File size: 892 Bytes
0e9a6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import logging
from sentence_transformers import CrossEncoder


class ArxivReranker:
    """Rerank Arxiv documents against a query using a cross-encoder model.

    Each document is paired with the query, the pair is scored by the
    cross-encoder, and the documents are returned in descending score order.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder used for scoring.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
        """
        self.model_name = model_name
        self.reranker = CrossEncoder(self.model_name)

    def rerank_documents(self, query: str, documents, top_k: int = 3):
        """Return the ``top_k`` highest-scoring documents for ``query``.

        Args:
            query: The search query each document is scored against.
            documents: Iterable of objects exposing a ``page_content``
                attribute. Any iterable is accepted, including one-shot
                iterators; ``None`` is treated as no documents.
            top_k: Maximum number of documents to return. It is applied as a
                slice bound (``[:top_k]``), so ordinary slice semantics hold.

        Returns:
            A list of the input document objects ordered by descending score.
            Documents with equal scores keep their input order. An empty list
            is returned when there are no documents.
        """
        # Materialize once: the documents are walked twice below (to build the
        # pairs and to zip with the scores), which would exhaust an iterator.
        docs = [] if documents is None else list(documents)
        if not docs:
            # Nothing to score; do not hand the model an empty batch.
            return []

        pairs = [[query, doc.page_content] for doc in docs]
        scores = self.reranker.predict(pairs)

        # list.sort is stable, so ties retain their original relative order.
        scored_docs = list(zip(docs, scores))
        scored_docs.sort(key=lambda item: item[1], reverse=True)
        top = scored_docs[:top_k]

        # Lazy %-style argument: the message is only rendered if INFO is on.
        logging.info(
            "Reranking scores: %s", [f"{score:.2f}" for _, score in top]
        )

        return [doc for doc, _ in top]