File size: 3,282 Bytes
bb04c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# searcher/query_understanding.py

import yaml
from nltk.corpus import wordnet
import nltk

nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)


class QueryUnderstanding:
    """

    Expands and rewrites the raw user query before retrieval.



    Two strategies:

    1. WordNet synonym expansion β€” "apples" β†’ ["apple", "malus", "orchard", "fruit"]

    2. (Optional) T5 query rewriting β€” rephrase for better embedding alignment



    Why expand?

    - A search for "apples" should also find "fruit", "orchard", "nutrition"

    - Embedding models are good but can miss synonyms in different domains

    - Expansion bridges the vocabulary gap before we even hit FAISS

    """

    def __init__(self, config_path="config.yaml"):
        with open(config_path) as f:
            config = yaml.safe_load(f)
        self.max_synonyms = config.get("max_synonyms", 5)
        self.expansion_enabled = config.get("query_expansion", True)

    def expand(self, query: str) -> str:
        """

        Expand the query using WordNet synonyms.



        Args:

            query (str) β€” raw user query



        Returns:

            str β€” expanded query string (original + synonyms appended)



        Example:

            "apple nutrition" β†’ "apple nutrition fruit malus orchard dietary"

        """
        if not self.expansion_enabled:
            return query

        words = query.lower().split()
        synonyms = set()

        for word in words:
            for syn in wordnet.synsets(word):
                for lemma in syn.lemmas():
                    cleaned = lemma.name().replace("_", " ")
                    if cleaned.lower() != word:
                        synonyms.add(cleaned)
                if len(synonyms) >= self.max_synonyms:
                    break

        expanded = query
        if synonyms:
            expanded = query + " " + " ".join(list(synonyms)[:self.max_synonyms])

        return expanded

    def rewrite(self, query: str) -> str:
        """

        Lightweight query rewriting β€” normalise and clean the query.

        (Plug in a T5/FLAN-T5 model here for more powerful rewriting.)



        Args:

            query (str) β€” raw user query



        Returns:

            str β€” cleaned query

        """
        # Basic normalisation: strip extra spaces, lowercase
        query = " ".join(query.strip().split())
        return query

    def process(self, query: str) -> dict:
        """

        Full query understanding pipeline.



        Returns:

            dict with:

                original   β€” the raw input

                rewritten  β€” cleaned version

                expanded   β€” synonym-expanded version for dense retrieval

        """
        rewritten = self.rewrite(query)
        expanded = self.expand(rewritten)
        return {
            "original": query,
            "rewritten": rewritten,
            "expanded": expanded,
        }


if __name__ == "__main__":
    qu = QueryUnderstanding()
    result = qu.process("quarterly budget report")
    print(f"Original : {result['original']}")
    print(f"Rewritten: {result['rewritten']}")
    print(f"Expanded : {result['expanded']}")