Ars135 committed on
Commit
d940c82
·
verified ·
1 Parent(s): bfb6726

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -32
app.py CHANGED
@@ -1,36 +1,90 @@
1
  import gradio as gr
2
- import torch
3
- from torchvision import transforms
4
  from PIL import Image
5
- from transformers import AutoModelForImageClassification, AutoImageProcessor
6
-
7
- model_name = "nateraw/fer-2013"
8
-
9
- processor = AutoImageProcessor.from_pretrained(model_name)
10
- model = AutoModelForImageClassification.from_pretrained(model_name)
11
-
12
- transform = transforms.Compose([
13
- transforms.Resize((224, 224)),
14
- transforms.ToTensor()
15
- ])
16
-
17
- def predict(img):
18
- img = Image.fromarray(img).convert("RGB")
19
- inputs = processor(images=img, return_tensors="pt")
20
- with torch.no_grad():
21
- outputs = model(**inputs)
22
- logits = outputs.logits
23
- prob = logits.softmax(dim=1)
24
- score, label_id = torch.max(prob, dim=1)
25
- label = model.config.id2label[label_id.item()]
26
- return f"Emotion: {label} ({score.item():.2f})"
27
-
28
- ui = gr.Interface(
29
- fn=predict,
30
- inputs=gr.Image(type="numpy", label="Upload Image"),
31
- outputs="text",
32
- title="Emotion Detection (PyTorch)",
33
- description="Detect emotions using a lightweight PyTorch model (FER-2013)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
 
36
- ui.launch()
 
 
 
1
  import gradio as gr
2
+ import numpy as np
 
3
  from PIL import Image
4
+ import torch
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+
7
# --- Configuration ---
MODEL_NAME = "nateraw/fer-2013"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model and Processor Loading ---
# Load once at import time so Spaces cold-starts do the download up front.
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model = model.to(DEVICE)
    model.eval()  # inference mode: disables dropout/batch-norm updates
    LABELS = model.config.id2label

    print(f"Model loaded successfully on device: {DEVICE}")
    print(f"Available labels: {LABELS}")

except Exception as e:
    # Best-effort startup: keep the app importable so the UI can report
    # the failure instead of the Space crashing on boot.
    print(f"Error loading model or processor: {e}")
    processor = None
    model = None
    LABELS = {0: "dummy_emotion"}
27
+
28
# --- Inference Function ---
def classify_emotion(image_np: np.ndarray) -> str:
    """Classify the dominant emotion in a face image.

    Args:
        image_np: Image as a uint8 numpy array from the Gradio ``Image``
            component, or ``None`` when the user submits without uploading.

    Returns:
        A Markdown-formatted string with the predicted emotion label and
        its softmax confidence, or a human-readable error message.
    """
    if model is None or processor is None:
        # Startup loading failed; surface a clear message instead of crashing.
        return "Error: Model or processor failed to load. Check logs."

    if image_np is None:
        # Gradio delivers None when no image was uploaded; fail fast with a
        # clear message instead of an opaque Image.fromarray traceback.
        return "Error: No image provided. Please upload an image."

    try:
        # Convert numpy array (from Gradio) to PIL Image
        image = Image.fromarray(image_np).convert("RGB")

        # Preprocess: the processor handles the resizing and normalization
        # that the model expects.
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        # Run inference without building autograd state.
        with torch.no_grad():
            outputs = model(**inputs)

        # Convert logits to class probabilities.
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

        # Find the dominant emotion.
        confidence, predicted_class_idx = torch.max(probabilities, 1)

        dominant_emotion = LABELS[predicted_class_idx.item()]
        confidence_score = confidence.item()

        # Markdown bold so the gr.Markdown output renders the label emphasized.
        result_str = f"Emotion: **{dominant_emotion}** ({confidence_score:.2f})"
        return result_str

    except Exception as e:
        # Report the failure to the UI rather than raising into Gradio.
        return f"Prediction Error: {type(e).__name__} - {str(e)}"
64
+
65
# --- Gradio Interface ---
# NOTE(review): the empty `examples=[]` placeholder was removed — Gradio
# raises on an empty examples list; leave it unset until real example
# images are added to the repo.
# NOTE(review): `allow_flagging` is deprecated in favor of `flagging_mode`
# in newer Gradio releases — confirm the Space's pinned Gradio version
# before migrating.
iface = gr.Interface(
    fn=classify_emotion,
    inputs=gr.Image(
        type="numpy",
        label="Upload an image of a face",
    ),
    outputs=gr.Markdown(label="Predicted Emotion"),
    title="Emotion Detection (PyTorch/Transformers)",
    description=(
        "Upload an image containing a face to classify the dominant emotion. "
        "Uses the **nateraw/fer-2013** PyTorch model from Hugging Face Transformers. "
        "No TensorFlow or Keras dependencies."
    ),
    allow_flagging="never",
    theme=gr.themes.Soft(),
)

# Launch the app (required for Hugging Face Spaces)
if __name__ == "__main__":
    iface.launch()