Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import os | |
| import time | |
app = FastAPI(title="AffectNet Facial Emotion Classifier")

# CORS is left wide open: every origin, method and header is accepted.
# NOTE(review): combined with allow_origins=["*"], allow_credentials=True makes
# Starlette reflect the caller's Origin (instead of "*") for credentialed and
# preflight requests, so any site may send credentialed calls. Nothing visible
# here uses cookies or auth — confirm credentials are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| EMOTIONS = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"] | |
| model = None | |
| _model_load_time = 0.0 | |
| def _load_model(): | |
| global model, _model_load_time | |
| model_path = os.environ.get("MODEL_PATH", "EfficientNetV2S_AffectNet_v2.keras") | |
| if not os.path.exists(model_path): | |
| raise RuntimeError(f"Model not found at {model_path}") | |
| t0 = time.time() | |
| model = tf.keras.models.load_model(model_path) | |
| _model_load_time = time.time() - t0 | |
| print(f"[STARTUP] Model loaded in {_model_load_time:.2f}s from {model_path}") | |
def health():
    """Report liveness plus basic model metadata.

    The returned dict says whether the ``model`` global is set, how long
    loading took (seconds, rounded to two decimals), the input shape the
    predict handlers resize images to, and the label order used in outputs.
    """
    # NOTE(review): no route decorator is visible in this paste. Confirm this
    # is registered (e.g. @app.get("/health")).
    return dict(
        status="ok",
        model_loaded=model is not None,
        load_time_s=round(_model_load_time, 2),
        # Mirrors the 300x300 RGB resize performed by the predict handlers.
        input_shape=[300, 300, 3],
        emotions=EMOTIONS,
    )
async def predict(file: UploadFile = File(...)):
    """Classify the facial emotion in an uploaded image file.

    The upload is decoded with Pillow, converted to RGB and resized to
    300x300. It is then passed to the model as a single-image float32 batch
    of raw 0-255 pixel values; no rescaling is done here.

    Returns:
        dict with the top ``emotion`` label, its ``confidence`` (the model's
        output value for that label), and ``probabilities`` mapping every
        label in ``EMOTIONS`` to its output value rounded to 4 decimals.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the upload is
            empty or cannot be decoded as an image.
    """
    # NOTE(review): no route decorator is visible in this paste. Confirm this
    # is registered (e.g. @app.post("/predict")).
    if model is None:
        raise HTTPException(503, "Model not loaded")
    contents = await file.read()
    if not contents:
        raise HTTPException(400, "Empty file")
    try:
        # UnidentifiedImageError and truncated-data errors are OSError
        # subclasses. Without this handler a bad upload surfaces as a 500.
        with Image.open(io.BytesIO(contents)) as src:
            img = src.convert("RGB")  # convert() returns a new image
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(400, "Could not decode image") from exc
    img = img.resize((300, 300))
    arr = np.expand_dims(np.array(img).astype(np.float32), axis=0)
    # NOTE(review): model.predict is synchronous and runs on the event-loop
    # thread inside this async handler, blocking other requests meanwhile.
    # Consider a worker thread or a plain `def` handler if that matters.
    preds = model.predict(arr, verbose=0)[0]
    idx = int(preds.argmax())
    return {
        "emotion": EMOTIONS[idx],
        "confidence": float(preds[idx]),
        "probabilities": {e: round(float(p), 4) for e, p in zip(EMOTIONS, preds)},
    }
async def predict_b64(data: dict):
    """Classify the facial emotion in a base64-encoded image sent as JSON.

    Expects a body like ``{"image": "<base64>"}``. A leading
    ``data:<mime>;base64,`` URL prefix (as produced by browser data URLs) is
    stripped before decoding. The decoded image goes through the same
    pipeline as the file-upload handler: RGB, 300x300, then a float32 batch
    of one with raw 0-255 pixel values.

    Returns:
        dict with the top ``emotion`` label, its ``confidence`` (the model's
        output value for that label), and ``probabilities`` mapping every
        label in ``EMOTIONS`` to its output value rounded to 4 decimals.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if ``image`` is
            missing or empty, is not valid base64, or does not decode to an
            image.
    """
    import base64

    # NOTE(review): no route decorator is visible in this paste. Confirm this
    # is registered as a POST route.
    if model is None:
        raise HTTPException(503, "Model not loaded")
    b64 = data.get("image")
    if not b64:
        raise HTTPException(400, "Missing 'image' field with base64 JPEG data")
    # ':' is not in the base64 alphabet, so a plain payload never starts with
    # "data:". Only real data URLs are affected by this strip.
    if isinstance(b64, str) and b64.startswith("data:"):
        b64 = b64.partition(",")[2]
    try:
        raw = base64.b64decode(b64)
    except (ValueError, TypeError) as exc:
        # binascii.Error (bad padding) and non-ASCII text raise ValueError.
        # A non-string JSON value (number, list, ...) raises TypeError.
        raise HTTPException(400, "Invalid base64") from exc
    try:
        # Undecodable or truncated bytes raise OSError subclasses; report them
        # as client errors rather than letting them become a 500.
        with Image.open(io.BytesIO(raw)) as src:
            img = src.convert("RGB")  # convert() returns a new image
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(400, "Could not decode image") from exc
    img = img.resize((300, 300))
    arr = np.expand_dims(np.array(img).astype(np.float32), axis=0)
    # NOTE(review): synchronous inference inside an async handler blocks the
    # event loop while it runs.
    preds = model.predict(arr, verbose=0)[0]
    idx = int(preds.argmax())
    return {
        "emotion": EMOTIONS[idx],
        "confidence": float(preds[idx]),
        "probabilities": {e: round(float(p), 4) for e, p in zip(EMOTIONS, preds)},
    }