Spaces:
Running
Running
| """Tests for the source-result -> chunk -> top-k retrieval pipeline.""" | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| import research.retrieval as retrieval | |
| from research.retrieval import chunk_markdown, rank_chunks_for_query | |
| from research.types import ResearchChunk | |
class FakeTextEmbedding:
    """Deterministic stand-in for the real text-embedding model.

    Texts mentioning backpropagation, the chain rule, or "target" embed to
    the unit vector [1.0, 0.0]; everything else embeds to [0.0, 1.0].
    Every text passed to ``embed`` is recorded in ``seen_texts`` so tests
    can assert what was embedded.
    """

    # Case-insensitive markers that classify a text as "on topic".
    _ON_TOPIC = ("backpropagation", "chain rule", "target")

    def __init__(self):
        # Running record of every text handed to embed(), in call order.
        self.seen_texts = []

    def embed(self, texts):
        """Yield one 2-d vector per input text (a generator, like the real model)."""
        self.seen_texts.extend(texts)
        for item in texts:
            folded = item.lower()
            if any(marker in folded for marker in self._ON_TOPIC):
                yield [1.0, 0.0]
            else:
                yield [0.0, 1.0]
def _chunk(title: str, text: str) -> ResearchChunk:
    """Build a ResearchChunk carrying fixed test metadata around *title*/*text*."""
    fields = {
        "source": "test",
        "tool": "fetch_docs",
        "title": title,
        "url": "https://example.test",
        "text": text,
    }
    return ResearchChunk(**fields)
def test_chunking_then_top5_ranking():
    """Markdown splits per heading and ranking returns exactly five ranked chunks."""
    markdown = """
# Intro
general overview
# Chain Rule
backpropagation chain rule gradients neural network
# History
unrelated history
"""
    sections = chunk_markdown(markdown, "Fallback")
    pool = [_chunk(heading, body) for heading, body in sections]
    # Pad the pool with noise so top-k selection actually has to discriminate.
    for i in range(8):
        pool.append(_chunk(f"Filler {i}", f"unrelated filler {i}"))
    ranked = rank_chunks_for_query(
        "backpropagation",
        "chain rule gradients",
        pool,
        embedding_model=FakeTextEmbedding(),
    )
    assert len(ranked) == 5
    assert [item.rank for item in ranked] == [1, 2, 3, 4, 5]
    assert ranked[0].title == "Chain Rule"
def test_embedding_ranking_is_not_bm25():
    """A semantic (embedding) match must outrank a purely lexical repeat-match."""
    specs = [
        ("Lexical Match", "query repeated query repeated lexical only"),
        ("Embedding Match", "target concept with less lexical overlap"),
        ("Other", "unrelated content"),
    ]
    candidates = [_chunk(title, body) for title, body in specs]
    top = rank_chunks_for_query(
        "query",
        "intent target",
        candidates,
        top_k=2,
        embedding_model=FakeTextEmbedding(),
    )
    assert len(top) == 2
    assert top[0].title == "Embedding Match"
def test_preload_embedding_model_warms_runtime():
    """preload_embedding_model() must run exactly one warmup embed on the cached model."""
    fake = FakeTextEmbedding()
    saved = retrieval._EMBEDDING_MODEL
    retrieval._EMBEDDING_MODEL = fake
    try:
        retrieval.preload_embedding_model()
        assert fake.seen_texts == ["startup warmup"]
    finally:
        # Always restore the module-level cache so other tests see the original state.
        retrieval._EMBEDDING_MODEL = saved
if __name__ == "__main__":
    # Minimal runner so the file can be executed directly without pytest.
    tests = [
        test_chunking_then_top5_ranking,
        test_embedding_ranking_is_not_bm25,
        test_preload_embedding_model_warms_runtime,
    ]
    passed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except Exception as exc:  # broad on purpose: report every failure, keep running
            print(f"FAIL: {test.__name__}: {exc}")
    # Bug fix: the summary previously printed "PASS" unconditionally, even when
    # one or more tests failed, and the process always exited 0.
    status = "PASS" if passed == len(tests) else "FAIL"
    print(f"{status}: test_retrieval ({passed}/{len(tests)})")
    if passed != len(tests):
        sys.exit(1)  # nonzero exit so CI and shell scripts notice the failure