# Dokumentassistent / tests/test_router.py
# Commit 05c89bc: "Update language and prompt" (XQ)
# Captured from the repository file view: raw | history | blame | 15.5 kB
"""Tests for query router with mock LLM and components.

The intent classifier, retriever, reranker and LLM chain are all
``MagicMock`` doubles. The LLM chain is scripted with an ordered list of
``invoke`` replies: a combined language/intent detection reply, then an
optional translation or query-broadening reply, then the final generated
answer. The tests therefore pin both the router's routing decisions and the
number and order of LLM calls it makes.
"""
from unittest.mock import MagicMock
import pytest
from src.agent.router import QueryRouter
from src.models import (
DocumentChunk,
GenerationResponse,
IntentType,
QueryResult,
)
def _make_query_result(text: str, score: float) -> QueryResult:
    """Build a single-chunk ``QueryResult`` fixture.

    The wrapped chunk always uses the fixed ids ``"c1"`` / ``"d1"`` and
    ``{"page": 1}`` metadata. The result is tagged with the ``"hybrid"``
    source.

    Args:
        text: Chunk text the router will see as retrieved context.
        score: Relevance score attached to the result.

    Returns:
        A ``QueryResult`` wrapping a freshly built ``DocumentChunk``.
    """
    return QueryResult(
        chunk=DocumentChunk(
            chunk_id="c1",
            document_id="d1",
            text=text,
            metadata={"page": 1},
        ),
        score=score,
        source="hybrid",
    )
def _make_hybrid_result(results: list[QueryResult]) -> MagicMock:
    """Return a ``MagicMock`` standing in for a ``HybridSearchResult``.

    The same ``results`` list object is exposed as the dense, sparse and
    fused result sets, so every retrieval stage appears to return the same
    hits.
    """
    stub = MagicMock()
    for attr_name in ("dense_results", "sparse_results", "fused_results"):
        setattr(stub, attr_name, results)
    return stub
@pytest.fixture
def mock_components():
    """Provide fresh ``MagicMock`` doubles for the router's collaborators.

    Returns:
        Tuple ``(classifier, retriever, reranker, llm_chain)`` with one
        independent mock per slot, in the positional order the tests pass
        them to ``QueryRouter``.
    """
    return tuple(MagicMock() for _ in range(4))
def _setup_llm_chain_danish(
llm_chain: MagicMock, final_answer: str, intent: str = "factual"
) -> None:
"""Configure llm_chain mock for Danish queries (no translation needed).
The first invoke returns the combined language+intent response,
the second invoke returns the final answer.
"""
combined = f"language: Danish\nintent: {intent}"
llm_chain.invoke.side_effect = [combined, final_answer]
def _setup_llm_chain_english(
llm_chain: MagicMock, translated_query: str, final_answer: str, intent: str = "rag"
) -> None:
"""Configure llm_chain mock for English queries (combined detection + translation + answer).
The first invoke returns combined language+intent, the second returns the
translated query, and the third returns the final answer.
"""
combined = f"language: English\nintent: {intent}"
llm_chain.invoke.side_effect = [combined, translated_query, final_answer]
class TestQueryRouterRAG:
    """Queries whose intent routes through retrieval (factual/summary/comparison/procedural)."""

    @pytest.mark.parametrize("intent_str,expected_intent", [
        ("factual", IntentType.RAG),  # FACTUAL overridden to RAG when sources exist
        ("summary", IntentType.SUMMARY),
        ("comparison", IntentType.COMPARISON),
        ("procedural", IntentType.PROCEDURAL),
    ])
    def test_rag_intent_returns_answer_with_sources(
        self, mock_components, intent_str: str, expected_intent: IntentType
    ) -> None:
        """Retrieval intents go through search, rerank and generation."""
        classifier, retriever, reranker, llm_chain = mock_components
        question = "Hvad er KU's feriepolitik?"
        hits = [_make_query_result("policy text", 0.85)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "Generated answer", intent=intent_str)

        reply = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            question, top_k=3
        )

        assert isinstance(reply, GenerationResponse)
        assert reply.answer == "Generated answer"
        assert reply.intent == expected_intent
        assert reply.confidence == pytest.approx(0.85, abs=1e-6)
        assert len(reply.sources) == 1
        retriever.search_detailed.assert_called_once_with(question, top_k=3)
        reranker.rerank.assert_called_once_with(question, hits, top_k=3)

    def test_prompt_contains_context_and_query(self, mock_components) -> None:
        """The generation prompt embeds both the retrieved context and the query."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("Relevant context text", 0.9)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "test query", top_k=3
        )

        # Generation is the last LLM call the router makes.
        generation_prompt = llm_chain.invoke.call_args_list[-1].args[0]
        assert "Relevant context text" in generation_prompt
        assert "test query" in generation_prompt

    def test_prompt_contains_language_rule(self, mock_components) -> None:
        """An English question produces an 'answer in English' instruction."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_english(
            llm_chain, "oversæt forespørgsel", "answer", intent="rag"
        )

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "What is KU's vacation policy?", top_k=3
        )

        generation_prompt = llm_chain.invoke.call_args_list[-1].args[0]
        assert "MUST answer in English" in generation_prompt
class TestQueryRouterDirect:
    """Queries answered directly, without retrieval (UNKNOWN intent)."""

    def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
        """UNKNOWN intent bypasses retrieval and reranking and reports zero confidence."""
        classifier, retriever, reranker, llm_chain = mock_components
        _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        reply = router.route("Hej, hvad kan du hjælpe med?", top_k=3)

        assert reply.answer == "Fallback answer"
        assert reply.intent == IntentType.UNKNOWN
        assert reply.confidence == 0.0
        retriever.search_detailed.assert_not_called()
        reranker.rerank.assert_not_called()

    def test_unknown_intent_prompt_uses_generic_instruction(
        self, mock_components
    ) -> None:
        """UNKNOWN intent falls back to the generic 'be helpful' instruction."""
        classifier, retriever, reranker, llm_chain = mock_components
        _setup_llm_chain_danish(llm_chain, "answer", intent="unknown")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "random input", top_k=3
        )

        last_prompt = llm_chain.invoke.call_args_list[-1].args[0]
        assert "as helpfully as possible" in last_prompt
class TestQueryRouterFallback:
    """Ambiguous input and graceful degradation when retrieval finds little or nothing."""

    def test_empty_reranked_results_gives_zero_confidence(
        self, mock_components
    ) -> None:
        """No reranked hits means zero confidence and no sources."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        reply = router.route("asdfghjkl", top_k=3)

        assert reply.confidence == 0.0
        assert reply.sources == []
        assert reply.answer == "No information found"

    def test_empty_context_passed_to_llm_chain(self, mock_components) -> None:
        """With nothing retrieved, the generation prompt carries an empty context section."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "gibberish", top_k=3
        )

        generation_prompt = llm_chain.invoke.call_args_list[-1].args[0]
        assert "Context:\n\n" in generation_prompt

    def test_multiple_results_confidence_uses_max_score(
        self, mock_components
    ) -> None:
        """Confidence is the highest score among the reranked hits."""
        classifier, retriever, reranker, llm_chain = mock_components
        scored = [("low", 0.3), ("high", 0.95), ("mid", 0.6)]
        hits = [_make_query_result(text, score) for text, score in scored]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "summary", intent="summary")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        reply = router.route("opsummer politikken", top_k=5)

        assert reply.confidence == pytest.approx(0.95, abs=1e-6)
class TestQueryTranslation:
    """Language detection and translation of the retrieval query."""

    def test_danish_query_not_translated(self, mock_components) -> None:
        """A Danish question reaches retrieval verbatim, with no translation call."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "svar", intent="rag")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        router.route("Hvad er reglerne?", top_k=3)

        # Two LLM calls only: combined detection, then generation.
        assert llm_chain.invoke.call_count == 2
        retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3)

    def test_english_query_translated_for_retrieval(self, mock_components) -> None:
        """An English question is translated into the corpus language before retrieval."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_english(
            llm_chain, "Hvad er reglerne?", "The rules are...", intent="rag"
        )
        router = QueryRouter(
            classifier,
            retriever,
            reranker,
            llm_chain,
            translate_query=True,
            document_languages=["Danish"],
        )

        reply = router.route("What are the rules?", top_k=3)

        # Three LLM calls: combined detection, translation, generation.
        assert llm_chain.invoke.call_count == 3
        translated = "Hvad er reglerne?"
        retriever.search_detailed.assert_called_once_with(translated, top_k=3)
        reranker.rerank.assert_called_once_with(translated, hits, top_k=3)
        assert reply.answer == "The rules are..."

    def test_translation_disabled_skips_translate(self, mock_components) -> None:
        """With translate_query=False an English question is retrieved untranslated."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        # Script only detection + generation; any extra call would exhaust it.
        llm_chain.invoke.side_effect = ["language: English\nintent: rag", "The answer"]
        router = QueryRouter(
            classifier, retriever, reranker, llm_chain, translate_query=False
        )

        reply = router.route("What are the rules?", top_k=3)

        assert llm_chain.invoke.call_count == 2
        retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3)
        assert reply.answer == "The answer"
class TestSigmoidInReranker:
    """The router uses reranker scores as-is; sigmoid normalization belongs to the reranker."""

    def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
        """Confidence is the largest reranked score, with no re-normalization by the router."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [
            _make_query_result(text, score)
            for text, score in (("a", 0.7), ("b", 0.9))
        ]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="rag")

        reply = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "test", top_k=3
        )

        assert reply.confidence == pytest.approx(0.9, abs=1e-6)
class TestLowConfidenceRetry:
    """Tests for the query-broadening retry loop on low confidence."""

    def test_low_confidence_triggers_retry(self, mock_components) -> None:
        """When reranker returns low-confidence results, the query should be
        broadened and retrieval retried once."""
        classifier, retriever, reranker, llm_chain = mock_components
        low_results = [_make_query_result("weak match", 0.15)]
        good_results = [_make_query_result("strong match", 0.85)]
        retriever.search_detailed.return_value = _make_hybrid_result(low_results)
        # First rerank: low confidence → triggers retry.
        # Second rerank: high confidence → proceeds to generate.
        reranker.rerank.side_effect = [low_results, good_results]
        # LLM calls, in order: combined detection, broaden_query, generate.
        combined = "language: Danish\nintent: factual"
        llm_chain.invoke.side_effect = [combined, "bredere søgning", "Final answer"]
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        response = router.route("snævert spørgsmål", top_k=3)

        assert response.answer == "Final answer"
        assert response.confidence == pytest.approx(0.85, abs=1e-6)
        assert retriever.search_detailed.call_count == 2
        assert reranker.rerank.call_count == 2
        # Exactly one broadening call sits between detection and generation.
        assert llm_chain.invoke.call_count == 3

    def test_empty_results_do_not_trigger_retry(self, mock_components) -> None:
        """When reranker returns no results at all, retrying is skipped."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        response = router.route("asdfghjkl", top_k=3)

        assert response.confidence == 0.0
        assert retriever.search_detailed.call_count == 1
        # Tolerate zero or one rerank call: whether the router reranks an
        # empty candidate list is an implementation detail. A second call
        # would mean a retry pass happened.
        assert reranker.rerank.call_count <= 1
        # Only detection + generation: no broaden_query call was made.
        assert llm_chain.invoke.call_count == 2

    def test_high_confidence_skips_retry(self, mock_components) -> None:
        """When confidence is above threshold, no retry is attempted."""
        classifier, retriever, reranker, llm_chain = mock_components
        results = [_make_query_result("good match", 0.9)]
        retriever.search_detailed.return_value = _make_hybrid_result(results)
        reranker.rerank.return_value = results
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")
        router = QueryRouter(classifier, retriever, reranker, llm_chain)

        router.route("test", top_k=3)

        assert retriever.search_detailed.call_count == 1
        assert reranker.rerank.call_count == 1
        # Only detection + generation: no broaden_query call was made.
        assert llm_chain.invoke.call_count == 2