File size: 3,095 Bytes
b62e029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1487b7f
b62e029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# models/reranker.py

import numbers
from typing import Any, Dict, List

import torch
from FlagEmbedding import FlagReranker

from core.exceptions import ModelLoadError
from core.logger import setup_logger

# Module-level logger used by TextReranker; created through core.logger.setup_logger.
logger = setup_logger("reranker")

class TextReranker:
    """
    Cross-encoder reranker built on a BGE-Reranker model (via FlagEmbedding).

    Documents returned by a first-stage retrieval are re-scored jointly with the
    query (cross-encoding) and returned in descending order of that score.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3", use_fp16: bool = False):
        """
        Load the reranker model and run a single warm-up inference.

        :param model_name: Model name or path passed to ``FlagReranker``.
        :param use_fp16: Request half-precision inference. Only honored when the
            detected device is CUDA; ignored on MPS/CPU.
        :raises ModelLoadError: If loading or warming up the model fails.
        """
        self.model_name = model_name
        self.device = self._get_device()

        try:
            logger.info(f"⏳ Loading Reranker Model: {self.model_name} on {self.device}")
            # NOTE(review): self.device is only used for logging and for the fp16
            # decision below; it is not passed to FlagReranker, which presumably
            # selects its own device. Confirm the two agree.
            self.reranker = FlagReranker(
                self.model_name,
                use_fp16=(use_fp16 and self.device.startswith("cuda")),
            )
            self._warmup()
            logger.info("✅ Reranker Model loaded successfully.")
        except Exception as e:
            logger.critical(f"❌ Failed to load Reranker Model: {e}", exc_info=True)
            # Chain the original exception so the root cause stays in the traceback.
            raise ModelLoadError(f"Reranker initialization failed: {e}") from e

    def _get_device(self) -> str:
        """
        Return the preferred torch device name: "cuda", then "mps", otherwise "cpu".
        """
        if torch.cuda.is_available():
            return "cuda"
        # Older torch releases have no torch.backends.mps; treat a missing backend
        # as "not available" rather than raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _warmup(self):
        """
        Run one dummy query/document pair through ``rerank``.

        Any failure here propagates out of ``__init__`` as ``ModelLoadError``.
        """
        logger.info("Warming up reranker model with a dummy input.")
        self.rerank(query="Hello world", documents=[{"text": "Hello world"}])

    def rerank(self, query: str, documents: List[Dict[str, Any]], text_key: str = "text") -> List[Dict[str, Any]]:
        """
        Score each document against the query and return the documents sorted by score.

        Each input dict is updated IN PLACE with a float ``rerank_score`` key
        (the value from ``compute_score(..., normalize=True)``). The return value
        is a new list holding the same dict objects, highest score first; ties
        keep their input order.

        :param query: The original search query string.
        :param documents: Dicts such as ``[{'chunk_id': 1, 'text': '...'}, ...]``.
        :param text_key: Key holding the body text in each document dict.
        :return: The documents sorted by ``rerank_score`` in descending order,
            or an empty list when ``documents`` is empty.
        :raises KeyError: If a document lacks ``text_key`` (raised before scoring).
        :raises RuntimeError: If scoring fails or yields a number of scores that
            does not match the number of documents.
        """
        if not documents:
            return []

        # Cross-encoder input: one [query, document_text] pair per document.
        sentence_pairs = [[query, doc[text_key]] for doc in documents]

        try:
            scores = self.reranker.compute_score(sentence_pairs, normalize=True)

            # compute_score may return a bare scalar for a single pair. numbers.Real
            # covers Python floats/ints and NumPy scalars (e.g. np.float32), which a
            # plain `float` isinstance check misses.
            if isinstance(scores, numbers.Real):
                scores = [scores]

            # Validate before mutating so a bad result never leaves the caller's
            # documents partially annotated.
            if len(scores) != len(documents):
                raise ValueError(f"expected {len(documents)} scores, got {len(scores)}")

            # Inject rerank_score into the caller's document dictionaries.
            for doc, score in zip(documents, scores):
                doc["rerank_score"] = float(score)

            return sorted(documents, key=lambda doc: doc["rerank_score"], reverse=True)

        except Exception as e:
            logger.error(f"Reranking failed for query '{query}': {e}", exc_info=True)
            raise RuntimeError(f"Reranking process failed: {e}") from e