Ars135 commited on
Commit
0cde079
·
verified ·
1 Parent(s): d6b132d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -24
app.py CHANGED
@@ -4,27 +4,35 @@ from PIL import Image
4
  import torch
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
- # --- Configuration ---
8
  # --- Configuration ---
9
  MODEL_NAME = "nateraw/fer-2013"
10
- DEVICE = "cpu" # FORCING CPU to ensure model loads correctly on basic hardware
 
 
11
 
12
  # --- Model and Processor Loading ---
 
13
  try:
14
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
15
- model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(DEVICE)
 
 
 
 
 
 
 
16
  model.eval()
17
  LABELS = model.config.id2label
18
 
19
  print(f"Model loaded successfully on device: {DEVICE}")
20
- print(f"Available labels: {LABELS}")
21
 
22
  except Exception as e:
23
- print(f"Error loading model or processor: {e}")
24
- # Use a dummy function if loading fails
25
  processor = None
26
  model = None
27
- LABELS = {0: "dummy_emotion"}
28
 
29
  # --- Inference Function ---
30
  def classify_emotion(image_np: np.ndarray) -> str:
@@ -32,14 +40,13 @@ def classify_emotion(image_np: np.ndarray) -> str:
32
  Performs emotion classification on an input image (numpy array).
33
  """
34
  if model is None or processor is None:
35
- return f"Error: Model or processor failed to load. Check logs."
36
 
37
  try:
38
  # Convert numpy array (from Gradio) to PIL Image
39
  image = Image.fromarray(image_np).convert("RGB")
40
 
41
  # Preprocess the image
42
- # The processor handles necessary resizing and normalization
43
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
44
 
45
  # Run inference
@@ -47,8 +54,7 @@ def classify_emotion(image_np: np.ndarray) -> str:
47
  outputs = model(**inputs)
48
 
49
  # Get predictions
50
- logits = outputs.logits
51
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
52
 
53
  # Find the dominant emotion
54
  confidence, predicted_class_idx = torch.max(probabilities, 1)
@@ -56,12 +62,15 @@ def classify_emotion(image_np: np.ndarray) -> str:
56
  dominant_emotion = LABELS[predicted_class_idx.item()]
57
  confidence_score = confidence.item()
58
 
59
- # Format the result string
60
- result_str = f"Emotion: **{dominant_emotion}** ({confidence_score:.2f})"
 
 
 
61
  return result_str
62
 
63
  except Exception as e:
64
- return f"Prediction Error: {type(e).__name__} - {str(e)}"
65
 
66
  # --- Gradio Interface ---
67
  iface = gr.Interface(
@@ -71,21 +80,14 @@ iface = gr.Interface(
71
  label="Upload an image of a face"
72
  ),
73
  outputs=gr.Markdown(label="Predicted Emotion"),
74
- title="Emotion Detection (PyTorch/Transformers)",
75
  description=(
76
- "Upload an image containing a face to classify the dominant emotion. "
77
- "Uses the **nateraw/fer-2013** PyTorch model from Hugging Face Transformers. "
78
- "No TensorFlow or Keras dependencies."
79
  ),
80
- examples=[
81
- # Providing simple examples is good practice
82
- # Note: Gradio will handle downloading and using these if they exist in the repo
83
- # Since this is for a new Space, use placeholder or common sense (omit paths if files aren't included)
84
- ],
85
  allow_flagging="never",
86
  theme=gr.themes.Soft()
87
  )
88
 
89
- # Launch the app (required for Hugging Face Spaces)
90
  if __name__ == "__main__":
91
  iface.launch()
 
4
  import torch
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
 
7
  # --- Configuration ---
8
  MODEL_NAME = "nateraw/fer-2013"
9
+ # CRITICAL FIX: Explicitly set DEVICE to 'cpu' to prevent CUDA initialization errors
10
+ # and memory issues on default Hugging Face Spaces hardware.
11
+ DEVICE = "cpu"
12
 
13
  # --- Model and Processor Loading ---
14
+ # Load model outside the prediction function for efficiency
15
  try:
16
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
17
+
18
+ # CRITICAL FIX: Load model with map_location='cpu' to prevent Out-Of-Memory (OOM)
19
+ # errors during the loading process by mapping all tensors directly to CPU memory.
20
+ model = AutoModelForImageClassification.from_pretrained(
21
+ MODEL_NAME,
22
+ map_location=DEVICE
23
+ ).to(DEVICE)
24
+
25
  model.eval()
26
  LABELS = model.config.id2label
27
 
28
  print(f"Model loaded successfully on device: {DEVICE}")
 
29
 
30
  except Exception as e:
31
+ # If loading fails, ensure the error message is descriptive.
32
+ print(f"CRITICAL ERROR during model loading: {e}")
33
  processor = None
34
  model = None
35
+ LABELS = {0: "Load_Error"}
36
 
37
  # --- Inference Function ---
38
  def classify_emotion(image_np: np.ndarray) -> str:
 
40
  Performs emotion classification on an input image (numpy array).
41
  """
42
  if model is None or processor is None:
43
+ return "System Error: Model failed to initialize. Please perform a Factory Reboot or check Space logs."
44
 
45
  try:
46
  # Convert numpy array (from Gradio) to PIL Image
47
  image = Image.fromarray(image_np).convert("RGB")
48
 
49
  # Preprocess the image
 
50
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
51
 
52
  # Run inference
 
54
  outputs = model(**inputs)
55
 
56
  # Get predictions
57
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
58
 
59
  # Find the dominant emotion
60
  confidence, predicted_class_idx = torch.max(probabilities, 1)
 
62
  dominant_emotion = LABELS[predicted_class_idx.item()]
63
  confidence_score = confidence.item()
64
 
65
+ # Format the result with clear markdown
66
+ result_str = (
67
+ f"**Predicted Emotion:** **{dominant_emotion.upper()}**\n\n"
68
+ f"Confidence: {confidence_score:.2f}"
69
+ )
70
  return result_str
71
 
72
  except Exception as e:
73
+ return f"Prediction Runtime Error: {type(e).__name__} - {str(e)}"
74
 
75
  # --- Gradio Interface ---
76
  iface = gr.Interface(
 
80
  label="Upload an image of a face"
81
  ),
82
  outputs=gr.Markdown(label="Predicted Emotion"),
83
+ title="😊 PyTorch Facial Emotion Detection",
84
  description=(
85
+ "Upload an image to classify the dominant emotion. Uses the **nateraw/fer-2013** PyTorch model. "
86
+ "Built for stable deployment on Hugging Face Spaces."
 
87
  ),
 
 
 
 
 
88
  allow_flagging="never",
89
  theme=gr.themes.Soft()
90
  )
91
 
 
92
  if __name__ == "__main__":
93
  iface.launch()