Spaces:
Build error
Build error
| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| import hashlib | |
| import cv2 | |
| from huggingface_hub import hf_hub_download | |
| from transformers import CLIPProcessor, CLIPModel | |
# === Load the custom-trained emotion classifier ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the trained weights from the Hugging Face Hub; the returned path is passed to torch.load.
model_path = hf_hub_download(
    repo_id="thoeppner/emotion_model",
    filename="emotion_model.pt",
)

# ResNet-18 with its classifier head swapped for a 9-way linear layer,
# one output unit per entry of `labels`.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 9)
# NOTE(review): torch.load unpickles the downloaded checkpoint. Because the file comes from a
# remote repository, consider torch.load(..., weights_only=True) if the installed torch supports it.
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
model = model.to(device).eval()
# === Load the zero-shot model (CLIP) ===
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(device)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
# Inference only: switch CLIP to eval mode once at start-up.
clip_model.eval()
# === Class labels ===
# Index order matches the output units of the trained model (predictions index into this list).
labels = [
    "happy", "sad", "angry", "surprised", "fear",
    "disgust", "neutral", "contempt", "unknown",
]

# Zero-shot text prompts for CLIP, one per entry of `labels`, in the same order.
zero_shot_prompts = [
    "a happy person",
    "a sad person",
    "an angry person",
    "a surprised person",
    "a fearful person",
    "a disgusted person",
    "a neutral person",
    "a contemptuous person",
    "an unknown emotion",
]
# === Image preprocessing ===
# Resize every input to 224x224 and convert it to a tensor; the same transform feeds both
# the prediction and the Grad-CAM pass.
# NOTE(review): no mean/std normalization is applied here — confirm this matches training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# === Feedback storage ===
# CSV file (relative path, i.e. resolved against the working directory) collecting user feedback.
FEEDBACK_FILE = "user_feedback.csv"
# === Helper functions ===
def get_image_hash(image):
    """Return the hex MD5 digest of the image's raw pixel buffer.

    Only ``image.tobytes()`` is hashed; the image size and mode are not part of the digest.
    The result is stored with each feedback row as an image identifier.
    """
    digest = hashlib.md5(image.tobytes())
    return digest.hexdigest()
def plot_probabilities(probabilities, labels):
    """Build a horizontal bar chart of per-class probabilities.

    Args:
        probabilities: tensor of class probabilities; it is flattened, so a (1, C) batch row works.
        labels: C class names, drawn top-to-bottom in the given order.

    Returns:
        A matplotlib ``Figure`` (consumed by ``gr.Plot``).
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps every figure it creates
    # alive in its global registry until plt.close() is called, so a server that calls this once
    # per request would accumulate one figure per prediction.
    from matplotlib.figure import Figure

    probs = probabilities.detach().cpu().numpy().flatten()
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.barh(labels, probs)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()  # put the first label at the top
    ax.set_xlabel('Confidence')
    ax.set_title('Emotion Probabilities')
    fig.tight_layout()
    return fig
def generate_gradcam(image, model, class_idx, target_layer=None):
    """Return a Grad-CAM overlay explaining ``model``'s score for ``class_idx`` on ``image``.

    Args:
        image: PIL image; preprocessed with the module-level ``transform`` and moved to ``device``.
        model: classifier whose output is indexed as ``output[0, class_idx]``.
        class_idx: index of the output logit to explain.
        target_layer: module whose activations and gradients form the map. Defaults to
            ``model.layer4[1].conv2`` (the last convolution of a torchvision ResNet-18).

    Returns:
        224x224 RGB ``PIL.Image``: 60 % resized input blended with 40 % JET-coloured heatmap.
    """
    model.eval()
    if target_layer is None:
        target_layer = model.layer4[1].conv2

    image_tensor = transform(image).unsqueeze(0).to(device)
    activation, gradient = _gradcam_capture(model, target_layer, image_tensor, class_idx)
    cam = _gradcam_map(activation, gradient)
    return _gradcam_overlay(image, cam)


def _gradcam_capture(model, target_layer, image_tensor, class_idx):
    """Run one forward/backward pass and return (activation, gradient) of ``target_layer`` as (C, H, W) arrays."""
    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # Tensor hook: receives the gradient of the class score w.r.t. this layer's output.
        output.register_hook(gradients.append)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        output = model(image_tensor)
        model.zero_grad()
        output[0, class_idx].backward()
    finally:
        # Always remove the hook, even on error: a leftover hook would fire on every later
        # forward pass, and under torch.no_grad() register_hook on a tensor that does not
        # require grad raises, breaking all subsequent predictions.
        handle.remove()

    activation = activations[0].detach().cpu().numpy()[0]
    gradient = gradients[0].detach().cpu().numpy()[0]
    return activation, gradient


def _gradcam_map(activation, gradient, size=(224, 224)):
    """Weight activation channels by their mean gradient and return a map scaled to [0, 1], resized to ``size`` (width, height)."""
    weights = gradient.mean(axis=(1, 2))  # one weight per channel
    cam = np.tensordot(weights, activation, axes=1).astype(np.float32)
    cam = np.maximum(cam, 0)  # keep only regions that push the class score up
    cam = cv2.resize(cam, size)
    cam = cam - cam.min()
    peak = cam.max()
    if peak != 0:
        cam = cam / peak
    return cam


def _gradcam_overlay(image, cam):
    """Colour ``cam`` with the JET colormap and blend it over ``image`` resized to 224x224 RGB."""
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap returns BGR while the PIL-derived image is RGB; without this conversion
    # the red and blue channels are swapped (hot regions would render blue).
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    image_np = np.array(image.resize((224, 224)).convert("RGB"))
    if heatmap.shape != image_np.shape:
        heatmap = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))
    overlay = cv2.addWeighted(image_np, 0.6, heatmap, 0.4, 0)
    return Image.fromarray(overlay)
# === Trained model: prediction ===
def predict_emotion(image):
    """Classify ``image`` with the trained model.

    Returns a 6-tuple:
        (predicted label, "Confidence: xx.xx%\\n<status>" text, top-3 list of (label, "xx.xx%"),
         probability bar chart, Grad-CAM overlay image, MD5 hash of the RGB image)
    A prediction below 70 % confidence is marked as uncertain.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)

    top_probs, top_indices = torch.topk(probs, 3)
    top3 = [
        (labels[idx], f"{prob * 100:.2f}%")
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]

    confidence, predicted = torch.max(probs, 1)
    class_idx = predicted.item()
    confidence_value = confidence.item()
    if confidence_value >= 0.7:
        prediction_status = "✅ Sichere Vorhersage"
    else:
        prediction_status = "⚠️ Unsichere Vorhersage"

    fig = plot_probabilities(probs, labels)
    img_hash = get_image_hash(rgb)
    gradcam_img = generate_gradcam(rgb, model, class_idx)
    confidence_text = f"Confidence: {confidence_value*100:.2f}%\n{prediction_status}"
    return labels[class_idx], confidence_text, top3, fig, gradcam_img, img_hash
# === Zero-shot model: prediction ===
def zero_shot_predict(image):
    """Score ``image`` against every prompt in ``zero_shot_prompts`` with CLIP.

    Returns:
        (best-matching prompt text, top-3 list of (prompt, "xx.xx%"))
    """
    inputs = clip_processor(
        text=zero_shot_prompts,
        images=image.convert("RGB"),
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image

    # Softmax across the prompt axis gives one probability per prompt.
    scores, indices = torch.topk(logits.softmax(dim=1), 3)
    ranked = list(zip(indices[0].tolist(), scores[0].tolist()))
    top3 = [(zero_shot_prompts[i], f"{s * 100:.2f}%") for i, s in ranked]
    return zero_shot_prompts[ranked[0][0]], top3
# === Persist feedback ===
def save_feedback(img_hash, model_prediction, user_feedback, confidence):
    """Append one feedback row to ``FEEDBACK_FILE`` and return a confirmation message.

    Columns: image_hash, model_prediction, user_feedback, confidence.
    The header is written only when the file does not exist yet or is empty.
    """
    row = pd.DataFrame(
        {
            "image_hash": [img_hash],
            "model_prediction": [model_prediction],
            "user_feedback": [user_feedback],
            "confidence": [confidence],
        }
    )
    # Append instead of read-concat-rewrite. The old approach:
    # - re-parsed and rewrote the whole file on every submission;
    # - round-tripped stored values through pandas type inference, which can change them
    #   (e.g. a hash that parses as a number);
    # - crashed on an existing but empty file.
    write_header = not os.path.exists(FEEDBACK_FILE) or os.path.getsize(FEEDBACK_FILE) == 0
    row.to_csv(FEEDBACK_FILE, mode="a", header=write_header, index=False)
    return "✅ Vielen Dank für dein Feedback!"
# Offer the collected feedback file for download
def download_feedback():
    """Return the path of the feedback CSV if it exists, otherwise None."""
    return FEEDBACK_FILE if os.path.exists(FEEDBACK_FILE) else None
# Combined pipeline: trained model + zero-shot model + feedback logging
def full_pipeline(image, user_feedback):
    """Run both models on ``image`` and record the user's feedback.

    Args:
        image: PIL image from the upload component, or None when nothing was uploaded.
        user_feedback: label chosen in the dropdown; presumably None when nothing is selected.

    Returns:
        8-tuple: prediction, confidence text, top-3, probability plot, Grad-CAM image,
        zero-shot prediction, zero-shot top-3, feedback message.

    Raises:
        gr.Error: if no image was uploaded. The UI then shows a readable message instead of
        failing with an AttributeError from ``image.convert``.
    """
    if image is None:
        raise gr.Error("Bitte lade zuerst ein Bild hoch.")

    prediction, confidence_text, top3, fig, gradcam_img, img_hash = predict_emotion(image)
    zero_shot_prediction, zero_shot_top3 = zero_shot_predict(image)

    if user_feedback:
        # Store only the "Confidence: xx.xx%" line; the status line after "\n" is dropped.
        feedback_message = save_feedback(img_hash, prediction, user_feedback, confidence_text.split("\n")[0])
    else:
        # Without a chosen label there is nothing to record. Skip the write so the CSV does
        # not fill up with rows whose user_feedback column is empty.
        feedback_message = "ℹ️ Kein Feedback ausgewählt – nichts gespeichert."

    return (
        prediction,
        confidence_text,
        top3,
        fig,
        gradcam_img,
        zero_shot_prediction,
        zero_shot_top3,
        feedback_message,
    )
# === Gradio interface ===
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: inputs and actions.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Lade dein Bild hoch")
            feedback_input = gr.Dropdown(
                choices=labels,
                label="Dein Feedback: Was ist die richtige Emotion?",
            )
            submit_btn = gr.Button("Absenden")
            download_btn = gr.Button("Feedback-Daten herunterladen")

        # Right column: results of the trained model, the zero-shot model and the feedback status.
        with gr.Column():
            prediction_output = gr.Textbox(label="Dein Modell: Vorhergesagte Emotion")
            confidence_output = gr.Textbox(label="Confidence + Einschätzung")
            top3_output = gr.Dataframe(
                headers=["Emotion", "Wahrscheinlichkeit (%)"],
                label="Top 3 Emotionen",
            )
            plot_output = gr.Plot(label="Verteilung der Emotionen")
            gradcam_output = gr.Image(label="Grad-CAM Visualisierung")
            zero_shot_prediction_output = gr.Textbox(label="Zero-Shot Modell: Vorhergesagte Emotion")
            zero_shot_top3_output = gr.Dataframe(
                headers=["Emotion", "Confidence (%)"],
                label="Zero-Shot Top 3 Emotionen",
            )
            feedback_message_output = gr.Textbox(label="Feedback-Status")

    # Order must match the tuple returned by full_pipeline.
    pipeline_outputs = [
        prediction_output,
        confidence_output,
        top3_output,
        plot_output,
        gradcam_output,
        zero_shot_prediction_output,
        zero_shot_top3_output,
        feedback_message_output,
    ]
    submit_btn.click(
        fn=full_pipeline,
        inputs=[image_input, feedback_input],
        outputs=pipeline_outputs,
    )

    # File component created inside the Blocks context, below the row; receives the CSV path.
    feedback_file_output = gr.File()
    download_btn.click(fn=download_feedback, inputs=[], outputs=[feedback_file_output])

interface.launch()