File size: 1,957 Bytes
6c58cf4
dfa6a46
6c58cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import List, Optional

from flashrank.Ranker import Ranker, RerankRequest
from langchain.schema import Document

from project.logger.logging import get_logger
from project.utils.config_loader import load_config

# Module-level logger, created by the project's get_logger helper with this
# module's __name__; used by DocumentReranker for its info/warning messages.
logger = get_logger(__name__)


class DocumentReranker:
    """Rerank LangChain documents against a query using a FlashRank ``Ranker``.

    Settings are read from the ``reranker`` section of the loaded config:

    * ``model_name`` -- model handed to ``Ranker`` (default ``'rank-T5-flan'``).
    * ``cache_dir`` -- optional; only forwarded to ``Ranker`` when truthy.
    * ``top_k`` -- default number of documents returned by :meth:`rerank`
      (default ``3``).
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load the configuration and build the underlying ranker.

        Args:
            config_path: Forwarded unchanged to ``load_config``.
        """
        self.config = load_config(config_path)
        # ``.get('reranker', {})`` would still return None when the key is
        # present but set to None, so normalise that case to an empty mapping.
        reranker_config = self.config.get('reranker') or {}
        model_name = reranker_config.get('model_name', 'rank-T5-flan')
        cache_dir = reranker_config.get('cache_dir')
        self.top_k = reranker_config.get('top_k', 3)

        # Only pass cache_dir when configured so the Ranker keeps its own
        # default otherwise.
        ranker_kwargs = {'model_name': model_name}
        if cache_dir:
            ranker_kwargs['cache_dir'] = cache_dir
        self.ranker = Ranker(**ranker_kwargs)

        logger.info(f"FlashRank reranker initialized with model: {model_name}")

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
    ) -> List[Document]:
        """Return the highest-ranked documents for ``query``.

        The returned list holds the *same* ``Document`` objects that were
        passed in; each returned document's ``metadata`` is updated in place
        with a ``"rerank_score"`` entry.

        Args:
            query: Text the documents are scored against.
            documents: Candidate documents; an empty list yields ``[]``.
            top_k: Maximum number of documents to return. ``None`` falls back
                to the configured ``self.top_k``.

        Returns:
            At most ``top_k`` documents, in the order produced by the ranker
            (assumed best-first -- the first ``top_k`` results are kept).

        Raises:
            ValueError: If the effective ``top_k`` is negative.
        """
        if not documents:
            logger.warning("No documents to rerank")
            return []

        if top_k is None:
            top_k = self.top_k
        # A negative value would silently turn the slice below into
        # "everything except the last |top_k| results", so reject it loudly.
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        # The positional index doubles as the passage id so each result can
        # be mapped back to its source document.
        passages = [
            {
                "id": i,
                "text": doc.page_content,
                "meta": doc.metadata,
            }
            for i, doc in enumerate(documents)
        ]

        rerank_request = RerankRequest(query=query, passages=passages)
        results = self.ranker.rerank(rerank_request)

        reranked_docs = []
        for result in results[:top_k]:
            original_doc = documents[result["id"]]
            # Coerce to a built-in float: ranker scores may be NumPy scalars,
            # which would make the metadata non-JSON-serialisable.
            original_doc.metadata["rerank_score"] = float(result["score"])
            reranked_docs.append(original_doc)

        logger.info(f"Reranked {len(documents)} documents, returning top {len(reranked_docs)}")
        return reranked_docs