File size: 4,817 Bytes
6192a31 fe72ddb 462c3fa f9afe8f 462c3fa 9c84d1a 17771de f9afe8f 462c3fa db5858d 9c84d1a 051a4e1 462c3fa f9afe8f fa845ee fe72ddb 462c3fa fe72ddb 6192a31 839df3e 7724b1c 839df3e f9afe8f 90aceef 121d531 7724b1c 121d531 a6456c8 90aceef a6456c8 f9afe8f 2fd201f f9afe8f 74ec669 6192a31 f9afe8f 7d7bd68 fe72ddb f9afe8f 8addcc6 f9afe8f fe72ddb f9afe8f 2fd201f 816472f 462c3fa 09b5b07 f9afe8f e1322ae 462c3fa 6192a31 561de13 7724b1c f9afe8f 7724b1c 561de13 6192a31 462c3fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnetS11 import LITResNet
import os
import re
import matplotlib.pyplot as plt
from io import BytesIO
# Normalization uses the standard CIFAR-10 per-channel mean/std, so inference
# inputs are preprocessed exactly as the training data was.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.49139968, 0.48215827, 0.44653124),
        (0.24703233, 0.24348505, 0.26158768),
    ),
])

# CIFAR-10 labels, index-aligned with the model's output logits.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Restore the trained Lightning checkpoint on CPU and switch to eval mode.
model = LITResNet(classes)
checkpoint = torch.load("model.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Names of all registered submodules (not used below; presumably kept for
# debugging/inspection — TODO confirm).
modellayers = [name for name, _ in model.named_modules()]
def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3, num_misclassified_images=3):
    """Classify one image with the CIFAR-10 model and build GradCAM overlays.

    Args:
        input_img: Input image as a numpy/PIL-compatible array; resized to the
            32x32 resolution the model was trained on.
        num_gradcam_images: Number of GradCAM overlays to produce, one per
            successive block of ``model.layer2`` counting backwards from
            ``target_layer_number``.
        target_layer_number: Negative index into ``model.layer2`` of the first
            layer to visualize (-1 = last block).  The original code ignored
            this argument; it is now honoured, and the default of -1
            reproduces the old -1, -2, ... traversal exactly.
        transparency: Weight of the original image in the overlay (0..1).
        show_misclassified: If True, also render a grid of stored
            misclassified images from ``./misclassified/``.
        num_top_classes: Number of top predictions to return.
        num_misclassified_images: Max number of misclassified images to plot.

    Returns:
        Tuple ``(top_n_confidences, gradcam_image, misclassified_image)``;
        the last element is ``None`` when ``show_misclassified`` is False.
    """
    # Resize to 32x32 and keep the raw HWC uint8 image for overlay rendering.
    org_img = np.array(Image.fromarray(np.array(input_img)).resize((32, 32)))
    input_tensor = transform(org_img).unsqueeze(0)

    outputs = model(input_tensor)
    # Batch size is 1, so a softmax over the flattened logits yields the
    # per-class probabilities directly.
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {classes[i]: float(probabilities[i]) for i in range(len(classes))}

    # One overlay per requested layer, counting backwards from
    # target_layer_number (default -1 => -1, -2, ... as before).
    visualization = []
    for offset in range(num_gradcam_images):
        cam = GradCAM(model=model, target_layers=[model.layer2[target_layer_number - offset]])
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization.append(
            show_cam_on_image(org_img / 255, grayscale_cam,
                              use_rgb=True, image_weight=transparency))

    fig = plt.figure(figsize=(12, 5))
    for i, overlay in enumerate(visualization):
        ax = fig.add_subplot(2, 5, i + 1)
        ax.imshow(overlay)
        ax.axis('off')
    plt.tight_layout()
    buffer = BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)  # release the figure; leaking figures grows memory per request
    buffer.seek(0)
    visualization = Image.open(buffer)

    # Top-N predictions, sorted by descending confidence.
    sorted_confidences = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)
    top_n_confidences = dict(sorted_confidences[:num_top_classes])

    if not show_misclassified:
        return top_n_confidences, visualization, None

    # Plot stored misclassified images; filenames encode
    # "<actual>_<predicted>.png".  Skip anything that doesn't match instead
    # of crashing, and stop when we run out of files or subplot slots.
    files = os.listdir('./misclassified/')
    fig = plt.figure(figsize=(12, 5))
    shown = 0
    for fname in files:
        if shown >= num_misclassified_images:
            break
        match = re.search(r'(\w+)_(\w+)\.png', fname)
        if match is None:
            continue  # not an "<actual>_<predicted>.png" file
        sub = fig.add_subplot(2, 5, shown + 1)
        npimg = Image.open('./misclassified/' + fname)
        sub.imshow(npimg, cmap='gray', interpolation='none')
        sub.set_title("Actual: {}, Pred: {}".format(match.group(1), match.group(2)), color='red')
        shown += 1
    plt.tight_layout()
    buffer = BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)
    visualization_misclassified = Image.open(buffer)
    return top_n_confidences, visualization, visualization_misclassified
# UI copy shown at the top of the Gradio page.
title = "CIFAR10 trained on ResNet18 Model using Pytorch Lightning with GradCAM"
description = "A simple Gradio interface to infer on ResNet18 model using Pytorch Lightning, and get GradCAM results"

# One pre-canned example per CIFAR-10 class, all sharing the same settings:
# 1 GradCAM image, last layer (-1), 0.8 opacity, show misclassified, top-3, 3 images.
_example_files = ["cat.jpg", "dog.jpg", "plane.jpg", "deer.jpg", "horse.jpg",
                  "bird.jpg", "frog.jpg", "ship.jpg", "truck.jpg", "car.jpg"]
examples = [[fname, 1, -1, 0.8, True, 3, 3] for fname in _example_files]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(1, 2, value=1, step=1, label="Number of GradCAM Images"),
        gr.Slider(-2, -1, value=-1, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
        gr.Checkbox(value=True, label="Show Misclassified Images"),
        gr.Slider(2, 10, value=3, step=1, label="Top Predictions"),
        gr.Slider(1, 10, value=3, step=1, label="Misclassified Images"),
    ],
    outputs=[
        gr.Label(label="Top Predictions"),
        gr.Image(label="Output", width=640, height=360),
        gr.Image(label="Misclassified Images", width=640, height=360),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()