| |
| """ |
| pipeline.py β Production SAM2 + MatAnyone (T4-optimized, single-pass streaming) |
| |
| Key features |
| ------------ |
| - One SAM2 inference state for the entire video (no per-chunk reinit). |
| - In-stream pipeline: Read β SAM2 β MatAnyone β Compose β Write (no big RAM dicts). |
| - Bounded memory everywhere (deque/window); optional CPU spill. |
| - fp16 + channels_last on SAM2; mixed precision blocks. |
| - VRAM-aware controller adjusts memory window/scale. |
| - Heartbeat logger to prevent HF watchdog restarts. |
| - Safer FFmpeg audio re-mux. |
| |
| Compatible with Tesla T4 (β15β16 GB) and PyTorch 2.5.x + CUDA 12.4 wheels. |
| """ |
|
|
| import os |
| import gc |
| import cv2 |
| import time |
| import uuid |
| import torch |
| import queue |
| import shutil |
| import logging |
| import tempfile |
| import subprocess |
| import threading |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from typing import Optional, Tuple, Dict, Any, Callable |
| from collections import deque |
|
|
| |
| |
| |
# Module logger: configured once; the handler guard keeps repeated imports
# (or re-execution in notebooks) from stacking duplicate handlers.
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s")
    )
    logger.addHandler(_handler)
    logger.setLevel(logging.INFO)
|
|
| |
| |
| |
def setup_t4_environment():
    """Configure env vars and torch backends for inference on a Tesla T4.

    Safe to call more than once: every env variable is set via ``setdefault``
    so operator overrides are respected.
    """
    # Allocator / threading / OpenCV knobs, applied only when unset.
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF":
            "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENCV_OPENCL_RUNTIME": "disabled",
        "OPENCV_IO_ENABLE_OPENEXR": "0",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    # Inference only: autograd off globally.
    torch.set_grad_enabled(False)
    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best-effort: some switches are absent on older torch builds.
        pass

    if torch.cuda.is_available():
        try:
            frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
            torch.cuda.set_per_process_memory_fraction(frac)
            logger.info(f"CUDA per-process memory fraction = {frac:.2f}")
        except Exception as e:
            logger.warning(f"Could not set CUDA memory fraction: {e}")
|
|
def vram_gb() -> Tuple[float, float]:
    """Return (free, total) GPU memory in GiB, or (0.0, 0.0) without CUDA."""
    if not torch.cuda.is_available():
        return 0.0, 0.0
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    gib = 1024 ** 3
    return free_bytes / gib, total_bytes / gib
|
|
| |
| |
| |
def heartbeat_monitor(running_flag: Dict[str, bool], interval: float = 8.0):
    """Print a heartbeat line every `interval` seconds while the flag is set.

    Meant to run on a daemon thread so platform watchdogs (e.g. HF Spaces)
    see periodic stdout activity during long GPU jobs. Returns as soon as
    ``running_flag["running"]`` is falsy or missing.
    """
    while True:
        if not running_flag.get("running", False):
            break
        print(f"[HB] t={int(time.time())}", flush=True)
        time.sleep(interval)
|
|
| |
| |
| |
class StreamingVideoIO:
    """Context-managed, frame-at-a-time video reader/writer.

    Opens ``video_path`` for reading and ``out_path`` for writing on
    ``__enter__`` and releases both on ``__exit__``. Frames are BGR
    ``np.ndarray`` values, OpenCV-style.
    """

    def __init__(self, video_path: str, out_path: str, fps: float):
        self.video_path = video_path
        self.out_path = out_path
        self.fps = fps
        self.cap = None      # cv2.VideoCapture, assigned in __enter__
        self.writer = None   # cv2.VideoWriter, assigned in __enter__
        self.size = None     # (width, height), assigned in __enter__

    def __enter__(self):
        self.cap = cv2.VideoCapture(self.video_path)
        if not self.cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.size = (w, h)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.writer = cv2.VideoWriter(self.out_path, fourcc, self.fps, (w, h))
        # Fail fast: a VideoWriter that failed to open silently swallows
        # every write() and leaves an empty/corrupt output file.
        if not self.writer.isOpened():
            self.cap.release()
            self.cap = None
            self.writer = None
            raise RuntimeError(f"Cannot open writer: {self.out_path}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Null out handles so read/write after exit hit the guarded paths
        # below instead of touching released OpenCV objects.
        if self.cap:
            self.cap.release()
            self.cap = None
        if self.writer:
            self.writer.release()
            self.writer = None

    def read_frame(self):
        """Return ``(ok, frame_bgr)``; ``(False, None)`` when not entered."""
        if not self.cap:
            return False, None
        return self.cap.read()

    def write_frame(self, frame_bgr: np.ndarray):
        """Write one BGR frame; silently ignored when not entered."""
        if not self.writer:
            return
        self.writer.write(frame_bgr)
|
|
| |
| |
| |
def load_sam2_predictor(device: torch.device):
    """Build the project's SAM2 wrapper, applying fp16 + channels_last.

    Uses the local wrapper (``models.sam2_loader``) so the interface stays
    stable across SAM2 releases. The fp16/channels_last conversion is
    best-effort; failure to import or construct the wrapper re-raises.
    """
    try:
        from models.sam2_loader import SAM2Predictor
        predictor = SAM2Predictor(device=device)
        try:
            # Equivalent to hasattr(...) and ... is not None.
            if getattr(predictor, "model", None) is not None:
                predictor.model = predictor.model.half().to(device)
                predictor.model = predictor.model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied (wrapper model).")
        except Exception as e:
            logger.warning(f"SAM2 fp16 optimization warning: {e}")
        return predictor
    except Exception as e:
        logger.error(f"Failed to import SAM2Predictor: {e}")
        raise
|
|
def load_matany_session(device: torch.device):
    """Instantiate the MatAnyone wrapper, trying both historical class names.

    Returns ``None`` when MatAnyone is unavailable so callers can proceed
    without matting refinement. fp16 conversion is best-effort (some layers
    reject half precision, in which case fp32 is kept).
    """
    try:
        try:
            from models.matanyone_loader import MatAnyoneSession as _MatAny
        except Exception:
            from models.matanyone_loader import MatAnyoneLoader as _MatAny
        session = _MatAny(device=device)

        model = getattr(session, "model", None)
        if model is not None:
            model.eval()
            try:
                session.model = model.half().to(device)
                logger.info("MatAnyone: fp16 + eval applied.")
            except Exception:
                logger.info("MatAnyone: using fp32 (fp16 not supported for some layers).")
        return session
    except Exception as e:
        logger.warning(f"MatAnyone not available ({e}). Proceeding without refinement.")
        return None
|
|
| |
| |
| |
def prune_sam2_state(predictor, state: Any, keep: int):
    """Best-effort pruning of SAM2 temporal caches down to `keep` frames.

    Prefers ``predictor.prune_state(state, keep=...)``; falls back to a
    callable ``state.prune(keep=...)``; otherwise does nothing. Never raises
    (failures are logged at debug level only).
    """
    try:
        if hasattr(predictor, "prune_state"):
            predictor.prune_state(state, keep=keep)
            return
        state_prune = getattr(state, "prune", None)
        if callable(state_prune):
            state_prune(keep=keep)
    except Exception as e:
        logger.debug(f"SAM2 prune_state warning: {e}")
|
|
| |
| |
| |
class VRAMAdaptiveController:
    """Adapts SAM2 memory window / propagation scale to free VRAM.

    Tightens settings under pressure (< 1.6 GB free) and relaxes them when
    there is headroom (> 3.0 GB free); a no-op on CPU-only hosts.
    """

    def __init__(self):
        # Tunable via env vars; defaults sized for a 16 GB T4.
        self.memory_window = int(os.getenv("SAM2_WINDOW", "96"))
        self.propagation_scale = float(os.getenv("SAM2_PROP_SCALE", "0.90"))
        self.cleanup_every = 20

    def adapt(self):
        free, total = vram_gb()
        if free == 0.0:
            # No CUDA (or probe failed): nothing to adapt.
            return

        if free < 1.6:
            # Pressure: shrink window/scale, clean up more frequently.
            self.memory_window = max(48, self.memory_window - 8)
            self.propagation_scale = max(0.75, self.propagation_scale - 0.03)
            self.cleanup_every = max(12, self.cleanup_every - 2)
            logger.warning(f"Low VRAM ({free:.2f} GB free) β window={self.memory_window}, scale={self.propagation_scale:.2f}")
        elif free > 3.0:
            # Headroom: grow window/scale, clean up less frequently.
            self.memory_window = min(128, self.memory_window + 4)
            self.propagation_scale = min(1.0, self.propagation_scale + 0.01)
            self.cleanup_every = min(40, self.cleanup_every + 2)
|
|
| |
| |
| |
def mux_audio(video_path_no_audio: str, source_with_audio: str, out_path: str) -> bool:
    """Mux the video stream of one file with the audio track of another.

    Copies the first video stream from ``video_path_no_audio`` and re-encodes
    the first audio stream of ``source_with_audio`` to AAC into ``out_path``.
    Returns True on success, False on any failure (missing ffmpeg, no audio
    stream, nonzero exit, timeout); never raises.
    """
    command = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", video_path_no_audio,
        "-i", source_with_audio,
        "-map", "0:v:0", "-map", "1:a:0",   # first video + first audio stream
        "-c:v", "copy", "-c:a", "aac",      # no video re-encode; AAC audio
        "-shortest",
        out_path,
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=180)
        if result.returncode != 0:
            logger.warning(f"FFmpeg mux failed: {result.stderr.strip()}")
            return False
        return True
    except Exception as e:
        logger.warning(f"FFmpeg mux error: {e}")
        return False
|
|
| |
| |
| |
def process(
    video_path: str,
    background_image: Optional[Image.Image] = None,
    background_type: str = "custom",
    background_prompt: str = "",
    job_directory: Optional[Path] = None,
    progress_callback: Optional[Callable[[str, float], None]] = None
) -> str:
    """
    Production SAM2 + MatAnyone pipeline for T4.
    - Single-pass streaming (no large mask dicts)
    - Bounded memory windows

    Streams the video through SAM2 (segmentation) and optionally MatAnyone
    (matte refinement), composites each frame over ``background_image``,
    writes the composite, then re-muxes the source audio via ffmpeg.

    Parameters
    ----------
    video_path : str
        Path to the input video; must exist (FileNotFoundError otherwise).
    background_image : PIL.Image.Image
        Required despite the ``None`` default (ValueError when missing);
        resized to the video resolution with LANCZOS.
    background_type, background_prompt :
        Accepted for interface compatibility; not referenced in this body.
    job_directory : Path, optional
        Output/working directory; a fresh ``tmp/job_<hex>`` dir is created
        under the CWD when omitted.
    progress_callback : callable(step, progress), optional
        Best-effort progress reporting; its exceptions are swallowed.

    Returns
    -------
    str
        Path to the final output (audio-muxed file on success, otherwise the
        raw composite without audio).
    """
    setup_t4_environment()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Heartbeat keeps stdout alive for platform watchdogs (HF Spaces).
    # NOTE(review): the flag is cleared only on the early-validation raises
    # and at the end; if an exception escapes mid-stream the thread keeps
    # printing (it is daemon=True, so it still dies with the process).
    hb_flag = {"running": True}
    hb_thread = threading.Thread(target=heartbeat_monitor, args=(hb_flag, 8.0), daemon=True)
    hb_thread.start()

    def report(step: str, p: Optional[float] = None):
        # Log locally, then forward to the caller's callback (best-effort).
        if p is None:
            logger.info(step)
        else:
            logger.info(f"{step} [{p:.1%}]")
        if progress_callback:
            try:
                progress_callback(step, p)
            except Exception as e:
                logger.debug(f"progress_callback error: {e}")

    # ---- Input validation ------------------------------------------------
    src = Path(video_path)
    if not src.exists():
        hb_flag["running"] = False
        raise FileNotFoundError(f"Video not found: {video_path}")

    if job_directory is None:
        job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    job_directory.mkdir(parents=True, exist_ok=True)

    # ---- Probe geometry/fps, then release the capture --------------------
    cap_probe = cv2.VideoCapture(str(src))
    if not cap_probe.isOpened():
        hb_flag["running"] = False
        raise RuntimeError(f"Cannot open video: {video_path}")
    # `or 25.0` falls back when OpenCV reports fps as 0/None.
    fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = frame_count / fps if fps > 0 else 0.0
    cap_probe.release()
    logger.info(f"Video: {width}x{height} @ {fps:.2f} fps | {frame_count} frames ({duration:.1f}s)")

    # ---- Background preparation ------------------------------------------
    if background_image is None:
        hb_flag["running"] = False
        raise ValueError("background_image is required")
    bg = background_image.resize((width, height), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)  # float32 RGB for alpha blending

    # ---- Model loading ----------------------------------------------------
    report("Loading SAM2 + MatAnyone", 0.05)
    predictor = load_sam2_predictor(device)
    matany = load_matany_session(device)  # may be None (no refinement)

    # One SAM2 inference state for the whole video (no per-chunk reinit).
    report("Initializing SAM2 video state", 0.08)
    state = predictor.init_state(video_path=str(src))

    # Seed SAM2 with a single positive click at the frame center on frame 0.
    # NOTE(review): assumes the subject is centered at t=0 — confirm this is
    # acceptable for the expected inputs.
    center_pt = np.array([[width // 2, height // 2]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    ann_obj_id = 1
    with torch.inference_mode():
        _ = predictor.add_new_points(
            inference_state=state,
            frame_idx=0,
            obj_id=ann_obj_id,
            points=center_pt,
            labels=labels,
        )

    ctrl = VRAMAdaptiveController()

    out_raw = str(job_directory / f"composite_{int(time.time())}.mp4")
    out_final = str(job_directory / f"final_{int(time.time())}.mp4")

    # Bounded auxiliary window (cleared on each cleanup pass below).
    aux_window = deque(maxlen=max(32, min(96, ctrl.memory_window // 2)))

    start = time.time()
    frames_done = 0
    next_cleanup_at = ctrl.cleanup_every

    report("Streaming: SAM2 β MatAnyone β Compose β Write", 0.12)
    with StreamingVideoIO(str(src), out_raw, fps) as vio:
        # NOTE(review): device_type="cuda" with dtype=None on a CPU-only host
        # may be rejected by some torch builds — confirm CPU fallback works.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16 if device.type == "cuda" else None):
            # NOTE(review): assumes propagate_in_video yields frames in order
            # from 0, in lockstep with the sequential reads below; also that
            # the wrapper accepts a `scale` kwarg — confirm against
            # SAM2Predictor.
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state, scale=ctrl.propagation_scale):
                ret, frame_bgr = vio.read_frame()
                if not ret:
                    # Reader exhausted before the propagator — stop cleanly.
                    break

                # ---- Extract the binary mask for our tracked object ------
                mask_t = None
                try:
                    if isinstance(out_obj_ids, torch.Tensor):
                        # Tensor ids: locate our object's row in the logits.
                        idxs = (out_obj_ids == ann_obj_id).nonzero(as_tuple=False)
                        if idxs.numel() > 0:
                            i = idxs[0].item()
                            logits = out_mask_logits[i]
                        else:
                            logits = None
                    else:
                        # List-like ids.
                        ids_list = list(out_obj_ids)
                        i = ids_list.index(ann_obj_id) if ann_obj_id in ids_list else -1
                        logits = out_mask_logits[i] if i >= 0 else None

                    if logits is not None:
                        # Threshold logits at 0 -> {0.0, 1.0} float mask.
                        mask_t = (logits > 0).float()
                except Exception as e:
                    logger.debug(f"Mask extraction warning @frame {out_frame_idx}: {e}")
                    mask_t = None

                # ---- Optional MatAnyone refinement -----------------------
                if mask_t is not None and matany is not None:
                    try:
                        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                        # The wrapper API has varied; support both spellings.
                        refined = None
                        if hasattr(matany, "refine_mask"):
                            refined = matany.refine_mask(frame_rgb, mask_t)
                        elif hasattr(matany, "process_frame"):
                            refined = matany.process_frame(frame_rgb, mask_t)
                        if refined is not None:
                            # Normalize return type (tensor or ndarray).
                            if isinstance(refined, torch.Tensor):
                                mask_t = refined.float()
                            else:
                                mask_t = torch.from_numpy(refined.astype(np.float32))
                            if device.type == "cuda":
                                mask_t = mask_t.to(device)
                    except Exception as e:
                        # Refinement is best-effort; keep the SAM2 mask.
                        logger.debug(f"MatAnyone refinement failed (frame {out_frame_idx}): {e}")

                # ---- Composite and write ---------------------------------
                if mask_t is not None:
                    # NOTE(review): assumes mask_t is a 2-D (H, W) map at the
                    # frame resolution; a leading channel dim or a different
                    # size would break the broadcasting below — confirm.
                    mask_np = mask_t.detach().clamp(0, 1).to("cpu", non_blocking=True).float().numpy()
                    m3 = mask_np[..., None]  # (H, W, 1) alpha for RGB blend
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
                    comp = frame_rgb * m3 + bg_np * (1.0 - m3)
                    comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
                    vio.write_frame(comp_bgr)
                else:
                    # No mask: pass the original frame through unchanged.
                    vio.write_frame(frame_bgr)

                # ---- Periodic VRAM-adaptive cleanup ----------------------
                frames_done += 1
                if frames_done >= next_cleanup_at:
                    ctrl.adapt()
                    prune_sam2_state(predictor, state, keep=ctrl.memory_window)
                    aux_window.clear()
                    if device.type == "cuda":
                        torch.cuda.ipc_collect()
                        torch.cuda.empty_cache()
                    next_cleanup_at = frames_done + ctrl.cleanup_every

                # Progress at ~12%..87% across the streaming phase.
                if frames_done % 25 == 0 and frame_count > 0:
                    p = 0.12 + 0.75 * (frames_done / frame_count)
                    report(f"Processing frame {frames_done}/{frame_count} | win={ctrl.memory_window} scale={ctrl.propagation_scale:.2f}", p)

    # ---- Audio re-mux (best-effort; fall back to silent composite) -------
    report("Restoring audio", 0.93)
    ok = mux_audio(out_raw, str(src), out_final)
    final_path = out_final if ok else out_raw

    # ---- Teardown: free models/state before reporting --------------------
    try:
        del predictor
        del state
        if matany is not None:
            del matany
    except Exception:
        pass

    if device.type == "cuda":
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    gc.collect()

    hb_flag["running"] = False
    elapsed = time.time() - start
    try:
        peak = torch.cuda.max_memory_allocated() / (1024 ** 3) if device.type == "cuda" else 0.0
        logger.info(f"Peak GPU memory: {peak:.2f} GB")
    except Exception:
        pass
    report(f"Done in {elapsed:.1f}s", 1.0)
    logger.info(f"Output: {final_path}")
    logger.info(f"Artifacts: {job_directory}")
    return final_path
|
|
|
|
| |
| |
| |
| if __name__ == "__main__": |
| import argparse |
| parser = argparse.ArgumentParser(description="BackgroundFX Pro pipeline") |
| parser.add_argument("--video", required=True, help="Path to input video") |
| parser.add_argument("--background", required=True, help="Path to background image") |
| parser.add_argument("--outdir", default=None, help="Job directory (optional)") |
| args = parser.parse_args() |
|
|
| bg_img = Image.open(args.background).convert("RGB") |
| outdir = Path(args.outdir) if args.outdir else None |
| out_path = process(args.video, background_image=bg_img, job_directory=outdir) |
| print(out_path) |
|
|