| import os |
| import time |
| from contextlib import asynccontextmanager |
|
|
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import RedirectResponse, Response |
| from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST |
| from pydantic import BaseModel |
| from transformers import pipeline |
|
|
# Pull configuration (e.g. HF_MODEL_ID) from a local .env file into the process
# environment before anything reads os.environ.
load_dotenv(dotenv_path=".env")
|
|
| |
| |
| |
|
|
# Request counter, partitioned by the predicted label so per-class traffic
# can be graphed separately.
REQUESTS = Counter(
    "classify_requests_total",
    "Total classification requests",
    ["label"],
)
# End-to-end latency histogram; buckets span 10 ms – 1 s, the expected range
# for CPU inference on short texts.
LATENCY = Histogram(
    "classify_latency_seconds",
    "End-to-end classification latency",
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)
# Counts inference failures only (incremented in the /classify error path).
ERRORS = Counter(
    "classify_errors_total",
    "Total classification errors",
)
|
|
| |
| |
| |
| |
# Mutable holder for objects created in the lifespan handler; a module-level
# dict lets the loaded pipeline be shared across requests without `global`.
model_store: dict = {}
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the text-classification model at startup; release it at shutdown.

    The model id comes from HF_MODEL_ID, falling back to the project default.
    The loaded pipeline is published via ``model_store["classifier"]``.
    """
    model_id = os.environ.get("HF_MODEL_ID", "pranavsagar10/content-classifier-distilbert")
    print(f"Loading model: {model_id}")

    model_store["classifier"] = pipeline(
        "text-classification",
        model=model_id,
        device=-1,  # -1 = CPU
    )
    print("Model loaded and ready.")
    try:
        yield
    finally:
        # Run teardown even when the server exits via an exception; without the
        # try/finally the cleanup after `yield` would be skipped in that case.
        model_store.clear()
        print("Model unloaded.")
|
|
|
|
| |
app = FastAPI(
    title="Content Intelligence API",
    description="Classifies news text into World / Sports / Business / Sci/Tech",
    version="1.0.0",
    lifespan=lifespan,
)


# NOTE(review): allow_origins=["*"] permits any site to call this API from a
# browser. Fine for a public demo; tighten to an explicit origin list before
# exposing anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
|
|
|
|
| |
| |
| |
class ClassifyRequest(BaseModel):
    """Request body for POST /classify."""

    text: str  # raw text to classify; blank-only input is rejected with 422
|
|
class ClassifyResponse(BaseModel):
    """Response body for POST /classify."""

    label: str  # top predicted label
    confidence: float  # score of the top label, rounded to 4 decimals
    latency_ms: float  # end-to-end inference time in milliseconds
    scores: dict[str, float]  # score per label, rounded to 4 decimals
|
|
|
|
| |
@app.get("/", include_in_schema=False)
def root():
    """Send bare root hits to the interactive API docs."""
    return RedirectResponse("/docs")
|
|
|
|
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest):
    """Classify one text and return the top label, all scores, and latency.

    Raises:
        HTTPException 422: the text is empty or whitespace-only.
        HTTPException 503: the model is not loaded (startup/shutdown window).
        HTTPException 500: inference itself failed.
    """
    if not req.text.strip():
        raise HTTPException(status_code=422, detail="text cannot be empty")

    classifier = model_store.get("classifier")
    if classifier is None:
        # During startup (model still loading) or shutdown (store cleared) the
        # original dict lookup raised KeyError -> opaque 500; 503 is accurate.
        raise HTTPException(status_code=503, detail="model not loaded")

    start = time.perf_counter()

    try:
        # top_k=None yields a score for every label, best-first per the
        # transformers text-classification pipeline contract.
        results = classifier(req.text, truncation=True, max_length=128, top_k=None)
    except Exception as e:
        ERRORS.inc()
        # NOTE(review): str(e) can leak internal details to clients; consider a
        # generic message + server-side logging.
        raise HTTPException(status_code=500, detail=str(e)) from e

    latency_s = time.perf_counter() - start
    top = results[0]  # highest-scoring label

    REQUESTS.labels(label=top["label"]).inc()
    LATENCY.observe(latency_s)

    return ClassifyResponse(
        label=top["label"],
        confidence=round(top["score"], 4),
        latency_ms=round(latency_s * 1000, 2),
        scores={r["label"]: round(r["score"], 4) for r in results},
    )
|
|
|
|
@app.get("/health")
def health():
    """Readiness probe: 503 until the classifier is loaded, 200 afterwards."""
    if "classifier" not in model_store:
        raise HTTPException(status_code=503, detail="model not loaded")
    return {"status": "ok", "model": os.environ.get("HF_MODEL_ID", "unknown")}
|
|
|
|
@app.get("/metrics")
def metrics():
    """Expose Prometheus metrics in the text exposition format."""
    payload = generate_latest()
    return Response(payload, media_type=CONTENT_TYPE_LATEST)
|
|