| | |
| | """ |
| | BackgroundFX Pro - Model Loading & Utilities (Hardened) |
| | ====================================================== |
| | - Avoids heavy CUDA/Hydra work at import time |
| | - Adds timeouts to subprocess probes |
| | - Safer sys.path wiring for third_party repos |
| | - MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession |
| | |
| | Changes (2025-09-16): |
| | - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0 |
| | - Updated load_matany to apply T=1 squeeze patch before InferenceCore import |
| | - Added patch status logging and MatAnyone version |
| | - Added InferenceCore attributes logging for debugging |
| | - Fixed InferenceCore import path to matanyone.inference.inference_core |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | import sys |
| | import cv2 |
| | import subprocess |
| | import inspect |
| | import logging |
| | import importlib.metadata |
| | from pathlib import Path |
| | from typing import Optional, Tuple, Dict, Any, Union, Callable |
| |
|
| | import numpy as np |
| | import yaml |
| |
|
| | |
| | try: |
| | import torch |
| | except ImportError: |
| | torch = None |
| |
|
| | |
| | |
| | |
# Package-wide logger. A handler is attached only when none exists yet, so
# re-executing this module (e.g. a reload) does not duplicate log lines.
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(logging.Formatter(fmt="[%(asctime)s] %(levelname)s: %(message)s"))
    logger.addHandler(_h)
logger.setLevel(logging.INFO)
| |
|
| | |
# Cap OpenCV's internal thread pool via CV_THREADS (default 1) — presumably to
# avoid oversubscribing CPU cores alongside torch. Best-effort: a bad value or
# a cv2 build without setNumThreads is logged at DEBUG rather than silently ignored.
try:
    cv_threads = int(os.environ.get("CV_THREADS", "1"))
    if hasattr(cv2, "setNumThreads"):
        cv2.setNumThreads(cv_threads)
except Exception as _cv_threads_err:
    logger.debug(f"cv2.setNumThreads skipped: {_cv_threads_err}")
| |
|
| | |
| | |
| | |
| | try: |
| | import mediapipe as mp |
| | _HAS_MEDIAPIPE = True |
| | except Exception: |
| | _HAS_MEDIAPIPE = False |
| |
|
| | |
| | |
| | |
# Repository root: the grandparent directory of this file.
ROOT = Path(__file__).resolve().parent.parent
# Vendored third-party checkouts; each can be relocated with an environment variable.
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT.joinpath("third_party", "sam2"))).resolve()
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT.joinpath("third_party", "matanyone"))).resolve()
| |
|
| | def _add_sys_path(p: Path) -> None: |
| | if p.exists(): |
| | p_str = str(p) |
| | if p_str not in sys.path: |
| | sys.path.insert(0, p_str) |
| | else: |
| | logger.warning(f"third_party path not found: {p}") |
| |
|
# Put the vendored repos on sys.path so `import sam2` / `import matanyone`
# resolve to the local checkouts. MatAnyone is inserted second, so when both
# are newly added it ends up ahead of SAM2.
for _third_party_dir in (TP_SAM2, TP_MATANY):
    _add_sys_path(_third_party_dir)
del _third_party_dir
| |
|
| | |
| | |
| | |
| | def _torch(): |
| | try: |
| | import torch |
| | return torch |
| | except Exception as e: |
| | logger.warning(f"[models.safe-torch] import failed: {e}") |
| | return None |
| |
|
def _has_cuda() -> bool:
    """Report whether torch can see a CUDA device.

    Returns False when torch cannot be imported or when
    ``torch.cuda.is_available()`` raises (that error is logged as a warning).
    """
    torch_mod = _torch()
    if torch_mod is None:
        return False
    try:
        available = torch_mod.cuda.is_available()
    except Exception as e:
        logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
        return False
    return bool(available)
| |
|
def _pick_device(env_key: str) -> str:
    """Choose the device string ("cuda" or "cpu") for a model.

    *env_key* names an environment variable holding the requested device
    (compared case-insensitively after stripping whitespace).

    Policy:
      * If CUDA is available, "cuda" is returned unless "cpu" was requested.
      * Without CUDA, an explicit "cuda" or "cpu" request is returned as-is
        (so "cuda" can be returned even though CUDA is unavailable).
      * Anything else resolves to "cpu".

    Several CUDA-related environment variables and the decision are logged at INFO.
    """
    requested = os.environ.get(env_key, "").strip().lower()
    has_cuda = _has_cuda()

    watched_vars = (
        "FORCE_CUDA_DEVICE",
        "CUDA_MEMORY_FRACTION",
        "PYTORCH_CUDA_ALLOC_CONF",
        "REQUIRE_CUDA",
        "SAM2_DEVICE",
        "MATANY_DEVICE",
    )
    cuda_env_vars = {name: os.environ.get(name, "") for name in watched_vars}
    logger.info(f"CUDA environment variables: {cuda_env_vars}")
    logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")

    if has_cuda and requested != "cpu":
        logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
        return "cuda"
    if requested in ("cuda", "cpu"):
        logger.info(f"Using explicitly requested device: {requested}")
        return requested

    result = "cuda" if has_cuda else "cpu"
    logger.info(f"Auto-selected device: {result}")
    return result
| |
|
| | |
| | |
| | |
| | def _ffmpeg_bin() -> str: |
| | return os.environ.get("FFMPEG_BIN", "ffmpeg") |
| |
|
def _probe_ffmpeg(timeout: int = 2) -> bool:
    """Return True if ``<ffmpeg> -version`` exits 0 within *timeout* seconds.

    Any failure (missing binary, non-zero exit, timeout) yields False. Output is discarded.
    """
    try:
        subprocess.run(
            [_ffmpeg_bin(), "-version"],
            check=True,
            timeout=timeout,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        return False
    return True
| |
|
| | def _ensure_dir(p: Path) -> None: |
| | p.mkdir(parents=True, exist_ok=True) |
| |
|
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Read the first frame of a video together with its FPS and frame size.

    Returns:
        ``(frame_bgr, fps, (width, height))``.
        * ``frame_bgr`` is None when the first read fails.
        * ``(None, 0, (0, 0))`` is returned when the file cannot be opened.
        * FPS falls back to 25 when the container reports 0, a negative value or NaN.
        * Width/height fall back to the decoded frame's shape when the container
          reports 0.

    The capture is always released, even if a property query raises.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None, 0, (0, 0)
        raw_fps = cap.get(cv2.CAP_PROP_FPS)
        # int(round(nan)) raises ValueError; treat NaN/non-positive FPS as unknown.
        fps = int(round(raw_fps)) if raw_fps and np.isfinite(raw_fps) and raw_fps > 0 else 25
        ok, frame = cap.read()
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    finally:
        cap.release()

    if not ok or frame is None:
        return None, fps, (w, h)
    if w <= 0 or h <= 0:
        # Some containers do not report dimensions; the decoded frame is authoritative.
        h, w = frame.shape[:2]
    return frame, fps, (w, h)
| |
|
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit image and return ``str(path)``.

    Conversion rules:
      * bool masks map to 0/255.
      * Floating-point masks whose maximum is <= 1.0 are treated as 0..1 mattes
        and scaled to 0..255. Otherwise they would be written as near-black 0/1
        values that ``> 127`` thresholds elsewhere read as empty.
      * Any other non-uint8 input is clipped to 0..255.

    A failed write is logged as a warning; the path is still returned.
    """
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        if np.issubdtype(mask.dtype, np.floating) and mask.size and float(mask.max()) <= 1.0:
            mask = mask * 255.0
        mask = np.clip(mask, 0, 255).astype(np.uint8)
    if not cv2.imwrite(str(path), mask):
        logger.warning(f"Failed to write mask PNG: {path}")
    return str(path)
| |
|
| | def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray: |
| | tw, th = target_wh |
| | h, w = image.shape[:2] |
| | if h == 0 or w == 0 or tw == 0 or th == 0: |
| | return image |
| | scale = min(tw / w, th / h) |
| | nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale))) |
| | resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC) |
| | canvas = np.zeros((th, tw, 3), dtype=resized.dtype) |
| | x0 = (tw - nw) // 2 |
| | y0 = (th - nh) // 2 |
| | canvas[y0:y0+nh, x0:x0+nw] = resized |
| | return canvas |
| |
|
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Open a ``cv2.VideoWriter`` at *out_path* using the "mp4v" FOURCC.

    *size* is (width, height). *fps* is clamped to at least 1.
    """
    mp4v = cv2.VideoWriter_fourcc("m", "p", "4", "v")
    clamped_fps = max(1, fps)
    return cv2.VideoWriter(str(out_path), mp4v, clamped_fps, size)
| |
|
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC).

    Encoding:
      * The video stream of *silent_video* is stream-copied.
      * The first audio stream of *src_video* is re-encoded to AAC at 192k.
        It is optional, because of the ``?`` in the map spec.
      * The output is cut to the shorter stream (``-shortest``).

    Returns:
        True on success. False if ffmpeg could not be run or exited non-zero;
        the failure is logged, including the tail of ffmpeg's stderr when it
        is available.
    """
    try:
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(silent_video),
            "-i", str(src_video),
            "-map", "0:v:0",
            "-map", "1:a:0?",
            "-c:v", "copy",
            "-c:a", "aac", "-b:a", "192k",
            "-shortest",
            str(out_path),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        # stderr is captured above; surface its tail so failures are diagnosable.
        tail = (e.stderr or b"").decode("utf-8", errors="replace").strip()[-800:]
        detail = f" | ffmpeg: {tail}" if tail else ""
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}{detail}")
        return False
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
| |
|
| | |
| | |
| | |
| | def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray: |
| | if alpha.dtype != np.float32: |
| | a = alpha.astype(np.float32) |
| | if a.max() > 1.0: |
| | a = a / 255.0 |
| | else: |
| | a = alpha.copy() |
| |
|
| | a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8) |
| | if erode_px > 0: |
| | k = max(1, int(erode_px)) |
| | a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1) |
| | if dilate_px > 0: |
| | k = max(1, int(dilate_px)) |
| | a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1) |
| | a = a_u8.astype(np.float32) / 255.0 |
| |
|
| | if blur_px and blur_px > 0: |
| | rad = max(1, int(round(blur_px))) |
| | a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0) |
| |
|
| | return np.clip(a, 0.0, 1.0) |
| |
|
| | def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
| | x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0) |
| | return np.power(x, gamma) |
| |
|
| | def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
| | x = np.clip(lin, 0.0, 1.0) |
| | return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8) |
| |
|
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Compute an additive light-wrap term (float, 0..255 scale) for foreground edges.

    ``blur(1 - alpha)`` is large just inside the foreground boundary. It is
    multiplied by ``alpha`` so the term stays on the foreground side of the edge.
    Without that factor, every pure-background pixel (alpha 0, blurred inverse
    1) would receive ``amount`` x its own colour, uniformly brightening the
    backdrop instead of wrapping light around the subject.

    *alpha01* is a matte in [0, 1]. *radius* sets the (odd) Gaussian kernel size.
    """
    r = max(1, int(radius))
    inv_blur = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    # Confine the wrap to foreground pixels near the edge.
    edge_weight = inv_blur * alpha01
    return bg_rgb.astype(np.float32) * edge_weight[..., None] * float(amount)
| |
|
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Desaturate foreground pixels in the semi-transparent edge band of the matte.

    The weight is ``1 - 2*|alpha - 0.5|``: 1.0 at alpha 0.5, falling linearly
    to 0 at alpha 0 or 1. HSV saturation is multiplied by
    ``1 - amount * weight``; hue and value are left unchanged. The colour is
    not tied to any particular key colour. Returns a uint8 RGB image.
    """
    edge_weight = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = np.clip(sat * (1.0 - amount * edge_weight), 0, 255)
    despilled = cv2.merge([hue, sat, val])
    return cv2.cvtColor(despilled.astype(np.uint8), cv2.COLOR_HSV2RGB)
| |
|
def _composite_frame_pro(
    fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
    erode_px: int = None, dilate_px: int = None, blur_px: float = None,
    lw_radius: int = None, lw_amount: float = None, despill_amount: float = None
) -> np.ndarray:
    """Composite one RGB foreground frame over an RGB background using a matte.

    Pipeline:
      1. Refine the matte (``_refine_alpha``).
      2. Desaturate edge pixels (``_despill_edges``).
      3. Blend foreground over background in approximately linear light.
      4. Add a light-wrap term (``_light_wrap``).
      5. Re-encode to uint8 RGB.

    Every tuning argument left as None is read from the environment on each call:
    EDGE_ERODE (default 1), EDGE_DILATE (2), EDGE_BLUR (1.5), LIGHTWRAP_RADIUS (5),
    LIGHTWRAP_AMOUNT (0.18), DESPILL_AMOUNT (0.35).

    Assumes *fg_rgb*, *bg_rgb* and *alpha* share the same height/width — not
    checked here.
    """
    def _setting(value, env_key, default, cast):
        # An explicit argument wins; otherwise parse the environment variable.
        return value if value is not None else cast(os.environ.get(env_key, default))

    erode_px = _setting(erode_px, "EDGE_ERODE", "1", int)
    dilate_px = _setting(dilate_px, "EDGE_DILATE", "2", int)
    blur_px = _setting(blur_px, "EDGE_BLUR", "1.5", float)
    lw_radius = _setting(lw_radius, "LIGHTWRAP_RADIUS", "5", int)
    lw_amount = _setting(lw_amount, "LIGHTWRAP_AMOUNT", "0.18", float)
    despill_amount = _setting(despill_amount, "DESPILL_AMOUNT", "0.35", float)

    matte = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
    fg_clean = _despill_edges(fg_rgb, matte, amount=despill_amount)

    fg_lin = _to_linear(fg_clean)
    bg_lin = _to_linear(bg_rgb)
    wrap = _light_wrap(bg_rgb, matte, radius=lw_radius, amount=lw_amount)
    wrap_lin = _to_linear(np.clip(wrap, 0, 255).astype(np.uint8))

    coverage = matte[..., None]
    return _to_srgb(fg_lin * coverage + bg_lin * (1.0 - coverage) + wrap_lin)
| |
|
| | |
| | |
| | |
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Resolve SAM2 config path - return relative path for Hydra compatibility.

    *cfg_str* is looked up under the SAM2 checkout: ``$THIRD_PARTY_SAM2_DIR``,
    defaulting to the module-level ``TP_SAM2`` so it matches the directory put
    on ``sys.path``.

    Lookup order:
      1. ``<checkout>/<cfg_str>``. If it exists, a leading ``sam2/configs/`` is
         rewritten to ``configs/`` and the result is returned.
      2. ``<checkout>/sam2/<cfg_str>``, then ``<checkout>/configs/<cfg_str>``.
         The first that exists and contains ``configs/`` is returned as
         ``configs/<rest>``.
      3. Otherwise *cfg_str* is returned unchanged and a warning is logged.
    """
    logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")

    # Default to TP_SAM2 rather than a hard-coded deployment path, so config
    # lookup uses the same checkout that was added to sys.path at import time.
    tp_sam2 = os.environ.get("THIRD_PARTY_SAM2_DIR", str(TP_SAM2))
    logger.info(f"TP_SAM2 = {tp_sam2}")

    candidate = os.path.join(tp_sam2, cfg_str)
    logger.info(f"Candidate path: {candidate}")
    logger.info(f"Candidate exists: {os.path.exists(candidate)}")

    if os.path.exists(candidate):
        prefix = "sam2/configs/"
        # Strip only a leading prefix (str.replace would rewrite every occurrence).
        relative_path = "configs/" + cfg_str[len(prefix):] if cfg_str.startswith(prefix) else cfg_str
        logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
        return relative_path

    fallbacks = [
        os.path.join(tp_sam2, "sam2", cfg_str),
        os.path.join(tp_sam2, "configs", cfg_str),
    ]
    for fallback in fallbacks:
        logger.info(f"Trying fallback: {fallback}")
        if os.path.exists(fallback) and "configs/" in fallback:
            relative_path = "configs/" + fallback.split("configs/")[-1]
            logger.info(f"Returning fallback relative path: {relative_path}")
            return relative_path

    logger.warning(f"Config not found, returning original: {cfg_str}")
    return cfg_str
| |
|
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    *cfg_path* may be an on-disk path or the Hydra-relative name produced by
    ``_resolve_sam2_cfg`` (e.g. ``configs/sam2/x.yaml``). It is resolved as
    given, then under ``TP_SAM2`` and ``TP_SAM2/sam2``.

    If the config's ``model.image_encoder.trunk`` target contains "hieradet",
    every ``*.yaml`` under ``TP_SAM2`` is scanned. The path of the first one
    whose trunk target contains ".hiera." is returned.

    Best-effort: returns None when nothing is found or on any error (logged at DEBUG).
    """
    def _trunk_target(doc: Any) -> Optional[str]:
        # Safely walk model -> image_encoder -> trunk; YAML may be None or a non-dict.
        model = doc.get("model") if isinstance(doc, dict) else None
        enc = model.get("image_encoder") if isinstance(model, dict) else None
        trunk = enc.get("trunk") if isinstance(enc, dict) else None
        if not isinstance(trunk, dict):
            return None
        target = trunk.get("_target_") or trunk.get("target")
        return target if isinstance(target, str) else None

    try:
        # A Hydra-relative name is not valid relative to CWD; look inside the checkout too.
        candidates = [Path(cfg_path), TP_SAM2 / cfg_path, TP_SAM2 / "sam2" / cfg_path]
        cfg_file = next((c for c in candidates if c.is_file()), None)
        if cfg_file is None:
            return None

        with open(cfg_file, "r") as f:
            target = _trunk_target(yaml.safe_load(f))
        if target is None or "hieradet" not in target:
            return None

        for y in TP_SAM2.rglob("*.yaml"):
            try:
                with open(y, "r") as f2:
                    tgt2 = _trunk_target(yaml.safe_load(f2))
            except Exception:
                continue
            if tgt2 is not None and ".hiera." in tgt2:
                logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                return str(y)
    except Exception as e:
        logger.debug(f"SAM2 hiera config search failed: {e}")
    return None
| |
|
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Environment:
      * ``SAM2_MODEL_CFG``: config name, resolved via ``_resolve_sam2_cfg``.
      * ``SAM2_CHECKPOINT``: weights path; a warning is logged if unset or missing on disk.
      * ``SAM2_DEVICE``: read via ``_pick_device``.

    ``build_sam2`` is called with keyword names detected from its signature.
    If that raises ``TypeError``, it is retried positionally as
    ``(config, checkpoint-or-None, device)``. If building fails altogether, a
    'hiera' replacement config is tried once (``_find_hiera_config_if_hieradet``).

    Returns:
        ``(SAM2ImagePredictor, True, meta)`` on success, ``(None, False, meta)``
        otherwise. ``meta`` holds ``sam2_import_ok``, ``sam2_init_ok`` and, on
        success, ``sam2_device``.
    """
    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    if torch and torch.cuda.is_available():
        mem_before = torch.cuda.memory_allocated() / 1024**3
        logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")
    if not ckpt:
        logger.warning("SAM2_CHECKPOINT is not set; SAM2 may be built without pretrained weights")
    elif not os.path.exists(ckpt):
        logger.warning(f"SAM2_CHECKPOINT does not exist on disk: {ckpt}")

    def _try_build(cfg_path: str):
        """Call build_sam2 with whatever keyword names its signature exposes."""
        logger.info(f"_try_build called with cfg_path: {cfg_path}")
        params = set(inspect.signature(build_sam2).parameters.keys())
        logger.info(f"build_sam2 parameters: {list(params)}")
        kwargs = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
            logger.info(f"Using config_file parameter: {cfg_path}")
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
            logger.info(f"Using model_cfg parameter: {cfg_path}")
        if ckpt:
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
            result = build_sam2(**kwargs)
            logger.info("build_sam2 succeeded with kwargs")
            if hasattr(result, 'device'):
                logger.info(f"SAM2 model device: {result.device}")
            elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
                logger.info(f"SAM2 model device: {result.image_encoder.device}")
            return result
        except TypeError as e:
            logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
            # Assumes the upstream positional order (config, ckpt_path, device).
            # Always fill the checkpoint slot (None when unset) so the device string
            # can never be taken as a checkpoint path, and always pass the device.
            pos = [cfg_path, ckpt or None, device]
            logger.info(f"Calling build_sam2 with positional args: {pos}")
            result = build_sam2(*pos)
            logger.info("build_sam2 succeeded with positional args")
            return result

    try:
        try:
            sam = _try_build(cfg)
        except Exception:
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                sam = _try_build(alt_cfg)
            else:
                raise

        if sam is None:
            return None, False, meta
        predictor = SAM2ImagePredictor(sam)
        meta["sam2_init_ok"] = True
        meta["sam2_device"] = device
        return predictor, True, meta
    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}", exc_info=True)
        return None, False, meta
| |
|
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok).

    Prompt, in priority order:
      1. *auto*: a box inset 5% from each border.
      2. *point*: a single positive click at (x, y).
      3. Otherwise: a box inset 10% from each border.

    The predictor is assumed to return ``(masks, scores, logits)``, as
    ``SAM2ImagePredictor.predict`` does. When several candidate masks come back
    with matching scores, the highest-scoring one is used rather than blindly
    taking the first. A single 2-D mask is also accepted.

    Any failure (including a None predictor) yields ``(None, False)``;
    exceptions are logged.
    """
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if auto:
            box = np.array([int(0.05 * w), int(0.05 * h), int(0.95 * w), int(0.95 * h)])
            masks, scores, _ = predictor.predict(box=box)
        elif point is not None:
            pts = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)
            masks, scores, _ = predictor.predict(point_coords=pts, point_labels=labels)
        else:
            box = np.array([int(0.1 * w), int(0.1 * h), int(0.9 * w), int(0.9 * h)])
            masks, scores, _ = predictor.predict(box=box)

        if masks is None or len(masks) == 0:
            return None, False

        masks = np.asarray(masks)
        if masks.ndim == 2:
            best = masks  # a single (H, W) mask rather than a stack of candidates
        else:
            idx = 0
            if scores is not None:
                score_arr = np.asarray(scores).reshape(-1)
                if score_arr.size == masks.shape[0]:
                    idx = int(np.argmax(score_arr))
            best = masks[idx]
        return best.astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
| |
|
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: int = None,
                         trimap_erode: int = None,
                         trimap_dilate: int = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    Building the GrabCut labels:
      * The seed is binarised at > 127.
      * Foreground pixels that survive an erosion with an elliptical kernel of
        side ``trimap_erode`` become definite foreground.
      * Background pixels that survive an erosion of the inverted seed (kernel
        side ``trimap_dilate``) become definite background.
      * Everything else starts as "probably background".

    GrabCut then runs ``iters`` iterations. Its foreground/probable-foreground
    labels are returned as a 0/255 uint8 mask, median-blurred with size 5.

    Arguments left as None come from REFINE_GRABCUT_ITERS (2),
    REFINE_TRIMAP_ERODE (3) and REFINE_TRIMAP_DILATE (6). If GrabCut raises,
    a warning is logged and the binarised seed is returned.
    """
    def _param(value, env_key, default):
        return int(os.environ.get(env_key, default)) if value is None else int(value)

    iters = _param(iters, "REFINE_GRABCUT_ITERS", "2")
    fg_k = max(1, _param(trimap_erode, "REFINE_TRIMAP_ERODE", "3"))
    bg_k = max(1, _param(trimap_dilate, "REFINE_TRIMAP_DILATE", "6"))

    seed = np.where(mask_u8 > 127, 255, 0).astype(np.uint8)

    def _ellipse(k):
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))

    sure_fg = cv2.erode(seed, _ellipse(fg_k), iterations=1)
    sure_bg = cv2.erode(255 - seed, _ellipse(bg_k), iterations=1)

    labels = np.full(seed.shape[:2], cv2.GC_PR_BGD, dtype=np.uint8)
    labels[sure_bg > 0] = cv2.GC_BGD
    labels[sure_fg > 0] = cv2.GC_FGD

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, labels, None, bgd_model, fgd_model, iters, cv2.GC_INIT_WITH_MASK)
        is_fg = (labels == cv2.GC_FGD) | (labels == cv2.GC_PR_FGD)
        return cv2.medianBlur(np.where(is_fg, 255, 0).astype(np.uint8), 5)
    except Exception as e:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}")
        return seed
| |
|
| | |
| | |
| | |
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """
    Probe MatAnyone availability with T=1 squeeze patch for conv2d compatibility.
    Returns (None, available, meta); actual instantiation happens in MatAnyoneSession.

    Behaviour:
      * ``ENABLE_MATANY`` set to 0/false/off/no disables the probe
        (``meta["disabled"] = True``).
      * The compatibility patch is attempted before ``InferenceCore`` is
        imported; its outcome is stored in ``meta["patch_applied"]``.
      * On a successful import, ``meta`` also gets ``matany_import_ok``,
        ``matany_repo_id`` (``$MATANY_REPO_ID``) and ``matany_device``.
    """
    meta = {"matany_import_ok": False, "matany_init_ok": False}
    if os.environ.get("ENABLE_MATANY", "1").strip().lower() in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    # Apply the shape guard before importing InferenceCore below.
    try:
        from .matany_compat_patch import apply_matany_t1_squeeze_guard
        patched = apply_matany_t1_squeeze_guard()
    except Exception as e:
        logger.warning(f"[MatAnyCompat] Patch import failed: {e}")
        meta["patch_applied"] = False
    else:
        if patched:
            logger.info("[MatAnyCompat] T=1 squeeze guard applied")
        else:
            logger.warning("[MatAnyCompat] T=1 squeeze patch failed; conv2d errors may occur")
        meta["patch_applied"] = bool(patched)

    try:
        from matanyone.inference.inference_core import InferenceCore
        meta["matany_import_ok"] = True

        try:
            installed = importlib.metadata.version("matanyone")
            logger.info(f"[MATANY] MatAnyone version: {installed}")
        except Exception:
            logger.info("[MATANY] MatAnyone version unknown")
        logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")

        device = _pick_device("MATANY_DEVICE")
        meta["matany_repo_id"] = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
        meta["matany_device"] = device
        return None, True, meta
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
| |
|
| | |
| | |
| | |
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255.

    Strategy, in order:
      1. If MediaPipe imported: SelfieSegmentation (model_selection=1),
         thresholded at 0.5, then median-blurred (size 5).
      2. Otherwise, or if step 1 raises: GrabCut for 5 iterations, initialised
         from a rectangle inset 10% on every side.
      3. If GrabCut also raises: an all-zero mask of the frame size.
    """
    h, w = first_frame_bgr.shape[:2]

    if _HAS_MEDIAPIPE:
        try:
            selfie = mp.solutions.selfie_segmentation
            with selfie.SelfieSegmentation(model_selection=1) as segmenter:
                result = segmenter.process(cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB))
                person = np.clip(result.segmentation_mask, 0, 1) > 0.5
                return cv2.medianBlur(person.astype(np.uint8) * 255, 5)
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    gc_mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1 * w), int(0.1 * h), int(0.8 * w), int(0.8 * h))
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, gc_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
    foreground = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    return np.where(foreground, 255, 0).astype(np.uint8)
| |
|
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor.

    Reading and compositing:
      * FG and alpha videos are read in lockstep; processing stops at whichever
        ends first.
      * Each frame is resized to *size* = (width, height) and composited with
        ``_composite_frame_pro`` over *bg_image_path*.
      * The background is letterboxed to *size*; it is flat grey if it cannot
        be read.

    Output:
      * When ffmpeg is available, frames go to ``<out_path>.tmp.mp4`` (mp4v)
        and are then re-encoded to H.264 (yuv420p, faststart) at *out_path*.
      * If that re-encode fails, the mp4v intermediate becomes *out_path*.

    Returns:
        True if at least one frame was written. False if an input cannot be
        opened, the writer cannot be opened, or no frames were produced.
        Captures and the writer are always released, and no temp file is left
        behind when nothing was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    final_out = Path(out_path)
    tmp_out = Path(str(out_path) + ".tmp.mp4")
    writer = None
    post_h264 = False
    ok_any = False
    try:
        if not fg_cap.isOpened() or not al_cap.isOpened():
            return False

        w, h = size
        bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
        if bg is None:
            bg = np.full((h, w, 3), 127, dtype=np.uint8)
        # The background is static: letterbox and convert it to RGB once, not per frame.
        bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

        # With ffmpeg present, write an mp4v intermediate and re-encode to H.264 afterwards.
        post_h264 = _probe_ffmpeg()
        writer = _video_writer(tmp_out if post_h264 else final_out, fps, (w, h))
        if not writer.isOpened():
            logger.warning(f"composite_video: could not open video writer for {out_path}")
            return False

        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not ok_fg or not ok_al:
                break
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al_gray, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        if writer is not None:
            writer.release()

    if post_h264:
        if not ok_any:
            # Nothing was composited; do not leave an empty intermediate behind.
            tmp_out.unlink(missing_ok=True)
        else:
            try:
                cmd = [
                    _ffmpeg_bin(), "-y",
                    "-i", str(tmp_out),
                    "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                    str(out_path),
                ]
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                tmp_out.unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"ffmpeg finalize failed: {e}")
                # Fall back to the mp4v intermediate as the final output.
                final_out.unlink(missing_ok=True)
                tmp_out.replace(final_out)

    return ok_any
| |
|
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor.

    Compositing:
      * One grayscale mask (*mask_path*, nearest-neighbour resized) is applied
        to every frame of *video_path*.
      * The background is *bg_image_path*, letterboxed; it is flat grey if it
        cannot be read.
      * Output size and FPS follow the source video; FPS defaults to 25 when
        it is not reported.

    Output:
      * When ffmpeg is available, the result is re-encoded to H.264 (yuv420p,
        faststart).
      * If that re-encode fails, the mp4v intermediate is kept as *out_path*.

    Returns:
        True if at least one frame was written. False if the mask or video
        cannot be opened, the frame size is unknown, the writer cannot be
        opened, or the video has no frames. The capture and the writer are
        always released.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    final_out = Path(out_path)
    tmp_out = Path(str(out_path) + ".tmp.mp4")
    writer = None
    use_post_ffmpeg = False
    ok_any = False
    try:
        if mask is None or not cap.isOpened():
            return False

        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
        if w <= 0 or h <= 0:
            # cv2.resize cannot target a 0-sized frame.
            logger.warning(f"fallback_composite: unknown frame size for {video_path}")
            return False

        bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
        if bg is None:
            bg = np.full((h, w, 3), 127, dtype=np.uint8)
        mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        # The background is static: letterbox and convert it to RGB once.
        bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

        use_post_ffmpeg = _probe_ffmpeg()
        writer = _video_writer(tmp_out if use_post_ffmpeg else final_out, fps, (w, h))
        if not writer.isOpened():
            logger.warning(f"fallback_composite: could not open video writer for {out_path}")
            return False

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if frame.shape[:2] != (h, w):
                # Reported size may disagree with decoded frames; keep shapes aligned with the mask.
                frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask_resized, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    if use_post_ffmpeg:
        if not ok_any:
            # Nothing was written; do not leave an empty intermediate behind.
            tmp_out.unlink(missing_ok=True)
        else:
            try:
                cmd = [
                    _ffmpeg_bin(), "-y",
                    "-i", str(tmp_out),
                    "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                    str(out_path),
                ]
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                tmp_out.unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"ffmpeg H.264 finalize failed: {e}")
                # Fall back to the mp4v intermediate as the final output.
                final_out.unlink(missing_ok=True)
                tmp_out.replace(final_out)

    return ok_any
| |
|
| | |
| | |
| | |
| | def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray: |
| | y, x = np.mgrid[0:h, 0:w] |
| | c = ((x // tile) + (y // tile)) % 2 |
| | a = np.where(c == 0, 200, 150).astype(np.uint8) |
| | return np.stack([a, a, a], axis=-1) |
| |
|
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge a FG video and an alpha video into one VP9 WebM with alpha (yuva420p).

    Encoding:
      * Both inputs are scaled to *size* = (width, height), resampled to
        *fps*, and combined with ffmpeg's ``alphamerge``.
      * If *src_audio* is given, its first audio stream (optional) is encoded
        as Opus at 128k.
      * Video quality comes from ``$STAGEA_VP9_CRF`` (default "28").

    Returns False when ffmpeg is unavailable or the encode fails (logged).
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd.extend(["-i", str(src_audio)])
        filter_graph = ";".join((
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al]",
            f"[0:v]scale={w}:{h},fps={fps}[fg]",
            "[fg][al]alphamerge[outv]",
        ))
        cmd.extend(["-filter_complex", filter_graph, "-map", "[outv]"])
        if src_audio:
            cmd.extend(["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"])
        cmd.extend([
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ])
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
    return True
| |
|
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Attach a static PNG mask as alpha to *video_path* and encode VP9 WebM (yuva420p).

    Encoding:
      * The mask is looped (``-loop 1``) so it can pair with every video frame.
      * Both streams are scaled to *size* = (width, height) and resampled to
        *fps* before ``alphamerge``.
      * No audio is written.
      * Quality comes from ``$STAGEA_VP9_CRF`` (default "28").

    Returns False when ffmpeg is unavailable or the encode fails (logged).
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        filter_graph = ";".join((
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al]",
            f"[0:v]scale={w}:{h},fps={fps}[fg]",
            "[fg][al]alphamerge[outv]",
        ))
        cmd = [_ffmpeg_bin(), "-y"]
        cmd.extend(["-i", str(video_path)])
        cmd.extend(["-loop", "1", "-i", str(mask_png)])
        cmd.extend(["-filter_complex", filter_graph, "-map", "[outv]"])
        cmd.extend([
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ])
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
    return True
| |
|
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview FG+alpha composited over a grey checkerboard, written as an mp4v video.

    FG and alpha frames are read in lockstep and resized to *size* = (width,
    height); processing stops at whichever video ends first.

    Returns:
        True if at least one frame was written. False if either input or the
        writer cannot be opened. Both captures and the writer are released on
        every path, including early failures.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    writer = None
    ok_any = False
    try:
        if not fg_cap.isOpened() or not al_cap.isOpened():
            return False
        w, h = size
        writer = _video_writer(Path(out_mp4), fps, (w, h))
        if not writer.isOpened():
            logger.warning(f"Stage-A checkerboard: could not open video writer for {out_mp4}")
            return False
        bg = _checkerboard_bg(w, h)
        while True:
            okf, fg = fg_cap.read()
            oka, al = al_cap.read()
            if not okf or not oka:
                break
            fg = cv2.resize(fg, (w, h))
            al = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al, bg)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        if writer is not None:
            writer.release()
    return ok_any
| |
|
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview *video_path* cut out by a static mask over a grey checkerboard (mp4v).

    The grayscale mask is nearest-neighbour resized to *size* = (width,
    height) and applied to every frame; frames are resized to *size* as well.

    Returns:
        True if at least one frame was written. False if the video, the mask
        or the writer cannot be opened. The capture and the writer are
        released on every path.
    """
    cap = cv2.VideoCapture(str(video_path))
    writer = None
    ok_any = False
    try:
        if not cap.isOpened():
            return False
        w, h = size
        mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            return False
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        writer = _video_writer(Path(out_mp4), fps, (w, h))
        if not writer.isOpened():
            logger.warning(f"Stage-A checkerboard (mask): could not open video writer for {out_mp4}")
            return False
        bg = _checkerboard_bg(w, h)
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, bg)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    return ok_any
| |
|
| | |
| | |
| | |
def run_matany(
    video_path: Union[str, Path],
    mask_path: Optional[Union[str, Path]],
    out_dir: Union[str, Path],
    device: Optional[str] = None,
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> Tuple[Path, Path]:
    """
    Run MatAnyone streaming matting via our shape-guarded adapter.
    Returns (alpha_mp4_path, fg_mp4_path).
    Raises MatAnyError on failure.

    *mask_path* is an optional seed mask; any falsy value means "no seed".
    *progress_callback* is forwarded to the session as ``progress_cb``.
    """
    # Imported inside the function, keeping MatAnyone out of module import time.
    from .matanyone_loader import MatAnyoneSession, MatAnyError  # noqa: F401

    session = MatAnyoneSession(device=device, precision="auto")
    outputs = session.process_stream(
        video_path=Path(video_path),
        seed_mask_path=None if not mask_path else Path(mask_path),
        out_dir=Path(out_dir),
        progress_cb=progress_callback,
    )
    alpha_video, fg_video = outputs
    return alpha_video, fg_video