RashiAgarwal commited on
Commit
f739176
·
1 Parent(s): 84b79fd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from torch.nn import functional as F
7
+ from collections import OrderedDict
8
+ from torchvision import transforms
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.image import show_cam_on_image
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ from pytorch_lightning import LightningModule, Trainer, seed_everything
13
+ import albumentations as A
14
+ from albumentations.pytorch import ToTensorV2
15
+ import torchvision.transforms as T
16
+ from custom_resnet import LitResnet
17
+
18
# CIFAR-10 class names, in label-index order.
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Bookkeeping for the misclassified-image browser: one CSV row per saved image.
wrong_img = pd.read_csv('misclassified_data.csv')
wrong_img_no = len(wrong_img)

# Load the Lightning ResNet on CPU for inference only.
# strict=False tolerates checkpoint keys that don't match the module exactly.
model = LitResnet()
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
model.eval()

# Preprocessing applied to every uploaded image (ImageNet mean/std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Inverse of a 0.50/0.23 normalization, kept for de-normalizing tensors back
# to displayable images.
inv_normalize = T.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23],
)

# One Grad-CAM extractor per sub-layer of convblock3 (indices 0-4), built once
# at startup so requests don't pay the construction cost.
grad_cams = [GradCAM(model=model, target_layers=[model.convblock3[i]], use_cuda=False) for i in range(5)]
37
+
38
def get_gradcam_image(input_tensor, label, target_layer):
    """Compute a Grad-CAM heatmap for *label* on *input_tensor*.

    Uses the pre-built extractor at index *target_layer* in the module-level
    ``grad_cams`` list and returns the 2-D grayscale CAM for the first (only)
    image in the batch.
    """
    cam = grad_cams[target_layer]
    heatmaps = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(label)])
    return heatmaps[0, :]
44
+
45
+
46
def image_classifier(input_image, top_classes=3, show_cam=True, target_layers=(2, 3), transparency=0.5):
    """Classify a CIFAR-10 image and optionally produce Grad-CAM overlays.

    Args:
        input_image: HxWx3 numpy array from the Gradio image widget
            (assumed uint8 in 0-255, hence the /255 below -- TODO confirm).
        top_classes: number of highest-confidence classes to report.
        show_cam: whether to render Grad-CAM overlays.
        target_layers: indices (0-4) into the module-level ``grad_cams`` list.
            The default is a tuple rather than a list to avoid the shared
            mutable-default pitfall; callers may still pass any iterable.
        transparency: image weight passed to show_cam_on_image (0..1).

    Returns:
        (gallery_items, confidences): list of (overlay_image, caption) pairs
        and a dict of the top class names -> softmax probability.
    """
    orig_image = input_image
    input_tensor = transform(input_image).unsqueeze(0)  # add batch dim
    output = model(input_tensor)

    # Single sample, so softmax over the flattened logits is the class dist.
    probs = F.softmax(output.flatten(), dim=0)

    confidences = {classes[i]: float(probs[i]) for i in range(10)}
    ranked = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)
    confidences = dict(ranked[:top_classes])

    _, label = torch.max(output, 1)

    outputs = []
    if show_cam:
        for layer in target_layers:
            # Pass a plain int class index, not the shape-[1] tensor, to the
            # Grad-CAM target.
            grayscale_cam = get_gradcam_image(input_tensor, int(label.item()), layer)
            overlay = show_cam_on_image(orig_image / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
            # CheckboxGroup type='index' yields 0-4; caption them as -5..-1.
            outputs.append((overlay, f"Layer {layer - 5}"))

    return outputs, confidences
69
+
70
# One example per CIFAR-10 class:
# [image path, top_classes, show_cam, layer choices, transparency].
examples = [[f'examples/{name}.jpg', 3, True, ["-2", "-1"], 0.5] for name in classes]
74
+
75
# Tab 1: upload an image, view top-k predictions and optional Grad-CAM maps.
image_input = gr.Image(shape=(32, 32), label="Input Image").style(width=128, height=128)
topk_slider = gr.Slider(1, 10, value=3, step=1, label="Top Classes",
                        info="How many top classes do you want to see?")
cam_toggle = gr.Checkbox(label="Enable GradCAM", value=True, info="Do you want to see GradCAM Images?")
layer_picker = gr.CheckboxGroup(["-5", "-4", "-3", "-2", "-1"], value=["-2", "-1"], label="Network Layers", type='index',
                                info="Which layer(s) GradCAM do you want to visualize?")
alpha_slider = gr.Slider(0, 1, value=0.5, label="Transparency", step=0.1,
                         info="Set Transparency of CAMs")

demo_1 = gr.Interface(
    fn=image_classifier,
    inputs=[image_input, topk_slider, cam_toggle, layer_picker, alpha_slider],
    outputs=[gr.Gallery(label="Output Images", columns=2, rows=2), gr.Label(label='Top Classes')],
    examples=examples,
)
90
+
91
+
92
def show_incorrect(num_examples=10):
    """Sample *num_examples* misclassified images (with replacement).

    Returns a list of (image, "Actual:<x> / Predicted:<y>") pairs for a
    gr.Gallery. Assumes images are stored as Misclassified_images/<j>.jpg
    with j starting at 1, and that CSV row j-1 holds that image's labels
    (consistent with the loc[j-1] lookup) -- TODO confirm file numbering.
    """
    result = []
    for _ in range(num_examples):
        # np.random.randint's upper bound is exclusive; use wrong_img_no + 1
        # so the last misclassified image can also be drawn (the original
        # bound of wrong_img_no silently excluded it).
        j = np.random.randint(1, wrong_img_no + 1)
        image = np.asarray(Image.open(f'Misclassified_images/{j}.jpg'))
        actual = classes[wrong_img.loc[j-1].at["actual"]]
        predicted = classes[wrong_img.loc[j-1].at["predicted"]]

        result.append((image, f"Actual:{actual} / Predicted:{predicted}"))

    return result
103
+
104
+
105
# Tab 2: browse a random sample of the model's misclassified images.
count_input = gr.Number(value=10, minimum=1, maximum=50, label="Input number(s) of images", precision=0,
                        info="How many misclassified examples do you want to view? (max 50)")

demo_2 = gr.Interface(
    fn=show_incorrect,
    inputs=[count_input],
    outputs=[gr.Gallery(label="Misclassified Images (Actual / Predicted)", columns=5)],
)

# Combine both interfaces into a tabbed app and start serving.
demo = gr.TabbedInterface([demo_1, demo_2], ["Image Classifier", "Misclassified Images"])
demo.launch()