MogensR committed on
Commit
7aa0137
·
1 Parent(s): 90db7e5
Files changed (2) hide show
  1. models.py +0 -795
  2. models/__init__.py +795 -0
models.py DELETED
@@ -1,795 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- BackgroundFX Pro - Model Loading & Utilities
4
- ===========================================
5
- Contains all model loading, inference functions, and utility functions
6
- moved from the main pipeline for better organization.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import os
12
- import sys
13
- import cv2
14
- import subprocess
15
- import inspect
16
- import logging
17
- from pathlib import Path
18
- from typing import Optional, Tuple, Dict, Any, Union
19
-
20
- import numpy as np
21
- import yaml
22
import torch  # For memory management and CUDA operations
24
-
25
# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
# Shared module logger; this module never configures handlers or levels itself.
logger = logging.getLogger("backgroundfx_pro")
29
-
30
- # --------------------------------------------------------------------------------------
31
- # Optional dependencies
32
- # --------------------------------------------------------------------------------------
33
# MediaPipe is optional: it is only consulted by fallback_mask() below when the
# primary segmentation models are unavailable.
try:
    import mediapipe as mp  # type: ignore
    _HAS_MEDIAPIPE = True
except Exception:
    # Any import-time failure (missing wheel, broken native deps) just disables the fallback.
    _HAS_MEDIAPIPE = False
38
-
39
- # --------------------------------------------------------------------------------------
40
- # Path setup for third_party repos
41
- # --------------------------------------------------------------------------------------
42
# Vendored third-party checkouts; deployments can point elsewhere via env vars.
ROOT = Path(__file__).resolve().parent
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
45
-
46
- def _add_sys_path(p: Path) -> None:
47
- p_str = str(p)
48
- if p_str not in sys.path:
49
- sys.path.insert(0, p_str)
50
-
51
- _add_sys_path(TP_SAM2)
52
- _add_sys_path(TP_MATANY)
53
-
54
- # --------------------------------------------------------------------------------------
55
- # Basic Utilities
56
- # --------------------------------------------------------------------------------------
57
- def _ffmpeg_bin() -> str:
58
- return os.environ.get("FFMPEG_BIN", "ffmpeg")
59
-
60
def _probe_ffmpeg() -> bool:
    """Return True when the ffmpeg binary can actually run (`ffmpeg -version`)."""
    try:
        subprocess.run(
            [_ffmpeg_bin(), "-version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except Exception:
        return False
    return True
66
-
67
- def _has_cuda() -> bool:
68
- try:
69
- import torch # type: ignore
70
- return torch.cuda.is_available()
71
- except Exception:
72
- return False
73
-
74
- def _pick_device(env_key: str) -> str:
75
- requested = os.environ.get(env_key, "").strip().lower()
76
- if requested in {"cuda", "cpu"}:
77
- return requested
78
- return "cuda" if _has_cuda() else "cpu"
79
-
80
def _ensure_dir(p: Path) -> None:
    """Create directory *p* (including parents); no-op if it already exists."""
    p.mkdir(parents=True, exist_ok=True)
82
-
83
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Open *video_path* and return (first_frame_or_None, fps, (width, height)).

    fps falls back to 25 when the container reports nothing; an unopenable
    file yields (None, 0, (0, 0)). The capture is always released.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
    grabbed, frame = cap.read()
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    cap.release()
    if grabbed:
        return frame, fps, (width, height)
    return None, fps, (width, height)
95
-
96
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit grayscale PNG and return the path string.

    Accepts bool, uint8, and float masks. Float masks whose values lie in
    [0, 1] are rescaled to [0, 255] first — the previous code clipped them to
    0/1, producing an effectively black image that downstream `> 127`
    thresholds could never recover.
    """
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        m = mask.astype(np.float32)
        # Conventional float alpha range is [0, 1]; promote to 8-bit range.
        if m.size and m.max() <= 1.0:
            m = m * 255.0
        mask = np.clip(m, 0, 255).astype(np.uint8)
    cv2.imwrite(str(path), mask)
    return str(path)
103
-
104
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Resize *image* to fit inside (target_w, target_h) preserving aspect ratio.

    The scaled image is centered (letterboxed) on a zero-filled canvas of
    exactly the target size. Handles both H x W x C and single-channel H x W
    inputs — the previous version always allocated a 3-channel canvas and
    raised on grayscale images. Degenerate input/target sizes are returned
    unchanged.
    """
    tw, th = target_wh
    h, w = image.shape[:2]
    if h == 0 or w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / w, th / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    # Match the resized image's channel layout instead of hard-coding 3 channels.
    canvas = np.zeros((th, tw) + resized.shape[2:], dtype=resized.dtype)
    x0 = (tw - nw) // 2
    y0 = (th - nh) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = resized
    return canvas
117
-
118
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Open an MP4 ('mp4v') cv2.VideoWriter at *size*; fps is clamped to >= 1."""
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    safe_fps = fps if fps > 1 else 1
    return cv2.VideoWriter(str(out_path), codec, safe_fps, size)
121
-
122
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC).

    Returns True on success. Any ffmpeg failure is logged as a warning and
    reported as False so the caller can fall back to the silent video.
    """
    try:
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(silent_video),
            "-i", str(src_video),
            "-map", "0:v:0",
            "-map", "1:a:0?",  # trailing '?' makes the audio stream optional
            "-c:v", "copy",    # video is already encoded; never re-encode here
            "-c:a", "aac", "-b:a", "192k",
            "-shortest",
            str(out_path)
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
141
-
142
- # --------------------------------------------------------------------------------------
143
- # Compositing & Image Processing
144
- # --------------------------------------------------------------------------------------
145
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Clean up an alpha matte: erode → dilate, then a gentle Gaussian blur.

    Args:
        alpha: matte as bool/uint8/float, in either a [0, 1] or [0, 255] range.
        erode_px / dilate_px: morphology kernel sizes in pixels (<= 0 disables).
        blur_px: Gaussian blur radius (<= 0 disables).

    Returns:
        float32 alpha clipped to [0, 1].

    Fix over the original: normalization now keys on the value range, not the
    dtype — a float32 input already in [0, 255] was previously left unscaled
    and saturated to solid white.
    """
    a = alpha.astype(np.float32)
    if a.size and a.max() > 1.0:
        a = a / 255.0

    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    if blur_px and blur_px > 0:
        rad = max(1, int(round(blur_px)))
        # GaussianBlur needs an odd kernel size; `rad | 1` bumps even radii by one.
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
168
-
169
- def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
170
- x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
171
- return np.power(x, gamma)
172
-
173
- def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
174
- x = np.clip(lin, 0.0, 1.0)
175
- return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
176
-
177
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Additive 'light wrap': blurred background light bleeding into subject edges."""
    r = max(1, int(radius))
    # Blurring the inverted alpha spreads background weight across the boundary.
    bg_weight = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * bg_weight[..., None] * float(amount)
184
-
185
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Desaturate the alpha≈0.5 boundary band to remove old-background tint.

    The weight is bell-shaped: it peaks where alpha == 0.5 and vanishes at 0
    and 1, so only genuinely mixed edge pixels lose saturation.
    """
    band = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = sat * (1.0 - amount * band)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
195
-
196
def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
                         erode_px: Optional[int] = None, dilate_px: Optional[int] = None, blur_px: Optional[float] = None,
                         lw_radius: Optional[int] = None, lw_amount: Optional[float] = None,
                         despill_amount: Optional[float] = None) -> np.ndarray:
    """Gamma-aware composite + edge refinement + light wrap + boundary de-spill.

    Any parameter left as None is resolved from its environment variable
    (EDGE_ERODE, EDGE_DILATE, EDGE_BLUR, LIGHTWRAP_RADIUS, LIGHTWRAP_AMOUNT,
    DESPILL_AMOUNT), allowing per-deployment tuning without code changes.
    Returns an RGB uint8 frame the size of the inputs.
    """
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    # refine alpha [0,1]
    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)

    # edge de-spill: temper saturation where a≈0.5 (must happen before linearizing,
    # since it operates in 8-bit HSV space)
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    # linearize for better blending
    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)

    # light wrap: additive background bleed along the subject edge
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    # standard over-composite in linear light, plus the light-wrap term
    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    comp = _to_srgb(comp_lin)
    return comp
225
-
226
- # --------------------------------------------------------------------------------------
227
- # SAM2 Integration
228
- # --------------------------------------------------------------------------------------
229
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Resolve a SAM2 config reference to an absolute path when possible.

    Resolution order: relative to the vendored SAM2 checkout, then the path
    as given, then well-known default configs inside the checkout. Falls back
    to the raw string so build_sam2 can raise a clear error.
    """
    cfg_path = Path(cfg_str)
    if not cfg_path.is_absolute():
        in_repo = TP_SAM2 / cfg_path
        if in_repo.exists():
            return str(in_repo)
    if cfg_path.exists():
        return str(cfg_path)
    defaults = (
        "configs/sam2/sam2_hiera_l.yaml",
        "configs/sam2/sam2_hiera_b.yaml",
        "configs/sam2/sam2_hiera_s.yaml",
    )
    for name in defaults:
        candidate = TP_SAM2 / name
        if candidate.exists():
            return str(candidate)
    return str(cfg_str)  # let build_sam2 raise a clear error
244
-
245
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    Reads the image-encoder trunk target from *cfg_path*; when it points at a
    'hieradet' class, scans every YAML under the SAM2 checkout for one whose
    trunk target lives in a `.hiera.` module and returns the first match.
    Returns None when no substitution applies. All parse/IO errors are
    deliberately swallowed — this is a best-effort rescue path.
    """
    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        target = None
        model = data.get("model", {})
        enc = (model.get("image_encoder") or {})
        trunk = (enc.get("trunk") or {})
        # Hydra-style configs use `_target_`; tolerate a plain `target` key too.
        target = trunk.get("_target_") or trunk.get("target")
        if isinstance(target, str) and "hieradet" in target:
            for y in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(y, "r") as f2:
                        d2 = yaml.safe_load(f2)
                    m2 = (d2 or {}).get("model", {})
                    e2 = (m2.get("image_encoder") or {})
                    t2 = (e2.get("trunk") or {})
                    tgt2 = t2.get("_target_") or t2.get("target")
                    if isinstance(tgt2, str) and ".hiera." in tgt2:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                        return str(y)
                except Exception:
                    # Unparseable candidate YAML — keep scanning.
                    continue
    except Exception:
        pass
    return None
272
-
273
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns (predictor_or_None, ok, meta). `meta` records how far loading got
    (import vs. init) plus the device/config/checkpoint actually used.
    Configuration comes from the SAM2_DEVICE, SAM2_MODEL_CFG and
    SAM2_CHECKPOINT environment variables.
    """
    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2  # type: ignore
        from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        # Different SAM2 forks name the constructor parameters differently;
        # inspect the signature and pass whichever keywords it accepts.
        params = set(inspect.signature(build_sam2).parameters.keys())
        kwargs = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
        if ckpt:
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            return build_sam2(**kwargs)
        except TypeError:
            # Keyword call rejected — retry positionally: (cfg, [ckpt], [device]).
            pos = [cfg_path]
            if ckpt:
                pos.append(ckpt)
            if "device" not in kwargs:
                pos.append(device)
            return build_sam2(*pos)

    try:
        try:
            sam = _try_build(cfg)
        except Exception as e1:
            # A 'hieradet' config can fail on forks that only ship 'hiera';
            # retry once with an equivalent 'hiera' config if one can be found.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
                sam = _try_build(alt_cfg)
                cfg = alt_cfg
            else:
                raise

        predictor = SAM2ImagePredictor(sam)
        meta.update({
            "sam2_init_ok": True,
            "sam2_device": device,
            "sam2_cfg": cfg,
            "sam2_ckpt": ckpt or "(repo default)"
        })
        return predictor, True, meta
    except Exception as e:
        logger.error(f"SAM2 init failed: {e}")
        return None, False, meta
338
-
339
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Run one SAM2 prediction on the first frame.

    Prompt selection: `auto` uses a near-full-frame box (5% inset), an
    explicit `point` uses a single positive click, otherwise a centered 80%
    box is used.

    Returns (mask_uint8_0_255, ok); (None, False) on any failure.
    """
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)

        if auto:
            h, w = rgb.shape[:2]
            box = np.array([int(0.05 * w), int(0.05 * h), int(0.95 * w), int(0.95 * h)])
            masks, _, _ = predictor.predict(box=box)
        elif point is not None:
            x, y = int(point[0]), int(point[1])
            pts = np.array([[x, y]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)
            masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
        else:
            h, w = rgb.shape[:2]
            box = np.array([int(0.1 * w), int(0.1 * h), int(0.9 * w), int(0.9 * h)])
            masks, _, _ = predictor.predict(box=box)

        if masks is None or len(masks) == 0:
            return None, False

        # Threshold instead of `astype(np.uint8) * 255`: the old expression
        # wrapped around (255 * 255 → 1) when the predictor returned 0/255
        # uint8 masks, and truncated float probability masks to 0/1 before
        # scaling. `> 0.5` is correct for bool, 0/1, 0/255 and float masks.
        m = (np.asarray(masks[0]) > 0.5).astype(np.uint8) * 255
        return m, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
372
-
373
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use a seed mask (e.g. from SAM2) as initialization for GrabCut refinement.

    A trimap is built by eroding the binarized mask (sure foreground) and
    eroding its complement (sure background); everything else is left as
    "probable background" for GrabCut to decide. None parameters default from
    REFINE_GRABCUT_ITERS / REFINE_TRIMAP_ERODE / REFINE_TRIMAP_DILATE env vars.

    Returns a 0/255 uint8 mask; on GrabCut failure, the binarized input mask.

    Fix over the original: the exception variable in the fallback path was
    named `e`, shadowing the trimap-erode size local `e` — locals renamed to
    remove the shadowing hazard (log output is unchanged).
    """
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    erode_k = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    dilate_k = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255

    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, erode_k), max(1, erode_k))), iterations=1)
    # Eroding the complement pulls "sure background" away from the subject edge.
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, dilate_k), max(1, dilate_k))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    # GrabCut requires these 1x65 model buffers even when starting from a mask.
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)
        return out
    except Exception as exc:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {exc}")
        return m
403
-
404
- # --------------------------------------------------------------------------------------
405
- # MatAnyone Integration
406
- # --------------------------------------------------------------------------------------
407
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """MatAnyone loader with disable switch and error handling.

    Returns (core_or_None, ok, meta). Honors ENABLE_MATANY=0/false/off/no as a
    hard disable; device/repo/checkpoint come from MATANY_DEVICE,
    MATANY_REPO_ID and MATANY_CHECKPOINT environment variables. Because
    different MatAnyone forks expose different constructors, several argument
    combinations are attempted in order.
    """
    meta = {"matany_import_ok": False, "matany_init_ok": False}

    enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
    if enable_env in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    try:
        # Some forks expose InferenceCore at top level, others under the package.
        try:
            from inference_core import InferenceCore  # type: ignore
        except Exception:
            from matanyone.inference.inference_core import InferenceCore  # type: ignore
        meta["matany_import_ok"] = True
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta

    device = _pick_device("MATANY_DEVICE")
    repo_id = os.environ.get("MATANY_REPO_ID", "")
    ckpt = os.environ.get("MATANY_CHECKPOINT", "")

    # Check if this fork needs a prebuilt network: a required `network`
    # parameter means we cannot construct the core from a checkpoint here.
    try:
        sig = inspect.signature(InferenceCore)
        if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
            logger.error(
                "This MatAnyone fork expects `InferenceCore(network=...)`. "
                "Pin a fork/commit that supplies a checkpoint-based constructor, "
                "or set ENABLE_MATANY=0 to skip."
            )
            meta["needs_network_arg"] = True
            return None, False, meta
    except Exception:
        pass

    # Constructor variants, most specific first; first one that works wins.
    candidates = [
        {"kwargs": {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device}},
        {"kwargs": {"checkpoint": ckpt or None, "device": device}},
        {"args": (), "kwargs": {"device": device}},
    ]
    last_err = None
    for cand in candidates:
        try:
            matany = InferenceCore(*cand.get("args", ()), **cand.get("kwargs", {}))
            meta["matany_init_ok"] = True
            meta["matany_device"] = device
            meta["matany_repo_id"] = repo_id or "(unset)"
            meta["matany_checkpoint"] = ckpt or "(unset)"
            return matany, True, meta
        except Exception as e:
            last_err = e
            continue

    logger.error(f"MatAnyone init failed with all fallbacks: {last_err}")
    return None, False, meta
465
-
466
def run_matany(matany: object,
               video_path: Union[str, Path],
               first_mask_path: Union[str, Path],
               work_dir: Union[str, Path]) -> Tuple[Optional[str], Optional[str], bool]:
    """Run MatAnyone on a video seeded with a first-frame mask.

    Returns (foreground_video_path, alpha_video_path, ok). Tries the two
    entry points seen across forks (`process_video`, then `run`) and accepts
    either tuple/list or dict results, probing the dict key spellings each
    fork uses. (None, None, False) when nothing usable comes back.
    """
    if matany is None:
        return None, None, False
    try:
        if hasattr(matany, "process_video"):
            out = matany.process_video(input_path=str(video_path), mask_path=str(first_mask_path), output_dir=str(work_dir))
            if isinstance(out, (list, tuple)) and len(out) >= 2:
                # Positional convention: (foreground, alpha, ...)
                return str(out[0]), str(out[1]), True
            if isinstance(out, dict):
                fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
                al = out.get("alpha") or out.get("alpha_path")
                if fg and al:
                    return str(fg), str(al), True

        if hasattr(matany, "run"):
            out = matany.run(video_path=str(video_path), seed_mask=str(first_mask_path), out_dir=str(work_dir))
            if isinstance(out, dict):
                fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
                al = out.get("alpha") or out.get("alpha_path")
                if fg and al:
                    return str(fg), str(al), True

        logger.error("MatAnyone returned no usable paths.")
        return None, None, False
    except Exception as e:
        logger.warning(f"MatAnyone processing failed: {e}")
        return None, None, False
497
-
498
- # --------------------------------------------------------------------------------------
499
- # Fallback Functions
500
- # --------------------------------------------------------------------------------------
501
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255.

    Used when SAM2 is unavailable or failed. The final safety net is an
    all-zero mask so callers always receive a valid array of frame size.
    """
    h, w = first_frame_bgr.shape[:2]
    if _HAS_MEDIAPIPE:
        try:
            mp_selfie = mp.solutions.selfie_segmentation
            # model_selection=1 is the landscape-oriented selfie model.
            with mp_selfie.SelfieSegmentation(model_selection=1) as segmenter:
                rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
                res = segmenter.process(rgb)
                m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
                m = cv2.medianBlur(m, 5)  # knock out salt-and-pepper speckle
                return m
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut seeded with a centered 80% rectangle.
    mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1*w), int(0.1*h), int(0.8*w), int(0.8*h))
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
        mask_bin = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        return mask_bin
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
528
-
529
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor.

    Reads the foreground and alpha videos in lockstep, composites each frame
    over the (letterboxed) background image, and writes *out_path*. When
    ffmpeg is available the result is re-encoded to H.264/yuv420p for broad
    playback compatibility; otherwise the raw mp4v file is kept.
    Returns True if at least one frame was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/broken background image: fall back to flat mid-gray.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        # Write mp4v to a temp file first; finalize to H.264 below.
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        post_h264 = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        post_h264 = False

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not ok_fg or not ok_al:
                # Stop at the shorter of the two streams.
                break
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)

            # Compositor works in RGB; cv2 delivers BGR.
            comp = _composite_frame_pro(
                cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
                al_gray,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: keep the mp4v temp file as the final output.
            logger.warning(f"ffmpeg finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
593
-
594
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor.

    Applies one fixed mask to every frame of *video_path* — used when
    MatAnyone (per-frame alpha) is unavailable. Output is H.264-finalized
    when ffmpeg is present, raw mp4v otherwise. Returns True if at least one
    frame was written.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    # Output geometry/rate come from the source video itself.
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/broken background image: fall back to flat mid-gray.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    # NEAREST keeps the mask binary; the compositor does its own feathering.
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        # Write mp4v to a temp file first; finalize to H.264 below.
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        use_post_ffmpeg = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        use_post_ffmpeg = False

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                mask_resized,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: keep the mp4v temp file as the final output.
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
656
-
657
- # --------------------------------------------------------------------------------------
658
- # Stage-A (Transparent Export) Functions
659
- # --------------------------------------------------------------------------------------
660
- def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
661
- """RGB checkerboard for preview when no real alpha is possible."""
662
- y, x = np.mgrid[0:h, 0:w]
663
- c = ((x // tile) + (y // tile)) % 2
664
- a = np.where(c == 0, 200, 150).astype(np.uint8)
665
- return np.stack([a, a, a], axis=-1)
666
-
667
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG+ALPHA → RGBA WebM (VP9 with alpha).

    Scales both streams to *size*/*fps*, merges the alpha stream into the
    foreground via ffmpeg's `alphamerge`, and encodes VP9 with the yuva420p
    pixel format (true alpha channel). Optionally muxes audio from
    *src_audio* as Opus. Requires ffmpeg; returns False when it is missing or
    the command fails.
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        # [1:v] = alpha video → grayscale; [0:v] = foreground; alphamerge combines them.
        fcx = f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];" \
              f"[0:v]scale={w}:{h},fps={fps}[fg];" \
              f"[fg][al]alphamerge[outv]"
        cmd += ["-filter_complex", fcx, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
699
-
700
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Merge original video + static mask → RGBA WebM (VP9 with alpha).

    Like _build_stage_a_rgba_vp9_from_fg_alpha but the alpha comes from one
    still PNG looped for the whole clip (`-loop 1`). Requires ffmpeg;
    returns False when it is missing or the command fails.
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            "-loop", "1", "-i", str(mask_png),
            "-filter_complex",
            # [1:v] = looped mask → grayscale alpha; [0:v] = source video.
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]",
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
730
-
731
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: FG+ALPHA over checkerboard → MP4 (no real alpha channel)."""
    fg_reader = cv2.VideoCapture(str(fg_path))
    alpha_reader = cv2.VideoCapture(str(alpha_path))
    if not (fg_reader.isOpened() and alpha_reader.isOpened()):
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got_fg, fg_frame = fg_reader.read()
            got_al, alpha_frame = alpha_reader.read()
            if not (got_fg and got_al):
                break
            fg_frame = cv2.resize(fg_frame, (w, h))
            alpha_gray = cv2.cvtColor(cv2.resize(alpha_frame, (w, h)), cv2.COLOR_BGR2GRAY)
            composed = _composite_frame_pro(cv2.cvtColor(fg_frame, cv2.COLOR_BGR2RGB), alpha_gray, board)
            writer.write(cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        fg_reader.release()
        alpha_reader.release()
        writer.release()
    return wrote_any
763
-
764
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: original video + static mask over checkerboard → MP4."""
    reader = cv2.VideoCapture(str(video_path))
    if not reader.isOpened():
        return False
    w, h = size
    static_mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if static_mask is None:
        return False
    static_mask = cv2.resize(static_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got, frame = reader.read()
            if not got:
                break
            frame = cv2.resize(frame, (w, h))
            composed = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), static_mask, board)
            writer.write(cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        reader.release()
        writer.release()
    return wrote_any
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__init__.py CHANGED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro - Model Loading & Utilities
4
+ ===========================================
5
+ Contains all model loading, inference functions, and utility functions
6
+ moved from the main pipeline for better organization.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
# Standard library
import inspect
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

# Third-party
import cv2
import numpy as np
import torch  # For memory management and CUDA operations
import yaml
24
+
25
+ # --------------------------------------------------------------------------------------
26
+ # Logging
27
+ # --------------------------------------------------------------------------------------
28
+ logger = logging.getLogger("backgroundfx_pro")
29
+
30
+ # --------------------------------------------------------------------------------------
31
+ # Optional dependencies
32
+ # --------------------------------------------------------------------------------------
33
+ try:
34
+ import mediapipe as mp # type: ignore
35
+ _HAS_MEDIAPIPE = True
36
+ except Exception:
37
+ _HAS_MEDIAPIPE = False
38
+
39
+ # --------------------------------------------------------------------------------------
40
+ # Path setup for third_party repos
41
+ # --------------------------------------------------------------------------------------
42
+ ROOT = Path(__file__).resolve().parent
43
+ TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
44
+ TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
45
+
46
+ def _add_sys_path(p: Path) -> None:
47
+ p_str = str(p)
48
+ if p_str not in sys.path:
49
+ sys.path.insert(0, p_str)
50
+
51
+ _add_sys_path(TP_SAM2)
52
+ _add_sys_path(TP_MATANY)
53
+
54
+ # --------------------------------------------------------------------------------------
55
+ # Basic Utilities
56
+ # --------------------------------------------------------------------------------------
57
+ def _ffmpeg_bin() -> str:
58
+ return os.environ.get("FFMPEG_BIN", "ffmpeg")
59
+
60
def _probe_ffmpeg() -> bool:
    """Return True iff the configured ffmpeg binary runs (`ffmpeg -version`)."""
    try:
        subprocess.run(
            [_ffmpeg_bin(), "-version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except Exception:
        return False
    return True
66
+
67
+ def _has_cuda() -> bool:
68
+ try:
69
+ import torch # type: ignore
70
+ return torch.cuda.is_available()
71
+ except Exception:
72
+ return False
73
+
74
+ def _pick_device(env_key: str) -> str:
75
+ requested = os.environ.get(env_key, "").strip().lower()
76
+ if requested in {"cuda", "cpu"}:
77
+ return requested
78
+ return "cuda" if _has_cuda() else "cpu"
79
+
80
+ def _ensure_dir(p: Path) -> None:
81
+ p.mkdir(parents=True, exist_ok=True)
82
+
83
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Read the first frame of a video plus its (fps, (width, height)).

    Returns (None, 0, (0, 0)) when the file cannot be opened, and
    (None, fps, (w, h)) when it opens but no frame can be decoded.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))  # 25 fps default when metadata is 0
    got, frame = cap.read()
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    cap.release()
    if not got:
        return None, fps, (width, height)
    return frame, fps, (width, height)
95
+
96
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Persist *mask* as an 8-bit PNG at *path*; returns the path as str."""
    if mask.dtype == bool:
        out = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        out = np.clip(mask, 0, 255).astype(np.uint8)
    else:
        out = mask
    cv2.imwrite(str(path), out)
    return str(path)
103
+
104
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Letterbox *image* into a (target_w, target_h) canvas, preserving aspect.

    The image is scaled to fit inside the target and centered on a black
    canvas. Degenerate sizes (zero width/height on either side) return the
    input unchanged.

    Fix: the canvas previously was hard-coded to 3 channels, so a grayscale
    (H, W) input crashed on the paste; the canvas now mirrors the input's
    channel layout.
    """
    tw, th = target_wh
    h, w = image.shape[:2]
    if h == 0 or w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / w, th / h)
    nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
    resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    # Match the (resized) input's dimensionality: (H, W) or (H, W, C).
    canvas_shape = (th, tw) if resized.ndim == 2 else (th, tw, resized.shape[2])
    canvas = np.zeros(canvas_shape, dtype=resized.dtype)
    x0 = (tw - nw) // 2
    y0 = (th - nh) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = resized
    return canvas
117
+
118
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Open an mp4v-encoded VideoWriter at *out_path* (fps clamped to >= 1)."""
    return cv2.VideoWriter(
        str(out_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        max(1, fps),
        size,
    )
121
+
122
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC).

    Returns False (and logs a warning) on any failure so callers can fall
    back to the silent render.
    """
    cmd = [
        _ffmpeg_bin(), "-y",
        "-i", str(silent_video),
        "-i", str(src_video),
        "-map", "0:v:0",
        "-map", "1:a:0?",   # '?' → tolerate a source with no audio stream
        "-c:v", "copy",
        "-c:a", "aac", "-b:a", "192k",
        "-shortest",
        str(out_path)
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
141
+
142
+ # --------------------------------------------------------------------------------------
143
+ # Compositing & Image Processing
144
+ # --------------------------------------------------------------------------------------
145
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Erode→dilate + gentle blur → float alpha in [0, 1].

    Accepts bool, uint8 (0-255), or float input. Any input whose max exceeds
    1.0 is treated as 0-255 scaled and normalized.

    Fix: float32 input already in the 0-255 range previously skipped
    normalization (only non-float32 inputs were rescaled), which saturated
    the matte to all-255 after the round-trip through uint8. Normalization
    now applies uniformly; in-range float mattes are unaffected.
    """
    a = alpha.astype(np.float32) if alpha.dtype != np.float32 else alpha.copy()
    if a.max() > 1.0:
        a = a / 255.0

    # Morphology runs on a uint8 view; erode first to cut halo, then dilate
    # to recover silhouette coverage.
    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    if blur_px and blur_px > 0:
        rad = max(1, int(round(blur_px)))
        # GaussianBlur requires odd kernel sizes; `| 1` bumps even radii.
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
168
+
169
+ def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
170
+ x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
171
+ return np.power(x, gamma)
172
+
173
+ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
174
+ x = np.clip(lin, 0.0, 1.0)
175
+ return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
176
+
177
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Simple light wrap: bleed blurred background light across subject edges.

    Returns a float32 additive term scaled by *amount* (not a final image).
    """
    r = max(1, int(radius))
    # Blur the inverse matte so background brightness leaks softly inward.
    inv_blur = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount)
184
+
185
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Reduce saturation in the boundary band (alpha ≈ 0.5) to remove old-background tint."""
    # Bell-shaped weight: 1 where alpha == 0.5, 0 where alpha is 0 or 1.
    weight = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    # Desaturate proportionally to the boundary weight.
    sat = sat * (1.0 - amount * weight)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
195
+
196
def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
                         erode_px: Optional[int] = None, dilate_px: Optional[int] = None,
                         blur_px: Optional[float] = None,
                         lw_radius: Optional[int] = None, lw_amount: Optional[float] = None,
                         despill_amount: Optional[float] = None) -> np.ndarray:
    """Gamma-aware composite + edge refinement + light wrap + boundary de-spill.

    Args:
        fg_rgb: Foreground frame, RGB uint8.
        alpha: Matte (uint8 0-255, bool, or float); refined via `_refine_alpha`.
        bg_rgb: Background frame, RGB uint8, same size as *fg_rgb*.
        erode_px/dilate_px/blur_px: Edge-refinement knobs; default from the
            EDGE_ERODE / EDGE_DILATE / EDGE_BLUR environment variables.
        lw_radius/lw_amount: Light-wrap knobs (LIGHTWRAP_RADIUS / LIGHTWRAP_AMOUNT).
        despill_amount: Boundary de-spill strength (DESPILL_AMOUNT).

    Returns:
        Composited RGB uint8 frame.

    Note: annotations were `int = None` etc.; they are now `Optional[...]`
    per PEP 484 — runtime behavior is unchanged.
    """
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    # refine alpha [0,1]
    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)

    # edge de-spill: temper saturation where a≈0.5
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    # linearize for better blending (blend in linear light, not gamma space)
    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)

    # light wrap (additive term; _to_srgb clips any overshoot)
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    comp = _to_srgb(comp_lin)
    return comp
225
+
226
+ # --------------------------------------------------------------------------------------
227
+ # SAM2 Integration
228
+ # --------------------------------------------------------------------------------------
229
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Make the SAM2 config path absolute (prefer inside TP_SAM2).

    Resolution order: relative path under TP_SAM2, then the path as given,
    then known default configs in the repo, then the input unchanged (so a
    bad path surfaces as a clear error from build_sam2).
    """
    cfg_path = Path(cfg_str)
    if not cfg_path.is_absolute():
        candidate = TP_SAM2 / cfg_path
        if candidate.exists():
            return str(candidate)
    if cfg_path.exists():
        return str(cfg_path)
    # Last resort: common defaults inside the repo
    for default_name in ("configs/sam2/sam2_hiera_l.yaml",
                         "configs/sam2/sam2_hiera_b.yaml",
                         "configs/sam2/sam2_hiera_s.yaml"):
        default_path = TP_SAM2 / default_name
        if default_path.exists():
            return str(default_path)
    return str(cfg_str)  # let build_sam2 raise a clear error
244
+
245
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    Reads the YAML at *cfg_path*; when its image-encoder trunk target string
    contains 'hieradet', scans every YAML under TP_SAM2 and returns the first
    whose trunk target contains '.hiera.'. Returns None when no swap is
    needed or nothing suitable exists. All parse/IO errors are swallowed —
    this is a best-effort compatibility shim, not a hard requirement.
    """
    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        target = None
        model = data.get("model", {})
        enc = (model.get("image_encoder") or {})
        trunk = (enc.get("trunk") or {})
        # Hydra-style configs may use either "_target_" or "target".
        target = trunk.get("_target_") or trunk.get("target")
        if isinstance(target, str) and "hieradet" in target:
            # Scan the whole repo for a config using the plain 'hiera' trunk.
            for y in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(y, "r") as f2:
                        d2 = yaml.safe_load(f2)
                    m2 = (d2 or {}).get("model", {})
                    e2 = (m2.get("image_encoder") or {})
                    t2 = (e2.get("trunk") or {})
                    tgt2 = t2.get("_target_") or t2.get("target")
                    if isinstance(tgt2, str) and ".hiera." in tgt2:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                        return str(y)
                except Exception:
                    # Unparseable candidate — keep scanning.
                    continue
    except Exception:
        pass
    return None
272
+
273
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns (predictor_or_None, ok, meta). *meta* records which stage
    succeeded (import vs. init) plus the device/config/checkpoint used.
    Configuration comes from SAM2_DEVICE, SAM2_MODEL_CFG and SAM2_CHECKPOINT
    environment variables. Never raises — failures log and return (None,
    False, meta).
    """
    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2  # type: ignore
        from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        # Different SAM2 forks expose different build_sam2 signatures; sniff
        # the parameter names and pass whatever this fork accepts.
        params = set(inspect.signature(build_sam2).parameters.keys())
        kwargs = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
        if ckpt:
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            return build_sam2(**kwargs)
        except TypeError:
            # Keyword call rejected — retry positionally (cfg[, ckpt][, device]).
            pos = [cfg_path]
            if ckpt:
                pos.append(ckpt)
            if "device" not in kwargs:
                pos.append(device)
            return build_sam2(*pos)

    try:
        try:
            sam = _try_build(cfg)
        except Exception as e1:
            # Some checkpoints need the 'hiera' (not 'hieradet') trunk config.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
                sam = _try_build(alt_cfg)
                cfg = alt_cfg
            else:
                raise

        predictor = SAM2ImagePredictor(sam)
        meta.update({
            "sam2_init_ok": True,
            "sam2_device": device,
            "sam2_cfg": cfg,
            "sam2_ckpt": ckpt or "(repo default)"
        })
        return predictor, True, meta
    except Exception as e:
        logger.error(f"SAM2 init failed: {e}")
        return None, False, meta
338
+
339
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok).

    Prompting strategy: *auto* → a near-full-frame box (5% inset); an
    explicit *point* → a single positive point prompt; otherwise a default
    box with a 10% inset. Only the first returned mask is used. Failures
    log a warning and return (None, False).
    """
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)

        if auto:
            h, w = rgb.shape[:2]
            box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
            masks, _, _ = predictor.predict(box=box)
        elif point is not None:
            x, y = int(point[0]), int(point[1])
            pts = np.array([[x, y]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)  # 1 = positive (foreground) click
            masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
        else:
            h, w = rgb.shape[:2]
            box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
            masks, _, _ = predictor.predict(box=box)

        if masks is None or len(masks) == 0:
            return None, False

        # NOTE(review): assumes masks[0] is boolean/0-1 valued so the uint8
        # cast scales cleanly to 0/255 — confirm against the SAM2 fork in use.
        m = masks[0].astype(np.uint8) * 255
        return m, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
372
+
373
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    Builds a trimap from the binarized seed: an eroded core becomes
    definite-foreground, an eroded *inverse* becomes definite-background,
    and everything else stays probable-background for GrabCut to resolve.
    Knobs default from REFINE_GRABCUT_ITERS / REFINE_TRIMAP_ERODE /
    REFINE_TRIMAP_DILATE. On failure the binarized input is returned.
    """
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255  # binarize the seed

    # Core foreground: erode inward by e px. Certain background: erode the
    # inverted mask by d px (i.e. keep only pixels well outside the subject).
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    # GrabCut's GMM model buffers (fixed 1x65 shape required by OpenCV).
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)  # knock out speckles
        return out
    except Exception as e:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}")
        return m
403
+
404
+ # --------------------------------------------------------------------------------------
405
+ # MatAnyone Integration
406
+ # --------------------------------------------------------------------------------------
407
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """MatAnyone loader with disable switch and error handling.

    Returns (core_or_None, ok, meta). Respects ENABLE_MATANY (0/false/off/no
    disables), MATANY_DEVICE, MATANY_REPO_ID and MATANY_CHECKPOINT. Tries two
    import locations and several constructor signatures, since MatAnyone
    forks differ. Never raises.
    """
    meta = {"matany_import_ok": False, "matany_init_ok": False}

    enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
    if enable_env in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    try:
        # Some forks expose InferenceCore at top level, others nest it.
        try:
            from inference_core import InferenceCore  # type: ignore
        except Exception:
            from matanyone.inference.inference_core import InferenceCore  # type: ignore
        meta["matany_import_ok"] = True
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta

    device = _pick_device("MATANY_DEVICE")
    repo_id = os.environ.get("MATANY_REPO_ID", "")
    ckpt = os.environ.get("MATANY_CHECKPOINT", "")

    # Check if this fork needs a prebuilt network: a required `network`
    # parameter means we cannot construct it from a checkpoint path here.
    try:
        sig = inspect.signature(InferenceCore)
        if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
            logger.error(
                "This MatAnyone fork expects `InferenceCore(network=...)`. "
                "Pin a fork/commit that supplies a checkpoint-based constructor, "
                "or set ENABLE_MATANY=0 to skip."
            )
            meta["needs_network_arg"] = True
            return None, False, meta
    except Exception:
        pass

    # Constructor candidates, richest first; first one that works wins.
    candidates = [
        {"kwargs": {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device}},
        {"kwargs": {"checkpoint": ckpt or None, "device": device}},
        {"args": (), "kwargs": {"device": device}},
    ]
    last_err = None
    for cand in candidates:
        try:
            matany = InferenceCore(*cand.get("args", ()), **cand.get("kwargs", {}))
            meta["matany_init_ok"] = True
            meta["matany_device"] = device
            meta["matany_repo_id"] = repo_id or "(unset)"
            meta["matany_checkpoint"] = ckpt or "(unset)"
            return matany, True, meta
        except Exception as e:
            last_err = e
            continue

    logger.error(f"MatAnyone init failed with all fallbacks: {last_err}")
    return None, False, meta
465
+
466
def run_matany(matany: object,
               video_path: Union[str, Path],
               first_mask_path: Union[str, Path],
               work_dir: Union[str, Path]) -> Tuple[Optional[str], Optional[str], bool]:
    """Return (foreground_video_path, alpha_video_path, ok).

    Adapts to two fork APIs: `process_video(...)` (preferred) and `run(...)`.
    Each may return a (fg, alpha) tuple/list or a dict with foreground/alpha
    keys; anything else is treated as failure. Never raises.
    """
    if matany is None:
        return None, None, False
    try:
        if hasattr(matany, "process_video"):
            out = matany.process_video(input_path=str(video_path), mask_path=str(first_mask_path), output_dir=str(work_dir))
            if isinstance(out, (list, tuple)) and len(out) >= 2:
                return str(out[0]), str(out[1]), True
            if isinstance(out, dict):
                # Key names vary across forks — accept the known aliases.
                fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
                al = out.get("alpha") or out.get("alpha_path")
                if fg and al:
                    return str(fg), str(al), True

        if hasattr(matany, "run"):
            out = matany.run(video_path=str(video_path), seed_mask=str(first_mask_path), out_dir=str(work_dir))
            if isinstance(out, dict):
                fg = out.get("foreground") or out.get("fg") or out.get("foreground_path")
                al = out.get("alpha") or out.get("alpha_path")
                if fg and al:
                    return str(fg), str(al), True

        logger.error("MatAnyone returned no usable paths.")
        return None, None, False
    except Exception as e:
        logger.warning(f"MatAnyone processing failed: {e}")
        return None, None, False
497
+
498
+ # --------------------------------------------------------------------------------------
499
+ # Fallback Functions
500
+ # --------------------------------------------------------------------------------------
501
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255.

    MediaPipe selfie segmentation (model_selection=1) is tried first when
    available; otherwise rectangle-initialized GrabCut over the central 80%
    of the frame. On total failure an all-zero mask is returned.
    """
    h, w = first_frame_bgr.shape[:2]
    if _HAS_MEDIAPIPE:
        try:
            mp_selfie = mp.solutions.selfie_segmentation
            with mp_selfie.SelfieSegmentation(model_selection=1) as segmenter:
                rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
                res = segmenter.process(rgb)
                # Threshold the soft probability map at 0.5 and despeckle.
                m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
                m = cv2.medianBlur(m, 5)
                return m
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut
    mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1*w), int(0.1*h), int(0.8*w), int(0.8*h))  # x, y, width, height
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
        mask_bin = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        return mask_bin
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
528
+
529
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor.

    Reads the foreground and alpha videos in lockstep, composites each frame
    over the (letterboxed) background, and writes *out_path*. When ffmpeg is
    available, frames go to a temporary mp4v file that is re-encoded to
    H.264 afterwards; otherwise the mp4v output is final. Returns True iff
    at least one frame was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        bg = np.full((h, w, 3), 127, dtype=np.uint8)  # neutral gray fallback
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        post_h264 = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        post_h264 = False

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not ok_fg or not ok_al:
                break  # stop at the shorter of the two streams
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)

            # Compositor works in RGB; VideoWriter needs BGR back.
            comp = _composite_frame_pro(
                cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
                al_gray,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: promote the mp4v temp file to the output path.
            logger.warning(f"ffmpeg finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
593
+
594
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor.

    Applies one fixed mask (PNG at *mask_path*) to every frame of
    *video_path*, compositing over *bg_image_path*. Output size/fps come
    from the source video. Uses the same mp4v-then-H.264 finalize strategy
    as `composite_video`. Returns True iff at least one frame was written.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        bg = np.full((h, w, 3), 127, dtype=np.uint8)  # neutral gray fallback

    # Nearest-neighbor keeps the mask binary (no gray interpolation fringe).
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        use_post_ffmpeg = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        use_post_ffmpeg = False

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                mask_resized,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Finalize failed: fall back to the raw mp4v render.
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
656
+
657
+ # --------------------------------------------------------------------------------------
658
+ # Stage-A (Transparent Export) Functions
659
+ # --------------------------------------------------------------------------------------
660
+ def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
661
+ """RGB checkerboard for preview when no real alpha is possible."""
662
+ y, x = np.mgrid[0:h, 0:w]
663
+ c = ((x // tile) + (y // tile)) % 2
664
+ a = np.where(c == 0, 200, 150).astype(np.uint8)
665
+ return np.stack([a, a, a], axis=-1)
666
+
667
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG+ALPHA → RGBA WebM (VP9 with alpha).

    Builds one ffmpeg invocation: the alpha stream is converted to gray and
    merged onto the foreground via `alphamerge`, then encoded as VP9 with
    yuva420p (alpha-capable). Optional *src_audio* is muxed as Opus. CRF is
    tunable via STAGEA_VP9_CRF. Requires ffmpeg; returns False otherwise or
    on any encode failure.
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        # [1:v] alpha → gray; [0:v] foreground; merge into one RGBA stream.
        fcx = f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];" \
              f"[0:v]scale={w}:{h},fps={fps}[fg];" \
              f"[fg][al]alphamerge[outv]"
        cmd += ["-filter_complex", fcx, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
699
+
700
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Merge original video + static mask → RGBA WebM (VP9 with alpha).

    Same encode pipeline as `_build_stage_a_rgba_vp9_from_fg_alpha`, but the
    alpha source is a single PNG looped for the full duration (`-loop 1` +
    `-shortest`). Requires ffmpeg; returns False otherwise or on failure.
    """
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            "-loop", "1", "-i", str(mask_png),   # loop the still mask as a stream
            "-filter_complex",
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]",
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
730
+
731
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: FG+ALPHA over checkerboard → MP4 (no real alpha).

    Pure-OpenCV fallback when an alpha-capable container cannot be produced:
    each frame is composited over a static checkerboard so transparency is
    visually apparent. Returns True iff at least one frame was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    bg = _checkerboard_bg(w, h)
    ok_any = False
    try:
        while True:
            okf, fg = fg_cap.read()
            oka, al = al_cap.read()
            if not okf or not oka:
                break  # stop at the shorter stream
            fg = cv2.resize(fg, (w, h))
            al = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al, bg)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()
    return ok_any
763
+
764
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: original video + static mask over checkerboard → MP4.

    Static-mask variant of the checkerboard preview: one PNG mask applied to
    every frame. Returns False when the video or mask cannot be read; True
    iff at least one frame was written.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return False
    # Nearest-neighbor keeps mask edges hard (no interpolation fringe).
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    bg = _checkerboard_bg(w, h)
    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, bg)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()
    return ok_any