|
|
import gradio as gr |
|
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
# Hugging Face Hub repo id of the fine-tuned facial-expression model.
model_id = "SEAR01/FER_model"

try:
    # Image preprocessor (resize/normalize) paired with the checkpoint.
    processor = AutoImageProcessor.from_pretrained(model_id)
    # trust_remote_code=True: the repo ships custom modeling code that
    # must be executed to build the architecture.
    model = AutoModelForImageClassification.from_pretrained(model_id, trust_remote_code=True)
except Exception as e:
    # Chain with `from e` so the original download/config traceback is
    # preserved instead of being flattened into the message string.
    raise ValueError(f"Model load failed: {e}. Check repo files.") from e
|
|
|
|
|
emotion_labels = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'] |
|
|
|
|
|
def predict_emotion(image):
    """Classify the facial expression shown in *image*.

    Args:
        image: PIL image from the Gradio widget, or None when nothing
            was uploaded.

    Returns:
        A human-readable string with the predicted emotion label and its
        softmax confidence, or a prompt asking for an upload.
    """
    if image is None:
        return "Upload an image."

    # Normalize to 3-channel RGB so grayscale or RGBA uploads don't
    # break the processor's channel expectations.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax once over the single-item batch, then read both the
    # winning index and its probability from the same distribution
    # (the original ran argmax and softmax over the logits separately).
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    predicted = int(probs.argmax().item())
    confidence = probs[predicted].item()

    # NOTE(review): assumes the model's output index order matches this
    # hard-coded list — confirm against model.config.id2label.
    emotion = emotion_labels[predicted]
    return f"Emotion: {emotion} (Confidence: {confidence:.2f})"
|
|
|
|
|
# Minimal Gradio UI: one PIL image input wired to the classifier, one
# text output for the formatted prediction.
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="FER Demo",
)

# Launch the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()