# QuerySphere / retrieval / reranker.py
# Commit 0a4529c ("first commit") by satyakimitra
# DEPENDENCIES
from typing import List
from typing import Optional
from config.models import ChunkWithScore
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import RerankingError
from sentence_transformers import CrossEncoder
# Setup Settings and Logging
# `settings` supplies Reranker defaults (RERANKER_MODEL, ENABLE_RERANKING);
# `logger` is shared by every Reranker instance via `self.logger`
settings = get_settings()
logger = get_logger(__name__)
class Reranker:
    """
    Cross-encoder reranking for retrieval results: Provides more accurate relevance scoring using cross-encoder models
    Optionally enabled for improved accuracy at cost of latency

    When reranking is disabled (or the model failed to load), `rerank` passes results
    through unchanged apart from optional `top_k` truncation
    """
    def __init__(self, model_name: Optional[str] = None, enable_reranking: Optional[bool] = None):
        """
        Initialize reranker

        Arguments:
        ----------
            model_name       { str }  : Cross-encoder model name (falls back to settings.RERANKER_MODEL when None or empty)

            enable_reranking { bool } : Whether reranking is enabled (falls back to settings.ENABLE_RERANKING when None)
        """
        self.logger           = logger
        self.model_name       = model_name or settings.RERANKER_MODEL
        self.enable_reranking = enable_reranking if (enable_reranking is not None) else settings.ENABLE_RERANKING
        self.model            = None

        # Statistics: number of successful cross-encoder rerank passes
        self.rerank_count     = 0

        # Load model eagerly only when reranking is enabled (loading may be slow)
        if self.enable_reranking:
            self._load_model()

        self.logger.info(f"Initialized Reranker: enabled={self.enable_reranking}, model={self.model_name}")


    def _load_model(self):
        """
        Load cross-encoder model

        On any failure the model is left as None and reranking is disabled, so the
        instance degrades to pass-through instead of raising
        """
        try:
            self.logger.info(f"Loading cross-encoder model: {self.model_name}")
            self.model = CrossEncoder(self.model_name)
            self.logger.info("Cross-encoder model loaded successfully")

        # NOTE(review): CrossEncoder is imported at module level, so a missing
        # sentence-transformers package fails at import time; this branch only
        # catches ImportErrors raised while constructing the model
        except ImportError:
            self.logger.error("sentence-transformers not available for cross-encoder")
            self.model            = None
            self.enable_reranking = False

        except Exception as e:
            self.logger.error(f"Failed to load cross-encoder model: {repr(e)}")
            self.model            = None
            self.enable_reranking = False


    @staticmethod
    def _apply_top_k(items: list, top_k: Optional[int]) -> list:
        """
        Keep only the first `top_k` items

        A None or non-positive `top_k` means "keep all". A negative value would
        otherwise slice items off the end of the list. When nothing is truncated,
        the same list object is returned.
        """
        if (top_k is not None) and (top_k > 0):
            return items[:top_k]

        return items


    @staticmethod
    def _normalize_scores(scores: List[float]) -> List[float]:
        """
        Min-max normalize scores to [0, 1]

        Cross-encoders output unbounded logits (possibly negative). When all scores
        are (nearly) identical, every score becomes 0.5.
        """
        if not scores:
            return []

        min_score   = min(scores)
        score_range = max(scores) - min_score

        # Avoid division by zero when all scores are (nearly) the same
        if (score_range > 1e-6):
            return [(s - min_score) / score_range for s in scores]

        return [0.5] * len(scores)


    @handle_errors(error_type = RerankingError, log_error = True, reraise = False)
    def rerank(self, query: str, chunks_with_scores: List[ChunkWithScore], top_k: Optional[int] = None) -> List[ChunkWithScore]:
        """
        Rerank retrieved chunks using cross-encoder

        Arguments:
        ----------
            query              { str }  : Original query

            chunks_with_scores { list } : Initial retrieval results

            top_k              { int }  : Number of top results to return (default: all; None, 0 or negative also mean all).
                                          Also applied when the original results are returned unchanged
                                          (reranking disabled, empty query, or scoring failure)

        Returns:
        --------
            { list } : Reranked ChunkWithScore objects. Scores are min-max normalized to [0, 1],
                       ranks start at 1, and retrieval_method is 'reranked'
        """
        if not self.enable_reranking or self.model is None:
            self.logger.debug("Reranking disabled, returning original results")
            return self._apply_top_k(chunks_with_scores, top_k)

        if not chunks_with_scores:
            return []

        if not query or not query.strip():
            self.logger.warning("Empty query provided for reranking")
            return self._apply_top_k(chunks_with_scores, top_k)

        self.logger.debug(f"Reranking {len(chunks_with_scores)} chunks")

        try:
            # Score every (query, chunk text) pair with the cross-encoder
            pairs      = [(query, cws.chunk.text) for cws in chunks_with_scores]
            raw_scores = [float(score) for score in self.model.predict(pairs)]
            scores     = self._normalize_scores(raw_scores)

            # Stable descending sort: tied chunks keep their original retrieval order
            ordered    = sorted(zip(chunks_with_scores, scores),
                                key     = lambda pair: pair[1],
                                reverse = True,
                               )
            ordered    = self._apply_top_k(ordered, top_k)

            # NOTE(review): only chunk/score/rank/retrieval_method are carried over;
            # confirm ChunkWithScore has no other fields worth preserving
            reranked   = [ChunkWithScore(chunk            = cws.chunk,
                                         score            = score,
                                         rank             = rank,
                                         retrieval_method = 'reranked',
                                        )
                          for rank, (cws, score) in enumerate(ordered, 1)]

            self.rerank_count += 1
            self.logger.info(f"Reranked {len(reranked)} chunks using cross-encoder")

            return reranked

        except Exception as e:
            self.logger.error(f"Reranking failed: {repr(e)}, returning original results")
            return self._apply_top_k(chunks_with_scores, top_k)


    def rerank_with_scores(self, query: str, texts: List[str]) -> List[tuple]:
        """
        Rerank texts and return with scores

        Arguments:
        ----------
            query { str }  : Query string

            texts { list } : List of text strings

        Returns:
        --------
            { list } : List of (text, score) tuples sorted by score (descending). Scores are raw,
                       un-normalized cross-encoder outputs as Python floats. When reranking is
                       unavailable or fails, texts keep their input order with score 0.0
        """
        if not self.enable_reranking or self.model is None:
            self.logger.warning("Reranking not available")
            return [(text, 0.0) for text in texts]

        # Nothing to score: skip the model call on an empty batch
        if not texts:
            return []

        try:
            pairs   = [(query, text) for text in texts]

            # predict returns numpy scalars; convert so callers always get plain floats
            scores  = [float(score) for score in self.model.predict(pairs)]

            results = sorted(zip(texts, scores),
                             key     = lambda item: item[1],
                             reverse = True,
                            )
            return results

        except Exception as e:
            self.logger.error(f"Reranking with scores failed: {repr(e)}")
            return [(text, 0.0) for text in texts]


    def get_reranker_stats(self) -> dict:
        """
        Get reranker statistics

        Returns:
        --------
            { dict } : Reranker statistics (enabled flag, model name, whether the model is loaded,
                       number of successful rerank passes)
        """
        return {"enabled"      : self.enable_reranking,
                "model_name"   : self.model_name,
                "model_loaded" : self.model is not None,
                "rerank_count" : self.rerank_count,
               }


    def is_available(self) -> bool:
        """
        Check if reranking is available

        Returns:
        --------
            { bool } : True if reranking is enabled and the model is loaded
        """
        return self.enable_reranking and (self.model is not None)
# Global reranker instance: created lazily on the first get_reranker() call
_reranker = None


def get_reranker() -> Reranker:
    """
    Return the shared module-level Reranker, constructing it with default settings on first use

    Returns:
    --------
        { Reranker } : Shared Reranker instance
    """
    global _reranker

    # Fast path: already constructed
    if _reranker is not None:
        return _reranker

    _reranker = Reranker()
    return _reranker
def rerank_results(query: str, chunks_with_scores: List[ChunkWithScore], **kwargs) -> List[ChunkWithScore]:
    """
    Rerank results through the shared reranker instance (convenience wrapper)

    Arguments:
    ----------
        query              { str }  : Query string

        chunks_with_scores { list } : Results to rerank

        **kwargs                    : Extra keyword arguments forwarded unchanged to Reranker.rerank (e.g. top_k)

    Returns:
    --------
        { list } : Reranked results
    """
    return get_reranker().rerank(query, chunks_with_scores, **kwargs)