Spaces:
Sleeping
Sleeping
| """ | |
| Cross-encoder re-ranking for RAG search results. | |
| Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for fast, accurate re-ranking | |
| of vector search results to improve retrieval accuracy. | |
| """ | |
| from __future__ import annotations | |
| from functools import lru_cache | |
| from typing import List, Dict, Any, Optional | |
| try: | |
| from sentence_transformers import CrossEncoder | |
| except ImportError: | |
| CrossEncoder = None # type: ignore | |
@lru_cache(maxsize=1)
def _get_reranker() -> Optional[Any]:
    """
    Lazily load the cross-encoder model once per process.

    Uses cross-encoder/ms-marco-MiniLM-L-6-v2, a model trained on the
    MS MARCO passage-ranking data for re-ranking search results.

    The result is memoized with ``lru_cache``, so the model is constructed
    at most once. A failed load (``None``) is cached as well: later calls
    do not retry, which avoids repeated slow failing loads on every search.

    Returns:
        The loaded ``CrossEncoder`` instance, or ``None`` if
        sentence-transformers is not installed or the model failed to load.
    """
    # sentence-transformers is optional; the module-level import falls back to None.
    if CrossEncoder is None:
        return None
    try:
        return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    except Exception as e:
        # Best-effort: re-ranking is an optional enhancement, so degrade
        # gracefully instead of failing the whole search.
        print(f"Warning: Failed to load cross-encoder model: {e}")
        print("RAG search will continue without re-ranking.")
        return None
def _sigmoid(x: float) -> float:
    """Map a raw logit to [0, 1] without overflowing for large-magnitude inputs."""
    import math

    # NaN would make the subsequent sort order undefined; treat it as irrelevant.
    if math.isnan(x):
        return 0.0
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, exp(-x) can overflow; use the equivalent e^x / (1 + e^x).
    z = math.exp(x)
    return z / (1.0 + z)


def _apply_top_k(
    items: List[Dict[str, Any]], top_k: Optional[int]
) -> List[Dict[str, Any]]:
    """Return a new list holding the first ``top_k`` items (all items if top_k is None or <= 0)."""
    if top_k is not None and top_k > 0:
        return items[:top_k]
    return list(items)


def rerank_results(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Re-rank search results using a cross-encoder for improved accuracy.

    Args:
        query: The search query.
        candidates: Candidate results, each expected to carry a "text" field
            (a missing or None "text" is scored as the empty string).
        top_k: Optional limit on the number of results returned. A value of
            None or <= 0 means "no limit". The limit is applied on every path,
            including when re-ranking is unavailable or fails; in that case
            the first ``top_k`` candidates are returned in their original order.

    Returns:
        A new list. On success, each entry is a copy of its candidate with
        "score" and "relevance" set to the sigmoid-normalized cross-encoder
        score in [0, 1] and "reranked" set to True, sorted by score
        descending. Otherwise, the original candidates (unmodified).
    """
    if not candidates:
        return []

    reranker = _get_reranker()
    if reranker is None:
        # Cross-encoder unavailable: keep the incoming order but honour top_k.
        return _apply_top_k(candidates, top_k)

    try:
        pairs = [
            (query, "" if c.get("text") is None else str(c["text"]))
            for c in candidates
        ]
        # Raw cross-encoder outputs are logits (any sign; higher = more relevant).
        raw_scores = reranker.predict(pairs)

        reranked: List[Dict[str, Any]] = []
        for candidate, raw in zip(candidates, raw_scores):
            # Normalize to [0, 1] so scores are comparable with vector similarity scores.
            normalized = _sigmoid(float(raw))
            reranked.append(
                {
                    **candidate,
                    "score": normalized,
                    "relevance": normalized,  # Keep both keys for compatibility
                    "reranked": True,  # Marks entries that went through re-ranking
                }
            )
    except Exception as e:
        # Best-effort: fall back to the vector-search order rather than failing.
        print(f"Warning: Cross-encoder re-ranking failed: {e}")
        print("Returning original results without re-ranking.")
        return _apply_top_k(candidates, top_k)

    reranked.sort(key=lambda item: item["score"], reverse=True)
    return _apply_top_k(reranked, top_k)