Spaces:
Sleeping
Sleeping
| from loguru import logger | |
| from sentence_transformers import CrossEncoder | |
| from scientific_rag.domain.documents import PaperChunk | |
| from scientific_rag.settings import settings | |
| class CrossEncoderReranker: | |
| def __init__(self): | |
| self.model_name = settings.reranker_model_name | |
| self.device = settings.embedding_device | |
| logger.info(f"Loading Reranker model: {self.model_name} on {self.device}") | |
| self.model = CrossEncoder(self.model_name, device=self.device) | |
| logger.info("Reranker model loaded") | |
| def rerank(self, query: str, chunks: list[PaperChunk], top_k: int = settings.rerank_top_k) -> list[PaperChunk]: | |
| if not chunks: | |
| return [] | |
| pairs = [[query, chunk.text] for chunk in chunks] | |
| scores = self.model.predict(pairs) | |
| reranked_chunks = [] | |
| for i, chunk in enumerate(chunks): | |
| chunk_copy = chunk.model_copy() | |
| chunk_copy.score = float(scores[i]) | |
| reranked_chunks.append(chunk_copy) | |
| reranked_chunks.sort(key=lambda x: x.score, reverse=True) | |
| return reranked_chunks[:top_k] | |
| reranker = CrossEncoderReranker() | |