"""FastAPI service exposing an AffectNet facial-emotion classifier.

Endpoints:
    GET  /health       -- liveness flag and model metadata.
    POST /predict      -- multipart image upload (form field ``file``).
    POST /predict_b64  -- JSON body ``{"image": "<base64 image bytes>"}``.

The Keras model path is read from the ``MODEL_PATH`` environment variable
(default ``EfficientNetV2S_AffectNet_v2.keras``) and loaded once at startup.
"""

import base64
import io
import os
import threading
import time

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse  # noqa: F401  (kept from original imports)
from PIL import Image, ImageOps

app = FastAPI(title="AffectNet Facial Emotion Classifier")

# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any website issue credentialed cross-origin requests; restrict allow_origins
# before exposing this service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Label for each model output index (EMOTIONS[i] names preds[i]); assumed to
# match the label order used in training -- verify against the training code.
EMOTIONS = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# (width, height) that every input image is resized to before inference.
INPUT_SIZE = (300, 300)

model = None
_model_load_time = 0.0
# Inference runs on worker threads (see _classify_bytes); Keras does not
# guarantee that concurrent predict() calls on one model are safe, so they are
# serialized here while the event loop stays free for other requests.
_predict_lock = threading.Lock()


# NOTE(review): on_event is deprecated in recent FastAPI in favour of a
# ``lifespan`` handler; kept for compatibility with older FastAPI versions.
@app.on_event("startup")
def _load_model():
    """Load the Keras model from ``MODEL_PATH`` into the module-level ``model``.

    Raises:
        RuntimeError: if the model file does not exist, which aborts startup.
    """
    global model, _model_load_time
    model_path = os.environ.get("MODEL_PATH", "EfficientNetV2S_AffectNet_v2.keras")
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at {model_path}")
    t0 = time.perf_counter()
    model = tf.keras.models.load_model(model_path)
    _model_load_time = time.perf_counter() - t0
    print(f"[STARTUP] Model loaded in {_model_load_time:.2f}s from {model_path}")


@app.get("/health")
def health():
    """Report whether the model is loaded, its load time, input shape and labels."""
    width, height = INPUT_SIZE
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "load_time_s": round(_model_load_time, 2),
        "input_shape": [height, width, 3],
        "emotions": EMOTIONS,
    }


def _decode_image(raw):
    """Decode image bytes into a ``(1, height, width, 3)`` float32 batch.

    Pixel values are left in their raw 0-255 range, as in the original
    pipeline (assumes the saved model does its own normalization -- verify).

    Raises:
        HTTPException: 400 if ``raw`` is not a decodable image (unknown
            format, truncated data, or a decompression-bomb sized image).
    """
    try:
        with Image.open(io.BytesIO(raw)) as src:
            # Honour the EXIF orientation tag so rotated phone photos are not
            # classified sideways; convert() forces the full decode here.
            img = ImageOps.exif_transpose(src).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise HTTPException(400, "Could not decode image") from err
    img = img.resize(INPUT_SIZE)
    return np.expand_dims(np.asarray(img, dtype=np.float32), axis=0)


def _classify_bytes(raw):
    """Run the full decode -> predict -> format pipeline on image bytes.

    Blocking and CPU-bound; call it via ``run_in_threadpool`` from async code.

    Returns:
        dict with ``emotion`` (top label), ``confidence`` (its score) and
        ``probabilities`` (label -> score rounded to 4 decimals).
    """
    batch = _decode_image(raw)
    with _predict_lock:
        preds = model.predict(batch, verbose=0)[0]
    idx = int(np.argmax(preds))
    return {
        "emotion": EMOTIONS[idx],
        "confidence": float(preds[idx]),
        "probabilities": {e: round(float(p), 4) for e, p in zip(EMOTIONS, preds)},
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the facial emotion in an uploaded image file.

    Responds 503 if the model is not loaded and 400 for an empty or
    undecodable upload.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded")
    contents = await file.read()
    if not contents:
        raise HTTPException(400, "Empty file")
    # Offload decoding + inference so the event loop is not blocked.
    return await run_in_threadpool(_classify_bytes, contents)


@app.post("/predict_b64")
async def predict_b64(data: dict):
    """Classify the facial emotion in a base64-encoded image.

    Expects a JSON object whose ``image`` field holds base64 image bytes,
    optionally as a ``data:<mime>;base64,<payload>`` URL. Responds 503 if
    the model is not loaded, and 400 for a missing field, invalid base64, or
    undecodable image data.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded")
    b64 = data.get("image")
    if not b64:
        raise HTTPException(400, "Missing 'image' field with base64 JPEG data")
    if isinstance(b64, str) and b64.startswith("data:"):
        # Strip a data-URL header; otherwise the lenient decoder would turn
        # the header characters into garbage bytes in front of the image.
        b64 = b64.partition(",")[2]
    try:
        raw = base64.b64decode(b64)
    except (TypeError, ValueError) as err:
        # binascii.Error (bad padding) is a ValueError subclass; TypeError
        # covers non-string JSON values such as numbers or lists.
        raise HTTPException(400, "Invalid base64") from err
    return await run_in_threadpool(_classify_bytes, raw)