# Hugging Face Spaces page status at capture time: "Build error".
import ast
import html

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from resnet_model import ResNet50
from utils import load_checkpoint
# Load the model
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build ResNet-50, wrap it in DataParallel and move it to `device` in one go
# (nn.Module.to() returns the module itself).
# NOTE(review): presumably the checkpoint's state_dict keys carry the "module."
# prefix that DataParallel adds, which is why wrapping happens before loading —
# confirm against how checkpoint.pth was saved.
model = torch.nn.DataParallel(ResNet50()).to(device)

# Load the checkpoint
checkpoint_path = "checkpoint.pth"
# load_checkpoint (from utils) returns a 4-tuple; only the model is kept. The
# second argument is passed as None (presumably "no optimizer" — verify in utils).
# NOTE(review): on CPU-only hosts this only works if load_checkpoint maps
# CUDA-saved tensors to the CPU (e.g. torch.load(..., map_location=...)) — confirm.
model, _, _, _ = load_checkpoint(model, None, checkpoint_path)
# Switch layers such as dropout / batch-norm to their evaluation behavior.
model.eval()
# Define the image transformation
# Preprocessing applied to every uploaded image before inference:
#   1. resize to exactly 224x224 (the aspect ratio is NOT preserved),
#   2. convert the PIL image to a channels-first float tensor in [0, 1],
#   3. normalize each channel with the standard ImageNet mean / std.
# NOTE(review): the common ImageNet evaluation recipe is Resize(256) followed by
# CenterCrop(224); confirm this plain resize matches how the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load class labels from the file
# The file is expected to hold a Python dict literal mapping class index ->
# label (predict() looks labels up by integer index with .get()).
# ast.literal_eval only parses literals, so nothing in the file is executed.
# The encoding is explicit because open()'s default is locale-dependent.
with open("imagenet1000_clsidx_to_labels.txt", encoding="utf-8") as f:
    class_labels = ast.literal_eval(f.read())
if not isinstance(class_labels, dict):
    # Fail at startup rather than with an AttributeError on the first request.
    raise TypeError(
        "imagenet1000_clsidx_to_labels.txt must contain a dict literal, "
        f"got {type(class_labels).__name__}"
    )
# Opening tag of the HTML fragment returned to the Gradio "html" output.
_RESULTS_OPEN = "<div style='font-family: Arial, sans-serif; font-size: 18px;'>"


def render_predictions_html(probs, class_indices, labels):
    """Render ranked predictions as an HTML fragment of labelled bars.

    Args:
        probs: Probabilities in [0, 1], one per prediction, in display order.
        class_indices: Class indices aligned element-wise with ``probs``.
        labels: Mapping from class index to label; indices missing from it
            are shown as "Unknown".

    Returns:
        An HTML string: an outer styled ``<div>`` containing, per prediction,
        a "label: NN.NN%" line followed by a bar whose width is the percentage.
    """
    rows = []
    for prob, class_index in zip(probs, class_indices):
        # Labels come from a data file and are interpolated into markup, so
        # escape them ("&", "<", quotes would otherwise corrupt the HTML).
        label = html.escape(str(labels.get(class_index, "Unknown")))
        percent = prob * 100
        rows.append(
            f"<div style='margin-bottom: 10px;'><strong>{label}</strong>: {percent:.2f}%</div>"
            "<div style='background-color: #ddd; width: 100%;'>"
            f"<div style='width: {percent}%; background-color: #4CAF50; height: 20px;'></div></div>"
        )
    # Single join instead of repeated string concatenation.
    return _RESULTS_OPEN + "".join(rows) + "</div>"


# Define the prediction function
def predict(image, top_k=5):
    """Classify an uploaded image and return its top-k predictions as HTML.

    Args:
        image: A ``PIL.Image.Image`` from ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading an image.
        top_k: Number of highest-probability classes to show (default 5).
            Clamped to the range [0, number of model outputs].

    Returns:
        An HTML string for the Gradio "html" output component.
    """
    if image is None:
        # Gradio passes None when nothing has been uploaded; the original
        # code crashed inside `transform` here.
        return _RESULTS_OPEN + "Please upload an image.</div>"
    # RGBA (PNG with alpha), grayscale and palette images yield a channel count
    # that does not match the 3-channel Normalize in `transform`.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes (or is negative).
    k = max(0, min(top_k, probabilities.numel()))
    top_prob, top_catid = torch.topk(probabilities, k)
    return render_predictions_html(top_prob.tolist(), top_catid.tolist(), class_labels)
# Create the Gradio interface
# One PIL image input, rendered HTML output produced by predict().
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="html",
    title="ResNet 50 Image Classifier",
)
# Start the web server (blocks while the app is running).
iface.launch()