File size: 7,418 Bytes
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Tests for the embedding-based retriever and the build_retriever factory."""

from __future__ import annotations

import math
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from zsgdp.benchmarks.embedding_retriever import (
    EmbeddingRetriever,
    build_retriever,
)
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document
from zsgdp.schema import Chunk, ParsedDocument, QualityReport


def _chunk(chunk_id: str, text: str) -> Chunk:
    """Build a single-page prose ``Chunk`` fixture belonging to document ``"d1"``.

    Args:
        chunk_id: Identifier assigned to the chunk.
        text: Chunk body; its whitespace-separated word count becomes
            ``token_count``.

    Returns:
        A ``Chunk`` spanning page 1 only, with an empty section path.
    """
    word_count = len(text.split())
    fields = {
        "chunk_id": chunk_id,
        "doc_id": "d1",
        "page_start": 1,
        "page_end": 1,
        "section_path": [],
        "content_type": "prose",
        "text": text,
        "token_count": word_count,
    }
    return Chunk(**fields)


def _hashing_embedder(dim: int = 32):
    """Return a deterministic bag-of-words embedder with ``dim`` buckets.

    Each text is lower-cased and split on whitespace; every token increments
    the bucket selected by its MD5 digest. MD5 is used rather than the
    built-in ``hash()`` because the latter is salted per interpreter process,
    which would make rankings differ from one test run to the next.

    Args:
        dim: Length of every vector the embedder produces.

    Returns:
        A callable mapping an iterable of strings to a list of
        ``dim``-length float vectors, one per input text.
    """

    import hashlib

    def bucket_of(word: str) -> int:
        # The first 8 digest bytes, read big-endian, give a process-stable integer.
        digest = hashlib.md5(word.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % dim

    def embed_one(text):
        counts = [0.0] * dim
        for word in text.lower().split():
            counts[bucket_of(word)] += 1.0
        return counts

    def encode(texts):
        return [embed_one(text) for text in texts]

    return encode


class TestEmbeddingRetriever(unittest.TestCase):
    """Tests for ``EmbeddingRetriever`` using injected embedders and the lazy-load path."""

    def test_finds_distinctive_chunk_with_injected_embedder(self):
        """The top hit for "apples orchard" is the chunk about apples."""
        corpus = (
            ("c1", "Apples grow on trees in the orchard."),
            ("c2", "Cars drive on highways across the country."),
            ("c3", "Boats sail on rivers and oceans."),
        )
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        retriever.index([_chunk(chunk_id, body) for chunk_id, body in corpus])

        hits = retriever.query("apples orchard", top_k=3)
        self.assertEqual(hits[0], "c1")

    def test_empty_index_returns_empty(self):
        """Querying before anything was indexed yields an empty ranking."""
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        hits = retriever.query("anything", top_k=3)
        self.assertEqual(hits, [])

    def test_zero_norm_vector_skipped(self):
        """An embedder that only emits all-zero vectors yields an empty ranking."""

        def zero_embedder(texts):
            return [[0.0, 0.0, 0.0]] * len(texts)

        retriever = EmbeddingRetriever(embedder=zero_embedder)
        retriever.index([_chunk("c1", "anything")])
        # The query goes through the same all-zero embedder, so nothing is expected back.
        hits = retriever.query("anything", top_k=3)
        self.assertEqual(hits, [])

    def test_embedder_returning_wrong_count_raises(self):
        """Indexing two chunks with an embedder that returns one vector raises."""

        def single_vector_embedder(texts):
            # Always one vector, regardless of how many texts were passed in.
            return [[1.0]]

        retriever = EmbeddingRetriever(embedder=single_vector_embedder)
        two_chunks = [_chunk("c1", "a"), _chunk("c2", "b")]
        self.assertRaises(RuntimeError, retriever.index, two_chunks)

    def test_lazy_load_path_raises_if_sentence_transformers_missing(self):
        """``index`` raises a RuntimeError naming sentence-transformers when its import fails."""
        import builtins

        retriever = EmbeddingRetriever(model_id="fake/model")
        original_import = builtins.__import__

        def import_without_sentence_transformers(name, *args, **kwargs):
            # Every import except the one under test is delegated to the real importer.
            if name != "sentence_transformers":
                return original_import(name, *args, **kwargs)
            raise ImportError("not installed")

        with patch("builtins.__import__", side_effect=import_without_sentence_transformers):
            with self.assertRaises(RuntimeError) as ctx:
                retriever.index([_chunk("c1", "anything")])
            message = str(ctx.exception)
            self.assertIn("sentence-transformers", message)


class TestBuildRetriever(unittest.TestCase):
    """Tests for backend selection and model resolution in ``build_retriever``."""

    def test_default_returns_lexical(self):
        """An empty config produces a ``LexicalRetriever``."""
        self.assertIsInstance(build_retriever({}), LexicalRetriever)

    def test_explicit_lexical_backend(self):
        """Requesting the "lexical" backend produces a ``LexicalRetriever``."""
        settings = {"benchmarks": {"retriever": {"backend": "lexical"}}}
        self.assertIsInstance(build_retriever(settings), LexicalRetriever)

    def test_embedding_backend_uses_gpu_models_embedding_default(self):
        """Without an explicit model, the ``gpu.models.embedding`` entry is used."""
        gpu_embedding = {"model_id": "custom/model", "task": "retrieval.query"}
        settings = {
            "benchmarks": {"retriever": {"backend": "embedding"}},
            "gpu": {"models": {"embedding": gpu_embedding}},
        }
        built = build_retriever(settings)
        self.assertIsInstance(built, EmbeddingRetriever)
        self.assertEqual(built._model_id, "custom/model")
        self.assertEqual(built._task, "retrieval.query")

    def test_explicit_model_id_overrides_gpu_default(self):
        """A ``model_id`` under ``benchmarks.retriever`` wins over the GPU default."""
        retriever_section = {"backend": "embedding", "model_id": "explicit/model"}
        settings = {
            "benchmarks": {"retriever": retriever_section},
            "gpu": {"models": {"embedding": {"model_id": "ignored/model"}}},
        }
        built = build_retriever(settings)
        self.assertEqual(built._model_id, "explicit/model")

    def test_unknown_backend_raises(self):
        """An unrecognised backend name raises ``ValueError``."""
        settings = {"benchmarks": {"retriever": {"backend": "magic"}}}
        self.assertRaises(ValueError, build_retriever, settings)


class TestRunRetrievalWithEmbedding(unittest.TestCase):
    """Checks that ``run_retrieval_for_document`` works with an ``EmbeddingRetriever``."""

    def test_run_retrieval_for_document_accepts_embedding_retriever(self):
        """Every evaluated query ranks its first ground-truth chunk at the top."""
        chunk_texts = (
            ("c1", "Apples grow on trees in the orchard during autumn."),
            ("c2", "Submarines navigate beneath the ocean using sonar."),
        )
        document = ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=[_chunk(chunk_id, body) for chunk_id, body in chunk_texts],
            quality_report=QualityReport(),
        )
        embedding_retriever = EmbeddingRetriever(embedder=_hashing_embedder())

        outcome = run_retrieval_for_document(document, retriever=embedding_retriever)

        self.assertTrue(outcome["evaluated"])
        self.assertGreater(outcome["query_count"], 0)
        for entry in outcome["results"]:
            expected_top = entry["truths"][0]
            actual_top = entry["retrieved"][0]
            self.assertEqual(actual_top, expected_top)


class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase):
    """Checks that the parser benchmark builds a retriever when the config opts in."""

    def test_benchmark_uses_embedding_when_config_says_so(self):
        """With the embedding backend configured, ``build_retriever`` is called."""
        # build_retriever is patched to hand back a retriever wired to the toy
        # embedder, so the opt-in path runs without loading a real model.
        toy_retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        opt_in_config = {
            "benchmarks": {"retriever": {"backend": "embedding"}},
        }
        markdown = (
            "# Doc\n\n"
            "Apples grow on trees in the orchard during autumn season.\n\n"
            "Submarines navigate beneath the ocean using sonar pulses across waters.\n"
        )

        with tempfile.TemporaryDirectory() as tmp_name:
            workdir = Path(tmp_name)
            input_dir = workdir / "in"
            input_dir.mkdir()
            (input_dir / "doc.md").write_text(markdown, encoding="utf-8")

            with patch(
                "zsgdp.benchmarks.parser_quality.load_config",
                return_value=opt_in_config,
            ), patch(
                "zsgdp.benchmarks.embedding_retriever.build_retriever",
                return_value=toy_retriever,
            ) as build_call:
                summary = run_parser_benchmark(
                    input_dir, workdir / "out", dataset_name="custom_folder"
                )

            self.assertGreaterEqual(build_call.call_count, 1)
            first_document = summary["documents"][0]
            self.assertTrue(first_document["retrieval_evaluated"])


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()