Spaces:
Running
perf: pipeline GSAM2 tracking + rendering with startup buffer
Browse filesPipeline tracking and rendering so segments stream as they're tracked
instead of waiting for all tracking to complete before rendering begins.
- Add on_segment callback to process_video() for incremental feeding
- Hoist render/writer infrastructure above tracking branch
- Single-GPU: callback-based incremental feeding via on_segment
- Multi-GPU: streaming reconciliation with segment_buffer pattern
- Writer startup buffer (60 frames) before streaming begins
- 3x frame duplication for 18 FPS stream (6 FPS processing throughput)
- Safety threshold: pause streaming when buffer < 20 frames
- try/finally sentinel safety for render worker shutdown
- Increase render_out queue from 64 to 128 for buffer headroom
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- inference.py +385 -311
- models/segmenters/grounded_sam2.py +14 -1
|
@@ -1195,9 +1195,6 @@ def run_inference(
|
|
| 1195 |
display_labels.append("")
|
| 1196 |
continue
|
| 1197 |
lbl = d.get('label', 'obj')
|
| 1198 |
-
# Append Track ID
|
| 1199 |
-
if 'track_id' in d:
|
| 1200 |
-
lbl = f"{d['track_id']} {lbl}"
|
| 1201 |
display_labels.append(lbl)
|
| 1202 |
|
| 1203 |
p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels)
|
|
@@ -1675,303 +1672,15 @@ def run_grounded_sam2_tracking(
|
|
| 1675 |
|
| 1676 |
num_gpus = torch.cuda.device_count()
|
| 1677 |
|
| 1678 |
-
# ==================================================================
|
| 1679 |
-
# Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
|
| 1680 |
-
# ==================================================================
|
| 1681 |
-
if num_gpus <= 1:
|
| 1682 |
-
# ---------- Single-GPU fallback ----------
|
| 1683 |
-
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1684 |
-
_seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 1685 |
-
segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
|
| 1686 |
-
_check_cancellation(job_id)
|
| 1687 |
-
|
| 1688 |
-
if _perf_metrics is not None:
|
| 1689 |
-
segmenter._perf_metrics = _perf_metrics
|
| 1690 |
-
segmenter._perf_lock = None
|
| 1691 |
-
|
| 1692 |
-
if _perf_metrics is not None:
|
| 1693 |
-
_t_track = time.perf_counter()
|
| 1694 |
-
|
| 1695 |
-
tracking_results = segmenter.process_video(
|
| 1696 |
-
frame_dir, frame_names, queries,
|
| 1697 |
-
)
|
| 1698 |
-
|
| 1699 |
-
if _perf_metrics is not None:
|
| 1700 |
-
_perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
|
| 1701 |
-
|
| 1702 |
-
logging.info(
|
| 1703 |
-
"Single-GPU tracking complete: %d frames",
|
| 1704 |
-
len(tracking_results),
|
| 1705 |
-
)
|
| 1706 |
-
else:
|
| 1707 |
-
# ---------- Multi-GPU pipeline ----------
|
| 1708 |
-
logging.info(
|
| 1709 |
-
"Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
|
| 1710 |
-
num_gpus, total_frames, step,
|
| 1711 |
-
)
|
| 1712 |
-
|
| 1713 |
-
# Phase 1: Load one segmenter per GPU (parallel)
|
| 1714 |
-
segmenters = []
|
| 1715 |
-
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 1716 |
-
_seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 1717 |
-
futs = [
|
| 1718 |
-
pool.submit(
|
| 1719 |
-
load_segmenter_on_device,
|
| 1720 |
-
active_segmenter,
|
| 1721 |
-
f"cuda:{i}",
|
| 1722 |
-
**_seg_kw_multi,
|
| 1723 |
-
)
|
| 1724 |
-
for i in range(num_gpus)
|
| 1725 |
-
]
|
| 1726 |
-
segmenters = [f.result() for f in futs]
|
| 1727 |
-
logging.info("Loaded %d segmenters", len(segmenters))
|
| 1728 |
-
|
| 1729 |
-
if _perf_metrics is not None:
|
| 1730 |
-
import threading as _th
|
| 1731 |
-
_actual_lock = _perf_lock or _th.Lock()
|
| 1732 |
-
for seg in segmenters:
|
| 1733 |
-
seg._perf_metrics = _perf_metrics
|
| 1734 |
-
seg._perf_lock = _actual_lock
|
| 1735 |
-
|
| 1736 |
-
# Phase 2: Init SAM2 models/state per GPU (parallel)
|
| 1737 |
-
def _init_seg_state(seg):
|
| 1738 |
-
seg._ensure_models_loaded()
|
| 1739 |
-
return seg._video_predictor.init_state(
|
| 1740 |
-
video_path=frame_dir,
|
| 1741 |
-
offload_video_to_cpu=True,
|
| 1742 |
-
async_loading_frames=True,
|
| 1743 |
-
)
|
| 1744 |
-
|
| 1745 |
-
with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
|
| 1746 |
-
futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
|
| 1747 |
-
inference_states = [f.result() for f in futs]
|
| 1748 |
-
|
| 1749 |
-
if _perf_metrics is not None:
|
| 1750 |
-
_t_track = time.perf_counter()
|
| 1751 |
-
|
| 1752 |
-
# Phase 3: Parallel segment processing (queue-based workers)
|
| 1753 |
-
segments = list(range(0, total_frames, step))
|
| 1754 |
-
seg_queue_in: Queue = Queue()
|
| 1755 |
-
seg_queue_out: Queue = Queue()
|
| 1756 |
-
|
| 1757 |
-
for i, start_idx in enumerate(segments):
|
| 1758 |
-
seg_queue_in.put((i, start_idx))
|
| 1759 |
-
for _ in segmenters:
|
| 1760 |
-
seg_queue_in.put(None) # sentinel
|
| 1761 |
-
|
| 1762 |
-
iou_thresh = segmenters[0].iou_threshold
|
| 1763 |
-
|
| 1764 |
-
def _segment_worker(gpu_idx: int):
|
| 1765 |
-
seg = segmenters[gpu_idx]
|
| 1766 |
-
state = inference_states[gpu_idx]
|
| 1767 |
-
device_type = seg.device.split(":")[0]
|
| 1768 |
-
ac = (
|
| 1769 |
-
torch.autocast(device_type=device_type, dtype=torch.bfloat16)
|
| 1770 |
-
if device_type == "cuda"
|
| 1771 |
-
else nullcontext()
|
| 1772 |
-
)
|
| 1773 |
-
with ac:
|
| 1774 |
-
while True:
|
| 1775 |
-
if job_id:
|
| 1776 |
-
try:
|
| 1777 |
-
_check_cancellation(job_id)
|
| 1778 |
-
except RuntimeError as e:
|
| 1779 |
-
if "cancelled" in str(e).lower():
|
| 1780 |
-
logging.info(
|
| 1781 |
-
"Segment worker %d cancelled.",
|
| 1782 |
-
gpu_idx,
|
| 1783 |
-
)
|
| 1784 |
-
break
|
| 1785 |
-
raise
|
| 1786 |
-
item = seg_queue_in.get()
|
| 1787 |
-
if item is None:
|
| 1788 |
-
break
|
| 1789 |
-
seg_idx, start_idx = item
|
| 1790 |
-
try:
|
| 1791 |
-
logging.info(
|
| 1792 |
-
"GPU %d processing segment %d (frame %d)",
|
| 1793 |
-
gpu_idx, seg_idx, start_idx,
|
| 1794 |
-
)
|
| 1795 |
-
img_path = os.path.join(
|
| 1796 |
-
frame_dir, frame_names[start_idx]
|
| 1797 |
-
)
|
| 1798 |
-
with PILImage.open(img_path) as pil_img:
|
| 1799 |
-
image = pil_img.convert("RGB")
|
| 1800 |
-
|
| 1801 |
-
if job_id:
|
| 1802 |
-
_check_cancellation(job_id)
|
| 1803 |
-
masks, boxes, labels = seg.detect_keyframe(
|
| 1804 |
-
image, queries,
|
| 1805 |
-
)
|
| 1806 |
-
|
| 1807 |
-
if masks is None:
|
| 1808 |
-
seg_queue_out.put(
|
| 1809 |
-
(seg_idx, start_idx, None, {})
|
| 1810 |
-
)
|
| 1811 |
-
continue
|
| 1812 |
-
|
| 1813 |
-
mask_dict = MaskDictionary()
|
| 1814 |
-
mask_dict.add_new_frame_annotation(
|
| 1815 |
-
mask_list=torch.tensor(masks).to(seg.device),
|
| 1816 |
-
box_list=(
|
| 1817 |
-
boxes.clone()
|
| 1818 |
-
if torch.is_tensor(boxes)
|
| 1819 |
-
else torch.tensor(boxes)
|
| 1820 |
-
),
|
| 1821 |
-
label_list=labels,
|
| 1822 |
-
)
|
| 1823 |
-
|
| 1824 |
-
segment_output = seg.propagate_segment(
|
| 1825 |
-
state, start_idx, mask_dict, step,
|
| 1826 |
-
)
|
| 1827 |
-
seg_queue_out.put(
|
| 1828 |
-
(seg_idx, start_idx, mask_dict, segment_output)
|
| 1829 |
-
)
|
| 1830 |
-
except RuntimeError as e:
|
| 1831 |
-
if "cancelled" in str(e).lower():
|
| 1832 |
-
logging.info(
|
| 1833 |
-
"Segment worker %d cancelled.",
|
| 1834 |
-
gpu_idx,
|
| 1835 |
-
)
|
| 1836 |
-
break
|
| 1837 |
-
raise
|
| 1838 |
-
except Exception:
|
| 1839 |
-
logging.exception(
|
| 1840 |
-
"Segment %d failed on GPU %d",
|
| 1841 |
-
seg_idx, gpu_idx,
|
| 1842 |
-
)
|
| 1843 |
-
seg_queue_out.put(
|
| 1844 |
-
(seg_idx, start_idx, None, {})
|
| 1845 |
-
)
|
| 1846 |
-
|
| 1847 |
-
seg_workers = []
|
| 1848 |
-
for i in range(num_gpus):
|
| 1849 |
-
t = Thread(
|
| 1850 |
-
target=_segment_worker, args=(i,), daemon=True,
|
| 1851 |
-
)
|
| 1852 |
-
t.start()
|
| 1853 |
-
seg_workers.append(t)
|
| 1854 |
-
|
| 1855 |
-
for t in seg_workers:
|
| 1856 |
-
t.join()
|
| 1857 |
-
|
| 1858 |
-
# Collect all segment outputs
|
| 1859 |
-
segment_data: Dict[int, Tuple] = {}
|
| 1860 |
-
while not seg_queue_out.empty():
|
| 1861 |
-
seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get()
|
| 1862 |
-
segment_data[seg_idx] = (start_idx, mask_dict, segment_output)
|
| 1863 |
-
|
| 1864 |
-
# Phase 4: Sequential ID reconciliation
|
| 1865 |
-
if _perf_metrics is not None:
|
| 1866 |
-
_t_recon = time.perf_counter()
|
| 1867 |
-
|
| 1868 |
-
global_id_counter = 0
|
| 1869 |
-
sam2_masks = MaskDictionary()
|
| 1870 |
-
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1871 |
-
|
| 1872 |
-
def _mask_to_cpu(mask):
|
| 1873 |
-
"""Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
|
| 1874 |
-
if torch.is_tensor(mask):
|
| 1875 |
-
return mask.detach().cpu()
|
| 1876 |
-
return mask
|
| 1877 |
-
|
| 1878 |
-
for seg_idx in sorted(segment_data.keys()):
|
| 1879 |
-
start_idx, mask_dict, segment_output = segment_data[seg_idx]
|
| 1880 |
-
|
| 1881 |
-
if mask_dict is None or not mask_dict.labels:
|
| 1882 |
-
# No detections — carry forward previous masks
|
| 1883 |
-
for fi in range(
|
| 1884 |
-
start_idx, min(start_idx + step, total_frames)
|
| 1885 |
-
):
|
| 1886 |
-
if fi not in tracking_results:
|
| 1887 |
-
tracking_results[fi] = (
|
| 1888 |
-
{
|
| 1889 |
-
k: ObjectInfo(
|
| 1890 |
-
instance_id=v.instance_id,
|
| 1891 |
-
mask=v.mask,
|
| 1892 |
-
class_name=v.class_name,
|
| 1893 |
-
x1=v.x1, y1=v.y1,
|
| 1894 |
-
x2=v.x2, y2=v.y2,
|
| 1895 |
-
)
|
| 1896 |
-
for k, v in sam2_masks.labels.items()
|
| 1897 |
-
}
|
| 1898 |
-
if sam2_masks.labels
|
| 1899 |
-
else {}
|
| 1900 |
-
)
|
| 1901 |
-
continue
|
| 1902 |
-
|
| 1903 |
-
# Normalize keyframe masks to CPU before cross-GPU IoU matching.
|
| 1904 |
-
for info in mask_dict.labels.values():
|
| 1905 |
-
info.mask = _mask_to_cpu(info.mask)
|
| 1906 |
-
|
| 1907 |
-
# IoU match + get local→global remapping
|
| 1908 |
-
global_id_counter, remapping = (
|
| 1909 |
-
mask_dict.update_masks_with_remapping(
|
| 1910 |
-
tracking_dict=sam2_masks,
|
| 1911 |
-
iou_threshold=iou_thresh,
|
| 1912 |
-
objects_count=global_id_counter,
|
| 1913 |
-
)
|
| 1914 |
-
)
|
| 1915 |
-
|
| 1916 |
-
if not mask_dict.labels:
|
| 1917 |
-
for fi in range(
|
| 1918 |
-
start_idx, min(start_idx + step, total_frames)
|
| 1919 |
-
):
|
| 1920 |
-
tracking_results[fi] = {}
|
| 1921 |
-
continue
|
| 1922 |
-
|
| 1923 |
-
# Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
|
| 1924 |
-
segment_results = segment_output.to_object_dicts()
|
| 1925 |
-
|
| 1926 |
-
# Apply remapping to every frame in this segment
|
| 1927 |
-
for frame_idx, frame_objects in segment_results.items():
|
| 1928 |
-
remapped: Dict[int, ObjectInfo] = {}
|
| 1929 |
-
for local_id, obj_info in frame_objects.items():
|
| 1930 |
-
global_id = remapping.get(local_id)
|
| 1931 |
-
if global_id is None:
|
| 1932 |
-
continue
|
| 1933 |
-
remapped[global_id] = ObjectInfo(
|
| 1934 |
-
instance_id=global_id,
|
| 1935 |
-
mask=obj_info.mask,
|
| 1936 |
-
class_name=obj_info.class_name,
|
| 1937 |
-
x1=obj_info.x1, y1=obj_info.y1,
|
| 1938 |
-
x2=obj_info.x2, y2=obj_info.y2,
|
| 1939 |
-
)
|
| 1940 |
-
tracking_results[frame_idx] = remapped
|
| 1941 |
-
|
| 1942 |
-
# Update running tracker with last frame of this segment
|
| 1943 |
-
if segment_results:
|
| 1944 |
-
last_fi = max(segment_results.keys())
|
| 1945 |
-
last_objs = tracking_results.get(last_fi, {})
|
| 1946 |
-
sam2_masks = MaskDictionary()
|
| 1947 |
-
sam2_masks.labels = copy.deepcopy(last_objs)
|
| 1948 |
-
if last_objs:
|
| 1949 |
-
first_info = next(iter(last_objs.values()))
|
| 1950 |
-
if first_info.mask is not None:
|
| 1951 |
-
m = first_info.mask
|
| 1952 |
-
sam2_masks.mask_height = (
|
| 1953 |
-
m.shape[-2] if m.ndim >= 2 else 0
|
| 1954 |
-
)
|
| 1955 |
-
sam2_masks.mask_width = (
|
| 1956 |
-
m.shape[-1] if m.ndim >= 2 else 0
|
| 1957 |
-
)
|
| 1958 |
-
|
| 1959 |
-
if _perf_metrics is not None:
|
| 1960 |
-
_perf_metrics["id_reconciliation_ms"] = (time.perf_counter() - _t_recon) * 1000.0
|
| 1961 |
-
_perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
|
| 1962 |
-
|
| 1963 |
-
logging.info(
|
| 1964 |
-
"Multi-GPU reconciliation complete: %d frames, %d objects",
|
| 1965 |
-
len(tracking_results), global_id_counter,
|
| 1966 |
-
)
|
| 1967 |
-
|
| 1968 |
# ==================================================================
|
| 1969 |
# Phase 5: Parallel rendering + sequential video writing
|
|
|
|
|
|
|
| 1970 |
# ==================================================================
|
| 1971 |
_check_cancellation(job_id)
|
| 1972 |
|
| 1973 |
render_in: Queue = Queue(maxsize=32)
|
| 1974 |
-
render_out: Queue = Queue(maxsize=
|
| 1975 |
render_done = False
|
| 1976 |
num_render_workers = min(4, os.cpu_count() or 1)
|
| 1977 |
|
|
@@ -2112,6 +1821,11 @@ def run_grounded_sam2_tracking(
|
|
| 2112 |
next_idx = 0
|
| 2113 |
buf: Dict[int, Tuple] = {}
|
| 2114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2115 |
# Per-track bbox history (replaces ByteTracker for GSAM2)
|
| 2116 |
track_history: Dict[int, List] = {}
|
| 2117 |
speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
|
|
@@ -2127,6 +1841,28 @@ def run_grounded_sam2_tracking(
|
|
| 2127 |
with StreamingVideoWriter(
|
| 2128 |
output_video_path, fps, width, height
|
| 2129 |
) as writer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2130 |
while next_idx < total_frames:
|
| 2131 |
try:
|
| 2132 |
while next_idx not in buf:
|
|
@@ -2239,22 +1975,40 @@ def run_grounded_sam2_tracking(
|
|
| 2239 |
if _perf_metrics is not None:
|
| 2240 |
_t_w = time.perf_counter()
|
| 2241 |
|
|
|
|
| 2242 |
writer.write(frm)
|
| 2243 |
|
| 2244 |
if _perf_metrics is not None:
|
| 2245 |
_perf_metrics["writer_total_ms"] += (time.perf_counter() - _t_w) * 1000.0
|
| 2246 |
|
| 2247 |
-
|
| 2248 |
-
|
| 2249 |
-
|
| 2250 |
-
|
| 2251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2252 |
if job_id:
|
| 2253 |
-
|
|
|
|
| 2254 |
else:
|
| 2255 |
-
|
| 2256 |
-
|
| 2257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2258 |
|
| 2259 |
next_idx += 1
|
| 2260 |
if next_idx % 30 == 0:
|
|
@@ -2285,15 +2039,335 @@ def run_grounded_sam2_tracking(
|
|
| 2285 |
writer_thread = Thread(target=_writer_loop, daemon=True)
|
| 2286 |
writer_thread.start()
|
| 2287 |
|
| 2288 |
-
#
|
| 2289 |
-
|
| 2290 |
-
|
| 2291 |
-
|
| 2292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2293 |
|
| 2294 |
-
|
| 2295 |
-
|
| 2296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2297 |
|
| 2298 |
for t in r_workers:
|
| 2299 |
t.join()
|
|
|
|
| 1195 |
display_labels.append("")
|
| 1196 |
continue
|
| 1197 |
lbl = d.get('label', 'obj')
|
|
|
|
|
|
|
|
|
|
| 1198 |
display_labels.append(lbl)
|
| 1199 |
|
| 1200 |
p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels)
|
|
|
|
| 1672 |
|
| 1673 |
num_gpus = torch.cuda.device_count()
|
| 1674 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1675 |
# ==================================================================
|
| 1676 |
# Phase 5: Parallel rendering + sequential video writing
|
| 1677 |
+
# (Hoisted above tracking so render pipeline starts before tracking
|
| 1678 |
+
# completes — segments are fed incrementally via callback / queue.)
|
| 1679 |
# ==================================================================
|
| 1680 |
_check_cancellation(job_id)
|
| 1681 |
|
| 1682 |
render_in: Queue = Queue(maxsize=32)
|
| 1683 |
+
render_out: Queue = Queue(maxsize=128)
|
| 1684 |
render_done = False
|
| 1685 |
num_render_workers = min(4, os.cpu_count() or 1)
|
| 1686 |
|
|
|
|
| 1821 |
next_idx = 0
|
| 1822 |
buf: Dict[int, Tuple] = {}
|
| 1823 |
|
| 1824 |
+
# Streaming constants
|
| 1825 |
+
STARTUP_BUFFER = 60
|
| 1826 |
+
SAFETY_THRESHOLD = 20
|
| 1827 |
+
FRAME_DUP = 3
|
| 1828 |
+
|
| 1829 |
# Per-track bbox history (replaces ByteTracker for GSAM2)
|
| 1830 |
track_history: Dict[int, List] = {}
|
| 1831 |
speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
|
|
|
|
| 1841 |
with StreamingVideoWriter(
|
| 1842 |
output_video_path, fps, width, height
|
| 1843 |
) as writer:
|
| 1844 |
+
# --- Phase 1: Startup buffering ---
|
| 1845 |
+
playback_started = False
|
| 1846 |
+
while not playback_started:
|
| 1847 |
+
try:
|
| 1848 |
+
idx, frm, fobjs = render_out.get(timeout=1.0)
|
| 1849 |
+
buf[idx] = (frm, fobjs)
|
| 1850 |
+
except Empty:
|
| 1851 |
+
if not any(t.is_alive() for t in r_workers) and render_out.empty():
|
| 1852 |
+
playback_started = True
|
| 1853 |
+
break
|
| 1854 |
+
continue
|
| 1855 |
+
|
| 1856 |
+
ahead = sum(1 for k in buf if k >= next_idx)
|
| 1857 |
+
if ahead >= STARTUP_BUFFER or ahead >= total_frames:
|
| 1858 |
+
playback_started = True
|
| 1859 |
+
|
| 1860 |
+
logging.info(
|
| 1861 |
+
"Startup buffer filled (%d frames), beginning playback",
|
| 1862 |
+
len(buf),
|
| 1863 |
+
)
|
| 1864 |
+
|
| 1865 |
+
# --- Phase 2: Write + stream with safety gating ---
|
| 1866 |
while next_idx < total_frames:
|
| 1867 |
try:
|
| 1868 |
while next_idx not in buf:
|
|
|
|
| 1975 |
if _perf_metrics is not None:
|
| 1976 |
_t_w = time.perf_counter()
|
| 1977 |
|
| 1978 |
+
# Write to video file (always, single copy)
|
| 1979 |
writer.write(frm)
|
| 1980 |
|
| 1981 |
if _perf_metrics is not None:
|
| 1982 |
_perf_metrics["writer_total_ms"] += (time.perf_counter() - _t_w) * 1000.0
|
| 1983 |
|
| 1984 |
+
# --- Streaming with buffer gating + frame duplication ---
|
| 1985 |
+
if stream_queue or job_id:
|
| 1986 |
+
# Drain any immediately available frames for accurate buffer level
|
| 1987 |
+
while True:
|
| 1988 |
+
try:
|
| 1989 |
+
idx2, frm2, fobjs2 = render_out.get_nowait()
|
| 1990 |
+
buf[idx2] = (frm2, fobjs2)
|
| 1991 |
+
except Empty:
|
| 1992 |
+
break
|
| 1993 |
+
|
| 1994 |
+
buffer_ahead = sum(1 for k in buf if k > next_idx)
|
| 1995 |
+
|
| 1996 |
+
if buffer_ahead >= SAFETY_THRESHOLD or next_idx >= total_frames - 1:
|
| 1997 |
+
from jobs.streaming import publish_frame as _pub
|
| 1998 |
if job_id:
|
| 1999 |
+
for _ in range(FRAME_DUP):
|
| 2000 |
+
_pub(job_id, frm)
|
| 2001 |
else:
|
| 2002 |
+
for _ in range(FRAME_DUP):
|
| 2003 |
+
try:
|
| 2004 |
+
stream_queue.put(frm, timeout=0.01)
|
| 2005 |
+
except Exception:
|
| 2006 |
+
pass
|
| 2007 |
+
else:
|
| 2008 |
+
logging.debug(
|
| 2009 |
+
"Stream paused: buffer=%d < threshold=%d at frame %d",
|
| 2010 |
+
buffer_ahead, SAFETY_THRESHOLD, next_idx,
|
| 2011 |
+
)
|
| 2012 |
|
| 2013 |
next_idx += 1
|
| 2014 |
if next_idx % 30 == 0:
|
|
|
|
| 2039 |
writer_thread = Thread(target=_writer_loop, daemon=True)
|
| 2040 |
writer_thread.start()
|
| 2041 |
|
| 2042 |
+
# ==================================================================
|
| 2043 |
+
# Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
|
| 2044 |
+
# Segments are fed incrementally to render_in as they complete.
|
| 2045 |
+
# ==================================================================
|
| 2046 |
+
try:
|
| 2047 |
+
if num_gpus <= 1:
|
| 2048 |
+
# ---------- Single-GPU fallback ----------
|
| 2049 |
+
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 2050 |
+
_seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 2051 |
+
segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
|
| 2052 |
+
_check_cancellation(job_id)
|
| 2053 |
+
|
| 2054 |
+
if _perf_metrics is not None:
|
| 2055 |
+
segmenter._perf_metrics = _perf_metrics
|
| 2056 |
+
segmenter._perf_lock = None
|
| 2057 |
+
|
| 2058 |
+
if _perf_metrics is not None:
|
| 2059 |
+
_t_track = time.perf_counter()
|
| 2060 |
+
|
| 2061 |
+
def _feed_segment(seg_frames):
|
| 2062 |
+
for fidx in sorted(seg_frames.keys()):
|
| 2063 |
+
render_in.put((fidx, seg_frames[fidx]))
|
| 2064 |
+
|
| 2065 |
+
tracking_results = segmenter.process_video(
|
| 2066 |
+
frame_dir, frame_names, queries,
|
| 2067 |
+
on_segment=_feed_segment,
|
| 2068 |
+
)
|
| 2069 |
+
|
| 2070 |
+
if _perf_metrics is not None:
|
| 2071 |
+
_perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
|
| 2072 |
+
|
| 2073 |
+
logging.info(
|
| 2074 |
+
"Single-GPU tracking complete: %d frames",
|
| 2075 |
+
len(tracking_results),
|
| 2076 |
+
)
|
| 2077 |
+
else:
|
| 2078 |
+
# ---------- Multi-GPU pipeline ----------
|
| 2079 |
+
logging.info(
|
| 2080 |
+
"Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
|
| 2081 |
+
num_gpus, total_frames, step,
|
| 2082 |
+
)
|
| 2083 |
+
|
| 2084 |
+
# Phase 1: Load one segmenter per GPU (parallel)
|
| 2085 |
+
segmenters = []
|
| 2086 |
+
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 2087 |
+
_seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 2088 |
+
futs = [
|
| 2089 |
+
pool.submit(
|
| 2090 |
+
load_segmenter_on_device,
|
| 2091 |
+
active_segmenter,
|
| 2092 |
+
f"cuda:{i}",
|
| 2093 |
+
**_seg_kw_multi,
|
| 2094 |
+
)
|
| 2095 |
+
for i in range(num_gpus)
|
| 2096 |
+
]
|
| 2097 |
+
segmenters = [f.result() for f in futs]
|
| 2098 |
+
logging.info("Loaded %d segmenters", len(segmenters))
|
| 2099 |
+
|
| 2100 |
+
if _perf_metrics is not None:
|
| 2101 |
+
import threading as _th
|
| 2102 |
+
_actual_lock = _perf_lock or _th.Lock()
|
| 2103 |
+
for seg in segmenters:
|
| 2104 |
+
seg._perf_metrics = _perf_metrics
|
| 2105 |
+
seg._perf_lock = _actual_lock
|
| 2106 |
+
|
| 2107 |
+
# Phase 2: Init SAM2 models/state per GPU (parallel)
|
| 2108 |
+
def _init_seg_state(seg):
|
| 2109 |
+
seg._ensure_models_loaded()
|
| 2110 |
+
return seg._video_predictor.init_state(
|
| 2111 |
+
video_path=frame_dir,
|
| 2112 |
+
offload_video_to_cpu=True,
|
| 2113 |
+
async_loading_frames=True,
|
| 2114 |
+
)
|
| 2115 |
|
| 2116 |
+
with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
|
| 2117 |
+
futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
|
| 2118 |
+
inference_states = [f.result() for f in futs]
|
| 2119 |
+
|
| 2120 |
+
if _perf_metrics is not None:
|
| 2121 |
+
_t_track = time.perf_counter()
|
| 2122 |
+
|
| 2123 |
+
# Phase 3: Parallel segment processing (queue-based workers)
|
| 2124 |
+
segments = list(range(0, total_frames, step))
|
| 2125 |
+
num_total_segments = len(segments)
|
| 2126 |
+
seg_queue_in: Queue = Queue()
|
| 2127 |
+
seg_queue_out: Queue = Queue()
|
| 2128 |
+
|
| 2129 |
+
for i, start_idx in enumerate(segments):
|
| 2130 |
+
seg_queue_in.put((i, start_idx))
|
| 2131 |
+
for _ in segmenters:
|
| 2132 |
+
seg_queue_in.put(None) # sentinel
|
| 2133 |
+
|
| 2134 |
+
iou_thresh = segmenters[0].iou_threshold
|
| 2135 |
+
|
| 2136 |
+
def _segment_worker(gpu_idx: int):
|
| 2137 |
+
seg = segmenters[gpu_idx]
|
| 2138 |
+
state = inference_states[gpu_idx]
|
| 2139 |
+
device_type = seg.device.split(":")[0]
|
| 2140 |
+
ac = (
|
| 2141 |
+
torch.autocast(device_type=device_type, dtype=torch.bfloat16)
|
| 2142 |
+
if device_type == "cuda"
|
| 2143 |
+
else nullcontext()
|
| 2144 |
+
)
|
| 2145 |
+
with ac:
|
| 2146 |
+
while True:
|
| 2147 |
+
if job_id:
|
| 2148 |
+
try:
|
| 2149 |
+
_check_cancellation(job_id)
|
| 2150 |
+
except RuntimeError as e:
|
| 2151 |
+
if "cancelled" in str(e).lower():
|
| 2152 |
+
logging.info(
|
| 2153 |
+
"Segment worker %d cancelled.",
|
| 2154 |
+
gpu_idx,
|
| 2155 |
+
)
|
| 2156 |
+
break
|
| 2157 |
+
raise
|
| 2158 |
+
item = seg_queue_in.get()
|
| 2159 |
+
if item is None:
|
| 2160 |
+
break
|
| 2161 |
+
seg_idx, start_idx = item
|
| 2162 |
+
try:
|
| 2163 |
+
logging.info(
|
| 2164 |
+
"GPU %d processing segment %d (frame %d)",
|
| 2165 |
+
gpu_idx, seg_idx, start_idx,
|
| 2166 |
+
)
|
| 2167 |
+
img_path = os.path.join(
|
| 2168 |
+
frame_dir, frame_names[start_idx]
|
| 2169 |
+
)
|
| 2170 |
+
with PILImage.open(img_path) as pil_img:
|
| 2171 |
+
image = pil_img.convert("RGB")
|
| 2172 |
+
|
| 2173 |
+
if job_id:
|
| 2174 |
+
_check_cancellation(job_id)
|
| 2175 |
+
masks, boxes, labels = seg.detect_keyframe(
|
| 2176 |
+
image, queries,
|
| 2177 |
+
)
|
| 2178 |
+
|
| 2179 |
+
if masks is None:
|
| 2180 |
+
seg_queue_out.put(
|
| 2181 |
+
(seg_idx, start_idx, None, {})
|
| 2182 |
+
)
|
| 2183 |
+
continue
|
| 2184 |
+
|
| 2185 |
+
mask_dict = MaskDictionary()
|
| 2186 |
+
mask_dict.add_new_frame_annotation(
|
| 2187 |
+
mask_list=torch.tensor(masks).to(seg.device),
|
| 2188 |
+
box_list=(
|
| 2189 |
+
boxes.clone()
|
| 2190 |
+
if torch.is_tensor(boxes)
|
| 2191 |
+
else torch.tensor(boxes)
|
| 2192 |
+
),
|
| 2193 |
+
label_list=labels,
|
| 2194 |
+
)
|
| 2195 |
+
|
| 2196 |
+
segment_output = seg.propagate_segment(
|
| 2197 |
+
state, start_idx, mask_dict, step,
|
| 2198 |
+
)
|
| 2199 |
+
seg_queue_out.put(
|
| 2200 |
+
(seg_idx, start_idx, mask_dict, segment_output)
|
| 2201 |
+
)
|
| 2202 |
+
except RuntimeError as e:
|
| 2203 |
+
if "cancelled" in str(e).lower():
|
| 2204 |
+
logging.info(
|
| 2205 |
+
"Segment worker %d cancelled.",
|
| 2206 |
+
gpu_idx,
|
| 2207 |
+
)
|
| 2208 |
+
break
|
| 2209 |
+
raise
|
| 2210 |
+
except Exception:
|
| 2211 |
+
logging.exception(
|
| 2212 |
+
"Segment %d failed on GPU %d",
|
| 2213 |
+
seg_idx, gpu_idx,
|
| 2214 |
+
)
|
| 2215 |
+
seg_queue_out.put(
|
| 2216 |
+
(seg_idx, start_idx, None, {})
|
| 2217 |
+
)
|
| 2218 |
+
|
| 2219 |
+
seg_workers = []
|
| 2220 |
+
for i in range(num_gpus):
|
| 2221 |
+
t = Thread(
|
| 2222 |
+
target=_segment_worker, args=(i,), daemon=True,
|
| 2223 |
+
)
|
| 2224 |
+
t.start()
|
| 2225 |
+
seg_workers.append(t)
|
| 2226 |
+
|
| 2227 |
+
# Phase 4: Streaming reconciliation — process segments in order
|
| 2228 |
+
# as they arrive, feeding render_in incrementally.
|
| 2229 |
+
if _perf_metrics is not None:
|
| 2230 |
+
_t_recon = time.perf_counter()
|
| 2231 |
+
|
| 2232 |
+
global_id_counter = 0
|
| 2233 |
+
sam2_masks = MaskDictionary()
|
| 2234 |
+
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 2235 |
+
|
| 2236 |
+
def _mask_to_cpu(mask):
|
| 2237 |
+
"""Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
|
| 2238 |
+
if torch.is_tensor(mask):
|
| 2239 |
+
return mask.detach().cpu()
|
| 2240 |
+
return mask
|
| 2241 |
+
|
| 2242 |
+
next_seg_idx = 0
|
| 2243 |
+
segment_buffer: Dict[int, Tuple] = {}
|
| 2244 |
+
|
| 2245 |
+
while next_seg_idx < num_total_segments:
|
| 2246 |
+
try:
|
| 2247 |
+
seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get(timeout=1.0)
|
| 2248 |
+
except Empty:
|
| 2249 |
+
if job_id:
|
| 2250 |
+
_check_cancellation(job_id)
|
| 2251 |
+
# Check if all segment workers are still alive
|
| 2252 |
+
if not any(t.is_alive() for t in seg_workers) and seg_queue_out.empty():
|
| 2253 |
+
logging.error(
|
| 2254 |
+
"All segment workers stopped while waiting for segment %d",
|
| 2255 |
+
next_seg_idx,
|
| 2256 |
+
)
|
| 2257 |
+
break
|
| 2258 |
+
continue
|
| 2259 |
+
segment_buffer[seg_idx] = (start_idx, mask_dict, segment_output)
|
| 2260 |
+
|
| 2261 |
+
# Process contiguous ready segments in order
|
| 2262 |
+
while next_seg_idx in segment_buffer:
|
| 2263 |
+
start_idx, mask_dict, segment_output = segment_buffer.pop(next_seg_idx)
|
| 2264 |
+
|
| 2265 |
+
if mask_dict is None or not mask_dict.labels:
|
| 2266 |
+
# No detections — carry forward previous masks
|
| 2267 |
+
for fi in range(
|
| 2268 |
+
start_idx, min(start_idx + step, total_frames)
|
| 2269 |
+
):
|
| 2270 |
+
if fi not in tracking_results:
|
| 2271 |
+
tracking_results[fi] = (
|
| 2272 |
+
{
|
| 2273 |
+
k: ObjectInfo(
|
| 2274 |
+
instance_id=v.instance_id,
|
| 2275 |
+
mask=v.mask,
|
| 2276 |
+
class_name=v.class_name,
|
| 2277 |
+
x1=v.x1, y1=v.y1,
|
| 2278 |
+
x2=v.x2, y2=v.y2,
|
| 2279 |
+
)
|
| 2280 |
+
for k, v in sam2_masks.labels.items()
|
| 2281 |
+
}
|
| 2282 |
+
if sam2_masks.labels
|
| 2283 |
+
else {}
|
| 2284 |
+
)
|
| 2285 |
+
render_in.put((fi, tracking_results.get(fi, {})))
|
| 2286 |
+
next_seg_idx += 1
|
| 2287 |
+
continue
|
| 2288 |
+
|
| 2289 |
+
# Normalize keyframe masks to CPU before cross-GPU IoU matching.
|
| 2290 |
+
for info in mask_dict.labels.values():
|
| 2291 |
+
info.mask = _mask_to_cpu(info.mask)
|
| 2292 |
+
|
| 2293 |
+
# IoU match + get local→global remapping
|
| 2294 |
+
global_id_counter, remapping = (
|
| 2295 |
+
mask_dict.update_masks_with_remapping(
|
| 2296 |
+
tracking_dict=sam2_masks,
|
| 2297 |
+
iou_threshold=iou_thresh,
|
| 2298 |
+
objects_count=global_id_counter,
|
| 2299 |
+
)
|
| 2300 |
+
)
|
| 2301 |
+
|
| 2302 |
+
if not mask_dict.labels:
|
| 2303 |
+
for fi in range(
|
| 2304 |
+
start_idx, min(start_idx + step, total_frames)
|
| 2305 |
+
):
|
| 2306 |
+
tracking_results[fi] = {}
|
| 2307 |
+
render_in.put((fi, {}))
|
| 2308 |
+
next_seg_idx += 1
|
| 2309 |
+
continue
|
| 2310 |
+
|
| 2311 |
+
# Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
|
| 2312 |
+
segment_results = segment_output.to_object_dicts()
|
| 2313 |
+
|
| 2314 |
+
# Apply remapping to every frame in this segment
|
| 2315 |
+
for frame_idx, frame_objects in segment_results.items():
|
| 2316 |
+
remapped: Dict[int, ObjectInfo] = {}
|
| 2317 |
+
for local_id, obj_info in frame_objects.items():
|
| 2318 |
+
global_id = remapping.get(local_id)
|
| 2319 |
+
if global_id is None:
|
| 2320 |
+
continue
|
| 2321 |
+
remapped[global_id] = ObjectInfo(
|
| 2322 |
+
instance_id=global_id,
|
| 2323 |
+
mask=obj_info.mask,
|
| 2324 |
+
class_name=obj_info.class_name,
|
| 2325 |
+
x1=obj_info.x1, y1=obj_info.y1,
|
| 2326 |
+
x2=obj_info.x2, y2=obj_info.y2,
|
| 2327 |
+
)
|
| 2328 |
+
tracking_results[frame_idx] = remapped
|
| 2329 |
+
|
| 2330 |
+
# Update running tracker with last frame of this segment
|
| 2331 |
+
if segment_results:
|
| 2332 |
+
last_fi = max(segment_results.keys())
|
| 2333 |
+
last_objs = tracking_results.get(last_fi, {})
|
| 2334 |
+
sam2_masks = MaskDictionary()
|
| 2335 |
+
sam2_masks.labels = copy.deepcopy(last_objs)
|
| 2336 |
+
if last_objs:
|
| 2337 |
+
first_info = next(iter(last_objs.values()))
|
| 2338 |
+
if first_info.mask is not None:
|
| 2339 |
+
m = first_info.mask
|
| 2340 |
+
sam2_masks.mask_height = (
|
| 2341 |
+
m.shape[-2] if m.ndim >= 2 else 0
|
| 2342 |
+
)
|
| 2343 |
+
sam2_masks.mask_width = (
|
| 2344 |
+
m.shape[-1] if m.ndim >= 2 else 0
|
| 2345 |
+
)
|
| 2346 |
+
|
| 2347 |
+
# Feed reconciled frames to render immediately
|
| 2348 |
+
for fi in range(start_idx, min(start_idx + step, total_frames)):
|
| 2349 |
+
render_in.put((fi, tracking_results.get(fi, {})))
|
| 2350 |
+
|
| 2351 |
+
next_seg_idx += 1
|
| 2352 |
+
|
| 2353 |
+
for t in seg_workers:
|
| 2354 |
+
t.join()
|
| 2355 |
+
|
| 2356 |
+
if _perf_metrics is not None:
|
| 2357 |
+
_perf_metrics["id_reconciliation_ms"] = (time.perf_counter() - _t_recon) * 1000.0
|
| 2358 |
+
_perf_metrics["tracking_total_ms"] = (time.perf_counter() - _t_track) * 1000.0
|
| 2359 |
+
|
| 2360 |
+
logging.info(
|
| 2361 |
+
"Multi-GPU reconciliation complete: %d frames, %d objects",
|
| 2362 |
+
len(tracking_results), global_id_counter,
|
| 2363 |
+
)
|
| 2364 |
+
finally:
|
| 2365 |
+
# Sentinels for render workers — always sent even on error/cancellation
|
| 2366 |
+
for _ in r_workers:
|
| 2367 |
+
try:
|
| 2368 |
+
render_in.put(None, timeout=5.0)
|
| 2369 |
+
except Full:
|
| 2370 |
+
pass
|
| 2371 |
|
| 2372 |
for t in r_workers:
|
| 2373 |
t.join()
|
|
@@ -13,7 +13,7 @@ import logging
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
-
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
@@ -673,6 +673,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 673 |
frame_dir: str,
|
| 674 |
frame_names: List[str],
|
| 675 |
text_prompts: List[str],
|
|
|
|
| 676 |
) -> Dict[int, Dict[int, ObjectInfo]]:
|
| 677 |
"""Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
|
| 678 |
|
|
@@ -680,6 +681,8 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 680 |
frame_dir: Directory containing JPEG frames.
|
| 681 |
frame_names: Sorted list of frame filenames.
|
| 682 |
text_prompts: Text queries for Grounding DINO.
|
|
|
|
|
|
|
| 683 |
|
| 684 |
Returns:
|
| 685 |
Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
|
|
@@ -764,6 +767,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 764 |
if input_boxes.shape[0] == 0:
|
| 765 |
logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
|
| 766 |
# Fill empty results for this segment
|
|
|
|
| 767 |
for fi in range(start_idx, min(start_idx + step, total_frames)):
|
| 768 |
if fi not in all_results:
|
| 769 |
# Carry forward last known masks
|
|
@@ -776,6 +780,9 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 776 |
)
|
| 777 |
for k, v in sam2_masks.labels.items()
|
| 778 |
} if sam2_masks.labels else {}
|
|
|
|
|
|
|
|
|
|
| 779 |
continue
|
| 780 |
|
| 781 |
# -- SAM2 image predictor on keyframe --
|
|
@@ -831,8 +838,12 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 831 |
_pm["id_reconciliation_ms"] += _d
|
| 832 |
|
| 833 |
if len(mask_dict.labels) == 0:
|
|
|
|
| 834 |
for fi in range(start_idx, min(start_idx + step, total_frames)):
|
| 835 |
all_results[fi] = {}
|
|
|
|
|
|
|
|
|
|
| 836 |
continue
|
| 837 |
|
| 838 |
# -- SAM2 video predictor: propagate masks --
|
|
@@ -846,6 +857,8 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 846 |
|
| 847 |
if segment_results:
|
| 848 |
all_results.update(segment_results)
|
|
|
|
|
|
|
| 849 |
last_fi = segment_output.last_frame_idx()
|
| 850 |
if last_fi is not None:
|
| 851 |
last_frame_objects = all_results.get(last_fi, {})
|
|
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
|
|
| 673 |
frame_dir: str,
|
| 674 |
frame_names: List[str],
|
| 675 |
text_prompts: List[str],
|
| 676 |
+
on_segment: Optional[Callable[[Dict[int, Dict[int, "ObjectInfo"]]], None]] = None,
|
| 677 |
) -> Dict[int, Dict[int, ObjectInfo]]:
|
| 678 |
"""Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
|
| 679 |
|
|
|
|
| 681 |
frame_dir: Directory containing JPEG frames.
|
| 682 |
frame_names: Sorted list of frame filenames.
|
| 683 |
text_prompts: Text queries for Grounding DINO.
|
| 684 |
+
on_segment: Optional callback invoked after each segment completes.
|
| 685 |
+
Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.
|
| 686 |
|
| 687 |
Returns:
|
| 688 |
Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
|
|
|
|
| 767 |
if input_boxes.shape[0] == 0:
|
| 768 |
logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
|
| 769 |
# Fill empty results for this segment
|
| 770 |
+
seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 771 |
for fi in range(start_idx, min(start_idx + step, total_frames)):
|
| 772 |
if fi not in all_results:
|
| 773 |
# Carry forward last known masks
|
|
|
|
| 780 |
)
|
| 781 |
for k, v in sam2_masks.labels.items()
|
| 782 |
} if sam2_masks.labels else {}
|
| 783 |
+
seg_results[fi] = all_results[fi]
|
| 784 |
+
if on_segment and seg_results:
|
| 785 |
+
on_segment(seg_results)
|
| 786 |
continue
|
| 787 |
|
| 788 |
# -- SAM2 image predictor on keyframe --
|
|
|
|
| 838 |
_pm["id_reconciliation_ms"] += _d
|
| 839 |
|
| 840 |
if len(mask_dict.labels) == 0:
|
| 841 |
+
seg_results_empty: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 842 |
for fi in range(start_idx, min(start_idx + step, total_frames)):
|
| 843 |
all_results[fi] = {}
|
| 844 |
+
seg_results_empty[fi] = {}
|
| 845 |
+
if on_segment:
|
| 846 |
+
on_segment(seg_results_empty)
|
| 847 |
continue
|
| 848 |
|
| 849 |
# -- SAM2 video predictor: propagate masks --
|
|
|
|
| 857 |
|
| 858 |
if segment_results:
|
| 859 |
all_results.update(segment_results)
|
| 860 |
+
if on_segment:
|
| 861 |
+
on_segment(segment_results)
|
| 862 |
last_fi = segment_output.last_frame_idx()
|
| 863 |
if last_fi is not None:
|
| 864 |
last_frame_objects = all_results.get(last_fi, {})
|