# Marwa-Khan — changes (commit 8fcfa6d)
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
MODEL_NAME = "mo-thecreator/vit-Facial-Expression-Recognition"
EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
# Load model and processor
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.eval()
def analyze_emotion(image):
if image is None:
return "Upload an image", None
# Ensure RGB
image = image.convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
probs = F.softmax(outputs.logits, dim=-1)[0].numpy()
top_idx = np.argmax(probs)
top_emotion = EMOTIONS[top_idx]
chart_data = {"emotion": EMOTIONS, "confidence": probs.tolist()}
result_text = f"Predicted Emotion: {top_emotion} ({probs[top_idx]*100:.1f}%)"
return result_text, chart_data
# Build Gradio interface
demo = gr.Interface(
fn=analyze_emotion,
inputs=gr.Image(type="pil", label="Upload Facial Image"),
outputs=[
gr.Textbox(label="Prediction"),
gr.BarPlot(x="emotion", y="confidence", y_lim=[0,1], label="Confidence")
],
title="Facial Expression Recognition (ViT)",
description="Upload a facial image and detect emotions (Angry, Disgust, Fear, Happy, Sad, Surprise, Neutral) using a Vision Transformer."
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)