MogensR committed on
Commit
87688ee
·
1 Parent(s): 2a63856
Files changed (1) hide show
  1. models/matanyone_loader.py +170 -137
models/matanyone_loader.py CHANGED
@@ -1,23 +1,19 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- MatAnyone adapter — SAM2-seeded, streaming, build-agnostic (HF Spaces ready).
5
-
6
- GOAL (pipeline contract):
7
- - Use SAM2 only to define the person on frame 0 (seed mask).
8
- - Feed MatAnyone frames one-by-one to generate the alpha matte.
9
- - Always pass tensors in the shapes MatAnyone’s conv2d expects:
10
- * image : [3, H, W] (float32, 0..1, RGB, CHW)
11
- * mask : [H, W] (float32, 0..1, binary)
12
- (No 5D tensors. No time dimension.)
13
-
14
- Outputs:
15
- - alpha.mp4 grayscale stored as BGR for broad mp4v compatibility
16
- - fg.mp4 original RGB multiplied by alpha (for later compositing)
17
-
18
- Works on HF Spaces:
19
- - Reads from /tmp/gradio/...
20
- - Writes to the same folder (or a provided out_dir, e.g. /data/outputs).
21
  """
22
 
23
  from __future__ import annotations
@@ -32,46 +28,43 @@
32
 
33
  log = logging.getLogger(__name__)
34
 
35
- # =============================================================================
36
- # [0] Progress helper (safe & rate-limited)
37
- # =============================================================================
38
  def _env_flag(name: str, default: str = "0") -> bool:
39
  return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
40
 
41
  _PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
42
  _PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
43
- _progress_last_t = 0.0
44
- _progress_last_msg: Optional[str] = None
45
  _progress_disabled = False
46
 
47
  def _emit_progress(cb, pct: float, msg: str):
48
- """[0.1] Emit progress without ever crashing the caller."""
49
- global _progress_last_t, _progress_last_msg, _progress_disabled
50
  if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
51
  return
52
  now = time.time()
53
- if (now - _progress_last_t) < _PROGRESS_MIN_INTERVAL and msg == _progress_last_msg:
54
  return
55
  try:
56
  try:
57
- cb(pct, msg) # preferred signature (pct, msg)
58
  except TypeError:
59
- cb(msg) # legacy signature (msg)
60
- _progress_last_t = now
61
  _progress_last_msg = msg
62
  except Exception as e:
63
  _progress_disabled = True
64
  log.warning("[progress-cb] disabled due to exception: %s", e)
65
 
66
- # =============================================================================
67
- # [1] Errors & CUDA helpers
68
- # =============================================================================
69
  class MatAnyError(RuntimeError):
70
- """Single error type the pipeline can catch & decide to fallback."""
71
  pass
72
 
 
73
  def _cuda_snapshot(device: Optional[torch.device]) -> str:
74
- """[1.1] Short, safe description of CUDA memory state."""
75
  try:
76
  if not torch.cuda.is_available():
77
  return "CUDA: N/A"
@@ -86,30 +79,23 @@ def _cuda_snapshot(device: Optional[torch.device]) -> str:
86
  return f"CUDA snapshot error: {e!r}"
87
 
88
  def _safe_empty_cache():
89
- """[1.2] Try to free CUDA cache; never raise."""
90
  if not torch.cuda.is_available():
91
  return
92
- try:
93
- torch.cuda.synchronize()
94
- except Exception:
95
- pass
96
  try:
97
  torch.cuda.empty_cache()
98
  except Exception:
99
  pass
100
 
101
- # =============================================================================
102
- # [2] Mask & frame preparation
103
- # =============================================================================
104
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
105
  """
106
- [2.1] Convert SAM2 mask (0/255 or 0..1) into a clean binary [H,W] float32 in {0,1}.
107
- Auto-invert if coverage is > 60% (typical “background is white” case).
 
108
  """
109
  if not isinstance(sam2_mask, np.ndarray):
110
  raise MatAnyError(f"SAM2 mask must be numpy array, got {type(sam2_mask)}")
111
-
112
- # Accept accidental 3-channel masks
113
  if sam2_mask.ndim == 3 and sam2_mask.shape[2] == 3:
114
  sam2_mask = cv2.cvtColor(sam2_mask, cv2.COLOR_BGR2GRAY)
115
  if sam2_mask.ndim != 2:
@@ -120,25 +106,26 @@ def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
120
 
121
  m = sam2_mask.astype(np.float32)
122
  if m.max() > 1.0:
123
- m *= (1.0 / 255.0)
124
  m = np.clip(m, 0.0, 1.0)
125
 
126
  cov = float((m > 0.5).mean())
127
  if cov > 0.60:
128
- m = 1.0 - m # Auto-polarity for “mask covers most of the frame”
129
- cov = 1.0 - cov
130
- # Binarize (MatAnyone seed likes a crisp mask)
131
  m = (m > 0.5).astype(np.float32)
132
  return m
133
 
134
- def _frame_bgr_to_rgb_hwc(frame: np.ndarray) -> np.ndarray:
 
135
  """
136
- [2.2] Accept OpenCV BGR uint8 HWC (or CHW uint8), return RGB uint8 HWC.
137
  """
138
  if not isinstance(frame, np.ndarray) or frame.ndim != 3:
139
  raise MatAnyError(f"Frame must be HWC/CHW numpy array, got {type(frame)}, shape={getattr(frame, 'shape', None)}")
140
  arr = frame
141
- # Allow CHW input (rare, but we support it)
142
  if arr.shape[0] == 3 and arr.shape[2] != 3:
143
  arr = np.transpose(arr, (1, 2, 0)) # CHW -> HWC
144
  if arr.dtype != np.uint8:
@@ -146,24 +133,24 @@ def _frame_bgr_to_rgb_hwc(frame: np.ndarray) -> np.ndarray:
146
  rgb = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
147
  return rgb
148
 
149
- # =============================================================================
150
- # [3] Main session
151
- # =============================================================================
152
  class MatAnyoneSession:
153
  """
154
- Streaming wrapper that seeds MatAnyone with a SAM2 mask on frame 0.
155
-
156
- KEY DECISION: We always pass CHW (3,H,W) to core.step(), and HW (H,W) for the mask.
157
- Absolutely no [B,T,C,H,W] tensors.
158
  """
159
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
160
- # [3.1] Device & AMP
161
- self.device = torch.device(device) if device else (
162
- torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
163
- )
164
  self.precision = precision.lower()
165
 
166
- # [3.2] Import & instantiate the MatAnyone core
 
 
 
 
 
167
  try:
168
  from matanyone.inference.inference_core import InferenceCore
169
  except ImportError as e:
@@ -171,122 +158,166 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
171
  try:
172
  self.core = InferenceCore()
173
  except TypeError:
174
- # Some builds require the repo-id
175
  self.core = InferenceCore("PeiqingYang/MatAnyone")
176
 
177
- # [3.3] Choose API (prefer step)
178
- if hasattr(self.core, "step") and callable(getattr(self.core, "step")):
179
- self.api = "step"
180
- elif hasattr(self.core, "process_frame") and callable(getattr(self.core, "process_frame")):
181
- self.api = "process_frame"
182
- else:
183
  raise MatAnyError("MatAnyone core exposes neither 'step' nor 'process_frame'")
184
 
185
- log.info(f"[MATANY] Using API: {self.api} | device={self.device}")
186
 
187
- # [3.4] AMP context (enabled on CUDA unless precision=='fp32')
188
  def _amp(self):
 
189
  if self.device.type != "cuda":
190
  return torch.amp.autocast(device_type="cuda", enabled=False)
191
  if self.precision == "fp32":
192
  return torch.amp.autocast(device_type="cuda", enabled=False)
193
  if self.precision == "fp16":
194
  return torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
 
195
  return torch.amp.autocast(device_type="cuda", enabled=True)
196
 
197
- # [3.5] Tensor builders — STRICT shapes
198
- def _to_tensors_strict(self, rgb_hwc: np.ndarray, mask_hw: Optional[np.ndarray]):
199
  """
200
- image_out: torch float32 [3,H,W] in 0..1 (RGB, CHW)
201
- mask_out : torch float32 [H,W] in {0,1}
 
 
 
 
202
  """
203
- # image -> CHW
204
- img = torch.from_numpy(rgb_hwc).to(self.device)
205
  if img.dtype != torch.float32:
206
  img = img.float()
207
  if float(img.max().item()) > 1.0:
208
  img = img / 255.0
 
209
  img_chw = img.permute(2, 0, 1).contiguous() # [3,H,W]
 
 
210
 
211
- # mask -> HW
212
- mask_t = None
213
  if mask_hw is not None:
214
  m = torch.from_numpy(mask_hw).to(self.device)
215
  if m.dtype != torch.float32:
216
  m = m.float()
217
- # Robust binarization (accepts 0/1 or 0..1 or 0/255 upstream)
218
- if float(m.max().item()) > 1.0:
219
- m = (m >= 128).float()
220
- else:
221
- m = (m >= 0.5).float()
222
- mask_t = m.contiguous() # [H,W]
223
- return img_chw, mask_t
224
-
225
- # [3.6] Core call (NO 5D, ever)
226
- def _core_call(self, img_chw: torch.Tensor, mask_hw: Optional[torch.Tensor], is_first: bool):
227
  """
228
- Route strictly:
229
- - step(image_chw, mask_hw) on frame 0 (if mask exists)
230
- - step(image_chw) on subsequent frames
231
- Fallbacks only switch between step/process_frame, NOT shapes.
232
  """
233
- with torch.no_grad(), self._amp():
 
 
234
  if self.api == "step":
235
- try:
236
- if is_first and mask_hw is not None:
237
- return self.core.step(img_chw, mask_hw) # <-- strict CHW/HW
238
- else:
239
- return self.core.step(img_chw)
240
- except TypeError:
241
- # Some wheels might gate arguments differently; try process_frame
242
- if is_first and mask_hw is not None and hasattr(self.core, "process_frame"):
243
- return self.core.process_frame(img_chw, mask_hw)
244
- elif hasattr(self.core, "process_frame"):
245
- return self.core.process_frame(img_chw, None)
246
- raise
247
  else:
248
- # process_frame fallback API
249
- return self.core.process_frame(img_chw, mask_hw if (is_first and mask_hw is not None) else None)
250
 
251
- # [3.7] Per-frame runner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  def _run_frame(self, frame_bgr: np.ndarray, sam2_mask_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
253
- rgb_hwc = _frame_bgr_to_rgb_hwc(frame_bgr)
 
254
  H, W = rgb_hwc.shape[:2]
255
 
256
- seed = None
257
  if is_first and sam2_mask_hw is not None:
258
- seed = _prepare_seed_mask(sam2_mask_hw, H, W) # [H,W] float32 {0,1}
259
 
260
- img_chw, mask_hw = self._to_tensors_strict(rgb_hwc, seed)
261
 
262
  try:
263
- out = self._core_call(img_chw, mask_hw, is_first)
264
  except torch.cuda.OutOfMemoryError as e:
265
  snap = _cuda_snapshot(self.device)
266
  raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
267
- except RuntimeError as e:
268
- # Add CUDA snapshot if relevant
269
- if "CUDA" in str(e):
270
- snap = _cuda_snapshot(self.device)
271
- raise MatAnyError(f"CUDA runtime error: {e} | {snap}") from e
272
  raise MatAnyError(f"Runtime error: {e}") from e
273
 
274
- # Normalize output -> [H,W] float32 0..1
275
  if isinstance(out, torch.Tensor):
276
  alpha = out.detach().float().squeeze().cpu().numpy()
277
  else:
278
  alpha = np.asarray(out)
279
  alpha = alpha.astype(np.float32)
280
  if float(alpha.max()) > 1.0:
281
- alpha *= (1.0 / 255.0)
282
  alpha = np.squeeze(alpha)
283
  if alpha.ndim != 2:
284
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha.shape}")
285
  return np.clip(alpha, 0.0, 1.0)
286
 
287
- # =============================================================================
288
- # [4] Public: stream the whole video
289
- # =============================================================================
290
  def process_stream(
291
  self,
292
  video_path: Path,
@@ -294,7 +325,9 @@ def process_stream(
294
  out_dir: Optional[Path] = None,
295
  progress_cb: Optional[Callable] = None,
296
  ) -> Tuple[Path, Path]:
297
- # [4.1] IO setup
 
 
298
  video_path = Path(video_path)
299
  if not video_path.exists():
300
  raise MatAnyError(f"Video file not found: {video_path}")
@@ -302,21 +335,23 @@ def process_stream(
302
  out_dir = Path(out_dir) if out_dir else video_path.parent
303
  out_dir.mkdir(parents=True, exist_ok=True)
304
 
305
- # [4.2] Probe video
306
  cap_probe = cv2.VideoCapture(str(video_path))
307
  if not cap_probe.isOpened():
308
  raise MatAnyError(f"Failed to open video: {video_path}")
309
  N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
310
- fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0
311
  W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
312
  H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
313
  cap_probe.release()
 
 
314
 
315
  log.info(f"MatAnyone: {video_path.name} | {N} frames {W}x{H} @ {fps:.2f} fps")
316
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
317
  _emit_progress(progress_cb, 0.08, "Using step (frame-by-frame)")
318
 
319
- # [4.3] Writers (alpha as BGR so mp4v is happy)
320
  alpha_path = out_dir / "alpha.mp4"
321
  fg_path = out_dir / "fg.mp4"
322
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -325,7 +360,7 @@ def process_stream(
325
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
326
  raise MatAnyError("Failed to initialize VideoWriter(s)")
327
 
328
- # [4.4] Load seed mask file if provided
329
  seed_mask_np = None
330
  if seed_mask_path is not None:
331
  p = Path(seed_mask_path)
@@ -334,9 +369,8 @@ def process_stream(
334
  m = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
335
  if m is None:
336
  raise MatAnyError(f"Failed to read seed mask: {p}")
337
- seed_mask_np = m # Prepare per-frame to ensure correct (H,W)
338
 
339
- # [4.5] Stream frames
340
  cap = cv2.VideoCapture(str(video_path))
341
  if not cap.isOpened():
342
  raise MatAnyError(f"Failed to open video for reading: {video_path}")
@@ -349,11 +383,10 @@ def process_stream(
349
  ret, frame = cap.read()
350
  if not ret:
351
  break
352
-
353
  is_first = (idx == 0)
354
- alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first) # [H,W] 0..1
355
 
356
- # Compose outputs (note: alpha already 0..1 — no double scaling)
357
  alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
358
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
359
  fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
@@ -378,7 +411,7 @@ def process_stream(
378
  except: pass
379
  _safe_empty_cache()
380
 
381
- # [4.6] Verify outputs
382
  if not alpha_path.exists() or alpha_path.stat().st_size == 0:
383
  raise MatAnyError(f"Output file missing/empty: {alpha_path}")
384
  if not fg_path.exists() or fg_path.stat().st_size == 0:
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ MatAnyone adapter — SAM2-seeded, streaming, build-agnostic.
5
+
6
+ #1 Overview
7
+ - SAM2 provides a seed mask on frame 0.
8
+ - MatAnyone does frame-by-frame alpha matting.
9
+ - Supports wheels that expect either 4D [B,C,H,W] or 5D [B,T,C,H,W].
10
+ - Accepts HWC or CHW frames; converts to HWC RGB.
11
+ - Writes alpha.mp4 (grayscale-as-BGR) and fg.mp4 (RGB on black).
12
+
13
+ Public API used by pipeline:
14
+ MatAnyError (exception)
15
+ class MatAnyoneSession:
16
+ process_stream(video_path, seed_mask_path=None, out_dir=None, progress_cb=None) -> (alpha_path, fg_path)
 
 
 
 
17
  """
18
 
19
  from __future__ import annotations
 
28
 
29
  log = logging.getLogger(__name__)
30
 
31
+ # ---------- Progress helper (safe & rate-limited) ----------
 
 
32
  def _env_flag(name: str, default: str = "0") -> bool:
33
  return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
34
 
35
  _PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
36
  _PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
37
+ _progress_last = 0.0
38
+ _progress_last_msg = None
39
  _progress_disabled = False
40
 
41
  def _emit_progress(cb, pct: float, msg: str):
42
+ """#2 UI progress callback wrapper (tolerant of legacy 1-arg signatures)"""
43
+ global _progress_last, _progress_last_msg, _progress_disabled
44
  if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
45
  return
46
  now = time.time()
47
+ if (now - _progress_last) < _PROGRESS_MIN_INTERVAL and msg == _progress_last_msg:
48
  return
49
  try:
50
  try:
51
+ cb(pct, msg) # preferred (pct, msg)
52
  except TypeError:
53
+ cb(msg) # legacy (msg-only)
54
+ _progress_last = now
55
  _progress_last_msg = msg
56
  except Exception as e:
57
  _progress_disabled = True
58
  log.warning("[progress-cb] disabled due to exception: %s", e)
59
 
60
+ # ---------- Errors ----------
 
 
61
  class MatAnyError(RuntimeError):
62
+ """#3 Adapter-level error (keeps upstream logs readable)"""
63
  pass
64
 
65
+ # ---------- CUDA snapshots ----------
66
  def _cuda_snapshot(device: Optional[torch.device]) -> str:
67
+ """#4 Best-effort CUDA memory + device info (for error context)"""
68
  try:
69
  if not torch.cuda.is_available():
70
  return "CUDA: N/A"
 
79
  return f"CUDA snapshot error: {e!r}"
80
 
81
  def _safe_empty_cache():
82
+ """#5 Non-blocking VRAM cleanup (avoid synchronize() in Spaces)"""
83
  if not torch.cuda.is_available():
84
  return
 
 
 
 
85
  try:
86
  torch.cuda.empty_cache()
87
  except Exception:
88
  pass
89
 
90
+ # ---------- SAM2 → seed mask prep ----------
 
 
91
  def _prepare_seed_mask(sam2_mask: np.ndarray, H: int, W: int) -> np.ndarray:
92
  """
93
+ #6 Normalize SAM2 mask to float32 [H,W] in {0,1}, white = foreground.
94
+ - Accepts 2D or 3-channel images; resizes with NEAREST to keep edges crisp.
95
+ - Auto-inverts if >60% of the image is ON (likely polarity swap).
96
  """
97
  if not isinstance(sam2_mask, np.ndarray):
98
  raise MatAnyError(f"SAM2 mask must be numpy array, got {type(sam2_mask)}")
 
 
99
  if sam2_mask.ndim == 3 and sam2_mask.shape[2] == 3:
100
  sam2_mask = cv2.cvtColor(sam2_mask, cv2.COLOR_BGR2GRAY)
101
  if sam2_mask.ndim != 2:
 
106
 
107
  m = sam2_mask.astype(np.float32)
108
  if m.max() > 1.0:
109
+ m /= 255.0
110
  m = np.clip(m, 0.0, 1.0)
111
 
112
  cov = float((m > 0.5).mean())
113
  if cov > 0.60:
114
+ m = 1.0 - m
115
+
116
+ # hard binarize for a clean seed
117
  m = (m > 0.5).astype(np.float32)
118
  return m
119
 
120
+ # ---------- Frame conversion ----------
121
+ def _frame_bgr_to_hwc_rgb_numpy(frame) -> np.ndarray:
122
  """
123
+ #7 Accepts OpenCV BGR uint8 HWC, or uint8 CHW; returns HWC RGB uint8.
124
  """
125
  if not isinstance(frame, np.ndarray) or frame.ndim != 3:
126
  raise MatAnyError(f"Frame must be HWC/CHW numpy array, got {type(frame)}, shape={getattr(frame, 'shape', None)}")
127
  arr = frame
128
+ # Accept CHW and convert to HWC
129
  if arr.shape[0] == 3 and arr.shape[2] != 3:
130
  arr = np.transpose(arr, (1, 2, 0)) # CHW -> HWC
131
  if arr.dtype != np.uint8:
 
133
  rgb = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
134
  return rgb
135
 
136
+ # ============================================================================
137
+
 
138
  class MatAnyoneSession:
139
  """
140
+ #8 Streaming wrapper that seeds MatAnyone with a SAM2 mask on frame 0.
141
+ - Tries 4D first; if the wheel truly wants 5D, promotes both image AND mask.
142
+ - Has an override env: MATANY_FORCE_FORMAT=4D|5D (for debugging).
 
143
  """
144
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
145
+ self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
 
 
 
146
  self.precision = precision.lower()
147
 
148
+ # Optional override: MATANY_FORCE_FORMAT=4D|5D
149
+ fmt = os.getenv("MATANY_FORCE_FORMAT", "").strip().lower()
150
+ self._force_4d = (fmt == "4d")
151
+ self._force_5d = (fmt == "5d")
152
+ self._use_5d = self._force_5d # start in 5D only if forced
153
+
154
  try:
155
  from matanyone.inference.inference_core import InferenceCore
156
  except ImportError as e:
 
158
  try:
159
  self.core = InferenceCore()
160
  except TypeError:
161
+ # HF wheel constructor that needs a repo string
162
  self.core = InferenceCore("PeiqingYang/MatAnyone")
163
 
164
+ self.api = "step" if hasattr(self.core, "step") else ("process_frame" if hasattr(self.core, "process_frame") else None)
165
+ if not self.api:
 
 
 
 
166
  raise MatAnyError("MatAnyone core exposes neither 'step' nor 'process_frame'")
167
 
168
+ log.info(f"[MATANY] API: {self.api} | device={self.device} | force4d={self._force_4d} | force5d={self._force_5d}")
169
 
170
+ # ----- AMP policy -----
171
  def _amp(self):
172
+ """#9 Simple AMP gate (auto/fp16/fp32)"""
173
  if self.device.type != "cuda":
174
  return torch.amp.autocast(device_type="cuda", enabled=False)
175
  if self.precision == "fp32":
176
  return torch.amp.autocast(device_type="cuda", enabled=False)
177
  if self.precision == "fp16":
178
  return torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
179
+ # auto
180
  return torch.amp.autocast(device_type="cuda", enabled=True)
181
 
182
+ # ----- Tensor builders -----
183
+ def _to_tensors(self, img_hwc_rgb: np.ndarray, mask_hw: Optional[np.ndarray]):
184
  """
185
+ #10 Build both 4D and 5D tensors.
186
+ Returns: (img_4d, img_5d, mask_4d, mask_5d)
187
+ - img_4d: [1, 3, H, W]
188
+ - img_5d: [1, 1, 3, H, W]
189
+ - mask_4d: [1, 1, H, W] or None
190
+ - mask_5d: [1, 1, 1, H, W] or None
191
  """
192
+ img = torch.from_numpy(img_hwc_rgb).to(self.device)
 
193
  if img.dtype != torch.float32:
194
  img = img.float()
195
  if float(img.max().item()) > 1.0:
196
  img = img / 255.0
197
+
198
  img_chw = img.permute(2, 0, 1).contiguous() # [3,H,W]
199
+ img_4d = img_chw.unsqueeze(0) # [1,3,H,W]
200
+ img_5d = img_chw.unsqueeze(0).unsqueeze(0) # [1,1,3,H,W]
201
 
202
+ mask_4d = mask_5d = None
 
203
  if mask_hw is not None:
204
  m = torch.from_numpy(mask_hw).to(self.device)
205
  if m.dtype != torch.float32:
206
  m = m.float()
207
+ # robust binarize
208
+ m = (m >= 0.5).float() if float(m.max().item()) <= 1.0 else (m >= 128).float()
209
+ mask_4d = m.unsqueeze(0).unsqueeze(0).contiguous() # [1,1,H,W]
210
+ mask_5d = mask_4d.unsqueeze(1).contiguous() # [1,1,1,H,W]
211
+ return img_4d, img_5d, mask_4d, mask_5d
212
+
213
+ # ----- Core call (4D first, 5D only if demanded) -----
214
+ def _core_call(self, img_4d, img_5d, mask_4d, mask_5d, is_first: bool):
 
 
215
  """
216
+ #11 Dispatch into the wheel, trying 4D, then 5D if the error suggests it.
217
+ Also backs off from 5D → 4D when conv2d complains about 3D/4D.
 
 
218
  """
219
+ def run(use_5d: bool):
220
+ img = img_5d if use_5d else img_4d
221
+ msk = mask_5d if use_5d else mask_4d # <<< IMPORTANT: match ranks
222
  if self.api == "step":
223
+ if is_first and msk is not None:
224
+ try:
225
+ return self.core.step(img, msk, is_first=True)
226
+ except TypeError:
227
+ return self.core.step(img, msk) # older signature
228
+ else:
229
+ return self.core.step(img)
 
 
 
 
 
230
  else:
231
+ return self.core.process_frame(img, msk if is_first else None)
 
232
 
233
+ with torch.no_grad(), self._amp():
234
+ # Forced modes for debugging
235
+ if self._force_4d:
236
+ return run(False)
237
+ if self._force_5d:
238
+ return run(True)
239
+
240
+ # If a previous frame decided on 5D, try 5D first but back off if needed
241
+ if self._use_5d:
242
+ try:
243
+ return run(True)
244
+ except RuntimeError as e5:
245
+ msg5 = str(e5)
246
+ # If the wheel says conv2d needs 3D/4D, revert to 4D permanently
247
+ if "Expected 3D" in msg5 and "4D" in msg5 and "conv2d" in msg5:
248
+ log.info("[MATANY] 5D rejected by wheel (conv2d wants 3D/4D). Falling back to 4D.")
249
+ self._use_5d = False
250
+ return run(False)
251
+ raise MatAnyError(f"Runtime error (5D path): {msg5}") from e5
252
+
253
+ # Default: try 4D first
254
+ try:
255
+ return run(False)
256
+ except RuntimeError as e4:
257
+ msg4 = str(e4)
258
+ # Hints that the wheel actually expects 5D
259
+ wants_5d = any(kw in msg4 for kw in [
260
+ "expected 5D",
261
+ "expects 5D",
262
+ "input.dim() == 5",
263
+ "but got 4D",
264
+ "got input of size: [1, 3," # some wheels report this pattern
265
+ ])
266
+ if wants_5d:
267
+ log.info("[MATANY] Wheel appears to expect 5D — retrying with [1,1,3,H,W] and [1,1,1,H,W].")
268
+ self._use_5d = True
269
+ try:
270
+ return run(True)
271
+ except RuntimeError as e5b:
272
+ msg5b = str(e5b)
273
+ # If retry says conv2d wants 3D/4D, undo and raise original
274
+ if "Expected 3D" in msg5b and "4D" in msg5b and "conv2d" in msg5b:
275
+ self._use_5d = False
276
+ raise MatAnyError(f"Wheel ultimately expects 4D (conv2d). Original 4D error: {msg4}") from e4
277
+ raise MatAnyError(f"5D attempt failed: {msg5b}") from e5b
278
+
279
+ # Add CUDA context for GPU errors
280
+ if "CUDA" in msg4 or "cublas" in msg4.lower() or "cudnn" in msg4.lower():
281
+ snap = _cuda_snapshot(self.device)
282
+ raise MatAnyError(f"CUDA runtime error: {msg4} | {snap}") from e4
283
+
284
+ # Generic wrap
285
+ raise MatAnyError(f"Runtime error (4D path): {msg4}") from e4
286
+
287
+ # ----- Per-frame runner -----
288
  def _run_frame(self, frame_bgr: np.ndarray, sam2_mask_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
289
+ """#12 Convert inputs, seed frame 0, call core, and normalize to [H,W] alpha."""
290
+ rgb_hwc = _frame_bgr_to_hwc_rgb_numpy(frame_bgr)
291
  H, W = rgb_hwc.shape[:2]
292
 
293
+ seed_for_this_frame = None
294
  if is_first and sam2_mask_hw is not None:
295
+ seed_for_this_frame = _prepare_seed_mask(sam2_mask_hw, H, W)
296
 
297
+ img_4d, img_5d, mask_4d, mask_5d = self._to_tensors(rgb_hwc, seed_for_this_frame)
298
 
299
  try:
300
+ out = self._core_call(img_4d, img_5d, mask_4d, mask_5d, is_first)
301
  except torch.cuda.OutOfMemoryError as e:
302
  snap = _cuda_snapshot(self.device)
303
  raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
304
+ except Exception as e:
 
 
 
 
305
  raise MatAnyError(f"Runtime error: {e}") from e
306
 
307
+ # Normalize output to [H,W] float32 in [0,1]
308
  if isinstance(out, torch.Tensor):
309
  alpha = out.detach().float().squeeze().cpu().numpy()
310
  else:
311
  alpha = np.asarray(out)
312
  alpha = alpha.astype(np.float32)
313
  if float(alpha.max()) > 1.0:
314
+ alpha /= 255.0
315
  alpha = np.squeeze(alpha)
316
  if alpha.ndim != 2:
317
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha.shape}")
318
  return np.clip(alpha, 0.0, 1.0)
319
 
320
+ # ----- Public: streaming processor -----
 
 
321
  def process_stream(
322
  self,
323
  video_path: Path,
 
325
  out_dir: Optional[Path] = None,
326
  progress_cb: Optional[Callable] = None,
327
  ) -> Tuple[Path, Path]:
328
+ """
329
+ #13 Stream the video one frame at a time (T=1), write alpha.mp4 & fg.mp4.
330
+ """
331
  video_path = Path(video_path)
332
  if not video_path.exists():
333
  raise MatAnyError(f"Video file not found: {video_path}")
 
335
  out_dir = Path(out_dir) if out_dir else video_path.parent
336
  out_dir.mkdir(parents=True, exist_ok=True)
337
 
338
+ # Probe video
339
  cap_probe = cv2.VideoCapture(str(video_path))
340
  if not cap_probe.isOpened():
341
  raise MatAnyError(f"Failed to open video: {video_path}")
342
  N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
343
+ fps = cap_probe.get(cv2.CAP_PROP_FPS)
344
  W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
345
  H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
346
  cap_probe.release()
347
+ if not fps or fps <= 0 or np.isnan(fps):
348
+ fps = 25.0
349
 
350
  log.info(f"MatAnyone: {video_path.name} | {N} frames {W}x{H} @ {fps:.2f} fps")
351
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
352
  _emit_progress(progress_cb, 0.08, "Using step (frame-by-frame)")
353
 
354
+ # Prepare writers
355
  alpha_path = out_dir / "alpha.mp4"
356
  fg_path = out_dir / "fg.mp4"
357
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 
360
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
361
  raise MatAnyError("Failed to initialize VideoWriter(s)")
362
 
363
+ # Load seed mask if provided (file path on disk)
364
  seed_mask_np = None
365
  if seed_mask_path is not None:
366
  p = Path(seed_mask_path)
 
369
  m = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
370
  if m is None:
371
  raise MatAnyError(f"Failed to read seed mask: {p}")
372
+ seed_mask_np = m # we resize/polarize/binarize inside _run_frame
373
 
 
374
  cap = cv2.VideoCapture(str(video_path))
375
  if not cap.isOpened():
376
  raise MatAnyError(f"Failed to open video for reading: {video_path}")
 
383
  ret, frame = cap.read()
384
  if not ret:
385
  break
 
386
  is_first = (idx == 0)
387
+ alpha = self._run_frame(frame, seed_mask_np if is_first else None, is_first) # [H,W] in [0,1]
388
 
389
+ # Compose outputs (no double divide)
390
  alpha_u8 = (alpha * 255.0 + 0.5).astype(np.uint8)
391
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
392
  fg_bgr = (frame.astype(np.float32) * alpha[..., None]).clip(0, 255).astype(np.uint8)
 
411
  except: pass
412
  _safe_empty_cache()
413
 
414
+ # Verify outputs are non-empty
415
  if not alpha_path.exists() or alpha_path.stat().st_size == 0:
416
  raise MatAnyError(f"Output file missing/empty: {alpha_path}")
417
  if not fg_path.exists() or fg_path.stat().st_size == 0: