Spaces:
Running
Running
| # backend/app/services/reranker.py | |
| # Dual-mode reranker. | |
| # - local (ENVIRONMENT != prod): lazy-loads CrossEncoder in-process on first call. | |
| # - prod: calls the HuggingFace personabot-reranker Space via async HTTP. | |
| from typing import Any, Optional | |
| import httpx | |
| from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential | |
| from app.models.pipeline import Chunk | |
# Process-wide cache for the in-process cross-encoder; populated on first use.
_local_model: Optional[Any] = None


def _get_local_model() -> Any:
    """Return the shared CrossEncoder instance, loading it lazily on first call.

    The import lives inside the function so that prod deployments (which use
    the remote Space) never pay the sentence_transformers import/model cost.
    """
    global _local_model  # noqa: PLW0603
    if _local_model is not None:
        return _local_model
    from sentence_transformers import CrossEncoder

    _local_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
    return _local_model
class Reranker:
    """Score (query, chunk) pairs with a cross-encoder and keep the best matches.

    Two modes, chosen at construction time:
    - remote: ``environment == "prod"`` and a non-empty ``remote_url`` — POSTs
      to the personabot-reranker Space's ``/rerank`` endpoint over async HTTP.
    - local: lazily loads a CrossEncoder in-process via ``_get_local_model``.
    """

    def __init__(self, remote_url: str = "", environment: str = "local") -> None:
        # Remote mode requires BOTH a prod environment and a configured URL;
        # anything else falls back to the in-process model.
        self._remote = environment == "prod" and bool(remote_url)
        self._url = remote_url.rstrip("/") if self._remote else ""
        # Score of the lowest-ranked chunk returned by the most recent
        # rerank() call; exposed via min_score().
        self._min_score: float = 0.0

    async def rerank(self, query: str, chunks: list[Chunk], top_k: int = 5) -> list[Chunk]:
        """
        Builds (query, chunk.text) pairs, scores all, returns top_k sorted descending.
        Attaches score to chunk metadata as rerank_score.
        Top-20 → reranker → top-5 is the validated sweet spot for latency/quality.

        Returned chunks are copies — the caller's chunks (and their metadata
        dicts) are never mutated.
        """
        if not chunks:
            self._min_score = 0.0
            return []
        # RC-12: prefer contextualised_text (doc title + section prefix) so the
        # cross-encoder sees the same enriched text as the dense retriever.
        # Falls back to raw chunk text for old points that pre-date contextualisation.
        texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]
        if self._remote:
            # Retry transient HTTP failures with exponential backoff. This is
            # what the tenacity imports at the top of the file are for — the
            # original code imported them but never applied the decorator.
            @retry(
                retry=retry_if_exception_type(httpx.HTTPError),
                stop=stop_after_attempt(3),
                wait=wait_exponential(multiplier=1, min=1, max=8),
                reraise=True,
            )
            async def _remote_call() -> tuple[list[int], list[float]]:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    # Truncate defensively so oversized payloads don't get
                    # rejected by the Space.
                    truncated = [t[:1500] for t in texts]
                    resp = await client.post(
                        f"{self._url}/rerank",
                        json={"query": query[:512], "texts": truncated, "top_k": top_k},
                    )
                    resp.raise_for_status()
                    data = resp.json()
                    indices = data.get("indices")
                    scores = data.get("scores")
                    if not isinstance(indices, list) or not isinstance(scores, list):
                        raise httpx.HTTPError("Invalid reranker response schema")
                    return [int(i) for i in indices], [float(s) for s in scores]

            indices, scores = await _remote_call()
            result = [
                self._with_score(chunks[idx], score)
                for idx, score in zip(indices, scores)
            ]
            self._min_score = scores[-1] if scores else 0.0
            return result  # type: ignore[return-value]
        model = _get_local_model()
        pairs = [(query, text) for text in texts]
        raw_scores = [float(s) for s in model.predict(pairs)]
        top_scored = sorted(zip(chunks, raw_scores), key=lambda x: x[1], reverse=True)[:top_k]
        result = [self._with_score(chunk, score) for chunk, score in top_scored]
        self._min_score = top_scored[-1][1] if top_scored else 0.0
        return result  # type: ignore[return-value]

    @staticmethod
    def _with_score(chunk: Chunk, score: float) -> dict:
        """Return a copy of *chunk* with metadata.rerank_score set.

        Bug fix: the original did ``dict(chunk)`` (shallow copy) and then wrote
        into ``chunk_copy["metadata"]`` — that nested dict was still shared with
        the input, so the caller's chunk metadata was silently mutated. Here the
        metadata dict is rebuilt, leaving the input untouched.
        """
        chunk_copy = dict(chunk)
        chunk_copy["metadata"] = {**(chunk.get("metadata") or {}), "rerank_score": score}
        return chunk_copy

    def min_score(self) -> float:
        """Score of the lowest-ranked chunk from the most recent rerank() call."""
        return self._min_score