MogensR committed on
Commit
975ab1f
·
1 Parent(s): 9923851
Files changed (1) hide show
  1. models/matanyone_loader.py +263 -227
models/matanyone_loader.py CHANGED
@@ -1,12 +1,13 @@
1
  #!/usr/bin/env python3
 
 
 
2
  """
3
- MatAnyone Adapter (streaming, API-agnostic)
4
- -------------------------------------------
5
  - Supports multiple MatAnyone variants:
6
  * frame API: core.step(image[, mask]) or core.process_frame(image, mask)
7
  * video API: core.process_video(video_path[, mask_path]) [DISABLED BY DEFAULT]
8
  - Streams frames: no full-video-in-RAM.
9
- - Emits alpha.mp4 (grayscale-as-BGR for compatibility) and fg.mp4 (RGB-on-black) as it goes.
10
  - Validates outputs and raises MatAnyError on failure (so pipeline can fallback).
11
 
12
  I/O conventions:
@@ -17,31 +18,37 @@
17
  Requires: OpenCV, Torch, NumPy
18
  """
19
 
 
 
 
20
  from __future__ import annotations
 
21
  import os
22
  import cv2
23
  import time
24
  import shutil
25
- import torch
26
  import logging
27
  import numpy as np
 
 
28
  from pathlib import Path
29
  from typing import Optional, Callable, Tuple, List
30
 
31
  log = logging.getLogger(__name__)
32
 
33
 
34
- # -----------------------------
35
- # Small utilities
36
- # -----------------------------
37
  def _emit_progress(cb, pct: float, msg: str):
 
38
  if not cb:
39
  return
40
  try:
41
- cb(pct, msg) # preferred 2-args
42
  except TypeError:
43
  try:
44
- cb(msg) # legacy 1-arg
45
  except TypeError:
46
  pass
47
 
@@ -52,6 +59,7 @@ class MatAnyError(RuntimeError):
52
 
53
 
54
  def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
 
55
  if not torch.cuda.is_available():
56
  return "CUDA: N/A"
57
  idx = 0
@@ -64,6 +72,7 @@ def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
64
 
65
 
66
  def _safe_empty_cache():
 
67
  if torch.cuda.is_available():
68
  try:
69
  torch.cuda.synchronize()
@@ -73,7 +82,7 @@ def _safe_empty_cache():
73
 
74
 
75
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
76
- """Read mask image, convert to float32 [0,1], resize to target (H,W)."""
77
  if not Path(mask_path).exists():
78
  raise MatAnyError(f"Seed mask not found: {mask_path}")
79
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
@@ -83,34 +92,28 @@ def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
83
  if mask.shape[:2] != (H, W):
84
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
85
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
86
- return maskf # (H, W)
87
-
88
-
89
- def _to_hwc01(img_bgr: np.ndarray) -> np.ndarray:
90
- """BGR [H,W,3] uint8 -> HWC float32 [0,1] RGB."""
91
- rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
92
- rgbf = rgb.astype(np.float32) / 255.0
93
- return rgbf # (H, W, 3)
94
 
95
 
96
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
97
  """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
98
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
99
  rgbf = rgb.astype(np.float32) / 255.0
100
- chw = np.transpose(rgbf, (2, 0, 1)) # (3, H, W)
101
  return chw
102
 
103
 
104
  def _validate_nonempty(file_path: Path) -> None:
 
105
  if not file_path.exists() or file_path.stat().st_size == 0:
106
  raise MatAnyError(f"Output file missing/empty: {file_path}")
107
 
108
 
109
  def _select_matany_mode(core) -> str:
110
  """
111
- Pick the best-available MatAnyone API at runtime.
112
  Priority: process_video > process_frame > step
113
- (Note: we force frame mode in _lazy_init; this helper is used only in a chunk helper.)
114
  """
115
  if hasattr(core, "process_video") and callable(getattr(core, "process_video")):
116
  return "process_video"
@@ -121,9 +124,9 @@ def _select_matany_mode(core) -> str:
121
  raise MatAnyError("No supported MatAnyone API on core (process_video/process_frame/step).")
122
 
123
 
124
- # -----------------------------
125
- # Main session
126
- # -----------------------------
127
  class MatAnyoneSession:
128
  """
129
  Unified, streaming wrapper over MatAnyone variants.
@@ -133,24 +136,23 @@ class MatAnyoneSession:
133
  -> returns (alpha_path, fg_path)
134
  """
135
 
 
 
 
136
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
137
  """
138
  Args:
139
  device: 'cuda', 'cpu', 'cuda:0', etc. If None, auto-detects CUDA.
140
  precision: 'auto' | 'fp32' | 'fp16'
141
  """
142
- self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
 
 
143
  self.precision = precision.lower()
144
  self.use_fp16 = (self.precision == "fp16") or (self.precision == "auto" and self.device.type == "cuda")
145
  self._core = None
146
  self._api_mode = None
147
  self._initialized = False
148
-
149
- # chosen builders after first frame succeeds
150
- self._build_img = None # Callable[[np.ndarray], torch.Tensor]
151
- self._build_msk = None # Optional[Callable[[np.ndarray], Optional[torch.Tensor]]]
152
- self._layout_name = None
153
-
154
  self._lazy_init()
155
 
156
  log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
@@ -159,8 +161,8 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
159
  log.info(f"CUDA device: {torch.cuda.get_device_name(idx)}")
160
  self._log_gpu_memory()
161
 
162
- # ---- internals ----
163
  def _log_gpu_memory(self) -> Tuple[float, float]:
 
164
  if torch.cuda.is_available():
165
  idx = self.device.index if isinstance(self.device, torch.device) and self.device.index is not None else 0
166
  try:
@@ -172,8 +174,11 @@ def _log_gpu_memory(self) -> Tuple[float, float]:
172
  log.warning(f"Failed to read GPU memory: {e}")
173
  return 0.0, 0.0
174
 
 
 
 
175
  def _lazy_init(self) -> None:
176
- """Import and initialize the MatAnyone InferenceCore and choose API mode."""
177
  try:
178
  from matanyone.inference.inference_core import InferenceCore # type: ignore
179
  except ImportError as e:
@@ -187,20 +192,43 @@ def _lazy_init(self) -> None:
187
  except TypeError:
188
  self._core = InferenceCore("PeiqingYang/MatAnyone")
189
 
190
- # --- Force reliable frame-by-frame mode (avoid process_video) ---
191
  if hasattr(self._core, "process_frame"):
192
  self._api_mode = "process_frame"
193
  elif hasattr(self._core, "step"):
194
  self._api_mode = "step"
195
  else:
196
  raise MatAnyError(
197
- "MatAnyone build has no frame API (process_frame/step). "
198
- "Cannot proceed safely."
199
  )
200
 
201
  log.info(f"[MATANY] API mode forced to: {self._api_mode} (video-mode disabled)")
 
 
 
 
202
  self._initialized = True
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  def _maybe_amp(self):
205
  enabled = (self.device.type == "cuda")
206
  if self.precision == "fp32":
@@ -210,6 +238,9 @@ def _maybe_amp(self):
210
  # auto
211
  return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
212
 
 
 
 
213
  def _validate_input_frame(self, frame: np.ndarray) -> None:
214
  if not isinstance(frame, np.ndarray):
215
  raise MatAnyError(f"Frame must be numpy.ndarray, got {type(frame)}")
@@ -218,125 +249,118 @@ def _validate_input_frame(self, frame: np.ndarray) -> None:
218
  if frame.ndim != 3 or frame.shape[2] != 3:
219
  raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
220
 
221
- def _core_call(self, img_t: torch.Tensor, mask_t: Optional[torch.Tensor]):
222
- if self._api_mode == "step":
223
- return self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
224
- elif self._api_mode == "process_frame":
225
- return self._core.process_frame(img_t, mask_t) if mask_t is not None else self._core.process_frame(img_t)
226
- raise MatAnyError("Internal error: unknown API mode")
227
-
228
- # ---- builders for probing ----
229
- def _mk_builder_bchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
230
- def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
231
- chw = _to_chw01(frame_bgr)
232
- return torch.from_numpy(chw).unsqueeze(0).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [1,3,H,W]
233
- def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
234
- return torch.from_numpy(seed_hw).unsqueeze(0).unsqueeze(0).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [1,1,H,W]
235
- return "BCHW+B1HW", b_img, b_msk
236
-
237
- def _mk_builder_bchw_nomask(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
238
- def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
239
- chw = _to_chw01(frame_bgr)
240
- return torch.from_numpy(chw).unsqueeze(0).contiguous().to(self.device, dtype=torch.float32, non_blocking=True)
241
- def b_msk(_: np.ndarray) -> Optional[torch.Tensor]:
242
- return None
243
- return "BCHW+None", b_img, b_msk
244
-
245
- def _mk_builder_btchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
246
- def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
247
- chw = _to_chw01(frame_bgr)
248
- return torch.from_numpy(chw).unsqueeze(0).unsqueeze(1).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [1,1,3,H,W]
249
- def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
250
- return torch.from_numpy(seed_hw).unsqueeze(0).unsqueeze(0).unsqueeze(0).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [1,1,1,H,W]
251
- return "BTCHW+B1THW", b_img, b_msk
252
-
253
- def _mk_builder_chw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
254
- def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
255
- chw = _to_chw01(frame_bgr)
256
- return torch.from_numpy(chw).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [3,H,W]
257
- def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
258
- return torch.from_numpy(seed_hw).unsqueeze(0).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [1,H,W]
259
- return "CHW+1HW", b_img, b_msk
260
-
261
- def _mk_builder_hwc(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
262
- def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
263
- hwc = _to_hwc01(frame_bgr)
264
- return torch.from_numpy(hwc).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [H,W,3]
265
- def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
266
- return torch.from_numpy(seed_hw).contiguous().to(self.device, dtype=torch.float32, non_blocking=True) # [H,W]
267
- return "HWC+HW", b_img, b_msk
268
-
269
- def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
270
  """
271
- Returns alpha matte as 2D np.float32 in [0,1].
272
- - On first frame, try several (image,mask) layout combos and remember the winner.
273
- - On later frames, use the recorded builders (mask is None).
274
  """
275
  self._validate_input_frame(frame_bgr)
276
 
277
- # Later frames: use the memorized builders
278
- if self._build_img is not None and not is_first:
279
- img_t = self._build_img(frame_bgr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  with torch.no_grad(), self._maybe_amp():
281
- out = self._core_call(img_t, None)
282
- alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
283
- else np.asarray(out, dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  if alpha_np.max() > 1.0:
285
  alpha_np = alpha_np / 255.0
286
- alpha_np = np.squeeze(alpha_np)
287
- if alpha_np.ndim != 2:
288
- raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
289
- return alpha_np.astype(np.float32)
290
-
291
- # First frame: probe combos
292
- attempts = [
293
- self._mk_builder_bchw(), # [1,3,H,W] + [1,1,H,W]
294
- self._mk_builder_bchw_nomask(), # [1,3,H,W] + None
295
- self._mk_builder_btchw(), # [1,1,3,H,W] + [1,1,1,H,W]
296
- self._mk_builder_chw(), # [3,H,W] + [1,H,W]
297
- self._mk_builder_hwc(), # [H,W,3] + [H,W]
298
- ]
299
-
300
- last_err = None
301
- for name, mk_img, mk_msk in attempts:
302
- try:
303
- img_t = mk_img(frame_bgr)
304
- mask_t = None
305
- if seed_hw is not None:
306
- mask_t = mk_msk(seed_hw)
307
-
308
- log.info(f"[MATANY] Trying layout: {name} | img.shape={tuple(img_t.shape)}"
309
- f"{'' if mask_t is None else ' mask.shape=' + str(tuple(mask_t.shape))}")
310
-
311
- with torch.no_grad(), self._maybe_amp():
312
- out = self._core_call(img_t, mask_t)
313
-
314
- # success → remember builders for subsequent frames
315
- self._build_img = mk_img
316
- # after first frame, we won't pass mask anymore
317
- self._build_msk = mk_msk
318
- self._layout_name = name
319
- log.info(f"[MATANY] Selected layout: {name}")
320
-
321
- alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
322
- else np.asarray(out, dtype=np.float32)
323
- if alpha_np.max() > 1.0:
324
- alpha_np = alpha_np / 255.0
325
- alpha_np = np.squeeze(alpha_np)
326
- if alpha_np.ndim != 2:
327
- raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
328
- return alpha_np.astype(np.float32)
329
 
330
- except Exception as e:
331
- last_err = e
332
- log.warning(f"[MATANY] Layout attempt failed ({name}): {e}")
333
 
334
- snap = _cuda_snapshot(self.device)
335
- raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
336
 
337
- # -----------------------------
338
- # Public API
339
- # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  def process_stream(
341
  self,
342
  video_path: Path,
@@ -345,7 +369,7 @@ def process_stream(
345
  progress_cb: Optional[Callable] = None,
346
  ) -> Tuple[Path, Path]:
347
  """
348
- Process a video with MatAnyone.
349
 
350
  Returns:
351
  (alpha_path, fg_path)
@@ -376,74 +400,99 @@ def process_stream(
376
  log.info(f"[MATANY] {video_path.name}: {N} frames {W}x{H} @ {fps:.2f} fps")
377
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
378
 
379
- # Writers (alpha as BGR grayscale for broad mp4v compatibility)
380
  alpha_path = out_dir / "alpha.mp4"
381
  fg_path = out_dir / "fg.mp4"
382
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
383
-
384
- cap = cv2.VideoCapture(str(video_path))
385
- if not cap.isOpened():
386
- raise MatAnyError(f"Failed to open video for reading: {video_path}")
387
-
388
- alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True) # isColor=True
389
- fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
390
- if not alpha_writer.isOpened() or not fg_writer.isOpened():
391
- raise MatAnyError("Failed to initialize VideoWriter(s)")
392
-
393
- # Optional seed mask for first frame
394
- seed_hw = None
395
- if seed_mask_path is not None:
396
- seed_hw = _read_mask_hw(Path(seed_mask_path), (H, W))
397
-
398
- idx = 0
399
- last_tick = time.time()
400
- start = time.time()
401
 
402
  try:
403
- while True:
404
- ret, frame = cap.read()
405
- if not ret:
406
- break
407
-
408
- alpha_hw = self._run_frame(frame, seed_hw if idx == 0 else None, is_first=(idx == 0))
409
-
410
- # Compose outputs
411
- alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
412
- alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
413
- fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
414
-
415
- alpha_writer.write(alpha_bgr)
416
- fg_writer.write(fg_bgr)
417
-
418
- idx += 1
419
- # progress & ETA
420
- if N > 0 and (idx % max(5, N // 100) == 0 or (time.time() - last_tick) > 2.0):
421
- elapsed = time.time() - start
422
- prog = idx / max(1, N)
423
- eta_s = (elapsed / prog) * (1.0 - prog) if prog > 0 else 0.0
424
- if eta_s > 3600:
425
- eta = f"{eta_s/3600:.1f} h"
426
- elif eta_s > 60:
427
- eta = f"{eta_s/60:.1f} m"
428
- else:
429
- eta = f"{eta_s:.0f} s"
430
- fps_run = idx / elapsed if elapsed > 0 else 0.0
431
- gpu_tail = ""
432
- if torch.cuda.is_available():
433
- idx_dev = self.device.index if self.device.index is not None else 0
434
- mem_a = torch.cuda.memory_allocated(idx_dev) / 1024**2
435
- mem_r = torch.cuda.memory_reserved(idx_dev) / 1024**2
436
- gpu_tail = f" | GPU {mem_a:.0f}/{mem_r:.0f}MB"
437
- _emit_progress(progress_cb, min(0.99, prog), f"Frame {idx}/{N} • {fps_run:.1f} FPS • ETA {eta}{gpu_tail}")
438
- last_tick = time.time()
439
-
440
- # finalize
441
- _validate_nonempty(alpha_path)
442
- _validate_nonempty(fg_path)
443
- total = time.time() - start
444
- fps_run = idx / total if total > 0 else 0.0
445
- _emit_progress(progress_cb, 1.0, f"Complete! {idx} frames at {fps_run:.1f} FPS")
446
- return alpha_path, fg_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  except Exception as e:
449
  msg = f"Error during video processing: {e}"
@@ -452,20 +501,7 @@ def process_stream(
452
  msg += f" | {_cuda_snapshot(self.device)}"
453
  _emit_progress(progress_cb, -1, msg)
454
  raise MatAnyError(msg) from e
455
- finally:
456
- try:
457
- if cap and hasattr(cap, "isOpened") and cap.isOpened():
458
- cap.release()
459
- except Exception:
460
- pass
461
- try:
462
- if alpha_writer:
463
- alpha_writer.release()
464
- except Exception:
465
- pass
466
- try:
467
- if fg_writer:
468
- fg_writer.release()
469
- except Exception:
470
- pass
471
- _safe_empty_cache()
 
1
  #!/usr/bin/env python3
2
+ # =============================================================================
3
+ # MatAnyone Adapter (streaming, API-agnostic) — with chapter markers
4
+ # =============================================================================
5
  """
 
 
6
  - Supports multiple MatAnyone variants:
7
  * frame API: core.step(image[, mask]) or core.process_frame(image, mask)
8
  * video API: core.process_video(video_path[, mask_path]) [DISABLED BY DEFAULT]
9
  - Streams frames: no full-video-in-RAM.
10
+ - Emits alpha.mp4 (grayscale-as-BGR for compatibility) and fg.mp4 (RGB-on-black).
11
  - Validates outputs and raises MatAnyError on failure (so pipeline can fallback).
12
 
13
  I/O conventions:
 
18
  Requires: OpenCV, Torch, NumPy
19
  """
20
 
21
+ # =============================================================================
22
+ # CHAPTER 0 — Imports & logging
23
+ # =============================================================================
24
  from __future__ import annotations
25
+
26
  import os
27
  import cv2
28
  import time
29
  import shutil
 
30
  import logging
31
  import numpy as np
32
+ import torch
33
+
34
  from pathlib import Path
35
  from typing import Optional, Callable, Tuple, List
36
 
37
  log = logging.getLogger(__name__)
38
 
39
 
40
+ # =============================================================================
41
+ # CHAPTER 1 — Small utilities
42
+ # =============================================================================
43
  def _emit_progress(cb, pct: float, msg: str):
44
+ """Route progress to callback (supports new 2-arg and legacy 1-arg styles)."""
45
  if not cb:
46
  return
47
  try:
48
+ cb(pct, msg) # preferred 2-arg
49
  except TypeError:
50
  try:
51
+ cb(msg) # legacy 1-arg
52
  except TypeError:
53
  pass
54
 
 
59
 
60
 
61
  def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
62
+ """Human-friendly GPU memory snapshot."""
63
  if not torch.cuda.is_available():
64
  return "CUDA: N/A"
65
  idx = 0
 
72
 
73
 
74
  def _safe_empty_cache():
75
+ """Synchronize and empty CUDA cache if present (best-effort)."""
76
  if torch.cuda.is_available():
77
  try:
78
  torch.cuda.synchronize()
 
82
 
83
 
84
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
85
+ """Read mask, convert to float32 [0,1], resize to target (H,W)."""
86
  if not Path(mask_path).exists():
87
  raise MatAnyError(f"Seed mask not found: {mask_path}")
88
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
 
92
  if mask.shape[:2] != (H, W):
93
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
94
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
95
+ return maskf
 
 
 
 
 
 
 
96
 
97
 
98
def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
    """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
    rgb01 = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return np.transpose(rgb01, (2, 0, 1))  # (3, H, W)
104
 
105
 
106
  def _validate_nonempty(file_path: Path) -> None:
107
+ """Ensure output file exists and is non-empty."""
108
  if not file_path.exists() or file_path.stat().st_size == 0:
109
  raise MatAnyError(f"Output file missing/empty: {file_path}")
110
 
111
 
112
  def _select_matany_mode(core) -> str:
113
  """
114
+ Inspect available APIs.
115
  Priority: process_video > process_frame > step
116
+ (Note: we still force frame mode in _lazy_init; this helper is used by chunk helper.)
117
  """
118
  if hasattr(core, "process_video") and callable(getattr(core, "process_video")):
119
  return "process_video"
 
124
  raise MatAnyError("No supported MatAnyone API on core (process_video/process_frame/step).")
125
 
126
 
127
+ # =============================================================================
128
+ # CHAPTER 2 — Main session
129
+ # =============================================================================
130
  class MatAnyoneSession:
131
  """
132
  Unified, streaming wrapper over MatAnyone variants.
 
136
  -> returns (alpha_path, fg_path)
137
  """
138
 
139
+ # -------------------------------------------------------------------------
140
+ # 2.1 — Init & device
141
+ # -------------------------------------------------------------------------
142
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
143
  """
144
  Args:
145
  device: 'cuda', 'cpu', 'cuda:0', etc. If None, auto-detects CUDA.
146
  precision: 'auto' | 'fp32' | 'fp16'
147
  """
148
+ self.device = torch.device(device) if device else (
149
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
150
+ )
151
  self.precision = precision.lower()
152
  self.use_fp16 = (self.precision == "fp16") or (self.precision == "auto" and self.device.type == "cuda")
153
  self._core = None
154
  self._api_mode = None
155
  self._initialized = False
 
 
 
 
 
 
156
  self._lazy_init()
157
 
158
  log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
 
161
  log.info(f"CUDA device: {torch.cuda.get_device_name(idx)}")
162
  self._log_gpu_memory()
163
 
 
164
  def _log_gpu_memory(self) -> Tuple[float, float]:
165
+ """Log current GPU memory usage (MB)."""
166
  if torch.cuda.is_available():
167
  idx = self.device.index if isinstance(self.device, torch.device) and self.device.index is not None else 0
168
  try:
 
174
  log.warning(f"Failed to read GPU memory: {e}")
175
  return 0.0, 0.0
176
 
177
+ # -------------------------------------------------------------------------
178
+ # 2.2 — Lazy init of MatAnyone core & API selection + API probe
179
+ # -------------------------------------------------------------------------
180
  def _lazy_init(self) -> None:
181
+ """Import and initialize the MatAnyone InferenceCore, choose API mode, and probe capabilities."""
182
  try:
183
  from matanyone.inference.inference_core import InferenceCore # type: ignore
184
  except ImportError as e:
 
192
  except TypeError:
193
  self._core = InferenceCore("PeiqingYang/MatAnyone")
194
 
195
+ # ---- Force reliable frame-by-frame mode (avoid process_video by default)
196
  if hasattr(self._core, "process_frame"):
197
  self._api_mode = "process_frame"
198
  elif hasattr(self._core, "step"):
199
  self._api_mode = "step"
200
  else:
201
  raise MatAnyError(
202
+ "MatAnyone build has no frame API (process_frame/step). Cannot proceed safely."
 
203
  )
204
 
205
  log.info(f"[MATANY] API mode forced to: {self._api_mode} (video-mode disabled)")
206
+
207
+ # Probe & log exactly what APIs exist (and process_video signature if available)
208
+ self._probe_api_support()
209
+
210
  self._initialized = True
211
 
212
+ def _probe_api_support(self) -> None:
213
+ """Log which APIs the installed MatAnyone exposes + best-effort signature for process_video."""
214
+ core = self._core
215
+ have = {
216
+ "process_video": hasattr(core, "process_video") and callable(getattr(core, "process_video", None)),
217
+ "process_frame": hasattr(core, "process_frame") and callable(getattr(core, "process_frame", None)),
218
+ "step": hasattr(core, "step") and callable(getattr(core, "step", None)),
219
+ }
220
+ log.info(f"[MATANY] API availability: {have}")
221
+ if have["process_video"]:
222
+ try:
223
+ import inspect
224
+ sig = inspect.signature(core.process_video) # type: ignore[attr-defined]
225
+ log.info(f"[MATANY] process_video signature: {sig}")
226
+ except Exception as e:
227
+ log.info(f"[MATANY] process_video signature probe failed: {e}")
228
+
229
+ # -------------------------------------------------------------------------
230
+ # 2.3 — Autocast policy
231
+ # -------------------------------------------------------------------------
232
  def _maybe_amp(self):
233
  enabled = (self.device.type == "cuda")
234
  if self.precision == "fp32":
 
238
  # auto
239
  return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
240
 
241
+ # -------------------------------------------------------------------------
242
+ # 2.4 — Frame validation & core call
243
+ # -------------------------------------------------------------------------
244
  def _validate_input_frame(self, frame: np.ndarray) -> None:
245
  if not isinstance(frame, np.ndarray):
246
  raise MatAnyError(f"Frame must be numpy.ndarray, got {type(frame)}")
 
249
  if frame.ndim != 3 or frame.shape[2] != 3:
250
  raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
251
 
252
+ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  """
254
+ Run a single frame through MatAnyone.
255
+ Returns: alpha matte as 2D np.float32 in [0,1].
 
256
  """
257
  self._validate_input_frame(frame_bgr)
258
 
259
+ # Image -> CHW float32 [0,1], then torch on device
260
+ img_chw = _to_chw01(frame_bgr) # (3,H,W) float32
261
+ img_t = torch.from_numpy(img_chw).to(self.device)
262
+
263
+ # Optional seed mask on first frame: expect HW float32 [0,1]
264
+ mask_t = None
265
+ if is_first and seed_1hw is not None:
266
+ if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
267
+ seed_hw = seed_1hw[0]
268
+ elif seed_1hw.ndim == 2:
269
+ seed_hw = seed_1hw
270
+ else:
271
+ raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
272
+ mask_t = torch.from_numpy(seed_hw).to(self.device)
273
+
274
+ # Dispatch into the selected frame API
275
+ try:
276
  with torch.no_grad(), self._maybe_amp():
277
+ if self._api_mode == "step":
278
+ out = self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
279
+ elif self._api_mode == "process_frame":
280
+ out = self._core.process_frame(img_t, mask_t)
281
+ else:
282
+ raise MatAnyError("Internal error: _run_frame used in non-frame mode")
283
+ except torch.cuda.OutOfMemoryError as e:
284
+ snap = _cuda_snapshot(self.device)
285
+ self._log_gpu_memory()
286
+ raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
287
+ except RuntimeError as e:
288
+ # If it’s a CUDA-side runtime issue, annotate with snapshot
289
+ if "CUDA" in str(e):
290
+ snap = _cuda_snapshot(self.device)
291
+ self._log_gpu_memory()
292
+ raise MatAnyError(f"CUDA runtime error: {e} | {snap}") from e
293
+ raise MatAnyError(f"Runtime error: {e}") from e
294
+ except Exception as e:
295
+ raise MatAnyError(f"Processing failed: {e}") from e
296
+
297
+ # Normalize to pure 2D numpy [0,1]
298
+ if isinstance(out, torch.Tensor):
299
+ alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy()
300
+ else:
301
+ alpha_np = np.asarray(out, dtype=np.float32)
302
  if alpha_np.max() > 1.0:
303
  alpha_np = alpha_np / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ alpha_np = np.squeeze(alpha_np)
306
+ if alpha_np.ndim != 2:
307
+ raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
308
 
309
+ return alpha_np.astype(np.float32)
 
310
 
311
+ # -------------------------------------------------------------------------
312
+ # 2.5 — process_video harvesting (kept for completeness; not used in forced frame mode)
313
+ # -------------------------------------------------------------------------
314
+ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
315
+ """
316
+ Accepts varied return types from MatAnyone.process_video and produces
317
+ (alpha.mp4, fg.mp4) inside out_dir. Strategy: prefer path returns; fallback glob.
318
+ If backend returns arrays only, we raise (cannot reconstruct FG here).
319
+ """
320
+ alpha_mp4 = out_dir / "alpha.mp4"
321
+ fg_mp4 = out_dir / "fg.mp4"
322
+
323
+ # Dict style: look for common keys
324
+ if isinstance(res, dict):
325
+ cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
326
+ cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
327
+ moved = 0
328
+ if cand_alpha and Path(cand_alpha).exists():
329
+ shutil.copy2(cand_alpha, alpha_mp4); moved += 1
330
+ if cand_fg and Path(cand_fg).exists():
331
+ shutil.copy2(cand_fg, fg_mp4); moved += 1
332
+ if moved == 2:
333
+ return alpha_mp4, fg_mp4
334
+
335
+ # Tuple/list of paths
336
+ if isinstance(res, (list, tuple)) and len(res) >= 1:
337
+ paths = [Path(x) for x in res if isinstance(x, (str, Path))]
338
+ if paths:
339
+ alpha_candidates = [p for p in paths if p.exists() and ("alpha" in p.name or "matte" in p.name)]
340
+ fg_candidates = [p for p in paths if p.exists() and ("fg" in p.name or "fore" in p.name)]
341
+ if alpha_candidates and fg_candidates:
342
+ shutil.copy2(alpha_candidates[0], alpha_mp4)
343
+ shutil.copy2(fg_candidates[0], fg_mp4)
344
+ return alpha_mp4, fg_mp4
345
+
346
+ # Fallback: glob common dirs
347
+ search_dirs = [Path.cwd(), out_dir, Path("results"), Path("result"), Path("output"), Path("outputs")]
348
+ hits: List[Path] = []
349
+ for d in search_dirs:
350
+ if d.exists():
351
+ hits.extend(list(d.rglob(f"*{base}*.*")))
352
+ alpha_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("alpha" in p.name or "matte" in p.name)]
353
+ fg_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("fg" in p.name or "fore" in p.name)]
354
+ if alpha_candidates and fg_candidates:
355
+ shutil.copy2(alpha_candidates[0], alpha_mp4)
356
+ shutil.copy2(fg_candidates[0], fg_mp4)
357
+ return alpha_mp4, fg_mp4
358
+
359
+ raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
360
+
361
+ # -------------------------------------------------------------------------
362
+ # 2.6 — Public API: process_stream
363
+ # -------------------------------------------------------------------------
364
  def process_stream(
365
  self,
366
  video_path: Path,
 
369
  progress_cb: Optional[Callable] = None,
370
  ) -> Tuple[Path, Path]:
371
  """
372
+ Process a video with MatAnyone (frame-by-frame path enforced by default).
373
 
374
  Returns:
375
  (alpha_path, fg_path)
 
400
  log.info(f"[MATANY] {video_path.name}: {N} frames {W}x{H} @ {fps:.2f} fps")
401
  _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
402
 
 
403
  alpha_path = out_dir / "alpha.mp4"
404
  fg_path = out_dir / "fg.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  try:
407
+ # -----------------------------
408
+ # Frame-by-frame streaming path
409
+ # -----------------------------
410
+ _emit_progress(progress_cb, 0.10, f"Using {self._api_mode} (frame-by-frame)")
411
+ cap = cv2.VideoCapture(str(video_path))
412
+ if not cap.isOpened():
413
+ raise MatAnyError(f"Failed to open video for reading: {video_path}")
414
+
415
+ # Writers (alpha as BGR grayscale for broad mp4v compatibility)
416
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
417
+ alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True) # isColor=True
418
+ fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
419
+ if not alpha_writer.isOpened() or not fg_writer.isOpened():
420
+ raise MatAnyError("Failed to initialize VideoWriter(s)")
421
+
422
+ # Optional seed mask (resized to video HxW, normalized to [0,1])
423
+ seed_1hw = None
424
+ if seed_mask_path is not None:
425
+ seed_1hw = _read_mask_hw(Path(seed_mask_path), (H, W))
426
+
427
+ idx = 0
428
+ last_tick = time.time()
429
+ start = time.time()
430
+
431
+ try:
432
+ while True:
433
+ ret, frame = cap.read()
434
+ if not ret:
435
+ break
436
+
437
+ current_mask = seed_1hw if idx == 0 else None
438
+ alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
439
+
440
+ # Compose outputs
441
+ alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
442
+ alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
443
+ # alpha_hw already [0,1]
444
+ fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
445
+
446
+ alpha_writer.write(alpha_bgr)
447
+ fg_writer.write(fg_bgr)
448
+
449
+ idx += 1
450
+ # progress & ETA
451
+ if N > 0 and (idx % max(5, N // 100) == 0 or (time.time() - last_tick) > 2.0):
452
+ elapsed = time.time() - start
453
+ prog = idx / max(1, N)
454
+ eta_s = (elapsed / prog) * (1.0 - prog) if prog > 0 else 0.0
455
+ if eta_s > 3600:
456
+ eta = f"{eta_s/3600:.1f} h"
457
+ elif eta_s > 60:
458
+ eta = f"{eta_s/60:.1f} m"
459
+ else:
460
+ eta = f"{eta_s:.0f} s"
461
+ fps_run = idx / elapsed if elapsed > 0 else 0.0
462
+ gpu_tail = ""
463
+ if torch.cuda.is_available():
464
+ idx_dev = self.device.index if self.device.index is not None else 0
465
+ mem_a = torch.cuda.memory_allocated(idx_dev) / 1024**2
466
+ mem_r = torch.cuda.memory_reserved(idx_dev) / 1024**2
467
+ gpu_tail = f" | GPU {mem_a:.0f}/{mem_r:.0f}MB"
468
+ _emit_progress(progress_cb, min(0.99, prog), f"Frame {idx}/{N} • {fps_run:.1f} FPS • ETA {eta}{gpu_tail}")
469
+ last_tick = time.time()
470
+
471
+ # finalize
472
+ _validate_nonempty(alpha_path)
473
+ _validate_nonempty(fg_path)
474
+ total = time.time() - start
475
+ fps_run = idx / total if total > 0 else 0.0
476
+ _emit_progress(progress_cb, 1.0, f"Complete! {idx} frames at {fps_run:.1f} FPS")
477
+ return alpha_path, fg_path
478
+
479
+ finally:
480
+ try:
481
+ if cap and hasattr(cap, "isOpened") and cap.isOpened():
482
+ cap.release()
483
+ except Exception:
484
+ pass
485
+ try:
486
+ if alpha_writer:
487
+ alpha_writer.release()
488
+ except Exception:
489
+ pass
490
+ try:
491
+ if fg_writer:
492
+ fg_writer.release()
493
+ except Exception:
494
+ pass
495
+ _safe_empty_cache()
496
 
497
  except Exception as e:
498
  msg = f"Error during video processing: {e}"
 
501
  msg += f" | {_cuda_snapshot(self.device)}"
502
  _emit_progress(progress_cb, -1, msg)
503
  raise MatAnyError(msg) from e
504
+
505
+ # =============================================================================
506
+ # END OF FILE
507
+ # =============================================================================