MogensR committed on
Commit
386575c
·
1 Parent(s): 87688ee
Files changed (1) hide show
  1. models/matanyone_loader.py +105 -135
models/matanyone_loader.py CHANGED
@@ -3,17 +3,10 @@
3
  """
4
  MatAnyone adapter — SAM2-seeded, streaming, build-agnostic.
5
 
6
- #1 Overview
7
- - SAM2 provides a seed mask on frame 0.
8
  - MatAnyone does frame-by-frame alpha matting.
9
- - Supports wheels that expect either 4D [B,C,H,W] or 5D [B,T,C,H,W].
10
- - Accepts HWC or CHW frames; converts to HWC RGB.
11
- - Writes alpha.mp4 (grayscale-as-BGR) and fg.mp4 (RGB on black).
12
-
13
- Public API used by pipeline:
14
- MatAnyError (exception)
15
- class MatAnyoneSession:
16
- process_stream(video_path, seed_mask_path=None, out_dir=None, progress_cb=None) -> (alpha_path, fg_path)
17
  """
18
 
19
  from __future__ import annotations
@@ -28,9 +21,9 @@ class MatAnyoneSession:
28
 
29
  log = logging.getLogger(__name__)
30
 
31
- # ---------- Progress helper (safe & rate-limited) ----------
32
  def _env_flag(name: str, default: str = "0") -> bool:
33
- return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
34
 
35
  _PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
36
  _PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
@@ -39,7 +32,6 @@ def _env_flag(name: str, default: str = "0") -> bool:
39
  _progress_disabled = False
40
 
41
  def _emit_progress(cb, pct: float, msg: str):
42
- """#2 UI progress callback wrapper (tolerant of legacy 1-arg signatures)"""
43
  global _progress_last, _progress_last_msg, _progress_disabled
44
  if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
45
  return
@@ -50,7 +42,7 @@ def _emit_progress(cb, pct: float, msg: str):
50
  try:
51
  cb(pct, msg) # preferred (pct, msg)
52
  except TypeError:
53
- cb(msg) # legacy (msg-only)
54
  _progress_last = now
55
  _progress_last_msg = msg
56
  except Exception as e:
@@ -59,12 +51,10 @@ def _emit_progress(cb, pct: float, msg: str):
59
 
60
  # ---------- Errors ----------
61
  class MatAnyError(RuntimeError):
62
- """#3 Adapter-level error (keeps upstream logs readable)"""
63
  pass
64
 
65
- # ---------- CUDA snapshots ----------
66
  def _cuda_snapshot(device: Optional[torch.device]) -> str:
67
- """#4 Best-effort CUDA memory + device info (for error context)"""
68
  try:
69
  if not torch.cuda.is_available():
70
  return "CUDA: N/A"
@@ -79,7 +69,6 @@ def _cuda_snapshot(device: Optional[torch.device]) -> str:
79
  return f"CUDA snapshot error: {e!r}"
80
 
81
  def _safe_empty_cache():
82
- """#5 Non-blocking VRAM cleanup (avoid synchronize() in Spaces)"""
83
  if not torch.cuda.is_available():
84
  return
85
  try:
@@ -90,9 +79,8 @@ def _safe_empty_cache():
90
  # ---------- SAM2 → seed mask prep ----------
91
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
92
  """
93
- #6 Normalize SAM2 mask to float32 [H,W] in {0,1}, white = foreground.
94
- - Accepts 2D or 3-channel images; resizes with NEAREST to keep edges crisp.
95
- - Auto-inverts if >60% of the image is ON (likely polarity swap).
96
  """
97
  if not isinstance(sam2_mask, np.ndarray):
98
  raise MatAnyError(f"SAM2 mask must be numpy array, got {type(sam2_mask)}")
@@ -109,47 +97,41 @@ def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
109
  m /= 255.0
110
  m = np.clip(m, 0.0, 1.0)
111
 
112
- cov = float((m > 0.5).mean())
113
- if cov > 0.60:
114
  m = 1.0 - m
115
 
116
- # hard binarize for a clean seed
117
- m = (m > 0.5).astype(np.float32)
118
- return m
119
 
120
  # ---------- Frame conversion ----------
121
  def _frame_bgr_to_hwc_rgb_numpy(frame) -> np.ndarray:
122
- """
123
- #7 Accepts OpenCV BGR uint8 HWC, or uint8 CHW; returns HWC RGB uint8.
124
- """
125
  if not isinstance(frame, np.ndarray) or frame.ndim != 3:
126
  raise MatAnyError(f"Frame must be HWC/CHW numpy array, got {type(frame)}, shape={getattr(frame, 'shape', None)}")
127
  arr = frame
128
- # Accept CHW and convert to HWC
129
- if arr.shape[0] == 3 and arr.shape[2] != 3:
130
- arr = np.transpose(arr, (1, 2, 0)) # CHW -> HWC
131
  if arr.dtype != np.uint8:
132
  raise MatAnyError(f"Frame must be uint8, got {arr.dtype}")
133
- rgb = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
134
- return rgb
135
 
136
  # ============================================================================
137
 
138
  class MatAnyoneSession:
139
  """
140
- #8 Streaming wrapper that seeds MatAnyone with a SAM2 mask on frame 0.
141
- - Tries 4D first; if the wheel truly wants 5D, promotes both image AND mask.
142
- - Has an override env: MATANY_FORCE_FORMAT=4D|5D (for debugging).
143
  """
144
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
145
  self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
146
  self.precision = precision.lower()
147
 
148
- # Optional override: MATANY_FORCE_FORMAT=4D|5D
149
- fmt = os.getenv("MATANY_FORCE_FORMAT", "").strip().lower()
150
- self._force_4d = (fmt == "4d")
151
- self._force_5d = (fmt == "5d")
152
- self._use_5d = self._force_5d # start in 5D only if forced
 
 
153
 
154
  try:
155
  from matanyone.inference.inference_core import InferenceCore
@@ -158,37 +140,36 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
158
  try:
159
  self.core = InferenceCore()
160
  except TypeError:
161
- # HF wheel constructor that needs a repo string
162
  self.core = InferenceCore("PeiqingYang/MatAnyone")
163
 
164
- self.api = "step" if hasattr(self.core, "step") else ("process_frame" if hasattr(self.core, "process_frame") else None)
165
- if not self.api:
166
- raise MatAnyError("MatAnyone core exposes neither 'step' nor 'process_frame'")
 
 
 
 
 
 
 
167
 
168
- log.info(f"[MATANY] API: {self.api} | device={self.device} | force4d={self._force_4d} | force5d={self._force_5d}")
 
169
 
170
- # ----- AMP policy -----
 
 
171
  def _amp(self):
172
- """#9 Simple AMP gate (auto/fp16/fp32)"""
173
  if self.device.type != "cuda":
174
  return torch.amp.autocast(device_type="cuda", enabled=False)
175
  if self.precision == "fp32":
176
  return torch.amp.autocast(device_type="cuda", enabled=False)
177
  if self.precision == "fp16":
178
  return torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
179
- # auto
180
  return torch.amp.autocast(device_type="cuda", enabled=True)
181
 
182
- # ----- Tensor builders -----
183
  def _to_tensors(self, img_hwc_rgb: np.ndarray, mask_hw: Optional[np.ndarray]):
184
- """
185
- #10 Build both 4D and 5D tensors.
186
- Returns: (img_4d, img_5d, mask_4d, mask_5d)
187
- - img_4d: [1, 3, H, W]
188
- - img_5d: [1, 1, 3, H, W]
189
- - mask_4d: [1, 1, H, W] or None
190
- - mask_5d: [1, 1, 1, H, W] or None
191
- """
192
  img = torch.from_numpy(img_hwc_rgb).to(self.device)
193
  if img.dtype != torch.float32:
194
  img = img.float()
@@ -204,107 +185,104 @@ def _to_tensors(self, img_hwc_rgb: np.ndarray, mask_hw: Optional[np.ndarray]):
204
  m = torch.from_numpy(mask_hw).to(self.device)
205
  if m.dtype != torch.float32:
206
  m = m.float()
207
- # robust binarize
208
  m = (m >= 0.5).float() if float(m.max().item()) <= 1.0 else (m >= 128).float()
209
- mask_4d = m.unsqueeze(0).unsqueeze(0).contiguous() # [1,1,H,W]
210
- mask_5d = mask_4d.unsqueeze(1).contiguous() # [1,1,1,H,W]
211
  return img_4d, img_5d, mask_4d, mask_5d
212
 
213
- # ----- Core call (4D first, 5D only if demanded) -----
214
- def _core_call(self, img_4d, img_5d, mask_4d, mask_5d, is_first: bool):
215
- """
216
- #11 Dispatch into the wheel, trying 4D, then 5D if the error suggests it.
217
- Also backs off from 5D → 4D when conv2d complains about 3D/4D.
218
- """
219
- def run(use_5d: bool):
220
- img = img_5d if use_5d else img_4d
221
- msk = mask_5d if use_5d else mask_4d # <<< IMPORTANT: match ranks
222
- if self.api == "step":
223
- if is_first and msk is not None:
 
 
 
224
  try:
225
- return self.core.step(img, msk, is_first=True)
226
- except TypeError:
227
- return self.core.step(img, msk) # older signature
228
- else:
229
- return self.core.step(img)
 
 
 
 
 
 
 
 
 
 
 
 
230
  else:
231
- return self.core.process_frame(img, msk if is_first else None)
232
 
233
  with torch.no_grad(), self._amp():
234
- # Forced modes for debugging
235
  if self._force_4d:
236
  return run(False)
237
  if self._force_5d:
238
  return run(True)
239
 
240
- # If a previous frame decided on 5D, try 5D first but back off if needed
241
  if self._use_5d:
242
  try:
243
  return run(True)
244
  except RuntimeError as e5:
245
- msg5 = str(e5)
246
- # If the wheel says conv2d needs 3D/4D, revert to 4D permanently
247
- if "Expected 3D" in msg5 and "4D" in msg5 and "conv2d" in msg5:
248
  log.info("[MATANY] 5D rejected by wheel (conv2d wants 3D/4D). Falling back to 4D.")
249
  self._use_5d = False
250
  return run(False)
251
- raise MatAnyError(f"Runtime error (5D path): {msg5}") from e5
252
 
253
- # Default: try 4D first
254
  try:
255
- return run(False)
256
  except RuntimeError as e4:
257
- msg4 = str(e4)
258
- # Hints that the wheel actually expects 5D
259
- wants_5d = any(kw in msg4 for kw in [
260
- "expected 5D",
261
- "expects 5D",
262
- "input.dim() == 5",
263
- "but got 4D",
264
- "got input of size: [1, 3," # some wheels report this pattern
265
- ])
266
- if wants_5d:
267
  log.info("[MATANY] Wheel appears to expect 5D — retrying with [1,1,3,H,W] and [1,1,1,H,W].")
268
  self._use_5d = True
269
  try:
270
  return run(True)
271
  except RuntimeError as e5b:
272
- msg5b = str(e5b)
273
- # If retry says conv2d wants 3D/4D, undo and raise original
274
- if "Expected 3D" in msg5b and "4D" in msg5b and "conv2d" in msg5b:
275
  self._use_5d = False
276
- raise MatAnyError(f"Wheel ultimately expects 4D (conv2d). Original 4D error: {msg4}") from e4
277
- raise MatAnyError(f"5D attempt failed: {msg5b}") from e5b
278
-
279
- # Add CUDA context for GPU errors
280
- if "CUDA" in msg4 or "cublas" in msg4.lower() or "cudnn" in msg4.lower():
281
  snap = _cuda_snapshot(self.device)
282
- raise MatAnyError(f"CUDA runtime error: {msg4} | {snap}") from e4
283
-
284
- # Generic wrap
285
- raise MatAnyError(f"Runtime error (4D path): {msg4}") from e4
286
 
287
  # ----- Per-frame runner -----
288
  def _run_frame(self, frame_bgr: np.ndarray, sam2_mask_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
289
- """#12 Convert inputs, seed frame 0, call core, and normalize to [H,W] alpha."""
290
  rgb_hwc = _frame_bgr_to_hwc_rgb_numpy(frame_bgr)
291
  H, W = rgb_hwc.shape[:2]
 
292
 
293
- seed_for_this_frame = None
294
- if is_first and sam2_mask_hw is not None:
295
- seed_for_this_frame = _prepare_seed_mask(sam2_mask_hw, H, W)
296
-
297
- img_4d, img_5d, mask_4d, mask_5d = self._to_tensors(rgb_hwc, seed_for_this_frame)
298
-
299
- try:
300
- out = self._core_call(img_4d, img_5d, mask_4d, mask_5d, is_first)
301
- except torch.cuda.OutOfMemoryError as e:
302
- snap = _cuda_snapshot(self.device)
303
- raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
304
- except Exception as e:
305
- raise MatAnyError(f"Runtime error: {e}") from e
306
 
307
- # Normalize output to [H,W] float32 in [0,1]
308
  if isinstance(out, torch.Tensor):
309
  alpha = out.detach().float().squeeze().cpu().numpy()
310
  else:
@@ -325,9 +303,6 @@ def process_stream(
325
  out_dir: Optional[Path] = None,
326
  progress_cb: Optional[Callable] = None,
327
  ) -> Tuple[Path, Path]:
328
- """
329
- #13 Stream the video one frame at a time (T=1), write alpha.mp4 & fg.mp4.
330
- """
331
  video_path = Path(video_path)
332
  if not video_path.exists():
333
  raise MatAnyError(f"Video file not found: {video_path}")
@@ -335,7 +310,6 @@ def process_stream(
335
  out_dir = Path(out_dir) if out_dir else video_path.parent
336
  out_dir.mkdir(parents=True, exist_ok=True)
337
 
338
- # Probe video
339
  cap_probe = cv2.VideoCapture(str(video_path))
340
  if not cap_probe.isOpened():
341
  raise MatAnyError(f"Failed to open video: {video_path}")
@@ -349,9 +323,8 @@ def process_stream(
349
 
350
  log.info(f"MatAnyone: {video_path.name} | {N} frames {W}x{H} @ {fps:.2f} fps")
351
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
352
- _emit_progress(progress_cb, 0.08, "Using step (frame-by-frame)")
353
 
354
- # Prepare writers
355
  alpha_path = out_dir / "alpha.mp4"
356
  fg_path = out_dir / "fg.mp4"
357
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -360,7 +333,6 @@ def process_stream(
360
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
361
  raise MatAnyError("Failed to initialize VideoWriter(s)")
362
 
363
- # Load seed mask if provided (file path on disk)
364
  seed_mask_np = None
365
  if seed_mask_path is not None:
366
  p = Path(seed_mask_path)
@@ -369,7 +341,7 @@ def process_stream(
369
  m = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
370
  if m is None:
371
  raise MatAnyError(f"Failed to read seed mask: {p}")
372
- seed_mask_np = m # we resize/polarize/binarize inside _run_frame
373
 
374
  cap = cv2.VideoCapture(str(video_path))
375
  if not cap.isOpened():
@@ -384,9 +356,8 @@ def process_stream(
384
  if not ret:
385
  break
386
  is_first = (idx == 0)
387
- alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first) # [H,W] in [0,1]
388
 
389
- # Compose outputs (no double divide)
390
  alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
391
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
392
  fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
@@ -411,7 +382,6 @@ def process_stream(
411
  except: pass
412
  _safe_empty_cache()
413
 
414
- # Verify outputs are non-empty
415
  if not alpha_path.exists() or alpha_path.stat().st_size == 0:
416
  raise MatAnyError(f"Output file missing/empty: {alpha_path}")
417
  if not fg_path.exists() or fg_path.stat().st_size == 0:
 
3
  """
4
  MatAnyone adapter — SAM2-seeded, streaming, build-agnostic.
5
 
6
+ - SAM2 defines the subject (seed mask) on frame 0.
 
7
  - MatAnyone does frame-by-frame alpha matting.
8
+ - Prefers process_frame (HWC numpy) and falls back to step.
9
+ - For step(): supports 4D [B,C,H,W] and 5D [B,T,C,H,W] with matching mask rank.
 
 
 
 
 
 
10
  """
11
 
12
  from __future__ import annotations
 
21
 
22
  log = logging.getLogger(__name__)
23
 
24
+ # ---------- Progress helper ----------
25
  def _env_flag(name: str, default: str = "0") -> bool:
26
+ return os.getenv(name, default).strip().lower() in {"1","true","yes","on"}
27
 
28
  _PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
29
  _PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
 
32
  _progress_disabled = False
33
 
34
  def _emit_progress(cb, pct: float, msg: str):
 
35
  global _progress_last, _progress_last_msg, _progress_disabled
36
  if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
37
  return
 
42
  try:
43
  cb(pct, msg) # preferred (pct, msg)
44
  except TypeError:
45
+ cb(msg) # legacy (msg)
46
  _progress_last = now
47
  _progress_last_msg = msg
48
  except Exception as e:
 
51
 
52
  # ---------- Errors ----------
53
  class MatAnyError(RuntimeError):
 
54
  pass
55
 
56
+ # ---------- CUDA helpers ----------
57
  def _cuda_snapshot(device: Optional[torch.device]) -> str:
 
58
  try:
59
  if not torch.cuda.is_available():
60
  return "CUDA: N/A"
 
69
  return f"CUDA snapshot error: {e!r}"
70
 
71
  def _safe_empty_cache():
 
72
  if not torch.cuda.is_available():
73
  return
74
  try:
 
79
  # ---------- SAM2 → seed mask prep ----------
80
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
81
  """
82
+ Normalize to float32 [H,W] in {0,1}, white=FG.
83
+ Auto-invert if >60% ON (likely wrong polarity).
 
84
  """
85
  if not isinstance(sam2_mask, np.ndarray):
86
  raise MatAnyError(f"SAM2 mask must be numpy array, got {type(sam2_mask)}")
 
97
  m /= 255.0
98
  m = np.clip(m, 0.0, 1.0)
99
 
100
+ if (m > 0.5).mean() > 0.60:
 
101
  m = 1.0 - m
102
 
103
+ return (m > 0.5).astype(np.float32)
 
 
104
 
105
  # ---------- Frame conversion ----------
106
  def _frame_bgr_to_hwc_rgb_numpy(frame) -> np.ndarray:
107
+ """Accept HWC/CHW BGR uint8 → return HWC RGB uint8."""
 
 
108
  if not isinstance(frame, np.ndarray) or frame.ndim != 3:
109
  raise MatAnyError(f"Frame must be HWC/CHW numpy array, got {type(frame)}, shape={getattr(frame, 'shape', None)}")
110
  arr = frame
111
+ if arr.shape[0] == 3 and arr.shape[2] != 3: # CHW → HWC
112
+ arr = np.transpose(arr, (1, 2, 0))
 
113
  if arr.dtype != np.uint8:
114
  raise MatAnyError(f"Frame must be uint8, got {arr.dtype}")
115
+ return cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
 
116
 
117
  # ============================================================================
118
 
119
  class MatAnyoneSession:
120
  """
121
+ Streaming wrapper that seeds MatAnyone on frame 0.
122
+ Prefers core.process_frame (HWC numpy), falls back to core.step with 4D/5D.
 
123
  """
124
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
125
  self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
126
  self.precision = precision.lower()
127
 
128
+ # API/format overrides for debugging
129
+ api_force = os.getenv("MATANY_FORCE_API", "").strip().lower() # "process" or "step"
130
+ fmt_force = os.getenv("MATANY_FORCE_FORMAT", "").strip().lower() # "4d" or "5d"
131
+ self._force_api_process = (api_force == "process")
132
+ self._force_api_step = (api_force == "step")
133
+ self._force_4d = (fmt_force == "4d")
134
+ self._force_5d = (fmt_force == "5d")
135
 
136
  try:
137
  from matanyone.inference.inference_core import InferenceCore
 
140
  try:
141
  self.core = InferenceCore()
142
  except TypeError:
 
143
  self.core = InferenceCore("PeiqingYang/MatAnyone")
144
 
145
+ self._has_process = hasattr(self.core, "process_frame")
146
+ self._has_step = hasattr(self.core, "step")
147
+ if not (self._has_process or self._has_step):
148
+ raise MatAnyError("MatAnyone core exposes neither 'process_frame' nor 'step'")
149
+
150
+ # Prefer process_frame unless forced to step
151
+ if self._force_api_step and not self._has_step:
152
+ raise MatAnyError("MATANY_FORCE_API=step but core.step is missing")
153
+ if self._force_api_process and not self._has_process:
154
+ raise MatAnyError("MATANY_FORCE_API=process but core.process_frame is missing")
155
 
156
+ self._api = "process_frame" if (self._has_process and not self._force_api_step) or self._force_api_process else "step"
157
+ self._use_5d = bool(self._force_5d) # only used in step mode
158
 
159
+ log.info(f"[MATANY] APIs: process_frame={self._has_process}, step={self._has_step} | active={self._api} | force4d={self._force_4d} force5d={self._force_5d}")
160
+
161
+ # AMP only affects step() path where we may use torch tensors
162
  def _amp(self):
 
163
  if self.device.type != "cuda":
164
  return torch.amp.autocast(device_type="cuda", enabled=False)
165
  if self.precision == "fp32":
166
  return torch.amp.autocast(device_type="cuda", enabled=False)
167
  if self.precision == "fp16":
168
  return torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
 
169
  return torch.amp.autocast(device_type="cuda", enabled=True)
170
 
171
+ # ----- Tensor builders for step() mode -----
172
  def _to_tensors(self, img_hwc_rgb: np.ndarray, mask_hw: Optional[np.ndarray]):
 
 
 
 
 
 
 
 
173
  img = torch.from_numpy(img_hwc_rgb).to(self.device)
174
  if img.dtype != torch.float32:
175
  img = img.float()
 
185
  m = torch.from_numpy(mask_hw).to(self.device)
186
  if m.dtype != torch.float32:
187
  m = m.float()
 
188
  m = (m >= 0.5).float() if float(m.max().item()) <= 1.0 else (m >= 128).float()
189
+ mask_4d = m.unsqueeze(0).unsqueeze(0).contiguous() # [1,1,H,W]
190
+ mask_5d = mask_4d.unsqueeze(1).contiguous() # [1,1,1,H,W]
191
  return img_4d, img_5d, mask_4d, mask_5d
192
 
193
+ # ----- Core call: process_frame preferred, fallback to step -----
194
+ def _call_process_frame(self, rgb_hwc: np.ndarray, seed_mask_hw: Optional[np.ndarray], is_first: bool):
195
+ """Try numpy path first; fallback to torch path if the wheel requests tensors."""
196
+ seed = seed_mask_hw if is_first else None
197
+
198
+ # 1) Most wheels want numpy HWC + 2D mask (float 0..1 or uint8)
199
+ try:
200
+ return self.core.process_frame(rgb_hwc, seed)
201
+ except TypeError as e_np:
202
+ msg = str(e_np).lower()
203
+ # 2) Some wheels want torch [B,C,H,W] tensors even in process_frame
204
+ if "tensor" in msg or "expected" in msg or "conv2d" in msg:
205
+ img_4d, _, mask_4d, _ = self._to_tensors(rgb_hwc, seed)
206
+ with torch.no_grad(), self._amp():
207
  try:
208
+ return self.core.process_frame(img_4d, mask_4d)
209
+ except Exception as e_t:
210
+ raise MatAnyError(f"process_frame tensor path failed: {e_t}") from e_t
211
+ raise
212
+
213
+ def _call_step(self, rgb_hwc: np.ndarray, seed_mask_hw: Optional[np.ndarray], is_first: bool):
214
+ """4D first; if the wheel wants 5D, promote both image AND mask."""
215
+ img_4d, img_5d, mask_4d, mask_5d = self._to_tensors(rgb_hwc, seed_mask_hw if is_first else None)
216
+
217
+ def run(use_5d: bool):
218
+ img = img_5d if use_5d else img_4d
219
+ msk = mask_5d if use_5d else mask_4d
220
+ if is_first and msk is not None:
221
+ try:
222
+ return self.core.step(img, msk, is_first=True)
223
+ except TypeError:
224
+ return self.core.step(img, msk)
225
  else:
226
+ return self.core.step(img)
227
 
228
  with torch.no_grad(), self._amp():
 
229
  if self._force_4d:
230
  return run(False)
231
  if self._force_5d:
232
  return run(True)
233
 
 
234
  if self._use_5d:
235
  try:
236
  return run(True)
237
  except RuntimeError as e5:
238
+ m5 = str(e5)
239
+ if "expected 3d" in m5.lower() and "4d" in m5 and "conv2d" in m5.lower():
 
240
  log.info("[MATANY] 5D rejected by wheel (conv2d wants 3D/4D). Falling back to 4D.")
241
  self._use_5d = False
242
  return run(False)
243
+ raise MatAnyError(f"Runtime error (step/5D): {m5}") from e5
244
 
 
245
  try:
246
+ return run(False) # 4D
247
  except RuntimeError as e4:
248
+ m4 = str(e4)
249
+ needs_5d = any(kw in m4 for kw in ["expected 5D", "expects 5D", "input.dim() == 5", "but got 4D", "got input of size: [1, 3,"])
250
+ if needs_5d:
 
 
 
 
 
 
 
251
  log.info("[MATANY] Wheel appears to expect 5D — retrying with [1,1,3,H,W] and [1,1,1,H,W].")
252
  self._use_5d = True
253
  try:
254
  return run(True)
255
  except RuntimeError as e5b:
256
+ m5b = str(e5b)
257
+ if "expected 3d" in m5b.lower() and "4d" in m5b and "conv2d" in m5b.lower():
 
258
  self._use_5d = False
259
+ raise MatAnyError(f"Wheel ultimately expects 4D (conv2d). Original 4D error: {m4}") from e4
260
+ raise MatAnyError(f"step/5D attempt failed: {m5b}") from e5b
261
+ if "cuda" in m4.lower():
 
 
262
  snap = _cuda_snapshot(self.device)
263
+ raise MatAnyError(f"CUDA runtime error: {m4} | {snap}") from e4
264
+ raise MatAnyError(f"Runtime error (step/4D): {m4}") from e4
 
 
265
 
266
  # ----- Per-frame runner -----
267
  def _run_frame(self, frame_bgr: np.ndarray, sam2_mask_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
 
268
  rgb_hwc = _frame_bgr_to_hwc_rgb_numpy(frame_bgr)
269
  H, W = rgb_hwc.shape[:2]
270
+ seed_for_this_frame = _prepare_seed_mask(sam2_mask_hw, H, W) if (is_first and sam2_mask_hw is not None) else None
271
 
272
+ # Primary: process_frame
273
+ if self._api == "process_frame":
274
+ try:
275
+ out = self._call_process_frame(rgb_hwc, seed_for_this_frame, is_first)
276
+ except Exception as e_proc:
277
+ log.warning(f"[MATANY] process_frame failed ({e_proc}); falling back to step().")
278
+ if not self._has_step:
279
+ raise MatAnyError(f"process_frame failed and step() is unavailable: {e_proc}")
280
+ self._api = "step"
281
+ out = self._call_step(rgb_hwc, seed_for_this_frame, is_first)
282
+ else:
283
+ out = self._call_step(rgb_hwc, seed_for_this_frame, is_first)
 
284
 
285
+ # Normalize to 2D alpha [H,W] in [0,1]
286
  if isinstance(out, torch.Tensor):
287
  alpha = out.detach().float().squeeze().cpu().numpy()
288
  else:
 
303
  out_dir: Optional[Path] = None,
304
  progress_cb: Optional[Callable] = None,
305
  ) -> Tuple[Path, Path]:
 
 
 
306
  video_path = Path(video_path)
307
  if not video_path.exists():
308
  raise MatAnyError(f"Video file not found: {video_path}")
 
310
  out_dir = Path(out_dir) if out_dir else video_path.parent
311
  out_dir.mkdir(parents=True, exist_ok=True)
312
 
 
313
  cap_probe = cv2.VideoCapture(str(video_path))
314
  if not cap_probe.isOpened():
315
  raise MatAnyError(f"Failed to open video: {video_path}")
 
323
 
324
  log.info(f"MatAnyone: {video_path.name} | {N} frames {W}x{H} @ {fps:.2f} fps")
325
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
326
+ _emit_progress(progress_cb, 0.08, "Using per-frame processing")
327
 
 
328
  alpha_path = out_dir / "alpha.mp4"
329
  fg_path = out_dir / "fg.mp4"
330
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 
333
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
334
  raise MatAnyError("Failed to initialize VideoWriter(s)")
335
 
 
336
  seed_mask_np = None
337
  if seed_mask_path is not None:
338
  p = Path(seed_mask_path)
 
341
  m = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
342
  if m is None:
343
  raise MatAnyError(f"Failed to read seed mask: {p}")
344
+ seed_mask_np = m
345
 
346
  cap = cv2.VideoCapture(str(video_path))
347
  if not cap.isOpened():
 
356
  if not ret:
357
  break
358
  is_first = (idx == 0)
359
+ alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first)
360
 
 
361
  alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
362
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
363
  fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
 
382
  except: pass
383
  _safe_empty_cache()
384
 
 
385
  if not alpha_path.exists() or alpha_path.stat().st_size == 0:
386
  raise MatAnyError(f"Output file missing/empty: {alpha_path}")
387
  if not fg_path.exists() or fg_path.stat().st_size == 0: