# RAG_Eval / tests/test_hybrid_retriever.py
# Last commit: 549e0c8 by Rom89823974978 ("forgot the c")
from pathlib import Path
from typing import Optional

import pytest

from evaluation.retrievers.base import Context
from evaluation.retrievers.hybrid import HybridRetriever
class DummyBM25:
    """Fake ``BM25Retriever`` that always returns the same two contexts.

    The constructor parameters presumably mirror what ``HybridRetriever``
    forwards to ``BM25Retriever`` (TODO confirm against hybrid.py). They are
    accepted and ignored, so no index is ever read.
    """

    def __init__(self, bm25_idx: str, doc_store: str):
        # Intentionally a no-op: the fake holds no state.
        pass

    def retrieve(self, query: str, top_k: int):
        """Return contexts "a" (score 1.0) and "b" (score 0.5), in that order.

        Both ``query`` and ``top_k`` are ignored.
        """
        canned_hits = (
            ("a", "bm25_doc_a", 1.0),
            ("b", "bm25_doc_b", 0.5),
        )
        return [
            Context(id=doc_id, text=body, score=weight)
            for doc_id, body, weight in canned_hits
        ]
class DummyDense:
    """Fake ``DenseRetriever`` that always returns the same two contexts.

    It shares id "b" with DummyBM25, so the hybrid retriever has to merge one
    overlapping hit ("b") and two single-source hits ("a", "c").
    """

    def __init__(
        self,
        faiss_index: str,
        doc_store: str,
        model_name: str,
        embedder_cache: Optional[str],
        device: str,
    ):
        # All arguments are ignored. ``embedder_cache`` is Optional because the
        # test below builds HybridRetriever with ``embedder_cache=None``.
        pass

    def retrieve(self, query: str, top_k: int):
        """Return contexts "b" (score 0.8) and "c" (score 0.3), in that order.

        Both ``query`` and ``top_k`` are ignored.
        """
        return [
            Context(id="b", text="dense_doc_b", score=0.8),
            Context(id="c", text="dense_doc_c", score=0.3),
        ]
@pytest.fixture(autouse=True)
def patch_internal_retrievers(monkeypatch):
    """Point HybridRetriever's sub-retriever names at the in-file fakes.

    ``autouse=True`` applies this to every test in the module, so no real
    BM25 / FAISS backend is involved. ``monkeypatch`` undoes each ``setattr``
    when the test finishes.

    This assumes hybrid.py resolves ``BM25Retriever`` and ``DenseRetriever``
    through its own module globals when it builds them. TODO confirm.
    """
    import evaluation.retrievers.hybrid as hybrid_mod

    stubs = {
        "BM25Retriever": DummyBM25,
        "DenseRetriever": DummyDense,
    }
    for attr_name, fake_cls in stubs.items():
        monkeypatch.setattr(hybrid_mod, attr_name, fake_cls)
    yield
def test_hybrid_retriever_combines_scores(tmp_path):
    """HybridRetriever fuses the two fake rankings with weight ``alpha=0.5``.

    Expected fused scores, where an id missing from one retriever counts as 0.0:
        "b": 0.5 * 0.5 + 0.5 * 0.8 = 0.65  (returned by both fakes)
        "a": 0.5 * 1.0 + 0.5 * 0.0 = 0.50  (BM25 fake only)
    The same formula gives "c" 0.15, which puts it outside ``top_k=2``.
    """
    doc_store = tmp_path / "docs.jsonl"
    doc_store.write_text('{"id":0,"text":"hello"}\n')

    # The index paths are never created; the patched fakes ignore them.
    hybrid = HybridRetriever(
        bm25_idx=str(tmp_path / "bm25_index"),
        faiss_index=str(tmp_path / "dense_index"),
        doc_store=str(doc_store),
        alpha=0.5,
        model_name="ignored",
        embedder_cache=None,
        device="cpu",
    )

    hits = hybrid.retrieve("dummy query", top_k=2)

    assert isinstance(hits, list)
    for hit in hits:
        assert isinstance(hit, Context)

    score_of = {hit.id: hit.score for hit in hits}
    assert [hit.id for hit in hits] == ["b", "a"]
    assert score_of["b"] == pytest.approx(0.65, rel=1e-6)
    assert score_of["a"] == pytest.approx(0.50, rel=1e-6)