Vvaann's picture
Update app.py
e5ef4e7 verified
raw
history blame
3.69 kB
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet_lightning import ResNet18Model
import gradio as gr
# Load the trained network once at import time so every request reuses it.
model = ResNet18Model.load_from_checkpoint("epoch=19-step=3920.ckpt")
# Switch to inference mode: without this, layers such as BatchNorm/Dropout may
# still behave as in training when the demo feeds single-image batches.
model.eval()

# Inverse of a Normalize(mean=0.50, std=0.23) transform: applying it computes
# x * 0.23 + 0.50 per channel. It is not referenced by the code visible here.
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)

# Class names, indexed by the position of the corresponding model output.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Choices offered in the UI for the Grad-CAM target layer (see get_layer).
model_layer_names = ["0", "1", "2", "3"]
def get_layer(layer_name):
    """Return the Grad-CAM target layer list for the selected stage.

    Args:
        layer_name: Stage selector. Accepts an int (0-3) or anything that
            ``int()`` can parse to one, e.g. the strings ``"0"``..``"3"``
            listed in ``model_layer_names``.

    Returns:
        A one-element list holding the last sub-module of ``model.prep``,
        ``model.layer1``, ``model.layer2`` or ``model.layer3`` for selectors
        0, 1, 2 and 3 respectively, or ``None`` when the selector does not
        name one of those stages.
    """
    print(layer_name)
    # A fractional number never matched any stage in the integer comparison.
    if isinstance(layer_name, float) and not layer_name.is_integer():
        return None
    # UI widgets may deliver the choice as a string, so normalise to an int
    # before the lookup instead of comparing raw values against integers.
    try:
        layer_index = int(layer_name)
    except (TypeError, ValueError, OverflowError):
        return None

    stage_by_index = {0: "prep", 1: "layer1", 2: "layer2", 3: "layer3"}
    stage_name = stage_by_index.get(layer_index)
    if stage_name is None:
        return None
    # Resolve the attribute lazily so an unknown selector never touches model.
    return [getattr(model, stage_name)[-1]]
def resize_image_pil(image, new_width, new_height):
    """Scale an image to fit inside a box, then crop to exactly that box.

    Args:
        image: Anything ``np.array`` can convert into an image array.
        new_width: Width of the target box in pixels.
        new_height: Height of the target box in pixels.

    Returns:
        A PIL image obtained by resizing with a single aspect-preserving
        factor (nearest-neighbour sampling) and then cropping to the box
        ``(0, 0, new_width, new_height)``.
    """
    source = Image.fromarray(np.array(image))
    src_w, src_h = source.size

    # One shared factor for both axes keeps the aspect ratio; taking the
    # smaller ratio makes the scaled image fit within the requested box.
    factor = min(new_width / src_w, new_height / src_h)
    scaled_size = (int(src_w * factor), int(src_h * factor))

    shrunk = source.resize(scaled_size, Image.NEAREST)
    return shrunk.crop((0, 0, new_width, new_height))
def inference(input_img, show_gradcam, layer_name, num_classes, transparancy = 0.5):
    """Classify an image and optionally overlay a Grad-CAM heatmap on it.

    Args:
        input_img: Input image; it is resized to 32x32 before inference.
        show_gradcam: When truthy, return the Grad-CAM overlay instead of the
            plain resized image.
        layer_name: Stage selector forwarded to ``get_layer``.
        num_classes: How many top classes to report; converted with ``int``
            and clamped to the number of available classes.
        transparancy: Passed to ``show_cam_on_image`` as ``image_weight``.

    Returns:
        A tuple ``(predicted_class_name, visualization, confidences)`` where
        ``confidences`` maps the top class names to softmax probabilities.

    Raises:
        ValueError: If Grad-CAM is requested but ``get_layer`` returns no
            layer for ``layer_name``.
    """
    print(show_gradcam, layer_name, num_classes, transparancy)

    org_img = np.array(resize_image_pil(input_img, 32, 32))
    # Fails loudly if the image is not 32x32 with exactly three channels.
    org_img = org_img.reshape((32, 32, 3))
    input_tensor = transforms.ToTensor()(org_img).unsqueeze(0)

    # The classification pass needs no gradients; Grad-CAM below runs its own
    # forward/backward pass.
    with torch.no_grad():
        outputs = model(input_tensor)
    logits = outputs.flatten()
    probabilities = torch.softmax(logits, dim=0)

    # Class indices ordered from most to least likely.
    ranked = torch.argsort(logits, descending=True).tolist()
    top_k = max(0, min(int(num_classes), len(ranked), len(classes)))
    confidences = {
        classes[class_idx]: float(probabilities[class_idx])
        for class_idx in ranked[:top_k]
    }
    # Use the arg-max *index* for the label; torch.max(outputs, 1)[0] is the
    # maximum logit value and must not be used to index ``classes``.
    predicted_class = classes[ranked[0]]

    if not show_gradcam:
        return predicted_class, org_img, confidences

    target_layers = get_layer(layer_name)
    print(target_layers)
    if target_layers is None:
        raise ValueError(f"No Grad-CAM target layer for selection {layer_name!r}")

    cam = GradCAM(model=model, target_layers=target_layers)
    # Single-image batch: keep the first (only) heatmap.
    grayscale_cam = cam(input_tensor=input_tensor)[0, :]
    visualization = show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparancy,
    )
    return predicted_class, visualization, confidences
# Gradio passes input component values to the function positionally, so the
# components below are listed in the same order as inference()'s parameters:
# (input_img, show_gradcam, layer_name, num_classes, transparancy).
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input image"),
        gr.Checkbox(True, label="Show GradCAM Image"),
        # Offer the layer choices as integers so the selection matches the
        # integer stage numbers that get_layer compares against.
        gr.Dropdown(
            [int(name) for name in model_layer_names],
            value=3,
            label="Which layer for Gradcam",
        ),
        gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0,
                  label="Number of classes to display"),
        gr.Slider(0, 1, value=0.5, label="Overall opacity of the overlay"),
    ],
    outputs=[
        gr.Label(label="Class", container=True, show_label=True),
        gr.Image(width=256, height=256, label="Output Image"),
        gr.Label(label="Confidences", container=True, show_label=True),
    ],
    title="CIFAR 10 trained on ResNet model in pytorch lightning with Gradcam",
    description=" A simple gradio inference to infer on resnet18 model",
    # One value per input component, in order; every value lies inside the
    # corresponding component's allowed choices/range.
    examples=[
        ["cat.jpg", True, 3, 10, 0.5],
        ["dog.jpg", False, 3, 4, 0.5],
    ],
)

if __name__ == "__main__":
    demo.launch()