# SAM3-video-segmentation-tracking / parallel_segment_worker.py
# bellmake's picture
# SAM3 Video Segmentation - Clean deployment
# ae50268
"""
Worker entry for the parallel multi-GPU Auto-Mode dispatcher.
This module is intentionally lightweight at top level — it must NOT import
torch (or anything that imports torch) before `worker_main()` has a chance to
narrow ``CUDA_VISIBLE_DEVICES`` to a single device. ``mp.spawn`` re-imports the
target module in each child process; importing torch at top level here would
cause CUDA to initialize against all visible devices before we can pin the
worker to one GPU.
Flow:
1. Parent dispatcher in app.py spawns one of these per concurrent video.
2. ``worker_main(gpu_index, args, progress_queue)`` is the spawn target.
3. It sets ``CUDA_VISIBLE_DEVICES`` to ``str(gpu_index)`` BEFORE importing
torch, so the child sees only one device (always referenced as cuda:0).
4. It then imports ``app._segment_video_core`` and runs the segmentation,
streaming progress / status / result / error messages back to the parent
via ``progress_queue``. Each message carries ``gpu_index`` so the
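dispatcher can route it to the right per-video UI slot.
5. ``overlay_worker_main(task_id, args, event_queue)`` is a second, CPU-only
spawn target that renders the segmentation / trails overlays after
segmentation finishes and reports a single ``overlay_done`` or
``overlay_error`` message keyed by ``task_id``.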
"""
import os
import pathlib
import shutil
import sys
import tempfile
import traceback
import uuid
_DISTRIBUTED_ENV_KEYS = (
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_WORLD_SIZE",
"GROUP_RANK",
"GROUP_WORLD_SIZE",
"ROLE_RANK",
"ROLE_WORLD_SIZE",
"MASTER_ADDR",
"MASTER_PORT",
"TORCHELASTIC_RUN_ID",
"TORCHELASTIC_RESTART_COUNT",
"TORCHELASTIC_MAX_RESTARTS",
)
def _truthy(value):
return str(value).strip().lower() not in {"0", "false", "no", "off"}
def _force_single_rank_env(gpu_index=None, cpu_only=False):
    """Rewrite ``os.environ`` so this process presents as a lone rank-0 worker.

    Args:
        gpu_index: Device index exported as ``CUDA_VISIBLE_DEVICES`` (with
            ``CUDA_DEVICE_ORDER`` defaulting to ``PCI_BUS_ID``). ``None``
            leaves any inherited ``CUDA_VISIBLE_DEVICES`` untouched.
        cpu_only: When true, ``CUDA_VISIBLE_DEVICES`` becomes ``""`` so no GPU
            is visible; this takes precedence over ``gpu_index``.

    Every key in ``_DISTRIBUTED_ENV_KEYS`` is removed, then a one-process
    world is declared (RANK / LOCAL_RANK = 0, WORLD_SIZE / LOCAL_WORLD_SIZE = 1)
    and ``SAM3_WORKER_MODE`` is set to ``"1"``. A few SAM3 / Python knobs get
    defaults only if the parent did not export them, and ``.pyc`` writing is
    disabled. Must run before torch is imported for the CUDA pinning to apply.
    """
    env = os.environ
    env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

    if cpu_only:
        visible = ""
    elif gpu_index is not None:
        visible = str(gpu_index)
    else:
        visible = None
    if visible is not None:
        env["CUDA_VISIBLE_DEVICES"] = visible

    # Drop inherited launcher state first, then pin a single-rank world.
    for stale in _DISTRIBUTED_ENV_KEYS:
        env.pop(stale, None)
    env.update(
        RANK="0",
        WORLD_SIZE="1",
        LOCAL_RANK="0",
        LOCAL_WORLD_SIZE="1",
        SAM3_WORKER_MODE="1",
    )

    # Soft defaults: values the parent already exported win.
    for name, fallback in (
        ("SAM3_CACHE_FRAME_OUTPUTS", "0"),
        ("SAM3_OFFLOAD_TRACKER_STATE_TO_CPU", "1"),
        ("PYTHONDONTWRITEBYTECODE", "1"),
    ):
        env.setdefault(name, fallback)
    sys.dont_write_bytecode = True
def _copy_runtime_item(source_dir, runtime_dir, name):
src = pathlib.Path(source_dir) / name
if not src.exists():
fallback = pathlib.Path(__file__).resolve().parent / name
src = fallback if fallback.exists() else src
if not src.exists():
return
dst = pathlib.Path(runtime_dir) / name
if src.is_dir():
shutil.copytree(
src,
dst,
symlinks=False,
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".git"),
)
else:
shutil.copy2(src, dst)
def _prepare_runtime(worker_options, tag):
    """
    Resolve (and optionally build) the directory this worker imports ``app`` from.

    ``worker_options`` is an optional dict. The keys read here are:

    * ``source_app_dir``: defaults to this module's directory.
    * ``isolate_runtime``: parsed with ``_truthy``; defaults to on.

    ``tag`` is embedded in the temp-dir prefix so the copy is identifiable.

    ``SAM3_OUTPUT_ROOT`` is always set to the source directory. Presumably this
    makes persisted outputs resolve against the original checkout rather than
    the throw-away copy; confirm against app.py.

    Isolation runs when it is enabled and ``app.py`` exists in the source
    directory. In that case the following are copied into a fresh
    ``tempfile.mkdtemp`` directory, and ``SAM3_ISOLATED_RUNTIME_DIR`` is
    pointed at it:

    * ``app.py``
    * ``parallel_segment_worker.py``
    * ``sam3/``
    * ``assets/``

    Items missing from the source directory fall back to this module's
    directory; items missing everywhere are skipped. Model weights are not
    copied (they are expected to come from the HF cache). What differs per
    worker is Python module globals, ``__pycache__`` and source imports.
    Result temp files stay in the normal system temp area so the parent
    process can persist them after the worker exits.

    Isolation is skipped when ``isolate_runtime`` is a falsy flag. The
    dispatcher presumably derives it from ``SAM3_PARALLEL_COPY_RUNTIME=0``;
    confirm. Skip it when startup latency matters more than isolation.

    Returns:
        ``(app_root, runtime_dir)``. ``app_root`` is the directory to import
        from. ``runtime_dir`` is the temp directory the caller must delete
        afterwards, or ``None`` when nothing was copied.

    Raises:
        Whatever the copy raises (including KeyboardInterrupt / SystemExit);
        the partially built temp directory is removed before re-raising.
    """
    worker_options = worker_options or {}
    source_dir = pathlib.Path(
        worker_options.get("source_app_dir") or pathlib.Path(__file__).resolve().parent
    ).resolve()
    os.environ["SAM3_OUTPUT_ROOT"] = str(source_dir)

    if not _truthy(worker_options.get("isolate_runtime", "1")):
        return str(source_dir), None
    if not (source_dir / "app.py").exists():
        # Nothing worth isolating; import straight from the source tree.
        return str(source_dir), None

    runtime_dir = pathlib.Path(
        tempfile.mkdtemp(prefix=f"sam3_{tag}_{uuid.uuid4().hex[:8]}_")
    ).resolve()
    try:
        for item in ("app.py", "parallel_segment_worker.py", "sam3", "assets"):
            _copy_runtime_item(source_dir, runtime_dir, item)
        os.environ["SAM3_ISOLATED_RUNTIME_DIR"] = str(runtime_dir)
    except BaseException:
        # BaseException, not Exception: a Ctrl-C / SystemExit in the middle
        # of a large copytree would otherwise orphan a partial temp runtime.
        shutil.rmtree(runtime_dir, ignore_errors=True)
        raise
    return str(runtime_dir), str(runtime_dir)
def _prepend_import_paths(*paths):
for path in reversed([p for p in paths if p]):
if path in sys.path:
sys.path.remove(path)
sys.path.insert(0, path)
def _cleanup_runtime(runtime_dir):
if runtime_dir:
shutil.rmtree(runtime_dir, ignore_errors=True)
def _put_worker_error(progress_queue, gpu_index, message):
    """Push an ``error`` event for ``gpu_index`` onto ``progress_queue``.

    Call only from inside an ``except`` block: ``traceback.format_exc()``
    captures the exception currently being handled.
    """
    progress_queue.put({
        "type": "error",
        "message": message,
        "traceback": traceback.format_exc(),
        "gpu_index": gpu_index,
    })


def _run_segmentation_job(gpu_index, args, progress_queue, app_root):
    """Import torch and ``app`` from ``app_root``, then run one segmentation.

    Every failure is reported as an ``error`` event and ends the job early.
    The caller owns runtime-directory cleanup.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    _prepend_import_paths(app_root, repo_root)
    # Drop any cached ``app`` module so the import below resolves via app_root.
    sys.modules.pop("app", None)

    try:
        import torch  # safe now: only one device visible

        if torch.cuda.is_available():
            torch.cuda.set_device(0)
    except Exception as exc:  # noqa: BLE001
        _put_worker_error(
            progress_queue, gpu_index, f"torch init failed on GPU {gpu_index}: {exc}"
        )
        return

    try:
        progress_queue.put({
            "type": "status",
            "message": (
                f"🔒 GPU {gpu_index}: isolated worker ready "
                f"(CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}, runtime={app_root})"
            ),
            "gpu_index": gpu_index,
        })
        from app import _segment_video_core  # imports torch but env is already pinned
    except Exception as exc:  # noqa: BLE001
        _put_worker_error(
            progress_queue, gpu_index, f"import _segment_video_core failed: {exc}"
        )
        return

    # Unpack under a guard: a malformed ``args`` must reach the dispatcher as
    # an ``error`` event instead of crashing the worker without a message.
    try:
        (
            video_path,
            text_prompt,
            duration_limit,
            id_corrections_text,
            id_drop_text,
            id_override_start_sec,
            show_trails,
            view_mode,
        ) = args
    except (TypeError, ValueError) as exc:
        _put_worker_error(
            progress_queue, gpu_index, f"invalid worker args on GPU {gpu_index}: {exc}"
        )
        return

    def _progress_cb(val, desc):
        progress_queue.put({
            "type": "progress",
            "value": val,
            "desc": desc,
            "gpu_index": gpu_index,
        })

    def _status_cb(msg):
        progress_queue.put({
            "type": "status",
            "message": msg,
            "gpu_index": gpu_index,
        })

    try:
        progress_queue.put({
            "type": "progress",
            "value": 0.0,
            "desc": f"GPU {gpu_index}: starting...",
            "gpu_index": gpu_index,
        })
        out_path, status, loc_path = _segment_video_core(
            video_path,
            text_prompt,
            duration_limit,
            id_corrections_text=id_corrections_text,
            id_drop_text=id_drop_text,
            id_override_start_sec=id_override_start_sec,
            show_trails=show_trails,
            view_mode=view_mode,
            progress_callback=_progress_cb,
            status_callback=_status_cb,
        )
        progress_queue.put({
            "type": "result",
            "data": (out_path, status, loc_path),
            "gpu_index": gpu_index,
        })
    except Exception as exc:  # noqa: BLE001
        _put_worker_error(progress_queue, gpu_index, str(exc))
    finally:
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        except Exception:
            # Best-effort release of cached CUDA memory; never mask the outcome.
            pass


def worker_main(gpu_index, args, progress_queue, worker_options=None):
    """Spawn target: segment one video pinned to a single GPU.

    Args:
        gpu_index: Index exported as ``CUDA_VISIBLE_DEVICES`` before torch is
            imported, so the worker addresses that device as ``cuda:0``.
            Attached to every event as ``gpu_index``.
        args: 8-tuple forwarded to ``app._segment_video_core``:
            ``(video_path, text_prompt, duration_limit, id_corrections_text,
            id_drop_text, id_override_start_sec, show_trails, view_mode)``.
        progress_queue: Object with ``put`` (presumably a multiprocessing
            queue). It receives dict events whose ``type`` is one of:

            * ``"progress"``: carries ``value`` and ``desc``.
            * ``"status"``: carries ``message``.
            * ``"result"``: ``data`` is ``(out_path, status, loc_path)``.
            * ``"error"``: carries ``message`` and ``traceback``.
        worker_options: Optional dict forwarded to ``_prepare_runtime``.

    Stage failures are reported as ``error`` events rather than raised.
    The private runtime copy, if any, is removed on every exit path.
    """
    # Pin BEFORE any torch import — this is the whole point of the separate file.
    _force_single_rank_env(gpu_index=gpu_index)
    try:
        app_root, runtime_dir = _prepare_runtime(worker_options, f"gpu{gpu_index}")
    except Exception as exc:  # noqa: BLE001
        _put_worker_error(
            progress_queue,
            gpu_index,
            f"runtime isolation setup failed on GPU {gpu_index}: {exc}",
        )
        return
    try:
        _run_segmentation_job(gpu_index, args, progress_queue, app_root)
    finally:
        # Single cleanup point for early returns and unexpected exceptions alike.
        _cleanup_runtime(runtime_dir)
def _overlay_error_event(task_id, error):
    """Build an ``overlay_error`` event for ``task_id``.

    Call only from inside an ``except`` block so the active exception's
    traceback is attached.
    """
    return {
        "type": "overlay_error",
        "task_id": task_id,
        "error": error,
        "traceback": traceback.format_exc(),
    }


def _run_overlay_task(task_id, args, event_queue, app_root):
    """Import the overlay helpers from ``app_root``, render, and emit one event.

    The caller owns runtime-directory cleanup.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    _prepend_import_paths(app_root, repo_root)
    # Drop any cached ``app`` module so the import below resolves via app_root.
    sys.modules.pop("app", None)
    try:
        from app import (
            _render_segmentation_overlay_video,
            _render_trails_overlay_video,
            _build_trail_filter_options,
            _persist_for_download,
        )
    except Exception as exc:  # noqa: BLE001
        event_queue.put(_overlay_error_event(task_id, f"import failed: {exc}"))
        return

    # Unpack under a guard so a malformed ``args`` is reported, not fatal.
    try:
        output_video, location_path, text_prompt = args
    except (TypeError, ValueError) as exc:
        event_queue.put(_overlay_error_event(task_id, f"invalid overlay args: {exc}"))
        return

    try:
        seg_overlay = _render_segmentation_overlay_video(
            output_video, location_path, text_prompt
        )
    except Exception as exc:  # noqa: BLE001
        event_queue.put(_overlay_error_event(task_id, f"seg overlay failed: {exc}"))
        return
    seg_display_path = seg_overlay or output_video

    try:
        trails_overlay = _render_trails_overlay_video(
            seg_display_path, location_path, text_prompt, force_unique=True
        )
    except Exception as exc:  # noqa: BLE001
        # Trails overlay is a "nice to have" — fall back to seg overlay
        trails_overlay = None
        trails_error = f"trails overlay failed: {exc}"
    else:
        trails_error = None

    try:
        choices, defaults, _legend = _build_trail_filter_options(
            location_path, text_prompt
        )
    except Exception:
        # Best-effort: an empty filter list is an acceptable result.
        choices, defaults = [], []

    download_path = trails_overlay or seg_display_path
    try:
        persisted = _persist_for_download(download_path, subdir="downloads")
        if persisted:
            download_path = persisted
    except Exception:
        pass  # best-effort: keep the un-persisted path

    event_queue.put({
        "type": "overlay_done",
        "task_id": task_id,
        "seg_overlay": seg_display_path,
        "trails_overlay": trails_overlay,
        "download_path": download_path,
        "trail_choices": choices,
        "trail_selected": defaults,
        "warning": trails_error,
    })


def overlay_worker_main(task_id, args, event_queue, worker_options=None):
    """
    Render segmentation + trails overlays in an isolated process so that
    multiple videos' CPU-bound overlay work can run on different vCPUs in
    parallel after their GPU segmentation completes.

    Args:
        task_id: Opaque id echoed on the outgoing event so the dispatcher can
            route it back to the correct video slot.
        args: ``(output_video, location_path, text_prompt)``.
        event_queue: Object with ``put`` (presumably a multiprocessing queue).
        worker_options: Optional dict forwarded to ``_prepare_runtime``.

    Pushes a single ``overlay_done`` or ``overlay_error`` message onto
    ``event_queue``. A failed trails render still yields ``overlay_done``,
    with the failure text in ``warning``. The private runtime copy, if any,
    is removed on every exit path.
    """
    # Hide every GPU: overlay rendering here is CPU work. Importing ``app``
    # below may still pull in torch; with no visible devices that is harmless.
    _force_single_rank_env(cpu_only=True)
    try:
        app_root, runtime_dir = _prepare_runtime(worker_options, f"overlay_{task_id}")
    except Exception as exc:  # noqa: BLE001
        event_queue.put(
            _overlay_error_event(task_id, f"runtime isolation setup failed: {exc}")
        )
        return
    try:
        _run_overlay_task(task_id, args, event_queue, app_root)
    finally:
        # Runs even if the final ``event_queue.put`` raises.
        _cleanup_runtime(runtime_dir)