jatamura commited on
Commit
ab90404
·
verified ·
1 Parent(s): 692fad8

Added bbox and mask class agnostic nms code

Browse files
Files changed (1) hide show
  1. python_utils/get_model.py +51 -0
python_utils/get_model.py CHANGED
@@ -33,6 +33,57 @@ def load_model():
33
 
34
  return predictor
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if __name__ == '__main__':
37
  # get_set_up()
38
  load_model()
 
33
 
34
  return predictor
35
 
36
+ def mask_nms(masks, scores, nms_threshold=0.5):
37
+ import supervision as sv
38
+
39
+ polygons = []
40
+ for mask in masks:
41
+ contour = sv.mask_to_polygons(mask)
42
+ if len(contour) > 0:
43
+ polygons.append(Polygon(contour[0]))
44
+ else:
45
+ polygons.append(Polygon([]))
46
+ order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
47
+ keep = []
48
+ while order:
49
+ i = order.pop(0)
50
+ keep.append(i)
51
+ for j in order:
52
+ # Calculate the IoU between the two polygons
53
+ intersection = polygons[i].intersection(polygons[j]).area
54
+ union = polygons[i].union(polygons[j]).area
55
+ iou = intersection / union
56
+
57
+ # intersection = masks[i] * masks[j]
58
+ # union = masks[i] + masks[j]
59
+ # iou = intersection.sum() / union.sum()
60
+
61
+ # Remove masks with IoU greater than the threshold
62
+ if iou > nms_threshold:
63
+ order.remove(j)
64
+ return keep
65
+
66
+ def apply_nms(prediction, cls_agnostic_nms=0.5, mask=False):
67
+ from torchvision.ops import nms
68
+ from detectron2.structures import Instances
69
+
70
+ if mask:
71
+ # print("Applying mask NMS")
72
+ nms_indices = mask_nms(prediction["instances"].pred_masks.numpy(),
73
+ prediction["instances"]._fields["scores"], cls_agnostic_nms)
74
+ else:
75
+ # print("Applying box NMS")
76
+ nms_indices = nms(prediction["instances"].pred_boxes.tensor,
77
+ prediction["instances"].scores, cls_agnostic_nms)
78
+
79
+ pred = {"instances": Instances(image_size=prediction["instances"].image_size,
80
+ pred_boxes=prediction["instances"].pred_boxes[nms_indices],
81
+ scores=prediction["instances"].scores[nms_indices],
82
+ pred_classes=prediction["instances"].pred_classes[nms_indices]+1,
83
+ pred_masks=prediction["instances"].pred_masks[nms_indices])}
84
+
85
+ return pred
86
+
87
  if __name__ == '__main__':
88
  # get_set_up()
89
  load_model()