# Source metadata captured from the repository file viewer:
# author "nothingworry", commit 0e8c152 ("feat: update the encoding model"), 4.01 kB.
"""
Cross-encoder re-ranking for RAG search results.
Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for fast, accurate re-ranking
of vector search results to improve retrieval accuracy.
"""
from __future__ import annotations

import math
from functools import lru_cache
from typing import List, Dict, Any, Optional
try:
from sentence_transformers import CrossEncoder
except ImportError:
CrossEncoder = None # type: ignore
@lru_cache(maxsize=1)
def _get_reranker() -> Optional[Any]:
    """
    Return the process-wide cross-encoder, constructing it on the first call.

    ``lru_cache(maxsize=1)`` memoises whatever the first call returns --
    including ``None`` -- so a failed load is not retried later in the same
    process.

    Returns:
        A ``CrossEncoder`` for "cross-encoder/ms-marco-MiniLM-L-6-v2", or
        ``None`` when sentence-transformers could not be imported or the
        model constructor raised.
    """
    if CrossEncoder is not None:
        try:
            return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        except Exception as load_error:
            # Best-effort: report the failure and let callers skip re-ranking.
            print(f"Warning: Failed to load cross-encoder model: {load_error}")
            print("RAG search will continue without re-ranking.")
    return None
def _sigmoid(logit: Any) -> float:
    """
    Squash a raw cross-encoder logit into [0, 1] without overflowing.

    Scores that cannot be converted to ``float`` are treated as a neutral
    logit of 0 (result 0.5); NaN maps to 0.0 so it sorts last.
    """
    try:
        x = float(logit)
    except (TypeError, ValueError):
        return 0.5
    if math.isnan(x):
        return 0.0
    # Two algebraically equal forms: choose the one whose exp() argument is
    # never positive, so very large magnitudes underflow to 0 instead of
    # raising OverflowError.
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _apply_top_k(
    items: List[Dict[str, Any]], top_k: Optional[int]
) -> List[Dict[str, Any]]:
    """Truncate *items* to *top_k* entries; ``None`` or a value <= 0 means no limit."""
    if top_k is not None and top_k > 0:
        return items[:top_k]
    return items


def rerank_results(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: Optional[int] = None,
    *,
    reranker: Optional[Any] = None,
) -> List[Dict[str, Any]]:
    """
    Re-rank search results using a cross-encoder for improved accuracy.

    Args:
        query: The search query.
        candidates: Candidate results; each is a dict whose "text" field is
            scored against the query (a missing or ``None`` text is scored
            as the empty string).
        top_k: Optional limit on the number of results returned. ``None`` or
            a value <= 0 means "no limit". The limit is applied whether or
            not re-ranking succeeds.
        reranker: Optional object exposing ``predict(pairs)`` that returns one
            raw score (logit) per ``(query, text)`` pair. Defaults to the
            lazily loaded module-wide cross-encoder.

    Returns:
        On success, new dicts (the inputs are not mutated) sorted by
        descending relevance, each with "score" and "relevance" set to the
        sigmoid of the model logit (in [0, 1]) and "reranked" set to True.
        If no model is available or prediction fails, the original candidates
        in their original order (still limited by ``top_k``).
    """
    if not candidates:
        return []

    model = reranker if reranker is not None else _get_reranker()
    if model is None:
        # Cross-encoder unavailable: keep the vector-search order.
        return _apply_top_k(candidates, top_k)

    try:
        pairs = [(query, candidate.get("text") or "") for candidate in candidates]
        raw_scores = model.predict(pairs)
        # zip() would silently drop candidates if the model returned too few
        # scores; treat a count mismatch as a failed prediction instead.
        if len(raw_scores) != len(candidates):
            raise ValueError(
                f"expected {len(candidates)} scores, got {len(raw_scores)}"
            )
    except Exception as e:
        # Best-effort: re-ranking is an enhancement, never a hard failure.
        print(f"Warning: Cross-encoder re-ranking failed: {e}")
        print("Returning original results without re-ranking.")
        return _apply_top_k(candidates, top_k)

    # Cross-encoder scores are unbounded logits; the sigmoid puts them on the
    # same [0, 1] scale as the vector-similarity scores they replace.
    reranked = [
        {
            **candidate,
            "score": normalized,
            "relevance": normalized,  # Keep both for compatibility
            "reranked": True,  # Flag to indicate this was re-ranked
        }
        for candidate, normalized in zip(
            candidates, (_sigmoid(raw) for raw in raw_scores)
        )
    ]
    # Stable sort: equal scores keep their original retrieval order.
    reranked.sort(key=lambda item: item["score"], reverse=True)
    return _apply_top_k(reranked, top_k)