Zhen Ye committed on
Commit
c90fe44
·
1 Parent(s): 032b60f

Harden GSAM2 parallel pipeline and tracking reconciliation

Browse files
Files changed (2) hide show
  1. inference.py +418 -79
  2. models/segmenters/grounded_sam2.py +165 -2
inference.py CHANGED
@@ -1586,6 +1586,63 @@ def run_segmentation(
1586
 
1587
 
1588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1589
  def run_grounded_sam2_tracking(
1590
  input_video_path: str,
1591
  output_video_path: str,
@@ -1598,14 +1655,16 @@ def run_grounded_sam2_tracking(
1598
  ) -> str:
1599
  """Run Grounded-SAM-2 video tracking pipeline.
1600
 
1601
- Unlike per-frame segmentation, this extracts all frames to JPEG,
1602
- runs SAM2 video predictor for temporal mask propagation, then
1603
- renders the results back into a video.
1604
  """
 
1605
  import shutil
 
 
1606
 
1607
  from utils.video import extract_frames_to_jpeg_dir
1608
- from models.segmenters.model_loader import load_segmenter as _load_seg
1609
 
1610
  active_segmenter = segmenter_name or "gsam2_large"
1611
  logging.info(
@@ -1622,92 +1681,372 @@ def run_grounded_sam2_tracking(
1622
  total_frames = len(frame_names)
1623
  logging.info("Extracted %d frames to %s", total_frames, frame_dir)
1624
 
1625
- # 2. Load segmenter
1626
- segmenter = _load_seg(active_segmenter)
1627
 
1628
- # 3. Run tracking pipeline
1629
- _check_cancellation(job_id)
1630
- tracking_results = segmenter.process_video(frame_dir, frame_names, queries)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1631
 
1632
- # 4. Render results into output video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1633
  _check_cancellation(job_id)
1634
- import os as _os
1635
-
1636
- with StreamingVideoWriter(output_video_path, fps, width, height) as writer:
1637
- for frame_idx in range(total_frames):
1638
- _check_cancellation(job_id)
1639
-
1640
- # Read original frame
1641
- frame_path = _os.path.join(frame_dir, frame_names[frame_idx])
1642
- frame = cv2.imread(frame_path)
1643
- if frame is None:
1644
- logging.warning("Failed to read frame %d, writing blank", frame_idx)
1645
- frame = np.zeros((height, width, 3), dtype=np.uint8)
1646
-
1647
- frame_objects = tracking_results.get(frame_idx, {})
1648
-
1649
- if frame_objects:
1650
- # Collect masks, boxes, and labels for rendering
1651
- masks_list = []
1652
- boxes_list = []
1653
- label_list = []
1654
-
1655
- for obj_id, obj_info in frame_objects.items():
1656
- mask = obj_info.mask
1657
- if mask is not None:
1658
- if isinstance(mask, torch.Tensor):
1659
- mask_np = mask.cpu().numpy().astype(bool)
1660
- else:
1661
- mask_np = np.asarray(mask).astype(bool)
1662
- # Resize mask if needed
1663
- if mask_np.shape[:2] != (height, width):
1664
- mask_np = cv2.resize(
1665
- mask_np.astype(np.uint8),
1666
- (width, height),
1667
- interpolation=cv2.INTER_NEAREST,
1668
- ).astype(bool)
1669
- masks_list.append(mask_np)
1670
-
1671
- label = f"{obj_info.instance_id} {obj_info.class_name}"
1672
- label_list.append(label)
1673
-
1674
- has_box = not (obj_info.x1 == 0 and obj_info.y1 == 0 and obj_info.x2 == 0 and obj_info.y2 == 0)
1675
- if has_box:
1676
- boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
1677
-
1678
- # Draw masks
1679
- if masks_list:
1680
- masks_array = np.stack(masks_list)
1681
- frame = draw_masks(frame, masks_array, labels=label_list)
1682
-
1683
- # Draw boxes
1684
- if boxes_list:
1685
- boxes_array = np.array(boxes_list)
1686
- frame = draw_boxes(frame, boxes_array, label_names=label_list)
1687
-
1688
- writer.write(frame)
1689
-
1690
- # Stream frame if requested
1691
- if stream_queue:
1692
  try:
1693
- from jobs.streaming import publish_frame as _pub
1694
- if job_id:
1695
- _pub(job_id, frame)
1696
- else:
1697
- stream_queue.put(frame, timeout=0.01)
1698
- except Exception:
1699
  pass
1700
 
1701
- if frame_idx % 30 == 0:
1702
- logging.info(
1703
- "Rendered frame %d / %d", frame_idx, total_frames
1704
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1705
 
1706
  logging.info("Grounded-SAM-2 output written to: %s", output_video_path)
1707
  return output_video_path
1708
 
1709
  finally:
1710
- # Cleanup temp frame directory
1711
  try:
1712
  shutil.rmtree(frame_dir)
1713
  logging.info("Cleaned up temp frame dir: %s", frame_dir)
 
1586
 
1587
 
1588
 
1589
+ def _gsam2_render_frame(
1590
+ frame_dir: str,
1591
+ frame_names: List[str],
1592
+ frame_idx: int,
1593
+ frame_objects: Dict,
1594
+ height: int,
1595
+ width: int,
1596
+ ) -> np.ndarray:
1597
+ """Render a single GSAM2 tracking frame (masks + boxes). CPU-only."""
1598
+ from models.segmenters.grounded_sam2 import ObjectInfo
1599
+
1600
+ frame_path = os.path.join(frame_dir, frame_names[frame_idx])
1601
+ frame = cv2.imread(frame_path)
1602
+ if frame is None:
1603
+ return np.zeros((height, width, 3), dtype=np.uint8)
1604
+
1605
+ if not frame_objects:
1606
+ return frame
1607
+
1608
+ masks_list: List[np.ndarray] = []
1609
+ mask_labels: List[str] = []
1610
+ boxes_list: List[List[int]] = []
1611
+ box_labels: List[str] = []
1612
+
1613
+ for _obj_id, obj_info in frame_objects.items():
1614
+ mask = obj_info.mask
1615
+ label = f"{obj_info.instance_id} {obj_info.class_name}"
1616
+ if mask is not None:
1617
+ if isinstance(mask, torch.Tensor):
1618
+ mask_np = mask.cpu().numpy().astype(bool)
1619
+ else:
1620
+ mask_np = np.asarray(mask).astype(bool)
1621
+ if mask_np.shape[:2] != (height, width):
1622
+ mask_np = cv2.resize(
1623
+ mask_np.astype(np.uint8),
1624
+ (width, height),
1625
+ interpolation=cv2.INTER_NEAREST,
1626
+ ).astype(bool)
1627
+ masks_list.append(mask_np)
1628
+ mask_labels.append(label)
1629
+
1630
+ has_box = not (
1631
+ obj_info.x1 == 0 and obj_info.y1 == 0
1632
+ and obj_info.x2 == 0 and obj_info.y2 == 0
1633
+ )
1634
+ if has_box:
1635
+ boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
1636
+ box_labels.append(label)
1637
+
1638
+ if masks_list:
1639
+ frame = draw_masks(frame, np.stack(masks_list), labels=mask_labels)
1640
+ if boxes_list:
1641
+ frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)
1642
+
1643
+ return frame
1644
+
1645
+
1646
  def run_grounded_sam2_tracking(
1647
  input_video_path: str,
1648
  output_video_path: str,
 
1655
  ) -> str:
1656
  """Run Grounded-SAM-2 video tracking pipeline.
1657
 
1658
+ Uses multi-GPU data parallelism when multiple GPUs are available.
1659
+ Falls back to single-GPU ``process_video`` otherwise.
 
1660
  """
1661
+ import copy
1662
  import shutil
1663
+ from contextlib import nullcontext
1664
+ from PIL import Image as PILImage
1665
 
1666
  from utils.video import extract_frames_to_jpeg_dir
1667
+ from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo
1668
 
1669
  active_segmenter = segmenter_name or "gsam2_large"
1670
  logging.info(
 
1681
  total_frames = len(frame_names)
1682
  logging.info("Extracted %d frames to %s", total_frames, frame_dir)
1683
 
1684
+ num_gpus = torch.cuda.device_count()
 
1685
 
1686
+ # ==================================================================
1687
+ # Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
1688
+ # ==================================================================
1689
+ if num_gpus <= 1:
1690
+ # ---------- Single-GPU fallback ----------
1691
+ device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
1692
+ segmenter = load_segmenter_on_device(active_segmenter, device_str)
1693
+ _check_cancellation(job_id)
1694
+ tracking_results = segmenter.process_video(
1695
+ frame_dir, frame_names, queries,
1696
+ )
1697
+ logging.info(
1698
+ "Single-GPU tracking complete: %d frames",
1699
+ len(tracking_results),
1700
+ )
1701
+ else:
1702
+ # ---------- Multi-GPU pipeline ----------
1703
+ logging.info(
1704
+ "Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
1705
+ num_gpus, total_frames, step,
1706
+ )
1707
+
1708
+ # Phase 1: Load one segmenter per GPU (parallel)
1709
+ segmenters = []
1710
+ with ThreadPoolExecutor(max_workers=num_gpus) as pool:
1711
+ futs = [
1712
+ pool.submit(
1713
+ load_segmenter_on_device,
1714
+ active_segmenter,
1715
+ f"cuda:{i}",
1716
+ )
1717
+ for i in range(num_gpus)
1718
+ ]
1719
+ segmenters = [f.result() for f in futs]
1720
+ logging.info("Loaded %d segmenters", len(segmenters))
1721
+
1722
+ # Phase 2: Init SAM2 models/state per GPU (parallel)
1723
+ def _init_seg_state(seg):
1724
+ seg._ensure_models_loaded()
1725
+ return seg._video_predictor.init_state(
1726
+ video_path=frame_dir,
1727
+ offload_video_to_cpu=True,
1728
+ async_loading_frames=True,
1729
+ )
1730
+
1731
+ with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
1732
+ futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
1733
+ inference_states = [f.result() for f in futs]
1734
+
1735
+ # Phase 3: Parallel segment processing (queue-based workers)
1736
+ segments = list(range(0, total_frames, step))
1737
+ seg_queue_in: Queue = Queue()
1738
+ seg_queue_out: Queue = Queue()
1739
+
1740
+ for i, start_idx in enumerate(segments):
1741
+ seg_queue_in.put((i, start_idx))
1742
+ for _ in segmenters:
1743
+ seg_queue_in.put(None) # sentinel
1744
+
1745
+ iou_thresh = segmenters[0].iou_threshold
1746
+
1747
+ def _segment_worker(gpu_idx: int):
1748
+ seg = segmenters[gpu_idx]
1749
+ state = inference_states[gpu_idx]
1750
+ device_type = seg.device.split(":")[0]
1751
+ ac = (
1752
+ torch.autocast(device_type=device_type, dtype=torch.bfloat16)
1753
+ if device_type == "cuda"
1754
+ else nullcontext()
1755
+ )
1756
+ with ac:
1757
+ while True:
1758
+ if job_id:
1759
+ try:
1760
+ _check_cancellation(job_id)
1761
+ except RuntimeError as e:
1762
+ if "cancelled" in str(e).lower():
1763
+ logging.info(
1764
+ "Segment worker %d cancelled.",
1765
+ gpu_idx,
1766
+ )
1767
+ break
1768
+ raise
1769
+ item = seg_queue_in.get()
1770
+ if item is None:
1771
+ break
1772
+ seg_idx, start_idx = item
1773
+ try:
1774
+ logging.info(
1775
+ "GPU %d processing segment %d (frame %d)",
1776
+ gpu_idx, seg_idx, start_idx,
1777
+ )
1778
+ img_path = os.path.join(
1779
+ frame_dir, frame_names[start_idx]
1780
+ )
1781
+ with PILImage.open(img_path) as pil_img:
1782
+ image = pil_img.convert("RGB")
1783
+
1784
+ if job_id:
1785
+ _check_cancellation(job_id)
1786
+ masks, boxes, labels = seg.detect_keyframe(
1787
+ image, queries,
1788
+ )
1789
+
1790
+ if masks is None:
1791
+ seg_queue_out.put(
1792
+ (seg_idx, start_idx, None, {})
1793
+ )
1794
+ continue
1795
+
1796
+ mask_dict = MaskDictionary()
1797
+ mask_dict.add_new_frame_annotation(
1798
+ mask_list=torch.tensor(masks).to(seg.device),
1799
+ box_list=(
1800
+ boxes.clone()
1801
+ if torch.is_tensor(boxes)
1802
+ else torch.tensor(boxes)
1803
+ ),
1804
+ label_list=labels,
1805
+ )
1806
+
1807
+ segment_results = seg.propagate_segment(
1808
+ state, start_idx, mask_dict, step,
1809
+ )
1810
+ seg_queue_out.put(
1811
+ (seg_idx, start_idx, mask_dict, segment_results)
1812
+ )
1813
+ except RuntimeError as e:
1814
+ if "cancelled" in str(e).lower():
1815
+ logging.info(
1816
+ "Segment worker %d cancelled.",
1817
+ gpu_idx,
1818
+ )
1819
+ break
1820
+ raise
1821
+ except Exception:
1822
+ logging.exception(
1823
+ "Segment %d failed on GPU %d",
1824
+ seg_idx, gpu_idx,
1825
+ )
1826
+ seg_queue_out.put(
1827
+ (seg_idx, start_idx, None, {})
1828
+ )
1829
 
1830
+ seg_workers = []
1831
+ for i in range(num_gpus):
1832
+ t = Thread(
1833
+ target=_segment_worker, args=(i,), daemon=True,
1834
+ )
1835
+ t.start()
1836
+ seg_workers.append(t)
1837
+
1838
+ for t in seg_workers:
1839
+ t.join()
1840
+
1841
+ # Collect all segment outputs
1842
+ segment_data: Dict[int, Tuple] = {}
1843
+ while not seg_queue_out.empty():
1844
+ seg_idx, start_idx, mask_dict, results = seg_queue_out.get()
1845
+ segment_data[seg_idx] = (start_idx, mask_dict, results)
1846
+
1847
+ # Phase 4: Sequential ID reconciliation
1848
+ global_id_counter = 0
1849
+ sam2_masks = MaskDictionary()
1850
+ tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1851
+
1852
+ for seg_idx in sorted(segment_data.keys()):
1853
+ start_idx, mask_dict, segment_results = segment_data[seg_idx]
1854
+
1855
+ if mask_dict is None or not mask_dict.labels:
1856
+ # No detections — carry forward previous masks
1857
+ for fi in range(
1858
+ start_idx, min(start_idx + step, total_frames)
1859
+ ):
1860
+ if fi not in tracking_results:
1861
+ tracking_results[fi] = (
1862
+ {
1863
+ k: ObjectInfo(
1864
+ instance_id=v.instance_id,
1865
+ mask=v.mask,
1866
+ class_name=v.class_name,
1867
+ x1=v.x1, y1=v.y1,
1868
+ x2=v.x2, y2=v.y2,
1869
+ )
1870
+ for k, v in sam2_masks.labels.items()
1871
+ }
1872
+ if sam2_masks.labels
1873
+ else {}
1874
+ )
1875
+ continue
1876
+
1877
+ # IoU match + get local→global remapping
1878
+ global_id_counter, remapping = (
1879
+ mask_dict.update_masks_with_remapping(
1880
+ tracking_dict=sam2_masks,
1881
+ iou_threshold=iou_thresh,
1882
+ objects_count=global_id_counter,
1883
+ )
1884
+ )
1885
+
1886
+ if not mask_dict.labels:
1887
+ for fi in range(
1888
+ start_idx, min(start_idx + step, total_frames)
1889
+ ):
1890
+ tracking_results[fi] = {}
1891
+ continue
1892
+
1893
+ # Apply remapping to every frame in this segment
1894
+ for frame_idx, frame_objects in segment_results.items():
1895
+ remapped: Dict[int, ObjectInfo] = {}
1896
+ for local_id, obj_info in frame_objects.items():
1897
+ global_id = remapping.get(local_id)
1898
+ if global_id is None:
1899
+ continue
1900
+ remapped[global_id] = ObjectInfo(
1901
+ instance_id=global_id,
1902
+ mask=obj_info.mask,
1903
+ class_name=obj_info.class_name,
1904
+ x1=obj_info.x1, y1=obj_info.y1,
1905
+ x2=obj_info.x2, y2=obj_info.y2,
1906
+ )
1907
+ tracking_results[frame_idx] = remapped
1908
+
1909
+ # Update running tracker with last frame of this segment
1910
+ if segment_results:
1911
+ last_fi = max(segment_results.keys())
1912
+ last_objs = tracking_results.get(last_fi, {})
1913
+ sam2_masks = MaskDictionary()
1914
+ sam2_masks.labels = copy.deepcopy(last_objs)
1915
+ if last_objs:
1916
+ first_info = next(iter(last_objs.values()))
1917
+ if first_info.mask is not None:
1918
+ m = first_info.mask
1919
+ sam2_masks.mask_height = (
1920
+ m.shape[-2] if m.ndim >= 2 else 0
1921
+ )
1922
+ sam2_masks.mask_width = (
1923
+ m.shape[-1] if m.ndim >= 2 else 0
1924
+ )
1925
+
1926
+ logging.info(
1927
+ "Multi-GPU reconciliation complete: %d frames, %d objects",
1928
+ len(tracking_results), global_id_counter,
1929
+ )
1930
+
1931
+ # ==================================================================
1932
+ # Phase 5: Parallel rendering + sequential video writing
1933
+ # ==================================================================
1934
  _check_cancellation(job_id)
1935
+
1936
+ render_in: Queue = Queue(maxsize=32)
1937
+ render_out: Queue = Queue(maxsize=64)
1938
+ render_done = False
1939
+ num_render_workers = min(4, os.cpu_count() or 1)
1940
+
1941
+ def _render_worker():
1942
+ while True:
1943
+ item = render_in.get()
1944
+ if item is None:
1945
+ break
1946
+ fidx, fobjs = item
1947
+ try:
1948
+ frm = _gsam2_render_frame(
1949
+ frame_dir, frame_names, fidx, fobjs,
1950
+ height, width,
1951
+ )
1952
+ while True:
1953
+ try:
1954
+ render_out.put((fidx, frm), timeout=1.0)
1955
+ break
1956
+ except Full:
1957
+ if render_done:
1958
+ return
1959
+ except Exception:
1960
+ logging.exception("Render failed for frame %d", fidx)
1961
+ blank = np.zeros((height, width, 3), dtype=np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1962
  try:
1963
+ render_out.put((fidx, blank), timeout=5.0)
1964
+ except Full:
 
 
 
 
1965
  pass
1966
 
1967
+ r_workers = [
1968
+ Thread(target=_render_worker, daemon=True)
1969
+ for _ in range(num_render_workers)
1970
+ ]
1971
+ for t in r_workers:
1972
+ t.start()
1973
+
1974
+ def _writer_loop():
1975
+ nonlocal render_done
1976
+ next_idx = 0
1977
+ buf: Dict[int, np.ndarray] = {}
1978
+ try:
1979
+ with StreamingVideoWriter(
1980
+ output_video_path, fps, width, height
1981
+ ) as writer:
1982
+ while next_idx < total_frames:
1983
+ try:
1984
+ while next_idx not in buf:
1985
+ if len(buf) > 128:
1986
+ logging.warning(
1987
+ "Render reorder buffer large (%d), "
1988
+ "waiting for frame %d",
1989
+ len(buf), next_idx,
1990
+ )
1991
+ time.sleep(0.05)
1992
+ idx, frm = render_out.get(timeout=1.0)
1993
+ buf[idx] = frm
1994
+
1995
+ frm = buf.pop(next_idx)
1996
+ writer.write(frm)
1997
+
1998
+ if stream_queue:
1999
+ try:
2000
+ from jobs.streaming import (
2001
+ publish_frame as _pub,
2002
+ )
2003
+ if job_id:
2004
+ _pub(job_id, frm)
2005
+ else:
2006
+ stream_queue.put(frm, timeout=0.01)
2007
+ except Exception:
2008
+ pass
2009
+
2010
+ next_idx += 1
2011
+ if next_idx % 30 == 0:
2012
+ logging.info(
2013
+ "Rendered frame %d / %d",
2014
+ next_idx, total_frames,
2015
+ )
2016
+ except Empty:
2017
+ if job_id:
2018
+ _check_cancellation(job_id)
2019
+ if not any(t.is_alive() for t in r_workers) and render_out.empty():
2020
+ logging.error(
2021
+ "Render workers stopped while waiting "
2022
+ "for frame %d", next_idx,
2023
+ )
2024
+ break
2025
+ continue
2026
+ finally:
2027
+ render_done = True
2028
+
2029
+ writer_thread = Thread(target=_writer_loop, daemon=True)
2030
+ writer_thread.start()
2031
+
2032
+ # Feed render queue
2033
+ for fidx in range(total_frames):
2034
+ _check_cancellation(job_id)
2035
+ fobjs = tracking_results.get(fidx, {})
2036
+ render_in.put((fidx, fobjs))
2037
+
2038
+ # Sentinels for render workers
2039
+ for _ in r_workers:
2040
+ render_in.put(None)
2041
+
2042
+ for t in r_workers:
2043
+ t.join()
2044
+ writer_thread.join()
2045
 
2046
  logging.info("Grounded-SAM-2 output written to: %s", output_video_path)
2047
  return output_video_path
2048
 
2049
  finally:
 
2050
  try:
2051
  shutil.rmtree(frame_dir)
2052
  logging.info("Cleaned up temp frame dir: %s", frame_dir)
models/segmenters/grounded_sam2.py CHANGED
@@ -90,18 +90,24 @@ class MaskDictionary:
90
  ) -> int:
91
  """Match current detections against tracked objects via IoU."""
92
  updated = {}
 
93
  for _seg_id, seg_info in self.labels.items():
94
  if seg_info.mask is None or seg_info.mask.sum() == 0:
95
  continue
96
  matched_id = 0
 
97
  for _obj_id, obj_info in tracking_dict.labels.items():
 
 
98
  iou = self._iou(seg_info.mask, obj_info.mask)
99
- if iou > iou_threshold:
 
100
  matched_id = obj_info.instance_id
101
- break
102
  if not matched_id:
103
  objects_count += 1
104
  matched_id = objects_count
 
 
105
  new_info = ObjectInfo(
106
  instance_id=matched_id,
107
  mask=seg_info.mask,
@@ -111,6 +117,47 @@ class MaskDictionary:
111
  self.labels = updated
112
  return objects_count
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def get_target_class_name(self, instance_id: int) -> str:
115
  info = self.labels.get(instance_id)
116
  return info.class_name if info else ""
@@ -277,6 +324,122 @@ class GroundedSAM2Segmenter(Segmenter):
277
  boxes=det.boxes,
278
  )
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  # -- Video-level tracking interface -------------------------------------
281
 
282
  def process_video(
 
90
  ) -> int:
91
  """Match current detections against tracked objects via IoU."""
92
  updated = {}
93
+ used_tracked_ids = set()
94
  for _seg_id, seg_info in self.labels.items():
95
  if seg_info.mask is None or seg_info.mask.sum() == 0:
96
  continue
97
  matched_id = 0
98
+ best_iou = iou_threshold
99
  for _obj_id, obj_info in tracking_dict.labels.items():
100
+ if obj_info.instance_id in used_tracked_ids:
101
+ continue
102
  iou = self._iou(seg_info.mask, obj_info.mask)
103
+ if iou > best_iou:
104
+ best_iou = iou
105
  matched_id = obj_info.instance_id
 
106
  if not matched_id:
107
  objects_count += 1
108
  matched_id = objects_count
109
+ else:
110
+ used_tracked_ids.add(matched_id)
111
  new_info = ObjectInfo(
112
  instance_id=matched_id,
113
  mask=seg_info.mask,
 
117
  self.labels = updated
118
  return objects_count
119
 
120
+ def update_masks_with_remapping(
121
+ self,
122
+ tracking_dict: "MaskDictionary",
123
+ iou_threshold: float = 0.5,
124
+ objects_count: int = 0,
125
+ ) -> Tuple[int, Dict[int, int]]:
126
+ """Match detections against tracked objects, returning ID remapping.
127
+
128
+ Same logic as ``update_masks`` but additionally returns a dict
129
+ mapping original (local) IDs to the assigned (global) IDs.
130
+ """
131
+ updated = {}
132
+ remapping: Dict[int, int] = {}
133
+ used_tracked_ids = set()
134
+ for seg_id, seg_info in self.labels.items():
135
+ if seg_info.mask is None or seg_info.mask.sum() == 0:
136
+ continue
137
+ matched_id = 0
138
+ best_iou = iou_threshold
139
+ for _obj_id, obj_info in tracking_dict.labels.items():
140
+ if obj_info.instance_id in used_tracked_ids:
141
+ continue
142
+ iou = self._iou(seg_info.mask, obj_info.mask)
143
+ if iou > best_iou:
144
+ best_iou = iou
145
+ matched_id = obj_info.instance_id
146
+ if not matched_id:
147
+ objects_count += 1
148
+ matched_id = objects_count
149
+ else:
150
+ used_tracked_ids.add(matched_id)
151
+ new_info = ObjectInfo(
152
+ instance_id=matched_id,
153
+ mask=seg_info.mask,
154
+ class_name=seg_info.class_name,
155
+ )
156
+ updated[matched_id] = new_info
157
+ remapping[seg_id] = matched_id
158
+ self.labels = updated
159
+ return objects_count, remapping
160
+
161
  def get_target_class_name(self, instance_id: int) -> str:
162
  info = self.labels.get(instance_id)
163
  return info.class_name if info else ""
 
324
  boxes=det.boxes,
325
  )
326
 
327
+ # -- Multi-GPU helper methods -------------------------------------------
328
+
329
+ def detect_keyframe(
330
+ self,
331
+ image: "Image",
332
+ text_prompts: List[str],
333
+ ) -> Tuple[Optional[np.ndarray], Optional[torch.Tensor], List[str]]:
334
+ """Run GDINO + SAM2 image predictor on a single keyframe.
335
+
336
+ Args:
337
+ image: PIL Image in RGB mode.
338
+ text_prompts: Text queries for Grounding DINO.
339
+
340
+ Returns:
341
+ ``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
342
+ numpy array, *boxes* is an ``(N, 4)`` tensor on device, and
343
+ *labels* is a list of strings. Returns ``(None, None, [])``
344
+ when no objects are detected.
345
+ """
346
+ self._ensure_models_loaded()
347
+
348
+ prompt = self._gdino_detector._build_prompt(text_prompts)
349
+ gdino_processor = self._gdino_detector.processor
350
+ gdino_model = self._gdino_detector.model
351
+
352
+ inputs = gdino_processor(
353
+ images=image, text=prompt, return_tensors="pt"
354
+ )
355
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
356
+
357
+ with torch.no_grad():
358
+ outputs = gdino_model(**inputs)
359
+
360
+ results = self._gdino_detector._post_process(
361
+ outputs,
362
+ inputs["input_ids"],
363
+ target_sizes=[image.size[::-1]],
364
+ )
365
+
366
+ input_boxes = results[0]["boxes"]
367
+ det_labels = results[0].get("text_labels") or results[0].get("labels", [])
368
+ if torch.is_tensor(det_labels):
369
+ det_labels = det_labels.detach().cpu().tolist()
370
+ det_labels = [str(l) for l in det_labels]
371
+
372
+ if input_boxes.shape[0] == 0:
373
+ return None, None, []
374
+
375
+ # SAM2 image predictor
376
+ self._image_predictor.set_image(np.array(image))
377
+ masks, scores, logits = self._image_predictor.predict(
378
+ point_coords=None,
379
+ point_labels=None,
380
+ box=input_boxes,
381
+ multimask_output=False,
382
+ )
383
+
384
+ if masks.ndim == 2:
385
+ masks = masks[None]
386
+ elif masks.ndim == 4:
387
+ masks = masks.squeeze(1)
388
+
389
+ return masks, input_boxes, det_labels
390
+
391
+ def propagate_segment(
392
+ self,
393
+ inference_state: Any,
394
+ start_idx: int,
395
+ mask_dict: "MaskDictionary",
396
+ step: int,
397
+ ) -> Dict[int, Dict[int, "ObjectInfo"]]:
398
+ """Propagate masks for a single segment via SAM2 video predictor.
399
+
400
+ Calls ``reset_state`` first, making this safe to call independently
401
+ (and therefore parallelisable across GPUs).
402
+
403
+ Args:
404
+ inference_state: SAM2 video predictor state (from ``init_state``).
405
+ start_idx: Starting frame index for this segment.
406
+ mask_dict: MaskDictionary with object masks for the keyframe.
407
+ step: Maximum number of frames to propagate.
408
+
409
+ Returns:
410
+ Dict mapping ``frame_idx`` → ``{obj_id: ObjectInfo}`` using the
411
+ IDs from *mask_dict* (local, not yet reconciled).
412
+ """
413
+ self._video_predictor.reset_state(inference_state)
414
+
415
+ for obj_id, obj_info in mask_dict.labels.items():
416
+ self._video_predictor.add_new_mask(
417
+ inference_state,
418
+ start_idx,
419
+ obj_id,
420
+ obj_info.mask,
421
+ )
422
+
423
+ segment_results: Dict[int, Dict[int, ObjectInfo]] = {}
424
+ for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
425
+ inference_state,
426
+ max_frame_num_to_track=step,
427
+ start_frame_idx=start_idx,
428
+ ):
429
+ frame_objects: Dict[int, ObjectInfo] = {}
430
+ for i, out_obj_id in enumerate(out_obj_ids):
431
+ out_mask = (out_mask_logits[i] > 0.0)
432
+ info = ObjectInfo(
433
+ instance_id=out_obj_id,
434
+ mask=out_mask[0],
435
+ class_name=mask_dict.get_target_class_name(out_obj_id),
436
+ )
437
+ info.update_box()
438
+ frame_objects[out_obj_id] = info
439
+ segment_results[out_frame_idx] = frame_objects
440
+
441
+ return segment_results
442
+
443
  # -- Video-level tracking interface -------------------------------------
444
 
445
  def process_video(