Marwa-Khan commited on
Commit
8fcfa6d
·
1 Parent(s): 1efb530
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -17,28 +17,23 @@ def analyze_emotion(image):
17
  if image is None:
18
  return "Upload an image", None
19
 
20
- # Preprocess
21
- if isinstance(image, np.ndarray):
22
- image = Image.fromarray(image)
23
- if image.mode != "RGB":
24
- image = image.convert("RGB")
25
 
26
  inputs = processor(images=image, return_tensors="pt")
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
  probs = F.softmax(outputs.logits, dim=-1)[0].numpy()
30
 
31
- # Get top emotion
32
  top_idx = np.argmax(probs)
33
  top_emotion = EMOTIONS[top_idx]
34
 
35
- # Prepare bar chart
36
  chart_data = {"emotion": EMOTIONS, "confidence": probs.tolist()}
37
-
38
  result_text = f"Predicted Emotion: {top_emotion} ({probs[top_idx]*100:.1f}%)"
39
 
40
  return result_text, chart_data
41
 
 
42
  # Build Gradio interface
43
  demo = gr.Interface(
44
  fn=analyze_emotion,
@@ -51,5 +46,6 @@ demo = gr.Interface(
51
  description="Upload a facial image and detect emotions (Angry, Disgust, Fear, Happy, Sad, Surprise, Neutral) using a Vision Transformer."
52
  )
53
 
 
54
  if __name__ == "__main__":
55
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
17
  if image is None:
18
  return "Upload an image", None
19
 
20
+ # Ensure RGB
21
+ image = image.convert("RGB")
 
 
 
22
 
23
  inputs = processor(images=image, return_tensors="pt")
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  probs = F.softmax(outputs.logits, dim=-1)[0].numpy()
27
 
 
28
  top_idx = np.argmax(probs)
29
  top_emotion = EMOTIONS[top_idx]
30
 
 
31
  chart_data = {"emotion": EMOTIONS, "confidence": probs.tolist()}
 
32
  result_text = f"Predicted Emotion: {top_emotion} ({probs[top_idx]*100:.1f}%)"
33
 
34
  return result_text, chart_data
35
 
36
+
37
  # Build Gradio interface
38
  demo = gr.Interface(
39
  fn=analyze_emotion,
 
46
  description="Upload a facial image and detect emotions (Angry, Disgust, Fear, Happy, Sad, Surprise, Neutral) using a Vision Transformer."
47
  )
48
 
49
+
50
  if __name__ == "__main__":
51
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)