# Spaces: Sleeping
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms, models

# ---------------------------------------------------------------------------
# Trained model
# ---------------------------------------------------------------------------
# The checkpoint is a dict; this script reads two keys from it:
#   "class_names"      -> label list; entry i names output unit i of the model
#   "model_state_dict" -> weights for the ResNet-50 rebuilt below
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("animal_model.pth", map_location="cpu")  # map tensors to CPU
class_names = checkpoint["class_names"]

# Recreate the training architecture: a ResNet-50 with no pretrained weights,
# whose final fully connected layer is resized to the number of classes.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=len(class_names),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: BatchNorm layers use their running statistics
# Image preprocessing.
# NOTE(review): presumably mirrors the transforms used during training — confirm.
transform = transforms.Compose(
    [
        # Fixed 224x224 output; aspect ratio is not preserved.
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
        # Per-channel normalization with the ImageNet mean/std values
        # published in the torchvision documentation.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def predict(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: Value from the ``gr.Image(type="numpy")`` input — a NumPy
            image array, or ``None`` when the user submits without
            uploading anything.

    Returns:
        The entry of ``class_names`` with the highest model score, or a
        short prompt asking for an image when ``image`` is ``None``.
    """
    # An empty Gradio image input arrives as None; Image.fromarray(None)
    # would raise and surface only as a generic error in the UI.
    if image is None:
        return "No image provided. Please upload an image."

    # convert("RGB") turns grayscale / RGBA inputs into the 3 channels that
    # the Normalize step and the network expect.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(0)  # add batch dimension -> (1, C, H, W)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)

    pred_idx = logits.argmax(dim=1).item()
    return class_names[pred_idx]
# Gradio Interface: one image input, mapped to a text label by predict().
app = gr.Interface(
    title="Animal Image Classifier",
    description="Upload an image of an animal and the model will classify it.",
    fn=predict,
    inputs=gr.Image(type="numpy"),  # predict() receives a NumPy array (or None)
    outputs="text",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()