Spaces:
Running on Zero
Running on Zero
| """Tests for the embedding-based retriever and the build_retriever factory.""" | |
| from __future__ import annotations | |
| import math | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| from unittest.mock import patch | |
| from zsgdp.benchmarks.embedding_retriever import ( | |
| EmbeddingRetriever, | |
| build_retriever, | |
| ) | |
| from zsgdp.benchmarks.parser_quality import run_parser_benchmark | |
| from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document | |
| from zsgdp.schema import Chunk, ParsedDocument, QualityReport | |
def _chunk(chunk_id: str, text: str) -> Chunk:
    """Build a minimal prose ``Chunk`` fixture for these tests.

    Every chunk belongs to document ``"d1"``, sits on page 1, has an empty
    section path, and reports a token count equal to the number of
    whitespace-separated words in ``text``.
    """
    word_count = len(text.split())
    return Chunk(
        chunk_id=chunk_id,
        doc_id="d1",
        page_start=1,
        page_end=1,
        section_path=[],
        content_type="prose",
        text=text,
        token_count=word_count,
    )
| def _hashing_embedder(dim: int = 32): | |
| """Deterministic toy embedder: tokens hashed into a fixed-dim vector. | |
| Uses a process-stable hash (hashlib.md5) instead of builtins.hash(), which | |
| is randomized per Python process and would make ranking non-deterministic | |
| across test runs. | |
| """ | |
| import hashlib | |
| def stable_hash(token: str) -> int: | |
| return int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big") | |
| def encode(texts): | |
| out = [] | |
| for text in texts: | |
| vector = [0.0] * dim | |
| for token in text.lower().split(): | |
| vector[stable_hash(token) % dim] += 1.0 | |
| out.append(vector) | |
| return out | |
| return encode | |
class TestEmbeddingRetriever(unittest.TestCase):
    """Tests for ``EmbeddingRetriever`` driven by injected, model-free embedders."""

    def test_finds_distinctive_chunk_with_injected_embedder(self):
        """A query sharing a token with exactly one chunk ranks that chunk first."""
        chunks = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        retriever.index(chunks)
        ranking = retriever.query("apples orchard", top_k=3)
        self.assertEqual(ranking[0], "c1")

    def test_empty_index_returns_empty(self):
        """Querying before anything is indexed yields an empty ranking."""
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_zero_norm_vector_skipped(self):
        """All-zero embeddings produce no results instead of an error."""
        retriever = EmbeddingRetriever(embedder=lambda texts: [[0.0, 0.0, 0.0]] * len(texts))
        retriever.index([_chunk("c1", "anything")])
        # Query embedder also returns zero vector, normalization fails -> empty.
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_embedder_returning_wrong_count_raises(self):
        """Indexing fails when the embedder returns fewer vectors than texts."""

        def bad(texts):
            # Always returns one vector, regardless of how many texts came in.
            return [[1.0]]

        retriever = EmbeddingRetriever(embedder=bad)
        with self.assertRaises(RuntimeError):
            retriever.index([_chunk("c1", "a"), _chunk("c2", "b")])

    def test_lazy_load_path_raises_if_sentence_transformers_missing(self):
        """Without an injected embedder, a missing sentence-transformers is reported."""
        import sys

        retriever = EmbeddingRetriever(model_id="fake/model")
        # A None entry in sys.modules makes any import of that name raise
        # ModuleNotFoundError (a subclass of ImportError). Unlike replacing
        # builtins.__import__, this leaves every other import untouched and
        # also covers importlib.import_module and submodule imports.
        with patch.dict(sys.modules, {"sentence_transformers": None}):
            with self.assertRaises(RuntimeError) as ctx:
                retriever.index([_chunk("c1", "anything")])
        self.assertIn("sentence-transformers", str(ctx.exception))
class TestBuildRetriever(unittest.TestCase):
    """Tests for which retriever ``build_retriever`` returns for a given config."""

    def test_default_returns_lexical(self):
        """An empty config yields a ``LexicalRetriever``."""
        self.assertIsInstance(build_retriever({}), LexicalRetriever)

    def test_explicit_lexical_backend(self):
        """Asking for the "lexical" backend yields a ``LexicalRetriever``."""
        lexical_config = {"benchmarks": {"retriever": {"backend": "lexical"}}}
        self.assertIsInstance(build_retriever(lexical_config), LexicalRetriever)

    def test_embedding_backend_uses_gpu_models_embedding_default(self):
        """With no retriever-level model, gpu.models.embedding settings are used."""
        gpu_embedding = {"model_id": "custom/model", "task": "retrieval.query"}
        built = build_retriever(
            {
                "benchmarks": {"retriever": {"backend": "embedding"}},
                "gpu": {"models": {"embedding": gpu_embedding}},
            }
        )
        self.assertIsInstance(built, EmbeddingRetriever)
        self.assertEqual(built._model_id, "custom/model")
        self.assertEqual(built._task, "retrieval.query")

    def test_explicit_model_id_overrides_gpu_default(self):
        """A model_id on the retriever config wins over the gpu default."""
        retriever_section = {"backend": "embedding", "model_id": "explicit/model"}
        built = build_retriever(
            {
                "benchmarks": {"retriever": retriever_section},
                "gpu": {"models": {"embedding": {"model_id": "ignored/model"}}},
            }
        )
        self.assertEqual(built._model_id, "explicit/model")

    def test_unknown_backend_raises(self):
        """An unrecognised backend name is rejected with ``ValueError``."""
        unknown_config = {"benchmarks": {"retriever": {"backend": "magic"}}}
        self.assertRaises(ValueError, build_retriever, unknown_config)
class TestRunRetrievalWithEmbedding(unittest.TestCase):
    """Tests ``run_retrieval_for_document`` with an ``EmbeddingRetriever`` supplied."""

    def test_run_retrieval_for_document_accepts_embedding_retriever(self):
        """Every evaluated query retrieves its first ground-truth chunk at rank 1."""
        fixture_chunks = [
            _chunk("c1", "Apples grow on trees in the orchard during autumn."),
            _chunk("c2", "Submarines navigate beneath the ocean using sonar."),
        ]
        document = ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=fixture_chunks,
            quality_report=QualityReport(),
        )
        toy_retriever = EmbeddingRetriever(embedder=_hashing_embedder())

        outcome = run_retrieval_for_document(document, retriever=toy_retriever)

        self.assertTrue(outcome["evaluated"])
        self.assertGreater(outcome["query_count"], 0)
        for entry in outcome["results"]:
            expected_top = entry["truths"][0]
            self.assertEqual(entry["retrieved"][0], expected_top)
class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase):
    """Tests that the parser benchmark uses the configured embedding backend."""

    def test_benchmark_uses_embedding_when_config_says_so(self):
        """With backend "embedding" configured, the benchmark calls build_retriever."""
        # build_retriever is replaced so the benchmark gets an EmbeddingRetriever
        # backed by the toy embedder and never has to load a real model.
        toy = EmbeddingRetriever(embedder=_hashing_embedder())
        config_patcher = patch(
            "zsgdp.benchmarks.parser_quality.load_config",
            return_value={
                "benchmarks": {"retriever": {"backend": "embedding"}},
            },
        )
        # NOTE(review): patched on its defining module; presumably the benchmark
        # resolves build_retriever from there at call time - confirm.
        factory_patcher = patch(
            "zsgdp.benchmarks.embedding_retriever.build_retriever",
            return_value=toy,
        )

        with tempfile.TemporaryDirectory() as tmp_name:
            workdir = Path(tmp_name)
            src = workdir / "in"
            src.mkdir()
            doc_text = (
                "# Doc\n\n"
                "Apples grow on trees in the orchard during autumn season.\n\n"
                "Submarines navigate beneath the ocean using sonar pulses across waters.\n"
            )
            (src / "doc.md").write_text(doc_text, encoding="utf-8")

            with config_patcher, factory_patcher as build_call:
                summary = run_parser_benchmark(
                    src, workdir / "out", dataset_name="custom_folder"
                )

            self.assertGreaterEqual(build_call.call_count, 1)
            self.assertTrue(summary["documents"][0]["retrieval_evaluated"])
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()