explainer-env / tests /test_retrieval.py
kgdrathan's picture
Upload folder using huggingface_hub
8fa7af1 verified
"""Tests for the source-result -> chunk -> top-k retrieval pipeline."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import research.retrieval as retrieval
from research.retrieval import chunk_markdown, rank_chunks_for_query
from research.types import ResearchChunk
class FakeTextEmbedding:
    """Deterministic stand-in for an embedding model, used by the retrieval tests.

    Only ``embed`` is provided; the tests pass an instance as
    ``embedding_model``. Every text maps to one of two orthogonal unit vectors,
    so "relevant" and "irrelevant" texts are trivially separable.
    """

    # Substrings that mark a text as "relevant"; matched case-insensitively.
    _RELEVANT_MARKERS = ("backpropagation", "chain rule", "target")

    def __init__(self):
        # Every text passed to embed(), in call order, so tests can assert on it.
        self.seen_texts = []

    def embed(self, texts):
        """Yield one 2-d vector per input text.

        Texts containing any of ``_RELEVANT_MARKERS`` (case-insensitive) map to
        ``[1.0, 0.0]``; all other texts map to ``[0.0, 1.0]``.

        Args:
            texts: a single string or an iterable of strings. The input is
                materialized once so iterators/generators can be both recorded
                and embedded (iterating them twice would yield nothing).

        Yields:
            list[float]: the vector for each text, in input order.

        Note:
            This is a generator, so ``seen_texts`` is only updated once the
            result starts being iterated.
        """
        if isinstance(texts, str):
            # A bare string is one text, not a sequence of characters.
            texts = [texts]
        texts = list(texts)
        self.seen_texts.extend(texts)
        for text in texts:
            lower = text.lower()
            if any(marker in lower for marker in self._RELEVANT_MARKERS):
                yield [1.0, 0.0]
            else:
                yield [0.0, 1.0]
def _chunk(title: str, text: str) -> ResearchChunk:
    """Build a test ResearchChunk; only ``title`` and ``text`` vary per call."""
    fixed_fields = {
        "source": "test",
        "tool": "fetch_docs",
        "url": "https://example.test",
    }
    return ResearchChunk(title=title, text=text, **fixed_fields)
def test_chunking_then_top5_ranking():
    """Chunk a small markdown doc, pad with filler chunks, and check the top-5 ranking."""
    markdown = """
# Intro
general overview
# Chain Rule
backpropagation chain rule gradients neural network
# History
unrelated history
"""
    section_chunks = [
        _chunk(title, text) for title, text in chunk_markdown(markdown, "Fallback")
    ]
    # Eight irrelevant chunks so that top-5 truncation is actually exercised.
    filler_chunks = [
        _chunk(f"Filler {i}", f"unrelated filler {i}") for i in range(8)
    ]
    ranked = rank_chunks_for_query(
        "backpropagation",
        "chain rule gradients",
        section_chunks + filler_chunks,
        embedding_model=FakeTextEmbedding(),
    )

    assert len(ranked) == 5
    assert [item.rank for item in ranked] == list(range(1, 6))
    assert ranked[0].title == "Chain Rule"
def test_embedding_ranking_is_not_bm25():
    """The embedding match must outrank a chunk that only repeats the query term."""
    lexical_only = _chunk("Lexical Match", "query repeated query repeated lexical only")
    semantic = _chunk("Embedding Match", "target concept with less lexical overlap")
    noise = _chunk("Other", "unrelated content")

    ranked = rank_chunks_for_query(
        "query",
        "intent target",
        [lexical_only, semantic, noise],
        top_k=2,
        embedding_model=FakeTextEmbedding(),
    )

    assert len(ranked) == 2
    assert ranked[0].title == "Embedding Match"
def test_preload_embedding_model_warms_runtime():
    """preload_embedding_model() should embed the warmup text via ``_EMBEDDING_MODEL``."""
    stub = FakeTextEmbedding()
    saved_model = retrieval._EMBEDDING_MODEL
    retrieval._EMBEDDING_MODEL = stub
    try:
        retrieval.preload_embedding_model()
        assert stub.seen_texts == ["startup warmup"]
    finally:
        # Put the real module-level model back so other tests are unaffected.
        retrieval._EMBEDDING_MODEL = saved_model
if __name__ == "__main__":
    # Minimal runner so the file can be executed directly without pytest.
    import traceback

    tests = [
        test_chunking_then_top5_ranking,
        test_embedding_ranking_is_not_bm25,
        test_preload_embedding_model_warms_runtime,
    ]
    failed = []
    for test in tests:
        try:
            test()
        except Exception as exc:  # keep running the remaining tests
            failed.append(test.__name__)
            # repr() so bare `assert` failures still show "AssertionError()".
            print(f"FAIL: {test.__name__}: {exc!r}")
            traceback.print_exc()
    passed = len(tests) - len(failed)
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_retrieval ({passed}/{len(tests)})")
    # Non-zero exit status so scripts/CI notice failures.
    sys.exit(1 if failed else 0)