MogensR committed on
Commit
b595219
·
1 Parent(s): 0deca70
Files changed (2) hide show
  1. models.py +793 -0
  2. pipeline.py +229 -945
models.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro - Model Loading & Utilities
4
+ ===========================================
5
+ Contains all model loading, inference functions, and utility functions
6
+ moved from the main pipeline for better organization.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import sys
13
+ import cv2
14
+ import subprocess
15
+ import inspect
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import Optional, Tuple, Dict, Any, Union
19
+
20
+ import numpy as np
21
+ import yaml
22
+
23
+ # --------------------------------------------------------------------------------------
24
+ # Logging
25
+ # --------------------------------------------------------------------------------------
26
+ logger = logging.getLogger("backgroundfx_pro")
27
+
28
+ # --------------------------------------------------------------------------------------
29
+ # Optional dependencies
30
+ # --------------------------------------------------------------------------------------
31
+ try:
32
+ import mediapipe as mp # type: ignore
33
+ _HAS_MEDIAPIPE = True
34
+ except Exception:
35
+ _HAS_MEDIAPIPE = False
36
+
37
+ # --------------------------------------------------------------------------------------
38
+ # Path setup for third_party repos
39
+ # --------------------------------------------------------------------------------------
40
+ ROOT = Path(__file__).resolve().parent
41
+ TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
42
+ TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
43
+
44
+ def _add_sys_path(p: Path) -> None:
45
+ p_str = str(p)
46
+ if p_str not in sys.path:
47
+ sys.path.insert(0, p_str)
48
+
49
+ _add_sys_path(TP_SAM2)
50
+ _add_sys_path(TP_MATANY)
51
+
52
+ # --------------------------------------------------------------------------------------
53
+ # Basic Utilities
54
+ # --------------------------------------------------------------------------------------
55
+ def _ffmpeg_bin() -> str:
56
+ return os.environ.get("FFMPEG_BIN", "ffmpeg")
57
+
58
def _probe_ffmpeg() -> bool:
    """True if the ffmpeg binary can be executed (``ffmpeg -version`` succeeds)."""
    cmd = [_ffmpeg_bin(), "-version"]
    try:
        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
    except Exception:
        return False
    return True
64
+
65
+ def _has_cuda() -> bool:
66
+ try:
67
+ import torch # type: ignore
68
+ return torch.cuda.is_available()
69
+ except Exception:
70
+ return False
71
+
72
def _pick_device(env_key: str) -> str:
    """Resolve "cuda"/"cpu" from the env var *env_key*; otherwise auto-detect."""
    choice = os.environ.get(env_key, "").strip().lower()
    if choice == "cuda" or choice == "cpu":
        return choice
    # No explicit request — prefer CUDA when torch reports it.
    if _has_cuda():
        return "cuda"
    return "cpu"
77
+
78
+ def _ensure_dir(p: Path) -> None:
79
+ p.mkdir(parents=True, exist_ok=True)
80
+
81
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Open *video_path* and return ``(first_frame_bgr_or_None, fps, (width, height))``.

    Returns ``(None, 0, (0, 0))`` when the container cannot be opened.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    # Some containers report 0/NaN fps; fall back to 25.
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
    grabbed, frame = cap.read()
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    cap.release()
    return (frame if grabbed else None), fps, (width, height)
93
+
94
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit image and return the path as ``str``."""
    if mask.dtype == bool:
        out = mask.astype(np.uint8) * 255
    elif mask.dtype == np.uint8:
        out = mask
    else:
        # Arbitrary numeric dtype: clamp into byte range.
        out = np.clip(mask, 0, 255).astype(np.uint8)
    cv2.imwrite(str(path), out)
    return str(path)
101
+
102
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Letterbox *image* into ``target_wh`` (width, height), preserving aspect ratio.

    The scaled image is centered on a black canvas; degenerate sizes return
    the input unchanged.
    """
    tw, th = target_wh
    src_h, src_w = image.shape[:2]
    if src_h == 0 or src_w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / src_w, th / src_h)
    new_w = max(1, int(round(src_w * scale)))
    new_h = max(1, int(round(src_h * scale)))
    scaled = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    canvas = np.zeros((th, tw, 3), dtype=scaled.dtype)
    off_x = (tw - new_w) // 2
    off_y = (th - new_h) // 2
    canvas[off_y:off_y + new_h, off_x:off_x + new_w] = scaled
    return canvas
115
+
116
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """mp4v ``VideoWriter`` for *out_path*; fps is clamped to at least 1."""
    return cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), max(1, fps), size)
119
+
120
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC)."""
    cmd = [
        _ffmpeg_bin(), "-y",
        "-i", str(silent_video),
        "-i", str(src_video),
        "-map", "0:v:0",       # video: the freshly rendered silent track
        "-map", "1:a:0?",      # audio: first stream of the source, if any
        "-c:v", "copy",
        "-c:a", "aac", "-b:a", "192k",
        "-shortest",
        str(out_path),
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
139
+
140
+ # --------------------------------------------------------------------------------------
141
+ # Compositing & Image Processing
142
+ # --------------------------------------------------------------------------------------
143
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Erode→dilate + gentle blur → float alpha in [0,1].

    Accepts bool / uint8 / float mattes in either [0,1] or [0,255] range.

    Args:
        alpha: Input matte, any numeric dtype.
        erode_px: Elliptical erosion kernel size in px (0 disables).
        dilate_px: Elliptical dilation kernel size in px (0 disables).
        blur_px: Gaussian blur radius in px (0/None disables).

    Returns:
        float32 matte clipped to [0, 1].
    """
    # Bug fix: normalize any 0-255-range input by value, not by dtype — the
    # old code skipped normalization whenever dtype was already float32,
    # so a float32 matte in [0,255] came through 255x too bright.
    a = alpha.astype(np.float32, copy=True)
    if a.max() > 1.0:
        a = a / 255.0

    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    if blur_px and blur_px > 0:
        rad = max(1, int(round(blur_px)))
        # GaussianBlur requires odd kernel sizes; `| 1` forces oddness.
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
166
+
167
+ def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
168
+ x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
169
+ return np.power(x, gamma)
170
+
171
+ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
172
+ x = np.clip(lin, 0.0, 1.0)
173
+ return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
174
+
175
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Simple light wrap from background into subject edges."""
    r = max(1, int(radius))
    # Blur the *inverse* alpha so background light bleeds across the edge band.
    spill = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * spill[..., None] * float(amount)
182
+
183
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Reduce saturation in boundary band (alpha≈0.5) to remove old-background tint."""
    # Bell-shaped weight: 1 where alpha == 0.5, falling to 0 at alpha 0 or 1.
    band = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = sat * (1.0 - amount * band)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
193
+
194
def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
                         erode_px: Optional[int] = None, dilate_px: Optional[int] = None,
                         blur_px: Optional[float] = None,
                         lw_radius: Optional[int] = None, lw_amount: Optional[float] = None,
                         despill_amount: Optional[float] = None) -> np.ndarray:
    """Gamma-aware composite + edge refinement + light wrap + boundary de-spill.

    Args:
        fg_rgb: Foreground frame, RGB uint8.
        alpha: Matte in any supported range; refined to float [0,1].
        bg_rgb: Background frame, RGB uint8, same size as *fg_rgb*.
        erode_px / dilate_px / blur_px: Edge-refinement knobs; defaults come
            from EDGE_ERODE / EDGE_DILATE / EDGE_BLUR env vars.
        lw_radius / lw_amount: Light-wrap knobs; defaults from
            LIGHTWRAP_RADIUS / LIGHTWRAP_AMOUNT env vars.
        despill_amount: Boundary de-spill strength; default from DESPILL_AMOUNT.

    Returns:
        Composited RGB uint8 frame.
    """
    # Fix: parameters defaulting to None are Optional, not bare int/float.
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    # Refine matte to float [0,1] via morphology + blur.
    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)

    # Edge de-spill: temper saturation where a ≈ 0.5.
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    # Blend in linear light to avoid gamma-darkened edges.
    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)

    # Additive light wrap contribution from the background.
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    return _to_srgb(comp_lin)
223
+
224
+ # --------------------------------------------------------------------------------------
225
+ # SAM2 Integration
226
+ # --------------------------------------------------------------------------------------
227
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Make the SAM2 config path absolute (prefer inside TP_SAM2)."""
    cfg_path = Path(cfg_str)
    if not cfg_path.is_absolute():
        inside_repo = TP_SAM2 / cfg_path
        if inside_repo.exists():
            return str(inside_repo)
    if cfg_path.exists():
        return str(cfg_path)
    # Last resort: common defaults inside the repo
    defaults = (
        "configs/sam2/sam2_hiera_l.yaml",
        "configs/sam2/sam2_hiera_b.yaml",
        "configs/sam2/sam2_hiera_s.yaml",
    )
    for name in defaults:
        candidate = TP_SAM2 / name
        if candidate.exists():
            return str(candidate)
    return str(cfg_str)  # let build_sam2 raise a clear error
242
+
243
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config."""

    def _trunk_target(doc) -> Optional[str]:
        # Dig model.image_encoder.trunk._target_ (or .target) out of a config.
        model = (doc or {}).get("model", {})
        encoder = model.get("image_encoder") or {}
        trunk = encoder.get("trunk") or {}
        return trunk.get("_target_") or trunk.get("target")

    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        # NOTE: mirrors original behavior — a None document raises here and
        # falls through to `return None`.
        target = (data.get("model", {}).get("image_encoder") or {}).get("trunk") or {}
        target = target.get("_target_") or target.get("target")
        if isinstance(target, str) and "hieradet" in target:
            # Scan every yaml in the repo for one whose trunk is plain 'hiera'.
            for candidate in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(candidate, "r") as f2:
                        alt_target = _trunk_target(yaml.safe_load(f2))
                    if isinstance(alt_target, str) and ".hiera." in alt_target:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {candidate}")
                        return str(candidate)
                except Exception:
                    continue
    except Exception:
        pass
    return None
270
+
271
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling."""
    meta: Dict[str, Any] = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2  # type: ignore
        from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta
    meta["sam2_import_ok"] = True

    device = _pick_device("SAM2_DEVICE")
    cfg = _resolve_sam2_cfg(os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml"))
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _construct(cfg_path: str):
        """Call build_sam2 with whatever argument names this fork exposes."""
        accepted = set(inspect.signature(build_sam2).parameters.keys())
        kwargs: Dict[str, Any] = {}
        if "config_file" in accepted:
            kwargs["config_file"] = cfg_path
        elif "model_cfg" in accepted:
            kwargs["model_cfg"] = cfg_path
        if ckpt:
            for key in ("checkpoint", "ckpt_path", "weights"):
                if key in accepted:
                    kwargs[key] = ckpt
                    break
        if "device" in accepted:
            kwargs["device"] = device
        try:
            return build_sam2(**kwargs)
        except TypeError:
            # Keyword convention rejected — retry positionally.
            positional = [cfg_path]
            if ckpt:
                positional.append(ckpt)
            if "device" not in kwargs:
                positional.append(device)
            return build_sam2(*positional)

    try:
        try:
            sam = _construct(cfg)
        except Exception:
            # A 'hieradet' config on a 'hiera'-only fork — look for a sibling.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if not alt_cfg:
                raise
            logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
            sam = _construct(alt_cfg)
            cfg = alt_cfg

        predictor = SAM2ImagePredictor(sam)
        meta["sam2_init_ok"] = True
        meta["sam2_device"] = device
        meta["sam2_cfg"] = cfg
        meta["sam2_ckpt"] = ckpt or "(repo default)"
        return predictor, True, meta
    except Exception as e:
        logger.error(f"SAM2 init failed: {e}")
        return None, False, meta
336
+
337
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok)."""
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if auto:
            # Near-full-frame box prompt for "auto" mode.
            box = np.array([int(0.05 * w), int(0.05 * h), int(0.95 * w), int(0.95 * h)])
            masks, _, _ = predictor.predict(box=box)
        elif point is not None:
            coords = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)  # 1 = positive click
            masks, _, _ = predictor.predict(point_coords=coords, point_labels=labels)
        else:
            # Default: slightly tighter centered box.
            box = np.array([int(0.1 * w), int(0.1 * h), int(0.9 * w), int(0.9 * h)])
            masks, _, _ = predictor.predict(box=box)

        if masks is None or len(masks) == 0:
            return None, False
        return masks[0].astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
370
+
371
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    Args:
        image_bgr: Source frame (BGR uint8).
        mask_u8: Seed mask, uint8; thresholded at 127.
        iters: GrabCut iterations; default from REFINE_GRABCUT_ITERS.
        trimap_erode: Erosion px producing the sure-foreground band
            (REFINE_TRIMAP_ERODE).
        trimap_dilate: Erosion px of the *inverse* mask producing sure
            background (REFINE_TRIMAP_DILATE).

    Returns:
        Refined 0/255 uint8 mask; the binarized input on GrabCut failure.
    """
    # Fix: parameters defaulting to None are Optional, not bare int.
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255

    # Sure-FG: erode the mask inward; sure-BG: erode the inverse mask.
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)
        return out
    # Fix: renamed exception var — `as e` shadowed the trimap-erode local `e`.
    except Exception as exc:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {exc}")
        return m
401
+
402
+ # --------------------------------------------------------------------------------------
403
+ # MatAnyone Integration
404
+ # --------------------------------------------------------------------------------------
405
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """MatAnyone loader with disable switch and error handling."""
    meta: Dict[str, Any] = {"matany_import_ok": False, "matany_init_ok": False}

    if os.environ.get("ENABLE_MATANY", "1").strip().lower() in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    try:
        try:
            from inference_core import InferenceCore  # type: ignore
        except Exception:
            # Some forks nest the module under the package namespace.
            from matanyone.inference.inference_core import InferenceCore  # type: ignore
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
    meta["matany_import_ok"] = True

    device = _pick_device("MATANY_DEVICE")
    repo_id = os.environ.get("MATANY_REPO_ID", "")
    ckpt = os.environ.get("MATANY_CHECKPOINT", "")

    # Check if this fork needs a prebuilt network
    try:
        params = inspect.signature(InferenceCore).parameters
        if "network" in params and params["network"].default is inspect._empty:
            logger.error(
                "This MatAnyone fork expects `InferenceCore(network=...)`. "
                "Pin a fork/commit that supplies a checkpoint-based constructor, "
                "or set ENABLE_MATANY=0 to skip."
            )
            meta["needs_network_arg"] = True
            return None, False, meta
    except Exception:
        pass

    # Try progressively simpler constructor signatures.
    attempts = (
        {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device},
        {"checkpoint": ckpt or None, "device": device},
        {"device": device},
    )
    last_err = None
    for kwargs in attempts:
        try:
            matany = InferenceCore(**kwargs)
        except Exception as e:
            last_err = e
            continue
        meta["matany_init_ok"] = True
        meta["matany_device"] = device
        meta["matany_repo_id"] = repo_id or "(unset)"
        meta["matany_checkpoint"] = ckpt or "(unset)"
        return matany, True, meta

    logger.error(f"MatAnyone init failed with all fallbacks: {last_err}")
    return None, False, meta
463
+
464
def run_matany(matany: object,
               video_path: Union[str, Path],
               first_mask_path: Union[str, Path],
               work_dir: Union[str, Path]) -> Tuple[Optional[str], Optional[str], bool]:
    """Return (foreground_video_path, alpha_video_path, ok)."""
    if matany is None:
        return None, None, False

    def _paths_from_dict(result: dict) -> Optional[Tuple[str, str]]:
        # Forks name the keys differently; try the known variants.
        fg = result.get("foreground") or result.get("fg") or result.get("foreground_path")
        al = result.get("alpha") or result.get("alpha_path")
        if fg and al:
            return str(fg), str(al)
        return None

    try:
        if hasattr(matany, "process_video"):
            out = matany.process_video(input_path=str(video_path), mask_path=str(first_mask_path), output_dir=str(work_dir))
            if isinstance(out, (list, tuple)) and len(out) >= 2:
                return str(out[0]), str(out[1]), True
            if isinstance(out, dict):
                found = _paths_from_dict(out)
                if found:
                    return found[0], found[1], True

        if hasattr(matany, "run"):
            out = matany.run(video_path=str(video_path), seed_mask=str(first_mask_path), out_dir=str(work_dir))
            if isinstance(out, dict):
                found = _paths_from_dict(out)
                if found:
                    return found[0], found[1], True

        logger.error("MatAnyone returned no usable paths.")
        return None, None, False
    except Exception as e:
        logger.warning(f"MatAnyone processing failed: {e}")
        return None, None, False
495
+
496
+ # --------------------------------------------------------------------------------------
497
+ # Fallback Functions
498
+ # --------------------------------------------------------------------------------------
499
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255."""
    h, w = first_frame_bgr.shape[:2]

    if _HAS_MEDIAPIPE:
        try:
            selfie = mp.solutions.selfie_segmentation
            with selfie.SelfieSegmentation(model_selection=1) as segmenter:
                res = segmenter.process(cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB))
            binary = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
            return cv2.medianBlur(binary, 5)
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut
    gc_mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1 * w), int(0.1 * h), int(0.8 * w), int(0.8 * h))
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, gc_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
        keep = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
        return np.where(keep, 255, 0).astype(np.uint8)
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
526
+
527
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not (fg_cap.isOpened() and al_cap.isOpened()):
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Neutral grey stand-in when the background image is unreadable.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

    # With ffmpeg available we render to a temp file and re-encode to H.264.
    post_h264 = _probe_ffmpeg()
    if post_h264:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not (ok_fg and ok_al):
                break
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al_gray, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            subprocess.run(
                [_ffmpeg_bin(), "-y",
                 "-i", str(tmp_out),
                 "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                 str(out_path)],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Keep the mp4v temp render rather than failing outright.
            logger.warning(f"ffmpeg finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
591
+
592
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor."""
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Neutral grey stand-in when the background image is unreadable.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_rgb = cv2.cvtColor(_resize_keep_ar(bg, (w, h)), cv2.COLOR_BGR2RGB)

    # With ffmpeg available we render to a temp file and re-encode to H.264.
    use_post_ffmpeg = _probe_ffmpeg()
    if use_post_ffmpeg:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask_resized, bg_rgb)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            subprocess.run(
                [_ffmpeg_bin(), "-y",
                 "-i", str(tmp_out),
                 "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                 str(out_path)],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Keep the mp4v temp render rather than failing outright.
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
654
+
655
+ # --------------------------------------------------------------------------------------
656
+ # Stage-A (Transparent Export) Functions
657
+ # --------------------------------------------------------------------------------------
658
+ def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
659
+ """RGB checkerboard for preview when no real alpha is possible."""
660
+ y, x = np.mgrid[0:h, 0:w]
661
+ c = ((x // tile) + (y // tile)) % 2
662
+ a = np.where(c == 0, 200, 150).astype(np.uint8)
663
+ return np.stack([a, a, a], axis=-1)
664
+
665
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG+ALPHA → RGBA WebM (VP9 with alpha)."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        # Scale both streams to target size/fps, then fold alpha into the FG.
        graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd += ["-filter_complex", graph, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
697
+
698
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Merge original video + static mask → RGBA WebM (VP9 with alpha)."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            "-loop", "1", "-i", str(mask_png),  # loop the still mask for the whole clip
            "-filter_complex", graph,
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
728
+
729
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: FG+ALPHA over checkerboard → MP4 (no real alpha)."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not (fg_cap.isOpened() and al_cap.isOpened()):
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_frame = False
    try:
        while True:
            got_fg, fg = fg_cap.read()
            got_al, al = al_cap.read()
            if not (got_fg and got_al):
                break
            fg = cv2.resize(fg, (w, h))
            al = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_frame = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()
    return wrote_frame
761
+
762
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: original video + static mask over checkerboard → MP4."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return False
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_frame = False
    try:
        while True:
            got, frame = cap.read()
            if not got:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_frame = True
    finally:
        cap.release()
        writer.release()
    return wrote_frame
pipeline.py CHANGED
@@ -1,52 +1,32 @@
1
- # pipeline.py
2
  #!/usr/bin/env python3
3
  """
4
- BackgroundFX Pro - Dynamic SAM2 + MatAnyone Pipeline (pro masking, pro compositing, audio mux)
5
- ==============================================================================================
6
-
7
- What's inside:
8
- - SAM2 (first-frame segmentation) via third_party/sam2 (env-configurable)
9
- - MatAnyone (temporal matting) via third_party/matanyone (env-configurable)
10
- - First-frame mask refinement via GrabCut (optional, default ON)
11
- - Pro compositing: alpha refinement, gamma-aware blending, light wrap, edge de-spill
12
- - Fallbacks: MediaPipe SelfieSegmentation → else OpenCV GrabCut
13
- - H.264 MP4 output (ffmpeg when available; OpenCV fallback)
14
- - Audio mux: original audio copied into final output (AAC) if present
15
- - Stage-A transparent export (VP9 with alpha or checkerboard preview)
16
-
17
- Environment knobs (all optional):
18
- - THIRD_PARTY_SAM2_DIR, THIRD_PARTY_MATANY_DIR
19
- - SAM2_MODEL_CFG, SAM2_CHECKPOINT, SAM2_DEVICE
20
- - MATANY_REPO_ID, MATANY_CHECKPOINT, MATANY_DEVICE, ENABLE_MATANY=1|0
21
- - FFMPEG_BIN
22
- - REFINE_GRABCUT=1 | 0 (enable/disable seed mask GrabCut refinement)
23
- - REFINE_GRABCUT_ITERS=2 (GrabCut iterations)
24
- - REFINE_TRIMAP_ERODE=3 (px for sure-FG erode)
25
- - REFINE_TRIMAP_DILATE=6 (px for sure-BG erode of inverse)
26
- - EDGE_ERODE=1, EDGE_DILATE=2, EDGE_BLUR=1.5
27
- - LIGHTWRAP_RADIUS=5, LIGHTWRAP_AMOUNT=0.18
28
- - DESPILL_AMOUNT=0.35
29
- - RETURN_STAGE_A=0 | 1 (if 1, return Stage-A file instead of final composite)
30
- - STAGEA_VP9_CRF=28 (quality for VP9 alpha export)
31
  """
32
 
33
  from __future__ import annotations
34
 
35
  import os
36
- import sys
37
- import cv2
38
  import time
39
  import tempfile
40
  import logging
41
- import subprocess
42
- import inspect
43
  from pathlib import Path
44
  from typing import Optional, Tuple, Dict, Any, Union
45
 
46
- import numpy as np
47
- import yaml # for SAM2 config introspection
48
-
49
- # Try to apply GPU/perf tuning early if present
 
 
 
 
 
 
 
50
  try:
51
  import perf_tuning # noqa: F401
52
  except Exception:
@@ -63,821 +43,41 @@
63
  logger.addHandler(_h)
64
 
65
  # --------------------------------------------------------------------------------------
66
- # Optional dependency: MediaPipe SelfieSegmentation
67
- # --------------------------------------------------------------------------------------
68
- try:
69
- import mediapipe as mp # type: ignore
70
- _HAS_MEDIAPIPE = True
71
- except Exception:
72
- _HAS_MEDIAPIPE = False
73
-
74
- # --------------------------------------------------------------------------------------
75
- # Path setup for third_party repos (dynamically override-able)
76
- # --------------------------------------------------------------------------------------
77
- ROOT = Path(__file__).resolve().parent
78
- TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
79
- TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
80
-
81
- def _add_sys_path(p: Path) -> None:
82
- p_str = str(p)
83
- if p_str not in sys.path:
84
- sys.path.insert(0, p_str)
85
-
86
- _add_sys_path(TP_SAM2)
87
- _add_sys_path(TP_MATANY)
88
-
89
- # --------------------------------------------------------------------------------------
90
- # Utilities
91
- # --------------------------------------------------------------------------------------
92
- def _ffmpeg_bin() -> str:
93
- return os.environ.get("FFMPEG_BIN", "ffmpeg")
94
-
95
- def _probe_ffmpeg() -> bool:
96
- try:
97
- subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
98
- return True
99
- except Exception:
100
- return False
101
-
102
- def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
103
- """Copy video from silent_video + audio from src_video into out_path (AAC)."""
104
- try:
105
- cmd = [
106
- _ffmpeg_bin(), "-y",
107
- "-i", str(silent_video),
108
- "-i", str(src_video),
109
- "-map", "0:v:0",
110
- "-map", "1:a:0?",
111
- "-c:v", "copy",
112
- "-c:a", "aac", "-b:a", "192k",
113
- "-shortest",
114
- str(out_path)
115
- ]
116
- subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
117
- return True
118
- except Exception as e:
119
- logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
120
- return False
121
-
122
- def _has_cuda() -> bool:
123
- try:
124
- import torch # type: ignore
125
- return torch.cuda.is_available()
126
- except Exception:
127
- return False
128
-
129
- def _pick_device(env_key: str) -> str:
130
- requested = os.environ.get(env_key, "").strip().lower()
131
- if requested in {"cuda", "cpu"}:
132
- return requested
133
- return "cuda" if _has_cuda() else "cpu"
134
-
135
- def _ensure_dir(p: Path) -> None:
136
- p.mkdir(parents=True, exist_ok=True)
137
-
138
- def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
139
- cap = cv2.VideoCapture(str(video_path))
140
- if not cap.isOpened():
141
- return None, 0, (0, 0)
142
- fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
143
- ok, frame = cap.read()
144
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
145
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
146
- cap.release()
147
- if not ok:
148
- return None, fps, (w, h)
149
- return frame, fps, (w, h)
150
-
151
- def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
152
- # expects mask as uint8 0..255 or bool
153
- if mask.dtype == bool:
154
- mask = (mask.astype(np.uint8) * 255)
155
- elif mask.dtype != np.uint8:
156
- mask = np.clip(mask, 0, 255).astype(np.uint8)
157
- cv2.imwrite(str(path), mask)
158
- return str(path)
159
-
160
- def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
161
- tw, th = target_wh
162
- h, w = image.shape[:2]
163
- if h == 0 or w == 0 or tw == 0 or th == 0:
164
- return image
165
- scale = min(tw / w, th / h)
166
- nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
167
- resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
168
- canvas = np.zeros((th, tw, 3), dtype=resized.dtype)
169
- x0 = (tw - nw) // 2
170
- y0 = (th - nh) // 2
171
- canvas[y0:y0+nh, x0:x0+nw] = resized
172
- return canvas
173
-
174
- # ---- Edge refinement / compositing helpers ----
175
- def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
176
- """Erode→dilate + gentle blur → float alpha in [0,1]."""
177
- if alpha.dtype != np.float32:
178
- a = alpha.astype(np.float32)
179
- if a.max() > 1.0:
180
- a = a / 255.0
181
- else:
182
- a = alpha.copy()
183
-
184
- a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
185
- if erode_px > 0:
186
- k = max(1, int(erode_px))
187
- a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
188
- if dilate_px > 0:
189
- k = max(1, int(dilate_px))
190
- a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
191
- a = a_u8.astype(np.float32) / 255.0
192
-
193
- if blur_px and blur_px > 0:
194
- rad = max(1, int(round(blur_px)))
195
- a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)
196
-
197
- return np.clip(a, 0.0, 1.0)
198
-
199
- def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
200
- x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
201
- return np.power(x, gamma)
202
-
203
- def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
204
- x = np.clip(lin, 0.0, 1.0)
205
- return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
206
-
207
- def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
208
- """Simple light wrap from background into subject edges."""
209
- r = max(1, int(radius))
210
- inv = 1.0 - alpha01
211
- inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
212
- lw = (bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount))
213
- return lw
214
-
215
- def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
216
- """
217
- Reduce saturation only in the boundary band (alpha≈0.5) to remove old-background tint.
218
- amount: 0..1 (how strongly to desaturate)
219
- """
220
- w = 1.0 - 2.0 * np.abs(alpha01 - 0.5) # bell-shaped weight
221
- w = np.clip(w, 0.0, 1.0)
222
- hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
223
- H, S, V = cv2.split(hsv)
224
- S = S * (1.0 - amount * w)
225
- hsv2 = cv2.merge([H, np.clip(S, 0, 255), V])
226
- out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
227
- return out
228
-
229
- def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
230
- erode_px: int = None, dilate_px: int = None, blur_px: float = None,
231
- lw_radius: int = None, lw_amount: float = None,
232
- despill_amount: float = None) -> np.ndarray:
233
- """Gamma-aware composite + edge refinement + light wrap + boundary de-spill."""
234
- erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
235
- dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
236
- blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
237
- lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
238
- lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
239
- despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
240
-
241
- # refine alpha [0,1]
242
- a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
243
-
244
- # edge de-spill: temper saturation where a≈0.5
245
- fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
246
-
247
- # linearize for better blending
248
- fg_lin = _to_linear(fg_rgb)
249
- bg_lin = _to_linear(bg_rgb)
250
-
251
- # light wrap
252
- lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
253
- lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
254
-
255
- comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
256
- comp = _to_srgb(comp_lin)
257
- return comp
258
-
259
- def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
260
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
261
- return cv2.VideoWriter(str(out_path), fourcc, max(1, fps), size)
262
-
263
- # --- Stage-A (transparent) builders ----------------------------------------------------
264
- def _build_stage_a_rgba_vp9_from_fg_alpha(
265
- fg_path: Union[str, Path],
266
- alpha_path: Union[str, Path],
267
- out_webm: Union[str, Path],
268
- fps: int,
269
- size: Tuple[int, int],
270
- src_audio: Optional[Union[str, Path]] = None,
271
- ) -> bool:
272
- """Merge FG+ALPHA → RGBA WebM (VP9 with alpha). Optionally mux original audio (Opus)."""
273
- if not _probe_ffmpeg():
274
- return False
275
- w, h = size
276
- try:
277
- cmd = [
278
- _ffmpeg_bin(), "-y",
279
- "-i", str(fg_path), # 0: FG video
280
- "-i", str(alpha_path), # 1: ALPHA video (grayscale)
281
- ]
282
- if src_audio:
283
- cmd += ["-i", str(src_audio)] # 2: original (for audio)
284
- fcx = f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];" \
285
- f"[0:v]scale={w}:{h},fps={fps}[fg];" \
286
- f"[fg][al]alphamerge[outv]"
287
- cmd += ["-filter_complex", fcx, "-map", "[outv]"]
288
- if src_audio:
289
- cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
290
- cmd += [
291
- "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
292
- "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
293
- "-b:v", "0", "-row-mt", "1",
294
- "-shortest",
295
- str(out_webm),
296
- ]
297
- subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
298
- return True
299
- except Exception as e:
300
- logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
301
- return False
302
-
303
- def _build_stage_a_rgba_vp9_from_mask(
304
- video_path: Union[str, Path],
305
- mask_png: Union[str, Path],
306
- out_webm: Union[str, Path],
307
- fps: int,
308
- size: Tuple[int, int],
309
- ) -> bool:
310
- """Merge original video + static mask → RGBA WebM (VP9 with alpha)."""
311
- if not _probe_ffmpeg():
312
- return False
313
- w, h = size
314
- try:
315
- cmd = [
316
- _ffmpeg_bin(), "-y",
317
- "-i", str(video_path), # 0: original video
318
- "-loop", "1", "-i", str(mask_png), # 1: static PNG mask (grayscale)
319
- "-filter_complex",
320
- f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
321
- f"[0:v]scale={w}:{h},fps={fps}[fg];"
322
- f"[fg][al]alphamerge[outv]",
323
- "-map", "[outv]",
324
- "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
325
- "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
326
- "-b:v", "0", "-row-mt", "1",
327
- "-shortest",
328
- str(out_webm),
329
- ]
330
- subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
331
- return True
332
- except Exception as e:
333
- logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
334
- return False
335
-
336
- def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
337
- """RGB checkerboard (for preview when no real alpha is possible)."""
338
- y, x = np.mgrid[0:h, 0:w]
339
- c = ((x // tile) + (y // tile)) % 2
340
- a = np.where(c == 0, 200, 150).astype(np.uint8)
341
- return np.stack([a, a, a], axis=-1)
342
-
343
- def _build_stage_a_checkerboard_from_fg_alpha(
344
- fg_path: Union[str, Path],
345
- alpha_path: Union[str, Path],
346
- out_mp4: Union[str, Path],
347
- fps: int,
348
- size: Tuple[int, int],
349
- ) -> bool:
350
- """Preview: FG+ALPHA over checkerboard → MP4 (no real alpha)."""
351
- fg_cap = cv2.VideoCapture(str(fg_path))
352
- al_cap = cv2.VideoCapture(str(alpha_path))
353
- if not fg_cap.isOpened() or not al_cap.isOpened():
354
- return False
355
- w, h = size
356
- writer = _video_writer(Path(out_mp4), fps, (w, h))
357
- bg = _checkerboard_bg(w, h)
358
- ok_any = False
359
- try:
360
- while True:
361
- okf, fg = fg_cap.read()
362
- oka, al = al_cap.read()
363
- if not okf or not oka:
364
- break
365
- fg = cv2.resize(fg, (w, h))
366
- al = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
367
- comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al, bg)
368
- writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
369
- ok_any = True
370
- finally:
371
- fg_cap.release()
372
- al_cap.release()
373
- writer.release()
374
- return ok_any
375
-
376
- def _build_stage_a_checkerboard_from_mask(
377
- video_path: Union[str, Path],
378
- mask_png: Union[str, Path],
379
- out_mp4: Union[str, Path],
380
- fps: int,
381
- size: Tuple[int, int],
382
- ) -> bool:
383
- """Preview: original video + static mask over checkerboard → MP4."""
384
- cap = cv2.VideoCapture(str(video_path))
385
- if not cap.isOpened():
386
- return False
387
- w, h = size
388
- mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
389
- if mask is None:
390
- return False
391
- mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
392
- writer = _video_writer(Path(out_mp4), fps, (w, h))
393
- bg = _checkerboard_bg(w, h)
394
- ok_any = False
395
- try:
396
- while True:
397
- ok, frame = cap.read()
398
- if not ok:
399
- break
400
- frame = cv2.resize(frame, (w, h))
401
- comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, bg)
402
- writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
403
- ok_any = True
404
- finally:
405
- cap.release()
406
- writer.release()
407
- return ok_any
408
-
409
- # --------------------------------------------------------------------------------------
410
- # SAM2 helpers (config resolution & robust loader)
411
  # --------------------------------------------------------------------------------------
412
- def _resolve_sam2_cfg(cfg_str: str) -> str:
413
- """Make the SAM2 config path absolute (prefer inside TP_SAM2)."""
414
- cfg_path = Path(cfg_str)
415
- if not cfg_path.is_absolute():
416
- candidate = TP_SAM2 / cfg_path
417
- if candidate.exists():
418
- return str(candidate)
419
- if cfg_path.exists():
420
- return str(cfg_path)
421
- # Last resort: common defaults inside the repo
422
- for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
423
- p = TP_SAM2 / name
424
- if p.exists():
425
- return str(p)
426
- return str(cfg_str) # let build_sam2 raise a clear error
427
-
428
- def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
429
- """
430
- If the given config references 'hieradet', try to find a 'hiera' config in the repo and return it.
431
- """
432
  try:
433
- with open(cfg_path, "r") as f:
434
- data = yaml.safe_load(f)
435
- # Look for target under model.image_encoder.trunk._target_ (Hydra style)
436
- target = None
437
- model = data.get("model", {})
438
- enc = (model.get("image_encoder") or {})
439
- trunk = (enc.get("trunk") or {})
440
- target = trunk.get("_target_") or trunk.get("target")
441
- if isinstance(target, str) and "hieradet" in target:
442
- # Search all yaml files under TP_SAM2/configs for those that reference ".hiera."
443
- for y in TP_SAM2.rglob("*.yaml"):
444
- try:
445
- with open(y, "r") as f2:
446
- d2 = yaml.safe_load(f2)
447
- m2 = (d2 or {}).get("model", {})
448
- e2 = (m2.get("image_encoder") or {})
449
- t2 = (e2.get("trunk") or {})
450
- tgt2 = t2.get("_target_") or t2.get("target")
451
- if isinstance(tgt2, str) and ".hiera." in tgt2:
452
- logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
453
- return str(y)
454
- except Exception:
455
- continue
456
  except Exception:
457
  pass
458
- return None
459
-
460
- # --------------------------------------------------------------------------------------
461
- # SAM2 Integration (robust to different build_sam2 signatures)
462
- # --------------------------------------------------------------------------------------
463
- def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
464
- """
465
- Robust SAM2 loader that adapts to different build_sam2 signatures:
466
- - config_file vs model_cfg
467
- - checkpoint vs ckpt_path vs weights
468
- - optional device kwarg
469
- - absolute config resolution (inside third_party/sam2)
470
- - auto-fix if config references 'hieradet' but repo has 'hiera'
471
- """
472
- meta = {"sam2_import_ok": False, "sam2_init_ok": False}
473
- try:
474
- from sam2.build_sam import build_sam2 # type: ignore
475
- from sam2.sam2_image_predictor import SAM2ImagePredictor # type: ignore
476
- meta["sam2_import_ok"] = True
477
- except Exception as e:
478
- logger.warning(f"SAM2 import failed: {e}")
479
- return None, False, meta
480
-
481
- device = _pick_device("SAM2_DEVICE")
482
- cfg_env = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
483
- cfg = _resolve_sam2_cfg(cfg_env)
484
- ckpt = os.environ.get("SAM2_CHECKPOINT", "")
485
 
486
- def _try_build(cfg_path: str):
487
- params = set(inspect.signature(build_sam2).parameters.keys())
488
- kwargs = {}
489
- # Config arg
490
- if "config_file" in params:
491
- kwargs["config_file"] = cfg_path
492
- elif "model_cfg" in params:
493
- kwargs["model_cfg"] = cfg_path
494
- # Checkpoint arg
495
- if ckpt:
496
- if "checkpoint" in params:
497
- kwargs["checkpoint"] = ckpt
498
- elif "ckpt_path" in params:
499
- kwargs["ckpt_path"] = ckpt
500
- elif "weights" in params:
501
- kwargs["weights"] = ckpt
502
- # Device
503
- if "device" in params:
504
- kwargs["device"] = device
505
- # Try keywords first, then positional fallback
506
  try:
507
- return build_sam2(**kwargs)
508
- except TypeError:
509
- pos = [cfg_path]
510
- if ckpt:
511
- pos.append(ckpt)
512
- if "device" not in kwargs:
513
- pos.append(device)
514
- return build_sam2(*pos)
515
-
516
- try:
517
- try:
518
- sam = _try_build(cfg)
519
- except Exception as e1:
520
- msg = str(e1)
521
- # If the config is using 'hieradet', try to swap to a 'hiera' config
522
- alt_cfg = _find_hiera_config_if_hieradet(cfg)
523
- if alt_cfg:
524
- logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
525
- sam = _try_build(alt_cfg)
526
- cfg = alt_cfg
527
- else:
528
- raise
529
-
530
- predictor = SAM2ImagePredictor(sam)
531
- meta.update({
532
- "sam2_init_ok": True,
533
- "sam2_device": device,
534
- "sam2_cfg": cfg,
535
- "sam2_ckpt": ckpt or "(repo default)"
536
- })
537
- return predictor, True, meta
538
- except Exception as e:
539
- logger.error(f"SAM2 init failed: {e}")
540
- return None, False, meta
541
-
542
- def run_sam2_mask(predictor: object,
543
- first_frame_bgr: np.ndarray,
544
- point: Optional[Tuple[int, int]] = None,
545
- auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
546
- """Return (mask_uint8_0_255, ok)."""
547
- if predictor is None:
548
- return None, False
549
- try:
550
- rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
551
- predictor.set_image(rgb)
552
-
553
- if auto:
554
- h, w = rgb.shape[:2]
555
- box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
556
- masks, _, _ = predictor.predict(box=box)
557
- elif point is not None:
558
- x, y = int(point[0]), int(point[1])
559
- pts = np.array([[x, y]], dtype=np.int32)
560
- labels = np.array([1], dtype=np.int32)
561
- masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
562
- else:
563
- h, w = rgb.shape[:2]
564
- box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
565
- masks, _, _ = predictor.predict(box=box)
566
-
567
- if masks is None or len(masks) == 0:
568
- return None, False
569
-
570
- m = masks[0].astype(np.uint8) * 255
571
- return m, True
572
- except Exception as e:
573
- logger.warning(f"SAM2 mask failed: {e}")
574
- return None, False
575
-
576
- # --------------------------------------------------------------------------------------
577
- # First-frame mask refinement (GrabCut with mask init)
578
- # --------------------------------------------------------------------------------------
579
- def _refine_mask_grabcut(image_bgr: np.ndarray,
580
- mask_u8: np.ndarray,
581
- iters: int = None,
582
- trimap_erode: int = None,
583
- trimap_dilate: int = None) -> np.ndarray:
584
- """
585
- Use SAM2 seed as initialization for GrabCut (GC_INIT_WITH_MASK).
586
- - sure FG: eroded mask
587
- - sure BG: eroded inverse
588
- - unknown: rest
589
- Returns refined binary mask (uint8 0/255).
590
- """
591
- iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
592
- e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
593
- d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)
594
-
595
- h, w = mask_u8.shape[:2]
596
- m = (mask_u8 > 127).astype(np.uint8) * 255
597
-
598
- sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
599
- sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)
600
-
601
- gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
602
- gc_mask[sure_bg > 0] = cv2.GC_BGD
603
- gc_mask[sure_fg > 0] = cv2.GC_FGD
604
-
605
- bgdModel = np.zeros((1, 65), np.float64)
606
- fgdModel = np.zeros((1, 65), np.float64)
607
- try:
608
- cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
609
- out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
610
- out = cv2.medianBlur(out, 5)
611
- return out
612
- except Exception as e:
613
- logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}")
614
- return m
615
-
616
- # --------------------------------------------------------------------------------------
617
- # MatAnyone Integration (robust + disable switch)
618
- # --------------------------------------------------------------------------------------
619
- def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
620
- """
621
- MatAnyone loader that:
622
- - Skips if ENABLE_MATANY=0
623
- - Detects forks that require a `network` arg and exits cleanly with diagnostics
624
- - Otherwise tries repo/checkpoint style constructors
625
- """
626
- meta = {"matany_import_ok": False, "matany_init_ok": False}
627
-
628
- enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
629
- if enable_env in {"0", "false", "off", "no"}:
630
- logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
631
- meta["disabled"] = True
632
- return None, False, meta
633
-
634
- try:
635
- try:
636
- from inference_core import InferenceCore # type: ignore
637
  except Exception:
638
- from matanyone.inference.inference_core import InferenceCore # type: ignore
639
- meta["matany_import_ok"] = True
640
- except Exception as e:
641
- logger.warning(f"MatAnyone import failed: {e}")
642
- return None, False, meta
643
-
644
- device = _pick_device("MATANY_DEVICE")
645
- repo_id = os.environ.get("MATANY_REPO_ID", "")
646
- ckpt = os.environ.get("MATANY_CHECKPOINT", "")
647
-
648
- # If this fork needs a prebuilt network, tell the user and skip
649
- try:
650
- sig = inspect.signature(InferenceCore)
651
- if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
652
- logger.error(
653
- "This MatAnyone fork expects `InferenceCore(network=...)`. "
654
- "Pin a fork/commit that supplies a checkpoint-based constructor, "
655
- "or set ENABLE_MATANY=0 to skip."
656
- )
657
- meta["needs_network_arg"] = True
658
- return None, False, meta
659
- except Exception:
660
- pass
661
 
662
- candidates = [
663
- {"kwargs": {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device}},
664
- {"kwargs": {"checkpoint": ckpt or None, "device": device}},
665
- {"args": (), "kwargs": {"device": device}},
666
- ]
667
- last_err = None
668
- for cand in candidates:
669
- try:
670
- matany = InferenceCore(*cand.get("args", ()), **cand.get("kwargs", {}))
671
- meta["matany_init_ok"] = True
672
- meta["matany_device"] = device
673
- meta["matany_repo_id"] = repo_id or "(unset)"
674
- meta["matany_checkpoint"] = ckpt or "(unset)"
675
- return matany, True, meta
676
- except Exception as e:
677
- last_err = e
678
- continue
679
-
680
- logger.error(f"MatAnyone init failed with all fallbacks: {last_err}")
681
- return None, False, meta
682
-
683
- def run_matany(matany: object,
684
- video_path: Union[str, Path],
685
- first_mask_path: Union[str, Path],
686
- work_dir: Union[str, Path]) -> Tuple[Optional[str], Optional[str], bool]:
687
- """Return (foreground_video_path, alpha_video_path, ok)."""
688
- if matany is None:
689
- return None, None, False
690
- video_path = str(video_path)
691
- first_mask_path = str(first_mask_path)
692
- work_dir = str(work_dir)
693
  try:
694
- if hasattr(matany, "process_video"):
695
- out = matany.process_video(input_path=video_path, mask_path=first_mask_path, output_dir=work_dir)
696
- if isinstance(out, (list, tuple)) and len(out) >= 2:
697
- return str(out[0]), str(out[1]), True
698
- if isinstance(out, dict):
699
- fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
700
- al = out.get("alpha") or out.get("alpha_path")
701
- if fg and al:
702
- return str(fg), str(al), True
703
-
704
- if hasattr(matany, "run"):
705
- out = matany.run(video_path=video_path, seed_mask=first_mask_path, out_dir=work_dir)
706
- if isinstance(out, dict):
707
- fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
708
- al = out.get("alpha") or out.get("alpha_path")
709
- if fg and al:
710
- return str(fg), str(al), True
711
-
712
- logger.error("MatAnyone returned no usable paths.")
713
- return None, None, False
714
  except Exception as e:
715
- logger.warning(f"MatAnyone processing failed: {e}")
716
- return None, None, False
717
 
718
  # --------------------------------------------------------------------------------------
719
- # Fallbacks
720
- # --------------------------------------------------------------------------------------
721
- def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
722
- """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255."""
723
- h, w = first_frame_bgr.shape[:2]
724
- if _HAS_MEDIAPIPE:
725
- try:
726
- mp_selfie = mp.solutions.selfie_segmentation
727
- with mp_selfie.SelfieSegmentation(model_selection=1) as segmenter:
728
- rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
729
- res = segmenter.process(rgb)
730
- m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
731
- m = cv2.medianBlur(m, 5)
732
- return m
733
- except Exception as e:
734
- logger.warning(f"MediaPipe fallback failed: {e}")
735
-
736
- mask = np.zeros((h, w), np.uint8)
737
- rect = (int(0.1*w), int(0.1*h), int(0.8*w), int(0.8*h))
738
- bgdModel = np.zeros((1, 65), np.float64)
739
- fgdModel = np.zeros((1, 65), np.float64)
740
- try:
741
- cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
742
- mask_bin = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
743
- return mask_bin
744
- except Exception as e:
745
- logger.warning(f"GrabCut failed: {e}")
746
- return np.zeros((h, w), dtype=np.uint8)
747
-
748
- def fallback_composite(video_path: Union[str, Path],
749
- mask_path: Union[str, Path],
750
- bg_image_path: Union[str, Path],
751
- out_path: Union[str, Path]) -> bool:
752
- """Static-mask compositing (uses pro compositor to reduce halos)."""
753
- mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
754
- cap = cv2.VideoCapture(str(video_path))
755
- if mask is None or not cap.isOpened():
756
- return False
757
-
758
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
759
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
760
- fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
761
-
762
- bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
763
- if bg is None:
764
- bg = np.full((h, w, 3), 127, dtype=np.uint8)
765
-
766
- mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
767
- bg_f = _resize_keep_ar(bg, (w, h))
768
-
769
- if _probe_ffmpeg():
770
- tmp_out = Path(str(out_path) + ".tmp.mp4")
771
- writer = _video_writer(tmp_out, fps, (w, h))
772
- use_post_ffmpeg = True
773
- else:
774
- writer = _video_writer(Path(out_path), fps, (w, h))
775
- use_post_ffmpeg = False
776
-
777
- ok_any = False
778
- try:
779
- while True:
780
- ok, frame = cap.read()
781
- if not ok:
782
- break
783
- comp = _composite_frame_pro(
784
- cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
785
- mask_resized,
786
- cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
787
- )
788
- writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
789
- ok_any = True
790
- finally:
791
- cap.release()
792
- writer.release()
793
-
794
- if use_post_ffmpeg and ok_any:
795
- try:
796
- cmd = [
797
- _ffmpeg_bin(), "-y",
798
- "-i", str(tmp_out),
799
- "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
800
- str(out_path)
801
- ]
802
- subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
803
- tmp_out.unlink(missing_ok=True)
804
- except Exception as e:
805
- logger.warning(f"ffmpeg H.264 finalize failed: {e}")
806
- Path(out_path).unlink(missing_ok=True)
807
- tmp_out.replace(out_path)
808
-
809
- return ok_any
810
-
811
- # --------------------------------------------------------------------------------------
812
- # Compositing using MatAnyone outputs
813
- # --------------------------------------------------------------------------------------
814
- def composite_video(fg_path: Union[str, Path],
815
- alpha_path: Union[str, Path],
816
- bg_image_path: Union[str, Path],
817
- out_path: Union[str, Path],
818
- fps: int,
819
- size: Tuple[int, int]) -> bool:
820
- """Blend MatAnyone FG+ALPHA over background using pro compositor."""
821
- fg_cap = cv2.VideoCapture(str(fg_path))
822
- al_cap = cv2.VideoCapture(str(alpha_path))
823
- if not fg_cap.isOpened() or not al_cap.isOpened():
824
- return False
825
-
826
- w, h = size
827
- bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
828
- if bg is None:
829
- bg = np.full((h, w, 3), 127, dtype=np.uint8)
830
- bg_f = _resize_keep_ar(bg, (w, h))
831
-
832
- if _probe_ffmpeg():
833
- tmp_out = Path(str(out_path) + ".tmp.mp4")
834
- writer = _video_writer(tmp_out, fps, (w, h))
835
- post_h264 = True
836
- else:
837
- writer = _video_writer(Path(out_path), fps, (w, h))
838
- post_h264 = False
839
-
840
- ok_any = False
841
- try:
842
- while True:
843
- ok_fg, fg = fg_cap.read()
844
- ok_al, al = al_cap.read()
845
- if not ok_fg or not ok_al:
846
- break
847
- fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
848
- al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
849
-
850
- comp = _composite_frame_pro(
851
- cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
852
- al_gray,
853
- cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
854
- )
855
- writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
856
- ok_any = True
857
- finally:
858
- fg_cap.release()
859
- al_cap.release()
860
- writer.release()
861
-
862
- if post_h264 and ok_any:
863
- try:
864
- cmd = [
865
- _ffmpeg_bin(), "-y",
866
- "-i", str(tmp_out),
867
- "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
868
- str(out_path)
869
- ]
870
- subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
871
- tmp_out.unlink(missing_ok=True)
872
- except Exception as e:
873
- logger.warning(f"ffmpeg finalize failed: {e}")
874
- Path(out_path).unlink(missing_ok=True)
875
- tmp_out.replace(out_path)
876
-
877
- return ok_any
878
-
879
- # --------------------------------------------------------------------------------------
880
- # High-level process function (for app.py)
881
  # --------------------------------------------------------------------------------------
882
  def process(video_path: Union[str, Path],
883
  bg_image_path: Union[str, Path],
@@ -885,7 +85,14 @@ def process(video_path: Union[str, Path],
885
  point_y: Optional[float] = None,
886
  auto_box: bool = False,
887
  work_dir: Optional[Union[str, Path]] = None) -> Tuple[Optional[str], Dict[str, Any]]:
888
- """Orchestrate: SAM2 mask → (optional GrabCut refine) → MatAnyone → Stage-A → composite → mux audio."""
 
 
 
 
 
 
 
889
  t0 = time.time()
890
  diagnostics: Dict[str, Any] = {
891
  "sam2_ok": False,
@@ -897,120 +104,197 @@ def process(video_path: Union[str, Path],
897
  "matany_meta": {},
898
  "device_sam2": None,
899
  "device_matany": None,
 
900
  }
901
 
902
  tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
903
  _ensure_dir(tmp_root)
904
 
905
- # 0) Basic video info
906
- first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
907
- diagnostics["fps"] = int(fps or 25)
908
- diagnostics["resolution"] = [int(vw), int(vh)]
909
- if first_frame is None or vw == 0 or vh == 0:
910
- diagnostics["fallback_used"] = "invalid_video"
911
- return None, diagnostics
912
-
913
- # 1) First-frame mask via SAM2 (or fallback)
914
- mask_png = tmp_root / "seed_mask.png"
915
- predictor, sam2_ok, sam_meta = load_sam2()
916
- diagnostics["sam2_meta"] = sam_meta
917
- diagnostics["device_sam2"] = sam_meta.get("sam2_device") if sam_meta else None
918
-
919
- seed_mask = None
920
- if sam2_ok:
921
- px = int(point_x) if point_x is not None else None
922
- py = int(point_y) if point_y is not None else None
923
- seed_mask, ok_mask = run_sam2_mask(
924
- predictor, first_frame,
925
- point=(px, py) if (px is not None and py is not None) else None,
926
- auto=auto_box
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
927
  )
928
- diagnostics["sam2_ok"] = bool(ok_mask)
929
- else:
930
- ok_mask = False
931
-
932
- if not ok_mask or seed_mask is None:
933
- logger.info("SAM2 failed or not available. Using fallback mask.")
934
- seed_mask = fallback_mask(first_frame)
935
- diagnostics["fallback_used"] = "mask_generation"
936
-
937
- # 1b) Optional GrabCut refinement over SAM2 seed
938
- if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
939
- seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
940
-
941
- _save_mask_png(seed_mask, mask_png)
942
 
943
- # 2) Try MatAnyone
944
- matany, mat_ok, mat_meta = load_matany()
945
- diagnostics["matany_meta"] = mat_meta
946
- diagnostics["device_matany"] = mat_meta.get("matany_device") if mat_meta else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
 
948
- out_dir = tmp_root / "matany_out"
949
- _ensure_dir(out_dir)
950
- fg_path, al_path = None, None
951
- if mat_ok:
952
- fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
953
- diagnostics["matany_ok"] = bool(ran)
954
- else:
955
- ran = False
956
 
957
- # --- Build Stage-A (transparent) file for inspection ---
958
- stageA_path = None
959
- stageA_ok = False
960
- if diagnostics["matany_ok"] and fg_path and al_path:
961
- stageA_path = tmp_root / "stageA_transparent.webm"
962
- if _probe_ffmpeg():
963
- stageA_ok = _build_stage_a_rgba_vp9_from_fg_alpha(
964
- fg_path, al_path, stageA_path, diagnostics["fps"], (vw, vh), src_audio=video_path
965
- )
966
- if not stageA_ok:
967
- stageA_path = tmp_root / "stageA_checkerboard.mp4"
968
- stageA_ok = _build_stage_a_checkerboard_from_fg_alpha(
969
- fg_path, al_path, stageA_path, diagnostics["fps"], (vw, vh)
970
- )
971
- else:
972
- stageA_path = tmp_root / "stageA_transparent.webm"
973
  if _probe_ffmpeg():
974
- stageA_ok = _build_stage_a_rgba_vp9_from_mask(
975
- video_path, mask_png, stageA_path, diagnostics["fps"], (vw, vh)
976
- )
977
- if not stageA_ok:
978
- stageA_path = tmp_root / "stageA_checkerboard.mp4"
979
- stageA_ok = _build_stage_a_checkerboard_from_mask(
980
- video_path, mask_png, stageA_path, diagnostics["fps"], (vw, vh)
981
- )
982
-
983
- diagnostics["stageA_path"] = str(stageA_path) if stageA_ok else None
984
- diagnostics["stageA_note"] = (
985
- "WebM with real alpha (VP9)" if stageA_ok and str(stageA_path).endswith(".webm")
986
- else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
987
- )
988
-
989
- # Optional: return Stage-A instead of final composite
990
- if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
991
- return str(stageA_path), diagnostics
992
-
993
- # 3) Composite to final background
994
- output_path = tmp_root / "output.mp4"
995
- if diagnostics["matany_ok"] and fg_path and al_path:
996
- ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
997
- if not ok_comp:
998
- logger.info("MatAnyone composite failed; falling back to static mask composite.")
999
- fallback_composite(video_path, mask_png, bg_image_path, output_path)
1000
- diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
1001
- else:
1002
- logger.info("MatAnyone not used; doing static mask composite.")
1003
- fallback_composite(video_path, mask_png, bg_image_path, output_path)
1004
- diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
1005
-
1006
- diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
1007
 
1008
- # 4) Add audio back from the original input (if present)
1009
- final_path = tmp_root / "output_with_audio.mp4"
1010
- if _probe_ffmpeg():
1011
- mux_ok = _mux_audio(video_path, output_path, final_path)
1012
- if mux_ok:
1013
- return str(final_path), diagnostics
 
 
1014
 
1015
- # Fallback: return the silent video if mux failed or ffmpeg not available
1016
- return str(output_path), diagnostics
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro - Memory-Optimized Pipeline
4
+ ===========================================
5
+ Orchestrates SAM2 → MatAnyone → Compositing with aggressive memory management.
6
+ Models are loaded sequentially and freed immediately after use.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  from __future__ import annotations
10
 
11
  import os
12
+ import gc
 
13
  import time
14
  import tempfile
15
  import logging
 
 
16
  from pathlib import Path
17
  from typing import Optional, Tuple, Dict, Any, Union
18
 
19
+ import torch
20
+ from models import (
21
+ load_sam2, run_sam2_mask, load_matany, run_matany,
22
+ fallback_mask, fallback_composite, composite_video,
23
+ _cv_read_first_frame, _save_mask_png, _ensure_dir, _mux_audio, _probe_ffmpeg,
24
+ _refine_mask_grabcut, _build_stage_a_rgba_vp9_from_fg_alpha,
25
+ _build_stage_a_rgba_vp9_from_mask, _build_stage_a_checkerboard_from_fg_alpha,
26
+ _build_stage_a_checkerboard_from_mask
27
+ )
28
+
29
+ # Try to apply GPU/perf tuning early
30
  try:
31
  import perf_tuning # noqa: F401
32
  except Exception:
 
43
  logger.addHandler(_h)
44
 
45
  # --------------------------------------------------------------------------------------
46
+ # Memory Management Utilities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # --------------------------------------------------------------------------------------
48
+ def _cleanup_temp_files(tmp_root: Path) -> None:
49
+ """Clean up temporary files aggressively"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
+ for pattern in ["*.tmp", "*.temp", "*.bak"]:
52
+ for f in tmp_root.glob(pattern):
53
+ f.unlink(missing_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  except Exception:
55
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ def _log_memory() -> float:
58
+ """Log current GPU memory usage and return allocated GB"""
59
+ if torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  try:
61
+ allocated = torch.cuda.memory_allocated() / 1e9
62
+ reserved = torch.cuda.memory_reserved() / 1e9
63
+ logger.info(f"GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
64
+ return allocated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  except Exception:
66
+ pass
67
+ return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ def _force_cleanup() -> None:
70
+ """Aggressive memory cleanup"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  try:
72
+ gc.collect()
73
+ if torch.cuda.is_available():
74
+ torch.cuda.empty_cache()
75
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  except Exception as e:
77
+ logger.warning(f"Cleanup warning: {e}")
 
78
 
79
  # --------------------------------------------------------------------------------------
80
+ # Main Processing Function (Memory-Optimized)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # --------------------------------------------------------------------------------------
82
  def process(video_path: Union[str, Path],
83
  bg_image_path: Union[str, Path],
 
85
  point_y: Optional[float] = None,
86
  auto_box: bool = False,
87
  work_dir: Optional[Union[str, Path]] = None) -> Tuple[Optional[str], Dict[str, Any]]:
88
+ """
89
+ Memory-optimized orchestration: lazy loading, sequential model usage, aggressive cleanup.
90
+
91
+ Flow:
92
+ 1. Load SAM2 → get mask → FREE SAM2 immediately
93
+ 2. Load MatAnyone → process → FREE MatAnyone immediately
94
+ 3. Composite & finalize (CPU-based operations)
95
+ """
96
  t0 = time.time()
97
  diagnostics: Dict[str, Any] = {
98
  "sam2_ok": False,
 
104
  "matany_meta": {},
105
  "device_sam2": None,
106
  "device_matany": None,
107
+ "memory_peak_gb": 0.0,
108
  }
109
 
110
  tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
111
  _ensure_dir(tmp_root)
112
 
113
+ try:
114
+ # 0) Basic video info
115
+ logger.info("Reading video metadata...")
116
+ first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
117
+ diagnostics["fps"] = int(fps or 25)
118
+ diagnostics["resolution"] = [int(vw), int(vh)]
119
+
120
+ if first_frame is None or vw == 0 or vh == 0:
121
+ diagnostics["fallback_used"] = "invalid_video"
122
+ return None, diagnostics
123
+
124
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
125
+
126
+ # 1) PHASE 1: SAM2 Loading & Processing → IMMEDIATE CLEANUP
127
+ logger.info("=== PHASE 1: Loading SAM2 for segmentation ===")
128
+ predictor, sam2_ok, sam_meta = load_sam2()
129
+ diagnostics["sam2_meta"] = sam_meta
130
+ diagnostics["device_sam2"] = sam_meta.get("sam2_device") if sam_meta else None
131
+
132
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
133
+
134
+ seed_mask = None
135
+ mask_png = tmp_root / "seed_mask.png"
136
+
137
+ if sam2_ok and predictor is not None:
138
+ logger.info("Running SAM2 segmentation...")
139
+ px = int(point_x) if point_x is not None else None
140
+ py = int(point_y) if point_y is not None else None
141
+
142
+ seed_mask, ok_mask = run_sam2_mask(
143
+ predictor, first_frame,
144
+ point=(px, py) if (px is not None and py is not None) else None,
145
+ auto=auto_box
146
+ )
147
+ diagnostics["sam2_ok"] = bool(ok_mask)
148
+
149
+ # CRITICAL: Free SAM2 immediately after getting the mask
150
+ logger.info("Freeing SAM2 memory...")
151
+ del predictor
152
+ predictor = None
153
+ _force_cleanup()
154
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
155
+
156
+ else:
157
+ ok_mask = False
158
+ logger.info("SAM2 not available or failed to load")
159
+
160
+ # Fallback mask generation if SAM2 failed
161
+ if not ok_mask or seed_mask is None:
162
+ logger.info("Using fallback mask generation...")
163
+ seed_mask = fallback_mask(first_frame)
164
+ diagnostics["fallback_used"] = "mask_generation"
165
+ _force_cleanup()
166
+
167
+ # Optional GrabCut refinement
168
+ if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
169
+ logger.info("Refining mask with GrabCut...")
170
+ seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
171
+ _force_cleanup()
172
+
173
+ _save_mask_png(seed_mask, mask_png)
174
+
175
+ # Clean up the first frame from memory
176
+ del first_frame
177
+ _force_cleanup()
178
+ _cleanup_temp_files(tmp_root)
179
+
180
+ # 2) PHASE 2: MatAnyone Loading & Processing → IMMEDIATE CLEANUP
181
+ logger.info("=== PHASE 2: Loading MatAnyone for temporal processing ===")
182
+ matany, mat_ok, mat_meta = load_matany()
183
+ diagnostics["matany_meta"] = mat_meta
184
+ diagnostics["device_matany"] = mat_meta.get("matany_device") if mat_meta else None
185
+
186
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
187
+
188
+ fg_path, al_path = None, None
189
+ out_dir = tmp_root / "matany_out"
190
+ _ensure_dir(out_dir)
191
+
192
+ if mat_ok and matany is not None:
193
+ logger.info("Running MatAnyone processing...")
194
+ fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
195
+ diagnostics["matany_ok"] = bool(ran)
196
+
197
+ # CRITICAL: Free MatAnyone immediately after processing
198
+ logger.info("Freeing MatAnyone memory...")
199
+ del matany
200
+ matany = None
201
+ _force_cleanup()
202
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
203
+ else:
204
+ ran = False
205
+ logger.info("MatAnyone not available, disabled, or failed to load")
206
+
207
+ # 3) PHASE 3: Stage-A Creation (lightweight, CPU-based)
208
+ logger.info("=== PHASE 3: Building Stage-A (transparent export) ===")
209
+ stageA_path = None
210
+ stageA_ok = False
211
+
212
+ if diagnostics["matany_ok"] and fg_path and al_path:
213
+ stageA_path = tmp_root / "stageA_transparent.webm"
214
+ if _probe_ffmpeg():
215
+ stageA_ok = _build_stage_a_rgba_vp9_from_fg_alpha(
216
+ fg_path, al_path, stageA_path, diagnostics["fps"], (vw, vh), src_audio=video_path
217
+ )
218
+ if not stageA_ok:
219
+ stageA_path = tmp_root / "stageA_checkerboard.mp4"
220
+ stageA_ok = _build_stage_a_checkerboard_from_fg_alpha(
221
+ fg_path, al_path, stageA_path, diagnostics["fps"], (vw, vh)
222
+ )
223
+ else:
224
+ stageA_path = tmp_root / "stageA_transparent.webm"
225
+ if _probe_ffmpeg():
226
+ stageA_ok = _build_stage_a_rgba_vp9_from_mask(
227
+ video_path, mask_png, stageA_path, diagnostics["fps"], (vw, vh)
228
+ )
229
+ if not stageA_ok:
230
+ stageA_path = tmp_root / "stageA_checkerboard.mp4"
231
+ stageA_ok = _build_stage_a_checkerboard_from_mask(
232
+ video_path, mask_png, stageA_path, diagnostics["fps"], (vw, vh)
233
+ )
234
+
235
+ diagnostics["stageA_path"] = str(stageA_path) if stageA_ok else None
236
+ diagnostics["stageA_note"] = (
237
+ "WebM with real alpha (VP9)" if stageA_ok and str(stageA_path).endswith(".webm")
238
+ else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
239
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ # Optional: return Stage-A instead of final composite
242
+ if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
243
+ _force_cleanup()
244
+ _cleanup_temp_files(tmp_root)
245
+ return str(stageA_path), diagnostics
246
+
247
+ # 4) PHASE 4: Final Compositing (CPU-based, memory-efficient)
248
+ logger.info("=== PHASE 4: Creating final composite ===")
249
+ output_path = tmp_root / "output.mp4"
250
+
251
+ if diagnostics["matany_ok"] and fg_path and al_path:
252
+ logger.info("Compositing with MatAnyone outputs...")
253
+ ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
254
+ if not ok_comp:
255
+ logger.info("MatAnyone composite failed; falling back to static mask composite.")
256
+ fallback_composite(video_path, mask_png, bg_image_path, output_path)
257
+ diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
258
+ else:
259
+ logger.info("Using static mask composite...")
260
+ fallback_composite(video_path, mask_png, bg_image_path, output_path)
261
+ diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
262
 
263
+ # Clean up intermediate files
264
+ _cleanup_temp_files(tmp_root)
265
+ _force_cleanup()
 
 
 
 
 
266
 
267
+ # 5) PHASE 5: Audio Muxing (final step)
268
+ logger.info("=== PHASE 5: Adding audio track ===")
269
+ final_path = tmp_root / "output_with_audio.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  if _probe_ffmpeg():
271
+ mux_ok = _mux_audio(video_path, output_path, final_path)
272
+ if mux_ok:
273
+ # Clean up the silent version
274
+ output_path.unlink(missing_ok=True)
275
+ _force_cleanup()
276
+ diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
277
+ logger.info(f"Processing completed successfully in {diagnostics['elapsed_sec']}s")
278
+ logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
279
+ return str(final_path), diagnostics
280
+
281
+ # Final cleanup
282
+ _force_cleanup()
283
+ diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
284
+ logger.info(f"Processing completed in {diagnostics['elapsed_sec']}s (no audio)")
285
+ logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
286
+ return str(output_path), diagnostics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
+ except Exception as e:
289
+ logger.error(f"Processing failed: {e}")
290
+ import traceback
291
+ logger.error(f"Traceback: {traceback.format_exc()}")
292
+ _force_cleanup()
293
+ diagnostics["error"] = str(e)
294
+ diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
295
+ return None, diagnostics
296
 
297
+ finally:
298
+ # Ensure cleanup even if something goes wrong
299
+ _force_cleanup()
300
+ _cleanup_temp_files(tmp_root)