File size: 3,654 Bytes
b0b150b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
MEXAR - Cross-Encoder Reranking Module
Improves retrieval precision by reranking candidates with a cross-encoder model.
"""
import logging
from typing import List, Tuple, Any

logger = logging.getLogger(__name__)

# Lazy load to avoid slow import on startup
_reranker_model = None


def _get_reranker():
    """Lazy load the cross-encoder model."""
    global _reranker_model
    if _reranker_model is None:
        try:
            from sentence_transformers import CrossEncoder
            _reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
            logger.info("Cross-encoder reranker loaded successfully")
        except ImportError:
            logger.warning("sentence-transformers not installed. Install with: pip install sentence-transformers")
            _reranker_model = False
        except Exception as e:
            logger.warning(f"Failed to load cross-encoder: {e}")
            _reranker_model = False
    return _reranker_model


class Reranker:
    """
    Cross-encoder reranking for improved retrieval precision.
    
    Cross-encoders are more accurate than bi-encoders because they
    process query and document together, capturing fine-grained interactions.
    """
    
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize reranker.
        
        Args:
            model_name: HuggingFace model name for cross-encoder
        """
        self.model_name = model_name
        self._model = None
    
    @property
    def model(self):
        """Lazy load model on first use."""
        if self._model is None:
            self._model = _get_reranker()
        return self._model
    
    def rerank(
        self, 
        query: str, 
        chunks: List[Any], 
        top_k: int = 5
    ) -> List[Tuple[Any, float]]:
        """
        Rerank chunks using cross-encoder.
        
        Args:
            query: User's query
            chunks: List of DocumentChunk objects
            top_k: Number of top results to return
            
        Returns:
            List of (chunk, score) tuples, sorted by relevance
        """
        if not chunks:
            return []
        
        if not self.model:
            # Fallback: return chunks with placeholder scores
            logger.warning("Reranker not available, using original order")
            return [(chunk, 0.5) for chunk in chunks[:top_k]]
        
        try:
            # Create query-document pairs
            # Truncate content to avoid memory issues
            pairs = [[query, self._get_content(chunk)[:512]] for chunk in chunks]
            
            # Get cross-encoder scores
            scores = self.model.predict(pairs)
            
            # Combine chunks with scores
            chunk_scores = list(zip(chunks, scores))
            
            # Sort by score descending
            ranked = sorted(chunk_scores, key=lambda x: x[1], reverse=True)
            
            logger.info(f"Reranked {len(chunks)} chunks, returning top {top_k}")
            return ranked[:top_k]
            
        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return [(chunk, 0.5) for chunk in chunks[:top_k]]
    
    def _get_content(self, chunk) -> str:
        """Extract content from chunk object."""
        if hasattr(chunk, 'content'):
            return chunk.content
        elif isinstance(chunk, dict):
            return chunk.get('content', '')
        else:
            return str(chunk)


def create_reranker() -> Reranker:
    """Factory function to create Reranker."""
    return Reranker()