File size: 6,459 Bytes
43a2563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""FastAPI application: routes, lifespan, and wiring.

Endpoints:
    GET  /            -> redirect to the demo UI
    GET  /healthz     -> readiness/liveness (200 once the model is loaded)
    GET  /metrics     -> Prometheus exposition
    POST /predict     -> single or batch classification (pydantic-validated)
    GET  /demo        -> static HTML demo that calls /predict

The model is loaded once in the lifespan startup and shared via ``app.state``;
the micro-batcher runs as a background task for the lifetime of the process.
"""
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import ValidationError

from app import __version__
from app.batching import MicroBatcher
from app.classifier import load_classifier
from app.config import get_settings
from app.metrics import metrics_middleware, render_latest
from app.schemas import (BatchPredictRequest, BatchPredictResponse,
                         HealthResponse, PredictRequest, PredictResponse,
                         Prediction)

# Module-scoped logger; handlers/levels are configured during app startup.
logger = logging.getLogger(__name__)

# The demo UI lives in a "demo" directory two levels above this file; its
# "vendor" subdirectory holds the front-end assets served under /static.
DEMO_DIR = Path(__file__).resolve().parent.parent.joinpath("demo")
STATIC_DIR = DEMO_DIR.joinpath("vendor")


def _build_prediction(text: str, dist: Dict[str, float]) -> Prediction:
    """Wrap a label -> probability mapping in a ``Prediction``.

    ``label`` and ``score`` come from the entry of ``dist`` with the highest
    value (the first such entry when several tie); ``probabilities`` is
    ``dist`` itself. ``text`` is accepted but not used by this helper.
    """
    best_label, best_score = max(dist.items(), key=lambda item: item[1])
    return Prediction(label=best_label, score=best_score, probabilities=dist)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the classifier and micro-batcher for the lifetime of ``app``.

    Startup: read settings, configure logging, load the classifier, start the
    micro-batcher, then publish all of them on ``app.state`` and set
    ``app.state.ready``. Shutdown: clear the ready flag, stop the batcher and
    log that the service has stopped.
    """
    cfg = get_settings()

    # Level names that ``logging`` does not define fall back to INFO.
    level = getattr(logging, cfg.log_level.upper(), logging.INFO)
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    logger.info(
        "starting emotion-api v%s (offline=%s, model_id=%s)",
        __version__, cfg.offline, cfg.model_id,
    )

    model = load_classifier(cfg)
    micro_batcher = MicroBatcher(
        model,
        max_batch_size=cfg.max_batch_size,
        max_delay_ms=cfg.batch_max_delay_ms,
    )
    await micro_batcher.start()

    # Expose the shared objects to request handlers; ``ready`` is set last so
    # it is only true once everything else is in place.
    state = app.state
    state.settings = cfg
    state.classifier = model
    state.batcher = micro_batcher
    state.ready = True
    try:
        yield
    finally:
        # Flip the flag before stopping the batcher.
        app.state.ready = False
        await micro_batcher.stop()
        logger.info("emotion-api stopped")


def create_app() -> FastAPI:
    """Application factory (one per process; tests build their own).

    Creates the ``FastAPI`` instance, installs the metrics middleware,
    registers the ops / inference / demo routes and, when the vendored asset
    directory exists, mounts it under ``/static``. The classifier and the
    micro-batcher are not created here; ``lifespan`` attaches them to
    ``app.state`` at startup.

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI(
        title="DistilBERT Emotion API",
        version=__version__,
        summary="Production inference service for the LaelaZ/distilbert-emotion model.",
        description=(
            "Classify a sentence into one of six emotions "
            "(sadness, joy, love, anger, fear, surprise) with full per-class "
            "probabilities. Batched, observable, and runnable fully offline."
        ),
        lifespan=lifespan,
    )
    app.middleware("http")(metrics_middleware)

    # Bare root just forwards to the demo page.
    @app.get("/", include_in_schema=False)
    async def root() -> RedirectResponse:
        return RedirectResponse(url="/demo")

    @app.get("/healthz", response_model=HealthResponse, tags=["ops"])
    async def healthz(request: Request) -> HealthResponse:
        """Readiness + liveness. 503 until the model is loaded and the batcher runs."""
        state = request.app.state
        # ``ready`` only exists once the lifespan startup has completed.
        if not getattr(state, "ready", False):
            raise HTTPException(status_code=503, detail="not ready")
        settings = state.settings
        return HealthResponse(
            status="ok",
            backend=state.classifier.backend,
            model_id=settings.model_id,
            offline=settings.offline,
            version=__version__,
        )

    # Metrics exposition; the payload is whatever ``render_latest`` produces.
    @app.get("/metrics", tags=["ops"], include_in_schema=False)
    async def metrics():
        return render_latest()

    @app.post(
        "/predict",
        tags=["inference"],
        summary="Classify one sentence or a batch.",
        responses={
            200: {
                "description": "Single prediction (PredictResponse) or batch "
                "(BatchPredictResponse), matching the request shape."
            }
        },
    )
    async def predict(request: Request, body: Dict[str, Any]) -> Any:
        """Accept ``{text: str}`` or ``{texts: [str, ...]}`` and classify.

        The body is parsed manually so one endpoint can serve both shapes and
        return the matching response shape. Validation errors surface as 422.
        """
        # The batcher is attached by the lifespan startup. If it is missing,
        # answer 503 like /healthz instead of failing with an AttributeError.
        batcher = getattr(request.app.state, "batcher", None)
        if batcher is None:
            raise HTTPException(status_code=503, detail="not ready")

        is_single = "text" in body
        is_batch = "texts" in body
        if is_single == is_batch:  # neither, or both
            raise HTTPException(
                status_code=422,
                detail="provide exactly one of 'text' (string) or 'texts' (list of strings)",
            )

        # Only request parsing sits inside the ``try``: a ValidationError
        # raised later (while building the response models) is a server
        # fault and must not be reported to the client as a 422.
        request_model = PredictRequest if is_single else BatchPredictRequest
        try:
            req = request_model(**body)
        except ValidationError as exc:
            # Translate pydantic errors into FastAPI's standard 422 envelope.
            # Strip the non-JSON-serializable bits pydantic attaches (the raw
            # exception object under `ctx`, and the docs URL).
            detail = [
                {k: v for k, v in err.items() if k not in ("ctx", "url")}
                for err in exc.errors()
            ]
            raise HTTPException(status_code=422, detail=detail) from exc

        if is_single:
            dist = await batcher.submit(req.text)
            pred = _build_prediction(req.text, dist)
            return PredictResponse(**pred.model_dump())

        dists = await batcher.submit_many(req.texts)
        preds: List[Prediction] = [
            _build_prediction(t, d) for t, d in zip(req.texts, dists)
        ]
        return BatchPredictResponse(predictions=preds)

    @app.get("/demo", tags=["ui"], include_in_schema=False)
    async def demo() -> FileResponse:
        index_html = DEMO_DIR / "index.html"
        # A missing page is a 404, not an internal error raised while the
        # response is being sent.
        if not index_html.is_file():
            raise HTTPException(status_code=404, detail="demo page not found")
        return FileResponse(index_html)

    # Vendored, offline front-end assets (Tailwind Play CDN bundle). Mounting a
    # directory keeps the demo page fully network-free — no external CDN at runtime.
    if STATIC_DIR.is_dir():
        app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

    return app


# Module-level application instance, built once at import time by the factory.
app = create_app()