Daryna Vasylashko committed on
Commit
58787ee
·
1 Parent(s): 305c138

feat: reranker

Browse files
src/scientific_rag/application/reranking/cross_encoder.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ from sentence_transformers import CrossEncoder
3
+
4
+ from scientific_rag.domain.documents import PaperChunk
5
+ from scientific_rag.settings import settings
6
+
7
+
8
class CrossEncoderReranker:
    """Re-score retrieved chunks against a query using a cross-encoder.

    The model name and device are taken from the project ``settings``
    object at construction time, and the model is loaded immediately.
    """

    def __init__(self):
        """Read the reranker configuration and load the cross-encoder."""
        self.model_name = settings.reranker_model_name
        self.device = settings.embedding_device

        logger.info(f"Loading Reranker model: {self.model_name} on {self.device}")
        self.model = CrossEncoder(self.model_name, device=self.device)
        logger.info("Reranker model loaded")

    def rerank(self, query: str, chunks: list[PaperChunk], top_k: int = settings.rerank_top_k) -> list[PaperChunk]:
        """Return the best-scoring chunks for ``query``, highest score first.

        Args:
            query: Text that every chunk is scored against.
            chunks: Candidate chunks; each chunk's ``text`` is paired with
                the query and passed to the model.
            top_k: Upper bound on the number of chunks returned. The default
                is read from ``settings.rerank_top_k`` once, when the class
                body is executed.

        Returns:
            Copies of the input chunks (made with ``model_copy()``) whose
            ``score`` holds the model's prediction as a ``float``, sorted by
            that score in descending order and sliced to ``top_k``. An empty
            list is returned when ``chunks`` is empty.
        """
        if not chunks:
            return []

        def _scored_copy(item, raw_score):
            # Work on a copy so the caller's chunk keeps its original score.
            clone = item.model_copy()
            clone.score = float(raw_score)
            return clone

        relevance = self.model.predict([[query, item.text] for item in chunks])

        # Look scores up by position (rather than zipping) so that a length
        # mismatch between scores and chunks surfaces the same way as before.
        rescored = [
            _scored_copy(item, relevance[position])
            for position, item in enumerate(chunks)
        ]

        ranked = sorted(rescored, key=lambda entry: entry.score, reverse=True)
        return ranked[:top_k]
33
+
34
+
35
+ reranker = CrossEncoderReranker()  # Module-level instance; constructing it loads the model when this module is imported.