agent 3.0
Browse files- models/matanyone_loader.py +210 -79
models/matanyone_loader.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# =============================================================================
|
| 3 |
-
# MatAnyone Adapter (streaming, API-agnostic) β with chapter markers
|
| 4 |
# =============================================================================
|
| 5 |
"""
|
| 6 |
- Supports multiple MatAnyone variants:
|
|
@@ -40,90 +40,152 @@
|
|
| 40 |
# =============================================================================
|
| 41 |
# CHAPTER 1 β Small utilities
|
| 42 |
# =============================================================================
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
try:
|
| 48 |
-
cb(pct, msg) # preferred 2-arg
|
| 49 |
-
except TypeError:
|
| 50 |
try:
|
| 51 |
-
cb(msg)
|
| 52 |
except TypeError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
pass
|
| 54 |
|
| 55 |
|
|
|
|
| 56 |
class MatAnyError(RuntimeError):
|
| 57 |
"""Custom exception for MatAnyone processing errors."""
|
| 58 |
pass
|
| 59 |
|
| 60 |
|
|
|
|
| 61 |
def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
|
| 62 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
if not torch.cuda.is_available():
|
| 64 |
-
return
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
resv = torch.cuda.memory_reserved(idx) / 1e9
|
| 71 |
-
return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _safe_empty_cache():
|
| 75 |
-
"""Synchronize and empty CUDA cache if present (best-effort)."""
|
| 76 |
-
if torch.cuda.is_available():
|
| 77 |
-
try:
|
| 78 |
-
torch.cuda.synchronize()
|
| 79 |
-
except Exception:
|
| 80 |
-
pass
|
| 81 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
|
| 85 |
-
"""
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
raise MatAnyError(f"Seed mask not found: {mask_path}")
|
|
|
|
| 88 |
mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
|
| 89 |
-
if mask is None:
|
| 90 |
-
raise MatAnyError(f"Failed to read seed mask: {mask_path}")
|
|
|
|
| 91 |
H, W = target_hw
|
|
|
|
|
|
|
| 92 |
if mask.shape[:2] != (H, W):
|
| 93 |
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
|
|
|
| 94 |
maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
|
| 95 |
return maskf
|
| 96 |
|
| 97 |
|
| 98 |
def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
|
| 99 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 101 |
rgbf = rgb.astype(np.float32) / 255.0
|
| 102 |
-
chw = np.transpose(rgbf, (2, 0, 1)) # C,H,W
|
| 103 |
return chw
|
| 104 |
|
| 105 |
|
| 106 |
-
def _validate_nonempty(file_path: Path) -> None:
|
| 107 |
-
"""Ensure output file exists and is non-empty."""
|
| 108 |
-
if not file_path.exists() or file_path.stat().st_size == 0:
|
| 109 |
-
raise MatAnyError(f"Output file missing/empty: {file_path}")
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def _select_matany_mode(core) -> str:
|
| 113 |
-
"""
|
| 114 |
-
Inspect available APIs.
|
| 115 |
-
Priority: process_video > process_frame > step
|
| 116 |
-
(Note: we still force frame mode in _lazy_init; this helper is used by chunk helper.)
|
| 117 |
-
"""
|
| 118 |
-
if hasattr(core, "process_video") and callable(getattr(core, "process_video")):
|
| 119 |
-
return "process_video"
|
| 120 |
-
if hasattr(core, "process_frame") and callable(getattr(core, "process_frame")):
|
| 121 |
-
return "process_frame"
|
| 122 |
-
if hasattr(core, "step") and callable(getattr(core, "step")):
|
| 123 |
-
return "step"
|
| 124 |
-
raise MatAnyError("No supported MatAnyone API on core (process_video/process_frame/step).")
|
| 125 |
-
|
| 126 |
-
|
| 127 |
# =============================================================================
|
| 128 |
# CHAPTER 2 β Main session
|
| 129 |
# =============================================================================
|
|
@@ -153,6 +215,7 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
|
|
| 153 |
self._core = None
|
| 154 |
self._api_mode = None
|
| 155 |
self._initialized = False
|
|
|
|
| 156 |
self._lazy_init()
|
| 157 |
|
| 158 |
log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
|
|
@@ -239,7 +302,7 @@ def _maybe_amp(self):
|
|
| 239 |
return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
|
| 240 |
|
| 241 |
# -------------------------------------------------------------------------
|
| 242 |
-
# 2.4 β Frame validation
|
| 243 |
# -------------------------------------------------------------------------
|
| 244 |
def _validate_input_frame(self, frame: np.ndarray) -> None:
|
| 245 |
if not isinstance(frame, np.ndarray):
|
|
@@ -249,43 +312,105 @@ def _validate_input_frame(self, frame: np.ndarray) -> None:
|
|
| 249 |
if frame.ndim != 3 or frame.shape[2] != 3:
|
| 250 |
raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
| 253 |
"""
|
| 254 |
-
|
| 255 |
-
|
|
|
|
| 256 |
"""
|
| 257 |
-
self.
|
| 258 |
|
| 259 |
-
#
|
| 260 |
-
|
| 261 |
-
|
| 262 |
|
| 263 |
-
|
| 264 |
-
mask_t = None
|
| 265 |
if is_first and seed_1hw is not None:
|
|
|
|
| 266 |
if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
|
| 267 |
seed_hw = seed_1hw[0]
|
| 268 |
elif seed_1hw.ndim == 2:
|
| 269 |
seed_hw = seed_1hw
|
| 270 |
else:
|
| 271 |
raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
try:
|
| 276 |
with torch.no_grad(), self._maybe_amp():
|
| 277 |
-
|
| 278 |
-
out = self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
|
| 279 |
-
elif self._api_mode == "process_frame":
|
| 280 |
-
out = self._core.process_frame(img_t, mask_t)
|
| 281 |
-
else:
|
| 282 |
-
raise MatAnyError("Internal error: _run_frame used in non-frame mode")
|
| 283 |
except torch.cuda.OutOfMemoryError as e:
|
| 284 |
snap = _cuda_snapshot(self.device)
|
| 285 |
self._log_gpu_memory()
|
| 286 |
raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
|
| 287 |
except RuntimeError as e:
|
| 288 |
-
# If itβs a CUDA-side runtime issue, annotate with snapshot
|
| 289 |
if "CUDA" in str(e):
|
| 290 |
snap = _cuda_snapshot(self.device)
|
| 291 |
self._log_gpu_memory()
|
|
@@ -296,20 +421,26 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_f
|
|
| 296 |
|
| 297 |
# Normalize to pure 2D numpy [0,1]
|
| 298 |
if isinstance(out, torch.Tensor):
|
| 299 |
-
alpha_np = out.detach().float().
|
| 300 |
else:
|
| 301 |
-
alpha_np = np.asarray(out
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
| 304 |
|
|
|
|
| 305 |
alpha_np = np.squeeze(alpha_np)
|
| 306 |
if alpha_np.ndim != 2:
|
| 307 |
raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
|
| 308 |
|
| 309 |
-
|
|
|
|
|
|
|
| 310 |
|
| 311 |
# -------------------------------------------------------------------------
|
| 312 |
-
# 2.
|
| 313 |
# -------------------------------------------------------------------------
|
| 314 |
def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
|
| 315 |
"""
|
|
@@ -320,7 +451,7 @@ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[
|
|
| 320 |
alpha_mp4 = out_dir / "alpha.mp4"
|
| 321 |
fg_mp4 = out_dir / "fg.mp4"
|
| 322 |
|
| 323 |
-
# Dict style
|
| 324 |
if isinstance(res, dict):
|
| 325 |
cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
|
| 326 |
cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
|
|
@@ -359,7 +490,7 @@ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[
|
|
| 359 |
raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
|
| 360 |
|
| 361 |
# -------------------------------------------------------------------------
|
| 362 |
-
# 2.
|
| 363 |
# -------------------------------------------------------------------------
|
| 364 |
def process_stream(
|
| 365 |
self,
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# =============================================================================
|
| 3 |
+
# MatAnyone Adapter (streaming, API-agnostic) β with chapter markers + layout probe
|
| 4 |
# =============================================================================
|
| 5 |
"""
|
| 6 |
- Supports multiple MatAnyone variants:
|
|
|
|
| 40 |
# =============================================================================
|
| 41 |
# CHAPTER 1 β Small utilities
|
| 42 |
# =============================================================================
|
| 43 |
+
|
| 44 |
+
# --- Progress callback controls ---
|
| 45 |
+
def _env_flag(name: str, default: str = "0") -> bool:
|
| 46 |
+
return os.getenv(name, default).strip() in {"1", "true", "TRUE", "yes", "YES", "on", "ON"}
|
| 47 |
+
|
| 48 |
+
# Module-level progress-reporting configuration and throttle state.
# Master on/off switch for progress callbacks (env flag, default on).
_PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
# Minimum seconds between repeated emissions of the same message.
_PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
# t: time of last successful emit; last: last message emitted;
# disabled: latched True forever once a callback raises (see _emit_progress).
_progress_state = {"t": 0.0, "last": None, "disabled": False}
|
| 51 |
+
|
| 52 |
+
def _emit_progress(cb, pct: float, msg: str, *, force: bool = False) -> None:
    """
    Safe progress emitter:
    - Respects MATANY_PROGRESS and rate-limits updates.
    - Never raises upstream; disables itself if the callback misbehaves.
    - Accepts either 2-arg (pct, msg) or legacy 1-arg (msg) callbacks.
    """
    # No callback, feature off, or previously latched off: nothing to do.
    if not cb or not _PROGRESS_CB_ENABLED or _progress_state["disabled"]:
        return
    now = time.time()
    # Throttle: drop a repeat of the same message arriving faster than the
    # configured minimum interval (unless the caller forces emission).
    if not force and (now - _progress_state["t"] < _PROGRESS_MIN_INTERVAL) and msg == _progress_state["last"]:
        return

    try:
        try:
            cb(pct, msg)  # preferred signature
        except TypeError:
            # NOTE(review): a TypeError raised *inside* a 2-arg callback also
            # lands here and triggers a second (1-arg) invocation — confirm
            # callbacks are cheap/idempotent before relying on this fallback.
            cb(msg)  # legacy signature
        # Record the successful emission for the throttle above.
        _progress_state["t"] = now
        _progress_state["last"] = msg
    except Exception as e:
        # Permanently disable to avoid log spam and user-facing crashes
        _progress_state["disabled"] = True
        try:
            log.warning(f"[progress-cb] disabled due to exception: {e}")
        except Exception:
            pass
|
| 79 |
|
| 80 |
|
| 81 |
+
# --- Errors ---
class MatAnyError(RuntimeError):
    """Custom exception for MatAnyone processing errors."""
|
| 85 |
|
| 86 |
|
| 87 |
+
# --- CUDA helpers ---
|
| 88 |
def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
|
| 89 |
+
"""
|
| 90 |
+
Return a short, exception-safe string describing CUDA memory on a device.
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
if not torch.cuda.is_available():
|
| 94 |
+
return "CUDA: N/A"
|
| 95 |
+
idx = 0
|
| 96 |
+
if isinstance(device, torch.device) and device.type == "cuda" and device.index is not None:
|
| 97 |
+
idx = device.index
|
| 98 |
+
name = torch.cuda.get_device_name(idx)
|
| 99 |
+
alloc = torch.cuda.memory_allocated(idx) / (1024 ** 3)
|
| 100 |
+
resv = torch.cuda.memory_reserved(idx) / (1024 ** 3)
|
| 101 |
+
return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
|
| 102 |
+
except Exception as e:
|
| 103 |
+
return f"CUDA: snapshot-error: {e!r}"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _safe_empty_cache() -> None:
|
| 107 |
+
"""Try hard to release CUDA cache; never raises."""
|
| 108 |
if not torch.cuda.is_available():
|
| 109 |
+
return
|
| 110 |
+
try:
|
| 111 |
+
torch.cuda.synchronize()
|
| 112 |
+
except Exception:
|
| 113 |
+
pass
|
| 114 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
torch.cuda.empty_cache()
|
| 116 |
+
except Exception:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _supports_fp16(device: Optional[torch.device]) -> bool:
|
| 121 |
+
"""
|
| 122 |
+
Best-effort check whether the device can benefit from fp16.
|
| 123 |
+
Returns False for CPU; True for most modern NVIDIA GPUs.
|
| 124 |
+
"""
|
| 125 |
+
if not isinstance(device, torch.device) or device.type != "cuda" or not torch.cuda.is_available():
|
| 126 |
+
return False
|
| 127 |
+
try:
|
| 128 |
+
major, minor = torch.cuda.get_device_capability(device.index or 0)
|
| 129 |
+
# Volta (7.0)+ generally supports fast fp16 paths; T4 is 7.5.
|
| 130 |
+
return (major, minor) >= (7, 0)
|
| 131 |
+
except Exception:
|
| 132 |
+
return True # be optimistic if capability query fails
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _ensure_device_usable(device: torch.device) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Validate that the chosen device is actually usable.
|
| 138 |
+
Raise MatAnyError early if CUDA is requested but unavailable.
|
| 139 |
+
"""
|
| 140 |
+
if device.type == "cuda" and not torch.cuda.is_available():
|
| 141 |
+
raise MatAnyError("CUDA device requested but torch.cuda.is_available() == False")
|
| 142 |
+
if device.type not in {"cuda", "cpu"}:
|
| 143 |
+
raise MatAnyError(f"Unsupported device type: {device.type!r}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# --- File & image helpers ---
|
| 147 |
+
def _validate_nonempty(file_path: Path) -> None:
|
| 148 |
+
if (not isinstance(file_path, Path)) or (not file_path.exists()) or file_path.stat().st_size <= 0:
|
| 149 |
+
raise MatAnyError(f"Output file missing/empty: {file_path}")
|
| 150 |
|
| 151 |
|
| 152 |
def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
    """
    Load a grayscale seed mask as float32 in [0, 1] with shape (H, W).

    Validates the path and image content, then bilinearly resizes to the
    requested target (H, W) when the stored size differs.
    """
    if not isinstance(mask_path, (str, Path)):
        raise MatAnyError(f"Seed mask path must be str/Path, got {type(mask_path)}")
    mask_path = Path(mask_path)
    if not mask_path.exists():
        raise MatAnyError(f"Seed mask not found: {mask_path}")

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None or mask.size == 0:
        raise MatAnyError(f"Failed to read seed mask or empty file: {mask_path}")

    height, width = target_hw
    if mask.ndim != 2:
        raise MatAnyError(f"Seed mask must be single-channel; got shape {mask.shape}")
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)

    return (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
|
| 175 |
|
| 176 |
|
| 177 |
def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
    """Convert a BGR uint8 (H, W, 3) frame to RGB float32 (C, H, W) in [0, 1]."""
    valid = (
        isinstance(img_bgr, np.ndarray)
        and img_bgr.dtype == np.uint8
        and img_bgr.ndim == 3
        and img_bgr.shape[2] == 3
    )
    if not valid:
        raise MatAnyError(f"Frame must be uint8 HWC BGR; got {type(img_bgr)}, shape={getattr(img_bgr, 'shape', None)}")
    # BGR -> RGB, scale to unit range, then move channels first.
    scaled = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return np.transpose(scaled, (2, 0, 1))  # C, H, W
|
| 187 |
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
# =============================================================================
|
| 190 |
# CHAPTER 2 β Main session
|
| 191 |
# =============================================================================
|
|
|
|
| 215 |
self._core = None
|
| 216 |
self._api_mode = None
|
| 217 |
self._initialized = False
|
| 218 |
+
self._layout_locked: Optional[str] = None # 'BCHW+B1HW', 'CHW+HW', etc.
|
| 219 |
self._lazy_init()
|
| 220 |
|
| 221 |
log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
|
|
|
|
| 302 |
return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
|
| 303 |
|
| 304 |
# -------------------------------------------------------------------------
|
| 305 |
+
# 2.4 β Frame validation
|
| 306 |
# -------------------------------------------------------------------------
|
| 307 |
def _validate_input_frame(self, frame: np.ndarray) -> None:
|
| 308 |
if not isinstance(frame, np.ndarray):
|
|
|
|
| 312 |
if frame.ndim != 3 or frame.shape[2] != 3:
|
| 313 |
raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
|
| 314 |
|
| 315 |
+
    # -------------------------------------------------------------------------
    # 2.5 — Core call helper with first-frame layout probe (locks after success)
    # -------------------------------------------------------------------------
    def _call_core_frame(self, img_chw: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool):
        """
        Calls MatAnyone frame API trying a small set of plausible layouts.
        Locks layout after the first successful call to avoid repeated probing.
        Returns the raw output (torch.Tensor or numpy).
        """
        core = self._core

        # Build base tensors
        img_t_chw = torch.from_numpy(img_chw).to(self.device)  # (3,H,W)
        H, W = img_chw.shape[1], img_chw.shape[2]

        # Seed mask is only supplied on the very first frame of a stream.
        mask_t_hw = None
        if is_first and seed_1hw is not None:
            # Ensure pure HW float32 in [0,1]
            if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
                seed_hw = seed_1hw[0]
            elif seed_1hw.ndim == 2:
                seed_hw = seed_1hw
            else:
                raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
            mask_t_hw = torch.from_numpy(seed_hw.astype(np.float32)).to(self.device)

        def _do_call(layout: str):
            """Dispatch according to a named layout."""
            if layout == "BCHW+B1HW":  # Preferred for many PyTorch models
                img_in = img_t_chw.unsqueeze(0).contiguous()  # (1,3,H,W)
                mask_in = mask_t_hw.unsqueeze(0).unsqueeze(0).contiguous() if mask_t_hw is not None else None  # (1,1,H,W)
            elif layout == "CHW+HW":  # Some APIs accept unbatched tensors
                img_in = img_t_chw  # (3,H,W)
                mask_in = mask_t_hw if mask_t_hw is not None else None  # (H,W)
            elif layout == "BCHW+HW":
                img_in = img_t_chw.unsqueeze(0).contiguous()  # (1,3,H,W)
                mask_in = mask_t_hw if mask_t_hw is not None else None  # (H,W)
            elif layout == "CHW+1HW":
                img_in = img_t_chw  # (3,H,W)
                mask_in = mask_t_hw.unsqueeze(0).contiguous() if mask_t_hw is not None else None  # (1,H,W)
            else:
                raise MatAnyError(f"Unknown layout spec: {layout}")

            if self._api_mode == "step":
                return core.step(img_in, mask_in) if mask_in is not None else core.step(img_in)
            elif self._api_mode == "process_frame":
                return core.process_frame(img_in, mask_in)
            else:
                raise MatAnyError("Internal error: frame dispatch used in non-frame mode")

        # If layout was already found, use it directly
        # NOTE(review): once locked there is deliberately no re-probe; a layout
        # that worked on frame 1 but fails later is surfaced as a hard error.
        if self._layout_locked is not None:
            try:
                return _do_call(self._layout_locked)
            except Exception as e:
                # If a previously-working layout starts failing, surface clear error
                raise MatAnyError(f"MatAnyone call failed with locked layout {self._layout_locked}: {e}")

        # First-frame probe: try a few reasonable layouts in this order
        probe_order = ["BCHW+B1HW", "CHW+HW", "BCHW+HW", "CHW+1HW"]
        last_err: Optional[str] = None

        for layout in probe_order:
            try:
                out = _do_call(layout)
                # Success — lock layout for subsequent frames
                self._layout_locked = layout
                log.info(f"[MATANY] First-frame layout locked: {layout} (H={H}, W={W})")
                return out
            except Exception as e:
                last_err = str(e)
                log.warning(f"[MATANY] Layout attempt failed ({layout}): {last_err}")

        # If we reach here, all attempts failed
        snap = _cuda_snapshot(self.device)
        raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
|
| 391 |
+
|
| 392 |
+
# -------------------------------------------------------------------------
|
| 393 |
+
# 2.6 β Frame runner (normalizes output to 2D [0,1])
|
| 394 |
+
# -------------------------------------------------------------------------
|
| 395 |
+
def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
|
| 396 |
+
"""
|
| 397 |
+
Run a single frame through MatAnyone.
|
| 398 |
+
Returns: alpha matte as 2D np.float32 in [0,1].
|
| 399 |
+
"""
|
| 400 |
+
self._validate_input_frame(frame_bgr)
|
| 401 |
+
|
| 402 |
+
# Image -> CHW float32 [0,1]
|
| 403 |
+
img_chw = _to_chw01(frame_bgr) # (3,H,W)
|
| 404 |
+
|
| 405 |
+
# Dispatch (with autocast + no_grad)
|
| 406 |
try:
|
| 407 |
with torch.no_grad(), self._maybe_amp():
|
| 408 |
+
out = self._call_core_frame(img_chw, seed_1hw, is_first=is_first)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
except torch.cuda.OutOfMemoryError as e:
|
| 410 |
snap = _cuda_snapshot(self.device)
|
| 411 |
self._log_gpu_memory()
|
| 412 |
raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
|
| 413 |
except RuntimeError as e:
|
|
|
|
| 414 |
if "CUDA" in str(e):
|
| 415 |
snap = _cuda_snapshot(self.device)
|
| 416 |
self._log_gpu_memory()
|
|
|
|
| 421 |
|
| 422 |
# Normalize to pure 2D numpy [0,1]
|
| 423 |
if isinstance(out, torch.Tensor):
|
| 424 |
+
alpha_np = out.detach().float().squeeze().cpu().numpy()
|
| 425 |
else:
|
| 426 |
+
alpha_np = np.asarray(out)
|
| 427 |
+
|
| 428 |
+
# Scale if it looks like 0..255
|
| 429 |
+
alpha_np = alpha_np.astype(np.float32)
|
| 430 |
+
if alpha_np.max() > 1.0:
|
| 431 |
+
alpha_np = alpha_np / 255.0
|
| 432 |
|
| 433 |
+
# In case model returns shape like (1,H,W) or (1,1,H,W), squeeze to (H,W)
|
| 434 |
alpha_np = np.squeeze(alpha_np)
|
| 435 |
if alpha_np.ndim != 2:
|
| 436 |
raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
|
| 437 |
|
| 438 |
+
# Clamp to [0,1]
|
| 439 |
+
alpha_np = np.clip(alpha_np, 0.0, 1.0).astype(np.float32)
|
| 440 |
+
return alpha_np
|
| 441 |
|
| 442 |
# -------------------------------------------------------------------------
|
| 443 |
+
# 2.7 β process_video harvesting (kept for completeness; not used in forced frame mode)
|
| 444 |
# -------------------------------------------------------------------------
|
| 445 |
def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
|
| 446 |
"""
|
|
|
|
| 451 |
alpha_mp4 = out_dir / "alpha.mp4"
|
| 452 |
fg_mp4 = out_dir / "fg.mp4"
|
| 453 |
|
| 454 |
+
# Dict style
|
| 455 |
if isinstance(res, dict):
|
| 456 |
cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
|
| 457 |
cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
|
|
|
|
| 490 |
raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
|
| 491 |
|
| 492 |
# -------------------------------------------------------------------------
|
| 493 |
+
# 2.8 β Public API: process_stream
|
| 494 |
# -------------------------------------------------------------------------
|
| 495 |
def process_stream(
|
| 496 |
self,
|