XQ committed on
Commit
4ba88df
·
1 Parent(s): ec64993

Update router logic

Browse files
Files changed (2) hide show
  1. src/agent/router.py +90 -11
  2. tests/test_router.py +61 -2
src/agent/router.py CHANGED
@@ -15,6 +15,11 @@ from src.retrieval.reranker import Reranker
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
18
 
19
  class RouterState(TypedDict):
20
  """LangGraph state passed between routing nodes.
@@ -31,6 +36,7 @@ class RouterState(TypedDict):
31
  fused_results: Results after RRF fusion.
32
  reranked: Results after cross-encoder reranking.
33
  confidence: Max reranker score (0.0-1.0).
 
34
  answer: Final generated answer.
35
  """
36
 
@@ -45,6 +51,7 @@ class RouterState(TypedDict):
45
  fused_results: list[QueryResult]
46
  reranked: list[QueryResult]
47
  confidence: float
 
48
  answer: str
49
 
50
 
@@ -70,6 +77,7 @@ def _make_initial_state(query: str, top_k: int) -> RouterState:
70
  fused_results=[],
71
  reranked=[],
72
  confidence=0.0,
 
73
  answer="",
74
  )
75
 
@@ -274,6 +282,53 @@ class QueryRouter:
274
  logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
275
  return {"reranked": reranked, "confidence": confidence}
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  @staticmethod
278
  def _update_intent_node(state: RouterState) -> dict:
279
  """Promote FACTUAL to RAG when sources are found."""
@@ -295,19 +350,28 @@ class QueryRouter:
295
 
296
  @staticmethod
297
  def _should_retrieve(state: RouterState) -> str:
298
- """Skip retrieval when intent is UNKNOWN."""
299
- return "retrieve" if state["intent"] != IntentType.UNKNOWN else "rerank"
300
 
301
  def _build_graph(self) -> object:
302
  """Build the LangGraph routing graph.
303
 
304
- Nodes:
305
- detect → detect language and intent
306
- translate → translate query to Danish if needed
307
- retrieve → hybrid search (skipped when intent is UNKNOWN)
308
- rerank → cross-encoder reranking
309
- update_intent → promote FACTUAL to RAG when sources are found
310
- generate → build prompt and call LLM
 
 
 
 
 
 
 
 
 
311
 
312
  Returns:
313
  Compiled LangGraph graph.
@@ -317,18 +381,30 @@ class QueryRouter:
317
  graph.add_node("translate", self._translate_node)
318
  graph.add_node("retrieve", self._retrieve_node)
319
  graph.add_node("rerank", self._rerank_node)
 
320
  graph.add_node("update_intent", self._update_intent_node)
321
  graph.add_node("generate", self._generate_node)
322
 
323
  graph.set_entry_point("detect")
324
  graph.add_edge("detect", "translate")
 
 
325
  graph.add_conditional_edges(
326
  "translate",
327
  self._should_retrieve,
328
- {"retrieve": "retrieve", "rerank": "rerank"},
329
  )
 
330
  graph.add_edge("retrieve", "rerank")
331
- graph.add_edge("rerank", "update_intent")
 
 
 
 
 
 
 
 
332
  graph.add_edge("update_intent", "generate")
333
  graph.add_edge("generate", END)
334
 
@@ -461,6 +537,9 @@ class QueryRouter:
461
  elif node_name == "rerank":
462
  event["reranked_count"] = len(update.get("reranked", []))
463
  event["confidence"] = round(update.get("confidence", 0.0), 4)
 
 
 
464
 
465
  yield event
466
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Reranker confidence below this triggers a query-broadening retry.
19
+ # Cross-encoder sigmoid scores below 0.3 generally indicate poor relevance.
20
+ _LOW_CONFIDENCE_THRESHOLD = 0.3
21
+ _MAX_RETRIES = 1
22
+
23
 
24
  class RouterState(TypedDict):
25
  """LangGraph state passed between routing nodes.
 
36
  fused_results: Results after RRF fusion.
37
  reranked: Results after cross-encoder reranking.
38
  confidence: Max reranker score (0.0-1.0).
39
+ retry_count: Number of query-broadening retries performed so far.
40
  answer: Final generated answer.
41
  """
42
 
 
51
  fused_results: list[QueryResult]
52
  reranked: list[QueryResult]
53
  confidence: float
54
+ retry_count: int
55
  answer: str
56
 
57
 
 
77
  fused_results=[],
78
  reranked=[],
79
  confidence=0.0,
80
+ retry_count=0,
81
  answer="",
82
  )
83
 
 
282
  logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
283
  return {"reranked": reranked, "confidence": confidence}
284
 
285
def _broaden_query_node(self, state: RouterState) -> dict:
    """Rewrite the retrieval query when reranker confidence is low.

    Asks the LLM chain for a broader or differently worded search query
    that keeps the meaning of the user's question, then bumps the retry
    counter.

    Args:
        state: Current router state. Reads ``query``, ``retrieval_query``
            and ``retry_count``.

    Returns:
        Partial state update with the new ``retrieval_query`` and the
        incremented ``retry_count``.
    """
    previous_query = state["retrieval_query"]
    attempt = state["retry_count"] + 1
    prompt = (
        "The following search query did not return good results from "
        "the document database. Rewrite it to be broader or use "
        "different keywords while keeping the same meaning. "
        "Reply with ONLY the rewritten query, nothing else.\n\n"
        f"Original question: {state['query']}\n"
        f"Failed search query: {previous_query}"
    )
    broadened = str(self._llm_chain.invoke(prompt)).strip()
    if not broadened:
        # An empty rewrite would make the retry search for nothing at all;
        # keep the previous query instead. The counter is still incremented
        # below so the retry loop terminates.
        logger.warning(
            "LLM returned an empty rewrite for retry %d; reusing previous query",
            attempt,
        )
        broadened = previous_query
    logger.info(
        "Broadened query for retry %d: %s",
        attempt,
        broadened,
    )
    return {
        "retrieval_query": broadened,
        "retry_count": attempt,
    }
309
+
310
@staticmethod
def _check_confidence(state: RouterState) -> str:
    """Pick the edge to follow after reranking: ``"retry"`` or ``"accept"``.

    A retry is requested only when all three conditions hold: reranked
    results are present, the confidence is under
    ``_LOW_CONFIDENCE_THRESHOLD``, and fewer than ``_MAX_RETRIES`` retries
    have been performed. An empty result list is accepted as-is, since
    broadening the query cannot help when nothing matched at all.
    """
    # No reranked documents: nothing to improve on, go straight on.
    if not state.get("reranked"):
        return "accept"

    score = state["confidence"]
    if not score < _LOW_CONFIDENCE_THRESHOLD:
        return "accept"

    # Retry budget exhausted: accept whatever we have.
    if not state["retry_count"] < _MAX_RETRIES:
        return "accept"

    logger.info(
        "Low confidence (%.4f < %.2f), retrying with broadened query",
        score,
        _LOW_CONFIDENCE_THRESHOLD,
    )
    return "retry"
331
+
332
  @staticmethod
333
  def _update_intent_node(state: RouterState) -> dict:
334
  """Promote FACTUAL to RAG when sources are found."""
 
350
 
351
  @staticmethod
352
  def _should_retrieve(state: RouterState) -> str:
353
+ """Skip retrieval entirely when intent is UNKNOWN."""
354
+ return "retrieve" if state["intent"] != IntentType.UNKNOWN else "generate"
355
 
356
  def _build_graph(self) -> object:
357
  """Build the LangGraph routing graph.
358
 
359
+ Graph topology::
360
+
361
+ detect → translate ─┬─ (UNKNOWN) ──────────────→ generate
362
+ └─ (other) ──→ retrieve → rerank
363
+ ↑ │
364
+ │ check_confidence
365
+ │ │ │
366
+ broaden ←─ retry accept
367
+ _query → update_intent
368
+
369
+ generate
370
+
371
+ Key LangGraph features demonstrated:
372
+ - Conditional edges: intent-based skip, confidence-based routing
373
+ - Cycle: low-confidence retry loop (broaden_query → retrieve)
374
+ - Shared state: retry_count controls loop termination
375
 
376
  Returns:
377
  Compiled LangGraph graph.
 
381
  graph.add_node("translate", self._translate_node)
382
  graph.add_node("retrieve", self._retrieve_node)
383
  graph.add_node("rerank", self._rerank_node)
384
+ graph.add_node("broaden_query", self._broaden_query_node)
385
  graph.add_node("update_intent", self._update_intent_node)
386
  graph.add_node("generate", self._generate_node)
387
 
388
  graph.set_entry_point("detect")
389
  graph.add_edge("detect", "translate")
390
+
391
+ # Branch: skip retrieval entirely for off-topic queries
392
  graph.add_conditional_edges(
393
  "translate",
394
  self._should_retrieve,
395
+ {"retrieve": "retrieve", "generate": "generate"},
396
  )
397
+
398
  graph.add_edge("retrieve", "rerank")
399
+
400
+ # Branch + cycle: retry with broadened query on low confidence
401
+ graph.add_conditional_edges(
402
+ "rerank",
403
+ self._check_confidence,
404
+ {"retry": "broaden_query", "accept": "update_intent"},
405
+ )
406
+ graph.add_edge("broaden_query", "retrieve") # ← the loop
407
+
408
  graph.add_edge("update_intent", "generate")
409
  graph.add_edge("generate", END)
410
 
 
537
  elif node_name == "rerank":
538
  event["reranked_count"] = len(update.get("reranked", []))
539
  event["confidence"] = round(update.get("confidence", 0.0), 4)
540
+ elif node_name == "broaden_query":
541
+ event["retrieval_query"] = update.get("retrieval_query", "")
542
+ event["retry_count"] = update.get("retry_count", 0)
543
 
544
  yield event
545
 
tests/test_router.py CHANGED
@@ -143,7 +143,6 @@ class TestQueryRouterDirect:
143
  """UNKNOWN intent skips retrieval and returns zero confidence."""
144
  classifier, retriever, reranker, llm_chain = mock_components
145
 
146
- reranker.rerank.return_value = []
147
  _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
148
 
149
  router = QueryRouter(classifier, retriever, reranker, llm_chain)
@@ -153,6 +152,7 @@ class TestQueryRouterDirect:
153
  assert response.intent == IntentType.UNKNOWN
154
  assert response.confidence == 0.0
155
  retriever.search_detailed.assert_not_called()
 
156
 
157
  def test_unknown_intent_prompt_uses_generic_instruction(
158
  self, mock_components
@@ -160,7 +160,6 @@ class TestQueryRouterDirect:
160
  """UNKNOWN intent should use the generic helpful instruction."""
161
  classifier, retriever, reranker, llm_chain = mock_components
162
 
163
- reranker.rerank.return_value = []
164
  _setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
165
 
166
  router = QueryRouter(classifier, retriever, reranker, llm_chain)
@@ -300,3 +299,63 @@ class TestSigmoidInReranker:
300
  response = router.route("test", top_k=3)
301
 
302
  assert response.confidence == pytest.approx(0.9, abs=1e-6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  """UNKNOWN intent skips retrieval and returns zero confidence."""
144
  classifier, retriever, reranker, llm_chain = mock_components
145
 
 
146
  _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
147
 
148
  router = QueryRouter(classifier, retriever, reranker, llm_chain)
 
152
  assert response.intent == IntentType.UNKNOWN
153
  assert response.confidence == 0.0
154
  retriever.search_detailed.assert_not_called()
155
+ reranker.rerank.assert_not_called()
156
 
157
  def test_unknown_intent_prompt_uses_generic_instruction(
158
  self, mock_components
 
160
  """UNKNOWN intent should use the generic helpful instruction."""
161
  classifier, retriever, reranker, llm_chain = mock_components
162
 
 
163
  _setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
164
 
165
  router = QueryRouter(classifier, retriever, reranker, llm_chain)
 
299
  response = router.route("test", top_k=3)
300
 
301
  assert response.confidence == pytest.approx(0.9, abs=1e-6)
302
+
303
+
304
class TestLowConfidenceRetry:
    """Exercise the low-confidence query-broadening retry loop."""

    def test_low_confidence_triggers_retry(self, mock_components) -> None:
        """A low-confidence first pass leads to one broadened retry."""
        classifier, retriever, reranker, llm_chain = mock_components

        weak = [_make_query_result("weak match", 0.15)]
        strong = [_make_query_result("strong match", 0.85)]

        retriever.search_detailed.return_value = _make_hybrid_result(weak)
        # Rerank pass 1 is below the threshold (retry); pass 2 is above it.
        reranker.rerank.side_effect = [weak, strong]

        # LLM is invoked for: detect, broaden_query, generate (in that order).
        llm_chain.invoke.side_effect = [
            "language: Danish\nintent: factual",
            "bredere søgning",
            "Final answer",
        ]

        result = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "snævert spørgsmål", top_k=3
        )

        assert result.answer == "Final answer"
        assert result.confidence == pytest.approx(0.85, abs=1e-6)
        assert retriever.search_detailed.call_count == 2
        assert reranker.rerank.call_count == 2

    def test_empty_results_do_not_trigger_retry(self, mock_components) -> None:
        """An empty reranker result must not start a retry."""
        classifier, retriever, reranker, llm_chain = mock_components

        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")

        result = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "asdfghjkl", top_k=3
        )

        assert result.confidence == 0.0
        assert retriever.search_detailed.call_count == 1
        # The reranker runs at most once; there must be no second pass.
        assert reranker.rerank.call_count <= 1

    def test_high_confidence_skips_retry(self, mock_components) -> None:
        """Confidence above the threshold means a single retrieval pass."""
        classifier, retriever, reranker, llm_chain = mock_components

        hits = [_make_query_result("good match", 0.9)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route("test", top_k=3)

        assert retriever.search_detailed.call_count == 1
        assert reranker.rerank.call_count == 1