image_model / app.py
A7md47's picture
Upload app.py with huggingface_hub
d581963 verified
Raw
History Blame Contribute Delete
2.76 kB
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import os
import time
# Labels returned to clients; the prediction handlers index this list with the
# argmax of the model output, so the order must match the model's class order
# (presumably the training label order -- verify against the training code).
EMOTIONS = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# Module-level state filled in by the startup hook.
model = None
_model_load_time = 0.0

app = FastAPI(title="AffectNet Facial Emotion Classifier")

# Fully permissive CORS: any origin, method and header may call this API.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# the FastAPI CORS docs advise against that combination -- confirm whether
# credentialed cross-origin requests are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.on_event("startup")
def _load_model():
    """Load the Keras model into the module-level ``model`` at startup.

    The path is taken from the ``MODEL_PATH`` environment variable and falls
    back to ``EfficientNetV2S_AffectNet_v2.keras`` in the working directory.
    The wall-clock load duration is stored in ``_model_load_time``.

    Raises:
        RuntimeError: If nothing exists at the resolved model path.
    """
    global model, _model_load_time

    model_path = os.getenv("MODEL_PATH", "EfficientNetV2S_AffectNet_v2.keras")

    if os.path.exists(model_path):
        started_at = time.time()
        model = tf.keras.models.load_model(model_path)
        finished_at = time.time()
        _model_load_time = finished_at - started_at
        print(f"[STARTUP] Model loaded in {_model_load_time:.2f}s from {model_path}")
    else:
        # Fail application startup rather than serve without a model.
        raise RuntimeError(f"Model not found at {model_path}")
@app.get("/health")
def health():
    """Report service status for health checks.

    Returns:
        A dict with a constant ``status`` of ``"ok"``, whether the model has
        been loaded, the recorded model load time in seconds (rounded to two
        decimals), the hard-coded input shape and the list of emotion labels.
    """
    is_loaded = model is not None
    load_seconds = round(_model_load_time, 2)

    payload = {
        "status": "ok",
        "model_loaded": is_loaded,
        "load_time_s": load_seconds,
        # Static value; it is not read from the loaded model.
        "input_shape": [300, 300, 3],
        "emotions": EMOTIONS,
    }
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the facial emotion in an uploaded image file.

    The upload is decoded with Pillow, converted to RGB, resized to 300x300
    and passed to the model as a float32 batch containing that single image.

    Args:
        file: Multipart upload holding the raw image bytes.

    Returns:
        A dict with the top-scoring ``emotion`` label, the model output for
        that label as ``confidence``, and ``probabilities`` mapping each label
        in ``EMOTIONS`` to its model output rounded to four decimals.

    Raises:
        HTTPException: 503 if the model has not been loaded; 400 if the
            upload is empty or cannot be decoded as an image.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded")

    contents = await file.read()
    if not contents:
        raise HTTPException(400, "Empty file")

    try:
        # convert() stays inside the try: decoding the pixel data can fail on
        # truncated or corrupt files even when the header parsed correctly.
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # Undecodable client input is a client error, not an unhandled 500.
        # (PIL.UnidentifiedImageError is a subclass of OSError.)
        raise HTTPException(400, "Invalid or unsupported image file") from err

    img = img.resize((300, 300))

    # Pixels are passed as raw float32 values with no scaling applied here --
    # presumably the model normalises internally; TODO confirm.
    arr = np.expand_dims(np.array(img).astype(np.float32), axis=0)

    # NOTE(review): this is a synchronous call inside an async handler, so the
    # event loop is busy for the duration of the inference.
    preds = model.predict(arr, verbose=0)[0]
    idx = int(preds.argmax())
    return {
        "emotion": EMOTIONS[idx],
        "confidence": float(preds[idx]),
        "probabilities": {e: round(float(p), 4) for e, p in zip(EMOTIONS, preds)},
    }
@app.post("/predict_b64")
async def predict_b64(data: dict):
    """Classify the facial emotion in a base64-encoded image.

    The JSON body must contain an ``image`` field holding base64 image data.
    A browser-style data URL (``data:image/jpeg;base64,<payload>``) is also
    accepted; its header is stripped before decoding. The decoded image is
    converted to RGB, resized to 300x300 and passed to the model as a float32
    batch containing that single image.

    Args:
        data: Parsed JSON request body.

    Returns:
        A dict with the top-scoring ``emotion`` label, the model output for
        that label as ``confidence``, and ``probabilities`` mapping each label
        in ``EMOTIONS`` to its model output rounded to four decimals.

    Raises:
        HTTPException: 503 if the model has not been loaded; 400 if the
            ``image`` field is missing or empty, is not valid base64, or does
            not decode to a readable image.
    """
    import base64

    if model is None:
        raise HTTPException(503, "Model not loaded")

    b64 = data.get("image")
    if not b64:
        raise HTTPException(400, "Missing 'image' field with base64 JPEG data")

    if isinstance(b64, str) and b64.startswith("data:"):
        # ":" and "," never occur in base64, so this prefix can only be a data
        # URL header. Left in place, the non-validating decoder would fold the
        # header's letters into the payload and corrupt the image bytes.
        b64 = b64.partition(",")[2]

    try:
        raw = base64.b64decode(b64)
    except (ValueError, TypeError) as err:
        # binascii.Error (bad padding / non-ASCII text) is a ValueError;
        # TypeError covers a non-string, non-bytes "image" value.
        raise HTTPException(400, "Invalid base64") from err

    try:
        # convert() stays inside the try: decoding the pixel data can fail on
        # truncated or corrupt files even when the header parsed correctly.
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # Undecodable client input is a client error, not an unhandled 500.
        # (PIL.UnidentifiedImageError is a subclass of OSError.)
        raise HTTPException(400, "Invalid or unsupported image data") from err

    img = img.resize((300, 300))

    # Pixels are passed as raw float32 values with no scaling applied here --
    # presumably the model normalises internally; TODO confirm.
    arr = np.expand_dims(np.array(img).astype(np.float32), axis=0)

    # NOTE(review): this is a synchronous call inside an async handler, so the
    # event loop is busy for the duration of the inference.
    preds = model.predict(arr, verbose=0)[0]
    idx = int(preds.argmax())
    return {
        "emotion": EMOTIONS[idx],
        "confidence": float(preds[idx]),
        "probabilities": {e: round(float(p), 4) for e, p in zip(EMOTIONS, preds)},
    }