| import logging | |
| from sentence_transformers import CrossEncoder | |
| class ArxivReranker: | |
| """Reranker for Arxiv documents using a cross-encoder model.""" | |
| def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): | |
| self.model_name = model_name | |
| self.reranker = CrossEncoder(self.model_name) | |
| def rerank_documents(self, query: str, documents, top_k: int = 3): | |
| """Rerank documents using a cross-encoder model.""" | |
| pairs = [[query, doc.page_content] for doc in documents] | |
| scores = self.reranker.predict(pairs) | |
| scored_docs = list(zip(documents, scores)) | |
| scored_docs.sort(key=lambda x: x[1], reverse=True) | |
| reranked = [doc for doc, score in scored_docs[:top_k]] | |
| logging.info( | |
| f"Reranking scores: {[f'{score:.2f}' for _, score in scored_docs[:top_k]]}" | |
| ) | |
| return reranked | |