Spaces:
Sleeping
Sleeping
XQ commited on
Commit ·
4ba88df
1
Parent(s): ec64993
Update router logic
Browse files- src/agent/router.py +90 -11
- tests/test_router.py +61 -2
src/agent/router.py
CHANGED
|
@@ -15,6 +15,11 @@ from src.retrieval.reranker import Reranker
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
class RouterState(TypedDict):
|
| 20 |
"""LangGraph state passed between routing nodes.
|
|
@@ -31,6 +36,7 @@ class RouterState(TypedDict):
|
|
| 31 |
fused_results: Results after RRF fusion.
|
| 32 |
reranked: Results after cross-encoder reranking.
|
| 33 |
confidence: Max reranker score (0.0-1.0).
|
|
|
|
| 34 |
answer: Final generated answer.
|
| 35 |
"""
|
| 36 |
|
|
@@ -45,6 +51,7 @@ class RouterState(TypedDict):
|
|
| 45 |
fused_results: list[QueryResult]
|
| 46 |
reranked: list[QueryResult]
|
| 47 |
confidence: float
|
|
|
|
| 48 |
answer: str
|
| 49 |
|
| 50 |
|
|
@@ -70,6 +77,7 @@ def _make_initial_state(query: str, top_k: int) -> RouterState:
|
|
| 70 |
fused_results=[],
|
| 71 |
reranked=[],
|
| 72 |
confidence=0.0,
|
|
|
|
| 73 |
answer="",
|
| 74 |
)
|
| 75 |
|
|
@@ -274,6 +282,53 @@ class QueryRouter:
|
|
| 274 |
logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
|
| 275 |
return {"reranked": reranked, "confidence": confidence}
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
@staticmethod
|
| 278 |
def _update_intent_node(state: RouterState) -> dict:
|
| 279 |
"""Promote FACTUAL to RAG when sources are found."""
|
|
@@ -295,19 +350,28 @@ class QueryRouter:
|
|
| 295 |
|
| 296 |
@staticmethod
|
| 297 |
def _should_retrieve(state: RouterState) -> str:
|
| 298 |
-
"""Skip retrieval when intent is UNKNOWN."""
|
| 299 |
-
return "retrieve" if state["intent"] != IntentType.UNKNOWN else "
|
| 300 |
|
| 301 |
def _build_graph(self) -> object:
|
| 302 |
"""Build the LangGraph routing graph.
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
Returns:
|
| 313 |
Compiled LangGraph graph.
|
|
@@ -317,18 +381,30 @@ class QueryRouter:
|
|
| 317 |
graph.add_node("translate", self._translate_node)
|
| 318 |
graph.add_node("retrieve", self._retrieve_node)
|
| 319 |
graph.add_node("rerank", self._rerank_node)
|
|
|
|
| 320 |
graph.add_node("update_intent", self._update_intent_node)
|
| 321 |
graph.add_node("generate", self._generate_node)
|
| 322 |
|
| 323 |
graph.set_entry_point("detect")
|
| 324 |
graph.add_edge("detect", "translate")
|
|
|
|
|
|
|
| 325 |
graph.add_conditional_edges(
|
| 326 |
"translate",
|
| 327 |
self._should_retrieve,
|
| 328 |
-
{"retrieve": "retrieve", "
|
| 329 |
)
|
|
|
|
| 330 |
graph.add_edge("retrieve", "rerank")
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
graph.add_edge("update_intent", "generate")
|
| 333 |
graph.add_edge("generate", END)
|
| 334 |
|
|
@@ -461,6 +537,9 @@ class QueryRouter:
|
|
| 461 |
elif node_name == "rerank":
|
| 462 |
event["reranked_count"] = len(update.get("reranked", []))
|
| 463 |
event["confidence"] = round(update.get("confidence", 0.0), 4)
|
|
|
|
|
|
|
|
|
|
| 464 |
|
| 465 |
yield event
|
| 466 |
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
# Reranker confidence below this triggers a query-broadening retry.
|
| 19 |
+
# Cross-encoder sigmoid scores below 0.3 generally indicate poor relevance.
|
| 20 |
+
_LOW_CONFIDENCE_THRESHOLD = 0.3
|
| 21 |
+
_MAX_RETRIES = 1
|
| 22 |
+
|
| 23 |
|
| 24 |
class RouterState(TypedDict):
|
| 25 |
"""LangGraph state passed between routing nodes.
|
|
|
|
| 36 |
fused_results: Results after RRF fusion.
|
| 37 |
reranked: Results after cross-encoder reranking.
|
| 38 |
confidence: Max reranker score (0.0-1.0).
|
| 39 |
+
retry_count: Number of query-broadening retries performed so far.
|
| 40 |
answer: Final generated answer.
|
| 41 |
"""
|
| 42 |
|
|
|
|
| 51 |
fused_results: list[QueryResult]
|
| 52 |
reranked: list[QueryResult]
|
| 53 |
confidence: float
|
| 54 |
+
retry_count: int
|
| 55 |
answer: str
|
| 56 |
|
| 57 |
|
|
|
|
| 77 |
fused_results=[],
|
| 78 |
reranked=[],
|
| 79 |
confidence=0.0,
|
| 80 |
+
retry_count=0,
|
| 81 |
answer="",
|
| 82 |
)
|
| 83 |
|
|
|
|
| 282 |
logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
|
| 283 |
return {"reranked": reranked, "confidence": confidence}
|
| 284 |
|
| 285 |
+
def _broaden_query_node(self, state: RouterState) -> dict:
|
| 286 |
+
"""Rewrite the retrieval query when reranker confidence is low.
|
| 287 |
+
|
| 288 |
+
Uses the LLM to generate alternative search terms while preserving
|
| 289 |
+
the original meaning, then increments the retry counter.
|
| 290 |
+
"""
|
| 291 |
+
prompt = (
|
| 292 |
+
"The following search query did not return good results from "
|
| 293 |
+
"the document database. Rewrite it to be broader or use "
|
| 294 |
+
"different keywords while keeping the same meaning. "
|
| 295 |
+
"Reply with ONLY the rewritten query, nothing else.\n\n"
|
| 296 |
+
f"Original question: {state['query']}\n"
|
| 297 |
+
f"Failed search query: {state['retrieval_query']}"
|
| 298 |
+
)
|
| 299 |
+
broadened = str(self._llm_chain.invoke(prompt)).strip()
|
| 300 |
+
logger.info(
|
| 301 |
+
"Broadened query for retry %d: %s",
|
| 302 |
+
state["retry_count"] + 1,
|
| 303 |
+
broadened,
|
| 304 |
+
)
|
| 305 |
+
return {
|
| 306 |
+
"retrieval_query": broadened,
|
| 307 |
+
"retry_count": state["retry_count"] + 1,
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
@staticmethod
|
| 311 |
+
def _check_confidence(state: RouterState) -> str:
|
| 312 |
+
"""Decide whether to retry retrieval or proceed to generation.
|
| 313 |
+
|
| 314 |
+
Triggers a retry when results exist but confidence is below
|
| 315 |
+
the threshold and retries remain. Empty results (no documents
|
| 316 |
+
matched at all) are not retried — broadening cannot help when
|
| 317 |
+
the knowledge base simply lacks coverage.
|
| 318 |
+
"""
|
| 319 |
+
if (
|
| 320 |
+
state.get("reranked")
|
| 321 |
+
and state["confidence"] < _LOW_CONFIDENCE_THRESHOLD
|
| 322 |
+
and state["retry_count"] < _MAX_RETRIES
|
| 323 |
+
):
|
| 324 |
+
logger.info(
|
| 325 |
+
"Low confidence (%.4f < %.2f), retrying with broadened query",
|
| 326 |
+
state["confidence"],
|
| 327 |
+
_LOW_CONFIDENCE_THRESHOLD,
|
| 328 |
+
)
|
| 329 |
+
return "retry"
|
| 330 |
+
return "accept"
|
| 331 |
+
|
| 332 |
@staticmethod
|
| 333 |
def _update_intent_node(state: RouterState) -> dict:
|
| 334 |
"""Promote FACTUAL to RAG when sources are found."""
|
|
|
|
| 350 |
|
| 351 |
@staticmethod
|
| 352 |
def _should_retrieve(state: RouterState) -> str:
|
| 353 |
+
"""Skip retrieval entirely when intent is UNKNOWN."""
|
| 354 |
+
return "retrieve" if state["intent"] != IntentType.UNKNOWN else "generate"
|
| 355 |
|
| 356 |
def _build_graph(self) -> object:
|
| 357 |
"""Build the LangGraph routing graph.
|
| 358 |
|
| 359 |
+
Graph topology::
|
| 360 |
+
|
| 361 |
+
detect → translate ─┬─ (UNKNOWN) ──────────────→ generate
|
| 362 |
+
└─ (other) → retrieve → rerank
|
| 363 |
+
↑ │
|
| 364 |
+
│ check_confidence
|
| 365 |
+
│ │ │
|
| 366 |
+
broaden ←─ retry accept
|
| 367 |
+
_query → update_intent
|
| 368 |
+
│
|
| 369 |
+
generate
|
| 370 |
+
|
| 371 |
+
Key LangGraph features demonstrated:
|
| 372 |
+
- Conditional edges: intent-based skip, confidence-based routing
|
| 373 |
+
- Cycle: low-confidence retry loop (broaden_query → retrieve)
|
| 374 |
+
- Shared state: retry_count controls loop termination
|
| 375 |
|
| 376 |
Returns:
|
| 377 |
Compiled LangGraph graph.
|
|
|
|
| 381 |
graph.add_node("translate", self._translate_node)
|
| 382 |
graph.add_node("retrieve", self._retrieve_node)
|
| 383 |
graph.add_node("rerank", self._rerank_node)
|
| 384 |
+
graph.add_node("broaden_query", self._broaden_query_node)
|
| 385 |
graph.add_node("update_intent", self._update_intent_node)
|
| 386 |
graph.add_node("generate", self._generate_node)
|
| 387 |
|
| 388 |
graph.set_entry_point("detect")
|
| 389 |
graph.add_edge("detect", "translate")
|
| 390 |
+
|
| 391 |
+
# Branch: skip retrieval entirely for off-topic queries
|
| 392 |
graph.add_conditional_edges(
|
| 393 |
"translate",
|
| 394 |
self._should_retrieve,
|
| 395 |
+
{"retrieve": "retrieve", "generate": "generate"},
|
| 396 |
)
|
| 397 |
+
|
| 398 |
graph.add_edge("retrieve", "rerank")
|
| 399 |
+
|
| 400 |
+
# Branch + cycle: retry with broadened query on low confidence
|
| 401 |
+
graph.add_conditional_edges(
|
| 402 |
+
"rerank",
|
| 403 |
+
self._check_confidence,
|
| 404 |
+
{"retry": "broaden_query", "accept": "update_intent"},
|
| 405 |
+
)
|
| 406 |
+
graph.add_edge("broaden_query", "retrieve") # ← the loop
|
| 407 |
+
|
| 408 |
graph.add_edge("update_intent", "generate")
|
| 409 |
graph.add_edge("generate", END)
|
| 410 |
|
|
|
|
| 537 |
elif node_name == "rerank":
|
| 538 |
event["reranked_count"] = len(update.get("reranked", []))
|
| 539 |
event["confidence"] = round(update.get("confidence", 0.0), 4)
|
| 540 |
+
elif node_name == "broaden_query":
|
| 541 |
+
event["retrieval_query"] = update.get("retrieval_query", "")
|
| 542 |
+
event["retry_count"] = update.get("retry_count", 0)
|
| 543 |
|
| 544 |
yield event
|
| 545 |
|
tests/test_router.py
CHANGED
|
@@ -143,7 +143,6 @@ class TestQueryRouterDirect:
|
|
| 143 |
"""UNKNOWN intent skips retrieval and returns zero confidence."""
|
| 144 |
classifier, retriever, reranker, llm_chain = mock_components
|
| 145 |
|
| 146 |
-
reranker.rerank.return_value = []
|
| 147 |
_setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
|
| 148 |
|
| 149 |
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
|
@@ -153,6 +152,7 @@ class TestQueryRouterDirect:
|
|
| 153 |
assert response.intent == IntentType.UNKNOWN
|
| 154 |
assert response.confidence == 0.0
|
| 155 |
retriever.search_detailed.assert_not_called()
|
|
|
|
| 156 |
|
| 157 |
def test_unknown_intent_prompt_uses_generic_instruction(
|
| 158 |
self, mock_components
|
|
@@ -160,7 +160,6 @@ class TestQueryRouterDirect:
|
|
| 160 |
"""UNKNOWN intent should use the generic helpful instruction."""
|
| 161 |
classifier, retriever, reranker, llm_chain = mock_components
|
| 162 |
|
| 163 |
-
reranker.rerank.return_value = []
|
| 164 |
_setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
|
| 165 |
|
| 166 |
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
|
@@ -300,3 +299,63 @@ class TestSigmoidInReranker:
|
|
| 300 |
response = router.route("test", top_k=3)
|
| 301 |
|
| 302 |
assert response.confidence == pytest.approx(0.9, abs=1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
"""UNKNOWN intent skips retrieval and returns zero confidence."""
|
| 144 |
classifier, retriever, reranker, llm_chain = mock_components
|
| 145 |
|
|
|
|
| 146 |
_setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
|
| 147 |
|
| 148 |
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
|
|
|
| 152 |
assert response.intent == IntentType.UNKNOWN
|
| 153 |
assert response.confidence == 0.0
|
| 154 |
retriever.search_detailed.assert_not_called()
|
| 155 |
+
reranker.rerank.assert_not_called()
|
| 156 |
|
| 157 |
def test_unknown_intent_prompt_uses_generic_instruction(
|
| 158 |
self, mock_components
|
|
|
|
| 160 |
"""UNKNOWN intent should use the generic helpful instruction."""
|
| 161 |
classifier, retriever, reranker, llm_chain = mock_components
|
| 162 |
|
|
|
|
| 163 |
_setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
|
| 164 |
|
| 165 |
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
|
|
|
| 299 |
response = router.route("test", top_k=3)
|
| 300 |
|
| 301 |
assert response.confidence == pytest.approx(0.9, abs=1e-6)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class TestLowConfidenceRetry:
|
| 305 |
+
"""Tests for the query-broadening retry loop on low confidence."""
|
| 306 |
+
|
| 307 |
+
def test_low_confidence_triggers_retry(self, mock_components) -> None:
|
| 308 |
+
"""When reranker returns low-confidence results, the query should be
|
| 309 |
+
broadened and retrieval retried once."""
|
| 310 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 311 |
+
|
| 312 |
+
low_results = [_make_query_result("weak match", 0.15)]
|
| 313 |
+
good_results = [_make_query_result("strong match", 0.85)]
|
| 314 |
+
|
| 315 |
+
retriever.search_detailed.return_value = _make_hybrid_result(low_results)
|
| 316 |
+
# First rerank: low confidence → triggers retry
|
| 317 |
+
# Second rerank: high confidence → proceeds to generate
|
| 318 |
+
reranker.rerank.side_effect = [low_results, good_results]
|
| 319 |
+
|
| 320 |
+
# LLM calls: detect, broaden_query, generate
|
| 321 |
+
combined = "language: Danish\nintent: factual"
|
| 322 |
+
llm_chain.invoke.side_effect = [combined, "bredere søgning", "Final answer"]
|
| 323 |
+
|
| 324 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 325 |
+
response = router.route("snævert spørgsmål", top_k=3)
|
| 326 |
+
|
| 327 |
+
assert response.answer == "Final answer"
|
| 328 |
+
assert response.confidence == pytest.approx(0.85, abs=1e-6)
|
| 329 |
+
assert retriever.search_detailed.call_count == 2
|
| 330 |
+
assert reranker.rerank.call_count == 2
|
| 331 |
+
|
| 332 |
+
def test_empty_results_do_not_trigger_retry(self, mock_components) -> None:
|
| 333 |
+
"""When reranker returns no results at all, retrying is skipped."""
|
| 334 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 335 |
+
|
| 336 |
+
retriever.search_detailed.return_value = _make_hybrid_result([])
|
| 337 |
+
reranker.rerank.return_value = []
|
| 338 |
+
_setup_llm_chain_danish(llm_chain, "No information found", intent="factual")
|
| 339 |
+
|
| 340 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 341 |
+
response = router.route("asdfghjkl", top_k=3)
|
| 342 |
+
|
| 343 |
+
assert response.confidence == 0.0
|
| 344 |
+
assert retriever.search_detailed.call_count == 1
|
| 345 |
+
# Reranker still called once (with empty input, returns [])
|
| 346 |
+
assert reranker.rerank.call_count <= 1
|
| 347 |
+
|
| 348 |
+
def test_high_confidence_skips_retry(self, mock_components) -> None:
|
| 349 |
+
"""When confidence is above threshold, no retry is attempted."""
|
| 350 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 351 |
+
|
| 352 |
+
results = [_make_query_result("good match", 0.9)]
|
| 353 |
+
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 354 |
+
reranker.rerank.return_value = results
|
| 355 |
+
_setup_llm_chain_danish(llm_chain, "answer", intent="factual")
|
| 356 |
+
|
| 357 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 358 |
+
router.route("test", top_k=3)
|
| 359 |
+
|
| 360 |
+
assert retriever.search_detailed.call_count == 1
|
| 361 |
+
assert reranker.rerank.call_count == 1
|