Spaces:
Runtime error
Runtime error
| # gradioMisClassGradCAMimageInputter | |
| import os | |
| import math | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sn | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| import matplotlib.pyplot as plt | |
| from pl_bolts.datamodules import CIFAR10DataModule | |
| from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | |
| from pytorch_lightning import LightningModule, Trainer, seed_everything | |
| from pytorch_lightning.callbacks import LearningRateMonitor | |
| from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
| from pytorch_lightning.loggers import CSVLogger | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from torch.optim.swa_utils import AveragedModel, update_bn | |
| from torchmetrics.functional import accuracy | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from torchvision import datasets, transforms, utils | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import gradio as gr | |
| import misclas_helper | |
| import gradcam_helper | |
| import lightningmodel | |
| from misclas_helper import display_cifar_misclassified_data | |
| from gradcam_helper import display_gradcam_output | |
| from misclas_helper import get_misclassified_data2 | |
| from lightningmodel import LitResnet | |
# Module-level placeholder; the callbacks below bind their own local file name.
fileName = None
# Forwarded unchanged to the Grad-CAM helper as its `targets` argument.
targets = None
# All inference in this app runs on the CPU.
device = torch.device("cpu")
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# `load_from_checkpoint` is a classmethod, so it is called on the class rather
# than on a throwaway instance (newer Lightning releases reject the instance
# call). `lr` is passed through so the constructor still receives the value
# the original instance was built with, and `map_location` keeps the weights
# on the CPU even if the checkpoint was written on another device.
model = LitResnet.load_from_checkpoint(
    "weights_92.ckpt",
    lr=0.05,
    map_location=device,
)

# Denormalize the data using test mean and std deviation
# (algebraic inverse of Normalize(mean=0.50, std=0.23) on every channel).
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)

# Get the misclassified data from test dataset, computed once at start-up and
# reused by both image tabs.
# NOTE(review): the third argument (20) is presumably the number of samples to
# collect -- confirm against misclas_helper.get_misclassified_data2.
misclassified_data = get_misclassified_data2(model, device, 20)
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Gradio callback: show a grid of misclassified test images.

    Args:
        DoYouWantToShowMisClassifiedImages: Free-text answer from the UI.
            Only "yes" (case-insensitive, surrounding whitespace ignored)
            produces an image.
        HowManyImages: Number of samples requested via the slider.

    Returns:
        The PIL image opened from whatever
        ``misclas_helper.display_cifar_misclassified_data`` returns, or
        ``None`` when the answer is not "yes" or no samples were requested.
    """
    # An empty textbox may arrive as "" or None; neither is a "yes".
    answer = (DoYouWantToShowMisClassifiedImages or "").strip().lower()
    if answer != "yes":
        return None

    # The slider starts at 0; there is nothing to draw for an empty request.
    sample_count = int(HowManyImages)
    if sample_count <= 0:
        return None

    image_path = misclas_helper.display_cifar_misclassified_data(
        misclassified_data,
        classes,
        inv_normalize,
        number_of_samples=sample_count,
    )
    return Image.open(image_path)
# Tab 1: misclassified-image viewer. Inputs map positionally onto
# hello(DoYouWantToShowMisClassifiedImages, HowManyImages).
misClass_demo = gr.Interface(
    fn=hello,
    inputs=[
        'text',
        gr.Slider(0, 20, step=5),
    ],
    outputs=['image'],
    # Title typo ("Misclasseified") corrected; it is shown to the user.
    title="Misclassified Images",
    description="If your answer to the question DoYouWantToShowMisClassifiedImages is yes, then only it works.",
)
| ############ | |
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
    """Gradio callback: show Grad-CAM overlays for misclassified images.

    Args:
        DoYouWantToShowGradCAMMedImages: Free-text answer from the UI. Only
            "yes" (case-insensitive, surrounding whitespace ignored)
            produces an image.
        HowManyImages: Number of samples requested via the slider.
        WhichLayer: Layer selector; -1, -2 and -3 pick the last element of
            ``resNetLayer2Part2``, ``resNetLayer2Part1`` and ``Layer3`` of
            ``model.model`` respectively.
        transparency: Overlay opacity chosen in the UI, forwarded to the
            Grad-CAM helper.

    Returns:
        The PIL image opened from whatever
        ``gradcam_helper.display_gradcam_output`` returns, or ``None`` when
        the answer is not "yes" or no samples were requested.

    Raises:
        ValueError: If ``WhichLayer`` is not one of -1, -2 or -3.
    """
    # An empty textbox may arrive as "" or None; neither is a "yes".
    answer = (DoYouWantToShowGradCAMMedImages or "").strip().lower()
    if answer != "yes":
        return None

    # The slider starts at 0; there is nothing to draw for an empty request.
    sample_count = int(HowManyImages)
    if sample_count <= 0:
        return None

    # Validate the selector before touching the model so that a bad value
    # fails with a clear message instead of an unbound local variable.
    layer_attr_by_choice = {
        -1: "resNetLayer2Part2",
        -2: "resNetLayer2Part1",
        -3: "Layer3",
    }
    layer_attr = layer_attr_by_choice.get(WhichLayer)
    if layer_attr is None:
        raise ValueError(
            f"WhichLayer must be one of -1, -2 or -3, got {WhichLayer!r}"
        )
    target_layers = [getattr(model.model, layer_attr)[-1]]

    image_path = gradcam_helper.display_gradcam_output(
        misclassified_data,
        classes,
        inv_normalize,
        model.model,
        target_layers,
        targets,
        number_of_samples=sample_count,
        # Honour the opacity chosen in the UI (it used to be fixed at 0.70).
        transparency=transparency,
    )
    return Image.open(image_path)
# Tab 2: Grad-CAM viewer. The input components map positionally onto
# inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency).
gradCAM_demo = gr.Interface(
    fn=inference,
    inputs=[
        'text',                                   # DoYouWantToShowGradCAMMedImages
        gr.Slider(0, 20, step=5),                 # HowManyImages
        gr.Slider(-3, -1, value=-1, step=1),      # WhichLayer
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),  # transparency
    ],
    outputs=['image'],
    title="GradCammd Images",
    description="If your answer to the question DoYouWantToShowGradCAMMedImages is yes, then only it works.",
)
| ############ | |
def ImageInputter(img1, img2, img3, img4, img5, img6, img7, img8, img9, img10):
    """Return the ten inputs unchanged, as a tuple in the order received."""
    passthrough = (
        img1, img2, img3, img4, img5,
        img6, img7, img8, img9, img10,
    )
    return passthrough
# Tab 3: ten image inputs echoed to ten image outputs by ImageInputter.
# Each example row supplies values for the first three inputs only.
imageInputter_demo = gr.Interface(
    fn=ImageInputter,
    inputs=["image"] * 10,
    outputs=["image"] * 10,
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg"],
        ["deer.jpg", "dog.jpg", "frog.jpg"],
        ["horse.jpg", "plane.jpg", "ship.jpg"],
        [None, "truck.jpg", None],
    ],
    title="Max 10 images input",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them. You may drag and drop images from bottom examples to the input feeders",
)
| ############ | |
# Combine the three interfaces into one tabbed app; the names pair with the
# interfaces by position.
_tab_interfaces = [misClass_demo, gradCAM_demo, imageInputter_demo]
_tab_titles = ["MisClassified Images", "GradCAMMed Images", "10 images inputter"]
demo = gr.TabbedInterface(_tab_interfaces, _tab_titles)

# Start the app with debug mode enabled.
demo.launch(debug=True)