Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import pytorch_lightning as pl | |
| from torchvision import transforms | |
| from PIL import Image | |
| from torchvision import models | |
| import torch.nn as nn | |
# LightningModule wrapper around torchvision's ResNet-50. Its structure must
# match the module used at training time so the checkpoint's state_dict keys
# (prefixed ``model.`` after the attribute below) line up when loading.
class ResNet50Image2k(pl.LightningModule):
    """ResNet-50 classifier with a ``num_classes``-way linear head.

    Args:
        num_classes: Number of logits produced by the final ``fc`` layer.
            Defaults to 1000.
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        # ``weights=None`` builds an un-pretrained backbone; the trained
        # parameters are restored from the checkpoint afterwards. This is the
        # replacement for the ``pretrained=False`` argument, which torchvision
        # has deprecated since 0.13.
        self.model = models.resnet50(weights=None)
        # Replace the stock classification head with one sized to num_classes.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-softmax) class logits for the input batch ``x``."""
        return self.model(x)
# Choose where inference runs: the GPU when CUDA is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained network from its PyTorch Lightning checkpoint and move it
# onto the chosen device.
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path).to(device)

# Inference only: put layers such as batch-norm into evaluation behavior.
model.eval()
# Load the label names, one per line. predict_top5 indexes this list by model
# output index, so the line order must match the label order used in training.
# The encoding is explicit so decoding does not depend on the platform locale.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Per-channel (R, G, B) normalization statistics, the standard ImageNet values
# used with torchvision's classification models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation-time transform applied to every incoming PIL image.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),  # scale so the shorter edge is 256 px
        transforms.CenterCrop(224),  # take the central 224x224 patch
        transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Gradio callback: image in, top-5 {label: confidence} mapping out.
def predict_top5(image):
    """Classify ``image`` and return its five most probable class labels.

    Args:
        image: The ``PIL.Image.Image`` delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without providing an image.

    Returns:
        A dict mapping class label -> probability (a float in [0, 1], rounded
        to 4 decimals), which is the value format ``gr.Label`` expects. When
        fewer than 5 classes exist, all of them are returned. When ``image``
        is ``None``, an empty dict is returned.
    """
    if image is None:
        return {}

    # Force 3 channels. RGBA or grayscale inputs would otherwise break
    # Normalize (three per-channel statistics) and the model's first conv.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    # Single-image batch: softmax over the class dimension of row 0.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # torch.topk raises if k exceeds the number of elements.
    k = min(5, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)

    # Return numeric confidences rather than formatted strings so gr.Label
    # can sort them and render its confidence bars.
    return {
        class_labels[idx]: round(prob, 4)
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
# Sample images (relative paths) shown as one-click examples in the UI. Each
# inner list holds the value for the interface's single Image input.
_EXAMPLE_IMAGES = (
    "Images/Bird.JPEG",
    "Images/Chamelion.JPEG",
    "Images/Lizard.JPEG",
    "Images/Shark.JPEG",
    "Images/Turtle.JPEG",
)
examples = [[path] for path in _EXAMPLE_IMAGES]
# Assemble the web UI: one PIL image in, the top-5 labels with confidences out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=predict_top5,
    inputs=_image_input,
    outputs=_label_output,
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
    examples=examples,
)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()