| | |
| | import gradio as gr |
| | import torch |
| | import torchvision.transforms as transforms |
| | from torchvision import models |
| | import torch.nn as nn |
| | from PIL import Image |
| | import json |
| | import os |
| |
|
| | |
# Load the class labels. The file is expected to hold a JSON array: its length
# sets the size of the classifier's output layer, and predict_image looks up
# labels by integer output index.
# An explicit UTF-8 encoding avoids locale-dependent decoding (e.g. cp1252 on
# Windows), which would corrupt or reject non-ASCII class names.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
print(f"Klassen geladen: {len(class_names)} Klassen gefunden")
| |
|
| | |
def load_model(model_path='reptile_classifier.pth', num_classes=None):
    """Build a ResNet-50 classifier and load trained weights from disk.

    The network is created without pretrained ImageNet weights. Its final
    fully connected layer is replaced so that it emits one logit per class.

    Args:
        model_path: Checkpoint file to load. Relative paths resolve against
            the current working directory.
        num_classes: Number of output classes. Defaults to
            ``len(class_names)``, the module-level label list.

    Returns:
        The model, on the CPU and switched to evaluation mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the checkpoint weights do not match the architecture.
    """
    if num_classes is None:
        num_classes = len(class_names)

    # ``weights=None`` is the current spelling of the deprecated
    # ``pretrained=False``: a randomly initialised backbone.
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    print(f"Lade Modell von: {model_path}")

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Accept both a training checkpoint wrapped as {'model_state_dict': ...}
    # and a bare state_dict saved directly with torch.save(model.state_dict()).
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.eval()  # disable dropout / use running BatchNorm statistics
    return model
| |
|
| | |
# Build the classifier once, at import time, so every request served by the
# interface reuses the same in-memory model.
model = load_model()
print("Modell erfolgreich geladen")


# Preprocessing applied to each input image:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the given mean/std.
# NOTE(review): the mean/std values are the ImageNet statistics documented by
# torchvision; presumably they match the training pipeline, so confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
| |
|
| | |
def predict_image(image, top_k=3):
    """Classify an image and return the most likely classes.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            None when nothing has been uploaded.
        top_k: Maximum number of classes to return. It is clamped to the
            number of model outputs, so small label sets do not crash.

    Returns:
        None if ``image`` is None. Otherwise a dict mapping class name to
        softmax probability (float) for the ``top_k`` most likely classes.
        This is the mapping ``gr.Label`` displays.
    """
    if image is None:
        return None

    # The normalisation step and the first convolution expect exactly three
    # channels. RGBA PNGs or grayscale images would otherwise raise at runtime.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # torch.topk raises if k exceeds the number of elements.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return {
        class_names[int(idx)]: float(prob)
        for idx, prob in zip(top_indices, top_probs)
    }
| |
|
| | |
def main():
    """Wrap predict_image in a Gradio web UI and start serving it.

    Users upload a PIL image; the three highest-scoring classes are shown
    as a label widget.
    """
    app = gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=3),
        title="Reptilien- und Amphibien-Klassifikation",
        description="Lade ein Bild eines Reptils oder Amphibiums hoch, um es zu klassifizieren. Dieses Modell kann verschiedene Arten basierend auf dem Reptiles and Amphibians Dataset von Kaggle identifizieren.",
    )

    # Blocks here and serves the UI until the process is stopped.
    app.launch()


if __name__ == "__main__":
    main()