VoiceVault / tests /test_phase4.py
NinjainPJs's picture
Initial release: VoiceVault v1.0.0 — Voice-First RAG Knowledge Agent
85f900d
"""
tests/test_phase4.py
====================
Phase 4 — LLM Generation Chain Tests

Tests:
  - CitationInjector: marker parsing, resolution strategies, deduplication
  - FaithfulnessGuard: refusal detection, confidence scoring, system prompt
  - AnswerChain: message building, fallback logic (mocked LLMs), streaming,
    token extraction, max_tokens per query_type

All LLM calls are mocked — no real API keys required.

Run with: pytest tests/test_phase4.py -v
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from voicevault.generation.citation_injector import CitationInjector
from voicevault.generation.faithfulness_guard import (
REFUSAL_PHRASE,
FaithfulnessGuard,
)
from voicevault.models import Citation, RetrievalResult
# ------------------------------------------------------------------ #
# Helpers #
# ------------------------------------------------------------------ #
def _make_citation(
    source_file: str = "report.pdf",
    page_number: int = 1,
    section: str = "Introduction",
    excerpt: str = "Some relevant excerpt.",
    relevance_score: float = 0.8,
) -> Citation:
    """Build a ``Citation`` for tests; every field has an overridable default."""
    fields = {
        "source_file": source_file,
        "page_number": page_number,
        "section": section,
        "excerpt": excerpt,
        "relevance_score": relevance_score,
    }
    return Citation(**fields)
def _make_retrieval_result(rerank_score: float = 0.0, rrf_score: float = 0.0) -> RetrievalResult:
    """Build a ``RetrievalResult`` with fixed identity fields and the given scores."""
    scores = {"rrf_score": rrf_score, "rerank_score": rerank_score}
    return RetrievalResult(
        chunk_id="test-chunk",
        text="test text",
        source_file="test.pdf",
        page_number=1,
        **scores,
    )
# ------------------------------------------------------------------ #
# CitationInjector Tests #
# ------------------------------------------------------------------ #
class TestCitationInjectorBasic:
    """Baseline ``CitationInjector.inject`` checks: parsing, dedup, ordering."""

    def setup_method(self) -> None:
        """Create a fresh injector and a two-entry citation map for each test."""
        self.injector = CitationInjector()
        self.citation_map = [
            _make_citation(name, page)
            for name, page in (("report.pdf", 3), ("paper.pdf", 7))
        ]

    def test_empty_answer_returns_empty(self) -> None:
        """An empty answer comes back empty, with no citations."""
        answer, resolved = self.injector.inject("", self.citation_map)
        assert answer == ""
        assert resolved == []

    def test_answer_without_markers_returned_unchanged(self) -> None:
        """Text without any source marker is returned as-is, with no citations."""
        plain = "Machine learning is a field of AI."
        answer, resolved = self.injector.inject(plain, self.citation_map)
        assert answer == plain
        assert resolved == []

    def test_exact_filename_and_page_resolved(self) -> None:
        """A marker naming an exact file and page resolves to that one citation."""
        _, resolved = self.injector.inject(
            "The accuracy was 94% [Source: report.pdf, p.3].",
            self.citation_map,
        )
        assert len(resolved) == 1
        only = resolved[0]
        assert only.source_file == "report.pdf"
        assert only.page_number == 3

    def test_multiple_markers_resolved(self) -> None:
        """Two distinct markers produce two citations."""
        first = "First fact [Source: report.pdf, p.3]."
        second = "Second fact [Source: paper.pdf, p.7]."
        _, resolved = self.injector.inject(f"{first} {second}", self.citation_map)
        assert len(resolved) == 2

    def test_duplicate_markers_deduplicated(self) -> None:
        """The same marker appearing twice yields a single citation."""
        first = "Claim one [Source: report.pdf, p.3]."
        second = "Same source again [Source: report.pdf, p.3]."
        _, resolved = self.injector.inject(f"{first} {second}", self.citation_map)
        assert len(resolved) == 1

    def test_answer_text_preserved_with_markers(self) -> None:
        """Markers are preserved in the answer text (not stripped)."""
        answer, _ = self.injector.inject(
            "The result was 94% [Source: report.pdf, p.3].",
            self.citation_map,
        )
        assert "[Source: report.pdf, p.3]" in answer

    def test_citation_order_matches_first_appearance(self) -> None:
        """Citations follow the order of the markers in the text, not the map."""
        first = "Paper result [Source: paper.pdf, p.7]."
        second = "Report result [Source: report.pdf, p.3]."
        _, resolved = self.injector.inject(f"{first} {second}", self.citation_map)
        leading, following = resolved[0], resolved[1]
        assert leading.source_file == "paper.pdf"
        assert following.source_file == "report.pdf"

    def test_empty_citation_map_returns_no_citations(self) -> None:
        """With an empty citation map, a marker resolves to nothing."""
        _, resolved = self.injector.inject("Result [Source: anything.pdf, p.1].", [])
        assert resolved == []
class TestCitationInjectorMatchingStrategies:
    """One test per marker-resolution strategy, plus the last-resort fallback."""

    def setup_method(self) -> None:
        """Give each test its own injector."""
        self.injector = CitationInjector()

    def test_strategy1_exact_match(self) -> None:
        """Strategy 1: exact filename + exact page."""
        candidates = [
            _make_citation(name, 5) for name in ("report.pdf", "other.pdf")
        ]
        _, resolved = self.injector.inject("[Source: report.pdf, p.5]", candidates)
        assert resolved[0].source_file == "report.pdf"

    def test_strategy2_substring_match(self) -> None:
        """Strategy 2: filename substring + page."""
        candidates = [_make_citation("annual_report_2024.pdf", 3)]
        _, resolved = self.injector.inject("[Source: report, p.3]", candidates)
        assert len(resolved) == 1
        assert resolved[0].source_file == "annual_report_2024.pdf"

    def test_strategy3_page_only_match(self) -> None:
        """Strategy 3: page number match as fallback."""
        candidates = [_make_citation("unique_name.pdf", 9)]
        _, resolved = self.injector.inject("[Source: unknownfile, p.9]", candidates)
        assert len(resolved) == 1
        assert resolved[0].page_number == 9

    def test_strategy4_filename_no_page(self) -> None:
        """Strategy 4: filename substring with no page number."""
        candidates = [_make_citation("research.pdf", 1)]
        _, resolved = self.injector.inject("[Source: research]", candidates)
        assert len(resolved) == 1

    def test_last_resort_first_citation(self) -> None:
        """Last resort: return first citation when nothing else matches."""
        candidates = [
            _make_citation(name, page)
            for name, page in (("alpha.pdf", 1), ("beta.pdf", 2))
        ]
        marker = "[Source: zzz_no_match.pdf, p.99]"
        _, resolved = self.injector.inject(marker, candidates)
        assert len(resolved) == 1
        assert resolved[0].source_file == "alpha.pdf"
# ------------------------------------------------------------------ #
# FaithfulnessGuard Tests #
# ------------------------------------------------------------------ #
class TestFaithfulnessGuardRefusal:
    """Edge cases for ``FaithfulnessGuard.is_refusal``."""

    def setup_method(self) -> None:
        """Give each test its own guard."""
        self.guard = FaithfulnessGuard()

    def test_exact_refusal_phrase_detected(self) -> None:
        """The canonical refusal phrase is flagged."""
        verdict = self.guard.is_refusal(REFUSAL_PHRASE)
        assert verdict is True

    def test_refusal_case_insensitive(self) -> None:
        """An upper-cased refusal phrase is still flagged."""
        shouted = REFUSAL_PHRASE.upper()
        assert self.guard.is_refusal(shouted) is True

    def test_refusal_embedded_in_text(self) -> None:
        """The phrase is flagged when surrounded by other text."""
        wrapped = f"Sorry, {REFUSAL_PHRASE} Please try another query."
        verdict = self.guard.is_refusal(wrapped)
        assert verdict is True

    def test_normal_answer_not_refusal(self) -> None:
        """An ordinary answer is not flagged."""
        verdict = self.guard.is_refusal("Machine learning is a subset of AI.")
        assert verdict is False

    def test_empty_string_is_refusal(self) -> None:
        """An empty answer counts as a refusal."""
        verdict = self.guard.is_refusal("")
        assert verdict is True

    def test_partial_phrase_not_refusal(self) -> None:
        """The fragment used here is not enough to count as a refusal."""
        verdict = self.guard.is_refusal("I could not find this")
        assert verdict is False

    def test_refusal_without_trailing_period(self) -> None:
        """The phrase is flagged even with its trailing period(s) removed."""
        trimmed = REFUSAL_PHRASE.rstrip(".")
        assert self.guard.is_refusal(trimmed) is True
class TestFaithfulnessGuardConfidence:
    """Expected ``confidence_level`` labels for given retrieval scores."""

    def setup_method(self) -> None:
        """Give each test its own guard."""
        self.guard = FaithfulnessGuard()

    def _level_for(self, *rerank_scores: float):
        """Score one retrieval result per given rerank score (rrf left at default)."""
        batch = [_make_retrieval_result(rerank_score=score) for score in rerank_scores]
        return self.guard.confidence_level(batch)

    def test_empty_results_returns_low(self) -> None:
        """No retrieval results at all is reported as low confidence."""
        assert self._level_for() == "low"

    def test_high_rerank_score_returns_high(self) -> None:
        """A rerank score of 0.9 is high."""
        assert self._level_for(0.9) == "high"

    def test_medium_rerank_score_returns_medium(self) -> None:
        """A rerank score of 0.35 is medium."""
        assert self._level_for(0.35) == "medium"

    def test_low_rerank_score_returns_low(self) -> None:
        """A rerank score of 0.1 is low."""
        assert self._level_for(0.1) == "low"

    def test_uses_max_score_across_results(self) -> None:
        """With mixed scores, the highest one (0.8) determines the label."""
        assert self._level_for(0.1, 0.8, 0.3) == "high"

    def test_zero_rerank_falls_back_to_rrf_score(self) -> None:
        """When rerank_score is 0, rrf_score should be used."""
        only_rrf = _make_retrieval_result(rerank_score=0.0, rrf_score=0.6)
        assert self.guard.confidence_level([only_rrf]) == "high"

    def test_boundary_above_0_5_is_high(self) -> None:
        """0.51 is just above the high threshold."""
        assert self._level_for(0.51) == "high"

    def test_boundary_exactly_0_5_is_medium(self) -> None:
        """Exactly 0.5 is still medium."""
        assert self._level_for(0.5) == "medium"

    def test_boundary_exactly_0_2_is_low(self) -> None:
        """Exactly 0.2 is still low."""
        assert self._level_for(0.2) == "low"

    def test_boundary_above_0_2_is_medium(self) -> None:
        """0.21 is just above the medium threshold."""
        assert self._level_for(0.21) == "medium"
class TestFaithfulnessGuardSystemPrompt:
    """Content checks on the prompt text produced by ``FaithfulnessGuard``."""

    def test_system_prompt_instruction_contains_refusal_phrase(self) -> None:
        """The instruction text includes the refusal phrase."""
        text = FaithfulnessGuard.system_prompt_instruction()
        assert REFUSAL_PHRASE in text

    def test_system_prompt_instruction_non_empty(self) -> None:
        """The instruction text is longer than 50 characters."""
        text = FaithfulnessGuard.system_prompt_instruction()
        assert len(text) > 50

    def test_build_system_prompt_contains_citation_rules(self) -> None:
        """The full prompt has a CITATION RULES section."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert "CITATION RULES" in full_prompt

    def test_build_system_prompt_contains_faithfulness_rules(self) -> None:
        """The full prompt has a FAITHFULNESS RULES section."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert "FAITHFULNESS RULES" in full_prompt

    def test_build_system_prompt_contains_refusal_phrase(self) -> None:
        """The full prompt includes the refusal phrase."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert REFUSAL_PHRASE in full_prompt

    def test_build_system_prompt_non_empty(self) -> None:
        """The full prompt is longer than 200 characters."""
        full_prompt = FaithfulnessGuard.build_system_prompt()
        assert len(full_prompt) > 200
# ------------------------------------------------------------------ #
# AnswerChain Tests #
# ------------------------------------------------------------------ #
class TestAnswerChainMessageBuilding:
    """Shape and content of the list returned by ``AnswerChain._build_messages``."""

    def setup_method(self) -> None:
        """Import ``AnswerChain`` lazily and build one chain per test."""
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_messages_start_with_system(self) -> None:
        """The first message is a ``SystemMessage``."""
        from langchain_core.messages import SystemMessage

        built = self.chain._build_messages("what is AI?", "ctx", [])
        opening = built[0]
        assert isinstance(opening, SystemMessage)

    def test_messages_end_with_human(self) -> None:
        """The last message is a ``HumanMessage``."""
        from langchain_core.messages import HumanMessage

        built = self.chain._build_messages("what is AI?", "ctx", [])
        closing = built[-1]
        assert isinstance(closing, HumanMessage)

    def test_context_in_last_human_message(self) -> None:
        """The context string appears in the final message."""
        built = self.chain._build_messages("what is AI?", "CONTEXT_TEXT", [])
        closing = built[-1]
        assert "CONTEXT_TEXT" in closing.content

    def test_query_in_last_human_message(self) -> None:
        """The query string appears in the final message."""
        built = self.chain._build_messages("what is AI?", "ctx", [])
        closing = built[-1]
        assert "what is AI?" in closing.content

    def test_history_injected_as_human_ai_pairs(self) -> None:
        """Each history turn becomes a Human/AI message pair after the system message."""
        from langchain_core.messages import AIMessage, HumanMessage

        turns = [("q1", "a1"), ("q2", "a2")]
        built = self.chain._build_messages("q3", "ctx", turns)
        # system + (human + AI) × 2 + human = 6
        assert len(built) == 6
        first_question, first_reply = built[1], built[2]
        assert isinstance(first_question, HumanMessage)
        assert isinstance(first_reply, AIMessage)
        assert first_question.content == "q1"
        assert first_reply.content == "a1"

    def test_history_capped_at_conversation_window(self) -> None:
        """A 20-turn history is trimmed to ``cfg.conversation_window`` turns."""
        from config import cfg

        turns = [(f"q{n}", f"a{n}") for n in range(20)]
        built = self.chain._build_messages("current", "ctx", turns)
        # system + (human + AI) × window + human
        history_messages = cfg.conversation_window * 2
        assert len(built) == 1 + history_messages + 1

    def test_no_history_three_messages_only(self) -> None:
        """With no history the list has two messages (despite the test's name)."""
        built = self.chain._build_messages("q", "ctx", [])
        # system + human
        assert len(built) == 2
class TestAnswerChainMaxTokens:
    """Token budget returned by ``AnswerChain._max_tokens_for`` per query type."""

    def setup_method(self) -> None:
        """Import ``AnswerChain`` lazily and build one chain per test."""
        from voicevault.generation.answer_chain import AnswerChain

        self.chain = AnswerChain()

    def test_factual_uses_base_max_tokens(self) -> None:
        """``factual`` gets exactly ``cfg.max_answer_tokens``."""
        from config import cfg

        budget = self.chain._max_tokens_for("factual")
        assert budget == cfg.max_answer_tokens

    def test_summary_uses_double_max_tokens(self) -> None:
        """``summary`` gets twice ``cfg.max_answer_tokens``."""
        from config import cfg

        budget = self.chain._max_tokens_for("summary")
        assert budget == cfg.max_answer_tokens * 2

    def test_compare_uses_base_max_tokens(self) -> None:
        """``compare`` gets exactly ``cfg.max_answer_tokens``."""
        from config import cfg

        budget = self.chain._max_tokens_for("compare")
        assert budget == cfg.max_answer_tokens
class TestAnswerChainTokenExtraction:
    """``_extract_tokens`` against well-formed and malformed mock responses."""

    def test_extracts_tokens_from_usage_metadata(self) -> None:
        """``total_tokens`` is read from a ``usage_metadata`` dict."""
        from voicevault.generation.answer_chain import _extract_tokens

        reply = MagicMock(usage_metadata={"total_tokens": 123})
        assert _extract_tokens(reply) == 123

    def test_returns_zero_when_no_metadata(self) -> None:
        """``usage_metadata`` set to ``None`` yields zero."""
        from voicevault.generation.answer_chain import _extract_tokens

        reply = MagicMock(usage_metadata=None)
        assert _extract_tokens(reply) == 0

    def test_returns_zero_when_attribute_missing(self) -> None:
        """A response with no ``usage_metadata`` attribute yields zero."""
        from voicevault.generation.answer_chain import _extract_tokens

        # An empty spec means the mock exposes no attributes.
        bare_reply = MagicMock(spec=[])
        assert _extract_tokens(bare_reply) == 0

    def test_returns_zero_on_type_error(self) -> None:
        """A non-dict ``usage_metadata`` yields zero instead of raising."""
        from voicevault.generation.answer_chain import _extract_tokens

        # .get() on a string raises AttributeError
        reply = MagicMock(usage_metadata="not_a_dict")
        assert _extract_tokens(reply) == 0
class TestAnswerChainConfidenceFromCitations:
    """Labels returned by ``_confidence_from_citations`` for given relevance scores."""

    def test_empty_citation_map_returns_low(self) -> None:
        """No citations at all is reported as low."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        level = _confidence_from_citations([])
        assert level == "low"

    def test_high_relevance_returns_high(self) -> None:
        """A relevance score of 0.9 is high."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        level = _confidence_from_citations([_make_citation(relevance_score=0.9)])
        assert level == "high"

    def test_medium_relevance_returns_medium(self) -> None:
        """A relevance score of 0.35 is medium."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        level = _confidence_from_citations([_make_citation(relevance_score=0.35)])
        assert level == "medium"

    def test_low_relevance_returns_low(self) -> None:
        """A relevance score of 0.05 is low."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        level = _confidence_from_citations([_make_citation(relevance_score=0.05)])
        assert level == "low"

    def test_uses_max_across_multiple_citations(self) -> None:
        """With scores 0.1 and 0.9, the higher one determines the label."""
        from voicevault.generation.answer_chain import _confidence_from_citations

        mixed = [_make_citation(relevance_score=score) for score in (0.1, 0.9)]
        assert _confidence_from_citations(mixed) == "high"
class TestAnswerChainGenerateMocked:
    """Test generate() with mocked LLM responses.

    Every test patches ``_build_groq`` on a fresh ``AnswerChain`` so that it
    returns a stub LLM whose ``invoke`` hands back a canned response.
    """

    def _make_mock_response(self, content: str, total_tokens: int = 150) -> MagicMock:
        """Return a mock response carrying *content* and a ``total_tokens`` count."""
        response = MagicMock()
        response.content = content
        response.usage_metadata = {"total_tokens": total_tokens}
        return response

    def _make_mock_llm(self, response: MagicMock) -> MagicMock:
        """Return a stub LLM whose ``invoke`` yields *response* for any arguments.

        ``invoke`` is a mock with ``return_value`` set, rather than a
        one-argument lambda, so the stub does not raise ``TypeError`` if it
        is called with extra positional or keyword arguments.
        """
        llm = MagicMock()
        llm.invoke.return_value = response
        return llm

    def test_generate_returns_generation_result(self) -> None:
        """generate() returns a ``GenerationResult`` instance."""
        from voicevault.generation.answer_chain import AnswerChain, GenerationResult

        chain = AnswerChain()
        citation = _make_citation("doc.pdf", 1, relevance_score=0.8)
        mock_response = self._make_mock_response("ML is a subset of AI [Source: doc.pdf, p.1].")
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate(
                query="what is ML",
                context="[Source: doc.pdf, p.1]\nML is...",
                citation_map=[citation],
                query_type="factual",
            )
        assert isinstance(result, GenerationResult)

    def test_generate_extracts_answer(self) -> None:
        """The response content ends up in ``result.answer``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_response = self._make_mock_response("ML stands for Machine Learning.")
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate(
                query="what is ML",
                context="context text",
                citation_map=[],
            )
        assert "Machine Learning" in result.answer

    def test_generate_records_latency(self) -> None:
        """``latency_ms`` is populated with a non-negative value."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_response = self._make_mock_response("Some answer.")
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate("q", "ctx", [])
        assert result.latency_ms >= 0

    def test_generate_records_tokens(self) -> None:
        """``tokens_used`` mirrors the response's ``total_tokens``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_response = self._make_mock_response("Answer.", total_tokens=200)
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate("q", "ctx", [])
        assert result.tokens_used == 200

    def test_generate_detects_refusal(self) -> None:
        """A response equal to the refusal phrase sets ``is_refusal``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_response = self._make_mock_response(REFUSAL_PHRASE)
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate("q", "ctx", [])
        assert result.is_refusal is True

    def test_generate_non_refusal_answer_not_flagged(self) -> None:
        """An ordinary answer leaves ``is_refusal`` False."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        mock_response = self._make_mock_response("This is a real answer.")
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate("q", "ctx", [])
        assert result.is_refusal is False

    def test_generate_resolves_citations(self) -> None:
        """A source marker in the answer resolves to the matching citation."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        citation = _make_citation("paper.pdf", 4, relevance_score=0.8)
        mock_response = self._make_mock_response(
            "The accuracy was 94% [Source: paper.pdf, p.4]."
        )
        with patch.object(chain, "_build_groq", return_value=self._make_mock_llm(mock_response)):
            result = chain.generate("q", "ctx", citation_map=[citation])
        assert len(result.citations) == 1
        assert result.citations[0].source_file == "paper.pdf"
class TestAnswerChainFallback:
    """Which model ``generate`` reports when the Groq and/or Gemini builders fail."""

    def _make_mock_response(self, content: str) -> MagicMock:
        """Return a mock response carrying *content* and a fixed 50-token usage."""
        return MagicMock(content=content, usage_metadata={"total_tokens": 50})

    def test_falls_back_to_gemini_when_groq_raises(self) -> None:
        """If Groq's ``invoke`` raises, the Gemini answer and model name are used."""
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        broken_groq = MagicMock()
        broken_groq.invoke.side_effect = RuntimeError("Groq API error")
        working_gemini = MagicMock()
        working_gemini.invoke.return_value = self._make_mock_response("Gemini answered.")

        with patch.object(chain, "_build_groq", return_value=broken_groq):
            with patch.object(chain, "_build_gemini", return_value=working_gemini):
                outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == cfg.gemini_llm_model
        assert "Gemini answered" in outcome.answer

    def test_returns_refusal_when_both_fail(self) -> None:
        """If both LLMs raise, the result is a refusal attributed to ``"none"``."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = RuntimeError("API error")

        with patch.object(chain, "_build_groq", return_value=broken_llm):
            with patch.object(chain, "_build_gemini", return_value=broken_llm):
                outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == "none"
        assert outcome.is_refusal is True
        assert REFUSAL_PHRASE in outcome.answer

    def test_returns_refusal_when_no_keys_configured(self) -> None:
        """If both builders return ``None``, the refusal phrase is returned."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        with patch.object(chain, "_build_groq", return_value=None):
            with patch.object(chain, "_build_gemini", return_value=None):
                outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == "none"
        assert REFUSAL_PHRASE in outcome.answer

    def test_groq_used_when_available(self) -> None:
        """A working Groq LLM is used and its model name is reported."""
        from config import cfg
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        working_groq = MagicMock()
        working_groq.invoke.return_value = self._make_mock_response("Groq answered.")

        with patch.object(chain, "_build_groq", return_value=working_groq):
            outcome = chain.generate("q", "ctx", [])

        assert outcome.model_used == cfg.groq_llm_model
        assert "Groq answered" in outcome.answer
class TestAnswerChainStreaming:
    """Chunks yielded by ``AnswerChain.stream_generate`` under mocked LLMs."""

    def test_streaming_yields_chunks(self) -> None:
        """Each streamed chunk's content is yielded in order."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        streaming_llm = MagicMock()
        streaming_llm.stream.return_value = iter(
            [MagicMock(content=piece) for piece in ("Hello ", "world", "!")]
        )

        with patch.object(chain, "_build_groq", return_value=streaming_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        assert received == ["Hello ", "world", "!"]

    def test_streaming_skips_empty_chunks(self) -> None:
        """Chunks with empty content are dropped from the output."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        # The middle chunk is empty and should be skipped.
        pieces = ("real", "", " content")
        streaming_llm = MagicMock()
        streaming_llm.stream.return_value = iter(
            [MagicMock(content=piece) for piece in pieces]
        )

        with patch.object(chain, "_build_groq", return_value=streaming_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        assert "" not in received
        assert received == ["real", " content"]

    def test_streaming_returns_refusal_when_no_llm(self) -> None:
        """With both builders returning ``None``, the refusal phrase is streamed."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()

        with patch.object(chain, "_build_groq", return_value=None):
            with patch.object(chain, "_build_gemini", return_value=None):
                received = list(chain.stream_generate("q", "ctx", []))

        assert REFUSAL_PHRASE in "".join(received)

    def test_streaming_yields_error_on_exception(self) -> None:
        """If ``stream`` raises, the streamed text mentions an error."""
        from voicevault.generation.answer_chain import AnswerChain

        chain = AnswerChain()
        failing_llm = MagicMock()
        failing_llm.stream.side_effect = RuntimeError("connection refused")

        with patch.object(chain, "_build_groq", return_value=failing_llm):
            received = list(chain.stream_generate("q", "ctx", []))

        combined = "".join(received)
        assert any(word in combined for word in ("Error", "error"))
# ------------------------------------------------------------------ #
# GenerationResult Model Tests #
# ------------------------------------------------------------------ #
class TestGenerationResult:
    """Construction and field access on ``GenerationResult``."""

    def test_can_instantiate(self) -> None:
        """Fields passed to the constructor are readable back unchanged."""
        from voicevault.generation.answer_chain import GenerationResult

        fields = {
            "answer": "test answer",
            "citations": [],
            "confidence_level": "high",
            "is_refusal": False,
            "model_used": "llama-3.1-70b-versatile",
            "tokens_used": 100,
            "latency_ms": 250,
        }
        built = GenerationResult(**fields)

        assert built.answer == "test answer"
        assert built.confidence_level == "high"
        assert built.is_refusal is False
        assert built.tokens_used == 100
        assert built.latency_ms == 250

    def test_citations_list_is_mutable(self) -> None:
        """Appending to ``citations`` after construction is reflected on the result."""
        from voicevault.generation.answer_chain import GenerationResult

        fields = {
            "answer": "",
            "citations": [],
            "confidence_level": "low",
            "is_refusal": True,
            "model_used": "none",
            "tokens_used": 0,
            "latency_ms": 0,
        }
        built = GenerationResult(**fields)

        built.citations.append(_make_citation())
        assert len(built.citations) == 1