jatamura commited on
Commit
8f7b2c9
·
verified ·
1 Parent(s): 0e9299b

Update python_utils/get_model.py

Browse files
Files changed (1) hide show
  1. python_utils/get_model.py +22 -9
python_utils/get_model.py CHANGED
@@ -35,34 +35,47 @@ def load_model():
35
  return predictor
36
 
37
  def mask_nms(masks, scores, nms_threshold=0.5):
 
 
 
 
 
 
 
38
  import supervision as sv
 
39
 
 
 
 
 
 
 
 
40
  order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
41
- keep = []
42
  while order:
43
  i = order.pop(0)
44
- keep.append(i)
45
  for j in order:
46
-
47
- intersection = masks[i] * masks[j]
48
- union = masks[i] + masks[j]
49
- iou = intersection.sum() / union.sum()
50
 
51
  # Remove masks with IoU greater than the threshold
52
  if iou > nms_threshold:
53
  order.remove(j)
54
- return keep
55
 
56
  def apply_nms(prediction, mask=False, cls_agnostic_nms=0.8):
57
  from torchvision.ops import nms
58
  from detectron2.structures import Instances
59
 
60
  if mask:
61
- # print("Applying mask NMS")
62
  nms_indices = mask_nms(prediction["instances"].pred_masks.numpy(),
63
  prediction["instances"]._fields["scores"], cls_agnostic_nms)
64
  else:
65
- # print("Applying box NMS")
66
  nms_indices = nms(prediction["instances"].pred_boxes.tensor,
67
  prediction["instances"].scores, cls_agnostic_nms)
68
 
 
35
  return predictor
36
 
37
  def mask_nms(masks, scores, nms_threshold=0.5):
38
+ """
39
+ Runs class agnostic NMS on masks/segmentations instead of the bounding boxes.
40
+ :param masks: (list float) List of coordinates that make up the mask output from the model.
41
+ :param scores: (list float) List of corresponding confidence scores given to each mask.
42
+ :param nms_threshold: (float) Threshold to apply mask-based class agnostic NMS.
43
+ :return masks_kept (list float): List of masks kept after applying NMS.
44
+ """
45
  import supervision as sv
46
+ from shapely.geometry.polygon import Polygon
47
 
48
+ polygons = []
49
+ for mask in masks:
50
+ contour = sv.mask_to_polygons(mask)
51
+ if len(contour) > 0:
52
+ polygons.append(Polygon(contour[0]))
53
+ else:
54
+ polygons.append(Polygon([]))
55
  order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
56
+ masks_kept = []
57
  while order:
58
  i = order.pop(0)
59
+ masks_kept.append(i)
60
  for j in order:
61
+ # Calculate the IoU between the two polygons
62
+ intersection = polygons[i].intersection(polygons[j]).area
63
+ union = polygons[i].union(polygons[j]).area
64
+ iou = intersection / union
65
 
66
  # Remove masks with IoU greater than the threshold
67
  if iou > nms_threshold:
68
  order.remove(j)
69
+ return masks_kept
70
 
71
  def apply_nms(prediction, mask=False, cls_agnostic_nms=0.8):
72
  from torchvision.ops import nms
73
  from detectron2.structures import Instances
74
 
75
  if mask:
 
76
  nms_indices = mask_nms(prediction["instances"].pred_masks.numpy(),
77
  prediction["instances"]._fields["scores"], cls_agnostic_nms)
78
  else:
 
79
  nms_indices = nms(prediction["instances"].pred_boxes.tensor,
80
  prediction["instances"].scores, cls_agnostic_nms)
81