"""
tests/test_phase4.py
====================
Phase 4 — LLM Generation Chain Tests

Tests:
  - CitationInjector: marker parsing, resolution strategies, deduplication
  - FaithfulnessGuard: refusal detection, confidence scoring, system prompt
  - AnswerChain: message building, fallback logic (mocked LLMs), streaming,
    token extraction, max_tokens per query_type

All LLM calls are mocked — no real API keys required. Every AnswerChain test
that exercises generation patches BOTH ``_build_groq`` and ``_build_gemini``,
so an API key present in the developer's environment can never turn a
fallback path into a real network call.

Run with:
    pytest tests/test_phase4.py -v
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from voicevault.generation.citation_injector import CitationInjector
from voicevault.generation.faithfulness_guard import (
    REFUSAL_PHRASE,
    FaithfulnessGuard,
)
from voicevault.models import Citation, RetrievalResult


# ------------------------------------------------------------------ #
#  Helpers                                                             #
# ------------------------------------------------------------------ #

def _make_citation(
    source_file: str = "report.pdf",
    page_number: int = 1,
    section: str = "Introduction",
    excerpt: str = "Some relevant excerpt.",
    relevance_score: float = 0.8,
) -> Citation:
    """Build a ``Citation`` with sensible defaults for any field not overridden."""
    return Citation(
        source_file=source_file,
        page_number=page_number,
        section=section,
        excerpt=excerpt,
        relevance_score=relevance_score,
    )


def _make_retrieval_result(rerank_score: float = 0.0, rrf_score: float = 0.0) -> RetrievalResult:
    """Build a minimal ``RetrievalResult`` whose only interesting fields are the scores."""
    return RetrievalResult(
        chunk_id="test-chunk",
        text="test text",
        source_file="test.pdf",
        page_number=1,
        rrf_score=rrf_score,
        rerank_score=rerank_score,
    )


def _make_llm_response(content: str, total_tokens: int = 150) -> MagicMock:
    """Build a fake chat-model response exposing ``content`` and ``usage_metadata``."""
    response = MagicMock()
    response.content = content
    response.usage_metadata = {"total_tokens": total_tokens}
    return response


def _make_llm(response: Any) -> MagicMock:
    """Build a fake chat model whose ``invoke`` returns *response* for ANY arguments.

    Using ``invoke.return_value`` (rather than a one-argument lambda) means a
    change in how the chain calls ``invoke`` cannot surface as a ``TypeError``
    that the chain would silently treat as a provider failure.
    """
    llm = MagicMock()
    llm.invoke.return_value = response
    return llm


@contextmanager
def _patched_llms(chain: Any, groq: Any = None, gemini: Any = None) -> Iterator[None]:
    """Patch both LLM builders on *chain* for the duration of the ``with`` block.

    A builder patched to return ``None`` mirrors the "no key configured" case
    exercised by ``test_returns_refusal_when_no_keys_configured``. Patching both
    builders keeps every test hermetic even when real keys exist in the env.
    """
    with (
        patch.object(chain, "_build_groq", return_value=groq),
        patch.object(chain, "_build_gemini", return_value=gemini),
    ):
        yield


# ------------------------------------------------------------------ #
#  CitationInjector Tests                                              #
# ------------------------------------------------------------------ #

class TestCitationInjectorBasic:
    """Core parsing and injection behavior."""

    def setup_method(self) -> None:
        self.injector = CitationInjector()
        self.citation_map = [
            _make_citation("report.pdf", 3),
            _make_citation("paper.pdf", 7),
        ]

    def test_empty_answer_returns_empty(self) -> None:
        answer, citations = self.injector.inject("", self.citation_map)
        assert answer == ""
        assert citations == []

    def test_answer_without_markers_returned_unchanged(self) -> None:
        text = "Machine learning is a field of AI."
        answer, citations = self.injector.inject(text, self.citation_map)
        assert answer == text
        assert citations == []

    def test_exact_filename_and_page_resolved(self) -> None:
        text = "The accuracy was 94% [Source: report.pdf, p.3]."
        _, citations = self.injector.inject(text, self.citation_map)
        assert len(citations) == 1
        assert citations[0].source_file == "report.pdf"
        assert citations[0].page_number == 3

    def test_multiple_markers_resolved(self) -> None:
        text = (
            "First fact [Source: report.pdf, p.3]. "
            "Second fact [Source: paper.pdf, p.7]."
        )
        _, citations = self.injector.inject(text, self.citation_map)
        assert len(citations) == 2

    def test_duplicate_markers_deduplicated(self) -> None:
        text = (
            "Claim one [Source: report.pdf, p.3]. "
            "Same source again [Source: report.pdf, p.3]."
        )
        _, citations = self.injector.inject(text, self.citation_map)
        assert len(citations) == 1

    def test_answer_text_preserved_with_markers(self) -> None:
        """Markers are preserved in the answer text (not stripped)."""
        text = "The result was 94% [Source: report.pdf, p.3]."
        answer, _ = self.injector.inject(text, self.citation_map)
        assert "[Source: report.pdf, p.3]" in answer

    def test_citation_order_matches_first_appearance(self) -> None:
        text = (
            "Paper result [Source: paper.pdf, p.7]. "
            "Report result [Source: report.pdf, p.3]."
        )
        _, citations = self.injector.inject(text, self.citation_map)
        assert citations[0].source_file == "paper.pdf"
        assert citations[1].source_file == "report.pdf"

    def test_empty_citation_map_returns_no_citations(self) -> None:
        text = "Result [Source: anything.pdf, p.1]."
        _, citations = self.injector.inject(text, [])
        assert citations == []


class TestCitationInjectorMatchingStrategies:
    """Test the four resolution strategies."""

    def setup_method(self) -> None:
        self.injector = CitationInjector()

    def test_strategy1_exact_match(self) -> None:
        """Strategy 1: exact filename + exact page."""
        cmap = [_make_citation("report.pdf", 5), _make_citation("other.pdf", 5)]
        _, citations = self.injector.inject("[Source: report.pdf, p.5]", cmap)
        assert citations[0].source_file == "report.pdf"

    def test_strategy2_substring_match(self) -> None:
        """Strategy 2: filename substring + page."""
        cmap = [_make_citation("annual_report_2024.pdf", 3)]
        _, citations = self.injector.inject("[Source: report, p.3]", cmap)
        assert len(citations) == 1
        assert citations[0].source_file == "annual_report_2024.pdf"

    def test_strategy3_page_only_match(self) -> None:
        """Strategy 3: page number match as fallback."""
        cmap = [_make_citation("unique_name.pdf", 9)]
        _, citations = self.injector.inject("[Source: unknownfile, p.9]", cmap)
        assert len(citations) == 1
        assert citations[0].page_number == 9

    def test_strategy4_filename_no_page(self) -> None:
        """Strategy 4: filename substring with no page number."""
        cmap = [_make_citation("research.pdf", 1)]
        _, citations = self.injector.inject("[Source: research]", cmap)
        assert len(citations) == 1

    def test_last_resort_first_citation(self) -> None:
        """Last resort: return first citation when nothing else matches."""
        cmap = [
            _make_citation("alpha.pdf", 1),
            _make_citation("beta.pdf", 2),
        ]
        _, citations = self.injector.inject("[Source: zzz_no_match.pdf, p.99]", cmap)
        assert len(citations) == 1
        assert citations[0].source_file == "alpha.pdf"


# ------------------------------------------------------------------ #
#  FaithfulnessGuard Tests                                             #
# ------------------------------------------------------------------ #

class TestFaithfulnessGuardRefusal:
    """Refusal detection edge cases."""

    def setup_method(self) -> None:
        self.guard = FaithfulnessGuard()

    def test_exact_refusal_phrase_detected(self) -> None:
        assert self.guard.is_refusal(REFUSAL_PHRASE) is True

    def test_refusal_case_insensitive(self) -> None:
        assert self.guard.is_refusal(REFUSAL_PHRASE.upper()) is True

    def test_refusal_embedded_in_text(self) -> None:
        text = f"Sorry, {REFUSAL_PHRASE} Please try another query."
        assert self.guard.is_refusal(text) is True

    def test_normal_answer_not_refusal(self) -> None:
        assert self.guard.is_refusal("Machine learning is a subset of AI.") is False

    def test_empty_string_is_refusal(self) -> None:
        assert self.guard.is_refusal("") is True

    def test_partial_phrase_not_refusal(self) -> None:
        assert self.guard.is_refusal("I could not find this") is False

    def test_refusal_without_trailing_period(self) -> None:
        phrase_no_period = REFUSAL_PHRASE.rstrip(".")
        assert self.guard.is_refusal(phrase_no_period) is True


class TestFaithfulnessGuardConfidence:
    """Confidence level scoring."""

    def setup_method(self) -> None:
        self.guard = FaithfulnessGuard()

    def test_empty_results_returns_low(self) -> None:
        assert self.guard.confidence_level([]) == "low"

    def test_high_rerank_score_returns_high(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.9)]
        assert self.guard.confidence_level(results) == "high"

    def test_medium_rerank_score_returns_medium(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.35)]
        assert self.guard.confidence_level(results) == "medium"

    def test_low_rerank_score_returns_low(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.1)]
        assert self.guard.confidence_level(results) == "low"

    def test_uses_max_score_across_results(self) -> None:
        results = [
            _make_retrieval_result(rerank_score=0.1),
            _make_retrieval_result(rerank_score=0.8),
            _make_retrieval_result(rerank_score=0.3),
        ]
        assert self.guard.confidence_level(results) == "high"

    def test_zero_rerank_falls_back_to_rrf_score(self) -> None:
        """When rerank_score is 0, rrf_score should be used."""
        results = [_make_retrieval_result(rerank_score=0.0, rrf_score=0.6)]
        assert self.guard.confidence_level(results) == "high"

    def test_boundary_above_0_5_is_high(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.51)]
        assert self.guard.confidence_level(results) == "high"

    def test_boundary_exactly_0_5_is_medium(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.5)]
        assert self.guard.confidence_level(results) == "medium"

    def test_boundary_exactly_0_2_is_low(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.2)]
        assert self.guard.confidence_level(results) == "low"

    def test_boundary_above_0_2_is_medium(self) -> None:
        results = [_make_retrieval_result(rerank_score=0.21)]
        assert self.guard.confidence_level(results) == "medium"


class TestFaithfulnessGuardSystemPrompt:
    """System prompt construction."""

    def test_system_prompt_instruction_contains_refusal_phrase(self) -> None:
        instruction = FaithfulnessGuard.system_prompt_instruction()
        assert REFUSAL_PHRASE in instruction

    def test_system_prompt_instruction_non_empty(self) -> None:
        assert len(FaithfulnessGuard.system_prompt_instruction()) > 50

    def test_build_system_prompt_contains_citation_rules(self) -> None:
        prompt = FaithfulnessGuard.build_system_prompt()
        assert "CITATION RULES" in prompt

    def test_build_system_prompt_contains_faithfulness_rules(self) -> None:
        prompt = FaithfulnessGuard.build_system_prompt()
        assert "FAITHFULNESS RULES" in prompt

    def test_build_system_prompt_contains_refusal_phrase(self) -> None:
        prompt = FaithfulnessGuard.build_system_prompt()
        assert REFUSAL_PHRASE in prompt

    def test_build_system_prompt_non_empty(self) -> None:
        assert len(FaithfulnessGuard.build_system_prompt()) > 200


# ------------------------------------------------------------------ #
#  AnswerChain Tests                                                   #
# ------------------------------------------------------------------ #

class TestAnswerChainMessageBuilding:
    """Verify the LangChain message list is constructed correctly."""

    def setup_method(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_messages_start_with_system(self) -> None:
        from langchain_core.messages import SystemMessage

        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert isinstance(messages[0], SystemMessage)

    def test_messages_end_with_human(self) -> None:
        from langchain_core.messages import HumanMessage

        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert isinstance(messages[-1], HumanMessage)

    def test_context_in_last_human_message(self) -> None:
        messages = self.chain._build_messages("what is AI?", "CONTEXT_TEXT", [])
        assert "CONTEXT_TEXT" in messages[-1].content

    def test_query_in_last_human_message(self) -> None:
        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert "what is AI?" in messages[-1].content

    def test_history_injected_as_human_ai_pairs(self) -> None:
        from langchain_core.messages import AIMessage, HumanMessage

        history = [("q1", "a1"), ("q2", "a2")]
        messages = self.chain._build_messages("q3", "ctx", history)
        # system + (human + AI) × 2 + human = 6
        assert len(messages) == 6
        assert isinstance(messages[1], HumanMessage)
        assert isinstance(messages[2], AIMessage)
        assert messages[1].content == "q1"
        assert messages[2].content == "a1"

    def test_history_capped_at_conversation_window(self) -> None:
        from config import cfg

        # Derive the history length from the window so the cap is always
        # exercised, whatever value the config holds (a fixed 20 would not be
        # capped at all once conversation_window >= 20).
        long_history = [
            (f"q{i}", f"a{i}") for i in range(cfg.conversation_window + 5)
        ]
        messages = self.chain._build_messages("current", "ctx", long_history)
        # system + (human + AI) × window + human
        expected_len = 1 + cfg.conversation_window * 2 + 1
        assert len(messages) == expected_len

    def test_no_history_two_messages_only(self) -> None:
        messages = self.chain._build_messages("q", "ctx", [])
        assert len(messages) == 2  # system + human


class TestAnswerChainMaxTokens:
    """Max tokens budget per query type."""

    def setup_method(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_factual_uses_base_max_tokens(self) -> None:
        from config import cfg

        assert self.chain._max_tokens_for("factual") == cfg.max_answer_tokens

    def test_summary_uses_double_max_tokens(self) -> None:
        from config import cfg

        assert self.chain._max_tokens_for("summary") == cfg.max_answer_tokens * 2

    def test_compare_uses_base_max_tokens(self) -> None:
        from config import cfg

        assert self.chain._max_tokens_for("compare") == cfg.max_answer_tokens


class TestAnswerChainTokenExtraction:
    """Token extraction from AIMessage responses."""

    def test_extracts_tokens_from_usage_metadata(self) -> None:
        from voicevault.generation.answer_chain import _extract_tokens

        response = MagicMock()
        response.usage_metadata = {"total_tokens": 123}
        assert _extract_tokens(response) == 123

    def test_returns_zero_when_no_metadata(self) -> None:
        from voicevault.generation.answer_chain import _extract_tokens

        response = MagicMock()
        response.usage_metadata = None
        assert _extract_tokens(response) == 0

    def test_returns_zero_when_attribute_missing(self) -> None:
        from voicevault.generation.answer_chain import _extract_tokens

        response = MagicMock(spec=[])  # No attributes
        assert _extract_tokens(response) == 0

    def test_returns_zero_on_malformed_metadata(self) -> None:
        from voicevault.generation.answer_chain import _extract_tokens

        response = MagicMock()
        # A str has no .get(), so a dict-style lookup raises AttributeError;
        # the extractor must tolerate non-dict metadata and report 0.
        response.usage_metadata = "not_a_dict"
        assert _extract_tokens(response) == 0


class TestAnswerChainConfidenceFromCitations:
    """Citation-based confidence scoring."""

    def test_empty_citation_map_returns_low(self) -> None:
        from voicevault.generation.answer_chain import _confidence_from_citations

        assert _confidence_from_citations([]) == "low"

    def test_high_relevance_returns_high(self) -> None:
        from voicevault.generation.answer_chain import _confidence_from_citations

        cmap = [_make_citation(relevance_score=0.9)]
        assert _confidence_from_citations(cmap) == "high"

    def test_medium_relevance_returns_medium(self) -> None:
        from voicevault.generation.answer_chain import _confidence_from_citations

        cmap = [_make_citation(relevance_score=0.35)]
        assert _confidence_from_citations(cmap) == "medium"

    def test_low_relevance_returns_low(self) -> None:
        from voicevault.generation.answer_chain import _confidence_from_citations

        cmap = [_make_citation(relevance_score=0.05)]
        assert _confidence_from_citations(cmap) == "low"

    def test_uses_max_across_multiple_citations(self) -> None:
        from voicevault.generation.answer_chain import _confidence_from_citations

        cmap = [
            _make_citation(relevance_score=0.1),
            _make_citation(relevance_score=0.9),
        ]
        assert _confidence_from_citations(cmap) == "high"


class TestAnswerChainGenerateMocked:
    """Test generate() with mocked LLM responses (Groq succeeds, Gemini absent)."""

    def test_generate_returns_generation_result(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain, GenerationResult

        chain = AnswerChain()
        citation = _make_citation("doc.pdf", 1, relevance_score=0.8)
        groq = _make_llm(_make_llm_response("ML is a subset of AI [Source: doc.pdf, p.1]."))
        with _patched_llms(chain, groq=groq):
            result = chain.generate(
                query="what is ML",
                context="[Source: doc.pdf, p.1]\nML is...",
                citation_map=[citation],
                query_type="factual",
            )
        assert isinstance(result, GenerationResult)

    def test_generate_extracts_answer(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq = _make_llm(_make_llm_response("ML stands for Machine Learning."))
        with _patched_llms(chain, groq=groq):
            result = chain.generate(
                query="what is ML",
                context="context text",
                citation_map=[],
            )
        assert "Machine Learning" in result.answer

    def test_generate_records_latency(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq = _make_llm(_make_llm_response("Some answer."))
        with _patched_llms(chain, groq=groq):
            result = chain.generate("q", "ctx", [])
        assert result.latency_ms >= 0

    def test_generate_records_tokens(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq = _make_llm(_make_llm_response("Answer.", total_tokens=200))
        with _patched_llms(chain, groq=groq):
            result = chain.generate("q", "ctx", [])
        assert result.tokens_used == 200

    def test_generate_detects_refusal(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq = _make_llm(_make_llm_response(REFUSAL_PHRASE))
        with _patched_llms(chain, groq=groq):
            result = chain.generate("q", "ctx", [])
        assert result.is_refusal is True

    def test_generate_non_refusal_answer_not_flagged(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq = _make_llm(_make_llm_response("This is a real answer."))
        with _patched_llms(chain, groq=groq):
            result = chain.generate("q", "ctx", [])
        assert result.is_refusal is False

    def test_generate_resolves_citations(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        citation = _make_citation("paper.pdf", 4, relevance_score=0.8)
        groq = _make_llm(
            _make_llm_response("The accuracy was 94% [Source: paper.pdf, p.4].")
        )
        with _patched_llms(chain, groq=groq):
            result = chain.generate("q", "ctx", citation_map=[citation])
        assert len(result.citations) == 1
        assert result.citations[0].source_file == "paper.pdf"


class TestAnswerChainFallback:
    """Test Groq → Gemini fallback behavior."""

    def test_falls_back_to_gemini_when_groq_raises(self) -> None:
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq_llm = MagicMock()
        groq_llm.invoke.side_effect = RuntimeError("Groq API error")
        gemini_llm = _make_llm(_make_llm_response("Gemini answered."))

        with _patched_llms(chain, groq=groq_llm, gemini=gemini_llm):
            result = chain.generate("q", "ctx", [])

        groq_llm.invoke.assert_called()  # Groq must be attempted before falling back
        assert result.model_used == cfg.gemini_llm_model
        assert "Gemini answered" in result.answer

    def test_returns_refusal_when_both_fail(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        failing_llm = MagicMock()
        failing_llm.invoke.side_effect = RuntimeError("API error")

        with _patched_llms(chain, groq=failing_llm, gemini=failing_llm):
            result = chain.generate("q", "ctx", [])

        assert result.model_used == "none"
        assert result.is_refusal is True
        assert REFUSAL_PHRASE in result.answer

    def test_returns_refusal_when_no_keys_configured(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        with _patched_llms(chain, groq=None, gemini=None):
            result = chain.generate("q", "ctx", [])

        assert result.model_used == "none"
        assert result.is_refusal is True
        assert REFUSAL_PHRASE in result.answer

    def test_groq_used_when_available(self) -> None:
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        groq_llm = _make_llm(_make_llm_response("Groq answered."))
        gemini_llm = _make_llm(_make_llm_response("Gemini answered."))

        with _patched_llms(chain, groq=groq_llm, gemini=gemini_llm):
            result = chain.generate("q", "ctx", [])

        assert result.model_used == cfg.groq_llm_model
        assert "Groq answered" in result.answer
        gemini_llm.invoke.assert_not_called()  # no fallback when Groq succeeds


class TestAnswerChainStreaming:
    """Test stream_generate() token streaming."""

    def test_streaming_yields_chunks(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_chunks = [
            MagicMock(content="Hello "),
            MagicMock(content="world"),
            MagicMock(content="!"),
        ]
        mock_llm = MagicMock()
        mock_llm.stream.return_value = iter(mock_chunks)

        with _patched_llms(chain, groq=mock_llm):
            chunks = list(chain.stream_generate("q", "ctx", []))

        assert chunks == ["Hello ", "world", "!"]

    def test_streaming_skips_empty_chunks(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_chunks = [
            MagicMock(content="real"),
            MagicMock(content=""),  # empty — should be skipped
            MagicMock(content=" content"),
        ]
        mock_llm = MagicMock()
        mock_llm.stream.return_value = iter(mock_chunks)

        with _patched_llms(chain, groq=mock_llm):
            chunks = list(chain.stream_generate("q", "ctx", []))

        assert "" not in chunks
        assert chunks == ["real", " content"]

    def test_streaming_returns_refusal_when_no_llm(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        with _patched_llms(chain, groq=None, gemini=None):
            chunks = list(chain.stream_generate("q", "ctx", []))

        assert REFUSAL_PHRASE in "".join(chunks)

    def test_streaming_yields_error_on_exception(self) -> None:
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_llm = MagicMock()
        mock_llm.stream.side_effect = RuntimeError("connection refused")

        # Gemini is patched to "not configured" so any fallback cannot reach a
        # real provider even when a key is present in the environment.
        with _patched_llms(chain, groq=mock_llm, gemini=None):
            chunks = list(chain.stream_generate("q", "ctx", []))

        assert "error" in "".join(chunks).lower()


# ------------------------------------------------------------------ #
#  GenerationResult Model Tests                                        #
# ------------------------------------------------------------------ #

class TestGenerationResult:
    """Verify GenerationResult dataclass."""

    def test_can_instantiate(self) -> None:
        from voicevault.generation.answer_chain import GenerationResult

        result = GenerationResult(
            answer="test answer",
            citations=[],
            confidence_level="high",
            is_refusal=False,
            model_used="llama-3.1-70b-versatile",
            tokens_used=100,
            latency_ms=250,
        )
        assert result.answer == "test answer"
        assert result.citations == []
        assert result.confidence_level == "high"
        assert result.is_refusal is False
        assert result.model_used == "llama-3.1-70b-versatile"
        assert result.tokens_used == 100
        assert result.latency_ms == 250

    def test_citations_list_is_mutable(self) -> None:
        from voicevault.generation.answer_chain import GenerationResult

        result = GenerationResult(
            answer="",
            citations=[],
            confidence_level="low",
            is_refusal=True,
            model_used="none",
            tokens_used=0,
            latency_ms=0,
        )
        result.citations.append(_make_citation())
        assert len(result.citations) == 1