# zeroshotGPU/tests/test_embedding_retriever.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""Tests for the embedding-based retriever and the build_retriever factory."""
from __future__ import annotations
import math
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from zsgdp.benchmarks.embedding_retriever import (
EmbeddingRetriever,
build_retriever,
)
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document
from zsgdp.schema import Chunk, ParsedDocument, QualityReport
def _chunk(chunk_id: str, text: str) -> Chunk:
    """Build a one-page prose ``Chunk`` belonging to document ``d1``.

    ``token_count`` is the number of whitespace-separated words in ``text``.
    """
    word_count = len(text.split())
    fields = {
        "chunk_id": chunk_id,
        "doc_id": "d1",
        "page_start": 1,
        "page_end": 1,
        "section_path": [],
        "content_type": "prose",
        "text": text,
        "token_count": word_count,
    }
    return Chunk(**fields)
def _hashing_embedder(dim: int = 32):
"""Deterministic toy embedder: tokens hashed into a fixed-dim vector.
Uses a process-stable hash (hashlib.md5) instead of builtins.hash(), which
is randomized per Python process and would make ranking non-deterministic
across test runs.
"""
import hashlib
def stable_hash(token: str) -> int:
return int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big")
def encode(texts):
out = []
for text in texts:
vector = [0.0] * dim
for token in text.lower().split():
vector[stable_hash(token) % dim] += 1.0
out.append(vector)
return out
return encode
class TestEmbeddingRetriever(unittest.TestCase):
    """Tests for ``EmbeddingRetriever`` driven by injected embedder callables."""

    def test_finds_distinctive_chunk_with_injected_embedder(self):
        """The apple/orchard chunk is ranked first for an apple/orchard query."""
        chunks = [
            _chunk("c1", "Apples grow on trees in the orchard."),
            _chunk("c2", "Cars drive on highways across the country."),
            _chunk("c3", "Boats sail on rivers and oceans."),
        ]
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        retriever.index(chunks)
        ranking = retriever.query("apples orchard", top_k=3)
        self.assertEqual(ranking[0], "c1")

    def test_empty_index_returns_empty(self):
        """Querying before anything has been indexed returns an empty list."""
        retriever = EmbeddingRetriever(embedder=_hashing_embedder())
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_zero_norm_vector_skipped(self):
        """An embedder that only produces zero vectors yields no results."""
        retriever = EmbeddingRetriever(embedder=lambda texts: [[0.0, 0.0, 0.0]] * len(texts))
        retriever.index([_chunk("c1", "anything")])
        # The query is embedded to a zero vector as well, so the expected
        # ranking is empty.
        self.assertEqual(retriever.query("anything", top_k=3), [])

    def test_embedder_returning_wrong_count_raises(self):
        """Indexing fails when the embedder returns fewer vectors than texts."""

        def single_vector_embedder(texts):
            # Always returns exactly one vector, however many texts came in.
            return [[1.0]]

        retriever = EmbeddingRetriever(embedder=single_vector_embedder)
        with self.assertRaises(RuntimeError):
            retriever.index([_chunk("c1", "a"), _chunk("c2", "b")])

    def test_lazy_load_path_raises_if_sentence_transformers_missing(self):
        """Indexing without the optional dependency raises a helpful RuntimeError."""
        import sys

        retriever = EmbeddingRetriever(model_id="fake/model")
        # A ``None`` entry in sys.modules makes every import of that name raise
        # ModuleNotFoundError (an ImportError subclass). Unlike replacing
        # builtins.__import__, this leaves all other imports untouched and
        # also covers importlib.import_module().
        with patch.dict(sys.modules, {"sentence_transformers": None}):
            with self.assertRaises(RuntimeError) as ctx:
                retriever.index([_chunk("c1", "anything")])
        self.assertIn("sentence-transformers", str(ctx.exception))
class TestBuildRetriever(unittest.TestCase):
    """Checks how ``build_retriever`` turns a config dict into a retriever."""

    def test_default_returns_lexical(self):
        """An empty config produces a ``LexicalRetriever``."""
        self.assertIsInstance(build_retriever({}), LexicalRetriever)

    def test_explicit_lexical_backend(self):
        """Asking for the ``lexical`` backend produces a ``LexicalRetriever``."""
        retriever_section = {"backend": "lexical"}
        built = build_retriever({"benchmarks": {"retriever": retriever_section}})
        self.assertIsInstance(built, LexicalRetriever)

    def test_embedding_backend_uses_gpu_models_embedding_default(self):
        """Model id and task are taken from ``gpu.models.embedding``."""
        embedding_defaults = {"model_id": "custom/model", "task": "retrieval.query"}
        built = build_retriever(
            {
                "benchmarks": {"retriever": {"backend": "embedding"}},
                "gpu": {"models": {"embedding": embedding_defaults}},
            }
        )
        self.assertIsInstance(built, EmbeddingRetriever)
        self.assertEqual(built._model_id, "custom/model")
        self.assertEqual(built._task, "retrieval.query")

    def test_explicit_model_id_overrides_gpu_default(self):
        """A ``model_id`` in the retriever section wins over the GPU default."""
        retriever_section = {"backend": "embedding", "model_id": "explicit/model"}
        gpu_section = {"models": {"embedding": {"model_id": "ignored/model"}}}
        built = build_retriever(
            {"benchmarks": {"retriever": retriever_section}, "gpu": gpu_section}
        )
        self.assertEqual(built._model_id, "explicit/model")

    def test_unknown_backend_raises(self):
        """An unrecognised backend name raises ``ValueError``."""
        unknown_backend = {"benchmarks": {"retriever": {"backend": "magic"}}}
        with self.assertRaises(ValueError):
            build_retriever(unknown_backend)
class TestRunRetrievalWithEmbedding(unittest.TestCase):
    """Runs ``run_retrieval_for_document`` with an ``EmbeddingRetriever``."""

    def test_run_retrieval_for_document_accepts_embedding_retriever(self):
        """Every evaluated query retrieves its first truth chunk at rank one."""
        orchard_chunk = _chunk("c1", "Apples grow on trees in the orchard during autumn.")
        submarine_chunk = _chunk("c2", "Submarines navigate beneath the ocean using sonar.")
        parsed = ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.md",
            file_type="markdown",
            chunks=[orchard_chunk, submarine_chunk],
            quality_report=QualityReport(),
        )
        toy_retriever = EmbeddingRetriever(embedder=_hashing_embedder())

        run = run_retrieval_for_document(parsed, retriever=toy_retriever)

        self.assertTrue(run["evaluated"])
        self.assertGreater(run["query_count"], 0)
        for outcome in run["results"]:
            expected_id = outcome["truths"][0]
            top_hit = outcome["retrieved"][0]
            self.assertEqual(top_hit, expected_id)
class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase):
    """Covers the benchmark's opt-in path to the embedding retriever backend."""

    def test_benchmark_uses_embedding_when_config_says_so(self):
        """The benchmark calls ``build_retriever`` and evaluates retrieval."""
        # build_retriever is patched to hand back an EmbeddingRetriever wired
        # to the toy embedder, so the opt-in code path runs without loading a
        # real model.
        toy = EmbeddingRetriever(embedder=_hashing_embedder())
        opt_in_config = {
            "benchmarks": {"retriever": {"backend": "embedding"}},
        }
        with tempfile.TemporaryDirectory() as tmp_name:
            workdir = Path(tmp_name)
            src = workdir / "in"
            src.mkdir()
            markdown = (
                "# Doc\n\n"
                "Apples grow on trees in the orchard during autumn season.\n\n"
                "Submarines navigate beneath the ocean using sonar pulses across waters.\n"
            )
            (src / "doc.md").write_text(markdown, encoding="utf-8")

            config_patch = patch(
                "zsgdp.benchmarks.parser_quality.load_config",
                return_value=opt_in_config,
            )
            builder_patch = patch(
                "zsgdp.benchmarks.embedding_retriever.build_retriever",
                return_value=toy,
            )
            with config_patch, builder_patch as build_call:
                summary = run_parser_benchmark(src, workdir / "out", dataset_name="custom_folder")

            self.assertGreaterEqual(build_call.call_count, 1)
            self.assertTrue(summary["documents"][0]["retrieval_evaluated"])
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()