# Hugging Face Space by Ars135 — commit 6c7d72c ("Update app.py", verified).
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# --- Configuration ---
# Hugging Face Hub identifier of the ViT-based facial-emotion classifier
# (both the image processor and the model weights are fetched from it).
MODEL_NAME = 'abhilash88/face-emotion-detection'
# Inference is pinned to the CPU on purpose; no GPU detection is attempted.
DEVICE = 'cpu'
# --- Model and Processor Loading ---
# Runs once at import time so every request reuses the same processor/model.
# On failure, `processor` and `model` are set to None and LABELS[0] holds a
# user-facing error message, which classify_emotion() returns instead of a
# prediction.
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    # NOTE: `from_pretrained` has no `map_location` parameter (that is a
    # `torch.load` option). Unrecognised kwargs are forwarded to the model's
    # constructor and make loading fail, so the model is loaded plainly and
    # then moved to DEVICE explicitly.
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(DEVICE)
    model.eval()  # inference mode (disables dropout etc.)
    # Mapping from class index to human-readable emotion name.
    LABELS = model.config.id2label
    print(f"Model loaded successfully on device: {DEVICE}")
except Exception as e:
    print(f"CRITICAL ERROR during model loading: {e}")
    processor = None
    model = None
    # Report the actual cause rather than assuming out-of-memory. Download
    # errors, missing files or invalid arguments also end up here.
    LABELS = {0: f"MODEL LOADING FAILED: {type(e).__name__} - {e}"}
# --- Inference Function ---
def classify_emotion(image_np: np.ndarray) -> str:
    """Classify the dominant facial emotion in an uploaded image.

    Args:
        image_np: The image as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``. May be ``None`` when the user submits
            without uploading an image.

    Returns:
        An HTML snippet (rendered by the ``gr.Markdown`` output) with the
        predicted emotion and its softmax confidence. If the model failed to
        load, no image was given, or inference raised, a plain error message
        is returned instead.
    """
    if model is None or processor is None:
        # Startup loading failed; LABELS[0] holds the error message.
        return LABELS[0]
    if image_np is None:
        return "Please upload an image of a face."
    try:
        # Normalise greyscale / RGBA uploads to the 3-channel input the processor expects.
        image = Image.fromarray(image_np).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class_idx = torch.max(probabilities, dim=-1)
        dominant_emotion = LABELS[predicted_class_idx.item()]
        confidence_score = confidence.item()
        # Use <strong> instead of Markdown `**`: CommonMark does not parse
        # Markdown inside a raw HTML block, so asterisks would show literally.
        return (
            "<h2 class='text-xl font-bold'>Predicted Emotion:</h2>"
            f"<p class='text-3xl mt-2'><strong>{dominant_emotion.upper()}</strong></p>"
            f"<p class='text-lg text-gray-600 mt-1'>Confidence: {confidence_score:.2f}</p>"
        )
    except Exception as e:
        # UI boundary: show the failure to the user rather than crashing the request.
        return f"Prediction Runtime Error: {type(e).__name__} - {str(e)}"
# --- Gradio Interface ---
# Single-input / single-output app: an uploaded face image goes to
# classify_emotion, whose HTML result is rendered by a Markdown component.
iface = gr.Interface(
    fn=classify_emotion,
    inputs=gr.Image(
        type="numpy",  # classify_emotion expects a NumPy array
        label="Upload an image of a face",
    ),
    outputs=gr.Markdown(label="Predicted Emotion"),
    # U+1F60A (smiling face with smiling eyes). It is written as an escape so
    # the title cannot be corrupted into mojibake by a mis-decoding editor
    # again.
    title="\U0001F60A PyTorch Facial Emotion Detection (ViT Model)",
    description=(
        "Uses a stable ViT (Vision Transformer) model fine-tuned on the FER-2013 dataset."
    ),
    allow_flagging="never",
    theme=gr.themes.Soft(),
)
def _launch_app() -> None:
    """Start the Gradio web server that serves ``iface``."""
    iface.launch()


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    _launch_app()