Spaces:
Running
Running
File size: 1,638 Bytes
"""Cross-encoder reranking."""
import logging
import math
from src.models import QueryResult
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def _sigmoid(score: float) -> float:
    """Normalize a raw cross-encoder score to the 0-1 range via the logistic sigmoid.

    Args:
        score: Raw, unbounded score to normalize.

    Returns:
        ``1 / (1 + exp(-score))`` evaluated on the score clamped to
        ``[-500, 500]``. A NaN score yields ``0.0``.
    """
    # NaN fails every comparison, so the min/max clamp below would silently
    # turn it into 500.0 and report an invalid score as a perfect 1.0 match.
    # Map it to the lowest possible relevance instead, so an invalid score can
    # never outrank a real one.
    if math.isnan(score):
        return 0.0
    # Clamp before exponentiating: math.exp() raises OverflowError for
    # arguments a little above 709, and +/-500 already saturates the result.
    clamped = max(-500.0, min(500.0, score))
    return 1.0 / (1.0 + math.exp(-clamped))
class Reranker:
    """Reranks retrieval results using a cross-encoder model.

    The wrapped model only needs a ``predict(pairs)`` method; it is called
    with a list of ``[query, text]`` pairs and its return value is iterated
    alongside the candidate results.
    """

    def __init__(self, model: object) -> None:
        """Initialize the reranker with a cross-encoder model.

        Args:
            model: A cross-encoder model instance (e.g. from provider.create_reranker).
        """
        self._model = model
        logger.info("Loaded cross-encoder reranker")

    def rerank(self, query: str, results: list[QueryResult], top_k: int) -> list[QueryResult]:
        """Rerank retrieval results using the cross-encoder.

        Each candidate's chunk text is scored against the query, the raw
        score is squashed to 0-1 with ``_sigmoid``, and new ``QueryResult``
        objects (``source="reranked"``) are returned in descending score
        order. The input list and its items are not modified.

        Args:
            query: The original search query.
            results: Candidate results to rerank.
            top_k: Number of top results to keep after reranking. Values
                of zero or below yield an empty list.

        Returns:
            Reranked list of QueryResult objects, at most ``top_k`` long.
        """
        # Guard non-positive top_k up front: a negative value used as a slice
        # bound would drop items from the END of the list (e.g. [:-1] keeps
        # all but the last) instead of keeping none, and there is no point
        # running model inference when nothing will be returned.
        if not results or top_k <= 0:
            return []

        pairs = [[query, result.chunk.text] for result in results]
        # Assumes predict() yields one score per pair, in input order.
        # zip() stops at the shorter input, so any results without a
        # matching score are left out of the output.
        scores = self._model.predict(pairs)
        reranked = [
            QueryResult(chunk=result.chunk, score=_sigmoid(float(score)), source="reranked")
            for result, score in zip(results, scores)
        ]
        # list.sort is stable, so equally scored results keep their
        # original relative order.
        reranked.sort(key=lambda r: r.score, reverse=True)
        logger.debug("Reranked %d results, keeping top %d", len(reranked), top_k)
        return reranked[:top_k]
|