File size: 5,596 Bytes
0083b07
 
 
 
 
 
565cb89
7dfa257
0083b07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565cb89
 
 
 
 
 
 
0083b07
 
 
 
 
 
 
 
 
 
 
565cb89
0083b07
 
 
7dfa257
 
 
 
 
0083b07
 
 
 
 
 
 
 
565cb89
0083b07
 
 
 
 
565cb89
0083b07
565cb89
0083b07
 
 
565cb89
 
0083b07
565cb89
0083b07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import logging
import os
import time
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from pydantic import BaseModel
from transformers import pipeline

load_dotenv(dotenv_path=".env")

# ── Prometheus metrics ─────────────────────────────────────────────────────────
# Counter: monotonically increasing. Good for "how many times did X happen?"
# Histogram: tracks distribution of values. Good for latency (gives you p50/p95/p99).

# Successful classifications, split by the predicted label. One time series
# per label, so a dashboard can show the label distribution directly.
REQUESTS = Counter(
    "classify_requests_total",
    "Total classification requests",
    labelnames=["label"],
)

# Latency distribution in seconds; bucket upper bounds span 10 ms to 1 s.
# Note: /classify times only the model call, so despite the help text this
# excludes HTTP parsing and response serialization.
LATENCY = Histogram(
    "classify_latency_seconds",
    "End-to-end classification latency",
    buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0),
)

# Failed classification attempts (no labels).
ERRORS = Counter("classify_errors_total", "Total classification errors")

# ── Model store ────────────────────────────────────────────────────────────────
# A plain dict used as a module-level container for the loaded model.
# This is the "kept in memory" part — model is loaded once and lives here
# for the entire lifetime of the process.
# Filled by `lifespan` at startup under the key "classifier" (the transformers
# pipeline) and emptied again at shutdown; request handlers read from it.
model_store: dict = dict()


# ── Lifespan ───────────────────────────────────────────────────────────────────
# FastAPI's lifespan runs the code before `yield` at startup and after `yield`
# at shutdown. Equivalent to @PostConstruct / @PreDestroy in Spring Boot.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classifier into ``model_store`` at startup and drop it at shutdown.

    The model id is read from the ``HF_MODEL_ID`` environment variable
    (``.env`` is loaded at import time). If it is unset, the default id
    below is used.

    Args:
        app: The FastAPI application. The lifespan protocol requires this
            argument; it is not used here.
    """
    model_id = os.environ.get("HF_MODEL_ID", "pranavsagar10/content-classifier-distilbert")
    print(f"Loading model: {model_id}")

    model_store["classifier"] = pipeline(
        "text-classification",
        model=model_id,
        device=-1,      # -1 = CPU. For serving we use CPU — MPS/CUDA is for training.
    )
    print("Model loaded and ready.")
    try:
        yield                           # app runs here, handling requests
    finally:
        # asynccontextmanager re-raises any error from the running app at the
        # `yield`. Without `finally`, the cleanup below would be skipped then.
        model_store.clear()
        print("Model unloaded.")


# ── App ────────────────────────────────────────────────────────────────────────
app = FastAPI(
    lifespan=lifespan,
    title="Content Intelligence API",
    version="1.0.0",
    description="Classifies news text into World / Sports / Business / Sci/Tech",
)

# CORS: browsers on any origin may call GET/POST endpoints (plus OPTIONS
# preflight) with any request headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_origins=["*"],
)


# ── Request / Response schemas ─────────────────────────────────────────────────
# Pydantic models validate the request body automatically.
# If `text` is missing or not a string, FastAPI returns a 422 before your code runs.
# Request body for POST /classify.
class ClassifyRequest(BaseModel):
    # Text to classify. Pydantic only checks that it is a string; the
    # empty/whitespace-only check happens in the /classify handler.
    text: str

# Response body for POST /classify.
class ClassifyResponse(BaseModel):
    # Highest-scoring label.
    label: str
    # Score of `label`; the handler rounds it to 4 decimals.
    confidence: float
    # Time spent in the model call, in milliseconds (rounded to 2 decimals).
    latency_ms: float
    # label -> score for every label the model returned (4 for the news model named in the API description — confirm if HF_MODEL_ID differs), highest score first
    scores: dict[str, float]


# ── Endpoints ──────────────────────────────────────────────────────────────────
@app.get("/", include_in_schema=False)
def root():
    """Send visitors of the bare root URL to the interactive API docs.

    Hidden from the OpenAPI schema because it is only a convenience redirect.
    """
    docs_path = "/docs"
    return RedirectResponse(url=docs_path)


@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest):
    """Classify ``req.text`` and return the top label plus every label's score.

    Args:
        req: Validated request body; ``req.text`` is the text to classify.

    Returns:
        ClassifyResponse containing:
            - the winning label and its score;
            - the model-call latency in milliseconds;
            - a label -> score map, with scores rounded to 4 decimals.

    Raises:
        HTTPException: 422 if ``text`` is empty or whitespace-only; 503 if
            the classifier is not in ``model_store``; 500 if inference fails.
    """
    if not req.text.strip():
        raise HTTPException(status_code=422, detail="text cannot be empty")

    # Same readiness condition as /health: a missing model means "not ready".
    # Previously the KeyError was swallowed by the broad except below and
    # reported as a 500 whose detail was just "'classifier'".
    classifier = model_store.get("classifier")
    if classifier is None:
        ERRORS.inc()
        raise HTTPException(status_code=503, detail="model not loaded")

    start = time.perf_counter()
    try:
        # truncation + max_length cap the input at 128 tokens; top_k=None
        # asks the pipeline for a score for every label.
        results = classifier(req.text, truncation=True, max_length=128, top_k=None)
    except Exception as e:
        ERRORS.inc()
        # Keep the full traceback in the server log. Don't echo internal
        # exception text back to arbitrary (CORS "*") clients.
        logging.getLogger(__name__).exception("classification failed")
        raise HTTPException(status_code=500, detail="classification failed") from e

    latency_s = time.perf_counter() - start
    # The first entry is treated as the winner. This relies on the pipeline
    # returning scores sorted highest first.
    top = results[0]

    REQUESTS.labels(label=top["label"]).inc()
    LATENCY.observe(latency_s)

    return ClassifyResponse(
        label=top["label"],
        confidence=round(top["score"], 4),
        latency_ms=round(latency_s * 1000, 2),
        scores={r["label"]: round(r["score"], 4) for r in results},
    )


@app.get("/health")
def health():
    """Readiness probe: return 200 only once the classifier is in ``model_store``.

    Load balancers call this to decide whether to send traffic here. It
    checks that the model is actually loaded, not merely that the process
    is alive. For example, it returns 503 after shutdown has cleared the
    store.

    Returns:
        ``{"status": "ok", "model": <model id>}``.

    Raises:
        HTTPException: 503 if the classifier is not loaded.
    """
    classifier = model_store.get("classifier")
    if classifier is None:
        raise HTTPException(status_code=503, detail="model not loaded")

    # Report the id the pipeline was actually loaded from. Reading only
    # HF_MODEL_ID reported "unknown" whenever the variable was unset, even
    # though the lifespan hook then loads its built-in default model.
    loaded_id = getattr(getattr(classifier, "model", None), "name_or_path", None)
    return {"status": "ok", "model": loaded_id or os.environ.get("HF_MODEL_ID", "unknown")}


@app.get("/metrics")
def metrics():
    """Expose all registered Prometheus metrics in the text exposition format.

    Prometheus scrapes this endpoint periodically, and Grafana reads from
    Prometheus; this is the first link in that chain.
    """
    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)