# Author: sanjanatule
# Commit 67b18f0: "Update app.py"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
import numpy as np
from torch_lr_finder import LRFinder
from torch.optim.lr_scheduler import OneCycleLR
import torch, torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import gradio as gr
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Accuracy
from models import custom_resnet
from network import LitResnet
# Load the trained Lightning checkpoint once, at import time; every handler
# below shares this single instance.
inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
# load_from_checkpoint leaves the module in training mode. Without this call,
# forward passes made before the first GradCAM construction (pytorch_grad_cam
# switches the model to eval itself) would run in training mode, so any
# BatchNorm/Dropout layers (presumably present in the custom ResNet — not
# visible here) would use batch statistics and update their running averages.
inference_model.eval()

# CIFAR-10 class names; the handlers map output logit i to classes[i].
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gradcam=False,num_gradcam_imgs=0,transparency = 0.85, target_layer_number = -1,top_classes=3):
# if see_misclassified: # show misclassified images
# org_img = np.asarray(Image.open('misclassified_images/mis_eg_0.jpg'))
# input_img = org_img
# elif num_gradcam_imgs > 0: # show gradcam on example images
# org_img = np.asarray(Image.open('examples/car.jpg'))
# input_img = org_img
# else: # nothing chosen - misclassified or gradcam
# org_img = input_img
# # model inference
# transform = transforms.ToTensor()
# input_img = transform(input_img)
# input_img = input_img.unsqueeze(0)
# outputs = inference_model.model(input_img)
# softmax = torch.nn.Softmax(dim=0)
# o = softmax(outputs.flatten())
# confidences = {classes[i]: float(o[i]) for i in range(10)}
# sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
# _, prediction = torch.max(outputs, 1)
# # gradcam
# if see_gradcam:
# target_layers = [inference_model.model.layer2[target_layer_number]]
# cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
# grayscale_cam = cam(input_tensor=input_img, targets=None)
# grayscale_cam = grayscale_cam[0, :]
# visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
# else:
# visualization = org_img
# # top n classes only
# sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
# return sorted_confidences, [visualization]
# title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
# description = "A Gradio interface to infer on ResNet model, and get GradCAM results"
# examples = [["examples/cat.jpg"], ["examples/plane.jpg"],["examples/dog.jpg"],["examples/truck.jpg"],["examples/bird.jpg"],["examples/ship.jpg"],["examples/horse.jpg"],["examples/frog.jpg"],["examples/deer.jpg"],["examples/car.jpg"]]
# demo = gr.Interface(
# inference,
# inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Number(value=2,minimum=0,maximum=10,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Number(value=2,minimum=0,maximum=10,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
# outputs = [gr.Label(), gr.Gallery(label="Output Images", show_label=False, elem_id="gallery").style(columns=[2], rows=[5], object_fit="contain", height="auto")],
# title = title,
# description = description,
# examples = examples)
# demo.launch()
def inference_up_img(input_img, see_gradcam=True, target_layer_number=-1, transparency=0.85, top_classes=3):
    """Classify one image and optionally blend a GradCAM heat-map into it.

    Args:
        input_img: Image from ``gr.Image``. Presumably an HxWx3 uint8 numpy
            array, since it is fed to ``ToTensor`` and divided by 255 for
            blending. ``None`` when the widget is empty.
        see_gradcam: If true, compute GradCAM on
            ``inference_model.model.layer2[target_layer_number]``.
        target_layer_number: Index into ``layer2``; the UI offers -2 or -1.
        transparency: Forwarded as ``image_weight`` to ``show_cam_on_image``,
            i.e. the weight of the original image in the blend.
        top_classes: Number of highest-probability classes to return.

    Returns:
        ``(confidences, visualization)``:
        - ``confidences`` maps the ``top_classes`` most likely class names to
          their softmax probability, in descending order.
        - ``visualization`` is the GradCAM overlay, or the unchanged input.
        Returns ``(None, None)`` when ``input_img`` is ``None``.
    """
    # Clicking Submit with no image would otherwise crash inside ToTensor;
    # return empty outputs so both widgets are simply cleared.
    if input_img is None:
        return None, None

    # Slider values may arrive as floats; slicing needs an int.
    top_classes = int(top_classes)

    input_tensor = transforms.ToTensor()(input_img).unsqueeze(0)  # add batch dim
    # Plain prediction needs no autograd graph; GradCAM below runs its own
    # forward/backward pass with gradients enabled.
    with torch.no_grad():
        outputs = inference_model.model(input_tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probs)}
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_confidences = dict(ranked[:top_classes])

    if see_gradcam:
        target_layers = [inference_model.model.layer2[int(target_layer_number)]]
        cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
        # targets=None lets pytorch_grad_cam pick the top-scoring class.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(input_img / 255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
    else:
        visualization = input_img
    return top_confidences, visualization
def _load_misclassified_image(index):
    """Load ``misclassified_images/mis_eg_<index>.jpg`` as a numpy array.

    The pixels are decoded into the array before the file is closed.
    """
    with Image.open(f'misclassified_images/mis_eg_{index}.jpg') as img:
        return np.asarray(img)


def misclass_fn(misclassified_check, num_misclassified=1, see_gradcam=True, num_gradcam=1, gradcam_layer=-2, gradcam_opa=0.50):
    """Build the image gallery for the "Explore Misclassified/Gradcam Images" tab.

    Both modes read the same files, ``misclassified_images/mis_eg_<i>.jpg``
    for ``i = 0 .. count-1``.

    Args:
        misclassified_check: Show ``num_misclassified`` images. They get a
            GradCAM overlay when ``see_gradcam`` is also set.
        num_misclassified: Image count used when ``misclassified_check`` is set.
        see_gradcam: Overlay GradCAM. When ``misclassified_check`` is off, this
            alone selects ``num_gradcam`` overlaid images.
        num_gradcam: Image count used in GradCAM-only mode.
        gradcam_layer: Index into ``inference_model.model.layer2`` used as the
            GradCAM target layer.
        gradcam_opa: Passed to ``show_cam_on_image`` as ``image_weight`` (the
            weight of the original image in the blend).

    Returns:
        A list of numpy images. It is empty when neither box is ticked or the
        requested count is missing or non-positive.
    """
    if misclassified_check:
        count = num_misclassified
    elif see_gradcam:
        count = num_gradcam
    else:
        return []
    # gr.Number delivers floats, or None after the Clear button; treat a
    # missing count as zero instead of crashing in int().
    count = 0 if count is None else int(count)
    if count <= 0:
        return []

    if not see_gradcam:
        return [_load_misclassified_image(i) for i in range(count)]

    # Target layer, CAM object and transform are identical for every image,
    # so build them once instead of once per loop iteration.
    target_layers = [inference_model.model.layer2[int(gradcam_layer)]]
    cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
    to_tensor = transforms.ToTensor()

    img_gallery = []
    for i in range(count):
        org_img = _load_misclassified_image(i)
        input_tensor = to_tensor(org_img).unsqueeze(0)  # add batch dim
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        img_gallery.append(
            show_cam_on_image(org_img / 255.0, grayscale_cam, use_rgb=True, image_weight=gradcam_opa)
        )
    return img_gallery
# Example inputs for the upload tab, one per CIFAR-10 class. Kept under its own
# name because `examples` is bound to the gr.Examples component below.
example_paths = [
    [f"examples/{name}.jpg"]
    for name in ("cat", "plane", "dog", "truck", "bird", "ship", "horse", "frog", "deer", "car")
]

with gr.Blocks() as demo:
    gr.Markdown("Explore Custom ResNet model for CIFAR10.")

    # Tab 1: classify a user-supplied image, optionally with a GradCAM overlay.
    with gr.Tab("Upload your own image"):
        with gr.Row():
            image_input = gr.Image(shape=(32, 32), label="Input Image")
            image_label = gr.Label()
        with gr.Row():
            with gr.Column():
                gradcam_check = gr.Checkbox(label="Gradcam")
                gradcam_layer = gr.Slider(-2, -1, value=-1, step=1, label="Which Layer?")
                # NOTE(review): this value becomes show_cam_on_image's
                # image_weight, so higher values show *less* of the heat-map.
                gradcam_opa = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM")
                top_classes = gr.Slider(1, 10, value=3, step=1, label="How many top classes?")
            image_output = gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)
        with gr.Row():
            # Each example row has a single value, so only image_input is
            # filled; the other listed inputs receive no example values.
            examples = gr.Examples(
                examples=example_paths,
                inputs=[image_input, gradcam_check, gradcam_layer, gradcam_opa, top_classes, image_label],
                # inference_up_img returns (confidences, image); list both
                # outputs so enabling example caching would not break.
                outputs=[image_label, image_output],
                fn=inference_up_img,
                cache_examples=False,
            )
        with gr.Row():
            tab_1_button = gr.Button("Submit")
            tab_1_cl_button = gr.ClearButton(
                [image_input, gradcam_check, gradcam_layer, gradcam_opa, top_classes, image_label, image_output]
            )

    # Tab 2: gallery of stored misclassified images, optionally with GradCAM.
    with gr.Tab("Explore Misclassified/Gradcam Images"):
        with gr.Row():
            with gr.Column():
                misclassified_check = gr.Checkbox(label="Misclassified")
                num_misclassified = gr.Number(value=2, minimum=1, maximum=10, label="Total Misclassified Images")
                gradcam_check1 = gr.Checkbox(label="Gradcam")
                num_gradcam = gr.Number(value=2, minimum=1, maximum=10, label="Total Gradcam Images")
                gradcam_layer1 = gr.Slider(-2, -1, value=-1, step=1, label="Which Layer?")
                gradcam_opa1 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM")
            image_gallery_output = gr.Gallery(label="Output Images", show_label=False, elem_id="gallery").style(
                columns=[2], rows=[5], object_fit="contain", height="auto"
            )
        with gr.Row():
            tab_2_button = gr.Button("Submit")
            tab_2_cl_button = gr.ClearButton(
                [misclassified_check, num_misclassified, gradcam_check1, num_gradcam,
                 gradcam_layer1, gradcam_opa1, image_gallery_output]
            )

    # Listeners are attached inside the Blocks context.
    tab_1_button.click(
        inference_up_img,
        inputs=[image_input, gradcam_check, gradcam_layer, gradcam_opa, top_classes],
        outputs=[image_label, image_output],
    )
    tab_2_button.click(
        misclass_fn,
        inputs=[misclassified_check, num_misclassified, gradcam_check1, num_gradcam, gradcam_layer1, gradcam_opa1],
        outputs=[image_gallery_output],
    )

# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch(debug=True)