File size: 4,100 Bytes
24addac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaf3508
24addac
 
 
 
 
eaf3508
24addac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch, torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from custom_resnet import Net
import gradio as gr
from io import BytesIO
import os, re
import matplotlib.pyplot as plt


# Build the network and restore the trained weights onto the CPU.
model = Net()
# NOTE(review): strict=False makes load_state_dict ignore missing/unexpected
# keys instead of raising, so a mismatched checkpoint would load silently --
# confirm this is intentional.
model.load_state_dict(torch.load("custom_resnet_model.pt", map_location=torch.device('cpu')), strict=False)
# The app only runs inference: switch to evaluation mode up front so that the
# very first forward pass does not execute with training-mode layer behaviour
# (e.g. BatchNorm batch statistics / Dropout, if the network has such layers).
model.eval()

# NOTE(review): despite its name this is a plain transforms.Normalize with the
# given per-channel mean/std (it computes (x - mean) / std), not an inverse
# transform -- confirm what was intended.
inv_normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4471],
    std=[0.2469, 0.2433, 0.2615]
)
# Class names, indexed by the position of the corresponding model output.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Folder holding the pre-rendered misclassified samples and the file-name
# convention ("Target_<actual>_Pred_<predicted>_<n>.jpeg") used to label them.
_MISCLASSIFIED_DIR = './misclassified_images/'
_MISCLASSIFIED_NAME = re.compile(r'Target_(\w+)_Pred_(\w+)_\d+\.jpeg')
# The gallery is laid out on a fixed 2 x 5 grid, so at most 10 images fit.
_GRID_ROWS, _GRID_COLS = 2, 5


def _render_misclassified_grid(count, image_dir=_MISCLASSIFIED_DIR):
    """Plot up to ``count`` labelled misclassified images and return the plot as a PIL image."""
    # Keep only files that follow the naming convention (stray files would
    # otherwise crash the label extraction) and sort them, because
    # os.listdir returns entries in arbitrary order.
    labelled = []
    for file_name in sorted(os.listdir(image_dir)):
        match = _MISCLASSIFIED_NAME.search(file_name)
        if match:
            labelled.append((file_name, match.group(1), match.group(2)))

    # Never ask for more cells than the grid has or more images than exist.
    limit = max(0, min(count, _GRID_ROWS * _GRID_COLS))
    labelled = labelled[:limit]

    # Draw through the figure/axes objects rather than pyplot's implicit
    # "current figure", and always close the figure so repeated requests do
    # not accumulate open figures in memory.
    fig = plt.figure(figsize=(12, 5))
    try:
        for position, (file_name, target_class, predicted_class) in enumerate(labelled, start=1):
            sub = fig.add_subplot(_GRID_ROWS, _GRID_COLS, position)
            with Image.open(os.path.join(image_dir, file_name)) as sample:
                sub.imshow(sample, cmap='gray', interpolation='none')
            sub.set_title("Actual: {}, Pred: {}".format(target_class, predicted_class), color='red')
        fig.tight_layout()
        buffer = BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)

    buffer.seek(0)
    return Image.open(buffer)


def inference(input_img, transparency=0.5, target_layer_number=-1, top_predictions=3, miss_classified_images_count=3):
    """Classify an image, build its GradCAM overlay and a gallery of misclassified samples.

    Args:
        input_img: Image handed over by the Gradio image input. It is fed to
            ``transforms.ToTensor()`` and divided by 255 for the overlay, so it
            is presumably an HxWx3 uint8 array -- TODO confirm.
        transparency: Passed as ``image_weight`` to ``show_cam_on_image``.
        target_layer_number: Index into ``[[model.X3], [model.R3]]`` selecting
            the layer GradCAM is computed on.
        top_predictions: How many of the highest-confidence classes to return.
        miss_classified_images_count: How many images from
            ``./misclassified_images/`` to show (capped at 10 and at the
            number of suitably named files available).

    Returns:
        A tuple ``(top_n_confidences, visualization, visualization_missclassified)``:
        a dict mapping class name to confidence for the top classes, the
        GradCAM overlay produced by ``show_cam_on_image``, and a PIL image of
        the misclassified-samples plot.
    """
    # Slider values may arrive as floats; they are used as a list index, a
    # slice bound and an image count, all of which must be integers.
    target_layer_number = int(target_layer_number)
    top_predictions = int(top_predictions)
    miss_classified_images_count = int(miss_classified_images_count)

    org_img = input_img
    input_tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    # Plain classification pass: gradients are not needed here (GradCAM runs
    # its own forward/backward pass below).
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(prob) for name, prob in zip(classes, probabilities)}

    # Rank classes by confidence and keep the best ``top_predictions``.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_n_confidences = dict(ranked[:top_predictions])

    # The context manager releases the hooks GradCAM registers on the model.
    target_layers = [[model.X3], [model.R3]]
    with GradCAM(model=model, target_layers=target_layers[target_layer_number], use_cuda=False) as cam:
        # targets=None lets GradCAM explain the highest-scoring class.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)

    visualization_missclassified = _render_misclassified_grid(miss_classified_images_count)

    return top_n_confidences, visualization, visualization_missclassified

# ---- Gradio UI wiring -------------------------------------------------------
title = "Jaiyesh's ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# Every bundled sample image is offered with the same presets:
# opacity 0.8, layer index -1, top-3 predictions, 3 misclassified images.
_example_images = (
    "cat.jpg",
    "dog.jpeg",
    "plane.jpeg",
    "deer.jpeg",
    "horse.jpeg",
    "bird.jpeg",
    "frog.jpeg",
    "ship.jpeg",
    "truck.jpeg",
    "car.jpeg",
)
examples = [[image_name, 0.8, -1, 3, 3] for image_name in _example_images]

# Input widgets, in the same order as the parameters of ``inference``.
input_components = [
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
    gr.Slider(-2, -1, value=-1, step=1, label="Which Layer?"),
    gr.Slider(2, 10, value=3, step=1, label="Top Predictions"),
    gr.Slider(1, 10, value=3, step=1, label="Misclassified Images"),
]

# Output widgets, in the same order as the values returned by ``inference``.
output_components = [
    gr.Label(label="Top Predictions"),
    gr.Image(shape=(32, 32), label="Output").style(width=128, height=128),
    gr.Image(shape=(640, 360), label="Misclassified Images").style(width=640, height=360),
]

demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()