|
|
"""FastAPI Emotion Detection API for ONNX models. |
|
|
|
|
|
Features: |
|
|
* Dynamic model input shape discovery (supports NCHW: [N, C, H, W]) |
|
|
* Automatic channel handling (grayscale C=1 or RGB C=3) |
|
|
* Configurable preprocessing via environment variables: |
|
|
* CENTER_CROP (default: true) |
|
|
* INPUT_SCALE (0_1 or 0_255) |
|
|
* CHANNEL_ORDER (RGB or BGR) |
|
|
* NORM_MEAN / NORM_STD (comma separated floats, length == C) |
|
|
* Label map can be provided as JSON string or a path via LABEL_MAP_JSON |
|
|
* Safe softmax application only if output tensor isn't already probabilities |
|
|
* Returns: emotion, confidence, and all_predictions dict |
|
|
|
|
|
Environment Variables: |
|
|
MODEL_PATH Path to ONNX file (default: model.onnx) |
|
|
LABEL_MAP_JSON JSON object OR path to JSON file of {label: index} |
|
|
ORT_PROVIDERS Comma separated ONNX Runtime providers (default CPUExecutionProvider) |
|
|
CENTER_CROP true/false (default true) |
|
|
INPUT_SCALE 0_1 or 0_255 (default 0_1) |
|
|
(choose 0_255 if training used raw 0..255 then normalized manually) |
|
|
CHANNEL_ORDER RGB (default) or BGR (applies after PIL conversion) |
|
|
NORM_MEAN Comma separated mean values (if single value & C>1 will be broadcast) |
|
|
NORM_STD Comma separated std values |
|
|
DEBUG_LOG true/false to log top-3 predictions |
|
|
|
|
|
Usage (multipart): |
|
|
curl -X POST -F "file=@face.jpg" http://localhost:7860/predict |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import io |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from fastapi import FastAPI, File, HTTPException, UploadFile |
|
|
from fastapi.responses import JSONResponse |
|
|
from PIL import Image |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)


# --- Runtime configuration (environment-driven) ---

# Path to the ONNX model file loaded once at startup.
MODEL_PATH = os.getenv("MODEL_PATH", "model.onnx")

# Optional label map: inline JSON object or a path to a JSON file
# (interpreted by _load_label_map below).
LABEL_MAP_ENV = os.getenv("LABEL_MAP_JSON")

# ONNX Runtime execution providers, comma separated (default: CPU only).
PROVIDERS = [
    p.strip() for p in os.getenv("ORT_PROVIDERS", "CPUExecutionProvider").split(",") if p.strip()
]
|
|
|
|
|
DEFAULT_LABEL_MAP: Dict[str, int] = { |
|
|
"angry": 0, |
|
|
"fearful": 1, |
|
|
"happy": 2, |
|
|
"neutral": 3, |
|
|
"sad": 4, |
|
|
"surprised": 5, |
|
|
} |
|
|
|
|
|
|
|
|
def _load_label_map(raw: str | None) -> Dict[str, int]: |
|
|
"""Load label map from path or inline JSON string. |
|
|
|
|
|
Accepts either: |
|
|
- Path to a JSON file containing {label: index} |
|
|
- Inline JSON string |
|
|
Falls back to DEFAULT_LABEL_MAP on any failure. |
|
|
""" |
|
|
if not raw: |
|
|
logging.info("LABEL_MAP_JSON not set; using default label map: %s", DEFAULT_LABEL_MAP) |
|
|
return DEFAULT_LABEL_MAP |
|
|
|
|
|
if os.path.isfile(raw): |
|
|
try: |
|
|
with open(raw, "r", encoding="utf-8") as f: |
|
|
data = json.load(f) |
|
|
logging.info("Loaded label map from file: %s", raw) |
|
|
return data |
|
|
except Exception as exc: |
|
|
logging.warning("Failed reading label map file '%s': %s; using default", raw, exc) |
|
|
return DEFAULT_LABEL_MAP |
|
|
|
|
|
|
|
|
try: |
|
|
data = json.loads(raw) |
|
|
logging.info("Loaded label map from inline JSON") |
|
|
return data |
|
|
except Exception as exc: |
|
|
logging.warning("Failed to parse LABEL_MAP_JSON inline: %s; using default", exc) |
|
|
return DEFAULT_LABEL_MAP |
|
|
|
|
|
|
|
|
# Resolved label -> index map and its inverse (index -> label) used when
# naming prediction outputs in API responses.
LABEL_MAP: Dict[str, int] = _load_label_map(LABEL_MAP_ENV)
INDEX_TO_LABEL = {v: k for k, v in LABEL_MAP.items()}
|
|
|
|
|
|
|
|
def _extract_chw(shape) -> Tuple[int, int, int]: |
|
|
"""Extract channel, height, width from ONNX input shape. |
|
|
|
|
|
shape typical forms: [N, C, H, W] or may include symbolic dims. |
|
|
Fallbacks: C=3, H=224, W=224. |
|
|
""" |
|
|
|
|
|
def _as_int(val, default): |
|
|
if isinstance(val, int): |
|
|
return val |
|
|
try: |
|
|
return int(val) |
|
|
except Exception: |
|
|
return default |
|
|
|
|
|
c = _as_int(shape[1] if len(shape) > 1 else None, 3) |
|
|
h = _as_int(shape[2] if len(shape) > 2 else None, 224) |
|
|
w = _as_int(shape[3] if len(shape) > 3 else None, 224) |
|
|
return c, h, w |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the ONNX model once at import time; fail fast with a clear error
# rather than surfacing a confusing failure on the first request.
try:
    session = ort.InferenceSession(MODEL_PATH, providers=PROVIDERS)
except Exception as exc:
    raise RuntimeError(f"Failed to load ONNX model '{MODEL_PATH}': {exc}") from exc

# Discover the model's expected input tensor name and (C, H, W) dims so
# preprocessing can adapt to grayscale vs RGB models and any input size.
input_meta = session.get_inputs()[0]
INPUT_NAME = input_meta.name
INPUT_SHAPE = input_meta.shape
MODEL_C, MODEL_H, MODEL_W = _extract_chw(INPUT_SHAPE)
logging.info(
    "Model loaded: %s | input name=%s shape=%s -> (C,H,W)=(%s,%s,%s)",
    MODEL_PATH,
    INPUT_NAME,
    INPUT_SHAPE,
    MODEL_C,
    MODEL_H,
    MODEL_W,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing toggles (see module docstring for accepted values).
CENTER_CROP = os.getenv("CENTER_CROP", "true").lower() in {"1", "true", "yes", "y"}
INPUT_SCALE_MODE = os.getenv("INPUT_SCALE", "0_1").lower()
CHANNEL_ORDER = os.getenv("CHANNEL_ORDER", "RGB").upper()
|
|
|
|
|
|
|
|
def _parse_norm(val: str | None, expected: int, defaults: list[float]) -> np.ndarray: |
|
|
if not val: |
|
|
return np.array(defaults, dtype=np.float32) |
|
|
try: |
|
|
parts = [float(x.strip()) for x in val.split(",") if x.strip()] |
|
|
if len(parts) == 1 and expected > 1: |
|
|
parts = parts * expected |
|
|
if len(parts) != expected: |
|
|
logging.warning( |
|
|
"Normalization value length %d != expected %d; using defaults %s", |
|
|
len(parts), expected, defaults, |
|
|
) |
|
|
return np.array(defaults, dtype=np.float32) |
|
|
return np.array(parts, dtype=np.float32) |
|
|
except Exception as exc: |
|
|
logging.warning("Failed parsing normalization values '%s': %s; using defaults", val, exc) |
|
|
return np.array(defaults, dtype=np.float32) |
|
|
|
|
|
|
|
|
# Default normalization stats: the common ImageNet per-channel values for
# RGB models, 0.5/0.5 for grayscale models.
if MODEL_C == 3:
    DEFAULT_MEAN = [0.485, 0.456, 0.406]
    DEFAULT_STD = [0.229, 0.224, 0.225]
else:
    DEFAULT_MEAN = [0.5]
    DEFAULT_STD = [0.5]

# Effective mean/std arrays (length == MODEL_C), overridable via
# NORM_MEAN / NORM_STD environment variables.
NORM_MEAN = _parse_norm(os.getenv("NORM_MEAN"), MODEL_C, DEFAULT_MEAN)
NORM_STD = _parse_norm(os.getenv("NORM_STD"), MODEL_C, DEFAULT_STD)
logging.info(
    "Preprocess config: CENTER_CROP=%s INPUT_SCALE=%s CHANNEL_ORDER=%s MEAN=%s STD=%s",
    CENTER_CROP, INPUT_SCALE_MODE, CHANNEL_ORDER, NORM_MEAN.tolist(), NORM_STD.tolist(),
)
|
|
|
|
|
|
|
|
def preprocess(pil_img: Image.Image) -> np.ndarray:
    """Turn a PIL image into an NCHW float32 tensor for the model.

    Pipeline (each step driven by module-level config):
      1. Convert to grayscale ("L") or RGB depending on MODEL_C.
      2. Optionally center-crop to a square (CENTER_CROP).
      3. Resize to the model's (MODEL_W, MODEL_H).
      4. Optionally flip RGB -> BGR (CHANNEL_ORDER).
      5. Scale pixel values per INPUT_SCALE mode.
      6. Normalize with NORM_MEAN / NORM_STD and add batch (+channel) axes.
    """
    mode = "L" if MODEL_C == 1 else "RGB"
    pil_img = pil_img.convert(mode)

    if CENTER_CROP:
        width, height = pil_img.size
        edge = min(width, height)
        x0 = (width - edge) // 2
        y0 = (height - edge) // 2
        pil_img = pil_img.crop((x0, y0, x0 + edge, y0 + edge))

    pil_img = pil_img.resize((MODEL_W, MODEL_H))
    arr = np.asarray(pil_img, dtype=np.float32)

    # Channel swap only applies to 3-channel input.
    if MODEL_C == 3 and CHANNEL_ORDER == "BGR":
        arr = arr[:, :, ::-1]

    if INPUT_SCALE_MODE == "0_255":
        pass  # model was trained on raw 0..255 values
    else:
        if INPUT_SCALE_MODE != "0_1":
            logging.warning("Unknown INPUT_SCALE '%s'; defaulting to 0_1", INPUT_SCALE_MODE)
        arr /= 255.0

    if MODEL_C == 1:
        # Grayscale array is 2-D: normalize with scalars, then add N and C axes.
        arr = (arr - NORM_MEAN[0]) / NORM_STD[0]
        arr = arr[np.newaxis, np.newaxis, :, :]
    else:
        # HWC -> CHW, then prepend the batch axis.
        arr = ((arr - NORM_MEAN) / NORM_STD).transpose(2, 0, 1)
        arr = arr[np.newaxis, :, :, :]

    return arr.astype(np.float32, copy=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application instance (routes: /, /predict, /healthz).
app = FastAPI(title="Emotion ONNX API", version="1.0.0")
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
async def root(): |
|
|
return { |
|
|
"status": "ok", |
|
|
"model": os.path.basename(MODEL_PATH), |
|
|
"input_shape": INPUT_SHAPE, |
|
|
"labels": INDEX_TO_LABEL, |
|
|
} |
|
|
|
|
|
|
|
|
# Reusable upload-file parameter spec shared by the /predict endpoint.
FILE_PARAM = File(..., description="Image file (jpg/png)")
|
|
|
|
|
|
|
|
@app.post("/predict") |
|
|
async def predict(file: UploadFile = FILE_PARAM): |
|
|
|
|
|
try: |
|
|
bytes_data = await file.read() |
|
|
pil_img = Image.open(io.BytesIO(bytes_data)) |
|
|
except Exception as exc: |
|
|
raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc |
|
|
|
|
|
|
|
|
inp = preprocess(pil_img) |
|
|
|
|
|
|
|
|
try: |
|
|
outputs = session.run(None, {INPUT_NAME: inp}) |
|
|
except Exception as exc: |
|
|
raise HTTPException(status_code=500, detail=f"ONNX inference failed: {exc}") from exc |
|
|
|
|
|
if not outputs: |
|
|
raise HTTPException(status_code=500, detail="Model returned no outputs") |
|
|
|
|
|
raw = outputs[0] |
|
|
arr = np.asarray(raw) |
|
|
|
|
|
|
|
|
if arr.ndim == 1: |
|
|
logits = arr |
|
|
elif arr.ndim == 2 and arr.shape[0] == 1: |
|
|
logits = arr[0] |
|
|
else: |
|
|
raise HTTPException(status_code=500, detail=f"Unexpected output shape: {arr.shape}") |
|
|
|
|
|
|
|
|
if ( |
|
|
np.all((logits >= 0.0) & (logits <= 1.0)) |
|
|
and 0.98 <= float(np.sum(logits)) <= 1.02 |
|
|
): |
|
|
probs = logits.astype(np.float32, copy=False) |
|
|
else: |
|
|
|
|
|
exp = np.exp(logits - np.max(logits)) |
|
|
denom = np.sum(exp) |
|
|
probs = exp / (denom if denom != 0 else 1.0) |
|
|
|
|
|
|
|
|
predictions: Dict[str, float] = {} |
|
|
for idx, p in enumerate(probs): |
|
|
label = INDEX_TO_LABEL.get(idx, f"label_{idx}") |
|
|
predictions[label] = float(p) |
|
|
|
|
|
top_idx = int(np.argmax(probs)) |
|
|
top_label = INDEX_TO_LABEL.get(top_idx, f"label_{top_idx}") |
|
|
top_conf = float(probs[top_idx]) |
|
|
|
|
|
if os.getenv("DEBUG_LOG", "false").lower() in {"1", "true", "yes", "y"}: |
|
|
order = np.argsort(-probs)[:3] |
|
|
top3 = [ |
|
|
( |
|
|
int(i), |
|
|
INDEX_TO_LABEL.get(int(i), f"label_{int(i)}"), |
|
|
float(probs[int(i)]), |
|
|
) |
|
|
for i in order |
|
|
] |
|
|
logging.info("Top-3: %s", top3) |
|
|
|
|
|
return JSONResponse( |
|
|
{ |
|
|
"emotion": top_label, |
|
|
"confidence": top_conf, |
|
|
"all_predictions": predictions, |
|
|
} |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/healthz") |
|
|
async def healthz(): |
|
|
return {"ok": True} |
|
|
|
|
|
|