Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
c90fe44
1
Parent(s):
032b60f
Harden GSAM2 parallel pipeline and tracking reconciliation
Browse files- inference.py +418 -79
- models/segmenters/grounded_sam2.py +165 -2
inference.py
CHANGED
|
@@ -1586,6 +1586,63 @@ def run_segmentation(
|
|
| 1586 |
|
| 1587 |
|
| 1588 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1589 |
def run_grounded_sam2_tracking(
|
| 1590 |
input_video_path: str,
|
| 1591 |
output_video_path: str,
|
|
@@ -1598,14 +1655,16 @@ def run_grounded_sam2_tracking(
|
|
| 1598 |
) -> str:
|
| 1599 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1600 |
|
| 1601 |
-
|
| 1602 |
-
|
| 1603 |
-
renders the results back into a video.
|
| 1604 |
"""
|
|
|
|
| 1605 |
import shutil
|
|
|
|
|
|
|
| 1606 |
|
| 1607 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1608 |
-
from models.segmenters.
|
| 1609 |
|
| 1610 |
active_segmenter = segmenter_name or "gsam2_large"
|
| 1611 |
logging.info(
|
|
@@ -1622,92 +1681,372 @@ def run_grounded_sam2_tracking(
|
|
| 1622 |
total_frames = len(frame_names)
|
| 1623 |
logging.info("Extracted %d frames to %s", total_frames, frame_dir)
|
| 1624 |
|
| 1625 |
-
|
| 1626 |
-
segmenter = _load_seg(active_segmenter)
|
| 1627 |
|
| 1628 |
-
#
|
| 1629 |
-
|
| 1630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1631 |
|
| 1632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1633 |
_check_cancellation(job_id)
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
|
| 1638 |
-
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
|
| 1642 |
-
|
| 1643 |
-
if
|
| 1644 |
-
|
| 1645 |
-
|
| 1646 |
-
|
| 1647 |
-
|
| 1648 |
-
|
| 1649 |
-
|
| 1650 |
-
|
| 1651 |
-
|
| 1652 |
-
|
| 1653 |
-
|
| 1654 |
-
|
| 1655 |
-
|
| 1656 |
-
|
| 1657 |
-
|
| 1658 |
-
|
| 1659 |
-
|
| 1660 |
-
|
| 1661 |
-
mask_np = np.asarray(mask).astype(bool)
|
| 1662 |
-
# Resize mask if needed
|
| 1663 |
-
if mask_np.shape[:2] != (height, width):
|
| 1664 |
-
mask_np = cv2.resize(
|
| 1665 |
-
mask_np.astype(np.uint8),
|
| 1666 |
-
(width, height),
|
| 1667 |
-
interpolation=cv2.INTER_NEAREST,
|
| 1668 |
-
).astype(bool)
|
| 1669 |
-
masks_list.append(mask_np)
|
| 1670 |
-
|
| 1671 |
-
label = f"{obj_info.instance_id} {obj_info.class_name}"
|
| 1672 |
-
label_list.append(label)
|
| 1673 |
-
|
| 1674 |
-
has_box = not (obj_info.x1 == 0 and obj_info.y1 == 0 and obj_info.x2 == 0 and obj_info.y2 == 0)
|
| 1675 |
-
if has_box:
|
| 1676 |
-
boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
|
| 1677 |
-
|
| 1678 |
-
# Draw masks
|
| 1679 |
-
if masks_list:
|
| 1680 |
-
masks_array = np.stack(masks_list)
|
| 1681 |
-
frame = draw_masks(frame, masks_array, labels=label_list)
|
| 1682 |
-
|
| 1683 |
-
# Draw boxes
|
| 1684 |
-
if boxes_list:
|
| 1685 |
-
boxes_array = np.array(boxes_list)
|
| 1686 |
-
frame = draw_boxes(frame, boxes_array, label_names=label_list)
|
| 1687 |
-
|
| 1688 |
-
writer.write(frame)
|
| 1689 |
-
|
| 1690 |
-
# Stream frame if requested
|
| 1691 |
-
if stream_queue:
|
| 1692 |
try:
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
_pub(job_id, frame)
|
| 1696 |
-
else:
|
| 1697 |
-
stream_queue.put(frame, timeout=0.01)
|
| 1698 |
-
except Exception:
|
| 1699 |
pass
|
| 1700 |
|
| 1701 |
-
|
| 1702 |
-
|
| 1703 |
-
|
| 1704 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1705 |
|
| 1706 |
logging.info("Grounded-SAM-2 output written to: %s", output_video_path)
|
| 1707 |
return output_video_path
|
| 1708 |
|
| 1709 |
finally:
|
| 1710 |
-
# Cleanup temp frame directory
|
| 1711 |
try:
|
| 1712 |
shutil.rmtree(frame_dir)
|
| 1713 |
logging.info("Cleaned up temp frame dir: %s", frame_dir)
|
|
|
|
| 1586 |
|
| 1587 |
|
| 1588 |
|
| 1589 |
+
def _gsam2_render_frame(
    frame_dir: str,
    frame_names: List[str],
    frame_idx: int,
    frame_objects: Dict,
    height: int,
    width: int,
) -> np.ndarray:
    """Render a single GSAM2 tracking frame (masks + boxes). CPU-only.

    Args:
        frame_dir: Directory containing the extracted frame images.
        frame_names: Frame file names; ``frame_names[frame_idx]`` is read.
        frame_idx: Index of the frame to render.
        frame_objects: Mapping of object id to an object exposing ``mask``,
            ``instance_id``, ``class_name`` and ``x1``/``y1``/``x2``/``y2``.
            An empty or ``None`` mapping returns the frame untouched.
        height: Output frame height; masks of another size are resized to it.
        width: Output frame width; masks of another size are resized to it.

    Returns:
        The annotated frame, or an all-black ``(height, width, 3)`` uint8
        image when the frame file cannot be read.
    """
    frame_path = os.path.join(frame_dir, frame_names[frame_idx])
    frame = cv2.imread(frame_path)
    if frame is None:
        # cv2.imread signals an unreadable file with None; substitute a blank
        # frame so the caller still receives an image for this index.
        return np.zeros((height, width, 3), dtype=np.uint8)

    if not frame_objects:
        return frame

    # Masks and boxes keep separate label lists because an object may
    # contribute one without the other.
    masks_list: List[np.ndarray] = []
    mask_labels: List[str] = []
    boxes_list: List[List[int]] = []
    box_labels: List[str] = []

    for obj_info in frame_objects.values():
        label = f"{obj_info.instance_id} {obj_info.class_name}"

        mask = obj_info.mask
        if mask is not None:
            if isinstance(mask, torch.Tensor):
                # detach() keeps the conversion valid even if the tensor is
                # still attached to an autograd graph.
                mask_np = mask.detach().cpu().numpy().astype(bool)
            else:
                mask_np = np.asarray(mask).astype(bool)
            if mask_np.shape[:2] != (height, width):
                # Nearest-neighbour keeps the resized mask strictly binary.
                mask_np = cv2.resize(
                    mask_np.astype(np.uint8),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                ).astype(bool)
            masks_list.append(mask_np)
            mask_labels.append(label)

        # An all-zero box is treated as "no box available" and is not drawn.
        box = [obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2]
        if any(coord != 0 for coord in box):
            boxes_list.append(box)
            box_labels.append(label)

    if masks_list:
        frame = draw_masks(frame, np.stack(masks_list), labels=mask_labels)
    if boxes_list:
        frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)

    return frame
|
| 1644 |
+
|
| 1645 |
+
|
| 1646 |
def run_grounded_sam2_tracking(
|
| 1647 |
input_video_path: str,
|
| 1648 |
output_video_path: str,
|
|
|
|
| 1655 |
) -> str:
|
| 1656 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1657 |
|
| 1658 |
+
Uses multi-GPU data parallelism when multiple GPUs are available.
|
| 1659 |
+
Falls back to single-GPU ``process_video`` otherwise.
|
|
|
|
| 1660 |
"""
|
| 1661 |
+
import copy
|
| 1662 |
import shutil
|
| 1663 |
+
from contextlib import nullcontext
|
| 1664 |
+
from PIL import Image as PILImage
|
| 1665 |
|
| 1666 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1667 |
+
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo
|
| 1668 |
|
| 1669 |
active_segmenter = segmenter_name or "gsam2_large"
|
| 1670 |
logging.info(
|
|
|
|
| 1681 |
total_frames = len(frame_names)
|
| 1682 |
logging.info("Extracted %d frames to %s", total_frames, frame_dir)
|
| 1683 |
|
| 1684 |
+
num_gpus = torch.cuda.device_count()
|
|
|
|
| 1685 |
|
| 1686 |
+
# ==================================================================
|
| 1687 |
+
# Phase 1-4: Tracking (single-GPU fallback vs multi-GPU pipeline)
|
| 1688 |
+
# ==================================================================
|
| 1689 |
+
if num_gpus <= 1:
|
| 1690 |
+
# ---------- Single-GPU fallback ----------
|
| 1691 |
+
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1692 |
+
segmenter = load_segmenter_on_device(active_segmenter, device_str)
|
| 1693 |
+
_check_cancellation(job_id)
|
| 1694 |
+
tracking_results = segmenter.process_video(
|
| 1695 |
+
frame_dir, frame_names, queries,
|
| 1696 |
+
)
|
| 1697 |
+
logging.info(
|
| 1698 |
+
"Single-GPU tracking complete: %d frames",
|
| 1699 |
+
len(tracking_results),
|
| 1700 |
+
)
|
| 1701 |
+
else:
|
| 1702 |
+
# ---------- Multi-GPU pipeline ----------
|
| 1703 |
+
logging.info(
|
| 1704 |
+
"Multi-GPU GSAM2 tracking: %d GPUs, %d frames, step=%d",
|
| 1705 |
+
num_gpus, total_frames, step,
|
| 1706 |
+
)
|
| 1707 |
+
|
| 1708 |
+
# Phase 1: Load one segmenter per GPU (parallel)
|
| 1709 |
+
segmenters = []
|
| 1710 |
+
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 1711 |
+
futs = [
|
| 1712 |
+
pool.submit(
|
| 1713 |
+
load_segmenter_on_device,
|
| 1714 |
+
active_segmenter,
|
| 1715 |
+
f"cuda:{i}",
|
| 1716 |
+
)
|
| 1717 |
+
for i in range(num_gpus)
|
| 1718 |
+
]
|
| 1719 |
+
segmenters = [f.result() for f in futs]
|
| 1720 |
+
logging.info("Loaded %d segmenters", len(segmenters))
|
| 1721 |
+
|
| 1722 |
+
# Phase 2: Init SAM2 models/state per GPU (parallel)
|
| 1723 |
+
def _init_seg_state(seg):
|
| 1724 |
+
seg._ensure_models_loaded()
|
| 1725 |
+
return seg._video_predictor.init_state(
|
| 1726 |
+
video_path=frame_dir,
|
| 1727 |
+
offload_video_to_cpu=True,
|
| 1728 |
+
async_loading_frames=True,
|
| 1729 |
+
)
|
| 1730 |
+
|
| 1731 |
+
with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
|
| 1732 |
+
futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
|
| 1733 |
+
inference_states = [f.result() for f in futs]
|
| 1734 |
+
|
| 1735 |
+
# Phase 3: Parallel segment processing (queue-based workers)
|
| 1736 |
+
segments = list(range(0, total_frames, step))
|
| 1737 |
+
seg_queue_in: Queue = Queue()
|
| 1738 |
+
seg_queue_out: Queue = Queue()
|
| 1739 |
+
|
| 1740 |
+
for i, start_idx in enumerate(segments):
|
| 1741 |
+
seg_queue_in.put((i, start_idx))
|
| 1742 |
+
for _ in segmenters:
|
| 1743 |
+
seg_queue_in.put(None) # sentinel
|
| 1744 |
+
|
| 1745 |
+
iou_thresh = segmenters[0].iou_threshold
|
| 1746 |
+
|
| 1747 |
+
def _segment_worker(gpu_idx: int):
|
| 1748 |
+
seg = segmenters[gpu_idx]
|
| 1749 |
+
state = inference_states[gpu_idx]
|
| 1750 |
+
device_type = seg.device.split(":")[0]
|
| 1751 |
+
ac = (
|
| 1752 |
+
torch.autocast(device_type=device_type, dtype=torch.bfloat16)
|
| 1753 |
+
if device_type == "cuda"
|
| 1754 |
+
else nullcontext()
|
| 1755 |
+
)
|
| 1756 |
+
with ac:
|
| 1757 |
+
while True:
|
| 1758 |
+
if job_id:
|
| 1759 |
+
try:
|
| 1760 |
+
_check_cancellation(job_id)
|
| 1761 |
+
except RuntimeError as e:
|
| 1762 |
+
if "cancelled" in str(e).lower():
|
| 1763 |
+
logging.info(
|
| 1764 |
+
"Segment worker %d cancelled.",
|
| 1765 |
+
gpu_idx,
|
| 1766 |
+
)
|
| 1767 |
+
break
|
| 1768 |
+
raise
|
| 1769 |
+
item = seg_queue_in.get()
|
| 1770 |
+
if item is None:
|
| 1771 |
+
break
|
| 1772 |
+
seg_idx, start_idx = item
|
| 1773 |
+
try:
|
| 1774 |
+
logging.info(
|
| 1775 |
+
"GPU %d processing segment %d (frame %d)",
|
| 1776 |
+
gpu_idx, seg_idx, start_idx,
|
| 1777 |
+
)
|
| 1778 |
+
img_path = os.path.join(
|
| 1779 |
+
frame_dir, frame_names[start_idx]
|
| 1780 |
+
)
|
| 1781 |
+
with PILImage.open(img_path) as pil_img:
|
| 1782 |
+
image = pil_img.convert("RGB")
|
| 1783 |
+
|
| 1784 |
+
if job_id:
|
| 1785 |
+
_check_cancellation(job_id)
|
| 1786 |
+
masks, boxes, labels = seg.detect_keyframe(
|
| 1787 |
+
image, queries,
|
| 1788 |
+
)
|
| 1789 |
+
|
| 1790 |
+
if masks is None:
|
| 1791 |
+
seg_queue_out.put(
|
| 1792 |
+
(seg_idx, start_idx, None, {})
|
| 1793 |
+
)
|
| 1794 |
+
continue
|
| 1795 |
+
|
| 1796 |
+
mask_dict = MaskDictionary()
|
| 1797 |
+
mask_dict.add_new_frame_annotation(
|
| 1798 |
+
mask_list=torch.tensor(masks).to(seg.device),
|
| 1799 |
+
box_list=(
|
| 1800 |
+
boxes.clone()
|
| 1801 |
+
if torch.is_tensor(boxes)
|
| 1802 |
+
else torch.tensor(boxes)
|
| 1803 |
+
),
|
| 1804 |
+
label_list=labels,
|
| 1805 |
+
)
|
| 1806 |
+
|
| 1807 |
+
segment_results = seg.propagate_segment(
|
| 1808 |
+
state, start_idx, mask_dict, step,
|
| 1809 |
+
)
|
| 1810 |
+
seg_queue_out.put(
|
| 1811 |
+
(seg_idx, start_idx, mask_dict, segment_results)
|
| 1812 |
+
)
|
| 1813 |
+
except RuntimeError as e:
|
| 1814 |
+
if "cancelled" in str(e).lower():
|
| 1815 |
+
logging.info(
|
| 1816 |
+
"Segment worker %d cancelled.",
|
| 1817 |
+
gpu_idx,
|
| 1818 |
+
)
|
| 1819 |
+
break
|
| 1820 |
+
raise
|
| 1821 |
+
except Exception:
|
| 1822 |
+
logging.exception(
|
| 1823 |
+
"Segment %d failed on GPU %d",
|
| 1824 |
+
seg_idx, gpu_idx,
|
| 1825 |
+
)
|
| 1826 |
+
seg_queue_out.put(
|
| 1827 |
+
(seg_idx, start_idx, None, {})
|
| 1828 |
+
)
|
| 1829 |
|
| 1830 |
+
seg_workers = []
|
| 1831 |
+
for i in range(num_gpus):
|
| 1832 |
+
t = Thread(
|
| 1833 |
+
target=_segment_worker, args=(i,), daemon=True,
|
| 1834 |
+
)
|
| 1835 |
+
t.start()
|
| 1836 |
+
seg_workers.append(t)
|
| 1837 |
+
|
| 1838 |
+
for t in seg_workers:
|
| 1839 |
+
t.join()
|
| 1840 |
+
|
| 1841 |
+
# Collect all segment outputs
|
| 1842 |
+
segment_data: Dict[int, Tuple] = {}
|
| 1843 |
+
while not seg_queue_out.empty():
|
| 1844 |
+
seg_idx, start_idx, mask_dict, results = seg_queue_out.get()
|
| 1845 |
+
segment_data[seg_idx] = (start_idx, mask_dict, results)
|
| 1846 |
+
|
| 1847 |
+
# Phase 4: Sequential ID reconciliation
|
| 1848 |
+
global_id_counter = 0
|
| 1849 |
+
sam2_masks = MaskDictionary()
|
| 1850 |
+
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1851 |
+
|
| 1852 |
+
for seg_idx in sorted(segment_data.keys()):
|
| 1853 |
+
start_idx, mask_dict, segment_results = segment_data[seg_idx]
|
| 1854 |
+
|
| 1855 |
+
if mask_dict is None or not mask_dict.labels:
|
| 1856 |
+
# No detections — carry forward previous masks
|
| 1857 |
+
for fi in range(
|
| 1858 |
+
start_idx, min(start_idx + step, total_frames)
|
| 1859 |
+
):
|
| 1860 |
+
if fi not in tracking_results:
|
| 1861 |
+
tracking_results[fi] = (
|
| 1862 |
+
{
|
| 1863 |
+
k: ObjectInfo(
|
| 1864 |
+
instance_id=v.instance_id,
|
| 1865 |
+
mask=v.mask,
|
| 1866 |
+
class_name=v.class_name,
|
| 1867 |
+
x1=v.x1, y1=v.y1,
|
| 1868 |
+
x2=v.x2, y2=v.y2,
|
| 1869 |
+
)
|
| 1870 |
+
for k, v in sam2_masks.labels.items()
|
| 1871 |
+
}
|
| 1872 |
+
if sam2_masks.labels
|
| 1873 |
+
else {}
|
| 1874 |
+
)
|
| 1875 |
+
continue
|
| 1876 |
+
|
| 1877 |
+
# IoU match + get local→global remapping
|
| 1878 |
+
global_id_counter, remapping = (
|
| 1879 |
+
mask_dict.update_masks_with_remapping(
|
| 1880 |
+
tracking_dict=sam2_masks,
|
| 1881 |
+
iou_threshold=iou_thresh,
|
| 1882 |
+
objects_count=global_id_counter,
|
| 1883 |
+
)
|
| 1884 |
+
)
|
| 1885 |
+
|
| 1886 |
+
if not mask_dict.labels:
|
| 1887 |
+
for fi in range(
|
| 1888 |
+
start_idx, min(start_idx + step, total_frames)
|
| 1889 |
+
):
|
| 1890 |
+
tracking_results[fi] = {}
|
| 1891 |
+
continue
|
| 1892 |
+
|
| 1893 |
+
# Apply remapping to every frame in this segment
|
| 1894 |
+
for frame_idx, frame_objects in segment_results.items():
|
| 1895 |
+
remapped: Dict[int, ObjectInfo] = {}
|
| 1896 |
+
for local_id, obj_info in frame_objects.items():
|
| 1897 |
+
global_id = remapping.get(local_id)
|
| 1898 |
+
if global_id is None:
|
| 1899 |
+
continue
|
| 1900 |
+
remapped[global_id] = ObjectInfo(
|
| 1901 |
+
instance_id=global_id,
|
| 1902 |
+
mask=obj_info.mask,
|
| 1903 |
+
class_name=obj_info.class_name,
|
| 1904 |
+
x1=obj_info.x1, y1=obj_info.y1,
|
| 1905 |
+
x2=obj_info.x2, y2=obj_info.y2,
|
| 1906 |
+
)
|
| 1907 |
+
tracking_results[frame_idx] = remapped
|
| 1908 |
+
|
| 1909 |
+
# Update running tracker with last frame of this segment
|
| 1910 |
+
if segment_results:
|
| 1911 |
+
last_fi = max(segment_results.keys())
|
| 1912 |
+
last_objs = tracking_results.get(last_fi, {})
|
| 1913 |
+
sam2_masks = MaskDictionary()
|
| 1914 |
+
sam2_masks.labels = copy.deepcopy(last_objs)
|
| 1915 |
+
if last_objs:
|
| 1916 |
+
first_info = next(iter(last_objs.values()))
|
| 1917 |
+
if first_info.mask is not None:
|
| 1918 |
+
m = first_info.mask
|
| 1919 |
+
sam2_masks.mask_height = (
|
| 1920 |
+
m.shape[-2] if m.ndim >= 2 else 0
|
| 1921 |
+
)
|
| 1922 |
+
sam2_masks.mask_width = (
|
| 1923 |
+
m.shape[-1] if m.ndim >= 2 else 0
|
| 1924 |
+
)
|
| 1925 |
+
|
| 1926 |
+
logging.info(
|
| 1927 |
+
"Multi-GPU reconciliation complete: %d frames, %d objects",
|
| 1928 |
+
len(tracking_results), global_id_counter,
|
| 1929 |
+
)
|
| 1930 |
+
|
| 1931 |
+
# ==================================================================
|
| 1932 |
+
# Phase 5: Parallel rendering + sequential video writing
|
| 1933 |
+
# ==================================================================
|
| 1934 |
_check_cancellation(job_id)
|
| 1935 |
+
|
| 1936 |
+
render_in: Queue = Queue(maxsize=32)
|
| 1937 |
+
render_out: Queue = Queue(maxsize=64)
|
| 1938 |
+
render_done = False
|
| 1939 |
+
num_render_workers = min(4, os.cpu_count() or 1)
|
| 1940 |
+
|
| 1941 |
+
def _render_worker():
|
| 1942 |
+
while True:
|
| 1943 |
+
item = render_in.get()
|
| 1944 |
+
if item is None:
|
| 1945 |
+
break
|
| 1946 |
+
fidx, fobjs = item
|
| 1947 |
+
try:
|
| 1948 |
+
frm = _gsam2_render_frame(
|
| 1949 |
+
frame_dir, frame_names, fidx, fobjs,
|
| 1950 |
+
height, width,
|
| 1951 |
+
)
|
| 1952 |
+
while True:
|
| 1953 |
+
try:
|
| 1954 |
+
render_out.put((fidx, frm), timeout=1.0)
|
| 1955 |
+
break
|
| 1956 |
+
except Full:
|
| 1957 |
+
if render_done:
|
| 1958 |
+
return
|
| 1959 |
+
except Exception:
|
| 1960 |
+
logging.exception("Render failed for frame %d", fidx)
|
| 1961 |
+
blank = np.zeros((height, width, 3), dtype=np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1962 |
try:
|
| 1963 |
+
render_out.put((fidx, blank), timeout=5.0)
|
| 1964 |
+
except Full:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1965 |
pass
|
| 1966 |
|
| 1967 |
+
r_workers = [
|
| 1968 |
+
Thread(target=_render_worker, daemon=True)
|
| 1969 |
+
for _ in range(num_render_workers)
|
| 1970 |
+
]
|
| 1971 |
+
for t in r_workers:
|
| 1972 |
+
t.start()
|
| 1973 |
+
|
| 1974 |
+
def _writer_loop():
|
| 1975 |
+
nonlocal render_done
|
| 1976 |
+
next_idx = 0
|
| 1977 |
+
buf: Dict[int, np.ndarray] = {}
|
| 1978 |
+
try:
|
| 1979 |
+
with StreamingVideoWriter(
|
| 1980 |
+
output_video_path, fps, width, height
|
| 1981 |
+
) as writer:
|
| 1982 |
+
while next_idx < total_frames:
|
| 1983 |
+
try:
|
| 1984 |
+
while next_idx not in buf:
|
| 1985 |
+
if len(buf) > 128:
|
| 1986 |
+
logging.warning(
|
| 1987 |
+
"Render reorder buffer large (%d), "
|
| 1988 |
+
"waiting for frame %d",
|
| 1989 |
+
len(buf), next_idx,
|
| 1990 |
+
)
|
| 1991 |
+
time.sleep(0.05)
|
| 1992 |
+
idx, frm = render_out.get(timeout=1.0)
|
| 1993 |
+
buf[idx] = frm
|
| 1994 |
+
|
| 1995 |
+
frm = buf.pop(next_idx)
|
| 1996 |
+
writer.write(frm)
|
| 1997 |
+
|
| 1998 |
+
if stream_queue:
|
| 1999 |
+
try:
|
| 2000 |
+
from jobs.streaming import (
|
| 2001 |
+
publish_frame as _pub,
|
| 2002 |
+
)
|
| 2003 |
+
if job_id:
|
| 2004 |
+
_pub(job_id, frm)
|
| 2005 |
+
else:
|
| 2006 |
+
stream_queue.put(frm, timeout=0.01)
|
| 2007 |
+
except Exception:
|
| 2008 |
+
pass
|
| 2009 |
+
|
| 2010 |
+
next_idx += 1
|
| 2011 |
+
if next_idx % 30 == 0:
|
| 2012 |
+
logging.info(
|
| 2013 |
+
"Rendered frame %d / %d",
|
| 2014 |
+
next_idx, total_frames,
|
| 2015 |
+
)
|
| 2016 |
+
except Empty:
|
| 2017 |
+
if job_id:
|
| 2018 |
+
_check_cancellation(job_id)
|
| 2019 |
+
if not any(t.is_alive() for t in r_workers) and render_out.empty():
|
| 2020 |
+
logging.error(
|
| 2021 |
+
"Render workers stopped while waiting "
|
| 2022 |
+
"for frame %d", next_idx,
|
| 2023 |
+
)
|
| 2024 |
+
break
|
| 2025 |
+
continue
|
| 2026 |
+
finally:
|
| 2027 |
+
render_done = True
|
| 2028 |
+
|
| 2029 |
+
writer_thread = Thread(target=_writer_loop, daemon=True)
|
| 2030 |
+
writer_thread.start()
|
| 2031 |
+
|
| 2032 |
+
# Feed render queue
|
| 2033 |
+
for fidx in range(total_frames):
|
| 2034 |
+
_check_cancellation(job_id)
|
| 2035 |
+
fobjs = tracking_results.get(fidx, {})
|
| 2036 |
+
render_in.put((fidx, fobjs))
|
| 2037 |
+
|
| 2038 |
+
# Sentinels for render workers
|
| 2039 |
+
for _ in r_workers:
|
| 2040 |
+
render_in.put(None)
|
| 2041 |
+
|
| 2042 |
+
for t in r_workers:
|
| 2043 |
+
t.join()
|
| 2044 |
+
writer_thread.join()
|
| 2045 |
|
| 2046 |
logging.info("Grounded-SAM-2 output written to: %s", output_video_path)
|
| 2047 |
return output_video_path
|
| 2048 |
|
| 2049 |
finally:
|
|
|
|
| 2050 |
try:
|
| 2051 |
shutil.rmtree(frame_dir)
|
| 2052 |
logging.info("Cleaned up temp frame dir: %s", frame_dir)
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -90,18 +90,24 @@ class MaskDictionary:
|
|
| 90 |
) -> int:
|
| 91 |
"""Match current detections against tracked objects via IoU."""
|
| 92 |
updated = {}
|
|
|
|
| 93 |
for _seg_id, seg_info in self.labels.items():
|
| 94 |
if seg_info.mask is None or seg_info.mask.sum() == 0:
|
| 95 |
continue
|
| 96 |
matched_id = 0
|
|
|
|
| 97 |
for _obj_id, obj_info in tracking_dict.labels.items():
|
|
|
|
|
|
|
| 98 |
iou = self._iou(seg_info.mask, obj_info.mask)
|
| 99 |
-
if iou >
|
|
|
|
| 100 |
matched_id = obj_info.instance_id
|
| 101 |
-
break
|
| 102 |
if not matched_id:
|
| 103 |
objects_count += 1
|
| 104 |
matched_id = objects_count
|
|
|
|
|
|
|
| 105 |
new_info = ObjectInfo(
|
| 106 |
instance_id=matched_id,
|
| 107 |
mask=seg_info.mask,
|
|
@@ -111,6 +117,47 @@ class MaskDictionary:
|
|
| 111 |
self.labels = updated
|
| 112 |
return objects_count
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
def get_target_class_name(self, instance_id: int) -> str:
|
| 115 |
info = self.labels.get(instance_id)
|
| 116 |
return info.class_name if info else ""
|
|
@@ -277,6 +324,122 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 277 |
boxes=det.boxes,
|
| 278 |
)
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
# -- Video-level tracking interface -------------------------------------
|
| 281 |
|
| 282 |
def process_video(
|
|
|
|
| 90 |
) -> int:
|
| 91 |
"""Match current detections against tracked objects via IoU."""
|
| 92 |
updated = {}
|
| 93 |
+
used_tracked_ids = set()
|
| 94 |
for _seg_id, seg_info in self.labels.items():
|
| 95 |
if seg_info.mask is None or seg_info.mask.sum() == 0:
|
| 96 |
continue
|
| 97 |
matched_id = 0
|
| 98 |
+
best_iou = iou_threshold
|
| 99 |
for _obj_id, obj_info in tracking_dict.labels.items():
|
| 100 |
+
if obj_info.instance_id in used_tracked_ids:
|
| 101 |
+
continue
|
| 102 |
iou = self._iou(seg_info.mask, obj_info.mask)
|
| 103 |
+
if iou > best_iou:
|
| 104 |
+
best_iou = iou
|
| 105 |
matched_id = obj_info.instance_id
|
|
|
|
| 106 |
if not matched_id:
|
| 107 |
objects_count += 1
|
| 108 |
matched_id = objects_count
|
| 109 |
+
else:
|
| 110 |
+
used_tracked_ids.add(matched_id)
|
| 111 |
new_info = ObjectInfo(
|
| 112 |
instance_id=matched_id,
|
| 113 |
mask=seg_info.mask,
|
|
|
|
| 117 |
self.labels = updated
|
| 118 |
return objects_count
|
| 119 |
|
| 120 |
+
def update_masks_with_remapping(
    self,
    tracking_dict: "MaskDictionary",
    iou_threshold: float = 0.5,
    objects_count: int = 0,
) -> Tuple[int, Dict[int, int]]:
    """Match detections against tracked objects and report the ID remapping.

    Each entry of ``self.labels`` with a non-empty mask is compared against
    the entries of ``tracking_dict.labels`` that have not been claimed yet.
    The tracked entry with the highest IoU strictly above ``iou_threshold``
    donates its ``instance_id``; otherwise ``objects_count`` is incremented
    and used as a fresh ID. ``self.labels`` is replaced by the re-keyed
    entries.

    Args:
        tracking_dict: Dictionary holding the already-tracked objects.
        iou_threshold: IoU a candidate must exceed to count as a match.
        objects_count: Running ID counter; new IDs continue from it.

    Returns:
        ``(objects_count, remapping)`` where *remapping* maps each original
        (local) key of ``self.labels`` to its assigned (global) ID. Entries
        skipped for having no mask or an empty mask do not appear in it.
    """
    relabelled = {}
    local_to_global: Dict[int, int] = {}
    claimed_ids = set()

    for local_id, detection in self.labels.items():
        # Detections without any foreground pixels cannot be matched.
        if detection.mask is None or detection.mask.sum() == 0:
            continue

        assigned_id = 0
        top_iou = iou_threshold
        for tracked in tracking_dict.labels.values():
            # A tracked ID may be handed to at most one detection.
            if tracked.instance_id in claimed_ids:
                continue
            overlap = self._iou(detection.mask, tracked.mask)
            if overlap > top_iou:
                top_iou = overlap
                assigned_id = tracked.instance_id

        if assigned_id:
            claimed_ids.add(assigned_id)
        else:
            # No tracked object overlapped enough: allocate a new ID.
            objects_count += 1
            assigned_id = objects_count

        relabelled[assigned_id] = ObjectInfo(
            instance_id=assigned_id,
            mask=detection.mask,
            class_name=detection.class_name,
        )
        local_to_global[local_id] = assigned_id

    self.labels = relabelled
    return objects_count, local_to_global
|
| 160 |
+
|
| 161 |
def get_target_class_name(self, instance_id: int) -> str:
    """Return the class name stored for ``instance_id``.

    Returns an empty string when ``self.labels`` has no (truthy) entry for
    that ID.
    """
    entry = self.labels.get(instance_id)
    if not entry:
        return ""
    return entry.class_name
|
|
|
|
| 324 |
boxes=det.boxes,
|
| 325 |
)
|
| 326 |
|
| 327 |
+
# -- Multi-GPU helper methods -------------------------------------------
|
| 328 |
+
|
| 329 |
+
def detect_keyframe(
    self,
    image: "Image",
    text_prompts: List[str],
) -> Tuple[Optional[np.ndarray], Optional[torch.Tensor], List[str]]:
    """Detect and segment the prompted objects on one keyframe.

    Runs the Grounding DINO detector on ``image`` and feeds the resulting
    boxes to the SAM2 image predictor.

    Args:
        image: PIL Image in RGB mode.
        text_prompts: Text queries for Grounding DINO.

    Returns:
        ``(masks, boxes, labels)``. *masks* is whatever the image predictor
        returned, normalised to three dimensions; *boxes* is the detector's
        box output; *labels* is a list of strings. Returns
        ``(None, None, [])`` when the detector yields no boxes.
    """
    self._ensure_models_loaded()

    detector = self._gdino_detector
    prompt = detector._build_prompt(text_prompts)
    processor = detector.processor
    model = detector.model

    encoded = processor(images=image, text=prompt, return_tensors="pt")
    # Move every processor output onto this segmenter's device.
    encoded = {name: value.to(self.device) for name, value in encoded.items()}

    with torch.no_grad():
        raw_outputs = model(**encoded)

    # image.size is reversed to build the target size; for a PIL image
    # (width, height) that gives (height, width).
    detections = detector._post_process(
        raw_outputs,
        encoded["input_ids"],
        target_sizes=[image.size[::-1]],
    )
    first = detections[0]

    input_boxes = first["boxes"]
    # Prefer "text_labels"; fall back to "labels" when it is missing or empty.
    raw_labels = first.get("text_labels") or first.get("labels", [])
    if torch.is_tensor(raw_labels):
        raw_labels = raw_labels.detach().cpu().tolist()
    det_labels = [str(item) for item in raw_labels]

    if input_boxes.shape[0] == 0:
        return None, None, []

    # SAM2 image predictor, prompted with the detected boxes only.
    predictor = self._image_predictor
    predictor.set_image(np.array(image))
    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Normalise to three dimensions: add a leading axis to a 2-D result,
    # drop axis 1 of a 4-D result.
    if masks.ndim == 4:
        masks = masks.squeeze(1)
    elif masks.ndim == 2:
        masks = masks[None]

    return masks, input_boxes, det_labels
|
| 390 |
+
|
| 391 |
+
def propagate_segment(
    self,
    inference_state: Any,
    start_idx: int,
    mask_dict: "MaskDictionary",
    step: int,
) -> Dict[int, Dict[int, "ObjectInfo"]]:
    """Propagate keyframe masks through one segment with the video predictor.

    The predictor state is reset before anything else, so each call starts
    from a clean state and does not depend on earlier calls that used the
    same ``inference_state``.

    Args:
        inference_state: SAM2 video predictor state (from ``init_state``).
        start_idx: Frame index of the keyframe where the masks are seeded.
        mask_dict: MaskDictionary with the object masks for the keyframe.
        step: Passed to the predictor as ``max_frame_num_to_track``.

    Returns:
        Dict mapping ``frame_idx`` to ``{obj_id: ObjectInfo}`` keyed by the
        IDs from *mask_dict* (local, not yet reconciled).
    """
    predictor = self._video_predictor
    predictor.reset_state(inference_state)

    # Seed the predictor with every keyframe mask.
    for seed_id, seed_info in mask_dict.labels.items():
        predictor.add_new_mask(
            inference_state,
            start_idx,
            seed_id,
            seed_info.mask,
        )

    per_frame: Dict[int, Dict[int, ObjectInfo]] = {}
    propagation = predictor.propagate_in_video(
        inference_state,
        max_frame_num_to_track=step,
        start_frame_idx=start_idx,
    )
    for frame_no, object_ids, mask_logits in propagation:
        objects_here: Dict[int, ObjectInfo] = {}
        for position, object_id in enumerate(object_ids):
            # Threshold the logits at 0.0 to get a boolean mask, then take
            # index 0 of its leading axis (presumably a singleton channel
            # dimension - TODO confirm).
            binary_mask = mask_logits[position] > 0.0
            tracked = ObjectInfo(
                instance_id=object_id,
                mask=binary_mask[0],
                class_name=mask_dict.get_target_class_name(object_id),
            )
            tracked.update_box()
            objects_here[object_id] = tracked
        per_frame[frame_no] = objects_here

    return per_frame
|
| 442 |
+
|
| 443 |
# -- Video-level tracking interface -------------------------------------
|
| 444 |
|
| 445 |
def process_video(
|