"""Tests for query router with mock LLM and components.""" from unittest.mock import MagicMock import pytest from src.agent.router import QueryRouter from src.models import ( DocumentChunk, GenerationResponse, IntentType, QueryResult, ) def _make_query_result(text: str, score: float) -> QueryResult: """Create a QueryResult for testing.""" chunk = DocumentChunk( chunk_id="c1", document_id="d1", text=text, metadata={"page": 1}, ) return QueryResult(chunk=chunk, score=score, source="hybrid") def _make_hybrid_result(results: list[QueryResult]) -> MagicMock: """Create a mock HybridSearchResult.""" hybrid = MagicMock() hybrid.dense_results = results hybrid.sparse_results = results hybrid.fused_results = results return hybrid @pytest.fixture def mock_components(): """Create mock intent classifier, retriever, reranker, and llm_chain.""" classifier = MagicMock() retriever = MagicMock() reranker = MagicMock() llm_chain = MagicMock() return classifier, retriever, reranker, llm_chain def _setup_llm_chain_danish( llm_chain: MagicMock, final_answer: str, intent: str = "factual" ) -> None: """Configure llm_chain mock for Danish queries (no translation needed). The first invoke returns the combined language+intent response, the second invoke returns the final answer. """ combined = f"language: Danish\nintent: {intent}" llm_chain.invoke.side_effect = [combined, final_answer] def _setup_llm_chain_english( llm_chain: MagicMock, translated_query: str, final_answer: str, intent: str = "rag" ) -> None: """Configure llm_chain mock for English queries (combined detection + translation + answer). The first invoke returns combined language+intent, the second returns the translated query, and the third returns the final answer. """ combined = f"language: English\nintent: {intent}" llm_chain.invoke.side_effect = [combined, translated_query, final_answer] class TestQueryRouterRAG: """Tests for queries routed as RAG (factual/summary/comparison/procedural).""" @pytest.mark.parametrize("intent_str,expected_intent", [ ("factual", IntentType.RAG), # FACTUAL overridden to RAG when sources exist ("summary", IntentType.SUMMARY), ("comparison", IntentType.COMPARISON), ("procedural", IntentType.PROCEDURAL), ]) def test_rag_intent_returns_answer_with_sources( self, mock_components, intent_str: str, expected_intent: IntentType ) -> None: """RAG intents should retrieve, rerank, and generate an answer.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("policy text", 0.85)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "Generated answer", intent=intent_str) router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("Hvad er KU's feriepolitik?", top_k=3) assert isinstance(response, GenerationResponse) assert response.answer == "Generated answer" assert response.intent == expected_intent assert response.confidence == pytest.approx(0.85, abs=1e-6) assert len(response.sources) == 1 retriever.search_detailed.assert_called_once_with( "Hvad er KU's feriepolitik?", top_k=3 ) reranker.rerank.assert_called_once_with( "Hvad er KU's feriepolitik?", results, top_k=3 ) def test_prompt_contains_context_and_query(self, mock_components) -> None: """The prompt sent to the LLM chain should include context and query.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("Relevant context text", 0.9)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "answer", intent="factual") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("test query", top_k=3) # The final invoke call is the generation call prompt = llm_chain.invoke.call_args_list[-1][0][0] assert "Relevant context text" in prompt assert "test query" in prompt def test_prompt_contains_language_rule(self, mock_components) -> None: """The prompt should contain a language instruction matching user language.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("ctx", 0.5)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_english(llm_chain, "oversæt forespørgsel", "answer", intent="rag") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("What is KU's vacation policy?", top_k=3) prompt = llm_chain.invoke.call_args_list[-1][0][0] assert "MUST answer in English" in prompt class TestQueryRouterDirect: """Tests for queries that get a direct answer (UNKNOWN intent, no retrieval hits).""" def test_unknown_intent_still_generates_answer(self, mock_components) -> None: """UNKNOWN intent skips retrieval and returns zero confidence.""" classifier, retriever, reranker, llm_chain = mock_components _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown") router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("Hej, hvad kan du hjælpe med?", top_k=3) assert response.answer == "Fallback answer" assert response.intent == IntentType.UNKNOWN assert response.confidence == 0.0 retriever.search_detailed.assert_not_called() reranker.rerank.assert_not_called() def test_unknown_intent_prompt_uses_generic_instruction( self, mock_components ) -> None: """UNKNOWN intent should use the generic helpful instruction.""" classifier, retriever, reranker, llm_chain = mock_components _setup_llm_chain_danish(llm_chain, "answer", intent="unknown") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("random input", top_k=3) prompt = llm_chain.invoke.call_args_list[-1][0][0] assert "as helpfully as possible" in prompt class TestQueryRouterFallback: """Tests for ambiguous input and fallback/degradation behaviour.""" def test_empty_reranked_results_gives_zero_confidence( self, mock_components ) -> None: """When reranker returns no results, confidence should be 0.0.""" classifier, retriever, reranker, llm_chain = mock_components retriever.search_detailed.return_value = _make_hybrid_result([]) reranker.rerank.return_value = [] _setup_llm_chain_danish(llm_chain, "No information found", intent="factual") router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("asdfghjkl", top_k=3) assert response.confidence == 0.0 assert response.sources == [] assert response.answer == "No information found" def test_empty_context_passed_to_llm_chain(self, mock_components) -> None: """When no chunks are retrieved, the prompt context should be empty.""" classifier, retriever, reranker, llm_chain = mock_components retriever.search_detailed.return_value = _make_hybrid_result([]) reranker.rerank.return_value = [] _setup_llm_chain_danish(llm_chain, "answer", intent="factual") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("gibberish", top_k=3) prompt = llm_chain.invoke.call_args_list[-1][0][0] assert "Context:\n\n" in prompt def test_multiple_results_confidence_uses_max_score( self, mock_components ) -> None: """Confidence should be the maximum score among reranked results.""" classifier, retriever, reranker, llm_chain = mock_components results = [ _make_query_result("low", 0.3), _make_query_result("high", 0.95), _make_query_result("mid", 0.6), ] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "summary", intent="summary") router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("opsummer politikken", top_k=5) assert response.confidence == pytest.approx(0.95, abs=1e-6) class TestQueryTranslation: """Tests for query language detection and translation.""" def test_danish_query_not_translated(self, mock_components) -> None: """Danish queries should be passed directly to retrieval without translation.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("ctx", 0.5)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "svar", intent="rag") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("Hvad er reglerne?", top_k=3) # Only 2 invoke calls: combined detection + generation (no translation) assert llm_chain.invoke.call_count == 2 retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3) def test_english_query_translated_for_retrieval(self, mock_components) -> None: """English queries should be translated into the corpus language for retrieval.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("ctx", 0.5)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_english(llm_chain, "Hvad er reglerne?", "The rules are...", intent="rag") router = QueryRouter( classifier, retriever, reranker, llm_chain, translate_query=True, document_languages=["Danish"], ) response = router.route("What are the rules?", top_k=3) # 3 invoke calls: combined detection + translation + generation assert llm_chain.invoke.call_count == 3 retriever.search_detailed.assert_called_once_with("Hvad er reglerne?", top_k=3) reranker.rerank.assert_called_once_with("Hvad er reglerne?", results, top_k=3) assert response.answer == "The rules are..." def test_translation_disabled_skips_translate(self, mock_components) -> None: """When translate_query=False, English queries go straight to retrieval untranslated.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("ctx", 0.5)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results # Only 2 calls: combined detection + generation (no translation) combined = "language: English\nintent: rag" llm_chain.invoke.side_effect = [combined, "The answer"] router = QueryRouter(classifier, retriever, reranker, llm_chain, translate_query=False) response = router.route("What are the rules?", top_k=3) assert llm_chain.invoke.call_count == 2 retriever.search_detailed.assert_called_once_with("What are the rules?", top_k=3) assert response.answer == "The answer" class TestSigmoidInReranker: """Tests that sigmoid normalization is in the reranker, not the router.""" def test_confidence_equals_max_reranked_score(self, mock_components) -> None: """Confidence should equal the max reranked score (already sigmoid-normalized).""" classifier, retriever, reranker, llm_chain = mock_components results = [ _make_query_result("a", 0.7), _make_query_result("b", 0.9), ] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "answer", intent="rag") router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("test", top_k=3) assert response.confidence == pytest.approx(0.9, abs=1e-6) class TestLowConfidenceRetry: """Tests for the query-broadening retry loop on low confidence.""" def test_low_confidence_triggers_retry(self, mock_components) -> None: """When reranker returns low-confidence results, the query should be broadened and retrieval retried once.""" classifier, retriever, reranker, llm_chain = mock_components low_results = [_make_query_result("weak match", 0.15)] good_results = [_make_query_result("strong match", 0.85)] retriever.search_detailed.return_value = _make_hybrid_result(low_results) # First rerank: low confidence → triggers retry # Second rerank: high confidence → proceeds to generate reranker.rerank.side_effect = [low_results, good_results] # LLM calls: detect, broaden_query, generate combined = "language: Danish\nintent: factual" llm_chain.invoke.side_effect = [combined, "bredere søgning", "Final answer"] router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("snævert spørgsmål", top_k=3) assert response.answer == "Final answer" assert response.confidence == pytest.approx(0.85, abs=1e-6) assert retriever.search_detailed.call_count == 2 assert reranker.rerank.call_count == 2 def test_empty_results_do_not_trigger_retry(self, mock_components) -> None: """When reranker returns no results at all, retrying is skipped.""" classifier, retriever, reranker, llm_chain = mock_components retriever.search_detailed.return_value = _make_hybrid_result([]) reranker.rerank.return_value = [] _setup_llm_chain_danish(llm_chain, "No information found", intent="factual") router = QueryRouter(classifier, retriever, reranker, llm_chain) response = router.route("asdfghjkl", top_k=3) assert response.confidence == 0.0 assert retriever.search_detailed.call_count == 1 # Reranker still called once (with empty input, returns []) assert reranker.rerank.call_count <= 1 def test_high_confidence_skips_retry(self, mock_components) -> None: """When confidence is above threshold, no retry is attempted.""" classifier, retriever, reranker, llm_chain = mock_components results = [_make_query_result("good match", 0.9)] retriever.search_detailed.return_value = _make_hybrid_result(results) reranker.rerank.return_value = results _setup_llm_chain_danish(llm_chain, "answer", intent="factual") router = QueryRouter(classifier, retriever, reranker, llm_chain) router.route("test", top_k=3) assert retriever.search_detailed.call_count == 1 assert reranker.rerank.call_count == 1