reputation committed on
Commit
d39299a
·
verified ·
1 Parent(s): 18d079b

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +2 -108
utils.py CHANGED
@@ -12,13 +12,6 @@ from tqdm import tqdm
12
 
13
 
14
  def iou_width_height(boxes1, boxes2):
15
- """
16
- Parameters:
17
- boxes1 (tensor): width and height of the first bounding boxes
18
- boxes2 (tensor): width and height of the second bounding boxes
19
- Returns:
20
- tensor: Intersection over union of the corresponding boxes
21
- """
22
  intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
23
  boxes1[..., 1], boxes2[..., 1]
24
  )
@@ -29,21 +22,6 @@ def iou_width_height(boxes1, boxes2):
29
 
30
 
31
  def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
32
- """
33
- Video explanation of this function:
34
- https://youtu.be/XXYG5ZWtjj0
35
-
36
- This function calculates intersection over union (iou) given pred boxes
37
- and target boxes.
38
-
39
- Parameters:
40
- boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
41
- boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
42
- box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
43
-
44
- Returns:
45
- tensor: Intersection over union for all examples
46
- """
47
 
48
  if box_format == "midpoint":
49
  box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
@@ -78,22 +56,6 @@ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
78
 
79
 
80
  def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
81
- """
82
- Video explanation of this function:
83
- https://youtu.be/YDkjWEN8jNA
84
-
85
- Does Non Max Suppression given bboxes
86
-
87
- Parameters:
88
- bboxes (list): list of lists containing all bboxes with each bboxes
89
- specified as [class_pred, prob_score, x1, y1, x2, y2]
90
- iou_threshold (float): threshold above which a predicted bbox is considered correct
91
- threshold (float): threshold to remove predicted bboxes (independent of IoU)
92
- box_format (str): "midpoint" or "corners" used to specify bboxes
93
-
94
- Returns:
95
- list: bboxes after performing NMS given a specific IoU threshold
96
- """
97
 
98
  assert type(bboxes) == list
99
 
@@ -124,37 +86,15 @@ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
124
  def mean_average_precision(
125
  pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
126
  ):
127
- """
128
- Video explanation of this function:
129
- https://youtu.be/FppOzcDvaDI
130
-
131
- This function calculates mean average precision (mAP)
132
-
133
- Parameters:
134
- pred_boxes (list): list of lists containing all bboxes with each bboxes
135
- specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
136
- true_boxes (list): Similar as pred_boxes except all the correct ones
137
- iou_threshold (float): threshold above which a predicted bbox is considered correct
138
- box_format (str): "midpoint" or "corners" used to specify bboxes
139
- num_classes (int): number of classes
140
 
141
- Returns:
142
- float: mAP value across all classes given a specific IoU threshold
143
- """
144
-
145
- # list storing all AP for respective classes
146
  average_precisions = []
147
 
148
- # used for numerical stability later on
149
  epsilon = 1e-6
150
 
151
  for c in range(num_classes):
152
  detections = []
153
  ground_truths = []
154
 
155
- # Go through all predictions and targets,
156
- # and only add the ones that belong to the
157
- # current class c
158
  for detection in pred_boxes:
159
  if detection[1] == c:
160
  detections.append(detection)
@@ -162,33 +102,19 @@ def mean_average_precision(
162
  for true_box in true_boxes:
163
  if true_box[1] == c:
164
  ground_truths.append(true_box)
165
-
166
- # find the amount of bboxes for each training example
167
- # Counter here finds how many ground truth bboxes we get
168
- # for each training example, so let's say img 0 has 3,
169
- # img 1 has 5 then we will obtain a dictionary with:
170
- # amount_bboxes = {0:3, 1:5}
171
  amount_bboxes = Counter([gt[0] for gt in ground_truths])
172
-
173
- # We then go through each key, val in this dictionary
174
- # and convert to the following (w.r.t same example):
175
- # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
176
  for key, val in amount_bboxes.items():
177
  amount_bboxes[key] = torch.zeros(val)
178
 
179
- # sort by box probabilities which is index 2
180
  detections.sort(key=lambda x: x[2], reverse=True)
181
  TP = torch.zeros((len(detections)))
182
  FP = torch.zeros((len(detections)))
183
  total_true_bboxes = len(ground_truths)
184
 
185
- # If none exists for this class then we can safely skip
186
  if total_true_bboxes == 0:
187
  continue
188
 
189
  for detection_idx, detection in enumerate(detections):
190
- # Only take out the ground_truths that have the same
191
- # training idx as detection
192
  ground_truth_img = [
193
  bbox for bbox in ground_truths if bbox[0] == detection[0]
194
  ]
@@ -208,15 +134,12 @@ def mean_average_precision(
208
  best_gt_idx = idx
209
 
210
  if best_iou > iou_threshold:
211
- # only detect ground truth detection once
212
  if amount_bboxes[detection[0]][best_gt_idx] == 0:
213
- # true positive and add this bounding box to seen
214
  TP[detection_idx] = 1
215
  amount_bboxes[detection[0]][best_gt_idx] = 1
216
  else:
217
  FP[detection_idx] = 1
218
 
219
- # if IoU is lower than the threshold, the detection is a false positive
220
  else:
221
  FP[detection_idx] = 1
222
 
@@ -233,22 +156,13 @@ def mean_average_precision(
233
 
234
 
235
  def plot_image(image, boxes):
236
- """Plots predicted bounding boxes on the image"""
237
  cmap = plt.get_cmap("tab20b")
238
  class_labels = config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
239
  colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
240
  im = np.array(image)
241
  height, width, _ = im.shape
242
-
243
- # Create figure and axes
244
  fig, ax = plt.subplots(1)
245
- # Display the image
246
  ax.imshow(im)
247
-
248
- # box[0] is x midpoint, box[2] is width
249
- # box[1] is y midpoint, box[3] is height
250
-
251
- # Create a Rectangle patch
252
  for box in boxes:
253
  assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
254
  class_pred = box[0]
@@ -263,7 +177,6 @@ def plot_image(image, boxes):
263
  edgecolor=colors[int(class_pred)],
264
  facecolor="none",
265
  )
266
- # Add the patch to the Axes
267
  ax.add_patch(rect)
268
  plt.text(
269
  upper_left_x * width,
@@ -286,7 +199,6 @@ def get_evaluation_bboxes(
286
  box_format="midpoint",
287
  device="cuda",
288
  ):
289
- # make sure model is in eval before get bboxes
290
  model.eval()
291
  train_idx = 0
292
  all_pred_boxes = []
@@ -308,7 +220,6 @@ def get_evaluation_bboxes(
308
  for idx, (box) in enumerate(boxes_scale_i):
309
  bboxes[idx] += box
310
 
311
- # we just want one bbox for each label, not one for each scale
312
  true_bboxes = cells_to_bboxes(
313
  labels[2], anchor, S=S, is_preds=False
314
  )
@@ -335,19 +246,6 @@ def get_evaluation_bboxes(
335
 
336
 
337
  def cells_to_bboxes(predictions, anchors, S, is_preds=True):
338
- """
339
- Scales the predictions coming from the model to
340
- be relative to the entire image such that they for example later
341
- can be plotted or evaluated.
342
- INPUT:
343
- predictions: tensor of size (N, 3, S, S, num_classes+5)
344
- anchors: the anchors used for the predictions
345
- S: the number of cells the image is divided in on the width (and height)
346
- is_preds: whether the input is predictions or the true bounding boxes
347
- OUTPUT:
348
- converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
349
- object score, bounding box coordinates
350
- """
351
  BATCH_SIZE = predictions.shape[0]
352
  num_anchors = len(anchors)
353
  box_predictions = predictions[..., 1:5]
@@ -387,8 +285,8 @@ def check_class_accuracy(model, loader, threshold):
387
 
388
  for i in range(3):
389
  y[i] = y[i].to(config.DEVICE)
390
- obj = y[i][..., 0] == 1 # in paper this is Iobj_i
391
- noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
392
 
393
  correct_class += torch.sum(
394
  torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
@@ -408,7 +306,6 @@ def check_class_accuracy(model, loader, threshold):
408
 
409
 
410
  def get_mean_std(loader):
411
- # var[X] = E[X**2] - E[X]**2
412
  channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
413
 
414
  for data, _ in tqdm(loader):
@@ -436,9 +333,6 @@ def load_checkpoint(checkpoint_file, model, optimizer, lr):
436
  checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
437
  model.load_state_dict(checkpoint["state_dict"])
438
  optimizer.load_state_dict(checkpoint["optimizer"])
439
-
440
- # If we don't do this then it will just have learning rate of old checkpoint
441
- # and it will lead to many hours of debugging \:
442
  for param_group in optimizer.param_groups:
443
  param_group["lr"] = lr
444
 
 
12
 
13
 
14
  def iou_width_height(boxes1, boxes2):
 
 
 
 
 
 
 
15
  intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
16
  boxes1[..., 1], boxes2[..., 1]
17
  )
 
22
 
23
 
24
  def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  if box_format == "midpoint":
27
  box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
 
56
 
57
 
58
  def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  assert type(bboxes) == list
61
 
 
86
  def mean_average_precision(
87
  pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
88
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
 
 
 
90
  average_precisions = []
91
 
 
92
  epsilon = 1e-6
93
 
94
  for c in range(num_classes):
95
  detections = []
96
  ground_truths = []
97
 
 
 
 
98
  for detection in pred_boxes:
99
  if detection[1] == c:
100
  detections.append(detection)
 
102
  for true_box in true_boxes:
103
  if true_box[1] == c:
104
  ground_truths.append(true_box)
 
 
 
 
 
 
105
  amount_bboxes = Counter([gt[0] for gt in ground_truths])
 
 
 
 
106
  for key, val in amount_bboxes.items():
107
  amount_bboxes[key] = torch.zeros(val)
108
 
 
109
  detections.sort(key=lambda x: x[2], reverse=True)
110
  TP = torch.zeros((len(detections)))
111
  FP = torch.zeros((len(detections)))
112
  total_true_bboxes = len(ground_truths)
113
 
 
114
  if total_true_bboxes == 0:
115
  continue
116
 
117
  for detection_idx, detection in enumerate(detections):
 
 
118
  ground_truth_img = [
119
  bbox for bbox in ground_truths if bbox[0] == detection[0]
120
  ]
 
134
  best_gt_idx = idx
135
 
136
  if best_iou > iou_threshold:
 
137
  if amount_bboxes[detection[0]][best_gt_idx] == 0:
 
138
  TP[detection_idx] = 1
139
  amount_bboxes[detection[0]][best_gt_idx] = 1
140
  else:
141
  FP[detection_idx] = 1
142
 
 
143
  else:
144
  FP[detection_idx] = 1
145
 
 
156
 
157
 
158
  def plot_image(image, boxes):
 
159
  cmap = plt.get_cmap("tab20b")
160
  class_labels = config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
161
  colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
162
  im = np.array(image)
163
  height, width, _ = im.shape
 
 
164
  fig, ax = plt.subplots(1)
 
165
  ax.imshow(im)
 
 
 
 
 
166
  for box in boxes:
167
  assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
168
  class_pred = box[0]
 
177
  edgecolor=colors[int(class_pred)],
178
  facecolor="none",
179
  )
 
180
  ax.add_patch(rect)
181
  plt.text(
182
  upper_left_x * width,
 
199
  box_format="midpoint",
200
  device="cuda",
201
  ):
 
202
  model.eval()
203
  train_idx = 0
204
  all_pred_boxes = []
 
220
  for idx, (box) in enumerate(boxes_scale_i):
221
  bboxes[idx] += box
222
 
 
223
  true_bboxes = cells_to_bboxes(
224
  labels[2], anchor, S=S, is_preds=False
225
  )
 
246
 
247
 
248
  def cells_to_bboxes(predictions, anchors, S, is_preds=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  BATCH_SIZE = predictions.shape[0]
250
  num_anchors = len(anchors)
251
  box_predictions = predictions[..., 1:5]
 
285
 
286
  for i in range(3):
287
  y[i] = y[i].to(config.DEVICE)
288
+ obj = y[i][..., 0] == 1
289
+ noobj = y[i][..., 0] == 0
290
 
291
  correct_class += torch.sum(
292
  torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
 
306
 
307
 
308
  def get_mean_std(loader):
 
309
  channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
310
 
311
  for data, _ in tqdm(loader):
 
333
  checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
334
  model.load_state_dict(checkpoint["state_dict"])
335
  optimizer.load_state_dict(checkpoint["optimizer"])
 
 
 
336
  for param_group in optimizer.param_groups:
337
  param_group["lr"] = lr
338