RAG_Eval / tests /test_reranker.py
Rom89823974978's picture
Updated metrics and tests
fc20fed
raw
history blame contribute delete
548 Bytes
from evaluation.rerankers.cross_encoder import CrossEncoderReranker
from evaluation.retrievers.base import Context
def test_rerank():
    """Smoke-test ``CrossEncoderReranker.rerank`` on a few dummy contexts.

    Per the original note below, the reranker either scores the candidates
    with the cross-encoder or, if the model cannot be loaded, falls back to
    the first ``k`` inputs. In both cases the result should be a non-empty
    list of at most ``k`` ``Context`` objects drawn from the candidates,
    without duplicates.
    """
    k = 3
    rer = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
    dummy = [Context(id=str(i), text=f"text {i}", score=1.0) for i in range(5)]
    out = rer.rerank("dummy query", dummy, k=k)

    # If the model loads, out is a list of up to k contexts; otherwise same as input[:k]
    assert isinstance(out, list)
    assert all(isinstance(r, Context) for r in out)
    # With 5 candidates and k=3 an empty result means the reranker dropped
    # everything; the previous `len(out) <= 3` check let that pass silently.
    assert 0 < len(out) <= k
    # A reranker may reorder and truncate, but must not invent or repeat
    # candidates (assumes Context exposes the `id` it was constructed with).
    input_ids = {c.id for c in dummy}
    out_ids = [r.id for r in out]
    assert set(out_ids) <= input_ids
    assert len(out_ids) == len(set(out_ids))