"""Gradio web demo: classify an uploaded image with a ResNet-50 checkpoint.

At import time the script loads ``animal_model.pth`` (expected to hold the
keys ``"class_names"`` and ``"model_state_dict"``), rebuilds the network, and
defines the ``app`` interface. Running the file as a script launches the UI.
"""

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms

# Load trained model.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
checkpoint = torch.load("animal_model.pth", map_location="cpu")
class_names = checkpoint["class_names"]

# Define model architecture: ResNet-50 without pretrained weights, with the
# final fully connected layer resized to one output per class so that the
# saved state dict fits.
model = models.resnet50(weights=None)  # same as trained
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout

# Image preprocessing: fixed 224x224 input, tensor conversion, then per-channel
# normalization.
# NOTE(review): presumably this mirrors the preprocessing used at training
# time -- confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict(image):
    """Return the predicted class name for one image.

    Args:
        image: Image as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``, or ``None`` when the user submits
            without providing an image.

    Returns:
        The entry of ``class_names`` with the highest model score, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio passes None when nothing was uploaded; Image.fromarray would
    # otherwise fail on it.
    if image is None:
        return "Please upload an image first."

    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        outputs = model(batch)

    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]


# Gradio Interface
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Animal Image Classifier",
    description="Upload an image of an animal and the model will classify it.",
)

if __name__ == "__main__":
    app.launch()