XQ
Code cleaning
9612292
raw
history blame
1.64 kB
"""Cross-encoder reranking."""
import logging
import math
from src.models import QueryResult
logger = logging.getLogger(__name__)
def _sigmoid(score: float) -> float:
"""Normalize a raw cross-encoder score to 0-1 via sigmoid."""
score = max(-500.0, min(500.0, score))
return 1.0 / (1.0 + math.exp(-score))
class Reranker:
"""Reranks retrieval results using a cross-encoder model."""
def __init__(self, model: object) -> None:
"""Initialize the reranker with a cross-encoder model.
Args:
model: A cross-encoder model instance (e.g. from provider.create_reranker).
"""
self._model = model
logger.info("Loaded cross-encoder reranker")
def rerank(self, query: str, results: list[QueryResult], top_k: int) -> list[QueryResult]:
"""Rerank retrieval results using the cross-encoder.
Args:
query: The original search query.
results: Candidate results to rerank.
top_k: Number of top results to keep after reranking.
Returns:
Reranked list of QueryResult objects.
"""
if not results:
return []
pairs = [[query, result.chunk.text] for result in results]
scores = self._model.predict(pairs)
reranked = [
QueryResult(chunk=result.chunk, score=_sigmoid(float(score)), source="reranked")
for result, score in zip(results, scores)
]
reranked.sort(key=lambda r: r.score, reverse=True)
logger.debug("Reranked %d results, keeping top %d", len(reranked), top_k)
return reranked[:top_k]