Spaces:
Running
Running
XQ commited on
Commit ·
ec64993
1
Parent(s): c263a7d
Code cleaning
Browse files- scripts/e2e_test.py +2 -2
- scripts/evaluate.py +2 -2
- src/agent/react_router.py +5 -19
- src/agent/router.py +115 -128
- src/agent/tools.py +1 -1
- src/api/main.py +13 -4
- src/api/routes.py +5 -25
- src/config.py +2 -0
- src/models.py +20 -0
- src/retrieval/bm25_search.py +0 -43
- src/retrieval/hybrid.py +5 -36
- src/retrieval/vector_store.py +24 -84
- tests/test_hybrid.py +5 -26
- tests/test_router.py +56 -56
scripts/e2e_test.py
CHANGED
|
@@ -98,12 +98,12 @@ def main() -> None:
|
|
| 98 |
)
|
| 99 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 100 |
classifier = IntentClassifier(llm=llm)
|
| 101 |
-
|
| 102 |
router = QueryRouter(
|
| 103 |
intent_classifier=classifier,
|
| 104 |
hybrid_retriever=hybrid,
|
| 105 |
reranker=reranker,
|
| 106 |
-
|
| 107 |
)
|
| 108 |
|
| 109 |
# --- 5) Run query ---
|
|
|
|
| 98 |
)
|
| 99 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 100 |
classifier = IntentClassifier(llm=llm)
|
| 101 |
+
llm_chain = llm | StrOutputParser()
|
| 102 |
router = QueryRouter(
|
| 103 |
intent_classifier=classifier,
|
| 104 |
hybrid_retriever=hybrid,
|
| 105 |
reranker=reranker,
|
| 106 |
+
llm_chain=llm_chain,
|
| 107 |
)
|
| 108 |
|
| 109 |
# --- 5) Run query ---
|
scripts/evaluate.py
CHANGED
|
@@ -156,12 +156,12 @@ def main() -> None:
|
|
| 156 |
)
|
| 157 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 158 |
classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
|
| 159 |
-
|
| 160 |
router = QueryRouter(
|
| 161 |
intent_classifier=classifier,
|
| 162 |
hybrid_retriever=hybrid,
|
| 163 |
reranker=reranker,
|
| 164 |
-
|
| 165 |
)
|
| 166 |
|
| 167 |
# --- 5) Run test set ---
|
|
|
|
| 156 |
)
|
| 157 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 158 |
classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
|
| 159 |
+
llm_chain = llm | StrOutputParser()
|
| 160 |
router = QueryRouter(
|
| 161 |
intent_classifier=classifier,
|
| 162 |
hybrid_retriever=hybrid,
|
| 163 |
reranker=reranker,
|
| 164 |
+
llm_chain=llm_chain,
|
| 165 |
)
|
| 166 |
|
| 167 |
# --- 5) Run test set ---
|
src/agent/react_router.py
CHANGED
|
@@ -39,20 +39,6 @@ _SYSTEM_PROMPT = (
|
|
| 39 |
)
|
| 40 |
|
| 41 |
|
| 42 |
-
def _ser_sources(sources: list[QueryResult]) -> list[dict]:
|
| 43 |
-
"""Serialise QueryResult list to a JSON-safe list of dicts."""
|
| 44 |
-
return [
|
| 45 |
-
{
|
| 46 |
-
"chunk_id": r.chunk.chunk_id,
|
| 47 |
-
"document_id": r.chunk.document_id,
|
| 48 |
-
"text": r.chunk.text,
|
| 49 |
-
"score": r.score,
|
| 50 |
-
"source": r.source,
|
| 51 |
-
}
|
| 52 |
-
for r in sources
|
| 53 |
-
]
|
| 54 |
-
|
| 55 |
-
|
| 56 |
class ReActRouter:
|
| 57 |
"""Routes queries through a multi-step ReAct agent with tool-calling LLM.
|
| 58 |
|
|
@@ -231,7 +217,7 @@ class ReActRouter:
|
|
| 231 |
"step": "done",
|
| 232 |
"result": {
|
| 233 |
"answer": answer,
|
| 234 |
-
"sources":
|
| 235 |
"intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
|
| 236 |
"confidence": confidence,
|
| 237 |
"pipeline_details": {
|
|
@@ -239,10 +225,10 @@ class ReActRouter:
|
|
| 239 |
"retrieval_query": ", ".join(q for _, q in store.tool_calls) or query,
|
| 240 |
"detected_language": "unknown",
|
| 241 |
"translated": False,
|
| 242 |
-
"dense_results":
|
| 243 |
-
"sparse_results":
|
| 244 |
-
"fused_results":
|
| 245 |
-
"reranked_results":
|
| 246 |
},
|
| 247 |
},
|
| 248 |
}
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
class ReActRouter:
|
| 43 |
"""Routes queries through a multi-step ReAct agent with tool-calling LLM.
|
| 44 |
|
|
|
|
| 217 |
"step": "done",
|
| 218 |
"result": {
|
| 219 |
"answer": answer,
|
| 220 |
+
"sources": [r.to_dict() for r in sources],
|
| 221 |
"intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
|
| 222 |
"confidence": confidence,
|
| 223 |
"pipeline_details": {
|
|
|
|
| 225 |
"retrieval_query": ", ".join(q for _, q in store.tool_calls) or query,
|
| 226 |
"detected_language": "unknown",
|
| 227 |
"translated": False,
|
| 228 |
+
"dense_results": [r.to_dict(include_text=False) for r in store.dense_results],
|
| 229 |
+
"sparse_results": [r.to_dict(include_text=False) for r in store.sparse_results],
|
| 230 |
+
"fused_results": [r.to_dict(include_text=False) for r in store.fused_results],
|
| 231 |
+
"reranked_results": [r.to_dict(include_text=False) for r in sources],
|
| 232 |
},
|
| 233 |
},
|
| 234 |
}
|
src/agent/router.py
CHANGED
|
@@ -48,6 +48,32 @@ class RouterState(TypedDict):
|
|
| 48 |
answer: str
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
class QueryRouter:
|
| 52 |
"""Routes queries to appropriate retrieval and generation pipelines."""
|
| 53 |
|
|
@@ -56,7 +82,7 @@ class QueryRouter:
|
|
| 56 |
intent_classifier: IntentClassifier,
|
| 57 |
hybrid_retriever: HybridRetriever,
|
| 58 |
reranker: Reranker,
|
| 59 |
-
|
| 60 |
*,
|
| 61 |
translate_query: bool = True,
|
| 62 |
) -> None:
|
|
@@ -66,7 +92,8 @@ class QueryRouter:
|
|
| 66 |
intent_classifier: IntentClassifier instance.
|
| 67 |
hybrid_retriever: HybridRetriever instance.
|
| 68 |
reranker: Reranker instance.
|
| 69 |
-
|
|
|
|
| 70 |
translate_query: Whether to translate non-Danish queries to Danish
|
| 71 |
before retrieval. When False, language detection still runs for
|
| 72 |
the answer-language rule but no translation is performed.
|
|
@@ -74,7 +101,7 @@ class QueryRouter:
|
|
| 74 |
self._intent_classifier = intent_classifier
|
| 75 |
self._hybrid_retriever = hybrid_retriever
|
| 76 |
self._reranker = reranker
|
| 77 |
-
self.
|
| 78 |
self._translate_query_enabled = translate_query
|
| 79 |
self._graph = self._build_graph()
|
| 80 |
|
|
@@ -155,7 +182,7 @@ class QueryRouter:
|
|
| 155 |
"intent: <intent>\n\n"
|
| 156 |
f"Query: {query}"
|
| 157 |
)
|
| 158 |
-
raw = str(self.
|
| 159 |
logger.debug("Combined detection raw response: %s", raw)
|
| 160 |
|
| 161 |
# Parse response
|
|
@@ -200,10 +227,77 @@ class QueryRouter:
|
|
| 200 |
"Reply with ONLY the translated text, nothing else.\n\n"
|
| 201 |
f"Text: {query}"
|
| 202 |
)
|
| 203 |
-
translated = str(self.
|
| 204 |
logger.info("Translated query to Danish: %s", translated)
|
| 205 |
return translated
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
def _build_graph(self) -> object:
|
| 208 |
"""Build the LangGraph routing graph.
|
| 209 |
|
|
@@ -218,75 +312,19 @@ class QueryRouter:
|
|
| 218 |
Returns:
|
| 219 |
Compiled LangGraph graph.
|
| 220 |
"""
|
| 221 |
-
|
| 222 |
-
def detect_node(state: RouterState) -> dict:
|
| 223 |
-
user_language, intent = self._detect_language_and_intent(state["query"])
|
| 224 |
-
return {"user_language": user_language, "intent": intent}
|
| 225 |
-
|
| 226 |
-
def translate_node(state: RouterState) -> dict:
|
| 227 |
-
retrieval_query = self._translate_query(state["query"], state["user_language"])
|
| 228 |
-
return {
|
| 229 |
-
"retrieval_query": retrieval_query,
|
| 230 |
-
"translated": retrieval_query != state["query"],
|
| 231 |
-
}
|
| 232 |
-
|
| 233 |
-
def retrieve_node(state: RouterState) -> dict:
|
| 234 |
-
hybrid_result = self._hybrid_retriever.search_detailed(
|
| 235 |
-
state["retrieval_query"], top_k=state["top_k"]
|
| 236 |
-
)
|
| 237 |
-
logger.info("Retrieved %d results from hybrid search", len(hybrid_result.fused_results))
|
| 238 |
-
return {
|
| 239 |
-
"dense_results": hybrid_result.dense_results,
|
| 240 |
-
"sparse_results": hybrid_result.sparse_results,
|
| 241 |
-
"fused_results": hybrid_result.fused_results,
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
def rerank_node(state: RouterState) -> dict:
|
| 245 |
-
results = state.get("fused_results", [])
|
| 246 |
-
reranked = (
|
| 247 |
-
self._reranker.rerank(state["retrieval_query"], results, top_k=state["top_k"])
|
| 248 |
-
if results
|
| 249 |
-
else []
|
| 250 |
-
)
|
| 251 |
-
confidence = max(r.score for r in reranked) if reranked else 0.0
|
| 252 |
-
logger.info("Reranked to %d results", len(reranked))
|
| 253 |
-
if reranked:
|
| 254 |
-
logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
|
| 255 |
-
return {"reranked": reranked, "confidence": confidence}
|
| 256 |
-
|
| 257 |
-
def update_intent_node(state: RouterState) -> dict:
|
| 258 |
-
if state.get("reranked") and state["intent"] == IntentType.FACTUAL:
|
| 259 |
-
logger.info("Overriding intent to RAG (sources retrieved)")
|
| 260 |
-
return {"intent": IntentType.RAG}
|
| 261 |
-
return {}
|
| 262 |
-
|
| 263 |
-
def generate_node(state: RouterState) -> dict:
|
| 264 |
-
reranked = state.get("reranked", [])
|
| 265 |
-
context = "\n\n".join(r.chunk.text for r in reranked)
|
| 266 |
-
prompt = self._build_prompt(
|
| 267 |
-
state["query"], state["intent"], context, state["user_language"]
|
| 268 |
-
)
|
| 269 |
-
answer = self._generator.invoke(prompt)
|
| 270 |
-
logger.info("Generated answer for intent=%s", state["intent"].value)
|
| 271 |
-
return {"answer": str(answer)}
|
| 272 |
-
|
| 273 |
-
def should_retrieve(state: RouterState) -> str:
|
| 274 |
-
"""Skip retrieval when intent is UNKNOWN."""
|
| 275 |
-
return "retrieve" if state["intent"] != IntentType.UNKNOWN else "rerank"
|
| 276 |
-
|
| 277 |
graph: StateGraph = StateGraph(RouterState)
|
| 278 |
-
graph.add_node("detect",
|
| 279 |
-
graph.add_node("translate",
|
| 280 |
-
graph.add_node("retrieve",
|
| 281 |
-
graph.add_node("rerank",
|
| 282 |
-
graph.add_node("update_intent",
|
| 283 |
-
graph.add_node("generate",
|
| 284 |
|
| 285 |
graph.set_entry_point("detect")
|
| 286 |
graph.add_edge("detect", "translate")
|
| 287 |
graph.add_conditional_edges(
|
| 288 |
"translate",
|
| 289 |
-
|
| 290 |
{"retrieve": "retrieve", "rerank": "rerank"},
|
| 291 |
)
|
| 292 |
graph.add_edge("retrieve", "rerank")
|
|
@@ -308,22 +346,7 @@ class QueryRouter:
|
|
| 308 |
"""
|
| 309 |
logger.info("Routing query: %s", query)
|
| 310 |
|
| 311 |
-
|
| 312 |
-
"query": query,
|
| 313 |
-
"top_k": top_k,
|
| 314 |
-
"user_language": "Danish",
|
| 315 |
-
"intent": IntentType.UNKNOWN,
|
| 316 |
-
"retrieval_query": query,
|
| 317 |
-
"translated": False,
|
| 318 |
-
"dense_results": [],
|
| 319 |
-
"sparse_results": [],
|
| 320 |
-
"fused_results": [],
|
| 321 |
-
"reranked": [],
|
| 322 |
-
"confidence": 0.0,
|
| 323 |
-
"answer": "",
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
final_state: RouterState = self._graph.invoke(initial_state)
|
| 327 |
|
| 328 |
pipeline = PipelineDetails(
|
| 329 |
original_query=query,
|
|
@@ -386,7 +409,7 @@ class QueryRouter:
|
|
| 386 |
# context = "\n\n".join(r.chunk.text for r in reranked)
|
| 387 |
# prompt = self._build_prompt(query, intent, context, user_language)
|
| 388 |
#
|
| 389 |
-
# answer = self.
|
| 390 |
# logger.info("Generated answer for intent=%s", intent.value)
|
| 391 |
#
|
| 392 |
# if reranked:
|
|
@@ -417,24 +440,9 @@ class QueryRouter:
|
|
| 417 |
Yields:
|
| 418 |
Step event dicts, then a final ``done`` event with the result.
|
| 419 |
"""
|
| 420 |
-
|
| 421 |
-
"query": query,
|
| 422 |
-
"top_k": top_k,
|
| 423 |
-
"user_language": "Danish",
|
| 424 |
-
"intent": IntentType.UNKNOWN,
|
| 425 |
-
"retrieval_query": query,
|
| 426 |
-
"translated": False,
|
| 427 |
-
"dense_results": [],
|
| 428 |
-
"sparse_results": [],
|
| 429 |
-
"fused_results": [],
|
| 430 |
-
"reranked": [],
|
| 431 |
-
"confidence": 0.0,
|
| 432 |
-
"answer": "",
|
| 433 |
-
}
|
| 434 |
-
|
| 435 |
-
accumulated: dict = dict(initial_state)
|
| 436 |
|
| 437 |
-
for chunk in self._graph.stream(
|
| 438 |
for node_name, update in chunk.items():
|
| 439 |
if update is None:
|
| 440 |
continue
|
|
@@ -453,24 +461,12 @@ class QueryRouter:
|
|
| 453 |
elif node_name == "rerank":
|
| 454 |
event["reranked_count"] = len(update.get("reranked", []))
|
| 455 |
event["confidence"] = round(update.get("confidence", 0.0), 4)
|
| 456 |
-
# update_intent and generate: no extra fields needed
|
| 457 |
|
| 458 |
yield event
|
| 459 |
|
| 460 |
# Build the final response from accumulated state and emit as "done"
|
| 461 |
reranked: list = accumulated.get("reranked", [])
|
| 462 |
|
| 463 |
-
def _ser(results: list) -> list[dict]:
|
| 464 |
-
return [
|
| 465 |
-
{
|
| 466 |
-
"document_id": r.chunk.document_id,
|
| 467 |
-
"chunk_id": r.chunk.chunk_id,
|
| 468 |
-
"score": r.score,
|
| 469 |
-
"source": r.source,
|
| 470 |
-
}
|
| 471 |
-
for r in results
|
| 472 |
-
]
|
| 473 |
-
|
| 474 |
pd_acc = PipelineDetails(
|
| 475 |
original_query=query,
|
| 476 |
retrieval_query=accumulated.get("retrieval_query", query),
|
|
@@ -486,16 +482,7 @@ class QueryRouter:
|
|
| 486 |
"step": "done",
|
| 487 |
"result": {
|
| 488 |
"answer": accumulated.get("answer", ""),
|
| 489 |
-
"sources": [
|
| 490 |
-
{
|
| 491 |
-
"chunk_id": r.chunk.chunk_id,
|
| 492 |
-
"document_id": r.chunk.document_id,
|
| 493 |
-
"text": r.chunk.text,
|
| 494 |
-
"score": r.score,
|
| 495 |
-
"source": r.source,
|
| 496 |
-
}
|
| 497 |
-
for r in reranked
|
| 498 |
-
],
|
| 499 |
"intent": accumulated.get("intent", IntentType.UNKNOWN).value,
|
| 500 |
"confidence": accumulated.get("confidence", 0.0),
|
| 501 |
"pipeline_details": {
|
|
@@ -503,10 +490,10 @@ class QueryRouter:
|
|
| 503 |
"retrieval_query": pd_acc.retrieval_query,
|
| 504 |
"detected_language": pd_acc.detected_language,
|
| 505 |
"translated": pd_acc.translated,
|
| 506 |
-
"dense_results":
|
| 507 |
-
"sparse_results":
|
| 508 |
-
"fused_results":
|
| 509 |
-
"reranked_results":
|
| 510 |
},
|
| 511 |
},
|
| 512 |
}
|
|
|
|
| 48 |
answer: str
|
| 49 |
|
| 50 |
|
| 51 |
+
def _make_initial_state(query: str, top_k: int) -> RouterState:
|
| 52 |
+
"""Create a fresh RouterState with sensible defaults.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
query: The user's original query.
|
| 56 |
+
top_k: Number of results to retrieve.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
RouterState ready to be passed into the graph.
|
| 60 |
+
"""
|
| 61 |
+
return RouterState(
|
| 62 |
+
query=query,
|
| 63 |
+
top_k=top_k,
|
| 64 |
+
user_language="Danish",
|
| 65 |
+
intent=IntentType.UNKNOWN,
|
| 66 |
+
retrieval_query=query,
|
| 67 |
+
translated=False,
|
| 68 |
+
dense_results=[],
|
| 69 |
+
sparse_results=[],
|
| 70 |
+
fused_results=[],
|
| 71 |
+
reranked=[],
|
| 72 |
+
confidence=0.0,
|
| 73 |
+
answer="",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
class QueryRouter:
|
| 78 |
"""Routes queries to appropriate retrieval and generation pipelines."""
|
| 79 |
|
|
|
|
| 82 |
intent_classifier: IntentClassifier,
|
| 83 |
hybrid_retriever: HybridRetriever,
|
| 84 |
reranker: Reranker,
|
| 85 |
+
llm_chain: Runnable,
|
| 86 |
*,
|
| 87 |
translate_query: bool = True,
|
| 88 |
) -> None:
|
|
|
|
| 92 |
intent_classifier: IntentClassifier instance.
|
| 93 |
hybrid_retriever: HybridRetriever instance.
|
| 94 |
reranker: Reranker instance.
|
| 95 |
+
llm_chain: LLM chain (llm | StrOutputParser) for generation,
|
| 96 |
+
translation, and language detection.
|
| 97 |
translate_query: Whether to translate non-Danish queries to Danish
|
| 98 |
before retrieval. When False, language detection still runs for
|
| 99 |
the answer-language rule but no translation is performed.
|
|
|
|
| 101 |
self._intent_classifier = intent_classifier
|
| 102 |
self._hybrid_retriever = hybrid_retriever
|
| 103 |
self._reranker = reranker
|
| 104 |
+
self._llm_chain = llm_chain
|
| 105 |
self._translate_query_enabled = translate_query
|
| 106 |
self._graph = self._build_graph()
|
| 107 |
|
|
|
|
| 182 |
"intent: <intent>\n\n"
|
| 183 |
f"Query: {query}"
|
| 184 |
)
|
| 185 |
+
raw = str(self._llm_chain.invoke(prompt)).strip()
|
| 186 |
logger.debug("Combined detection raw response: %s", raw)
|
| 187 |
|
| 188 |
# Parse response
|
|
|
|
| 227 |
"Reply with ONLY the translated text, nothing else.\n\n"
|
| 228 |
f"Text: {query}"
|
| 229 |
)
|
| 230 |
+
translated = str(self._llm_chain.invoke(translate_prompt)).strip()
|
| 231 |
logger.info("Translated query to Danish: %s", translated)
|
| 232 |
return translated
|
| 233 |
|
| 234 |
+
# ------------------------------------------------------------------
|
| 235 |
+
# LangGraph node functions
|
| 236 |
+
# ------------------------------------------------------------------
|
| 237 |
+
|
| 238 |
+
def _detect_node(self, state: RouterState) -> dict:
|
| 239 |
+
"""Detect language and classify intent."""
|
| 240 |
+
user_language, intent = self._detect_language_and_intent(state["query"])
|
| 241 |
+
return {"user_language": user_language, "intent": intent}
|
| 242 |
+
|
| 243 |
+
def _translate_node(self, state: RouterState) -> dict:
|
| 244 |
+
"""Translate query to Danish if needed."""
|
| 245 |
+
retrieval_query = self._translate_query(state["query"], state["user_language"])
|
| 246 |
+
return {
|
| 247 |
+
"retrieval_query": retrieval_query,
|
| 248 |
+
"translated": retrieval_query != state["query"],
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
def _retrieve_node(self, state: RouterState) -> dict:
|
| 252 |
+
"""Run hybrid search."""
|
| 253 |
+
hybrid_result = self._hybrid_retriever.search_detailed(
|
| 254 |
+
state["retrieval_query"], top_k=state["top_k"]
|
| 255 |
+
)
|
| 256 |
+
logger.info("Retrieved %d results from hybrid search", len(hybrid_result.fused_results))
|
| 257 |
+
return {
|
| 258 |
+
"dense_results": hybrid_result.dense_results,
|
| 259 |
+
"sparse_results": hybrid_result.sparse_results,
|
| 260 |
+
"fused_results": hybrid_result.fused_results,
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
def _rerank_node(self, state: RouterState) -> dict:
|
| 264 |
+
"""Rerank fused results with cross-encoder."""
|
| 265 |
+
results = state.get("fused_results", [])
|
| 266 |
+
reranked = (
|
| 267 |
+
self._reranker.rerank(state["retrieval_query"], results, top_k=state["top_k"])
|
| 268 |
+
if results
|
| 269 |
+
else []
|
| 270 |
+
)
|
| 271 |
+
confidence = max(r.score for r in reranked) if reranked else 0.0
|
| 272 |
+
logger.info("Reranked to %d results", len(reranked))
|
| 273 |
+
if reranked:
|
| 274 |
+
logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
|
| 275 |
+
return {"reranked": reranked, "confidence": confidence}
|
| 276 |
+
|
| 277 |
+
@staticmethod
|
| 278 |
+
def _update_intent_node(state: RouterState) -> dict:
|
| 279 |
+
"""Promote FACTUAL to RAG when sources are found."""
|
| 280 |
+
if state.get("reranked") and state["intent"] == IntentType.FACTUAL:
|
| 281 |
+
logger.info("Overriding intent to RAG (sources retrieved)")
|
| 282 |
+
return {"intent": IntentType.RAG}
|
| 283 |
+
return {}
|
| 284 |
+
|
| 285 |
+
def _generate_node(self, state: RouterState) -> dict:
|
| 286 |
+
"""Build prompt and call LLM."""
|
| 287 |
+
reranked = state.get("reranked", [])
|
| 288 |
+
context = "\n\n".join(r.chunk.text for r in reranked)
|
| 289 |
+
prompt = self._build_prompt(
|
| 290 |
+
state["query"], state["intent"], context, state["user_language"]
|
| 291 |
+
)
|
| 292 |
+
answer = self._llm_chain.invoke(prompt)
|
| 293 |
+
logger.info("Generated answer for intent=%s", state["intent"].value)
|
| 294 |
+
return {"answer": str(answer)}
|
| 295 |
+
|
| 296 |
+
@staticmethod
|
| 297 |
+
def _should_retrieve(state: RouterState) -> str:
|
| 298 |
+
"""Skip retrieval when intent is UNKNOWN."""
|
| 299 |
+
return "retrieve" if state["intent"] != IntentType.UNKNOWN else "rerank"
|
| 300 |
+
|
| 301 |
def _build_graph(self) -> object:
|
| 302 |
"""Build the LangGraph routing graph.
|
| 303 |
|
|
|
|
| 312 |
Returns:
|
| 313 |
Compiled LangGraph graph.
|
| 314 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
graph: StateGraph = StateGraph(RouterState)
|
| 316 |
+
graph.add_node("detect", self._detect_node)
|
| 317 |
+
graph.add_node("translate", self._translate_node)
|
| 318 |
+
graph.add_node("retrieve", self._retrieve_node)
|
| 319 |
+
graph.add_node("rerank", self._rerank_node)
|
| 320 |
+
graph.add_node("update_intent", self._update_intent_node)
|
| 321 |
+
graph.add_node("generate", self._generate_node)
|
| 322 |
|
| 323 |
graph.set_entry_point("detect")
|
| 324 |
graph.add_edge("detect", "translate")
|
| 325 |
graph.add_conditional_edges(
|
| 326 |
"translate",
|
| 327 |
+
self._should_retrieve,
|
| 328 |
{"retrieve": "retrieve", "rerank": "rerank"},
|
| 329 |
)
|
| 330 |
graph.add_edge("retrieve", "rerank")
|
|
|
|
| 346 |
"""
|
| 347 |
logger.info("Routing query: %s", query)
|
| 348 |
|
| 349 |
+
final_state: RouterState = self._graph.invoke(_make_initial_state(query, top_k))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
pipeline = PipelineDetails(
|
| 352 |
original_query=query,
|
|
|
|
| 409 |
# context = "\n\n".join(r.chunk.text for r in reranked)
|
| 410 |
# prompt = self._build_prompt(query, intent, context, user_language)
|
| 411 |
#
|
| 412 |
+
# answer = self._llm_chain.invoke(prompt)
|
| 413 |
# logger.info("Generated answer for intent=%s", intent.value)
|
| 414 |
#
|
| 415 |
# if reranked:
|
|
|
|
| 440 |
Yields:
|
| 441 |
Step event dicts, then a final ``done`` event with the result.
|
| 442 |
"""
|
| 443 |
+
accumulated: dict = dict(_make_initial_state(query, top_k))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
+
for chunk in self._graph.stream(_make_initial_state(query, top_k), stream_mode="updates"):
|
| 446 |
for node_name, update in chunk.items():
|
| 447 |
if update is None:
|
| 448 |
continue
|
|
|
|
| 461 |
elif node_name == "rerank":
|
| 462 |
event["reranked_count"] = len(update.get("reranked", []))
|
| 463 |
event["confidence"] = round(update.get("confidence", 0.0), 4)
|
|
|
|
| 464 |
|
| 465 |
yield event
|
| 466 |
|
| 467 |
# Build the final response from accumulated state and emit as "done"
|
| 468 |
reranked: list = accumulated.get("reranked", [])
|
| 469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
pd_acc = PipelineDetails(
|
| 471 |
original_query=query,
|
| 472 |
retrieval_query=accumulated.get("retrieval_query", query),
|
|
|
|
| 482 |
"step": "done",
|
| 483 |
"result": {
|
| 484 |
"answer": accumulated.get("answer", ""),
|
| 485 |
+
"sources": [r.to_dict() for r in reranked],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
"intent": accumulated.get("intent", IntentType.UNKNOWN).value,
|
| 487 |
"confidence": accumulated.get("confidence", 0.0),
|
| 488 |
"pipeline_details": {
|
|
|
|
| 490 |
"retrieval_query": pd_acc.retrieval_query,
|
| 491 |
"detected_language": pd_acc.detected_language,
|
| 492 |
"translated": pd_acc.translated,
|
| 493 |
+
"dense_results": [r.to_dict(include_text=False) for r in pd_acc.dense_results],
|
| 494 |
+
"sparse_results": [r.to_dict(include_text=False) for r in pd_acc.sparse_results],
|
| 495 |
+
"fused_results": [r.to_dict(include_text=False) for r in pd_acc.fused_results],
|
| 496 |
+
"reranked_results": [r.to_dict(include_text=False) for r in pd_acc.reranked_results],
|
| 497 |
},
|
| 498 |
},
|
| 499 |
}
|
src/agent/tools.py
CHANGED
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|
| 5 |
|
| 6 |
from langchain_core.tools import tool
|
| 7 |
|
| 8 |
-
from src.models import
|
| 9 |
from src.retrieval.hybrid import HybridRetriever
|
| 10 |
from src.retrieval.reranker import Reranker
|
| 11 |
from src.retrieval.vector_store import VectorStore
|
|
|
|
| 5 |
|
| 6 |
from langchain_core.tools import tool
|
| 7 |
|
| 8 |
+
from src.models import QueryResult
|
| 9 |
from src.retrieval.hybrid import HybridRetriever
|
| 10 |
from src.retrieval.reranker import Reranker
|
| 11 |
from src.retrieval.vector_store import VectorStore
|
src/api/main.py
CHANGED
|
@@ -83,12 +83,12 @@ def create_app() -> FastAPI:
|
|
| 83 |
else:
|
| 84 |
logger.info("Agent mode: pipeline (fixed DAG)")
|
| 85 |
intent_classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
|
| 86 |
-
|
| 87 |
query_router = QueryRouter(
|
| 88 |
intent_classifier=intent_classifier,
|
| 89 |
hybrid_retriever=hybrid_retriever,
|
| 90 |
reranker=reranker,
|
| 91 |
-
|
| 92 |
translate_query=settings.translate_query,
|
| 93 |
)
|
| 94 |
|
|
@@ -113,9 +113,18 @@ def create_app() -> FastAPI:
|
|
| 113 |
|
| 114 |
|
| 115 |
def _parse_strategy(settings: "Settings") -> "ChunkStrategy": # noqa: F821
|
| 116 |
-
"""Return the
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
from src.models import ChunkStrategy
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
app: FastAPI = create_app()
|
|
|
|
| 83 |
else:
|
| 84 |
logger.info("Agent mode: pipeline (fixed DAG)")
|
| 85 |
intent_classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
|
| 86 |
+
llm_chain = llm | StrOutputParser()
|
| 87 |
query_router = QueryRouter(
|
| 88 |
intent_classifier=intent_classifier,
|
| 89 |
hybrid_retriever=hybrid_retriever,
|
| 90 |
reranker=reranker,
|
| 91 |
+
llm_chain=llm_chain,
|
| 92 |
translate_query=settings.translate_query,
|
| 93 |
)
|
| 94 |
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def _parse_strategy(settings: "Settings") -> "ChunkStrategy": # noqa: F821
|
| 116 |
+
"""Return the chunking strategy from config, defaulting to SEMANTIC.
|
| 117 |
+
|
| 118 |
+
Reads the CHUNK_STRATEGY environment variable via settings. Falls back
|
| 119 |
+
to SEMANTIC when the variable is unset or empty.
|
| 120 |
+
"""
|
| 121 |
from src.models import ChunkStrategy
|
| 122 |
+
|
| 123 |
+
raw = getattr(settings, "chunk_strategy", "semantic")
|
| 124 |
+
try:
|
| 125 |
+
return ChunkStrategy(raw)
|
| 126 |
+
except ValueError:
|
| 127 |
+
return ChunkStrategy.SEMANTIC
|
| 128 |
|
| 129 |
|
| 130 |
app: FastAPI = create_app()
|
src/api/routes.py
CHANGED
|
@@ -183,27 +183,7 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
|
|
| 183 |
) from exc
|
| 184 |
raise
|
| 185 |
|
| 186 |
-
sources = [
|
| 187 |
-
{
|
| 188 |
-
"chunk_id": result.chunk.chunk_id,
|
| 189 |
-
"document_id": result.chunk.document_id,
|
| 190 |
-
"text": result.chunk.text,
|
| 191 |
-
"score": result.score,
|
| 192 |
-
"source": result.source,
|
| 193 |
-
}
|
| 194 |
-
for result in response.sources
|
| 195 |
-
]
|
| 196 |
-
|
| 197 |
-
def _to_pipeline_items(results: list) -> list[PipelineResultItem]:
|
| 198 |
-
return [
|
| 199 |
-
PipelineResultItem(
|
| 200 |
-
document_id=r.chunk.document_id,
|
| 201 |
-
chunk_id=r.chunk.chunk_id,
|
| 202 |
-
score=r.score,
|
| 203 |
-
source=r.source,
|
| 204 |
-
)
|
| 205 |
-
for r in results
|
| 206 |
-
]
|
| 207 |
|
| 208 |
pd = response.pipeline_details
|
| 209 |
pipeline_details = PipelineDetailsResponse(
|
|
@@ -211,10 +191,10 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
|
|
| 211 |
retrieval_query=pd.retrieval_query,
|
| 212 |
detected_language=pd.detected_language,
|
| 213 |
translated=pd.translated,
|
| 214 |
-
dense_results=
|
| 215 |
-
sparse_results=
|
| 216 |
-
fused_results=
|
| 217 |
-
reranked_results=
|
| 218 |
)
|
| 219 |
|
| 220 |
return QueryResponse(
|
|
|
|
| 183 |
) from exc
|
| 184 |
raise
|
| 185 |
|
| 186 |
+
sources = [result.to_dict() for result in response.sources]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
pd = response.pipeline_details
|
| 189 |
pipeline_details = PipelineDetailsResponse(
|
|
|
|
| 191 |
retrieval_query=pd.retrieval_query,
|
| 192 |
detected_language=pd.detected_language,
|
| 193 |
translated=pd.translated,
|
| 194 |
+
dense_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.dense_results],
|
| 195 |
+
sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
|
| 196 |
+
fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
|
| 197 |
+
reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
|
| 198 |
)
|
| 199 |
|
| 200 |
return QueryResponse(
|
src/config.py
CHANGED
|
@@ -26,6 +26,7 @@ class Settings:
|
|
| 26 |
embedding_dimension: int
|
| 27 |
generation_model: str
|
| 28 |
reranker_model: str
|
|
|
|
| 29 |
chunk_size: int
|
| 30 |
chunk_overlap: int
|
| 31 |
top_k: int
|
|
@@ -106,6 +107,7 @@ def load_settings() -> Settings:
|
|
| 106 |
embedding_dimension=int(os.environ.get("EMBEDDING_DIMENSION", "384")),
|
| 107 |
generation_model=os.environ.get("GENERATION_MODEL", "gemma4:e4b"),
|
| 108 |
reranker_model=os.environ.get("RERANKER_MODEL", "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"),
|
|
|
|
| 109 |
chunk_size=int(os.environ.get("CHUNK_SIZE", "512")),
|
| 110 |
chunk_overlap=int(os.environ.get("CHUNK_OVERLAP", "64")),
|
| 111 |
top_k=int(os.environ.get("TOP_K", "5")),
|
|
|
|
| 26 |
embedding_dimension: int
|
| 27 |
generation_model: str
|
| 28 |
reranker_model: str
|
| 29 |
+
chunk_strategy: str
|
| 30 |
chunk_size: int
|
| 31 |
chunk_overlap: int
|
| 32 |
top_k: int
|
|
|
|
| 107 |
embedding_dimension=int(os.environ.get("EMBEDDING_DIMENSION", "384")),
|
| 108 |
generation_model=os.environ.get("GENERATION_MODEL", "gemma4:e4b"),
|
| 109 |
reranker_model=os.environ.get("RERANKER_MODEL", "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"),
|
| 110 |
+
chunk_strategy=os.environ.get("CHUNK_STRATEGY", "semantic"),
|
| 111 |
chunk_size=int(os.environ.get("CHUNK_SIZE", "512")),
|
| 112 |
chunk_overlap=int(os.environ.get("CHUNK_OVERLAP", "64")),
|
| 113 |
top_k=int(os.environ.get("TOP_K", "5")),
|
src/models.py
CHANGED
|
@@ -56,6 +56,26 @@ class QueryResult:
|
|
| 56 |
score: float
|
| 57 |
source: str
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
@dataclass
|
| 61 |
class PipelineDetails:
|
|
|
|
| 56 |
score: float
|
| 57 |
source: str
|
| 58 |
|
| 59 |
+
def to_dict(self, *, include_text: bool = True) -> dict[str, str | float]:
|
| 60 |
+
"""Serialise to a JSON-safe dictionary.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
include_text: Whether to include the chunk text (default True).
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Dictionary with chunk_id, document_id, score, source, and
|
| 67 |
+
optionally text.
|
| 68 |
+
"""
|
| 69 |
+
d: dict[str, str | float] = {
|
| 70 |
+
"chunk_id": self.chunk.chunk_id,
|
| 71 |
+
"document_id": self.chunk.document_id,
|
| 72 |
+
"score": self.score,
|
| 73 |
+
"source": self.source,
|
| 74 |
+
}
|
| 75 |
+
if include_text:
|
| 76 |
+
d["text"] = self.chunk.text
|
| 77 |
+
return d
|
| 78 |
+
|
| 79 |
|
| 80 |
@dataclass
|
| 81 |
class PipelineDetails:
|
src/retrieval/bm25_search.py
CHANGED
|
@@ -1,12 +1,7 @@
|
|
| 1 |
"""BM25 sparse retrieval using rank_bm25."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
-
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
| 7 |
-
from langchain_core.documents import Document
|
| 8 |
-
from langchain_core.retrievers import BaseRetriever
|
| 9 |
-
from pydantic import ConfigDict
|
| 10 |
from rank_bm25 import BM25Okapi
|
| 11 |
|
| 12 |
from src.models import DocumentChunk, QueryResult
|
|
@@ -64,17 +59,6 @@ class BM25Search:
|
|
| 64 |
logger.debug("BM25 search returned %d results for query: %s", len(results), query)
|
| 65 |
return results
|
| 66 |
|
| 67 |
-
def as_retriever(self, top_k: int) -> BaseRetriever:
|
| 68 |
-
"""Return a LangChain BaseRetriever wrapping this BM25 index.
|
| 69 |
-
|
| 70 |
-
Args:
|
| 71 |
-
top_k: Number of results to return per query.
|
| 72 |
-
|
| 73 |
-
Returns:
|
| 74 |
-
A BaseRetriever that calls search() and returns Documents.
|
| 75 |
-
"""
|
| 76 |
-
return _BM25RetrieverAdapter(bm25_search=self, top_k=top_k)
|
| 77 |
-
|
| 78 |
@staticmethod
|
| 79 |
def _tokenize(text: str) -> list[str]:
|
| 80 |
"""Tokenize text by lowercasing and splitting on whitespace.
|
|
@@ -86,30 +70,3 @@ class BM25Search:
|
|
| 86 |
List of lowercase tokens.
|
| 87 |
"""
|
| 88 |
return text.lower().split()
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class _BM25RetrieverAdapter(BaseRetriever):
|
| 92 |
-
"""LangChain BaseRetriever adapter over BM25Search."""
|
| 93 |
-
|
| 94 |
-
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 95 |
-
|
| 96 |
-
bm25_search: Any
|
| 97 |
-
top_k: int
|
| 98 |
-
|
| 99 |
-
def _get_relevant_documents(
|
| 100 |
-
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
| 101 |
-
) -> list[Document]:
|
| 102 |
-
results = self.bm25_search.search(query, self.top_k)
|
| 103 |
-
return [
|
| 104 |
-
Document(
|
| 105 |
-
page_content=r.chunk.text,
|
| 106 |
-
metadata={
|
| 107 |
-
"chunk_id": r.chunk.chunk_id,
|
| 108 |
-
"document_id": r.chunk.document_id,
|
| 109 |
-
"chunk_metadata": r.chunk.metadata,
|
| 110 |
-
"strategy": r.chunk.strategy.value,
|
| 111 |
-
"score": r.score,
|
| 112 |
-
},
|
| 113 |
-
)
|
| 114 |
-
for r in results
|
| 115 |
-
]
|
|
|
|
| 1 |
"""BM25 sparse retrieval using rank_bm25."""
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from rank_bm25 import BM25Okapi
|
| 6 |
|
| 7 |
from src.models import DocumentChunk, QueryResult
|
|
|
|
| 59 |
logger.debug("BM25 search returned %d results for query: %s", len(results), query)
|
| 60 |
return results
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
@staticmethod
|
| 63 |
def _tokenize(text: str) -> list[str]:
|
| 64 |
"""Tokenize text by lowercasing and splitting on whitespace.
|
|
|
|
| 70 |
List of lowercase tokens.
|
| 71 |
"""
|
| 72 |
return text.lower().split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/retrieval/hybrid.py
CHANGED
|
@@ -3,9 +3,7 @@
|
|
| 3 |
import logging
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
from src.models import ChunkStrategy, DocumentChunk, QueryResult
|
| 9 |
from src.retrieval.bm25_search import BM25Search
|
| 10 |
from src.retrieval.embedder import Embedder
|
| 11 |
from src.retrieval.vector_store import VectorStore
|
|
@@ -70,8 +68,6 @@ class HybridRetriever:
|
|
| 70 |
def search_detailed(self, query: str, top_k: int) -> HybridSearchResult:
|
| 71 |
"""Execute hybrid search and return all intermediate results.
|
| 72 |
|
| 73 |
-
Uses LangChain BaseRetriever.invoke() for both dense and sparse retrieval.
|
| 74 |
-
|
| 75 |
Args:
|
| 76 |
query: The search query string.
|
| 77 |
top_k: Number of top results to return after fusion.
|
|
@@ -79,14 +75,9 @@ class HybridRetriever:
|
|
| 79 |
Returns:
|
| 80 |
HybridSearchResult containing dense, sparse, and fused results.
|
| 81 |
"""
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
dense_docs: list[Document] = dense_retriever.invoke(query)
|
| 86 |
-
sparse_docs: list[Document] = sparse_retriever.invoke(query)
|
| 87 |
-
|
| 88 |
-
dense_results = [self._doc_to_query_result(doc, "dense") for doc in dense_docs]
|
| 89 |
-
sparse_results = [self._doc_to_query_result(doc, "bm25") for doc in sparse_docs]
|
| 90 |
|
| 91 |
logger.debug(
|
| 92 |
"Hybrid search: %d dense, %d sparse results",
|
|
@@ -101,27 +92,6 @@ class HybridRetriever:
|
|
| 101 |
fused_results=fused[:top_k],
|
| 102 |
)
|
| 103 |
|
| 104 |
-
@staticmethod
|
| 105 |
-
def _doc_to_query_result(doc: Document, source: str) -> QueryResult:
|
| 106 |
-
"""Convert a LangChain Document to a QueryResult.
|
| 107 |
-
|
| 108 |
-
Args:
|
| 109 |
-
doc: Document returned by a BaseRetriever.
|
| 110 |
-
source: Retrieval source label (e.g. 'dense' or 'bm25').
|
| 111 |
-
|
| 112 |
-
Returns:
|
| 113 |
-
QueryResult with chunk and score populated from document metadata.
|
| 114 |
-
"""
|
| 115 |
-
meta = doc.metadata
|
| 116 |
-
chunk = DocumentChunk(
|
| 117 |
-
chunk_id=meta.get("chunk_id", ""),
|
| 118 |
-
document_id=meta.get("document_id", ""),
|
| 119 |
-
text=doc.page_content,
|
| 120 |
-
metadata=meta.get("chunk_metadata", {}),
|
| 121 |
-
strategy=ChunkStrategy(meta.get("strategy", ChunkStrategy.RECURSIVE.value)),
|
| 122 |
-
)
|
| 123 |
-
return QueryResult(chunk=chunk, score=float(meta.get("score", 0.0)), source=source)
|
| 124 |
-
|
| 125 |
def reciprocal_rank_fusion(
|
| 126 |
self,
|
| 127 |
dense_results: list[QueryResult],
|
|
@@ -138,9 +108,8 @@ class HybridRetriever:
|
|
| 138 |
Returns:
|
| 139 |
Merged and re-ranked list of QueryResult objects.
|
| 140 |
"""
|
| 141 |
-
# Map chunk_id -> (rrf_score, best QueryResult)
|
| 142 |
scores: dict[str, float] = {}
|
| 143 |
-
best_chunk: dict[str,
|
| 144 |
|
| 145 |
for rank, result in enumerate(dense_results):
|
| 146 |
cid = result.chunk.chunk_id
|
|
|
|
| 3 |
import logging
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
+
from src.models import DocumentChunk, QueryResult
|
|
|
|
|
|
|
| 7 |
from src.retrieval.bm25_search import BM25Search
|
| 8 |
from src.retrieval.embedder import Embedder
|
| 9 |
from src.retrieval.vector_store import VectorStore
|
|
|
|
| 68 |
def search_detailed(self, query: str, top_k: int) -> HybridSearchResult:
|
| 69 |
"""Execute hybrid search and return all intermediate results.
|
| 70 |
|
|
|
|
|
|
|
| 71 |
Args:
|
| 72 |
query: The search query string.
|
| 73 |
top_k: Number of top results to return after fusion.
|
|
|
|
| 75 |
Returns:
|
| 76 |
HybridSearchResult containing dense, sparse, and fused results.
|
| 77 |
"""
|
| 78 |
+
query_embedding = self._embedder.embed_text(query)
|
| 79 |
+
dense_results = self._vector_store.search(query_embedding, top_k)
|
| 80 |
+
sparse_results = self._bm25_search.search(query, top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
logger.debug(
|
| 83 |
"Hybrid search: %d dense, %d sparse results",
|
|
|
|
| 92 |
fused_results=fused[:top_k],
|
| 93 |
)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def reciprocal_rank_fusion(
|
| 96 |
self,
|
| 97 |
dense_results: list[QueryResult],
|
|
|
|
| 108 |
Returns:
|
| 109 |
Merged and re-ranked list of QueryResult objects.
|
| 110 |
"""
|
|
|
|
| 111 |
scores: dict[str, float] = {}
|
| 112 |
+
best_chunk: dict[str, DocumentChunk] = {}
|
| 113 |
|
| 114 |
for rank, result in enumerate(dense_results):
|
| 115 |
cid = result.chunk.chunk_id
|
src/retrieval/vector_store.py
CHANGED
|
@@ -2,12 +2,7 @@
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
-
from typing import Any
|
| 6 |
|
| 7 |
-
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
| 8 |
-
from langchain_core.documents import Document
|
| 9 |
-
from langchain_core.retrievers import BaseRetriever
|
| 10 |
-
from pydantic import ConfigDict
|
| 11 |
from qdrant_client import QdrantClient
|
| 12 |
from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
|
| 13 |
|
|
@@ -16,6 +11,24 @@ from src.models import ChunkStrategy, DocumentChunk, QueryResult
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class VectorStore:
|
| 20 |
"""Manages document storage and dense retrieval via Qdrant."""
|
| 21 |
|
|
@@ -96,18 +109,10 @@ class VectorStore:
|
|
| 96 |
limit=top_k,
|
| 97 |
).points
|
| 98 |
|
| 99 |
-
results: list[QueryResult] = [
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
chunk_id=payload["chunk_id"],
|
| 104 |
-
document_id=payload["document_id"],
|
| 105 |
-
text=payload["text"],
|
| 106 |
-
metadata=json.loads(payload["metadata"]),
|
| 107 |
-
strategy=ChunkStrategy(payload["strategy"]),
|
| 108 |
-
)
|
| 109 |
-
results.append(QueryResult(chunk=chunk, score=hit.score, source="dense"))
|
| 110 |
-
|
| 111 |
logger.debug("Dense search returned %d results", len(results))
|
| 112 |
return results
|
| 113 |
|
|
@@ -129,18 +134,7 @@ class VectorStore:
|
|
| 129 |
with_vectors=False,
|
| 130 |
)
|
| 131 |
|
| 132 |
-
chunks
|
| 133 |
-
for record in records:
|
| 134 |
-
payload = record.payload
|
| 135 |
-
chunks.append(
|
| 136 |
-
DocumentChunk(
|
| 137 |
-
chunk_id=payload["chunk_id"],
|
| 138 |
-
document_id=payload["document_id"],
|
| 139 |
-
text=payload["text"],
|
| 140 |
-
metadata=json.loads(payload["metadata"]),
|
| 141 |
-
strategy=ChunkStrategy(payload["strategy"]),
|
| 142 |
-
)
|
| 143 |
-
)
|
| 144 |
logger.info("Loaded %d chunks from collection '%s'", len(chunks), self._collection_name)
|
| 145 |
return chunks
|
| 146 |
|
|
@@ -176,67 +170,13 @@ class VectorStore:
|
|
| 176 |
with_vectors=False,
|
| 177 |
)
|
| 178 |
|
| 179 |
-
chunks
|
| 180 |
-
for record in records:
|
| 181 |
-
payload = record.payload
|
| 182 |
-
chunks.append(
|
| 183 |
-
DocumentChunk(
|
| 184 |
-
chunk_id=payload["chunk_id"],
|
| 185 |
-
document_id=payload["document_id"],
|
| 186 |
-
text=payload["text"],
|
| 187 |
-
metadata=json.loads(payload["metadata"]),
|
| 188 |
-
strategy=ChunkStrategy(payload["strategy"]),
|
| 189 |
-
)
|
| 190 |
-
)
|
| 191 |
logger.debug(
|
| 192 |
"Fetched %d chunks for document '%s'", len(chunks), document_id
|
| 193 |
)
|
| 194 |
return chunks
|
| 195 |
|
| 196 |
-
def as_retriever(self, embedder: Any, top_k: int) -> BaseRetriever:
|
| 197 |
-
"""Return a LangChain BaseRetriever wrapping this vector store.
|
| 198 |
-
|
| 199 |
-
Args:
|
| 200 |
-
embedder: Embedder instance used to encode queries.
|
| 201 |
-
top_k: Number of results to return per query.
|
| 202 |
-
|
| 203 |
-
Returns:
|
| 204 |
-
A BaseRetriever that calls search() and returns Documents.
|
| 205 |
-
"""
|
| 206 |
-
return _VectorStoreRetrieverAdapter(
|
| 207 |
-
vector_store=self, embedder=embedder, top_k=top_k
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
def delete_collection(self) -> None:
|
| 211 |
"""Delete the entire collection from the store."""
|
| 212 |
self._client.delete_collection(collection_name=self._collection_name)
|
| 213 |
logger.info("Deleted Qdrant collection '%s'", self._collection_name)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
class _VectorStoreRetrieverAdapter(BaseRetriever):
|
| 217 |
-
"""LangChain BaseRetriever adapter over VectorStore."""
|
| 218 |
-
|
| 219 |
-
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 220 |
-
|
| 221 |
-
vector_store: Any
|
| 222 |
-
embedder: Any
|
| 223 |
-
top_k: int
|
| 224 |
-
|
| 225 |
-
def _get_relevant_documents(
|
| 226 |
-
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
| 227 |
-
) -> list[Document]:
|
| 228 |
-
query_embedding = self.embedder.embed_text(query)
|
| 229 |
-
results = self.vector_store.search(query_embedding, self.top_k)
|
| 230 |
-
return [
|
| 231 |
-
Document(
|
| 232 |
-
page_content=r.chunk.text,
|
| 233 |
-
metadata={
|
| 234 |
-
"chunk_id": r.chunk.chunk_id,
|
| 235 |
-
"document_id": r.chunk.document_id,
|
| 236 |
-
"chunk_metadata": r.chunk.metadata,
|
| 237 |
-
"strategy": r.chunk.strategy.value,
|
| 238 |
-
"score": r.score,
|
| 239 |
-
},
|
| 240 |
-
)
|
| 241 |
-
for r in results
|
| 242 |
-
]
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from qdrant_client import QdrantClient
|
| 7 |
from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
|
| 8 |
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
+
def _payload_to_chunk(payload: dict) -> DocumentChunk:
|
| 15 |
+
"""Convert a Qdrant payload dict to a DocumentChunk.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
payload: Qdrant point payload.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
DocumentChunk reconstructed from the payload.
|
| 22 |
+
"""
|
| 23 |
+
return DocumentChunk(
|
| 24 |
+
chunk_id=payload["chunk_id"],
|
| 25 |
+
document_id=payload["document_id"],
|
| 26 |
+
text=payload["text"],
|
| 27 |
+
metadata=json.loads(payload["metadata"]),
|
| 28 |
+
strategy=ChunkStrategy(payload["strategy"]),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
class VectorStore:
|
| 33 |
"""Manages document storage and dense retrieval via Qdrant."""
|
| 34 |
|
|
|
|
| 109 |
limit=top_k,
|
| 110 |
).points
|
| 111 |
|
| 112 |
+
results: list[QueryResult] = [
|
| 113 |
+
QueryResult(chunk=_payload_to_chunk(hit.payload), score=hit.score, source="dense")
|
| 114 |
+
for hit in hits
|
| 115 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
logger.debug("Dense search returned %d results", len(results))
|
| 117 |
return results
|
| 118 |
|
|
|
|
| 134 |
with_vectors=False,
|
| 135 |
)
|
| 136 |
|
| 137 |
+
chunks = [_payload_to_chunk(record.payload) for record in records]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
logger.info("Loaded %d chunks from collection '%s'", len(chunks), self._collection_name)
|
| 139 |
return chunks
|
| 140 |
|
|
|
|
| 170 |
with_vectors=False,
|
| 171 |
)
|
| 172 |
|
| 173 |
+
chunks = [_payload_to_chunk(record.payload) for record in records]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
logger.debug(
|
| 175 |
"Fetched %d chunks for document '%s'", len(chunks), document_id
|
| 176 |
)
|
| 177 |
return chunks
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def delete_collection(self) -> None:
|
| 180 |
"""Delete the entire collection from the store."""
|
| 181 |
self._client.delete_collection(collection_name=self._collection_name)
|
| 182 |
logger.info("Deleted Qdrant collection '%s'", self._collection_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_hybrid.py
CHANGED
|
@@ -4,8 +4,6 @@ from unittest.mock import MagicMock
|
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
|
| 7 |
-
from langchain_core.documents import Document
|
| 8 |
-
|
| 9 |
from src.models import ChunkStrategy, DocumentChunk, QueryResult
|
| 10 |
from src.retrieval.hybrid import HybridRetriever
|
| 11 |
|
|
@@ -16,20 +14,6 @@ def _make_result(chunk_id: str, score: float = 0.0, source: str = "test") -> Que
|
|
| 16 |
return QueryResult(chunk=chunk, score=score, source=source)
|
| 17 |
|
| 18 |
|
| 19 |
-
def _result_to_doc(result: QueryResult) -> Document:
|
| 20 |
-
"""Convert a QueryResult to a LangChain Document (mirrors the adapter output)."""
|
| 21 |
-
return Document(
|
| 22 |
-
page_content=result.chunk.text,
|
| 23 |
-
metadata={
|
| 24 |
-
"chunk_id": result.chunk.chunk_id,
|
| 25 |
-
"document_id": result.chunk.document_id,
|
| 26 |
-
"chunk_metadata": result.chunk.metadata,
|
| 27 |
-
"strategy": result.chunk.strategy.value,
|
| 28 |
-
"score": result.score,
|
| 29 |
-
},
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
def _build_retriever(
|
| 34 |
dense_results: list[QueryResult],
|
| 35 |
sparse_results: list[QueryResult],
|
|
@@ -38,22 +22,17 @@ def _build_retriever(
|
|
| 38 |
) -> HybridRetriever:
|
| 39 |
"""Build a HybridRetriever with mocked dependencies.
|
| 40 |
|
| 41 |
-
Mocks
|
| 42 |
-
|
| 43 |
"""
|
| 44 |
-
dense_retriever_mock = MagicMock()
|
| 45 |
-
dense_retriever_mock.invoke.return_value = [_result_to_doc(r) for r in dense_results]
|
| 46 |
-
|
| 47 |
vector_store = MagicMock()
|
| 48 |
-
vector_store.
|
| 49 |
-
|
| 50 |
-
sparse_retriever_mock = MagicMock()
|
| 51 |
-
sparse_retriever_mock.invoke.return_value = [_result_to_doc(r) for r in sparse_results]
|
| 52 |
|
| 53 |
bm25_search = MagicMock()
|
| 54 |
-
bm25_search.
|
| 55 |
|
| 56 |
embedder = MagicMock()
|
|
|
|
| 57 |
|
| 58 |
return HybridRetriever(
|
| 59 |
vector_store=vector_store,
|
|
|
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from src.models import ChunkStrategy, DocumentChunk, QueryResult
|
| 8 |
from src.retrieval.hybrid import HybridRetriever
|
| 9 |
|
|
|
|
| 14 |
return QueryResult(chunk=chunk, score=score, source=source)
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def _build_retriever(
|
| 18 |
dense_results: list[QueryResult],
|
| 19 |
sparse_results: list[QueryResult],
|
|
|
|
| 22 |
) -> HybridRetriever:
|
| 23 |
"""Build a HybridRetriever with mocked dependencies.
|
| 24 |
|
| 25 |
+
Mocks vector_store.search() and bm25_search.search() since
|
| 26 |
+
HybridRetriever calls them directly.
|
| 27 |
"""
|
|
|
|
|
|
|
|
|
|
| 28 |
vector_store = MagicMock()
|
| 29 |
+
vector_store.search.return_value = dense_results
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
bm25_search = MagicMock()
|
| 32 |
+
bm25_search.search.return_value = sparse_results
|
| 33 |
|
| 34 |
embedder = MagicMock()
|
| 35 |
+
embedder.embed_text.return_value = [0.0] * 384
|
| 36 |
|
| 37 |
return HybridRetriever(
|
| 38 |
vector_store=vector_store,
|
tests/test_router.py
CHANGED
|
@@ -35,36 +35,36 @@ def _make_hybrid_result(results: list[QueryResult]) -> MagicMock:
|
|
| 35 |
|
| 36 |
@pytest.fixture
|
| 37 |
def mock_components():
|
| 38 |
-
"""Create mock intent classifier, retriever, reranker, and
|
| 39 |
classifier = MagicMock()
|
| 40 |
retriever = MagicMock()
|
| 41 |
reranker = MagicMock()
|
| 42 |
-
|
| 43 |
-
return classifier, retriever, reranker,
|
| 44 |
|
| 45 |
|
| 46 |
-
def
|
| 47 |
-
|
| 48 |
) -> None:
|
| 49 |
-
"""Configure
|
| 50 |
|
| 51 |
The first invoke returns the combined language+intent response,
|
| 52 |
the second invoke returns the final answer.
|
| 53 |
"""
|
| 54 |
combined = f"language: Danish\nintent: {intent}"
|
| 55 |
-
|
| 56 |
|
| 57 |
|
| 58 |
-
def
|
| 59 |
-
|
| 60 |
) -> None:
|
| 61 |
-
"""Configure
|
| 62 |
|
| 63 |
The first invoke returns combined language+intent, the second returns the
|
| 64 |
translated query, and the third returns the final answer.
|
| 65 |
"""
|
| 66 |
combined = f"language: English\nintent: {intent}"
|
| 67 |
-
|
| 68 |
|
| 69 |
|
| 70 |
class TestQueryRouterRAG:
|
|
@@ -80,14 +80,14 @@ class TestQueryRouterRAG:
|
|
| 80 |
self, mock_components, intent_str: str, expected_intent: IntentType
|
| 81 |
) -> None:
|
| 82 |
"""RAG intents should retrieve, rerank, and generate an answer."""
|
| 83 |
-
classifier, retriever, reranker,
|
| 84 |
|
| 85 |
results = [_make_query_result("policy text", 0.85)]
|
| 86 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 87 |
reranker.rerank.return_value = results
|
| 88 |
-
|
| 89 |
|
| 90 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 91 |
response = router.route("Hvad er KU's feriepolitik?", top_k=3)
|
| 92 |
|
| 93 |
assert isinstance(response, GenerationResponse)
|
|
@@ -104,35 +104,35 @@ class TestQueryRouterRAG:
|
|
| 104 |
)
|
| 105 |
|
| 106 |
def test_prompt_contains_context_and_query(self, mock_components) -> None:
|
| 107 |
-
"""The prompt sent to the
|
| 108 |
-
classifier, retriever, reranker,
|
| 109 |
|
| 110 |
results = [_make_query_result("Relevant context text", 0.9)]
|
| 111 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 112 |
reranker.rerank.return_value = results
|
| 113 |
-
|
| 114 |
|
| 115 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 116 |
router.route("test query", top_k=3)
|
| 117 |
|
| 118 |
# The final invoke call is the generation call
|
| 119 |
-
prompt =
|
| 120 |
assert "Relevant context text" in prompt
|
| 121 |
assert "test query" in prompt
|
| 122 |
|
| 123 |
def test_prompt_contains_language_rule(self, mock_components) -> None:
|
| 124 |
"""The prompt should contain a language instruction matching user language."""
|
| 125 |
-
classifier, retriever, reranker,
|
| 126 |
|
| 127 |
results = [_make_query_result("ctx", 0.5)]
|
| 128 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 129 |
reranker.rerank.return_value = results
|
| 130 |
-
|
| 131 |
|
| 132 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 133 |
router.route("What is KU's vacation policy?", top_k=3)
|
| 134 |
|
| 135 |
-
prompt =
|
| 136 |
assert "MUST answer in English" in prompt
|
| 137 |
|
| 138 |
|
|
@@ -141,12 +141,12 @@ class TestQueryRouterDirect:
|
|
| 141 |
|
| 142 |
def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
|
| 143 |
"""UNKNOWN intent skips retrieval and returns zero confidence."""
|
| 144 |
-
classifier, retriever, reranker,
|
| 145 |
|
| 146 |
reranker.rerank.return_value = []
|
| 147 |
-
|
| 148 |
|
| 149 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 150 |
response = router.route("Hej, hvad kan du hjælpe med?", top_k=3)
|
| 151 |
|
| 152 |
assert response.answer == "Fallback answer"
|
|
@@ -158,15 +158,15 @@ class TestQueryRouterDirect:
|
|
| 158 |
self, mock_components
|
| 159 |
) -> None:
|
| 160 |
"""UNKNOWN intent should use the generic helpful instruction."""
|
| 161 |
-
classifier, retriever, reranker,
|
| 162 |
|
| 163 |
reranker.rerank.return_value = []
|
| 164 |
-
|
| 165 |
|
| 166 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 167 |
router.route("random input", top_k=3)
|
| 168 |
|
| 169 |
-
prompt =
|
| 170 |
assert "as helpfully as possible" in prompt
|
| 171 |
|
| 172 |
|
|
@@ -177,38 +177,38 @@ class TestQueryRouterFallback:
|
|
| 177 |
self, mock_components
|
| 178 |
) -> None:
|
| 179 |
"""When reranker returns no results, confidence should be 0.0."""
|
| 180 |
-
classifier, retriever, reranker,
|
| 181 |
|
| 182 |
retriever.search_detailed.return_value = _make_hybrid_result([])
|
| 183 |
reranker.rerank.return_value = []
|
| 184 |
-
|
| 185 |
|
| 186 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 187 |
response = router.route("asdfghjkl", top_k=3)
|
| 188 |
|
| 189 |
assert response.confidence == 0.0
|
| 190 |
assert response.sources == []
|
| 191 |
assert response.answer == "No information found"
|
| 192 |
|
| 193 |
-
def
|
| 194 |
"""When no chunks are retrieved, the prompt context should be empty."""
|
| 195 |
-
classifier, retriever, reranker,
|
| 196 |
|
| 197 |
retriever.search_detailed.return_value = _make_hybrid_result([])
|
| 198 |
reranker.rerank.return_value = []
|
| 199 |
-
|
| 200 |
|
| 201 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 202 |
router.route("gibberish", top_k=3)
|
| 203 |
|
| 204 |
-
prompt =
|
| 205 |
assert "Context:\n\n" in prompt
|
| 206 |
|
| 207 |
def test_multiple_results_confidence_uses_max_score(
|
| 208 |
self, mock_components
|
| 209 |
) -> None:
|
| 210 |
"""Confidence should be the maximum score among reranked results."""
|
| 211 |
-
classifier, retriever, reranker,
|
| 212 |
|
| 213 |
results = [
|
| 214 |
_make_query_result("low", 0.3),
|
|
@@ -217,9 +217,9 @@ class TestQueryRouterFallback:
|
|
| 217 |
]
|
| 218 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 219 |
reranker.rerank.return_value = results
|
| 220 |
-
|
| 221 |
|
| 222 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 223 |
response = router.route("opsummer politikken", top_k=5)
|
| 224 |
|
| 225 |
assert response.confidence == pytest.approx(0.95, abs=1e-6)
|
|
@@ -230,53 +230,53 @@ class TestQueryTranslation:
|
|
| 230 |
|
| 231 |
def test_danish_query_not_translated(self, mock_components) -> None:
|
| 232 |
"""Danish queries should be passed directly to retrieval without translation."""
|
| 233 |
-
classifier, retriever, reranker,
|
| 234 |
|
| 235 |
results = [_make_query_result("ctx", 0.5)]
|
| 236 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 237 |
reranker.rerank.return_value = results
|
| 238 |
-
|
| 239 |
|
| 240 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 241 |
router.route("Hvad er reglerne?", top_k=3)
|
| 242 |
|
| 243 |
# Only 2 invoke calls: combined detection + generation (no translation)
|
| 244 |
-
assert
|
| 245 |
retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
|
| 246 |
|
| 247 |
def test_english_query_translated_for_retrieval(self, mock_components) -> None:
|
| 248 |
"""English queries should be translated to Danish for retrieval."""
|
| 249 |
-
classifier, retriever, reranker,
|
| 250 |
|
| 251 |
results = [_make_query_result("ctx", 0.5)]
|
| 252 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 253 |
reranker.rerank.return_value = results
|
| 254 |
-
|
| 255 |
|
| 256 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 257 |
response = router.route("What are the rules?", top_k=3)
|
| 258 |
|
| 259 |
# 3 invoke calls: combined detection + translation + generation
|
| 260 |
-
assert
|
| 261 |
retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
|
| 262 |
reranker.rerank.assert_called_once_with("Hvad er reglerne?", results, top_k=3)
|
| 263 |
assert response.answer == "The rules are..."
|
| 264 |
|
| 265 |
def test_translation_disabled_skips_translate(self, mock_components) -> None:
|
| 266 |
"""When translate_query=False, English queries go straight to retrieval untranslated."""
|
| 267 |
-
classifier, retriever, reranker,
|
| 268 |
|
| 269 |
results = [_make_query_result("ctx", 0.5)]
|
| 270 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 271 |
reranker.rerank.return_value = results
|
| 272 |
# Only 2 calls: combined detection + generation (no translation)
|
| 273 |
combined = "language: English\nintent: rag"
|
| 274 |
-
|
| 275 |
|
| 276 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 277 |
response = router.route("What are the rules?", top_k=3)
|
| 278 |
|
| 279 |
-
assert
|
| 280 |
retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3)
|
| 281 |
assert response.answer == "The answer"
|
| 282 |
|
|
@@ -286,7 +286,7 @@ class TestSigmoidInReranker:
|
|
| 286 |
|
| 287 |
def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
|
| 288 |
"""Confidence should equal the max reranked score (already sigmoid-normalized)."""
|
| 289 |
-
classifier, retriever, reranker,
|
| 290 |
|
| 291 |
results = [
|
| 292 |
_make_query_result("a", 0.7),
|
|
@@ -294,9 +294,9 @@ class TestSigmoidInReranker:
|
|
| 294 |
]
|
| 295 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 296 |
reranker.rerank.return_value = results
|
| 297 |
-
|
| 298 |
|
| 299 |
-
router = QueryRouter(classifier, retriever, reranker,
|
| 300 |
response = router.route("test", top_k=3)
|
| 301 |
|
| 302 |
assert response.confidence == pytest.approx(0.9, abs=1e-6)
|
|
|
|
| 35 |
|
| 36 |
@pytest.fixture
|
| 37 |
def mock_components():
|
| 38 |
+
"""Create mock intent classifier, retriever, reranker, and llm_chain."""
|
| 39 |
classifier = MagicMock()
|
| 40 |
retriever = MagicMock()
|
| 41 |
reranker = MagicMock()
|
| 42 |
+
llm_chain = MagicMock()
|
| 43 |
+
return classifier, retriever, reranker, llm_chain
|
| 44 |
|
| 45 |
|
| 46 |
+
def _setup_llm_chain_danish(
|
| 47 |
+
llm_chain: MagicMock, final_answer: str, intent: str = "factual"
|
| 48 |
) -> None:
|
| 49 |
+
"""Configure llm_chain mock for Danish queries (no translation needed).
|
| 50 |
|
| 51 |
The first invoke returns the combined language+intent response,
|
| 52 |
the second invoke returns the final answer.
|
| 53 |
"""
|
| 54 |
combined = f"language: Danish\nintent: {intent}"
|
| 55 |
+
llm_chain.invoke.side_effect = [combined, final_answer]
|
| 56 |
|
| 57 |
|
| 58 |
+
def _setup_llm_chain_english(
|
| 59 |
+
llm_chain: MagicMock, translated_query: str, final_answer: str, intent: str = "rag"
|
| 60 |
) -> None:
|
| 61 |
+
"""Configure llm_chain mock for English queries (combined detection + translation + answer).
|
| 62 |
|
| 63 |
The first invoke returns combined language+intent, the second returns the
|
| 64 |
translated query, and the third returns the final answer.
|
| 65 |
"""
|
| 66 |
combined = f"language: English\nintent: {intent}"
|
| 67 |
+
llm_chain.invoke.side_effect = [combined, translated_query, final_answer]
|
| 68 |
|
| 69 |
|
| 70 |
class TestQueryRouterRAG:
|
|
|
|
| 80 |
self, mock_components, intent_str: str, expected_intent: IntentType
|
| 81 |
) -> None:
|
| 82 |
"""RAG intents should retrieve, rerank, and generate an answer."""
|
| 83 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 84 |
|
| 85 |
results = [_make_query_result("policy text", 0.85)]
|
| 86 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 87 |
reranker.rerank.return_value = results
|
| 88 |
+
_setup_llm_chain_danish(llm_chain, "Generated answer", intent=intent_str)
|
| 89 |
|
| 90 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 91 |
response = router.route("Hvad er KU's feriepolitik?", top_k=3)
|
| 92 |
|
| 93 |
assert isinstance(response, GenerationResponse)
|
|
|
|
| 104 |
)
|
| 105 |
|
| 106 |
def test_prompt_contains_context_and_query(self, mock_components) -> None:
|
| 107 |
+
"""The prompt sent to the LLM chain should include context and query."""
|
| 108 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 109 |
|
| 110 |
results = [_make_query_result("Relevant context text", 0.9)]
|
| 111 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 112 |
reranker.rerank.return_value = results
|
| 113 |
+
_setup_llm_chain_danish(llm_chain, "answer", intent="factual")
|
| 114 |
|
| 115 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 116 |
router.route("test query", top_k=3)
|
| 117 |
|
| 118 |
# The final invoke call is the generation call
|
| 119 |
+
prompt = llm_chain.invoke.call_args_list[-1][0][0]
|
| 120 |
assert "Relevant context text" in prompt
|
| 121 |
assert "test query" in prompt
|
| 122 |
|
| 123 |
def test_prompt_contains_language_rule(self, mock_components) -> None:
|
| 124 |
"""The prompt should contain a language instruction matching user language."""
|
| 125 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 126 |
|
| 127 |
results = [_make_query_result("ctx", 0.5)]
|
| 128 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 129 |
reranker.rerank.return_value = results
|
| 130 |
+
_setup_llm_chain_english(llm_chain, "oversæt forespørgsel", "answer", intent="rag")
|
| 131 |
|
| 132 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 133 |
router.route("What is KU's vacation policy?", top_k=3)
|
| 134 |
|
| 135 |
+
prompt = llm_chain.invoke.call_args_list[-1][0][0]
|
| 136 |
assert "MUST answer in English" in prompt
|
| 137 |
|
| 138 |
|
|
|
|
| 141 |
|
| 142 |
def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
|
| 143 |
"""UNKNOWN intent skips retrieval and returns zero confidence."""
|
| 144 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 145 |
|
| 146 |
reranker.rerank.return_value = []
|
| 147 |
+
_setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
|
| 148 |
|
| 149 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 150 |
response = router.route("Hej, hvad kan du hjælpe med?", top_k=3)
|
| 151 |
|
| 152 |
assert response.answer == "Fallback answer"
|
|
|
|
| 158 |
self, mock_components
|
| 159 |
) -> None:
|
| 160 |
"""UNKNOWN intent should use the generic helpful instruction."""
|
| 161 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 162 |
|
| 163 |
reranker.rerank.return_value = []
|
| 164 |
+
_setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
|
| 165 |
|
| 166 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 167 |
router.route("random input", top_k=3)
|
| 168 |
|
| 169 |
+
prompt = llm_chain.invoke.call_args_list[-1][0][0]
|
| 170 |
assert "as helpfully as possible" in prompt
|
| 171 |
|
| 172 |
|
|
|
|
| 177 |
self, mock_components
|
| 178 |
) -> None:
|
| 179 |
"""When reranker returns no results, confidence should be 0.0."""
|
| 180 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 181 |
|
| 182 |
retriever.search_detailed.return_value = _make_hybrid_result([])
|
| 183 |
reranker.rerank.return_value = []
|
| 184 |
+
_setup_llm_chain_danish(llm_chain, "No information found", intent="factual")
|
| 185 |
|
| 186 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 187 |
response = router.route("asdfghjkl", top_k=3)
|
| 188 |
|
| 189 |
assert response.confidence == 0.0
|
| 190 |
assert response.sources == []
|
| 191 |
assert response.answer == "No information found"
|
| 192 |
|
| 193 |
+
def test_empty_context_passed_to_llm_chain(self, mock_components) -> None:
|
| 194 |
"""When no chunks are retrieved, the prompt context should be empty."""
|
| 195 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 196 |
|
| 197 |
retriever.search_detailed.return_value = _make_hybrid_result([])
|
| 198 |
reranker.rerank.return_value = []
|
| 199 |
+
_setup_llm_chain_danish(llm_chain, "answer", intent="factual")
|
| 200 |
|
| 201 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 202 |
router.route("gibberish", top_k=3)
|
| 203 |
|
| 204 |
+
prompt = llm_chain.invoke.call_args_list[-1][0][0]
|
| 205 |
assert "Context:\n\n" in prompt
|
| 206 |
|
| 207 |
def test_multiple_results_confidence_uses_max_score(
|
| 208 |
self, mock_components
|
| 209 |
) -> None:
|
| 210 |
"""Confidence should be the maximum score among reranked results."""
|
| 211 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 212 |
|
| 213 |
results = [
|
| 214 |
_make_query_result("low", 0.3),
|
|
|
|
| 217 |
]
|
| 218 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 219 |
reranker.rerank.return_value = results
|
| 220 |
+
_setup_llm_chain_danish(llm_chain, "summary", intent="summary")
|
| 221 |
|
| 222 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 223 |
response = router.route("opsummer politikken", top_k=5)
|
| 224 |
|
| 225 |
assert response.confidence == pytest.approx(0.95, abs=1e-6)
|
|
|
|
| 230 |
|
| 231 |
def test_danish_query_not_translated(self, mock_components) -> None:
|
| 232 |
"""Danish queries should be passed directly to retrieval without translation."""
|
| 233 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 234 |
|
| 235 |
results = [_make_query_result("ctx", 0.5)]
|
| 236 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 237 |
reranker.rerank.return_value = results
|
| 238 |
+
_setup_llm_chain_danish(llm_chain, "svar", intent="rag")
|
| 239 |
|
| 240 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 241 |
router.route("Hvad er reglerne?", top_k=3)
|
| 242 |
|
| 243 |
# Only 2 invoke calls: combined detection + generation (no translation)
|
| 244 |
+
assert llm_chain.invoke.call_count == 2
|
| 245 |
retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
|
| 246 |
|
| 247 |
def test_english_query_translated_for_retrieval(self, mock_components) -> None:
|
| 248 |
"""English queries should be translated to Danish for retrieval."""
|
| 249 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 250 |
|
| 251 |
results = [_make_query_result("ctx", 0.5)]
|
| 252 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 253 |
reranker.rerank.return_value = results
|
| 254 |
+
_setup_llm_chain_english(llm_chain, "Hvad er reglerne?", "The rules are...", intent="rag")
|
| 255 |
|
| 256 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain, translate_query=True)
|
| 257 |
response = router.route("What are the rules?", top_k=3)
|
| 258 |
|
| 259 |
# 3 invoke calls: combined detection + translation + generation
|
| 260 |
+
assert llm_chain.invoke.call_count == 3
|
| 261 |
retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
|
| 262 |
reranker.rerank.assert_called_once_with("Hvad er reglerne?", results, top_k=3)
|
| 263 |
assert response.answer == "The rules are..."
|
| 264 |
|
| 265 |
def test_translation_disabled_skips_translate(self, mock_components) -> None:
|
| 266 |
"""When translate_query=False, English queries go straight to retrieval untranslated."""
|
| 267 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 268 |
|
| 269 |
results = [_make_query_result("ctx", 0.5)]
|
| 270 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 271 |
reranker.rerank.return_value = results
|
| 272 |
# Only 2 calls: combined detection + generation (no translation)
|
| 273 |
combined = "language: English\nintent: rag"
|
| 274 |
+
llm_chain.invoke.side_effect = [combined, "The answer"]
|
| 275 |
|
| 276 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain, translate_query=False)
|
| 277 |
response = router.route("What are the rules?", top_k=3)
|
| 278 |
|
| 279 |
+
assert llm_chain.invoke.call_count == 2
|
| 280 |
retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3)
|
| 281 |
assert response.answer == "The answer"
|
| 282 |
|
|
|
|
| 286 |
|
| 287 |
def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
|
| 288 |
"""Confidence should equal the max reranked score (already sigmoid-normalized)."""
|
| 289 |
+
classifier, retriever, reranker, llm_chain = mock_components
|
| 290 |
|
| 291 |
results = [
|
| 292 |
_make_query_result("a", 0.7),
|
|
|
|
| 294 |
]
|
| 295 |
retriever.search_detailed.return_value = _make_hybrid_result(results)
|
| 296 |
reranker.rerank.return_value = results
|
| 297 |
+
_setup_llm_chain_danish(llm_chain, "answer", intent="rag")
|
| 298 |
|
| 299 |
+
router = QueryRouter(classifier, retriever, reranker, llm_chain)
|
| 300 |
response = router.route("test", top_k=3)
|
| 301 |
|
| 302 |
assert response.confidence == pytest.approx(0.9, abs=1e-6)
|