File size: 5,596 Bytes
0083b07 565cb89 7dfa257 0083b07 565cb89 0083b07 565cb89 0083b07 7dfa257 0083b07 565cb89 0083b07 565cb89 0083b07 565cb89 0083b07 565cb89 0083b07 565cb89 0083b07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import os
import time
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from pydantic import BaseModel
from transformers import pipeline
# Populate os.environ from ".env" (a relative path, so it is resolved against
# the current working directory) at import time, i.e. before lifespan() or
# /health read HF_MODEL_ID.
load_dotenv(".env")
# ── Prometheus metrics ─────────────────────────────────────────────────────────
# Counter: monotonically increasing. Good for "how many times did X happen?"
# Histogram: counts observations per bucket. Good for latency — Prometheus can
# estimate p50/p95/p99 from the bucket counts (histogram_quantile).
# Successful classifications, split by the label that won.
REQUESTS = Counter(
    name="classify_requests_total",
    documentation="Total classification requests",
    # One child series per predicted label — lets you see the label distribution.
    labelnames=["label"],
)

# Time spent per classification, bucketed for quantile estimation.
LATENCY = Histogram(
    name="classify_latency_seconds",
    documentation="End-to-end classification latency",
    # Upper bucket bounds in seconds; prometheus_client appends +Inf itself.
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)

# Failed classifications (no label dimension).
ERRORS = Counter(
    name="classify_errors_total",
    documentation="Total classification errors",
)
# ── Model store ────────────────────────────────────────────────────────────────
# A plain dict used as a module-level container for the loaded model.
# This is the "kept in memory" part — the model is loaded once at startup and
# stays here, shared by every request, until shutdown clears it.
# Filled by lifespan() at startup ("classifier" -> the transformers pipeline)
# and emptied again at shutdown; the endpoints look the model up here.
model_store: dict = dict()
# ── Lifespan ───────────────────────────────────────────────────────────────────
# FastAPI's lifespan runs the code before `yield` at startup and the code after
# `yield` at shutdown — roughly @PostConstruct / @PreDestroy in Spring Boot.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the text-classification pipeline at startup and drop it at shutdown.

    The model id comes from the HF_MODEL_ID environment variable, falling back
    to "pranavsagar10/content-classifier-distilbert". The pipeline is stored in
    ``model_store["classifier"]``, where the request handlers look it up.

    Args:
        app: The FastAPI application (required by the lifespan signature; unused).
    """
    model_id = os.environ.get("HF_MODEL_ID", "pranavsagar10/content-classifier-distilbert")
    print(f"Loading model: {model_id}")
    model_store["classifier"] = pipeline(
        "text-classification",
        model=model_id,
        device=-1,  # -1 = CPU. For serving we use CPU — MPS/CUDA is for training.
    )
    print("Model loaded and ready.")
    try:
        yield  # app runs here, handling requests
    finally:
        # If the server exits via an exception or cancellation, it is thrown
        # into this generator at `yield`. `finally` guarantees the model
        # reference is still dropped in that case.
        model_store.clear()
        print("Model unloaded.")
# ── App ────────────────────────────────────────────────────────────────────────
# The ASGI application; `lifespan` loads the model before requests are handled.
app = FastAPI(
    lifespan=lifespan,
    title="Content Intelligence API",
    version="1.0.0",
    description="Classifies news text into World / Sports / Business / Sci/Tech",
)

# Permissive CORS: browsers on any origin may call the API with any headers.
# GET/POST cover the endpoints below; OPTIONS is included for preflight requests.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_origins=["*"],
)
# ── Request / Response schemas ─────────────────────────────────────────────────
# Pydantic models validate the request body automatically.
# If `text` is missing or not a string, FastAPI returns a 422 before your code runs.
class ClassifyRequest(BaseModel):
    # JSON body of POST /classify. (Comments, not a docstring: a docstring
    # would become the schema description in the generated OpenAPI docs.)
    # Whitespace-only text passes validation here; classify() rejects it
    # with a 422 itself.
    text: str  # the news text to classify
class ClassifyResponse(BaseModel):
    # JSON body returned by POST /classify.
    label: str  # the top predicted label
    confidence: float  # score of `label`, rounded to 4 decimals
    latency_ms: float  # time taken by the model call, in ms, rounded to 2 decimals
    # Score for every label, rounded to 4 decimals, in the order the pipeline
    # returned them (presumably highest score first — all 4 classes).
    scores: dict[str, float]
# ── Endpoints ──────────────────────────────────────────────────────────────────
@app.get("/", include_in_schema=False)
def root():
    # Hidden from the schema: bounce visitors of the bare root URL to the
    # interactive API docs FastAPI serves at /docs.
    docs_url = "/docs"
    return RedirectResponse(docs_url)
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest):
    # Classify `req.text` and return the top label plus a score for every label.
    #
    # Error responses:
    #   422 - `text` is empty or whitespace-only.
    #   503 - no classifier in model_store (the same condition /health reports);
    #         counted in ERRORS.
    #   500 - the pipeline call raised; counted in ERRORS, and the exception
    #         text is returned as the detail.
    if not req.text.strip():
        raise HTTPException(status_code=422, detail="text cannot be empty")

    # Look the model up outside the try. A missing model means "not ready"
    # (503), not an inference failure. Previously the KeyError was caught
    # below and surfaced as a 500 with detail "'classifier'".
    classifier = model_store.get("classifier")
    if classifier is None:
        ERRORS.inc()
        raise HTTPException(status_code=503, detail="model not loaded")

    start = time.perf_counter()
    try:
        # top_k=None requests every label's score. truncation/max_length limit
        # the tokenized input to 128 tokens.
        results = classifier(req.text, truncation=True, max_length=128, top_k=None)
    except Exception as e:
        ERRORS.inc()
        # Chain the cause so the original exception is preserved for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    latency_s = time.perf_counter() - start

    # Assumes the pipeline returns labels sorted by score, highest first, so
    # results[0] is the winning label.
    top = results[0]
    REQUESTS.labels(label=top["label"]).inc()
    LATENCY.observe(latency_s)
    return ClassifyResponse(
        label=top["label"],
        confidence=round(top["score"], 4),
        latency_ms=round(latency_s * 1000, 2),
        scores={r["label"]: round(r["score"], 4) for r in results},
    )
@app.get("/health")
def health():
    # Readiness check: load balancers call this to decide whether to send traffic here.
    # Returns 200 only when the model is actually loaded — not just when the
    # process is alive. Returns 503 whenever model_store has no classifier
    # (e.g. before the lifespan has loaded it, or after shutdown cleared it).
    classifier = model_store.get("classifier")
    if classifier is None:
        raise HTTPException(status_code=503, detail="model not loaded")
    # Report the model that is actually loaded. Reading HF_MODEL_ID alone gave
    # "unknown" whenever the variable was unset, even though the lifespan then
    # loads its built-in default model. `model.name_or_path` is the id the
    # weights were loaded from. Fall back to the old env lookup if it is missing.
    loaded_id = getattr(getattr(classifier, "model", None), "name_or_path", None)
    return {"status": "ok", "model": loaded_id or os.environ.get("HF_MODEL_ID", "unknown")}
@app.get("/metrics")
def metrics():
    # Prometheus scrapes this endpoint on a schedule (e.g. every 15s) and
    # Grafana reads from Prometheus, so this is where that chain starts.
    # generate_latest() serialises the default registry in the text exposition format.
    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)
|