# app.py
"""Gradio web app that classifies reptile/amphibian photos with a ResNet-50 model."""
import json
import os

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# Directory containing this script. Data files are looked up here first so the
# app works regardless of the directory it is launched from.
BASE_DIR = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)

# Number of predictions returned by predict_image and shown by gr.Label.
TOP_K = 3


def _resolve_path(filename):
    """Return *filename* inside BASE_DIR if it exists there, else *filename* unchanged.

    Falling back to the bare name keeps the previous working-directory-relative
    lookup working.
    """
    candidate = os.path.join(BASE_DIR, filename)
    return candidate if os.path.exists(candidate) else filename


# Class labels; predict_image indexes this by the model's output position.
with open(_resolve_path('class_names.json'), 'r', encoding='utf-8') as f:
    class_names = json.load(f)
print(f"Klassen geladen: {len(class_names)} Klassen gefunden")


def load_model(model_path=None):
    """Build a ResNet-50 with a len(class_names)-way head and load trained weights.

    Args:
        model_path: Path to the checkpoint file. Defaults to
            'reptile_classifier.pth' next to this script (or in the current
            working directory if it is not found there).

    Returns:
        The model with weights loaded, on the CPU, in eval mode.
    """
    # No pretrained download: load_state_dict (strict by default) replaces every weight.
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))

    if model_path is None:
        model_path = _resolve_path('reptile_classifier.pth')
    print(f"Lade Modell von: {model_path}")

    # map_location lets a GPU-saved checkpoint load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # The expected format stores the weights under 'model_state_dict'; a bare
    # state_dict (torch.save(model.state_dict(), ...)) is accepted as well.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Load the model once at import time so every request reuses it.
model = load_model()
print("Modell erfolgreich geladen")

# Preprocessing: fixed 224x224 resize, then per-channel mean/std normalization
# (the standard ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict_image(image):
    """Classify a PIL image and return the most likely classes.

    Args:
        image: A PIL image, or None when nothing has been uploaded.

    Returns:
        A dict mapping class name -> probability for up to TOP_K classes,
        ordered from most to least likely, or None if *image* is None.
    """
    if image is None:
        return None

    # Normalize expects exactly 3 channels; grayscale, RGBA or palette images
    # would otherwise fail, so force RGB first.
    batch = transform(image.convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # Clamp k so topk does not fail when there are fewer than TOP_K classes.
    k = min(TOP_K, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # tolist() yields plain Python ints/floats; topk returns them sorted descending.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }


def main():
    """Build the Gradio interface around predict_image and launch the server."""
    title = "Reptilien- und Amphibien-Klassifikation"
    description = (
        "Lade ein Bild eines Reptils oder Amphibiums hoch, um es zu klassifizieren. "
        "Dieses Modell kann verschiedene Arten basierend auf dem Reptiles and "
        "Amphibians Dataset von Kaggle identifizieren."
    )

    interface = gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=TOP_K),
        title=title,
        description=description,
    )

    # Blocks and serves the web UI.
    interface.launch()


if __name__ == "__main__":
    main()