MogensR commited on
Commit
372831d
·
1 Parent(s): 041af61
Files changed (1) hide show
  1. models/__init__.py +760 -57
models/__init__.py CHANGED
@@ -1,75 +1,778 @@
1
- # models/sam2_loader.py
2
- import os, logging, torch
3
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pathlib import Path
 
 
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- log = logging.getLogger("sam2_loader")
 
 
 
 
 
 
8
 
9
- DEFAULT_MODEL_ID = os.environ.get("SAM2_MODEL_ID", "facebook/sam2")
10
- DEFAULT_VARIANT = os.environ.get("SAM2_VARIANT", "sam2_hiera_large")
 
 
 
 
 
 
11
 
12
- # Map variant -> filenames (SAM2 releases follow this pattern)
13
- VARIANT_FILES = {
14
- "sam2_hiera_small": ("sam2_hiera_small.pt", "configs/sam2/sam2_hiera_s.yaml"),
15
- "sam2_hiera_base": ("sam2_hiera_base.pt", "configs/sam2/sam2_hiera_b.yaml"),
16
- "sam2_hiera_large": ("sam2_hiera_large.pt", "configs/sam2/sam2_hiera_l.yaml"),
17
- }
18
 
19
- def _download_checkpoint(model_id: str, ckpt_name: str) -> str:
20
- return hf_hub_download(repo_id=model_id, filename=ckpt_name, local_dir=os.environ.get("HF_HOME"))
 
 
 
 
 
21
 
22
- def _find_sam2_build():
 
 
 
 
 
 
23
  try:
24
- from sam2.build_sam import build_sam2
25
- return build_sam2
26
  except Exception as e:
27
- log.error("SAM2 not importable (check Dockerfile vendoring): %s", e)
28
  return None
29
 
30
- class SAM2Predictor:
31
- def __init__(self, device: torch.device):
32
- self.device = device
33
- self.model = None
34
- self.predictor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def load(self, variant: str = DEFAULT_VARIANT, model_id: str = DEFAULT_MODEL_ID):
37
- build_sam2 = _find_sam2_build()
38
- if build_sam2 is None:
39
- raise RuntimeError("SAM2 build function not available")
 
 
 
 
 
40
 
41
- ckpt_name, cfg_path = VARIANT_FILES.get(variant, VARIANT_FILES["sam2_hiera_large"])
42
- ckpt = _download_checkpoint(model_id, ckpt_name)
 
 
 
 
 
 
 
 
 
43
 
44
- # Compose config via hydra-free path (using explicit path args)
45
- model = build_sam2(config_file=cfg_path, ckpt_path=ckpt, device=str(self.device))
46
- model.eval()
47
- self.model = model
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  try:
50
- from sam2.sam2_video_predictor import SAM2VideoPredictor
51
- self.predictor = SAM2VideoPredictor(self.model)
52
  except Exception:
53
- # Fallback to image predictor if video predictor missing
54
- from sam2.sam2_image_predictor import SAM2ImagePredictor
55
- self.predictor = SAM2ImagePredictor(self.model)
56
-
57
- return self
58
-
59
- @torch.inference_mode()
60
- def first_frame_mask(self, image_rgb01):
61
- """
62
- Returns an initial binary-ish mask for the foreground subject from first frame.
63
- You can refine prompts here (points/boxes) if you add UI hooks later.
64
- """
65
- if hasattr(self.predictor, "set_image"):
66
- self.predictor.set_image((image_rgb01*255).astype("uint8"))
67
- # simple auto-box prompt (tight box)
68
- h, w = image_rgb01.shape[:2]
69
- box = np.array([1, 1, w-2, h-2])
70
- masks, _, _ = self.predictor.predict(box=box, multimask_output=False)
71
- mask = masks[0] # HxW bool/float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  else:
73
- # video predictor path: run_single_frame if available
74
- mask = (image_rgb01[...,0] > -1) # dummy, should not happen
75
- return mask.astype("float32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro - Model Loading & Utilities (Hardened)
4
+ ======================================================
5
+ - Avoids heavy CUDA/Hydra work at import time
6
+ - Adds timeouts to subprocess probes
7
+ - Safer sys.path wiring for third_party repos
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import os
13
+ import sys
14
+ import cv2
15
+ import subprocess
16
+ import inspect
17
+ import logging
18
  from pathlib import Path
19
+ from typing import Optional, Tuple, Dict, Any, Union
20
+
21
  import numpy as np
22
+ import yaml
23
+
24
# --------------------------------------------------------------------------------------
# Logging (ensure a handler exists very early)
# --------------------------------------------------------------------------------------
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    # First import wins: attach a stream handler so log lines are visible
    # even before any application-level logging config runs.
    _stream_handler = logging.StreamHandler()
    _stream_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
    logger.addHandler(_stream_handler)
    logger.setLevel(logging.INFO)
33
 
34
# Pin OpenCV threads (helps libgomp stability in Spaces)
try:
    _cv_thread_count = int(os.environ.get("CV_THREADS", "1"))
    _set_threads = getattr(cv2, "setNumThreads", None)
    if _set_threads is not None:
        _set_threads(_cv_thread_count)
except Exception:
    # Best-effort only; never block import on OpenCV quirks.
    pass
41
 
42
+ # --------------------------------------------------------------------------------------
43
+ # Optional dependencies
44
+ # --------------------------------------------------------------------------------------
45
+ try:
46
+ import mediapipe as mp # type: ignore
47
+ _HAS_MEDIAPIPE = True
48
+ except Exception:
49
+ _HAS_MEDIAPIPE = False
50
 
51
# --------------------------------------------------------------------------------------
# Path setup for third_party repos
# --------------------------------------------------------------------------------------
# Project root is one directory above this package.
ROOT = Path(__file__).resolve().parent.parent
# Vendored repo locations, overridable via env vars.
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
57
 
58
+ def _add_sys_path(p: Path) -> None:
59
+ if p.exists():
60
+ p_str = str(p)
61
+ if p_str not in sys.path:
62
+ sys.path.insert(0, p_str)
63
+ else:
64
+ logger.warning(f"third_party path not found: {p}")
65
 
66
# Wire both vendored repos onto sys.path (SAM2 first, so it wins lookups).
for _vendored_dir in (TP_SAM2, TP_MATANY):
    _add_sys_path(_vendored_dir)
68
+
69
+ # --------------------------------------------------------------------------------------
70
+ # Safe Torch accessors (no top-level import)
71
+ # --------------------------------------------------------------------------------------
72
+ def _torch():
73
  try:
74
+ import torch # local import avoids early CUDA init during module import
75
+ return torch
76
  except Exception as e:
77
+ logger.warning(f"[models.safe-torch] import failed: {e}")
78
  return None
79
 
80
def _has_cuda() -> bool:
    """Return True when torch reports a usable CUDA device; False on any failure."""
    torch_mod = _torch()
    if torch_mod is None:
        return False
    try:
        available = torch_mod.cuda.is_available()
    except Exception as e:
        logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
        return False
    return bool(available)
89
+
90
+ def _pick_device(env_key: str) -> str:
91
+ requested = os.environ.get(env_key, "").strip().lower()
92
+ if requested in {"cuda", "cpu"}:
93
+ return requested
94
+ return "cuda" if _has_cuda() else "cpu"
95
+
96
+ # --------------------------------------------------------------------------------------
97
+ # Basic Utilities
98
+ # --------------------------------------------------------------------------------------
99
+ def _ffmpeg_bin() -> str:
100
+ return os.environ.get("FFMPEG_BIN", "ffmpeg")
101
+
102
+ def _probe_ffmpeg(timeout: int = 2) -> bool:
103
+ try:
104
+ subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, timeout=timeout)
105
+ return True
106
+ except Exception:
107
+ return False
108
+
109
+ def _ensure_dir(p: Path) -> None:
110
+ p.mkdir(parents=True, exist_ok=True)
111
+
112
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Open *video_path* and return (first_frame_or_None, fps, (width, height)).

    fps defaults to 25 when the container reports none; (0, 0) size means the
    file could not be opened at all.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
    ok, frame = cap.read()
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    cap.release()
    if ok:
        return frame, fps, (width, height)
    return None, fps, (width, height)
124
+
125
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit image; returns the path as str.

    bool masks become 0/255; other non-uint8 dtypes are clipped to [0, 255].
    """
    if mask.dtype == bool:
        out = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        out = np.clip(mask, 0, 255).astype(np.uint8)
    else:
        out = mask
    cv2.imwrite(str(path), out)
    return str(path)
132
+
133
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Resize *image* to fit (w, h) preserving aspect ratio, centered on black.

    Degenerate sizes (zero width/height on either side) return the input
    unchanged. Works for 3-channel, 4-channel, and single-channel images.
    """
    tw, th = target_wh
    h, w = image.shape[:2]
    if h == 0 or w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / w, th / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    # Bug fix: the canvas used to hard-code 3 channels, which broke on
    # grayscale (2-D) or RGBA inputs; mirror the resized image's layout.
    if resized.ndim == 2:
        canvas = np.zeros((th, tw), dtype=resized.dtype)
    else:
        canvas = np.zeros((th, tw, resized.shape[2]), dtype=resized.dtype)
    x0 = (tw - nw) // 2
    y0 = (th - nh) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = resized
    return canvas
146
+
147
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Create an mp4v-encoded VideoWriter at *out_path* (fps clamped to >= 1)."""
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(str(out_path), codec, max(1, fps), size)
150
+
151
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC)."""
    mux_cmd = [
        _ffmpeg_bin(), "-y",
        "-i", str(silent_video),
        "-i", str(src_video),
        "-map", "0:v:0",
        "-map", "1:a:0?",       # '?' -> don't fail when the source has no audio
        "-c:v", "copy",
        "-c:a", "aac", "-b:a", "192k",
        "-shortest",
        str(out_path),
    ]
    try:
        subprocess.run(mux_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
170
+
171
+ # --------------------------------------------------------------------------------------
172
+ # Compositing & Image Processing
173
+ # --------------------------------------------------------------------------------------
174
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Clean up a matte: optional erode/dilate, then a Gaussian feather.

    Accepts uint8 (0-255) or float mattes in either [0, 1] or [0, 255];
    returns float32 in [0, 1].
    """
    a = alpha.astype(np.float32) if alpha.dtype != np.float32 else alpha.copy()
    # Bug fix: normalization used to run only on the non-float32 branch, so a
    # float32 matte still scaled 0..255 slipped through un-normalized.
    if a.max() > 1.0:
        a = a / 255.0

    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    if blur_px and blur_px > 0:
        # (rad | 1) forces the odd kernel size GaussianBlur requires.
        rad = max(1, int(round(blur_px)))
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
196
+
197
+ def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
198
+ x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
199
+ return np.power(x, gamma)
200
+
201
+ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
202
+ x = np.clip(lin, 0.0, 1.0)
203
+ return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
204
+
205
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Additive light-wrap term: blurred background leakage weighted by (1 - alpha)."""
    r = max(1, int(radius))
    # Blur the inverse matte so background light "wraps" around the subject edge.
    spill = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * spill[..., None] * float(amount)
211
 
212
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Desaturate foreground colors along the matte's soft edge to hide color spill."""
    # Edge weight peaks where alpha is near 0.5 (the transition band) and is
    # zero at fully-opaque or fully-transparent pixels.
    edge_weight = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = sat * (1.0 - amount * edge_weight)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
221
 
222
def _composite_frame_pro(
    fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
    erode_px: Optional[int] = None, dilate_px: Optional[int] = None, blur_px: Optional[float] = None,
    lw_radius: Optional[int] = None, lw_amount: Optional[float] = None, despill_amount: Optional[float] = None
) -> np.ndarray:
    """Composite fg over bg in linear light with alpha refinement, despill and light wrap.

    Parameters left as None fall back to the env vars EDGE_ERODE, EDGE_DILATE,
    EDGE_BLUR, LIGHTWRAP_RADIUS, LIGHTWRAP_AMOUNT, DESPILL_AMOUNT.
    Returns a uint8 RGB frame.
    """
    # Fix: None-defaulted parameters were annotated as plain int/float;
    # Optional[...] now states the actual contract.
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    # Alpha-blend in linear light, then add the light-wrap spill on top.
    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    return _to_srgb(comp_lin)
245
+
246
+ # --------------------------------------------------------------------------------------
247
+ # SAM2 Integration
248
+ # --------------------------------------------------------------------------------------
249
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Make the SAM2 config path absolute, preferring a copy inside TP_SAM2.

    Falls back to any known hiera config shipped with the vendored repo, and
    finally returns the input unchanged when nothing resolvable is found.
    """
    cfg_path = Path(cfg_str)
    if not cfg_path.is_absolute():
        inside_repo = TP_SAM2 / cfg_path
        if inside_repo.exists():
            return str(inside_repo)
    if cfg_path.exists():
        return str(cfg_path)
    for fallback in ("configs/sam2/sam2_hiera_l.yaml",
                     "configs/sam2/sam2_hiera_b.yaml",
                     "configs/sam2/sam2_hiera_s.yaml"):
        candidate = TP_SAM2 / fallback
        if candidate.exists():
            return str(candidate)
    return str(cfg_str)
263
+
264
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If *cfg_path* targets a 'hieradet' trunk, locate a '.hiera.' config instead.

    Scans TP_SAM2 for YAML files whose model.image_encoder.trunk target module
    is the plain 'hiera' variant. Returns None when no switch is needed or
    nothing suitable is found; all I/O/parse errors are swallowed.
    """
    def _trunk_target(doc: Dict[str, Any]) -> Any:
        encoder = (doc.get("model", {}) or {}).get("image_encoder") or {}
        trunk = encoder.get("trunk") or {}
        return trunk.get("_target_") or trunk.get("target")

    try:
        with open(cfg_path, "r") as f:
            target = _trunk_target(yaml.safe_load(f))
        if not (isinstance(target, str) and "hieradet" in target):
            return None
        for candidate in TP_SAM2.rglob("*.yaml"):
            try:
                with open(candidate, "r") as fh:
                    cand_target = _trunk_target(yaml.safe_load(fh) or {})
                if isinstance(cand_target, str) and ".hiera." in cand_target:
                    logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {candidate}")
                    return str(candidate)
            except Exception:
                continue
    except Exception:
        pass
    return None
289
+
290
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns (predictor_or_None, ok, meta); meta records import/init status,
    resolved config, checkpoint and device.
    """
    meta: Dict[str, Any] = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2  # type: ignore
        from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta
    meta["sam2_import_ok"] = True

    device = _pick_device("SAM2_DEVICE")
    cfg = _resolve_sam2_cfg(os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml"))
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        # build_sam2 signatures differ across SAM2 releases; adapt by
        # introspecting the accepted parameter names.
        params = set(inspect.signature(build_sam2).parameters.keys())
        kwargs: Dict[str, Any] = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
        if ckpt:
            for key in ("checkpoint", "ckpt_path", "weights"):
                if key in params:
                    kwargs[key] = ckpt
                    break
        if "device" in params:
            kwargs["device"] = device
        try:
            return build_sam2(**kwargs)
        except TypeError:
            # Last resort: positional calling convention.
            args = [cfg_path]
            if ckpt:
                args.append(ckpt)
            if "device" not in kwargs:
                args.append(device)
            return build_sam2(*args)

    try:
        try:
            sam = _try_build(cfg)
        except Exception:
            # A 'hieradet' config may need swapping for its 'hiera' twin.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if not alt_cfg:
                raise
            logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
            sam = _try_build(alt_cfg)
            cfg = alt_cfg

        predictor = SAM2ImagePredictor(sam)
        meta.update({
            "sam2_init_ok": True,
            "sam2_device": device,
            "sam2_cfg": cfg,
            "sam2_ckpt": ckpt or "(repo default)",
        })
        return predictor, True, meta
    except Exception as e:
        logger.error(f"SAM2 init failed: {e}")
        return None, False, meta
355
+
356
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Segment the first frame with SAM2. Returns (mask_uint8_0_255, ok).

    Prompt selection: auto=True uses a near-full-frame box; otherwise a user
    point when given; otherwise a generous centered box.
    """
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if auto:
            prompt_box = np.array([int(0.05 * w), int(0.05 * h), int(0.95 * w), int(0.95 * h)])
            masks, _, _ = predictor.predict(box=prompt_box)
        elif point is not None:
            coords = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)
            masks, _, _ = predictor.predict(point_coords=coords, point_labels=labels)
        else:
            prompt_box = np.array([int(0.1 * w), int(0.1 * h), int(0.9 * w), int(0.9 * h)])
            masks, _, _ = predictor.predict(box=prompt_box)

        if masks is None or len(masks) == 0:
            return None, False
        return masks[0].astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
389
+
390
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Refine a binary seed mask with GrabCut initialized from a trimap.

    Args:
        image_bgr: source frame (BGR).
        mask_u8: seed mask; values > 127 count as foreground.
        iters / trimap_erode / trimap_dilate: override the env defaults
            REFINE_GRABCUT_ITERS / REFINE_TRIMAP_ERODE / REFINE_TRIMAP_DILATE.

    Returns:
        uint8 mask of 0/255; the binarized seed on GrabCut failure.
    """
    # Fix: None-defaulted parameters were annotated as plain int;
    # Optional[int] states the actual contract.
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255

    # Sure-foreground: shrink the seed; sure-background: shrink the complement.
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        # Median filter knocks out speckle noise in the refined mask.
        return cv2.medianBlur(out, 5)
    except Exception as e:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}")
        return m
420
+
421
+ # --------------------------------------------------------------------------------------
422
+ # MatAnyone Integration
423
+ # --------------------------------------------------------------------------------------
424
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """MatAnyone loader with disable switch and error handling.

    Returns (core_or_None, ok, meta). ENABLE_MATANY=0/false/off/no disables
    loading entirely; MATANY_REPO_ID overrides the HF repo.
    """
    meta: Dict[str, Any] = {"matany_import_ok": False, "matany_init_ok": False}

    if os.environ.get("ENABLE_MATANY", "1").strip().lower() in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    try:
        from matanyone import InferenceCore  # type: ignore
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
    meta["matany_import_ok"] = True

    device = _pick_device("MATANY_DEVICE")
    repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")

    try:
        core = InferenceCore(repo_id)
    except Exception as e:
        logger.error(f"MatAnyone init failed: {e}")
        return None, False, meta
    meta["matany_init_ok"] = True
    meta["matany_device"] = device
    meta["matany_repo_id"] = repo_id
    return core, True, meta
453
+
454
def run_matany(matany: object,
               video_path: Union[str, Path],
               first_mask_path: Union[str, Path],
               work_dir: Union[str, Path]) -> Tuple[Optional[str], Optional[str], bool]:
    """Run MatAnyone on a video. Returns (foreground_video_path, alpha_video_path, ok).

    Tries ``process_video`` first (tuple or dict result), then ``run``
    (dict result only), matching the assorted key spellings seen in the wild.
    """
    if matany is None:
        return None, None, False

    def _paths_from_dict(result: dict) -> Optional[Tuple[str, str]]:
        fg = result.get("foreground") or result.get("fg") or result.get("foreground_path")
        al = result.get("alpha") or result.get("alpha_path")
        if fg and al:
            return str(fg), str(al)
        return None

    try:
        if hasattr(matany, "process_video"):
            out = matany.process_video(input_path=str(video_path), mask_path=str(first_mask_path), output_dir=str(work_dir))
            if isinstance(out, (list, tuple)) and len(out) >= 2:
                return str(out[0]), str(out[1]), True
            if isinstance(out, dict):
                found = _paths_from_dict(out)
                if found:
                    return found[0], found[1], True

        if hasattr(matany, "run"):
            out = matany.run(video_path=str(video_path), seed_mask=str(first_mask_path), out_dir=str(work_dir))
            if isinstance(out, dict):
                found = _paths_from_dict(out)
                if found:
                    return found[0], found[1], True

        logger.error("MatAnyone returned no usable paths.")
        return None, None, False
    except Exception as e:
        logger.warning(f"MatAnyone processing failed: {e}")
        return None, None, False
485
+
486
+ # --------------------------------------------------------------------------------------
487
+ # Fallback Functions
488
+ # --------------------------------------------------------------------------------------
489
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Segment without SAM2: prefer MediaPipe, fall back to GrabCut.

    Returns a uint8 mask of 0/255 (all-zero when every method fails).
    """
    h, w = first_frame_bgr.shape[:2]

    if _HAS_MEDIAPIPE:
        try:
            selfie_mod = mp.solutions.selfie_segmentation
            with selfie_mod.SelfieSegmentation(model_selection=1) as segmenter:
                rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
                result = segmenter.process(rgb)
            hard = (np.clip(result.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
            return cv2.medianBlur(hard, 5)
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut
    gc_mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1 * w), int(0.1 * h), int(0.8 * w), int(0.8 * h))
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, gc_mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
        return np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
516
+
517
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA streams over a background image.

    Writes mp4v frames, then (when ffmpeg is present) re-encodes to H.264 with
    faststart. Returns True when at least one frame was composited.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    alpha_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not alpha_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Neutral gray stand-in when the background can't be read.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

    finalize_h264 = _probe_ffmpeg()
    if finalize_h264:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    wrote_any = False
    try:
        while True:
            ok_fg, fg_frame = fg_cap.read()
            ok_al, al_frame = alpha_cap.read()
            if not (ok_fg and ok_al):
                break
            fg_frame = cv2.resize(fg_frame, (w, h), interpolation=cv2.INTER_CUBIC)
            alpha_gray = cv2.cvtColor(cv2.resize(al_frame, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg_frame, cv2.COLOR_BGR2RGB), alpha_gray, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        fg_cap.release()
        alpha_cap.release()
        writer.release()

    if finalize_h264 and wrote_any:
        try:
            subprocess.run(
                [_ffmpeg_bin(), "-y",
                 "-i", str(tmp_out),
                 "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                 str(out_path)],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: fall back to delivering the raw mp4v file.
            logger.warning(f"ffmpeg finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return wrote_any
581
+
582
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Composite every frame against a single static mask (no per-frame matting).

    Same output pipeline as composite_video: mp4v first, optional H.264
    finalize via ffmpeg. Returns True when at least one frame was written.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    mask_fit = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

    finalize_h264 = _probe_ffmpeg()
    if finalize_h264:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    wrote_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask_fit, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        cap.release()
        writer.release()

    if finalize_h264 and wrote_any:
        try:
            subprocess.run(
                [_ffmpeg_bin(), "-y",
                 "-i", str(tmp_out),
                 "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                 str(out_path)],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return wrote_any
644
+
645
+ # --------------------------------------------------------------------------------------
646
+ # Stage-A (Transparent Export) Functions
647
+ # --------------------------------------------------------------------------------------
648
+ def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
649
+ y, x = np.mgrid[0:h, 0:w]
650
+ c = ((x // tile) + (y // tile)) % 2
651
+ a = np.where(c == 0, 200, 150).astype(np.uint8)
652
+ return np.stack([a, a, a], axis=-1)
653
+
654
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG + alpha videos into a transparent VP9 .webm (optional audio track)."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        # Grayscale the alpha stream, normalize both to target size/fps, merge.
        filter_graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd += ["-filter_complex", filter_graph, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
685
+
686
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Build a transparent VP9 .webm by applying one static PNG mask as alpha."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        # '-loop 1' repeats the single mask frame for the whole video.
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            "-loop", "1", "-i", str(mask_png),
            "-filter_complex",
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]",
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
715
+
716
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview render: composite FG+alpha streams over a checkerboard pattern."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    alpha_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not alpha_cap.isOpened():
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    checker = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            ok_fg, fg_frame = fg_cap.read()
            ok_al, al_frame = alpha_cap.read()
            if not (ok_fg and ok_al):
                break
            fg_frame = cv2.resize(fg_frame, (w, h))
            alpha_gray = cv2.cvtColor(cv2.resize(al_frame, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg_frame, cv2.COLOR_BGR2RGB), alpha_gray, checker)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        fg_cap.release()
        alpha_cap.release()
        writer.release()
    return wrote_any
747
+
748
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview render: composite the video over a checkerboard using one static mask."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return False
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    checker = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, checker)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        cap.release()
        writer.release()
    return wrote_any