File size: 1,638 Bytes
9612292
31a2688
 
9612292
31a2688
 
 
 
 
 
9612292
 
 
 
 
 
31a2688
 
 
9612292
31a2688
 
 
9612292
31a2688
9612292
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9612292
31a2688
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Cross-encoder reranking."""

import logging
import math

from src.models import QueryResult

logger = logging.getLogger(__name__)


def _sigmoid(score: float) -> float:
    """Normalize a raw cross-encoder score to 0-1 via sigmoid."""
    score = max(-500.0, min(500.0, score))
    return 1.0 / (1.0 + math.exp(-score))


class Reranker:
    """Reranks retrieval results using a cross-encoder model."""

    def __init__(self, model: object) -> None:
        """Initialize the reranker with a cross-encoder model.

        Args:
            model: A cross-encoder model instance (e.g. from provider.create_reranker).
                It must expose ``predict(pairs)``, taking a list of ``[query, text]``
                pairs and returning one raw score per pair, in the same order.
        """
        self._model = model
        logger.info("Loaded cross-encoder reranker")

    def rerank(self, query: str, results: list[QueryResult], top_k: int) -> list[QueryResult]:
        """Rerank retrieval results using the cross-encoder.

        Each candidate's chunk text is scored against ``query``. The raw score is
        squashed to [0, 1] with a sigmoid. Candidates are sorted by that score in
        descending order; ties keep their original retrieval order.

        Args:
            query: The original search query.
            results: Candidate results to rerank.
            top_k: Number of top results to keep after reranking. Values <= 0
                yield an empty list.

        Returns:
            New QueryResult objects (same chunks, normalized scores,
            ``source="reranked"``), best first, at most ``top_k`` long.

        Raises:
            ValueError: If the model returns a different number of scores than
                there are candidates.
        """
        if not results:
            return []

        pairs = [[query, result.chunk.text] for result in results]
        # Materialize once so array-like and iterable model outputs can both be length-checked.
        scores = list(self._model.predict(pairs))
        # zip() would silently drop trailing candidates on a mismatch; fail loudly instead.
        if len(scores) != len(results):
            raise ValueError(
                f"Cross-encoder returned {len(scores)} scores for {len(results)} candidates"
            )

        reranked = [
            QueryResult(chunk=result.chunk, score=_sigmoid(float(score)), source="reranked")
            for result, score in zip(results, scores)
        ]
        # list.sort is stable even with reverse=True, so equal scores keep retrieval order.
        reranked.sort(key=lambda r: r.score, reverse=True)

        # A negative slice bound would drop items from the end instead of keeping none.
        keep = max(top_k, 0)
        logger.debug("Reranked %d results, keeping top %d", len(reranked), keep)
        return reranked[:keep]