File size: 3,937 Bytes
bbe01fe
 
 
 
 
 
 
 
c44df3b
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fdc5ad
 
 
 
bbe01fe
 
c44df3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# backend/app/services/reranker.py
# Dual-mode reranker.
# - local (ENVIRONMENT != prod, or no remote URL configured): lazy-loads the
#   CrossEncoder in-process on first call.
# - prod (with a remote URL configured): calls the HuggingFace personabot-reranker
#   Space via async HTTP.

from typing import Any, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from app.models.pipeline import Chunk


# Process-wide CrossEncoder instance; stays None until _get_local_model() first runs.
_local_model: Optional[Any] = None


def _get_local_model() -> Any:
    """Return the shared in-process CrossEncoder, constructing it on first use.

    The ``sentence_transformers`` import is deferred to the first call
    (presumably so remote-only deployments never import it — confirm).
    """
    global _local_model  # noqa: PLW0603
    if _local_model is not None:
        return _local_model
    from sentence_transformers import CrossEncoder
    _local_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
    return _local_model


class Reranker:
    """Cross-encoder reranker with an in-process and a remote (HTTP) backend.

    The remote backend is selected only when ``environment == "prod"`` and a
    non-empty ``remote_url`` is supplied; in every other case the local
    CrossEncoder (loaded lazily via ``_get_local_model``) is used.
    """

    # Character caps applied to the remote request payload.
    _MAX_QUERY_CHARS = 512
    _MAX_TEXT_CHARS = 1500

    def __init__(self, remote_url: str = "", environment: str = "local") -> None:
        self._remote = environment == "prod" and bool(remote_url)
        self._url = remote_url.rstrip("/") if self._remote else ""
        # Lowest score kept by the most recent rerank() call (0.0 when nothing was
        # kept). This is per-instance state, so concurrent calls overwrite it.
        self._min_score: float = 0.0

    async def rerank(self, query: str, chunks: list[Chunk], top_k: int = 5) -> list[Chunk]:
        """
        Builds (query, chunk text) pairs, scores them, and returns the top_k
        chunks sorted by descending score.

        Each returned chunk is a copy whose metadata dict (also copied) carries
        the score as ``rerank_score``; the caller's input chunks are not mutated.
        Top-20 → reranker → top-5 is the validated sweet spot for latency/quality.

        A negative ``top_k`` is treated as 0.

        Raises:
            httpx.HTTPError: remote mode only, when the request fails or the
                response is malformed after the retry budget is exhausted.
        """
        top_k = max(top_k, 0)
        if not chunks or top_k == 0:
            self._min_score = 0.0
            return []

        # RC-12: prefer contextualised_text (doc title + section prefix) so the
        # cross-encoder sees the same enriched text as the dense retriever.
        # Falls back to raw chunk text for old points that pre-date contextualisation.
        texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]

        if self._remote:
            indices, scores = await self._remote_scores(query, texts, top_k)
            candidates = [chunks[i] for i in indices]
        else:
            model = _get_local_model()
            scores = [float(s) for s in model.predict([(query, text) for text in texts])]
            candidates = list(chunks)

        # Enforce the "sorted descending, at most top_k" contract for both backends.
        # sorted() is stable, so tied scores keep the backend's order.
        ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)[:top_k]
        self._min_score = ranked[-1][1] if ranked else 0.0
        return [self._with_score(chunk, score) for chunk, score in ranked]

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=0.4, min=0.4, max=1.2),
        # httpx.TimeoutException already subclasses httpx.HTTPError; it is listed
        # explicitly for readability only.
        retry=retry_if_exception_type((httpx.TimeoutException, httpx.HTTPError)),
        reraise=True,
    )
    async def _remote_scores(
        self, query: str, texts: list[str], top_k: int
    ) -> tuple[list[int], list[float]]:
        """POST to ``{url}/rerank`` and return validated (indices, scores).

        Attempted at most twice on any ``httpx.HTTPError``: timeouts, non-2xx
        statuses (via ``raise_for_status``), and the schema errors raised below.
        Every returned index is guaranteed to be a valid position in ``texts``.
        """
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                f"{self._url}/rerank",
                json={
                    "query": query[: self._MAX_QUERY_CHARS],
                    "texts": [t[: self._MAX_TEXT_CHARS] for t in texts],
                    "top_k": top_k,
                },
            )
            resp.raise_for_status()
            data = resp.json()

        if not isinstance(data, dict):
            raise httpx.HTTPError("Invalid reranker response schema")
        raw_indices = data.get("indices")
        raw_scores = data.get("scores")
        if not isinstance(raw_indices, list) or not isinstance(raw_scores, list):
            raise httpx.HTTPError("Invalid reranker response schema")
        if len(raw_indices) != len(raw_scores):
            raise httpx.HTTPError("Reranker response has mismatched indices/scores lengths")
        indices = [int(i) for i in raw_indices]
        # Reject negatives too: chunks[-1] would otherwise silently pick the last chunk.
        if any(not 0 <= i < len(texts) for i in indices):
            raise httpx.HTTPError("Reranker response index out of range")
        return indices, [float(s) for s in raw_scores]

    @staticmethod
    def _with_score(chunk: Chunk, score: float) -> Chunk:
        """Return a copy of *chunk* whose own metadata dict carries ``rerank_score``.

        The metadata dict is copied as well as the chunk, because a shallow
        ``dict(chunk)`` would share (and mutate) the caller's metadata.
        """
        chunk_copy = dict(chunk)
        metadata = dict(chunk_copy.get("metadata") or {})
        metadata["rerank_score"] = score
        chunk_copy["metadata"] = metadata
        return chunk_copy  # type: ignore[return-value]

    @property
    def min_score(self) -> float:
        """Lowest rerank score among the chunks returned by the latest rerank() call."""
        return self._min_score