Spaces:
Running on Zero
Running on Zero
| """Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring.""" | |
| from __future__ import annotations | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| from zsgdp.benchmarks.parser_quality import run_parser_benchmark | |
| from zsgdp.benchmarks.retrieval import ( | |
| LexicalRetriever, | |
| RetrievalQuery, | |
| run_retrieval_for_document, | |
| synthesize_qa_set, | |
| ) | |
| from zsgdp.schema import Chunk, ParsedDocument, QualityReport | |
| from zsgdp.verify.retrieval import compute_retrieval_metrics | |
def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk:
    """Build a single-page prose ``Chunk`` belonging to document ``d1``.

    Args:
        chunk_id: Identifier assigned to the chunk.
        text: Chunk body text.
        page: Page number used for both ``page_start`` and ``page_end``.

    Returns:
        A ``Chunk`` with an empty section path whose ``token_count`` is the
        whitespace-separated word count of ``text``.
    """
    word_count = len(text.split())
    fields = dict(
        chunk_id=chunk_id,
        doc_id="d1",
        page_start=page,
        page_end=page,
        section_path=[],
        content_type="prose",
        text=text,
        token_count=word_count,
    )
    return Chunk(**fields)
class TestComputeRetrievalMetrics(unittest.TestCase):
    """Unit tests for ``compute_retrieval_metrics`` on hand-built rankings.

    Each input pair is ``(retrieved_ids_in_rank_order, truth_ids)``.
    """

    def test_perfect_retrieval(self):
        pairs = [(["c1"], ["c1"]), (["c2"], ["c2"])]
        metrics = compute_retrieval_metrics(pairs)
        self.assertEqual(metrics["recall_at_k"][1], 1.0)
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_truth_at_rank_three_yields_partial(self):
        ranking = ["x", "y", "c1", "z"]
        metrics = compute_retrieval_metrics([(ranking, ["c1"])])
        recall = metrics["recall_at_k"]
        self.assertEqual(recall[1], 0.0)
        self.assertEqual(recall[3], 1.0)
        # The only truth sits at position 3, so its reciprocal rank is 1/3.
        self.assertAlmostEqual(metrics["mean_reciprocal_rank"], 1 / 3)

    def test_no_hit_yields_zero_mrr(self):
        metrics = compute_retrieval_metrics([(["x", "y"], ["c1"])])
        self.assertEqual(metrics["mean_reciprocal_rank"], 0.0)
        self.assertEqual(metrics["recall_at_k"][5], 0.0)

    def test_citation_accuracy_mirrors_recall(self):
        metrics = compute_retrieval_metrics([(["c1"], ["c1"])])
        observed = metrics["citation_accuracy_at_k"][1]
        self.assertEqual(observed, metrics["recall_at_k"][1])

    def test_empty_queries_are_vacuous(self):
        # With no queries at all, the metric set reports a perfect MRR.
        metrics = compute_retrieval_metrics([])
        self.assertEqual(metrics["query_count"], 0)
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_empty_truth_sets_skipped(self):
        # A query whose truth list is empty is not counted.
        metrics = compute_retrieval_metrics([(["c1"], [])])
        self.assertEqual(metrics["query_count"], 0)
class TestLexicalRetriever(unittest.TestCase):
    """Tests for ``LexicalRetriever`` indexing and ranked querying."""

    def test_finds_distinctive_chunk(self):
        corpus = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        retriever = LexicalRetriever()
        retriever.index(corpus)
        top_ids = retriever.query("apples orchard", top_k=3)
        # Only c1 contains the query words, so it is expected at rank 1.
        self.assertEqual(top_ids[0], "c1")

    def test_query_text_with_no_indexed_terms_returns_empty(self):
        retriever = LexicalRetriever()
        retriever.index([_chunk("c1", "Apples grow on trees.")])
        hits = retriever.query("zzz qqq", top_k=3)
        self.assertEqual(hits, [])

    def test_empty_index_returns_empty(self):
        unindexed = LexicalRetriever()
        self.assertEqual(unindexed.query("anything", top_k=3), [])
class TestSynthesizeQASet(unittest.TestCase):
    """Tests for ``synthesize_qa_set`` query generation from chunk text."""

    @staticmethod
    def _document(*chunks):
        """Wrap ``chunks`` in a markdown ``ParsedDocument`` with doc id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=list(chunks),
            quality_report=QualityReport(),
        )

    def test_picks_distinctive_sentence_per_chunk(self):
        # The second sentence appears verbatim in both chunks, so it does not
        # distinguish either one and must not show up in any query.
        parsed = self._document(
            _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."),
            _chunk("c2", "Cars drive on highways across the country. Common shared sentence."),
        )
        queries = synthesize_qa_set(parsed, min_sentence_tokens=3)
        self.assertEqual(sorted(q.truths[0] for q in queries), ["c1", "c2"])
        for q in queries:
            self.assertNotIn("Common shared", q.text)

    def test_skips_chunks_with_no_distinctive_sentences(self):
        # Both chunks have identical text, so no sentence singles either out.
        parsed = self._document(
            _chunk("c1", "Same sentence here."),
            _chunk("c2", "Same sentence here."),
        )
        self.assertEqual(synthesize_qa_set(parsed, min_sentence_tokens=2), [])
class TestRunRetrievalForDocument(unittest.TestCase):
    """Tests for ``run_retrieval_for_document`` on in-memory parsed documents."""

    @staticmethod
    def _document(chunks):
        """Wrap ``chunks`` in a markdown ``ParsedDocument`` with doc id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_end_to_end_retrieval_on_synthetic_doc(self):
        parsed = self._document([
            _chunk("c1", "Apples grow on trees in the orchard during autumn season."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."),
            _chunk("c3", "Mountains rise above the clouds in the distant horizon."),
        ])
        run = run_retrieval_for_document(parsed, top_k=3)
        self.assertTrue(run["evaluated"])
        self.assertEqual(run["query_count"], 3)
        # Without this, the per-result loop below would pass vacuously if no
        # results were returned.
        self.assertEqual(len(run["results"]), run["query_count"])
        for result in run["results"]:
            truth = result["truths"][0]
            # subTest reports every failing query instead of stopping at the first.
            with self.subTest(truth=truth):
                # Verbatim retrieval should put the source chunk at rank 1.
                self.assertEqual(result["retrieved"][0], truth)

    def test_no_chunks_returns_unevaluated(self):
        run = run_retrieval_for_document(self._document([]))
        self.assertFalse(run["evaluated"])
        self.assertEqual(run["reason"], "no_chunks")

    def test_explicit_queries_override_synthesis(self):
        parsed = self._document([
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways."),
        ])
        queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])]
        run = run_retrieval_for_document(parsed, queries=queries)
        # Only the single explicit query is evaluated; none are synthesized.
        self.assertEqual(run["query_count"], 1)
        self.assertEqual(run["results"][0]["retrieved"][0], "c1")
class TestBenchmarkIntegration(unittest.TestCase):
    """End-to-end check that ``run_parser_benchmark`` reports retrieval metrics."""

    def test_retrieval_metrics_appear_in_summary(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            root = Path(tmp_name)
            src = root / "in"
            src.mkdir()
            out_dir = root / "out"
            # Three paragraphs on unrelated topics; the test expects each
            # synthesized query to retrieve its own chunk at rank 1.
            (src / "doc.md").write_text(
                "# Doc\n\n"
                "Apples grow on trees in the orchard during autumn.\n\n"
                "Submarines navigate beneath the ocean using sonar.\n\n"
                "Mountains rise above the clouds in the horizon.\n",
                encoding="utf-8",
            )
            summary = run_parser_benchmark(src, out_dir, dataset_name="custom_folder")

            # Exactly one input file was written, so exactly one document is expected.
            self.assertEqual(len(summary["documents"]), 1)
            doc = summary["documents"][0]
            self.assertTrue(doc["retrieval_evaluated"])
            self.assertGreater(doc["retrieval_query_count"], 0)
            self.assertEqual(doc["retrieval_recall_at_1"], 1.0)

            self.assertEqual(summary["retrieval_evaluated_count"], 1)
            # The previous `>= 0.0` check held for any recall value. With a
            # single evaluated document at recall@1 == 1.0, the mean must
            # also be 1.0.
            self.assertAlmostEqual(summary["mean_retrieval_recall_at_1"], 1.0)
            self.assertTrue((out_dir / "retrieval_runs.csv").exists())
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()