# Source: Hugging Face Space app.py by hirumunasinghe (commit c4a4b26, verified)
"""FastAPI Emotion Detection API for ONNX models.
Features:
* Dynamic model input shape discovery (supports NCHW: [N, C, H, W])
* Automatic channel handling (grayscale C=1 or RGB C=3)
* Configurable preprocessing via environment variables:
* CENTER_CROP (default: true)
* INPUT_SCALE (0_1 or 0_255)
* CHANNEL_ORDER (RGB or BGR)
* NORM_MEAN / NORM_STD (comma separated floats, length == C)
* Label map can be provided as JSON string or a path via LABEL_MAP_JSON
* Safe softmax application only if output tensor isn't already probabilities
* Returns: emotion, confidence, and all_predictions dict
Environment Variables:
MODEL_PATH Path to ONNX file (default: model.onnx)
LABEL_MAP_JSON JSON object OR path to JSON file of {label: index}
ORT_PROVIDERS Comma separated ONNX Runtime providers (default CPUExecutionProvider)
CENTER_CROP true/false (default true)
INPUT_SCALE 0_1 or 0_255 (default 0_1)
(choose 0_255 if training used raw 0..255 then normalized manually)
CHANNEL_ORDER RGB (default) or BGR (applies after PIL conversion)
NORM_MEAN Comma separated mean values (if single value & C>1 will be broadcast)
NORM_STD Comma separated std values
DEBUG_LOG true/false to log top-3 predictions
Usage (multipart):
curl -X POST -F "file=@face.jpg" http://localhost:7860/predict
"""
from __future__ import annotations
import io
import json
import logging
import os
from typing import Dict, Tuple
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
logging.basicConfig(level=logging.INFO)
# ---------------------------------------------------------------------------
# Environment & Config
# ---------------------------------------------------------------------------
MODEL_PATH = os.getenv("MODEL_PATH", "model.onnx")  # filesystem path to the ONNX model
LABEL_MAP_ENV = os.getenv("LABEL_MAP_JSON")  # inline JSON or file path; resolved below
# ONNX Runtime execution providers, comma separated (default: CPU only).
PROVIDERS = [
p.strip() for p in os.getenv("ORT_PROVIDERS", "CPUExecutionProvider").split(",") if p.strip()
]
DEFAULT_LABEL_MAP: Dict[str, int] = {
"angry": 0,
"fearful": 1,
"happy": 2,
"neutral": 3,
"sad": 4,
"surprised": 5,
}
def _load_label_map(raw: str | None) -> Dict[str, int]:
"""Load label map from path or inline JSON string.
Accepts either:
- Path to a JSON file containing {label: index}
- Inline JSON string
Falls back to DEFAULT_LABEL_MAP on any failure.
"""
if not raw:
logging.info("LABEL_MAP_JSON not set; using default label map: %s", DEFAULT_LABEL_MAP)
return DEFAULT_LABEL_MAP
if os.path.isfile(raw):
try:
with open(raw, "r", encoding="utf-8") as f:
data = json.load(f)
logging.info("Loaded label map from file: %s", raw)
return data
except Exception as exc: # pragma: no cover - defensive
logging.warning("Failed reading label map file '%s': %s; using default", raw, exc)
return DEFAULT_LABEL_MAP
# Try inline JSON
try:
data = json.loads(raw)
logging.info("Loaded label map from inline JSON")
return data
except Exception as exc: # pragma: no cover - defensive
logging.warning("Failed to parse LABEL_MAP_JSON inline: %s; using default", exc)
return DEFAULT_LABEL_MAP
LABEL_MAP: Dict[str, int] = _load_label_map(LABEL_MAP_ENV)
# Reverse lookup (index -> label) used when decoding model outputs.
# NOTE(review): assumes indices in LABEL_MAP are unique ints; duplicate
# values would be silently collapsed here — confirm for custom label maps.
INDEX_TO_LABEL = {v: k for k, v in LABEL_MAP.items()}
def _extract_chw(shape) -> Tuple[int, int, int]:
"""Extract channel, height, width from ONNX input shape.
shape typical forms: [N, C, H, W] or may include symbolic dims.
Fallbacks: C=3, H=224, W=224.
"""
def _as_int(val, default):
if isinstance(val, int):
return val
try:
return int(val) # symbolic numeric string
except Exception:
return default
c = _as_int(shape[1] if len(shape) > 1 else None, 3)
h = _as_int(shape[2] if len(shape) > 2 else None, 224)
w = _as_int(shape[3] if len(shape) > 3 else None, 224)
return c, h, w
# ---------------------------------------------------------------------------
# Initialize ONNX Runtime Session
# ---------------------------------------------------------------------------
try:
    session = ort.InferenceSession(MODEL_PATH, providers=PROVIDERS)
except Exception as exc:  # pragma: no cover - startup failure should be visible
    # Fail fast at import time: the API is useless without a loaded model.
    raise RuntimeError(f"Failed to load ONNX model '{MODEL_PATH}': {exc}") from exc
# Discover the model's input tensor name and declared (C, H, W) dimensions
# so preprocessing can adapt to the model instead of hard-coding a size.
input_meta = session.get_inputs()[0]
INPUT_NAME = input_meta.name
INPUT_SHAPE = input_meta.shape  # e.g. [1, 3, 224, 224]
MODEL_C, MODEL_H, MODEL_W = _extract_chw(INPUT_SHAPE)
logging.info(
    "Model loaded: %s | input name=%s shape=%s -> (C,H,W)=(%s,%s,%s)",
    MODEL_PATH,
    INPUT_NAME,
    INPUT_SHAPE,
    MODEL_C,
    MODEL_H,
    MODEL_W,
)
# ---------------------------------------------------------------------------
# Preprocessing Configuration
# ---------------------------------------------------------------------------
# Whether to crop the largest centered square before resizing.
CENTER_CROP = os.getenv("CENTER_CROP", "true").lower() in {"1", "true", "yes", "y"}
INPUT_SCALE_MODE = os.getenv("INPUT_SCALE", "0_1").lower()  # '0_1' or '0_255'
CHANNEL_ORDER = os.getenv("CHANNEL_ORDER", "RGB").upper()  # 'RGB' or 'BGR'
def _parse_norm(val: str | None, expected: int, defaults: list[float]) -> np.ndarray:
if not val:
return np.array(defaults, dtype=np.float32)
try:
parts = [float(x.strip()) for x in val.split(",") if x.strip()]
if len(parts) == 1 and expected > 1: # broadcast single value
parts = parts * expected
if len(parts) != expected:
logging.warning(
"Normalization value length %d != expected %d; using defaults %s",
len(parts), expected, defaults,
)
return np.array(defaults, dtype=np.float32)
return np.array(parts, dtype=np.float32)
except Exception as exc: # pragma: no cover - defensive
logging.warning("Failed parsing normalization values '%s': %s; using defaults", val, exc)
return np.array(defaults, dtype=np.float32)
# Default normalization constants: ImageNet statistics for 3-channel models,
# 0.5 mean/std for single-channel (grayscale) models.
if MODEL_C == 3:
    DEFAULT_MEAN = [0.485, 0.456, 0.406]
    DEFAULT_STD = [0.229, 0.224, 0.225]
else:  # grayscale
    DEFAULT_MEAN = [0.5]
    DEFAULT_STD = [0.5]
# Env overrides (NORM_MEAN / NORM_STD) are parsed and validated by _parse_norm.
NORM_MEAN = _parse_norm(os.getenv("NORM_MEAN"), MODEL_C, DEFAULT_MEAN)
NORM_STD = _parse_norm(os.getenv("NORM_STD"), MODEL_C, DEFAULT_STD)
logging.info(
    "Preprocess config: CENTER_CROP=%s INPUT_SCALE=%s CHANNEL_ORDER=%s MEAN=%s STD=%s",
    CENTER_CROP, INPUT_SCALE_MODE, CHANNEL_ORDER, NORM_MEAN.tolist(), NORM_STD.tolist(),
)
def preprocess(pil_img: Image.Image) -> np.ndarray:
    """Preprocess a PIL image into a 1xCxHxW float32 array for inference.

    Pipeline (driven by module-level config):
      1. Convert to "L" (C=1) or "RGB" (C=3) to match the model input.
      2. Optional center crop to the largest centered square (CENTER_CROP).
      3. Resize to (MODEL_W, MODEL_H).
      4. Optional RGB->BGR channel swap (CHANNEL_ORDER, RGB models only).
      5. Scale to [0, 1] or keep the raw 0..255 range (INPUT_SCALE_MODE).
      6. Normalize with NORM_MEAN / NORM_STD and reshape HWC -> 1xCxHxW.
    """
    # Convert to the mode matching the model's channel count.
    if MODEL_C == 1:
        pil_img = pil_img.convert("L")
    else:
        pil_img = pil_img.convert("RGB")
    # Optional center crop to square.
    if CENTER_CROP:
        w0, h0 = pil_img.size
        side = min(w0, h0)
        left = (w0 - side) // 2
        top = (h0 - side) // 2
        pil_img = pil_img.crop((left, top, left + side, top + side))
    # Resize to the model's expected spatial size.
    pil_img = pil_img.resize((MODEL_W, MODEL_H))
    arr = np.asarray(pil_img, dtype=np.float32)
    # Channel order swap if requested (only relevant for RGB models).
    if MODEL_C == 3 and CHANNEL_ORDER == "BGR":
        arr = arr[:, :, ::-1]
    # Scaling.
    if INPUT_SCALE_MODE == "0_1":
        arr /= 255.0
    elif INPUT_SCALE_MODE == "0_255":
        pass  # keep original 0..255 range
    else:
        logging.warning("Unknown INPUT_SCALE '%s'; defaulting to 0_1", INPUT_SCALE_MODE)
        arr /= 255.0
    # Normalization + NCHW layout.
    if MODEL_C == 1:
        arr = (arr - NORM_MEAN[0]) / NORM_STD[0]
        arr = arr[None, None, :, :]  # add batch + channel dims
    else:
        arr = ((arr - NORM_MEAN) / NORM_STD).transpose(2, 0, 1)  # HWC -> CHW
        arr = arr[None, :, :, :]  # add batch dim
    # transpose()/slicing above yield non-contiguous views; ONNX Runtime
    # expects a C-contiguous buffer, so force contiguity (and float32) here.
    return np.ascontiguousarray(arr, dtype=np.float32)
# ---------------------------------------------------------------------------
# FastAPI App
# ---------------------------------------------------------------------------
app = FastAPI(title="Emotion ONNX API", version="1.0.0")


@app.get("/")
async def root():
    """Health/info endpoint: model file name, raw ONNX input shape, labels."""
    info = {
        "status": "ok",
        "model": os.path.basename(MODEL_PATH),
        "input_shape": INPUT_SHAPE,
        "labels": INDEX_TO_LABEL,
    }
    return info
FILE_PARAM = File(..., description="Image file (jpg/png)")


@app.post("/predict")
async def predict(file: UploadFile = FILE_PARAM):
    """Classify the emotion in an uploaded face image.

    Returns JSON: {"emotion": top-1 label, "confidence": top-1 probability,
    "all_predictions": {label: probability for every output index}}.

    Raises:
        HTTPException 400: unreadable/corrupt image data.
        HTTPException 500: inference failure or unexpected model output.
    """
    # Read & fully decode the image. PIL's Image.open() is lazy: without the
    # explicit load(), a truncated/corrupt file would pass this guard and
    # blow up later inside preprocess() as an unhandled 500 instead of a 400.
    try:
        bytes_data = await file.read()
        pil_img = Image.open(io.BytesIO(bytes_data))
        pil_img.load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc
    # Preprocess to a 1xCxHxW float32 tensor.
    inp = preprocess(pil_img)
    # Run inference.
    try:
        outputs = session.run(None, {INPUT_NAME: inp})
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"ONNX inference failed: {exc}") from exc
    if not outputs:
        raise HTTPException(status_code=500, detail="Model returned no outputs")
    raw = outputs[0]
    arr = np.asarray(raw)
    # Accept (1, C) or (C,) output shapes.
    if arr.ndim == 1:
        logits = arr
    elif arr.ndim == 2 and arr.shape[0] == 1:
        logits = arr[0]
    else:
        raise HTTPException(status_code=500, detail=f"Unexpected output shape: {arr.shape}")
    # Heuristic: treat the output as probabilities if every value is in
    # [0, 1] and they sum to ~1; otherwise apply a numerically stable softmax.
    if (
        np.all((logits >= 0.0) & (logits <= 1.0))
        and 0.98 <= float(np.sum(logits)) <= 1.02
    ):
        probs = logits.astype(np.float32, copy=False)
    else:
        # Stable softmax: subtract max to avoid overflow in exp().
        exp = np.exp(logits - np.max(logits))
        denom = np.sum(exp)
        probs = exp / (denom if denom != 0 else 1.0)
    # Map indices to labels; unknown indices become "label_<idx>".
    predictions: Dict[str, float] = {}
    for idx, p in enumerate(probs):
        label = INDEX_TO_LABEL.get(idx, f"label_{idx}")
        predictions[label] = float(p)
    top_idx = int(np.argmax(probs))
    top_label = INDEX_TO_LABEL.get(top_idx, f"label_{top_idx}")
    top_conf = float(probs[top_idx])
    # Optional debug logging of the top-3 predictions.
    if os.getenv("DEBUG_LOG", "false").lower() in {"1", "true", "yes", "y"}:
        order = np.argsort(-probs)[:3]
        top3 = [
            (
                int(i),
                INDEX_TO_LABEL.get(int(i), f"label_{int(i)}"),
                float(probs[int(i)]),
            )
            for i in order
        ]
        logging.info("Top-3: %s", top3)
    return JSONResponse(
        {
            "emotion": top_label,
            "confidence": top_conf,
            "all_predictions": predictions,
        }
    )
# Optional: rudimentary ping for monitoring systems
@app.get("/healthz")
async def healthz():
    """Liveness probe; returns a constant OK payload while the process is up."""
    return {"ok": True}