# S13 / app.py
# Shivdutta's picture
# Update app.py
# 09b5b07 verified
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnetS11 import LITResNet
import os
import re
import matplotlib.pyplot as plt
from io import BytesIO
# Label names; position in the tuple is the class index used by the model output.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Preprocessing applied to every input image: convert to a tensor, then
# normalise each channel (presumably with the training-set statistics --
# TODO confirm against the training pipeline).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.49139968, 0.48215827, 0.44653124),
            std=(0.24703233, 0.24348505, 0.26158768),
        ),
    ]
)

# Build the network and restore the weights stored under the checkpoint's
# "state_dict" key, mapping all tensors onto the CPU.
model = LITResNet(classes)
model.load_state_dict(
    torch.load("model.pth", map_location=torch.device('cpu'))["state_dict"]
)
# Inference only: switch the model to evaluation mode.
model.eval()

# Names of every sub-module of the model, in registration order.
modellayers = list(dict(model.named_modules()).keys())
# Folder holding pre-rendered misclassified samples and the file-name pattern
# "<actual>_<predicted>.png" used to recover their labels.
_MISCLASSIFIED_DIR = './misclassified/'
_MISCLASSIFIED_NAME = re.compile(r'(\w+)_(\w+)\.png')

# All figures in this module use a 2 x 5 grid of subplots.
_GRID_ROWS = 2
_GRID_COLS = 5


def _figure_to_image(fig):
    """Render *fig* into an in-memory PNG and return it as a PIL image."""
    buffer = BytesIO()
    fig.tight_layout()
    fig.savefig(buffer, format='png')
    buffer.seek(0)
    return Image.open(buffer)


def _gradcam_overlays(input_tensor, rgb_image, count, image_weight):
    """Return *count* GradCAM overlays, one per trailing block of ``model.layer2``."""
    overlays = []
    for offset in range(1, count + 1):
        # The context manager releases the hooks GradCAM registers on the model.
        with GradCAM(model=model, target_layers=[model.layer2[-offset]]) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        overlays.append(
            show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True, image_weight=image_weight)
        )
    return overlays


def _overlay_figure(overlays):
    """Lay the GradCAM overlays out on a grid and return the rendered image."""
    fig = plt.figure(figsize=(12, 5))
    try:
        for position, overlay in enumerate(overlays, start=1):
            axis = fig.add_subplot(_GRID_ROWS, _GRID_COLS, position)
            axis.imshow(overlay)
            axis.axis('off')
        return _figure_to_image(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def _misclassified_figure(count):
    """Render up to *count* misclassified samples, or return None if none exist."""
    labelled = []
    # Sort for a deterministic selection; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(_MISCLASSIFIED_DIR)):
        match = _MISCLASSIFIED_NAME.search(filename)
        if match:  # skip files that do not follow the naming scheme
            labelled.append((filename, match.group(1), match.group(2)))
    labelled = labelled[:max(0, min(count, _GRID_ROWS * _GRID_COLS))]
    if not labelled:
        return None

    fig = plt.figure(figsize=(12, 5))
    try:
        for position, (filename, target_class, predicted_class) in enumerate(labelled, start=1):
            axis = fig.add_subplot(_GRID_ROWS, _GRID_COLS, position)
            with Image.open(os.path.join(_MISCLASSIFIED_DIR, filename)) as picture:
                axis.imshow(picture, cmap='gray', interpolation='none')
            axis.set_title(
                "Actual: {}, Pred: {}".format(target_class, predicted_class), color='red'
            )
        return _figure_to_image(fig)
    finally:
        plt.close(fig)


def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3, num_misclassified_images=3):
    """Classify an image and build the GradCAM / misclassified visualisations.

    Args:
        input_img: Image to classify, in any form ``np.array`` can turn into
            an array accepted by ``Image.fromarray``. It is resized to 32x32.
        num_gradcam_images: Number of GradCAM overlays to produce; overlay
            ``k`` targets ``model.layer2[-k]``.
        target_layer_number: Accepted for interface compatibility.
            NOTE(review): this value is not read; the target layers are
            derived from ``num_gradcam_images`` only -- confirm intent.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``.
        show_misclassified: When true, also render misclassified samples
            from ``./misclassified/``.
        num_top_classes: Number of highest-confidence classes to return.
        num_misclassified_images: Maximum number of misclassified samples to
            render (capped at the 2 x 5 grid size).

    Returns:
        A tuple ``(top_confidences, gradcam_image, misclassified_image)``:
        a dict mapping class name to confidence in descending order, a PIL
        image with the GradCAM overlays, and a PIL image of misclassified
        samples (``None`` when not requested or when no usable file exists).

    Raises:
        gr.Error: If no input image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    # UI sliders may deliver numeric values as floats; range() and slicing need ints.
    num_gradcam_images = int(num_gradcam_images)
    num_top_classes = int(num_top_classes)
    num_misclassified_images = int(num_misclassified_images)

    org_img = np.array(Image.fromarray(np.array(input_img)).resize((32, 32)))
    input_tensor = transform(org_img).unsqueeze(0)

    # The prediction pass needs no gradients; GradCAM runs its own forward pass.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {label: float(score) for label, score in zip(classes, probabilities)}

    # show_cam_on_image expects a float image scaled to [0, 1].
    overlays = _gradcam_overlays(
        input_tensor, np.float32(org_img) / 255, num_gradcam_images, transparency
    )
    visualization = _overlay_figure(overlays)

    # Keep only the top-n predictions, highest confidence first.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_n_confidences = dict(ranked[:num_top_classes])

    visualization_misclassified = None
    if show_misclassified:
        visualization_misclassified = _misclassified_figure(num_misclassified_images)
    return top_n_confidences, visualization, visualization_misclassified
# Heading and introduction shown on the Gradio page.
title = "CIFAR10 trained on ResNet18 Model using Pytorch Lightning with GradCAM"
description = "A simple Gradio interface to infer on ResNet18 model using Pytorch Lightning, and get GradCAM results"

# One example row per sample image; the remaining values follow the order of
# the input components declared below and are the same for every row.
examples = [
    [image_file, 1, -1, 0.8, True, 3, 3]
    for image_file in (
        "cat.jpg",
        "dog.jpg",
        "plane.jpg",
        "deer.jpg",
        "horse.jpg",
        "bird.jpg",
        "frog.jpg",
        "ship.jpg",
        "truck.jpg",
        "car.jpg",
    )
]

# Wire the inference function to its input controls and output panels.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(minimum=1, maximum=2, value=1, step=1, label="Number of GradCAM Images"),
        gr.Slider(minimum=-2, maximum=-1, value=-1, step=1, label="Which Layer?"),
        gr.Slider(minimum=0, maximum=1, value=0.8, label="Opacity of GradCAM"),
        gr.Checkbox(value=True, label="Show Misclassified Images"),
        gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Top Predictions"),
        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Misclassified Images"),
    ],
    outputs=[
        gr.Label(label="Top Predictions"),
        gr.Image(label="Output", width=640, height=360),
        gr.Image(label="Misclassified Images", width=640, height=360),
    ],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server for the interface.
demo.launch()