# Scraped from the Hugging Face file viewer (author: Vvaann).
# Commit 099df44 "Update app.py" (verified), file size 2.65 kB.
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet_lightning import ResNet18Model
import gradio as gr
# Trained CIFAR-10 ResNet-18 restored from a Lightning checkpoint.
# map_location="cpu": inference() builds its input tensors on the CPU and never moves
# them to another device, so the weights must be on the CPU as well, even if the
# checkpoint was written from a GPU run.
model = ResNet18Model.load_from_checkpoint("epoch=19-step=3920.ckpt", map_location="cpu")
# load_from_checkpoint does not switch the module to inference mode. Do it explicitly
# so BatchNorm uses its running statistics and single-image requests don't update them.
model.eval()

# Inverse of Normalize(mean=0.5, std=0.23) on each channel, i.e. x -> 0.23 * x + 0.5.
# Not referenced by the code in this section.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)

# CIFAR-10 label names; position i is the label for model output index i.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def resize_image_pil(image, new_width, new_height):
    """Fit an image inside ``new_width`` x ``new_height`` and return a canvas of exactly that size.

    The image is scaled uniformly (aspect ratio preserved) with nearest-neighbour
    resampling. It is then cropped from the top-left corner to the target size.
    Because the crop box can extend past the scaled image, the shorter side is
    padded (black, zero-filled) on the right/bottom.

    Args:
        image: a PIL image, or anything ``np.array`` converts into an image array.
        new_width: target width in pixels.
        new_height: target height in pixels.

    Returns:
        A ``PIL.Image.Image`` of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size
    scale = min(new_width / width, new_height / height)
    # Use round() rather than int(). Float products such as 49 * (32 / 49) evaluate
    # to 31.999999999999996, and truncating them would leave a spurious 1-pixel black
    # border on the side that should exactly fill the target. max(1, ...) stops very
    # thin images from collapsing to an invalid zero-sized resize.
    scaled_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    resized = img.resize(scaled_size, Image.NEAREST)
    return resized.crop((0, 0, new_width, new_height))
def inference(input_img, transparancy=0.5, target_layer_number=-1):
    """Classify an image with the CIFAR-10 model and render a Grad-CAM overlay.

    Args:
        input_img: the uploaded image (anything ``resize_image_pil`` accepts,
            e.g. the numpy array produced by ``gr.Image``).
        transparancy: forwarded to ``show_cam_on_image`` as ``image_weight``, the
            weight of the original image in the blend (higher = fainter heatmap).
        target_layer_number: index into ``model.layer2`` selecting the block whose
            activations Grad-CAM explains (the UI offers -2 or -1).

    Returns:
        A 3-tuple ``(predicted class name, Grad-CAM visualization,
        {class name: softmax probability})`` matching the Interface outputs.
    """
    # Downscale to the 32x32 input size. Convert to RGB so RGBA/grayscale uploads
    # don't break the 3-channel reshape below.
    rgb_small = resize_image_pil(input_img, 32, 32).convert("RGB")
    org_img = np.array(rgb_small)
    # NOTE(review): inv_normalize implies training used Normalize(0.5, 0.23), but only
    # ToTensor() is applied here. Confirm whether ResNet18Model normalizes internally.
    input_tensor = transforms.ToTensor()(org_img.reshape((32, 32, 3))).unsqueeze(0)

    # Inference mode for the prediction pass. No autograd graph is needed here;
    # GradCAM runs its own forward/backward below.
    model.eval()
    with torch.no_grad():
        outputs = model(input_tensor)
    print(outputs)
    probs = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(probs[i]) for i, name in enumerate(classes)}
    # torch.max(...)[0] is the max logit, not the class id; take the argmax index.
    predicted_idx = int(outputs.argmax(dim=1).item())

    # Slider values may arrive as floats; nn.Sequential indexing requires an int.
    target_layers = [model.layer2[int(target_layer_number)]]
    # Context manager releases GradCAM's layer hooks even if the call raises.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor)[0, :]
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True,
                                      image_weight=transparancy)
    return classes[predicted_idx], visualization, confidences
# Gradio UI: the three inputs map positionally onto inference(input_img,
# transparancy, target_layer_number). The outputs match its 3-tuple return.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="input image"),
        # Forwarded to show_cam_on_image as image_weight, i.e. the weight of the
        # original image in the blend (higher values give a fainter heatmap).
        gr.Slider(0, 1, value=0.5, label="Overall opacity of the overlay"),
        # Index into model.layer2 used as the Grad-CAM target layer.
        gr.Slider(-2, -1, value=-2, step=1, label="Which layer for Gradcam"),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3),
    ],
    title="CIFAR 10 trained on ResNet model in pytorch lightning with Gradcam",
    description=" A simple gradio inference to infer on resnet18 model",
    examples=[["cat.jpg", 0.5, -1], ["dog.jpg", 0.7, -2]],
)

if __name__ == "__main__":
    demo.launch()