File size: 1,442 Bytes
5b89d45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
from typing import List
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class Reranker:
    """
    Uses a Cross-Encoder to re-rank documents retrieved by the vector store.
    This significantly improves precision by scoring the query against each document directly.
    """
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model used for scoring.

        Args:
            model_name: Model identifier passed straight to ``CrossEncoder``.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Loading Reranker model: %s", model_name)
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
        """Return the ``top_k`` documents the cross-encoder scores highest for ``query``.

        Each ``[query, doc.page_content]`` pair is scored by ``self.model.predict``
        and documents are ordered by descending score. ``list.sort`` is stable,
        so documents with equal scores keep their original relative order.

        Side effect: every input document (including ones not returned) has
        ``metadata["rerank_score"]`` set in place to its score as a Python float.

        Args:
            query: The search query.
            documents: Candidate documents to score.
            top_k: Maximum number of documents to return. Values <= 0 yield an
                empty list (documents are still scored and annotated).

        Returns:
            Up to ``top_k`` documents, highest score first. Returns ``[]``
            immediately (without calling the model) when ``documents`` is empty.
        """
        if not documents:
            return []

        # Prepare pairs for scoring: [[query, doc_text], ...]
        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.model.predict(pairs)

        scored_docs = []
        for doc, raw_score in zip(documents, scores):
            # Convert once so the stored score and the sort key are the same value.
            score = float(raw_score)
            doc.metadata["rerank_score"] = score
            scored_docs.append((doc, score))

        scored_docs.sort(key=lambda item: item[1], reverse=True)

        # Clamp: a negative slice bound would drop items from the end instead of
        # returning nothing.
        limit = max(top_k, 0)
        return [doc for doc, _ in scored_docs[:limit]]