Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """ERAV2-S13-Himank-Gradio.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1HJ6wO2_czxZrJwnyUkJ_XaS5HYUvooMS | |
| """ | |
| import torch, torchvision | |
| from torchvision import transforms | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from model import ResNet18 | |
# Instantiate the project's ResNet18 and load trained weights from "model.pth" onto the CPU.
model = ResNet18()
_load_report = model.load_state_dict(
    torch.load("model.pth", map_location=torch.device('cpu')), strict=False
)
# strict=False tolerates key mismatches silently; report them so a checkpoint whose keys
# do not line up (for instance, prefixed parameter names) does not leave the model running
# on its initial, untrained weights without anyone noticing.
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        "WARNING: model.pth did not match ResNet18 exactly; "
        f"missing keys: {_load_report.missing_keys}, "
        f"unexpected keys: {_load_report.unexpected_keys}"
    )
# This module only serves predictions: switch layers with train/eval-specific behaviour
# (e.g. BatchNorm or Dropout, if the model has them) into inference mode.
model.eval()

# Inverse of transforms.Normalize(mean=0.5, std=0.23) applied per channel:
# x -> 0.23 * x + 0.5.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)

# CIFAR-10 class names; presumably in the label-index order used during training.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def resize_image_pil(image, new_width, new_height, resample=Image.NEAREST):
    """Letterbox *image* into a ``new_width`` x ``new_height`` PIL image.

    The image is scaled by one factor so it fits inside the target box, which
    preserves its aspect ratio. It is then cropped to exactly the target size.
    Because the crop box starts at (0, 0), the scaled image sits in the top-left
    corner, and PIL fills any uncovered area with zeros (black).

    Args:
        image: Anything ``np.array`` turns into an array accepted by
            ``Image.fromarray`` (e.g. the numpy image Gradio passes in).
        new_width: Target width in pixels.
        new_height: Target height in pixels.
        resample: PIL resampling filter; defaults to ``Image.NEAREST`` (the
            previous hard-coded behaviour).

    Returns:
        A ``PIL.Image.Image`` of size ``(new_width, new_height)`` in the mode
        ``Image.fromarray`` inferred for the input.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size
    scale = min(new_width / width, new_height / height)
    # int() truncation can yield 0 for extreme aspect ratios, and PIL refuses to
    # resize to a zero-sized image, so keep at least one pixel per side.
    scaled_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = img.resize(scaled_size, resample)
    return resized.crop((0, 0, new_width, new_height))
def inference(input_img, enable_grad_cam, transparency=0.5, target_layer_number=-1, num_top_classes=2):
    """Classify one image with the CIFAR-10 ResNet18 and optionally overlay a GradCAM map.

    Args:
        input_img: Image from the ``gr.Image`` input; anything ``np.array`` turns
            into an array PIL can read (presumably an HxWxC uint8 array — TODO confirm).
        enable_grad_cam: If truthy, compute the GradCAM overlay; otherwise the
            overlay output is None.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``.
        target_layer_number: Index into ``model.layer2`` selecting the block that
            GradCAM hooks (the UI offers -2 or -1). Cast to int because Gradio may
            deliver numeric inputs as floats.
        num_top_classes: How many of the most probable classes to return. Cast to int.

    Returns:
        ``(predicted_class_name, overlay_array_or_None, {class_name: probability})``,
        where the dict holds the ``num_top_classes`` most probable classes.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please upload an input image.")
    # Gradio's Number (and possibly Slider) components can hand back floats, but
    # the list indexing and slicing below require ints.
    target_layer_number = int(target_layer_number)
    num_top_classes = int(num_top_classes)

    # Predictions must use inference-mode layer behaviour (e.g. for any BatchNorm or
    # Dropout layers). eval() is idempotent, and calling it here keeps this function
    # correct regardless of how the model was left.
    model.eval()

    # Letterbox to CIFAR-10's 32x32 and force three channels, so grayscale or RGBA
    # uploads produce the (32, 32, 3) layout used below.
    org_img = np.array(resize_image_pil(input_img, 32, 32).convert("RGB"))

    # NOTE(review): only ToTensor (scaling to [0, 1]) is applied here, yet the
    # module-level `inv_normalize` undoes Normalize(mean=0.5, std=0.23). If training
    # used that normalization, it should be applied here too — confirm against the
    # training code.
    input_tensor = transforms.ToTensor()(org_img).unsqueeze(0)  # (1, 3, 32, 32)

    # Plain classification pass: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probabilities)}
    predicted = classes[int(outputs.argmax(dim=1)[0])]

    # GradCAM runs its own gradient-enabled pass, so only pay for it when requested.
    visualization = None
    if enable_grad_cam:
        cam = GradCAM(model=model, target_layers=[model.layer2[target_layer_number]])
        # targets=None: presumably the top-scoring class is explained (library
        # default) — confirm against the installed pytorch-grad-cam version.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(
            org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency
        )

    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return predicted, visualization, dict(ranked[:num_top_classes])
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# Each row follows the order of `inputs` below:
# [image path, enable GradCAM, opacity, layer index, number of top classes].
examples = [
    ["cat.jpg", True, 0.5, -1, 2],
    ["dog.jpg", True, 0.5, -1, 3],
    ["bird.jpg", True, 0.5, -1, 4],
    ["car.jpg", False, 0.5, -1, 5],
    ["deer.jpg", True, 0.5, -1, 6],
    ["frog.jpg", False, 0.5, -1, 7],
    ["horse.jpg", False, 0.45, -1, 8],
    ["plane.jpg", True, 0.30, -2, 9],
    ["ship.jpg", False, 0.25, -2, 10],
    ["truck.jpg", True, 0.75, -2, 1],
]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Checkbox(value=False, label="Enable grad-cam image"),
        gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
        gr.Slider(-2, -1, value=-2, step=1, label="Select Layer"),
        # precision=0 makes Gradio round to and return an int; inference() uses
        # this value as a slice bound, which rejects floats.
        gr.Number(value=2, label="Number of Top Classes to Show", minimum=1, maximum=10, precision=0),
    ],
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(),
    ],
    title=title,
    description=description,
    examples=examples,
)

# Launch only when run as a script (as `python app.py` does), not on plain import.
if __name__ == "__main__":
    demo.launch()