File size: 4,010 Bytes
0e8c152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Cross-encoder re-ranking for RAG search results.

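Uses the cross-encoder/ms-marco-MiniLM-L-6-v2 model to re-rank vector search
results, improving retrieval accuracy at low latency. If sentence-transformers
is not installed or the model cannot be loaded, results are returned in their
original order.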
"""

from __future__ import annotations

import logging
import math
from functools import lru_cache
from typing import List, Dict, Any, Optional

try:
    from sentence_transformers import CrossEncoder
except ImportError:
    CrossEncoder = None  # type: ignore


@lru_cache(maxsize=1)
def _get_reranker(
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
) -> Optional[Any]:
    """
    Lazily load the cross-encoder model and memoise it.

    The default, cross-encoder/ms-marco-MiniLM-L-6-v2, is a small model trained
    on MS MARCO for passage re-ranking.

    Args:
        model_name: Hugging Face model id passed to ``CrossEncoder``. Because the
            cache holds a single entry (``maxsize=1``), alternating between
            different names reloads the model each time.

    Returns:
        The loaded ``CrossEncoder`` instance, or ``None`` when
        sentence-transformers is not installed or the model fails to load.
        ``None`` is memoised too, so a failed load is not retried until
        ``_get_reranker.cache_clear()`` is called.
    """
    if CrossEncoder is None:
        return None
    try:
        return CrossEncoder(model_name)
    except Exception as exc:  # best-effort: any load failure just disables re-ranking
        # Log instead of print so library output does not pollute stdout; with no
        # handlers configured, logging's last-resort handler still emits WARNINGs
        # to stderr.
        logging.getLogger(__name__).warning(
            "Failed to load cross-encoder model %r: %s. "
            "RAG search will continue without re-ranking.",
            model_name,
            exc,
        )
        return None


def _sigmoid(value: Any) -> float:
    """
    Map a raw cross-encoder logit to [0, 1] with a numerically stable sigmoid.

    Values that cannot be converted to float, and NaN, map to 0.0 so the
    corresponding candidate sorts last instead of corrupting the ordering.
    """
    try:
        x = float(value)
    except (TypeError, ValueError):
        return 0.0
    if math.isnan(x):
        return 0.0
    # Branch on the sign so math.exp only ever receives a non-positive argument
    # and cannot raise OverflowError for large-magnitude logits.
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rerank_results(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Re-rank search results using a cross-encoder for improved accuracy.

    Args:
        query: The search query.
        candidates: Candidate results, each ideally with a "text" field; a
            missing or None "text" is scored as an empty string.
        top_k: Optional limit on the number of results returned. None or a
            value <= 0 means "no limit". The limit is applied whether or not
            re-ranking succeeds.

    Returns:
        On success: new dicts (shallow copies of the candidates) with "score"
        and "relevance" set to the sigmoid-normalised cross-encoder score and
        "reranked" set to True, sorted by score descending (ties keep their
        input order). When the cross-encoder is unavailable or fails: the
        candidates in their original order, unmodified.
    """
    if not candidates:
        return []

    # None / non-positive limits mean "return everything".
    limit = top_k if top_k is not None and top_k > 0 else None

    reranker = _get_reranker()
    if reranker is None:
        # No cross-encoder: keep the upstream (vector-similarity) ordering.
        return candidates if limit is None else candidates[:limit]

    try:
        # The cross-encoder scores each (query, passage) pair jointly and
        # returns raw logits (higher = more relevant, sign unbounded).
        pairs = [(query, candidate.get("text") or "") for candidate in candidates]
        normalized = [_sigmoid(raw) for raw in reranker.predict(pairs)]
        if len(normalized) != len(candidates):
            # zip() would silently drop the unscored candidates.
            raise ValueError(
                f"expected {len(candidates)} scores, got {len(normalized)}"
            )
        reranked = [
            {**candidate, "score": score, "relevance": score, "reranked": True}
            for candidate, score in zip(candidates, normalized)
        ]
    except Exception as exc:  # best-effort: fall back to the original ranking
        logging.getLogger(__name__).warning(
            "Cross-encoder re-ranking failed: %s. "
            "Returning original results without re-ranking.",
            exc,
            exc_info=True,
        )
        return candidates if limit is None else candidates[:limit]

    # list.sort is stable, so equal scores keep their upstream order.
    reranked.sort(key=lambda item: item["score"], reverse=True)
    return reranked if limit is None else reranked[:limit]