File size: 7,595 Bytes
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring."""

from __future__ import annotations

import tempfile
import unittest
from pathlib import Path

from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import (
    LexicalRetriever,
    RetrievalQuery,
    run_retrieval_for_document,
    synthesize_qa_set,
)
from zsgdp.schema import Chunk, ParsedDocument, QualityReport
from zsgdp.verify.retrieval import compute_retrieval_metrics


def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk:
    """Build a minimal prose ``Chunk`` belonging to document ``d1``.

    The chunk covers the single page ``page`` with an empty section path;
    its ``token_count`` is the number of whitespace-separated words in
    ``text``.
    """
    word_total = len(text.split())
    fields = dict(
        chunk_id=chunk_id,
        doc_id="d1",
        page_start=page,
        page_end=page,
        section_path=[],
        content_type="prose",
        text=text,
        token_count=word_total,
    )
    return Chunk(**fields)


class TestComputeRetrievalMetrics(unittest.TestCase):
    """Unit tests for ``compute_retrieval_metrics`` over (ranking, truths) pairs."""

    def test_perfect_retrieval(self):
        pairs = [(["c1"], ["c1"]), (["c2"], ["c2"])]
        metrics = compute_retrieval_metrics(pairs)
        self.assertEqual(metrics["recall_at_k"][1], 1.0)
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_truth_at_rank_three_yields_partial(self):
        ranking = ["x", "y", "c1", "z"]
        metrics = compute_retrieval_metrics([(ranking, ["c1"])])
        recall = metrics["recall_at_k"]
        self.assertEqual(recall[1], 0.0)
        self.assertEqual(recall[3], 1.0)
        # The only relevant chunk sits at rank 3, so the reciprocal rank is 1/3.
        self.assertAlmostEqual(metrics["mean_reciprocal_rank"], 1 / 3)

    def test_no_hit_yields_zero_mrr(self):
        ranking, truths = ["x", "y"], ["c1"]
        metrics = compute_retrieval_metrics([(ranking, truths)])
        self.assertEqual(metrics["mean_reciprocal_rank"], 0.0)
        self.assertEqual(metrics["recall_at_k"][5], 0.0)

    def test_citation_accuracy_mirrors_recall(self):
        metrics = compute_retrieval_metrics([(["c1"], ["c1"])])
        citation, recall = metrics["citation_accuracy_at_k"], metrics["recall_at_k"]
        self.assertEqual(citation[1], recall[1])

    def test_empty_queries_are_vacuous(self):
        metrics = compute_retrieval_metrics([])
        self.assertEqual(metrics["query_count"], 0)
        # With nothing to evaluate, MRR is reported as a vacuous 1.0.
        self.assertEqual(metrics["mean_reciprocal_rank"], 1.0)

    def test_empty_truth_sets_skipped(self):
        unanswerable = (["c1"], [])
        metrics = compute_retrieval_metrics([unanswerable])
        self.assertEqual(metrics["query_count"], 0)


class TestLexicalRetriever(unittest.TestCase):
    """Behaviour of ``LexicalRetriever`` indexing and term-based querying."""

    def test_finds_distinctive_chunk(self):
        corpus = {
            "c1": "Apples grow on trees in the orchard.",
            "c2": "Cars drive on highways across the country.",
            "c3": "Boats sail on rivers and oceans.",
        }
        chunks = [_chunk(cid, body) for cid, body in corpus.items()]
        retriever = LexicalRetriever()
        retriever.index(chunks)

        top = retriever.query("apples orchard", top_k=3)
        self.assertEqual(top[0], "c1")

    def test_query_text_with_no_indexed_terms_returns_empty(self):
        retriever = LexicalRetriever()
        retriever.index([_chunk("c1", "Apples grow on trees.")])
        hits = retriever.query("zzz qqq", top_k=3)
        self.assertEqual(hits, [])

    def test_empty_index_returns_empty(self):
        unindexed = LexicalRetriever()
        self.assertEqual(unindexed.query("anything", top_k=3), [])


class TestSynthesizeQASet(unittest.TestCase):
    """Checks which chunk sentences ``synthesize_qa_set`` turns into queries."""

    @staticmethod
    def _document(*chunks: Chunk) -> ParsedDocument:
        """Wrap ``chunks`` in a markdown ``ParsedDocument`` with id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=list(chunks),
            quality_report=QualityReport(),
        )

    def test_picks_distinctive_sentence_per_chunk(self):
        parsed = self._document(
            _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."),
            _chunk("c2", "Cars drive on highways across the country. Common shared sentence."),
        )
        queries = synthesize_qa_set(parsed, min_sentence_tokens=3)
        self.assertEqual(sorted(q.truths[0] for q in queries), ["c1", "c2"])
        # The sentence shared by both chunks is not distinctive, so it must
        # never be chosen as query text.
        for q in queries:
            self.assertNotIn("Common shared", q.text)

    def test_skips_chunks_with_no_distinctive_sentences(self):
        duplicate = "Same sentence here."
        parsed = self._document(_chunk("c1", duplicate), _chunk("c2", duplicate))
        self.assertEqual(synthesize_qa_set(parsed, min_sentence_tokens=2), [])


class TestRunRetrievalForDocument(unittest.TestCase):
    """End-to-end checks for ``run_retrieval_for_document``."""

    @staticmethod
    def _document(chunks: list[Chunk]) -> ParsedDocument:
        """Wrap ``chunks`` in a markdown ``ParsedDocument`` with id ``d1``."""
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=chunks,
            quality_report=QualityReport(),
        )

    def test_end_to_end_retrieval_on_synthetic_doc(self):
        parsed = self._document([
            _chunk("c1", "Apples grow on trees in the orchard during autumn season."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."),
            _chunk("c3", "Mountains rise above the clouds in the distant horizon."),
        ])
        run = run_retrieval_for_document(parsed, top_k=3)
        self.assertTrue(run["evaluated"])
        self.assertEqual(run["query_count"], 3)
        # Guard against a vacuous pass: an empty ``results`` list would
        # otherwise skip the per-result rank check below entirely.
        self.assertEqual(len(run["results"]), 3)
        for result in run["results"]:
            truth = result["truths"][0]
            with self.subTest(truth=truth):
                # Verbatim retrieval should put the source chunk at rank 1.
                self.assertEqual(result["retrieved"][0], truth)

    def test_no_chunks_returns_unevaluated(self):
        run = run_retrieval_for_document(self._document([]))
        self.assertFalse(run["evaluated"])
        self.assertEqual(run["reason"], "no_chunks")

    def test_explicit_queries_override_synthesis(self):
        parsed = self._document([
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways."),
        ])
        queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])]
        run = run_retrieval_for_document(parsed, queries=queries)
        # Only the single caller-supplied query is evaluated (no synthesis).
        self.assertEqual(run["query_count"], 1)
        self.assertEqual(run["results"][0]["retrieved"][0], "c1")


class TestBenchmarkIntegration(unittest.TestCase):
    """Checks that retrieval metrics are wired into the parser benchmark output."""

    def test_retrieval_metrics_appear_in_summary(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            root = Path(tmp_name)
            src = root / "in"
            out = root / "out"
            src.mkdir()
            (src / "doc.md").write_text(
                "# Doc\n\n"
                "Apples grow on trees in the orchard during autumn.\n\n"
                "Submarines navigate beneath the ocean using sonar.\n\n"
                "Mountains rise above the clouds in the horizon.\n",
                encoding="utf-8",
            )

            summary = run_parser_benchmark(src, out, dataset_name="custom_folder")

            doc = summary["documents"][0]
            self.assertTrue(doc["retrieval_evaluated"])
            self.assertGreater(doc["retrieval_query_count"], 0)
            self.assertEqual(doc["retrieval_recall_at_1"], 1.0)
            # The sole evaluated document scores 1.0, so the benchmark-wide
            # mean must be 1.0 as well; a ``>= 0.0`` check could never fail.
            self.assertAlmostEqual(summary["mean_retrieval_recall_at_1"], 1.0)
            self.assertEqual(summary["retrieval_evaluated_count"], 1)
            self.assertTrue((out / "retrieval_runs.csv").exists())


if __name__ == "__main__":
    # Allow running this file directly with the Python interpreter.
    unittest.main(module="__main__")