File size: 1,150 Bytes
58787ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from loguru import logger
from sentence_transformers import CrossEncoder

from scientific_rag.domain.documents import PaperChunk
from scientific_rag.settings import settings


class CrossEncoderReranker:
    """Rerank retrieved paper chunks with a sentence-transformers ``CrossEncoder``.

    The model named by ``settings.reranker_model_name`` is loaded once, in the
    constructor, on the device given by ``settings.embedding_device``.
    """

    def __init__(self):
        self.model_name = settings.reranker_model_name
        # NOTE(review): the reranker reuses the embedding device setting instead
        # of a dedicated one; confirm this is intended.
        self.device = settings.embedding_device

        logger.info(f"Loading Reranker model: {self.model_name} on {self.device}")
        self.model = CrossEncoder(self.model_name, device=self.device)
        logger.info("Reranker model loaded")

    def rerank(
        self,
        query: str,
        chunks: list[PaperChunk],
        top_k: int | None = None,
    ) -> list[PaperChunk]:
        """Score every chunk against *query* and return the best ``top_k`` of them.

        Args:
            query: Query text each chunk is paired with for cross-encoder scoring.
            chunks: Candidate chunks. They are not modified; scored copies are
                returned instead.
            top_k: Maximum number of chunks to return. ``None`` (the default)
                means ``settings.rerank_top_k``, read at call time.

        Returns:
            Copies of the input chunks with ``score`` set to the cross-encoder
            score, sorted by descending score and truncated to ``top_k``.
            An empty list when *chunks* is empty.
        """
        if not chunks:
            return []
        # Resolve the default per call: a ``settings.rerank_top_k`` default in the
        # signature would be frozen when the class body is first executed.
        if top_k is None:
            top_k = settings.rerank_top_k

        scores = self.model.predict([[query, chunk.text] for chunk in chunks])

        reranked_chunks = []
        for chunk, score in zip(chunks, scores):
            # Work on a copy so the caller's chunk objects keep their own scores.
            scored = chunk.model_copy()
            # Coerce to a plain Python float (predict may yield numpy values).
            scored.score = float(score)
            reranked_chunks.append(scored)

        # list.sort is stable, so equally scored chunks keep their input order.
        reranked_chunks.sort(key=lambda c: c.score, reverse=True)
        return reranked_chunks[:top_k]


# Shared module-level instance. Constructing it loads the cross-encoder model,
# so this happens as soon as the module is imported.
reranker: CrossEncoderReranker = CrossEncoderReranker()