Zhen Ye Claude Opus 4.6 committed on
Commit
7124ca1
·
1 Parent(s): 882ee33

perf: pipeline GSAM2 tracking + rendering with startup buffer

Browse files

Pipeline tracking and rendering so segments stream as they're tracked
instead of waiting for all tracking to complete before rendering begins.

- Add on_segment callback to process_video() for incremental feeding
- Hoist render/writer infrastructure above tracking branch
- Single-GPU: callback-based incremental feeding via on_segment
- Multi-GPU: streaming reconciliation with segment_buffer pattern
- Writer startup buffer (60 frames) before streaming begins
- 3x frame duplication for 18 FPS stream (6 FPS processing throughput)
- Safety threshold: pause streaming when buffer < 20 frames
- try/finally sentinel safety for render worker shutdown
- Increase render_out queue from 64 to 128 for buffer headroom

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. inference.py +385 -311
  2. models/segmenters/grounded_sam2.py +14 -1
inference.py CHANGED
@@ -1195,9 +1195,6 @@ def run_inference(
1195
  display_labels.append("")
1196
  continue
1197
  lbl = d.get('label', 'obj')
1198
- # Append Track ID
1199
- if 'track_id' in d:
1200
- lbl = f"{d['track_id']} {lbl}"
1201
  display_labels.append(lbl)
1202
 
1203
  p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels)
@@ -1675,303 +1672,15 @@ def run_grounded_sam2_tracking(
1675
 
1676
  num_gpus = torch.cuda.device_count()
1677
 
1678
- # ==================================================================
1679
- # Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
1680
- # ==================================================================
1681
- if num_gpus <= 1:
1682
- # ---------- Single-GPU fallback ----------
1683
- device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
1684
- _seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
1685
- segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
1686
- _check_cancellation(job_id)
1687
-
1688
- if _perf_metrics is not None:
1689
- segmenter._perf_metrics = _perf_metrics
1690
- segmenter._perf_lock = None
1691
-
1692
- if _perf_metrics is not None:
1693
- _t_track = time.perf_counter()
1694
-
1695
- tracking_results = segmenter.process_video(
1696
- frame_dir, frame_names, queries,
1697
- )
1698
-
1699
- if _perf_metrics is not None:
1700
- _perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
1701
-
1702
- logging.info(
1703
- "Single-GPU tracking complete: %d frames",
1704
- len(tracking_results),
1705
- )
1706
- else:
1707
- # ---------- Multi-GPU pipeline ----------
1708
- logging.info(
1709
- "Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
1710
- num_gpus, total_frames, step,
1711
- )
1712
-
1713
- # Phase 1: Load one segmenter per GPU (parallel)
1714
- segmenters = []
1715
- with ThreadPoolExecutor(max_workers=num_gpus) as pool:
1716
- _seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
1717
- futs = [
1718
- pool.submit(
1719
- load_segmenter_on_device,
1720
- active_segmenter,
1721
- f"cuda:{i}",
1722
- **_seg_kw_multi,
1723
- )
1724
- for i in range(num_gpus)
1725
- ]
1726
- segmenters = [f.result() for f in futs]
1727
- logging.info("Loaded %d segmenters", len(segmenters))
1728
-
1729
- if _perf_metrics is not None:
1730
- import threading as _th
1731
- _actual_lock = _perf_lock or _th.Lock()
1732
- for seg in segmenters:
1733
- seg._perf_metrics = _perf_metrics
1734
- seg._perf_lock = _actual_lock
1735
-
1736
- # Phase 2: Init SAM2 models/state per GPU (parallel)
1737
- def _init_seg_state(seg):
1738
- seg._ensure_models_loaded()
1739
- return seg._video_predictor.init_state(
1740
- video_path=frame_dir,
1741
- offload_video_to_cpu=True,
1742
- async_loading_frames=True,
1743
- )
1744
-
1745
- with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
1746
- futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
1747
- inference_states = [f.result() for f in futs]
1748
-
1749
- if _perf_metrics is not None:
1750
- _t_track = time.perf_counter()
1751
-
1752
- # Phase 3: Parallel segment processing (queue-based workers)
1753
- segments = list(range(0, total_frames, step))
1754
- seg_queue_in: Queue = Queue()
1755
- seg_queue_out: Queue = Queue()
1756
-
1757
- for i, start_idx in enumerate(segments):
1758
- seg_queue_in.put((i, start_idx))
1759
- for _ in segmenters:
1760
- seg_queue_in.put(None) # sentinel
1761
-
1762
- iou_thresh = segmenters[0].iou_threshold
1763
-
1764
- def _segment_worker(gpu_idx: int):
1765
- seg = segmenters[gpu_idx]
1766
- state = inference_states[gpu_idx]
1767
- device_type = seg.device.split(":")[0]
1768
- ac = (
1769
- torch.autocast(device_type=device_type, dtype=torch.bfloat16)
1770
- if device_type == "cuda"
1771
- else nullcontext()
1772
- )
1773
- with ac:
1774
- while True:
1775
- if job_id:
1776
- try:
1777
- _check_cancellation(job_id)
1778
- except RuntimeError as e:
1779
- if "cancelled" in str(e).lower():
1780
- logging.info(
1781
- "Segment worker %d cancelled.",
1782
- gpu_idx,
1783
- )
1784
- break
1785
- raise
1786
- item = seg_queue_in.get()
1787
- if item is None:
1788
- break
1789
- seg_idx, start_idx = item
1790
- try:
1791
- logging.info(
1792
- "GPU %d processing segment %d (frame %d)",
1793
- gpu_idx, seg_idx, start_idx,
1794
- )
1795
- img_path = os.path.join(
1796
- frame_dir, frame_names[start_idx]
1797
- )
1798
- with PILImage.open(img_path) as pil_img:
1799
- image = pil_img.convert("RGB")
1800
-
1801
- if job_id:
1802
- _check_cancellation(job_id)
1803
- masks, boxes, labels = seg.detect_keyframe(
1804
- image, queries,
1805
- )
1806
-
1807
- if masks is None:
1808
- seg_queue_out.put(
1809
- (seg_idx, start_idx, None, {})
1810
- )
1811
- continue
1812
-
1813
- mask_dict = MaskDictionary()
1814
- mask_dict.add_new_frame_annotation(
1815
- mask_list=torch.tensor(masks).to(seg.device),
1816
- box_list=(
1817
- boxes.clone()
1818
- if torch.is_tensor(boxes)
1819
- else torch.tensor(boxes)
1820
- ),
1821
- label_list=labels,
1822
- )
1823
-
1824
- segment_output = seg.propagate_segment(
1825
- state, start_idx, mask_dict, step,
1826
- )
1827
- seg_queue_out.put(
1828
- (seg_idx, start_idx, mask_dict, segment_output)
1829
- )
1830
- except RuntimeError as e:
1831
- if "cancelled" in str(e).lower():
1832
- logging.info(
1833
- "Segment worker %d cancelled.",
1834
- gpu_idx,
1835
- )
1836
- break
1837
- raise
1838
- except Exception:
1839
- logging.exception(
1840
- "Segment %d failed on GPU %d",
1841
- seg_idx, gpu_idx,
1842
- )
1843
- seg_queue_out.put(
1844
- (seg_idx, start_idx, None, {})
1845
- )
1846
-
1847
- seg_workers = []
1848
- for i in range(num_gpus):
1849
- t = Thread(
1850
- target=_segment_worker, args=(i,), daemon=True,
1851
- )
1852
- t.start()
1853
- seg_workers.append(t)
1854
-
1855
- for t in seg_workers:
1856
- t.join()
1857
-
1858
- # Collect all segment outputs
1859
- segment_data: Dict[int, Tuple] = {}
1860
- while not seg_queue_out.empty():
1861
- seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get()
1862
- segment_data[seg_idx] = (start_idx, mask_dict, segment_output)
1863
-
1864
- # Phase 4: Sequential ID reconciliation
1865
- if _perf_metrics is not None:
1866
- _t_recon = time.perf_counter()
1867
-
1868
- global_id_counter = 0
1869
- sam2_masks = MaskDictionary()
1870
- tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1871
-
1872
- def _mask_to_cpu(mask):
1873
- """Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
1874
- if torch.is_tensor(mask):
1875
- return mask.detach().cpu()
1876
- return mask
1877
-
1878
- for seg_idx in sorted(segment_data.keys()):
1879
- start_idx, mask_dict, segment_output = segment_data[seg_idx]
1880
-
1881
- if mask_dict is None or not mask_dict.labels:
1882
- # No detections — carry forward previous masks
1883
- for fi in range(
1884
- start_idx, min(start_idx + step, total_frames)
1885
- ):
1886
- if fi not in tracking_results:
1887
- tracking_results[fi] = (
1888
- {
1889
- k: ObjectInfo(
1890
- instance_id=v.instance_id,
1891
- mask=v.mask,
1892
- class_name=v.class_name,
1893
- x1=v.x1, y1=v.y1,
1894
- x2=v.x2, y2=v.y2,
1895
- )
1896
- for k, v in sam2_masks.labels.items()
1897
- }
1898
- if sam2_masks.labels
1899
- else {}
1900
- )
1901
- continue
1902
-
1903
- # Normalize keyframe masks to CPU before cross-GPU IoU matching.
1904
- for info in mask_dict.labels.values():
1905
- info.mask = _mask_to_cpu(info.mask)
1906
-
1907
- # IoU match + get local→global remapping
1908
- global_id_counter, remapping = (
1909
- mask_dict.update_masks_with_remapping(
1910
- tracking_dict=sam2_masks,
1911
- iou_threshold=iou_thresh,
1912
- objects_count=global_id_counter,
1913
- )
1914
- )
1915
-
1916
- if not mask_dict.labels:
1917
- for fi in range(
1918
- start_idx, min(start_idx + step, total_frames)
1919
- ):
1920
- tracking_results[fi] = {}
1921
- continue
1922
-
1923
- # Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
1924
- segment_results = segment_output.to_object_dicts()
1925
-
1926
- # Apply remapping to every frame in this segment
1927
- for frame_idx, frame_objects in segment_results.items():
1928
- remapped: Dict[int, ObjectInfo] = {}
1929
- for local_id, obj_info in frame_objects.items():
1930
- global_id = remapping.get(local_id)
1931
- if global_id is None:
1932
- continue
1933
- remapped[global_id] = ObjectInfo(
1934
- instance_id=global_id,
1935
- mask=obj_info.mask,
1936
- class_name=obj_info.class_name,
1937
- x1=obj_info.x1, y1=obj_info.y1,
1938
- x2=obj_info.x2, y2=obj_info.y2,
1939
- )
1940
- tracking_results[frame_idx] = remapped
1941
-
1942
- # Update running tracker with last frame of this segment
1943
- if segment_results:
1944
- last_fi = max(segment_results.keys())
1945
- last_objs = tracking_results.get(last_fi, {})
1946
- sam2_masks = MaskDictionary()
1947
- sam2_masks.labels = copy.deepcopy(last_objs)
1948
- if last_objs:
1949
- first_info = next(iter(last_objs.values()))
1950
- if first_info.mask is not None:
1951
- m = first_info.mask
1952
- sam2_masks.mask_height = (
1953
- m.shape[-2] if m.ndim >= 2 else 0
1954
- )
1955
- sam2_masks.mask_width = (
1956
- m.shape[-1] if m.ndim >= 2 else 0
1957
- )
1958
-
1959
- if _perf_metrics is not None:
1960
- _perf_metrics["id_reconciliation_ms"] = (time.perf_counter() - _t_recon) * 1000.0
1961
- _perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
1962
-
1963
- logging.info(
1964
- "Multi-GPU reconciliation complete: %d frames, %d objects",
1965
- len(tracking_results), global_id_counter,
1966
- )
1967
-
1968
  # ==================================================================
1969
  # Phase 5: Parallel rendering + sequential video writing
 
 
1970
  # ==================================================================
1971
  _check_cancellation(job_id)
1972
 
1973
  render_in: Queue = Queue(maxsize=32)
1974
- render_out: Queue = Queue(maxsize=64)
1975
  render_done = False
1976
  num_render_workers = min(4, os.cpu_count() or 1)
1977
 
@@ -2112,6 +1821,11 @@ def run_grounded_sam2_tracking(
2112
  next_idx = 0
2113
  buf: Dict[int, Tuple] = {}
2114
 
 
 
 
 
 
2115
  # Per-track bbox history (replaces ByteTracker for GSAM2)
2116
  track_history: Dict[int, List] = {}
2117
  speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
@@ -2127,6 +1841,28 @@ def run_grounded_sam2_tracking(
2127
  with StreamingVideoWriter(
2128
  output_video_path, fps, width, height
2129
  ) as writer:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2130
  while next_idx < total_frames:
2131
  try:
2132
  while next_idx not in buf:
@@ -2239,22 +1975,40 @@ def run_grounded_sam2_tracking(
2239
  if _perf_metrics is not None:
2240
  _t_w = time.perf_counter()
2241
 
 
2242
  writer.write(frm)
2243
 
2244
  if _perf_metrics is not None:
2245
  _perf_metrics["writer_total_ms"] += (time.perf_counter() - _t_w) * 1000.0
2246
 
2247
- if stream_queue:
2248
- try:
2249
- from jobs.streaming import (
2250
- publish_frame as _pub,
2251
- )
 
 
 
 
 
 
 
 
 
2252
  if job_id:
2253
- _pub(job_id, frm)
 
2254
  else:
2255
- stream_queue.put(frm, timeout=0.01)
2256
- except Exception:
2257
- pass
 
 
 
 
 
 
 
2258
 
2259
  next_idx += 1
2260
  if next_idx % 30 == 0:
@@ -2285,15 +2039,335 @@ def run_grounded_sam2_tracking(
2285
  writer_thread = Thread(target=_writer_loop, daemon=True)
2286
  writer_thread.start()
2287
 
2288
- # Feed render queue
2289
- for fidx in range(total_frames):
2290
- _check_cancellation(job_id)
2291
- fobjs = tracking_results.get(fidx, {})
2292
- render_in.put((fidx, fobjs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2293
 
2294
- # Sentinels for render workers
2295
- for _ in r_workers:
2296
- render_in.put(None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2297
 
2298
  for t in r_workers:
2299
  t.join()
 
1195
  display_labels.append("")
1196
  continue
1197
  lbl = d.get('label', 'obj')
 
 
 
1198
  display_labels.append(lbl)
1199
 
1200
  p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels)
 
1672
 
1673
  num_gpus = torch.cuda.device_count()
1674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1675
  # ==================================================================
1676
  # Phase 5: Parallel rendering + sequential video writing
1677
+ # (Hoisted above tracking so render pipeline starts before tracking
1678
+ # completes — segments are fed incrementally via callback / queue.)
1679
  # ==================================================================
1680
  _check_cancellation(job_id)
1681
 
1682
  render_in: Queue = Queue(maxsize=32)
1683
+ render_out: Queue = Queue(maxsize=128)
1684
  render_done = False
1685
  num_render_workers = min(4, os.cpu_count() or 1)
1686
 
 
1821
  next_idx = 0
1822
  buf: Dict[int, Tuple] = {}
1823
 
1824
+ # Streaming constants
1825
+ STARTUP_BUFFER = 60
1826
+ SAFETY_THRESHOLD = 20
1827
+ FRAME_DUP = 3
1828
+
1829
  # Per-track bbox history (replaces ByteTracker for GSAM2)
1830
  track_history: Dict[int, List] = {}
1831
  speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
 
1841
  with StreamingVideoWriter(
1842
  output_video_path, fps, width, height
1843
  ) as writer:
1844
+ # --- Phase 1: Startup buffering ---
1845
+ playback_started = False
1846
+ while not playback_started:
1847
+ try:
1848
+ idx, frm, fobjs = render_out.get(timeout=1.0)
1849
+ buf[idx] = (frm, fobjs)
1850
+ except Empty:
1851
+ if not any(t.is_alive() for t in r_workers) and render_out.empty():
1852
+ playback_started = True
1853
+ break
1854
+ continue
1855
+
1856
+ ahead = sum(1 for k in buf if k >= next_idx)
1857
+ if ahead >= STARTUP_BUFFER or ahead >= total_frames:
1858
+ playback_started = True
1859
+
1860
+ logging.info(
1861
+ "Startup buffer filled (%d frames), beginning playback",
1862
+ len(buf),
1863
+ )
1864
+
1865
+ # --- Phase 2: Write + stream with safety gating ---
1866
  while next_idx < total_frames:
1867
  try:
1868
  while next_idx not in buf:
 
1975
  if _perf_metrics is not None:
1976
  _t_w = time.perf_counter()
1977
 
1978
+ # Write to video file (always, single copy)
1979
  writer.write(frm)
1980
 
1981
  if _perf_metrics is not None:
1982
  _perf_metrics["writer_total_ms"] += (time.perf_counter() - _t_w) * 1000.0
1983
 
1984
+ # --- Streaming with buffer gating + frame duplication ---
1985
+ if stream_queue or job_id:
1986
+ # Drain any immediately available frames for accurate buffer level
1987
+ while True:
1988
+ try:
1989
+ idx2, frm2, fobjs2 = render_out.get_nowait()
1990
+ buf[idx2] = (frm2, fobjs2)
1991
+ except Empty:
1992
+ break
1993
+
1994
+ buffer_ahead = sum(1 for k in buf if k > next_idx)
1995
+
1996
+ if buffer_ahead >= SAFETY_THRESHOLD or next_idx >= total_frames - 1:
1997
+ from jobs.streaming import publish_frame as _pub
1998
  if job_id:
1999
+ for _ in range(FRAME_DUP):
2000
+ _pub(job_id, frm)
2001
  else:
2002
+ for _ in range(FRAME_DUP):
2003
+ try:
2004
+ stream_queue.put(frm, timeout=0.01)
2005
+ except Exception:
2006
+ pass
2007
+ else:
2008
+ logging.debug(
2009
+ "Stream paused: buffer=%d < threshold=%d at frame %d",
2010
+ buffer_ahead, SAFETY_THRESHOLD, next_idx,
2011
+ )
2012
 
2013
  next_idx += 1
2014
  if next_idx % 30 == 0:
 
2039
  writer_thread = Thread(target=_writer_loop, daemon=True)
2040
  writer_thread.start()
2041
 
2042
+ # ==================================================================
2043
+ # Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
2044
+ # Segments are fed incrementally to render_in as they complete.
2045
+ # ==================================================================
2046
+ try:
2047
+ if num_gpus <= 1:
2048
+ # ---------- Single-GPU fallback ----------
2049
+ device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
2050
+ _seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
2051
+ segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
2052
+ _check_cancellation(job_id)
2053
+
2054
+ if _perf_metrics is not None:
2055
+ segmenter._perf_metrics = _perf_metrics
2056
+ segmenter._perf_lock = None
2057
+
2058
+ if _perf_metrics is not None:
2059
+ _t_track = time.perf_counter()
2060
+
2061
+ def _feed_segment(seg_frames):
2062
+ for fidx in sorted(seg_frames.keys()):
2063
+ render_in.put((fidx, seg_frames[fidx]))
2064
+
2065
+ tracking_results = segmenter.process_video(
2066
+ frame_dir, frame_names, queries,
2067
+ on_segment=_feed_segment,
2068
+ )
2069
+
2070
+ if _perf_metrics is not None:
2071
+ _perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
2072
+
2073
+ logging.info(
2074
+ "Single-GPU tracking complete: %d frames",
2075
+ len(tracking_results),
2076
+ )
2077
+ else:
2078
+ # ---------- Multi-GPU pipeline ----------
2079
+ logging.info(
2080
+ "Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
2081
+ num_gpus, total_frames, step,
2082
+ )
2083
+
2084
+ # Phase 1: Load one segmenter per GPU (parallel)
2085
+ segmenters = []
2086
+ with ThreadPoolExecutor(max_workers=num_gpus) as pool:
2087
+ _seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
2088
+ futs = [
2089
+ pool.submit(
2090
+ load_segmenter_on_device,
2091
+ active_segmenter,
2092
+ f"cuda:{i}",
2093
+ **_seg_kw_multi,
2094
+ )
2095
+ for i in range(num_gpus)
2096
+ ]
2097
+ segmenters = [f.result() for f in futs]
2098
+ logging.info("Loaded %d segmenters", len(segmenters))
2099
+
2100
+ if _perf_metrics is not None:
2101
+ import threading as _th
2102
+ _actual_lock = _perf_lock or _th.Lock()
2103
+ for seg in segmenters:
2104
+ seg._perf_metrics = _perf_metrics
2105
+ seg._perf_lock = _actual_lock
2106
+
2107
+ # Phase 2: Init SAM2 models/state per GPU (parallel)
2108
+ def _init_seg_state(seg):
2109
+ seg._ensure_models_loaded()
2110
+ return seg._video_predictor.init_state(
2111
+ video_path=frame_dir,
2112
+ offload_video_to_cpu=True,
2113
+ async_loading_frames=True,
2114
+ )
2115
 
2116
+ with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
2117
+ futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
2118
+ inference_states = [f.result() for f in futs]
2119
+
2120
+ if _perf_metrics is not None:
2121
+ _t_track = time.perf_counter()
2122
+
2123
+ # Phase 3: Parallel segment processing (queue-based workers)
2124
+ segments = list(range(0, total_frames, step))
2125
+ num_total_segments = len(segments)
2126
+ seg_queue_in: Queue = Queue()
2127
+ seg_queue_out: Queue = Queue()
2128
+
2129
+ for i, start_idx in enumerate(segments):
2130
+ seg_queue_in.put((i, start_idx))
2131
+ for _ in segmenters:
2132
+ seg_queue_in.put(None) # sentinel
2133
+
2134
+ iou_thresh = segmenters[0].iou_threshold
2135
+
2136
+ def _segment_worker(gpu_idx: int):
2137
+ seg = segmenters[gpu_idx]
2138
+ state = inference_states[gpu_idx]
2139
+ device_type = seg.device.split(":")[0]
2140
+ ac = (
2141
+ torch.autocast(device_type=device_type, dtype=torch.bfloat16)
2142
+ if device_type == "cuda"
2143
+ else nullcontext()
2144
+ )
2145
+ with ac:
2146
+ while True:
2147
+ if job_id:
2148
+ try:
2149
+ _check_cancellation(job_id)
2150
+ except RuntimeError as e:
2151
+ if "cancelled" in str(e).lower():
2152
+ logging.info(
2153
+ "Segment worker %d cancelled.",
2154
+ gpu_idx,
2155
+ )
2156
+ break
2157
+ raise
2158
+ item = seg_queue_in.get()
2159
+ if item is None:
2160
+ break
2161
+ seg_idx, start_idx = item
2162
+ try:
2163
+ logging.info(
2164
+ "GPU %d processing segment %d (frame %d)",
2165
+ gpu_idx, seg_idx, start_idx,
2166
+ )
2167
+ img_path = os.path.join(
2168
+ frame_dir, frame_names[start_idx]
2169
+ )
2170
+ with PILImage.open(img_path) as pil_img:
2171
+ image = pil_img.convert("RGB")
2172
+
2173
+ if job_id:
2174
+ _check_cancellation(job_id)
2175
+ masks, boxes, labels = seg.detect_keyframe(
2176
+ image, queries,
2177
+ )
2178
+
2179
+ if masks is None:
2180
+ seg_queue_out.put(
2181
+ (seg_idx, start_idx, None, {})
2182
+ )
2183
+ continue
2184
+
2185
+ mask_dict = MaskDictionary()
2186
+ mask_dict.add_new_frame_annotation(
2187
+ mask_list=torch.tensor(masks).to(seg.device),
2188
+ box_list=(
2189
+ boxes.clone()
2190
+ if torch.is_tensor(boxes)
2191
+ else torch.tensor(boxes)
2192
+ ),
2193
+ label_list=labels,
2194
+ )
2195
+
2196
+ segment_output = seg.propagate_segment(
2197
+ state, start_idx, mask_dict, step,
2198
+ )
2199
+ seg_queue_out.put(
2200
+ (seg_idx, start_idx, mask_dict, segment_output)
2201
+ )
2202
+ except RuntimeError as e:
2203
+ if "cancelled" in str(e).lower():
2204
+ logging.info(
2205
+ "Segment worker %d cancelled.",
2206
+ gpu_idx,
2207
+ )
2208
+ break
2209
+ raise
2210
+ except Exception:
2211
+ logging.exception(
2212
+ "Segment %d failed on GPU %d",
2213
+ seg_idx, gpu_idx,
2214
+ )
2215
+ seg_queue_out.put(
2216
+ (seg_idx, start_idx, None, {})
2217
+ )
2218
+
2219
+ seg_workers = []
2220
+ for i in range(num_gpus):
2221
+ t = Thread(
2222
+ target=_segment_worker, args=(i,), daemon=True,
2223
+ )
2224
+ t.start()
2225
+ seg_workers.append(t)
2226
+
2227
+ # Phase 4: Streaming reconciliation — process segments in order
2228
+ # as they arrive, feeding render_in incrementally.
2229
+ if _perf_metrics is not None:
2230
+ _t_recon = time.perf_counter()
2231
+
2232
+ global_id_counter = 0
2233
+ sam2_masks = MaskDictionary()
2234
+ tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
2235
+
2236
+ def _mask_to_cpu(mask):
2237
+ """Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
2238
+ if torch.is_tensor(mask):
2239
+ return mask.detach().cpu()
2240
+ return mask
2241
+
2242
+ next_seg_idx = 0
2243
+ segment_buffer: Dict[int, Tuple] = {}
2244
+
2245
+ while next_seg_idx < num_total_segments:
2246
+ try:
2247
+ seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get(timeout=1.0)
2248
+ except Empty:
2249
+ if job_id:
2250
+ _check_cancellation(job_id)
2251
+ # Check if all segment workers are still alive
2252
+ if not any(t.is_alive() for t in seg_workers) and seg_queue_out.empty():
2253
+ logging.error(
2254
+ "All segment workers stopped while waiting for segment %d",
2255
+ next_seg_idx,
2256
+ )
2257
+ break
2258
+ continue
2259
+ segment_buffer[seg_idx] = (start_idx, mask_dict, segment_output)
2260
+
2261
+ # Process contiguous ready segments in order
2262
+ while next_seg_idx in segment_buffer:
2263
+ start_idx, mask_dict, segment_output = segment_buffer.pop(next_seg_idx)
2264
+
2265
+ if mask_dict is None or not mask_dict.labels:
2266
+ # No detections — carry forward previous masks
2267
+ for fi in range(
2268
+ start_idx, min(start_idx + step, total_frames)
2269
+ ):
2270
+ if fi not in tracking_results:
2271
+ tracking_results[fi] = (
2272
+ {
2273
+ k: ObjectInfo(
2274
+ instance_id=v.instance_id,
2275
+ mask=v.mask,
2276
+ class_name=v.class_name,
2277
+ x1=v.x1, y1=v.y1,
2278
+ x2=v.x2, y2=v.y2,
2279
+ )
2280
+ for k, v in sam2_masks.labels.items()
2281
+ }
2282
+ if sam2_masks.labels
2283
+ else {}
2284
+ )
2285
+ render_in.put((fi, tracking_results.get(fi, {})))
2286
+ next_seg_idx += 1
2287
+ continue
2288
+
2289
+ # Normalize keyframe masks to CPU before cross-GPU IoU matching.
2290
+ for info in mask_dict.labels.values():
2291
+ info.mask = _mask_to_cpu(info.mask)
2292
+
2293
+ # IoU match + get local→global remapping
2294
+ global_id_counter, remapping = (
2295
+ mask_dict.update_masks_with_remapping(
2296
+ tracking_dict=sam2_masks,
2297
+ iou_threshold=iou_thresh,
2298
+ objects_count=global_id_counter,
2299
+ )
2300
+ )
2301
+
2302
+ if not mask_dict.labels:
2303
+ for fi in range(
2304
+ start_idx, min(start_idx + step, total_frames)
2305
+ ):
2306
+ tracking_results[fi] = {}
2307
+ render_in.put((fi, {}))
2308
+ next_seg_idx += 1
2309
+ continue
2310
+
2311
+ # Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
2312
+ segment_results = segment_output.to_object_dicts()
2313
+
2314
+ # Apply remapping to every frame in this segment
2315
+ for frame_idx, frame_objects in segment_results.items():
2316
+ remapped: Dict[int, ObjectInfo] = {}
2317
+ for local_id, obj_info in frame_objects.items():
2318
+ global_id = remapping.get(local_id)
2319
+ if global_id is None:
2320
+ continue
2321
+ remapped[global_id] = ObjectInfo(
2322
+ instance_id=global_id,
2323
+ mask=obj_info.mask,
2324
+ class_name=obj_info.class_name,
2325
+ x1=obj_info.x1, y1=obj_info.y1,
2326
+ x2=obj_info.x2, y2=obj_info.y2,
2327
+ )
2328
+ tracking_results[frame_idx] = remapped
2329
+
2330
+ # Update running tracker with last frame of this segment
2331
+ if segment_results:
2332
+ last_fi = max(segment_results.keys())
2333
+ last_objs = tracking_results.get(last_fi, {})
2334
+ sam2_masks = MaskDictionary()
2335
+ sam2_masks.labels = copy.deepcopy(last_objs)
2336
+ if last_objs:
2337
+ first_info = next(iter(last_objs.values()))
2338
+ if first_info.mask is not None:
2339
+ m = first_info.mask
2340
+ sam2_masks.mask_height = (
2341
+ m.shape[-2] if m.ndim >= 2 else 0
2342
+ )
2343
+ sam2_masks.mask_width = (
2344
+ m.shape[-1] if m.ndim >= 2 else 0
2345
+ )
2346
+
2347
+ # Feed reconciled frames to render immediately
2348
+ for fi in range(start_idx, min(start_idx + step, total_frames)):
2349
+ render_in.put((fi, tracking_results.get(fi, {})))
2350
+
2351
+ next_seg_idx += 1
2352
+
2353
+ for t in seg_workers:
2354
+ t.join()
2355
+
2356
+ if _perf_metrics is not None:
2357
+ _perf_metrics["id_reconciliation_ms"] = (time.perf_counter() - _t_recon) * 1000.0
2358
+ _perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
2359
+
2360
+ logging.info(
2361
+ "Multi-GPU reconciliation complete: %d frames, %d objects",
2362
+ len(tracking_results), global_id_counter,
2363
+ )
2364
+ finally:
2365
+ # Sentinels for render workers — always sent even on error/cancellation
2366
+ for _ in r_workers:
2367
+ try:
2368
+ render_in.put(None, timeout=5.0)
2369
+ except Full:
2370
+ pass
2371
 
2372
  for t in r_workers:
2373
  t.join()
models/segmenters/grounded_sam2.py CHANGED
@@ -13,7 +13,7 @@ import logging
13
  import time
14
  from contextlib import nullcontext
15
  from dataclasses import dataclass, field
16
- from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
17
 
18
  import numpy as np
19
  import torch
@@ -673,6 +673,7 @@ class GroundedSAM2Segmenter(Segmenter):
673
  frame_dir: str,
674
  frame_names: List[str],
675
  text_prompts: List[str],
 
676
  ) -> Dict[int, Dict[int, ObjectInfo]]:
677
  """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
678
 
@@ -680,6 +681,8 @@ class GroundedSAM2Segmenter(Segmenter):
680
  frame_dir: Directory containing JPEG frames.
681
  frame_names: Sorted list of frame filenames.
682
  text_prompts: Text queries for Grounding DINO.
 
 
683
 
684
  Returns:
685
  Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
@@ -764,6 +767,7 @@ class GroundedSAM2Segmenter(Segmenter):
764
  if input_boxes.shape[0] == 0:
765
  logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
766
  # Fill empty results for this segment
 
767
  for fi in range(start_idx, min(start_idx + step, total_frames)):
768
  if fi not in all_results:
769
  # Carry forward last known masks
@@ -776,6 +780,9 @@ class GroundedSAM2Segmenter(Segmenter):
776
  )
777
  for k, v in sam2_masks.labels.items()
778
  } if sam2_masks.labels else {}
 
 
 
779
  continue
780
 
781
  # -- SAM2 image predictor on keyframe --
@@ -831,8 +838,12 @@ class GroundedSAM2Segmenter(Segmenter):
831
  _pm["id_reconciliation_ms"] += _d
832
 
833
  if len(mask_dict.labels) == 0:
 
834
  for fi in range(start_idx, min(start_idx + step, total_frames)):
835
  all_results[fi] = {}
 
 
 
836
  continue
837
 
838
  # -- SAM2 video predictor: propagate masks --
@@ -846,6 +857,8 @@ class GroundedSAM2Segmenter(Segmenter):
846
 
847
  if segment_results:
848
  all_results.update(segment_results)
 
 
849
  last_fi = segment_output.last_frame_idx()
850
  if last_fi is not None:
851
  last_frame_objects = all_results.get(last_fi, {})
 
13
  import time
14
  from contextlib import nullcontext
15
  from dataclasses import dataclass, field
16
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
17
 
18
  import numpy as np
19
  import torch
 
673
  frame_dir: str,
674
  frame_names: List[str],
675
  text_prompts: List[str],
676
+ on_segment: Optional[Callable[[Dict[int, Dict[int, "ObjectInfo"]]], None]] = None,
677
  ) -> Dict[int, Dict[int, ObjectInfo]]:
678
  """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
679
 
 
681
  frame_dir: Directory containing JPEG frames.
682
  frame_names: Sorted list of frame filenames.
683
  text_prompts: Text queries for Grounding DINO.
684
+ on_segment: Optional callback invoked after each segment completes.
685
+ Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.
686
 
687
  Returns:
688
  Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
 
767
  if input_boxes.shape[0] == 0:
768
  logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
769
  # Fill empty results for this segment
770
+ seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
771
  for fi in range(start_idx, min(start_idx + step, total_frames)):
772
  if fi not in all_results:
773
  # Carry forward last known masks
 
780
  )
781
  for k, v in sam2_masks.labels.items()
782
  } if sam2_masks.labels else {}
783
+ seg_results[fi] = all_results[fi]
784
+ if on_segment and seg_results:
785
+ on_segment(seg_results)
786
  continue
787
 
788
  # -- SAM2 image predictor on keyframe --
 
838
  _pm["id_reconciliation_ms"] += _d
839
 
840
  if len(mask_dict.labels) == 0:
841
+ seg_results_empty: Dict[int, Dict[int, ObjectInfo]] = {}
842
  for fi in range(start_idx, min(start_idx + step, total_frames)):
843
  all_results[fi] = {}
844
+ seg_results_empty[fi] = {}
845
+ if on_segment:
846
+ on_segment(seg_results_empty)
847
  continue
848
 
849
  # -- SAM2 video predictor: propagate masks --
 
857
 
858
  if segment_results:
859
  all_results.update(segment_results)
860
+ if on_segment:
861
+ on_segment(segment_results)
862
  last_fi = segment_output.last_frame_idx()
863
  if last_fi is not None:
864
  last_frame_objects = all_results.get(last_fi, {})