Spaces:
Build error
Build error
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification
from transformers import pipeline
# CIFAR-10 class labels; also passed to CLIP as the zero-shot candidate labels.
labels_cifar10 = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()
# Load the image processor and the fine-tuned ViT classifier separately.
# NOTE(review): the processor comes from the base checkpoint, not from
# "Fadri/results"; presumably this matches the preprocessing used during
# fine-tuning — confirm against the training setup.
processor = AutoImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
model = ViTForImageClassification.from_pretrained(
    "Fadri/results"
)

# CLIP zero-shot classifier, used as a comparison baseline (unchanged).
clip_detector = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
def predict_cifar10(image_path, top_k=3):
    """Classify an image with the fine-tuned ViT model and return its top predictions.

    Args:
        image_path: Path to the image file (the Gradio input is declared with
            type="filepath").
        top_k: Number of highest-probability classes to return. Capped at the
            number of classes the model outputs. Defaults to 3.

    Returns:
        dict mapping each label (from ``model.config.id2label``) to its softmax
        probability rounded to 4 decimals, ordered from most to least likely.
    """
    # Use a context manager so the file handle is closed promptly; convert()
    # returns a new, fully loaded image that stays valid after the file closes.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension for the single image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Never ask topk for more entries than there are classes.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return {
        model.config.id2label[idx]: round(prob, 4)
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
def classify_image(image):
    """Run both classifiers on one uploaded image and return the results side by side.

    Args:
        image: File path of the uploaded image (the Gradio input is declared
            with type="filepath"), or None if the user submitted no image.

    Returns:
        dict with two entries, each mapping label -> score:
        "CIFAR-10 ViT Klassifikation" holds the top predictions of the
        fine-tuned ViT model; "CLIP Zero-Shot Klassifikation" holds CLIP's
        scores for every CIFAR-10 label.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an internal error from opening ``None``.
    """
    if image is None:
        raise gr.Error("Bitte lade zuerst ein Bild hoch.")

    # Classification with the fine-tuned ViT model.
    cifar10_output = predict_cifar10(image)

    # Zero-shot classification with CLIP over the same label set.
    clip_results = clip_detector(image, candidate_labels=labels_cifar10)
    clip_output = {result["label"]: result["score"] for result in clip_results}

    return {
        "CIFAR-10 ViT Klassifikation": cifar10_output,
        "CLIP Zero-Shot Klassifikation": clip_output,
    }
# Example images offered below the interface (adjust these paths as needed).
# Paths are relative to the current working directory.
_EXAMPLE_PATHS = [
    "examples/airplane.jpg",
    "examples/car.jpg",
    "examples/dog.jpg",
    "examples/cat.jpg",
]

# Offer only the example files that actually exist. If none are found, fall
# back to None so the interface is built without an examples section instead
# of referencing files that are not there.
example_images = [
    [path] for path in _EXAMPLE_PATHS if Path(path).is_file()
] or None
# Gradio interface: one uploaded image in, JSON with both classifiers' results out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    title="CIFAR-10 Klassifikation",
    description="Lade ein Bild hoch und vergleiche die Ergebnisse zwischen deinem trainierten ViT Modell und CLIP.",
    examples=example_images,
)

# Start the server only when executed as a script, so that merely importing
# this module (e.g. from a test) does not block inside launch().
if __name__ == "__main__":
    iface.launch()