# Deployed via GitHub Actions (deploy commit f8b1b4c, build 85c969e).
# infra/hf_spaces/reranker/app.py
# Serves cross-encoder/ms-marco-MiniLM-L-6-v2 reranking over HTTP.
# Model is loaded from /app/model_cache (baked into the Docker image at build time).
from contextlib import asynccontextmanager
from typing import Annotated
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
# Pipeline retrieves top-20, reranks to top-5. Cap at 64 to guard against abuse.
_MAX_TEXTS = 64
# Maximum character length accepted for the query string (enforced by Pydantic).
_MAX_QUERY_LEN = 512
class RerankRequest(BaseModel):
    """Request payload for POST /rerank."""

    # Query string every candidate text is scored against.
    query: str = Field(..., max_length=_MAX_QUERY_LEN)
    # Candidate texts; list length is capped at _MAX_TEXTS. No per-string
    # length cap: the cross-encoder's tokenizer truncates long text safely.
    texts: list[str] = Field(..., max_length=_MAX_TEXTS)
    # Number of top-ranked results to return (pipeline default is 5).
    top_k: int = Field(5, ge=1, le=20)
class RerankResponse(BaseModel):
    """Response payload for POST /rerank."""

    # Indices into the input texts list, sorted by descending relevance.
    indices: list[int]
    # Raw cross-encoder scores, aligned one-to-one with indices.
    scores: list[float]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the cross-encoder at startup; drop the reference on shutdown.

    The weights are read from /app/model_cache (baked into the Docker image
    at build time), so startup does not hit the network.
    """
    model = CrossEncoder(
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
        cache_folder="/app/model_cache",
    )
    app.state.model = model
    yield
    app.state.model = None
# Internal-only service: interactive docs and the OpenAPI schema are disabled,
# leaving exactly two routes exposed (/health and /rerank).
app = FastAPI(
    title="PersonaBot Reranker",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness/readiness probe.

    Uses getattr with a default because app.state has no ``model`` attribute
    until the lifespan startup assigns it; a probe that lands before then
    should report "loading", not crash with AttributeError (HTTP 500).
    """
    if getattr(app.state, "model", None) is None:
        return {"status": "loading"}
    return {"status": "ok"}
@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest) -> RerankResponse:
    """Score each candidate text against the query and return the top_k.

    Declared as a plain (sync) ``def`` deliberately: CrossEncoder.predict is
    CPU-bound and blocking, so an ``async def`` handler would stall the event
    loop for every concurrent request. FastAPI runs sync handlers in its
    threadpool, keeping the loop responsive. The HTTP interface is unchanged.

    Returns indices into request.texts sorted by descending relevance, with
    the matching raw cross-encoder scores.
    """
    if not request.texts:
        return RerankResponse(indices=[], scores=[])
    pairs = [(request.query, text) for text in request.texts]
    raw_scores: list[float] = [float(s) for s in app.state.model.predict(pairs)]
    # Sort by score descending; sorted() is stable, so ties keep input order.
    ranked = sorted(enumerate(raw_scores), key=lambda item: item[1], reverse=True)
    top = ranked[: request.top_k]
    return RerankResponse(
        indices=[i for i, _ in top],
        scores=[s for _, s in top],
    )