# personabot-api / app/services/reranker.py
# (GitHub Actions — Deploy 2e8cff3, commit c44df3b)
# backend/app/services/reranker.py
# Dual-mode reranker.
# - local (ENVIRONMENT != prod): lazy-loads CrossEncoder in-process on first call.
# - prod: calls the HuggingFace personabot-reranker Space via async HTTP.
from typing import Any, Optional
import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from app.models.pipeline import Chunk
# Process-wide CrossEncoder cache; populated on first use by _get_local_model().
_local_model: Optional[Any] = None


def _get_local_model() -> Any:
    """Return the shared in-process CrossEncoder, loading it lazily on first call.

    The heavy sentence-transformers import is deferred into the function body so
    that remote/prod deployments never pay the model's import and load cost.
    """
    global _local_model  # noqa: PLW0603
    if _local_model is not None:
        return _local_model
    from sentence_transformers import CrossEncoder

    _local_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
    return _local_model
class Reranker:
    """Dual-mode cross-encoder reranker.

    - local (ENVIRONMENT != prod): scores (query, text) pairs in-process via the
      lazily loaded CrossEncoder.
    - prod: delegates scoring to the personabot-reranker Space over async HTTP
      (with a short tenacity retry) when a remote_url is configured.
    """

    def __init__(self, remote_url: str = "", environment: str = "local") -> None:
        # Remote mode requires BOTH prod and a non-empty URL; anything else
        # falls back to the in-process model.
        self._remote = environment == "prod" and bool(remote_url)
        self._url = remote_url.rstrip("/") if self._remote else ""
        # Lowest rerank score among the chunks returned by the most recent
        # rerank() call; 0.0 before the first call and for empty input.
        self._min_score: float = 0.0

    @staticmethod
    def _with_score(chunk: Chunk, score: float) -> Chunk:
        """Return a copy of *chunk* with metadata.rerank_score attached.

        Copies the metadata dict as well: dict(chunk) is only a shallow copy,
        so writing into the original metadata would mutate the caller's chunk
        (and anything upstream still holding a reference to it). Chunks with
        no metadata at all get a fresh dict instead of raising KeyError.
        """
        chunk_copy = dict(chunk)
        metadata = dict(chunk_copy.get("metadata") or {})
        metadata["rerank_score"] = score
        chunk_copy["metadata"] = metadata
        return chunk_copy  # type: ignore[return-value]

    async def rerank(self, query: str, chunks: list[Chunk], top_k: int = 5) -> list[Chunk]:
        """
        Builds (query, chunk.text) pairs, scores all, returns top_k sorted descending.
        Attaches score to chunk metadata as rerank_score.
        Top-20 → reranker → top-5 is the validated sweet spot for latency/quality.

        Raises httpx.HTTPError (after retries) if the remote service fails or
        returns a malformed payload.
        """
        if not chunks:
            self._min_score = 0.0
            return []

        # RC-12: prefer contextualised_text (doc title + section prefix) so the
        # cross-encoder sees the same enriched text as the dense retriever.
        # Falls back to raw chunk text for old points that pre-date contextualisation.
        texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]

        if self._remote:
            # Decorated closure (not a method) so tenacity/httpx are only
            # touched when remote mode is actually exercised.
            @retry(
                stop=stop_after_attempt(2),
                wait=wait_exponential(multiplier=0.4, min=0.4, max=1.2),
                retry=retry_if_exception_type((httpx.TimeoutException, httpx.HTTPError)),
                reraise=True,
            )
            async def _remote_call() -> tuple[list[int], list[float]]:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    # Truncate defensively; shorter payloads keep requests fast.
                    truncated = [t[:1500] for t in texts]
                    resp = await client.post(
                        f"{self._url}/rerank",
                        json={"query": query[:512], "texts": truncated, "top_k": top_k},
                    )
                    resp.raise_for_status()
                    data = resp.json()
                    indices = data.get("indices")
                    scores = data.get("scores")
                    if not isinstance(indices, list) or not isinstance(scores, list):
                        # Raise an httpx error type so the retry policy above
                        # treats a malformed body like any other HTTP failure.
                        raise httpx.HTTPError("Invalid reranker response schema")
                    return [int(i) for i in indices], [float(s) for s in scores]

            indices, scores = await _remote_call()
            result: list[Chunk] = []
            for idx, score in zip(indices, scores):
                # Defensive: skip out-of-range indices from a buggy service
                # response rather than raising IndexError mid-request.
                if 0 <= idx < len(chunks):
                    result.append(self._with_score(chunks[idx], score))
            self._min_score = scores[-1] if scores else 0.0
            return result

        # Local mode: lazily loaded in-process cross-encoder on CPU.
        model = _get_local_model()
        pairs = [(query, text) for text in texts]
        raw_scores = [float(s) for s in model.predict(pairs)]
        # Stable sort on score only (key=x[1]) — chunks themselves are never compared.
        scored = sorted(zip(chunks, raw_scores), key=lambda x: x[1], reverse=True)
        top_scored = scored[:top_k]
        result = [self._with_score(chunk, score) for chunk, score in top_scored]
        self._min_score = top_scored[-1][1] if top_scored else 0.0
        return result

    @property
    def min_score(self) -> float:
        """Lowest score among the chunks returned by the last rerank() call."""
        return self._min_score