File size: 1,145 Bytes
26fe9a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging
import os
from typing import List, Dict

from sentence_transformers import CrossEncoder

from ..observability.langfuse_client import observe

class Reranker:
    """Re-score retrieved chunks against a query with a sentence-transformers CrossEncoder."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model eagerly.

        Args:
            model_name: Model identifier passed straight to ``CrossEncoder``.

        Construction loads the model immediately; ``get_reranker()`` in this
        module caches a single instance so the load is not repeated per call.
        """
        logging.getLogger(__name__).info("Loading reranker model: %s", model_name)
        self.model = CrossEncoder(model_name)

    @observe(name="rerank")
    def rerank(self, query: str, chunks: List[Dict], top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` chunks most relevant to ``query``, best first.

        Each chunk dict must have a ``'content'`` key. Each (query, content)
        pair is scored by the cross-encoder. The score is written onto the
        chunk dict itself under ``'rerank_score'``. The caller's ``chunks`` list
        is not reordered.

        Args:
            query: Search query text.
            chunks: Candidate chunks, each a dict with a ``'content'`` entry.
            top_k: Maximum number of chunks to return; ``<= 0`` returns ``[]``.

        Returns:
            A new list of at most ``top_k`` chunk dicts, sorted by descending
            ``'rerank_score'``.
        """
        # A non-positive top_k would otherwise slice from the end of the list
        # (e.g. [:-1] drops one chunk); treat it as "nothing requested" and
        # skip model inference entirely.
        if not chunks or top_k <= 0:
            return []

        pairs = [[query, chunk["content"]] for chunk in chunks]
        scores = self.model.predict(pairs)

        for chunk, score in zip(chunks, scores):
            # float() converts whatever numeric type predict() yields into a
            # plain Python float.
            chunk["rerank_score"] = float(score)

        # sorted() rather than list.sort(): avoid reordering the caller's list.
        ranked = sorted(chunks, key=lambda chunk: chunk["rerank_score"], reverse=True)
        return ranked[:top_k]

# Module-level cache for the lazily created Reranker; None until first use.
_shared_reranker = None

def get_reranker():
    """Return the module-wide Reranker, creating it on the first call.

    The instance is stored in ``_shared_reranker`` and returned as-is by later
    calls, so the cross-encoder model is not reloaded for every request.
    """
    global _shared_reranker
    existing = _shared_reranker
    if existing is not None:
        return existing
    # First use: build with Reranker's default (small, fast) cross-encoder.
    _shared_reranker = Reranker()
    return _shared_reranker