Spaces:
Running on Zero
Running on Zero
File size: 7,595 Bytes
"""Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring."""
from __future__ import annotations
import tempfile
import unittest
from pathlib import Path
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import (
LexicalRetriever,
RetrievalQuery,
run_retrieval_for_document,
synthesize_qa_set,
)
from zsgdp.schema import Chunk, ParsedDocument, QualityReport
from zsgdp.verify.retrieval import compute_retrieval_metrics
def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk:
    """Build a one-page prose ``Chunk`` that belongs to document ``d1``.

    ``page`` is used for both the start and end page. ``token_count`` is a
    plain whitespace word count of ``text``.
    """
    word_total = len(text.split())
    return Chunk(
        chunk_id=chunk_id,
        doc_id="d1",
        page_start=page,
        page_end=page,
        section_path=[],
        content_type="prose",
        text=text,
        token_count=word_total,
    )
class TestComputeRetrievalMetrics(unittest.TestCase):
    """Unit tests for ``compute_retrieval_metrics``.

    Each input item is a ``(retrieved_ids, truth_ids)`` pair, where
    ``retrieved_ids`` is a ranked list. The result exposes ``recall_at_k``
    and ``citation_accuracy_at_k`` (keyed by k), ``mean_reciprocal_rank``,
    and ``query_count``.
    """

    def test_perfect_retrieval(self):
        pairs = [(["c1"], ["c1"]), (["c2"], ["c2"])]
        metrics = compute_retrieval_metrics(pairs)
        self.assertEqual(metrics["recall_at_k"][1], 1.0)
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_truth_at_rank_three_yields_partial(self):
        # The single relevant id sits at rank 3: missed at k=1, found at k=3,
        # and its reciprocal rank is 1/3.
        ranking = ["x", "y", "c1", "z"]
        metrics = compute_retrieval_metrics([(ranking, ["c1"])])
        recall = metrics["recall_at_k"]
        self.assertEqual(recall[1], 0.0)
        self.assertEqual(recall[3], 1.0)
        self.assertAlmostEqual(metrics["mean_reciprocal_rank"], 1 / 3)

    def test_no_hit_yields_zero_mrr(self):
        metrics = compute_retrieval_metrics([(["x", "y"], ["c1"])])
        self.assertEqual(metrics["mean_reciprocal_rank"], 0.0)
        self.assertEqual(metrics["recall_at_k"][5], 0.0)

    def test_citation_accuracy_mirrors_recall(self):
        metrics = compute_retrieval_metrics([(["c1"], ["c1"])])
        citation = metrics["citation_accuracy_at_k"][1]
        self.assertEqual(citation, metrics["recall_at_k"][1])

    def test_empty_queries_are_vacuous(self):
        # With nothing to evaluate, MRR is reported as a vacuous 1.0.
        metrics = compute_retrieval_metrics([])
        self.assertEqual(metrics["query_count"], 0)
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_empty_truth_sets_skipped(self):
        # A query with no ground-truth ids is not counted at all.
        metrics = compute_retrieval_metrics([(["c1"], [])])
        self.assertEqual(metrics["query_count"], 0)
class TestLexicalRetriever(unittest.TestCase):
    """Tests for ``LexicalRetriever`` indexing and querying."""

    def test_finds_distinctive_chunk(self):
        corpus = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        retriever = LexicalRetriever()
        retriever.index(corpus)
        best = retriever.query("apples orchard", top_k=3)[0]
        self.assertEqual(best, "c1")

    def test_query_text_with_no_indexed_terms_returns_empty(self):
        retriever = LexicalRetriever()
        retriever.index([_chunk("c1", "Apples grow on trees.")])
        hits = retriever.query("zzz qqq", top_k=3)
        self.assertEqual(hits, [])

    def test_empty_index_returns_empty(self):
        # Querying a retriever that was never indexed yields no hits.
        hits = LexicalRetriever().query("anything", top_k=3)
        self.assertEqual(hits, [])
class TestSynthesizeQASet(unittest.TestCase):
    """Tests for ``synthesize_qa_set`` query generation."""

    @staticmethod
    def _markdown_doc(chunks):
        """Wrap ``chunks`` in a minimal markdown ``ParsedDocument`` with id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_picks_distinctive_sentence_per_chunk(self):
        doc = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."),
            _chunk("c2", "Cars drive on highways across the country. Common shared sentence."),
        ])
        generated = synthesize_qa_set(doc, min_sentence_tokens=3)
        # One query per chunk, each pointing back at the chunk it came from.
        self.assertEqual(sorted(q.truths[0] for q in generated), ["c1", "c2"])
        # The sentence present in both chunks cannot tell them apart, so it
        # must not be used as query text.
        for q in generated:
            self.assertNotIn("Common shared", q.text)

    def test_skips_chunks_with_no_distinctive_sentences(self):
        doc = self._markdown_doc([
            _chunk("c1", "Same sentence here."),
            _chunk("c2", "Same sentence here."),
        ])
        # Identical chunks offer nothing distinctive, so nothing is synthesized.
        self.assertEqual(synthesize_qa_set(doc, min_sentence_tokens=2), [])
class TestRunRetrievalForDocument(unittest.TestCase):
    """Tests for ``run_retrieval_for_document`` with synthesized and explicit queries."""

    @staticmethod
    def _markdown_doc(chunks):
        """Wrap ``chunks`` in a minimal markdown ``ParsedDocument`` with id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_end_to_end_retrieval_on_synthetic_doc(self):
        parsed = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard during autumn season."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."),
            _chunk("c3", "Mountains rise above the clouds in the distant horizon."),
        ])
        run = run_retrieval_for_document(parsed, top_k=3)
        self.assertTrue(run["evaluated"])
        self.assertEqual(run["query_count"], 3)
        results = run["results"]
        # Without this guard, the per-query loop below would pass vacuously
        # if no results were returned.
        self.assertTrue(results, "expected at least one per-query result")
        for index, result in enumerate(results):
            truth = result["truths"][0]
            # Verbatim retrieval should put the source chunk at rank 1.
            # subTest reports every failing query instead of stopping at the first one.
            with self.subTest(query=index, truth=truth):
                self.assertEqual(result["retrieved"][0], truth)

    def test_no_chunks_returns_unevaluated(self):
        parsed = self._markdown_doc([])
        run = run_retrieval_for_document(parsed)
        self.assertFalse(run["evaluated"])
        self.assertEqual(run["reason"], "no_chunks")

    def test_explicit_queries_override_synthesis(self):
        parsed = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways."),
        ])
        queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])]
        run = run_retrieval_for_document(parsed, queries=queries)
        self.assertEqual(run["query_count"], 1)
        self.assertEqual(run["results"][0]["retrieved"][0], "c1")
class TestBenchmarkIntegration(unittest.TestCase):
    """Check that ``run_parser_benchmark`` reports retrieval metrics and writes its CSV."""

    def test_retrieval_metrics_appear_in_summary(self):
        # Three topically unrelated paragraphs, so each one is lexically distinctive.
        markdown = (
            "# Doc\n\n"
            "Apples grow on trees in the orchard during autumn.\n\n"
            "Submarines navigate beneath the ocean using sonar.\n\n"
            "Mountains rise above the clouds in the horizon.\n"
        )
        with tempfile.TemporaryDirectory() as tmp_name:
            root = Path(tmp_name)
            input_dir = root / "in"
            output_dir = root / "out"
            input_dir.mkdir()
            (input_dir / "doc.md").write_text(markdown, encoding="utf-8")

            summary = run_parser_benchmark(input_dir, output_dir, dataset_name="custom_folder")

            first_doc = summary["documents"][0]
            self.assertTrue(first_doc["retrieval_evaluated"])
            self.assertGreater(first_doc["retrieval_query_count"], 0)
            self.assertEqual(first_doc["retrieval_recall_at_1"], 1.0)
            self.assertGreaterEqual(summary["mean_retrieval_recall_at_1"], 0.0)
            self.assertEqual(summary["retrieval_evaluated_count"], 1)
            self.assertTrue((output_dir / "retrieval_runs.csv").exists())
if __name__ == "__main__":
    # Lets this module be run directly as a script, in addition to test discovery.
    unittest.main()
|