import json
from pathlib import Path

import numpy as np
import pytest

from evaluation.retrievers.base import Context
from evaluation.retrievers.dense import DenseRetriever
class DummyIndex:
    """Minimal stand-in for a FAISS index.

    Exposes only what the retriever under test touches: ``ntotal``,
    ``metric_type`` and ``search``.
    """

    def __init__(self):
        self.ntotal = 3
        # Imported lazily so merely defining this class never requires faiss.
        import faiss
        # Use inner-product if this faiss build exposes it, else fall back to L2.
        self.metric_type = getattr(faiss, "METRIC_INNER_PRODUCT", faiss.METRIC_L2)

    def search(self, vec, top_k):
        """Return canned (distances, indices), truncated to ``top_k``.

        Real FAISS never returns more than ``top_k`` hits per query, so the
        stub slices its fixed result rows accordingly; the query vector is
        ignored on purpose.
        """
        dists = np.array([[0.2, 0.15, 0.05]])
        idxs = np.array([[0, 1, 2]])
        return dists[:, :top_k], idxs[:, :top_k]
class DummyEmbedder:
    """Stand-in for ``SentenceTransformer`` that returns a fixed vector."""

    def encode(self, texts, normalize_embeddings=False, **kwargs):
        """Return a fixed 4-d float32 vector for any input.

        ``normalize_embeddings`` defaults like the real
        ``SentenceTransformer.encode`` keyword, and extra keyword arguments
        (e.g. ``batch_size``, ``convert_to_numpy``) are accepted and ignored
        so the stub survives the real call signature. The values themselves
        are irrelevant to the tests.
        """
        return np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")
@pytest.fixture
def patch_faiss_and_transformer(monkeypatch):
    """Pytest fixture that stubs out the heavy faiss / transformer deps.

    Replaces ``faiss.read_index`` with a factory returning :class:`DummyIndex`
    and ``sentence_transformers.SentenceTransformer`` with a factory returning
    :class:`DummyEmbedder`, then yields control to the test. ``monkeypatch``
    undoes both patches automatically at teardown.

    NOTE(review): without the ``@pytest.fixture`` decorator this was an
    ordinary generator function that no test ever consumed, so the patches
    never took effect.
    """
    import faiss
    monkeypatch.setattr(faiss, "read_index", lambda _: DummyIndex())

    import sentence_transformers
    monkeypatch.setattr(
        sentence_transformers,
        "SentenceTransformer",
        lambda *args, **kwargs: DummyEmbedder(),
    )
    yield
def test_dense_index_build_and_search(tmp_path):
    """DenseRetriever persists a FAISS index and returns scored contexts."""
    # Build a three-document JSONL doc store on disk.
    doc_store_path = tmp_path / "docs.jsonl"
    texts = ["Doc zero", "Doc one", "Doc two"]
    with doc_store_path.open("w") as f:
        f.writelines(
            json.dumps({"id": n, "text": t}) + "\n" for n, t in enumerate(texts)
        )

    faiss_idx = tmp_path / "index.faiss"
    if faiss_idx.exists():
        faiss_idx.unlink()

    # Constructing the retriever should write a real FAISS file to disk.
    retriever = DenseRetriever(
        faiss_index=faiss_idx,
        doc_store=doc_store_path,
        model_name="dummy-model-name",
        device="cpu",
    )
    assert faiss_idx.exists()

    results = retriever.retrieve("any query", top_k=3)
    assert isinstance(results, list)
    assert len(results) == 3

    # DummyIndex reports distances [0.2, 0.15, 0.05] for indices 0..2.
    expected_scores = [0.2, 0.15, 0.05]
    known_texts = set(texts)
    for i, ctx in enumerate(results):
        assert isinstance(ctx, Context)
        assert ctx.id == str(i)
        assert ctx.score == pytest.approx(expected_scores[i], rel=1e-6)
        # Each returned text must be resolved from the doc store.
        assert ctx.text in known_texts
def test_dense_retrieve_when_faiss_or_transformer_fails(monkeypatch, tmp_path):
    """retrieve() degrades to an empty list when the FAISS index cannot load."""
    import faiss

    def _boom(_path):
        raise Exception("fail")

    # Force every index load to fail.
    monkeypatch.setattr(faiss, "read_index", _boom)

    doc_store_path = tmp_path / "docs.jsonl"
    doc_store_path.write_text('{"id":0,"text":"hello"}\n')

    faiss_idx = tmp_path / "index2.faiss"
    if faiss_idx.exists():
        faiss_idx.unlink()

    retriever = DenseRetriever(
        faiss_index=faiss_idx,
        doc_store=doc_store_path,
        model_name="dummy-model-name",
        device="cpu",
    )

    # With no usable index, retrieval must fail soft and return nothing.
    assert retriever.retrieve("whatever", top_k=5) == []