Zhen Ye commited on
Commit
968c327
·
1 Parent(s): c90fe44

Fix cross-GPU mask device mismatch in GSAM2 reconciliation

Browse files
Files changed (2) hide show
  1. inference.py +11 -2
  2. models/segmenters/grounded_sam2.py +12 -1
inference.py CHANGED
@@ -1849,6 +1849,11 @@ def run_grounded_sam2_tracking(
1849
  sam2_masks = MaskDictionary()
1850
  tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1851
 
 
 
 
 
 
1852
  for seg_idx in sorted(segment_data.keys()):
1853
  start_idx, mask_dict, segment_results = segment_data[seg_idx]
1854
 
@@ -1862,7 +1867,7 @@ def run_grounded_sam2_tracking(
1862
  {
1863
  k: ObjectInfo(
1864
  instance_id=v.instance_id,
1865
- mask=v.mask,
1866
  class_name=v.class_name,
1867
  x1=v.x1, y1=v.y1,
1868
  x2=v.x2, y2=v.y2,
@@ -1874,6 +1879,10 @@ def run_grounded_sam2_tracking(
1874
  )
1875
  continue
1876
 
 
 
 
 
1877
  # IoU match + get local→global remapping
1878
  global_id_counter, remapping = (
1879
  mask_dict.update_masks_with_remapping(
@@ -1899,7 +1908,7 @@ def run_grounded_sam2_tracking(
1899
  continue
1900
  remapped[global_id] = ObjectInfo(
1901
  instance_id=global_id,
1902
- mask=obj_info.mask,
1903
  class_name=obj_info.class_name,
1904
  x1=obj_info.x1, y1=obj_info.y1,
1905
  x2=obj_info.x2, y2=obj_info.y2,
 
1849
  sam2_masks = MaskDictionary()
1850
  tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1851
 
1852
+ def _mask_to_cpu(mask):
1853
+ if torch.is_tensor(mask):
1854
+ return mask.detach().cpu()
1855
+ return mask
1856
+
1857
  for seg_idx in sorted(segment_data.keys()):
1858
  start_idx, mask_dict, segment_results = segment_data[seg_idx]
1859
 
 
1867
  {
1868
  k: ObjectInfo(
1869
  instance_id=v.instance_id,
1870
+ mask=_mask_to_cpu(v.mask),
1871
  class_name=v.class_name,
1872
  x1=v.x1, y1=v.y1,
1873
  x2=v.x2, y2=v.y2,
 
1879
  )
1880
  continue
1881
 
1882
+ # Normalize keyframe masks to CPU before cross-GPU IoU matching.
1883
+ for info in mask_dict.labels.values():
1884
+ info.mask = _mask_to_cpu(info.mask)
1885
+
1886
  # IoU match + get local→global remapping
1887
  global_id_counter, remapping = (
1888
  mask_dict.update_masks_with_remapping(
 
1908
  continue
1909
  remapped[global_id] = ObjectInfo(
1910
  instance_id=global_id,
1911
+ mask=_mask_to_cpu(obj_info.mask),
1912
  class_name=obj_info.class_name,
1913
  x1=obj_info.x1, y1=obj_info.y1,
1914
  x2=obj_info.x2, y2=obj_info.y2,
models/segmenters/grounded_sam2.py CHANGED
@@ -164,13 +164,24 @@ class MaskDictionary:
164
 
165
  @staticmethod
166
  def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
 
 
 
 
 
 
 
 
 
 
 
167
  m1f = m1.to(torch.float32)
168
  m2f = m2.to(torch.float32)
169
  inter = (m1f * m2f).sum()
170
  union = m1f.sum() + m2f.sum() - inter
171
  if union == 0:
172
  return 0.0
173
- return float(inter / union)
174
 
175
 
176
  # ---------------------------------------------------------------------------
 
164
 
165
  @staticmethod
166
  def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
167
+ if not torch.is_tensor(m1):
168
+ m1 = torch.as_tensor(m1)
169
+ if not torch.is_tensor(m2):
170
+ m2 = torch.as_tensor(m2)
171
+
172
+ # Multi-GPU reconciliation can compare masks produced on different
173
+ # devices; normalize both masks onto CPU before arithmetic.
174
+ if m1.device != m2.device:
175
+ m1 = m1.detach().to(device="cpu")
176
+ m2 = m2.detach().to(device="cpu")
177
+
178
  m1f = m1.to(torch.float32)
179
  m2f = m2.to(torch.float32)
180
  inter = (m1f * m2f).sum()
181
  union = m1f.sum() + m2f.sum() - inter
182
  if union == 0:
183
  return 0.0
184
+ return float((inter / union).item())
185
 
186
 
187
  # ---------------------------------------------------------------------------