# RAG_Eval / tests/test_dense_retriever.py
# Last commit fc20fed by Rom89823974978: "Updated metrics and tests"
import json
import numpy as np
import pytest
from pathlib import Path
from evaluation.retrievers.dense import DenseRetriever
from evaluation.retrievers.base import Context
class DummyIndex:
    """Canned stand-in for a loaded FAISS index.

    Reports three stored vectors and answers every search with the same
    three hits, whatever query vector or ``top_k`` is passed.
    """

    def __init__(self):
        self.ntotal = 3
        import faiss

        # Prefer inner-product if this faiss build exposes it; otherwise L2.
        l2_fallback = faiss.METRIC_L2
        self.metric_type = getattr(faiss, "METRIC_INNER_PRODUCT", l2_fallback)

    def search(self, vec, top_k):
        """Return fixed ``(distances, indices)`` arrays, each of shape (1, 3).

        Both ``vec`` and ``top_k`` are ignored.
        """
        scores = np.array([[0.2, 0.15, 0.05]])
        positions = np.array([[0, 1, 2]])
        return scores, positions
class DummyEmbedder:
    """Stand-in for ``sentence_transformers.SentenceTransformer``.

    ``encode`` always returns the same 4-dimensional float32 vector, so tests
    never load a real model.
    """

    def encode(self, texts, normalize_embeddings=False, **_ignored_kwargs):
        """Return a fixed float32 vector of shape ``(4,)``.

        ``texts`` and every option are ignored. ``normalize_embeddings`` is
        optional and extra keyword arguments (e.g. ``batch_size``,
        ``show_progress_bar``) are accepted, so the stub tolerates the same
        call shapes as the real ``SentenceTransformer.encode``.
        """
        # The concrete values do not matter: DummyIndex.search ignores the query vector.
        return np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")
@pytest.fixture(autouse=True)
def patch_faiss_and_transformer(monkeypatch):
    """Stub out FAISS index loading and the sentence-transformer model.

    Applied automatically to every test in this module:
    ``faiss.read_index`` returns a ``DummyIndex`` and ``SentenceTransformer``
    returns a ``DummyEmbedder``. ``monkeypatch`` restores the originals after
    each test.
    """
    import faiss
    import sentence_transformers

    import evaluation.retrievers.dense as dense_module

    def fake_read_index(*_args, **_kwargs):
        # faiss.read_index also takes an optional io_flags argument; accept anything.
        return DummyIndex()

    def fake_sentence_transformer(*_args, **_kwargs):
        return DummyEmbedder()

    monkeypatch.setattr(faiss, "read_index", fake_read_index)
    monkeypatch.setattr(
        sentence_transformers,
        "SentenceTransformer",
        fake_sentence_transformer,
    )
    # The retriever module was imported at the top of this file, before this
    # fixture runs. If it did ``from sentence_transformers import
    # SentenceTransformer``, it holds its own reference that the patch above
    # cannot reach, so patch that binding as well when it exists.
    if hasattr(dense_module, "SentenceTransformer"):
        monkeypatch.setattr(
            dense_module,
            "SentenceTransformer",
            fake_sentence_transformer,
        )
    yield
def test_dense_index_build_and_search(tmp_path):
    """Build a retriever over a three-document store, then search it.

    With the autouse fixture in place, searches are answered by
    ``DummyIndex``: indices ``[0, 1, 2]`` with distances
    ``[0.2, 0.15, 0.05]``.
    """
    docs = [
        {"id": 0, "text": "Doc zero"},
        {"id": 1, "text": "Doc one"},
        {"id": 2, "text": "Doc two"},
    ]
    expected_scores = [0.2, 0.15, 0.05]

    doc_store_path = tmp_path / "docs.jsonl"
    doc_store_path.write_text(
        "".join(json.dumps(doc) + "\n" for doc in docs),
        encoding="utf-8",
    )
    # tmp_path is a fresh directory per test, so this file cannot pre-exist.
    faiss_idx = tmp_path / "index.faiss"

    # Instantiate DenseRetriever; should write a real FAISS file to disk
    retriever = DenseRetriever(
        faiss_index=faiss_idx,
        doc_store=doc_store_path,
        model_name="dummy-model-name",
        device="cpu",
    )
    assert faiss_idx.exists()

    results = retriever.retrieve("any query", top_k=3)
    assert isinstance(results, list)
    assert len(results) == len(docs)
    for i, (ctx, doc, score) in enumerate(zip(results, docs, expected_scores)):
        assert isinstance(ctx, Context)
        assert ctx.id == str(i)
        assert ctx.score == pytest.approx(score, rel=1e-6)
        # Doc ids equal their line positions here, so hit i must carry docs[i]'s text
        # (a set-membership check would not catch a mis-mapped document).
        assert ctx.text == doc["text"]
def test_dense_retrieve_when_faiss_or_transformer_fails(monkeypatch, tmp_path):
    """``retrieve`` returns an empty list when the FAISS index cannot be loaded.

    Only the FAISS failure path is exercised; ``SentenceTransformer`` stays
    stubbed by the autouse fixture.
    """
    import faiss

    def failing_read_index(*_args, **_kwargs):
        # A generic Exception: retrieve() must degrade to [] on any load error.
        raise Exception("fail")

    # Overrides the autouse fixture's read_index stub for this test only.
    monkeypatch.setattr(faiss, "read_index", failing_read_index)

    doc_store_path = tmp_path / "docs.jsonl"
    doc_store_path.write_text('{"id":0,"text":"hello"}\n')
    # tmp_path is a fresh directory per test, so this file cannot pre-exist.
    faiss_idx = tmp_path / "index2.faiss"

    retriever = DenseRetriever(
        faiss_index=faiss_idx,
        doc_store=doc_store_path,
        model_name="dummy-model-name",
        device="cpu",
    )
    # Since index load failed, retrieve() must return []
    assert retriever.retrieve("whatever", top_k=5) == []