Spaces:
Running
Running
File size: 8,646 Bytes
0a4529c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
# DEPENDENCIES
from typing import List
from typing import Optional
from config.models import ChunkWithScore
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import RerankingError
from sentence_transformers import CrossEncoder
# Setup Settings and Logging
# Both objects are created once, at import time, and are shared by everything in this module
# `settings` supplies the RERANKER_MODEL and ENABLE_RERANKING defaults read by Reranker.__init__
settings = get_settings()
# Module-scoped logger: Reranker keeps a reference to it as self.logger
logger = get_logger(__name__)
class Reranker:
    """
    Cross-encoder reranking for retrieval results: Provides more accurate relevance scoring using cross-encoder models

    Optionally enabled for improved accuracy at cost of latency. When reranking is disabled, or the model could
    not be loaded, the public methods fall back to pass-through results instead of raising
    """
    def __init__(self, model_name: Optional[str] = None, enable_reranking: Optional[bool] = None):
        """
        Initialize reranker

        Arguments:
        ----------
            model_name       { str }  : Cross-encoder model name (settings.RERANKER_MODEL is used when falsy)

            enable_reranking { bool } : Whether reranking is enabled (settings.ENABLE_RERANKING is used when None)
        """
        self.logger           = logger
        self.model_name       = model_name or settings.RERANKER_MODEL
        # `is not None` (rather than `or`) so an explicit False overrides the configured default
        self.enable_reranking = enable_reranking if (enable_reranking is not None) else settings.ENABLE_RERANKING
        self.model            = None

        # Statistics: number of successful rerank() passes
        self.rerank_count     = 0

        # Load model if reranking is enabled: a failed load flips enable_reranking back to False
        if self.enable_reranking:
            self._load_model()

        self.logger.info(f"Initialized Reranker: enabled={self.enable_reranking}, model={self.model_name}")


    def _load_model(self):
        """
        Load cross-encoder model

        Any failure is logged and leaves the reranker disabled (model = None, enable_reranking = False)
        instead of propagating, so construction never fails because of the model
        """
        try:
            self.logger.info(f"Loading cross-encoder model: {self.model_name}")

            self.model = CrossEncoder(self.model_name)

            self.logger.info("Cross-encoder model loaded successfully")

        except ImportError:
            # Listed before the generic handler only to emit a more specific log message
            self.logger.error("sentence-transformers not available for cross-encoder")

            self.model            = None
            self.enable_reranking = False

        except Exception as e:
            self.logger.error(f"Failed to load cross-encoder model: {repr(e)}")

            self.model            = None
            self.enable_reranking = False


    # NOTE(review): with reraise = False the decorator presumably logs and swallows anything escaping this
    # method - confirm against utils.error_handler. The body already catches Exception itself
    @handle_errors(error_type = RerankingError, log_error = True, reraise = False)
    def rerank(self, query: str, chunks_with_scores: List[ChunkWithScore], top_k: Optional[int] = None) -> List[ChunkWithScore]:
        """
        Rerank retrieved chunks using cross-encoder

        Scores are min-max normalized to [0, 1] within the batch, so they are only comparable inside one call:
        the best chunk always scores 1.0, the worst 0.0, and a batch whose scores are all equal (including a
        single chunk) scores 0.5 throughout

        Arguments:
        ----------
            query              { str }  : Original query

            chunks_with_scores { list } : Initial retrieval results

            top_k              { int }  : Number of top results to return (default: all); None, 0 and negative
                                          values all mean "no limit"

        Returns:
        --------
                 { list }               : Reranked ChunkWithScore objects (new objects, retrieval_method = 'reranked');
                                          the input list is returned unchanged when reranking is unavailable,
                                          the query is blank, or scoring fails
        """
        if not self.enable_reranking or self.model is None:
            self.logger.debug("Reranking disabled, returning original results")
            return chunks_with_scores

        if not chunks_with_scores:
            return []

        if not query or not query.strip():
            self.logger.warning("Empty query provided for reranking")
            return chunks_with_scores

        self.logger.debug(f"Reranking {len(chunks_with_scores)} chunks")

        try:
            # Prepare query-document pairs
            pairs  = [(query, cws.chunk.text) for cws in chunks_with_scores]

            # Get cross-encoder scores
            scores = self.model.predict(pairs)

            # Normalize Cross-encoder scores
            if (len(scores) > 0):
                # Cross-encoder outputs logits that can be negative: Apply min-max normalization to [0, 1]
                min_score   = min(scores)
                max_score   = max(scores)
                score_range = max_score - min_score

                # Avoid division by zero
                if (score_range > 1e-6):
                    scores = [(s - min_score) / score_range for s in scores]

                else:
                    # All scores are the same, set to 0.5
                    scores = [0.5] * len(scores)

            # Update scores and rerank: inputs are never mutated, fresh ChunkWithScore objects are built
            reranked = list()

            for i, (cws, new_score) in enumerate(zip(chunks_with_scores, scores)):
                reranked_cws = ChunkWithScore(chunk            = cws.chunk,
                                              score            = float(new_score),
                                              rank             = i + 1,  # Placeholder: overwritten after sorting
                                              retrieval_method = 'reranked',
                                             )

                reranked.append(reranked_cws)

            # Sort by new scores (descending): the sort is stable, so ties keep their retrieval order
            reranked.sort(key     = lambda x: x.score,
                          reverse = True,
                         )

            # Update ranks (1-based) to reflect the new order
            for rank, cws in enumerate(reranked, 1):
                cws.rank = rank

            # Return top_k if specified: a negative value is ignored rather than being used as a slice bound,
            # which would silently drop results from the end of the list
            if (top_k is not None) and (top_k > 0):
                reranked = reranked[:top_k]

            self.rerank_count += 1

            self.logger.info(f"Reranked {len(reranked)} chunks using cross-encoder")

            return reranked

        except Exception as e:
            # Best effort: a scoring failure must not lose the initial retrieval results
            self.logger.error(f"Reranking failed: {repr(e)}, returning original results")
            return chunks_with_scores


    def rerank_with_scores(self, query: str, texts: List[str]) -> List[tuple]:
        """
        Rerank texts and return with scores

        Unlike rerank(), the scores returned here are the raw model outputs (not normalized)

        Arguments:
        ----------
            query { str }  : Query string

            texts { list } : List of text strings

        Returns:
        --------
              { list }     : List of (text, score) tuples sorted by score (descending), score always a Python float;
                             every text is paired with 0.0, in input order, when reranking is unavailable or fails
        """
        if not self.enable_reranking or self.model is None:
            self.logger.warning("Reranking not available")
            return [(text, 0.0) for text in texts]

        # Nothing to score: skip the model call entirely
        if not texts:
            return []

        try:
            # Prepare pairs
            pairs   = [(query, text) for text in texts]

            # Get scores
            scores  = self.model.predict(pairs)

            # Combine and sort: float() keeps the score type identical to the 0.0 fallback paths
            results = [(text, float(score)) for text, score in zip(texts, scores)]

            results.sort(key     = lambda x: x[1],
                         reverse = True,
                        )

            return results

        except Exception as e:
            self.logger.error(f"Reranking with scores failed: {repr(e)}")
            return [(text, 0.0) for text in texts]


    def get_reranker_stats(self) -> dict:
        """
        Get reranker statistics

        Returns:
        --------
            { dict } : Reranker statistics: enabled flag, model name, whether the model is loaded and the
                       number of successful rerank() calls
        """
        return {"enabled"      : self.enable_reranking,
                "model_name"   : self.model_name,
                "model_loaded" : self.model is not None,
                "rerank_count" : self.rerank_count,
               }


    def is_available(self) -> bool:
        """
        Check if reranking is available

        Returns:
        --------
            { bool } : True if reranking is enabled and the model is loaded
        """
        # bool() guarantees a real boolean even if the configured flag is a non-bool falsy/truthy value
        return bool(self.enable_reranking) and (self.model is not None)
# Process-wide Reranker, built lazily on the first get_reranker() call
_reranker = None


def get_reranker() -> Reranker:
    """
    Get global reranker instance

    The instance is created with default arguments on first use and then reused by every later call

    Returns:
    --------
        { Reranker } : Reranker instance
    """
    global _reranker

    # Fast path: the shared instance already exists
    if _reranker is not None:
        return _reranker

    _reranker = Reranker()

    return _reranker
def rerank_results(query: str, chunks_with_scores: List[ChunkWithScore], **kwargs) -> List[ChunkWithScore]:
    """
    Convenience function for reranking: delegates to the shared Reranker returned by get_reranker()

    Arguments:
    ----------
        query              { str }  : Query string

        chunks_with_scores { list } : Results to rerank

        **kwargs                    : Additional keyword arguments, forwarded unchanged to Reranker.rerank (e.g. top_k)

    Returns:
    --------
                 { list }           : Reranked results
    """
    reranker = get_reranker()

    return reranker.rerank(query, chunks_with_scores, **kwargs)