pluto90 commited on
Commit
082f3f8
Β·
verified Β·
1 Parent(s): 59a7be2

Update app/graph/nodes/router.py

Browse files
Files changed (1) hide show
  1. app/graph/nodes/router.py +257 -235
app/graph/nodes/router.py CHANGED
@@ -1,235 +1,257 @@
1
- # app/graph/nodes/router.py
2
-
3
- from app.core.rag_service import get_rag_context
4
- from app.core.llm_engine import eval_llm # βœ… use eval_llm β€” faster, no thinking
5
- from langchain_core.output_parsers import StrOutputParser
6
- from langchain_core.prompts import PromptTemplate
7
- import ast
8
-
9
- # Query expansion prompt
10
- expansion_prompt = PromptTemplate(
11
- input_variables=["query"],
12
- template=(
13
- "Generate 3 short alternative phrasings of this question for document search.\n"
14
- "Cover singular/plural, synonyms, and sub-concepts.\n"
15
- "Return ONLY a Python list of strings, nothing else.\n"
16
- "Example: ['What is an array?', 'array data structure', 'arrays in programming']\n\n"
17
- "Question: {query}\n\n"
18
- "List:"
19
- )
20
- )
21
-
22
- # βœ… Sanity check β€” asks LLM if the retrieved context actually answers the query
23
- relevance_check_prompt = PromptTemplate(
24
- input_variables=["query", "context"],
25
- template=(
26
- "Does the following context contain enough information to answer the query?\n"
27
- "Reply with ONLY one word: yes or no\n\n"
28
- "Query: {query}\n\n"
29
- "Context: {context}\n\n"
30
- "Answer:"
31
- )
32
- )
33
-
34
-
35
-
36
- expansion_chain = expansion_prompt | eval_llm | StrOutputParser()
37
- relevance_chain = relevance_check_prompt | eval_llm | StrOutputParser()
38
-
39
-
40
- def expand_query(query: str) -> list:
41
- try:
42
- raw = expansion_chain.invoke({"query": query}).strip()
43
- expansions = ast.literal_eval(raw)
44
- if isinstance(expansions, list):
45
- return [query] + [q for q in expansions if q != query][:3]
46
- except Exception as e:
47
- print(f"QUERY EXPANSION FAILED β†’ {e}")
48
- return [query]
49
-
50
-
51
- def is_context_relevant(query: str, context: str) -> bool:
52
- """LLM confirms whether retrieved context actually answers the query."""
53
- try:
54
- answer = relevance_chain.invoke({
55
- "query": query,
56
- "context": context[:800] # βœ… cap context sent to LLM β€” saves tokens
57
- }).strip().lower()
58
- print(f"RELEVANCE CHECK β†’ '{answer}'")
59
- return answer.startswith("yes")
60
- except Exception as e:
61
- print(f"RELEVANCE CHECK FAILED β†’ {e}")
62
- return False # βœ… fail safe: if check fails, go general
63
-
64
-
65
- # def router_node(state):
66
- # query = state.get("query")
67
- # doc_id = state.get("doc_id")
68
-
69
- # # ── Stage 1: score original query ──────────────────────────────────────
70
- # original_context, original_sources, original_scores = get_rag_context(
71
- # query, doc_id, top_k=3
72
- # )
73
- # original_max_score = max(original_scores) if original_scores else 0.0
74
- # print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
75
-
76
- # HIGH_THRESHOLD = 0.70 # βœ… auto-RAG β€” very confident
77
- # LOW_THRESHOLD = 0.50 # βœ… below this β†’ always general, no LLM check needed
78
-
79
- # # ── Stage 2: definitive general (score too low) ─────────────────────────
80
- # if original_max_score < LOW_THRESHOLD:
81
- # print(f"ROUTER DECISION β†’ general | score too low: {original_max_score:.3f}")
82
- # return {
83
- # **state,
84
- # "route": "general",
85
- # "context": "",
86
- # "sources": [],
87
- # "score": original_max_score
88
- # }
89
-
90
- # # ── Stage 3: definitive RAG (score very high) ──────────────────────────
91
- # if original_max_score >= HIGH_THRESHOLD:
92
- # # Still expand to get more chunks, but don't need LLM sanity check
93
- # expanded_queries = expand_query(query)
94
- # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
95
- # all_contexts, all_scores, seen = _collect_chunks(
96
- # expanded_queries, original_context, original_scores, doc_id
97
- # )
98
- # merged = "\n\n---\n\n".join(all_contexts)
99
- # print(f"ROUTER DECISION β†’ rag (high confidence) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
100
- # return {
101
- # **state,
102
- # "route": "rag",
103
- # "context": merged,
104
- # "sources": all_contexts,
105
- # "score": original_max_score
106
- # }
107
-
108
- # # ── Stage 4: ambiguous zone (0.50–0.70) β†’ LLM sanity check ────────────
109
- # print(f"AMBIGUOUS SCORE β†’ {original_max_score:.3f} | running relevance check...")
110
- # context_is_relevant = is_context_relevant(query, original_context)
111
-
112
- # if not context_is_relevant:
113
- # print(f"ROUTER DECISION β†’ general | LLM says context doesn't answer query")
114
- # return {
115
- # **state,
116
- # "route": "general",
117
- # "context": "",
118
- # "sources": [],
119
- # "score": original_max_score
120
- # }
121
-
122
- # # Context confirmed relevant β€” expand and collect chunks
123
- # expanded_queries = expand_query(query)
124
- # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
125
- # all_contexts, all_scores, seen = _collect_chunks(
126
- # expanded_queries, original_context, original_scores, doc_id
127
- # )
128
- # merged = "\n\n---\n\n".join(all_contexts)
129
- # print(f"ROUTER DECISION β†’ rag (llm confirmed) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
130
-
131
- # return {
132
- # **state,
133
- # "route": "rag",
134
- # "context": merged,
135
- # "sources": all_contexts,
136
- # "score": original_max_score
137
- # }
138
-
139
-
140
-
141
-
142
-
143
- # app/graph/nodes/router.py
144
- # Add a third threshold zone between general and ambiguous
145
-
146
- def router_node(state):
147
- query = state.get("query")
148
- doc_id = state.get("doc_id")
149
-
150
- original_context, original_sources, original_scores = get_rag_context(
151
- query, doc_id, top_k=3
152
- )
153
- original_max_score = max(original_scores) if original_scores else 0.0
154
- print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
155
-
156
- HIGH_THRESHOLD = 0.70 # strong match β†’ RAG only
157
- HYBRID_THRESHOLD = 0.40 # weak match β†’ hybrid (doc snippet + general knowledge)
158
- LOW_THRESHOLD = 0.40 # below this β†’ pure general
159
-
160
- # Pure general β€” no document relevance at all
161
- if original_max_score < LOW_THRESHOLD:
162
- print(f"ROUTER DECISION β†’ general | score: {original_max_score:.3f}")
163
- return {**state, "route": "general", "context": "", "sources": [], "score": original_max_score}
164
-
165
- # Strong match β€” full RAG
166
- if original_max_score >= HIGH_THRESHOLD:
167
- expanded_queries = expand_query(query)
168
- print(f"EXPANDED QUERIES β†’ {expanded_queries}")
169
- all_contexts, all_scores, _ = _collect_chunks(
170
- expanded_queries, original_context, original_scores, doc_id
171
- )
172
- merged = "\n\n---\n\n".join(all_contexts)
173
- print(f"ROUTER DECISION β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
174
- return {**state, "route": "rag", "context": merged, "sources": all_contexts, "score": original_max_score}
175
-
176
- # Ambiguous zone (0.40–0.70) β€” LLM sanity check first
177
- print(f"AMBIGUOUS SCORE β†’ {original_max_score:.3f} | running relevance check...")
178
- context_is_relevant = is_context_relevant(query, original_context)
179
-
180
- if not context_is_relevant:
181
- # Doc has weak overlap but context doesn't actually answer it β†’ hybrid
182
- print(f"ROUTER DECISION β†’ hybrid | LLM says context partial")
183
- return {
184
- **state,
185
- "route": "hybrid",
186
- "context": original_context, # pass what we have β€” synthesizer will supplement
187
- "sources": [original_context],
188
- "score": original_max_score
189
- }
190
-
191
- # LLM confirmed context is relevant β€” full RAG with expansion
192
- expanded_queries = expand_query(query)
193
- print(f"EXPANDED QUERIES β†’ {expanded_queries}")
194
- all_contexts, all_scores, _ = _collect_chunks(
195
- expanded_queries, original_context, original_scores, doc_id
196
- )
197
- merged = "\n\n---\n\n".join(all_contexts)
198
- print(f"ROUTER DECISION β†’ rag (confirmed) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
199
- return {**state, "route": "rag", "context": merged, "sources": all_contexts, "score": original_max_score}
200
-
201
-
202
-
203
-
204
-
205
- def _collect_chunks(expanded_queries, original_context, original_scores, doc_id):
206
- """Merge chunks from original + expanded queries, deduplicating by text."""
207
- seen = set()
208
- all_contexts = []
209
- all_scores = []
210
-
211
- # Seed with original results
212
- for chunk, score in zip(original_context.split("\n\n---\n\n"), original_scores):
213
- chunk = chunk.strip()
214
- if chunk and chunk not in seen:
215
- seen.add(chunk)
216
- all_contexts.append(chunk)
217
- all_scores.append(score)
218
-
219
- # Add expanded query results
220
- for q in expanded_queries[1:]:
221
- ctx, _, scores = get_rag_context(q, doc_id, top_k=2)
222
- if ctx:
223
- for chunk, score in zip(ctx.split("\n\n---\n\n"), scores):
224
- chunk = chunk.strip()
225
- if chunk and chunk not in seen:
226
- seen.add(chunk)
227
- all_contexts.append(chunk)
228
- all_scores.append(score)
229
-
230
- return all_contexts, all_scores, seen
231
-
232
-
233
-
234
-
235
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/graph/nodes/router.py
2
+
3
+ from app.core.rag_service import get_rag_context
4
+ from app.core.llm_engine import eval_llm # βœ… use eval_llm β€” faster, no thinking
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import PromptTemplate
7
+ import ast, re
8
+
9
+ # Query expansion prompt
10
+ expansion_prompt = PromptTemplate(
11
+ input_variables=["query"],
12
+ template=(
13
+ "Generate 3 short alternative phrasings of this question for document search.\n"
14
+ "Cover singular/plural, synonyms, and sub-concepts.\n"
15
+ "Return ONLY a Python list of strings, nothing else.\n"
16
+ "Example: ['What is an array?', 'array data structure', 'arrays in programming']\n\n"
17
+ "Question: {query}\n\n"
18
+ "List:"
19
+ )
20
+ )
21
+
22
+ # βœ… Sanity check β€” asks LLM if the retrieved context actually answers the query
23
+ # relevance_check_prompt = PromptTemplate(
24
+ # input_variables=["query", "context"],
25
+ # template=(
26
+ # "Does the following context contain enough information to answer the query?\n"
27
+ # "Reply with ONLY one word: yes or no\n\n"
28
+ # "Query: {query}\n\n"
29
+ # "Context: {context}\n\n"
30
+ # "Answer:"
31
+ # )
32
+ # )
33
+
34
+
35
+
36
+ expansion_chain = expansion_prompt | llm | StrOutputParser()
37
+ # expansion_chain = expansion_prompt | eval_llm | StrOutputParser()
38
+ # relevance_chain = relevance_check_prompt | eval_llm | StrOutputParser()
39
+
40
+
41
+ def expand_query(query: str) -> list:
42
+ try:
43
+ raw = expansion_chain.invoke({"query": query}).strip()
44
+ # strip thinking tags if present
45
+ raw = re.sub(r"<thinking>.*?</thinking>", "", raw, flags=re.DOTALL).strip()
46
+
47
+ expansions = ast.literal_eval(raw)
48
+ if isinstance(expansions, list):
49
+ return [query] + [q for q in expansions if q != query][:3]
50
+ except Exception as e:
51
+ print(f"QUERY EXPANSION FAILED β†’ {e}")
52
+ return [query]
53
+
54
+ # invokes relevance chain
55
+ # def is_context_relevant(query: str, context: str) -> bool:
56
+ """LLM confirms whether retrieved context actually answers the query."""
57
+ try:
58
+ answer = relevance_chain.invoke({
59
+ "query": query,
60
+ "context": context[:800] # βœ… cap context sent to LLM β€” saves tokens
61
+ }).strip().lower()
62
+ print(f"RELEVANCE CHECK β†’ '{answer}'")
63
+ return answer.startswith("yes")
64
+ except Exception as e:
65
+ print(f"RELEVANCE CHECK FAILED β†’ {e}")
66
+ return False # βœ… fail safe: if check fails, go general
67
+
68
+
69
+ def _collect_chunks(expanded_queries, original_context, original_scores, doc_id):
70
+ seen = set()
71
+ all_contexts = []
72
+ all_scores = []
73
+
74
+ for chunk, score in zip(original_context.split("\n\n---\n\n"), original_scores):
75
+ chunk = chunk.strip()
76
+ if chunk and chunk not in seen:
77
+ seen.add(chunk)
78
+ all_contexts.append(chunk)
79
+ all_scores.append(score)
80
+
81
+ for q in expanded_queries[1:]:
82
+ ctx, _, scores = get_rag_context(q, doc_id, top_k=2)
83
+ if ctx:
84
+ for chunk, score in zip(ctx.split("\n\n---\n\n"), scores):
85
+ chunk = chunk.strip()
86
+ if chunk and chunk not in seen:
87
+ seen.add(chunk)
88
+ all_contexts.append(chunk)
89
+ all_scores.append(score)
90
+
91
+ return all_contexts, all_scores, seen
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+ # def router_node(state):
101
+ # query = state.get("query")
102
+ # doc_id = state.get("doc_id")
103
+
104
+ # # ── Stage 1: score original query ──────────────────────────────────────
105
+ # original_context, original_sources, original_scores = get_rag_context(
106
+ # query, doc_id, top_k=3
107
+ # )
108
+ # original_max_score = max(original_scores) if original_scores else 0.0
109
+ # print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
110
+
111
+ # HIGH_THRESHOLD = 0.70 # βœ… auto-RAG β€” very confident
112
+ # LOW_THRESHOLD = 0.50 # βœ… below this β†’ always general, no LLM check needed
113
+
114
+ # # ── Stage 2: definitive general (score too low) ─────────────────────────
115
+ # if original_max_score < LOW_THRESHOLD:
116
+ # print(f"ROUTER DECISION β†’ general | score too low: {original_max_score:.3f}")
117
+ # return {
118
+ # **state,
119
+ # "route": "general",
120
+ # "context": "",
121
+ # "sources": [],
122
+ # "score": original_max_score
123
+ # }
124
+
125
+ # # ── Stage 3: definitive RAG (score very high) ──────────────────────────
126
+ # if original_max_score >= HIGH_THRESHOLD:
127
+ # # Still expand to get more chunks, but don't need LLM sanity check
128
+ # expanded_queries = expand_query(query)
129
+ # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
130
+ # all_contexts, all_scores, seen = _collect_chunks(
131
+ # expanded_queries, original_context, original_scores, doc_id
132
+ # )
133
+ # merged = "\n\n---\n\n".join(all_contexts)
134
+ # print(f"ROUTER DECISION β†’ rag (high confidence) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
135
+ # return {
136
+ # **state,
137
+ # "route": "rag",
138
+ # "context": merged,
139
+ # "sources": all_contexts,
140
+ # "score": original_max_score
141
+ # }
142
+
143
+ # # ── Stage 4: ambiguous zone (0.50–0.70) β†’ LLM sanity check ────────────
144
+ # print(f"AMBIGUOUS SCORE β†’ {original_max_score:.3f} | running relevance check...")
145
+ # context_is_relevant = is_context_relevant(query, original_context)
146
+
147
+ # if not context_is_relevant:
148
+ # print(f"ROUTER DECISION β†’ general | LLM says context doesn't answer query")
149
+ # return {
150
+ # **state,
151
+ # "route": "general",
152
+ # "context": "",
153
+ # "sources": [],
154
+ # "score": original_max_score
155
+ # }
156
+
157
+ # # Context confirmed relevant β€” expand and collect chunks
158
+ # expanded_queries = expand_query(query)
159
+ # print(f"EXPANDED QUERIES β†’ {expanded_queries}")
160
+ # all_contexts, all_scores, seen = _collect_chunks(
161
+ # expanded_queries, original_context, original_scores, doc_id
162
+ # )
163
+ # merged = "\n\n---\n\n".join(all_contexts)
164
+ # print(f"ROUTER DECISION β†’ rag (llm confirmed) | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
165
+
166
+ # return {
167
+ # **state,
168
+ # "route": "rag",
169
+ # "context": merged,
170
+ # "sources": all_contexts,
171
+ # "score": original_max_score
172
+ # }
173
+
174
+
175
+
176
+
177
+
178
+ # app/graph/nodes/router.py
179
+ # Add a third threshold zone between general and ambiguous
180
+
181
+ def router_node(state):
182
+ query = state.get("query")
183
+ doc_id = state.get("doc_id")
184
+
185
+ original_context, original_sources, original_scores = get_rag_context(
186
+ query, doc_id, top_k=3
187
+ )
188
+ original_max_score = max(original_scores) if original_scores else 0.0
189
+ print(f"ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")
190
+
191
+ # HIGH_THRESHOLD = 0.70 # strong match β†’ RAG only
192
+ # HYBRID_THRESHOLD = 0.40 # weak match β†’ hybrid (doc snippet + general knowledge)
193
+ # LOW_THRESHOLD = 0.40 # below this β†’ pure general
194
+
195
+
196
+ # ── Thresholds ──────────────────────────────────────────────────────────
197
+ # Tuned for BGE-base-en cosine similarity.
198
+ # Scores below 0.55 on a DSA doc for "currency of romania" = noise, not match.
199
+ RAG_THRESHOLD = 0.60 # βœ… confident match β†’ RAG
200
+ HYBRID_THRESHOLD = 0.50 # βœ… partial match β†’ hybrid
201
+ # below 0.50 β†’ general
202
+
203
+
204
+
205
+ # Pure general β€” no document relevance at all
206
+ if original_max_score < HYBRID_THRESHOLD:
207
+ print(f"ROUTER DECISION β†’ general | score: {original_max_score:.3f}")
208
+ return {**state, "route": "general", "context": "", "sources": [], "score": original_max_score}
209
+
210
+ # ── RAG with expansion ──────────────────────────────────────────────────
211
+ if original_max_score >= RAG_THRESHOLD:
212
+ expanded_queries = expand_query(query)
213
+ print(f"EXPANDED QUERIES β†’ {expanded_queries}")
214
+ all_contexts, all_scores, _ = _collect_chunks(
215
+ expanded_queries, original_context, original_scores, doc_id
216
+ )
217
+ merged = "\n\n---\n\n".join(all_contexts)
218
+ print(f"ROUTER DECISION β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
219
+ return {
220
+ **state,
221
+ "route": "rag",
222
+ "context": merged,
223
+ "sources": all_contexts,
224
+ "score": original_max_score
225
+ }
226
+
227
+ if not context_is_relevant:
228
+ # Doc has weak overlap but context doesn't actually answer it β†’ hybrid
229
+ print(f"ROUTER DECISION β†’ hybrid | LLM says context partial")
230
+ return {
231
+ **state,
232
+ "route": "hybrid",
233
+ "context": original_context, # pass what we have β€” synthesizer will supplement
234
+ "sources": [original_context],
235
+ "score": original_max_score
236
+ }
237
+
238
+ # ── RAG with expansion ──────────────────────────────────────────────────
239
+ if original_max_score >= RAG_THRESHOLD:
240
+ expanded_queries = expand_query(query)
241
+ print(f"EXPANDED QUERIES β†’ {expanded_queries}")
242
+ all_contexts, all_scores, _ = _collect_chunks(
243
+ expanded_queries, original_context, original_scores, doc_id
244
+ )
245
+ merged = "\n\n---\n\n".join(all_contexts)
246
+ print(f"ROUTER DECISION β†’ rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
247
+ return {
248
+ **state,
249
+ "route": "rag",
250
+ "context": merged,
251
+ "sources": all_contexts,
252
+ "score": original_max_score
253
+ }
254
+
255
+
256
+
257
+