Spaces:
Running
Running
| """Tests for query router with mock LLM and components.""" | |
| from unittest.mock import MagicMock | |
| import pytest | |
| from src.agent.router import QueryRouter | |
| from src.models import ( | |
| DocumentChunk, | |
| GenerationResponse, | |
| IntentType, | |
| QueryResult, | |
| ) | |
def _make_query_result(text: str, score: float) -> QueryResult:
    """Build a one-chunk ``QueryResult`` for use in router tests.

    The chunk always uses the fixed ids ``c1`` / ``d1`` and ``{"page": 1}``
    metadata; only ``text`` and ``score`` vary between callers. The result is
    tagged with the ``hybrid`` source.
    """
    return QueryResult(
        chunk=DocumentChunk(
            chunk_id="c1",
            document_id="d1",
            text=text,
            metadata={"page": 1},
        ),
        score=score,
        source="hybrid",
    )
def _make_hybrid_result(results: list[QueryResult]) -> MagicMock:
    """Return a ``MagicMock`` standing in for a hybrid search result.

    The same ``results`` list object is exposed as the dense, sparse and
    fused result sets.
    """
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(
        dense_results=results,
        sparse_results=results,
        fused_results=results,
    )
@pytest.fixture
def mock_components() -> tuple[MagicMock, MagicMock, MagicMock, MagicMock]:
    """Provide fresh mocks for the four ``QueryRouter`` collaborators.

    Registered as a pytest fixture because the tests in this module request
    ``mock_components`` as a test-function argument; without the decorator
    pytest cannot resolve that argument.

    Returns:
        A ``(classifier, retriever, reranker, llm_chain)`` tuple of four
        independent ``MagicMock`` instances.
    """
    classifier = MagicMock()
    retriever = MagicMock()
    reranker = MagicMock()
    llm_chain = MagicMock()
    return classifier, retriever, reranker, llm_chain
def _setup_llm_chain_danish(
    llm_chain: MagicMock, final_answer: str, intent: str = "factual"
) -> None:
    """Script ``llm_chain.invoke`` for a Danish query (no translation step).

    Two replies are queued, in order: the combined language+intent
    detection response, then ``final_answer``.
    """
    scripted_replies = [
        f"language: Danish\nintent: {intent}",
        final_answer,
    ]
    llm_chain.invoke.side_effect = scripted_replies
def _setup_llm_chain_english(
    llm_chain: MagicMock, translated_query: str, final_answer: str, intent: str = "rag"
) -> None:
    """Script ``llm_chain.invoke`` for an English query.

    Three replies are queued, in order: the combined language+intent
    detection response, the translated query, then ``final_answer``.
    """
    scripted_replies = [
        f"language: English\nintent: {intent}",
        translated_query,
        final_answer,
    ]
    llm_chain.invoke.side_effect = scripted_replies
class TestQueryRouterRAG:
    """Routing tests for RAG intents (factual/summary/comparison/procedural)."""

    # NOTE(review): ``intent_str`` and ``expected_intent`` are not supplied by
    # anything visible in this file -- presumably a
    # ``@pytest.mark.parametrize`` decorator belongs on this test; confirm.
    def test_rag_intent_returns_answer_with_sources(
        self, mock_components, intent_str: str, expected_intent: IntentType
    ) -> None:
        """A RAG intent yields an answer, a confidence score and one source."""
        query = "Hvad er KU's feriepolitik?"
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("policy text", 0.85)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "Generated answer", intent=intent_str)

        router = QueryRouter(classifier, retriever, reranker, llm_chain)
        response = router.route(query, top_k=3)

        assert isinstance(response, GenerationResponse)
        assert response.answer == "Generated answer"
        assert response.intent == expected_intent
        assert response.confidence == pytest.approx(0.85, abs=1e-6)
        assert len(response.sources) == 1
        # Retrieval and reranking must both have seen the untouched query.
        retriever.search_detailed.assert_called_once_with(query, top_k=3)
        reranker.rerank.assert_called_once_with(query, hits, top_k=3)

    def test_prompt_contains_context_and_query(self, mock_components) -> None:
        """The generation prompt embeds both the chunk text and the query."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("Relevant context text", 0.9)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "test query", top_k=3
        )

        # The last invoke call is the generation call; its first positional
        # argument is the prompt.
        generation_call = llm_chain.invoke.call_args_list[-1]
        prompt = generation_call.args[0]
        assert "Relevant context text" in prompt
        assert "test query" in prompt

    def test_prompt_contains_language_rule(self, mock_components) -> None:
        """The generation prompt instructs the LLM to answer in English."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_english(
            llm_chain, "oversæt forespørgsel", "answer", intent="rag"
        )

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "What is KU's vacation policy?", top_k=3
        )

        generation_call = llm_chain.invoke.call_args_list[-1]
        prompt = generation_call.args[0]
        assert "MUST answer in English" in prompt
class TestQueryRouterDirect:
    """Direct-answer routing tests (UNKNOWN intent, no retrieval hits)."""

    def test_unknown_intent_still_generates_answer(self, mock_components) -> None:
        """UNKNOWN intent answers with zero confidence and no retrieval."""
        classifier, retriever, reranker, llm_chain = mock_components
        _setup_llm_chain_danish(llm_chain, "Fallback answer", intent="unknown")

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "Hej, hvad kan du hjælpe med?", top_k=3
        )

        assert response.answer == "Fallback answer"
        assert response.intent == IntentType.UNKNOWN
        assert response.confidence == 0.0
        # Neither retrieval stage may have been touched.
        retriever.search_detailed.assert_not_called()
        reranker.rerank.assert_not_called()

    def test_unknown_intent_prompt_uses_generic_instruction(
        self, mock_components
    ) -> None:
        """UNKNOWN intent produces a prompt with the generic instruction."""
        classifier, retriever, reranker, llm_chain = mock_components
        _setup_llm_chain_danish(llm_chain, "answer", intent="unknown")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "random input", top_k=3
        )

        # First positional argument of the last invoke call is the prompt.
        generation_call = llm_chain.invoke.call_args_list[-1]
        prompt = generation_call.args[0]
        assert "as helpfully as possible" in prompt
class TestQueryRouterFallback:
    """Fallback and degradation tests for ambiguous or unmatched input."""

    def test_empty_reranked_results_gives_zero_confidence(
        self, mock_components
    ) -> None:
        """An empty rerank result gives confidence 0.0 and no sources."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "asdfghjkl", top_k=3
        )

        assert response.confidence == 0.0
        assert response.sources == []
        assert response.answer == "No information found"

    def test_empty_context_passed_to_llm_chain(self, mock_components) -> None:
        """With no retrieved chunks the prompt's context section is empty."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "gibberish", top_k=3
        )

        # First positional argument of the last invoke call is the prompt.
        generation_call = llm_chain.invoke.call_args_list[-1]
        prompt = generation_call.args[0]
        assert "Context:\n\n" in prompt

    def test_multiple_results_confidence_uses_max_score(
        self, mock_components
    ) -> None:
        """Confidence equals the highest score among the reranked results."""
        classifier, retriever, reranker, llm_chain = mock_components
        # The highest score is deliberately not first in the list.
        scored_texts = (("low", 0.3), ("high", 0.95), ("mid", 0.6))
        hits = [_make_query_result(text, score) for text, score in scored_texts]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "summary", intent="summary")

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "opsummer politikken", top_k=5
        )

        assert response.confidence == pytest.approx(0.95, abs=1e-6)
class TestQueryTranslation:
    """Tests covering language detection and query translation."""

    def test_danish_query_not_translated(self, mock_components) -> None:
        """A Danish query reaches retrieval unchanged, with no translation call."""
        query = "Hvad er reglerne?"
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "svar", intent="rag")

        QueryRouter(classifier, retriever, reranker, llm_chain).route(query, top_k=3)

        # Exactly two LLM calls: combined detection, then generation.
        assert llm_chain.invoke.call_count == 2
        retriever.search_detailed.assert_called_once_with(query, top_k=3)

    def test_english_query_translated_for_retrieval(self, mock_components) -> None:
        """An English query is translated before retrieval and reranking."""
        translated = "Hvad er reglerne?"
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_english(
            llm_chain, translated, "The rules are...", intent="rag"
        )

        router_options = {"translate_query": True, "document_languages": ["Danish"]}
        router = QueryRouter(
            classifier, retriever, reranker, llm_chain, **router_options
        )
        response = router.route("What are the rules?", top_k=3)

        # Exactly three LLM calls: combined detection, translation, generation.
        assert llm_chain.invoke.call_count == 3
        retriever.search_detailed.assert_called_once_with(translated, top_k=3)
        reranker.rerank.assert_called_once_with(translated, hits, top_k=3)
        assert response.answer == "The rules are..."

    def test_translation_disabled_skips_translate(self, mock_components) -> None:
        """With ``translate_query=False`` an English query is retrieved as-is."""
        query = "What are the rules?"
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("ctx", 0.5)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        # Only two replies are scripted: combined detection and generation.
        llm_chain.invoke.side_effect = [
            "language: English\nintent: rag",
            "The answer",
        ]

        router = QueryRouter(
            classifier, retriever, reranker, llm_chain, translate_query=False
        )
        response = router.route(query, top_k=3)

        assert llm_chain.invoke.call_count == 2
        retriever.search_detailed.assert_called_once_with(query, top_k=3)
        assert response.answer == "The answer"
class TestSigmoidInReranker:
    """Checks that the router reports reranker scores without re-normalizing."""

    def test_confidence_equals_max_reranked_score(self, mock_components) -> None:
        """Confidence equals the highest reranked score, passed through as-is."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [
            _make_query_result(text, score) for text, score in (("a", 0.7), ("b", 0.9))
        ]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="rag")

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "test", top_k=3
        )

        assert response.confidence == pytest.approx(0.9, abs=1e-6)
class TestLowConfidenceRetry:
    """Tests for the retry loop that broadens the query on low confidence."""

    def test_low_confidence_triggers_retry(self, mock_components) -> None:
        """A low-confidence first pass leads to one broadened retrieval retry."""
        classifier, retriever, reranker, llm_chain = mock_components
        weak_hits = [_make_query_result("weak match", 0.15)]
        strong_hits = [_make_query_result("strong match", 0.85)]
        retriever.search_detailed.return_value = _make_hybrid_result(weak_hits)
        # First rerank is low confidence (triggers the retry); the second is
        # high confidence (proceeds to generation).
        reranker.rerank.side_effect = [weak_hits, strong_hits]
        # Scripted LLM replies, in order: detect, broaden query, generate.
        llm_chain.invoke.side_effect = [
            "language: Danish\nintent: factual",
            "bredere søgning",
            "Final answer",
        ]

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "snævert spørgsmål", top_k=3
        )

        assert response.answer == "Final answer"
        assert response.confidence == pytest.approx(0.85, abs=1e-6)
        assert retriever.search_detailed.call_count == 2
        assert reranker.rerank.call_count == 2

    def test_empty_results_do_not_trigger_retry(self, mock_components) -> None:
        """An entirely empty result set does not cause a second retrieval."""
        classifier, retriever, reranker, llm_chain = mock_components
        retriever.search_detailed.return_value = _make_hybrid_result([])
        reranker.rerank.return_value = []
        _setup_llm_chain_danish(llm_chain, "No information found", intent="factual")

        response = QueryRouter(classifier, retriever, reranker, llm_chain).route(
            "asdfghjkl", top_k=3
        )

        assert response.confidence == 0.0
        assert retriever.search_detailed.call_count == 1
        # At most one rerank call is tolerated (zero or one both pass).
        assert reranker.rerank.call_count <= 1

    def test_high_confidence_skips_retry(self, mock_components) -> None:
        """A high-confidence first pass means exactly one retrieval round."""
        classifier, retriever, reranker, llm_chain = mock_components
        hits = [_make_query_result("good match", 0.9)]
        retriever.search_detailed.return_value = _make_hybrid_result(hits)
        reranker.rerank.return_value = hits
        _setup_llm_chain_danish(llm_chain, "answer", intent="factual")

        QueryRouter(classifier, retriever, reranker, llm_chain).route("test", top_k=3)

        assert retriever.search_detailed.call_count == 1
        assert reranker.rerank.call_count == 1