Spaces:
Sleeping
Sleeping
import itertools
import warnings

import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import utils as utils
from model import Net
# Build the network and load the trained weights onto the CPU.
model = Net()
_load_result = model.load_state_dict(
    torch.load("model.pth", map_location=torch.device('cpu')),
    strict=False,
)
# strict=False tolerates key mismatches between the checkpoint and Net.
# Report them so a stale or mismatched checkpoint does not silently leave
# layers at their random initialisation.
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model.pth does not match Net exactly: "
        f"missing keys={list(_load_result.missing_keys)}, "
        f"unexpected keys={list(_load_result.unexpected_keys)}"
    )
model.eval()

# Class names indexed by the model's output position.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Validation split of CIFAR-10 (downloaded into the current directory if absent).
cifar_valid = utils.Cifar10SearchDataset('.', train=False, download=True, transform=utils.augmentation_custom_resnet())

# Normalize computes (x - mean) / std; with mean=-0.5/0.23 and std=1/0.23 this
# maps y -> 0.23 * y + 0.5, i.e. it undoes (x - 0.5) / 0.23.
# NOTE(review): presumably mirrors the normalization inside
# utils.augmentation_custom_resnet — confirm the constants stay in sync.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
def _as_count(value):
    """Convert a slider value (possibly None or a float) to a non-negative int."""
    if value is None:
        return 0
    return max(int(value), 0)


def _to_display_image(data):
    """Undo dataset normalization for a single-image batch; return an HxWxC array clipped to [0, 1]."""
    img = inv_normalize(data).squeeze(0).numpy()
    img = np.transpose(img, (1, 2, 0))
    # show_cam_on_image raises if any value exceeds 1, and float rounding in
    # the inverse normalization can push pixels just outside [0, 1].
    return np.clip(img, 0.0, 1.0)


def _resolve_target_layers(target_layer_number):
    """Map the dropdown choice ('-2' / '-1', as str or int) to the layers GradCAM hooks."""
    # The dropdown may deliver its choice as an int or a str depending on how
    # the choices were declared, so compare on the string form.
    choice = str(target_layer_number)
    if choice == '-2':
        return [model.convblock31[0]]
    if choice == '-1':
        return [model.convblock21[0]]
    raise ValueError(f"Unsupported target layer {target_layer_number!r}; expected -2 or -1")


def _gradcam_gallery(n_gradcam, target_layer_number, transparency):
    """Return up to n_gradcam GradCAM overlays for randomly drawn validation images."""
    n_wanted = _as_count(n_gradcam)
    if n_wanted == 0:
        return []
    target_layers = _resolve_target_layers(target_layer_number)
    loader = DataLoader(cifar_valid, batch_size=1, shuffle=True)
    overlays = []
    # A single CAM object serves every image; leaving the `with` block
    # releases the hooks it registered on the model.
    with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
        for data, _target in itertools.islice(loader, n_wanted):
            data = data.to('cpu')
            # targets=None lets GradCAM explain the highest-scoring class.
            grayscale_cam = cam(input_tensor=data, targets=None)[0, :]
            org_img = _to_display_image(data)
            overlays.append(np.array(show_cam_on_image(
                org_img, grayscale_cam, use_rgb=True, image_weight=transparency)))
    return overlays


def _render_misclassified(data, true_idx, pred_idx):
    """Draw one image titled with its target/predicted labels; return it as an HxWx3 uint8 array."""
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.imshow(_to_display_image(data))
        ax.set_title(f'Target: {classes[true_idx]}\nPred: {classes[pred_idx]}')
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib removed;
        # drop the alpha channel to keep RGB output.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)


def _misclassified_gallery(n_misclassified):
    """Return rendered figures for up to n_misclassified wrongly predicted validation images."""
    n_wanted = _as_count(n_misclassified)
    figures = []
    if n_wanted == 0:
        return figures
    loader = DataLoader(cifar_valid, batch_size=1, shuffle=True)
    for data, target in loader:
        data, target = data.to('cpu'), target.to('cpu')
        with torch.no_grad():
            _, prediction = torch.max(model(data), 1)
        true_idx = int(target[0])
        pred_idx = int(prediction[0])
        if true_idx == pred_idx:
            continue
        figures.append(_render_misclassified(data, true_idx, pred_idx))
        if len(figures) >= n_wanted:
            break
    return figures


def _classify_image(input_img, n_top_classes):
    """Classify one uploaded image; return {class name: probability} for the top classes."""
    transform = utils.augmentation_custom_resnet('Valid')
    batch = transform(image=input_img)['image'].unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    probabilities = torch.nn.functional.softmax(outputs.flatten(), dim=0)
    ranked = sorted(
        ((classes[i], float(probabilities[i])) for i in range(len(classes))),
        key=lambda item: item[1],
        reverse=True,
    )
    # None (e.g. from the examples table) means "show every class".
    if n_top_classes is not None:
        ranked = ranked[:max(int(n_top_classes), 0)]
    return dict(ranked)


def inference(wants_gradcam, n_gradcam, target_layer_number, transparency, wants_misclassified, n_misclassified, input_img = None, n_top_classes=10):
    """Gradio callback producing the GradCAM gallery, misclassified gallery and top-class label.

    Args:
        wants_gradcam: whether to build GradCAM overlays.
        n_gradcam: number of overlays (0 or None yields an empty list).
        target_layer_number: '-2' or '-1' (str or int) selecting the hooked layer.
        transparency: image weight passed to show_cam_on_image.
        wants_misclassified: whether to collect misclassified validation images.
        n_misclassified: number of misclassified images (0 or None yields an empty list).
        input_img: optional uploaded image to classify.
        n_top_classes: how many top classes to return; None returns all.

    Returns:
        (gradcam_images or None, misclassified_images or None, confidences or None)

    Raises:
        ValueError: if GradCAM images are requested for an unknown target layer.
    """
    outputs_inference_gc = (
        _gradcam_gallery(n_gradcam, target_layer_number, transparency)
        if wants_gradcam else None
    )
    outputs_inference_mis = (
        _misclassified_gallery(n_misclassified) if wants_misclassified else None
    )
    confidences = (
        _classify_image(input_img, n_top_classes) if input_img is not None else None
    )
    return outputs_inference_gc, outputs_inference_mis, confidences
title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
description = "A Gradio interface to infer on Custom ResNet model, and to get GradCAM results"
# One example row per bundled image; only the input-image slot is filled.
examples = [[None, None, None, None, None, None, 'examples/gr_' + str(i) + '.jpg', None] for i in range(10)]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Checkbox(False, label='Do you want to see GradCAM outputs?'),
        gr.Slider(0, 10, value=0, step=1, label="How many?"),
        # String choices match the '-2' / '-1' comparisons in inference(); the
        # default avoids sending None when the user never touches the dropdown.
        gr.Dropdown(["-2", "-1"], value="-1", label="Which target layer?"),
        gr.Slider(0, 1, value=0, label="Opacity of GradCAM"),
        gr.Checkbox(False, label='Do you want to see misclassified images?'),
        gr.Slider(0, 10, value=0, step=1, label="How many?"),
        gr.Image(shape=(32, 32), label="Input image"),
        gr.Slider(0, 10, value=0, step=1, label="How many top classes you want to see?"),
    ],
    outputs=[
        # Distinct elem_ids: two DOM elements must not share one id.
        gr.Gallery(label="GradCAM Outputs", show_label=True, elem_id="gallery_gradcam").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Gallery(label="Misclassified Images", show_label=True, elem_id="gallery_misclassified").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Label(num_top_classes=10, label="Top classes"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()