Spaces:
Sleeping
Sleeping
File size: 6,459 Bytes
"""FastAPI application: routes, lifespan, and wiring.

Endpoints:
    GET  /          -> redirect to the demo UI
    GET  /healthz   -> readiness/liveness (200 once the model is loaded, 503 before)
    GET  /metrics   -> Prometheus exposition
    POST /predict   -> single or batch classification (pydantic-validated)
    GET  /demo      -> static HTML demo that calls /predict
    GET  /static/*  -> vendored front-end assets (mounted only if demo/vendor exists)

The model is loaded once in the lifespan startup and shared via ``app.state``;
the micro-batcher runs as a background task for the lifetime of the process.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import ValidationError
from app import __version__
from app.batching import MicroBatcher
from app.classifier import load_classifier
from app.config import get_settings
from app.metrics import metrics_middleware, render_latest
from app.schemas import (BatchPredictRequest, BatchPredictResponse,
HealthResponse, PredictRequest, PredictResponse,
Prediction)
# Module-scoped logger; configuration happens in ``lifespan`` via basicConfig.
logger = logging.getLogger(__name__)

# Demo UI assets live in a ``demo/`` directory two levels above this file
# (i.e. next to this module's package directory).
DEMO_DIR = Path(__file__).resolve().parent.parent.joinpath("demo")
# Vendored, offline front-end bundle served under ``/static`` when present.
STATIC_DIR = DEMO_DIR.joinpath("vendor")
def _build_prediction(text: str, dist: Dict[str, float]) -> Prediction:
    """Turn a label -> probability mapping into a ``Prediction``.

    The predicted label is the key with the highest probability; on ties the
    first such key in ``dist``'s iteration order wins. The full mapping is
    passed through unchanged as ``probabilities``. ``text`` is accepted for
    call-site symmetry but is not used here. An empty ``dist`` raises
    ``ValueError`` (from ``max``).
    """
    best_label, best_score = max(dist.items(), key=lambda pair: pair[1])
    return Prediction(label=best_label, score=best_score, probabilities=dist)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the classifier and micro-batcher for the lifetime of the app.

    Startup: read settings, configure root logging, load the classifier,
    start the micro-batcher, then publish ``settings``, ``classifier``,
    ``batcher`` and ``ready=True`` on ``app.state``.

    Shutdown (in ``finally``, so it also runs if the app body raises): clear
    ``ready`` first so ``/healthz`` answers 503, then stop the batcher.
    """
    settings = get_settings()

    # An unrecognised level name falls back to INFO instead of failing startup.
    level = getattr(logging, settings.log_level.upper(), logging.INFO)
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    logger.info(
        "starting emotion-api v%s (offline=%s, model_id=%s)",
        __version__,
        settings.offline,
        settings.model_id,
    )

    classifier = load_classifier(settings)
    batcher = MicroBatcher(
        classifier,
        max_batch_size=settings.max_batch_size,
        max_delay_ms=settings.batch_max_delay_ms,
    )
    await batcher.start()

    state = app.state
    state.settings = settings
    state.classifier = classifier
    state.batcher = batcher
    state.ready = True
    try:
        yield
    finally:
        state.ready = False
        await batcher.stop()
        logger.info("emotion-api stopped")
def _validation_http_error(exc: ValidationError) -> HTTPException:
    """Translate a pydantic ``ValidationError`` into a JSON-safe 422 ``HTTPException``.

    Pydantic attaches the raw exception object under ``ctx`` (not JSON
    serializable) and a docs URL under ``url``; both are dropped so the
    remaining per-error entries can be serialized in the response body.
    """
    detail = [
        {key: value for key, value in err.items() if key not in ("ctx", "url")}
        for err in exc.errors()
    ]
    return HTTPException(status_code=422, detail=detail)


def create_app() -> FastAPI:
    """Application factory (one per process; tests build their own).

    Registers the metrics middleware and the ``/``, ``/healthz``, ``/metrics``,
    ``/predict`` and ``/demo`` routes. It also mounts ``/static``, but only
    when ``STATIC_DIR`` exists. Settings, classifier and batcher are attached
    to ``app.state`` by ``lifespan`` at startup.
    """
    app = FastAPI(
        title="DistilBERT Emotion API",
        version=__version__,
        summary="Production inference service for the LaelaZ/distilbert-emotion model.",
        description=(
            "Classify a sentence into one of six emotions "
            "(sadness, joy, love, anger, fear, surprise) with full per-class "
            "probabilities. Batched, observable, and runnable fully offline."
        ),
        lifespan=lifespan,
    )
    app.middleware("http")(metrics_middleware)

    @app.get("/", include_in_schema=False)
    async def root() -> RedirectResponse:
        return RedirectResponse(url="/demo")

    @app.get("/healthz", response_model=HealthResponse, tags=["ops"])
    async def healthz(request: Request) -> HealthResponse:
        """Readiness + liveness. 503 until the model is loaded and the batcher runs."""
        state = request.app.state
        if not getattr(state, "ready", False):
            raise HTTPException(status_code=503, detail="not ready")
        settings = state.settings
        return HealthResponse(
            status="ok",
            backend=state.classifier.backend,
            model_id=settings.model_id,
            offline=settings.offline,
            version=__version__,
        )

    @app.get("/metrics", tags=["ops"], include_in_schema=False)
    async def metrics():
        return render_latest()

    @app.post(
        "/predict",
        tags=["inference"],
        summary="Classify one sentence or a batch.",
        responses={
            200: {
                "description": "Single prediction (PredictResponse) or batch "
                "(BatchPredictResponse), matching the request shape."
            }
        },
    )
    async def predict(request: Request, body: Dict[str, Any]) -> Any:
        """Accept ``{text: str}`` or ``{texts: [str, ...]}`` and classify.

        The body is parsed manually so one endpoint can serve both shapes and
        return the matching response shape. Request-validation errors surface
        as 422. A request that arrives while no batcher is attached to
        ``app.state`` gets a 503 instead of an AttributeError 500.
        """
        batcher = getattr(request.app.state, "batcher", None)
        if batcher is None:
            raise HTTPException(status_code=503, detail="not ready")

        is_single = "text" in body
        is_batch = "texts" in body
        if is_single == is_batch:  # neither, or both
            raise HTTPException(
                status_code=422,
                detail="provide exactly one of 'text' (string) or 'texts' (list of strings)",
            )

        # Each try wraps request parsing only. A ValidationError raised later
        # (e.g. model output failing the response schema) is a server fault
        # and must not be reported to the client as a 422.
        if is_single:
            try:
                req = PredictRequest(**body)
            except ValidationError as exc:
                raise _validation_http_error(exc) from exc
            dist = await batcher.submit(req.text)
            pred = _build_prediction(req.text, dist)
            return PredictResponse(**pred.model_dump())

        try:
            req_b = BatchPredictRequest(**body)
        except ValidationError as exc:
            raise _validation_http_error(exc) from exc
        dists = await batcher.submit_many(req_b.texts)
        preds: List[Prediction] = [
            _build_prediction(t, d) for t, d in zip(req_b.texts, dists)
        ]
        return BatchPredictResponse(predictions=preds)

    @app.get("/demo", tags=["ui"], include_in_schema=False)
    async def demo() -> FileResponse:
        index_html = DEMO_DIR / "index.html"
        # FileResponse only discovers a missing file while sending, which
        # surfaces as a 500. Demo assets are optional (see the guarded /static
        # mount below), so report a clean 404 instead.
        if not index_html.is_file():
            raise HTTPException(status_code=404, detail="demo UI not available")
        return FileResponse(index_html)

    # Vendored, offline front-end assets (Tailwind Play CDN bundle). Mounting a
    # directory keeps the demo page fully network-free — no external CDN at runtime.
    if STATIC_DIR.is_dir():
        app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
    return app
# Process-wide application instance (the factory's contract is one per
# process; tests call create_app() themselves rather than using this one).
app: FastAPI = create_app()
|