File size: 8,373 Bytes
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2473a9
f9ad313
 
 
 
c2473a9
 
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ecac6
f9ad313
 
 
7562827
8823302
7562827
f9ad313
 
 
 
 
 
 
 
 
 
 
 
8823302
f9ad313
 
 
 
 
8823302
 
 
 
 
 
 
f9ad313
 
 
 
 
 
c2473a9
 
 
 
 
 
 
 
f9ad313
 
 
 
c2473a9
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2473a9
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7562827
 
 
 
 
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ecac6
 
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""

Query Router - Decides between RAG, SQL, hybrid, or general-conversation handling.

Analyzes user intent and routes each query to the appropriate handler.

"""

import logging
from enum import Enum
from typing import Dict, Any, Optional, Tuple, List
from dataclasses import dataclass

# Module-level logger named after this module; the router uses it to warn
# when LLM-based routing fails and the heuristic fallback takes over.
logger = logging.getLogger(__name__)


class QueryType(Enum):
    """Handler categories a user query can be routed to."""

    # Semantic search in text
    RAG = "rag"
    # Structured query
    SQL = "sql"
    # Both RAG and SQL
    HYBRID = "hybrid"
    # General conversation
    GENERAL = "general"


@dataclass
class RoutingDecision:
    """Result of routing one user query.

    ``suggested_tables`` and ``token_usage`` accept ``None`` (the default, and
    also what callers may pass explicitly); ``__post_init__`` replaces ``None``
    with a fresh empty list / zeroed usage dict so instances never share a
    mutable default.
    """

    # Handler category selected for the query.
    query_type: QueryType
    # Confidence attached to the decision.
    confidence: float
    # Short explanation of why this route was chosen.
    reasoning: str
    # Table names considered relevant; None is normalised to [].
    suggested_tables: Optional[List[str]] = None
    # Token counts keyed by "input", "output" and "total"; None is normalised
    # to all zeros.
    token_usage: Optional[Dict[str, int]] = None

    def __post_init__(self) -> None:
        """Replace ``None`` placeholders with per-instance default containers."""
        if self.suggested_tables is None:
            self.suggested_tables = []
        if self.token_usage is None:
            self.token_usage = {"input": 0, "output": 0, "total": 0}


class QueryRouter:
    """Routes queries to appropriate handlers based on intent analysis.

    When an LLM client is configured, ``route`` asks the model to classify the
    query using ``ROUTING_PROMPT`` and parses the reply. Without a client, or
    when the LLM call raises, routing falls back to keyword heuristics.

    The client is expected to expose ``chat(messages)`` returning an object
    with ``content``, ``input_tokens``, ``output_tokens`` and ``total_tokens``
    attributes, since those are the members ``route`` reads.
    """

    ROUTING_PROMPT = """Analyze this user query and determine the best approach to answer it.



DATABASE SCHEMA:

{schema}



USER QUERY: {query}



Determine if this query needs:

1. RAG - Semantic search through text content (searching for meanings, concepts, descriptions)

2. SQL - Structured database query (counting, filtering, aggregating, specific lookups, OR pagination requests like "show more", "show other", "next results", "remaining items", OR subjective filtering like "for kids", "summer shoes", "rainy season" which map to columns)

3. HYBRID - Both semantic search and structured query

4. GENERAL - General conversation not requiring database access



IMPORTANT: If the user asks to "show more", "show other", "see remaining", "next results", or similar - this is a PAGINATION request and should be routed to SQL, NOT GENERAL.

5. REFERENTIAL/AFFIRMATIVE: If the query is simply "yes", "sure", "ok", "please", or "do it", check if it's likely a confirmation to a previous offer (like "would you like to see 10 more?"). If so, this is likely SQL (pagination or new query). If ambiguous, default to GENERAL.



Respond in this exact format:

TYPE: [RAG|SQL|HYBRID|GENERAL]

CONFIDENCE: [0.0-1.0]

TABLES: [comma-separated list of relevant tables, or NONE]

REASONING: [brief explanation]"""

    # Number of characters of the previous assistant message that are appended
    # to the query as conversational context.
    _PREV_CONTEXT_CHARS = 200

    # SQL keywords - for structured data retrieval. Built once at class level
    # instead of on every heuristic call.
    _SQL_KEYWORDS = (
        'how many', 'count', 'total', 'average', 'sum', 'max', 'min',
        'list all', 'show all', 'find all', 'get all', 'between',
        'greater than', 'less than', 'equal to', 'top', 'bottom',
        # Data listing patterns
        'what products', 'what customers', 'what orders', 'what items',
        'show me', 'list', 'display', 'give me', 'get me',
        'all products', 'all customers', 'all orders',
        'products do you have', 'customers do you have',
        'from new york', 'from chicago', 'from los angeles',
        # Specific lookups
        'price of', 'cost of', 'stock of', 'quantity',
        'where', 'which', 'who',
        # Pagination / follow-up requests
        'show more', 'show other', 'show rest', 'show remaining',
        'more results', 'next', 'remaining', 'rest of', 'other also',
        'continue', 'keep going', 'see more', 'view more',
    )

    # RAG keywords - for semantic/conceptual questions.
    _RAG_KEYWORDS = (
        'what is the policy', 'explain', 'describe', 'tell me about',
        'meaning of', 'definition', 'why', 'how does', 'what does',
        'similar to', 'return policy', 'shipping policy', 'warranty',
        'support', 'help with', 'information about', 'details about',
    )

    # A data noun combined with a listing word boosts the SQL score.
    _DATA_NOUNS = ('products', 'customers', 'orders', 'items')
    _LISTING_WORDS = ('what', 'show', 'list', 'all', 'have')

    # Nouns that make SQL the default when no keyword matched at all.
    _DEFAULT_SQL_NOUNS = ('products', 'customers', 'orders')

    def __init__(self, llm_client: Optional[Any] = None):
        """Create a router; with no ``llm_client`` only heuristics are used."""
        self.llm_client = llm_client

    def set_llm_client(self, llm_client: Optional[Any]) -> None:
        """Install or replace the LLM client used by ``route``."""
        self.llm_client = llm_client

    def route(self, query: str, schema_context: str, chat_history: Optional[List[Dict]] = None) -> RoutingDecision:
        """Analyze query and determine routing.

        Args:
            query: The user's message.
            schema_context: Text substituted into the prompt's schema section.
            chat_history: Optional list of message dicts. Only the last entry
                is inspected, and only when its ``role`` is ``"assistant"``.

        Returns:
            The parsed LLM decision, or a heuristic decision when no client is
            configured or the LLM call fails.
        """
        if not self.llm_client:
            # Fallback to simple heuristics
            return self._heuristic_route(query)

        prev_context = ""
        if chat_history:
            last_msg = chat_history[-1]
            if isinstance(last_msg, dict) and last_msg.get("role") == "assistant":
                # "content" may be missing, None or non-string; coerce before
                # slicing so a malformed history entry cannot raise here,
                # outside the try block below.
                content = str(last_msg.get("content") or "")
                prev_context = (
                    f"\nPREVIOUS ASSISTANT MSG: {content[:self._PREV_CONTEXT_CHARS]}..."
                )

        prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query + prev_context)

        try:
            response = self.llm_client.chat([
                {"role": "system", "content": "You are a query routing assistant."},
                {"role": "user", "content": prompt}
            ])

            usage = {
                "input": response.input_tokens,
                "output": response.output_tokens,
                "total": response.total_tokens
            }

            return self._parse_routing_response(response.content, usage)
        except Exception as e:
            # Deliberate best-effort boundary: any client or parsing failure
            # degrades to heuristic routing instead of failing the request.
            logger.warning("LLM routing failed: %s, using heuristics", e)
            return self._heuristic_route(query)

    def _parse_routing_response(self, response: str, usage: Optional[Dict[str, int]] = None) -> RoutingDecision:
        """Parse LLM routing response.

        Reads ``TYPE``, ``CONFIDENCE``, ``TABLES`` and ``REASONING`` lines.
        Labels are matched case-insensitively, and surrounding markdown
        asterisks or the square brackets shown in the prompt's format template
        (for example ``TYPE: [SQL]``) are tolerated.

        Args:
            response: Raw text returned by the LLM; ``None`` or empty yields
                the defaults.
            usage: Token-usage dict passed through to the decision.

        Returns:
            A ``RoutingDecision``. Defaults are GENERAL, confidence 0.5, no
            tables and empty reasoning. An unknown type maps to GENERAL, an
            unparsable confidence to 0.5, and a numeric confidence is clamped
            to the 0.0-1.0 range.
        """
        query_type = QueryType.GENERAL
        confidence = 0.5
        tables: List[str] = []
        reasoning = ""

        for raw_line in (response or "").strip().split('\n'):
            label, sep, value = raw_line.strip().partition(":")
            if not sep:
                continue

            # Drop list bullets / markdown emphasis around the label.
            label = label.strip("*#- \t").upper()
            value = value.strip()
            # Structured fields: also drop emphasis and template brackets.
            token = value.strip("*[] \t")

            if label == "TYPE":
                query_type = QueryType.__members__.get(token.upper(), QueryType.GENERAL)
            elif label == "CONFIDENCE":
                try:
                    parsed = float(token)
                except ValueError:
                    confidence = 0.5
                else:
                    if parsed != parsed:
                        # NaN carries no usable confidence; treat as unparsable.
                        confidence = 0.5
                    else:
                        confidence = min(1.0, max(0.0, parsed))
            elif label == "TABLES":
                if token.upper() != "NONE":
                    # Skip empty entries so "TABLES:" or "a,,b" never yields ''.
                    tables = [
                        name
                        for name in (part.strip() for part in token.split(","))
                        if name
                    ]
            elif label == "REASONING":
                # Only the label prefix is removed; the explanation text itself
                # is kept even if it repeats the word "REASONING:".
                reasoning = value.lstrip("*").strip()

        return RoutingDecision(query_type, confidence, reasoning, tables, token_usage=usage)

    def _heuristic_route(self, query: str) -> RoutingDecision:
        """Simple heuristic-based routing when LLM is unavailable.

        Counts how many SQL and RAG keywords occur in the lower-cased query,
        adds 2 to the SQL score when a data noun appears together with a
        listing word, and picks the higher score. Ties are resolved in order:
        consultative phrases -> GENERAL, both scores positive -> HYBRID,
        a data-table noun -> SQL, otherwise RAG.
        """
        query_lower = query.lower()

        # NOTE(review): matching is plain substring containment, so short
        # keywords also hit inside longer words (e.g. 'sum' in 'summer').
        # Kept as-is to preserve existing routing results.
        sql_score = sum(1 for kw in self._SQL_KEYWORDS if kw in query_lower)
        rag_score = sum(1 for kw in self._RAG_KEYWORDS if kw in query_lower)

        # Boost SQL score for common listing patterns
        mentions_data = any(word in query_lower for word in self._DATA_NOUNS)
        if mentions_data and any(word in query_lower for word in self._LISTING_WORDS):
            sql_score += 2

        if sql_score > rag_score:
            return RoutingDecision(QueryType.SQL, 0.8, "SQL query for data retrieval")
        if rag_score > sql_score:
            return RoutingDecision(QueryType.RAG, 0.8, "Semantic search for concepts")

        # From here on the two scores are equal.
        if "is it good" in query_lower or "consider other" in query_lower:
            return RoutingDecision(QueryType.GENERAL, 0.7, "Consultative question about metrics")
        if sql_score > 0 and rag_score > 0:
            return RoutingDecision(QueryType.HYBRID, 0.6, "Mixed query type")

        # Default to SQL for simple questions about data
        if any(word in query_lower for word in self._DEFAULT_SQL_NOUNS):
            return RoutingDecision(QueryType.SQL, 0.6, "Default to SQL for data tables")
        return RoutingDecision(QueryType.RAG, 0.5, "Default to semantic search")


# Shared router instance; stays None until get_query_router() is first called.
_router: Optional[QueryRouter] = None


def get_query_router() -> QueryRouter:
    """Return the module-wide ``QueryRouter``, creating it on first use.

    The instance is constructed with no arguments, so it has no LLM client
    until one is assigned via ``set_llm_client``.
    """
    global _router
    if _router is not None:
        return _router
    _router = QueryRouter()
    return _router