"""Tests for the source-result -> chunk -> top-k retrieval pipeline.""" import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) import research.retrieval as retrieval from research.retrieval import chunk_markdown, rank_chunks_for_query from research.types import ResearchChunk class FakeTextEmbedding: def __init__(self): self.seen_texts = [] def embed(self, texts): self.seen_texts.extend(texts) for text in texts: lower = text.lower() if "backpropagation" in lower or "chain rule" in lower or "target" in lower: yield [1.0, 0.0] else: yield [0.0, 1.0] def _chunk(title: str, text: str) -> ResearchChunk: return ResearchChunk( source="test", tool="fetch_docs", title=title, url="https://example.test", text=text, ) def test_chunking_then_top5_ranking(): docs = chunk_markdown( """ # Intro general overview # Chain Rule backpropagation chain rule gradients neural network # History unrelated history """, "Fallback", ) chunks = [_chunk(title, text) for title, text in docs] chunks.extend(_chunk(f"Filler {i}", f"unrelated filler {i}") for i in range(8)) ranked = rank_chunks_for_query( "backpropagation", "chain rule gradients", chunks, embedding_model=FakeTextEmbedding(), ) assert len(ranked) == 5 assert [chunk.rank for chunk in ranked] == [1, 2, 3, 4, 5] assert ranked[0].title == "Chain Rule" def test_embedding_ranking_is_not_bm25(): chunks = [ _chunk("Lexical Match", "query repeated query repeated lexical only"), _chunk("Embedding Match", "target concept with less lexical overlap"), _chunk("Other", "unrelated content"), ] ranked = rank_chunks_for_query( "query", "intent target", chunks, top_k=2, embedding_model=FakeTextEmbedding(), ) assert len(ranked) == 2 assert ranked[0].title == "Embedding Match" def test_preload_embedding_model_warms_runtime(): previous_model = retrieval._EMBEDDING_MODEL fake_model = FakeTextEmbedding() retrieval._EMBEDDING_MODEL = fake_model try: retrieval.preload_embedding_model() assert fake_model.seen_texts == ["startup warmup"] finally: retrieval._EMBEDDING_MODEL = previous_model if __name__ == "__main__": tests = [ test_chunking_then_top5_ranking, test_embedding_ranking_is_not_bm25, test_preload_embedding_model_warms_runtime, ] passed = 0 for test in tests: try: test() passed += 1 except Exception as exc: print(f"FAIL: {test.__name__}: {exc}") print(f"PASS: test_retrieval ({passed}/{len(tests)})")