Spaces:
Running on Zero
Running on Zero
File size: 7,418 Bytes
"""Tests for the embedding-based retriever and the build_retriever factory."""
from __future__ import annotations
import math
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from zsgdp.benchmarks.embedding_retriever import (
EmbeddingRetriever,
build_retriever,
)
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document
from zsgdp.schema import Chunk, ParsedDocument, QualityReport
def _chunk(chunk_id: str, text: str) -> Chunk:
    """Build a one-page prose ``Chunk`` belonging to document ``d1``.

    The token count is the number of whitespace-separated words in ``text``.
    """
    word_count = len(text.split())
    fields = dict(
        chunk_id=chunk_id,
        doc_id="d1",
        page_start=1,
        page_end=1,
        section_path=[],
        content_type="prose",
        text=text,
        token_count=word_count,
    )
    return Chunk(**fields)
def _hashing_embedder(dim: int = 32):
    """Return a deterministic toy embedder that hashes tokens into ``dim`` buckets.

    Each text is lower-cased and split on whitespace; every token adds 1.0 to
    the bucket selected by its hash modulo ``dim``.

    Uses a process-stable hash (hashlib.md5) instead of builtins.hash(), which
    is randomized per Python process and would make ranking non-deterministic
    across test runs.

    Args:
        dim: Length of every produced vector; must be a positive integer.

    Returns:
        A callable that maps an iterable of strings to a list of ``dim``-length
        float lists, one per input text, in input order.

    Raises:
        ValueError: If ``dim`` is not positive.
    """
    import hashlib

    if dim <= 0:
        # Fail up front instead of with ZeroDivisionError / IndexError once
        # encode() meets its first token.
        raise ValueError(f"dim must be a positive integer, got {dim!r}")

    def stable_hash(token: str) -> int:
        # First 8 bytes of the MD5 digest read as a big-endian unsigned integer.
        return int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big")

    def encode(texts):
        out = []
        for text in texts:
            vector = [0.0] * dim
            for token in text.lower().split():
                vector[stable_hash(token) % dim] += 1.0
            out.append(vector)
        return out

    return encode
class TestEmbeddingRetriever(unittest.TestCase):
    """Exercises ``EmbeddingRetriever`` with injected embedders and the lazy-load path."""

    def test_finds_distinctive_chunk_with_injected_embedder(self):
        """The chunk sharing tokens with the query is ranked first."""
        chunks = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        retriever.index(chunks)
        ranking = retriever.query("apples orchard", top_k=3)
        self.assertEqual(ranking[0], "c1")

    def test_empty_index_returns_empty(self):
        """Querying before anything has been indexed returns an empty list."""
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_zero_norm_vector_skipped(self):
        """All-zero embeddings for both chunk and query yield an empty ranking."""

        def zero_embedder(texts):
            # One fresh list per text so the vectors never alias a single object.
            return [[0.0, 0.0, 0.0] for _ in texts]

        retriever = EmbeddingRetriever(embedder=zero_embedder)
        retriever.index([_chunk("c1", "anything")])
        # Query embedder also returns zero vector, normalization fails -> empty.
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_embedder_returning_wrong_count_raises(self):
        """Indexing raises RuntimeError when the embedder returns too few vectors."""

        def bad(texts):
            # Always returns one vector, regardless of how many texts came in.
            return [[1.0]]

        retriever = EmbeddingRetriever(embedder=bad)
        with self.assertRaises(RuntimeError):
            retriever.index([_chunk("c1", "a"), _chunk("c2", "b")])

    def test_lazy_load_path_raises_if_sentence_transformers_missing(self):
        """Without an injected embedder, a missing dependency surfaces as RuntimeError."""
        import sys

        retriever = EmbeddingRetriever(model_id="fake/model")
        # A ``None`` entry in sys.modules makes every import of that name fail
        # with ImportError, whether it goes through the ``import`` statement or
        # importlib.import_module, and regardless of the package being
        # installed. patch.dict restores sys.modules on exit.
        with patch.dict(sys.modules, {"sentence_transformers": None}):
            with self.assertRaises(RuntimeError) as ctx:
                retriever.index([_chunk("c1", "anything")])
        self.assertIn("sentence-transformers", str(ctx.exception))
class TestBuildRetriever(unittest.TestCase):
    """Checks which retriever ``build_retriever`` produces for a given config."""

    def test_default_returns_lexical(self):
        """An empty config yields a ``LexicalRetriever``."""
        self.assertIsInstance(build_retriever({}), LexicalRetriever)

    def test_explicit_lexical_backend(self):
        """Requesting the ``lexical`` backend yields a ``LexicalRetriever``."""
        retriever_cfg = {"backend": "lexical"}
        built = build_retriever({"benchmarks": {"retriever": retriever_cfg}})
        self.assertIsInstance(built, LexicalRetriever)

    def test_embedding_backend_uses_gpu_models_embedding_default(self):
        """With no retriever-level model, the gpu.models.embedding entry is used."""
        embedding_defaults = {"model_id": "custom/model", "task": "retrieval.query"}
        built = build_retriever(
            {
                "benchmarks": {"retriever": {"backend": "embedding"}},
                "gpu": {"models": {"embedding": embedding_defaults}},
            }
        )
        self.assertIsInstance(built, EmbeddingRetriever)
        # The test inspects the retriever's private fields directly.
        self.assertEqual(built._model_id, "custom/model")
        self.assertEqual(built._task, "retrieval.query")

    def test_explicit_model_id_overrides_gpu_default(self):
        """A model_id on the retriever config wins over the gpu.models entry."""
        retriever_cfg = {"backend": "embedding", "model_id": "explicit/model"}
        gpu_cfg = {"models": {"embedding": {"model_id": "ignored/model"}}}
        built = build_retriever({"benchmarks": {"retriever": retriever_cfg}, "gpu": gpu_cfg})
        self.assertEqual(built._model_id, "explicit/model")

    def test_unknown_backend_raises(self):
        """An unrecognised backend name raises ``ValueError``."""
        bad_config = {"benchmarks": {"retriever": {"backend": "magic"}}}
        self.assertRaises(ValueError, build_retriever, bad_config)
class TestRunRetrievalWithEmbedding(unittest.TestCase):
    """Runs ``run_retrieval_for_document`` with an ``EmbeddingRetriever`` plugged in."""

    def test_run_retrieval_for_document_accepts_embedding_retriever(self):
        """Each result's top retrieved id equals its first ground-truth id."""
        doc_chunks = [
            _chunk("c1", "Apples grow on trees in the orchard during autumn."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar."),
        ]
        document = ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=doc_chunks,
            quality_report=QualityReport(),
        )

        outcome = run_retrieval_for_document(
            document,
            retriever=EmbeddingRetriever(embedder=_hashing_embedder()),
        )

        self.assertTrue(outcome["evaluated"])
        self.assertGreater(outcome["query_count"], 0)
        for entry in outcome["results"]:
            expected_top = entry["truths"][0]
            self.assertEqual(entry["retrieved"][0], expected_top)
class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase):
    """Checks that ``run_parser_benchmark`` builds a retriever when config opts in."""

    def test_benchmark_uses_embedding_when_config_says_so(self):
        """With the embedding backend configured, build_retriever is called at least once."""
        # build_retriever is swapped for a stub that hands back a retriever with
        # the toy embedder, so the opt-in path runs without loading a real model.
        stub_retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        fake_config = {
            "benchmarks": {"retriever": {"backend": "embedding"}},
        }
        markdown = (
            "# Doc\n\n"
            "Apples grow on trees in the orchard during autumn season.\n\n"
            "Submarines navigate beneath the ocean using sonar pulses across waters.\n"
        )

        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            input_dir = workdir / "in"
            input_dir.mkdir()
            (input_dir / "doc.md").write_text(markdown, encoding="utf-8")

            # Patchers are only created here; they take effect inside the ``with``.
            config_patch = patch(
                "zsgdp.benchmarks.parser_quality.load_config",
                return_value=fake_config,
            )
            builder_patch = patch(
                "zsgdp.benchmarks.embedding_retriever.build_retriever",
                return_value=stub_retriever,
            )
            with config_patch, builder_patch as build_call:
                summary = run_parser_benchmark(
                    input_dir, workdir / "out", dataset_name="custom_folder"
                )

            self.assertGreaterEqual(build_call.call_count, 1)
            self.assertTrue(summary["documents"][0]["retrieval_evaluated"])
if __name__ == "__main__":
    # Lets the module be executed directly as a script, not only via a test runner.
    unittest.main()
|