""" api.py ────── Production-ready REST API for document classification. Supports any combination of saved models via lazy loading. Usage ───── pip install fastapi uvicorn[standard] pydantic (already in requirements.txt) # Start the server (from project root, venv active) uvicorn api:app --host 0.0.0.0 --port 8000 --reload # Health check curl http://localhost:8000/health # Single prediction curl -X POST http://localhost:8000/predict \ -H "Content-Type: application/json" \ -d '{"text": "Fed raises interest rates by 50 bps", "model_name": "roberta_base"}' # Batch prediction curl -X POST http://localhost:8000/batch_predict \ -H "Content-Type: application/json" \ -d '{"texts": ["Apple unveils M5 chip", "Ronaldo scores again"], "model_name": "roberta_base"}' # Explore interactive docs at: http://localhost:8000/docs """ import logging import os import time from contextlib import asynccontextmanager from typing import Dict, List, Optional from uuid import uuid4 import numpy as np import torch from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel, Field from config import CFG import database logger = logging.getLogger("api") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") # ── Pydantic schemas ────────────────────────────────────────────────────────── class PredictRequest(BaseModel): text: str = Field(..., min_length=1, max_length=10_000, example="Apple launches a groundbreaking AI chip.") model_name: str = Field(default="roberta_base", example="roberta_base", description="Directory name in saved_models/. " "Examples: 'roberta_base', 'lr', 'svm'.") class BatchPredictRequest(BaseModel): texts: List[str] = Field(..., min_length=1, max_length=256) model_name: str = Field(default="roberta_base") class Prediction(BaseModel): text: str request_id: str label_id: int label: str probabilities: Optional[Dict[str, float]] = None is_low_confidence: bool latency_ms: float class BatchResponse(BaseModel): predictions: List[Prediction] count: int total_latency_ms: float class HealthResponse(BaseModel): status: str loaded_models: List[str] quantized: Dict[str, bool] device: str version: str = "1.0.0" # ── Model registry ──────────────────────────────────────────────────────────── _registry: Dict[str, Dict] = {} # model_name → {"obj": ..., "kind": str, "quantized": bool} def _load_model(model_name: str): """Lazy-load a model on first request, then cache it in _registry.""" # Normalise model name mapping (case-insensitive & support aliases) name_lower = model_name.lower() if name_lower in ("lr", "svm"): model_name = name_lower elif "distilbert" in name_lower: model_name = "distilbert_base_uncased" elif "roberta" in name_lower: model_name = "roberta_base" elif "bert" in name_lower: model_name = "bert_base_uncased" if model_name in _registry: entry = _registry[model_name] return entry["obj"], entry["kind"], entry["quantized"] if model_name in ("lr", "svm"): import joblib path = os.path.join(CFG.models_dir, f"traditional_{model_name}.joblib") if not os.path.exists(path): raise FileNotFoundError(f"No model file: {path}") obj = joblib.load(path) kind = "sklearn" quantized = False else: from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformer_model import _checkpoint_to_dir path = os.path.join(CFG.models_dir, model_name) if not os.path.isdir(path): alt = os.path.join(CFG.models_dir, _checkpoint_to_dir(model_name)) if os.path.isdir(alt): path = alt else: raise FileNotFoundError( f"No model directory: {path}\n" f"Hint: check saved_models/ for available directories." ) int8_path = f"{path}_int8" int8_file = os.path.join(int8_path, "model_int8.pt") if os.path.exists(int8_file): try: torch.backends.quantized.engine = "qnnpack" except Exception: pass try: model = torch.load(int8_file, map_location="cpu", weights_only=False) except TypeError: model = torch.load(int8_file, map_location="cpu") tokenizer = AutoTokenizer.from_pretrained(int8_path) quantized = True else: model = AutoModelForSequenceClassification.from_pretrained(path) tokenizer = AutoTokenizer.from_pretrained(path) quantized = False model.eval() obj = (model, tokenizer) kind = "transformer" _registry[model_name] = {"obj": obj, "kind": kind, "quantized": quantized} q = "int8" if quantized else "fp32" logger.info(f"Model cached: {model_name} [{kind}:{q}]") return obj, kind, quantized def _infer_single(text: str, obj, kind: str) -> Dict: if kind == "transformer": model, tokenizer = obj enc = tokenizer(text, truncation=True, max_length=CFG.max_length, return_tensors="pt") with torch.no_grad(): probs = torch.softmax(model(**enc).logits[0], dim=-1).numpy() pred_id = int(np.argmax(probs)) conf = float(np.max(probs)) return { "label_id": pred_id, "label": CFG.label_names[pred_id], "probabilities": { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) }, "confidence": conf, } # sklearn pred_id = int(obj.predict([text])[0]) result = {"label_id": pred_id, "label": CFG.label_names[pred_id], "probabilities": None, "confidence": 1.0} clf = list(obj.named_steps.values())[-1] if hasattr(clf, "predict_proba"): probs = obj.predict_proba([text])[0] result["probabilities"] = { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) } result["confidence"] = float(np.max(probs)) elif hasattr(clf, "decision_function"): scores = obj.decision_function([text]) scores = np.asarray(scores, dtype=np.float64).reshape(1, -1) scores = scores - np.max(scores, axis=1, keepdims=True) exps = np.exp(scores) probs = exps / np.sum(exps, axis=1, keepdims=True) result["confidence"] = float(np.max(probs)) return result def _infer_batch(texts: List[str], obj, kind: str) -> List[Dict]: if kind == "transformer": model, tokenizer = obj results = [] batch_size = 16 for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] enc = tokenizer(batch, truncation=True, max_length=CFG.max_length, padding=True, return_tensors="pt") with torch.no_grad(): logits = model(**enc).logits probs_batch = torch.softmax(logits, dim=-1).numpy() for text, probs in zip(batch, probs_batch): pred_id = int(np.argmax(probs)) conf = float(np.max(probs)) results.append({ "label_id": pred_id, "label": CFG.label_names[pred_id], "probabilities": { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) }, "confidence": conf, "text": text, }) return results # sklearn batch preds = obj.predict(texts) clf = list(obj.named_steps.values())[-1] confidences = np.ones(len(texts), dtype=np.float64) if hasattr(clf, "predict_proba"): probs = obj.predict_proba(texts) confidences = np.max(probs, axis=1) elif hasattr(clf, "decision_function"): scores = obj.decision_function(texts) scores = np.asarray(scores, dtype=np.float64) if scores.ndim == 1: scores = np.stack([-scores, scores], axis=1) scores = scores - np.max(scores, axis=1, keepdims=True) exps = np.exp(scores) probs = exps / np.sum(exps, axis=1, keepdims=True) confidences = np.max(probs, axis=1) results = [] for p, t, c in zip(preds, texts, confidences): results.append( { "label_id": int(p), "label": CFG.label_names[int(p)], "probabilities": None, "confidence": float(c), "text": t, } ) return results # ── FastAPI app ─────────────────────────────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): """Pre-warm the default model on server startup.""" try: database.init_db() _load_model("roberta_base") logger.info("Default model (roberta_base) pre-loaded.") except FileNotFoundError: logger.warning("Default model not found; will load on first request.") yield _registry.clear() logger.info("Model registry cleared.") app = FastAPI( title="Document Classifier API", description=( "Multi-class news text classification over four categories: " "World · Sports · Business · Sci/Tech. " "Supports traditional ML and transformer models." ), version="1.0.0", lifespan=lifespan, ) from fastapi.middleware.cors import CORSMiddleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.middleware("http") async def add_private_network_header(request: Request, call_next): response = await call_next(request) if "access-control-request-private-network" in request.headers: response.headers["Access-Control-Allow-Private-Network"] = "true" origin = request.headers.get("origin") if origin: response.headers["Access-Control-Allow-Origin"] = origin return response @app.get("/health", response_model=HealthResponse, tags=["Status"]) async def health(): """Confirm the API is running, list loaded models, and report device.""" return HealthResponse( status="ok", loaded_models=list(_registry.keys()), quantized={k: bool(v.get("quantized")) for k, v in _registry.items()}, device=CFG.device, ) @app.get("/labels", tags=["Status"]) async def get_labels(): """Return the four classification labels.""" return { "labels": [ {"id": i, "name": n} for i, n in enumerate(CFG.label_names) ] } @app.get("/models", tags=["Status"]) async def list_available_models(): """List all models that exist in saved_models/ and are ready to load.""" available = [] if os.path.isdir(CFG.models_dir): for name in os.listdir(CFG.models_dir): path = os.path.join(CFG.models_dir, name) if name.endswith("_int8"): continue if os.path.isdir(path) and os.path.exists( os.path.join(path, "config.json") ): int8_file = os.path.join(f"{path}_int8", "model_int8.pt") available.append( { "name": name, "type": "transformer", "quantized": bool(os.path.exists(int8_file)), } ) for fname in os.listdir(CFG.models_dir): if fname.startswith("traditional_") and fname.endswith(".joblib"): short = fname.replace("traditional_", "").replace(".joblib", "") available.append({"name": short, "type": "sklearn", "quantized": False}) return {"models": available, "count": len(available)} @app.post("/predict", response_model=Prediction, tags=["Inference"]) async def predict(req: PredictRequest): """Classify a single text document and return label + probabilities.""" t0 = time.perf_counter() request_id = str(uuid4()) try: obj, kind, _ = _load_model(req.model_name) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) result = _infer_single(req.text, obj, kind) latency = (time.perf_counter() - t0) * 1000 confidence = float(result.get("confidence", 1.0)) is_low = bool(confidence < float(CFG.low_confidence_threshold)) database.log_request( request_id=request_id, model_name=req.model_name, input_text=req.text, predicted_label=str(result["label"]), predicted_label_id=int(result["label_id"]), confidence=confidence, latency_ms=float(latency), is_batch=False, ) return Prediction( text=req.text[:200], request_id=request_id, is_low_confidence=is_low, latency_ms=round(latency, 2), label_id=result["label_id"], label=result["label"], probabilities=result.get("probabilities"), ) @app.post("/batch_predict", response_model=BatchResponse, tags=["Inference"]) async def batch_predict(req: BatchPredictRequest): """Classify a list of documents in one call (up to 256 texts).""" t0 = time.perf_counter() try: obj, kind, _ = _load_model(req.model_name) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) raw_results = _infer_batch(req.texts, obj, kind) total_ms = (time.perf_counter() - t0) * 1000 per_item_ms = (total_ms / len(req.texts)) if req.texts else 0.0 predictions = [ Prediction( text=r["text"][:200], request_id=str(uuid4()), label_id=r["label_id"], label=r["label"], probabilities=r.get("probabilities"), is_low_confidence=bool(float(r.get("confidence", 1.0)) < float(CFG.low_confidence_threshold)), latency_ms=round(per_item_ms, 2), ) for r in raw_results ] for r, pred in zip(raw_results, predictions): database.log_request( request_id=pred.request_id, model_name=req.model_name, input_text=r["text"], predicted_label=str(r["label"]), predicted_label_id=int(r["label_id"]), confidence=float(r.get("confidence", 1.0)), latency_ms=float(per_item_ms), is_batch=True, ) return BatchResponse( predictions=predictions, count=len(predictions), total_latency_ms=round(total_ms, 2), ) @app.get("/analytics/summary", tags=["Analytics"]) async def analytics_summary(model_name: Optional[str] = None, days: int = 7): return database.get_summary(model_name=model_name, days=days) @app.get("/analytics/history", tags=["Analytics"]) async def analytics_history(limit: int = 50, offset: int = 0): return database.get_request_history(limit=limit, offset=offset) @app.get("/analytics/low_confidence", tags=["Analytics"]) async def analytics_low_confidence(reviewed: bool = False, limit: int = 50): return database.get_low_confidence_flags(reviewed=reviewed, limit=limit) class ReviewBody(BaseModel): note: Optional[str] = None @app.patch("/analytics/review/{request_id}", tags=["Analytics"]) async def analytics_mark_reviewed(request_id: str, body: ReviewBody): database.mark_reviewed(request_id=request_id, note=body.note) return {"request_id": request_id, "reviewed": True} @app.post("/analytics/export_flags", tags=["Analytics"]) async def analytics_export_flags(): return database.export_low_confidence_to_folder()