Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import requests | |
# Load ImageNet labels: one class name per line; line i is used as the name of
# class index i when predictions are turned into label names.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Fail fast at startup: the timeout keeps a stalled connection from hanging the
# app forever, and raise_for_status() prevents an HTTP error page from being
# silently parsed as if it were the label list.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
# splitlines() also handles CRLF line endings, which split("\n") would leave
# as a trailing "\r" on every label.
labels = _labels_response.text.strip().splitlines()
# Load model.
# For the demo, use torchvision's pretrained ImageNet weights. The network is
# built exactly once: constructing a throwaway, randomly initialized ResNet-50
# first would only cost startup time and memory.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# To serve your own converted FP16 checkpoint instead, replace the line above
# with the following (FP16 tensors are upcast to FP32 before loading):
# model = models.resnet50(weights=None)
# state = torch.load("model_cpu.pt", map_location="cpu")
# def to_fp32(obj):
#     if isinstance(obj, torch.Tensor) and obj.dtype == torch.float16:
#         return obj.float()
#     if isinstance(obj, dict):
#         return {k: to_fp32(v) for k, v in obj.items()}
#     return obj
# model.load_state_dict(to_fp32(state))

# Inference only: switch the network to evaluation mode.
model.eval()
# Preprocessing: resize, center-crop to 224x224, convert to a tensor, then
# normalize each of the three channels with the statistics below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)
def predict(image):
    """Classify an image and return its five most probable labels.

    Args:
        image: A PIL image (the Gradio input is configured with ``type="pil"``),
            or ``None`` when nothing has been uploaded.

    Returns:
        A dict mapping label name to probability for the top-5 classes, or an
        empty dict when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # The normalization step is configured for exactly three channels, so
    # grayscale, palette or RGBA inputs are converted to RGB first.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]

    top5_probs, top5_indices = torch.topk(probs, 5)
    # tolist() yields plain Python floats and ints, so the values can be used
    # directly as dict values and list indices.
    return {
        labels[idx]: prob
        for prob, idx in zip(top5_probs.tolist(), top5_indices.tolist())
    }
# Gradio interface: one image input, one label output showing the top 5 classes.
_image_input = gr.Image(type="pil", label="Upload Image")
_label_output = gr.Label(num_top_classes=5, label="Predictions")

# NOTE(review): the example is a remote URL; confirm the deployed Gradio
# version can fetch it when the app starts.
_example_rows = [
    ["https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg"],
]

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="🖼️ ImageNet 1K Classifier",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
    examples=_example_rows,
    theme=gr.themes.Soft(),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()