XQ commited on
Commit
ec64993
·
1 Parent(s): c263a7d

Code cleaning

Browse files
scripts/e2e_test.py CHANGED
@@ -98,12 +98,12 @@ def main() -> None:
98
  )
99
  reranker = Reranker(model=create_reranker(settings.reranker_model))
100
  classifier = IntentClassifier(llm=llm)
101
- generator = llm | StrOutputParser()
102
  router = QueryRouter(
103
  intent_classifier=classifier,
104
  hybrid_retriever=hybrid,
105
  reranker=reranker,
106
- generator=generator,
107
  )
108
 
109
  # --- 5) Run query ---
 
98
  )
99
  reranker = Reranker(model=create_reranker(settings.reranker_model))
100
  classifier = IntentClassifier(llm=llm)
101
+ llm_chain = llm | StrOutputParser()
102
  router = QueryRouter(
103
  intent_classifier=classifier,
104
  hybrid_retriever=hybrid,
105
  reranker=reranker,
106
+ llm_chain=llm_chain,
107
  )
108
 
109
  # --- 5) Run query ---
scripts/evaluate.py CHANGED
@@ -156,12 +156,12 @@ def main() -> None:
156
  )
157
  reranker = Reranker(model=create_reranker(settings.reranker_model))
158
  classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
159
- generator = llm | StrOutputParser()
160
  router = QueryRouter(
161
  intent_classifier=classifier,
162
  hybrid_retriever=hybrid,
163
  reranker=reranker,
164
- generator=generator,
165
  )
166
 
167
  # --- 5) Run test set ---
 
156
  )
157
  reranker = Reranker(model=create_reranker(settings.reranker_model))
158
  classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
159
+ llm_chain = llm | StrOutputParser()
160
  router = QueryRouter(
161
  intent_classifier=classifier,
162
  hybrid_retriever=hybrid,
163
  reranker=reranker,
164
+ llm_chain=llm_chain,
165
  )
166
 
167
  # --- 5) Run test set ---
src/agent/react_router.py CHANGED
@@ -39,20 +39,6 @@ _SYSTEM_PROMPT = (
39
  )
40
 
41
 
42
- def _ser_sources(sources: list[QueryResult]) -> list[dict]:
43
- """Serialise QueryResult list to a JSON-safe list of dicts."""
44
- return [
45
- {
46
- "chunk_id": r.chunk.chunk_id,
47
- "document_id": r.chunk.document_id,
48
- "text": r.chunk.text,
49
- "score": r.score,
50
- "source": r.source,
51
- }
52
- for r in sources
53
- ]
54
-
55
-
56
  class ReActRouter:
57
  """Routes queries through a multi-step ReAct agent with tool-calling LLM.
58
 
@@ -231,7 +217,7 @@ class ReActRouter:
231
  "step": "done",
232
  "result": {
233
  "answer": answer,
234
- "sources": _ser_sources(sources),
235
  "intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
236
  "confidence": confidence,
237
  "pipeline_details": {
@@ -239,10 +225,10 @@ class ReActRouter:
239
  "retrieval_query": ", ".join(q for _, q in store.tool_calls) or query,
240
  "detected_language": "unknown",
241
  "translated": False,
242
- "dense_results": _ser_sources(store.dense_results),
243
- "sparse_results": _ser_sources(store.sparse_results),
244
- "fused_results": _ser_sources(store.fused_results),
245
- "reranked_results": _ser_sources(sources),
246
  },
247
  },
248
  }
 
39
  )
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class ReActRouter:
43
  """Routes queries through a multi-step ReAct agent with tool-calling LLM.
44
 
 
217
  "step": "done",
218
  "result": {
219
  "answer": answer,
220
+ "sources": [r.to_dict() for r in sources],
221
  "intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
222
  "confidence": confidence,
223
  "pipeline_details": {
 
225
  "retrieval_query": ", ".join(q for _, q in store.tool_calls) or query,
226
  "detected_language": "unknown",
227
  "translated": False,
228
+ "dense_results": [r.to_dict(include_text=False) for r in store.dense_results],
229
+ "sparse_results": [r.to_dict(include_text=False) for r in store.sparse_results],
230
+ "fused_results": [r.to_dict(include_text=False) for r in store.fused_results],
231
+ "reranked_results": [r.to_dict(include_text=False) for r in sources],
232
  },
233
  },
234
  }
src/agent/router.py CHANGED
@@ -48,6 +48,32 @@ class RouterState(TypedDict):
48
  answer: str
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  class QueryRouter:
52
  """Routes queries to appropriate retrieval and generation pipelines."""
53
 
@@ -56,7 +82,7 @@ class QueryRouter:
56
  intent_classifier: IntentClassifier,
57
  hybrid_retriever: HybridRetriever,
58
  reranker: Reranker,
59
- generator: Runnable,
60
  *,
61
  translate_query: bool = True,
62
  ) -> None:
@@ -66,7 +92,8 @@ class QueryRouter:
66
  intent_classifier: IntentClassifier instance.
67
  hybrid_retriever: HybridRetriever instance.
68
  reranker: Reranker instance.
69
- generator: LLM generation chain.
 
70
  translate_query: Whether to translate non-Danish queries to Danish
71
  before retrieval. When False, language detection still runs for
72
  the answer-language rule but no translation is performed.
@@ -74,7 +101,7 @@ class QueryRouter:
74
  self._intent_classifier = intent_classifier
75
  self._hybrid_retriever = hybrid_retriever
76
  self._reranker = reranker
77
- self._generator = generator
78
  self._translate_query_enabled = translate_query
79
  self._graph = self._build_graph()
80
 
@@ -155,7 +182,7 @@ class QueryRouter:
155
  "intent: <intent>\n\n"
156
  f"Query: {query}"
157
  )
158
- raw = str(self._generator.invoke(prompt)).strip()
159
  logger.debug("Combined detection raw response: %s", raw)
160
 
161
  # Parse response
@@ -200,10 +227,77 @@ class QueryRouter:
200
  "Reply with ONLY the translated text, nothing else.\n\n"
201
  f"Text: {query}"
202
  )
203
- translated = str(self._generator.invoke(translate_prompt)).strip()
204
  logger.info("Translated query to Danish: %s", translated)
205
  return translated
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  def _build_graph(self) -> object:
208
  """Build the LangGraph routing graph.
209
 
@@ -218,75 +312,19 @@ class QueryRouter:
218
  Returns:
219
  Compiled LangGraph graph.
220
  """
221
-
222
- def detect_node(state: RouterState) -> dict:
223
- user_language, intent = self._detect_language_and_intent(state["query"])
224
- return {"user_language": user_language, "intent": intent}
225
-
226
- def translate_node(state: RouterState) -> dict:
227
- retrieval_query = self._translate_query(state["query"], state["user_language"])
228
- return {
229
- "retrieval_query": retrieval_query,
230
- "translated": retrieval_query != state["query"],
231
- }
232
-
233
- def retrieve_node(state: RouterState) -> dict:
234
- hybrid_result = self._hybrid_retriever.search_detailed(
235
- state["retrieval_query"], top_k=state["top_k"]
236
- )
237
- logger.info("Retrieved %d results from hybrid search", len(hybrid_result.fused_results))
238
- return {
239
- "dense_results": hybrid_result.dense_results,
240
- "sparse_results": hybrid_result.sparse_results,
241
- "fused_results": hybrid_result.fused_results,
242
- }
243
-
244
- def rerank_node(state: RouterState) -> dict:
245
- results = state.get("fused_results", [])
246
- reranked = (
247
- self._reranker.rerank(state["retrieval_query"], results, top_k=state["top_k"])
248
- if results
249
- else []
250
- )
251
- confidence = max(r.score for r in reranked) if reranked else 0.0
252
- logger.info("Reranked to %d results", len(reranked))
253
- if reranked:
254
- logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
255
- return {"reranked": reranked, "confidence": confidence}
256
-
257
- def update_intent_node(state: RouterState) -> dict:
258
- if state.get("reranked") and state["intent"] == IntentType.FACTUAL:
259
- logger.info("Overriding intent to RAG (sources retrieved)")
260
- return {"intent": IntentType.RAG}
261
- return {}
262
-
263
- def generate_node(state: RouterState) -> dict:
264
- reranked = state.get("reranked", [])
265
- context = "\n\n".join(r.chunk.text for r in reranked)
266
- prompt = self._build_prompt(
267
- state["query"], state["intent"], context, state["user_language"]
268
- )
269
- answer = self._generator.invoke(prompt)
270
- logger.info("Generated answer for intent=%s", state["intent"].value)
271
- return {"answer": str(answer)}
272
-
273
- def should_retrieve(state: RouterState) -> str:
274
- """Skip retrieval when intent is UNKNOWN."""
275
- return "retrieve" if state["intent"] != IntentType.UNKNOWN else "rerank"
276
-
277
  graph: StateGraph = StateGraph(RouterState)
278
- graph.add_node("detect", detect_node)
279
- graph.add_node("translate", translate_node)
280
- graph.add_node("retrieve", retrieve_node)
281
- graph.add_node("rerank", rerank_node)
282
- graph.add_node("update_intent", update_intent_node)
283
- graph.add_node("generate", generate_node)
284
 
285
  graph.set_entry_point("detect")
286
  graph.add_edge("detect", "translate")
287
  graph.add_conditional_edges(
288
  "translate",
289
- should_retrieve,
290
  {"retrieve": "retrieve", "rerank": "rerank"},
291
  )
292
  graph.add_edge("retrieve", "rerank")
@@ -308,22 +346,7 @@ class QueryRouter:
308
  """
309
  logger.info("Routing query: %s", query)
310
 
311
- initial_state: RouterState = {
312
- "query": query,
313
- "top_k": top_k,
314
- "user_language": "Danish",
315
- "intent": IntentType.UNKNOWN,
316
- "retrieval_query": query,
317
- "translated": False,
318
- "dense_results": [],
319
- "sparse_results": [],
320
- "fused_results": [],
321
- "reranked": [],
322
- "confidence": 0.0,
323
- "answer": "",
324
- }
325
-
326
- final_state: RouterState = self._graph.invoke(initial_state)
327
 
328
  pipeline = PipelineDetails(
329
  original_query=query,
@@ -386,7 +409,7 @@ class QueryRouter:
386
  # context = "\n\n".join(r.chunk.text for r in reranked)
387
  # prompt = self._build_prompt(query, intent, context, user_language)
388
  #
389
- # answer = self._generator.invoke(prompt)
390
  # logger.info("Generated answer for intent=%s", intent.value)
391
  #
392
  # if reranked:
@@ -417,24 +440,9 @@ class QueryRouter:
417
  Yields:
418
  Step event dicts, then a final ``done`` event with the result.
419
  """
420
- initial_state: RouterState = {
421
- "query": query,
422
- "top_k": top_k,
423
- "user_language": "Danish",
424
- "intent": IntentType.UNKNOWN,
425
- "retrieval_query": query,
426
- "translated": False,
427
- "dense_results": [],
428
- "sparse_results": [],
429
- "fused_results": [],
430
- "reranked": [],
431
- "confidence": 0.0,
432
- "answer": "",
433
- }
434
-
435
- accumulated: dict = dict(initial_state)
436
 
437
- for chunk in self._graph.stream(initial_state, stream_mode="updates"):
438
  for node_name, update in chunk.items():
439
  if update is None:
440
  continue
@@ -453,24 +461,12 @@ class QueryRouter:
453
  elif node_name == "rerank":
454
  event["reranked_count"] = len(update.get("reranked", []))
455
  event["confidence"] = round(update.get("confidence", 0.0), 4)
456
- # update_intent and generate: no extra fields needed
457
 
458
  yield event
459
 
460
  # Build the final response from accumulated state and emit as "done"
461
  reranked: list = accumulated.get("reranked", [])
462
 
463
- def _ser(results: list) -> list[dict]:
464
- return [
465
- {
466
- "document_id": r.chunk.document_id,
467
- "chunk_id": r.chunk.chunk_id,
468
- "score": r.score,
469
- "source": r.source,
470
- }
471
- for r in results
472
- ]
473
-
474
  pd_acc = PipelineDetails(
475
  original_query=query,
476
  retrieval_query=accumulated.get("retrieval_query", query),
@@ -486,16 +482,7 @@ class QueryRouter:
486
  "step": "done",
487
  "result": {
488
  "answer": accumulated.get("answer", ""),
489
- "sources": [
490
- {
491
- "chunk_id": r.chunk.chunk_id,
492
- "document_id": r.chunk.document_id,
493
- "text": r.chunk.text,
494
- "score": r.score,
495
- "source": r.source,
496
- }
497
- for r in reranked
498
- ],
499
  "intent": accumulated.get("intent", IntentType.UNKNOWN).value,
500
  "confidence": accumulated.get("confidence", 0.0),
501
  "pipeline_details": {
@@ -503,10 +490,10 @@ class QueryRouter:
503
  "retrieval_query": pd_acc.retrieval_query,
504
  "detected_language": pd_acc.detected_language,
505
  "translated": pd_acc.translated,
506
- "dense_results": _ser(pd_acc.dense_results),
507
- "sparse_results": _ser(pd_acc.sparse_results),
508
- "fused_results": _ser(pd_acc.fused_results),
509
- "reranked_results": _ser(pd_acc.reranked_results),
510
  },
511
  },
512
  }
 
48
  answer: str
49
 
50
 
51
+ def _make_initial_state(query: str, top_k: int) -> RouterState:
52
+ """Create a fresh RouterState with sensible defaults.
53
+
54
+ Args:
55
+ query: The user's original query.
56
+ top_k: Number of results to retrieve.
57
+
58
+ Returns:
59
+ RouterState ready to be passed into the graph.
60
+ """
61
+ return RouterState(
62
+ query=query,
63
+ top_k=top_k,
64
+ user_language="Danish",
65
+ intent=IntentType.UNKNOWN,
66
+ retrieval_query=query,
67
+ translated=False,
68
+ dense_results=[],
69
+ sparse_results=[],
70
+ fused_results=[],
71
+ reranked=[],
72
+ confidence=0.0,
73
+ answer="",
74
+ )
75
+
76
+
77
  class QueryRouter:
78
  """Routes queries to appropriate retrieval and generation pipelines."""
79
 
 
82
  intent_classifier: IntentClassifier,
83
  hybrid_retriever: HybridRetriever,
84
  reranker: Reranker,
85
+ llm_chain: Runnable,
86
  *,
87
  translate_query: bool = True,
88
  ) -> None:
 
92
  intent_classifier: IntentClassifier instance.
93
  hybrid_retriever: HybridRetriever instance.
94
  reranker: Reranker instance.
95
+ llm_chain: LLM chain (llm | StrOutputParser) for generation,
96
+ translation, and language detection.
97
  translate_query: Whether to translate non-Danish queries to Danish
98
  before retrieval. When False, language detection still runs for
99
  the answer-language rule but no translation is performed.
 
101
  self._intent_classifier = intent_classifier
102
  self._hybrid_retriever = hybrid_retriever
103
  self._reranker = reranker
104
+ self._llm_chain = llm_chain
105
  self._translate_query_enabled = translate_query
106
  self._graph = self._build_graph()
107
 
 
182
  "intent: <intent>\n\n"
183
  f"Query: {query}"
184
  )
185
+ raw = str(self._llm_chain.invoke(prompt)).strip()
186
  logger.debug("Combined detection raw response: %s", raw)
187
 
188
  # Parse response
 
227
  "Reply with ONLY the translated text, nothing else.\n\n"
228
  f"Text: {query}"
229
  )
230
+ translated = str(self._llm_chain.invoke(translate_prompt)).strip()
231
  logger.info("Translated query to Danish: %s", translated)
232
  return translated
233
 
234
+ # ------------------------------------------------------------------
235
+ # LangGraph node functions
236
+ # ------------------------------------------------------------------
237
+
238
+ def _detect_node(self, state: RouterState) -> dict:
239
+ """Detect language and classify intent."""
240
+ user_language, intent = self._detect_language_and_intent(state["query"])
241
+ return {"user_language": user_language, "intent": intent}
242
+
243
+ def _translate_node(self, state: RouterState) -> dict:
244
+ """Translate query to Danish if needed."""
245
+ retrieval_query = self._translate_query(state["query"], state["user_language"])
246
+ return {
247
+ "retrieval_query": retrieval_query,
248
+ "translated": retrieval_query != state["query"],
249
+ }
250
+
251
+ def _retrieve_node(self, state: RouterState) -> dict:
252
+ """Run hybrid search."""
253
+ hybrid_result = self._hybrid_retriever.search_detailed(
254
+ state["retrieval_query"], top_k=state["top_k"]
255
+ )
256
+ logger.info("Retrieved %d results from hybrid search", len(hybrid_result.fused_results))
257
+ return {
258
+ "dense_results": hybrid_result.dense_results,
259
+ "sparse_results": hybrid_result.sparse_results,
260
+ "fused_results": hybrid_result.fused_results,
261
+ }
262
+
263
+ def _rerank_node(self, state: RouterState) -> dict:
264
+ """Rerank fused results with cross-encoder."""
265
+ results = state.get("fused_results", [])
266
+ reranked = (
267
+ self._reranker.rerank(state["retrieval_query"], results, top_k=state["top_k"])
268
+ if results
269
+ else []
270
+ )
271
+ confidence = max(r.score for r in reranked) if reranked else 0.0
272
+ logger.info("Reranked to %d results", len(reranked))
273
+ if reranked:
274
+ logger.info("Confidence: %.4f (sigmoid-normalized by reranker)", confidence)
275
+ return {"reranked": reranked, "confidence": confidence}
276
+
277
+ @staticmethod
278
+ def _update_intent_node(state: RouterState) -> dict:
279
+ """Promote FACTUAL to RAG when sources are found."""
280
+ if state.get("reranked") and state["intent"] == IntentType.FACTUAL:
281
+ logger.info("Overriding intent to RAG (sources retrieved)")
282
+ return {"intent": IntentType.RAG}
283
+ return {}
284
+
285
+ def _generate_node(self, state: RouterState) -> dict:
286
+ """Build prompt and call LLM."""
287
+ reranked = state.get("reranked", [])
288
+ context = "\n\n".join(r.chunk.text for r in reranked)
289
+ prompt = self._build_prompt(
290
+ state["query"], state["intent"], context, state["user_language"]
291
+ )
292
+ answer = self._llm_chain.invoke(prompt)
293
+ logger.info("Generated answer for intent=%s", state["intent"].value)
294
+ return {"answer": str(answer)}
295
+
296
+ @staticmethod
297
+ def _should_retrieve(state: RouterState) -> str:
298
+ """Skip retrieval when intent is UNKNOWN."""
299
+ return "retrieve" if state["intent"] != IntentType.UNKNOWN else "rerank"
300
+
301
  def _build_graph(self) -> object:
302
  """Build the LangGraph routing graph.
303
 
 
312
  Returns:
313
  Compiled LangGraph graph.
314
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  graph: StateGraph = StateGraph(RouterState)
316
+ graph.add_node("detect", self._detect_node)
317
+ graph.add_node("translate", self._translate_node)
318
+ graph.add_node("retrieve", self._retrieve_node)
319
+ graph.add_node("rerank", self._rerank_node)
320
+ graph.add_node("update_intent", self._update_intent_node)
321
+ graph.add_node("generate", self._generate_node)
322
 
323
  graph.set_entry_point("detect")
324
  graph.add_edge("detect", "translate")
325
  graph.add_conditional_edges(
326
  "translate",
327
+ self._should_retrieve,
328
  {"retrieve": "retrieve", "rerank": "rerank"},
329
  )
330
  graph.add_edge("retrieve", "rerank")
 
346
  """
347
  logger.info("Routing query: %s", query)
348
 
349
+ final_state: RouterState = self._graph.invoke(_make_initial_state(query, top_k))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  pipeline = PipelineDetails(
352
  original_query=query,
 
409
  # context = "\n\n".join(r.chunk.text for r in reranked)
410
  # prompt = self._build_prompt(query, intent, context, user_language)
411
  #
412
+ # answer = self._llm_chain.invoke(prompt)
413
  # logger.info("Generated answer for intent=%s", intent.value)
414
  #
415
  # if reranked:
 
440
  Yields:
441
  Step event dicts, then a final ``done`` event with the result.
442
  """
443
+ accumulated: dict = dict(_make_initial_state(query, top_k))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
+ for chunk in self._graph.stream(_make_initial_state(query, top_k), stream_mode="updates"):
446
  for node_name, update in chunk.items():
447
  if update is None:
448
  continue
 
461
  elif node_name == "rerank":
462
  event["reranked_count"] = len(update.get("reranked", []))
463
  event["confidence"] = round(update.get("confidence", 0.0), 4)
 
464
 
465
  yield event
466
 
467
  # Build the final response from accumulated state and emit as "done"
468
  reranked: list = accumulated.get("reranked", [])
469
 
 
 
 
 
 
 
 
 
 
 
 
470
  pd_acc = PipelineDetails(
471
  original_query=query,
472
  retrieval_query=accumulated.get("retrieval_query", query),
 
482
  "step": "done",
483
  "result": {
484
  "answer": accumulated.get("answer", ""),
485
+ "sources": [r.to_dict() for r in reranked],
 
 
 
 
 
 
 
 
 
486
  "intent": accumulated.get("intent", IntentType.UNKNOWN).value,
487
  "confidence": accumulated.get("confidence", 0.0),
488
  "pipeline_details": {
 
490
  "retrieval_query": pd_acc.retrieval_query,
491
  "detected_language": pd_acc.detected_language,
492
  "translated": pd_acc.translated,
493
+ "dense_results": [r.to_dict(include_text=False) for r in pd_acc.dense_results],
494
+ "sparse_results": [r.to_dict(include_text=False) for r in pd_acc.sparse_results],
495
+ "fused_results": [r.to_dict(include_text=False) for r in pd_acc.fused_results],
496
+ "reranked_results": [r.to_dict(include_text=False) for r in pd_acc.reranked_results],
497
  },
498
  },
499
  }
src/agent/tools.py CHANGED
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
5
 
6
  from langchain_core.tools import tool
7
 
8
- from src.models import DocumentChunk, QueryResult
9
  from src.retrieval.hybrid import HybridRetriever
10
  from src.retrieval.reranker import Reranker
11
  from src.retrieval.vector_store import VectorStore
 
5
 
6
  from langchain_core.tools import tool
7
 
8
+ from src.models import QueryResult
9
  from src.retrieval.hybrid import HybridRetriever
10
  from src.retrieval.reranker import Reranker
11
  from src.retrieval.vector_store import VectorStore
src/api/main.py CHANGED
@@ -83,12 +83,12 @@ def create_app() -> FastAPI:
83
  else:
84
  logger.info("Agent mode: pipeline (fixed DAG)")
85
  intent_classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
86
- generator = llm | StrOutputParser()
87
  query_router = QueryRouter(
88
  intent_classifier=intent_classifier,
89
  hybrid_retriever=hybrid_retriever,
90
  reranker=reranker,
91
- generator=generator,
92
  translate_query=settings.translate_query,
93
  )
94
 
@@ -113,9 +113,18 @@ def create_app() -> FastAPI:
113
 
114
 
115
  def _parse_strategy(settings: "Settings") -> "ChunkStrategy": # noqa: F821
116
- """Return the default chunking strategy from config."""
 
 
 
 
117
  from src.models import ChunkStrategy
118
- return ChunkStrategy.SEMANTIC
 
 
 
 
 
119
 
120
 
121
  app: FastAPI = create_app()
 
83
  else:
84
  logger.info("Agent mode: pipeline (fixed DAG)")
85
  intent_classifier = IntentClassifier(llm=llm, model_name=settings.generation_model)
86
+ llm_chain = llm | StrOutputParser()
87
  query_router = QueryRouter(
88
  intent_classifier=intent_classifier,
89
  hybrid_retriever=hybrid_retriever,
90
  reranker=reranker,
91
+ llm_chain=llm_chain,
92
  translate_query=settings.translate_query,
93
  )
94
 
 
113
 
114
 
115
  def _parse_strategy(settings: "Settings") -> "ChunkStrategy": # noqa: F821
116
+ """Return the chunking strategy from config, defaulting to SEMANTIC.
117
+
118
+ Reads the CHUNK_STRATEGY environment variable via settings. Falls back
119
+ to SEMANTIC when the variable is unset or empty.
120
+ """
121
  from src.models import ChunkStrategy
122
+
123
+ raw = getattr(settings, "chunk_strategy", "semantic")
124
+ try:
125
+ return ChunkStrategy(raw)
126
+ except ValueError:
127
+ return ChunkStrategy.SEMANTIC
128
 
129
 
130
  app: FastAPI = create_app()
src/api/routes.py CHANGED
@@ -183,27 +183,7 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
183
  ) from exc
184
  raise
185
 
186
- sources = [
187
- {
188
- "chunk_id": result.chunk.chunk_id,
189
- "document_id": result.chunk.document_id,
190
- "text": result.chunk.text,
191
- "score": result.score,
192
- "source": result.source,
193
- }
194
- for result in response.sources
195
- ]
196
-
197
- def _to_pipeline_items(results: list) -> list[PipelineResultItem]:
198
- return [
199
- PipelineResultItem(
200
- document_id=r.chunk.document_id,
201
- chunk_id=r.chunk.chunk_id,
202
- score=r.score,
203
- source=r.source,
204
- )
205
- for r in results
206
- ]
207
 
208
  pd = response.pipeline_details
209
  pipeline_details = PipelineDetailsResponse(
@@ -211,10 +191,10 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
211
  retrieval_query=pd.retrieval_query,
212
  detected_language=pd.detected_language,
213
  translated=pd.translated,
214
- dense_results=_to_pipeline_items(pd.dense_results),
215
- sparse_results=_to_pipeline_items(pd.sparse_results),
216
- fused_results=_to_pipeline_items(pd.fused_results),
217
- reranked_results=_to_pipeline_items(pd.reranked_results),
218
  )
219
 
220
  return QueryResponse(
 
183
  ) from exc
184
  raise
185
 
186
+ sources = [result.to_dict() for result in response.sources]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  pd = response.pipeline_details
189
  pipeline_details = PipelineDetailsResponse(
 
191
  retrieval_query=pd.retrieval_query,
192
  detected_language=pd.detected_language,
193
  translated=pd.translated,
194
+ dense_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.dense_results],
195
+ sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
196
+ fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
197
+ reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
198
  )
199
 
200
  return QueryResponse(
src/config.py CHANGED
@@ -26,6 +26,7 @@ class Settings:
26
  embedding_dimension: int
27
  generation_model: str
28
  reranker_model: str
 
29
  chunk_size: int
30
  chunk_overlap: int
31
  top_k: int
@@ -106,6 +107,7 @@ def load_settings() -> Settings:
106
  embedding_dimension=int(os.environ.get("EMBEDDING_DIMENSION", "384")),
107
  generation_model=os.environ.get("GENERATION_MODEL", "gemma4:e4b"),
108
  reranker_model=os.environ.get("RERANKER_MODEL", "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"),
 
109
  chunk_size=int(os.environ.get("CHUNK_SIZE", "512")),
110
  chunk_overlap=int(os.environ.get("CHUNK_OVERLAP", "64")),
111
  top_k=int(os.environ.get("TOP_K", "5")),
 
26
  embedding_dimension: int
27
  generation_model: str
28
  reranker_model: str
29
+ chunk_strategy: str
30
  chunk_size: int
31
  chunk_overlap: int
32
  top_k: int
 
107
  embedding_dimension=int(os.environ.get("EMBEDDING_DIMENSION", "384")),
108
  generation_model=os.environ.get("GENERATION_MODEL", "gemma4:e4b"),
109
  reranker_model=os.environ.get("RERANKER_MODEL", "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"),
110
+ chunk_strategy=os.environ.get("CHUNK_STRATEGY", "semantic"),
111
  chunk_size=int(os.environ.get("CHUNK_SIZE", "512")),
112
  chunk_overlap=int(os.environ.get("CHUNK_OVERLAP", "64")),
113
  top_k=int(os.environ.get("TOP_K", "5")),
src/models.py CHANGED
@@ -56,6 +56,26 @@ class QueryResult:
56
  score: float
57
  source: str
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @dataclass
61
  class PipelineDetails:
 
56
  score: float
57
  source: str
58
 
59
+ def to_dict(self, *, include_text: bool = True) -> dict[str, str | float]:
60
+ """Serialise to a JSON-safe dictionary.
61
+
62
+ Args:
63
+ include_text: Whether to include the chunk text (default True).
64
+
65
+ Returns:
66
+ Dictionary with chunk_id, document_id, score, source, and
67
+ optionally text.
68
+ """
69
+ d: dict[str, str | float] = {
70
+ "chunk_id": self.chunk.chunk_id,
71
+ "document_id": self.chunk.document_id,
72
+ "score": self.score,
73
+ "source": self.source,
74
+ }
75
+ if include_text:
76
+ d["text"] = self.chunk.text
77
+ return d
78
+
79
 
80
  @dataclass
81
  class PipelineDetails:
src/retrieval/bm25_search.py CHANGED
@@ -1,12 +1,7 @@
1
  """BM25 sparse retrieval using rank_bm25."""
2
 
3
  import logging
4
- from typing import Any
5
 
6
- from langchain_core.callbacks import CallbackManagerForRetrieverRun
7
- from langchain_core.documents import Document
8
- from langchain_core.retrievers import BaseRetriever
9
- from pydantic import ConfigDict
10
  from rank_bm25 import BM25Okapi
11
 
12
  from src.models import DocumentChunk, QueryResult
@@ -64,17 +59,6 @@ class BM25Search:
64
  logger.debug("BM25 search returned %d results for query: %s", len(results), query)
65
  return results
66
 
67
- def as_retriever(self, top_k: int) -> BaseRetriever:
68
- """Return a LangChain BaseRetriever wrapping this BM25 index.
69
-
70
- Args:
71
- top_k: Number of results to return per query.
72
-
73
- Returns:
74
- A BaseRetriever that calls search() and returns Documents.
75
- """
76
- return _BM25RetrieverAdapter(bm25_search=self, top_k=top_k)
77
-
78
  @staticmethod
79
  def _tokenize(text: str) -> list[str]:
80
  """Tokenize text by lowercasing and splitting on whitespace.
@@ -86,30 +70,3 @@ class BM25Search:
86
  List of lowercase tokens.
87
  """
88
  return text.lower().split()
89
-
90
-
91
- class _BM25RetrieverAdapter(BaseRetriever):
92
- """LangChain BaseRetriever adapter over BM25Search."""
93
-
94
- model_config = ConfigDict(arbitrary_types_allowed=True)
95
-
96
- bm25_search: Any
97
- top_k: int
98
-
99
- def _get_relevant_documents(
100
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
101
- ) -> list[Document]:
102
- results = self.bm25_search.search(query, self.top_k)
103
- return [
104
- Document(
105
- page_content=r.chunk.text,
106
- metadata={
107
- "chunk_id": r.chunk.chunk_id,
108
- "document_id": r.chunk.document_id,
109
- "chunk_metadata": r.chunk.metadata,
110
- "strategy": r.chunk.strategy.value,
111
- "score": r.score,
112
- },
113
- )
114
- for r in results
115
- ]
 
1
  """BM25 sparse retrieval using rank_bm25."""
2
 
3
  import logging
 
4
 
 
 
 
 
5
  from rank_bm25 import BM25Okapi
6
 
7
  from src.models import DocumentChunk, QueryResult
 
59
  logger.debug("BM25 search returned %d results for query: %s", len(results), query)
60
  return results
61
 
 
 
 
 
 
 
 
 
 
 
 
62
  @staticmethod
63
  def _tokenize(text: str) -> list[str]:
64
  """Tokenize text by lowercasing and splitting on whitespace.
 
70
  List of lowercase tokens.
71
  """
72
  return text.lower().split()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/retrieval/hybrid.py CHANGED
@@ -3,9 +3,7 @@
3
  import logging
4
  from dataclasses import dataclass
5
 
6
- from langchain_core.documents import Document
7
-
8
- from src.models import ChunkStrategy, DocumentChunk, QueryResult
9
  from src.retrieval.bm25_search import BM25Search
10
  from src.retrieval.embedder import Embedder
11
  from src.retrieval.vector_store import VectorStore
@@ -70,8 +68,6 @@ class HybridRetriever:
70
  def search_detailed(self, query: str, top_k: int) -> HybridSearchResult:
71
  """Execute hybrid search and return all intermediate results.
72
 
73
- Uses LangChain BaseRetriever.invoke() for both dense and sparse retrieval.
74
-
75
  Args:
76
  query: The search query string.
77
  top_k: Number of top results to return after fusion.
@@ -79,14 +75,9 @@ class HybridRetriever:
79
  Returns:
80
  HybridSearchResult containing dense, sparse, and fused results.
81
  """
82
- dense_retriever = self._vector_store.as_retriever(self._embedder, top_k)
83
- sparse_retriever = self._bm25_search.as_retriever(top_k)
84
-
85
- dense_docs: list[Document] = dense_retriever.invoke(query)
86
- sparse_docs: list[Document] = sparse_retriever.invoke(query)
87
-
88
- dense_results = [self._doc_to_query_result(doc, "dense") for doc in dense_docs]
89
- sparse_results = [self._doc_to_query_result(doc, "bm25") for doc in sparse_docs]
90
 
91
  logger.debug(
92
  "Hybrid search: %d dense, %d sparse results",
@@ -101,27 +92,6 @@ class HybridRetriever:
101
  fused_results=fused[:top_k],
102
  )
103
 
104
- @staticmethod
105
- def _doc_to_query_result(doc: Document, source: str) -> QueryResult:
106
- """Convert a LangChain Document to a QueryResult.
107
-
108
- Args:
109
- doc: Document returned by a BaseRetriever.
110
- source: Retrieval source label (e.g. 'dense' or 'bm25').
111
-
112
- Returns:
113
- QueryResult with chunk and score populated from document metadata.
114
- """
115
- meta = doc.metadata
116
- chunk = DocumentChunk(
117
- chunk_id=meta.get("chunk_id", ""),
118
- document_id=meta.get("document_id", ""),
119
- text=doc.page_content,
120
- metadata=meta.get("chunk_metadata", {}),
121
- strategy=ChunkStrategy(meta.get("strategy", ChunkStrategy.RECURSIVE.value)),
122
- )
123
- return QueryResult(chunk=chunk, score=float(meta.get("score", 0.0)), source=source)
124
-
125
  def reciprocal_rank_fusion(
126
  self,
127
  dense_results: list[QueryResult],
@@ -138,9 +108,8 @@ class HybridRetriever:
138
  Returns:
139
  Merged and re-ranked list of QueryResult objects.
140
  """
141
- # Map chunk_id -> (rrf_score, best QueryResult)
142
  scores: dict[str, float] = {}
143
- best_chunk: dict[str, QueryResult] = {}
144
 
145
  for rank, result in enumerate(dense_results):
146
  cid = result.chunk.chunk_id
 
3
  import logging
4
  from dataclasses import dataclass
5
 
6
+ from src.models import DocumentChunk, QueryResult
 
 
7
  from src.retrieval.bm25_search import BM25Search
8
  from src.retrieval.embedder import Embedder
9
  from src.retrieval.vector_store import VectorStore
 
68
  def search_detailed(self, query: str, top_k: int) -> HybridSearchResult:
69
  """Execute hybrid search and return all intermediate results.
70
 
 
 
71
  Args:
72
  query: The search query string.
73
  top_k: Number of top results to return after fusion.
 
75
  Returns:
76
  HybridSearchResult containing dense, sparse, and fused results.
77
  """
78
+ query_embedding = self._embedder.embed_text(query)
79
+ dense_results = self._vector_store.search(query_embedding, top_k)
80
+ sparse_results = self._bm25_search.search(query, top_k)
 
 
 
 
 
81
 
82
  logger.debug(
83
  "Hybrid search: %d dense, %d sparse results",
 
92
  fused_results=fused[:top_k],
93
  )
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def reciprocal_rank_fusion(
96
  self,
97
  dense_results: list[QueryResult],
 
108
  Returns:
109
  Merged and re-ranked list of QueryResult objects.
110
  """
 
111
  scores: dict[str, float] = {}
112
+ best_chunk: dict[str, DocumentChunk] = {}
113
 
114
  for rank, result in enumerate(dense_results):
115
  cid = result.chunk.chunk_id
src/retrieval/vector_store.py CHANGED
@@ -2,12 +2,7 @@
2
 
3
  import json
4
  import logging
5
- from typing import Any
6
 
7
- from langchain_core.callbacks import CallbackManagerForRetrieverRun
8
- from langchain_core.documents import Document
9
- from langchain_core.retrievers import BaseRetriever
10
- from pydantic import ConfigDict
11
  from qdrant_client import QdrantClient
12
  from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
13
 
@@ -16,6 +11,24 @@ from src.models import ChunkStrategy, DocumentChunk, QueryResult
16
  logger = logging.getLogger(__name__)
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class VectorStore:
20
  """Manages document storage and dense retrieval via Qdrant."""
21
 
@@ -96,18 +109,10 @@ class VectorStore:
96
  limit=top_k,
97
  ).points
98
 
99
- results: list[QueryResult] = []
100
- for hit in hits:
101
- payload = hit.payload
102
- chunk = DocumentChunk(
103
- chunk_id=payload["chunk_id"],
104
- document_id=payload["document_id"],
105
- text=payload["text"],
106
- metadata=json.loads(payload["metadata"]),
107
- strategy=ChunkStrategy(payload["strategy"]),
108
- )
109
- results.append(QueryResult(chunk=chunk, score=hit.score, source="dense"))
110
-
111
  logger.debug("Dense search returned %d results", len(results))
112
  return results
113
 
@@ -129,18 +134,7 @@ class VectorStore:
129
  with_vectors=False,
130
  )
131
 
132
- chunks: list[DocumentChunk] = []
133
- for record in records:
134
- payload = record.payload
135
- chunks.append(
136
- DocumentChunk(
137
- chunk_id=payload["chunk_id"],
138
- document_id=payload["document_id"],
139
- text=payload["text"],
140
- metadata=json.loads(payload["metadata"]),
141
- strategy=ChunkStrategy(payload["strategy"]),
142
- )
143
- )
144
  logger.info("Loaded %d chunks from collection '%s'", len(chunks), self._collection_name)
145
  return chunks
146
 
@@ -176,67 +170,13 @@ class VectorStore:
176
  with_vectors=False,
177
  )
178
 
179
- chunks: list[DocumentChunk] = []
180
- for record in records:
181
- payload = record.payload
182
- chunks.append(
183
- DocumentChunk(
184
- chunk_id=payload["chunk_id"],
185
- document_id=payload["document_id"],
186
- text=payload["text"],
187
- metadata=json.loads(payload["metadata"]),
188
- strategy=ChunkStrategy(payload["strategy"]),
189
- )
190
- )
191
  logger.debug(
192
  "Fetched %d chunks for document '%s'", len(chunks), document_id
193
  )
194
  return chunks
195
 
196
- def as_retriever(self, embedder: Any, top_k: int) -> BaseRetriever:
197
- """Return a LangChain BaseRetriever wrapping this vector store.
198
-
199
- Args:
200
- embedder: Embedder instance used to encode queries.
201
- top_k: Number of results to return per query.
202
-
203
- Returns:
204
- A BaseRetriever that calls search() and returns Documents.
205
- """
206
- return _VectorStoreRetrieverAdapter(
207
- vector_store=self, embedder=embedder, top_k=top_k
208
- )
209
-
210
  def delete_collection(self) -> None:
211
  """Delete the entire collection from the store."""
212
  self._client.delete_collection(collection_name=self._collection_name)
213
  logger.info("Deleted Qdrant collection '%s'", self._collection_name)
214
-
215
-
216
- class _VectorStoreRetrieverAdapter(BaseRetriever):
217
- """LangChain BaseRetriever adapter over VectorStore."""
218
-
219
- model_config = ConfigDict(arbitrary_types_allowed=True)
220
-
221
- vector_store: Any
222
- embedder: Any
223
- top_k: int
224
-
225
- def _get_relevant_documents(
226
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
227
- ) -> list[Document]:
228
- query_embedding = self.embedder.embed_text(query)
229
- results = self.vector_store.search(query_embedding, self.top_k)
230
- return [
231
- Document(
232
- page_content=r.chunk.text,
233
- metadata={
234
- "chunk_id": r.chunk.chunk_id,
235
- "document_id": r.chunk.document_id,
236
- "chunk_metadata": r.chunk.metadata,
237
- "strategy": r.chunk.strategy.value,
238
- "score": r.score,
239
- },
240
- )
241
- for r in results
242
- ]
 
2
 
3
  import json
4
  import logging
 
5
 
 
 
 
 
6
  from qdrant_client import QdrantClient
7
  from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
8
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
+ def _payload_to_chunk(payload: dict) -> DocumentChunk:
15
+ """Convert a Qdrant payload dict to a DocumentChunk.
16
+
17
+ Args:
18
+ payload: Qdrant point payload.
19
+
20
+ Returns:
21
+ DocumentChunk reconstructed from the payload.
22
+ """
23
+ return DocumentChunk(
24
+ chunk_id=payload["chunk_id"],
25
+ document_id=payload["document_id"],
26
+ text=payload["text"],
27
+ metadata=json.loads(payload["metadata"]),
28
+ strategy=ChunkStrategy(payload["strategy"]),
29
+ )
30
+
31
+
32
  class VectorStore:
33
  """Manages document storage and dense retrieval via Qdrant."""
34
 
 
109
  limit=top_k,
110
  ).points
111
 
112
+ results: list[QueryResult] = [
113
+ QueryResult(chunk=_payload_to_chunk(hit.payload), score=hit.score, source="dense")
114
+ for hit in hits
115
+ ]
 
 
 
 
 
 
 
 
116
  logger.debug("Dense search returned %d results", len(results))
117
  return results
118
 
 
134
  with_vectors=False,
135
  )
136
 
137
+ chunks = [_payload_to_chunk(record.payload) for record in records]
 
 
 
 
 
 
 
 
 
 
 
138
  logger.info("Loaded %d chunks from collection '%s'", len(chunks), self._collection_name)
139
  return chunks
140
 
 
170
  with_vectors=False,
171
  )
172
 
173
+ chunks = [_payload_to_chunk(record.payload) for record in records]
 
 
 
 
 
 
 
 
 
 
 
174
  logger.debug(
175
  "Fetched %d chunks for document '%s'", len(chunks), document_id
176
  )
177
  return chunks
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def delete_collection(self) -> None:
180
  """Delete the entire collection from the store."""
181
  self._client.delete_collection(collection_name=self._collection_name)
182
  logger.info("Deleted Qdrant collection '%s'", self._collection_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_hybrid.py CHANGED
@@ -4,8 +4,6 @@ from unittest.mock import MagicMock
4
 
5
  import pytest
6
 
7
- from langchain_core.documents import Document
8
-
9
  from src.models import ChunkStrategy, DocumentChunk, QueryResult
10
  from src.retrieval.hybrid import HybridRetriever
11
 
@@ -16,20 +14,6 @@ def _make_result(chunk_id: str, score: float = 0.0, source: str = "test") -> Que
16
  return QueryResult(chunk=chunk, score=score, source=source)
17
 
18
 
19
- def _result_to_doc(result: QueryResult) -> Document:
20
- """Convert a QueryResult to a LangChain Document (mirrors the adapter output)."""
21
- return Document(
22
- page_content=result.chunk.text,
23
- metadata={
24
- "chunk_id": result.chunk.chunk_id,
25
- "document_id": result.chunk.document_id,
26
- "chunk_metadata": result.chunk.metadata,
27
- "strategy": result.chunk.strategy.value,
28
- "score": result.score,
29
- },
30
- )
31
-
32
-
33
  def _build_retriever(
34
  dense_results: list[QueryResult],
35
  sparse_results: list[QueryResult],
@@ -38,22 +22,17 @@ def _build_retriever(
38
  ) -> HybridRetriever:
39
  """Build a HybridRetriever with mocked dependencies.
40
 
41
- Mocks as_retriever().invoke() since HybridRetriever now uses the
42
- LangChain BaseRetriever interface instead of .search() directly.
43
  """
44
- dense_retriever_mock = MagicMock()
45
- dense_retriever_mock.invoke.return_value = [_result_to_doc(r) for r in dense_results]
46
-
47
  vector_store = MagicMock()
48
- vector_store.as_retriever.return_value = dense_retriever_mock
49
-
50
- sparse_retriever_mock = MagicMock()
51
- sparse_retriever_mock.invoke.return_value = [_result_to_doc(r) for r in sparse_results]
52
 
53
  bm25_search = MagicMock()
54
- bm25_search.as_retriever.return_value = sparse_retriever_mock
55
 
56
  embedder = MagicMock()
 
57
 
58
  return HybridRetriever(
59
  vector_store=vector_store,
 
4
 
5
  import pytest
6
 
 
 
7
  from src.models import ChunkStrategy, DocumentChunk, QueryResult
8
  from src.retrieval.hybrid import HybridRetriever
9
 
 
14
  return QueryResult(chunk=chunk, score=score, source=source)
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _build_retriever(
18
  dense_results: list[QueryResult],
19
  sparse_results: list[QueryResult],
 
22
  ) -> HybridRetriever:
23
  """Build a HybridRetriever with mocked dependencies.
24
 
25
+ Mocks vector_store.search() and bm25_search.search() since
26
+ HybridRetriever calls them directly.
27
  """
 
 
 
28
  vector_store = MagicMock()
29
+ vector_store.search.return_value = dense_results
 
 
 
30
 
31
  bm25_search = MagicMock()
32
+ bm25_search.search.return_value = sparse_results
33
 
34
  embedder = MagicMock()
35
+ embedder.embed_text.return_value = [0.0] * 384
36
 
37
  return HybridRetriever(
38
  vector_store=vector_store,
tests/test_router.py CHANGED
@@ -35,36 +35,36 @@ def _make_hybrid_result(results: list[QueryResult]) -> MagicMock:
35
 
36
  @pytest.fixture
37
  def mock_components():
38
- """Create mock intent classifier, retriever, reranker, and generator."""
39
  classifier = MagicMock()
40
  retriever = MagicMock()
41
  reranker = MagicMock()
42
- generator = MagicMock()
43
- return classifier, retriever, reranker, generator
44
 
45
 
46
- def _setup_generator_danish(
47
- generator: MagicMock, final_answer: str, intent: str = "factual"
48
  ) -> None:
49
- """Configure generator mock for Danish queries (no translation needed).
50
 
51
  The first invoke returns the combined language+intent response,
52
  the second invoke returns the final answer.
53
  """
54
  combined = f"language: Danish\nintent: {intent}"
55
- generator.invoke.side_effect = [combined, final_answer]
56
 
57
 
58
- def _setup_generator_english(
59
- generator: MagicMock, translated_query: str, final_answer: str, intent: str = "rag"
60
  ) -> None:
61
- """Configure generator mock for English queries (combined detection + translation + answer).
62
 
63
  The first invoke returns combined language+intent, the second returns the
64
  translated query, and the third returns the final answer.
65
  """
66
  combined = f"language: English\nintent: {intent}"
67
- generator.invoke.side_effect = [combined, translated_query, final_answer]
68
 
69
 
70
  class TestQueryRouterRAG:
@@ -80,14 +80,14 @@ class TestQueryRouterRAG:
80
  self, mock_components, intent_str: str, expected_intent: IntentType
81
  ) -> None:
82
  """RAG intents should retrieve, rerank, and generate an answer."""
83
- classifier, retriever, reranker, generator = mock_components
84
 
85
  results = [_make_query_result("policy text", 0.85)]
86
  retriever.search_detailed.return_value = _make_hybrid_result(results)
87
  reranker.rerank.return_value = results
88
- _setup_generator_danish(generator, "Generated answer", intent=intent_str)
89
 
90
- router = QueryRouter(classifier, retriever, reranker, generator)
91
  response = router.route("Hvad er KU's feriepolitik?", top_k=3)
92
 
93
  assert isinstance(response, GenerationResponse)
@@ -104,35 +104,35 @@ class TestQueryRouterRAG:
104
  )
105
 
106
  def test_prompt_contains_context_and_query(self, mock_components) -> None:
107
- """The prompt sent to the generator should include context and query."""
108
- classifier, retriever, reranker, generator = mock_components
109
 
110
  results = [_make_query_result("Relevant context text", 0.9)]
111
  retriever.search_detailed.return_value = _make_hybrid_result(results)
112
  reranker.rerank.return_value = results
113
- _setup_generator_danish(generator, "answer", intent="factual")
114
 
115
- router = QueryRouter(classifier, retriever, reranker, generator)
116
  router.route("test query", top_k=3)
117
 
118
  # The final invoke call is the generation call
119
- prompt = generator.invoke.call_args_list[-1][0][0]
120
  assert "Relevant context text" in prompt
121
  assert "test query" in prompt
122
 
123
  def test_prompt_contains_language_rule(self, mock_components) -> None:
124
  """The prompt should contain a language instruction matching user language."""
125
- classifier, retriever, reranker, generator = mock_components
126
 
127
  results = [_make_query_result("ctx", 0.5)]
128
  retriever.search_detailed.return_value = _make_hybrid_result(results)
129
  reranker.rerank.return_value = results
130
- _setup_generator_english(generator, "oversæt forespørgsel", "answer", intent="rag")
131
 
132
- router = QueryRouter(classifier, retriever, reranker, generator)
133
  router.route("What is KU's vacation policy?", top_k=3)
134
 
135
- prompt = generator.invoke.call_args_list[-1][0][0]
136
  assert "MUST answer in English" in prompt
137
 
138
 
@@ -141,12 +141,12 @@ class TestQueryRouterDirect:
141
 
142
  def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
143
  """UNKNOWN intent skips retrieval and returns zero confidence."""
144
- classifier, retriever, reranker, generator = mock_components
145
 
146
  reranker.rerank.return_value = []
147
- _setup_generator_danish(generator, "Fallback answer", intent="unknown")
148
 
149
- router = QueryRouter(classifier, retriever, reranker, generator)
150
  response = router.route("Hej, hvad kan du hjælpe med?", top_k=3)
151
 
152
  assert response.answer == "Fallback answer"
@@ -158,15 +158,15 @@ class TestQueryRouterDirect:
158
  self, mock_components
159
  ) -> None:
160
  """UNKNOWN intent should use the generic helpful instruction."""
161
- classifier, retriever, reranker, generator = mock_components
162
 
163
  reranker.rerank.return_value = []
164
- _setup_generator_danish(generator, "answer", intent="unknown")
165
 
166
- router = QueryRouter(classifier, retriever, reranker, generator)
167
  router.route("random input", top_k=3)
168
 
169
- prompt = generator.invoke.call_args_list[-1][0][0]
170
  assert "as helpfully as possible" in prompt
171
 
172
 
@@ -177,38 +177,38 @@ class TestQueryRouterFallback:
177
  self, mock_components
178
  ) -> None:
179
  """When reranker returns no results, confidence should be 0.0."""
180
- classifier, retriever, reranker, generator = mock_components
181
 
182
  retriever.search_detailed.return_value = _make_hybrid_result([])
183
  reranker.rerank.return_value = []
184
- _setup_generator_danish(generator, "No information found", intent="factual")
185
 
186
- router = QueryRouter(classifier, retriever, reranker, generator)
187
  response = router.route("asdfghjkl", top_k=3)
188
 
189
  assert response.confidence == 0.0
190
  assert response.sources == []
191
  assert response.answer == "No information found"
192
 
193
- def test_empty_context_passed_to_generator(self, mock_components) -> None:
194
  """When no chunks are retrieved, the prompt context should be empty."""
195
- classifier, retriever, reranker, generator = mock_components
196
 
197
  retriever.search_detailed.return_value = _make_hybrid_result([])
198
  reranker.rerank.return_value = []
199
- _setup_generator_danish(generator, "answer", intent="factual")
200
 
201
- router = QueryRouter(classifier, retriever, reranker, generator)
202
  router.route("gibberish", top_k=3)
203
 
204
- prompt = generator.invoke.call_args_list[-1][0][0]
205
  assert "Context:\n\n" in prompt
206
 
207
  def test_multiple_results_confidence_uses_max_score(
208
  self, mock_components
209
  ) -> None:
210
  """Confidence should be the maximum score among reranked results."""
211
- classifier, retriever, reranker, generator = mock_components
212
 
213
  results = [
214
  _make_query_result("low", 0.3),
@@ -217,9 +217,9 @@ class TestQueryRouterFallback:
217
  ]
218
  retriever.search_detailed.return_value = _make_hybrid_result(results)
219
  reranker.rerank.return_value = results
220
- _setup_generator_danish(generator, "summary", intent="summary")
221
 
222
- router = QueryRouter(classifier, retriever, reranker, generator)
223
  response = router.route("opsummer politikken", top_k=5)
224
 
225
  assert response.confidence == pytest.approx(0.95, abs=1e-6)
@@ -230,53 +230,53 @@ class TestQueryTranslation:
230
 
231
  def test_danish_query_not_translated(self, mock_components) -> None:
232
  """Danish queries should be passed directly to retrieval without translation."""
233
- classifier, retriever, reranker, generator = mock_components
234
 
235
  results = [_make_query_result("ctx", 0.5)]
236
  retriever.search_detailed.return_value = _make_hybrid_result(results)
237
  reranker.rerank.return_value = results
238
- _setup_generator_danish(generator, "svar", intent="rag")
239
 
240
- router = QueryRouter(classifier, retriever, reranker, generator)
241
  router.route("Hvad er reglerne?", top_k=3)
242
 
243
  # Only 2 invoke calls: combined detection + generation (no translation)
244
- assert generator.invoke.call_count == 2
245
  retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
246
 
247
  def test_english_query_translated_for_retrieval(self, mock_components) -> None:
248
  """English queries should be translated to Danish for retrieval."""
249
- classifier, retriever, reranker, generator = mock_components
250
 
251
  results = [_make_query_result("ctx", 0.5)]
252
  retriever.search_detailed.return_value = _make_hybrid_result(results)
253
  reranker.rerank.return_value = results
254
- _setup_generator_english(generator, "Hvad er reglerne?", "The rules are...", intent="rag")
255
 
256
- router = QueryRouter(classifier, retriever, reranker, generator, translate_query=True)
257
  response = router.route("What are the rules?", top_k=3)
258
 
259
  # 3 invoke calls: combined detection + translation + generation
260
- assert generator.invoke.call_count == 3
261
  retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
262
  reranker.rerank.assert_called_once_with("Hvad er reglerne?", results, top_k=3)
263
  assert response.answer == "The rules are..."
264
 
265
  def test_translation_disabled_skips_translate(self, mock_components) -> None:
266
  """When translate_query=False, English queries go straight to retrieval untranslated."""
267
- classifier, retriever, reranker, generator = mock_components
268
 
269
  results = [_make_query_result("ctx", 0.5)]
270
  retriever.search_detailed.return_value = _make_hybrid_result(results)
271
  reranker.rerank.return_value = results
272
  # Only 2 calls: combined detection + generation (no translation)
273
  combined = "language: English\nintent: rag"
274
- generator.invoke.side_effect = [combined, "The answer"]
275
 
276
- router = QueryRouter(classifier, retriever, reranker, generator, translate_query=False)
277
  response = router.route("What are the rules?", top_k=3)
278
 
279
- assert generator.invoke.call_count == 2
280
  retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3)
281
  assert response.answer == "The answer"
282
 
@@ -286,7 +286,7 @@ class TestSigmoidInReranker:
286
 
287
  def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
288
  """Confidence should equal the max reranked score (already sigmoid-normalized)."""
289
- classifier, retriever, reranker, generator = mock_components
290
 
291
  results = [
292
  _make_query_result("a", 0.7),
@@ -294,9 +294,9 @@ class TestSigmoidInReranker:
294
  ]
295
  retriever.search_detailed.return_value = _make_hybrid_result(results)
296
  reranker.rerank.return_value = results
297
- _setup_generator_danish(generator, "answer", intent="rag")
298
 
299
- router = QueryRouter(classifier, retriever, reranker, generator)
300
  response = router.route("test", top_k=3)
301
 
302
  assert response.confidence == pytest.approx(0.9, abs=1e-6)
 
35
 
36
  @pytest.fixture
37
  def mock_components():
38
+ """Create mock intent classifier, retriever, reranker, and llm_chain."""
39
  classifier = MagicMock()
40
  retriever = MagicMock()
41
  reranker = MagicMock()
42
+ llm_chain = MagicMock()
43
+ return classifier, retriever, reranker, llm_chain
44
 
45
 
46
+ def _setup_llm_chain_danish(
47
+ llm_chain: MagicMock, final_answer: str, intent: str = "factual"
48
  ) -> None:
49
+ """Configure llm_chain mock for Danish queries (no translation needed).
50
 
51
  The first invoke returns the combined language+intent response,
52
  the second invoke returns the final answer.
53
  """
54
  combined = f"language: Danish\nintent: {intent}"
55
+ llm_chain.invoke.side_effect = [combined, final_answer]
56
 
57
 
58
+ def _setup_llm_chain_english(
59
+ llm_chain: MagicMock, translated_query: str, final_answer: str, intent: str = "rag"
60
  ) -> None:
61
+ """Configure llm_chain mock for English queries (combined detection + translation + answer).
62
 
63
  The first invoke returns combined language+intent, the second returns the
64
  translated query, and the third returns the final answer.
65
  """
66
  combined = f"language: English\nintent: {intent}"
67
+ llm_chain.invoke.side_effect = [combined, translated_query, final_answer]
68
 
69
 
70
  class TestQueryRouterRAG:
 
80
  self, mock_components, intent_str: str, expected_intent: IntentType
81
  ) -> None:
82
  """RAG intents should retrieve, rerank, and generate an answer."""
83
+ classifier, retriever, reranker, llm_chain = mock_components
84
 
85
  results = [_make_query_result("policy text", 0.85)]
86
  retriever.search_detailed.return_value = _make_hybrid_result(results)
87
  reranker.rerank.return_value = results
88
+ _setup_llm_chain_danish(llm_chain, "Generated answer", intent=intent_str)
89
 
90
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
91
  response = router.route("Hvad er KU's feriepolitik?", top_k=3)
92
 
93
  assert isinstance(response, GenerationResponse)
 
104
  )
105
 
106
  def test_prompt_contains_context_and_query(self, mock_components) -> None:
107
+ """The prompt sent to the LLM chain should include context and query."""
108
+ classifier, retriever, reranker, llm_chain = mock_components
109
 
110
  results = [_make_query_result("Relevant context text", 0.9)]
111
  retriever.search_detailed.return_value = _make_hybrid_result(results)
112
  reranker.rerank.return_value = results
113
+ _setup_llm_chain_danish(llm_chain, "answer", intent="factual")
114
 
115
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
116
  router.route("test query", top_k=3)
117
 
118
  # The final invoke call is the generation call
119
+ prompt = llm_chain.invoke.call_args_list[-1][0][0]
120
  assert "Relevant context text" in prompt
121
  assert "test query" in prompt
122
 
123
  def test_prompt_contains_language_rule(self, mock_components) -> None:
124
  """The prompt should contain a language instruction matching user language."""
125
+ classifier, retriever, reranker, llm_chain = mock_components
126
 
127
  results = [_make_query_result("ctx", 0.5)]
128
  retriever.search_detailed.return_value = _make_hybrid_result(results)
129
  reranker.rerank.return_value = results
130
+ _setup_llm_chain_english(llm_chain, "oversæt forespørgsel", "answer", intent="rag")
131
 
132
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
133
  router.route("What is KU's vacation policy?", top_k=3)
134
 
135
+ prompt = llm_chain.invoke.call_args_list[-1][0][0]
136
  assert "MUST answer in English" in prompt
137
 
138
 
 
141
 
142
  def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
143
  """UNKNOWN intent skips retrieval and returns zero confidence."""
144
+ classifier, retriever, reranker, llm_chain = mock_components
145
 
146
  reranker.rerank.return_value = []
147
+ _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
148
 
149
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
150
  response = router.route("Hej, hvad kan du hjælpe med?", top_k=3)
151
 
152
  assert response.answer == "Fallback answer"
 
158
  self, mock_components
159
  ) -> None:
160
  """UNKNOWN intent should use the generic helpful instruction."""
161
+ classifier, retriever, reranker, llm_chain = mock_components
162
 
163
  reranker.rerank.return_value = []
164
+ _setup_llm_chain_danish(llm_chain, "answer", intent="unknown")
165
 
166
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
167
  router.route("random input", top_k=3)
168
 
169
+ prompt = llm_chain.invoke.call_args_list[-1][0][0]
170
  assert "as helpfully as possible" in prompt
171
 
172
 
 
177
  self, mock_components
178
  ) -> None:
179
  """When reranker returns no results, confidence should be 0.0."""
180
+ classifier, retriever, reranker, llm_chain = mock_components
181
 
182
  retriever.search_detailed.return_value = _make_hybrid_result([])
183
  reranker.rerank.return_value = []
184
+ _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")
185
 
186
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
187
  response = router.route("asdfghjkl", top_k=3)
188
 
189
  assert response.confidence == 0.0
190
  assert response.sources == []
191
  assert response.answer == "No information found"
192
 
193
+ def test_empty_context_passed_to_llm_chain(self, mock_components) -> None:
194
  """When no chunks are retrieved, the prompt context should be empty."""
195
+ classifier, retriever, reranker, llm_chain = mock_components
196
 
197
  retriever.search_detailed.return_value = _make_hybrid_result([])
198
  reranker.rerank.return_value = []
199
+ _setup_llm_chain_danish(llm_chain, "answer", intent="factual")
200
 
201
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
202
  router.route("gibberish", top_k=3)
203
 
204
+ prompt = llm_chain.invoke.call_args_list[-1][0][0]
205
  assert "Context:\n\n" in prompt
206
 
207
  def test_multiple_results_confidence_uses_max_score(
208
  self, mock_components
209
  ) -> None:
210
  """Confidence should be the maximum score among reranked results."""
211
+ classifier, retriever, reranker, llm_chain = mock_components
212
 
213
  results = [
214
  _make_query_result("low", 0.3),
 
217
  ]
218
  retriever.search_detailed.return_value = _make_hybrid_result(results)
219
  reranker.rerank.return_value = results
220
+ _setup_llm_chain_danish(llm_chain, "summary", intent="summary")
221
 
222
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
223
  response = router.route("opsummer politikken", top_k=5)
224
 
225
  assert response.confidence == pytest.approx(0.95, abs=1e-6)
 
230
 
231
  def test_danish_query_not_translated(self, mock_components) -> None:
232
  """Danish queries should be passed directly to retrieval without translation."""
233
+ classifier, retriever, reranker, llm_chain = mock_components
234
 
235
  results = [_make_query_result("ctx", 0.5)]
236
  retriever.search_detailed.return_value = _make_hybrid_result(results)
237
  reranker.rerank.return_value = results
238
+ _setup_llm_chain_danish(llm_chain, "svar", intent="rag")
239
 
240
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
241
  router.route("Hvad er reglerne?", top_k=3)
242
 
243
  # Only 2 invoke calls: combined detection + generation (no translation)
244
+ assert llm_chain.invoke.call_count == 2
245
  retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
246
 
247
  def test_english_query_translated_for_retrieval(self, mock_components) -> None:
248
  """English queries should be translated to Danish for retrieval."""
249
+ classifier, retriever, reranker, llm_chain = mock_components
250
 
251
  results = [_make_query_result("ctx", 0.5)]
252
  retriever.search_detailed.return_value = _make_hybrid_result(results)
253
  reranker.rerank.return_value = results
254
+ _setup_llm_chain_english(llm_chain, "Hvad er reglerne?", "The rules are...", intent="rag")
255
 
256
+ router = QueryRouter(classifier, retriever, reranker, llm_chain, translate_query=True)
257
  response = router.route("What are the rules?", top_k=3)
258
 
259
  # 3 invoke calls: combined detection + translation + generation
260
+ assert llm_chain.invoke.call_count == 3
261
  retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)
262
  reranker.rerank.assert_called_once_with("Hvad er reglerne?", results, top_k=3)
263
  assert response.answer == "The rules are..."
264
 
265
  def test_translation_disabled_skips_translate(self, mock_components) -> None:
266
  """When translate_query=False, English queries go straight to retrieval untranslated."""
267
+ classifier, retriever, reranker, llm_chain = mock_components
268
 
269
  results = [_make_query_result("ctx", 0.5)]
270
  retriever.search_detailed.return_value = _make_hybrid_result(results)
271
  reranker.rerank.return_value = results
272
  # Only 2 calls: combined detection + generation (no translation)
273
  combined = "language: English\nintent: rag"
274
+ llm_chain.invoke.side_effect = [combined, "The answer"]
275
 
276
+ router = QueryRouter(classifier, retriever, reranker, llm_chain, translate_query=False)
277
  response = router.route("What are the rules?", top_k=3)
278
 
279
+ assert llm_chain.invoke.call_count == 2
280
  retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3)
281
  assert response.answer == "The answer"
282
 
 
286
 
287
  def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
288
  """Confidence should equal the max reranked score (already sigmoid-normalized)."""
289
+ classifier, retriever, reranker, llm_chain = mock_components
290
 
291
  results = [
292
  _make_query_result("a", 0.7),
 
294
  ]
295
  retriever.search_detailed.return_value = _make_hybrid_result(results)
296
  reranker.rerank.return_value = results
297
+ _setup_llm_chain_danish(llm_chain, "answer", intent="rag")
298
 
299
+ router = QueryRouter(classifier, retriever, reranker, llm_chain)
300
  response = router.route("test", top_k=3)
301
 
302
  assert response.confidence == pytest.approx(0.9, abs=1e-6)