# infra/hf_spaces/reranker/app.py
# Serves cross-encoder/ms-marco-MiniLM-L-6-v2 reranking over HTTP.
# Model is loaded from /app/model_cache (baked into the Docker image at build time).
from contextlib import asynccontextmanager

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Pipeline retrieves top-20, reranks to top-5. Cap at 64 to guard against abuse.
_MAX_TEXTS = 64
_MAX_QUERY_LEN = 512


class RerankRequest(BaseModel):
    """Payload for POST /rerank: a query plus candidate passages to score."""

    query: str = Field(..., max_length=_MAX_QUERY_LEN)
    # No max_length on individual strings — the cross-encoder tokenizer
    # truncates long passages safely; we only cap the number of texts.
    texts: list[str] = Field(..., max_length=_MAX_TEXTS)
    top_k: int = Field(5, ge=1, le=20)


class RerankResponse(BaseModel):
    # Indices into the input texts list, sorted by descending relevance,
    # paired element-wise with their raw cross-encoder scores.
    indices: list[int]
    scores: list[float]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the cross-encoder once at startup; release the reference on shutdown."""
    app.state.model = CrossEncoder(
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
        cache_folder="/app/model_cache",
    )
    yield
    app.state.model = None


app = FastAPI(
    title="PersonaBot Reranker",
    lifespan=lifespan,
    # No docs/openapi endpoints: this is an internal service, keep the surface minimal.
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)


@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness/readiness probe: 'loading' until lifespan has set the model."""
    # getattr guards the window before lifespan startup has run (or probes
    # arriving outside the lifespan context), where the attribute is absent
    # and a bare app.state.model read would raise AttributeError.
    if getattr(app.state, "model", None) is None:
        return {"status": "loading"}
    return {"status": "ok"}


@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest) -> RerankResponse:
    """Score each text against the query and return the top_k best matches.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool: CrossEncoder.predict is a blocking, CPU-bound call and would
    otherwise stall the event loop for every concurrent request.
    """
    if not request.texts:
        # Nothing to score; avoid calling predict with an empty batch.
        return RerankResponse(indices=[], scores=[])
    pairs = [(request.query, text) for text in request.texts]
    # predict returns an array-like of floats; coerce for JSON serialization.
    raw_scores: list[float] = [float(s) for s in app.state.model.predict(pairs)]
    # Sort by score descending, keep top_k (slicing tolerates top_k > len).
    ranked = sorted(enumerate(raw_scores), key=lambda x: x[1], reverse=True)
    ranked = ranked[: request.top_k]
    return RerankResponse(
        indices=[i for i, _ in ranked],
        scores=[s for _, s in ranked],
    )