"""Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring.""" from __future__ import annotations import tempfile import unittest from pathlib import Path from zsgdp.benchmarks.parser_quality import run_parser_benchmark from zsgdp.benchmarks.retrieval import ( LexicalRetriever, RetrievalQuery, run_retrieval_for_document, synthesize_qa_set, ) from zsgdp.schema import Chunk, ParsedDocument, QualityReport from zsgdp.verify.retrieval import compute_retrieval_metrics def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk: return Chunk( chunk_id=chunk_id, doc_id="d1", page_start=page, page_end=page, section_path=[], content_type="prose", text=text, token_count=len(text.split()), ) class TestComputeRetrievalMetrics(unittest.TestCase): def test_perfect_retrieval(self): result = compute_retrieval_metrics([ (["c1"], ["c1"]), (["c2"], ["c2"]), ]) self.assertEqual(result["recall_at_k"][1], 1.0) self.assertEqual(result["mean_reciprocal_rank"], 1.0) def test_truth_at_rank_three_yields_partial(self): result = compute_retrieval_metrics([(["x", "y", "c1", "z"], ["c1"])]) self.assertEqual(result["recall_at_k"][1], 0.0) self.assertEqual(result["recall_at_k"][3], 1.0) self.assertAlmostEqual(result["mean_reciprocal_rank"], 1 / 3) def test_no_hit_yields_zero_mrr(self): result = compute_retrieval_metrics([(["x", "y"], ["c1"])]) self.assertEqual(result["mean_reciprocal_rank"], 0.0) self.assertEqual(result["recall_at_k"][5], 0.0) def test_citation_accuracy_mirrors_recall(self): result = compute_retrieval_metrics([(["c1"], ["c1"])]) self.assertEqual(result["citation_accuracy_at_k"][1], result["recall_at_k"][1]) def test_empty_queries_are_vacuous(self): result = compute_retrieval_metrics([]) self.assertEqual(result["query_count"], 0) self.assertEqual(result["mean_reciprocal_rank"], 1.0) def test_empty_truth_sets_skipped(self): result = compute_retrieval_metrics([(["c1"], [])]) self.assertEqual(result["query_count"], 0) class TestLexicalRetriever(unittest.TestCase): def test_finds_distinctive_chunk(self): chunks = [ _chunk("c1", "Apples grow on trees in the orchard."), _chunk("c2", "Cars drive on highways across the country."), _chunk("c3", "Boats sail on rivers and oceans."), ] retriever = LexicalRetriever() retriever.index(chunks) ranking = retriever.query("apples orchard", top_k=3) self.assertEqual(ranking[0], "c1") def test_query_text_with_no_indexed_terms_returns_empty(self): retriever = LexicalRetriever() retriever.index([_chunk("c1", "Apples grow on trees.")]) self.assertEqual(retriever.query("zzz qqq", top_k=3), []) def test_empty_index_returns_empty(self): self.assertEqual(LexicalRetriever().query("anything", top_k=3), []) class TestSynthesizeQASet(unittest.TestCase): def test_picks_distinctive_sentence_per_chunk(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[ _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."), _chunk("c2", "Cars drive on highways across the country. Common shared sentence."), ], quality_report=QualityReport(), ) queries = synthesize_qa_set(parsed, min_sentence_tokens=3) truths = sorted(query.truths[0] for query in queries) self.assertEqual(truths, ["c1", "c2"]) for query in queries: self.assertNotIn("Common shared", query.text) def test_skips_chunks_with_no_distinctive_sentences(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[ _chunk("c1", "Same sentence here."), _chunk("c2", "Same sentence here."), ], quality_report=QualityReport(), ) queries = synthesize_qa_set(parsed, min_sentence_tokens=2) self.assertEqual(queries, []) class TestRunRetrievalForDocument(unittest.TestCase): def test_end_to_end_retrieval_on_synthetic_doc(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[ _chunk("c1", "Apples grow on trees in the orchard during autumn season."), _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."), _chunk("c3", "Mountains rise above the clouds in the distant horizon."), ], quality_report=QualityReport(), ) run = run_retrieval_for_document(parsed, top_k=3) self.assertTrue(run["evaluated"]) self.assertEqual(run["query_count"], 3) for result in run["results"]: truth = result["truths"][0] # Verbatim retrieval should put the source chunk at rank 1. self.assertEqual(result["retrieved"][0], truth) def test_no_chunks_returns_unevaluated(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[], quality_report=QualityReport(), ) run = run_retrieval_for_document(parsed) self.assertFalse(run["evaluated"]) self.assertEqual(run["reason"], "no_chunks") def test_explicit_queries_override_synthesis(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[ _chunk("c1", "Apples grow on trees in the orchard."), _chunk("c2", "Cars drive on highways."), ], quality_report=QualityReport(), ) queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])] run = run_retrieval_for_document(parsed, queries=queries) self.assertEqual(run["query_count"], 1) self.assertEqual(run["results"][0]["retrieved"][0], "c1") class TestBenchmarkIntegration(unittest.TestCase): def test_retrieval_metrics_appear_in_summary(self): with tempfile.TemporaryDirectory() as tmp: tmp = Path(tmp) src = tmp / "in" src.mkdir() (src / "doc.md").write_text( "# Doc\n\n" "Apples grow on trees in the orchard during autumn.\n\n" "Submarines navigate beneath the ocean using sonar.\n\n" "Mountains rise above the clouds in the horizon.\n", encoding="utf-8", ) summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") doc = summary["documents"][0] self.assertTrue(doc["retrieval_evaluated"]) self.assertGreater(doc["retrieval_query_count"], 0) self.assertEqual(doc["retrieval_recall_at_1"], 1.0) self.assertGreaterEqual(summary["mean_retrieval_recall_at_1"], 0.0) self.assertEqual(summary["retrieval_evaluated_count"], 1) self.assertTrue((tmp / "out" / "retrieval_runs.csv").exists()) if __name__ == "__main__": unittest.main()