MogensR committed on
Commit
bcb443b
·
1 Parent(s): 4b817e6
app.py CHANGED
@@ -1,7 +1,17 @@
1
  #!/usr/bin/env python3
2
  """
3
  VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
 
 
 
 
 
 
 
 
 
4
  """
 
5
  print("=== APP STARTUP DEBUG: app.py starting ===")
6
  import sys
7
  print(f"=== APP STARTUP DEBUG: Python {sys.version} ===")
@@ -12,7 +22,6 @@
12
  import os
13
  os.environ.pop("OMP_NUM_THREADS", None)
14
 
15
- import sys
16
  import logging
17
  import threading
18
  import time
@@ -23,7 +32,7 @@
23
  # Suppress torchvision video deprecation warnings from MatAnyone
24
  warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
25
 
26
- # --- import path fix for third_party packages (MatAnyone, SAM2) ---
27
  third_party_root = Path(__file__).parent / "third_party"
28
  sam2_path = third_party_root / "sam2"
29
 
@@ -43,7 +52,7 @@
43
 
44
  # DEBUG: try importing MatAnyone and show its location
45
  try:
46
- import matanyone # noqa: F401
47
  import inspect
48
  import os as _os
49
  print("[MATANY] import OK from:", _os.path.dirname(inspect.getfile(matanyone)), flush=True)
@@ -78,9 +87,9 @@ def _heartbeat():
78
  # Safe, minimal startup diagnostics (no long CUDA probes)
79
  # -----------------------------------------------------------------------------
80
  def _safe_startup_diag():
81
- # Torch version only; defer CUDA availability checks to post-launch
82
  try:
83
- import torch # noqa: F401
84
  import importlib
85
  t = importlib.import_module("torch")
86
  logger.info(
@@ -91,6 +100,14 @@ def _safe_startup_diag():
91
  except Exception as e:
92
  logger.warning("Torch not available at startup: %s", e)
93
 
 
 
 
 
 
 
 
 
94
  # nvidia-smi with short timeout (avoid indefinite block)
95
  try:
96
  out = subprocess.run(
@@ -107,7 +124,7 @@ def _safe_startup_diag():
107
 
108
  # Optional perf tuning; never block startup
109
  try:
110
- import perf_tuning # noqa: F401
111
  logger.info("perf_tuning imported successfully.")
112
  except Exception as e:
113
  logger.info("perf_tuning not available: %s", e)
@@ -126,7 +143,6 @@ def _safe_startup_diag():
126
  logger.info(f"[MATANY] probe skipped: {e}")
127
 
128
  # Continue with app startup
129
-
130
  _safe_startup_diag()
131
 
132
  # -----------------------------------------------------------------------------
@@ -166,7 +182,7 @@ def build_ui() -> gr.Blocks:
166
  logger.info("Launching Gradio on %s:%s …", host, port)
167
 
168
  demo = build_ui()
169
- demo.queue(max_size=16)
170
 
171
  threading.Thread(target=_post_launch_diag, daemon=True).start()
172
- demo.launch(server_name=host, server_port=port, show_error=True)
 
1
  #!/usr/bin/env python3
2
  """
3
  VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
+ ================================================
5
+ - Sets up Gradio UI and launches pipeline
6
+ - Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
7
+
8
+ Changes (2025-09-16):
9
+ - Aligned with updated pipeline.py and models/
10
+ - Added MatAnyone version logging in startup diagnostics
11
+ - Updated Gradio launch for compatibility with gradio==5.42.0
12
+ - Ensured sys.path and environment variables match Dockerfile
13
  """
14
+
15
  print("=== APP STARTUP DEBUG: app.py starting ===")
16
  import sys
17
  print(f"=== APP STARTUP DEBUG: Python {sys.version} ===")
 
22
  import os
23
  os.environ.pop("OMP_NUM_THREADS", None)
24
 
 
25
  import logging
26
  import threading
27
  import time
 
32
  # Suppress torchvision video deprecation warnings from MatAnyone
33
  warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
34
 
35
+ # --- import path fix for third_party packages (SAM2) ---
36
  third_party_root = Path(__file__).parent / "third_party"
37
  sam2_path = third_party_root / "sam2"
38
 
 
52
 
53
  # DEBUG: try importing MatAnyone and show its location
54
  try:
55
+ import matanyone
56
  import inspect
57
  import os as _os
58
  print("[MATANY] import OK from:", _os.path.dirname(inspect.getfile(matanyone)), flush=True)
 
87
  # Safe, minimal startup diagnostics (no long CUDA probes)
88
  # -----------------------------------------------------------------------------
89
  def _safe_startup_diag():
90
+ # Torch version
91
  try:
92
+ import torch
93
  import importlib
94
  t = importlib.import_module("torch")
95
  logger.info(
 
100
  except Exception as e:
101
  logger.warning("Torch not available at startup: %s", e)
102
 
103
+ # MatAnyone version
104
+ try:
105
+ import importlib.metadata
106
+ version = importlib.metadata.version("matanyone")
107
+ logger.info(f"[MATANY] MatAnyone version: {version}")
108
+ except Exception:
109
+ logger.info("[MATANY] MatAnyone version unknown")
110
+
111
  # nvidia-smi with short timeout (avoid indefinite block)
112
  try:
113
  out = subprocess.run(
 
124
 
125
  # Optional perf tuning; never block startup
126
  try:
127
+ import perf_tuning
128
  logger.info("perf_tuning imported successfully.")
129
  except Exception as e:
130
  logger.info("perf_tuning not available: %s", e)
 
143
  logger.info(f"[MATANY] probe skipped: {e}")
144
 
145
  # Continue with app startup
 
146
  _safe_startup_diag()
147
 
148
  # -----------------------------------------------------------------------------
 
182
  logger.info("Launching Gradio on %s:%s …", host, port)
183
 
184
  demo = build_ui()
185
+ demo.queue(max_size=16, api_open=False) # Disable public API for security
186
 
187
  threading.Thread(target=_post_launch_diag, daemon=True).start()
188
+ demo.launch(server_name=host, server_port=port, show_error=True)
models/__init__.py CHANGED
@@ -8,8 +8,10 @@
8
  - MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
9
 
10
  Changes (2025-09-16):
 
11
  - Updated load_matany to apply T=1 squeeze patch before InferenceCore import
12
- - Added patch status logging in load_matany
 
13
  - Fixed InferenceCore import path to matanyone.inference.inference_core
14
  """
15
 
@@ -21,6 +23,7 @@
21
  import subprocess
22
  import inspect
23
  import logging
 
24
  from pathlib import Path
25
  from typing import Optional, Tuple, Dict, Any, Union, Callable
26
 
@@ -261,7 +264,7 @@ def _composite_frame_pro(
261
  ) -> np.ndarray:
262
  erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
263
  dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
264
- blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
265
  lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
266
  lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
267
  despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
@@ -528,6 +531,13 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
528
  try:
529
  from matanyone.inference.inference_core import InferenceCore # type: ignore
530
  meta["matany_import_ok"] = True
 
 
 
 
 
 
 
531
  device = _pick_device("MATANY_DEVICE")
532
  repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
533
  meta["matany_repo_id"] = repo_id
 
8
  - MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
9
 
10
  Changes (2025-09-16):
11
+ - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
12
  - Updated load_matany to apply T=1 squeeze patch before InferenceCore import
13
+ - Added patch status logging and MatAnyone version
14
+ - Added InferenceCore attributes logging for debugging
15
  - Fixed InferenceCore import path to matanyone.inference.inference_core
16
  """
17
 
 
23
  import subprocess
24
  import inspect
25
  import logging
26
+ import importlib.metadata
27
  from pathlib import Path
28
  from typing import Optional, Tuple, Dict, Any, Union, Callable
29
 
 
264
  ) -> np.ndarray:
265
  erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
266
  dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
267
+ blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
268
  lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
269
  lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
270
  despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
 
531
  try:
532
  from matanyone.inference.inference_core import InferenceCore # type: ignore
533
  meta["matany_import_ok"] = True
534
+ # Log MatAnyone version and InferenceCore attributes
535
+ try:
536
+ version = importlib.metadata.version("matanyone")
537
+ logger.info(f"[MATANY] MatAnyone version: {version}")
538
+ except Exception:
539
+ logger.info("[MATANY] MatAnyone version unknown")
540
+ logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")
541
  device = _pick_device("MATANY_DEVICE")
542
  repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
543
  meta["matany_repo_id"] = repo_id
models/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/models/__pycache__/__init__.cpython-313.pyc and b/models/__pycache__/__init__.cpython-313.pyc differ
 
models/matany_compat_patch.py CHANGED
@@ -1,11 +1,11 @@
1
  #!/usr/bin/env python3
2
  # MatAnyone HF-compat patch: squeeze time dim T=1 before first Conv2d
3
  # Changes (2025-09-16):
4
- # - Added fallback patching for forward/encode if encode_img missing
 
5
  # - Log dir(MatAnyone) and module version for debugging
6
  # - Added isinstance(img, torch.Tensor) for non-tensor safety
7
- # - Enhanced logging with input/output shapes
8
- # - Kept monkey-patch for HF Spaces compatibility
9
 
10
  import logging
11
  import torch
@@ -15,9 +15,9 @@
15
 
16
  def apply_matany_t1_squeeze_guard() -> bool:
17
  """
18
- Monkey-patch MatAnyone.encode_img (or forward/encode) to squeeze [B,1,C,H,W] → [B,C,H,W].
19
  Safe for multi-frame (T>1) as it only squeezes when T==1.
20
- Returns True if patch applied successfully, False otherwise.
21
  """
22
  try:
23
  import matanyone.model.matanyone as M
@@ -29,7 +29,7 @@ def apply_matany_t1_squeeze_guard() -> bool:
29
  return False
30
  MatAnyone = M.MatAnyone
31
 
32
- # Log MatAnyone version and attributes for debugging
33
  try:
34
  version = importlib.metadata.version("matanyone")
35
  log.info(f"[MatAnyCompat] MatAnyone version: {version}")
@@ -37,33 +37,33 @@ def apply_matany_t1_squeeze_guard() -> bool:
37
  log.info("[MatAnyCompat] MatAnyone version unknown")
38
  log.debug(f"[MatAnyCompat] MatAnyone attributes: {dir(MatAnyone)}")
39
 
40
- # Try encode_img first, then fallback to forward or encode
41
- method_name = None
42
- for candidate in ["encode_img", "forward", "encode"]:
43
- if hasattr(MatAnyone, candidate):
44
- method_name = candidate
45
- break
46
- if not method_name:
47
- log.warning("[MatAnyCompat] No patchable method (encode_img, forward, encode) found on MatAnyone")
48
- return False
49
- if getattr(MatAnyone, f"_{method_name}_patched", False):
50
- log.info(f"[MatAnyCompat] {method_name} already patched")
51
- return True
52
 
53
- # Store original method
54
- orig_method = getattr(MatAnyone, method_name)
 
 
 
 
 
 
55
 
56
- def method_compat(self, img, *args, **kwargs):
57
- # Handle inputs that MatAnyone.step turned into [B,1,C,H,W]
58
- try:
59
- if isinstance(img, torch.Tensor) and img.dim() == 5 and img.shape[1] == 1:
60
- log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {method_name}")
61
- img = img.squeeze(1) # [B,1,C,H,W] → [B,C,H,W]
62
- except Exception as e:
63
- log.warning(f"[MatAnyCompat] Failed to process input shape in {method_name}: %s", e)
64
- return orig_method(self, img, *args, **kwargs)
65
 
66
- setattr(MatAnyone, method_name, method_compat)
67
- setattr(MatAnyone, f"_{method_name}_patched", True)
68
- log.info(f"[MatAnyCompat] Applied T=1 squeeze guard in MatAnyone.{method_name}")
69
  return True
 
1
  #!/usr/bin/env python3
2
  # MatAnyone HF-compat patch: squeeze time dim T=1 before first Conv2d
3
  # Changes (2025-09-16):
4
+ # - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
5
+ # - Patch forward, encode, encode_img to cover all code paths
6
  # - Log dir(MatAnyone) and module version for debugging
7
  # - Added isinstance(img, torch.Tensor) for non-tensor safety
8
+ # - Log input/output shapes for verification
 
9
 
10
  import logging
11
  import torch
 
15
 
16
  def apply_matany_t1_squeeze_guard() -> bool:
17
  """
18
+ Monkey-patch MatAnyone.forward/encode/encode_img to squeeze [B,1,C,H,W] → [B,C,H,W].
19
  Safe for multi-frame (T>1) as it only squeezes when T==1.
20
+ Returns True if at least one method patched, False otherwise.
21
  """
22
  try:
23
  import matanyone.model.matanyone as M
 
29
  return False
30
  MatAnyone = M.MatAnyone
31
 
32
+ # Log MatAnyone version and attributes
33
  try:
34
  version = importlib.metadata.version("matanyone")
35
  log.info(f"[MatAnyCompat] MatAnyone version: {version}")
 
37
  log.info("[MatAnyCompat] MatAnyone version unknown")
38
  log.debug(f"[MatAnyCompat] MatAnyone attributes: {dir(MatAnyone)}")
39
 
40
+ # Try patching forward, encode, encode_img
41
+ patched = False
42
+ for method_name in ["forward", "encode", "encode_img"]:
43
+ if not hasattr(MatAnyone, method_name):
44
+ continue
45
+ if getattr(MatAnyone, f"_{method_name}_patched", False):
46
+ log.info(f"[MatAnyCompat] {method_name} already patched")
47
+ continue
48
+
49
+ # Store original method
50
+ orig_method = getattr(MatAnyone, method_name)
 
51
 
52
+ def method_compat(self, img, *args, _orig=orig_method, _name=method_name, **kwargs):
53
+ # NOTE: _orig/_name are bound as defaults to avoid the late-binding closure
53
+ # bug — without this, every patched method would call the LAST loop
53
+ # iteration's original method.
53
+ try:
54
+ if isinstance(img, torch.Tensor) and img.dim() == 5 and img.shape[1] == 1:
55
+ log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {_name}")
56
+ img = img.squeeze(1) # [B,1,C,H,W] → [B,C,H,W]
57
+ except Exception as e:
58
+ log.warning(f"[MatAnyCompat] Failed to process input shape in {_name}: %s", e)
59
+ return _orig(self, img, *args, **kwargs)
60
 
61
+ setattr(MatAnyone, method_name, method_compat)
62
+ setattr(MatAnyone, f"_{method_name}_patched", True)
63
+ log.info(f"[MatAnyCompat] Applied T=1 squeeze guard in MatAnyone.{method_name}")
64
+ patched = True
 
 
 
 
 
65
 
66
+ if not patched:
67
+ log.warning("[MatAnyCompat] No patchable methods (forward, encode, encode_img) found on MatAnyone")
68
+ return False
69
  return True
models/matanyone_loader.py CHANGED
@@ -5,16 +5,16 @@
5
 
6
  - SAM2 defines the subject (seed mask) on frame 0.
7
  - MatAnyone does frame-by-frame alpha matting.
8
- - Uses T=1 squeeze patch for conv2d compatibility.
9
- - Falls back to process_frame([H,W,3]) if step() is unavailable.
10
 
11
  Changes (2025-09-16):
12
- - Added comprehensive error handling for MatAnyone import and initialization
13
- - Enhanced VRAM management with auto-cleanup
14
- - Added support for multiple MatAnyone method patching (encode_img/forward/encode)
15
- - Improved logging with timestamps and memory usage
16
- - Added environment variable controls for debugging
17
- - Fixed potential memory leaks in tensor handling
18
  """
19
 
20
  from __future__ import annotations
@@ -24,6 +24,7 @@
24
  import logging
25
  import numpy as np
26
  import torch
 
27
  from pathlib import Path
28
  from typing import Optional, Callable, Tuple
29
 
@@ -71,39 +72,22 @@ def _cuda_snapshot(device: Optional[torch.device]) -> str:
71
  idx = device.index
72
  name = torch.cuda.get_device_name(idx)
73
  alloc = torch.cuda.memory_allocated(idx) / (1024**3)
74
- resv = torch.cuda.memory_reserved(idx) / (1024**3)
75
  return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
76
  except Exception as e:
77
  return f"CUDA snapshot error: {e!r}"
78
 
79
  def _safe_empty_cache():
80
- """Safely clear PyTorch cache with detailed memory reporting."""
 
81
  try:
82
- if not torch.cuda.is_available():
83
- return
84
-
85
- # Log memory stats before cleanup
86
- if _env_flag("MATANY_LOG_VRAM"):
87
- log.info("[MATANY] VRAM before cleanup:")
88
- log.info(f" Allocated: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
89
- log.info(f" Reserved: {torch.cuda.memory_reserved()/1024**2:.1f} MB")
90
-
91
- # Clear cache and sync
92
  torch.cuda.empty_cache()
93
- torch.cuda.synchronize()
94
-
95
- # Log memory stats after cleanup
96
- if _env_flag("MATANY_LOG_VRAM"):
97
- log.info("[MATANY] VRAM after cleanup:")
98
- log.info(f" Allocated: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
99
- log.info(f" Reserved: {torch.cuda.memory_reserved()/1024**2:.1f} MB")
100
-
101
- except Exception as e:
102
- log.warning(f"[MATANY] Error in cache cleanup: {e}", exc_info=True)
103
- try:
104
- torch.cuda.empty_cache()
105
- except Exception as e2:
106
- log.warning(f"[MATANY] Secondary cache cleanup failed: {e2}")
107
 
108
  # ---------- SAM2 → seed mask prep ----------
109
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
@@ -149,7 +133,7 @@ class MatAnyoneSession:
149
  """
150
  Streaming wrapper that seeds MatAnyone on frame 0.
151
  Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
152
- Falls back to process_frame([H,W,3]) if supported by the wheel.
153
  """
154
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
155
  from .matany_compat_patch import apply_matany_t1_squeeze_guard
@@ -157,18 +141,25 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
157
  self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
158
  self.precision = precision.lower()
159
 
160
- # Apply T=1 squeeze patch for conv2d fix
161
  if apply_matany_t1_squeeze_guard():
162
- log.info("[MATANY] T=1 squeeze patch applied for MatAnyone.encode_img")
163
  else:
164
  log.warning("[MATANY] T=1 squeeze patch failed; conv2d errors may occur")
165
 
 
 
 
 
 
 
 
166
  # API/format overrides for debugging
167
  api_force = os.getenv("MATANY_FORCE_API", "").strip().lower() # "process" or "step"
168
  fmt_force = os.getenv("MATANY_FORCE_FORMAT", "4d").strip().lower() # "4d" or "5d"
169
  self._force_api_process = (api_force == "process")
170
  self._force_api_step = (api_force == "step")
171
- self._force_4d = (fmt_force == "4d") or not fmt_force # Default to 4D post-patch
172
  self._force_5d = (fmt_force == "5d")
173
 
174
  try:
@@ -255,6 +246,7 @@ def _call_step(self, rgb_hwc: np.ndarray, seed_mask_hw: Optional[np.ndarray], is
255
  def run(use_5d: bool):
256
  img = img_5d if use_5d else img_4d
257
  msk = mask_5d if use_5d else mask_4d
 
258
  if is_first and msk is not None:
259
  try:
260
  return self.core.step(img, msk, is_first=True)
@@ -351,10 +343,10 @@ def process_stream(
351
  cap_probe = cv2.VideoCapture(str(video_path))
352
  if not cap_probe.isOpened():
353
  raise MatAnyError(f"Failed to open video: {video_path}")
354
- N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
355
  fps = cap_probe.get(cv2.CAP_PROP_FPS)
356
- W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
357
- H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
358
  cap_probe.release()
359
  if not fps or fps <= 0 or np.isnan(fps):
360
  fps = 25.0
@@ -364,10 +356,10 @@ def process_stream(
364
  _emit_progress(progress_cb, 0.08, "Using per-frame processing")
365
 
366
  alpha_path = out_dir / "alpha.mp4"
367
- fg_path = out_dir / "fg.mp4"
368
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
369
  alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True)
370
- fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
371
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
372
  raise MatAnyError("Failed to initialize VideoWriter(s)")
373
 
@@ -396,9 +388,9 @@ def process_stream(
396
  is_first = (idx == 0)
397
  alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first)
398
 
399
- alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
400
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
401
- fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
402
 
403
  alpha_writer.write(alpha_bgr)
404
  fg_writer.write(fg_bgr)
 
5
 
6
  - SAM2 defines the subject (seed mask) on frame 0.
7
  - MatAnyone does frame-by-frame alpha matting.
8
+ - Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
9
+ - Falls back to process_frame([H,W,3]) if supported.
10
 
11
  Changes (2025-09-16):
12
+ - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
13
+ - Added shape logging in _call_step to verify 5D-to-4D squeeze
14
+ - Set MATANY_FORCE_FORMAT=4d as default
15
+ - Added VRAM logging in process_stream (MATANY_LOG_VRAM=1)
16
+ - Enhanced _safe_empty_cache with memory_summary
17
+ - Added MatAnyone version logging
18
  """
19
 
20
  from __future__ import annotations
 
24
  import logging
25
  import numpy as np
26
  import torch
27
+ import importlib.metadata
28
  from pathlib import Path
29
  from typing import Optional, Callable, Tuple
30
 
 
72
  idx = device.index
73
  name = torch.cuda.get_device_name(idx)
74
  alloc = torch.cuda.memory_allocated(idx) / (1024**3)
75
+ resv = torch.cuda.memory_reserved(idx) / (1024**3)
76
  return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
77
  except Exception as e:
78
  return f"CUDA snapshot error: {e!r}"
79
 
80
  def _safe_empty_cache():
81
+ if not torch.cuda.is_available():
82
+ return
83
  try:
84
+ log.info(f"[MATANY] CUDA memory before empty_cache: {_cuda_snapshot(None)}")
 
 
 
 
 
 
 
 
 
85
  torch.cuda.empty_cache()
86
+ log.info(f"[MATANY] CUDA memory after empty_cache: {_cuda_snapshot(None)}")
87
+ if os.getenv("MATANY_LOG_VRAM", "0") == "1":
88
+ log.debug(f"[MATANY] VRAM summary:\n{torch.cuda.memory_summary()}")
89
+ except Exception:
90
+ pass
 
 
 
 
 
 
 
 
 
91
 
92
  # ---------- SAM2 → seed mask prep ----------
93
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
 
133
  """
134
  Streaming wrapper that seeds MatAnyone on frame 0.
135
  Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
136
+ Falls back to process_frame([H,W,3]) if supported.
137
  """
138
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
139
  from .matany_compat_patch import apply_matany_t1_squeeze_guard
 
141
  self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
142
  self.precision = precision.lower()
143
 
144
+ # Apply T=1 squeeze patch
145
  if apply_matany_t1_squeeze_guard():
146
+ log.info("[MATANY] T=1 squeeze patch applied for MatAnyone")
147
  else:
148
  log.warning("[MATANY] T=1 squeeze patch failed; conv2d errors may occur")
149
 
150
+ # Log MatAnyone version
151
+ try:
152
+ version = importlib.metadata.version("matanyone")
153
+ log.info(f"[MATANY] MatAnyone version: {version}")
154
+ except Exception:
155
+ log.info("[MATANY] MatAnyone version unknown")
156
+
157
  # API/format overrides for debugging
158
  api_force = os.getenv("MATANY_FORCE_API", "").strip().lower() # "process" or "step"
159
  fmt_force = os.getenv("MATANY_FORCE_FORMAT", "4d").strip().lower() # "4d" or "5d"
160
  self._force_api_process = (api_force == "process")
161
  self._force_api_step = (api_force == "step")
162
+ self._force_4d = (fmt_force == "4d") or not fmt_force # Default to 4D
163
  self._force_5d = (fmt_force == "5d")
164
 
165
  try:
 
246
  def run(use_5d: bool):
247
  img = img_5d if use_5d else img_4d
248
  msk = mask_5d if use_5d else mask_4d
249
+ log.debug(f"[MATANY] Step input: img={img.shape}, mask={msk.shape if msk is not None else None}, is_first={is_first}")
250
  if is_first and msk is not None:
251
  try:
252
  return self.core.step(img, msk, is_first=True)
 
343
  cap_probe = cv2.VideoCapture(str(video_path))
344
  if not cap_probe.isOpened():
345
  raise MatAnyError(f"Failed to open video: {video_path}")
346
+ N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
347
  fps = cap_probe.get(cv2.CAP_PROP_FPS)
348
+ W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
349
+ H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
350
  cap_probe.release()
351
  if not fps or fps <= 0 or np.isnan(fps):
352
  fps = 25.0
 
356
  _emit_progress(progress_cb, 0.08, "Using per-frame processing")
357
 
358
  alpha_path = out_dir / "alpha.mp4"
359
+ fg_path = out_dir / "fg.mp4"
360
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
361
  alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True)
362
+ fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
363
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
364
  raise MatAnyError("Failed to initialize VideoWriter(s)")
365
 
 
388
  is_first = (idx == 0)
389
  alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first)
390
 
391
+ alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
392
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
393
+ fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
394
 
395
  alpha_writer.write(alpha_bgr)
396
  fg_writer.write(fg_bgr)
models/sam2_loader.py CHANGED
@@ -1,195 +1,279 @@
1
- # models/sam2_loader.py
2
- import os, logging, torch
3
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pathlib import Path
 
 
5
  import numpy as np
 
 
6
 
7
- log = logging.getLogger("sam2_loader")
 
 
 
 
 
 
 
 
8
 
9
- DEFAULT_MODEL_ID = os.environ.get("SAM2_MODEL_ID", "facebook/sam2")
10
- DEFAULT_VARIANT = os.environ.get("SAM2_VARIANT", "sam2_hiera_large")
 
 
 
11
 
12
- # Map variant -> filenames (SAM2 releases follow this pattern)
13
- VARIANT_FILES = {
14
- "sam2_hiera_small": ("sam2_hiera_small.pt", "configs/sam2/sam2_hiera_s.yaml"),
15
- "sam2_hiera_base": ("sam2_hiera_base.pt", "configs/sam2/sam2_hiera_b.yaml"),
16
- "sam2_hiera_large": ("sam2_hiera_large.pt", "configs/sam2/sam2_hiera_l.yaml"),
17
- }
 
18
 
19
- def _download_checkpoint(model_id: str, ckpt_name: str) -> str:
20
- return hf_hub_download(repo_id=model_id, filename=ckpt_name, local_dir=os.environ.get("HF_HOME"))
21
 
22
- def _find_sam2_build():
 
 
 
23
  try:
24
- from sam2.build_sam import build_sam2
25
- return build_sam2
26
  except Exception as e:
27
- log.error("SAM2 not importable (check Dockerfile vendoring): %s", e)
28
  return None
29
 
30
- class SAM2Predictor:
31
- def __init__(self, device: torch.device):
32
- self.device = device
33
- self.model = None
34
- self.predictor = None
35
-
36
- def load(self, variant: str = DEFAULT_VARIANT, model_id: str = DEFAULT_MODEL_ID):
37
- log.info(f"SAM2Predictor.load() called with variant={variant}")
38
- build_sam2 = _find_sam2_build()
39
- if build_sam2 is None:
40
- log.error("SAM2 build function not available - raising RuntimeError")
41
- raise RuntimeError("SAM2 build function not available")
42
-
43
- ckpt_name, cfg_path = VARIANT_FILES.get(variant, VARIANT_FILES["sam2_hiera_large"])
44
- log.info(f"Downloading checkpoint: {ckpt_name}")
45
- ckpt = _download_checkpoint(model_id, ckpt_name)
46
- log.info(f"Checkpoint downloaded to: {ckpt}")
47
-
48
- # Use the symlinked config files in the sam2 package directory
49
- # From debug output: sam2_hiera_l.yaml -> configs/sam2/sam2_hiera_l.yaml
50
- sam2_pkg_dir = os.environ.get("THIRD_PARTY_SAM2_DIR", "/home/user/app/third_party/sam2")
51
- config_name = cfg_path.split('/')[-1] # Extract just the filename (e.g., "sam2_hiera_l.yaml")
52
- full_cfg_path = os.path.join(sam2_pkg_dir, "sam2", config_name)
53
- log.info(f"SAM2 config path: {full_cfg_path}")
54
- log.info(f"Config file exists: {os.path.exists(full_cfg_path)}")
55
-
56
- log.info("Calling build_sam2()...")
57
- model = build_sam2(config_file=full_cfg_path, ckpt_path=ckpt, device=str(self.device))
58
- log.info("build_sam2() completed successfully")
59
-
60
- # Explicitly move model to device and verify
61
- model = model.to(self.device)
62
- model.eval()
63
-
64
- # Verify model is on correct device
65
- if hasattr(model, 'parameters'):
66
- first_param = next(model.parameters(), None)
67
- if first_param is not None:
68
- actual_device = first_param.device
69
- log.info(f"SAM2 model device verification: expected={self.device}, actual={actual_device}")
70
- if str(actual_device) != str(self.device):
71
- log.warning(f"SAM2 model device mismatch! Moving to {self.device}")
72
- model = model.to(self.device)
73
-
74
- self.model = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
- from sam2.sam2_video_predictor import SAM2VideoPredictor
78
- self.predictor = SAM2VideoPredictor(self.model)
79
- except Exception:
80
- # Fallback to image predictor if video predictor missing
81
- from sam2.sam2_image_predictor import SAM2ImagePredictor
82
- self.predictor = SAM2ImagePredictor(self.model)
83
-
84
- return self
85
-
86
- def _detect_person_region(self, image_rgb01: np.ndarray) -> np.ndarray:
87
- """
88
- Simple centered box that works for most cases.
89
- Returns [x1, y1, x2, y2] in image coordinates.
90
- """
91
- h, w = image_rgb01.shape[:2]
92
-
93
- # Use 70% of the frame size, centered
94
- margin_w = int(w * 0.15)
95
- margin_h = int(h * 0.15)
96
-
97
- box = np.array([
98
- margin_w, # x1
99
- margin_h, # y1
100
- w - margin_w, # x2
101
- h - margin_h # y2
102
- ], dtype=np.float32)
103
-
104
- log.info(f"Using simple center box: {box}")
105
- return box
106
-
107
- @torch.inference_mode()
108
- def first_frame_mask(self, image_rgb01):
109
- """
110
- Returns an initial binary mask for the foreground person from first frame.
111
- Uses a robust approach with fallback strategies for better reliability.
112
- """
113
- log.info("🔍 SAM2 first_frame_mask() called - starting segmentation")
114
-
115
  try:
116
- # Ensure input tensor is on correct device
117
- if isinstance(image_rgb01, torch.Tensor):
118
- image_rgb01 = image_rgb01.to(self.device, non_blocking=True)
119
-
120
- if not hasattr(self.predictor, "set_image"):
121
- raise RuntimeError("SAM2 predictor doesn't support set_image")
122
-
123
- # Convert to numpy for predictor if needed
124
- if isinstance(image_rgb01, torch.Tensor):
125
- image_np = (image_rgb01.cpu().numpy() * 255).astype("uint8")
126
- else:
127
- image_np = (image_rgb01 * 255).astype("uint8")
128
-
129
- # Set the image for prediction
130
- self.predictor.set_image(image_np)
131
-
132
- # Strategy 1: Try with person-focused bounding box first
133
- box = self._detect_person_region(image_np)
134
- log.info(f"Trying person-focused box: {box}")
135
-
136
- masks, scores, _ = self.predictor.predict(
137
- box=box,
138
- multimask_output=True,
139
- mask_input=None,
140
- return_logits=False
141
- )
142
-
143
- # Strategy 2: If no good masks found, try with a point in the center
144
- if len(masks) == 0 or np.max(scores) < 0.5:
145
- log.info("No good masks found with box, trying center point")
146
- h, w = image_np.shape[:2]
147
- point = np.array([[w//2, h//2]])
148
- labels = np.array([1]) # 1=foreground point
149
- masks, scores, _ = self.predictor.predict(
150
- point_coords=point,
151
- point_labels=labels,
152
- multimask_output=True
153
- )
154
-
155
- # Choose the best mask (highest score)
156
- if len(masks) > 0 and len(scores) > 0:
157
- best_idx = np.argmax(scores)
158
- mask = masks[best_idx]
159
- score = float(scores[best_idx])
160
- log.info(f"Selected mask {best_idx+1}/{len(masks)} with score {score:.3f}")
161
-
162
- # Verify mask quality
163
- mask_coverage = (np.sum(mask > 0) / mask.size) * 100
164
- log.info(f"Mask coverage: {mask_coverage:.1f}% (target: 15-35%)")
165
-
166
- # If mask is too small or too large, try to refine it
167
- if mask_coverage < 5 or mask_coverage > 50:
168
- log.warning(f"Suspicious mask coverage {mask_coverage:.1f}%, applying post-processing")
169
- # Apply morphological operations to clean up the mask
170
- kernel = np.ones((5,5), np.uint8)
171
- if mask_coverage < 5: # Too small - dilate
172
- mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
173
- else: # Too large - erode
174
- mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
175
-
176
- # Ensure we still have a valid mask
177
- if np.sum(mask) == 0:
178
- log.warning("Post-processing removed all mask pixels, using original")
179
- mask = masks[best_idx]
180
  else:
181
- log.warning("No valid masks found, using fallback")
182
- mask = np.ones_like(image_np[:,:,0], dtype=bool)
183
-
184
- return mask.astype(np.float32)
185
-
186
- except Exception as e:
187
- log.error(f"Error in first_frame_mask: {e}")
188
- # Fallback to a simple centered box if anything goes wrong
189
- h, w = image_np.shape[:2]
190
- mask = np.zeros((h, w), dtype=np.float32)
191
- margin_h, margin_w = h//4, w//4
192
- mask[margin_h:h-margin_h, margin_w:w-margin_w] = 1.0
193
- return mask
194
 
195
- return mask.astype("float32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SAM2 Loader Robust loading and mask generation for SAM2
4
+ ========================================================
5
+ - Loads SAM2 model with Hydra config resolution
6
+ - Generates seed masks for MatAnyone
7
+ - Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
8
+
9
+ Changes (2025-09-16):
10
+ - Aligned with torch==2.3.1+cu121 and SAM2 commit
11
+ - Added GPU memory logging for Tesla T4
12
+ - Added SAM2 version logging via importlib.metadata
13
+ - Simplified config resolution to match __init__.py
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+ import logging
20
+ import importlib.metadata
21
  from pathlib import Path
22
+ from typing import Optional, Tuple, Dict, Any
23
+
24
  import numpy as np
25
+ import yaml
26
+ import torch
27
 
28
+ # --------------------------------------------------------------------------------------
29
+ # Logging
30
+ # --------------------------------------------------------------------------------------
31
+ logger = logging.getLogger("backgroundfx_pro")
32
+ if not logger.handlers:
33
+ _h = logging.StreamHandler()
34
+ _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
35
+ logger.addHandler(_h)
36
+ logger.setLevel(logging.INFO)
37
 
38
+ # --------------------------------------------------------------------------------------
39
+ # Path setup for third_party repos
40
+ # --------------------------------------------------------------------------------------
41
+ ROOT = Path(__file__).resolve().parent.parent # project root
42
+ TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
43
 
44
+ def _add_sys_path(p: Path) -> None:
45
+ if p.exists():
46
+ p_str = str(p)
47
+ if p_str not in sys.path:
48
+ sys.path.insert(0, p_str)
49
+ else:
50
+ logger.warning(f"third_party path not found: {p}")
51
 
52
+ _add_sys_path(TP_SAM2)
 
53
 
54
+ # --------------------------------------------------------------------------------------
55
+ # Safe Torch accessors
56
+ # --------------------------------------------------------------------------------------
57
+ def _torch():
58
  try:
59
+ import torch
60
+ return torch
61
  except Exception as e:
62
+ logger.warning(f"[sam2_loader.safe-torch] import failed: {e}")
63
  return None
64
 
65
+ def _has_cuda() -> bool:
66
+ t = _torch()
67
+ if t is None:
68
+ return False
69
+ try:
70
+ return bool(t.cuda.is_available())
71
+ except Exception as e:
72
+ logger.warning(f"[sam2_loader.safe-torch] cuda.is_available() failed: {e}")
73
+ return False
74
+
75
+ def _pick_device(env_key: str) -> str:
76
+ requested = os.environ.get(env_key, "").strip().lower()
77
+ has_cuda = _has_cuda()
78
+
79
+ logger.info(f"CUDA environment variables: {{'SAM2_DEVICE': '{os.environ.get('SAM2_DEVICE', '')}'}}")
80
+ logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
81
+
82
+ if has_cuda and requested not in {"cpu"}:
83
+ logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
84
+ return "cuda"
85
+ elif requested in {"cuda", "cpu"}:
86
+ logger.info(f"Using explicitly requested device: {requested}")
87
+ return requested
88
+
89
+ result = "cuda" if has_cuda else "cpu"
90
+ logger.info(f"Auto-selected device: {result}")
91
+ return result
92
+
93
+ # --------------------------------------------------------------------------------------
94
+ # SAM2 Loading and Mask Generation
95
+ # --------------------------------------------------------------------------------------
96
+ def _resolve_sam2_cfg(cfg_str: str) -> str:
97
+ """Resolve SAM2 config path - return relative path for Hydra compatibility."""
98
+ logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
99
+
100
+ candidate = os.path.join(TP_SAM2, cfg_str)
101
+ logger.info(f"Candidate path: {candidate}")
102
+ logger.info(f"Candidate exists: {os.path.exists(candidate)}")
103
+
104
+ if os.path.exists(candidate):
105
+ if cfg_str.startswith("sam2/configs/"):
106
+ relative_path = cfg_str.replace("sam2/configs/", "configs/")
107
+ else:
108
+ relative_path = cfg_str
109
+ logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
110
+ return relative_path
111
+
112
+ fallbacks = [
113
+ os.path.join(TP_SAM2, "sam2", cfg_str),
114
+ os.path.join(TP_SAM2, "configs", cfg_str),
115
+ ]
116
+
117
+ for fallback in fallbacks:
118
+ logger.info(f"Trying fallback: {fallback}")
119
+ if os.path.exists(fallback):
120
+ if "configs/" in fallback:
121
+ relative_path = "configs/" + fallback.split("configs/")[-1]
122
+ logger.info(f"Returning fallback relative path: {relative_path}")
123
+ return relative_path
124
+
125
+ logger.warning(f"Config not found, returning original: {cfg_str}")
126
+ return cfg_str
127
 
128
+ def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
129
+ """If config references 'hieradet', try to find a 'hiera' config."""
130
+ try:
131
+ with open(cfg_path, "r") as f:
132
+ data = yaml.safe_load(f)
133
+ model = data.get("model", {}) or {}
134
+ enc = model.get("image_encoder") or {}
135
+ trunk = enc.get("trunk") or {}
136
+ target = trunk.get("_target_") or trunk.get("target")
137
+ if isinstance(target, str) and "hieradet" in target:
138
+ for y in TP_SAM2.rglob("*.yaml"):
139
+ try:
140
+ with open(y, "r") as f2:
141
+ d2 = yaml.safe_load(f2) or {}
142
+ e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
143
+ t2 = (e2.get("trunk") or {})
144
+ tgt2 = t2.get("_target_") or t2.get("target")
145
+ if isinstance(tgt2, str) and ".hiera." in tgt2:
146
+ logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
147
+ return str(y)
148
+ except Exception:
149
+ continue
150
+ except Exception:
151
+ pass
152
+ return None
153
+
154
+ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
155
+ """Robust SAM2 loader with config resolution and error handling."""
156
+ meta = {"sam2_import_ok": False, "sam2_init_ok": False}
157
+ try:
158
+ from sam2.build_sam import build_sam2
159
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
160
+ meta["sam2_import_ok"] = True
161
+ except Exception as e:
162
+ logger.warning(f"SAM2 import failed: {e}")
163
+ return None, False, meta
164
+
165
+ # Log SAM2 version
166
+ try:
167
+ version = importlib.metadata.version("segment-anything-2")
168
+ logger.info(f"[SAM2] SAM2 version: {version}")
169
+ except Exception:
170
+ logger.info("[SAM2] SAM2 version unknown")
171
+
172
+ # Check GPU memory before loading
173
+ if torch and torch.cuda.is_available():
174
+ mem_before = torch.cuda.memory_allocated() / 1024**3
175
+ logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")
176
+
177
+ device = _pick_device("SAM2_DEVICE")
178
+ cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
179
+ cfg = _resolve_sam2_cfg(cfg_env)
180
+ ckpt = os.environ.get("SAM2_CHECKPOINT", "")
181
+
182
+ def _try_build(cfg_path: str):
183
+ logger.info(f"_try_build called with cfg_path: {cfg_path}")
184
+ params = set(inspect.signature(build_sam2).parameters.keys())
185
+ logger.info(f"build_sam2 parameters: {list(params)}")
186
+ kwargs = {}
187
+ if "config_file" in params:
188
+ kwargs["config_file"] = cfg_path
189
+ logger.info(f"Using config_file parameter: {cfg_path}")
190
+ elif "model_cfg" in params:
191
+ kwargs["model_cfg"] = cfg_path
192
+ logger.info(f"Using model_cfg parameter: {cfg_path}")
193
+ if ckpt:
194
+ if "checkpoint" in params:
195
+ kwargs["checkpoint"] = ckpt
196
+ elif "ckpt_path" in params:
197
+ kwargs["ckpt_path"] = ckpt
198
+ elif "weights" in params:
199
+ kwargs["weights"] = ckpt
200
+ if "device" in params:
201
+ kwargs["device"] = device
202
  try:
203
+ logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
204
+ result = build_sam2(**kwargs)
205
+ logger.info(f"build_sam2 succeeded with kwargs")
206
+ if hasattr(result, 'device'):
207
+ logger.info(f"SAM2 model device: {result.device}")
208
+ elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
209
+ logger.info(f"SAM2 model device: {result.image_encoder.device}")
210
+ return result
211
+ except TypeError as e:
212
+ logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
213
+ pos = [cfg_path]
214
+ if ckpt:
215
+ pos.append(ckpt)
216
+ if "device" not in kwargs:
217
+ pos.append(device)
218
+ logger.info(f"Calling build_sam2 with positional args: {pos}")
219
+ result = build_sam2(*pos)
220
+ logger.info(f"build_sam2 succeeded with positional args")
221
+ return result
222
+
223
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  try:
225
+ sam = _try_build(cfg)
226
+ except Exception:
227
+ alt_cfg = _find_hiera_config_if_hieradet(cfg)
228
+ if alt_cfg:
229
+ sam = _try_build(alt_cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  else:
231
+ raise
232
+
233
+ if sam is not None:
234
+ predictor = SAM2ImagePredictor(sam)
235
+ meta["sam2_init_ok"] = True
236
+ meta["sam2_device"] = device
237
+ return predictor, True, meta
238
+ else:
239
+ return None, False, meta
 
 
 
 
240
 
241
+ except Exception as e:
242
+ logger.error(f"SAM2 loading failed: {e}")
243
+ return None, False, meta
244
+
245
+ def run_sam2_mask(predictor: object,
246
+ first_frame_bgr: np.ndarray,
247
+ point: Optional[Tuple[int, int]] = None,
248
+ auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
249
+ """Generate a seed mask for MatAnyone. Returns (mask_uint8_0_255, ok)."""
250
+ if predictor is None:
251
+ return None, False
252
+ try:
253
+ import cv2
254
+ rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
255
+ predictor.set_image(rgb)
256
+
257
+ if auto:
258
+ h, w = rgb.shape[:2]
259
+ box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
260
+ masks, _, _ = predictor.predict(box=box)
261
+ elif point is not None:
262
+ x, y = int(point[0]), int(point[1])
263
+ pts = np.array([[x, y]], dtype=np.int32)
264
+ labels = np.array([1], dtype=np.int32)
265
+ masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
266
+ else:
267
+ h, w = rgb.shape[:2]
268
+ box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
269
+ masks, _, _ = predictor.predict(box=box)
270
+
271
+ if masks is None or len(masks) == 0:
272
+ return None, False
273
+
274
+ m = masks[0].astype(np.uint8) * 255
275
+ logger.info(f"[SAM2] Generated mask: shape={m.shape}, dtype={m.dtype}")
276
+ return m, True
277
+ except Exception as e:
278
+ logger.warning(f"SAM2 mask generation failed: {e}")
279
+ return None, False
pipeline.py CHANGED
@@ -8,6 +8,13 @@
8
  - Verbose breadcrumbs for pinpointing stalls
9
  - Enhanced mask validation for MatAnyone compatibility (robust to API changes)
10
  - Automatic mask inversion for high-coverage masks
 
 
 
 
 
 
 
11
  """
12
 
13
  from __future__ import annotations
@@ -18,6 +25,7 @@
18
  import tempfile
19
  import logging
20
  import importlib
 
21
  from pathlib import Path
22
  from typing import Optional, Tuple, Dict, Any, Union
23
 
@@ -216,17 +224,17 @@ def _progress(*args):
216
  logger.info(f"Progress: {msg}")
217
  if progress_callback:
218
  try:
219
- progress_callback(msg) # legacy 1-arg
220
  except TypeError:
221
- progress_callback(0.0, msg) # fallback
222
  elif len(args) >= 2:
223
  pct, msg = args[0], args[1]
224
  logger.info(f"Progress: {msg} ({int(pct*100)}%)")
225
  if progress_callback:
226
  try:
227
- progress_callback(pct, msg) # preferred 2-arg
228
  except TypeError:
229
- progress_callback(msg) # legacy 1-arg
230
  except Exception as e:
231
  logger.warning(f"progress callback failed: {e}")
232
 
@@ -235,7 +243,7 @@ def _progress(*args):
235
  # [5] PHASE 0: Video metadata
236
  # ----------------------------------------------------------------------------------
237
  logger.info("[0] Reading video metadata…")
238
- _progress("📹 Reading video metadata...")
239
  first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
240
  diagnostics["fps"] = int(fps or 25)
241
  diagnostics["resolution"] = [int(vw), int(vh)]
@@ -249,7 +257,7 @@ def _progress(*args):
249
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
250
  cap.release()
251
 
252
- _progress(f"✅ Video loaded: {vw}x{vh} @ {fps}fps ({total_frames} frames)")
253
  diagnostics["total_frames"] = total_frames
254
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
255
 
@@ -257,7 +265,7 @@ def _progress(*args):
257
  # [6] PHASE 1: SAM2 → seed mask (then free)
258
  # ----------------------------------------------------------------------------------
259
  logger.info("[1] Loading SAM2…")
260
- _progress("🤖 Loading SAM2 model...")
261
  predictor, sam2_ok, sam_meta = load_sam2()
262
  diagnostics["sam2_meta"] = sam_meta or {}
263
  diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
@@ -269,7 +277,7 @@ def _progress(*args):
269
 
270
  if sam2_ok and predictor is not None:
271
  logger.info("[1] Running SAM2 segmentation…")
272
- _progress("🎯 Running SAM2 segmentation...")
273
  px = int(point_x) if point_x is not None else None
274
  py = int(point_y) if point_y is not None else None
275
  seed_mask, ok_mask = run_sam2_mask(
@@ -278,10 +286,10 @@ def _progress(*args):
278
  auto=auto_box
279
  )
280
  diagnostics["sam2_ok"] = bool(ok_mask)
281
- _progress("✅ SAM2 segmentation complete")
282
  else:
283
  logger.info("[1] SAM2 unavailable or failed to load.")
284
- _progress("⚠️ SAM2 unavailable, using fallback")
285
 
286
  # Free SAM2 ASAP
287
  try:
@@ -290,13 +298,13 @@ def _progress(*args):
290
  pass
291
  predictor = None
292
  _force_cleanup()
293
- _progress("🧹 SAM2 memory cleared")
294
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
295
 
296
  # Fallback mask generation if SAM2 failed
297
  if not ok_mask or seed_mask is None:
298
  logger.info("[1] Using fallback mask generation…")
299
- _progress("🔄 Generating fallback mask...")
300
  seed_mask = fallback_mask(first_frame)
301
  diagnostics["fallback_used"] = "mask_generation"
302
  _force_cleanup()
@@ -304,7 +312,7 @@ def _progress(*args):
304
  # Optional GrabCut refinement
305
  if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
306
  logger.info("[1] Refining mask with GrabCut…")
307
- _progress("✨ Refining mask with GrabCut...")
308
  seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
309
  _force_cleanup()
310
 
@@ -341,7 +349,7 @@ def _progress(*args):
341
  logger.info(f"[1] ✅ Mask validation passed: {validation_msg}")
342
  diagnostics["mask_validation"] = {"valid": True, "stats": mask_stats}
343
 
344
- _progress("✅ Stage 1 complete - Mask generated and validated")
345
 
346
  # Free first frame memory
347
  try:
@@ -350,18 +358,25 @@ def _progress(*args):
350
  pass
351
  _force_cleanup()
352
  _cleanup_temp_files(tmp_root)
353
- _progress("🧹 Frame memory cleared")
354
 
355
  # ----------------------------------------------------------------------------------
356
  # [7] PHASE 2: MatAnyone (strict CHW/HW tensors are handled in matanyone_loader)
357
  # ----------------------------------------------------------------------------------
358
  logger.info("[2] Loading MatAnyone…")
359
- _progress("🎬 Loading MatAnyone model...")
360
  matany, mat_ok, mat_meta = load_matany()
361
  diagnostics["matany_meta"] = mat_meta or {}
362
  diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
363
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
364
 
 
 
 
 
 
 
 
365
  fg_path, al_path = None, None
366
  out_dir = tmp_root / "matany_out"
367
  _ensure_dir(out_dir)
@@ -369,19 +384,25 @@ def _progress(*args):
369
  from models.matanyone_loader import MatAnyError
370
 
371
  try:
372
- _progress("MatAnyone: starting…")
373
  logger.info("[2] Running MatAnyone processing…")
374
 
375
  mask_validation = diagnostics.get("mask_validation", {})
376
  if not mask_validation.get("valid", False):
377
  logger.warning(f"[2] Proceeding with MatAnyone despite mask validation failure: "
378
  f"{mask_validation.get('error', 'unknown')}")
379
-
380
  else:
381
  logger.info(f"[2] Mask validation OK - coverage: "
382
  f"{mask_validation['stats']['coverage_percent']}%")
383
 
384
- _progress("🎥 Running MatAnyone video matting...")
 
 
 
 
 
 
 
385
 
386
  # NOTE: The updated loader feeds CHW image + HW seed (frame 0 only) — no 5D tensors.
387
  al_path, fg_path = run_matany(
@@ -389,14 +410,13 @@ def _progress(*args):
389
  mask_path=mask_png,
390
  out_dir=out_dir,
391
  device="cuda" if _cuda_available() else "cpu",
392
- # Pass a simple status bridge; the loader already rate-limits progress
393
- progress_callback=lambda frac, msg: _progress(msg),
394
  )
395
 
396
  logger.info("Stage 2 success: MatAnyone produced outputs.")
397
  diagnostics["matany_ok"] = True
398
  mat_ok = True
399
- _progress("✅ MatAnyone processing complete")
400
  logger.info(f"[2] MatAnyone results: fg_path={fg_path}, al_path={al_path}")
401
 
402
  except MatAnyError as e:
@@ -409,7 +429,7 @@ def _progress(*args):
409
  fg_path, al_path = None, None
410
 
411
  if not mat_ok:
412
- _progress("MatAnyone failed → using fallback")
413
  logger.info("[2] MatAnyone unavailable or failed, using fallback.")
414
 
415
  # Free MatAnyone ASAP
@@ -419,7 +439,7 @@ def _progress(*args):
419
  pass
420
  matany = None
421
  _force_cleanup()
422
- _progress("🧹 MatAnyone memory cleared")
423
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
424
 
425
  # ----------------------------------------------------------------------------------
@@ -427,10 +447,10 @@ def _progress(*args):
427
  # ----------------------------------------------------------------------------------
428
  logger.info("[3] Building Stage-A (transparent or checkerboard)…")
429
  if diagnostics["matany_ok"]:
430
- _progress("✅ Stage 2 complete - Video matting done")
431
  else:
432
- _progress("ℹ️ Skipping MatAnyone outputs; building Stage-A from mask")
433
- _progress("🎨 Building Stage-A video...")
434
 
435
  stageA_path = None
436
  stageA_ok = False
@@ -475,13 +495,13 @@ def _progress(*args):
475
  # [9] PHASE 4: Final compositing
476
  # ----------------------------------------------------------------------------------
477
  logger.info("[4] Creating final composite…")
478
- _progress("✅ Stage 3 complete - Stage-A built")
479
- _progress("🎬 Creating final composite...")
480
  output_path = tmp_root / "output.mp4"
481
 
482
  if diagnostics["matany_ok"] and fg_path and al_path:
483
  logger.info(f"[4] Compositing with MatAnyone outputs: fg_path={fg_path}, al_path={al_path}")
484
- _progress(f"🎬 Compositing video with MatAnyone outputs...")
485
 
486
  fg_exists = Path(fg_path).exists() if fg_path else False
487
  al_exists = Path(al_path).exists() if al_path else False
@@ -492,19 +512,19 @@ def _progress(*args):
492
  logger.info(f"[4] Composite result: {ok_comp}")
493
  if not ok_comp:
494
  logger.info("[4] Composite failed; falling back to static mask composite.")
495
- _progress("⚠️ MatAnyone composite failed, using fallback...")
496
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
497
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
498
  else:
499
- _progress("✅ MatAnyone composite successful!")
500
  else:
501
  logger.error(f"[4] MatAnyone output files missing - using fallback composite")
502
- _progress("⚠️ MatAnyone files missing, using fallback...")
503
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
504
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
505
  else:
506
  logger.info(f"[4] Using static mask composite - matany_ok={diagnostics['matany_ok']}, fg_path={fg_path}, al_path={al_path}")
507
- _progress("🎬 Using static mask composite...")
508
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
509
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
510
 
@@ -513,20 +533,20 @@ def _progress(*args):
513
 
514
  if not output_path.exists():
515
  logger.error(f"[4] Output video not created at {output_path}")
516
- _progress("❌ Composite creation failed - no output file")
517
  diagnostics["error"] = "Composite video not created"
518
  return None, diagnostics
519
 
520
  output_size = output_path.stat().st_size
521
  logger.info(f"[4] Output video created: {output_path} ({output_size} bytes)")
522
- _progress(f"✅ Composite created ({output_size} bytes)")
523
 
524
  # ----------------------------------------------------------------------------------
525
  # [10] PHASE 5: Audio mux (if FFmpeg available)
526
  # ----------------------------------------------------------------------------------
527
  logger.info("[5] Adding audio track…")
528
- _progress("✅ Stage 4 complete - Composite created")
529
- _progress("🎵 Adding audio track...")
530
  final_path = tmp_root / "output_with_audio.mp4"
531
 
532
  if _probe_ffmpeg():
@@ -537,7 +557,7 @@ def _progress(*args):
537
  if mux_ok and final_path.exists():
538
  final_size = final_path.stat().st_size
539
  logger.info(f"[5] Final video with audio: {final_path} ({final_size} bytes)")
540
- _progress(f"�� Final video ready ({final_size} bytes)")
541
  output_path.unlink(missing_ok=True)
542
  _force_cleanup()
543
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
@@ -551,7 +571,7 @@ def _progress(*args):
551
 
552
  # Fallback return without audio
553
  logger.info(f"[5] Using output without audio: {output_path}")
554
- _progress(f"✅ Video ready (no audio) ({output_size} bytes)")
555
  _force_cleanup()
556
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
557
  diagnostics["total_time_sec"] = diagnostics["elapsed_sec"]
@@ -573,4 +593,4 @@ def _progress(*args):
573
  finally:
574
  # Ensure cleanup even if something goes wrong
575
  _force_cleanup()
576
- _cleanup_temp_files(tmp_root)
 
8
  - Verbose breadcrumbs for pinpointing stalls
9
  - Enhanced mask validation for MatAnyone compatibility (robust to API changes)
10
  - Automatic mask inversion for high-coverage masks
11
+
12
+ Changes (2025-09-16):
13
+ - Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
14
+ - Added input shape logging before run_matany to prevent 5D tensor issues
15
+ - Added MatAnyone version logging
16
+ - Ensured consistent progress callback with percentages
17
+ - Maintained compatibility with updated models/ files
18
  """
19
 
20
  from __future__ import annotations
 
25
  import tempfile
26
  import logging
27
  import importlib
28
+ import importlib.metadata
29
  from pathlib import Path
30
  from typing import Optional, Tuple, Dict, Any, Union
31
 
 
224
  logger.info(f"Progress: {msg}")
225
  if progress_callback:
226
  try:
227
+ progress_callback(0.0, msg) # Default to 0% if no percentage provided
228
  except TypeError:
229
+ progress_callback(msg) # Legacy 1-arg
230
  elif len(args) >= 2:
231
  pct, msg = args[0], args[1]
232
  logger.info(f"Progress: {msg} ({int(pct*100)}%)")
233
  if progress_callback:
234
  try:
235
+ progress_callback(pct, msg) # Preferred 2-arg
236
  except TypeError:
237
+ progress_callback(msg) # Legacy 1-arg
238
  except Exception as e:
239
  logger.warning(f"progress callback failed: {e}")
240
 
 
243
  # [5] PHASE 0: Video metadata
244
  # ----------------------------------------------------------------------------------
245
  logger.info("[0] Reading video metadata…")
246
+ _progress(0.0, "📹 Reading video metadata...")
247
  first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
248
  diagnostics["fps"] = int(fps or 25)
249
  diagnostics["resolution"] = [int(vw), int(vh)]
 
257
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
258
  cap.release()
259
 
260
+ _progress(0.05, f"✅ Video loaded: {vw}x{vh} @ {fps}fps ({total_frames} frames)")
261
  diagnostics["total_frames"] = total_frames
262
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
263
 
 
265
  # [6] PHASE 1: SAM2 → seed mask (then free)
266
  # ----------------------------------------------------------------------------------
267
  logger.info("[1] Loading SAM2…")
268
+ _progress(0.1, "🤖 Loading SAM2 model...")
269
  predictor, sam2_ok, sam_meta = load_sam2()
270
  diagnostics["sam2_meta"] = sam_meta or {}
271
  diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
 
277
 
278
  if sam2_ok and predictor is not None:
279
  logger.info("[1] Running SAM2 segmentation…")
280
+ _progress(0.15, "🎯 Running SAM2 segmentation...")
281
  px = int(point_x) if point_x is not None else None
282
  py = int(point_y) if point_y is not None else None
283
  seed_mask, ok_mask = run_sam2_mask(
 
286
  auto=auto_box
287
  )
288
  diagnostics["sam2_ok"] = bool(ok_mask)
289
+ _progress(0.2, "✅ SAM2 segmentation complete")
290
  else:
291
  logger.info("[1] SAM2 unavailable or failed to load.")
292
+ _progress(0.2, "⚠️ SAM2 unavailable, using fallback")
293
 
294
  # Free SAM2 ASAP
295
  try:
 
298
  pass
299
  predictor = None
300
  _force_cleanup()
301
+ _progress(0.25, "🧹 SAM2 memory cleared")
302
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
303
 
304
  # Fallback mask generation if SAM2 failed
305
  if not ok_mask or seed_mask is None:
306
  logger.info("[1] Using fallback mask generation…")
307
+ _progress(0.25, "🔄 Generating fallback mask...")
308
  seed_mask = fallback_mask(first_frame)
309
  diagnostics["fallback_used"] = "mask_generation"
310
  _force_cleanup()
 
312
  # Optional GrabCut refinement
313
  if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
314
  logger.info("[1] Refining mask with GrabCut…")
315
+ _progress(0.3, "✨ Refining mask with GrabCut...")
316
  seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
317
  _force_cleanup()
318
 
 
349
  logger.info(f"[1] ✅ Mask validation passed: {validation_msg}")
350
  diagnostics["mask_validation"] = {"valid": True, "stats": mask_stats}
351
 
352
+ _progress(0.35, "✅ Stage 1 complete - Mask generated and validated")
353
 
354
  # Free first frame memory
355
  try:
 
358
  pass
359
  _force_cleanup()
360
  _cleanup_temp_files(tmp_root)
361
+ _progress(0.4, "🧹 Frame memory cleared")
362
 
363
  # ----------------------------------------------------------------------------------
364
  # [7] PHASE 2: MatAnyone (strict CHW/HW tensors are handled in matanyone_loader)
365
  # ----------------------------------------------------------------------------------
366
  logger.info("[2] Loading MatAnyone…")
367
+ _progress(0.45, "🎬 Loading MatAnyone model...")
368
  matany, mat_ok, mat_meta = load_matany()
369
  diagnostics["matany_meta"] = mat_meta or {}
370
  diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
371
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
372
 
373
+ # Log MatAnyone version
374
+ try:
375
+ version = importlib.metadata.version("matanyone")
376
+ logger.info(f"[MATANY] MatAnyone version: {version}")
377
+ except Exception:
378
+ logger.info("[MATANY] MatAnyone version unknown")
379
+
380
  fg_path, al_path = None, None
381
  out_dir = tmp_root / "matany_out"
382
  _ensure_dir(out_dir)
 
384
  from models.matanyone_loader import MatAnyError
385
 
386
  try:
387
+ _progress(0.5, "MatAnyone: starting…")
388
  logger.info("[2] Running MatAnyone processing…")
389
 
390
  mask_validation = diagnostics.get("mask_validation", {})
391
  if not mask_validation.get("valid", False):
392
  logger.warning(f"[2] Proceeding with MatAnyone despite mask validation failure: "
393
  f"{mask_validation.get('error', 'unknown')}")
 
394
  else:
395
  logger.info(f"[2] Mask validation OK - coverage: "
396
  f"{mask_validation['stats']['coverage_percent']}%")
397
 
398
+ _progress(0.55, "🎥 Running MatAnyone video matting...")
399
+
400
+ # Validate input shapes before MatAnyone
401
+ import cv2
402
+ mask_array = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
403
+ logger.info(f"[2] Input mask shape: {mask_array.shape if mask_array is not None else None}")
404
+ if mask_array is None:
405
+ raise MatAnyError(f"Invalid mask at {mask_png}")
406
 
407
  # NOTE: The updated loader feeds CHW image + HW seed (frame 0 only) — no 5D tensors.
408
  al_path, fg_path = run_matany(
 
410
  mask_path=mask_png,
411
  out_dir=out_dir,
412
  device="cuda" if _cuda_available() else "cpu",
413
+ progress_callback=lambda frac, msg: _progress(0.55 + 0.35 * frac, msg),
 
414
  )
415
 
416
  logger.info("Stage 2 success: MatAnyone produced outputs.")
417
  diagnostics["matany_ok"] = True
418
  mat_ok = True
419
+ _progress(0.9, "✅ MatAnyone processing complete")
420
  logger.info(f"[2] MatAnyone results: fg_path={fg_path}, al_path={al_path}")
421
 
422
  except MatAnyError as e:
 
429
  fg_path, al_path = None, None
430
 
431
  if not mat_ok:
432
+ _progress(0.9, "MatAnyone failed → using fallback...")
433
  logger.info("[2] MatAnyone unavailable or failed, using fallback.")
434
 
435
  # Free MatAnyone ASAP
 
439
  pass
440
  matany = None
441
  _force_cleanup()
442
+ _progress(0.95, "🧹 MatAnyone memory cleared")
443
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
444
 
445
  # ----------------------------------------------------------------------------------
 
447
  # ----------------------------------------------------------------------------------
448
  logger.info("[3] Building Stage-A (transparent or checkerboard)…")
449
  if diagnostics["matany_ok"]:
450
+ _progress(0.95, "✅ Stage 2 complete - Video matting done")
451
  else:
452
+ _progress(0.95, "ℹ️ Skipping MatAnyone outputs; building Stage-A from mask")
453
+ _progress(0.95, "🎨 Building Stage-A video...")
454
 
455
  stageA_path = None
456
  stageA_ok = False
 
495
  # [9] PHASE 4: Final compositing
496
  # ----------------------------------------------------------------------------------
497
  logger.info("[4] Creating final composite…")
498
+ _progress(0.97, "✅ Stage 3 complete - Stage-A built")
499
+ _progress(0.97, "🎬 Creating final composite...")
500
  output_path = tmp_root / "output.mp4"
501
 
502
  if diagnostics["matany_ok"] and fg_path and al_path:
503
  logger.info(f"[4] Compositing with MatAnyone outputs: fg_path={fg_path}, al_path={al_path}")
504
+ _progress(0.97, f"🎬 Compositing video with MatAnyone outputs...")
505
 
506
  fg_exists = Path(fg_path).exists() if fg_path else False
507
  al_exists = Path(al_path).exists() if al_path else False
 
512
  logger.info(f"[4] Composite result: {ok_comp}")
513
  if not ok_comp:
514
  logger.info("[4] Composite failed; falling back to static mask composite.")
515
+ _progress(0.98, "⚠️ MatAnyone composite failed, using fallback...")
516
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
517
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
518
  else:
519
+ _progress(0.98, "✅ MatAnyone composite successful!")
520
  else:
521
  logger.error(f"[4] MatAnyone output files missing - using fallback composite")
522
+ _progress(0.98, "⚠️ MatAnyone files missing, using fallback...")
523
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
524
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
525
  else:
526
  logger.info(f"[4] Using static mask composite - matany_ok={diagnostics['matany_ok']}, fg_path={fg_path}, al_path={al_path}")
527
+ _progress(0.98, "🎬 Using static mask composite...")
528
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
529
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
530
 
 
533
 
534
  if not output_path.exists():
535
  logger.error(f"[4] Output video not created at {output_path}")
536
+ _progress(0.99, "❌ Composite creation failed - no output file")
537
  diagnostics["error"] = "Composite video not created"
538
  return None, diagnostics
539
 
540
  output_size = output_path.stat().st_size
541
  logger.info(f"[4] Output video created: {output_path} ({output_size} bytes)")
542
+ _progress(0.99, f"✅ Composite created ({output_size} bytes)")
543
 
544
  # ----------------------------------------------------------------------------------
545
  # [10] PHASE 5: Audio mux (if FFmpeg available)
546
  # ----------------------------------------------------------------------------------
547
  logger.info("[5] Adding audio track…")
548
+ _progress(0.99, "✅ Stage 4 complete - Composite created")
549
+ _progress(0.99, "🎵 Adding audio track...")
550
  final_path = tmp_root / "output_with_audio.mp4"
551
 
552
  if _probe_ffmpeg():
 
557
  if mux_ok and final_path.exists():
558
  final_size = final_path.stat().st_size
559
  logger.info(f"[5] Final video with audio: {final_path} ({final_size} bytes)")
560
+ _progress(1.0, f" Final video ready ({final_size} bytes)")
561
  output_path.unlink(missing_ok=True)
562
  _force_cleanup()
563
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
 
571
 
572
  # Fallback return without audio
573
  logger.info(f"[5] Using output without audio: {output_path}")
574
+ _progress(1.0, f"✅ Video ready (no audio) ({output_size} bytes)")
575
  _force_cleanup()
576
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
577
  diagnostics["total_time_sec"] = diagnostics["elapsed_sec"]
 
593
  finally:
594
  # Ensure cleanup even if something goes wrong
595
  _force_cleanup()
596
+ _cleanup_temp_files(tmp_root)
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  # ===== Core runtime (Torch is installed in Dockerfile with cu121 wheels) =====
2
- # DO NOT add torch/torchvision/torchaudio here when using the CUDA wheels in Dockerfile.
3
 
4
  # ===== Video / image IO =====
5
  opencv-python-headless==4.10.0.84
@@ -16,6 +16,7 @@ protobuf==4.25.3
16
  gradio==5.42.0
17
 
18
  # ===== SAM2 Dependencies =====
 
19
  hydra-core==1.3.2
20
  omegaconf==2.3.0
21
  einops==0.8.0
@@ -24,6 +25,7 @@ pyyaml==6.0.2
24
  matplotlib==3.9.2
25
 
26
  # ===== MatAnyone Dependencies =====
 
27
  kornia==0.7.3
28
  scikit-image==0.24.0
29
  tqdm==4.66.5
@@ -35,11 +37,6 @@ psutil==6.0.0
35
  requests==2.32.3
36
  scikit-learn==1.5.1
37
 
38
-
39
- # ===== (Optional) Extras =====
40
- # safetensors==0.4.5
41
- # aiohttp==3.10.5
42
-
43
  # ===== (Optional) Extras =====
44
- # safetensors==0.4.5 # if you pull weights that use safetensors
45
- # aiohttp==3.10.5 # if you later async-fetch assets
 
1
  # ===== Core runtime (Torch is installed in Dockerfile with cu121 wheels) =====
2
+ # DO NOT add torch/torchvision/torchaudio here when using CUDA wheels in Dockerfile.
3
 
4
  # ===== Video / image IO =====
5
  opencv-python-headless==4.10.0.84
 
16
  gradio==5.42.0
17
 
18
  # ===== SAM2 Dependencies =====
19
+ git+https://github.com/facebookresearch/segment-anything-2@main
20
  hydra-core==1.3.2
21
  omegaconf==2.3.0
22
  einops==0.8.0
 
25
  matplotlib==3.9.2
26
 
27
  # ===== MatAnyone Dependencies =====
28
+ git+https://github.com/pq-yang/MatAnyone@master
29
  kornia==0.7.3
30
  scikit-image==0.24.0
31
  tqdm==4.66.5
 
37
  requests==2.32.3
38
  scikit-learn==1.5.1
39
 
 
 
 
 
 
40
  # ===== (Optional) Extras =====
41
+ # safetensors==0.4.5 # Uncomment if pulling weights that use safetensors
42
+ # aiohttp==3.10.5 # Uncomment if async-fetching assets
ui.py CHANGED
@@ -3,7 +3,15 @@
3
  BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
4
  - No top-level import of pipeline (lazy import in handlers)
5
  - Compatible with pipeline.process()
6
- - FIXED: Proper SAM2 configuration for person segmentation
 
 
 
 
 
 
 
 
7
  """
8
 
9
  import io
@@ -16,6 +24,7 @@
16
  from typing import Optional, Tuple, List, Dict, Any
17
  from PIL import Image
18
  import gradio as gr
 
19
 
20
  logger = logging.getLogger("ui")
21
  if not logger.handlers:
@@ -146,7 +155,6 @@ def get_video_url(self, selection: str) -> Optional[str]:
146
  myavatar_api = MyAvatarAPI()
147
 
148
  # ---- Minimal stop flag (request-scoped) ----
149
- # We avoid pipeline globals; this just short-circuits the generator.
150
  class Stopper:
151
  def __init__(self):
152
  self.stop = False
@@ -175,9 +183,11 @@ def process_video_with_background_stoppable(
175
  video_path = None
176
  if input_video:
177
  video_path = input_video
 
178
  elif myavatar_selection and myavatar_selection != "No videos available":
179
  url = myavatar_api.get_video_url(myavatar_selection)
180
  if url:
 
181
  with requests.get(url, stream=True, timeout=60) as r:
182
  r.raise_for_status()
183
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
@@ -188,12 +198,14 @@ def process_video_with_background_stoppable(
188
  if chunk:
189
  tmp.write(chunk)
190
  video_path = tmp.name
 
191
 
192
  if STOP.stop:
193
  yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
194
  return
195
 
196
- if not video_path:
 
197
  yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
198
  return
199
 
@@ -202,21 +214,27 @@ def process_video_with_background_stoppable(
202
  bg_img = None
203
  if background_type == "gradient":
204
  bg_img = create_gradient_background(gradient_type, 1920, 1080)
 
205
  elif background_type == "solid":
206
  bg_img = create_solid_color(solid_color, 1920, 1080)
 
207
  elif background_type == "custom" and custom_background:
208
  try:
209
  bg_img = Image.open(custom_background).convert("RGB")
210
- except Exception:
 
 
211
  bg_img = None
212
  elif background_type == "ai" and ai_prompt:
213
- bg_img, _ = generate_ai_background(ai_prompt)
 
214
 
215
  if STOP.stop:
216
  yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
217
  return
218
 
219
  if bg_img is None:
 
220
  yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
221
  return
222
 
@@ -224,39 +242,45 @@ def process_video_with_background_stoppable(
224
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
225
  bg_img.save(tmp_bg.name, format="PNG")
226
  bg_path = tmp_bg.name
 
227
 
228
  # Run pipeline with enhanced real-time status updates
229
  yield gr.update(visible=False), gr.update(visible=True), None, "🔄 Initializing pipeline...\n⚡ Checking GPU acceleration..."
230
  logger.info(f"=== PIPELINE START ===")
231
-
232
  # Enhanced GPU diagnostics with detailed status
233
  try:
234
  import torch
235
  logger.info(f"✅ Torch version: {torch.__version__}")
236
  logger.info(f"✅ CUDA available: {torch.cuda.is_available()}")
237
-
 
 
 
 
 
238
  if torch.cuda.is_available():
239
  device_count = torch.cuda.device_count()
240
  current_device = torch.cuda.current_device()
241
  device_name = torch.cuda.get_device_name()
242
  device_capability = torch.cuda.get_device_capability()
243
-
244
  # Get GPU memory info
245
  memory_allocated = torch.cuda.memory_allocated() / (1024**3) # GB
246
  memory_reserved = torch.cuda.memory_reserved() / (1024**3) # GB
247
  memory_total = torch.cuda.get_device_properties(current_device).total_memory / (1024**3) # GB
248
-
249
  gpu_status = f"""✅ GPU Acceleration Active
250
  🖥️ Device: {device_name} (Compute {device_capability[0]}.{device_capability[1]})
251
  💾 Memory: {memory_allocated:.1f}GB allocated / {memory_total:.1f}GB total
252
  🔧 CUDA {torch.version.cuda} | PyTorch {torch.__version__}
253
  📊 Ready for SAM2 + MatAnyone processing..."""
254
-
255
  logger.info(f"✅ CUDA device count: {device_count}")
256
  logger.info(f"✅ Current device: {current_device}")
257
  logger.info(f"✅ Device name: {device_name}")
258
  logger.info(f"✅ GPU memory: {memory_allocated:.1f}GB/{memory_total:.1f}GB")
259
-
260
  yield gr.update(visible=False), gr.update(visible=True), None, gpu_status
261
  else:
262
  logger.error(f"❌ CUDA NOT AVAILABLE - GPU processing will fail")
@@ -266,34 +290,33 @@ def process_video_with_background_stoppable(
266
  logger.error(f"❌ Torch/CUDA check failed: {e}")
267
  yield gr.update(visible=True), gr.update(visible=False), None, f"GPU check error: {e}"
268
  return
269
-
270
  yield gr.update(visible=False), gr.update(visible=True), None, gpu_status + "\n\n🔄 Loading pipeline modules..."
271
  logger.info(f"About to import pipeline module...")
272
-
273
  try:
274
  pipe = importlib.import_module("pipeline")
275
  logger.info(f"✅ Pipeline module imported successfully")
276
-
277
  pipeline_status = gpu_status + "\n\n✅ Pipeline modules loaded\n📹 Initializing video processing pipeline..."
278
  yield gr.update(visible=False), gr.update(visible=True), None, pipeline_status
279
  except Exception as e:
280
  logger.error(f"❌ Pipeline import failed: {e}")
281
  yield gr.update(visible=True), gr.update(visible=False), None, f"Pipeline import error: {e}"
282
  return
283
-
284
  logger.info(f"Calling pipe.process with video_path={video_path}, bg_path={bg_path}")
285
  logger.info(f"=== CALLING PIPELINE.PROCESS ===")
286
-
287
  # Enhanced status during processing with detailed stage tracking
288
  stage_status = {
289
  "current_stage": "Starting...",
290
  "sam2_status": "⏳ Pending",
291
- "matany_status": "⏳ Pending",
292
  "composite_status": "⏳ Pending",
293
  "audio_status": "⏳ Pending",
294
  "frame_progress": ""
295
  }
296
-
297
  def format_status():
298
  return (gpu_status + f"\n\n🚀 PROCESSING: {stage_status['current_stage']}\n\n" +
299
  f"📊 PIPELINE STAGES:\n" +
@@ -302,59 +325,52 @@ def format_status():
302
  f"🎬 Video Compositing: {stage_status['composite_status']}\n" +
303
  f"🔊 Audio Muxing: {stage_status['audio_status']}\n" +
304
  (f"\n📈 {stage_status['frame_progress']}" if stage_status['frame_progress'] else ""))
305
-
306
  processing_status = format_status()
307
  yield gr.update(visible=False), gr.update(visible=True), None, processing_status
308
-
309
  # Create progress callback to update UI status with detailed tracking
310
- def progress_callback(message):
311
  nonlocal stage_status
312
-
313
  # Update current stage and frame progress
314
- stage_status['current_stage'] = message
315
-
316
  # Track specific stages
317
- if "SAM2" in message or "segmentation" in message.lower():
318
- if "complete" in message.lower() or "✅" in message:
319
  stage_status['sam2_status'] = "✅ Complete"
320
  else:
321
  stage_status['sam2_status'] = "🔄 Running..."
322
-
323
- elif "MatAnyone" in message or "matting" in message.lower():
324
- if "complete" in message.lower() or "✅" in message:
325
  stage_status['matany_status'] = "✅ Complete"
326
- elif "failed" in message.lower() or "fallback" in message.lower():
327
  stage_status['matany_status'] = "❌ Failed → Fallback"
328
  else:
329
  stage_status['matany_status'] = "🔄 Running..."
330
-
331
- elif "composit" in message.lower():
332
- if "complete" in message.lower() or "✅" in message:
333
  stage_status['composite_status'] = "✅ Complete"
334
  else:
335
  stage_status['composite_status'] = "🔄 Running..."
336
-
337
- elif "audio" in message.lower() or "mux" in message.lower():
338
- if "complete" in message.lower() or "✅" in message:
339
  stage_status['audio_status'] = "✅ Complete"
340
  else:
341
  stage_status['audio_status'] = "🔄 Running..."
342
-
343
  # Extract frame progress
344
- if "/" in message and any(word in message.lower() for word in ["frame", "matting", "chunking"]):
345
- stage_status['frame_progress'] = message
346
-
347
  updated_status = format_status()
348
  return gr.update(visible=False), gr.update(visible=True), None, updated_status
349
-
350
  try:
351
- # FIXED: Remove problematic auto_box setting and use smart person detection
352
  out_path, diag = pipe.process(
353
  video_path=video_path,
354
  bg_image_path=bg_path,
355
- point_x=None, # Let SAM2 use smart person detection
356
- point_y=None, # Let SAM2 use smart person detection
357
- auto_box=False, # FIXED: Disable auto_box to use our smart detection
358
  work_dir=None,
359
  progress_callback=progress_callback
360
  )
@@ -364,11 +380,10 @@ def progress_callback(message):
364
  logger.error(f"❌ Pipeline.process failed: {e}")
365
  import traceback
366
  logger.error(f"Full traceback: {traceback.format_exc()}")
367
-
368
  error_status = gpu_status + f"\n\n❌ PROCESSING FAILED\n🚨 Error: {str(e)[:200]}..."
369
  yield gr.update(visible=True), gr.update(visible=False), None, error_status
370
  return
371
-
372
  if out_path:
373
  # Enhanced final processing stats with detailed breakdown
374
  fps = diag.get('fps', 'unknown')
@@ -376,31 +391,25 @@ def progress_callback(message):
376
  sam2_ok = diag.get('sam2_ok', False)
377
  matany_ok = diag.get('matany_ok', False)
378
  processing_time = diag.get('total_time_sec', 0)
379
- sam2_time = diag.get('sam2_time_sec', 0)
380
- matany_time = diag.get('matany_time_sec', 0)
381
-
382
  # Check mask validation results for quality feedback
383
  mask_validation = diag.get('mask_validation', {})
384
  mask_valid = mask_validation.get('valid', False)
385
  mask_coverage = mask_validation.get('stats', {}).get('coverage_percent', 0)
386
-
387
  # Get final GPU memory usage and verify GPU acceleration was used
388
  try:
389
  import torch
390
  if torch.cuda.is_available():
391
  final_memory = torch.cuda.memory_allocated() / (1024**3)
392
  peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
393
-
394
- # Log GPU utilization to verify models used GPU
395
  logger.info(f"GPU USAGE VERIFICATION:")
396
  logger.info(f" Final memory allocated: {final_memory:.2f}GB")
397
  logger.info(f" Peak memory used: {peak_memory:.2f}GB")
398
-
399
  if peak_memory < 0.1: # Less than 100MB indicates CPU usage
400
- logger.warning(f"⚠️ LOW GPU USAGE! Peak memory {peak_memory:.2f}GB suggests CPU fallback")
401
  else:
402
  logger.info(f"✅ GPU ACCELERATION CONFIRMED - Peak usage {peak_memory:.2f}GB")
403
-
404
  torch.cuda.reset_peak_memory_stats() # Reset for next run
405
  else:
406
  final_memory = peak_memory = 0
@@ -408,7 +417,7 @@ def progress_callback(message):
408
  except Exception as e:
409
  logger.error(f"GPU memory check failed: {e}")
410
  final_memory = peak_memory = 0
411
-
412
  # Enhanced success message with segmentation quality info
413
  segmentation_quality = ""
414
  if mask_valid and mask_coverage > 0:
@@ -418,13 +427,13 @@ def progress_callback(message):
418
  segmentation_quality = f"⚠️ High segmentation ({mask_coverage:.1f}% - check background)"
419
  else:
420
  segmentation_quality = f"✅ Person segmented ({mask_coverage:.1f}%)"
421
-
422
  status_msg = gpu_status + f"""
423
 
424
  🎉 PROCESSING COMPLETE!
425
- ✅ Stage 1: SAM2 segmentation {'✓' if sam2_ok else '✗'} ({sam2_time:.1f}s)
426
  {segmentation_quality}
427
- ✅ Stage 2: MatAnyone matting {'✓' if matany_ok else '✗'} ({matany_time:.1f}s)
428
  ✅ Stage 3: Final compositing complete
429
 
430
  📊 RESULTS:
@@ -457,10 +466,12 @@ def progress_callback(message):
457
  logger.error(f"Full traceback: {traceback.format_exc()}")
458
  yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
459
  finally:
460
- # Best-effort cleanup of any temp download
461
  try:
462
  if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
463
  os.unlink(video_path)
 
 
464
  except Exception:
465
  pass
466
 
@@ -483,7 +494,7 @@ def create_interface():
483
  """
484
 
485
  with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
486
- gr.Markdown("# BackgroundFX Pro — SAM2 + MatAnyone (Fixed)")
487
 
488
  with gr.Row():
489
  status = _system_status()
@@ -525,72 +536,5 @@ def create_interface():
525
  result_video = gr.Video(label="Processed Video", height=400)
526
  status_output = gr.Textbox(label="Processing Status", lines=8, max_lines=15, elem_classes=["status-box"])
527
  gr.Markdown("""
528
- ### Pipeline (Fixed)
529
- 1. SAM2 Smart Person Detection → proper mask (15-35% coverage)
530
- 2. MatAnyone Matting → FG + ALPHA
531
- 3. Stage-A export (transparent WebM or checkerboard)
532
- 4. Final compositing (H.264)
533
- """)
534
-
535
- # handlers
536
- def update_background_options(bg_type):
537
- return {
538
- gradient_type: gr.update(visible=(bg_type == "gradient")),
539
- gradient_preview: gr.update(visible=(bg_type == "gradient")),
540
- solid_color: gr.update(visible=(bg_type == "solid")),
541
- color_preview: gr.update(visible=(bg_type == "solid")),
542
- custom_bg_upload: gr.update(visible=(bg_type == "custom")),
543
- ai_prompt: gr.update(visible=(bg_type == "ai")),
544
- ai_generate_btn: gr.update(visible=(bg_type == "ai")),
545
- ai_preview: gr.update(visible=(bg_type == "ai")),
546
- }
547
-
548
- def update_gradient_preview(grad_type):
549
- try:
550
- return create_gradient_background(grad_type, 400, 200)
551
- except Exception:
552
- return None
553
-
554
- def update_color_preview(color):
555
- try:
556
- return create_solid_color(color, 400, 200)
557
- except Exception:
558
- return None
559
-
560
- def refresh_myavatar_videos():
561
- try:
562
- return gr.update(choices=myavatar_api.get_video_choices(), value=None)
563
- except Exception:
564
- return gr.update(choices=["Error loading videos"], value=None)
565
-
566
- def load_video_preview(selection):
567
- try:
568
- return myavatar_api.get_video_url(selection)
569
- except Exception:
570
- return None
571
-
572
- def generate_ai_bg(prompt):
573
- bg_img, _ = generate_ai_background(prompt)
574
- return bg_img
575
-
576
- background_type.change(
577
- fn=update_background_options,
578
- inputs=[background_type],
579
- outputs=[gradient_type, gradient_preview, solid_color, color_preview, custom_bg_upload, ai_prompt, ai_generate_btn, ai_preview]
580
- )
581
- gradient_type.change(fn=update_gradient_preview, inputs=[gradient_type], outputs=[gradient_preview])
582
- solid_color.change(fn=update_color_preview, inputs=[solid_color], outputs=[color_preview])
583
- refresh_btn.click(fn=refresh_myavatar_videos, outputs=[myavatar_dropdown])
584
- myavatar_dropdown.change(fn=load_video_preview, inputs=[myavatar_dropdown], outputs=[video_preview])
585
- ai_generate_btn.click(fn=generate_ai_bg, inputs=[ai_prompt], outputs=[ai_preview])
586
-
587
- process_btn.click(
588
- fn=process_video_with_background_stoppable,
589
- inputs=[video_upload, myavatar_dropdown, background_type, gradient_type, solid_color, custom_bg_upload, ai_prompt],
590
- outputs=[process_btn, stop_btn, result_video, status_output]
591
- )
592
- stop_btn.click(fn=stop_processing_button, outputs=[stop_btn, status_output])
593
-
594
- app.load(fn=lambda: create_gradient_background("sunset", 400, 200), outputs=[gradient_preview])
595
-
596
- return app
 
3
  BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
4
  - No top-level import of pipeline (lazy import in handlers)
5
  - Compatible with pipeline.process()
6
+ - Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
7
+
8
+ Changes (2025-09-16):
9
+ - Aligned with updated pipeline.py and models/
10
+ - Updated progress callback to pass percentages to pipeline.process
11
+ - Added input path validation logging
12
+ - Simplified SAM2 arguments to use pipeline defaults
13
+ - Added MatAnyone version logging in GPU diagnostics
14
+ - Enhanced temporary file cleanup
15
  """
16
 
17
  import io
 
24
  from typing import Optional, Tuple, List, Dict, Any
25
  from PIL import Image
26
  import gradio as gr
27
+ import importlib.metadata
28
 
29
  logger = logging.getLogger("ui")
30
  if not logger.handlers:
 
155
  myavatar_api = MyAvatarAPI()
156
 
157
  # ---- Minimal stop flag (request-scoped) ----
 
158
  class Stopper:
159
  def __init__(self):
160
  self.stop = False
 
183
  video_path = None
184
  if input_video:
185
  video_path = input_video
186
+ logger.info(f"[UI] Using uploaded video: {video_path}")
187
  elif myavatar_selection and myavatar_selection != "No videos available":
188
  url = myavatar_api.get_video_url(myavatar_selection)
189
  if url:
190
+ logger.info(f"[UI] Fetching MyAvatar video: {url}")
191
  with requests.get(url, stream=True, timeout=60) as r:
192
  r.raise_for_status()
193
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
 
198
  if chunk:
199
  tmp.write(chunk)
200
  video_path = tmp.name
201
+ logger.info(f"[UI] Downloaded MyAvatar video to: {video_path}")
202
 
203
  if STOP.stop:
204
  yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
205
  return
206
 
207
+ if not video_path or not os.path.exists(video_path):
208
+ logger.error(f"[UI] No valid video provided: input_video={input_video}, myavatar_selection={myavatar_selection}")
209
  yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
210
  return
211
 
 
214
  bg_img = None
215
  if background_type == "gradient":
216
  bg_img = create_gradient_background(gradient_type, 1920, 1080)
217
+ logger.info(f"[UI] Generated gradient background: {gradient_type}")
218
  elif background_type == "solid":
219
  bg_img = create_solid_color(solid_color, 1920, 1080)
220
+ logger.info(f"[UI] Generated solid color background: {solid_color}")
221
  elif background_type == "custom" and custom_background:
222
  try:
223
  bg_img = Image.open(custom_background).convert("RGB")
224
+ logger.info(f"[UI] Loaded custom background: {custom_background}")
225
+ except Exception as e:
226
+ logger.error(f"[UI] Failed to load custom background: {e}")
227
  bg_img = None
228
  elif background_type == "ai" and ai_prompt:
229
+ bg_img, msg = generate_ai_background(ai_prompt)
230
+ logger.info(f"[UI] AI background generation: {msg}")
231
 
232
  if STOP.stop:
233
  yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
234
  return
235
 
236
  if bg_img is None:
237
+ logger.error(f"[UI] No background generated: type={background_type}, custom={custom_background}, ai_prompt={ai_prompt}")
238
  yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
239
  return
240
 
 
242
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
243
  bg_img.save(tmp_bg.name, format="PNG")
244
  bg_path = tmp_bg.name
245
+ logger.info(f"[UI] Saved background to: {bg_path}")
246
 
247
  # Run pipeline with enhanced real-time status updates
248
  yield gr.update(visible=False), gr.update(visible=True), None, "🔄 Initializing pipeline...\n⚡ Checking GPU acceleration..."
249
  logger.info(f"=== PIPELINE START ===")
250
+
251
  # Enhanced GPU diagnostics with detailed status
252
  try:
253
  import torch
254
  logger.info(f"✅ Torch version: {torch.__version__}")
255
  logger.info(f"✅ CUDA available: {torch.cuda.is_available()}")
256
+ try:
257
+ version = importlib.metadata.version("matanyone")
258
+ logger.info(f"[MATANY] MatAnyone version: {version}")
259
+ except Exception:
260
+ logger.info("[MATANY] MatAnyone version unknown")
261
+
262
  if torch.cuda.is_available():
263
  device_count = torch.cuda.device_count()
264
  current_device = torch.cuda.current_device()
265
  device_name = torch.cuda.get_device_name()
266
  device_capability = torch.cuda.get_device_capability()
267
+
268
  # Get GPU memory info
269
  memory_allocated = torch.cuda.memory_allocated() / (1024**3) # GB
270
  memory_reserved = torch.cuda.memory_reserved() / (1024**3) # GB
271
  memory_total = torch.cuda.get_device_properties(current_device).total_memory / (1024**3) # GB
272
+
273
  gpu_status = f"""✅ GPU Acceleration Active
274
  🖥️ Device: {device_name} (Compute {device_capability[0]}.{device_capability[1]})
275
  💾 Memory: {memory_allocated:.1f}GB allocated / {memory_total:.1f}GB total
276
  🔧 CUDA {torch.version.cuda} | PyTorch {torch.__version__}
277
  📊 Ready for SAM2 + MatAnyone processing..."""
278
+
279
  logger.info(f"✅ CUDA device count: {device_count}")
280
  logger.info(f"✅ Current device: {current_device}")
281
  logger.info(f"✅ Device name: {device_name}")
282
  logger.info(f"✅ GPU memory: {memory_allocated:.1f}GB/{memory_total:.1f}GB")
283
+
284
  yield gr.update(visible=False), gr.update(visible=True), None, gpu_status
285
  else:
286
  logger.error(f"❌ CUDA NOT AVAILABLE - GPU processing will fail")
 
290
  logger.error(f"❌ Torch/CUDA check failed: {e}")
291
  yield gr.update(visible=True), gr.update(visible=False), None, f"GPU check error: {e}"
292
  return
293
+
294
  yield gr.update(visible=False), gr.update(visible=True), None, gpu_status + "\n\n🔄 Loading pipeline modules..."
295
  logger.info(f"About to import pipeline module...")
296
+
297
  try:
298
  pipe = importlib.import_module("pipeline")
299
  logger.info(f"✅ Pipeline module imported successfully")
 
300
  pipeline_status = gpu_status + "\n\n✅ Pipeline modules loaded\n📹 Initializing video processing pipeline..."
301
  yield gr.update(visible=False), gr.update(visible=True), None, pipeline_status
302
  except Exception as e:
303
  logger.error(f"❌ Pipeline import failed: {e}")
304
  yield gr.update(visible=True), gr.update(visible=False), None, f"Pipeline import error: {e}"
305
  return
306
+
307
  logger.info(f"Calling pipe.process with video_path={video_path}, bg_path={bg_path}")
308
  logger.info(f"=== CALLING PIPELINE.PROCESS ===")
309
+
310
  # Enhanced status during processing with detailed stage tracking
311
  stage_status = {
312
  "current_stage": "Starting...",
313
  "sam2_status": "⏳ Pending",
314
+ "matany_status": "⏳ Pending",
315
  "composite_status": "⏳ Pending",
316
  "audio_status": "⏳ Pending",
317
  "frame_progress": ""
318
  }
319
+
320
  def format_status():
321
  return (gpu_status + f"\n\n🚀 PROCESSING: {stage_status['current_stage']}\n\n" +
322
  f"📊 PIPELINE STAGES:\n" +
 
325
  f"🎬 Video Compositing: {stage_status['composite_status']}\n" +
326
  f"🔊 Audio Muxing: {stage_status['audio_status']}\n" +
327
  (f"\n📈 {stage_status['frame_progress']}" if stage_status['frame_progress'] else ""))
328
+
329
  processing_status = format_status()
330
  yield gr.update(visible=False), gr.update(visible=True), None, processing_status
331
+
332
  # Create progress callback to update UI status with detailed tracking
333
+ def progress_callback(pct: float, msg: str):
334
  nonlocal stage_status
335
+
336
  # Update current stage and frame progress
337
+ stage_status['current_stage'] = msg
338
+
339
  # Track specific stages
340
+ if "SAM2" in msg or "segmentation" in msg.lower():
341
+ if "complete" in msg.lower() or "✅" in msg:
342
  stage_status['sam2_status'] = "✅ Complete"
343
  else:
344
  stage_status['sam2_status'] = "🔄 Running..."
345
+ elif "MatAnyone" in msg or "matting" in msg.lower():
346
+ if "complete" in msg.lower() or "" in msg:
 
347
  stage_status['matany_status'] = "✅ Complete"
348
+ elif "failed" in msg.lower() or "fallback" in msg.lower():
349
  stage_status['matany_status'] = "❌ Failed → Fallback"
350
  else:
351
  stage_status['matany_status'] = "🔄 Running..."
352
+ elif "composit" in msg.lower():
353
+ if "complete" in msg.lower() or "✅" in msg:
 
354
  stage_status['composite_status'] = "✅ Complete"
355
  else:
356
  stage_status['composite_status'] = "🔄 Running..."
357
+ elif "audio" in msg.lower() or "mux" in msg.lower():
358
+ if "complete" in msg.lower() or "" in msg:
 
359
  stage_status['audio_status'] = "✅ Complete"
360
  else:
361
  stage_status['audio_status'] = "🔄 Running..."
362
+
363
  # Extract frame progress
364
+ if "/" in msg and any(word in msg.lower() for word in ["frame", "matting", "chunking"]):
365
+ stage_status['frame_progress'] = msg
366
+
367
  updated_status = format_status()
368
  return gr.update(visible=False), gr.update(visible=True), None, updated_status
369
+
370
  try:
 
371
  out_path, diag = pipe.process(
372
  video_path=video_path,
373
  bg_image_path=bg_path,
 
 
 
374
  work_dir=None,
375
  progress_callback=progress_callback
376
  )
 
380
  logger.error(f"❌ Pipeline.process failed: {e}")
381
  import traceback
382
  logger.error(f"Full traceback: {traceback.format_exc()}")
 
383
  error_status = gpu_status + f"\n\n❌ PROCESSING FAILED\n🚨 Error: {str(e)[:200]}..."
384
  yield gr.update(visible=True), gr.update(visible=False), None, error_status
385
  return
386
+
387
  if out_path:
388
  # Enhanced final processing stats with detailed breakdown
389
  fps = diag.get('fps', 'unknown')
 
391
  sam2_ok = diag.get('sam2_ok', False)
392
  matany_ok = diag.get('matany_ok', False)
393
  processing_time = diag.get('total_time_sec', 0)
394
+
 
 
395
  # Check mask validation results for quality feedback
396
  mask_validation = diag.get('mask_validation', {})
397
  mask_valid = mask_validation.get('valid', False)
398
  mask_coverage = mask_validation.get('stats', {}).get('coverage_percent', 0)
399
+
400
  # Get final GPU memory usage and verify GPU acceleration was used
401
  try:
402
  import torch
403
  if torch.cuda.is_available():
404
  final_memory = torch.cuda.memory_allocated() / (1024**3)
405
  peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
 
 
406
  logger.info(f"GPU USAGE VERIFICATION:")
407
  logger.info(f" Final memory allocated: {final_memory:.2f}GB")
408
  logger.info(f" Peak memory used: {peak_memory:.2f}GB")
 
409
  if peak_memory < 0.1: # Less than 100MB indicates CPU usage
410
+ logger.warning(f"⚠️ LOW GPU USAGE! Peak memory {peak_memory:.2f}GB suggests CPU fallback")
411
  else:
412
  logger.info(f"✅ GPU ACCELERATION CONFIRMED - Peak usage {peak_memory:.2f}GB")
 
413
  torch.cuda.reset_peak_memory_stats() # Reset for next run
414
  else:
415
  final_memory = peak_memory = 0
 
417
  except Exception as e:
418
  logger.error(f"GPU memory check failed: {e}")
419
  final_memory = peak_memory = 0
420
+
421
  # Enhanced success message with segmentation quality info
422
  segmentation_quality = ""
423
  if mask_valid and mask_coverage > 0:
 
427
  segmentation_quality = f"⚠️ High segmentation ({mask_coverage:.1f}% - check background)"
428
  else:
429
  segmentation_quality = f"✅ Person segmented ({mask_coverage:.1f}%)"
430
+
431
  status_msg = gpu_status + f"""
432
 
433
  🎉 PROCESSING COMPLETE!
434
+ ✅ Stage 1: SAM2 segmentation {'✓' if sam2_ok else '✗'}
435
  {segmentation_quality}
436
+ ✅ Stage 2: MatAnyone matting {'✓' if matany_ok else '✗'}
437
  ✅ Stage 3: Final compositing complete
438
 
439
  📊 RESULTS:
 
466
  logger.error(f"Full traceback: {traceback.format_exc()}")
467
  yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
468
  finally:
469
+ # Best-effort cleanup of any temp files
470
  try:
471
  if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
472
  os.unlink(video_path)
473
+ if 'bg_path' in locals() and bg_path and os.path.exists(bg_path):
474
+ os.unlink(bg_path)
475
  except Exception:
476
  pass
477
 
 
494
  """
495
 
496
  with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
497
+ gr.Markdown("# BackgroundFX Pro — SAM2 + MatAnyone")
498
 
499
  with gr.Row():
500
  status = _system_status()
 
536
  result_video = gr.Video(label="Processed Video", height=400)
537
  status_output = gr.Textbox(label="Processing Status", lines=8, max_lines=15, elem_classes=["status-box"])
538
  gr.Markdown("""
539
+ ### Pipeline
540
+ 1. SAM2 Smart Person Detection → proper mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/mask_validation.py CHANGED
@@ -1,5 +1,16 @@
 
1
  """
2
  Mask validation utilities for BackgroundFX Pro.
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  import numpy as np
@@ -7,6 +18,16 @@
7
  from pathlib import Path
8
  from typing import Dict, Union, Tuple, Optional
9
 
 
 
 
 
 
 
 
 
 
 
10
  def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path], target_hw=None, *_, **__) -> Tuple[bool, Dict, str]:
11
  """
12
  Back/forward-compatible mask sanity check for MatAnyone/SAM2 pipelines.
@@ -22,6 +43,15 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
22
  - We only *validate*; inversion/repair is upstream's job.
23
  - Always returns True unless the mask is unreadable or wrong rank.
24
  """
 
 
 
 
 
 
 
 
 
25
  # ---- load to np.uint8/float32 2D ----
26
  try:
27
  import torch # optional
@@ -32,14 +62,17 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
32
  if isinstance(mask_input, (str, Path)):
33
  mask = cv2.imread(str(mask_input), cv2.IMREAD_GRAYSCALE)
34
  if mask is None:
 
35
  return False, {}, f"Mask not found or unreadable: {mask_input}"
36
  # Load from torch tensor
37
  elif (torch is not None) and isinstance(mask_input, torch.Tensor):
38
  mask = mask_input.detach().cpu().numpy()
 
39
  # Already numpy
40
  elif isinstance(mask_input, np.ndarray):
41
  mask = mask_input
42
  else:
 
43
  return False, {}, f"Unsupported mask type: {type(mask_input)}"
44
 
45
  # If 3D, squeeze/convert to grayscale
@@ -49,12 +82,16 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
49
  else:
50
  mask = np.squeeze(mask)
51
  if mask.ndim != 2:
 
52
  return False, {}, f"Mask must be 2D, got shape {mask.shape}"
53
 
 
 
54
  # Optional resize to target (H, W)
55
  if target_hw is not None and isinstance(target_hw, (tuple, list)) and len(target_hw) == 2:
56
  H, W = int(target_hw[0]), int(target_hw[1])
57
  if mask.shape != (H, W):
 
58
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
59
 
60
  # Normalize to [0,1] float
@@ -69,6 +106,8 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
69
 
70
  # We keep validation permissive; upstream may invert/repair based on coverage
71
  msg = f"Basic validation - {coverage:.1f}% coverage"
 
 
72
  return True, stats, msg
73
 
74
  def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float = 0.01) -> Tuple[bool, str]:
@@ -82,6 +121,8 @@ def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float
82
  Returns:
83
  Tuple of (is_valid, error_message)
84
  """
 
 
85
  # Basic validation first
86
  is_valid, msg = validate_mask_for_matanyone_simple(mask)
87
  if not is_valid:
@@ -94,8 +135,10 @@ def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float
94
  # Check foreground ratio
95
  fg_ratio = mask.mean()
96
  if fg_ratio < min_foreground:
 
97
  return False, f"Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})"
98
 
 
99
  return True, ""
100
 
101
  def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
@@ -109,19 +152,33 @@ def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = N
109
  Returns:
110
  Preprocessed mask (H,W) float32 in [0,1]
111
  """
 
 
112
  # Ensure 2D
113
  if mask.ndim == 3:
114
- mask = mask.squeeze(2)
115
-
 
 
 
 
 
116
  # Convert to float32 in [0,1]
117
  if mask.dtype != np.float32:
118
  mask = mask.astype(np.float32)
 
119
  if mask.max() > 1.0:
120
  mask = mask / 255.0
121
-
 
 
 
122
  # Resize if needed
123
  if target_size is not None and (mask.shape[0] != target_size[0] or mask.shape[1] != target_size[1]):
124
  import cv2
125
- mask = cv2.resize(mask, (target_size[1], target_size[0]), interpolation=cv2.INTER_LINEAR)
126
-
127
- return mask
 
 
 
 
1
+ #!/usr/bin/env python3
2
  """
3
  Mask validation utilities for BackgroundFX Pro.
4
+ ==============================================
5
+ - Validates masks for MatAnyone compatibility
6
+ - Ensures 2D masks [H,W] to avoid 5D tensor issues
7
+ - Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0
8
+
9
+ Changes (2025-09-16):
10
+ - Aligned with updated pipeline.py and models/
11
+ - Added logging for mask shape and coverage
12
+ - Enhanced preprocess_mask for Torch tensors from SAM2
13
+ - Ensured 2D mask output for MatAnyone
14
  """
15
 
16
  import numpy as np
 
18
  from pathlib import Path
19
  from typing import Dict, Union, Tuple, Optional
20
 
21
+ import logging
22
+ import importlib.metadata
23
+
24
+ logger = logging.getLogger("backgroundfx_pro")
25
+ if not logger.handlers:
26
+ h = logging.StreamHandler()
27
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
28
+ logger.addHandler(h)
29
+ logger.setLevel(logging.INFO)
30
+
31
  def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path], target_hw=None, *_, **__) -> Tuple[bool, Dict, str]:
32
  """
33
  Back/forward-compatible mask sanity check for MatAnyone/SAM2 pipelines.
 
43
  - We only *validate*; inversion/repair is upstream's job.
44
  - Always returns True unless the mask is unreadable or wrong rank.
45
  """
46
+ logger.info(f"[MaskValidation] Validating mask: {mask_input}")
47
+
48
+ # Log MatAnyone version for compatibility check
49
+ try:
50
+ version = importlib.metadata.version("matanyone")
51
+ logger.info(f"[MaskValidation] MatAnyone version: {version}")
52
+ except Exception:
53
+ logger.info("[MaskValidation] MatAnyone version unknown")
54
+
55
  # ---- load to np.uint8/float32 2D ----
56
  try:
57
  import torch # optional
 
62
  if isinstance(mask_input, (str, Path)):
63
  mask = cv2.imread(str(mask_input), cv2.IMREAD_GRAYSCALE)
64
  if mask is None:
65
+ logger.error(f"[MaskValidation] Could not load mask: {mask_input}")
66
  return False, {}, f"Mask not found or unreadable: {mask_input}"
67
  # Load from torch tensor
68
  elif (torch is not None) and isinstance(mask_input, torch.Tensor):
69
  mask = mask_input.detach().cpu().numpy()
70
+ logger.info(f"[MaskValidation] Loaded Torch tensor mask: shape={mask.shape}, dtype={mask.dtype}")
71
  # Already numpy
72
  elif isinstance(mask_input, np.ndarray):
73
  mask = mask_input
74
  else:
75
+ logger.error(f"[MaskValidation] Unsupported mask type: {type(mask_input)}")
76
  return False, {}, f"Unsupported mask type: {type(mask_input)}"
77
 
78
  # If 3D, squeeze/convert to grayscale
 
82
  else:
83
  mask = np.squeeze(mask)
84
  if mask.ndim != 2:
85
+ logger.error(f"[MaskValidation] Mask must be 2D, got shape {mask.shape}")
86
  return False, {}, f"Mask must be 2D, got shape {mask.shape}"
87
 
88
+ logger.info(f"[MaskValidation] Loaded mask shape: {mask.shape}, dtype: {mask.dtype}")
89
+
90
  # Optional resize to target (H, W)
91
  if target_hw is not None and isinstance(target_hw, (tuple, list)) and len(target_hw) == 2:
92
  H, W = int(target_hw[0]), int(target_hw[1])
93
  if mask.shape != (H, W):
94
+ logger.info(f"[MaskValidation] Resizing mask from {mask.shape} to {target_hw}")
95
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
96
 
97
  # Normalize to [0,1] float
 
106
 
107
  # We keep validation permissive; upstream may invert/repair based on coverage
108
  msg = f"Basic validation - {coverage:.1f}% coverage"
109
+ logger.info(f"[MaskValidation] Validation result: {msg}, valid: True, coverage: {coverage:.1f}%")
110
+
111
  return True, stats, msg
112
 
113
  def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float = 0.01) -> Tuple[bool, str]:
 
121
  Returns:
122
  Tuple of (is_valid, error_message)
123
  """
124
+ logger.info(f"[MaskValidation] Advanced validation on mask shape: {mask.shape}")
125
+
126
  # Basic validation first
127
  is_valid, msg = validate_mask_for_matanyone_simple(mask)
128
  if not is_valid:
 
135
  # Check foreground ratio
136
  fg_ratio = mask.mean()
137
  if fg_ratio < min_foreground:
138
+ logger.warning(f"[MaskValidation] Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})")
139
  return False, f"Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})"
140
 
141
+ logger.info(f"[MaskValidation] Advanced validation passed: fg_ratio={fg_ratio:.1%}")
142
  return True, ""
143
 
144
def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Normalize a mask to a 2-D float32 array in [0, 1] for MatAnyone.

    Args:
        mask: Input mask. Accepts 2-D (H, W); 3-D with singleton axes such as
            (H, W, 1) or (1, H, W) (e.g. channel-first tensors from SAM2,
            already converted to numpy); or 3-D with 3/4 channels, which is
            collapsed to grayscale by averaging the first three channels.
        target_size: Optional (H, W) to resize to (bilinear, via OpenCV).

    Returns:
        Preprocessed mask of shape (H, W), dtype float32, values clipped
        to [0, 1].

    Raises:
        ValueError: If the mask cannot be reduced to exactly 2 dimensions.
    """
    # Same named logger the module configures at import time.
    logger = logging.getLogger("backgroundfx_pro")
    # Lazy %-style args: no formatting cost when INFO is disabled.
    logger.info("[MaskValidation] Preprocessing mask: shape=%s, dtype=%s", mask.shape, mask.dtype)

    # Ensure 2D: drop singleton axes first (handles both (H, W, 1) and the
    # channel-first (1, H, W) layout, which np.squeeze(..., axis=2) would
    # reject), then collapse RGB(A) masks to grayscale by channel mean,
    # matching the 3-channel handling in validate_mask_for_matanyone_simple.
    if mask.ndim == 3:
        mask = np.squeeze(mask)
        if mask.ndim == 3 and mask.shape[2] in (3, 4):
            mask = mask[..., :3].astype(np.float32).mean(axis=2)
        logger.info("[MaskValidation] Reduced 3D mask to 2D: %s", mask.shape)

    if mask.ndim != 2:
        logger.error("[MaskValidation] Preprocessing failed: mask must be 2D, got %s", mask.shape)
        raise ValueError(f"Mask must be 2D, got shape {mask.shape}")

    # Convert to float32 in [0, 1]; a max above 1.0 is taken to mean an
    # 8-bit-style 0-255 value range. Empty masks are passed through without
    # touching max(), which would raise on a zero-size array.
    if mask.dtype != np.float32:
        mask = mask.astype(np.float32)
        logger.info("[MaskValidation] Converted dtype to float32")
    if mask.size and mask.max() > 1.0:
        mask = mask / 255.0
        logger.info("[MaskValidation] Normalized to [0,1] range")

    # Guard against numerical spill outside [0, 1] (e.g. from prior resizes).
    mask = np.clip(mask, 0.0, 1.0)

    # Resize if needed; cv2.resize takes (width, height), hence the swap.
    if target_size is not None and (mask.shape[0] != target_size[0] or mask.shape[1] != target_size[1]):
        import cv2
        H, W = target_size
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
        logger.info("[MaskValidation] Resized mask to %s", target_size)

    logger.info(
        "[MaskValidation] Preprocessed mask: shape=%s, dtype=%s, range=[%.3f, %.3f]",
        mask.shape,
        mask.dtype,
        float(mask.min()) if mask.size else 0.0,
        float(mask.max()) if mask.size else 0.0,
    )
    return mask