Spaces:
Sleeping
Sleeping
Zhen Ye Claude Opus 4.6 (1M context) commited on
Commit ·
8d09cca
1
Parent(s): ae40f9a
fix: add mask-level NMS in GSAM2/YSAM2 to deduplicate overlapping masks
Browse filesWithin a single keyframe, YOLO can detect the same object with
slightly different bounding boxes (e.g., cab vs full truck body)
that survive box-level NMS but produce overlapping SAM2 masks.
These all get unique IDs and render as stacked labels.
Added _mask_nms() to MaskDictionary.add_new_frame_annotation():
- Computes pairwise mask IoU for same-label detections
- Suppresses smaller masks when IoU > 0.5 with a larger one
- Runs before masks enter the SAM2 video predictor pipeline
Fixes duplicate "truck truck truck" labels on single objects.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -50,7 +50,16 @@ class MaskDictionary:
|
|
| 50 |
mask_list: torch.Tensor,
|
| 51 |
box_list: torch.Tensor,
|
| 52 |
label_list: list,
|
|
|
|
| 53 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
mask_img = torch.zeros(mask_list.shape[-2:])
|
| 55 |
anno = {}
|
| 56 |
for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
|
|
@@ -69,6 +78,49 @@ class MaskDictionary:
|
|
| 69 |
self.mask_width = mask_img.shape[1]
|
| 70 |
self.labels = anno
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def update_masks(
|
| 73 |
self,
|
| 74 |
tracking_dict: "MaskDictionary",
|
|
|
|
| 50 |
mask_list: torch.Tensor,
|
| 51 |
box_list: torch.Tensor,
|
| 52 |
label_list: list,
|
| 53 |
+
mask_iou_threshold: float = 0.5,
|
| 54 |
):
|
| 55 |
+
# Deduplicate overlapping masks within the same keyframe.
|
| 56 |
+
# YOLO can detect the same object with slightly different boxes
|
| 57 |
+
# (e.g., cab vs full truck), producing multiple masks for one object.
|
| 58 |
+
keep = self._mask_nms(mask_list, box_list, label_list, mask_iou_threshold)
|
| 59 |
+
mask_list = mask_list[keep]
|
| 60 |
+
box_list = box_list[keep]
|
| 61 |
+
label_list = [label_list[i] for i in keep]
|
| 62 |
+
|
| 63 |
mask_img = torch.zeros(mask_list.shape[-2:])
|
| 64 |
anno = {}
|
| 65 |
for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
|
|
|
|
| 78 |
self.mask_width = mask_img.shape[1]
|
| 79 |
self.labels = anno
|
| 80 |
|
| 81 |
+
@staticmethod
|
| 82 |
+
def _mask_nms(
|
| 83 |
+
masks: torch.Tensor,
|
| 84 |
+
boxes: torch.Tensor,
|
| 85 |
+
labels: list,
|
| 86 |
+
iou_threshold: float = 0.5,
|
| 87 |
+
) -> list:
|
| 88 |
+
"""Remove duplicate masks within a keyframe using mask IoU.
|
| 89 |
+
|
| 90 |
+
For each pair of masks with the same label, if their mask IoU
|
| 91 |
+
exceeds the threshold, keep the one with the larger area.
|
| 92 |
+
Returns indices to keep.
|
| 93 |
+
"""
|
| 94 |
+
n = len(masks)
|
| 95 |
+
if n <= 1:
|
| 96 |
+
return list(range(n))
|
| 97 |
+
|
| 98 |
+
# Compute mask areas
|
| 99 |
+
areas = [int(masks[i].sum()) for i in range(n)]
|
| 100 |
+
suppressed = [False] * n
|
| 101 |
+
|
| 102 |
+
# Sort by area descending (keep larger masks)
|
| 103 |
+
order = sorted(range(n), key=lambda i: areas[i], reverse=True)
|
| 104 |
+
|
| 105 |
+
keep = []
|
| 106 |
+
for i in order:
|
| 107 |
+
if suppressed[i]:
|
| 108 |
+
continue
|
| 109 |
+
keep.append(i)
|
| 110 |
+
for j in order:
|
| 111 |
+
if j <= i or suppressed[j]:
|
| 112 |
+
continue
|
| 113 |
+
# Only suppress same-label masks
|
| 114 |
+
if labels[i] != labels[j]:
|
| 115 |
+
continue
|
| 116 |
+
# Compute mask IoU
|
| 117 |
+
inter = int((masks[i] & masks[j]).sum())
|
| 118 |
+
union = areas[i] + areas[j] - inter
|
| 119 |
+
if union > 0 and inter / union > iou_threshold:
|
| 120 |
+
suppressed[j] = True
|
| 121 |
+
|
| 122 |
+
return sorted(keep)
|
| 123 |
+
|
| 124 |
def update_masks(
|
| 125 |
self,
|
| 126 |
tracking_dict: "MaskDictionary",
|