Zhen Ye Claude Opus 4.6 (1M context) commited on
Commit
8d09cca
·
1 Parent(s): ae40f9a

fix: add mask-level NMS in GSAM2/YSAM2 to deduplicate overlapping masks

Browse files

Within a single keyframe, YOLO can detect the same object with
slightly different bounding boxes (e.g., cab vs full truck body)
that survive box-level NMS but produce overlapping SAM2 masks.
These all get unique IDs and render as stacked labels.

Added _mask_nms() to MaskDictionary.add_new_frame_annotation():
- Computes pairwise mask IoU for same-label detections
- Suppresses smaller masks when IoU > 0.5 with a larger one
- Runs before masks enter the SAM2 video predictor pipeline

Fixes duplicate "truck truck truck" labels on single objects.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. models/segmenters/grounded_sam2.py +52 -0
models/segmenters/grounded_sam2.py CHANGED
@@ -50,7 +50,16 @@ class MaskDictionary:
50
  mask_list: torch.Tensor,
51
  box_list: torch.Tensor,
52
  label_list: list,
 
53
  ):
 
 
 
 
 
 
 
 
54
  mask_img = torch.zeros(mask_list.shape[-2:])
55
  anno = {}
56
  for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
@@ -69,6 +78,49 @@ class MaskDictionary:
69
  self.mask_width = mask_img.shape[1]
70
  self.labels = anno
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def update_masks(
73
  self,
74
  tracking_dict: "MaskDictionary",
 
50
  mask_list: torch.Tensor,
51
  box_list: torch.Tensor,
52
  label_list: list,
53
+ mask_iou_threshold: float = 0.5,
54
  ):
55
+ # Deduplicate overlapping masks within the same keyframe.
56
+ # YOLO can detect the same object with slightly different boxes
57
+ # (e.g., cab vs full truck), producing multiple masks for one object.
58
+ keep = self._mask_nms(mask_list, box_list, label_list, mask_iou_threshold)
59
+ mask_list = mask_list[keep]
60
+ box_list = box_list[keep]
61
+ label_list = [label_list[i] for i in keep]
62
+
63
  mask_img = torch.zeros(mask_list.shape[-2:])
64
  anno = {}
65
  for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
 
78
  self.mask_width = mask_img.shape[1]
79
  self.labels = anno
80
 
81
+ @staticmethod
82
+ def _mask_nms(
83
+ masks: torch.Tensor,
84
+ boxes: torch.Tensor,
85
+ labels: list,
86
+ iou_threshold: float = 0.5,
87
+ ) -> list:
88
+ """Remove duplicate masks within a keyframe using mask IoU.
89
+
90
+ For each pair of masks with the same label, if their mask IoU
91
+ exceeds the threshold, keep the one with the larger area.
92
+ Returns indices to keep.
93
+ """
94
+ n = len(masks)
95
+ if n <= 1:
96
+ return list(range(n))
97
+
98
+ # Compute mask areas
99
+ areas = [int(masks[i].sum()) for i in range(n)]
100
+ suppressed = [False] * n
101
+
102
+ # Sort by area descending (keep larger masks)
103
+ order = sorted(range(n), key=lambda i: areas[i], reverse=True)
104
+
105
+ keep = []
106
+ for i in order:
107
+ if suppressed[i]:
108
+ continue
109
+ keep.append(i)
110
+ for j in order:
111
+ if j <= i or suppressed[j]:
112
+ continue
113
+ # Only suppress same-label masks
114
+ if labels[i] != labels[j]:
115
+ continue
116
+ # Compute mask IoU
117
+ inter = int((masks[i] & masks[j]).sum())
118
+ union = areas[i] + areas[j] - inter
119
+ if union > 0 and inter / union > iou_threshold:
120
+ suppressed[j] = True
121
+
122
+ return sorted(keep)
123
+
124
  def update_masks(
125
  self,
126
  tracking_dict: "MaskDictionary",