Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
968c327
1
Parent(s):
c90fe44
Fix cross-GPU mask device mismatch in GSAM2 reconciliation
Browse files
- inference.py +11 -2
- models/segmenters/grounded_sam2.py +12 -1
inference.py
CHANGED
|
@@ -1849,6 +1849,11 @@ def run_grounded_sam2_tracking(
|
|
| 1849 |
sam2_masks = MaskDictionary()
|
| 1850 |
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1851 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1852 |
for seg_idx in sorted(segment_data.keys()):
|
| 1853 |
start_idx, mask_dict, segment_results = segment_data[seg_idx]
|
| 1854 |
|
|
@@ -1862,7 +1867,7 @@ def run_grounded_sam2_tracking(
|
|
| 1862 |
{
|
| 1863 |
k: ObjectInfo(
|
| 1864 |
instance_id=v.instance_id,
|
| 1865 |
-
mask=v.mask,
|
| 1866 |
class_name=v.class_name,
|
| 1867 |
x1=v.x1, y1=v.y1,
|
| 1868 |
x2=v.x2, y2=v.y2,
|
|
@@ -1874,6 +1879,10 @@ def run_grounded_sam2_tracking(
|
|
| 1874 |
)
|
| 1875 |
continue
|
| 1876 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1877 |
# IoU match + get local→global remapping
|
| 1878 |
global_id_counter, remapping = (
|
| 1879 |
mask_dict.update_masks_with_remapping(
|
|
@@ -1899,7 +1908,7 @@ def run_grounded_sam2_tracking(
|
|
| 1899 |
continue
|
| 1900 |
remapped[global_id] = ObjectInfo(
|
| 1901 |
instance_id=global_id,
|
| 1902 |
-
mask=obj_info.mask,
|
| 1903 |
class_name=obj_info.class_name,
|
| 1904 |
x1=obj_info.x1, y1=obj_info.y1,
|
| 1905 |
x2=obj_info.x2, y2=obj_info.y2,
|
|
|
|
| 1849 |
sam2_masks = MaskDictionary()
|
| 1850 |
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1851 |
|
| 1852 |
+
def _mask_to_cpu(mask):
|
| 1853 |
+
if torch.is_tensor(mask):
|
| 1854 |
+
return mask.detach().cpu()
|
| 1855 |
+
return mask
|
| 1856 |
+
|
| 1857 |
for seg_idx in sorted(segment_data.keys()):
|
| 1858 |
start_idx, mask_dict, segment_results = segment_data[seg_idx]
|
| 1859 |
|
|
|
|
| 1867 |
{
|
| 1868 |
k: ObjectInfo(
|
| 1869 |
instance_id=v.instance_id,
|
| 1870 |
+
mask=_mask_to_cpu(v.mask),
|
| 1871 |
class_name=v.class_name,
|
| 1872 |
x1=v.x1, y1=v.y1,
|
| 1873 |
x2=v.x2, y2=v.y2,
|
|
|
|
| 1879 |
)
|
| 1880 |
continue
|
| 1881 |
|
| 1882 |
+
# Normalize keyframe masks to CPU before cross-GPU IoU matching.
|
| 1883 |
+
for info in mask_dict.labels.values():
|
| 1884 |
+
info.mask = _mask_to_cpu(info.mask)
|
| 1885 |
+
|
| 1886 |
# IoU match + get local→global remapping
|
| 1887 |
global_id_counter, remapping = (
|
| 1888 |
mask_dict.update_masks_with_remapping(
|
|
|
|
| 1908 |
continue
|
| 1909 |
remapped[global_id] = ObjectInfo(
|
| 1910 |
instance_id=global_id,
|
| 1911 |
+
mask=_mask_to_cpu(obj_info.mask),
|
| 1912 |
class_name=obj_info.class_name,
|
| 1913 |
x1=obj_info.x1, y1=obj_info.y1,
|
| 1914 |
x2=obj_info.x2, y2=obj_info.y2,
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -164,13 +164,24 @@ class MaskDictionary:
|
|
| 164 |
|
| 165 |
@staticmethod
|
| 166 |
def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
m1f = m1.to(torch.float32)
|
| 168 |
m2f = m2.to(torch.float32)
|
| 169 |
inter = (m1f * m2f).sum()
|
| 170 |
union = m1f.sum() + m2f.sum() - inter
|
| 171 |
if union == 0:
|
| 172 |
return 0.0
|
| 173 |
-
return float(inter / union)
|
| 174 |
|
| 175 |
|
| 176 |
# ---------------------------------------------------------------------------
|
|
|
|
| 164 |
|
| 165 |
@staticmethod
|
| 166 |
def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
|
| 167 |
+
if not torch.is_tensor(m1):
|
| 168 |
+
m1 = torch.as_tensor(m1)
|
| 169 |
+
if not torch.is_tensor(m2):
|
| 170 |
+
m2 = torch.as_tensor(m2)
|
| 171 |
+
|
| 172 |
+
# Multi-GPU reconciliation can compare masks produced on different
|
| 173 |
+
# devices; normalize both masks onto CPU before arithmetic.
|
| 174 |
+
if m1.device != m2.device:
|
| 175 |
+
m1 = m1.detach().to(device="cpu")
|
| 176 |
+
m2 = m2.detach().to(device="cpu")
|
| 177 |
+
|
| 178 |
m1f = m1.to(torch.float32)
|
| 179 |
m2f = m2.to(torch.float32)
|
| 180 |
inter = (m1f * m2f).sum()
|
| 181 |
union = m1f.sum() + m2f.sum() - inter
|
| 182 |
if union == 0:
|
| 183 |
return 0.0
|
| 184 |
+
return float((inter / union).item())
|
| 185 |
|
| 186 |
|
| 187 |
# ---------------------------------------------------------------------------
|