sanjanatule committed on
Commit
ee9ceac
·
1 Parent(s): 8e927fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -25
app.py CHANGED
@@ -17,6 +17,10 @@ import torch
17
  import torch.optim as optim
18
  import matplotlib
19
  import cv2
 
 
 
 
20
 
21
  # my files
22
  import utils
@@ -154,36 +158,43 @@ with gr.Blocks() as demo:
154
  colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
155
  colors_hex = {class_labels[i]:matplotlib.colors.rgb2hex(colors[i]) for i in range(0,len(class_labels))}
156
 
157
- # app GUI
158
- with gr.Row():
159
- img_input = gr.Image()
160
- img_output = gr.AnnotatedImage().style(color_map = colors_hex)
161
- section_btn = gr.Button("Identify Objects")
162
-
163
- def yolo3_inference(input_img): # function for yolo inference
164
-
 
 
 
 
 
 
 
 
 
 
165
  yololit = LitYolo()
166
  inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
 
 
167
  anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
168
  bboxes = [[]]
169
  sections = [] # to return image and annotations
 
 
 
 
 
 
 
 
170
 
171
- # image transformation
172
- test_transforms = Al.Compose(
173
- [
174
- Al.LongestMaxSize(max_size=416),
175
- Al.PadIfNeeded(
176
- min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
177
- ),
178
- Al.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
179
- ToTensorV2(),
180
- ]
181
- )
182
- pr_input_img = test_transforms(image=input_img)
183
- pr_input_img = pr_input_img['image'].unsqueeze(0)
184
  # infer the image
185
  inference_model.eval()
186
- test_img_out = inference_model(pr_input_img)
187
 
188
  # process the outputs to create bounding boxes
189
  for i in range(3):
@@ -194,6 +205,7 @@ with gr.Blocks() as demo:
194
  bboxes[idx] += box
195
  # nms
196
  nms_boxes = utils.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)
 
197
 
198
  # use gradio image annotations
199
  height, width = 416, 416
@@ -205,13 +217,37 @@ with gr.Blocks() as demo:
205
  lower_right_x = int(upper_left_x + (box[2] * width))
206
  lower_right_y = int(upper_left_y + (box[3] * height))
207
  sections.append(((upper_left_x,upper_left_y,lower_right_x,lower_right_y), class_labels[int(class_pred)]))
208
- return (np.array(pr_input_img.squeeze(0).permute(1,2,0)),sections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- section_btn.click(yolo3_inference, inputs=[img_input], outputs=[img_output])
 
 
 
 
 
 
 
 
 
211
 
212
  gr.Markdown("## Some Examples")
213
  gr.Examples(examples=examples,
214
- inputs =img_input,
215
  outputs=img_output,
216
  fn=yolo3_inference, cache_examples=False)
217
 
 
17
  import torch.optim as optim
18
  import matplotlib
19
  import cv2
20
+ from pytorch_grad_cam import EigenCAM
21
+ from pytorch_grad_cam.utils.model_targets import FasterRCNNBoxScoreTarget
22
+ from pytorch_grad_cam.utils.image import show_cam_on_image
23
+
24
 
25
  # my files
26
  import utils
 
158
  colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
159
  colors_hex = {class_labels[i]:matplotlib.colors.rgb2hex(colors[i]) for i in range(0,len(class_labels))}
160
 
161
# Consolidate the multi-scale model output into one tensor so GradCAM can use it.
def yolov3_reshape_transform(x):
    """Merge the YOLOv3 per-scale head outputs into a single activation map.

    Each element of ``x`` is assumed to be shaped (B, A, S, S, C) — batch,
    anchors, grid, grid, predictions — TODO confirm against the model head.
    Every scale is flattened to (B, A*C, S, S), resampled (bilinear) to the
    spatial size of the first scale, and the scales are concatenated along
    the channel axis, e.g. three (1, 75, 13, 13) maps -> (1, 225, 13, 13).
    """
    # All scales are resized to the spatial grid of the first head (e.g. 13x13).
    target_size = x[0].size()[2:4]
    merged = []
    for scale_out in x:
        # (B, A, S, S, C) -> (B, A, C, S, S)
        moved = scale_out.permute(0, 1, 4, 2, 3)
        # Collapse anchors and predictions into one channel dim: (B, A*C, S, S).
        flat = moved.reshape((moved.shape[0],
                              moved.shape[1] * moved.shape[2],
                              *moved.shape[3:]))
        # abs() so activation magnitude (not sign) drives the CAM weighting.
        merged.append(torch.nn.functional.interpolate(flat.abs(), target_size, mode='bilinear'))
    return torch.cat(merged, axis=1)
173
+
174
+
175
+ # main function of the app
176
+ def yolo3_inference(input_img,gradcam=True,gradcam_opa=0.5): # function for yolo inference
177
+
178
+ # load model
179
  yololit = LitYolo()
180
  inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
181
+
182
+ # bboxes, gradcam
183
  anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
184
  bboxes = [[]]
185
  sections = [] # to return image and annotations
186
+ nms_boxes_output = []
187
+
188
+ # process the input image for inference/gradcam
189
+ input_img = cv2.resize(input_img, (416, 416))
190
+ input_img_copy = input_img.copy()
191
+ input_img = np.float32(input_img) / 255
192
+ transform = transforms.ToTensor()
193
+ input_img = transform(input_img).unsqueeze(0)
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # infer the image
196
  inference_model.eval()
197
+ test_img_out = inference_model(input_img)
198
 
199
  # process the outputs to create bounding boxes
200
  for i in range(3):
 
205
  bboxes[idx] += box
206
  # nms
207
  nms_boxes = utils.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)
208
+ nms_boxes_output.append(nms_boxes)
209
 
210
  # use gradio image annotations
211
  height, width = 416, 416
 
217
  lower_right_x = int(upper_left_x + (box[2] * width))
218
  lower_right_y = int(upper_left_y + (box[3] * height))
219
  sections.append(((upper_left_x,upper_left_y,lower_right_x,lower_right_y), class_labels[int(class_pred)]))
220
+
221
+ # for gradcam
222
+ if gradcam:
223
+ objs = [b[1] for b in nms_boxes_output[0]]
224
+ bbox_coord = [b[2:] for b in nms_boxes_output[0]]
225
+ targets = [FasterRCNNBoxScoreTarget(objs, bbox_coord)]
226
+
227
+ target_layers = [inference_model.model]
228
+ cam = EigenCAM(inference_model, target_layers, use_cuda=False,reshape_transform=yolov3_reshape_transform)
229
+ grayscale_cam = cam(input_tensor = input_img, targets= targets)
230
+ grayscale_cam = grayscale_cam[0, :]
231
+ visualization = show_cam_on_image(input_img_copy/255, grayscale_cam, use_rgb=True, image_weight=gradcam_opa)
232
+
233
+ return (visualization,sections)
234
+ else:
235
+ return (np.array(input_img.squeeze(0).permute(1,2,0)),sections)
236
 
237
# app GUI
# NOTE(review): these statements execute inside the `with gr.Blocks() as demo:`
# context opened earlier in the file — confirm indentation when merging.

# Layout: raw input image beside the annotated detection output.
with gr.Row():
    image_in = gr.Image()
    annotated_out = gr.AnnotatedImage().style(color_map=colors_hex)

# GradCAM controls: enable toggle plus overlay opacity.
with gr.Row():
    use_gradcam = gr.Checkbox(label="Gradcam")
    cam_opacity = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM")

detect_btn = gr.Button("Identify Objects")
detect_btn.click(
    yolo3_inference,
    inputs=[image_in, use_gradcam, cam_opacity],
    outputs=[annotated_out],
)

gr.Markdown("## Some Examples")
gr.Examples(
    examples=examples,
    inputs=[image_in, use_gradcam, cam_opacity],
    outputs=annotated_out,
    fn=yolo3_inference,
    cache_examples=False,
)
253