pluto90 committed on
Commit
edde763
·
verified ·
1 Parent(s): 26f5539

Update app/graph/nodes/router.py

Browse files
Files changed (1) hide show
  1. app/graph/nodes/router.py +132 -216
app/graph/nodes/router.py CHANGED
@@ -1,257 +1,173 @@
1
- # app/graph/nodes/router.py
2
 
 
 
3
  from app.core.rag_service import get_rag_context
4
- from app.core.llm_engine import llm # ✅ use eval_llm — faster, no thinking
5
  from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.prompts import PromptTemplate
7
- import ast, re
8
 
9
- # Query expansion prompt
10
  expansion_prompt = PromptTemplate(
11
  input_variables=["query"],
12
  template=(
13
- "Generate 3 short alternative phrasings of this question for document search.\n"
14
- "Cover singular/plural, synonyms, and sub-concepts.\n"
15
- "Return ONLY a Python list of strings, nothing else.\n"
16
- "Example: ['What is an array?', 'array data structure', 'arrays in programming']\n\n"
17
- "Question: {query}\n\n"
 
 
18
  "List:"
19
  )
20
  )
21
 
22
- # ✅ Sanity check — asks LLM if the retrieved context actually answers the query
23
- # relevance_check_prompt = PromptTemplate(
24
- # input_variables=["query", "context"],
25
- # template=(
26
- # "Does the following context contain enough information to answer the query?\n"
27
- # "Reply with ONLY one word: yes or no\n\n"
28
- # "Query: {query}\n\n"
29
- # "Context: {context}\n\n"
30
- # "Answer:"
31
- # )
32
- # )
33
-
34
-
35
-
36
  expansion_chain = expansion_prompt | llm | StrOutputParser()
37
- # expansion_chain = expansion_prompt | eval_llm | StrOutputParser()
38
- # relevance_chain = relevance_check_prompt | eval_llm | StrOutputParser()
39
-
40
 
41
  def expand_query(query: str) -> list:
 
42
  try:
43
  raw = expansion_chain.invoke({"query": query}).strip()
44
- # strip thinking tags if present
45
- raw = re.sub(r"<thinking>.*?</thinking>", "", raw, flags=re.DOTALL).strip()
 
 
 
46
 
47
- expansions = ast.literal_eval(raw)
48
- if isinstance(expansions, list):
49
- return [query] + [q for q in expansions if q != query][:3]
 
 
 
 
 
 
50
  except Exception as e:
51
- print(f"QUERY EXPANSION FAILED β†’ {e}")
 
52
  return [query]
53
 
54
- # invokes relevance chain
55
- # def is_context_relevant(query: str, context: str) -> bool:
56
- """LLM confirms whether retrieved context actually answers the query."""
57
- try:
58
- answer = relevance_chain.invoke({
59
- "query": query,
60
- "context": context[:800] # ✅ cap context sent to LLM — saves tokens
61
- }).strip().lower()
62
- print(f"RELEVANCE CHECK β†’ '{answer}'")
63
- return answer.startswith("yes")
64
- except Exception as e:
65
- print(f"RELEVANCE CHECK FAILED β†’ {e}")
66
- return False # ✅ fail safe: if check fails, go general
67
-
68
-
69
- def _collect_chunks(expanded_queries, original_context, original_scores, doc_id):
70
- seen = set()
71
- all_contexts = []
72
- all_scores = []
73
-
74
- for chunk, score in zip(original_context.split("\n\n---\n\n"), original_scores):
75
- chunk = chunk.strip()
76
- if chunk and chunk not in seen:
77
- seen.add(chunk)
78
- all_contexts.append(chunk)
79
- all_scores.append(score)
80
-
81
- for q in expanded_queries[1:]:
82
- ctx, _, scores = get_rag_context(q, doc_id, top_k=2)
83
- if ctx:
84
- for chunk, score in zip(ctx.split("\n\n---\n\n"), scores):
85
- chunk = chunk.strip()
86
- if chunk and chunk not in seen:
87
- seen.add(chunk)
88
- all_contexts.append(chunk)
89
- all_scores.append(score)
90
-
91
- return all_contexts, all_scores, seen
92
-
93
-
94
-
95
-
96
-
97
-
98
-
99
-
100
- # def router_node(state):
101
- # query = state.get("query")
102
- # doc_id = state.get("doc_id")
103
-
104
- # # ── Stage 1: score original query ─────────────────────────────────────
105
- # original_context, original_sources, original_scores = get_rag_context(
106
- # query, doc_id, top_k=3
107
- # )
108
- # original_max_score = max(original_scores) if original_scores else 0.0
109
- # print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
110
-
111
- # HIGH_THRESHOLD = 0.70 # ✅ auto-RAG — very confident
112
- # LOW_THRESHOLD = 0.50 # ✅ below this → always general, no LLM check needed
113
-
114
- # # ── Stage 2: definitive general (score too low) ─────────────────────────
115
- # if original_max_score < LOW_THRESHOLD:
116
- # print(f"ROUTER DECISION β†’ general | score too low: {original_max_score:.3f}")
117
- # return {
118
- # **state,
119
- # "route": "general",
120
- # "context": "",
121
- # "sources": [],
122
- # "score": original_max_score
123
- # }
124
-
125
- # # ── Stage 3: definitive RAG (score very high) ──────────────────────────
126
- # if original_max_score >= HIGH_THRESHOLD:
127
- # # Still expand to get more chunks, but don't need LLM sanity check
128
- # expanded_queries = expand_query(query)
129
- # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
130
- # all_contexts, all_scores, seen = _collect_chunks(
131
- # expanded_queries, original_context, original_scores, doc_id
132
- # )
133
- # merged = "\n\n---\n\n".join(all_contexts)
134
- # print(f"ROUTER DECISION β†’ rag (high confidence) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
135
- # return {
136
- # **state,
137
- # "route": "rag",
138
- # "context": merged,
139
- # "sources": all_contexts,
140
- # "score": original_max_score
141
- # }
142
-
143
- # # ── Stage 4: ambiguous zone (0.50–0.70) → LLM sanity check ────────────
144
- # print(f"AMBIGUOUS SCORE β†’ {original_max_score:.3f} | running relevance check...")
145
- # context_is_relevant = is_context_relevant(query, original_context)
146
-
147
- # if not context_is_relevant:
148
- # print(f"ROUTER DECISION β†’ general | LLM says context doesn't answer query")
149
- # return {
150
- # **state,
151
- # "route": "general",
152
- # "context": "",
153
- # "sources": [],
154
- # "score": original_max_score
155
- # }
156
-
157
- # # Context confirmed relevant — expand and collect chunks
158
- # expanded_queries = expand_query(query)
159
- # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
160
- # all_contexts, all_scores, seen = _collect_chunks(
161
- # expanded_queries, original_context, original_scores, doc_id
162
- # )
163
- # merged = "\n\n---\n\n".join(all_contexts)
164
- # print(f"ROUTER DECISION β†’ rag (llm confirmed) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
165
-
166
- # return {
167
- # **state,
168
- # "route": "rag",
169
- # "context": merged,
170
- # "sources": all_contexts,
171
- # "score": original_max_score
172
- # }
173
-
174
-
175
-
176
-
177
-
178
- # app/graph/nodes/router.py
179
- # Add a third threshold zone between general and ambiguous
180
 
181
  def router_node(state):
 
 
 
 
182
  query = state.get("query")
183
  doc_id = state.get("doc_id")
184
 
 
185
  original_context, original_sources, original_scores = get_rag_context(
186
- query, doc_id, top_k=3
187
  )
188
  original_max_score = max(original_scores) if original_scores else 0.0
189
- print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
190
-
191
- # HIGH_THRESHOLD = 0.70 # strong match → RAG only
192
- # HYBRID_THRESHOLD = 0.40 # weak match → hybrid (doc snippet + general knowledge)
193
- # LOW_THRESHOLD = 0.40 # below this → pure general
194
 
 
 
 
195
 
196
- # ── Thresholds ──────────────────────────────────────────────────────────
197
- # Tuned for BGE-base-en cosine similarity.
198
- # Scores below 0.55 on a DSA doc for "currency of romania" = noise, not match.
199
- RAG_THRESHOLD = 0.60 # ✅ confident match → RAG
200
- HYBRID_THRESHOLD = 0.50 # ✅ partial match → hybrid
201
- # below 0.50 → general
202
-
203
-
204
-
205
- # Pure general — no document relevance at all
206
  if original_max_score < HYBRID_THRESHOLD:
207
- print(f"ROUTER DECISION β†’ general | score: {original_max_score:.3f}")
208
- return {**state, "route": "general", "context": "", "sources": [], "score": original_max_score}
209
-
210
- # ── RAG with expansion ──────────────────────────────────────────────────
211
- if original_max_score >= RAG_THRESHOLD:
212
- expanded_queries = expand_query(query)
213
- print(f"EXPANDED QUERIES β†’ {expanded_queries}")
214
- all_contexts, all_scores, _ = _collect_chunks(
215
- expanded_queries, original_context, original_scores, doc_id
216
- )
217
- merged = "\n\n---\n\n".join(all_contexts)
218
- print(f"ROUTER DECISION β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
219
  return {
220
- **state,
221
- "route": "rag",
222
- "context": merged,
223
- "sources": all_contexts,
224
- "score": original_max_score
225
- }
226
-
227
- if not context_is_relevant:
228
- # Doc has weak overlap but context doesn't actually answer it → hybrid
229
- print(f"ROUTER DECISION β†’ hybrid | LLM says context partial")
230
- return {
231
- **state,
232
- "route": "hybrid",
233
- "context": original_context, # pass what we have — synthesizer will supplement
234
- "sources": [original_context],
235
  "score": original_max_score
236
  }
237
 
238
- # ── RAG with expansion ──────────────────────────────────────────────────
239
- if original_max_score >= RAG_THRESHOLD:
240
- expanded_queries = expand_query(query)
241
- print(f"EXPANDED QUERIES β†’ {expanded_queries}")
242
- all_contexts, all_scores, _ = _collect_chunks(
243
- expanded_queries, original_context, original_scores, doc_id
244
- )
245
- merged = "\n\n---\n\n".join(all_contexts)
246
- print(f"ROUTER DECISION β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
247
  return {
248
  **state,
249
- "route": "rag",
250
- "context": merged,
251
- "sources": all_contexts,
252
- "score": original_max_score
253
  }
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
-
257
-
 
 
1
 
2
+
3
+ # app/graph/nodes/router.py
4
  from app.core.rag_service import get_rag_context
5
+ from app.core.llm_engine import llm # ✅ Use main llm, not eval_llm
6
  from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.prompts import PromptTemplate
8
+ import ast
9
 
10
+ # ✅ IMPROVED: Multi-strategy expansion
11
  expansion_prompt = PromptTemplate(
12
  input_variables=["query"],
13
  template=(
14
+ "Generate 4 diverse search queries for: '{query}'\n\n"
15
+ "1. Synonym variation (different words, same meaning)\n"
16
+ "2. Acronym/abbreviation expansion (if applicable)\n"
17
+ "3. Broader concept query\n"
18
+ "4. Technical detail query\n\n"
19
+ "Return ONLY a Python list of 4 strings.\n"
20
+ "Example: ['machine learning algorithms', 'ML techniques', 'artificial intelligence methods', 'supervised learning models']\n\n"
21
  "List:"
22
  )
23
  )
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  expansion_chain = expansion_prompt | llm | StrOutputParser()
 
 
 
26
 
27
  def expand_query(query: str) -> list:
28
+ """Generate diverse query variations for better retrieval."""
29
  try:
30
  raw = expansion_chain.invoke({"query": query}).strip()
31
+ # Handle both list format and line-by-line format
32
+ if raw.startswith('['):
33
+ expansions = ast.literal_eval(raw)
34
+ else:
35
+ expansions = [line.strip().strip('"-') for line in raw.split('\n') if line.strip()]
36
 
37
+ if isinstance(expansions, list) and expansions:
38
+ # Return original + up to 4 unique expansions
39
+ unique_expansions = [query]
40
+ for exp in expansions:
41
+ if exp and exp != query and exp not in unique_expansions:
42
+ unique_expansions.append(exp)
43
+ if len(unique_expansions) >= 5: # Original + 4 expansions
44
+ break
45
+ return unique_expansions
46
  except Exception as e:
47
+ print(f"⚠️ QUERY EXPANSION FAILED β†’ {e}")
48
+
49
  return [query]
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def router_node(state):
53
+ """
54
+ Pure score-based routing with smart threshold zones.
55
+ NO LLM calls during routing — saves quota and latency.
56
+ """
57
  query = state.get("query")
58
  doc_id = state.get("doc_id")
59
 
60
+ # ✅ Initial retrieval with relaxed threshold
61
  original_context, original_sources, original_scores = get_rag_context(
62
+ query, doc_id, top_k=5 # ✅ Increased from 3
63
  )
64
  original_max_score = max(original_scores) if original_scores else 0.0
65
+ print(f"πŸ“Š ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
 
 
 
 
66
 
67
+ # ✅ IMPROVED THRESHOLDS
68
+ HIGH_THRESHOLD = 0.50 # Lowered from 0.55 — strong match → RAG
69
+ HYBRID_THRESHOLD = 0.28 # Lowered from 0.30 — weak match → hybrid
70
 
71
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
72
+ # ZONE 1: Pure General (< 0.28)
73
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
 
 
 
 
 
 
74
  if original_max_score < HYBRID_THRESHOLD:
75
+ print(f"πŸ”€ ROUTER β†’ general | score: {original_max_score:.3f}")
 
 
 
 
 
 
 
 
 
 
 
76
  return {
77
+ **state,
78
+ "route": "general",
79
+ "context": "",
80
+ "sources": [],
 
 
 
 
 
 
 
 
 
 
 
81
  "score": original_max_score
82
  }
83
 
84
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
85
+ # ZONE 2: Hybrid (0.28 - 0.50)
86
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
87
+ if HYBRID_THRESHOLD <= original_max_score < HIGH_THRESHOLD:
88
+ print(f"πŸ”€ ROUTER β†’ hybrid | score: {original_max_score:.3f}")
89
+ # Pass the best chunks we have, synthesizer will supplement with general knowledge
 
 
 
90
  return {
91
  **state,
92
+ "route": "hybrid",
93
+ "context": original_context,
94
+ "sources": original_sources,
95
+ "score": original_max_score
96
  }
97
 
98
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
99
+ # ZONE 3: Full RAG (≥ 0.50) — Query Expansion + Fusion
100
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
101
+ print(f"πŸš€ HIGH SCORE β†’ Expanding query for better coverage...")
102
+ expanded_queries = expand_query(query)
103
+ print(f"πŸ“ EXPANDED QUERIES β†’ {expanded_queries}")
104
+
105
+ # Collect chunks from all query variations
106
+ all_contexts, all_scores = _collect_chunks_with_rrf(
107
+ expanded_queries, doc_id
108
+ )
109
+
110
+ if not all_contexts:
111
+ # Fallback to original context if expansion failed
112
+ all_contexts = original_sources
113
+
114
+ merged = "\n\n---\n\n".join(all_contexts[:10]) # ✅ Cap at 10 chunks
115
+
116
+ print(f"🎯 ROUTER β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
117
+ return {
118
+ **state,
119
+ "route": "rag",
120
+ "context": merged,
121
+ "sources": all_contexts[:10],
122
+ "score": original_max_score
123
+ }
124
+
125
+
126
+ def _collect_chunks_with_rrf(queries: list, doc_id: str, k: int = 60):
127
+ """
128
+ Reciprocal Rank Fusion across multiple queries.
129
+ RRF formula: score = Σ(1 / (k + rank)) for each query
130
+ """
131
+ from collections import defaultdict
132
+
133
+ chunk_scores = defaultdict(lambda: {'text': '', 'rrf_score': 0.0, 'max_sim': 0.0})
134
+
135
+ for query in queries:
136
+ context, sources, scores = get_rag_context(query, doc_id, top_k=8)
137
+
138
+ if not context:
139
+ continue
140
+
141
+ chunks = context.split("\n\n---\n\n")
142
+
143
+ for rank, (chunk, score) in enumerate(zip(chunks, scores), start=1):
144
+ chunk = chunk.strip()
145
+ if not chunk:
146
+ continue
147
+
148
+ # Use chunk text as key for deduplication
149
+ chunk_key = chunk[:100] # First 100 chars as unique identifier
150
+
151
+ # RRF score accumulation
152
+ chunk_scores[chunk_key]['text'] = chunk
153
+ chunk_scores[chunk_key]['rrf_score'] += 1.0 / (k + rank)
154
+ chunk_scores[chunk_key]['max_sim'] = max(
155
+ chunk_scores[chunk_key]['max_sim'],
156
+ score
157
+ )
158
+
159
+ # Sort by RRF score (primary) and max similarity (tiebreaker)
160
+ sorted_chunks = sorted(
161
+ chunk_scores.values(),
162
+ key=lambda x: (x['rrf_score'], x['max_sim']),
163
+ reverse=True
164
+ )
165
+
166
+ texts = [item['text'] for item in sorted_chunks]
167
+ scores = [item['max_sim'] for item in sorted_chunks]
168
+
169
+ print(f"βœ… RRF FUSION β†’ {len(texts)} unique chunks from {len(queries)} queries")
170
+
171
+ return texts, scores
172
 
173
+