"""Tests for the embedding-based retriever and the build_retriever factory.""" from __future__ import annotations import math import tempfile import unittest from pathlib import Path from unittest.mock import patch from zsgdp.benchmarks.embedding_retriever import ( EmbeddingRetriever, build_retriever, ) from zsgdp.benchmarks.parser_quality import run_parser_benchmark from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document from zsgdp.schema import Chunk, ParsedDocument, QualityReport def _chunk(chunk_id: str, text: str) -> Chunk: return Chunk( chunk_id=chunk_id, doc_id="d1", page_start=1, page_end=1, section_path=[], content_type="prose", text=text, token_count=len(text.split()), ) def _hashing_embedder(dim: int = 32): """Deterministic toy embedder: tokens hashed into a fixed-dim vector. Uses a process-stable hash (hashlib.md5) instead of builtins.hash(), which is randomized per Python process and would make ranking non-deterministic across test runs. """ import hashlib def stable_hash(token: str) -> int: return int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big") def encode(texts): out = [] for text in texts: vector = [0.0] * dim for token in text.lower().split(): vector[stable_hash(token) % dim] += 1.0 out.append(vector) return out return encode class TestEmbeddingRetriever(unittest.TestCase): def test_finds_distinctive_chunk_with_injected_embedder(self): chunks = [ _chunk("c1", "Apples grow on trees in the orchard."), _chunk("c2", "Cars drive on highways across the country."), _chunk("c3", "Boats sail on rivers and oceans."), ] retriever = EmbeddingRetriever(embedder=_hashing_embedder()) retriever.index(chunks) ranking = retriever.query("apples orchard", top_k=3) self.assertEqual(ranking[0], "c1") def test_empty_index_returns_empty(self): retriever = EmbeddingRetriever(embedder=_hashing_embedder()) self.assertEqual(retriever.query("anything", top_k=3), []) def test_zero_norm_vector_skipped(self): retriever = EmbeddingRetriever(embedder=lambda texts: [[0.0, 0.0, 0.0]] * len(texts)) retriever.index([_chunk("c1", "anything")]) # Query embedder also returns zero vector, normalization fails -> empty. self.assertEqual(retriever.query("anything", top_k=3), []) def test_embedder_returning_wrong_count_raises(self): bad = lambda texts: [[1.0]] # always returns one vector retriever = EmbeddingRetriever(embedder=bad) with self.assertRaises(RuntimeError): retriever.index([_chunk("c1", "a"), _chunk("c2", "b")]) def test_lazy_load_path_raises_if_sentence_transformers_missing(self): retriever = EmbeddingRetriever(model_id="fake/model") # Force the import to fail by patching builtins.__import__. import builtins real_import = builtins.__import__ def fake_import(name, *args, **kwargs): if name == "sentence_transformers": raise ImportError("not installed") return real_import(name, *args, **kwargs) with patch("builtins.__import__", side_effect=fake_import): with self.assertRaises(RuntimeError) as ctx: retriever.index([_chunk("c1", "anything")]) self.assertIn("sentence-transformers", str(ctx.exception)) class TestBuildRetriever(unittest.TestCase): def test_default_returns_lexical(self): retriever = build_retriever({}) self.assertIsInstance(retriever, LexicalRetriever) def test_explicit_lexical_backend(self): retriever = build_retriever({"benchmarks": {"retriever": {"backend": "lexical"}}}) self.assertIsInstance(retriever, LexicalRetriever) def test_embedding_backend_uses_gpu_models_embedding_default(self): config = { "benchmarks": {"retriever": {"backend": "embedding"}}, "gpu": {"models": {"embedding": {"model_id": "custom/model", "task": "retrieval.query"}}}, } retriever = build_retriever(config) self.assertIsInstance(retriever, EmbeddingRetriever) self.assertEqual(retriever._model_id, "custom/model") self.assertEqual(retriever._task, "retrieval.query") def test_explicit_model_id_overrides_gpu_default(self): config = { "benchmarks": {"retriever": {"backend": "embedding", "model_id": "explicit/model"}}, "gpu": {"models": {"embedding": {"model_id": "ignored/model"}}}, } retriever = build_retriever(config) self.assertEqual(retriever._model_id, "explicit/model") def test_unknown_backend_raises(self): with self.assertRaises(ValueError): build_retriever({"benchmarks": {"retriever": {"backend": "magic"}}}) class TestRunRetrievalWithEmbedding(unittest.TestCase): def test_run_retrieval_for_document_accepts_embedding_retriever(self): parsed = ParsedDocument( doc_id="d1", source_path="/tmp/d1.md", file_type="markdown", chunks=[ _chunk("c1", "Apples grow on trees in the orchard during autumn."), _chunk("c2", "Submarines navigate beneath the ocean using sonar."), ], quality_report=QualityReport(), ) retriever = EmbeddingRetriever(embedder=_hashing_embedder()) run = run_retrieval_for_document(parsed, retriever=retriever) self.assertTrue(run["evaluated"]) self.assertGreater(run["query_count"], 0) for result in run["results"]: truth = result["truths"][0] self.assertEqual(result["retrieved"][0], truth) class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase): def test_benchmark_uses_embedding_when_config_says_so(self): # Patch build_retriever to return an EmbeddingRetriever with our toy embedder # so the benchmark exercises the opt-in code path without loading a real model. toy = EmbeddingRetriever(embedder=_hashing_embedder()) with tempfile.TemporaryDirectory() as tmp: tmp = Path(tmp) src = tmp / "in" src.mkdir() (src / "doc.md").write_text( "# Doc\n\n" "Apples grow on trees in the orchard during autumn season.\n\n" "Submarines navigate beneath the ocean using sonar pulses across waters.\n", encoding="utf-8", ) with patch("zsgdp.benchmarks.parser_quality.load_config") as load_config: load_config.return_value = { "benchmarks": {"retriever": {"backend": "embedding"}}, } with patch( "zsgdp.benchmarks.embedding_retriever.build_retriever", return_value=toy, ) as build_call: summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") self.assertGreaterEqual(build_call.call_count, 1) self.assertTrue(summary["documents"][0]["retrieval_evaluated"]) if __name__ == "__main__": unittest.main()