File size: 892 Bytes
0e9a6da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import logging
from sentence_transformers import CrossEncoder
class ArxivReranker:
    """Reranker for Arxiv documents using a cross-encoder model.

    Attributes:
        model_name: Model identifier handed to ``CrossEncoder``.
        reranker: The ``CrossEncoder`` instance used to score
            ``[query, document text]`` pairs.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Create the reranker and instantiate the cross-encoder.

        Args:
            model_name: Name of the cross-encoder model to load.
        """
        self.model_name = model_name
        self.reranker = CrossEncoder(self.model_name)

    def rerank_documents(self, query: str, documents, top_k: int = 3):
        """Return the ``top_k`` documents ordered by cross-encoder score.

        Args:
            query: The query text each document is scored against.
            documents: Iterable of objects exposing a ``page_content``
                attribute. Any iterable is accepted, including generators.
            top_k: Maximum number of documents to return.

        Returns:
            A list of at most ``top_k`` documents, highest score first.
            Documents with equal scores keep their input order. An empty
            list is returned when ``documents`` is empty.
        """
        # Materialize once: the input is walked twice (to build the pairs and
        # to zip with the scores), which would exhaust a one-shot iterable.
        docs = list(documents)
        if not docs:
            # Nothing to score; skip the model call entirely.
            return []

        pairs = [[query, doc.page_content] for doc in docs]
        scores = self.reranker.predict(pairs)

        # sorted() is stable, so ties retain the original document order.
        ranked = sorted(zip(docs, scores), key=lambda item: item[1], reverse=True)
        top = ranked[:top_k]

        logging.info(
            "Reranking scores: %s", [f"{score:.2f}" for _, score in top]
        )
        return [doc for doc, _ in top]
|