File size: 3,978 Bytes
ddc5c21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# =============================================================
# File: backend/api/services/query_expander.py
# =============================================================
"""
Query expansion and disambiguation service.
Uses LLM to expand ambiguous queries and improve search results.
"""

import logging
import re
from typing import List, Dict, Any, Optional

from .llm_client import LLMClient


class QueryExpander:
    """Expands and disambiguates queries for better search results.

    Three independent strategies are offered:

    * ``expand_ambiguous_query`` asks the LLM client for alternative
      phrasings when the query contains a known abbreviation or a very
      short word.
    * ``expand_news_query`` builds variations with fixed templates
      ("latest ...", "... news", "... <year>").
    * ``expand_short_query`` spells out a handful of common abbreviations.
    """

    # Upper bound on the number of variations the list-returning methods emit.
    MAX_VARIATIONS = 5

    # Years appended by expand_news_query when the caller does not pass any.
    # NOTE(review): these are static; callers that need current results
    # should pass ``years`` explicitly.
    DEFAULT_NEWS_YEARS = (2024, 2025)

    _log = logging.getLogger(__name__)

    # A query is treated as ambiguous when (lower-cased) it matches any of these.
    _AMBIGUOUS_PATTERNS = (
        re.compile(r'\b(al|ai|ml|dl|nlp|api|ui|ux|db|sql|js|ts|py|go|rs)\b'),
        re.compile(r'\b[a-z]{1,2}\b'),  # Very short words
    )

    # Leading bullet ("-", "*", "•") or numbering ("1.", "2)") on an LLM line.
    # Numbering must be followed by whitespace so that a query such as
    # "3.5 inch drive" is not mistaken for a list item.
    _LIST_MARKER = re.compile(r'^(?:[-*\u2022]+\s*|\d+[.)]\s+)')

    _LATEST = re.compile(r'latest', re.IGNORECASE)

    # Whole-word, case-insensitive abbreviation patterns, tried in this order.
    # "al" presumably covers "AI" typed with a lowercase L -- TODO confirm.
    _SHORT_EXPANSIONS = (
        (re.compile(r'\bal\b', re.IGNORECASE), "artificial intelligence AI"),
        (re.compile(r'\bai\b', re.IGNORECASE), "artificial intelligence"),
        (re.compile(r'\bml\b', re.IGNORECASE), "machine learning"),
        (re.compile(r'\bdl\b', re.IGNORECASE), "deep learning"),
        (re.compile(r'\bnlp\b', re.IGNORECASE), "natural language processing"),
    )

    # The annotation is a string so that defining the class does not require
    # evaluating the project-local name at definition time.
    def __init__(self, llm_client: "LLMClient"):
        """Store the LLM client used by ``expand_ambiguous_query``."""
        self.llm = llm_client

    @staticmethod
    def _unique(candidates: List[str], limit: int) -> List[str]:
        """Return up to ``limit`` candidates, dropping case-insensitive repeats."""
        seen = set()
        unique = []
        for candidate in candidates:
            key = candidate.casefold()
            if key in seen:
                continue
            seen.add(key)
            unique.append(candidate)
            if len(unique) == limit:
                break
        return unique

    @classmethod
    def _parse_llm_lines(cls, response: str) -> List[str]:
        """Turn a raw LLM reply into a list of clean query strings."""
        parsed = []
        for raw_line in response.split('\n'):
            line = raw_line.strip()
            # Blank lines and markdown headings carry no query.
            if not line or line.startswith('#'):
                continue
            # Models often number or bullet their answer despite the prompt;
            # keep the query and discard only the marker.
            line = cls._LIST_MARKER.sub('', line, count=1).strip()
            if line:
                parsed.append(line)
        return parsed

    async def expand_ambiguous_query(self, query: str, context: Optional[str] = None) -> List[str]:
        """
        Generate multiple query variations for ambiguous terms.

        Args:
            query: Original query
            context: Optional context to help disambiguation

        Returns:
            List of at most ``MAX_VARIATIONS`` unique queries. The original
            query is always first. When the query is not considered
            ambiguous, or the LLM call or parsing fails, the list contains
            only the original query.
        """
        lowered = query.lower()
        if not any(pattern.search(lowered) for pattern in self._AMBIGUOUS_PATTERNS):
            return [query]  # Return original if not ambiguous

        # Use LLM to generate query variations
        prompt = f"""Given the user query: "{query}"

Generate 3-5 alternative search queries that could help find relevant information. 
Consider different interpretations, synonyms, and related terms.

{f"Context: {context}" if context else ""}

Return only the queries, one per line, without numbering or bullets:"""

        try:
            response = await self.llm.simple_call(prompt, temperature=0.3)
            suggestions = self._parse_llm_lines(response)
        except Exception:
            # Expansion is best-effort: search must still work with the
            # original query, but the failure should be visible in logs.
            self._log.warning(
                "LLM query expansion failed; falling back to the original query",
                exc_info=True,
            )
            return [query]

        return self._unique([query, *suggestions], self.MAX_VARIATIONS)

    def expand_news_query(self, query: str, years: Optional[List[int]] = None) -> List[str]:
        """
        Generate multiple variations for news queries.

        Args:
            query: News query
            years: Years to append as date-specific variations. Defaults to
                ``DEFAULT_NEWS_YEARS`` when omitted.

        Returns:
            List of at most ``MAX_VARIATIONS`` unique variations, original
            query first. A blank query is returned unchanged as the only item.
        """
        if not query.strip():
            return [query]

        lowered = query.lower()
        has_latest = "latest" in lowered
        variations = [query]

        # Add time-based variations
        if not has_latest:
            variations.append(f"latest {query}")
        if "news" not in lowered:
            variations.append(f"{query} news")
        if has_latest and "breaking" not in lowered:
            # Case-insensitive so "Latest ..." is rewritten too, matching the
            # case-insensitive detection above.
            variations.append(self._LATEST.sub("breaking", query))

        # Add date-specific variations
        for year in (self.DEFAULT_NEWS_YEARS if years is None else years):
            variations.append(f"{query} {year}")

        return self._unique(variations, self.MAX_VARIATIONS)

    def expand_short_query(self, query: str) -> str:
        """
        Expand very short queries with common expansions.

        Only queries of at most three whitespace-separated words are
        touched. The first abbreviation (in table order) that appears as a
        whole word, in any letter case, has its first occurrence replaced.

        Args:
            query: Short query

        Returns:
            Expanded query, or the query unchanged when nothing applies
        """
        if len(query.split()) > 3:
            return query

        for pattern, expansion in self._SHORT_EXPANSIONS:
            # A callable replacement keeps the expansion text literal.
            expanded, hits = pattern.subn(lambda _match: expansion, query, count=1)
            if hits:
                return expanded

        return query