Spaces:
Running
Running
"""
tests/test_phase4.py
====================
Phase 4 — LLM Generation Chain Tests

Tests:
  - CitationInjector: marker parsing, resolution strategies, deduplication
  - FaithfulnessGuard: refusal detection, confidence scoring, system prompt
  - AnswerChain: message building, fallback logic (mocked LLMs), streaming,
    token extraction, max_tokens per query_type

All LLM calls are mocked — no real API keys required.
Run with: pytest tests/test_phase4.py -v
"""
| from __future__ import annotations | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from voicevault.generation.citation_injector import CitationInjector | |
| from voicevault.generation.faithfulness_guard import ( | |
| REFUSAL_PHRASE, | |
| FaithfulnessGuard, | |
| ) | |
| from voicevault.models import Citation, RetrievalResult | |
# ------------------------------------------------------------------ #
# Helpers                                                            #
# ------------------------------------------------------------------ #
def _make_citation(
    source_file: str = "report.pdf",
    page_number: int = 1,
    section: str = "Introduction",
    excerpt: str = "Some relevant excerpt.",
    relevance_score: float = 0.8,
) -> Citation:
    """Build a ``Citation`` fixture.

    Every field has a default so a test only needs to spell out the
    values it actually asserts on.
    """
    fields = {
        "source_file": source_file,
        "page_number": page_number,
        "section": section,
        "excerpt": excerpt,
        "relevance_score": relevance_score,
    }
    return Citation(**fields)
def _make_retrieval_result(rerank_score: float = 0.0, rrf_score: float = 0.0) -> RetrievalResult:
    """Build a ``RetrievalResult`` fixture whose only variable parts are its scores.

    The chunk id, text, source file and page number are fixed placeholder
    values; callers choose ``rerank_score`` and ``rrf_score``.
    """
    placeholder_fields = {
        "chunk_id": "test-chunk",
        "text": "test text",
        "source_file": "test.pdf",
        "page_number": 1,
    }
    return RetrievalResult(
        rrf_score=rrf_score,
        rerank_score=rerank_score,
        **placeholder_fields,
    )
# ------------------------------------------------------------------ #
# CitationInjector Tests                                             #
# ------------------------------------------------------------------ #
class TestCitationInjectorBasic:
    """Basic parse-and-resolve contract of ``CitationInjector.inject``."""

    def setup_method(self) -> None:
        """Give each test a fresh injector and a two-entry citation map."""
        self.citation_map = [
            _make_citation("report.pdf", 3),
            _make_citation("paper.pdf", 7),
        ]
        self.injector = CitationInjector()

    def test_empty_answer_returns_empty(self) -> None:
        """An empty answer yields an empty answer and no citations."""
        returned_text, resolved = self.injector.inject("", self.citation_map)

        assert returned_text == ""
        assert resolved == []

    def test_answer_without_markers_returned_unchanged(self) -> None:
        """Text with no ``[Source: ...]`` marker passes through untouched."""
        plain = "Machine learning is a field of AI."

        returned_text, resolved = self.injector.inject(plain, self.citation_map)

        assert returned_text == plain
        assert resolved == []

    def test_exact_filename_and_page_resolved(self) -> None:
        """A marker naming a known file and page resolves to that citation."""
        marked = "The accuracy was 94% [Source: report.pdf, p.3]."

        _, resolved = self.injector.inject(marked, self.citation_map)

        assert len(resolved) == 1
        only = resolved[0]
        assert only.source_file == "report.pdf"
        assert only.page_number == 3

    def test_multiple_markers_resolved(self) -> None:
        """Two distinct markers produce two citations."""
        marked = (
            "First fact [Source: report.pdf, p.3]. "
            "Second fact [Source: paper.pdf, p.7]."
        )

        _, resolved = self.injector.inject(marked, self.citation_map)

        assert len(resolved) == 2

    def test_duplicate_markers_deduplicated(self) -> None:
        """The same marker appearing twice yields a single citation."""
        marked = (
            "Claim one [Source: report.pdf, p.3]. "
            "Same source again [Source: report.pdf, p.3]."
        )

        _, resolved = self.injector.inject(marked, self.citation_map)

        assert len(resolved) == 1

    def test_answer_text_preserved_with_markers(self) -> None:
        """Markers are preserved in the answer text (not stripped)."""
        marked = "The result was 94% [Source: report.pdf, p.3]."

        returned_text, _ = self.injector.inject(marked, self.citation_map)

        assert "[Source: report.pdf, p.3]" in returned_text

    def test_citation_order_matches_first_appearance(self) -> None:
        """Citations come back in the order their markers appear in the text."""
        marked = (
            "Paper result [Source: paper.pdf, p.7]. "
            "Report result [Source: report.pdf, p.3]."
        )

        _, resolved = self.injector.inject(marked, self.citation_map)

        first, second = resolved[0], resolved[1]
        assert first.source_file == "paper.pdf"
        assert second.source_file == "report.pdf"

    def test_empty_citation_map_returns_no_citations(self) -> None:
        """With nothing to resolve against, no citations are returned."""
        marked = "Result [Source: anything.pdf, p.1]."

        _, resolved = self.injector.inject(marked, [])

        assert resolved == []
class TestCitationInjectorMatchingStrategies:
    """One test per marker-resolution strategy, plus the last-resort fallback."""

    def setup_method(self) -> None:
        """Give each test its own injector instance."""
        self.injector = CitationInjector()

    def test_strategy1_exact_match(self) -> None:
        """Strategy 1: exact filename + exact page."""
        candidates = [
            _make_citation("report.pdf", 5),
            _make_citation("other.pdf", 5),
        ]

        _, resolved = self.injector.inject("[Source: report.pdf, p.5]", candidates)

        assert resolved[0].source_file == "report.pdf"

    def test_strategy2_substring_match(self) -> None:
        """Strategy 2: filename substring + page."""
        candidates = [_make_citation("annual_report_2024.pdf", 3)]

        _, resolved = self.injector.inject("[Source: report, p.3]", candidates)

        assert len(resolved) == 1
        assert resolved[0].source_file == "annual_report_2024.pdf"

    def test_strategy3_page_only_match(self) -> None:
        """Strategy 3: page number match as fallback."""
        candidates = [_make_citation("unique_name.pdf", 9)]

        _, resolved = self.injector.inject("[Source: unknownfile, p.9]", candidates)

        assert len(resolved) == 1
        assert resolved[0].page_number == 9

    def test_strategy4_filename_no_page(self) -> None:
        """Strategy 4: filename substring with no page number."""
        candidates = [_make_citation("research.pdf", 1)]

        _, resolved = self.injector.inject("[Source: research]", candidates)

        assert len(resolved) == 1

    def test_last_resort_first_citation(self) -> None:
        """Last resort: return first citation when nothing else matches."""
        candidates = [
            _make_citation("alpha.pdf", 1),
            _make_citation("beta.pdf", 2),
        ]

        _, resolved = self.injector.inject(
            "[Source: zzz_no_match.pdf, p.99]", candidates
        )

        assert len(resolved) == 1
        assert resolved[0].source_file == "alpha.pdf"
# ------------------------------------------------------------------ #
# FaithfulnessGuard Tests                                            #
# ------------------------------------------------------------------ #
class TestFaithfulnessGuardRefusal:
    """Edge cases for ``FaithfulnessGuard.is_refusal``."""

    def setup_method(self) -> None:
        """Give each test its own guard instance."""
        self.guard = FaithfulnessGuard()

    def test_exact_refusal_phrase_detected(self) -> None:
        """The canonical refusal phrase on its own is a refusal."""
        verdict = self.guard.is_refusal(REFUSAL_PHRASE)
        assert verdict is True

    def test_refusal_case_insensitive(self) -> None:
        """Upper-casing the phrase does not hide the refusal."""
        shouted = REFUSAL_PHRASE.upper()
        assert self.guard.is_refusal(shouted) is True

    def test_refusal_embedded_in_text(self) -> None:
        """The phrase is still detected when surrounded by other text."""
        wrapped = f"Sorry, {REFUSAL_PHRASE} Please try another query."
        assert self.guard.is_refusal(wrapped) is True

    def test_normal_answer_not_refusal(self) -> None:
        """An ordinary answer is not flagged."""
        verdict = self.guard.is_refusal("Machine learning is a subset of AI.")
        assert verdict is False

    def test_empty_string_is_refusal(self) -> None:
        """An empty answer counts as a refusal."""
        verdict = self.guard.is_refusal("")
        assert verdict is True

    def test_partial_phrase_not_refusal(self) -> None:
        """A fragment resembling the phrase is not enough to flag a refusal."""
        verdict = self.guard.is_refusal("I could not find this")
        assert verdict is False

    def test_refusal_without_trailing_period(self) -> None:
        """Dropping trailing periods from the phrase still reads as a refusal."""
        trimmed = REFUSAL_PHRASE.rstrip(".")
        assert self.guard.is_refusal(trimmed) is True
class TestFaithfulnessGuardConfidence:
    """Score-to-label mapping of ``FaithfulnessGuard.confidence_level``."""

    def setup_method(self) -> None:
        """Give each test its own guard instance."""
        self.guard = FaithfulnessGuard()

    def _level_for(self, *rerank_scores: float) -> str:
        """Return the confidence label for results carrying the given rerank scores."""
        batch = [_make_retrieval_result(rerank_score=score) for score in rerank_scores]
        return self.guard.confidence_level(batch)

    def test_empty_results_returns_low(self) -> None:
        """No retrieval results means low confidence."""
        assert self.guard.confidence_level([]) == "low"

    def test_high_rerank_score_returns_high(self) -> None:
        """A rerank score of 0.9 is labelled high."""
        assert self._level_for(0.9) == "high"

    def test_medium_rerank_score_returns_medium(self) -> None:
        """A rerank score of 0.35 is labelled medium."""
        assert self._level_for(0.35) == "medium"

    def test_low_rerank_score_returns_low(self) -> None:
        """A rerank score of 0.1 is labelled low."""
        assert self._level_for(0.1) == "low"

    def test_uses_max_score_across_results(self) -> None:
        """The best score in the batch decides the label."""
        assert self._level_for(0.1, 0.8, 0.3) == "high"

    def test_zero_rerank_falls_back_to_rrf_score(self) -> None:
        """When rerank_score is 0, rrf_score should be used."""
        batch = [_make_retrieval_result(rerank_score=0.0, rrf_score=0.6)]
        assert self.guard.confidence_level(batch) == "high"

    def test_boundary_above_0_5_is_high(self) -> None:
        """Just above 0.5 is high."""
        assert self._level_for(0.51) == "high"

    def test_boundary_exactly_0_5_is_medium(self) -> None:
        """Exactly 0.5 is still medium."""
        assert self._level_for(0.5) == "medium"

    def test_boundary_exactly_0_2_is_low(self) -> None:
        """Exactly 0.2 is still low."""
        assert self._level_for(0.2) == "low"

    def test_boundary_above_0_2_is_medium(self) -> None:
        """Just above 0.2 is medium."""
        assert self._level_for(0.21) == "medium"
class TestFaithfulnessGuardSystemPrompt:
    """Content checks on the prompt text produced by ``FaithfulnessGuard``."""

    def test_system_prompt_instruction_contains_refusal_phrase(self) -> None:
        """The instruction text embeds the refusal phrase."""
        text = FaithfulnessGuard.system_prompt_instruction()
        assert REFUSAL_PHRASE in text

    def test_system_prompt_instruction_non_empty(self) -> None:
        """The instruction text is longer than 50 characters."""
        text = FaithfulnessGuard.system_prompt_instruction()
        assert len(text) > 50

    def test_build_system_prompt_contains_citation_rules(self) -> None:
        """The full prompt has a CITATION RULES section."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert "CITATION RULES" in full_prompt

    def test_build_system_prompt_contains_faithfulness_rules(self) -> None:
        """The full prompt has a FAITHFULNESS RULES section."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert "FAITHFULNESS RULES" in full_prompt

    def test_build_system_prompt_contains_refusal_phrase(self) -> None:
        """The full prompt embeds the refusal phrase."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert REFUSAL_PHRASE in full_prompt

    def test_build_system_prompt_non_empty(self) -> None:
        """The full prompt is longer than 200 characters."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert len(full_prompt) > 200
# ------------------------------------------------------------------ #
# AnswerChain Tests                                                  #
# ------------------------------------------------------------------ #
class TestAnswerChainMessageBuilding:
    """Verify the LangChain message list is constructed correctly."""

    def setup_method(self) -> None:
        """Create a fresh ``AnswerChain`` for each test.

        The import is done here rather than at module level so that the
        test classes that do not touch ``AnswerChain`` never import it.
        """
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_messages_start_with_system(self) -> None:
        """The first message is always the system prompt."""
        from langchain_core.messages import SystemMessage

        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert isinstance(messages[0], SystemMessage)

    def test_messages_end_with_human(self) -> None:
        """The final message is the current human turn."""
        from langchain_core.messages import HumanMessage

        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert isinstance(messages[-1], HumanMessage)

    def test_context_in_last_human_message(self) -> None:
        """The retrieved context is embedded in the final human message."""
        messages = self.chain._build_messages("what is AI?", "CONTEXT_TEXT", [])
        assert "CONTEXT_TEXT" in messages[-1].content

    def test_query_in_last_human_message(self) -> None:
        """The user's query is embedded in the final human message."""
        messages = self.chain._build_messages("what is AI?", "ctx", [])
        assert "what is AI?" in messages[-1].content

    def test_history_injected_as_human_ai_pairs(self) -> None:
        """Each history turn becomes a HumanMessage followed by an AIMessage."""
        from langchain_core.messages import AIMessage, HumanMessage

        history = [("q1", "a1"), ("q2", "a2")]
        messages = self.chain._build_messages("q3", "ctx", history)

        # system + (human + AI) × 2 + human = 6
        assert len(messages) == 6
        assert isinstance(messages[1], HumanMessage)
        assert isinstance(messages[2], AIMessage)
        assert messages[1].content == "q1"
        assert messages[2].content == "a1"

    def test_history_capped_at_conversation_window(self) -> None:
        """History longer than ``cfg.conversation_window`` turns is truncated."""
        from config import cfg

        long_history = [(f"q{i}", f"a{i}") for i in range(20)]
        messages = self.chain._build_messages("current", "ctx", long_history)

        # system + (human + AI) × window + human
        expected_len = 1 + cfg.conversation_window * 2 + 1
        assert len(messages) == expected_len

    def test_no_history_two_messages_only(self) -> None:
        """With no history the list is exactly system + human.

        Renamed from ``..._three_messages_only``: the assertion has always
        expected two messages, so the old name contradicted the check.
        """
        messages = self.chain._build_messages("q", "ctx", [])
        assert len(messages) == 2  # system + human
class TestAnswerChainMaxTokens:
    """Token budget returned by ``AnswerChain._max_tokens_for`` per query type."""

    def setup_method(self) -> None:
        """Create a fresh ``AnswerChain`` for each test."""
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_factual_uses_base_max_tokens(self) -> None:
        """Factual queries get the configured base budget."""
        from config import cfg

        budget = self.chain._max_tokens_for("factual")
        assert budget == cfg.max_answer_tokens

    def test_summary_uses_double_max_tokens(self) -> None:
        """Summary queries get twice the configured base budget."""
        from config import cfg

        budget = self.chain._max_tokens_for("summary")
        assert budget == cfg.max_answer_tokens * 2

    def test_compare_uses_base_max_tokens(self) -> None:
        """Compare queries get the configured base budget."""
        from config import cfg

        budget = self.chain._max_tokens_for("compare")
        assert budget == cfg.max_answer_tokens
class TestAnswerChainTokenExtraction:
    """Behaviour of ``_extract_tokens`` for well-formed and malformed responses."""

    def test_extracts_tokens_from_usage_metadata(self) -> None:
        """``total_tokens`` is read out of the ``usage_metadata`` mapping."""
        from voicevault.generation.answer_chain import _extract_tokens

        fake = MagicMock()
        fake.usage_metadata = {"total_tokens": 123}

        assert _extract_tokens(fake) == 123

    def test_returns_zero_when_no_metadata(self) -> None:
        """A ``None`` ``usage_metadata`` yields zero."""
        from voicevault.generation.answer_chain import _extract_tokens

        fake = MagicMock()
        fake.usage_metadata = None

        assert _extract_tokens(fake) == 0

    def test_returns_zero_when_attribute_missing(self) -> None:
        """A response without a ``usage_metadata`` attribute yields zero."""
        from voicevault.generation.answer_chain import _extract_tokens

        # spec=[] gives a mock that exposes no attributes at all.
        fake = MagicMock(spec=[])

        assert _extract_tokens(fake) == 0

    def test_returns_zero_on_type_error(self) -> None:
        """A ``usage_metadata`` of the wrong type (a str) yields zero."""
        from voicevault.generation.answer_chain import _extract_tokens

        fake = MagicMock()
        fake.usage_metadata = "not_a_dict"

        # NOTE(review): a str has no .get(), so a dict-style lookup would
        # raise AttributeError; presumably the helper guards against that —
        # confirm against _extract_tokens.
        assert _extract_tokens(fake) == 0
class TestAnswerChainConfidenceFromCitations:
    """Label returned by ``_confidence_from_citations`` for various relevance scores."""

    def test_empty_citation_map_returns_low(self) -> None:
        """No citations means low confidence."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        label = _confidence_from_citations([])
        assert label == "low"

    def test_high_relevance_returns_high(self) -> None:
        """A relevance score of 0.9 is labelled high."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        label = _confidence_from_citations([_make_citation(relevance_score=0.9)])
        assert label == "high"

    def test_medium_relevance_returns_medium(self) -> None:
        """A relevance score of 0.35 is labelled medium."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        label = _confidence_from_citations([_make_citation(relevance_score=0.35)])
        assert label == "medium"

    def test_low_relevance_returns_low(self) -> None:
        """A relevance score of 0.05 is labelled low."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        label = _confidence_from_citations([_make_citation(relevance_score=0.05)])
        assert label == "low"

    def test_uses_max_across_multiple_citations(self) -> None:
        """The highest relevance score among the citations decides the label."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        mixed = [
            _make_citation(relevance_score=0.1),
            _make_citation(relevance_score=0.9),
        ]
        assert _confidence_from_citations(mixed) == "high"
class TestAnswerChainGenerateMocked:
    """Drive ``AnswerChain.generate`` with a stubbed Groq client."""

    def _make_mock_response(self, content: str, total_tokens: int = 150) -> MagicMock:
        """Return a mock LLM response carrying ``content`` and a token count."""
        reply = MagicMock()
        reply.usage_metadata = {"total_tokens": total_tokens}
        reply.content = content
        return reply

    @staticmethod
    def _stub_llm(reply: MagicMock) -> MagicMock:
        """Return a fake LLM whose ``invoke`` takes one argument and returns ``reply``."""
        return MagicMock(invoke=lambda m: reply)

    def test_generate_returns_generation_result(self) -> None:
        """``generate`` returns a ``GenerationResult`` instance."""
        from voicevault.generation.answer_chain import AnswerChain, GenerationResult

        chain = AnswerChain()
        known_source = _make_citation("doc.pdf", 1, relevance_score=0.8)
        reply = self._make_mock_response("ML is a subset of AI [Source: doc.pdf, p.1].")

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate(
                query="what is ML",
                context="[Source: doc.pdf, p.1]\nML is...",
                citation_map=[known_source],
                query_type="factual",
            )

        assert isinstance(outcome, GenerationResult)

    def test_generate_extracts_answer(self) -> None:
        """The LLM's content ends up in ``result.answer``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        reply = self._make_mock_response("ML stands for Machine Learning.")

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate(
                query="what is ML",
                context="context text",
                citation_map=[],
            )

        assert "Machine Learning" in outcome.answer

    def test_generate_records_latency(self) -> None:
        """``latency_ms`` is populated with a non-negative value."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        reply = self._make_mock_response("Some answer.")

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.latency_ms >= 0

    def test_generate_records_tokens(self) -> None:
        """``tokens_used`` mirrors the response's ``total_tokens``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        reply = self._make_mock_response("Answer.", total_tokens=200)

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.tokens_used == 200

    def test_generate_detects_refusal(self) -> None:
        """A response consisting of the refusal phrase is flagged as a refusal."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        reply = self._make_mock_response(REFUSAL_PHRASE)

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.is_refusal is True

    def test_generate_non_refusal_answer_not_flagged(self) -> None:
        """An ordinary answer is not flagged as a refusal."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        reply = self._make_mock_response("This is a real answer.")

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.is_refusal is False

    def test_generate_resolves_citations(self) -> None:
        """A marker in the answer is resolved against the supplied citation map."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        known_source = _make_citation("paper.pdf", 4, relevance_score=0.8)
        reply = self._make_mock_response(
            "The accuracy was 94% [Source: paper.pdf, p.4]."
        )

        with patch.object(chain, "_build_groq", return_value=self._stub_llm(reply)):
            outcome = chain.generate("q", "ctx", citation_map=[known_source])

        assert len(outcome.citations) == 1
        assert outcome.citations[0].source_file == "paper.pdf"
class TestAnswerChainFallback:
    """Provider selection in ``AnswerChain.generate``: Groq first, then Gemini."""

    def _make_mock_response(self, content: str) -> MagicMock:
        """Return a mock LLM response carrying ``content`` and 50 total tokens."""
        reply = MagicMock()
        reply.usage_metadata = {"total_tokens": 50}
        reply.content = content
        return reply

    def test_falls_back_to_gemini_when_groq_raises(self) -> None:
        """If Groq's ``invoke`` raises, the Gemini answer is returned instead."""
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        broken_groq = MagicMock()
        broken_groq.invoke.side_effect = RuntimeError("Groq API error")

        working_gemini = MagicMock()
        working_gemini.invoke.return_value = self._make_mock_response("Gemini answered.")

        with (
            patch.object(chain, "_build_groq", return_value=broken_groq),
            patch.object(chain, "_build_gemini", return_value=working_gemini),
        ):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == cfg.gemini_llm_model
        assert "Gemini answered" in outcome.answer

    def test_returns_refusal_when_both_fail(self) -> None:
        """If both providers raise, the result is a refusal attributed to "none"."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        # The same failing mock stands in for both providers.
        always_fails = MagicMock()
        always_fails.invoke.side_effect = RuntimeError("API error")

        with (
            patch.object(chain, "_build_groq", return_value=always_fails),
            patch.object(chain, "_build_gemini", return_value=always_fails),
        ):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == "none"
        assert outcome.is_refusal is True
        assert REFUSAL_PHRASE in outcome.answer

    def test_returns_refusal_when_no_keys_configured(self) -> None:
        """If neither builder returns a client, the refusal phrase is returned."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        with (
            patch.object(chain, "_build_groq", return_value=None),
            patch.object(chain, "_build_gemini", return_value=None),
        ):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == "none"
        assert REFUSAL_PHRASE in outcome.answer

    def test_groq_used_when_available(self) -> None:
        """A working Groq client is used and reported as the model."""
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        working_groq = MagicMock()
        working_groq.invoke.return_value = self._make_mock_response("Groq answered.")

        with patch.object(chain, "_build_groq", return_value=working_groq):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == cfg.groq_llm_model
        assert "Groq answered" in outcome.answer
class TestAnswerChainStreaming:
    """Token-by-token output of ``AnswerChain.stream_generate``."""

    def test_streaming_yields_chunks(self) -> None:
        """Each chunk's content is yielded, in order."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        pieces = [MagicMock(content=part) for part in ("Hello ", "world", "!")]
        fake_llm = MagicMock()
        fake_llm.stream.return_value = iter(pieces)

        with patch.object(chain, "_build_groq", return_value=fake_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        assert received == ["Hello ", "world", "!"]

    def test_streaming_skips_empty_chunks(self) -> None:
        """Chunks with empty content are dropped from the stream."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        # The middle chunk is empty and should be skipped.
        pieces = [MagicMock(content=part) for part in ("real", "", " content")]
        fake_llm = MagicMock()
        fake_llm.stream.return_value = iter(pieces)

        with patch.object(chain, "_build_groq", return_value=fake_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        assert "" not in received
        assert received == ["real", " content"]

    def test_streaming_returns_refusal_when_no_llm(self) -> None:
        """With no provider available the stream carries the refusal phrase."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        with (
            patch.object(chain, "_build_groq", return_value=None),
            patch.object(chain, "_build_gemini", return_value=None),
        ):
            received = list(chain.stream_generate("q", "ctx", []))

        assert REFUSAL_PHRASE in "".join(received)

    def test_streaming_yields_error_on_exception(self) -> None:
        """When ``stream`` raises, the output mentions an error instead of propagating."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        fake_llm = MagicMock()
        fake_llm.stream.side_effect = RuntimeError("connection refused")

        with patch.object(chain, "_build_groq", return_value=fake_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        combined = "".join(received)
        assert any(word in combined for word in ("Error", "error"))
# ------------------------------------------------------------------ #
# GenerationResult Model Tests                                       #
# ------------------------------------------------------------------ #
class TestGenerationResult:
    """Construction and field access on ``GenerationResult``."""

    def test_can_instantiate(self) -> None:
        """All fields passed to the constructor are readable afterwards."""
        from voicevault.generation.answer_chain import GenerationResult

        fields = {
            "answer": "test answer",
            "citations": [],
            "confidence_level": "high",
            "is_refusal": False,
            "model_used": "llama-3.1-70b-versatile",
            "tokens_used": 100,
            "latency_ms": 250,
        }
        built = GenerationResult(**fields)

        assert built.answer == "test answer"
        assert built.confidence_level == "high"
        assert built.is_refusal is False
        assert built.tokens_used == 100
        assert built.latency_ms == 250

    def test_citations_list_is_mutable(self) -> None:
        """Appending to ``citations`` after construction is reflected on the result."""
        from voicevault.generation.answer_chain import GenerationResult

        fields = {
            "answer": "",
            "citations": [],
            "confidence_level": "low",
            "is_refusal": True,
            "model_used": "none",
            "tokens_used": 0,
            "latency_ms": 0,
        }
        built = GenerationResult(**fields)

        built.citations.append(_make_citation())

        assert len(built.citations) == 1