thoeppner commited on
Commit
39fced5
·
verified ·
1 Parent(s): 24e7cfa

Create zero_shot_emotion.py

Browse files
Files changed (1) hide show
  1. zero_shot_emotion.py +60 -0
zero_shot_emotion.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import CLIPProcessor, CLIPModel
3
+ import torch
4
+ import gradio as gr
5
+
6
+ # Modell laden
7
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
8
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model = model.to(device)
12
+ model.eval()
13
+
14
+ # Deine Emotion Labels
15
+ emotion_labels = [
16
+ "a happy person",
17
+ "a sad person",
18
+ "an angry person",
19
+ "a surprised person",
20
+ "a fearful person",
21
+ "a disgusted person",
22
+ "a neutral person",
23
+ "a contemptuous person",
24
+ "an unknown emotion"
25
+ ]
26
+
27
+ # Funktion
28
+ def zero_shot_predict(image):
29
+ image = image.convert("RGB")
30
+
31
+ inputs = processor(
32
+ text=emotion_labels,
33
+ images=image,
34
+ return_tensors="pt",
35
+ padding=True
36
+ ).to(device)
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+
41
+ logits_per_image = outputs.logits_per_image # Bild-Text Ähnlichkeiten
42
+ probs = logits_per_image.softmax(dim=1) # Wahrscheinlichkeiten
43
+ top3_prob, top3_idx = torch.topk(probs, 3)
44
+
45
+ # Ergebnisse
46
+ top3 = [(emotion_labels[i], f"{p.item() * 100:.2f}%") for i, p in zip(top3_idx[0], top3_prob[0])]
47
+ best_emotion = emotion_labels[top3_idx[0][0]]
48
+
49
+ return best_emotion, top3
50
+
51
+ # Gradio Interface
52
+ interface = gr.Interface(
53
+ fn=zero_shot_predict,
54
+ inputs=gr.Image(type="pil"),
55
+ outputs=["text", gr.Dataframe(headers=["Emotion", "Confidence (%)"])],
56
+ title="Zero-Shot Emotion Recognition",
57
+ description="Erkenne Emotionen ohne Training — einfach mit CLIP!"
58
+ )
59
+
60
+ interface.launch()