# zeroshotGPU/tests/test_retrieval.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring."""
from __future__ import annotations
import tempfile
import unittest
from pathlib import Path
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import (
LexicalRetriever,
RetrievalQuery,
run_retrieval_for_document,
synthesize_qa_set,
)
from zsgdp.schema import Chunk, ParsedDocument, QualityReport
from zsgdp.verify.retrieval import compute_retrieval_metrics
def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk:
    """Build a single-page prose ``Chunk`` fixture belonging to document ``d1``.

    ``page`` is used for both the start and end page, the section path is
    empty, and ``token_count`` is the number of whitespace-separated words
    in ``text``.
    """
    words = text.split()
    return Chunk(
        chunk_id=chunk_id,
        doc_id="d1",
        text=text,
        content_type="prose",
        section_path=[],
        page_start=page,
        page_end=page,
        token_count=len(words),
    )
class TestComputeRetrievalMetrics(unittest.TestCase):
    """Unit tests for ``compute_retrieval_metrics`` on hand-built rankings.

    Each input item is a ``(retrieved_ids, truth_ids)`` pair, where
    ``retrieved_ids`` is a ranked list of chunk ids.
    """

    def test_perfect_retrieval(self):
        result = compute_retrieval_metrics([
            (["c1"], ["c1"]),
            (["c2"], ["c2"]),
        ])
        self.assertEqual(result["recall_at_k"][1], 1.0)
        self.assertEqual(result["mean_reciprocal_rank"], 1.0)

    def test_truth_at_rank_three_yields_partial(self):
        result = compute_retrieval_metrics([(["x", "y", "c1", "z"], ["c1"])])
        self.assertEqual(result["recall_at_k"][1], 0.0)
        self.assertEqual(result["recall_at_k"][3], 1.0)
        # The first hit is at rank 3, so the reciprocal rank is 1/3.
        self.assertAlmostEqual(result["mean_reciprocal_rank"], 1 / 3)

    def test_no_hit_yields_zero_mrr(self):
        result = compute_retrieval_metrics([(["x", "y"], ["c1"])])
        self.assertEqual(result["mean_reciprocal_rank"], 0.0)
        self.assertEqual(result["recall_at_k"][5], 0.0)

    def test_citation_accuracy_mirrors_recall(self):
        # Mix a rank-1 hit with a rank-3 hit so recall differs across k.
        # An all-perfect input cannot tell "mirrors recall" apart from
        # "always 1.0".
        result = compute_retrieval_metrics([
            (["c1"], ["c1"]),
            (["x", "y", "c2"], ["c2"]),
        ])
        self.assertAlmostEqual(result["recall_at_k"][1], 0.5)
        for k in (1, 3, 5):
            with self.subTest(k=k):
                self.assertEqual(
                    result["citation_accuracy_at_k"][k],
                    result["recall_at_k"][k],
                )

    def test_empty_queries_are_vacuous(self):
        result = compute_retrieval_metrics([])
        self.assertEqual(result["query_count"], 0)
        # With no queries the metric is vacuously perfect.
        self.assertEqual(result["mean_reciprocal_rank"], 1.0)

    def test_empty_truth_sets_skipped(self):
        # A query with no ground-truth ids is not counted at all.
        result = compute_retrieval_metrics([(["c1"], [])])
        self.assertEqual(result["query_count"], 0)
class TestLexicalRetriever(unittest.TestCase):
    """Behavioural tests for the term-matching ``LexicalRetriever``."""

    def test_finds_distinctive_chunk(self):
        corpus = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        lexical = LexicalRetriever()
        lexical.index(corpus)
        hits = lexical.query("apples orchard", top_k=3)
        # Only c1 shares terms with the query, so it must rank first.
        self.assertEqual(hits[0], "c1")

    def test_query_text_with_no_indexed_terms_returns_empty(self):
        lexical = LexicalRetriever()
        lexical.index([_chunk("c1", "Apples grow on trees.")])
        hits = lexical.query("zzz qqq", top_k=3)
        self.assertEqual(hits, [])

    def test_empty_index_returns_empty(self):
        # Querying before anything has been indexed yields no ids.
        hits = LexicalRetriever().query("anything", top_k=3)
        self.assertEqual(hits, [])
class TestSynthesizeQASet(unittest.TestCase):
    """Tests for ``synthesize_qa_set`` query generation from chunk text."""

    @staticmethod
    def _markdown_doc(chunks):
        """Wrap ``chunks`` in a minimal markdown ``ParsedDocument`` for doc ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_picks_distinctive_sentence_per_chunk(self):
        parsed = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."),
            _chunk("c2", "Cars drive on highways across the country. Common shared sentence."),
        ])
        queries = synthesize_qa_set(parsed, min_sentence_tokens=3)
        first_truths = sorted(q.truths[0] for q in queries)
        self.assertEqual(first_truths, ["c1", "c2"])
        # The sentence present in both chunks must not appear in any query text.
        for q in queries:
            self.assertNotIn("Common shared", q.text)

    def test_skips_chunks_with_no_distinctive_sentences(self):
        duplicate = "Same sentence here."
        parsed = self._markdown_doc([_chunk("c1", duplicate), _chunk("c2", duplicate)])
        self.assertEqual(synthesize_qa_set(parsed, min_sentence_tokens=2), [])
class TestRunRetrievalForDocument(unittest.TestCase):
    """Tests for ``run_retrieval_for_document`` on small in-memory documents."""

    @staticmethod
    def _markdown_doc(chunks):
        """Wrap ``chunks`` in a minimal markdown ``ParsedDocument`` for doc ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_end_to_end_retrieval_on_synthetic_doc(self):
        parsed = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard during autumn season."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."),
            _chunk("c3", "Mountains rise above the clouds in the distant horizon."),
        ])
        run = run_retrieval_for_document(parsed, top_k=3)
        self.assertTrue(run["evaluated"])
        self.assertEqual(run["query_count"], 3)
        # Guard the loop below: an empty ``results`` list would otherwise
        # let the test pass without checking a single ranking.
        self.assertEqual(len(run["results"]), run["query_count"])
        for index, result in enumerate(run["results"]):
            truth = result["truths"][0]
            with self.subTest(query=index, truth=truth):
                # Verbatim retrieval should put the source chunk at rank 1.
                self.assertEqual(result["retrieved"][0], truth)

    def test_no_chunks_returns_unevaluated(self):
        parsed = self._markdown_doc([])
        run = run_retrieval_for_document(parsed)
        self.assertFalse(run["evaluated"])
        self.assertEqual(run["reason"], "no_chunks")

    def test_explicit_queries_override_synthesis(self):
        parsed = self._markdown_doc([
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways."),
        ])
        queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])]
        run = run_retrieval_for_document(parsed, queries=queries)
        self.assertTrue(run["evaluated"])
        # Exactly the one supplied query is run; nothing is synthesized.
        self.assertEqual(run["query_count"], 1)
        self.assertEqual(run["results"][0]["retrieved"][0], "c1")
class TestBenchmarkIntegration(unittest.TestCase):
    """End-to-end check that ``run_parser_benchmark`` reports retrieval metrics."""

    def test_retrieval_metrics_appear_in_summary(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp_dir = Path(tmp_name)
            src = tmp_dir / "in"
            out = tmp_dir / "out"
            src.mkdir()
            (src / "doc.md").write_text(
                "# Doc\n\n"
                "Apples grow on trees in the orchard during autumn.\n\n"
                "Submarines navigate beneath the ocean using sonar.\n\n"
                "Mountains rise above the clouds in the horizon.\n",
                encoding="utf-8",
            )
            summary = run_parser_benchmark(src, out, dataset_name="custom_folder")
            # One source file gives one document row. That also makes the
            # corpus-level mean below equal to this document's value.
            self.assertEqual(len(summary["documents"]), 1)
            doc = summary["documents"][0]
            self.assertTrue(doc["retrieval_evaluated"])
            self.assertGreater(doc["retrieval_query_count"], 0)
            self.assertEqual(doc["retrieval_recall_at_1"], 1.0)
            # With a single evaluated document at recall@1 == 1.0, the mean
            # must be 1.0 too. A ``>= 0.0`` bound could never fail.
            self.assertAlmostEqual(summary["mean_retrieval_recall_at_1"], 1.0)
            self.assertEqual(summary["retrieval_evaluated_count"], 1)
            self.assertTrue((out / "retrieval_runs.csv").exists())
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()