#!/usr/bin/env python3
"""
pipeline.py - Production SAM2 + MatAnyone (T4-optimized, single-pass streaming)
Key features
------------
- One SAM2 inference state for the entire video (no per-chunk reinit).
- In-stream pipeline: Read -> SAM2 -> MatAnyone -> Compose -> Write (no big RAM dicts).
- Bounded memory everywhere (deque/window); optional CPU spill.
- fp16 + channels_last on SAM2; mixed-precision (autocast) compute blocks.
- VRAM-aware controller adjusts memory window/scale.
- Heartbeat logger to prevent HF watchdog restarts.
- Safer FFmpeg audio re-mux.
Compatible with Tesla T4 (≈15–16 GB) and PyTorch 2.5.x + CUDA 12.4 wheels.
"""
import os
import gc
import cv2
import time
import uuid
import torch
import queue
import shutil
import logging
import tempfile
import subprocess
import threading
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, Callable
from collections import deque
# ----------------------------------------------------------------------------------------------------------------------
# Logging
# ----------------------------------------------------------------------------------------------------------------------
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
h = logging.StreamHandler()
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
logger.addHandler(h)
logger.setLevel(logging.INFO)
# ----------------------------------------------------------------------------------------------------------------------
# Environment & Torch tuning for T4
# ----------------------------------------------------------------------------------------------------------------------
def setup_t4_environment():
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF",
"expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENCV_OPENCL_RUNTIME", "disabled")
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
torch.set_grad_enabled(False)
try:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
if torch.cuda.is_available():
try:
frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
torch.cuda.set_per_process_memory_fraction(frac)
logger.info(f"CUDA per-process memory fraction = {frac:.2f}")
except Exception as e:
logger.warning(f"Could not set CUDA memory fraction: {e}")
def vram_gb() -> Tuple[float, float]:
if not torch.cuda.is_available():
return 0.0, 0.0
free, total = torch.cuda.mem_get_info()
return free / (1024 ** 3), total / (1024 ** 3)
# ----------------------------------------------------------------------------------------------------------------------
# Heartbeat (prevents the Spaces watchdog from killing the job)
# ----------------------------------------------------------------------------------------------------------------------
def heartbeat_monitor(running_flag: Dict[str, bool], interval: float = 8.0):
while running_flag.get("running", False):
print(f"[HB] t={int(time.time())}", flush=True)
time.sleep(interval)
# ----------------------------------------------------------------------------------------------------------------------
# Streaming video I/O
# ----------------------------------------------------------------------------------------------------------------------
class StreamingVideoIO:
def __init__(self, video_path: str, out_path: str, fps: float):
self.video_path = video_path
self.out_path = out_path
self.fps = fps
self.cap = None
self.writer = None
self.size = None
def __enter__(self):
self.cap = cv2.VideoCapture(self.video_path)
if not self.cap.isOpened():
raise RuntimeError(f"Cannot open video: {self.video_path}")
w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.size = (w, h)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
self.writer = cv2.VideoWriter(self.out_path, fourcc, self.fps, (w, h))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.cap:
self.cap.release()
if self.writer:
self.writer.release()
def read_frame(self):
if not self.cap:
return False, None
return self.cap.read()
def write_frame(self, frame_bgr: np.ndarray):
if not self.writer:
return
self.writer.write(frame_bgr)
# ----------------------------------------------------------------------------------------------------------------------
# Models: loaders and safe optimizations
# ----------------------------------------------------------------------------------------------------------------------
def load_sam2_predictor(device: torch.device):
"""
Prefer your local wrapper to keep interfaces stable.
"""
try:
from models.sam2_loader import SAM2Predictor # your wrapper
predictor = SAM2Predictor(device=device)
# Optional: try to access underlying model to set fp16 + channels_last
try:
if hasattr(predictor, "model") and predictor.model is not None:
predictor.model = predictor.model.half().to(device)
predictor.model = predictor.model.to(memory_format=torch.channels_last)
logger.info("SAM2: fp16 + channels_last applied (wrapper model).")
except Exception as e:
logger.warning(f"SAM2 fp16 optimization warning: {e}")
return predictor
except Exception as e:
logger.error(f"Failed to import SAM2Predictor: {e}")
raise
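# If the local wrapper is unavailable, SAM2 can also be loaded through the
# official `sam2` package. The function below is a hypothetical fallback
# sketch, not called by the pipeline; the checkpoint/config names are
# assumptions and must match your checkout.
def load_sam2_predictor_direct(device: torch.device):
    from sam2.build_sam import build_sam2_video_predictor
    ckpt = os.getenv("SAM2_CHECKPOINT", "checkpoints/sam2_hiera_tiny.pt")  # assumed path
    cfg = os.getenv("SAM2_CONFIG", "sam2_hiera_t.yaml")                    # assumed config name
    return build_sam2_video_predictor(cfg, ckpt, device=device)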
def load_matany_session(device: torch.device):
"""
    Supports either MatAnyoneSession or MatAnyoneLoader (earlier revisions shipped both names).
"""
try:
try:
from models.matanyone_loader import MatAnyoneSession as _MatAny
except Exception:
from models.matanyone_loader import MatAnyoneLoader as _MatAny
session = _MatAny(device=device)
# Try fp16 eval where safe
if hasattr(session, "model") and session.model is not None:
session.model.eval()
try:
session.model = session.model.half().to(device)
logger.info("MatAnyone: fp16 + eval applied.")
except Exception:
logger.info("MatAnyone: using fp32 (fp16 not supported for some layers).")
return session
except Exception as e:
logger.warning(f"MatAnyone not available ({e}). Proceeding without refinement.")
return None
# ----------------------------------------------------------------------------------------------------------------------
# SAM2 state pruning (adapter): we call predictor.prune_state if present, else best-effort
# ----------------------------------------------------------------------------------------------------------------------
def prune_sam2_state(predictor, state: Any, keep: int):
"""
Try to prune SAM2 temporal caches to a fixed window length.
Your SAM2Predictor should implement prune_state(state, keep=N). If not, we do nothing.
"""
try:
if hasattr(predictor, "prune_state"):
predictor.prune_state(state, keep=keep)
elif hasattr(state, "prune") and callable(getattr(state, "prune")):
state.prune(keep=keep)
else:
# No-op; rely on model internals and GC
pass
except Exception as e:
logger.debug(f"SAM2 prune_state warning: {e}")
# ----------------------------------------------------------------------------------------------------------------------
# VRAM-aware controller
# ----------------------------------------------------------------------------------------------------------------------
class VRAMAdaptiveController:
def __init__(self):
self.memory_window = int(os.getenv("SAM2_WINDOW", "96")) # frames to keep in model state
self.propagation_scale = float(os.getenv("SAM2_PROP_SCALE", "0.90")) # e.g., downscale factor for propagation
self.cleanup_every = 20 # frames
def adapt(self):
free, total = vram_gb()
if free == 0.0:
return
# Tighten if we dip under ~1.6 GB
if free < 1.6:
self.memory_window = max(48, self.memory_window - 8)
self.propagation_scale = max(0.75, self.propagation_scale - 0.03)
self.cleanup_every = max(12, self.cleanup_every - 2)
logger.warning(f"Low VRAM ({free:.2f} GB free) β†’ window={self.memory_window}, scale={self.propagation_scale:.2f}")
# Relax if plenty free
elif free > 3.0:
self.memory_window = min(128, self.memory_window + 4)
self.propagation_scale = min(1.0, self.propagation_scale + 0.01)
self.cleanup_every = min(40, self.cleanup_every + 2)
# ----------------------------------------------------------------------------------------------------------------------
# Audio mux helper (safer stream mapping)
# ----------------------------------------------------------------------------------------------------------------------
def mux_audio(video_path_no_audio: str, source_with_audio: str, out_path: str) -> bool:
cmd = [
"ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
"-i", video_path_no_audio,
"-i", source_with_audio,
"-map", "0:v:0", "-map", "1:a:0",
"-c:v", "copy", "-c:a", "aac", "-shortest",
out_path
]
try:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
if r.returncode != 0:
logger.warning(f"FFmpeg mux failed: {r.stderr.strip()}")
return False
return True
except Exception as e:
logger.warning(f"FFmpeg mux error: {e}")
return False
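# Optional pre-check (a sketch): probe whether the source actually carries an
# audio stream before calling mux_audio, using ffprobe from the same FFmpeg
# install. Purely defensive; the trailing '?' on the -map specifier above
# already tolerates missing audio.
def has_audio_stream(path: str) -> bool:
    cmd = ["ffprobe", "-v", "error", "-select_streams", "a",
           "-show_entries", "stream=index", "-of", "csv=p=0", path]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        # Any listed audio stream index means audio is present
        return r.returncode == 0 and bool(r.stdout.strip())
    except Exception:
        return False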
# ----------------------------------------------------------------------------------------------------------------------
# Main processing
# ----------------------------------------------------------------------------------------------------------------------
def process(
video_path: str,
background_image: Optional[Image.Image] = None,
background_type: str = "custom",
background_prompt: str = "",
job_directory: Optional[Path] = None,
progress_callback: Optional[Callable[[str, float], None]] = None
) -> str:
"""
Production SAM2 + MatAnyone pipeline for T4.
- Single-pass streaming (no large mask dicts)
- Bounded memory windows
"""
setup_t4_environment()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Heartbeat
hb_flag = {"running": True}
hb_thread = threading.Thread(target=heartbeat_monitor, args=(hb_flag, 8.0), daemon=True)
hb_thread.start()
def report(step: str, p: Optional[float] = None):
if p is None:
logger.info(step)
else:
logger.info(f"{step} [{p:.1%}]")
if progress_callback:
try:
progress_callback(step, p)
except Exception as e:
logger.debug(f"progress_callback error: {e}")
# Validate I/O
src = Path(video_path)
if not src.exists():
hb_flag["running"] = False
raise FileNotFoundError(f"Video not found: {video_path}")
if job_directory is None:
job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
job_directory.mkdir(parents=True, exist_ok=True)
# Probe video
cap_probe = cv2.VideoCapture(str(src))
if not cap_probe.isOpened():
hb_flag["running"] = False
raise RuntimeError(f"Cannot open video: {video_path}")
fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0
width = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_count = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
duration = frame_count / fps if fps > 0 else 0.0
cap_probe.release()
logger.info(f"Video: {width}x{height} @ {fps:.2f} fps | {frame_count} frames ({duration:.1f}s)")
# Prepare background
if background_image is None:
hb_flag["running"] = False
raise ValueError("background_image is required")
bg = background_image.resize((width, height), Image.LANCZOS)
bg_np = np.array(bg).astype(np.float32)
# Load models
report("Loading SAM2 + MatAnyone", 0.05)
predictor = load_sam2_predictor(device)
matany = load_matany_session(device)
# Init SAM2 state (single)
report("Initializing SAM2 video state", 0.08)
state = predictor.init_state(video_path=str(src))
    # Minimal prompt: single positive point at center (replace with your prompt UI
    # if needed; see the box-prompt sketch below)
center_pt = np.array([[width // 2, height // 2]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)
ann_obj_id = 1
with torch.inference_mode():
_ = predictor.add_new_points(
inference_state=state,
frame_idx=0,
obj_id=ann_obj_id,
points=center_pt,
labels=labels,
)
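    # A box prompt often isolates a subject more reliably than a single point.
    # A hedged alternative, assuming the wrapper mirrors SAM2's
    # add_new_points_or_box API (the method name and box fractions are assumptions):
    #
    #   box = np.array([width * 0.25, height * 0.10,
    #                   width * 0.75, height * 0.95], dtype=np.float32)
    #   _ = predictor.add_new_points_or_box(
    #       inference_state=state, frame_idx=0, obj_id=ann_obj_id, box=box)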
# Controller
ctrl = VRAMAdaptiveController()
# Output paths
out_raw = str(job_directory / f"composite_{int(time.time())}.mp4")
out_final = str(job_directory / f"final_{int(time.time())}.mp4")
# Windows/buffers (bounded)
# For completeness we keep a tiny deque for any auxiliary temporal ops (e.g., matting history)
aux_window = deque(maxlen=max(32, min(96, ctrl.memory_window // 2)))
# Stream processing
start = time.time()
frames_done = 0
next_cleanup_at = ctrl.cleanup_every
report("Streaming: SAM2 β†’ MatAnyone β†’ Compose β†’ Write", 0.12)
with StreamingVideoIO(str(src), out_raw, fps) as vio:
# iterate SAM2 propagation alongside reading frames
        # Autocast fp16 on CUDA only; with enabled=False this is a no-op on CPU
        amp_dtype = torch.float16 if device.type == "cuda" else torch.bfloat16
        with torch.inference_mode(), torch.autocast(
            device_type=device.type, dtype=amp_dtype, enabled=(device.type == "cuda")
        ):
            # NB: propagation_scale is read once when the generator is created;
            # later ctrl.adapt() changes only affect subsequent runs.
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state, scale=ctrl.propagation_scale):
# Read the matching frame
ret, frame_bgr = vio.read_frame()
if not ret:
break
# Get mask for ann_obj_id; keep on GPU as long as possible
mask_t = None
try:
if isinstance(out_obj_ids, torch.Tensor):
# find index where id == ann_obj_id
idxs = (out_obj_ids == ann_obj_id).nonzero(as_tuple=False)
if idxs.numel() > 0:
i = idxs[0].item()
logits = out_mask_logits[i]
else:
logits = None
else:
# list/array fallback
ids_list = list(out_obj_ids)
i = ids_list.index(ann_obj_id) if ann_obj_id in ids_list else -1
logits = out_mask_logits[i] if i >= 0 else None
if logits is not None:
                        # logits -> binary mask (threshold at 0); squeeze a possible
                        # leading channel dim so downstream math sees HxW
                        mask_t = (logits > 0).float().squeeze()
except Exception as e:
logger.debug(f"Mask extraction warning @frame {out_frame_idx}: {e}")
mask_t = None
# Optional: MatAnyone refinement
if mask_t is not None and matany is not None:
try:
                        # MatAnyone APIs vary; try common forms
# Convert RGB because many mattors expect RGB
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
# Move frame to GPU only if your matting backend supports it
refined = None
if hasattr(matany, "refine_mask"):
refined = matany.refine_mask(frame_rgb, mask_t) # allow handler to decide device
elif hasattr(matany, "process_frame"):
refined = matany.process_frame(frame_rgb, mask_t)
if refined is not None:
# ensure float mask 0..1 on CUDA or CPU
if isinstance(refined, torch.Tensor):
mask_t = refined.float()
else:
                                # numpy -> torch
mask_t = torch.from_numpy(refined.astype(np.float32))
if device.type == "cuda":
mask_t = mask_t.to(device)
except Exception as e:
logger.debug(f"MatAnyone refinement failed (frame {out_frame_idx}): {e}")
# Compose and write (convert once, keep math sane)
                if mask_t is not None:
                    # bring mask to CPU for np composition; keep as float [0,1]
                    mask_np = mask_t.detach().clamp(0, 1).to("cpu", non_blocking=True).float().numpy()
                    mask_np = np.squeeze(mask_np)  # refinement may reintroduce a channel dim
                    # propagation may run below full resolution; match the frame size
                    if mask_np.shape[:2] != frame_bgr.shape[:2]:
                        mask_np = cv2.resize(mask_np, (frame_bgr.shape[1], frame_bgr.shape[0]),
                                             interpolation=cv2.INTER_LINEAR)
                    m3 = mask_np[..., None]  # HxWx1
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
                    comp = frame_rgb * m3 + bg_np * (1.0 - m3)
                    comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
                    vio.write_frame(comp_bgr)
else:
# No mask β€” write original frame
vio.write_frame(frame_bgr)
# Periodic maintenance
frames_done += 1
if frames_done >= next_cleanup_at:
ctrl.adapt()
prune_sam2_state(predictor, state, keep=ctrl.memory_window)
# Clear small aux buffers
aux_window.clear()
if device.type == "cuda":
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
next_cleanup_at = frames_done + ctrl.cleanup_every
# Progress
if frames_done % 25 == 0 and frame_count > 0:
p = 0.12 + 0.75 * (frames_done / frame_count)
report(f"Processing frame {frames_done}/{frame_count} | win={ctrl.memory_window} scale={ctrl.propagation_scale:.2f}", p)
# Audio mux
report("Restoring audio", 0.93)
ok = mux_audio(out_raw, str(src), out_final)
final_path = out_final if ok else out_raw
# Cleanup models/state promptly
try:
del predictor
del state
if matany is not None:
del matany
except Exception:
pass
if device.type == "cuda":
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
gc.collect()
hb_flag["running"] = False
elapsed = time.time() - start
try:
peak = torch.cuda.max_memory_allocated() / (1024 ** 3) if device.type == "cuda" else 0.0
logger.info(f"Peak GPU memory: {peak:.2f} GB")
except Exception:
pass
report(f"Done in {elapsed:.1f}s", 1.0)
logger.info(f"Output: {final_path}")
logger.info(f"Artifacts: {job_directory}")
return final_path
# -------------------------------------------------------------------------------------------------
# CLI entry (optional)
# -------------------------------------------------------------------------------------------------
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="BackgroundFX Pro pipeline")
parser.add_argument("--video", required=True, help="Path to input video")
parser.add_argument("--background", required=True, help="Path to background image")
parser.add_argument("--outdir", default=None, help="Job directory (optional)")
args = parser.parse_args()
bg_img = Image.open(args.background).convert("RGB")
outdir = Path(args.outdir) if args.outdir else None
out_path = process(args.video, background_image=bg_img, job_directory=outdir)
print(out_path)
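    # Example invocation (paths are placeholders):
    #   python pipeline/pipeline.py --video input.mp4 --background bg.jpg --outdir tmp/job1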