# backend/app/services/reranker.py
# Dual-mode reranker.
# - local (ENVIRONMENT != prod): lazy-loads CrossEncoder in-process on first call.
# - prod: calls the HuggingFace personabot-reranker Space via async HTTP.
from typing import Any, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from app.models.pipeline import Chunk

# Process-wide cache for the in-process cross-encoder (local mode only).
_local_model: Optional[Any] = None


def _get_local_model() -> Any:
    """Lazily construct and cache the in-process cross-encoder (CPU only)."""
    global _local_model  # noqa: PLW0603
    if _local_model is None:
        # Imported here so prod deployments never pay the sentence-transformers
        # import cost — the model is only needed in local mode.
        from sentence_transformers import CrossEncoder

        _local_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
    return _local_model


class Reranker:
    """Scores (query, chunk) pairs with a cross-encoder, locally or remotely.

    Remote mode is used only when environment == "prod" AND a remote URL is
    configured; otherwise scoring runs in-process via _get_local_model().
    """

    def __init__(self, remote_url: str = "", environment: str = "local") -> None:
        self._remote = environment == "prod" and bool(remote_url)
        self._url = remote_url.rstrip("/") if self._remote else ""
        # Lowest score of the last returned top_k; exposed via min_score.
        self._min_score: float = 0.0

    @staticmethod
    def _with_score(chunk: Chunk, score: float) -> Chunk:
        """Return a copy of *chunk* with rerank_score attached to its metadata.

        Copies the metadata dict as well: a plain dict(chunk) is a shallow
        copy, so writing into chunk_copy["metadata"] would mutate the caller's
        original chunk. Missing/None metadata defaults to an empty dict.
        """
        chunk_copy = dict(chunk)
        chunk_copy["metadata"] = {**(chunk.get("metadata") or {}), "rerank_score": score}
        return chunk_copy  # type: ignore[return-value]

    async def _rerank_remote(
        self, query: str, texts: list[str], top_k: int
    ) -> tuple[list[int], list[float]]:
        """Call the remote reranker Space; returns (indices, scores).

        Retried once (2 attempts total) on timeout/HTTP errors, then re-raised.
        Raises httpx.HTTPError on a malformed response payload.
        """

        @retry(
            stop=stop_after_attempt(2),
            wait=wait_exponential(multiplier=0.4, min=0.4, max=1.2),
            retry=retry_if_exception_type((httpx.TimeoutException, httpx.HTTPError)),
            reraise=True,
        )
        async def _call() -> tuple[list[int], list[float]]:
            async with httpx.AsyncClient(timeout=60.0) as client:
                # Truncate query/texts to keep the payload bounded.
                truncated = [t[:1500] for t in texts]
                resp = await client.post(
                    f"{self._url}/rerank",
                    json={"query": query[:512], "texts": truncated, "top_k": top_k},
                )
                resp.raise_for_status()
                data = resp.json()
            indices = data.get("indices")
            scores = data.get("scores")
            if not isinstance(indices, list) or not isinstance(scores, list):
                raise httpx.HTTPError("Invalid reranker response schema")
            # Guard against a length mismatch that zip() would silently truncate.
            if len(indices) != len(scores):
                raise httpx.HTTPError("Invalid reranker response schema")
            return [int(i) for i in indices], [float(s) for s in scores]

        return await _call()

    async def rerank(self, query: str, chunks: list[Chunk], top_k: int = 5) -> list[Chunk]:
        """
        Builds (query, chunk.text) pairs, scores all, returns top_k sorted descending.
        Attaches score to chunk metadata as rerank_score.
        Top-20 → reranker → top-5 is the validated sweet spot for latency/quality.
        """
        if not chunks:
            self._min_score = 0.0
            return []

        # RC-12: prefer contextualised_text (doc title + section prefix) so the
        # cross-encoder sees the same enriched text as the dense retriever.
        # Falls back to raw chunk text for old points that pre-date contextualisation.
        texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]

        if self._remote:
            indices, scores = await self._rerank_remote(query, texts, top_k)
            result = [self._with_score(chunks[i], s) for i, s in zip(indices, scores)]
            # Remote returns scores sorted descending, so the last is the minimum.
            self._min_score = scores[-1] if scores else 0.0
            return result

        model = _get_local_model()
        raw_scores = [float(s) for s in model.predict([(query, t) for t in texts])]
        top_scored = sorted(zip(chunks, raw_scores), key=lambda x: x[1], reverse=True)[:top_k]
        self._min_score = top_scored[-1][1] if top_scored else 0.0
        return [self._with_score(chunk, score) for chunk, score in top_scored]

    @property
    def min_score(self) -> float:
        """Lowest rerank score among the most recently returned top_k (0.0 if none)."""
        return self._min_score