thoeppner commited on
Commit
d1773cc
·
verified ·
1 Parent(s): 7a9c57f

added features

Browse files

Top 3 Emotionen
Confidence Einschätzung
Balkendiagramm (Bar Chart)
Schöner Gradio Output

Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  from torchvision import models, transforms
3
  from PIL import Image
4
  import gradio as gr
 
5
  from huggingface_hub import hf_hub_download
6
 
7
  # Modell laden vom Hugging Face Model Hub
@@ -12,7 +13,6 @@ model_path = hf_hub_download(
12
  filename="emotion_model.pt"
13
  )
14
 
15
-
16
  model = models.resnet18()
17
  model.fc = torch.nn.Linear(model.fc.in_features, 9)
18
  model.load_state_dict(torch.load(model_path, map_location=device))
@@ -28,6 +28,17 @@ transform = transforms.Compose([
28
  transforms.ToTensor()
29
  ])
30
 
 
 
 
 
 
 
 
 
 
 
 
31
  def predict_emotion(image):
32
  image = image.convert("RGB")
33
  image = transform(image).unsqueeze(0).to(device)
@@ -35,20 +46,39 @@ def predict_emotion(image):
35
  with torch.no_grad():
36
  outputs = model(image)
37
  probs = torch.softmax(outputs, dim=1)
38
- confidence, predicted = torch.max(probs, 1)
39
 
 
 
 
 
 
 
 
 
 
40
  if confidence.item() < 0.7:
41
- return "Unbekannte Emotion", f"{confidence.item()*100:.2f}%"
42
  else:
43
- return labels[predicted.item()], f"{confidence.item()*100:.2f}%"
 
 
 
 
 
 
44
 
45
  # Gradio Interface
46
  interface = gr.Interface(
47
  fn=predict_emotion,
48
  inputs=gr.Image(type="pil"),
49
- outputs=["text", "text"],
 
 
 
 
 
50
  title="Emotion Recognition App",
51
- description="Lade ein Bild hoch und erkenne die Emotion."
52
  )
53
 
54
  interface.launch()
 
2
  from torchvision import models, transforms
3
  from PIL import Image
4
  import gradio as gr
5
+ import matplotlib.pyplot as plt
6
  from huggingface_hub import hf_hub_download
7
 
8
  # Modell laden vom Hugging Face Model Hub
 
13
  filename="emotion_model.pt"
14
  )
15
 
 
16
  model = models.resnet18()
17
  model.fc = torch.nn.Linear(model.fc.in_features, 9)
18
  model.load_state_dict(torch.load(model_path, map_location=device))
 
28
  transforms.ToTensor()
29
  ])
30
 
31
+ def plot_probabilities(probabilities, labels):
32
+ probs = probabilities.cpu().numpy().flatten()
33
+ fig, ax = plt.subplots(figsize=(8, 4))
34
+ ax.barh(labels, probs)
35
+ ax.set_xlim(0, 1)
36
+ ax.invert_yaxis() # Highest probability on top
37
+ ax.set_xlabel('Confidence')
38
+ ax.set_title('Emotion Probabilities')
39
+ plt.tight_layout()
40
+ return fig
41
+
42
  def predict_emotion(image):
43
  image = image.convert("RGB")
44
  image = transform(image).unsqueeze(0).to(device)
 
46
  with torch.no_grad():
47
  outputs = model(image)
48
  probs = torch.softmax(outputs, dim=1)
 
49
 
50
+ # Top 3 Predictions
51
+ top3_prob, top3_idx = torch.topk(probs, 3)
52
+ top3 = [(labels[i], f"{p.item()*100:.2f}%") for i, p in zip(top3_idx[0], top3_prob[0])]
53
+
54
+ # Overall Prediction
55
+ confidence, predicted = torch.max(probs, 1)
56
+ prediction = labels[predicted.item()]
57
+
58
+ # Unsicherheitswarnung
59
  if confidence.item() < 0.7:
60
+ prediction_status = "⚠️ Unsichere Vorhersage"
61
  else:
62
+ prediction_status = "✅ Sichere Vorhersage"
63
+
64
+ # Bar Chart
65
+ fig = plot_probabilities(probs, labels)
66
+
67
+ # Ausgabe
68
+ return prediction, f"Confidence: {confidence.item()*100:.2f}%\n{prediction_status}", top3, fig
69
 
70
  # Gradio Interface
71
  interface = gr.Interface(
72
  fn=predict_emotion,
73
  inputs=gr.Image(type="pil"),
74
+ outputs=[
75
+ gr.Textbox(label="Vorhergesagte Emotion"),
76
+ gr.Textbox(label="Confidence + Einschätzung"),
77
+ gr.Dataframe(headers=["Emotion", "Wahrscheinlichkeit (%)"], label="Top 3 Emotionen"),
78
+ gr.Plot(label="Verteilung der Emotionen")
79
+ ],
80
  title="Emotion Recognition App",
81
+ description="Lade ein Bild hoch und erkenne die Emotion. Zeigt auch die Top 3 Emotionen und alle Wahrscheinlichkeiten als Balkendiagramm."
82
  )
83
 
84
  interface.launch()