| from __future__ import annotations
|
|
|
| import logging
|
| import math
|
| import time
|
| from dataclasses import replace
|
| from functools import lru_cache
|
|
|
| import requests
|
|
|
| from app.config import ( |
| RERANK_API_RETRIES, |
| RERANK_API_RETRY_BACKOFF,
|
| RERANK_API_TIMEOUT,
|
| RERANK_API_URL,
|
| RERANK_BATCH_SIZE,
|
| RERANK_ENABLED,
|
| ) |
| from app.runtime_auth import get_hf_api_key |
| from app.schemas import RetrievedChunk
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# HTTP status codes for which a rerank request is re-sent instead of being
# treated as final.
RETRYABLE_STATUS_CODES = {
    408,  # Request Timeout
    429,  # Too Many Requests
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
}
|
|
|
|
|
class BGEReranker:
    """Rerank retrieved chunks by scoring (query, chunk text) pairs over HTTP.

    Pairs are posted to ``RERANK_API_URL`` in batches of ``RERANK_BATCH_SIZE``.
    Whenever scoring cannot be completed (no API key, request failure,
    unexpected response), the incoming order is kept and the chunks are tagged
    with the fallback reason instead of raising to the caller.
    """

    def __init__(self) -> None:
        """Capture the ``RERANK_ENABLED`` configuration flag."""
        self.enabled = RERANK_ENABLED

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks ordered by rerank score, best first.

        Args:
            query: Text each chunk is scored against.
            chunks: Candidates in their current (pre-rerank) order.
            top_k: Maximum number of chunks to return.

        Returns:
            * ``[]`` when ``chunks`` is empty or ``top_k`` is not positive.
            * ``chunks[:top_k]`` unchanged when reranking is disabled.
            * On success, copies of the best ``top_k`` chunks whose ``score``
              is ``sigmoid(raw_score)`` rounded to 6 places and whose metadata
              gains ``hybrid_score`` (the previous score), ``rerank_score``
              (the raw API score) and ``rerank_status="success"``.
            * On any handled failure, the result of :meth:`_fallback`.
        """
        if not chunks:
            return []
        # A non-positive limit selects nothing. Without this guard a negative
        # value would slice from the end of ``chunks`` and zero would still
        # pay for the scoring requests. ``None`` is deliberately left alone so
        # any caller relying on slice-all semantics keeps working.
        if top_k is not None and top_k <= 0:
            return []
        if not self.enabled:
            return chunks[:top_k]
        if not get_hf_api_key():
            return self._fallback(chunks, top_k, "missing_hf_api_key")

        try:
            scores: list[float] = []
            for start in range(0, len(chunks), RERANK_BATCH_SIZE):
                batch = chunks[start : start + RERANK_BATCH_SIZE]
                scores.extend(self._api_scores(query, [chunk.text for chunk in batch]))

            if len(scores) != len(chunks):
                raise RuntimeError(
                    f"Reranker returned {len(scores)} scores for {len(chunks)} candidates"
                )

            # sorted() is stable, so chunks with equal scores keep their
            # incoming relative order.
            ranked = sorted(
                zip(chunks, scores),
                key=lambda item: item[1],
                reverse=True,
            )
            # NOTE(review): sigmoid() is applied to whatever the API returned;
            # if the endpoint already yields probabilities rather than logits
            # this compresses them - confirm against the deployed model.
            return [
                replace(
                    chunk,
                    score=round(sigmoid(raw_score), 6),
                    metadata={
                        **chunk.metadata,
                        "hybrid_score": chunk.score,
                        "rerank_score": raw_score,
                        "rerank_status": "success",
                    },
                )
                for chunk, raw_score in ranked[:top_k]
            ]
        except (requests.RequestException, RuntimeError, TypeError, ValueError) as exc:
            logger.warning(
                "Reranker API unavailable; using hybrid ranking fallback: %s",
                exc,
            )
            return self._fallback(chunks, top_k, type(exc).__name__)

    def _fallback(
        self,
        chunks: list[RetrievedChunk],
        top_k: int,
        reason: str,
    ) -> list[RetrievedChunk]:
        """Return the first ``top_k`` chunks in their incoming order, tagged as fallback.

        Each returned chunk is a copy (via ``dataclasses.replace``) whose
        metadata gains ``hybrid_score`` (the chunk's current score),
        ``rerank_status="fallback"`` and ``rerank_fallback_reason=reason``.
        The chunk's ``score`` itself is not modified.
        """
        return [
            replace(
                chunk,
                metadata={
                    **chunk.metadata,
                    "hybrid_score": chunk.score,
                    "rerank_status": "fallback",
                    "rerank_fallback_reason": reason,
                },
            )
            for chunk in chunks[:top_k]
        ]

    def _api_scores(self, query: str, documents: list[str]) -> list[float]:
        """POST one batch of (query, document) pairs and return one score per document.

        The request is attempted up to ``RERANK_API_RETRIES + 1`` times (at
        least once). Responses whose status is in ``RETRYABLE_STATUS_CODES``
        and ``requests`` timeouts/connection errors are retried after an
        exponentially growing delay based on ``RERANK_API_RETRY_BACKOFF``.

        Raises:
            RuntimeError: No API key is available, the JSON body carries an
                ``error`` entry, or the body cannot be mapped to scores.
            requests.RequestException: The final attempt failed or the final
                response has an error status.
        """
        api_key = get_hf_api_key()
        if not api_key:
            raise RuntimeError("Enter a Hugging Face token to use reranking")
        headers = {"Authorization": f"Bearer {api_key}"}
        payload = {
            "inputs": [{"text": query, "text_pair": document} for document in documents],
            "options": {"wait_for_model": True},
        }

        response: requests.Response | None = None
        attempts = max(1, RERANK_API_RETRIES + 1)
        for attempt in range(1, attempts + 1):
            try:
                response = requests.post(
                    RERANK_API_URL,
                    headers=headers,
                    json=payload,
                    timeout=RERANK_API_TIMEOUT,
                )
                if response.status_code not in RETRYABLE_STATUS_CODES:
                    break
                if attempt == attempts:
                    # Out of retries: surface the retryable status as an HTTPError.
                    response.raise_for_status()
                logger.warning(
                    "Reranker API returned HTTP %s; retrying (%s/%s)",
                    response.status_code,
                    attempt,
                    attempts - 1,
                )
            except (requests.Timeout, requests.ConnectionError) as exc:
                if attempt == attempts:
                    raise
                logger.warning(
                    "Reranker API request failed; retrying (%s/%s): %s",
                    attempt,
                    attempts - 1,
                    exc,
                )

            # Exponential backoff: the base delay doubles with every failed attempt.
            delay = RERANK_API_RETRY_BACKOFF * (2 ** (attempt - 1))
            if delay > 0:
                time.sleep(delay)

        if response is None:
            raise RuntimeError("Reranker API did not return a response")
        if response.status_code == 400 and len(documents) > 1:
            # The batch as a whole was rejected; score each document in its own
            # request. A single document that is still rejected raises below.
            return [self._api_scores(query, [document])[0] for document in documents]

        response.raise_for_status()
        response_payload = response.json()
        if isinstance(response_payload, dict) and response_payload.get("error"):
            raise RuntimeError(str(response_payload["error"]))
        return self._coerce_scores(response_payload, expected_count=len(documents))

    def _coerce_scores(self, payload, expected_count: int) -> list[float]:
        """Normalise a decoded API response into ``expected_count`` floats.

        Accepted shapes: a list with one entry per document, the same list
        under a ``"scores"`` key of a dict, or either of those wrapped in one
        extra single-element list.

        Raises:
            RuntimeError: The payload is not a list or has the wrong length.
        """
        if isinstance(payload, dict) and "scores" in payload:
            scores = payload["scores"]
        else:
            scores = payload

        # Strip one level of nesting only when the inner list is the real
        # per-document list. For a single document whose entry is itself a
        # multi-item list (e.g. several label dicts), the outer list already
        # has the right length and unwrapping it would turn a valid response
        # into a length mismatch.
        if (
            isinstance(scores, list)
            and len(scores) == 1
            and isinstance(scores[0], list)
            and len(scores[0]) == expected_count
        ):
            scores = scores[0]

        if not isinstance(scores, list) or len(scores) != expected_count:
            # Report the length for lists; the bare type name says nothing
            # useful when the only problem is a count mismatch.
            if isinstance(scores, list):
                received = f"list of length {len(scores)}"
            else:
                received = type(scores).__name__
            raise RuntimeError(
                f"Unexpected rerank API response shape: expected {expected_count}, "
                f"received {received}"
            )

        return [self._score_from_item(item) for item in scores]

    def _score_from_item(self, item) -> float:
        """Extract a single float score from one response entry.

        Supports a bare number, a dict carrying ``"score"`` (preferred) or
        ``"logit"``, and a non-empty list of such entries, from which the
        entry with the highest ``"score"`` is used.

        Raises:
            RuntimeError: The entry matches none of the supported shapes.
        """
        # Tuple form instead of ``int | float``: the union is evaluated at
        # runtime here and would raise TypeError on interpreters before 3.10.
        if isinstance(item, (int, float)):
            return float(item)
        if isinstance(item, dict):
            if "score" in item:
                return float(item["score"])
            if "logit" in item:
                return float(item["logit"])
        if isinstance(item, list) and item:
            candidate = max(
                item,
                key=lambda value: (
                    float(value.get("score", 0.0)) if isinstance(value, dict) else 0.0
                ),
            )
            return self._score_from_item(candidate)
        raise RuntimeError("Unexpected rerank score item from API")
|
|
|
|
|
def sigmoid(value: float) -> float:
    """Return the logistic function ``1 / (1 + e**-value)`` of *value*.

    The exponent handed to ``math.exp`` is never positive, so large
    magnitudes of either sign cannot overflow.
    """
    decay = math.exp(-abs(value))
    if value < 0:
        # Negative input: use the algebraically equal form e**v / (1 + e**v).
        return decay / (1 + decay)
    return 1 / (1 + decay)
|
|
|
|
|
@lru_cache(maxsize=1)
def get_reranker() -> BGEReranker:
    """Return the cached ``BGEReranker``, constructing it on the first call.

    ``lru_cache`` memoises the zero-argument call, so later calls hand back
    the same instance rather than building a new one.
    """
    reranker = BGEReranker()
    return reranker
|
|
|