Spaces:
Running
Running
| """ | |
| api.py | |
| ────── | |
| Production-ready REST API for document classification. | |
| Supports any combination of saved models via lazy loading. | |
| Usage | |
| ───── | |
| pip install fastapi uvicorn[standard] pydantic (already in requirements.txt) | |
| # Start the server (from project root, venv active) | |
| uvicorn api:app --host 0.0.0.0 --port 8000 --reload | |
| # Health check | |
| curl http://localhost:8000/health | |
| # Single prediction | |
| curl -X POST http://localhost:8000/predict \ | |
| -H "Content-Type: application/json" \ | |
| -d '{"text": "Fed raises interest rates by 50 bps", "model_name": "roberta_base"}' | |
| # Batch prediction | |
| curl -X POST http://localhost:8000/batch_predict \ | |
| -H "Content-Type: application/json" \ | |
| -d '{"texts": ["Apple unveils M5 chip", "Ronaldo scores again"], "model_name": "roberta_base"}' | |
| # Explore interactive docs at: http://localhost:8000/docs | |
| """ | |
| import logging | |
| import os | |
| import time | |
| from contextlib import asynccontextmanager | |
| from typing import Dict, List, Optional | |
| from uuid import uuid4 | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, HTTPException, Request | |
| from pydantic import BaseModel, Field | |
| from config import CFG | |
| import database | |
# Configure root logging once at import time, then grab this module's logger.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("api")
| # ── Pydantic schemas ────────────────────────────────────────────────────────── | |
# Request body for classifying a single document.
class PredictRequest(BaseModel):
    # Document to classify; validated to be between 1 and 10,000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=10_000,
        example="Apple launches a groundbreaking AI chip.",
    )
    # Which saved model to use; "roberta_base" when the client omits it.
    model_name: str = Field(
        default="roberta_base",
        example="roberta_base",
        description=(
            "Directory name in saved_models/. "
            "Examples: 'roberta_base', 'lr', 'svm'."
        ),
    )
# Request body for classifying several documents in one call.
class BatchPredictRequest(BaseModel):
    # Documents to classify; the Field constraints are min_length=1 / max_length=256.
    # NOTE(review): individual texts carry no length bounds here, unlike
    # PredictRequest.text (1..10_000 chars) — confirm whether that is intended.
    texts: List[str] = Field(
        ...,
        min_length=1,
        max_length=256,
    )
    # Which saved model to use; "roberta_base" when the client omits it.
    model_name: str = Field(default="roberta_base")
# Response payload describing the classification of one document.
class Prediction(BaseModel):
    # Input text echoed back (callers in this module pass the first 200 chars).
    text: str
    # Identifier generated per prediction; the same id is passed to the request log.
    request_id: str
    # Predicted class index and its name from CFG.label_names.
    label_id: int
    label: str
    # Per-label probabilities, or None when the model provides none.
    probabilities: Optional[Dict[str, float]] = None
    # True when confidence fell below CFG.low_confidence_threshold.
    is_low_confidence: bool
    # Latency in milliseconds attributed to this prediction.
    latency_ms: float
# Response payload for a batch classification call.
class BatchResponse(BaseModel):
    # One Prediction per input text.
    predictions: List[Prediction]
    # Number of predictions returned.
    count: int
    # Latency of the whole batch in milliseconds.
    total_latency_ms: float
# Response payload for the health check.
class HealthResponse(BaseModel):
    # Status string reported by the health handler.
    status: str
    # Names of the models currently cached in the registry.
    loaded_models: List[str]
    # Per-model flag: True when the cached model is the quantized variant.
    quantized: Dict[str, bool]
    # Device name taken from the project configuration.
    device: str
    # API version string.
    version: str = "1.0.0"
| # ── Model registry ──────────────────────────────────────────────────────────── | |
# Process-wide cache of loaded models, keyed by normalised model name.
# Each value has the shape {"obj": ..., "kind": str, "quantized": bool}.
_registry: Dict[str, Dict] = dict()
def _is_inside_dir(base_dir: str, candidate: str) -> bool:
    """Return True when *candidate* normalises to a path strictly inside *base_dir*."""
    base = os.path.abspath(base_dir)
    target = os.path.abspath(candidate)
    try:
        return target != base and os.path.commonpath([base, target]) == base
    except ValueError:
        # os.path.commonpath raises when the paths cannot be compared
        # (e.g. different drives on Windows).
        return False


def _load_model(model_name: str):
    """Lazy-load a model on first request, then cache it in ``_registry``.

    The requested name is first normalised: ``lr``/``svm`` are matched
    case-insensitively, and any name containing ``distilbert``, ``roberta``
    or ``bert`` (checked in that order) is mapped to the corresponding
    default directory name. Other names are used unchanged.

    Args:
        model_name: Model identifier, typically taken from the request body.

    Returns:
        A ``(obj, kind, quantized)`` tuple. For ``kind == "sklearn"`` ``obj``
        is whatever ``joblib.load`` returned; for ``kind == "transformer"``
        it is a ``(model, tokenizer)`` pair. ``quantized`` is True when the
        ``*_int8`` checkpoint was loaded.

    Raises:
        FileNotFoundError: If no matching model file/directory exists under
            ``CFG.models_dir`` (including names that would resolve outside
            of that directory).
    """
    # Normalise model name mapping (case-insensitive & support aliases).
    name_lower = model_name.lower()
    if name_lower in ("lr", "svm"):
        model_name = name_lower
    elif "distilbert" in name_lower:
        model_name = "distilbert_base_uncased"
    elif "roberta" in name_lower:
        model_name = "roberta_base"
    elif "bert" in name_lower:
        model_name = "bert_base_uncased"

    cached = _registry.get(model_name)
    if cached is not None:
        return cached["obj"], cached["kind"], cached["quantized"]

    if model_name in ("lr", "svm"):
        import joblib

        path = os.path.join(CFG.models_dir, f"traditional_{model_name}.joblib")
        if not os.path.exists(path):
            raise FileNotFoundError(f"No model file: {path}")
        obj = joblib.load(path)
        kind = "sklearn"
        quantized = False
    else:
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        from transformer_model import _checkpoint_to_dir

        # model_name can come straight from a request, so only accept
        # directories that really live under CFG.models_dir ("..", absolute
        # paths and the models directory itself are treated as "not found").
        path = os.path.join(CFG.models_dir, model_name)
        if not (_is_inside_dir(CFG.models_dir, path) and os.path.isdir(path)):
            alt = os.path.join(CFG.models_dir, _checkpoint_to_dir(model_name))
            if _is_inside_dir(CFG.models_dir, alt) and os.path.isdir(alt):
                path = alt
            else:
                raise FileNotFoundError(
                    f"No model directory: {path}\n"
                    "Hint: check saved_models/ for available directories."
                )

        # Prefer the quantized sibling checkpoint "<path>_int8" when present.
        int8_path = f"{path}_int8"
        int8_file = os.path.join(int8_path, "model_int8.pt")
        if os.path.exists(int8_file):
            try:
                torch.backends.quantized.engine = "qnnpack"
            except Exception as exc:
                # Best effort: keep going with the current engine, but leave
                # a trace instead of hiding the failure.
                logger.warning("Could not select the qnnpack quantized engine: %s", exc)
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints that this project produced itself.
            try:
                model = torch.load(int8_file, map_location="cpu", weights_only=False)
            except TypeError:
                # torch.load rejected the weights_only keyword.
                model = torch.load(int8_file, map_location="cpu")
            tokenizer = AutoTokenizer.from_pretrained(int8_path)
            quantized = True
        else:
            model = AutoModelForSequenceClassification.from_pretrained(path)
            tokenizer = AutoTokenizer.from_pretrained(path)
            quantized = False
        model.eval()
        obj = (model, tokenizer)
        kind = "transformer"

    _registry[model_name] = {"obj": obj, "kind": kind, "quantized": quantized}
    q = "int8" if quantized else "fp32"
    logger.info("Model cached: %s [%s:%s]", model_name, kind, q)
    return obj, kind, quantized
def _infer_single(text: str, obj, kind: str) -> Dict:
    """Classify one text with an already-loaded model.

    Args:
        text: Document to classify.
        obj: ``(model, tokenizer)`` when ``kind == "transformer"``; otherwise
            an object exposing ``predict`` and ``named_steps``.
        kind: ``"transformer"`` selects the transformer path; any other value
            takes the sklearn path.

    Returns:
        A dict with ``label_id``, ``label``, ``probabilities`` (dict of label
        name to probability rounded to 4 places, or None) and ``confidence``.
        Confidence is the top probability when available, the top softmax of
        the decision scores otherwise, and 1.0 when the classifier exposes
        neither ``predict_proba`` nor ``decision_function``.
    """
    label_names = CFG.label_names

    if kind == "transformer":
        model, tokenizer = obj
        enc = tokenizer(text, truncation=True,
                        max_length=CFG.max_length, return_tensors="pt")
        with torch.no_grad():
            probs = torch.softmax(model(**enc).logits[0], dim=-1).numpy()
        pred_id = int(np.argmax(probs))
        return {
            "label_id": pred_id,
            "label": label_names[pred_id],
            "probabilities": {
                label_names[idx]: round(float(p), 4)
                for idx, p in enumerate(probs)
            },
            "confidence": float(np.max(probs)),
        }

    # sklearn
    pred_id = int(obj.predict([text])[0])
    result = {"label_id": pred_id, "label": label_names[pred_id],
              "probabilities": None, "confidence": 1.0}
    # The final pipeline step decides which scoring API is available.
    clf = list(obj.named_steps.values())[-1]
    if hasattr(clf, "predict_proba"):
        probs = obj.predict_proba([text])[0]
        result["probabilities"] = {
            label_names[idx]: round(float(p), 4) for idx, p in enumerate(probs)
        }
        result["confidence"] = float(np.max(probs))
    elif hasattr(clf, "decision_function"):
        scores = np.asarray(obj.decision_function([text]), dtype=np.float64)
        if scores.ndim == 1:
            # One signed margin per sample: expand to two columns so the
            # softmax below is meaningful (a single column would always give
            # confidence 1.0). Mirrors the handling in _infer_batch.
            scores = np.stack([-scores, scores], axis=1)
        # Subtract the row maximum before exponentiating for numerical stability.
        scores = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(scores)
        probs = exps / np.sum(exps, axis=1, keepdims=True)
        result["confidence"] = float(np.max(probs))
    return result
def _infer_batch(texts: List[str], obj, kind: str) -> List[Dict]:
    """Classify several texts with an already-loaded model.

    Args:
        texts: Documents to classify.
        obj: ``(model, tokenizer)`` when ``kind == "transformer"``; otherwise
            an object exposing ``predict`` and ``named_steps``.
        kind: ``"transformer"`` selects the transformer path; any other value
            takes the sklearn path.

    Returns:
        One dict per input text, in input order, with ``label_id``, ``label``,
        ``probabilities`` (dict of label name to probability rounded to 4
        places, or None when the model offers no ``predict_proba``),
        ``confidence`` and the original ``text``.
    """
    if not texts:
        return []

    label_names = CFG.label_names

    if kind == "transformer":
        model, tokenizer = obj
        results = []
        batch_size = 16  # texts per forward pass
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            enc = tokenizer(chunk, truncation=True, max_length=CFG.max_length,
                            padding=True, return_tensors="pt")
            with torch.no_grad():
                logits = model(**enc).logits
                probs_batch = torch.softmax(logits, dim=-1).numpy()
            for text, probs in zip(chunk, probs_batch):
                pred_id = int(np.argmax(probs))
                results.append({
                    "label_id": pred_id,
                    "label": label_names[pred_id],
                    "probabilities": {
                        label_names[idx]: round(float(p), 4)
                        for idx, p in enumerate(probs)
                    },
                    "confidence": float(np.max(probs)),
                    "text": text,
                })
        return results

    # sklearn batch
    preds = obj.predict(texts)
    # The final pipeline step decides which scoring API is available.
    clf = list(obj.named_steps.values())[-1]
    confidences = np.ones(len(texts), dtype=np.float64)
    prob_rows = None
    if hasattr(clf, "predict_proba"):
        prob_rows = obj.predict_proba(texts)
        confidences = np.max(prob_rows, axis=1)
    elif hasattr(clf, "decision_function"):
        scores = np.asarray(obj.decision_function(texts), dtype=np.float64)
        if scores.ndim == 1:
            # One signed margin per sample: expand to two columns.
            scores = np.stack([-scores, scores], axis=1)
        # Subtract the row maximum before exponentiating for numerical stability.
        scores = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(scores)
        confidences = np.max(exps / np.sum(exps, axis=1, keepdims=True), axis=1)

    results = []
    for row, (pred, text, conf) in enumerate(zip(preds, texts, confidences)):
        # Report per-label probabilities when the model provides them, as
        # _infer_single does.
        probabilities = None
        if prob_rows is not None:
            probabilities = {
                label_names[idx]: round(float(p), 4)
                for idx, p in enumerate(prob_rows[row])
            }
        results.append({
            "label_id": int(pred),
            "label": label_names[int(pred)],
            "probabilities": probabilities,
            "confidence": float(conf),
            "text": text,
        })
    return results
| # ── FastAPI app ─────────────────────────────────────────────────────────────── | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database and pre-warm the default model around the app's lifetime.

    On startup ``database.init_db()`` is called and ``roberta_base`` is loaded
    into the registry; a ``FileNotFoundError`` from either step is logged as a
    warning and startup continues. On shutdown the model registry is cleared.
    """
    try:
        database.init_db()
        _load_model("roberta_base")
        logger.info("Default model (roberta_base) pre-loaded.")
    except FileNotFoundError:
        logger.warning("Default model not found; will load on first request.")
    try:
        yield
    finally:
        # Drop cached models even when the application exits with an error.
        _registry.clear()
        logger.info("Model registry cleared.")
# Application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Document Classifier API",
    description=(
        "Multi-class news text classification over four categories: "
        "World · Sports · Business · Sci/Tech. "
        "Supports traditional ML and transformer models."
    ),
)

from fastapi.middleware.cors import CORSMiddleware

# NOTE(review): every origin, method and header is allowed and credentials are
# enabled — confirm this is acceptable for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@app.middleware("http")
async def add_private_network_header(request: Request, call_next):
    """Answer private-network access requests on every HTTP response.

    When the incoming request carries an
    ``Access-Control-Request-Private-Network`` header, the response gets
    ``Access-Control-Allow-Private-Network: true`` and, if the request has an
    ``Origin`` header, that origin is echoed in
    ``Access-Control-Allow-Origin``.

    Args:
        request: Incoming request.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response, with the headers above added when applicable.
    """
    response = await call_next(request)
    if "access-control-request-private-network" in request.headers:
        response.headers["Access-Control-Allow-Private-Network"] = "true"
        # NOTE(review): the Origin echo is scoped to private-network requests
        # here; confirm that matches the intended nesting.
        origin = request.headers.get("origin")
        if origin:
            response.headers["Access-Control-Allow-Origin"] = origin
    return response
@app.get("/health", response_model=HealthResponse)
async def health():
    """Confirm the API is running, list loaded models, and report device.

    Returns:
        A ``HealthResponse`` with status ``"ok"``, the names currently cached
        in the model registry, each one's quantized flag, and ``CFG.device``.
    """
    return HealthResponse(
        status="ok",
        loaded_models=list(_registry),
        quantized={name: bool(entry.get("quantized")) for name, entry in _registry.items()},
        device=CFG.device,
    )
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def get_labels():
    """Return every classification label as an ``{"id", "name"}`` entry.

    Ids are the positions of the names in ``CFG.label_names``.
    """
    labels = []
    for idx, label_name in enumerate(CFG.label_names):
        labels.append({"id": idx, "name": label_name})
    return {"labels": labels}
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def list_available_models():
    """List all models that exist in saved_models/ and are ready to load.

    A sub-directory of ``CFG.models_dir`` counts as a transformer model when it
    contains ``config.json`` and its name does not end in ``_int8``; it is
    reported as quantized when ``<dir>_int8/model_int8.pt`` exists. A file
    named ``traditional_<name>.joblib`` counts as an sklearn model called
    ``<name>``.

    Returns:
        ``{"models": [...], "count": int}`` with transformer entries first,
        then sklearn entries, each group in name order. Empty when the models
        directory does not exist.
    """
    models_dir = CFG.models_dir
    prefix, suffix = "traditional_", ".joblib"
    transformer_models = []
    sklearn_models = []

    if os.path.isdir(models_dir):
        # Single, sorted directory scan so the listing is deterministic.
        for name in sorted(os.listdir(models_dir)):
            path = os.path.join(models_dir, name)
            if (
                not name.endswith("_int8")
                and os.path.isdir(path)
                and os.path.exists(os.path.join(path, "config.json"))
            ):
                int8_file = os.path.join(f"{path}_int8", "model_int8.pt")
                transformer_models.append(
                    {
                        "name": name,
                        "type": "transformer",
                        "quantized": os.path.exists(int8_file),
                    }
                )
            if name.startswith(prefix) and name.endswith(suffix):
                # Strip only the leading prefix and trailing suffix; the
                # middle of the name is kept verbatim.
                short = name[len(prefix):-len(suffix)]
                sklearn_models.append(
                    {"name": short, "type": "sklearn", "quantized": False}
                )

    available = transformer_models + sklearn_models
    return {"models": available, "count": len(available)}
@app.post("/predict", response_model=Prediction)
async def predict(req: PredictRequest):
    """Classify a single text document and return label + probabilities.

    The measured latency covers model lookup/loading and inference. Every
    call is recorded through ``database.log_request`` with the full input
    text; the response echoes only the first 200 characters.

    Args:
        req: Validated request body.

    Returns:
        A ``Prediction`` whose ``is_low_confidence`` flag is set when the
        confidence is below ``CFG.low_confidence_threshold``.

    Raises:
        HTTPException: 404 when the requested model cannot be found.
    """
    t0 = time.perf_counter()
    request_id = str(uuid4())
    try:
        obj, kind, _ = _load_model(req.model_name)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    result = _infer_single(req.text, obj, kind)
    latency = (time.perf_counter() - t0) * 1000
    confidence = float(result.get("confidence", 1.0))
    is_low = confidence < float(CFG.low_confidence_threshold)

    database.log_request(
        request_id=request_id,
        model_name=req.model_name,
        input_text=req.text,
        predicted_label=str(result["label"]),
        predicted_label_id=int(result["label_id"]),
        confidence=confidence,
        latency_ms=float(latency),
        is_batch=False,
    )
    return Prediction(
        text=req.text[:200],
        request_id=request_id,
        is_low_confidence=is_low,
        latency_ms=round(latency, 2),
        label_id=result["label_id"],
        label=result["label"],
        probabilities=result.get("probabilities"),
    )
@app.post("/batch_predict", response_model=BatchResponse)
async def batch_predict(req: BatchPredictRequest):
    """Classify a list of documents in one call (up to 256 texts).

    Each item receives its own generated ``request_id`` and is recorded
    through ``database.log_request`` with the full input text. The per-item
    latency is the batch total divided evenly across the inputs.

    Args:
        req: Validated request body.

    Returns:
        A ``BatchResponse`` with one ``Prediction`` per input text.

    Raises:
        HTTPException: 404 when the requested model cannot be found.
    """
    t0 = time.perf_counter()
    try:
        obj, kind, _ = _load_model(req.model_name)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    raw_results = _infer_batch(req.texts, obj, kind)
    total_ms = (time.perf_counter() - t0) * 1000
    per_item_ms = (total_ms / len(req.texts)) if req.texts else 0.0
    threshold = float(CFG.low_confidence_threshold)

    predictions = []
    for r in raw_results:
        confidence = float(r.get("confidence", 1.0))
        pred = Prediction(
            text=r["text"][:200],
            request_id=str(uuid4()),
            label_id=r["label_id"],
            label=r["label"],
            probabilities=r.get("probabilities"),
            is_low_confidence=confidence < threshold,
            latency_ms=round(per_item_ms, 2),
        )
        predictions.append(pred)

    # Log only after every Prediction has been built successfully.
    for r, pred in zip(raw_results, predictions):
        database.log_request(
            request_id=pred.request_id,
            model_name=req.model_name,
            input_text=r["text"],
            predicted_label=str(r["label"]),
            predicted_label_id=int(r["label_id"]),
            confidence=float(r.get("confidence", 1.0)),
            latency_ms=float(per_item_ms),
            is_batch=True,
        )

    return BatchResponse(
        predictions=predictions,
        count=len(predictions),
        total_latency_ms=round(total_ms, 2),
    )
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def analytics_summary(model_name: Optional[str] = None, days: int = 7):
    """Return whatever ``database.get_summary`` yields for the given filters.

    Args:
        model_name: Optional model filter, forwarded unchanged.
        days: Forwarded unchanged; presumably a look-back window in days —
            confirm against ``database.get_summary``.
    """
    summary = database.get_summary(days=days, model_name=model_name)
    return summary
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def analytics_history(limit: int = 50, offset: int = 0):
    """Return whatever ``database.get_request_history`` yields.

    Args:
        limit: Forwarded unchanged as ``limit``.
        offset: Forwarded unchanged as ``offset``.
    """
    history = database.get_request_history(offset=offset, limit=limit)
    return history
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def analytics_low_confidence(reviewed: bool = False, limit: int = 50):
    """Return whatever ``database.get_low_confidence_flags`` yields.

    Args:
        reviewed: Forwarded unchanged as ``reviewed``.
        limit: Forwarded unchanged as ``limit``.
    """
    flags = database.get_low_confidence_flags(limit=limit, reviewed=reviewed)
    return flags
# Request body accepted when marking a logged request as reviewed.
class ReviewBody(BaseModel):
    # Optional free-text note; passed to database.mark_reviewed as `note`.
    note: Optional[str] = None
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def analytics_mark_reviewed(request_id: str, body: ReviewBody):
    """Mark a logged request as reviewed via ``database.mark_reviewed``.

    Args:
        request_id: Identifier of the logged request.
        body: Payload carrying the optional review note.

    Returns:
        ``{"request_id": <request_id>, "reviewed": True}``.
    """
    note = body.note
    database.mark_reviewed(note=note, request_id=request_id)
    acknowledgement = {"request_id": request_id, "reviewed": True}
    return acknowledgement
# NOTE(review): no route decorator is visible for this handler — confirm how
# it is registered with `app`.
async def analytics_export_flags():
    """Return whatever ``database.export_low_confidence_to_folder`` yields."""
    export_result = database.export_low_confidence_to_folder()
    return export_result