peeyushsinghal commited on
Commit
2d2d4be
·
1 Parent(s): 4ff0321

Updated Files

Browse files
Files changed (14)
  1. app.py +316 -0
  2. bird1.jpg +0 -0
  3. car1.jpg +0 -0
  4. cat1.jpg +0 -0
  5. deer1.jpg +0 -0
  6. dog1.jpg +0 -0
  7. frog1.jpg +0 -0
  8. horse1.jpg +0 -0
  9. misclassified_images_list.pt +3 -0
  10. model.pth +3 -0
  11. plane1.jpg +0 -0
  12. requirements.txt +8 -0
  13. ship1.jpg +0 -0
  14. truck1.jpg +0 -0
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """s12.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1dtu0lhq50jTGmRRKyFaXfDCE4ZVW-w4H
8
+ """
9
+
10
# NOTE(review): `!pip install ...` is IPython/Colab shell syntax and is a
# SyntaxError in a plain .py module. On the deployment host the dependencies
# (gradio, grad-cam) are installed from requirements.txt instead, so these
# lines are kept only as a record of the notebook's setup.
# !pip install gradio --quiet
# !pip install grad_cam --quiet
13
+
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ dropout_value = 0.1
18
class ResBlock(nn.Module):
    """Residual branch: two (3x3 conv -> BN -> ReLU) stages.

    Stride 1 and padding 1 throughout, so the spatial size is preserved;
    only the first conv may change the channel count.
    """

    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        stages = []
        channels = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(in_channels=channels, out_channels=out_channels,
                                    kernel_size=3, stride=1, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
            channels = out_channels
        # Keep the attribute name `res_block`: it is indexed from outside
        # (dict_layer uses res_block[-1] as a GradCAM target layer).
        self.res_block = nn.Sequential(*stages)

    def forward(self, x):
        return self.res_block(x)
33
+
34
+
35
class LayerBlock(nn.Module):
    """Downsampling stage: 3x3 conv -> 2x2 max-pool -> BN -> ReLU.

    Halves the spatial resolution and maps in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super(LayerBlock, self).__init__()
        self.layer_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer_block(x)
48
+
49
class custom_resnet_s10(nn.Module):
    """Custom ResNet (session-10 style) for 32x32 CIFAR-10 images.

    Prep conv -> three downsampling stages, with identity-style residual
    additions after stages 1 and 3 (and 2), then a 4x4 max-pool collapses
    the 4x4x512 map to 1x1x512 before a single linear classifier.
    """

    def __init__(self, num_classes=10):
        super(custom_resnet_s10, self).__init__()
        # Attribute names are load-bearing: they fix the state_dict keys of
        # the saved checkpoint and resblock2/resblock3 are referenced by the
        # GradCAM layer table elsewhere in the file.
        self.PrepLayer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.Layer1 = LayerBlock(in_channels=64, out_channels=128)
        self.resblock1 = ResBlock(in_channels=128, out_channels=128)
        self.Layer2 = LayerBlock(in_channels=128, out_channels=256)
        self.resblock2 = ResBlock(in_channels=256, out_channels=256)
        self.Layer3 = LayerBlock(in_channels=256, out_channels=512)
        self.resblock3 = ResBlock(in_channels=512, out_channels=512)
        # 4x4 pool: a 32px input has been halved three times to 4x4 here.
        self.max_pool4 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.PrepLayer(x)
        # Each stage downsamples, then adds its residual branch.
        out = self.Layer1(out)
        out = out + self.resblock1(out)
        out = self.Layer2(out)
        out = out + self.resblock2(out)
        out = self.Layer3(out)
        out = out + self.resblock3(out)
        out = self.max_pool4(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, 512)
        return self.fc(out)
97
+
98
+ # With Tabs
99
+ import gradio as gr
100
+ import torch
101
+ import random
102
+ from collections import OrderedDict
103
+ from pytorch_grad_cam import GradCAM
104
+ from pytorch_grad_cam.utils.image import show_cam_on_image
105
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
106
+ import numpy as np
107
+ from PIL import Image
108
+ from torchvision import transforms
109
+
110
+
111
def get_device():
    """Pick the best available torch device string.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        str: "cuda", "mps", or "cpu".
    """
    if torch.cuda.is_available():
        device = "cuda"
    # FIX: older torch builds have no torch.backends.mps attribute at all,
    # so touching it unguarded raises AttributeError; probe with getattr.
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Device Selected:", device)
    return device
120
+
121
# ---- One-time model / data setup (runs at import time) ----

DEVICE = get_device()
print(DEVICE)

# Load the list of tensors from the file.
# Entries are iterated later as (img, pred, correct) triples of tensors —
# TODO(review): confirm that layout against the notebook that produced the file.
loaded_misclassified_image_list = torch.load('misclassified_images_list.pt')

# Instantiate the model (make sure it has the same architecture)
loaded_model = custom_resnet_s10()
loaded_model = loaded_model.to(DEVICE)

# Load the saved state dictionary.
# NOTE(review): strict=False silently ignores missing/unexpected keys, so a
# mismatched checkpoint would partially load without any error being raised.
loaded_model.load_state_dict(torch.load('model.pth', map_location=DEVICE), strict=False)

# Put the loaded model in evaluation mode
loaded_model.eval()

# CIFAR-10 class names, indexed by label id.
classes = ['plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']
# CIFAR-10 per-channel normalization statistics used at inference time.
mean = (0.49139968, 0.48215827, 0.44653124)
std = (0.24703233, 0.24348505, 0.26158768)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Candidate GradCAM target layers: the final ReLU of a residual block.
dict_layer = {'layer3': loaded_model.resblock2.res_block[-1],
              'layer4': loaded_model.resblock3.res_block[-1]}
147
+
148
+ # def show_misclassified_images(num_images):
149
+ # return num_images
150
+
151
def view_gradcam_images(choice_gradcam):
    """Toggle visibility of the stored-image GradCAM controls.

    Returns five gr.update objects, in order, for: the num-images slider,
    the layer dropdown, the opacity slider, the view button, and the output
    gallery (which stays hidden until the button is pressed).
    """
    if choice_gradcam != "Yes (View Existing Images)":
        # TODO: to be completed — the "new images" branch just hides everything.
        hidden = tuple(gr.update(visible=False) for _ in range(5))
        return hidden
    return (
        gr.update(label="Number of GradCAM Images to view", visible=True, interactive=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),  # Gallery not shown as yet
    )
160
+
161
def process_gradcam_images(num_images, layer, opacity, image_list=None):
    """Render GradCAM overlays for misclassified (or caller-supplied) images.

    Args:
        num_images: how many random misclassified images to process
            (ignored when image_list is given).
        layer: key into dict_layer selecting the GradCAM target layer.
        opacity: weight of the original image when blending the heatmap.
        image_list: optional single (img, pred, correct) triple to use
            instead of sampling from the stored misclassified set.

    Returns:
        (gallery_items, update): a list of (overlay, caption) pairs and a
        gr.update that makes the gallery visible.
    """
    if not image_list:
        selected_data = random.sample(
            loaded_misclassified_image_list,
            min(num_images, len(loaded_misclassified_image_list)))
    else:
        selected_data = [image_list]

    layer_model = dict_layer.get(layer)
    # FIX: use_cuda was hard-coded to True, which crashes on CPU/MPS-only
    # hosts; derive it from the device actually selected at startup.
    cam = GradCAM(model=loaded_model, target_layers=[layer_model],
                  use_cuda=(DEVICE == "cuda"))
    grad_images = []
    # FIX: invert the SAME normalization that `transform` applied (the
    # module-level CIFAR-10 mean/std). The previous hard-coded statistics
    # ([0.2197, ...] / [0.1810, ...]) belonged to a different dataset and
    # produced incorrectly de-normalized images.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    for img, pred, correct in selected_data:
        input_tensor = img.unsqueeze(0)  # batch of one
        targets = [ClassifierOutputTarget(pred)]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Recover a displayable [0, 1] RGB image from the normalized tensor.
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
        rgb_img = torch.clamp(rgb_img, max=1)
        rgb_img = rgb_img.numpy()

        visualization = show_cam_on_image(rgb_img, grayscale_cam,
                                          use_rgb=True, image_weight=opacity)
        grad_images.append(
            (visualization,
             f'Pred: {classes[pred.cpu()]} | Truth :{classes[correct.cpu()]}'))

    print(str(num_images) + "**" + str(layer) + "**" + str(opacity))
    return grad_images, gr.update(visible=True)
196
+
197
+
198
+
199
def process_misclassified_images(num_images):
    """Sample stored misclassified images and prepare them for the gallery.

    Args:
        num_images: number of images to sample (capped at the list length).

    Returns:
        (gallery_items, update): a list of (PIL image, "Pred | Truth"
        caption) pairs and a gr.update that makes the gallery visible.
    """
    selected_data = random.sample(loaded_misclassified_image_list,
                                  min(num_images, len(loaded_misclassified_image_list)))
    misclassified_images = []
    for img, pred, correct in selected_data:
        # FIX: the third unpacked value was previously re-bound to an unused
        # name (`target`) while the caption read the loop variable `correct`;
        # use one consistent name throughout.
        img, pred, correct = img.cpu().numpy().astype(dtype=np.float32), pred.cpu(), correct.cpu()
        # Undo per-channel normalization so the image is displayable.
        for j in range(img.shape[0]):
            img[j] = (img[j] * std[j]) + mean[j]
        img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
        img = Image.fromarray((img * 255).astype(np.uint8))
        misclassified_images.append((img, f'Pred: {classes[pred]} | Truth :{classes[correct]}'))
    return misclassified_images, gr.update(visible=True)
210
+
211
def view_misclassified_images(choice_misclassified):
    """Show/hide the misclassified-image controls based on the radio choice.

    Returns three gr.update objects, in order, for: the num-images slider,
    the view button, and the output gallery (gallery stays hidden until the
    button is pressed).
    """
    if choice_misclassified != "Yes":
        return (gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False))
    return (gr.update(label="Number of Misclassified Images to view",
                      visible=True, interactive=True),
            gr.update(visible=True),
            gr.update(visible=False))
216
+
217
def classify_image(image, num_classes=3, grad_cam_choice=False, layer=None, opacity=0.8):
    """Classify an uploaded image and optionally produce its GradCAM overlay.

    Args:
        image: PIL image from the gradio Image component.
        num_classes: number of top class confidences to report.
        grad_cam_choice: when True, also compute a GradCAM visualization
            for the predicted class.
        layer: dict_layer key selecting the GradCAM target layer.
        opacity: image weight used when blending the GradCAM heatmap.

    Returns:
        (confidences, gallery): an OrderedDict of the top-k class
        confidences, and either GradCAM gallery items or a hide-update.
    """
    # Normalize the image and build a batch of one on the chosen device.
    transformed_image = transform(image)
    image_tensor = transformed_image.to(DEVICE).unsqueeze(0)

    # FIX: the plain classification forward needs no autograd graph, and
    # F.softmax without an explicit `dim` is deprecated/ambiguous — name it.
    with torch.no_grad():
        logits = loaded_model(image_tensor)
    output = F.softmax(logits.view(-1), dim=-1)

    confidences = [(classes[i], float(output[i])) for i in range(len(classes))]
    confidences.sort(key=lambda x: x[1], reverse=True)
    confidences = OrderedDict(confidences[:num_classes])
    label = torch.argmax(output).item()

    if grad_cam_choice:
        print("** Before Calling **", transformed_image.shape)
        # Reuse the GradCAM pipeline; the predicted label stands in for both
        # "pred" and "truth" since an upload has no ground-truth label.
        image_list = [transformed_image.to(DEVICE),
                      torch.tensor(label).to(DEVICE),
                      torch.tensor(label).to(DEVICE)]
        grad_cam_output, _ = process_gradcam_images(num_images=1, layer=layer,
                                                    opacity=opacity, image_list=image_list)
        return confidences, grad_cam_output
    else:
        return confidences, gr.update(visible=False)
244
+
245
+
246
# ---- Gradio UI: two tabs (GradCAM explorer, misclassified-image browser) ----
with gr.Blocks() as demo:
    with gr.Tab("GradCam"):
        gr.Markdown(
            """
            Visualize Class Activations Maps (helps to see what the model is actually looking at in the image) generated by the model's layer for the predicted class
            - For existing images
            - For new images (choose an example image or upload your own)
            """
        )
        with gr.Column():
            # Section 1: GradCAM over the stored misclassified images.
            with gr.Box():
                radio_gradcam = gr.Radio(["Yes (View Existing Images)", "No (New or Example Images)"], label="Do you want to view existing GradCAM images?")
                with gr.Column():
                    with gr.Row():
                        slider_gradcam_num_images = gr.Slider(minimum=1, maximum=10, value=1, step=1, visible=False, interactive=False)
                        dropdown_gradcam_layer = gr.Dropdown(choices=['layer4', 'layer3'], value="layer4", label="Please select the layer from which the GradCAM would be taken", interactive=True, visible=False)
                        slider_gradcam_opacity = gr.Slider(label="Opacity of Images", minimum=0.05, maximum=1.00, value=0.70, step=0.05, visible=False, interactive=True)
                    button_gradcam = gr.Button("View GradCAM Output", visible=False)
                    output_gallery_gradcam = gr.Gallery(label="GradCAM Output", min_width=512, columns=4, visible=False)
            # Section 2: classify a new/example image, optionally with GradCAM.
            with gr.Box():
                with gr.Row():
                    with gr.Column():
                        input_image_classify = gr.Image(label="Classification", type="pil", shape=(32, 32))
                        slider_classify_num_classes = gr.Slider(label="Select the number of top classes to be shown", minimum=1, maximum=10, value=3, step=1, visible=True, interactive=True)
                        checkbox_gradcam_classify = gr.Checkbox(label="Enable GradCAM", value=True, info="Do you want to see Class Activation Maps?", visible=True)
                        dropdown_gradcam_classify_layer = gr.Dropdown(choices=['layer4', 'layer3'], value="layer4", label="Please select the layer from which the GradCAM would be taken", interactive=True, visible=True)
                        slider_gradcam_classify_opacity = gr.Slider(label="Opacity of Images", minimum=0.05, maximum=1.00, value=0.80, step=0.05, visible=True, interactive=True)
                        button_classify = gr.Button("Submit to Classify Image", visible=True)
                    with gr.Column():
                        label_classify = gr.Label(num_top_classes=10, visible=True)
                        gallery_gradcam_classify = gr.Gallery(label="GradCAM Output", min_width=256, columns=1, visible=True)
                with gr.Row():
                    # FIX: examples previously pointed at Colab's /content/
                    # directory; the sample jpgs ship at the repo root, so use
                    # relative paths that resolve on the deployment host.
                    gr.Examples(['bird1.jpg', 'car1.jpg', 'deer1.jpg', 'frog1.jpg', 'plane1.jpg',
                                 'ship1.jpg', 'truck1.jpg', 'cat1.jpg', 'dog1.jpg', 'horse1.jpg'],
                                inputs=[input_image_classify])
    with gr.Tab("Misclassified Examples"):
        gr.Markdown(
            """
            The AI model is not able to predict correct image labels all the time.

            Select "Yes" to visualize the misclassified images with their model predicted label and ground truth label.
            """
        )
        with gr.Column():
            with gr.Box():
                radio_misclassified = gr.Radio(["Yes", "No"], label="Do you want to view Misclassified images?")
                slider_misclassified_num_images = gr.Slider(minimum=1, maximum=10, value=1, step=1, visible=False, interactive=False)
                button_misclassified = gr.Button("View Misclassified Output", visible=False)
                output_gallery_misclassification = gr.Gallery(label="Misclassification Output (Predicted/Truth)", min_width=512, columns=5, visible=False)

    # Event wiring: the radios toggle control visibility; the buttons run the
    # processing pipelines. Button outputs list the gallery twice because each
    # handler returns (items, visibility-update) for the same component.
    radio_gradcam.change(fn=view_gradcam_images, inputs=radio_gradcam, outputs=[slider_gradcam_num_images, dropdown_gradcam_layer, slider_gradcam_opacity, button_gradcam, output_gallery_gradcam])
    button_gradcam.click(fn=process_gradcam_images, inputs=[slider_gradcam_num_images, dropdown_gradcam_layer, slider_gradcam_opacity], outputs=[output_gallery_gradcam, output_gallery_gradcam])

    radio_misclassified.change(fn=view_misclassified_images, inputs=radio_misclassified, outputs=[slider_misclassified_num_images, button_misclassified, output_gallery_misclassification])
    button_misclassified.click(fn=process_misclassified_images, inputs=[slider_misclassified_num_images], outputs=[output_gallery_misclassification, output_gallery_misclassification])
    button_classify.click(fn=classify_image, inputs=[input_image_classify, slider_classify_num_classes, checkbox_gradcam_classify, dropdown_gradcam_classify_layer, slider_gradcam_classify_opacity], outputs=[label_classify, gallery_gradcam_classify])

demo.launch()
316
+
bird1.jpg ADDED
car1.jpg ADDED
cat1.jpg ADDED
deer1.jpg ADDED
dog1.jpg ADDED
frog1.jpg ADDED
horse1.jpg ADDED
misclassified_images_list.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70ee4dd1e925aa6ae87833561aa350412785acd57d354411273f6e36c022dc9c
3
+ size 15330301
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8135516f42a44764a75724b03bda96b10308f103ad33283c88962f57a3c018
3
+ size 31068665
plane1.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ grad-cam
5
+ pandas
6
+ gradio
7
+ Pillow
8
+
ship1.jpg ADDED
truck1.jpg ADDED