File size: 4,817 Bytes
6192a31
fe72ddb
462c3fa
 
 
 
 
 
 
f9afe8f
 
 
 
462c3fa
9c84d1a
 
 
17771de
f9afe8f
462c3fa
 
db5858d
9c84d1a
051a4e1
462c3fa
f9afe8f
 
fa845ee
fe72ddb
462c3fa
fe72ddb
 
 
 
6192a31
839df3e
7724b1c
839df3e
 
 
 
 
f9afe8f
90aceef
121d531
7724b1c
121d531
a6456c8
90aceef
a6456c8
 
 
 
 
f9afe8f
 
 
 
 
2fd201f
f9afe8f
74ec669
6192a31
f9afe8f
 
 
 
7d7bd68
fe72ddb
f9afe8f
8addcc6
f9afe8f
 
fe72ddb
f9afe8f
 
 
 
 
 
 
 
2fd201f
816472f
462c3fa
09b5b07
 
f9afe8f
 
e1322ae
 
 
 
 
 
 
 
462c3fa
6192a31
561de13
7724b1c
f9afe8f
 
 
 
 
 
7724b1c
561de13
6192a31
 
 
462c3fa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnetS11 import LITResNet
import os
import re
import matplotlib.pyplot as plt
from io import BytesIO

# Preprocessing pipeline: tensor conversion + CIFAR-10 per-channel
# normalization (mean / std computed over the training set).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.49139968, 0.48215827, 0.44653124),
        (0.24703233, 0.24348505, 0.26158768),
    ),
])

# CIFAR-10 class labels, in the dataset's canonical index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Restore weights from the Lightning checkpoint on CPU and freeze for inference.
model = LITResNet(classes)
checkpoint = torch.load("model.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Names of every registered (sub)module in the network.
modellayers = [layer_name for layer_name, _ in model.named_modules()]

def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3, num_misclassified_images=3):
    """Classify one image and render GradCAM (and optional misclassified-set) panels.

    Args:
        input_img: Input image as a PIL image / ndarray (resized to 32x32 here).
        num_gradcam_images: How many trailing layers of ``model.layer2`` to
            visualize with GradCAM.
        target_layer_number: Negative index of the first layer to visualize
            (default -1, the last layer of ``model.layer2``).
        transparency: GradCAM overlay weight passed to ``show_cam_on_image``.
        show_misclassified: If True, also render a grid of saved misclassified
            images from ``./misclassified/``.
        num_top_classes: Number of top predictions to return.
        num_misclassified_images: How many misclassified images to plot.

    Returns:
        (top_n_confidences, gradcam_panel, misclassified_panel_or_None)
    """
    input_img = np.array(Image.fromarray(np.array(input_img)).resize((32, 32)))
    org_img = input_img
    input_img = transform(input_img).unsqueeze(0)
    outputs = model(input_img)
    # Batch size is 1, so softmax over the flattened class-score vector.
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}

    visualization = []
    for offset in range(num_gradcam_images):
        # BUG FIX: target_layer_number was previously ignored, so the UI's
        # "Which Layer?" slider had no effect. Default -1 reproduces the old
        # behavior (last layer, then second-to-last, ...).
        cam = GradCAM(model=model, target_layers=[model.layer2[target_layer_number - offset]])
        grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]
        visualization.append(
            show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency))

    fig = plt.figure(figsize=(12, 5))
    for i, vis in enumerate(visualization):
        ax = fig.add_subplot(2, 5, i + 1)
        ax.imshow(vis)
        ax.axis('off')
    plt.tight_layout()
    buffer = BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)  # BUG FIX: release the figure so repeated calls don't leak
    buffer.seek(0)  # BUG FIX: rewind before Image.open, else PIL can't read it
    visualization = Image.open(buffer)

    # Sort the confidences dictionary based on confidence values and keep the top n.
    sorted_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True))
    top_n_confidences = dict(list(sorted_confidences.items())[:num_top_classes])

    if not show_misclassified:
        return top_n_confidences, visualization, None

    files = os.listdir('./misclassified/')

    # Plot the misclassified images; filenames encode "<actual>_<predicted>.png".
    fig = plt.figure(figsize=(12, 5))
    for i in range(num_misclassified_images):
        sub = fig.add_subplot(2, 5, i + 1)
        npimg = Image.open('./misclassified/' + files[i])

        # Use regex to extract target and predicted classes from the filename.
        match = re.search(r'(\w+)_(\w+).png', files[i])
        target_class = match.group(1)
        predicted_class = match.group(2)

        sub.imshow(npimg, cmap='gray', interpolation='none')
        sub.set_title("Actual: {}, Pred: {}".format(target_class, predicted_class), color='red')
    plt.tight_layout()
    buffer = BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)  # BUG FIX: rewind before Image.open
    visualization_misclassified = Image.open(buffer)

    return top_n_confidences, visualization, visualization_misclassified

title = "CIFAR10 trained on ResNet18 Model using Pytorch Lightning with GradCAM"
description = "A simple Gradio interface to infer on ResNet18 model using Pytorch Lightning, and get GradCAM results"

# One example row per class image; every row shares the same default settings
# (1 GradCAM image, layer -1, 0.8 opacity, show misclassified, top-3, 3 images).
_example_images = ("cat", "dog", "plane", "deer", "horse", "bird", "frog", "ship", "truck", "car")
examples = [[name + ".jpg", 1, -1, 0.8, True, 3, 3] for name in _example_images]

# Wire the UI controls to the positional parameters of `inference`.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(1, 2, value=1, step=1, label="Number of GradCAM Images"),
        gr.Slider(-2, -1, value=-1, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
        gr.Checkbox(value=True, label="Show Misclassified Images"),
        gr.Slider(2, 10, value=3, step=1, label="Top Predictions"),
        gr.Slider(1, 10, value=3, step=1, label="Misclassified Images"),
    ],
    outputs=[
        gr.Label(label="Top Predictions"),
        gr.Image(label="Output", width=640, height=360),
        gr.Image(label="Misclassified Images", width=640, height=360),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()