sanjanatule commited on
Commit
67b18f0
·
1 Parent(s): 2dd8354

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -21
app.py CHANGED
@@ -31,19 +31,61 @@ inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.
31
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
32
  'dog', 'frog', 'horse', 'ship', 'truck')
33
 
34
- def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gradcam=False,num_gradcam_imgs=0,transparency = 0.85, target_layer_number = -1,top_classes=3):
35
 
36
- if see_misclassified: # show misclassified images
37
- org_img = np.asarray(Image.open('misclassified_images/mis_eg_0.jpg'))
38
- input_img = org_img
39
 
40
- elif num_gradcam_imgs > 0: # show gradcam on example images
41
- org_img = np.asarray(Image.open('examples/car.jpg'))
42
- input_img = org_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- else: # nothing chosen - misclassified or gradcam
45
- org_img = input_img
 
 
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # model inference
48
  transform = transforms.ToTensor()
49
  input_img = transform(input_img)
@@ -53,7 +95,6 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
53
  o = softmax(outputs.flatten())
54
  confidences = {classes[i]: float(o[i]) for i in range(10)}
55
  sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
56
-
57
  _, prediction = torch.max(outputs, 1)
58
 
59
  # gradcam
@@ -68,18 +109,84 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
68
 
69
  # top n classes only
70
  sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
71
- return sorted_confidences, [visualization]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
74
- description = "A Gradio interface to infer on ResNet model, and get GradCAM results"
75
  examples = [["examples/cat.jpg"], ["examples/plane.jpg"],["examples/dog.jpg"],["examples/truck.jpg"],["examples/bird.jpg"],["examples/ship.jpg"],["examples/horse.jpg"],["examples/frog.jpg"],["examples/deer.jpg"],["examples/car.jpg"]]
76
 
77
- demo = gr.Interface(
78
- inference,
79
- inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Number(value=2,minimum=0,maximum=10,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Number(value=2,minimum=0,maximum=10,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
80
- outputs = [gr.Label(), gr.Gallery(label="Output Images", show_label=False, elem_id="gallery").style(columns=[2], rows=[5], object_fit="contain", height="auto")],
81
- title = title,
82
- description = description,
83
- examples = examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- demo.launch()
 
31
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
32
  'dog', 'frog', 'horse', 'ship', 'truck')
33
 
34
+ # def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gradcam=False,num_gradcam_imgs=0,transparency = 0.85, target_layer_number = -1,top_classes=3):
35
 
36
+ # if see_misclassified: # show misclassified images
37
+ # org_img = np.asarray(Image.open('misclassified_images/mis_eg_0.jpg'))
38
+ # input_img = org_img
39
 
40
+ # elif num_gradcam_imgs > 0: # show gradcam on example images
41
+ # org_img = np.asarray(Image.open('examples/car.jpg'))
42
+ # input_img = org_img
43
+
44
+ # else: # nothing chosen - misclassified or gradcam
45
+ # org_img = input_img
46
+
47
+ # # model inference
48
+ # transform = transforms.ToTensor()
49
+ # input_img = transform(input_img)
50
+ # input_img = input_img.unsqueeze(0)
51
+ # outputs = inference_model.model(input_img)
52
+ # softmax = torch.nn.Softmax(dim=0)
53
+ # o = softmax(outputs.flatten())
54
+ # confidences = {classes[i]: float(o[i]) for i in range(10)}
55
+ # sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
56
+
57
+ # _, prediction = torch.max(outputs, 1)
58
 
59
+ # # gradcam
60
+ # if see_gradcam:
61
+ # target_layers = [inference_model.model.layer2[target_layer_number]]
62
+ # cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
63
+ # grayscale_cam = cam(input_tensor=input_img, targets=None)
64
+ # grayscale_cam = grayscale_cam[0, :]
65
+ # visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
66
+ # else:
67
+ # visualization = org_img
68
 
69
+ # # top n classes only
70
+ # sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
71
+ # return sorted_confidences, [visualization]
72
+
73
+ # title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
74
+ # description = "A Gradio interface to infer on ResNet model, and get GradCAM results"
75
+ # examples = [["examples/cat.jpg"], ["examples/plane.jpg"],["examples/dog.jpg"],["examples/truck.jpg"],["examples/bird.jpg"],["examples/ship.jpg"],["examples/horse.jpg"],["examples/frog.jpg"],["examples/deer.jpg"],["examples/car.jpg"]]
76
+
77
+ # demo = gr.Interface(
78
+ # inference,
79
+ # inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Number(value=2,minimum=0,maximum=10,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Number(value=2,minimum=0,maximum=10,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
80
+ # outputs = [gr.Label(), gr.Gallery(label="Output Images", show_label=False, elem_id="gallery").style(columns=[2], rows=[5], object_fit="contain", height="auto")],
81
+ # title = title,
82
+ # description = description,
83
+ # examples = examples)
84
+
85
+ # demo.launch()
86
+
87
+ def inference_up_img(input_img,see_gradcam= True,target_layer_number = -1,transparency = 0.85,top_classes=3):
88
+ org_img = input_img
89
  # model inference
90
  transform = transforms.ToTensor()
91
  input_img = transform(input_img)
 
95
  o = softmax(outputs.flatten())
96
  confidences = {classes[i]: float(o[i]) for i in range(10)}
97
  sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
 
98
  _, prediction = torch.max(outputs, 1)
99
 
100
  # gradcam
 
109
 
110
  # top n classes only
111
  sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
112
+ return sorted_confidences, visualization
113
+
114
+ def misclass_fn(misclassified_check,num_misclassified=1,see_gradcam=True,num_gradcam=1,gradcam_layer=-2,gradcam_opa= 0.50):
115
+ img_gallery = []
116
+
117
+ if misclassified_check:
118
+ for i in range(int(num_misclassified)):
119
+ org_img = np.asarray(Image.open('misclassified_images/mis_eg_' + str(i) + '.jpg'))
120
+ input_img = org_img
121
+
122
+ if see_gradcam:
123
+ transform = transforms.ToTensor()
124
+ input_img = transform(input_img)
125
+ input_img = input_img.unsqueeze(0)
126
+ target_layers = [inference_model.model.layer2[gradcam_layer]]
127
+ cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
128
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
129
+ grayscale_cam = grayscale_cam[0, :]
130
+ visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=gradcam_opa)
131
+ img_gallery.append(visualization)
132
+ else:
133
+ img_gallery.append(org_img)
134
+
135
+ elif see_gradcam:
136
+ for i in range(int(num_gradcam)):
137
+ org_img = np.asarray(Image.open('misclassified_images/mis_eg_' + str(i) + '.jpg'))
138
+ input_img = org_img
139
+ transform = transforms.ToTensor()
140
+ input_img = transform(input_img)
141
+ input_img = input_img.unsqueeze(0)
142
+ target_layers = [inference_model.model.layer2[gradcam_layer]]
143
+ cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
144
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
145
+ grayscale_cam = grayscale_cam[0, :]
146
+ visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=gradcam_opa)
147
+ img_gallery.append(visualization)
148
+
149
+ return img_gallery
150
 
 
 
151
  examples = [["examples/cat.jpg"], ["examples/plane.jpg"],["examples/dog.jpg"],["examples/truck.jpg"],["examples/bird.jpg"],["examples/ship.jpg"],["examples/horse.jpg"],["examples/frog.jpg"],["examples/deer.jpg"],["examples/car.jpg"]]
152
 
153
+ with gr.Blocks() as demo:
154
+ gr.Markdown("Explore Custom ResNet model for CIFAR10.")
155
+ with gr.Tab("Upload your own image"):
156
+ with gr.Row():
157
+ image_input = gr.Image(shape=(32, 32), label="Input Image")
158
+ image_label = gr.Label()
159
+ with gr.Row():
160
+ with gr.Column():
161
+ gradcam_check = gr.Checkbox(label="Gradcam")
162
+ gradcam_layer = gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?")
163
+ gradcam_opa = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
164
+ top_classes = gr.Slider(1, 10, value=3, step=1, label="How many top classes?")
165
+ image_output = gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)
166
+ with gr.Row():
167
+ examples = gr.Examples(examples=examples,
168
+ inputs=[image_input,gradcam_check,gradcam_layer,gradcam_opa,top_classes,image_label],
169
+ outputs=[image_output],
170
+ fn=inference_up_img, cache_examples=False)
171
+ with gr.Row():
172
+ tab_1_button = gr.Button("Submit")
173
+ tab_1_cl_button = gr.ClearButton([image_input,gradcam_check,gradcam_layer,gradcam_opa,top_classes,image_label,image_output])
174
+
175
+ with gr.Tab("Explore Misclassified/Gradcam Images"):
176
+ with gr.Row():
177
+ with gr.Column():
178
+ misclassified_check = gr.Checkbox(label="Misclassified")
179
+ num_misclassified = gr.Number(value=2,minimum=1,maximum=10,label="Total Misclassified Images")
180
+ gradcam_check1 = gr.Checkbox(label="Gradcam")
181
+ num_gradcam = gr.Number(value=2,minimum=1,maximum=10,label="Total Gradcam Images")
182
+ gradcam_layer1 = gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?")
183
+ gradcam_opa1 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
184
+ image_gallery_output = gr.Gallery(label="Output Images", show_label=False, elem_id="gallery").style(columns=[2], rows=[5], object_fit="contain", height="auto")
185
+ with gr.Row():
186
+ tab_2_button = gr.Button("Submit")
187
+ tab_2_cl_button = gr.ClearButton([misclassified_check,num_misclassified,gradcam_check1,num_gradcam,gradcam_layer1,gradcam_opa1,image_gallery_output])
188
+
189
+ tab_1_button.click(inference_up_img, inputs=[image_input,gradcam_check,gradcam_layer,gradcam_opa,top_classes], outputs=[image_label,image_output])
190
+ tab_2_button.click(misclass_fn, inputs=[misclassified_check,num_misclassified,gradcam_check1,num_gradcam,gradcam_layer1,gradcam_opa1], outputs=[image_gallery_output])
191
+ demo.launch(debug=True)
192