# Hugging Face Spaces — status: Sleeping
import io
import logging
import math
import os

import cv2
import joblib
import mediapipe as mp
import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps
app = FastAPI()

# ---------------- LOAD MODEL ----------------
# Relative path: resolved against the process's current working directory.
MODEL_DIR = "model2"

# Classifier first, then the label encoder that maps its outputs to names.
# NOTE(review): joblib.load unpickles; only point MODEL_DIR at trusted files.
model, label_encoder = (
    joblib.load(os.path.join(MODEL_DIR, artifact_name))
    for artifact_name in ("emotion_model.joblib", "label_encoder.joblib")
)

# One FaceMesh instance is shared by every request. Per the MediaPipe docs,
# static_image_mode=True runs detection on each image independently instead
# of tracking across calls; only the first face found is returned.
face_mesh = mp.solutions.face_mesh.FaceMesh(
    max_num_faces=1,
    static_image_mode=True,
    min_detection_confidence=0.5,
)
# ---------------- FEATURE ORDER (FULL LIST) ----------------
# Column order of the feature vector fed to the classifier; presumably this
# must match the column order used when the model was trained.
FEATURE_ORDER = [
    # mouth
    "mouth_width",
    "mouth_height",
    # left eye
    "left_eye_width",
    "left_eye_height",
    # right eye
    "right_eye_width",
    "right_eye_height",
    # eyebrows
    "left_eyebrow_height",
    "right_eyebrow_height",
    # lips
    "lip_top_height",
    "lip_bottom_height",
    # overall face geometry
    "nose_length",
    "face_width",
]
# ---------------- FEATURE EXTRACTION FUNCTIONS ----------------
def euclidean(p1, p2):
    """Return the straight-line distance between two 3-D points.

    Only indices 0, 1 and 2 of each point are read; extra components are
    ignored, and a point with fewer than three raises ``IndexError``.
    """
    squared_sum = sum((p1[axis] - p2[axis]) ** 2 for axis in range(3))
    return math.sqrt(squared_sum)
def compute_basic_features(landmarks, w, h):
    """Compute all required geometric features.

    Each feature is the 3-D distance between a pair of face-mesh landmarks.
    Before measuring, x is scaled by the image width and y by the image
    height; z is used unscaled.

    Args:
        landmarks: indexable of per-landmark (x, y, z) rows; indices up to
            386 are read. Presumably the normalized FaceMesh landmarks built
            by extract_features.
        w: image width in pixels (multiplies x).
        h: image height in pixels (multiplies y).

    Returns:
        dict mapping feature name -> distance. Keys are inserted in the same
        order as FEATURE_ORDER.
    """
    # NOTE(review): z is not multiplied by w, so depth contributes on a much
    # smaller scale than x/y. Kept as-is because the trained model presumably
    # expects exactly these values; changing it would require retraining.
    landmark_pairs = (
        # Mouth metrics
        ("mouth_width", 61, 291),
        ("mouth_height", 0, 17),
        # Left eye metrics
        ("left_eye_width", 33, 133),
        ("left_eye_height", 159, 145),
        # Right eye metrics
        ("right_eye_width", 362, 263),
        ("right_eye_height", 386, 374),
        # Eyebrows
        ("left_eyebrow_height", 70, 105),
        ("right_eyebrow_height", 300, 334),
        # Lip heights
        ("lip_top_height", 13, 14),
        ("lip_bottom_height", 14, 18),
        # Nose
        ("nose_length", 1, 197),
        # Face width
        ("face_width", 127, 356),
    )

    def to_pixels(index):
        row = landmarks[index]
        return (row[0] * w, row[1] * h, row[2])

    return {
        name: euclidean(to_pixels(start), to_pixels(end))
        for name, start, end in landmark_pairs
    }
def extract_features(image):
    """Turn a BGR image into a single-row feature matrix for the classifier.

    Args:
        image: BGR image array with shape (height, width, channels).

    Returns:
        ``(X, None)`` on success, where ``X`` has shape
        ``(1, len(FEATURE_ORDER))``. Returns ``(None, ["No face detected"])``
        when FaceMesh finds no face.
    """
    # FaceMesh expects RGB input; the rest of the pipeline works in BGR.
    detection = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    faces = detection.multi_face_landmarks
    if not faces:
        return None, ["No face detected"]

    # Only the first detected face is measured.
    landmarks = np.array([(pt.x, pt.y, pt.z) for pt in faces[0].landmark])
    height, width, _ = image.shape
    measured = compute_basic_features(landmarks, width, height)

    # Arrange values in the fixed column order, defaulting absent ones to 0.
    row = [measured.get(name, 0) for name in FEATURE_ORDER]
    return np.array([row]), None
def predict_emotion(image):
    """Classify the facial expression in a BGR image.

    Args:
        image: BGR image array, as produced by the /predict endpoint.

    Returns:
        ``{"emotion": str, "confidence": float}`` on success. ``confidence``
        is the model's probability for the returned emotion.
        ``{"error": [...]}`` when feature extraction fails (e.g. no face).
    """
    X, error = extract_features(image)
    if error:
        return {"error": error}

    pred = model.predict(X)[0]
    proba = model.predict_proba(X)[0]

    # predict() and argmax(predict_proba()) can disagree for some estimators
    # (e.g. SVC with Platt scaling). Report the probability of the class that
    # was actually predicted, not the row maximum. Fall back to the maximum
    # when the model does not expose classes_.
    classes = getattr(model, "classes_", None)
    if classes is not None and pred in classes:
        prob = proba[list(classes).index(pred)]
    else:
        prob = proba.max()

    emotion = label_encoder.inverse_transform([pred])[0]
    return {
        # str() keeps the payload JSON-serializable even when the encoder
        # returns NumPy scalars instead of plain Python strings.
        "emotion": str(emotion),
        "confidence": float(prob),
    }
# ---------------- API ENDPOINT ----------------
def _decode_upload_to_bgr(img_bytes):
    """Decode uploaded image bytes into a 3-channel uint8 BGR array.

    Raises:
        OSError: the bytes are not a decodable image or are truncated.
            Pillow's ``UnidentifiedImageError`` subclasses ``OSError``.
        ValueError: Pillow cannot convert the image mode to RGB.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel
            limit.
    """
    with Image.open(io.BytesIO(img_bytes)) as pil_img:
        # Pillow does not apply the EXIF Orientation tag on open. Without this,
        # rotated phone photos reach FaceMesh sideways and often find no face.
        upright = ImageOps.exif_transpose(pil_img)
        # Converting through Pillow handles every mode uniformly (L, P, LA,
        # RGBA, CMYK, 1, ...): palette images are expanded instead of being
        # read as gray indices, and CMYK is no longer mistaken for RGBA.
        rgb = np.array(upright.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Predict the facial emotion in an uploaded image.

    Args:
        image: multipart-uploaded image in any format Pillow can decode.

    Returns:
        JSONResponse with ``{"emotion", "confidence"}`` on success, or
        ``{"error": ...}`` with one of these statuses:
        - 200 when no face is found;
        - 400 when the upload is not a decodable image;
        - 500 for unexpected failures.
    """
    try:
        img_bytes = await image.read()
        try:
            img = _decode_upload_to_bgr(img_bytes)
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Undecodable upload: a client error, not a server fault.
            return JSONResponse({"error": str(e)}, status_code=400)
        # Kept synchronous inside the async handler. The module-level
        # face_mesh is shared and presumably not safe for concurrent
        # process() calls, so do not move this to a thread pool without a lock.
        result = predict_emotion(img)
        return JSONResponse(result)
    except Exception as e:
        logging.getLogger(__name__).exception("Emotion prediction failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.get("/")
def root():
    """Health check: report that the service is up and name the predict route."""
    return {"status": "API running", "endpoint": "/predict"}