luck
Browse files- app.py +25 -9
- models/__init__.py +12 -2
- models/__pycache__/__init__.cpython-313.pyc +0 -0
- models/matany_compat_patch.py +32 -32
- models/matanyone_loader.py +37 -45
- models/sam2_loader.py +263 -179
- pipeline.py +61 -41
- requirements.txt +5 -8
- ui.py +75 -131
- utils/mask_validation.py +63 -6
app.py
CHANGED
|
@@ -1,7 +1,17 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
|
|
|
| 5 |
print("=== APP STARTUP DEBUG: app.py starting ===")
|
| 6 |
import sys
|
| 7 |
print(f"=== APP STARTUP DEBUG: Python {sys.version} ===")
|
|
@@ -12,7 +22,6 @@
|
|
| 12 |
import os
|
| 13 |
os.environ.pop("OMP_NUM_THREADS", None)
|
| 14 |
|
| 15 |
-
import sys
|
| 16 |
import logging
|
| 17 |
import threading
|
| 18 |
import time
|
|
@@ -23,7 +32,7 @@
|
|
| 23 |
# Suppress torchvision video deprecation warnings from MatAnyone
|
| 24 |
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
|
| 25 |
|
| 26 |
-
# --- import path fix for third_party packages (
|
| 27 |
third_party_root = Path(__file__).parent / "third_party"
|
| 28 |
sam2_path = third_party_root / "sam2"
|
| 29 |
|
|
@@ -43,7 +52,7 @@
|
|
| 43 |
|
| 44 |
# DEBUG: try importing MatAnyone and show its location
|
| 45 |
try:
|
| 46 |
-
import matanyone
|
| 47 |
import inspect
|
| 48 |
import os as _os
|
| 49 |
print("[MATANY] import OK from:", _os.path.dirname(inspect.getfile(matanyone)), flush=True)
|
|
@@ -78,9 +87,9 @@ def _heartbeat():
|
|
| 78 |
# Safe, minimal startup diagnostics (no long CUDA probes)
|
| 79 |
# -----------------------------------------------------------------------------
|
| 80 |
def _safe_startup_diag():
|
| 81 |
-
# Torch version
|
| 82 |
try:
|
| 83 |
-
import torch
|
| 84 |
import importlib
|
| 85 |
t = importlib.import_module("torch")
|
| 86 |
logger.info(
|
|
@@ -91,6 +100,14 @@ def _safe_startup_diag():
|
|
| 91 |
except Exception as e:
|
| 92 |
logger.warning("Torch not available at startup: %s", e)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
# nvidia-smi with short timeout (avoid indefinite block)
|
| 95 |
try:
|
| 96 |
out = subprocess.run(
|
|
@@ -107,7 +124,7 @@ def _safe_startup_diag():
|
|
| 107 |
|
| 108 |
# Optional perf tuning; never block startup
|
| 109 |
try:
|
| 110 |
-
import perf_tuning
|
| 111 |
logger.info("perf_tuning imported successfully.")
|
| 112 |
except Exception as e:
|
| 113 |
logger.info("perf_tuning not available: %s", e)
|
|
@@ -126,7 +143,6 @@ def _safe_startup_diag():
|
|
| 126 |
logger.info(f"[MATANY] probe skipped: {e}")
|
| 127 |
|
| 128 |
# Continue with app startup
|
| 129 |
-
|
| 130 |
_safe_startup_diag()
|
| 131 |
|
| 132 |
# -----------------------------------------------------------------------------
|
|
@@ -166,7 +182,7 @@ def build_ui() -> gr.Blocks:
|
|
| 166 |
logger.info("Launching Gradio on %s:%s …", host, port)
|
| 167 |
|
| 168 |
demo = build_ui()
|
| 169 |
-
demo.queue(max_size=16)
|
| 170 |
|
| 171 |
threading.Thread(target=_post_launch_diag, daemon=True).start()
|
| 172 |
-
demo.launch(server_name=host, server_port=port, show_error=True)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
|
| 4 |
+
================================================
|
| 5 |
+
- Sets up Gradio UI and launches pipeline
|
| 6 |
+
- Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
|
| 7 |
+
|
| 8 |
+
Changes (2025-09-16):
|
| 9 |
+
- Aligned with updated pipeline.py and models/
|
| 10 |
+
- Added MatAnyone version logging in startup diagnostics
|
| 11 |
+
- Updated Gradio launch for compatibility with gradio==5.42.0
|
| 12 |
+
- Ensured sys.path and environment variables match Dockerfile
|
| 13 |
"""
|
| 14 |
+
|
| 15 |
print("=== APP STARTUP DEBUG: app.py starting ===")
|
| 16 |
import sys
|
| 17 |
print(f"=== APP STARTUP DEBUG: Python {sys.version} ===")
|
|
|
|
| 22 |
import os
|
| 23 |
os.environ.pop("OMP_NUM_THREADS", None)
|
| 24 |
|
|
|
|
| 25 |
import logging
|
| 26 |
import threading
|
| 27 |
import time
|
|
|
|
| 32 |
# Suppress torchvision video deprecation warnings from MatAnyone
|
| 33 |
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
|
| 34 |
|
| 35 |
+
# --- import path fix for third_party packages (SAM2) ---
|
| 36 |
third_party_root = Path(__file__).parent / "third_party"
|
| 37 |
sam2_path = third_party_root / "sam2"
|
| 38 |
|
|
|
|
| 52 |
|
| 53 |
# DEBUG: try importing MatAnyone and show its location
|
| 54 |
try:
|
| 55 |
+
import matanyone
|
| 56 |
import inspect
|
| 57 |
import os as _os
|
| 58 |
print("[MATANY] import OK from:", _os.path.dirname(inspect.getfile(matanyone)), flush=True)
|
|
|
|
| 87 |
# Safe, minimal startup diagnostics (no long CUDA probes)
|
| 88 |
# -----------------------------------------------------------------------------
|
| 89 |
def _safe_startup_diag():
|
| 90 |
+
# Torch version
|
| 91 |
try:
|
| 92 |
+
import torch
|
| 93 |
import importlib
|
| 94 |
t = importlib.import_module("torch")
|
| 95 |
logger.info(
|
|
|
|
| 100 |
except Exception as e:
|
| 101 |
logger.warning("Torch not available at startup: %s", e)
|
| 102 |
|
| 103 |
+
# MatAnyone version
|
| 104 |
+
try:
|
| 105 |
+
import importlib.metadata
|
| 106 |
+
version = importlib.metadata.version("matanyone")
|
| 107 |
+
logger.info(f"[MATANY] MatAnyone version: {version}")
|
| 108 |
+
except Exception:
|
| 109 |
+
logger.info("[MATANY] MatAnyone version unknown")
|
| 110 |
+
|
| 111 |
# nvidia-smi with short timeout (avoid indefinite block)
|
| 112 |
try:
|
| 113 |
out = subprocess.run(
|
|
|
|
| 124 |
|
| 125 |
# Optional perf tuning; never block startup
|
| 126 |
try:
|
| 127 |
+
import perf_tuning
|
| 128 |
logger.info("perf_tuning imported successfully.")
|
| 129 |
except Exception as e:
|
| 130 |
logger.info("perf_tuning not available: %s", e)
|
|
|
|
| 143 |
logger.info(f"[MATANY] probe skipped: {e}")
|
| 144 |
|
| 145 |
# Continue with app startup
|
|
|
|
| 146 |
_safe_startup_diag()
|
| 147 |
|
| 148 |
# -----------------------------------------------------------------------------
|
|
|
|
| 182 |
logger.info("Launching Gradio on %s:%s …", host, port)
|
| 183 |
|
| 184 |
demo = build_ui()
|
| 185 |
+
demo.queue(max_size=16, api_open=False) # Disable public API for security
|
| 186 |
|
| 187 |
threading.Thread(target=_post_launch_diag, daemon=True).start()
|
| 188 |
+
demo.launch(server_name=host, server_port=port, show_error=True)
|
models/__init__.py
CHANGED
|
@@ -8,8 +8,10 @@
|
|
| 8 |
- MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
|
| 9 |
|
| 10 |
Changes (2025-09-16):
|
|
|
|
| 11 |
- Updated load_matany to apply T=1 squeeze patch before InferenceCore import
|
| 12 |
-
- Added patch status logging
|
|
|
|
| 13 |
- Fixed InferenceCore import path to matanyone.inference.inference_core
|
| 14 |
"""
|
| 15 |
|
|
@@ -21,6 +23,7 @@
|
|
| 21 |
import subprocess
|
| 22 |
import inspect
|
| 23 |
import logging
|
|
|
|
| 24 |
from pathlib import Path
|
| 25 |
from typing import Optional, Tuple, Dict, Any, Union, Callable
|
| 26 |
|
|
@@ -261,7 +264,7 @@ def _composite_frame_pro(
|
|
| 261 |
) -> np.ndarray:
|
| 262 |
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
|
| 263 |
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
|
| 264 |
-
blur_px
|
| 265 |
lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
|
| 266 |
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
|
| 267 |
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
|
|
@@ -528,6 +531,13 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
|
| 528 |
try:
|
| 529 |
from matanyone.inference.inference_core import InferenceCore # type: ignore
|
| 530 |
meta["matany_import_ok"] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
device = _pick_device("MATANY_DEVICE")
|
| 532 |
repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
|
| 533 |
meta["matany_repo_id"] = repo_id
|
|
|
|
| 8 |
- MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
|
| 9 |
|
| 10 |
Changes (2025-09-16):
|
| 11 |
+
- Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
|
| 12 |
- Updated load_matany to apply T=1 squeeze patch before InferenceCore import
|
| 13 |
+
- Added patch status logging and MatAnyone version
|
| 14 |
+
- Added InferenceCore attributes logging for debugging
|
| 15 |
- Fixed InferenceCore import path to matanyone.inference.inference_core
|
| 16 |
"""
|
| 17 |
|
|
|
|
| 23 |
import subprocess
|
| 24 |
import inspect
|
| 25 |
import logging
|
| 26 |
+
import importlib.metadata
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Optional, Tuple, Dict, Any, Union, Callable
|
| 29 |
|
|
|
|
| 264 |
) -> np.ndarray:
|
| 265 |
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
|
| 266 |
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
|
| 267 |
+
blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
|
| 268 |
lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
|
| 269 |
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
|
| 270 |
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
|
|
|
|
| 531 |
try:
|
| 532 |
from matanyone.inference.inference_core import InferenceCore # type: ignore
|
| 533 |
meta["matany_import_ok"] = True
|
| 534 |
+
# Log MatAnyone version and InferenceCore attributes
|
| 535 |
+
try:
|
| 536 |
+
version = importlib.metadata.version("matanyone")
|
| 537 |
+
logger.info(f"[MATANY] MatAnyone version: {version}")
|
| 538 |
+
except Exception:
|
| 539 |
+
logger.info("[MATANY] MatAnyone version unknown")
|
| 540 |
+
logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")
|
| 541 |
device = _pick_device("MATANY_DEVICE")
|
| 542 |
repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
|
| 543 |
meta["matany_repo_id"] = repo_id
|
models/__pycache__/__init__.cpython-313.pyc
CHANGED
|
Binary files a/models/__pycache__/__init__.cpython-313.pyc and b/models/__pycache__/__init__.cpython-313.pyc differ
|
|
|
models/matany_compat_patch.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# MatAnyone HF-compat patch: squeeze time dim T=1 before first Conv2d
|
| 3 |
# Changes (2025-09-16):
|
| 4 |
-
# -
|
|
|
|
| 5 |
# - Log dir(MatAnyone) and module version for debugging
|
| 6 |
# - Added isinstance(img, torch.Tensor) for non-tensor safety
|
| 7 |
-
# -
|
| 8 |
-
# - Kept monkey-patch for HF Spaces compatibility
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import torch
|
|
@@ -15,9 +15,9 @@
|
|
| 15 |
|
| 16 |
def apply_matany_t1_squeeze_guard() -> bool:
|
| 17 |
"""
|
| 18 |
-
Monkey-patch MatAnyone.
|
| 19 |
Safe for multi-frame (T>1) as it only squeezes when T==1.
|
| 20 |
-
Returns True if
|
| 21 |
"""
|
| 22 |
try:
|
| 23 |
import matanyone.model.matanyone as M
|
|
@@ -29,7 +29,7 @@ def apply_matany_t1_squeeze_guard() -> bool:
|
|
| 29 |
return False
|
| 30 |
MatAnyone = M.MatAnyone
|
| 31 |
|
| 32 |
-
# Log MatAnyone version and attributes
|
| 33 |
try:
|
| 34 |
version = importlib.metadata.version("matanyone")
|
| 35 |
log.info(f"[MatAnyCompat] MatAnyone version: {version}")
|
|
@@ -37,33 +37,33 @@ def apply_matany_t1_squeeze_guard() -> bool:
|
|
| 37 |
log.info("[MatAnyCompat] MatAnyone version unknown")
|
| 38 |
log.debug(f"[MatAnyCompat] MatAnyone attributes: {dir(MatAnyone)}")
|
| 39 |
|
| 40 |
-
# Try
|
| 41 |
-
|
| 42 |
-
for
|
| 43 |
-
if hasattr(MatAnyone,
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
return True
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {method_name}")
|
| 61 |
-
img = img.squeeze(1) # [B,1,C,H,W] → [B,C,H,W]
|
| 62 |
-
except Exception as e:
|
| 63 |
-
log.warning(f"[MatAnyCompat] Failed to process input shape in {method_name}: %s", e)
|
| 64 |
-
return orig_method(self, img, *args, **kwargs)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
return True
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# MatAnyone HF-compat patch: squeeze time dim T=1 before first Conv2d
|
| 3 |
# Changes (2025-09-16):
|
| 4 |
+
# - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
|
| 5 |
+
# - Patch forward, encode, encode_img to cover all code paths
|
| 6 |
# - Log dir(MatAnyone) and module version for debugging
|
| 7 |
# - Added isinstance(img, torch.Tensor) for non-tensor safety
|
| 8 |
+
# - Log input/output shapes for verification
|
|
|
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import torch
|
|
|
|
| 15 |
|
| 16 |
def apply_matany_t1_squeeze_guard() -> bool:
|
| 17 |
"""
|
| 18 |
+
Monkey-patch MatAnyone.forward/encode/encode_img to squeeze [B,1,C,H,W] → [B,C,H,W].
|
| 19 |
Safe for multi-frame (T>1) as it only squeezes when T==1.
|
| 20 |
+
Returns True if at least one method patched, False otherwise.
|
| 21 |
"""
|
| 22 |
try:
|
| 23 |
import matanyone.model.matanyone as M
|
|
|
|
| 29 |
return False
|
| 30 |
MatAnyone = M.MatAnyone
|
| 31 |
|
| 32 |
+
# Log MatAnyone version and attributes
|
| 33 |
try:
|
| 34 |
version = importlib.metadata.version("matanyone")
|
| 35 |
log.info(f"[MatAnyCompat] MatAnyone version: {version}")
|
|
|
|
| 37 |
log.info("[MatAnyCompat] MatAnyone version unknown")
|
| 38 |
log.debug(f"[MatAnyCompat] MatAnyone attributes: {dir(MatAnyone)}")
|
| 39 |
|
| 40 |
+
# Try patching forward, encode, encode_img
|
| 41 |
+
patched = False
|
| 42 |
+
for method_name in ["forward", "encode", "encode_img"]:
|
| 43 |
+
if not hasattr(MatAnyone, method_name):
|
| 44 |
+
continue
|
| 45 |
+
if getattr(MatAnyone, f"_{method_name}_patched", False):
|
| 46 |
+
log.info(f"[MatAnyCompat] {method_name} already patched")
|
| 47 |
+
continue
|
| 48 |
+
|
| 49 |
+
# Store original method
|
| 50 |
+
orig_method = getattr(MatAnyone, method_name)
|
|
|
|
| 51 |
|
| 52 |
+
def method_compat(self, img, *args, **kwargs):
|
| 53 |
+
try:
|
| 54 |
+
if isinstance(img, torch.Tensor) and img.dim() == 5 and img.shape[1] == 1:
|
| 55 |
+
log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {method_name}")
|
| 56 |
+
img = img.squeeze(1) # [B,1,C,H,W] → [B,C,H,W]
|
| 57 |
+
except Exception as e:
|
| 58 |
+
log.warning(f"[MatAnyCompat] Failed to process input shape in {method_name}: %s", e)
|
| 59 |
+
return orig_method(self, img, *args, **kwargs)
|
| 60 |
|
| 61 |
+
setattr(MatAnyone, method_name, method_compat)
|
| 62 |
+
setattr(MatAny, f"_{method_name}_patched", True)
|
| 63 |
+
log.info(f"[MatAnyCompat] Applied T=1 squeeze guard in MatAnyone.{method_name}")
|
| 64 |
+
patched = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
if not patched:
|
| 67 |
+
log.warning("[MatAnyCompat] No patchable methods (forward, encode, encode_img) found on MatAnyone")
|
| 68 |
+
return False
|
| 69 |
return True
|
models/matanyone_loader.py
CHANGED
|
@@ -5,16 +5,16 @@
|
|
| 5 |
|
| 6 |
- SAM2 defines the subject (seed mask) on frame 0.
|
| 7 |
- MatAnyone does frame-by-frame alpha matting.
|
| 8 |
-
-
|
| 9 |
-
- Falls back to process_frame([H,W,3]) if
|
| 10 |
|
| 11 |
Changes (2025-09-16):
|
| 12 |
-
-
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
-
|
| 17 |
-
-
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
|
@@ -24,6 +24,7 @@
|
|
| 24 |
import logging
|
| 25 |
import numpy as np
|
| 26 |
import torch
|
|
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Optional, Callable, Tuple
|
| 29 |
|
|
@@ -71,39 +72,22 @@ def _cuda_snapshot(device: Optional[torch.device]) -> str:
|
|
| 71 |
idx = device.index
|
| 72 |
name = torch.cuda.get_device_name(idx)
|
| 73 |
alloc = torch.cuda.memory_allocated(idx) / (1024**3)
|
| 74 |
-
resv
|
| 75 |
return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
|
| 76 |
except Exception as e:
|
| 77 |
return f"CUDA snapshot error: {e!r}"
|
| 78 |
|
| 79 |
def _safe_empty_cache():
|
| 80 |
-
|
|
|
|
| 81 |
try:
|
| 82 |
-
|
| 83 |
-
return
|
| 84 |
-
|
| 85 |
-
# Log memory stats before cleanup
|
| 86 |
-
if _env_flag("MATANY_LOG_VRAM"):
|
| 87 |
-
log.info("[MATANY] VRAM before cleanup:")
|
| 88 |
-
log.info(f" Allocated: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
|
| 89 |
-
log.info(f" Reserved: {torch.cuda.memory_reserved()/1024**2:.1f} MB")
|
| 90 |
-
|
| 91 |
-
# Clear cache and sync
|
| 92 |
torch.cuda.empty_cache()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
log.info(f" Allocated: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
|
| 99 |
-
log.info(f" Reserved: {torch.cuda.memory_reserved()/1024**2:.1f} MB")
|
| 100 |
-
|
| 101 |
-
except Exception as e:
|
| 102 |
-
log.warning(f"[MATANY] Error in cache cleanup: {e}", exc_info=True)
|
| 103 |
-
try:
|
| 104 |
-
torch.cuda.empty_cache()
|
| 105 |
-
except Exception as e2:
|
| 106 |
-
log.warning(f"[MATANY] Secondary cache cleanup failed: {e2}")
|
| 107 |
|
| 108 |
# ---------- SAM2 → seed mask prep ----------
|
| 109 |
def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
|
|
@@ -149,7 +133,7 @@ class MatAnyoneSession:
|
|
| 149 |
"""
|
| 150 |
Streaming wrapper that seeds MatAnyone on frame 0.
|
| 151 |
Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
|
| 152 |
-
Falls back to process_frame([H,W,3]) if supported
|
| 153 |
"""
|
| 154 |
def __init__(self, device: Optional[str] = None, precision: str = "auto"):
|
| 155 |
from .matany_compat_patch import apply_matany_t1_squeeze_guard
|
|
@@ -157,18 +141,25 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
|
|
| 157 |
self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
| 158 |
self.precision = precision.lower()
|
| 159 |
|
| 160 |
-
# Apply T=1 squeeze patch
|
| 161 |
if apply_matany_t1_squeeze_guard():
|
| 162 |
-
log.info("[MATANY] T=1 squeeze patch applied for MatAnyone
|
| 163 |
else:
|
| 164 |
log.warning("[MATANY] T=1 squeeze patch failed; conv2d errors may occur")
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
# API/format overrides for debugging
|
| 167 |
api_force = os.getenv("MATANY_FORCE_API", "").strip().lower() # "process" or "step"
|
| 168 |
fmt_force = os.getenv("MATANY_FORCE_FORMAT", "4d").strip().lower() # "4d" or "5d"
|
| 169 |
self._force_api_process = (api_force == "process")
|
| 170 |
self._force_api_step = (api_force == "step")
|
| 171 |
-
self._force_4d = (fmt_force == "4d") or not fmt_force # Default to 4D
|
| 172 |
self._force_5d = (fmt_force == "5d")
|
| 173 |
|
| 174 |
try:
|
|
@@ -255,6 +246,7 @@ def _call_step(self, rgb_hwc: np.ndarray, seed_mask_hw: Optional[np.ndarray], is
|
|
| 255 |
def run(use_5d: bool):
|
| 256 |
img = img_5d if use_5d else img_4d
|
| 257 |
msk = mask_5d if use_5d else mask_4d
|
|
|
|
| 258 |
if is_first and msk is not None:
|
| 259 |
try:
|
| 260 |
return self.core.step(img, msk, is_first=True)
|
|
@@ -351,10 +343,10 @@ def process_stream(
|
|
| 351 |
cap_probe = cv2.VideoCapture(str(video_path))
|
| 352 |
if not cap_probe.isOpened():
|
| 353 |
raise MatAnyError(f"Failed to open video: {video_path}")
|
| 354 |
-
N
|
| 355 |
fps = cap_probe.get(cv2.CAP_PROP_FPS)
|
| 356 |
-
W
|
| 357 |
-
H
|
| 358 |
cap_probe.release()
|
| 359 |
if not fps or fps <= 0 or np.isnan(fps):
|
| 360 |
fps = 25.0
|
|
@@ -364,10 +356,10 @@ def process_stream(
|
|
| 364 |
_emit_progress(progress_cb, 0.08, "Using per-frame processing")
|
| 365 |
|
| 366 |
alpha_path = out_dir / "alpha.mp4"
|
| 367 |
-
fg_path
|
| 368 |
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 369 |
alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True)
|
| 370 |
-
fg_writer
|
| 371 |
if not alpha_writer.isOpened() or not fg_writer.isOpened():
|
| 372 |
raise MatAnyError("Failed to initialize VideoWriter(s)")
|
| 373 |
|
|
@@ -396,9 +388,9 @@ def process_stream(
|
|
| 396 |
is_first = (idx == 0)
|
| 397 |
alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first)
|
| 398 |
|
| 399 |
-
alpha_u8
|
| 400 |
alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
|
| 401 |
-
fg_bgr
|
| 402 |
|
| 403 |
alpha_writer.write(alpha_bgr)
|
| 404 |
fg_writer.write(fg_bgr)
|
|
|
|
| 5 |
|
| 6 |
- SAM2 defines the subject (seed mask) on frame 0.
|
| 7 |
- MatAnyone does frame-by-frame alpha matting.
|
| 8 |
+
- Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
|
| 9 |
+
- Falls back to process_frame([H,W,3]) if supported.
|
| 10 |
|
| 11 |
Changes (2025-09-16):
|
| 12 |
+
- Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
|
| 13 |
+
- Added shape logging in _call_step to verify 5D-to-4D squeeze
|
| 14 |
+
- Set MATANY_FORCE_FORMAT=4d as default
|
| 15 |
+
- Added VRAM logging in process_stream (MATANY_LOG_VRAM=1)
|
| 16 |
+
- Enhanced _safe_empty_cache with memory_summary
|
| 17 |
+
- Added MatAnyone version logging
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
|
|
|
| 24 |
import logging
|
| 25 |
import numpy as np
|
| 26 |
import torch
|
| 27 |
+
import importlib.metadata
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Optional, Callable, Tuple
|
| 30 |
|
|
|
|
| 72 |
idx = device.index
|
| 73 |
name = torch.cuda.get_device_name(idx)
|
| 74 |
alloc = torch.cuda.memory_allocated(idx) / (1024**3)
|
| 75 |
+
resv = torch.cuda.memory_reserved(idx) / (1024**3)
|
| 76 |
return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
|
| 77 |
except Exception as e:
|
| 78 |
return f"CUDA snapshot error: {e!r}"
|
| 79 |
|
| 80 |
def _safe_empty_cache():
|
| 81 |
+
if not torch.cuda.is_available():
|
| 82 |
+
return
|
| 83 |
try:
|
| 84 |
+
log.info(f"[MATANY] CUDA memory before empty_cache: {_cuda_snapshot(None)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
torch.cuda.empty_cache()
|
| 86 |
+
log.info(f"[MATANY] CUDA memory after empty_cache: {_cuda_snapshot(None)}")
|
| 87 |
+
if os.getenv("MATANY_LOG_VRAM", "0") == "1":
|
| 88 |
+
log.debug(f"[MATANY] VRAM summary:\n{torch.cuda.memory_summary()}")
|
| 89 |
+
except Exception:
|
| 90 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# ---------- SAM2 → seed mask prep ----------
|
| 93 |
def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
|
|
|
|
| 133 |
"""
|
| 134 |
Streaming wrapper that seeds MatAnyone on frame 0.
|
| 135 |
Prefers step([B,C,H,W]) with T=1 squeeze patch for conv2d compatibility.
|
| 136 |
+
Falls back to process_frame([H,W,3]) if supported.
|
| 137 |
"""
|
| 138 |
def __init__(self, device: Optional[str] = None, precision: str = "auto"):
|
| 139 |
from .matany_compat_patch import apply_matany_t1_squeeze_guard
|
|
|
|
| 141 |
self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
| 142 |
self.precision = precision.lower()
|
| 143 |
|
| 144 |
+
# Apply T=1 squeeze patch
|
| 145 |
if apply_matany_t1_squeeze_guard():
|
| 146 |
+
log.info("[MATANY] T=1 squeeze patch applied for MatAnyone")
|
| 147 |
else:
|
| 148 |
log.warning("[MATANY] T=1 squeeze patch failed; conv2d errors may occur")
|
| 149 |
|
| 150 |
+
# Log MatAnyone version
|
| 151 |
+
try:
|
| 152 |
+
version = importlib.metadata.version("matanyone")
|
| 153 |
+
log.info(f"[MATANY] MatAnyone version: {version}")
|
| 154 |
+
except Exception:
|
| 155 |
+
log.info("[MATANY] MatAnyone version unknown")
|
| 156 |
+
|
| 157 |
# API/format overrides for debugging
|
| 158 |
api_force = os.getenv("MATANY_FORCE_API", "").strip().lower() # "process" or "step"
|
| 159 |
fmt_force = os.getenv("MATANY_FORCE_FORMAT", "4d").strip().lower() # "4d" or "5d"
|
| 160 |
self._force_api_process = (api_force == "process")
|
| 161 |
self._force_api_step = (api_force == "step")
|
| 162 |
+
self._force_4d = (fmt_force == "4d") or not fmt_force # Default to 4D
|
| 163 |
self._force_5d = (fmt_force == "5d")
|
| 164 |
|
| 165 |
try:
|
|
|
|
| 246 |
def run(use_5d: bool):
|
| 247 |
img = img_5d if use_5d else img_4d
|
| 248 |
msk = mask_5d if use_5d else mask_4d
|
| 249 |
+
log.debug(f"[MATANY] Step input: img={img.shape}, mask={msk.shape if msk is not None else None}, is_first={is_first}")
|
| 250 |
if is_first and msk is not None:
|
| 251 |
try:
|
| 252 |
return self.core.step(img, msk, is_first=True)
|
|
|
|
| 343 |
cap_probe = cv2.VideoCapture(str(video_path))
|
| 344 |
if not cap_probe.isOpened():
|
| 345 |
raise MatAnyError(f"Failed to open video: {video_path}")
|
| 346 |
+
N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 347 |
fps = cap_probe.get(cv2.CAP_PROP_FPS)
|
| 348 |
+
W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 349 |
+
H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 350 |
cap_probe.release()
|
| 351 |
if not fps or fps <= 0 or np.isnan(fps):
|
| 352 |
fps = 25.0
|
|
|
|
| 356 |
_emit_progress(progress_cb, 0.08, "Using per-frame processing")
|
| 357 |
|
| 358 |
alpha_path = out_dir / "alpha.mp4"
|
| 359 |
+
fg_path = out_dir / "fg.mp4"
|
| 360 |
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 361 |
alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True)
|
| 362 |
+
fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
|
| 363 |
if not alpha_writer.isOpened() or not fg_writer.isOpened():
|
| 364 |
raise MatAnyError("Failed to initialize VideoWriter(s)")
|
| 365 |
|
|
|
|
| 388 |
is_first = (idx == 0)
|
| 389 |
alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first)
|
| 390 |
|
| 391 |
+
alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
|
| 392 |
alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
|
| 393 |
+
fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
|
| 394 |
|
| 395 |
alpha_writer.write(alpha_bgr)
|
| 396 |
fg_writer.write(fg_bgr)
|
models/sam2_loader.py
CHANGED
|
@@ -1,195 +1,279 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from pathlib import Path
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
return hf_hub_download(repo_id=model_id, filename=ckpt_name, local_dir=os.environ.get("HF_HOME"))
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
try:
|
| 24 |
-
|
| 25 |
-
return
|
| 26 |
except Exception as e:
|
| 27 |
-
|
| 28 |
return None
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
try:
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
margin_w, # x1
|
| 99 |
-
margin_h, # y1
|
| 100 |
-
w - margin_w, # x2
|
| 101 |
-
h - margin_h # y2
|
| 102 |
-
], dtype=np.float32)
|
| 103 |
-
|
| 104 |
-
log.info(f"Using simple center box: {box}")
|
| 105 |
-
return box
|
| 106 |
-
|
| 107 |
-
@torch.inference_mode()
|
| 108 |
-
def first_frame_mask(self, image_rgb01):
|
| 109 |
-
"""
|
| 110 |
-
Returns an initial binary mask for the foreground person from first frame.
|
| 111 |
-
Uses a robust approach with fallback strategies for better reliability.
|
| 112 |
-
"""
|
| 113 |
-
log.info("🔍 SAM2 first_frame_mask() called - starting segmentation")
|
| 114 |
-
|
| 115 |
try:
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
raise RuntimeError("SAM2 predictor doesn't support set_image")
|
| 122 |
-
|
| 123 |
-
# Convert to numpy for predictor if needed
|
| 124 |
-
if isinstance(image_rgb01, torch.Tensor):
|
| 125 |
-
image_np = (image_rgb01.cpu().numpy() * 255).astype("uint8")
|
| 126 |
-
else:
|
| 127 |
-
image_np = (image_rgb01 * 255).astype("uint8")
|
| 128 |
-
|
| 129 |
-
# Set the image for prediction
|
| 130 |
-
self.predictor.set_image(image_np)
|
| 131 |
-
|
| 132 |
-
# Strategy 1: Try with person-focused bounding box first
|
| 133 |
-
box = self._detect_person_region(image_np)
|
| 134 |
-
log.info(f"Trying person-focused box: {box}")
|
| 135 |
-
|
| 136 |
-
masks, scores, _ = self.predictor.predict(
|
| 137 |
-
box=box,
|
| 138 |
-
multimask_output=True,
|
| 139 |
-
mask_input=None,
|
| 140 |
-
return_logits=False
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
# Strategy 2: If no good masks found, try with a point in the center
|
| 144 |
-
if len(masks) == 0 or np.max(scores) < 0.5:
|
| 145 |
-
log.info("No good masks found with box, trying center point")
|
| 146 |
-
h, w = image_np.shape[:2]
|
| 147 |
-
point = np.array([[w//2, h//2]])
|
| 148 |
-
labels = np.array([1]) # 1=foreground point
|
| 149 |
-
masks, scores, _ = self.predictor.predict(
|
| 150 |
-
point_coords=point,
|
| 151 |
-
point_labels=labels,
|
| 152 |
-
multimask_output=True
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
# Choose the best mask (highest score)
|
| 156 |
-
if len(masks) > 0 and len(scores) > 0:
|
| 157 |
-
best_idx = np.argmax(scores)
|
| 158 |
-
mask = masks[best_idx]
|
| 159 |
-
score = float(scores[best_idx])
|
| 160 |
-
log.info(f"Selected mask {best_idx+1}/{len(masks)} with score {score:.3f}")
|
| 161 |
-
|
| 162 |
-
# Verify mask quality
|
| 163 |
-
mask_coverage = (np.sum(mask > 0) / mask.size) * 100
|
| 164 |
-
log.info(f"Mask coverage: {mask_coverage:.1f}% (target: 15-35%)")
|
| 165 |
-
|
| 166 |
-
# If mask is too small or too large, try to refine it
|
| 167 |
-
if mask_coverage < 5 or mask_coverage > 50:
|
| 168 |
-
log.warning(f"Suspicious mask coverage {mask_coverage:.1f}%, applying post-processing")
|
| 169 |
-
# Apply morphological operations to clean up the mask
|
| 170 |
-
kernel = np.ones((5,5), np.uint8)
|
| 171 |
-
if mask_coverage < 5: # Too small - dilate
|
| 172 |
-
mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
|
| 173 |
-
else: # Too large - erode
|
| 174 |
-
mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
|
| 175 |
-
|
| 176 |
-
# Ensure we still have a valid mask
|
| 177 |
-
if np.sum(mask) == 0:
|
| 178 |
-
log.warning("Post-processing removed all mask pixels, using original")
|
| 179 |
-
mask = masks[best_idx]
|
| 180 |
else:
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
mask = np.zeros((h, w), dtype=np.float32)
|
| 191 |
-
margin_h, margin_w = h//4, w//4
|
| 192 |
-
mask[margin_h:h-margin_h, margin_w:w-margin_w] = 1.0
|
| 193 |
-
return mask
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SAM2 Loader — Robust loading and mask generation for SAM2
|
| 4 |
+
========================================================
|
| 5 |
+
- Loads SAM2 model with Hydra config resolution
|
| 6 |
+
- Generates seed masks for MatAnyone
|
| 7 |
+
- Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
|
| 8 |
+
|
| 9 |
+
Changes (2025-09-16):
|
| 10 |
+
- Aligned with torch==2.3.1+cu121 and SAM2 commit
|
| 11 |
+
- Added GPU memory logging for Tesla T4
|
| 12 |
+
- Added SAM2 version logging via importlib.metadata
|
| 13 |
+
- Simplified config resolution to match __init__.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import logging
|
| 20 |
+
import importlib.metadata
|
| 21 |
from pathlib import Path
|
| 22 |
+
from typing import Optional, Tuple, Dict, Any
|
| 23 |
+
|
| 24 |
import numpy as np
|
| 25 |
+
import yaml
|
| 26 |
+
import torch
|
| 27 |
|
| 28 |
+
# --------------------------------------------------------------------------------------
|
| 29 |
+
# Logging
|
| 30 |
+
# --------------------------------------------------------------------------------------
|
| 31 |
+
logger = logging.getLogger("backgroundfx_pro")
|
| 32 |
+
if not logger.handlers:
|
| 33 |
+
_h = logging.StreamHandler()
|
| 34 |
+
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 35 |
+
logger.addHandler(_h)
|
| 36 |
+
logger.setLevel(logging.INFO)
|
| 37 |
|
| 38 |
+
# --------------------------------------------------------------------------------------
|
| 39 |
+
# Path setup for third_party repos
|
| 40 |
+
# --------------------------------------------------------------------------------------
|
| 41 |
+
ROOT = Path(__file__).resolve().parent.parent # project root
|
| 42 |
+
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
|
| 43 |
|
| 44 |
+
def _add_sys_path(p: Path) -> None:
|
| 45 |
+
if p.exists():
|
| 46 |
+
p_str = str(p)
|
| 47 |
+
if p_str not in sys.path:
|
| 48 |
+
sys.path.insert(0, p_str)
|
| 49 |
+
else:
|
| 50 |
+
logger.warning(f"third_party path not found: {p}")
|
| 51 |
|
| 52 |
+
_add_sys_path(TP_SAM2)
|
|
|
|
| 53 |
|
| 54 |
+
# --------------------------------------------------------------------------------------
|
| 55 |
+
# Safe Torch accessors
|
| 56 |
+
# --------------------------------------------------------------------------------------
|
| 57 |
+
def _torch():
|
| 58 |
try:
|
| 59 |
+
import torch
|
| 60 |
+
return torch
|
| 61 |
except Exception as e:
|
| 62 |
+
logger.warning(f"[sam2_loader.safe-torch] import failed: {e}")
|
| 63 |
return None
|
| 64 |
|
| 65 |
+
def _has_cuda() -> bool:
|
| 66 |
+
t = _torch()
|
| 67 |
+
if t is None:
|
| 68 |
+
return False
|
| 69 |
+
try:
|
| 70 |
+
return bool(t.cuda.is_available())
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.warning(f"[sam2_loader.safe-torch] cuda.is_available() failed: {e}")
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
def _pick_device(env_key: str) -> str:
|
| 76 |
+
requested = os.environ.get(env_key, "").strip().lower()
|
| 77 |
+
has_cuda = _has_cuda()
|
| 78 |
+
|
| 79 |
+
logger.info(f"CUDA environment variables: {{'SAM2_DEVICE': '{os.environ.get('SAM2_DEVICE', '')}'}}")
|
| 80 |
+
logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
|
| 81 |
+
|
| 82 |
+
if has_cuda and requested not in {"cpu"}:
|
| 83 |
+
logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
|
| 84 |
+
return "cuda"
|
| 85 |
+
elif requested in {"cuda", "cpu"}:
|
| 86 |
+
logger.info(f"Using explicitly requested device: {requested}")
|
| 87 |
+
return requested
|
| 88 |
+
|
| 89 |
+
result = "cuda" if has_cuda else "cpu"
|
| 90 |
+
logger.info(f"Auto-selected device: {result}")
|
| 91 |
+
return result
|
| 92 |
+
|
| 93 |
+
# --------------------------------------------------------------------------------------
|
| 94 |
+
# SAM2 Loading and Mask Generation
|
| 95 |
+
# --------------------------------------------------------------------------------------
|
| 96 |
+
def _resolve_sam2_cfg(cfg_str: str) -> str:
|
| 97 |
+
"""Resolve SAM2 config path - return relative path for Hydra compatibility."""
|
| 98 |
+
logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
|
| 99 |
+
|
| 100 |
+
candidate = os.path.join(TP_SAM2, cfg_str)
|
| 101 |
+
logger.info(f"Candidate path: {candidate}")
|
| 102 |
+
logger.info(f"Candidate exists: {os.path.exists(candidate)}")
|
| 103 |
+
|
| 104 |
+
if os.path.exists(candidate):
|
| 105 |
+
if cfg_str.startswith("sam2/configs/"):
|
| 106 |
+
relative_path = cfg_str.replace("sam2/configs/", "configs/")
|
| 107 |
+
else:
|
| 108 |
+
relative_path = cfg_str
|
| 109 |
+
logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
|
| 110 |
+
return relative_path
|
| 111 |
+
|
| 112 |
+
fallbacks = [
|
| 113 |
+
os.path.join(TP_SAM2, "sam2", cfg_str),
|
| 114 |
+
os.path.join(TP_SAM2, "configs", cfg_str),
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
for fallback in fallbacks:
|
| 118 |
+
logger.info(f"Trying fallback: {fallback}")
|
| 119 |
+
if os.path.exists(fallback):
|
| 120 |
+
if "configs/" in fallback:
|
| 121 |
+
relative_path = "configs/" + fallback.split("configs/")[-1]
|
| 122 |
+
logger.info(f"Returning fallback relative path: {relative_path}")
|
| 123 |
+
return relative_path
|
| 124 |
+
|
| 125 |
+
logger.warning(f"Config not found, returning original: {cfg_str}")
|
| 126 |
+
return cfg_str
|
| 127 |
|
| 128 |
+
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
|
| 129 |
+
"""If config references 'hieradet', try to find a 'hiera' config."""
|
| 130 |
+
try:
|
| 131 |
+
with open(cfg_path, "r") as f:
|
| 132 |
+
data = yaml.safe_load(f)
|
| 133 |
+
model = data.get("model", {}) or {}
|
| 134 |
+
enc = model.get("image_encoder") or {}
|
| 135 |
+
trunk = enc.get("trunk") or {}
|
| 136 |
+
target = trunk.get("_target_") or trunk.get("target")
|
| 137 |
+
if isinstance(target, str) and "hieradet" in target:
|
| 138 |
+
for y in TP_SAM2.rglob("*.yaml"):
|
| 139 |
+
try:
|
| 140 |
+
with open(y, "r") as f2:
|
| 141 |
+
d2 = yaml.safe_load(f2) or {}
|
| 142 |
+
e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
|
| 143 |
+
t2 = (e2.get("trunk") or {})
|
| 144 |
+
tgt2 = t2.get("_target_") or t2.get("target")
|
| 145 |
+
if isinstance(tgt2, str) and ".hiera." in tgt2:
|
| 146 |
+
logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
|
| 147 |
+
return str(y)
|
| 148 |
+
except Exception:
|
| 149 |
+
continue
|
| 150 |
+
except Exception:
|
| 151 |
+
pass
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
| 155 |
+
"""Robust SAM2 loader with config resolution and error handling."""
|
| 156 |
+
meta = {"sam2_import_ok": False, "sam2_init_ok": False}
|
| 157 |
+
try:
|
| 158 |
+
from sam2.build_sam import build_sam2
|
| 159 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 160 |
+
meta["sam2_import_ok"] = True
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.warning(f"SAM2 import failed: {e}")
|
| 163 |
+
return None, False, meta
|
| 164 |
+
|
| 165 |
+
# Log SAM2 version
|
| 166 |
+
try:
|
| 167 |
+
version = importlib.metadata.version("segment-anything-2")
|
| 168 |
+
logger.info(f"[SAM2] SAM2 version: {version}")
|
| 169 |
+
except Exception:
|
| 170 |
+
logger.info("[SAM2] SAM2 version unknown")
|
| 171 |
+
|
| 172 |
+
# Check GPU memory before loading
|
| 173 |
+
if torch and torch.cuda.is_available():
|
| 174 |
+
mem_before = torch.cuda.memory_allocated() / 1024**3
|
| 175 |
+
logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")
|
| 176 |
+
|
| 177 |
+
device = _pick_device("SAM2_DEVICE")
|
| 178 |
+
cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
|
| 179 |
+
cfg = _resolve_sam2_cfg(cfg_env)
|
| 180 |
+
ckpt = os.environ.get("SAM2_CHECKPOINT", "")
|
| 181 |
+
|
| 182 |
+
def _try_build(cfg_path: str):
|
| 183 |
+
logger.info(f"_try_build called with cfg_path: {cfg_path}")
|
| 184 |
+
params = set(inspect.signature(build_sam2).parameters.keys())
|
| 185 |
+
logger.info(f"build_sam2 parameters: {list(params)}")
|
| 186 |
+
kwargs = {}
|
| 187 |
+
if "config_file" in params:
|
| 188 |
+
kwargs["config_file"] = cfg_path
|
| 189 |
+
logger.info(f"Using config_file parameter: {cfg_path}")
|
| 190 |
+
elif "model_cfg" in params:
|
| 191 |
+
kwargs["model_cfg"] = cfg_path
|
| 192 |
+
logger.info(f"Using model_cfg parameter: {cfg_path}")
|
| 193 |
+
if ckpt:
|
| 194 |
+
if "checkpoint" in params:
|
| 195 |
+
kwargs["checkpoint"] = ckpt
|
| 196 |
+
elif "ckpt_path" in params:
|
| 197 |
+
kwargs["ckpt_path"] = ckpt
|
| 198 |
+
elif "weights" in params:
|
| 199 |
+
kwargs["weights"] = ckpt
|
| 200 |
+
if "device" in params:
|
| 201 |
+
kwargs["device"] = device
|
| 202 |
try:
|
| 203 |
+
logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
|
| 204 |
+
result = build_sam2(**kwargs)
|
| 205 |
+
logger.info(f"build_sam2 succeeded with kwargs")
|
| 206 |
+
if hasattr(result, 'device'):
|
| 207 |
+
logger.info(f"SAM2 model device: {result.device}")
|
| 208 |
+
elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
|
| 209 |
+
logger.info(f"SAM2 model device: {result.image_encoder.device}")
|
| 210 |
+
return result
|
| 211 |
+
except TypeError as e:
|
| 212 |
+
logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
|
| 213 |
+
pos = [cfg_path]
|
| 214 |
+
if ckpt:
|
| 215 |
+
pos.append(ckpt)
|
| 216 |
+
if "device" not in kwargs:
|
| 217 |
+
pos.append(device)
|
| 218 |
+
logger.info(f"Calling build_sam2 with positional args: {pos}")
|
| 219 |
+
result = build_sam2(*pos)
|
| 220 |
+
logger.info(f"build_sam2 succeeded with positional args")
|
| 221 |
+
return result
|
| 222 |
+
|
| 223 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
try:
|
| 225 |
+
sam = _try_build(cfg)
|
| 226 |
+
except Exception:
|
| 227 |
+
alt_cfg = _find_hiera_config_if_hieradet(cfg)
|
| 228 |
+
if alt_cfg:
|
| 229 |
+
sam = _try_build(alt_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
else:
|
| 231 |
+
raise
|
| 232 |
+
|
| 233 |
+
if sam is not None:
|
| 234 |
+
predictor = SAM2ImagePredictor(sam)
|
| 235 |
+
meta["sam2_init_ok"] = True
|
| 236 |
+
meta["sam2_device"] = device
|
| 237 |
+
return predictor, True, meta
|
| 238 |
+
else:
|
| 239 |
+
return None, False, meta
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
except Exception as e:
|
| 242 |
+
logger.error(f"SAM2 loading failed: {e}")
|
| 243 |
+
return None, False, meta
|
| 244 |
+
|
| 245 |
+
def run_sam2_mask(predictor: object,
|
| 246 |
+
first_frame_bgr: np.ndarray,
|
| 247 |
+
point: Optional[Tuple[int, int]] = None,
|
| 248 |
+
auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
|
| 249 |
+
"""Generate a seed mask for MatAnyone. Returns (mask_uint8_0_255, ok)."""
|
| 250 |
+
if predictor is None:
|
| 251 |
+
return None, False
|
| 252 |
+
try:
|
| 253 |
+
import cv2
|
| 254 |
+
rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
|
| 255 |
+
predictor.set_image(rgb)
|
| 256 |
+
|
| 257 |
+
if auto:
|
| 258 |
+
h, w = rgb.shape[:2]
|
| 259 |
+
box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
|
| 260 |
+
masks, _, _ = predictor.predict(box=box)
|
| 261 |
+
elif point is not None:
|
| 262 |
+
x, y = int(point[0]), int(point[1])
|
| 263 |
+
pts = np.array([[x, y]], dtype=np.int32)
|
| 264 |
+
labels = np.array([1], dtype=np.int32)
|
| 265 |
+
masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
|
| 266 |
+
else:
|
| 267 |
+
h, w = rgb.shape[:2]
|
| 268 |
+
box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
|
| 269 |
+
masks, _, _ = predictor.predict(box=box)
|
| 270 |
+
|
| 271 |
+
if masks is None or len(masks) == 0:
|
| 272 |
+
return None, False
|
| 273 |
+
|
| 274 |
+
m = masks[0].astype(np.uint8) * 255
|
| 275 |
+
logger.info(f"[SAM2] Generated mask: shape={m.shape}, dtype={m.dtype}")
|
| 276 |
+
return m, True
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.warning(f"SAM2 mask generation failed: {e}")
|
| 279 |
+
return None, False
|
pipeline.py
CHANGED
|
@@ -8,6 +8,13 @@
|
|
| 8 |
- Verbose breadcrumbs for pinpointing stalls
|
| 9 |
- Enhanced mask validation for MatAnyone compatibility (robust to API changes)
|
| 10 |
- Automatic mask inversion for high-coverage masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
|
@@ -18,6 +25,7 @@
|
|
| 18 |
import tempfile
|
| 19 |
import logging
|
| 20 |
import importlib
|
|
|
|
| 21 |
from pathlib import Path
|
| 22 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 23 |
|
|
@@ -216,17 +224,17 @@ def _progress(*args):
|
|
| 216 |
logger.info(f"Progress: {msg}")
|
| 217 |
if progress_callback:
|
| 218 |
try:
|
| 219 |
-
progress_callback(msg)
|
| 220 |
except TypeError:
|
| 221 |
-
progress_callback(
|
| 222 |
elif len(args) >= 2:
|
| 223 |
pct, msg = args[0], args[1]
|
| 224 |
logger.info(f"Progress: {msg} ({int(pct*100)}%)")
|
| 225 |
if progress_callback:
|
| 226 |
try:
|
| 227 |
-
progress_callback(pct, msg)
|
| 228 |
except TypeError:
|
| 229 |
-
progress_callback(msg)
|
| 230 |
except Exception as e:
|
| 231 |
logger.warning(f"progress callback failed: {e}")
|
| 232 |
|
|
@@ -235,7 +243,7 @@ def _progress(*args):
|
|
| 235 |
# [5] PHASE 0: Video metadata
|
| 236 |
# ----------------------------------------------------------------------------------
|
| 237 |
logger.info("[0] Reading video metadata…")
|
| 238 |
-
_progress("📹 Reading video metadata...")
|
| 239 |
first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
|
| 240 |
diagnostics["fps"] = int(fps or 25)
|
| 241 |
diagnostics["resolution"] = [int(vw), int(vh)]
|
|
@@ -249,7 +257,7 @@ def _progress(*args):
|
|
| 249 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 250 |
cap.release()
|
| 251 |
|
| 252 |
-
_progress(f"✅ Video loaded: {vw}x{vh} @ {fps}fps ({total_frames} frames)")
|
| 253 |
diagnostics["total_frames"] = total_frames
|
| 254 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 255 |
|
|
@@ -257,7 +265,7 @@ def _progress(*args):
|
|
| 257 |
# [6] PHASE 1: SAM2 → seed mask (then free)
|
| 258 |
# ----------------------------------------------------------------------------------
|
| 259 |
logger.info("[1] Loading SAM2…")
|
| 260 |
-
_progress("🤖 Loading SAM2 model...")
|
| 261 |
predictor, sam2_ok, sam_meta = load_sam2()
|
| 262 |
diagnostics["sam2_meta"] = sam_meta or {}
|
| 263 |
diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
|
|
@@ -269,7 +277,7 @@ def _progress(*args):
|
|
| 269 |
|
| 270 |
if sam2_ok and predictor is not None:
|
| 271 |
logger.info("[1] Running SAM2 segmentation…")
|
| 272 |
-
_progress("🎯 Running SAM2 segmentation...")
|
| 273 |
px = int(point_x) if point_x is not None else None
|
| 274 |
py = int(point_y) if point_y is not None else None
|
| 275 |
seed_mask, ok_mask = run_sam2_mask(
|
|
@@ -278,10 +286,10 @@ def _progress(*args):
|
|
| 278 |
auto=auto_box
|
| 279 |
)
|
| 280 |
diagnostics["sam2_ok"] = bool(ok_mask)
|
| 281 |
-
_progress("✅ SAM2 segmentation complete")
|
| 282 |
else:
|
| 283 |
logger.info("[1] SAM2 unavailable or failed to load.")
|
| 284 |
-
_progress("⚠️ SAM2 unavailable, using fallback")
|
| 285 |
|
| 286 |
# Free SAM2 ASAP
|
| 287 |
try:
|
|
@@ -290,13 +298,13 @@ def _progress(*args):
|
|
| 290 |
pass
|
| 291 |
predictor = None
|
| 292 |
_force_cleanup()
|
| 293 |
-
_progress("🧹 SAM2 memory cleared")
|
| 294 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 295 |
|
| 296 |
# Fallback mask generation if SAM2 failed
|
| 297 |
if not ok_mask or seed_mask is None:
|
| 298 |
logger.info("[1] Using fallback mask generation…")
|
| 299 |
-
_progress("🔄 Generating fallback mask...")
|
| 300 |
seed_mask = fallback_mask(first_frame)
|
| 301 |
diagnostics["fallback_used"] = "mask_generation"
|
| 302 |
_force_cleanup()
|
|
@@ -304,7 +312,7 @@ def _progress(*args):
|
|
| 304 |
# Optional GrabCut refinement
|
| 305 |
if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
|
| 306 |
logger.info("[1] Refining mask with GrabCut…")
|
| 307 |
-
_progress("✨ Refining mask with GrabCut...")
|
| 308 |
seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
|
| 309 |
_force_cleanup()
|
| 310 |
|
|
@@ -341,7 +349,7 @@ def _progress(*args):
|
|
| 341 |
logger.info(f"[1] ✅ Mask validation passed: {validation_msg}")
|
| 342 |
diagnostics["mask_validation"] = {"valid": True, "stats": mask_stats}
|
| 343 |
|
| 344 |
-
_progress("✅ Stage 1 complete - Mask generated and validated")
|
| 345 |
|
| 346 |
# Free first frame memory
|
| 347 |
try:
|
|
@@ -350,18 +358,25 @@ def _progress(*args):
|
|
| 350 |
pass
|
| 351 |
_force_cleanup()
|
| 352 |
_cleanup_temp_files(tmp_root)
|
| 353 |
-
_progress("🧹 Frame memory cleared")
|
| 354 |
|
| 355 |
# ----------------------------------------------------------------------------------
|
| 356 |
# [7] PHASE 2: MatAnyone (strict CHW/HW tensors are handled in matanyone_loader)
|
| 357 |
# ----------------------------------------------------------------------------------
|
| 358 |
logger.info("[2] Loading MatAnyone…")
|
| 359 |
-
_progress("🎬 Loading MatAnyone model...")
|
| 360 |
matany, mat_ok, mat_meta = load_matany()
|
| 361 |
diagnostics["matany_meta"] = mat_meta or {}
|
| 362 |
diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
|
| 363 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
fg_path, al_path = None, None
|
| 366 |
out_dir = tmp_root / "matany_out"
|
| 367 |
_ensure_dir(out_dir)
|
|
@@ -369,19 +384,25 @@ def _progress(*args):
|
|
| 369 |
from models.matanyone_loader import MatAnyError
|
| 370 |
|
| 371 |
try:
|
| 372 |
-
_progress("MatAnyone: starting…")
|
| 373 |
logger.info("[2] Running MatAnyone processing…")
|
| 374 |
|
| 375 |
mask_validation = diagnostics.get("mask_validation", {})
|
| 376 |
if not mask_validation.get("valid", False):
|
| 377 |
logger.warning(f"[2] Proceeding with MatAnyone despite mask validation failure: "
|
| 378 |
f"{mask_validation.get('error', 'unknown')}")
|
| 379 |
-
|
| 380 |
else:
|
| 381 |
logger.info(f"[2] Mask validation OK - coverage: "
|
| 382 |
f"{mask_validation['stats']['coverage_percent']}%")
|
| 383 |
|
| 384 |
-
_progress("🎥 Running MatAnyone video matting...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
# NOTE: The updated loader feeds CHW image + HW seed (frame 0 only) — no 5D tensors.
|
| 387 |
al_path, fg_path = run_matany(
|
|
@@ -389,14 +410,13 @@ def _progress(*args):
|
|
| 389 |
mask_path=mask_png,
|
| 390 |
out_dir=out_dir,
|
| 391 |
device="cuda" if _cuda_available() else "cpu",
|
| 392 |
-
|
| 393 |
-
progress_callback=lambda frac, msg: _progress(msg),
|
| 394 |
)
|
| 395 |
|
| 396 |
logger.info("Stage 2 success: MatAnyone produced outputs.")
|
| 397 |
diagnostics["matany_ok"] = True
|
| 398 |
mat_ok = True
|
| 399 |
-
_progress("✅ MatAnyone processing complete")
|
| 400 |
logger.info(f"[2] MatAnyone results: fg_path={fg_path}, al_path={al_path}")
|
| 401 |
|
| 402 |
except MatAnyError as e:
|
|
@@ -409,7 +429,7 @@ def _progress(*args):
|
|
| 409 |
fg_path, al_path = None, None
|
| 410 |
|
| 411 |
if not mat_ok:
|
| 412 |
-
_progress("MatAnyone failed → using fallback
|
| 413 |
logger.info("[2] MatAnyone unavailable or failed, using fallback.")
|
| 414 |
|
| 415 |
# Free MatAnyone ASAP
|
|
@@ -419,7 +439,7 @@ def _progress(*args):
|
|
| 419 |
pass
|
| 420 |
matany = None
|
| 421 |
_force_cleanup()
|
| 422 |
-
_progress("🧹 MatAnyone memory cleared")
|
| 423 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 424 |
|
| 425 |
# ----------------------------------------------------------------------------------
|
|
@@ -427,10 +447,10 @@ def _progress(*args):
|
|
| 427 |
# ----------------------------------------------------------------------------------
|
| 428 |
logger.info("[3] Building Stage-A (transparent or checkerboard)…")
|
| 429 |
if diagnostics["matany_ok"]:
|
| 430 |
-
_progress("✅ Stage 2 complete - Video matting done")
|
| 431 |
else:
|
| 432 |
-
_progress("ℹ️ Skipping MatAnyone outputs; building Stage-A from mask")
|
| 433 |
-
_progress("🎨 Building Stage-A video...")
|
| 434 |
|
| 435 |
stageA_path = None
|
| 436 |
stageA_ok = False
|
|
@@ -475,13 +495,13 @@ def _progress(*args):
|
|
| 475 |
# [9] PHASE 4: Final compositing
|
| 476 |
# ----------------------------------------------------------------------------------
|
| 477 |
logger.info("[4] Creating final composite…")
|
| 478 |
-
_progress("✅ Stage 3 complete - Stage-A built")
|
| 479 |
-
_progress("🎬 Creating final composite...")
|
| 480 |
output_path = tmp_root / "output.mp4"
|
| 481 |
|
| 482 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 483 |
logger.info(f"[4] Compositing with MatAnyone outputs: fg_path={fg_path}, al_path={al_path}")
|
| 484 |
-
_progress(f"🎬 Compositing video with MatAnyone outputs...")
|
| 485 |
|
| 486 |
fg_exists = Path(fg_path).exists() if fg_path else False
|
| 487 |
al_exists = Path(al_path).exists() if al_path else False
|
|
@@ -492,19 +512,19 @@ def _progress(*args):
|
|
| 492 |
logger.info(f"[4] Composite result: {ok_comp}")
|
| 493 |
if not ok_comp:
|
| 494 |
logger.info("[4] Composite failed; falling back to static mask composite.")
|
| 495 |
-
_progress("⚠️ MatAnyone composite failed, using fallback...")
|
| 496 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 497 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 498 |
else:
|
| 499 |
-
_progress("✅ MatAnyone composite successful!")
|
| 500 |
else:
|
| 501 |
logger.error(f"[4] MatAnyone output files missing - using fallback composite")
|
| 502 |
-
_progress("⚠️ MatAnyone files missing, using fallback...")
|
| 503 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 504 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 505 |
else:
|
| 506 |
logger.info(f"[4] Using static mask composite - matany_ok={diagnostics['matany_ok']}, fg_path={fg_path}, al_path={al_path}")
|
| 507 |
-
_progress("🎬 Using static mask composite...")
|
| 508 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 509 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
|
| 510 |
|
|
@@ -513,20 +533,20 @@ def _progress(*args):
|
|
| 513 |
|
| 514 |
if not output_path.exists():
|
| 515 |
logger.error(f"[4] Output video not created at {output_path}")
|
| 516 |
-
_progress("❌ Composite creation failed - no output file")
|
| 517 |
diagnostics["error"] = "Composite video not created"
|
| 518 |
return None, diagnostics
|
| 519 |
|
| 520 |
output_size = output_path.stat().st_size
|
| 521 |
logger.info(f"[4] Output video created: {output_path} ({output_size} bytes)")
|
| 522 |
-
_progress(f"✅ Composite created ({output_size} bytes)")
|
| 523 |
|
| 524 |
# ----------------------------------------------------------------------------------
|
| 525 |
# [10] PHASE 5: Audio mux (if FFmpeg available)
|
| 526 |
# ----------------------------------------------------------------------------------
|
| 527 |
logger.info("[5] Adding audio track…")
|
| 528 |
-
_progress("✅ Stage 4 complete - Composite created")
|
| 529 |
-
_progress("🎵 Adding audio track...")
|
| 530 |
final_path = tmp_root / "output_with_audio.mp4"
|
| 531 |
|
| 532 |
if _probe_ffmpeg():
|
|
@@ -537,7 +557,7 @@ def _progress(*args):
|
|
| 537 |
if mux_ok and final_path.exists():
|
| 538 |
final_size = final_path.stat().st_size
|
| 539 |
logger.info(f"[5] Final video with audio: {final_path} ({final_size} bytes)")
|
| 540 |
-
_progress(f"
|
| 541 |
output_path.unlink(missing_ok=True)
|
| 542 |
_force_cleanup()
|
| 543 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
|
@@ -551,7 +571,7 @@ def _progress(*args):
|
|
| 551 |
|
| 552 |
# Fallback return without audio
|
| 553 |
logger.info(f"[5] Using output without audio: {output_path}")
|
| 554 |
-
_progress(f"✅ Video ready (no audio) ({output_size} bytes)")
|
| 555 |
_force_cleanup()
|
| 556 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 557 |
diagnostics["total_time_sec"] = diagnostics["elapsed_sec"]
|
|
@@ -573,4 +593,4 @@ def _progress(*args):
|
|
| 573 |
finally:
|
| 574 |
# Ensure cleanup even if something goes wrong
|
| 575 |
_force_cleanup()
|
| 576 |
-
_cleanup_temp_files(tmp_root)
|
|
|
|
| 8 |
- Verbose breadcrumbs for pinpointing stalls
|
| 9 |
- Enhanced mask validation for MatAnyone compatibility (robust to API changes)
|
| 10 |
- Automatic mask inversion for high-coverage masks
|
| 11 |
+
|
| 12 |
+
Changes (2025-09-16):
|
| 13 |
+
- Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
|
| 14 |
+
- Added input shape logging before run_matany to prevent 5D tensor issues
|
| 15 |
+
- Added MatAnyone version logging
|
| 16 |
+
- Ensured consistent progress callback with percentages
|
| 17 |
+
- Maintained compatibility with updated models/ files
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
|
|
|
| 25 |
import tempfile
|
| 26 |
import logging
|
| 27 |
import importlib
|
| 28 |
+
import importlib.metadata
|
| 29 |
from pathlib import Path
|
| 30 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 31 |
|
|
|
|
| 224 |
logger.info(f"Progress: {msg}")
|
| 225 |
if progress_callback:
|
| 226 |
try:
|
| 227 |
+
progress_callback(0.0, msg) # Default to 0% if no percentage provided
|
| 228 |
except TypeError:
|
| 229 |
+
progress_callback(msg) # Legacy 1-arg
|
| 230 |
elif len(args) >= 2:
|
| 231 |
pct, msg = args[0], args[1]
|
| 232 |
logger.info(f"Progress: {msg} ({int(pct*100)}%)")
|
| 233 |
if progress_callback:
|
| 234 |
try:
|
| 235 |
+
progress_callback(pct, msg) # Preferred 2-arg
|
| 236 |
except TypeError:
|
| 237 |
+
progress_callback(msg) # Legacy 1-arg
|
| 238 |
except Exception as e:
|
| 239 |
logger.warning(f"progress callback failed: {e}")
|
| 240 |
|
|
|
|
| 243 |
# [5] PHASE 0: Video metadata
|
| 244 |
# ----------------------------------------------------------------------------------
|
| 245 |
logger.info("[0] Reading video metadata…")
|
| 246 |
+
_progress(0.0, "📹 Reading video metadata...")
|
| 247 |
first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
|
| 248 |
diagnostics["fps"] = int(fps or 25)
|
| 249 |
diagnostics["resolution"] = [int(vw), int(vh)]
|
|
|
|
| 257 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 258 |
cap.release()
|
| 259 |
|
| 260 |
+
_progress(0.05, f"✅ Video loaded: {vw}x{vh} @ {fps}fps ({total_frames} frames)")
|
| 261 |
diagnostics["total_frames"] = total_frames
|
| 262 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 263 |
|
|
|
|
| 265 |
# [6] PHASE 1: SAM2 → seed mask (then free)
|
| 266 |
# ----------------------------------------------------------------------------------
|
| 267 |
logger.info("[1] Loading SAM2…")
|
| 268 |
+
_progress(0.1, "🤖 Loading SAM2 model...")
|
| 269 |
predictor, sam2_ok, sam_meta = load_sam2()
|
| 270 |
diagnostics["sam2_meta"] = sam_meta or {}
|
| 271 |
diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
|
|
|
|
| 277 |
|
| 278 |
if sam2_ok and predictor is not None:
|
| 279 |
logger.info("[1] Running SAM2 segmentation…")
|
| 280 |
+
_progress(0.15, "🎯 Running SAM2 segmentation...")
|
| 281 |
px = int(point_x) if point_x is not None else None
|
| 282 |
py = int(point_y) if point_y is not None else None
|
| 283 |
seed_mask, ok_mask = run_sam2_mask(
|
|
|
|
| 286 |
auto=auto_box
|
| 287 |
)
|
| 288 |
diagnostics["sam2_ok"] = bool(ok_mask)
|
| 289 |
+
_progress(0.2, "✅ SAM2 segmentation complete")
|
| 290 |
else:
|
| 291 |
logger.info("[1] SAM2 unavailable or failed to load.")
|
| 292 |
+
_progress(0.2, "⚠️ SAM2 unavailable, using fallback")
|
| 293 |
|
| 294 |
# Free SAM2 ASAP
|
| 295 |
try:
|
|
|
|
| 298 |
pass
|
| 299 |
predictor = None
|
| 300 |
_force_cleanup()
|
| 301 |
+
_progress(0.25, "🧹 SAM2 memory cleared")
|
| 302 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 303 |
|
| 304 |
# Fallback mask generation if SAM2 failed
|
| 305 |
if not ok_mask or seed_mask is None:
|
| 306 |
logger.info("[1] Using fallback mask generation…")
|
| 307 |
+
_progress(0.25, "🔄 Generating fallback mask...")
|
| 308 |
seed_mask = fallback_mask(first_frame)
|
| 309 |
diagnostics["fallback_used"] = "mask_generation"
|
| 310 |
_force_cleanup()
|
|
|
|
| 312 |
# Optional GrabCut refinement
|
| 313 |
if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
|
| 314 |
logger.info("[1] Refining mask with GrabCut…")
|
| 315 |
+
_progress(0.3, "✨ Refining mask with GrabCut...")
|
| 316 |
seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
|
| 317 |
_force_cleanup()
|
| 318 |
|
|
|
|
| 349 |
logger.info(f"[1] ✅ Mask validation passed: {validation_msg}")
|
| 350 |
diagnostics["mask_validation"] = {"valid": True, "stats": mask_stats}
|
| 351 |
|
| 352 |
+
_progress(0.35, "✅ Stage 1 complete - Mask generated and validated")
|
| 353 |
|
| 354 |
# Free first frame memory
|
| 355 |
try:
|
|
|
|
| 358 |
pass
|
| 359 |
_force_cleanup()
|
| 360 |
_cleanup_temp_files(tmp_root)
|
| 361 |
+
_progress(0.4, "🧹 Frame memory cleared")
|
| 362 |
|
| 363 |
# ----------------------------------------------------------------------------------
|
| 364 |
# [7] PHASE 2: MatAnyone (strict CHW/HW tensors are handled in matanyone_loader)
|
| 365 |
# ----------------------------------------------------------------------------------
|
| 366 |
logger.info("[2] Loading MatAnyone…")
|
| 367 |
+
_progress(0.45, "🎬 Loading MatAnyone model...")
|
| 368 |
matany, mat_ok, mat_meta = load_matany()
|
| 369 |
diagnostics["matany_meta"] = mat_meta or {}
|
| 370 |
diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
|
| 371 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 372 |
|
| 373 |
+
# Log MatAnyone version
|
| 374 |
+
try:
|
| 375 |
+
version = importlib.metadata.version("matanyone")
|
| 376 |
+
logger.info(f"[MATANY] MatAnyone version: {version}")
|
| 377 |
+
except Exception:
|
| 378 |
+
logger.info("[MATANY] MatAnyone version unknown")
|
| 379 |
+
|
| 380 |
fg_path, al_path = None, None
|
| 381 |
out_dir = tmp_root / "matany_out"
|
| 382 |
_ensure_dir(out_dir)
|
|
|
|
| 384 |
from models.matanyone_loader import MatAnyError
|
| 385 |
|
| 386 |
try:
|
| 387 |
+
_progress(0.5, "MatAnyone: starting…")
|
| 388 |
logger.info("[2] Running MatAnyone processing…")
|
| 389 |
|
| 390 |
mask_validation = diagnostics.get("mask_validation", {})
|
| 391 |
if not mask_validation.get("valid", False):
|
| 392 |
logger.warning(f"[2] Proceeding with MatAnyone despite mask validation failure: "
|
| 393 |
f"{mask_validation.get('error', 'unknown')}")
|
|
|
|
| 394 |
else:
|
| 395 |
logger.info(f"[2] Mask validation OK - coverage: "
|
| 396 |
f"{mask_validation['stats']['coverage_percent']}%")
|
| 397 |
|
| 398 |
+
_progress(0.55, "🎥 Running MatAnyone video matting...")
|
| 399 |
+
|
| 400 |
+
# Validate input shapes before MatAnyone
|
| 401 |
+
import cv2
|
| 402 |
+
mask_array = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
|
| 403 |
+
logger.info(f"[2] Input mask shape: {mask_array.shape if mask_array is not None else None}")
|
| 404 |
+
if mask_array is None:
|
| 405 |
+
raise MatAnyError(f"Invalid mask at {mask_png}")
|
| 406 |
|
| 407 |
# NOTE: The updated loader feeds CHW image + HW seed (frame 0 only) — no 5D tensors.
|
| 408 |
al_path, fg_path = run_matany(
|
|
|
|
| 410 |
mask_path=mask_png,
|
| 411 |
out_dir=out_dir,
|
| 412 |
device="cuda" if _cuda_available() else "cpu",
|
| 413 |
+
progress_callback=lambda frac, msg: _progress(0.55 + 0.35 * frac, msg),
|
|
|
|
| 414 |
)
|
| 415 |
|
| 416 |
logger.info("Stage 2 success: MatAnyone produced outputs.")
|
| 417 |
diagnostics["matany_ok"] = True
|
| 418 |
mat_ok = True
|
| 419 |
+
_progress(0.9, "✅ MatAnyone processing complete")
|
| 420 |
logger.info(f"[2] MatAnyone results: fg_path={fg_path}, al_path={al_path}")
|
| 421 |
|
| 422 |
except MatAnyError as e:
|
|
|
|
| 429 |
fg_path, al_path = None, None
|
| 430 |
|
| 431 |
if not mat_ok:
|
| 432 |
+
_progress(0.9, "MatAnyone failed → using fallback...")
|
| 433 |
logger.info("[2] MatAnyone unavailable or failed, using fallback.")
|
| 434 |
|
| 435 |
# Free MatAnyone ASAP
|
|
|
|
| 439 |
pass
|
| 440 |
matany = None
|
| 441 |
_force_cleanup()
|
| 442 |
+
_progress(0.95, "🧹 MatAnyone memory cleared")
|
| 443 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 444 |
|
| 445 |
# ----------------------------------------------------------------------------------
|
|
|
|
| 447 |
# ----------------------------------------------------------------------------------
|
| 448 |
logger.info("[3] Building Stage-A (transparent or checkerboard)…")
|
| 449 |
if diagnostics["matany_ok"]:
|
| 450 |
+
_progress(0.95, "✅ Stage 2 complete - Video matting done")
|
| 451 |
else:
|
| 452 |
+
_progress(0.95, "ℹ️ Skipping MatAnyone outputs; building Stage-A from mask")
|
| 453 |
+
_progress(0.95, "🎨 Building Stage-A video...")
|
| 454 |
|
| 455 |
stageA_path = None
|
| 456 |
stageA_ok = False
|
|
|
|
| 495 |
# [9] PHASE 4: Final compositing
|
| 496 |
# ----------------------------------------------------------------------------------
|
| 497 |
logger.info("[4] Creating final composite…")
|
| 498 |
+
_progress(0.97, "✅ Stage 3 complete - Stage-A built")
|
| 499 |
+
_progress(0.97, "🎬 Creating final composite...")
|
| 500 |
output_path = tmp_root / "output.mp4"
|
| 501 |
|
| 502 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 503 |
logger.info(f"[4] Compositing with MatAnyone outputs: fg_path={fg_path}, al_path={al_path}")
|
| 504 |
+
_progress(0.97, f"🎬 Compositing video with MatAnyone outputs...")
|
| 505 |
|
| 506 |
fg_exists = Path(fg_path).exists() if fg_path else False
|
| 507 |
al_exists = Path(al_path).exists() if al_path else False
|
|
|
|
| 512 |
logger.info(f"[4] Composite result: {ok_comp}")
|
| 513 |
if not ok_comp:
|
| 514 |
logger.info("[4] Composite failed; falling back to static mask composite.")
|
| 515 |
+
_progress(0.98, "⚠️ MatAnyone composite failed, using fallback...")
|
| 516 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 517 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 518 |
else:
|
| 519 |
+
_progress(0.98, "✅ MatAnyone composite successful!")
|
| 520 |
else:
|
| 521 |
logger.error(f"[4] MatAnyone output files missing - using fallback composite")
|
| 522 |
+
_progress(0.98, "⚠️ MatAnyone files missing, using fallback...")
|
| 523 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 524 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 525 |
else:
|
| 526 |
logger.info(f"[4] Using static mask composite - matany_ok={diagnostics['matany_ok']}, fg_path={fg_path}, al_path={al_path}")
|
| 527 |
+
_progress(0.98, "🎬 Using static mask composite...")
|
| 528 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 529 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
|
| 530 |
|
|
|
|
| 533 |
|
| 534 |
if not output_path.exists():
|
| 535 |
logger.error(f"[4] Output video not created at {output_path}")
|
| 536 |
+
_progress(0.99, "❌ Composite creation failed - no output file")
|
| 537 |
diagnostics["error"] = "Composite video not created"
|
| 538 |
return None, diagnostics
|
| 539 |
|
| 540 |
output_size = output_path.stat().st_size
|
| 541 |
logger.info(f"[4] Output video created: {output_path} ({output_size} bytes)")
|
| 542 |
+
_progress(0.99, f"✅ Composite created ({output_size} bytes)")
|
| 543 |
|
| 544 |
# ----------------------------------------------------------------------------------
|
| 545 |
# [10] PHASE 5: Audio mux (if FFmpeg available)
|
| 546 |
# ----------------------------------------------------------------------------------
|
| 547 |
logger.info("[5] Adding audio track…")
|
| 548 |
+
_progress(0.99, "✅ Stage 4 complete - Composite created")
|
| 549 |
+
_progress(0.99, "🎵 Adding audio track...")
|
| 550 |
final_path = tmp_root / "output_with_audio.mp4"
|
| 551 |
|
| 552 |
if _probe_ffmpeg():
|
|
|
|
| 557 |
if mux_ok and final_path.exists():
|
| 558 |
final_size = final_path.stat().st_size
|
| 559 |
logger.info(f"[5] Final video with audio: {final_path} ({final_size} bytes)")
|
| 560 |
+
_progress(1.0, f"✅ Final video ready ({final_size} bytes)")
|
| 561 |
output_path.unlink(missing_ok=True)
|
| 562 |
_force_cleanup()
|
| 563 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
|
|
|
| 571 |
|
| 572 |
# Fallback return without audio
|
| 573 |
logger.info(f"[5] Using output without audio: {output_path}")
|
| 574 |
+
_progress(1.0, f"✅ Video ready (no audio) ({output_size} bytes)")
|
| 575 |
_force_cleanup()
|
| 576 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 577 |
diagnostics["total_time_sec"] = diagnostics["elapsed_sec"]
|
|
|
|
| 593 |
finally:
|
| 594 |
# Ensure cleanup even if something goes wrong
|
| 595 |
_force_cleanup()
|
| 596 |
+
_cleanup_temp_files(tmp_root)
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# ===== Core runtime (Torch is installed in Dockerfile with cu121 wheels) =====
|
| 2 |
-
# DO NOT add torch/torchvision/torchaudio here when using
|
| 3 |
|
| 4 |
# ===== Video / image IO =====
|
| 5 |
opencv-python-headless==4.10.0.84
|
|
@@ -16,6 +16,7 @@ protobuf==4.25.3
|
|
| 16 |
gradio==5.42.0
|
| 17 |
|
| 18 |
# ===== SAM2 Dependencies =====
|
|
|
|
| 19 |
hydra-core==1.3.2
|
| 20 |
omegaconf==2.3.0
|
| 21 |
einops==0.8.0
|
|
@@ -24,6 +25,7 @@ pyyaml==6.0.2
|
|
| 24 |
matplotlib==3.9.2
|
| 25 |
|
| 26 |
# ===== MatAnyone Dependencies =====
|
|
|
|
| 27 |
kornia==0.7.3
|
| 28 |
scikit-image==0.24.0
|
| 29 |
tqdm==4.66.5
|
|
@@ -35,11 +37,6 @@ psutil==6.0.0
|
|
| 35 |
requests==2.32.3
|
| 36 |
scikit-learn==1.5.1
|
| 37 |
|
| 38 |
-
|
| 39 |
-
# ===== (Optional) Extras =====
|
| 40 |
-
# safetensors==0.4.5
|
| 41 |
-
# aiohttp==3.10.5
|
| 42 |
-
|
| 43 |
# ===== (Optional) Extras =====
|
| 44 |
-
# safetensors==0.4.5
|
| 45 |
-
# aiohttp==3.10.5
|
|
|
|
| 1 |
# ===== Core runtime (Torch is installed in Dockerfile with cu121 wheels) =====
|
| 2 |
+
# DO NOT add torch/torchvision/torchaudio here when using CUDA wheels in Dockerfile.
|
| 3 |
|
| 4 |
# ===== Video / image IO =====
|
| 5 |
opencv-python-headless==4.10.0.84
|
|
|
|
| 16 |
gradio==5.42.0
|
| 17 |
|
| 18 |
# ===== SAM2 Dependencies =====
|
| 19 |
+
git+https://github.com/facebookresearch/segment-anything-2@main
|
| 20 |
hydra-core==1.3.2
|
| 21 |
omegaconf==2.3.0
|
| 22 |
einops==0.8.0
|
|
|
|
| 25 |
matplotlib==3.9.2
|
| 26 |
|
| 27 |
# ===== MatAnyone Dependencies =====
|
| 28 |
+
git+https://github.com/pq-yang/MatAnyone@master
|
| 29 |
kornia==0.7.3
|
| 30 |
scikit-image==0.24.0
|
| 31 |
tqdm==4.66.5
|
|
|
|
| 37 |
requests==2.32.3
|
| 38 |
scikit-learn==1.5.1
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# ===== (Optional) Extras =====
|
| 41 |
+
# safetensors==0.4.5 # Uncomment if pulling weights that use safetensors
|
| 42 |
+
# aiohttp==3.10.5 # Uncomment if async-fetching assets
|
ui.py
CHANGED
|
@@ -3,7 +3,15 @@
|
|
| 3 |
BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
|
| 4 |
- No top-level import of pipeline (lazy import in handlers)
|
| 5 |
- Compatible with pipeline.process()
|
| 6 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import io
|
|
@@ -16,6 +24,7 @@
|
|
| 16 |
from typing import Optional, Tuple, List, Dict, Any
|
| 17 |
from PIL import Image
|
| 18 |
import gradio as gr
|
|
|
|
| 19 |
|
| 20 |
logger = logging.getLogger("ui")
|
| 21 |
if not logger.handlers:
|
|
@@ -146,7 +155,6 @@ def get_video_url(self, selection: str) -> Optional[str]:
|
|
| 146 |
myavatar_api = MyAvatarAPI()
|
| 147 |
|
| 148 |
# ---- Minimal stop flag (request-scoped) ----
|
| 149 |
-
# We avoid pipeline globals; this just short-circuits the generator.
|
| 150 |
class Stopper:
|
| 151 |
def __init__(self):
|
| 152 |
self.stop = False
|
|
@@ -175,9 +183,11 @@ def process_video_with_background_stoppable(
|
|
| 175 |
video_path = None
|
| 176 |
if input_video:
|
| 177 |
video_path = input_video
|
|
|
|
| 178 |
elif myavatar_selection and myavatar_selection != "No videos available":
|
| 179 |
url = myavatar_api.get_video_url(myavatar_selection)
|
| 180 |
if url:
|
|
|
|
| 181 |
with requests.get(url, stream=True, timeout=60) as r:
|
| 182 |
r.raise_for_status()
|
| 183 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
|
@@ -188,12 +198,14 @@ def process_video_with_background_stoppable(
|
|
| 188 |
if chunk:
|
| 189 |
tmp.write(chunk)
|
| 190 |
video_path = tmp.name
|
|
|
|
| 191 |
|
| 192 |
if STOP.stop:
|
| 193 |
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 194 |
return
|
| 195 |
|
| 196 |
-
if not video_path:
|
|
|
|
| 197 |
yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
|
| 198 |
return
|
| 199 |
|
|
@@ -202,21 +214,27 @@ def process_video_with_background_stoppable(
|
|
| 202 |
bg_img = None
|
| 203 |
if background_type == "gradient":
|
| 204 |
bg_img = create_gradient_background(gradient_type, 1920, 1080)
|
|
|
|
| 205 |
elif background_type == "solid":
|
| 206 |
bg_img = create_solid_color(solid_color, 1920, 1080)
|
|
|
|
| 207 |
elif background_type == "custom" and custom_background:
|
| 208 |
try:
|
| 209 |
bg_img = Image.open(custom_background).convert("RGB")
|
| 210 |
-
|
|
|
|
|
|
|
| 211 |
bg_img = None
|
| 212 |
elif background_type == "ai" and ai_prompt:
|
| 213 |
-
bg_img,
|
|
|
|
| 214 |
|
| 215 |
if STOP.stop:
|
| 216 |
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 217 |
return
|
| 218 |
|
| 219 |
if bg_img is None:
|
|
|
|
| 220 |
yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
|
| 221 |
return
|
| 222 |
|
|
@@ -224,39 +242,45 @@ def process_video_with_background_stoppable(
|
|
| 224 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
|
| 225 |
bg_img.save(tmp_bg.name, format="PNG")
|
| 226 |
bg_path = tmp_bg.name
|
|
|
|
| 227 |
|
| 228 |
# Run pipeline with enhanced real-time status updates
|
| 229 |
yield gr.update(visible=False), gr.update(visible=True), None, "🔄 Initializing pipeline...\n⚡ Checking GPU acceleration..."
|
| 230 |
logger.info(f"=== PIPELINE START ===")
|
| 231 |
-
|
| 232 |
# Enhanced GPU diagnostics with detailed status
|
| 233 |
try:
|
| 234 |
import torch
|
| 235 |
logger.info(f"✅ Torch version: {torch.__version__}")
|
| 236 |
logger.info(f"✅ CUDA available: {torch.cuda.is_available()}")
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
if torch.cuda.is_available():
|
| 239 |
device_count = torch.cuda.device_count()
|
| 240 |
current_device = torch.cuda.current_device()
|
| 241 |
device_name = torch.cuda.get_device_name()
|
| 242 |
device_capability = torch.cuda.get_device_capability()
|
| 243 |
-
|
| 244 |
# Get GPU memory info
|
| 245 |
memory_allocated = torch.cuda.memory_allocated() / (1024**3) # GB
|
| 246 |
memory_reserved = torch.cuda.memory_reserved() / (1024**3) # GB
|
| 247 |
memory_total = torch.cuda.get_device_properties(current_device).total_memory / (1024**3) # GB
|
| 248 |
-
|
| 249 |
gpu_status = f"""✅ GPU Acceleration Active
|
| 250 |
🖥️ Device: {device_name} (Compute {device_capability[0]}.{device_capability[1]})
|
| 251 |
💾 Memory: {memory_allocated:.1f}GB allocated / {memory_total:.1f}GB total
|
| 252 |
🔧 CUDA {torch.version.cuda} | PyTorch {torch.__version__}
|
| 253 |
📊 Ready for SAM2 + MatAnyone processing..."""
|
| 254 |
-
|
| 255 |
logger.info(f"✅ CUDA device count: {device_count}")
|
| 256 |
logger.info(f"✅ Current device: {current_device}")
|
| 257 |
logger.info(f"✅ Device name: {device_name}")
|
| 258 |
logger.info(f"✅ GPU memory: {memory_allocated:.1f}GB/{memory_total:.1f}GB")
|
| 259 |
-
|
| 260 |
yield gr.update(visible=False), gr.update(visible=True), None, gpu_status
|
| 261 |
else:
|
| 262 |
logger.error(f"❌ CUDA NOT AVAILABLE - GPU processing will fail")
|
|
@@ -266,34 +290,33 @@ def process_video_with_background_stoppable(
|
|
| 266 |
logger.error(f"❌ Torch/CUDA check failed: {e}")
|
| 267 |
yield gr.update(visible=True), gr.update(visible=False), None, f"GPU check error: {e}"
|
| 268 |
return
|
| 269 |
-
|
| 270 |
yield gr.update(visible=False), gr.update(visible=True), None, gpu_status + "\n\n🔄 Loading pipeline modules..."
|
| 271 |
logger.info(f"About to import pipeline module...")
|
| 272 |
-
|
| 273 |
try:
|
| 274 |
pipe = importlib.import_module("pipeline")
|
| 275 |
logger.info(f"✅ Pipeline module imported successfully")
|
| 276 |
-
|
| 277 |
pipeline_status = gpu_status + "\n\n✅ Pipeline modules loaded\n📹 Initializing video processing pipeline..."
|
| 278 |
yield gr.update(visible=False), gr.update(visible=True), None, pipeline_status
|
| 279 |
except Exception as e:
|
| 280 |
logger.error(f"❌ Pipeline import failed: {e}")
|
| 281 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Pipeline import error: {e}"
|
| 282 |
return
|
| 283 |
-
|
| 284 |
logger.info(f"Calling pipe.process with video_path={video_path}, bg_path={bg_path}")
|
| 285 |
logger.info(f"=== CALLING PIPELINE.PROCESS ===")
|
| 286 |
-
|
| 287 |
# Enhanced status during processing with detailed stage tracking
|
| 288 |
stage_status = {
|
| 289 |
"current_stage": "Starting...",
|
| 290 |
"sam2_status": "⏳ Pending",
|
| 291 |
-
"matany_status": "⏳ Pending",
|
| 292 |
"composite_status": "⏳ Pending",
|
| 293 |
"audio_status": "⏳ Pending",
|
| 294 |
"frame_progress": ""
|
| 295 |
}
|
| 296 |
-
|
| 297 |
def format_status():
|
| 298 |
return (gpu_status + f"\n\n🚀 PROCESSING: {stage_status['current_stage']}\n\n" +
|
| 299 |
f"📊 PIPELINE STAGES:\n" +
|
|
@@ -302,59 +325,52 @@ def format_status():
|
|
| 302 |
f"🎬 Video Compositing: {stage_status['composite_status']}\n" +
|
| 303 |
f"🔊 Audio Muxing: {stage_status['audio_status']}\n" +
|
| 304 |
(f"\n📈 {stage_status['frame_progress']}" if stage_status['frame_progress'] else ""))
|
| 305 |
-
|
| 306 |
processing_status = format_status()
|
| 307 |
yield gr.update(visible=False), gr.update(visible=True), None, processing_status
|
| 308 |
-
|
| 309 |
# Create progress callback to update UI status with detailed tracking
|
| 310 |
-
def progress_callback(
|
| 311 |
nonlocal stage_status
|
| 312 |
-
|
| 313 |
# Update current stage and frame progress
|
| 314 |
-
stage_status['current_stage'] =
|
| 315 |
-
|
| 316 |
# Track specific stages
|
| 317 |
-
if "SAM2" in
|
| 318 |
-
if "complete" in
|
| 319 |
stage_status['sam2_status'] = "✅ Complete"
|
| 320 |
else:
|
| 321 |
stage_status['sam2_status'] = "🔄 Running..."
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
if "complete" in message.lower() or "✅" in message:
|
| 325 |
stage_status['matany_status'] = "✅ Complete"
|
| 326 |
-
elif "failed" in
|
| 327 |
stage_status['matany_status'] = "❌ Failed → Fallback"
|
| 328 |
else:
|
| 329 |
stage_status['matany_status'] = "🔄 Running..."
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
if "complete" in message.lower() or "✅" in message:
|
| 333 |
stage_status['composite_status'] = "✅ Complete"
|
| 334 |
else:
|
| 335 |
stage_status['composite_status'] = "🔄 Running..."
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
if "complete" in message.lower() or "✅" in message:
|
| 339 |
stage_status['audio_status'] = "✅ Complete"
|
| 340 |
else:
|
| 341 |
stage_status['audio_status'] = "🔄 Running..."
|
| 342 |
-
|
| 343 |
# Extract frame progress
|
| 344 |
-
if "/" in
|
| 345 |
-
stage_status['frame_progress'] =
|
| 346 |
-
|
| 347 |
updated_status = format_status()
|
| 348 |
return gr.update(visible=False), gr.update(visible=True), None, updated_status
|
| 349 |
-
|
| 350 |
try:
|
| 351 |
-
# FIXED: Remove problematic auto_box setting and use smart person detection
|
| 352 |
out_path, diag = pipe.process(
|
| 353 |
video_path=video_path,
|
| 354 |
bg_image_path=bg_path,
|
| 355 |
-
point_x=None, # Let SAM2 use smart person detection
|
| 356 |
-
point_y=None, # Let SAM2 use smart person detection
|
| 357 |
-
auto_box=False, # FIXED: Disable auto_box to use our smart detection
|
| 358 |
work_dir=None,
|
| 359 |
progress_callback=progress_callback
|
| 360 |
)
|
|
@@ -364,11 +380,10 @@ def progress_callback(message):
|
|
| 364 |
logger.error(f"❌ Pipeline.process failed: {e}")
|
| 365 |
import traceback
|
| 366 |
logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 367 |
-
|
| 368 |
error_status = gpu_status + f"\n\n❌ PROCESSING FAILED\n🚨 Error: {str(e)[:200]}..."
|
| 369 |
yield gr.update(visible=True), gr.update(visible=False), None, error_status
|
| 370 |
return
|
| 371 |
-
|
| 372 |
if out_path:
|
| 373 |
# Enhanced final processing stats with detailed breakdown
|
| 374 |
fps = diag.get('fps', 'unknown')
|
|
@@ -376,31 +391,25 @@ def progress_callback(message):
|
|
| 376 |
sam2_ok = diag.get('sam2_ok', False)
|
| 377 |
matany_ok = diag.get('matany_ok', False)
|
| 378 |
processing_time = diag.get('total_time_sec', 0)
|
| 379 |
-
|
| 380 |
-
matany_time = diag.get('matany_time_sec', 0)
|
| 381 |
-
|
| 382 |
# Check mask validation results for quality feedback
|
| 383 |
mask_validation = diag.get('mask_validation', {})
|
| 384 |
mask_valid = mask_validation.get('valid', False)
|
| 385 |
mask_coverage = mask_validation.get('stats', {}).get('coverage_percent', 0)
|
| 386 |
-
|
| 387 |
# Get final GPU memory usage and verify GPU acceleration was used
|
| 388 |
try:
|
| 389 |
import torch
|
| 390 |
if torch.cuda.is_available():
|
| 391 |
final_memory = torch.cuda.memory_allocated() / (1024**3)
|
| 392 |
peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
|
| 393 |
-
|
| 394 |
-
# Log GPU utilization to verify models used GPU
|
| 395 |
logger.info(f"GPU USAGE VERIFICATION:")
|
| 396 |
logger.info(f" Final memory allocated: {final_memory:.2f}GB")
|
| 397 |
logger.info(f" Peak memory used: {peak_memory:.2f}GB")
|
| 398 |
-
|
| 399 |
if peak_memory < 0.1: # Less than 100MB indicates CPU usage
|
| 400 |
-
logger.warning(f"⚠️
|
| 401 |
else:
|
| 402 |
logger.info(f"✅ GPU ACCELERATION CONFIRMED - Peak usage {peak_memory:.2f}GB")
|
| 403 |
-
|
| 404 |
torch.cuda.reset_peak_memory_stats() # Reset for next run
|
| 405 |
else:
|
| 406 |
final_memory = peak_memory = 0
|
|
@@ -408,7 +417,7 @@ def progress_callback(message):
|
|
| 408 |
except Exception as e:
|
| 409 |
logger.error(f"GPU memory check failed: {e}")
|
| 410 |
final_memory = peak_memory = 0
|
| 411 |
-
|
| 412 |
# Enhanced success message with segmentation quality info
|
| 413 |
segmentation_quality = ""
|
| 414 |
if mask_valid and mask_coverage > 0:
|
|
@@ -418,13 +427,13 @@ def progress_callback(message):
|
|
| 418 |
segmentation_quality = f"⚠️ High segmentation ({mask_coverage:.1f}% - check background)"
|
| 419 |
else:
|
| 420 |
segmentation_quality = f"✅ Person segmented ({mask_coverage:.1f}%)"
|
| 421 |
-
|
| 422 |
status_msg = gpu_status + f"""
|
| 423 |
|
| 424 |
🎉 PROCESSING COMPLETE!
|
| 425 |
-
✅ Stage 1: SAM2 segmentation {'✓' if sam2_ok else '✗'}
|
| 426 |
{segmentation_quality}
|
| 427 |
-
✅ Stage 2: MatAnyone matting {'✓' if matany_ok else '✗'}
|
| 428 |
✅ Stage 3: Final compositing complete
|
| 429 |
|
| 430 |
📊 RESULTS:
|
|
@@ -457,10 +466,12 @@ def progress_callback(message):
|
|
| 457 |
logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 458 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
|
| 459 |
finally:
|
| 460 |
-
# Best-effort cleanup of any temp
|
| 461 |
try:
|
| 462 |
if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
|
| 463 |
os.unlink(video_path)
|
|
|
|
|
|
|
| 464 |
except Exception:
|
| 465 |
pass
|
| 466 |
|
|
@@ -483,7 +494,7 @@ def create_interface():
|
|
| 483 |
"""
|
| 484 |
|
| 485 |
with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
|
| 486 |
-
gr.Markdown("# BackgroundFX Pro — SAM2 + MatAnyone
|
| 487 |
|
| 488 |
with gr.Row():
|
| 489 |
status = _system_status()
|
|
@@ -525,72 +536,5 @@ def create_interface():
|
|
| 525 |
result_video = gr.Video(label="Processed Video", height=400)
|
| 526 |
status_output = gr.Textbox(label="Processing Status", lines=8, max_lines=15, elem_classes=["status-box"])
|
| 527 |
gr.Markdown("""
|
| 528 |
-
### Pipeline
|
| 529 |
-
1. SAM2 Smart Person Detection → proper mask
|
| 530 |
-
2. MatAnyone Matting → FG + ALPHA
|
| 531 |
-
3. Stage-A export (transparent WebM or checkerboard)
|
| 532 |
-
4. Final compositing (H.264)
|
| 533 |
-
""")
|
| 534 |
-
|
| 535 |
-
# handlers
|
| 536 |
-
def update_background_options(bg_type):
|
| 537 |
-
return {
|
| 538 |
-
gradient_type: gr.update(visible=(bg_type == "gradient")),
|
| 539 |
-
gradient_preview: gr.update(visible=(bg_type == "gradient")),
|
| 540 |
-
solid_color: gr.update(visible=(bg_type == "solid")),
|
| 541 |
-
color_preview: gr.update(visible=(bg_type == "solid")),
|
| 542 |
-
custom_bg_upload: gr.update(visible=(bg_type == "custom")),
|
| 543 |
-
ai_prompt: gr.update(visible=(bg_type == "ai")),
|
| 544 |
-
ai_generate_btn: gr.update(visible=(bg_type == "ai")),
|
| 545 |
-
ai_preview: gr.update(visible=(bg_type == "ai")),
|
| 546 |
-
}
|
| 547 |
-
|
| 548 |
-
def update_gradient_preview(grad_type):
|
| 549 |
-
try:
|
| 550 |
-
return create_gradient_background(grad_type, 400, 200)
|
| 551 |
-
except Exception:
|
| 552 |
-
return None
|
| 553 |
-
|
| 554 |
-
def update_color_preview(color):
|
| 555 |
-
try:
|
| 556 |
-
return create_solid_color(color, 400, 200)
|
| 557 |
-
except Exception:
|
| 558 |
-
return None
|
| 559 |
-
|
| 560 |
-
def refresh_myavatar_videos():
|
| 561 |
-
try:
|
| 562 |
-
return gr.update(choices=myavatar_api.get_video_choices(), value=None)
|
| 563 |
-
except Exception:
|
| 564 |
-
return gr.update(choices=["Error loading videos"], value=None)
|
| 565 |
-
|
| 566 |
-
def load_video_preview(selection):
|
| 567 |
-
try:
|
| 568 |
-
return myavatar_api.get_video_url(selection)
|
| 569 |
-
except Exception:
|
| 570 |
-
return None
|
| 571 |
-
|
| 572 |
-
def generate_ai_bg(prompt):
|
| 573 |
-
bg_img, _ = generate_ai_background(prompt)
|
| 574 |
-
return bg_img
|
| 575 |
-
|
| 576 |
-
background_type.change(
|
| 577 |
-
fn=update_background_options,
|
| 578 |
-
inputs=[background_type],
|
| 579 |
-
outputs=[gradient_type, gradient_preview, solid_color, color_preview, custom_bg_upload, ai_prompt, ai_generate_btn, ai_preview]
|
| 580 |
-
)
|
| 581 |
-
gradient_type.change(fn=update_gradient_preview, inputs=[gradient_type], outputs=[gradient_preview])
|
| 582 |
-
solid_color.change(fn=update_color_preview, inputs=[solid_color], outputs=[color_preview])
|
| 583 |
-
refresh_btn.click(fn=refresh_myavatar_videos, outputs=[myavatar_dropdown])
|
| 584 |
-
myavatar_dropdown.change(fn=load_video_preview, inputs=[myavatar_dropdown], outputs=[video_preview])
|
| 585 |
-
ai_generate_btn.click(fn=generate_ai_bg, inputs=[ai_prompt], outputs=[ai_preview])
|
| 586 |
-
|
| 587 |
-
process_btn.click(
|
| 588 |
-
fn=process_video_with_background_stoppable,
|
| 589 |
-
inputs=[video_upload, myavatar_dropdown, background_type, gradient_type, solid_color, custom_bg_upload, ai_prompt],
|
| 590 |
-
outputs=[process_btn, stop_btn, result_video, status_output]
|
| 591 |
-
)
|
| 592 |
-
stop_btn.click(fn=stop_processing_button, outputs=[stop_btn, status_output])
|
| 593 |
-
|
| 594 |
-
app.load(fn=lambda: create_gradient_background("sunset", 400, 200), outputs=[gradient_preview])
|
| 595 |
-
|
| 596 |
-
return app
|
|
|
|
| 3 |
BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
|
| 4 |
- No top-level import of pipeline (lazy import in handlers)
|
| 5 |
- Compatible with pipeline.process()
|
| 6 |
+
- Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
|
| 7 |
+
|
| 8 |
+
Changes (2025-09-16):
|
| 9 |
+
- Aligned with updated pipeline.py and models/
|
| 10 |
+
- Updated progress callback to pass percentages to pipeline.process
|
| 11 |
+
- Added input path validation logging
|
| 12 |
+
- Simplified SAM2 arguments to use pipeline defaults
|
| 13 |
+
- Added MatAnyone version logging in GPU diagnostics
|
| 14 |
+
- Enhanced temporary file cleanup
|
| 15 |
"""
|
| 16 |
|
| 17 |
import io
|
|
|
|
| 24 |
from typing import Optional, Tuple, List, Dict, Any
|
| 25 |
from PIL import Image
|
| 26 |
import gradio as gr
|
| 27 |
+
import importlib.metadata
|
| 28 |
|
| 29 |
logger = logging.getLogger("ui")
|
| 30 |
if not logger.handlers:
|
|
|
|
| 155 |
myavatar_api = MyAvatarAPI()
|
| 156 |
|
| 157 |
# ---- Minimal stop flag (request-scoped) ----
|
|
|
|
| 158 |
class Stopper:
|
| 159 |
def __init__(self):
|
| 160 |
self.stop = False
|
|
|
|
| 183 |
video_path = None
|
| 184 |
if input_video:
|
| 185 |
video_path = input_video
|
| 186 |
+
logger.info(f"[UI] Using uploaded video: {video_path}")
|
| 187 |
elif myavatar_selection and myavatar_selection != "No videos available":
|
| 188 |
url = myavatar_api.get_video_url(myavatar_selection)
|
| 189 |
if url:
|
| 190 |
+
logger.info(f"[UI] Fetching MyAvatar video: {url}")
|
| 191 |
with requests.get(url, stream=True, timeout=60) as r:
|
| 192 |
r.raise_for_status()
|
| 193 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
|
|
|
| 198 |
if chunk:
|
| 199 |
tmp.write(chunk)
|
| 200 |
video_path = tmp.name
|
| 201 |
+
logger.info(f"[UI] Downloaded MyAvatar video to: {video_path}")
|
| 202 |
|
| 203 |
if STOP.stop:
|
| 204 |
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 205 |
return
|
| 206 |
|
| 207 |
+
if not video_path or not os.path.exists(video_path):
|
| 208 |
+
logger.error(f"[UI] No valid video provided: input_video={input_video}, myavatar_selection={myavatar_selection}")
|
| 209 |
yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
|
| 210 |
return
|
| 211 |
|
|
|
|
| 214 |
bg_img = None
|
| 215 |
if background_type == "gradient":
|
| 216 |
bg_img = create_gradient_background(gradient_type, 1920, 1080)
|
| 217 |
+
logger.info(f"[UI] Generated gradient background: {gradient_type}")
|
| 218 |
elif background_type == "solid":
|
| 219 |
bg_img = create_solid_color(solid_color, 1920, 1080)
|
| 220 |
+
logger.info(f"[UI] Generated solid color background: {solid_color}")
|
| 221 |
elif background_type == "custom" and custom_background:
|
| 222 |
try:
|
| 223 |
bg_img = Image.open(custom_background).convert("RGB")
|
| 224 |
+
logger.info(f"[UI] Loaded custom background: {custom_background}")
|
| 225 |
+
except Exception as e:
|
| 226 |
+
logger.error(f"[UI] Failed to load custom background: {e}")
|
| 227 |
bg_img = None
|
| 228 |
elif background_type == "ai" and ai_prompt:
|
| 229 |
+
bg_img, msg = generate_ai_background(ai_prompt)
|
| 230 |
+
logger.info(f"[UI] AI background generation: {msg}")
|
| 231 |
|
| 232 |
if STOP.stop:
|
| 233 |
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 234 |
return
|
| 235 |
|
| 236 |
if bg_img is None:
|
| 237 |
+
logger.error(f"[UI] No background generated: type={background_type}, custom={custom_background}, ai_prompt={ai_prompt}")
|
| 238 |
yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
|
| 239 |
return
|
| 240 |
|
|
|
|
| 242 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
|
| 243 |
bg_img.save(tmp_bg.name, format="PNG")
|
| 244 |
bg_path = tmp_bg.name
|
| 245 |
+
logger.info(f"[UI] Saved background to: {bg_path}")
|
| 246 |
|
| 247 |
# Run pipeline with enhanced real-time status updates
|
| 248 |
yield gr.update(visible=False), gr.update(visible=True), None, "🔄 Initializing pipeline...\n⚡ Checking GPU acceleration..."
|
| 249 |
logger.info(f"=== PIPELINE START ===")
|
| 250 |
+
|
| 251 |
# Enhanced GPU diagnostics with detailed status
|
| 252 |
try:
|
| 253 |
import torch
|
| 254 |
logger.info(f"✅ Torch version: {torch.__version__}")
|
| 255 |
logger.info(f"✅ CUDA available: {torch.cuda.is_available()}")
|
| 256 |
+
try:
|
| 257 |
+
version = importlib.metadata.version("matanyone")
|
| 258 |
+
logger.info(f"[MATANY] MatAnyone version: {version}")
|
| 259 |
+
except Exception:
|
| 260 |
+
logger.info("[MATANY] MatAnyone version unknown")
|
| 261 |
+
|
| 262 |
if torch.cuda.is_available():
|
| 263 |
device_count = torch.cuda.device_count()
|
| 264 |
current_device = torch.cuda.current_device()
|
| 265 |
device_name = torch.cuda.get_device_name()
|
| 266 |
device_capability = torch.cuda.get_device_capability()
|
| 267 |
+
|
| 268 |
# Get GPU memory info
|
| 269 |
memory_allocated = torch.cuda.memory_allocated() / (1024**3) # GB
|
| 270 |
memory_reserved = torch.cuda.memory_reserved() / (1024**3) # GB
|
| 271 |
memory_total = torch.cuda.get_device_properties(current_device).total_memory / (1024**3) # GB
|
| 272 |
+
|
| 273 |
gpu_status = f"""✅ GPU Acceleration Active
|
| 274 |
🖥️ Device: {device_name} (Compute {device_capability[0]}.{device_capability[1]})
|
| 275 |
💾 Memory: {memory_allocated:.1f}GB allocated / {memory_total:.1f}GB total
|
| 276 |
🔧 CUDA {torch.version.cuda} | PyTorch {torch.__version__}
|
| 277 |
📊 Ready for SAM2 + MatAnyone processing..."""
|
| 278 |
+
|
| 279 |
logger.info(f"✅ CUDA device count: {device_count}")
|
| 280 |
logger.info(f"✅ Current device: {current_device}")
|
| 281 |
logger.info(f"✅ Device name: {device_name}")
|
| 282 |
logger.info(f"✅ GPU memory: {memory_allocated:.1f}GB/{memory_total:.1f}GB")
|
| 283 |
+
|
| 284 |
yield gr.update(visible=False), gr.update(visible=True), None, gpu_status
|
| 285 |
else:
|
| 286 |
logger.error(f"❌ CUDA NOT AVAILABLE - GPU processing will fail")
|
|
|
|
| 290 |
logger.error(f"❌ Torch/CUDA check failed: {e}")
|
| 291 |
yield gr.update(visible=True), gr.update(visible=False), None, f"GPU check error: {e}"
|
| 292 |
return
|
| 293 |
+
|
| 294 |
yield gr.update(visible=False), gr.update(visible=True), None, gpu_status + "\n\n🔄 Loading pipeline modules..."
|
| 295 |
logger.info(f"About to import pipeline module...")
|
| 296 |
+
|
| 297 |
try:
|
| 298 |
pipe = importlib.import_module("pipeline")
|
| 299 |
logger.info(f"✅ Pipeline module imported successfully")
|
|
|
|
| 300 |
pipeline_status = gpu_status + "\n\n✅ Pipeline modules loaded\n📹 Initializing video processing pipeline..."
|
| 301 |
yield gr.update(visible=False), gr.update(visible=True), None, pipeline_status
|
| 302 |
except Exception as e:
|
| 303 |
logger.error(f"❌ Pipeline import failed: {e}")
|
| 304 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Pipeline import error: {e}"
|
| 305 |
return
|
| 306 |
+
|
| 307 |
logger.info(f"Calling pipe.process with video_path={video_path}, bg_path={bg_path}")
|
| 308 |
logger.info(f"=== CALLING PIPELINE.PROCESS ===")
|
| 309 |
+
|
| 310 |
# Enhanced status during processing with detailed stage tracking
|
| 311 |
stage_status = {
|
| 312 |
"current_stage": "Starting...",
|
| 313 |
"sam2_status": "⏳ Pending",
|
| 314 |
+
"matany_status": "⏳ Pending",
|
| 315 |
"composite_status": "⏳ Pending",
|
| 316 |
"audio_status": "⏳ Pending",
|
| 317 |
"frame_progress": ""
|
| 318 |
}
|
| 319 |
+
|
| 320 |
def format_status():
|
| 321 |
return (gpu_status + f"\n\n🚀 PROCESSING: {stage_status['current_stage']}\n\n" +
|
| 322 |
f"📊 PIPELINE STAGES:\n" +
|
|
|
|
| 325 |
f"🎬 Video Compositing: {stage_status['composite_status']}\n" +
|
| 326 |
f"🔊 Audio Muxing: {stage_status['audio_status']}\n" +
|
| 327 |
(f"\n📈 {stage_status['frame_progress']}" if stage_status['frame_progress'] else ""))
|
| 328 |
+
|
| 329 |
processing_status = format_status()
|
| 330 |
yield gr.update(visible=False), gr.update(visible=True), None, processing_status
|
| 331 |
+
|
| 332 |
# Create progress callback to update UI status with detailed tracking
|
| 333 |
+
def progress_callback(pct: float, msg: str):
|
| 334 |
nonlocal stage_status
|
| 335 |
+
|
| 336 |
# Update current stage and frame progress
|
| 337 |
+
stage_status['current_stage'] = msg
|
| 338 |
+
|
| 339 |
# Track specific stages
|
| 340 |
+
if "SAM2" in msg or "segmentation" in msg.lower():
|
| 341 |
+
if "complete" in msg.lower() or "✅" in msg:
|
| 342 |
stage_status['sam2_status'] = "✅ Complete"
|
| 343 |
else:
|
| 344 |
stage_status['sam2_status'] = "🔄 Running..."
|
| 345 |
+
elif "MatAnyone" in msg or "matting" in msg.lower():
|
| 346 |
+
if "complete" in msg.lower() or "✅" in msg:
|
|
|
|
| 347 |
stage_status['matany_status'] = "✅ Complete"
|
| 348 |
+
elif "failed" in msg.lower() or "fallback" in msg.lower():
|
| 349 |
stage_status['matany_status'] = "❌ Failed → Fallback"
|
| 350 |
else:
|
| 351 |
stage_status['matany_status'] = "🔄 Running..."
|
| 352 |
+
elif "composit" in msg.lower():
|
| 353 |
+
if "complete" in msg.lower() or "✅" in msg:
|
|
|
|
| 354 |
stage_status['composite_status'] = "✅ Complete"
|
| 355 |
else:
|
| 356 |
stage_status['composite_status'] = "🔄 Running..."
|
| 357 |
+
elif "audio" in msg.lower() or "mux" in msg.lower():
|
| 358 |
+
if "complete" in msg.lower() or "✅" in msg:
|
|
|
|
| 359 |
stage_status['audio_status'] = "✅ Complete"
|
| 360 |
else:
|
| 361 |
stage_status['audio_status'] = "🔄 Running..."
|
| 362 |
+
|
| 363 |
# Extract frame progress
|
| 364 |
+
if "/" in msg and any(word in msg.lower() for word in ["frame", "matting", "chunking"]):
|
| 365 |
+
stage_status['frame_progress'] = msg
|
| 366 |
+
|
| 367 |
updated_status = format_status()
|
| 368 |
return gr.update(visible=False), gr.update(visible=True), None, updated_status
|
| 369 |
+
|
| 370 |
try:
|
|
|
|
| 371 |
out_path, diag = pipe.process(
|
| 372 |
video_path=video_path,
|
| 373 |
bg_image_path=bg_path,
|
|
|
|
|
|
|
|
|
|
| 374 |
work_dir=None,
|
| 375 |
progress_callback=progress_callback
|
| 376 |
)
|
|
|
|
| 380 |
logger.error(f"❌ Pipeline.process failed: {e}")
|
| 381 |
import traceback
|
| 382 |
logger.error(f"Full traceback: {traceback.format_exc()}")
|
|
|
|
| 383 |
error_status = gpu_status + f"\n\n❌ PROCESSING FAILED\n🚨 Error: {str(e)[:200]}..."
|
| 384 |
yield gr.update(visible=True), gr.update(visible=False), None, error_status
|
| 385 |
return
|
| 386 |
+
|
| 387 |
if out_path:
|
| 388 |
# Enhanced final processing stats with detailed breakdown
|
| 389 |
fps = diag.get('fps', 'unknown')
|
|
|
|
| 391 |
sam2_ok = diag.get('sam2_ok', False)
|
| 392 |
matany_ok = diag.get('matany_ok', False)
|
| 393 |
processing_time = diag.get('total_time_sec', 0)
|
| 394 |
+
|
|
|
|
|
|
|
| 395 |
# Check mask validation results for quality feedback
|
| 396 |
mask_validation = diag.get('mask_validation', {})
|
| 397 |
mask_valid = mask_validation.get('valid', False)
|
| 398 |
mask_coverage = mask_validation.get('stats', {}).get('coverage_percent', 0)
|
| 399 |
+
|
| 400 |
# Get final GPU memory usage and verify GPU acceleration was used
|
| 401 |
try:
|
| 402 |
import torch
|
| 403 |
if torch.cuda.is_available():
|
| 404 |
final_memory = torch.cuda.memory_allocated() / (1024**3)
|
| 405 |
peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
|
|
|
|
|
|
|
| 406 |
logger.info(f"GPU USAGE VERIFICATION:")
|
| 407 |
logger.info(f" Final memory allocated: {final_memory:.2f}GB")
|
| 408 |
logger.info(f" Peak memory used: {peak_memory:.2f}GB")
|
|
|
|
| 409 |
if peak_memory < 0.1: # Less than 100MB indicates CPU usage
|
| 410 |
+
logger.warning(f"⚠️ LOW GPU USAGE! Peak memory {peak_memory:.2f}GB suggests CPU fallback")
|
| 411 |
else:
|
| 412 |
logger.info(f"✅ GPU ACCELERATION CONFIRMED - Peak usage {peak_memory:.2f}GB")
|
|
|
|
| 413 |
torch.cuda.reset_peak_memory_stats() # Reset for next run
|
| 414 |
else:
|
| 415 |
final_memory = peak_memory = 0
|
|
|
|
| 417 |
except Exception as e:
|
| 418 |
logger.error(f"GPU memory check failed: {e}")
|
| 419 |
final_memory = peak_memory = 0
|
| 420 |
+
|
| 421 |
# Enhanced success message with segmentation quality info
|
| 422 |
segmentation_quality = ""
|
| 423 |
if mask_valid and mask_coverage > 0:
|
|
|
|
| 427 |
segmentation_quality = f"⚠️ High segmentation ({mask_coverage:.1f}% - check background)"
|
| 428 |
else:
|
| 429 |
segmentation_quality = f"✅ Person segmented ({mask_coverage:.1f}%)"
|
| 430 |
+
|
| 431 |
status_msg = gpu_status + f"""
|
| 432 |
|
| 433 |
🎉 PROCESSING COMPLETE!
|
| 434 |
+
✅ Stage 1: SAM2 segmentation {'✓' if sam2_ok else '✗'}
|
| 435 |
{segmentation_quality}
|
| 436 |
+
✅ Stage 2: MatAnyone matting {'✓' if matany_ok else '✗'}
|
| 437 |
✅ Stage 3: Final compositing complete
|
| 438 |
|
| 439 |
📊 RESULTS:
|
|
|
|
| 466 |
logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 467 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
|
| 468 |
finally:
|
| 469 |
+
# Best-effort cleanup of any temp files
|
| 470 |
try:
|
| 471 |
if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
|
| 472 |
os.unlink(video_path)
|
| 473 |
+
if 'bg_path' in locals() and bg_path and os.path.exists(bg_path):
|
| 474 |
+
os.unlink(bg_path)
|
| 475 |
except Exception:
|
| 476 |
pass
|
| 477 |
|
|
|
|
| 494 |
"""
|
| 495 |
|
| 496 |
with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
|
| 497 |
+
gr.Markdown("# BackgroundFX Pro — SAM2 + MatAnyone")
|
| 498 |
|
| 499 |
with gr.Row():
|
| 500 |
status = _system_status()
|
|
|
|
| 536 |
result_video = gr.Video(label="Processed Video", height=400)
|
| 537 |
status_output = gr.Textbox(label="Processing Status", lines=8, max_lines=15, elem_classes=["status-box"])
|
| 538 |
gr.Markdown("""
|
| 539 |
+
### Pipeline
|
| 540 |
+
1. SAM2 Smart Person Detection → proper mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/mask_validation.py
CHANGED
|
@@ -1,5 +1,16 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
Mask validation utilities for BackgroundFX Pro.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
|
@@ -7,6 +18,16 @@
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, Union, Tuple, Optional
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path], target_hw=None, *_, **__) -> Tuple[bool, Dict, str]:
|
| 11 |
"""
|
| 12 |
Back/forward-compatible mask sanity check for MatAnyone/SAM2 pipelines.
|
|
@@ -22,6 +43,15 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
|
|
| 22 |
- We only *validate*; inversion/repair is upstream's job.
|
| 23 |
- Always returns True unless the mask is unreadable or wrong rank.
|
| 24 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# ---- load to np.uint8/float32 2D ----
|
| 26 |
try:
|
| 27 |
import torch # optional
|
|
@@ -32,14 +62,17 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
|
|
| 32 |
if isinstance(mask_input, (str, Path)):
|
| 33 |
mask = cv2.imread(str(mask_input), cv2.IMREAD_GRAYSCALE)
|
| 34 |
if mask is None:
|
|
|
|
| 35 |
return False, {}, f"Mask not found or unreadable: {mask_input}"
|
| 36 |
# Load from torch tensor
|
| 37 |
elif (torch is not None) and isinstance(mask_input, torch.Tensor):
|
| 38 |
mask = mask_input.detach().cpu().numpy()
|
|
|
|
| 39 |
# Already numpy
|
| 40 |
elif isinstance(mask_input, np.ndarray):
|
| 41 |
mask = mask_input
|
| 42 |
else:
|
|
|
|
| 43 |
return False, {}, f"Unsupported mask type: {type(mask_input)}"
|
| 44 |
|
| 45 |
# If 3D, squeeze/convert to grayscale
|
|
@@ -49,12 +82,16 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
|
|
| 49 |
else:
|
| 50 |
mask = np.squeeze(mask)
|
| 51 |
if mask.ndim != 2:
|
|
|
|
| 52 |
return False, {}, f"Mask must be 2D, got shape {mask.shape}"
|
| 53 |
|
|
|
|
|
|
|
| 54 |
# Optional resize to target (H, W)
|
| 55 |
if target_hw is not None and isinstance(target_hw, (tuple, list)) and len(target_hw) == 2:
|
| 56 |
H, W = int(target_hw[0]), int(target_hw[1])
|
| 57 |
if mask.shape != (H, W):
|
|
|
|
| 58 |
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 59 |
|
| 60 |
# Normalize to [0,1] float
|
|
@@ -69,6 +106,8 @@ def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path],
|
|
| 69 |
|
| 70 |
# We keep validation permissive; upstream may invert/repair based on coverage
|
| 71 |
msg = f"Basic validation - {coverage:.1f}% coverage"
|
|
|
|
|
|
|
| 72 |
return True, stats, msg
|
| 73 |
|
| 74 |
def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float = 0.01) -> Tuple[bool, str]:
|
|
@@ -82,6 +121,8 @@ def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float
|
|
| 82 |
Returns:
|
| 83 |
Tuple of (is_valid, error_message)
|
| 84 |
"""
|
|
|
|
|
|
|
| 85 |
# Basic validation first
|
| 86 |
is_valid, msg = validate_mask_for_matanyone_simple(mask)
|
| 87 |
if not is_valid:
|
|
@@ -94,8 +135,10 @@ def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float
|
|
| 94 |
# Check foreground ratio
|
| 95 |
fg_ratio = mask.mean()
|
| 96 |
if fg_ratio < min_foreground:
|
|
|
|
| 97 |
return False, f"Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})"
|
| 98 |
|
|
|
|
| 99 |
return True, ""
|
| 100 |
|
| 101 |
def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
|
|
@@ -109,19 +152,33 @@ def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = N
|
|
| 109 |
Returns:
|
| 110 |
Preprocessed mask (H,W) float32 in [0,1]
|
| 111 |
"""
|
|
|
|
|
|
|
| 112 |
# Ensure 2D
|
| 113 |
if mask.ndim == 3:
|
| 114 |
-
mask =
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# Convert to float32 in [0,1]
|
| 117 |
if mask.dtype != np.float32:
|
| 118 |
mask = mask.astype(np.float32)
|
|
|
|
| 119 |
if mask.max() > 1.0:
|
| 120 |
mask = mask / 255.0
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
# Resize if needed
|
| 123 |
if target_size is not None and (mask.shape[0] != target_size[0] or mask.shape[1] != target_size[1]):
|
| 124 |
import cv2
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Mask validation utilities for BackgroundFX Pro.
|
| 4 |
+
==============================================
|
| 5 |
+
- Validates masks for MatAnyone compatibility
|
| 6 |
+
- Ensures 2D masks [H,W] to avoid 5D tensor issues
|
| 7 |
+
- Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0
|
| 8 |
+
|
| 9 |
+
Changes (2025-09-16):
|
| 10 |
+
- Aligned with updated pipeline.py and models/
|
| 11 |
+
- Added logging for mask shape and coverage
|
| 12 |
+
- Enhanced preprocess_mask for Torch tensors from SAM2
|
| 13 |
+
- Ensured 2D mask output for MatAnyone
|
| 14 |
"""
|
| 15 |
|
| 16 |
import numpy as np
|
|
|
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import Dict, Union, Tuple, Optional
|
| 20 |
|
| 21 |
+
import logging
|
| 22 |
+
import importlib.metadata
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("backgroundfx_pro")
|
| 25 |
+
if not logger.handlers:
|
| 26 |
+
h = logging.StreamHandler()
|
| 27 |
+
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 28 |
+
logger.addHandler(h)
|
| 29 |
+
logger.setLevel(logging.INFO)
|
| 30 |
+
|
| 31 |
def validate_mask_for_matanyone_simple(mask_input: Union[np.ndarray, str, Path], target_hw=None, *_, **__) -> Tuple[bool, Dict, str]:
|
| 32 |
"""
|
| 33 |
Back/forward-compatible mask sanity check for MatAnyone/SAM2 pipelines.
|
|
|
|
| 43 |
- We only *validate*; inversion/repair is upstream's job.
|
| 44 |
- Always returns True unless the mask is unreadable or wrong rank.
|
| 45 |
"""
|
| 46 |
+
logger.info(f"[MaskValidation] Validating mask: {mask_input}")
|
| 47 |
+
|
| 48 |
+
# Log MatAnyone version for compatibility check
|
| 49 |
+
try:
|
| 50 |
+
version = importlib.metadata.version("matanyone")
|
| 51 |
+
logger.info(f"[MaskValidation] MatAnyone version: {version}")
|
| 52 |
+
except Exception:
|
| 53 |
+
logger.info("[MaskValidation] MatAnyone version unknown")
|
| 54 |
+
|
| 55 |
# ---- load to np.uint8/float32 2D ----
|
| 56 |
try:
|
| 57 |
import torch # optional
|
|
|
|
| 62 |
if isinstance(mask_input, (str, Path)):
|
| 63 |
mask = cv2.imread(str(mask_input), cv2.IMREAD_GRAYSCALE)
|
| 64 |
if mask is None:
|
| 65 |
+
logger.error(f"[MaskValidation] Could not load mask: {mask_input}")
|
| 66 |
return False, {}, f"Mask not found or unreadable: {mask_input}"
|
| 67 |
# Load from torch tensor
|
| 68 |
elif (torch is not None) and isinstance(mask_input, torch.Tensor):
|
| 69 |
mask = mask_input.detach().cpu().numpy()
|
| 70 |
+
logger.info(f"[MaskValidation] Loaded Torch tensor mask: shape={mask.shape}, dtype={mask.dtype}")
|
| 71 |
# Already numpy
|
| 72 |
elif isinstance(mask_input, np.ndarray):
|
| 73 |
mask = mask_input
|
| 74 |
else:
|
| 75 |
+
logger.error(f"[MaskValidation] Unsupported mask type: {type(mask_input)}")
|
| 76 |
return False, {}, f"Unsupported mask type: {type(mask_input)}"
|
| 77 |
|
| 78 |
# If 3D, squeeze/convert to grayscale
|
|
|
|
| 82 |
else:
|
| 83 |
mask = np.squeeze(mask)
|
| 84 |
if mask.ndim != 2:
|
| 85 |
+
logger.error(f"[MaskValidation] Mask must be 2D, got shape {mask.shape}")
|
| 86 |
return False, {}, f"Mask must be 2D, got shape {mask.shape}"
|
| 87 |
|
| 88 |
+
logger.info(f"[MaskValidation] Loaded mask shape: {mask.shape}, dtype: {mask.dtype}")
|
| 89 |
+
|
| 90 |
# Optional resize to target (H, W)
|
| 91 |
if target_hw is not None and isinstance(target_hw, (tuple, list)) and len(target_hw) == 2:
|
| 92 |
H, W = int(target_hw[0]), int(target_hw[1])
|
| 93 |
if mask.shape != (H, W):
|
| 94 |
+
logger.info(f"[MaskValidation] Resizing mask from {mask.shape} to {target_hw}")
|
| 95 |
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 96 |
|
| 97 |
# Normalize to [0,1] float
|
|
|
|
| 106 |
|
| 107 |
# We keep validation permissive; upstream may invert/repair based on coverage
|
| 108 |
msg = f"Basic validation - {coverage:.1f}% coverage"
|
| 109 |
+
logger.info(f"[MaskValidation] Validation result: {msg}, valid: True, coverage: {coverage:.1f}%")
|
| 110 |
+
|
| 111 |
return True, stats, msg
|
| 112 |
|
| 113 |
def validate_mask_for_matanyone_advanced(mask: np.ndarray, min_foreground: float = 0.01) -> Tuple[bool, str]:
|
|
|
|
| 121 |
Returns:
|
| 122 |
Tuple of (is_valid, error_message)
|
| 123 |
"""
|
| 124 |
+
logger.info(f"[MaskValidation] Advanced validation on mask shape: {mask.shape}")
|
| 125 |
+
|
| 126 |
# Basic validation first
|
| 127 |
is_valid, msg = validate_mask_for_matanyone_simple(mask)
|
| 128 |
if not is_valid:
|
|
|
|
| 135 |
# Check foreground ratio
|
| 136 |
fg_ratio = mask.mean()
|
| 137 |
if fg_ratio < min_foreground:
|
| 138 |
+
logger.warning(f"[MaskValidation] Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})")
|
| 139 |
return False, f"Foreground area too small ({fg_ratio:.1%} < {min_foreground:.0%})"
|
| 140 |
|
| 141 |
+
logger.info(f"[MaskValidation] Advanced validation passed: fg_ratio={fg_ratio:.1%}")
|
| 142 |
return True, ""
|
| 143 |
|
| 144 |
def preprocess_mask(mask: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
|
|
|
|
| 152 |
Returns:
|
| 153 |
Preprocessed mask (H,W) float32 in [0,1]
|
| 154 |
"""
|
| 155 |
+
logger.info(f"[MaskValidation] Preprocessing mask: shape={mask.shape}, dtype={mask.dtype}")
|
| 156 |
+
|
| 157 |
# Ensure 2D
|
| 158 |
if mask.ndim == 3:
|
| 159 |
+
mask = np.squeeze(mask, axis=2)
|
| 160 |
+
logger.info(f"[MaskValidation] Squeezed 3D mask to 2D: {mask.shape}")
|
| 161 |
+
|
| 162 |
+
if mask.ndim != 2:
|
| 163 |
+
logger.error(f"[MaskValidation] Preprocessing failed: mask must be 2D, got {mask.shape}")
|
| 164 |
+
raise ValueError(f"Mask must be 2D, got shape {mask.shape}")
|
| 165 |
+
|
| 166 |
# Convert to float32 in [0,1]
|
| 167 |
if mask.dtype != np.float32:
|
| 168 |
mask = mask.astype(np.float32)
|
| 169 |
+
logger.info(f"[MaskValidation] Converted dtype to float32")
|
| 170 |
if mask.max() > 1.0:
|
| 171 |
mask = mask / 255.0
|
| 172 |
+
logger.info(f"[MaskValidation] Normalized to [0,1] range")
|
| 173 |
+
|
| 174 |
+
mask = np.clip(mask, 0.0, 1.0)
|
| 175 |
+
|
| 176 |
# Resize if needed
|
| 177 |
if target_size is not None and (mask.shape[0] != target_size[0] or mask.shape[1] != target_size[1]):
|
| 178 |
import cv2
|
| 179 |
+
H, W = target_size
|
| 180 |
+
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 181 |
+
logger.info(f"[MaskValidation] Resized mask to {target_size}")
|
| 182 |
+
|
| 183 |
+
logger.info(f"[MaskValidation] Preprocessed mask: shape={mask.shape}, dtype={mask.dtype}, range=[{mask.min():.3f}, {mask.max():.3f}]")
|
| 184 |
+
return mask
|