# Source: uploaded by Prototype6239 with huggingface_hub
# ("Upload folder using huggingface_hub", commit a229747, 16.5 kB).
"""
api.py
──────
Production-ready REST API for document classification.
Supports any combination of saved models via lazy loading.
Usage
─────
pip install fastapi uvicorn[standard] pydantic (already in requirements.txt)
# Start the server (from project root, venv active)
uvicorn api:app --host 0.0.0.0 --port 8000 --reload
# Health check
curl http://localhost:8000/health
# Single prediction
curl -X POST http://localhost:8000/predict \
-H "Content-Type: application/json" \
-d '{"text": "Fed raises interest rates by 50 bps", "model_name": "roberta_base"}'
# Batch prediction
curl -X POST http://localhost:8000/batch_predict \
-H "Content-Type: application/json" \
-d '{"texts": ["Apple unveils M5 chip", "Ronaldo scores again"], "model_name": "roberta_base"}'
# Explore interactive docs at: http://localhost:8000/docs
"""
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Annotated, Dict, List, Optional
from uuid import uuid4

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field

from config import CFG
import database
# Root logging: INFO and above, "timestamp LEVEL message" lines. The module
# itself logs through the named "api" logger.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("api")
# ── Pydantic schemas ──────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Bounded so a single request cannot submit an arbitrarily large document.
    # ``examples`` (a list) replaces the pydantic-v1 style ``example=`` extra
    # kwarg, which pydantic v2 deprecates.
    text: str = Field(
        ...,
        min_length=1,
        max_length=10_000,
        description="Document text to classify (1-10,000 characters).",
        examples=["Apple launches a groundbreaking AI chip."],
    )
    # The loader maps case-insensitive aliases onto canonical names
    # (e.g. any name containing "roberta" becomes "roberta_base").
    model_name: str = Field(
        default="roberta_base",
        examples=["roberta_base"],
        description="Directory name in saved_models/. "
        "Examples: 'roberta_base', 'lr', 'svm'.",
    )
class BatchPredictRequest(BaseModel):
    """Request body for ``POST /batch_predict``."""

    # 1-256 documents. Each item carries the same 1-10,000 character bounds as
    # PredictRequest.text. Otherwise a batch could contain empty strings or
    # unbounded payloads that the single-text endpoint rejects.
    texts: List[Annotated[str, Field(min_length=1, max_length=10_000)]] = Field(
        ..., min_length=1, max_length=256
    )
    model_name: str = Field(default="roberta_base")
class Prediction(BaseModel):
    """One classification result as returned by the inference endpoints."""

    # Input text echoed back (the endpoints truncate it to 200 characters).
    text: str
    # uuid4 string; also used as the key when the request is logged.
    request_id: str
    # Index into CFG.label_names and the matching label string.
    label_id: int
    label: str
    # Per-label probabilities rounded to 4 decimals; may be None (e.g. for
    # sklearn models that expose no probability estimates).
    probabilities: Optional[Dict[str, float]] = None
    # True when the model's confidence is below CFG.low_confidence_threshold.
    is_low_confidence: bool
    # Elapsed milliseconds (per-item average for batch requests).
    latency_ms: float
class BatchResponse(BaseModel):
    """Response body for ``POST /batch_predict``."""

    # One Prediction per input text, in input order.
    predictions: List[Prediction]
    # Number of entries in ``predictions``.
    count: int
    # Elapsed milliseconds for model lookup plus inference over the whole batch.
    total_latency_ms: float
class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    # Always "ok" when the handler runs.
    status: str
    # Names of models currently held in the in-process registry.
    loaded_models: List[str]
    # Model name -> whether its int8 quantized variant was loaded.
    quantized: Dict[str, bool]
    # Device string taken from CFG.device.
    device: str
    # Static API version string (the FastAPI app is also declared as 1.0.0).
    version: str = "1.0.0"
# ── Model registry ────────────────────────────────────────────────────────────
# In-process model cache, keyed by canonical model name:
# name -> {"obj": ..., "kind": "sklearn" | "transformer", "quantized": bool}
_registry: Dict[str, Dict] = {}


def _normalise_model_name(model_name: str) -> str:
    """Map a client-supplied model name or alias onto a canonical name.

    Matching is case-insensitive. ``"lr"``/``"svm"`` are lower-cased. Any name
    containing ``distilbert``, ``roberta`` or ``bert`` maps to the matching
    saved directory. The checks run in that order because the first two also
    contain "bert". Any other name is returned unchanged.
    """
    name_lower = model_name.lower()
    if name_lower in ("lr", "svm"):
        return name_lower
    if "distilbert" in name_lower:
        return "distilbert_base_uncased"
    if "roberta" in name_lower:
        return "roberta_base"
    if "bert" in name_lower:
        return "bert_base_uncased"
    return model_name


def _is_inside(base_dir: str, candidate: str) -> bool:
    """Return True if *candidate* is a path strictly below *base_dir*.

    The check is lexical (``abspath``, not ``realpath``), so symlinks an
    operator places inside the models directory keep working. Its purpose is
    to stop ``..`` segments or absolute paths in a client-supplied model name
    from escaping that directory.
    """
    base = os.path.abspath(base_dir)
    target = os.path.abspath(candidate)
    if target == base:
        return False
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:  # e.g. paths on different drives (Windows)
        return False


def _load_sklearn(model_name: str):
    """Load ``traditional_<model_name>.joblib`` from CFG.models_dir."""
    import joblib

    path = os.path.join(CFG.models_dir, f"traditional_{model_name}.joblib")
    if not os.path.exists(path):
        raise FileNotFoundError(f"No model file: {path}")
    # NOTE(security): joblib.load unpickles; only load files from trusted storage.
    return joblib.load(path), "sklearn", False


def _resolve_transformer_dir(model_name: str) -> str:
    """Return the saved-model directory for *model_name* inside CFG.models_dir.

    Tries ``<models_dir>/<model_name>`` first, then the directory name produced
    by ``transformer_model._checkpoint_to_dir``.

    Raises:
        FileNotFoundError: if neither directory exists, or if the name would
            resolve outside the models directory.
    """
    models_dir = CFG.models_dir
    path = os.path.join(models_dir, model_name)
    if not _is_inside(models_dir, path):
        raise FileNotFoundError(f"Invalid model name: {model_name!r}")
    if os.path.isdir(path):
        return path

    from transformer_model import _checkpoint_to_dir

    alt = os.path.join(models_dir, _checkpoint_to_dir(model_name))
    if _is_inside(models_dir, alt) and os.path.isdir(alt):
        return alt
    raise FileNotFoundError(
        f"No model directory: {path}\n"
        f"Hint: check saved_models/ for available directories."
    )


def _load_transformer(model_name: str):
    """Load a transformer model + tokenizer, preferring an int8 checkpoint.

    If ``<dir>_int8/model_int8.pt`` exists, it is loaded with ``torch.load``
    together with the tokenizer stored next to it. Otherwise the fp32 model is
    loaded with ``from_pretrained``. Both paths load onto the CPU.
    """
    path = _resolve_transformer_dir(model_name)
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    int8_path = f"{path}_int8"
    int8_file = os.path.join(int8_path, "model_int8.pt")
    if os.path.exists(int8_file):
        try:
            # Process-wide setting; best-effort because not every build ships qnnpack.
            torch.backends.quantized.engine = "qnnpack"
        except Exception as exc:
            logger.debug("Could not select qnnpack quantized engine: %s", exc)
        # NOTE(security): weights_only=False unpickles arbitrary objects; the
        # int8 checkpoint must come from trusted storage.
        try:
            model = torch.load(int8_file, map_location="cpu", weights_only=False)
        except TypeError:  # older torch without the weights_only kwarg
            model = torch.load(int8_file, map_location="cpu")
        tokenizer = AutoTokenizer.from_pretrained(int8_path)
        quantized = True
    else:
        model = AutoModelForSequenceClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        quantized = False
    model.eval()
    return (model, tokenizer), "transformer", quantized


def _load_model(model_name: str):
    """Lazy-load a model on first request, then cache it in ``_registry``.

    Args:
        model_name: canonical name or alias (see ``_normalise_model_name``).

    Returns:
        ``(obj, kind, quantized)``. ``obj`` is a fitted sklearn pipeline when
        ``kind == "sklearn"``, or a ``(model, tokenizer)`` tuple when
        ``kind == "transformer"``.

    Raises:
        FileNotFoundError: when no saved model matches, or the name would
            resolve outside ``CFG.models_dir``.
    """
    model_name = _normalise_model_name(model_name)
    entry = _registry.get(model_name)
    if entry is not None:
        return entry["obj"], entry["kind"], entry["quantized"]

    if model_name in ("lr", "svm"):
        obj, kind, quantized = _load_sklearn(model_name)
    else:
        obj, kind, quantized = _load_transformer(model_name)

    _registry[model_name] = {"obj": obj, "kind": kind, "quantized": quantized}
    logger.info(
        "Model cached: %s [%s:%s]", model_name, kind, "int8" if quantized else "fp32"
    )
    return obj, kind, quantized
def _infer_single(text: str, obj, kind: str) -> Dict:
    """Classify one document with an already-loaded model.

    Args:
        text: raw input document.
        obj: the object cached by ``_load_model``. This is a
            ``(model, tokenizer)`` tuple when ``kind == "transformer"``,
            otherwise a fitted sklearn pipeline.
        kind: ``"transformer"`` or ``"sklearn"``.

    Returns:
        Dict with ``label_id``, ``label``, ``probabilities`` and ``confidence``.
        ``probabilities`` is None when the sklearn estimator has no
        ``predict_proba``. ``confidence`` is 1.0 when the estimator exposes
        neither ``predict_proba`` nor ``decision_function``.
    """
    if kind == "transformer":
        model, tokenizer = obj
        enc = tokenizer(
            text, truncation=True, max_length=CFG.max_length, return_tensors="pt"
        )
        with torch.no_grad():
            probs = torch.softmax(model(**enc).logits[0], dim=-1).numpy()
        pred_id = int(np.argmax(probs))
        return {
            "label_id": pred_id,
            "label": CFG.label_names[pred_id],
            "probabilities": {
                CFG.label_names[idx]: round(float(p), 4)
                for idx, p in enumerate(probs)
            },
            "confidence": float(probs[pred_id]),
        }

    # sklearn pipeline: the final step is the estimator.
    pred_id = int(obj.predict([text])[0])
    result = {
        "label_id": pred_id,
        "label": CFG.label_names[pred_id],
        "probabilities": None,
        "confidence": 1.0,
    }
    clf = list(obj.named_steps.values())[-1]
    if hasattr(clf, "predict_proba"):
        probs = obj.predict_proba([text])[0]
        result["probabilities"] = {
            CFG.label_names[idx]: round(float(p), 4) for idx, p in enumerate(probs)
        }
        result["confidence"] = float(np.max(probs))
    elif hasattr(clf, "decision_function"):
        scores = np.asarray(obj.decision_function([text]), dtype=np.float64)
        if scores.ndim == 1:
            # Binary estimators return one signed margin per sample. Expand it
            # to two columns (as _infer_batch does) so the softmax produces a
            # real confidence instead of always 1.0.
            scores = np.stack([-scores, scores], axis=1)
        scores = scores.reshape(1, -1)
        # Numerically stable softmax over the per-class margins.
        scores = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(scores)
        probs = exps / np.sum(exps, axis=1, keepdims=True)
        result["confidence"] = float(np.max(probs))
    return result
def _infer_batch(
    texts: List[str], obj, kind: str, batch_size: int = 16
) -> List[Dict]:
    """Classify several documents with an already-loaded model.

    Args:
        texts: documents to classify.
        obj: the object cached by ``_load_model``. This is
            ``(model, tokenizer)`` for transformers, otherwise a fitted
            sklearn pipeline.
        kind: ``"transformer"`` or ``"sklearn"``.
        batch_size: texts per transformer forward pass. sklearn ignores it
            and scores all texts in one call.

    Returns:
        One dict per input text, in input order, with ``label_id``, ``label``,
        ``probabilities``, ``confidence`` and ``text``. ``probabilities`` is
        None only for sklearn estimators without ``predict_proba``. This
        matches ``_infer_single``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        # sklearn estimators reject zero-sample input, so short-circuit here.
        return []

    if kind == "transformer":
        model, tokenizer = obj
        results = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            enc = tokenizer(
                chunk,
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                logits = model(**enc).logits
            probs_batch = torch.softmax(logits, dim=-1).numpy()
            for text, probs in zip(chunk, probs_batch):
                pred_id = int(np.argmax(probs))
                results.append({
                    "label_id": pred_id,
                    "label": CFG.label_names[pred_id],
                    "probabilities": {
                        CFG.label_names[idx]: round(float(p), 4)
                        for idx, p in enumerate(probs)
                    },
                    "confidence": float(probs[pred_id]),
                    "text": text,
                })
        return results

    # sklearn pipeline: the final step is the estimator.
    preds = obj.predict(texts)
    clf = list(obj.named_steps.values())[-1]
    confidences = np.ones(len(texts), dtype=np.float64)
    prob_maps = [None] * len(texts)
    if hasattr(clf, "predict_proba"):
        probs = obj.predict_proba(texts)
        confidences = np.max(probs, axis=1)
        prob_maps = [
            {CFG.label_names[idx]: round(float(p), 4) for idx, p in enumerate(row)}
            for row in probs
        ]
    elif hasattr(clf, "decision_function"):
        scores = np.asarray(obj.decision_function(texts), dtype=np.float64)
        if scores.ndim == 1:
            # Binary estimators return one signed margin per sample.
            scores = np.stack([-scores, scores], axis=1)
        # Numerically stable softmax over the per-class margins.
        scores = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(scores)
        probs = exps / np.sum(exps, axis=1, keepdims=True)
        confidences = np.max(probs, axis=1)

    return [
        {
            "label_id": int(p),
            "label": CFG.label_names[int(p)],
            "probabilities": pm,
            "confidence": float(c),
            "text": t,
        }
        for p, t, c, pm in zip(preds, texts, confidences, prob_maps)
    ]
# ── FastAPI app ───────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database, pre-warm the default model, clean up on exit.

    ``database.init_db()`` runs unconditionally, and any error it raises
    propagates and fails startup. It is deliberately kept out of the
    model-preload ``try``, so a DB error is neither swallowed nor misreported
    as a missing model. Pre-loading ``roberta_base`` is best-effort. If its
    files are missing, models load lazily on first request instead. The
    in-memory model registry is cleared on shutdown, even if the app exits
    with an error.
    """
    database.init_db()
    try:
        _load_model("roberta_base")
    except FileNotFoundError:
        logger.warning("Default model not found; will load on first request.")
    else:
        logger.info("Default model (roberta_base) pre-loaded.")
    try:
        yield
    finally:
        _registry.clear()
        logger.info("Model registry cleared.")
# Shown on the OpenAPI /docs page.
_APP_DESCRIPTION = (
    "Multi-class news text classification over four categories: "
    "World · Sports · Business · Sci/Tech. "
    "Supports traditional ML and transformer models."
)

app = FastAPI(
    title="Document Classifier API",
    description=_APP_DESCRIPTION,
    version="1.0.0",
    lifespan=lifespan,  # DB init + default-model pre-warm
)
from fastapi.middleware.cors import CORSMiddleware
# NOTE(security): wildcard origins together with allow_credentials=True is the
# most permissive CORS policy possible. If cookies or other credentials are
# ever used for auth, restrict allow_origins to the known frontend origins.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
@app.middleware("http")
async def add_private_network_header(request: Request, call_next):
    """Add Private Network Access and origin-echo CORS headers to responses.

    - If the request carries ``Access-Control-Request-Private-Network``, the
      response gets ``Access-Control-Allow-Private-Network: true``.
    - If the request has an ``Origin`` header, that origin is written to
      ``Access-Control-Allow-Origin``, overwriting any value already on the
      response. ``Origin`` is also added to ``Vary``, because the response now
      differs per requesting origin. Without ``Vary``, a shared cache could
      serve one origin's response, and its ACAO header, to another origin.
    """
    response = await call_next(request)
    if "access-control-request-private-network" in request.headers:
        response.headers["Access-Control-Allow-Private-Network"] = "true"
    origin = request.headers.get("origin")
    if origin:
        # NOTE(security): this reflects *any* origin; see the CORS settings.
        response.headers["Access-Control-Allow-Origin"] = origin
        vary = response.headers.get("vary")
        if vary is None:
            response.headers["Vary"] = "Origin"
        elif "origin" not in {part.strip().lower() for part in vary.split(",")}:
            response.headers["Vary"] = f"{vary}, Origin"
    return response
@app.get("/health", response_model=HealthResponse, tags=["Status"])
async def health():
    """Confirm the API is running, list loaded models, and report device.

    ``quantized`` maps each cached model name to whether its int8 variant was
    loaded.
    """
    # NOTE(review): ``device`` echoes CFG.device, but _load_model loads
    # checkpoints onto the CPU (map_location="cpu" / plain from_pretrained) and
    # never moves them. Confirm this field reflects where inference runs.
    cached_names = list(_registry)
    quant_flags = {}
    for name, entry in _registry.items():
        quant_flags[name] = bool(entry.get("quantized"))
    return HealthResponse(
        status="ok",
        loaded_models=cached_names,
        quantized=quant_flags,
        device=CFG.device,
    )
@app.get("/labels", tags=["Status"])
async def get_labels():
    """Return every configured label as ``{"id": index, "name": label}``."""
    entries = []
    for label_id, label_name in enumerate(CFG.label_names):
        entries.append({"id": label_id, "name": label_name})
    return {"labels": entries}
@app.get("/models", tags=["Status"])
async def list_available_models():
    """List all models that exist in saved_models/.

    A transformer model is a sub-directory containing ``config.json``. Its
    ``<name>_int8`` sibling is reported through the ``quantized`` flag, not as
    a separate entry. An sklearn model is a file named
    ``traditional_<name>.joblib``. Transformers are listed first, and each
    group is sorted by name so the response is deterministic.
    """
    # NOTE(review): _load_model treats only "lr" and "svm" as sklearn models.
    # Other traditional_*.joblib stems are listed here, but /predict will not
    # load them.
    models_dir = CFG.models_dir
    if not os.path.isdir(models_dir):
        return {"models": [], "count": 0}

    prefix, suffix = "traditional_", ".joblib"
    transformer_entries = []
    sklearn_entries = []
    # One sorted scan instead of two unordered os.listdir passes.
    for name in sorted(os.listdir(models_dir)):
        path = os.path.join(models_dir, name)
        if os.path.isdir(path):
            if name.endswith("_int8"):
                continue
            if not os.path.exists(os.path.join(path, "config.json")):
                continue
            int8_file = os.path.join(f"{path}_int8", "model_int8.pt")
            transformer_entries.append(
                {
                    "name": name,
                    "type": "transformer",
                    "quantized": os.path.exists(int8_file),
                }
            )
        elif (
            os.path.isfile(path)
            and name.startswith(prefix)
            and name.endswith(suffix)
            and len(name) > len(prefix) + len(suffix)
        ):
            # Strip exactly the prefix/suffix (str.replace would also strip
            # matching substrings in the middle of the name).
            short = name[len(prefix) : -len(suffix)]
            sklearn_entries.append({"name": short, "type": "sklearn", "quantized": False})

    available = transformer_entries + sklearn_entries
    return {"models": available, "count": len(available)}
@app.post("/predict", response_model=Prediction, tags=["Inference"])
async def predict(req: PredictRequest):
    """Classify a single text document and return label + probabilities.

    Unknown models yield HTTP 404. ``latency_ms`` covers model lookup and
    inference, not the database write that follows. The echoed ``text`` is
    truncated to 200 characters.
    """
    # NOTE(review): model loading, inference and the DB write all run
    # synchronously inside this async handler and block the event loop.
    # Before offloading to a thread pool, check that the tokenizer and DB
    # connection are safe to use from worker threads.
    started = time.perf_counter()
    request_id = str(uuid4())

    try:
        model_obj, model_kind, _quantized = _load_model(req.model_name)
    except FileNotFoundError as err:
        raise HTTPException(status_code=404, detail=str(err))

    outcome = _infer_single(req.text, model_obj, model_kind)
    elapsed_ms = (time.perf_counter() - started) * 1000

    confidence = float(outcome.get("confidence", 1.0))
    low_confidence = bool(confidence < float(CFG.low_confidence_threshold))

    database.log_request(
        request_id=request_id,
        model_name=req.model_name,
        input_text=req.text,
        predicted_label=str(outcome["label"]),
        predicted_label_id=int(outcome["label_id"]),
        confidence=confidence,
        latency_ms=float(elapsed_ms),
        is_batch=False,
    )

    return Prediction(
        text=req.text[:200],
        request_id=request_id,
        label_id=outcome["label_id"],
        label=outcome["label"],
        probabilities=outcome.get("probabilities"),
        is_low_confidence=low_confidence,
        latency_ms=round(elapsed_ms, 2),
    )
@app.post("/batch_predict", response_model=BatchResponse, tags=["Inference"])
async def batch_predict(req: BatchPredictRequest):
    """Classify a list of documents in one call (up to 256 texts).

    Each item gets its own request id and its own log row. Its ``latency_ms``
    is the batch total divided evenly across the inputs. Logging happens only
    after every response item has been built.
    """
    started = time.perf_counter()
    try:
        model_obj, model_kind, _quantized = _load_model(req.model_name)
    except FileNotFoundError as err:
        raise HTTPException(status_code=404, detail=str(err))

    raw_results = _infer_batch(req.texts, model_obj, model_kind)
    total_ms = (time.perf_counter() - started) * 1000
    if req.texts:
        per_item_ms = total_ms / len(req.texts)
    else:
        per_item_ms = 0.0

    predictions = []
    for item in raw_results:
        item_confidence = float(item.get("confidence", 1.0))
        predictions.append(
            Prediction(
                text=item["text"][:200],
                request_id=str(uuid4()),
                label_id=item["label_id"],
                label=item["label"],
                probabilities=item.get("probabilities"),
                is_low_confidence=bool(
                    item_confidence < float(CFG.low_confidence_threshold)
                ),
                latency_ms=round(per_item_ms, 2),
            )
        )

    for item, pred in zip(raw_results, predictions):
        database.log_request(
            request_id=pred.request_id,
            model_name=req.model_name,
            input_text=item["text"],
            predicted_label=str(item["label"]),
            predicted_label_id=int(item["label_id"]),
            confidence=float(item.get("confidence", 1.0)),
            latency_ms=float(per_item_ms),
            is_batch=True,
        )

    return BatchResponse(
        predictions=predictions,
        count=len(predictions),
        total_latency_ms=round(total_ms, 2),
    )
@app.get("/analytics/summary", tags=["Analytics"])
async def analytics_summary(model_name: Optional[str] = None, days: int = 7):
    """Return the result of ``database.get_summary`` for the given filters."""
    # Both query params are passed through unchanged. ``days`` is presumably a
    # look-back window; confirm against database.get_summary.
    summary = database.get_summary(days=days, model_name=model_name)
    return summary
@app.get("/analytics/history", tags=["Analytics"])
async def analytics_history(limit: int = 50, offset: int = 0):
    """Return request history from ``database.get_request_history``."""
    # NOTE(review): limit/offset are forwarded unvalidated. Confirm the
    # database layer rejects or clamps negative or very large values.
    history = database.get_request_history(offset=offset, limit=limit)
    return history
@app.get("/analytics/low_confidence", tags=["Analytics"])
async def analytics_low_confidence(reviewed: bool = False, limit: int = 50):
    """Return low-confidence flags from ``database.get_low_confidence_flags``."""
    flags = database.get_low_confidence_flags(limit=limit, reviewed=reviewed)
    return flags
class ReviewBody(BaseModel):
    """Request body for ``PATCH /analytics/review/{request_id}``."""

    # Optional free-text note, forwarded to database.mark_reviewed.
    note: Optional[str] = None
@app.patch("/analytics/review/{request_id}", tags=["Analytics"])
async def analytics_mark_reviewed(request_id: str, body: ReviewBody):
    """Mark *request_id* as reviewed via ``database.mark_reviewed``."""
    # NOTE(review): the response always reports reviewed=True. Whether
    # database.mark_reviewed signals an unknown request_id is not visible here.
    database.mark_reviewed(note=body.note, request_id=request_id)
    return {"request_id": request_id, "reviewed": True}
@app.post("/analytics/export_flags", tags=["Analytics"])
async def analytics_export_flags():
    """Run ``database.export_low_confidence_to_folder`` and return its result."""
    export_result = database.export_low_confidence_to_folder()
    return export_result