thoeppner committed on
Commit
97adcd6
·
verified ·
1 Parent(s): 39fced5

Update app.py

Browse files

added zero shot

Files changed (1) hide show
  1. app.py +70 -33
app.py CHANGED
@@ -7,10 +7,11 @@ import pandas as pd
7
  import numpy as np
8
  import os
9
  import hashlib
10
- from huggingface_hub import hf_hub_download
11
  import cv2
 
 
12
 
13
- # Modell laden vom Hugging Face Model Hub
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  model_path = hf_hub_download(
@@ -24,24 +25,41 @@ model.load_state_dict(torch.load(model_path, map_location=device))
24
  model = model.to(device)
25
  model.eval()
26
 
27
- # Labels
 
 
 
 
 
28
  labels = ["happy", "sad", "angry", "surprised", "fear", "disgust", "neutral", "contempt", "unknown"]
29
 
30
- # Transformation
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  transform = transforms.Compose([
32
  transforms.Resize((224, 224)),
33
  transforms.ToTensor()
34
  ])
35
 
36
- # Feedback-File
37
  FEEDBACK_FILE = "user_feedback.csv"
38
 
39
- # Hilfsfunktion für Hash des Bildes
40
  def get_image_hash(image):
41
  img_bytes = image.tobytes()
42
  return hashlib.md5(img_bytes).hexdigest()
43
 
44
- # Plot-Funktion für Wahrscheinlichkeiten
45
  def plot_probabilities(probabilities, labels):
46
  probs = probabilities.cpu().numpy().flatten()
47
  fig, ax = plt.subplots(figsize=(8, 4))
@@ -53,11 +71,9 @@ def plot_probabilities(probabilities, labels):
53
  plt.tight_layout()
54
  return fig
55
 
56
- # Grad-CAM Hilfsfunktion
57
  def generate_gradcam(image, model, class_idx):
58
  model.eval()
59
 
60
- # Hook für Features und Gradients
61
  gradients = []
62
  activations = []
63
 
@@ -68,21 +84,16 @@ def generate_gradcam(image, model, class_idx):
68
  activations.append(output)
69
  output.register_hook(save_gradient)
70
 
71
- # Letztes Convolutional Layer
72
  target_layer = model.layer4[1].conv2
73
  handle = target_layer.register_forward_hook(forward_hook)
74
 
75
  image_tensor = transform(image).unsqueeze(0).to(device)
76
  output = model(image_tensor)
77
 
78
- # Softmax -> Klasse auswählen
79
- pred_class = output.argmax(dim=1).item()
80
-
81
  model.zero_grad()
82
  class_score = output[0, class_idx]
83
  class_score.backward()
84
 
85
- # Gradients und Activations holen
86
  gradients = gradients[0].cpu().data.numpy()[0]
87
  activations = activations[0].cpu().data.numpy()[0]
88
 
@@ -95,18 +106,22 @@ def generate_gradcam(image, model, class_idx):
95
  gradcam = np.maximum(gradcam, 0)
96
  gradcam = cv2.resize(gradcam, (224, 224))
97
  gradcam = gradcam - np.min(gradcam)
98
- gradcam = gradcam / np.max(gradcam)
 
99
 
100
- # Bild zurückkonvertieren
101
  heatmap = cv2.applyColorMap(np.uint8(255 * gradcam), cv2.COLORMAP_JET)
102
  image_np = np.array(image.resize((224, 224)).convert("RGB"))
 
 
 
 
103
  overlay = cv2.addWeighted(image_np, 0.6, heatmap, 0.4, 0)
104
 
105
- handle.remove() # Hook entfernen
106
 
107
  return Image.fromarray(overlay)
108
 
109
- # Prediction-Funktion
110
  def predict_emotion(image):
111
  image = image.convert("RGB")
112
  transformed_image = transform(image).unsqueeze(0).to(device)
@@ -115,32 +130,46 @@ def predict_emotion(image):
115
  outputs = model(transformed_image)
116
  probs = torch.softmax(outputs, dim=1)
117
 
118
- # Top 3 Predictions
119
  top3_prob, top3_idx = torch.topk(probs, 3)
120
  top3 = [(labels[i], f"{p.item()*100:.2f}%") for i, p in zip(top3_idx[0], top3_prob[0])]
121
 
122
- # Overall Prediction
123
  confidence, predicted = torch.max(probs, 1)
124
  prediction = labels[predicted.item()]
125
 
126
- # Unsicherheitswarnung
127
  if confidence.item() < 0.7:
128
  prediction_status = "⚠️ Unsichere Vorhersage"
129
  else:
130
  prediction_status = "✅ Sichere Vorhersage"
131
 
132
- # Bar Chart
133
  fig = plot_probabilities(probs, labels)
134
-
135
- # Bild-Hash für spätere Zuordnung
136
  img_hash = get_image_hash(image)
137
-
138
- # Grad-CAM Overlay
139
  gradcam_img = generate_gradcam(image, model, predicted.item())
140
 
141
  return prediction, f"Confidence: {confidence.item()*100:.2f}%\n{prediction_status}", top3, fig, gradcam_img, img_hash
142
 
143
- # Funktion um Feedback zu speichern
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def save_feedback(img_hash, model_prediction, user_feedback, confidence):
145
  data = {
146
  "image_hash": [img_hash],
@@ -157,20 +186,21 @@ def save_feedback(img_hash, model_prediction, user_feedback, confidence):
157
  df_new.to_csv(FEEDBACK_FILE, index=False)
158
  return "✅ Vielen Dank für dein Feedback!"
159
 
160
- # Download Funktion
161
  def download_feedback():
162
  if os.path.exists(FEEDBACK_FILE):
163
  return FEEDBACK_FILE
164
  else:
165
  return None
166
 
167
- # Kombinierte Funktion
168
  def full_pipeline(image, user_feedback):
169
  prediction, confidence_text, top3, fig, gradcam_img, img_hash = predict_emotion(image)
 
170
  feedback_message = save_feedback(img_hash, prediction, user_feedback, confidence_text.split("\n")[0])
171
- return prediction, confidence_text, top3, fig, gradcam_img, feedback_message
172
 
173
- # Gradio Interface
174
  with gr.Blocks() as interface:
175
  with gr.Row():
176
  with gr.Column():
@@ -179,17 +209,24 @@ with gr.Blocks() as interface:
179
  submit_btn = gr.Button("Absenden")
180
  download_btn = gr.Button("Feedback-Daten herunterladen")
181
  with gr.Column():
182
- prediction_output = gr.Textbox(label="Vorhergesagte Emotion")
183
  confidence_output = gr.Textbox(label="Confidence + Einschätzung")
184
  top3_output = gr.Dataframe(headers=["Emotion", "Wahrscheinlichkeit (%)"], label="Top 3 Emotionen")
185
  plot_output = gr.Plot(label="Verteilung der Emotionen")
186
  gradcam_output = gr.Image(label="Grad-CAM Visualisierung")
 
 
187
  feedback_message_output = gr.Textbox(label="Feedback-Status")
188
 
189
  submit_btn.click(
190
  fn=full_pipeline,
191
  inputs=[image_input, feedback_input],
192
- outputs=[prediction_output, confidence_output, top3_output, plot_output, gradcam_output, feedback_message_output]
 
 
 
 
 
193
  )
194
 
195
  download_btn.click(
 
7
  import numpy as np
8
  import os
9
  import hashlib
 
10
  import cv2
11
+ from huggingface_hub import hf_hub_download
12
+ from transformers import CLIPProcessor, CLIPModel
13
 
14
+ # === Dein trainiertes Modell laden ===
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  model_path = hf_hub_download(
 
25
  model = model.to(device)
26
  model.eval()
27
 
28
# === Zero-shot model (CLIP) loading ===
# CLIP scores an image against free-text prompts, so it needs no
# emotion-specific training; it serves as a second opinion alongside the
# fine-tuned classifier. Weights are fetched from the HF Hub on first run.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()  # inference only — disables dropout/batch-norm updates

# === Labels ===
# Emotion class names, in the index order the trained classifier emits.
labels = ["happy", "sad", "angry", "surprised", "fear", "disgust", "neutral", "contempt", "unknown"]
35
 
36
# Zero-shot text prompts — one natural-language description per entry in
# `labels`, in the same order, so a prompt index maps back to its label.
zero_shot_prompts = [
    "a happy person",
    "a sad person",
    "an angry person",
    "a surprised person",
    "a fearful person",
    "a disgusted person",
    "a neutral person",
    "a contemptuous person",
    "an unknown emotion"
]
+
49
# === Image transformation ===
# Resize to the 224x224 input size the model expects and convert
# PIL image -> float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied — presumably the
# model was trained without it; confirm against the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# === Feedback file ===
# CSV file where user feedback rows are appended.
FEEDBACK_FILE = "user_feedback.csv"
57
 
58
+ # === Hilfsfunktionen ===
59
def get_image_hash(image):
    """Return the MD5 hex digest of the image's raw pixel bytes.

    The hash is used only as a stable identifier to associate feedback
    rows with images — not for anything security-sensitive, so MD5 is fine.
    """
    return hashlib.md5(image.tobytes()).hexdigest()
62
 
 
63
  def plot_probabilities(probabilities, labels):
64
  probs = probabilities.cpu().numpy().flatten()
65
  fig, ax = plt.subplots(figsize=(8, 4))
 
71
  plt.tight_layout()
72
  return fig
73
 
 
74
  def generate_gradcam(image, model, class_idx):
75
  model.eval()
76
 
 
77
  gradients = []
78
  activations = []
79
 
 
84
  activations.append(output)
85
  output.register_hook(save_gradient)
86
 
 
87
  target_layer = model.layer4[1].conv2
88
  handle = target_layer.register_forward_hook(forward_hook)
89
 
90
  image_tensor = transform(image).unsqueeze(0).to(device)
91
  output = model(image_tensor)
92
 
 
 
 
93
  model.zero_grad()
94
  class_score = output[0, class_idx]
95
  class_score.backward()
96
 
 
97
  gradients = gradients[0].cpu().data.numpy()[0]
98
  activations = activations[0].cpu().data.numpy()[0]
99
 
 
106
  gradcam = np.maximum(gradcam, 0)
107
  gradcam = cv2.resize(gradcam, (224, 224))
108
  gradcam = gradcam - np.min(gradcam)
109
+ if np.max(gradcam) != 0:
110
+ gradcam = gradcam / np.max(gradcam)
111
 
 
112
  heatmap = cv2.applyColorMap(np.uint8(255 * gradcam), cv2.COLORMAP_JET)
113
  image_np = np.array(image.resize((224, 224)).convert("RGB"))
114
+
115
+ if heatmap.shape != image_np.shape:
116
+ heatmap = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))
117
+
118
  overlay = cv2.addWeighted(image_np, 0.6, heatmap, 0.4, 0)
119
 
120
+ handle.remove()
121
 
122
  return Image.fromarray(overlay)
123
 
124
+ # === Dein Modell: Prediction ===
125
  def predict_emotion(image):
126
  image = image.convert("RGB")
127
  transformed_image = transform(image).unsqueeze(0).to(device)
 
130
  outputs = model(transformed_image)
131
  probs = torch.softmax(outputs, dim=1)
132
 
 
133
  top3_prob, top3_idx = torch.topk(probs, 3)
134
  top3 = [(labels[i], f"{p.item()*100:.2f}%") for i, p in zip(top3_idx[0], top3_prob[0])]
135
 
 
136
  confidence, predicted = torch.max(probs, 1)
137
  prediction = labels[predicted.item()]
138
 
 
139
  if confidence.item() < 0.7:
140
  prediction_status = "⚠️ Unsichere Vorhersage"
141
  else:
142
  prediction_status = "✅ Sichere Vorhersage"
143
 
 
144
  fig = plot_probabilities(probs, labels)
 
 
145
  img_hash = get_image_hash(image)
 
 
146
  gradcam_img = generate_gradcam(image, model, predicted.item())
147
 
148
  return prediction, f"Confidence: {confidence.item()*100:.2f}%\n{prediction_status}", top3, fig, gradcam_img, img_hash
149
 
150
+ # === Zero-Shot Modell: Prediction ===
151
# === Zero-shot model: prediction ===
def zero_shot_predict(image):
    """Classify the emotion of `image` with CLIP in a zero-shot setup.

    Scores the image against every prompt in `zero_shot_prompts` and maps
    the winning prompt indices back to the shared `labels` list (the two
    lists are parallel), so the zero-shot output uses the same emotion
    names as the trained classifier instead of raw prompt text like
    "a happy person".

    Returns:
        tuple: (best_label, top3) where `best_label` is an entry of
        `labels` and `top3` is a list of (label, "xx.xx%") pairs for the
        three highest-scoring emotions.
    """
    image = image.convert("RGB")
    inputs = clip_processor(
        text=zero_shot_prompts,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Pure inference — no gradients needed.
    with torch.no_grad():
        outputs = clip_model(**inputs)

    # One similarity logit per (image, prompt) pair -> probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    top3_prob, top3_idx = torch.topk(probs, 3)

    # `.item()` converts the 0-dim index tensors to plain ints before
    # indexing the Python list.
    top3 = [
        (labels[i.item()], f"{p.item() * 100:.2f}%")
        for i, p in zip(top3_idx[0], top3_prob[0])
    ]
    best_label = labels[top3_idx[0][0].item()]

    return best_label, top3
171
+
172
+ # === Feedback speichern ===
173
  def save_feedback(img_hash, model_prediction, user_feedback, confidence):
174
  data = {
175
  "image_hash": [img_hash],
 
186
  df_new.to_csv(FEEDBACK_FILE, index=False)
187
  return "✅ Vielen Dank für dein Feedback!"
188
 
189
+ # Download Feedback
190
# Download feedback
def download_feedback():
    """Return the feedback CSV path for download, or None if no feedback
    has been collected yet."""
    feedback_path = FEEDBACK_FILE if os.path.exists(FEEDBACK_FILE) else None
    return feedback_path
195
 
196
+ # Kombinierte Funktion: Training + Zero-Shot
197
  def full_pipeline(image, user_feedback):
198
  prediction, confidence_text, top3, fig, gradcam_img, img_hash = predict_emotion(image)
199
+ zero_shot_prediction, zero_shot_top3 = zero_shot_predict(image)
200
  feedback_message = save_feedback(img_hash, prediction, user_feedback, confidence_text.split("\n")[0])
201
+ return prediction, confidence_text, top3, fig, gradcam_img, zero_shot_prediction, zero_shot_top3, feedback_message
202
 
203
+ # === Gradio Interface ===
204
  with gr.Blocks() as interface:
205
  with gr.Row():
206
  with gr.Column():
 
209
  submit_btn = gr.Button("Absenden")
210
  download_btn = gr.Button("Feedback-Daten herunterladen")
211
  with gr.Column():
212
+ prediction_output = gr.Textbox(label="Dein Modell: Vorhergesagte Emotion")
213
  confidence_output = gr.Textbox(label="Confidence + Einschätzung")
214
  top3_output = gr.Dataframe(headers=["Emotion", "Wahrscheinlichkeit (%)"], label="Top 3 Emotionen")
215
  plot_output = gr.Plot(label="Verteilung der Emotionen")
216
  gradcam_output = gr.Image(label="Grad-CAM Visualisierung")
217
+ zero_shot_prediction_output = gr.Textbox(label="Zero-Shot Modell: Vorhergesagte Emotion")
218
+ zero_shot_top3_output = gr.Dataframe(headers=["Emotion", "Confidence (%)"], label="Zero-Shot Top 3 Emotionen")
219
  feedback_message_output = gr.Textbox(label="Feedback-Status")
220
 
221
  submit_btn.click(
222
  fn=full_pipeline,
223
  inputs=[image_input, feedback_input],
224
+ outputs=[
225
+ prediction_output, confidence_output, top3_output,
226
+ plot_output, gradcam_output,
227
+ zero_shot_prediction_output, zero_shot_top3_output,
228
+ feedback_message_output
229
+ ]
230
  )
231
 
232
  download_btn.click(