MogensR commited on
Commit
ee1b711
·
1 Parent(s): e295279
Files changed (1) hide show
  1. models/matanyone_loader.py +295 -832
models/matanyone_loader.py CHANGED
@@ -2,11 +2,11 @@
2
  """
3
  MatAnyone Adapter (streaming, API-agnostic)
4
  -------------------------------------------
5
- - Works with multiple MatAnyone variants:
6
- - frame API: core.step(image[, mask]) or session.process_frame(image, mask)
7
- - video API: process_video(frames, mask) (falls back to chunking)
8
  - Streams frames: no full-video-in-RAM.
9
- - Emits alpha.mp4 (grayscale) and fg.mp4 (RGB) as it goes.
10
  - Validates outputs and raises MatAnyError on failure (so pipeline can fallback).
11
 
12
  I/O conventions:
@@ -21,18 +21,21 @@
21
  import os
22
  import cv2
23
  import sys
24
- import json
25
- import math
26
  import time
 
 
27
  import torch
28
  import logging
29
- import tempfile
30
  import numpy as np
31
  from pathlib import Path
32
- from typing import Optional, Callable, Tuple, Union
33
 
34
  log = logging.getLogger(__name__)
35
 
 
 
 
 
36
  def _emit_progress(cb, pct: float, msg: str):
37
  if not cb:
38
  return
@@ -42,85 +45,24 @@ def _emit_progress(cb, pct: float, msg: str):
42
  try:
43
  cb(msg) # legacy 1-arg
44
  except TypeError:
45
- pass # ignore if cb is incompatible
 
46
 
47
  class MatAnyError(RuntimeError):
48
  """Custom exception for MatAnyone processing errors."""
49
  pass
50
 
51
 
52
- def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
53
- """
54
- frames_bgr_np: list or np.ndarray of shape [N,H,W,3], dtype=uint8, BGR
55
- Returns torch tensor [N,3,H,W] on device, normalized to 0..1
56
- """
57
- if isinstance(frames_bgr_np, list):
58
- frames_bgr_np = np.stack(frames_bgr_np, axis=0)
59
- frames_rgb = frames_bgr_np[..., ::-1].copy(order="C") # BGR->RGB
60
- pin = torch.from_numpy(frames_rgb).pin_memory() # [N,H,W,3]
61
- t = pin.permute(0, 3, 1, 2).contiguous().to(device, non_blocking=True)
62
- t = t.to(dtype=dtype) / 255.0
63
- return t # [N,3,H,W]
64
-
65
-
66
- def _select_matany_mode(core):
67
- """Pick best available API."""
68
- if hasattr(core, "process_frame"):
69
- return "process_frame"
70
- if hasattr(core, "_process_tensor_video"):
71
- return "_process_tensor_video"
72
- if hasattr(core, "step"):
73
- return "step"
74
- raise MatAnyError("MatAnyone core has no supported API (process_frame/_process_tensor_video/step).")
75
-
76
-
77
- def _matany_run(core, mode, frames_04chw, seed_1hw=None, use_fp16=False):
78
- """
79
- Returns (alpha [N,1,H,W], fg [N,3,H,W]) on current device.
80
- """
81
- with torch.no_grad():
82
- if mode == "process_frame":
83
- alphas, fgs = [], []
84
- for i in range(frames_04chw.shape[0]):
85
- f = frames_04chw[i:i+1] # [1,3,H,W]
86
- if seed_1hw is not None and seed_1hw.ndim == 3:
87
- a, fg = core.process_frame(f, seed_1hw.unsqueeze(0))
88
- else:
89
- a, fg = core.process_frame(f)
90
- alphas.append(a) # [1,1,H,W]
91
- fgs.append(fg) # [1,3,H,W]
92
- alpha = torch.cat(alphas, dim=0)
93
- fg = torch.cat(fgs, dim=0)
94
- return alpha, fg
95
-
96
- elif mode == "_process_tensor_video":
97
- # Many repos expect float32 for this path
98
- return core._process_tensor_video(frames_04chw.float(), seed_1hw)
99
-
100
- elif mode == "step":
101
- alphas, fgs = [], []
102
- for i in range(frames_04chw.shape[0]):
103
- f = frames_04chw[i:i+1]
104
- if i == 0 and seed_1hw is not None:
105
- a, fg = core.step(f, seed_1hw)
106
- else:
107
- a, fg = core.step(f)
108
- alphas.append(a)
109
- fgs.append(fg)
110
- alpha = torch.cat(alphas, dim=0)
111
- fg = torch.cat(fgs, dim=0)
112
- return alpha, fg
113
-
114
- raise MatAnyError(f"Unsupported MatAnyone mode: {mode}")
115
-
116
-
117
- def _cuda_snapshot():
118
  if not torch.cuda.is_available():
119
  return "CUDA: N/A"
120
- i = torch.cuda.current_device()
121
- return (f"device={i}, name={torch.cuda.get_device_name(i)}, "
122
- f"alloc={torch.cuda.memory_allocated(i)/1e9:.2f}GB, "
123
- f"reserved={torch.cuda.memory_reserved(i)/1e9:.2f}GB")
 
 
 
124
 
125
 
126
  def _safe_empty_cache():
@@ -132,99 +74,6 @@ def _safe_empty_cache():
132
  torch.cuda.empty_cache()
133
 
134
 
135
- def _to_uint8_cpu(alpha_n1hw, fg_n3hw):
136
- alpha_cpu = (alpha_n1hw.clamp(0, 1) * 255.0).byte().squeeze(1).contiguous().cpu().numpy() # [N,H,W]
137
- fg_cpu = (fg_n3hw.clamp(0, 1) * 255.0).byte().permute(0, 2, 3, 1).contiguous().cpu().numpy() # [N,H,W,3] RGB
138
- return alpha_cpu, fg_cpu
139
-
140
-
141
- def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
142
- """
143
- Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
144
- CHW tensor on device using pinned memory + non_blocking copies.
145
- """
146
- if isinstance(frames_bgr_np, list):
147
- frames_bgr_np = np.stack(frames_bgr_np, axis=0) # [N,H,W,3]
148
- # BGR -> RGB
149
- frames_rgb = frames_bgr_np[..., ::-1].copy(order="C")
150
- # to torch
151
- pin = torch.from_numpy(frames_rgb).pin_memory() # uint8 [N,H,W,3]
152
- # NCHW and normalize
153
- t = pin.permute(0, 3, 1, 2).contiguous().to(device, non_blocking=True)
154
- t = t.to(dtype=dtype) / 255.0
155
- return t # [N,3,H,W]
156
-
157
-
158
- def _select_matany_mode(core):
159
- """
160
- Pick the best-available MatAnyone API at runtime.
161
- Priority: process_frame > _process_tensor_video > step
162
- """
163
- if hasattr(core, "process_frame"):
164
- return "process_frame"
165
- if hasattr(core, "_process_tensor_video"):
166
- return "_process_tensor_video"
167
- if hasattr(core, "step"):
168
- return "step"
169
- raise MatAnyError("No supported MatAnyone API on core (process_frame/_process_tensor_video/step).")
170
-
171
-
172
- def _matany_run(core, mode, frames_04chw, seed_1hw=None):
173
- """
174
- Dispatch into the selected API. All tensors are on device.
175
- Returns (alpha_1nhw, fg_n3hw) where alpha is [N,1,H,W], fg [N,3,H,W].
176
- """
177
- with torch.no_grad():
178
- if mode == "process_frame":
179
- alphas, fgs = [], []
180
- # process_frame usually wants per-frame tensors in [1,3,H,W]
181
- for i in range(frames_04chw.shape[0]):
182
- f = frames_04chw[i:i+1] # [1,3,H,W]
183
- if seed_1hw is not None and seed_1hw.ndim == 3:
184
- a, fg = core.process_frame(f, seed_1hw.unsqueeze(0))
185
- else:
186
- a, fg = core.process_frame(f)
187
- alphas.append(a) # [1,1,H,W]
188
- fgs.append(fg) # [1,3,H,W]
189
- alpha = torch.cat(alphas, dim=0)
190
- fg = torch.cat(fgs, dim=0)
191
- return alpha, fg
192
-
193
- elif mode == "_process_tensor_video":
194
- return core._process_tensor_video(frames_04chw.float(), seed_1hw)
195
-
196
- elif mode == "step":
197
- alphas, fgs = [], []
198
- for i in range(frames_04chw.shape[0]):
199
- f = frames_04chw[i:i+1]
200
- if i == 0 and seed_1hw is not None:
201
- a, fg = core.step(f, seed_1hw)
202
- else:
203
- a, fg = core.step(f)
204
- alphas.append(a)
205
- fgs.append(fg)
206
- alpha = torch.cat(alphas, dim=0)
207
- fg = torch.cat(fgs, dim=0)
208
- return alpha, fg
209
-
210
- raise MatAnyError(f"Unsupported mode: {mode}")
211
-
212
-
213
- def _safe_empty_cache():
214
- if torch.cuda.is_available():
215
- torch.cuda.synchronize()
216
- torch.cuda.empty_cache()
217
-
218
-
219
- def _cuda_snapshot():
220
- if not torch.cuda.is_available():
221
- return "CUDA: N/A"
222
- i = torch.cuda.current_device()
223
- return (f"device={i}, name={torch.cuda.get_device_name(i)}, "
224
- f"alloc={torch.cuda.memory_allocated(i)/1e9:.2f}GB, "
225
- f"reserved={torch.cuda.memory_reserved(i)/1e9:.2f}GB")
226
-
227
-
228
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
229
  """Read mask image, convert to float32 [0,1], resize to target (H,W)."""
230
  if not Path(mask_path).exists():
@@ -241,267 +90,197 @@ def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
241
 
242
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
243
  """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
244
- # OpenCV gives BGR; convert to RGB
245
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
246
  rgbf = rgb.astype(np.float32) / 255.0
247
  chw = np.transpose(rgbf, (2, 0, 1)) # C,H,W
248
  return chw
249
 
250
 
251
- def _mask_to_1hw(mask_hw01: np.ndarray) -> np.ndarray:
252
- """HW float32 [0,1] -> 1HW float32 [0,1]."""
253
- return np.expand_dims(mask_hw01, axis=0)
254
-
255
-
256
- def _ensure_dir(p: Path) -> None:
257
- p.mkdir(parents=True, exist_ok=True)
258
-
259
-
260
- def _open_video_writers(out_dir: Path, fps: float, size: Tuple[int, int]) -> Tuple[cv2.VideoWriter, cv2.VideoWriter]:
261
- """Return (alpha_writer, fg_writer). size=(W,H)."""
262
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
263
- W, H = size
264
- alpha_path = str(out_dir / "alpha.mp4")
265
- fg_path = str(out_dir / "fg.mp4")
266
- # alpha: single channel => write as 3-channel grayscale for broad compatibility
267
- alpha_writer = cv2.VideoWriter(alpha_path, fourcc, fps, (W, H), True)
268
- fg_writer = cv2.VideoWriter(fg_path, fourcc, fps, (W, H), True)
269
- if not alpha_writer.isOpened() or not fg_writer.isOpened():
270
- raise MatAnyError("Failed to open VideoWriter for alpha/fg outputs.")
271
- return alpha_writer, fg_writer
272
-
273
-
274
  def _validate_nonempty(file_path: Path) -> None:
275
  if not file_path.exists() or file_path.stat().st_size == 0:
276
  raise MatAnyError(f"Output file missing/empty: {file_path}")
277
 
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  class MatAnyoneSession:
280
  """
281
  Unified, streaming wrapper over MatAnyone variants.
282
 
283
  Public:
284
  - process_stream(video_path, seed_mask_path, out_dir, progress_cb)
 
285
 
286
- Detects API once at init:
287
- - prefers frame-wise: core.step(img[, mask]) OR session.process_frame(img, mask)
288
- - else uses video-wise: process_video(frames, mask) with chunk fallback
289
  """
290
 
291
  def __init__(self, device: Optional[str] = None, precision: str = "auto"):
292
- """Initialize MatAnyoneSession with optional device and precision settings.
293
-
294
  Args:
295
- device: Device to run on (e.g., 'cuda', 'cpu', 'cuda:0'). If None, auto-detects CUDA.
296
- precision: One of 'auto', 'fp32', or 'fp16'. 'auto' uses fp16 if CUDA is available.
297
  """
298
  self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
299
  self.precision = precision.lower()
 
300
  self._core = None
301
- self._api_mode = None # "step", "process_frame", or "process_video"
302
- self._frame_times = []
303
- self._start_time = 0.0
304
- self._gpu_mem_allocated = 0.0
305
- self._gpu_mem_cached = 0.0
306
  self._lazy_init()
307
-
308
- # Log initialization
309
- log.info(f"Initialized MatAnyoneSession on {self.device} with precision {self.precision}")
310
  if torch.cuda.is_available():
311
- log.info(f"CUDA device: {torch.cuda.get_device_name(self.device)}")
 
312
  self._log_gpu_memory()
313
 
314
- def _log_gpu_memory(self) -> None:
315
- """Log current GPU memory usage."""
316
  if torch.cuda.is_available():
 
317
  try:
318
- allocated = torch.cuda.memory_allocated(self.device) / 1024**2
319
- cached = torch.cuda.memory_reserved(self.device) / 1024**2
320
- log.info(f"GPU Memory - Allocated: {allocated:.1f}MB, Cached: {cached:.1f}MB")
321
- return allocated, cached
322
  except Exception as e:
323
- log.warning(f"Failed to get GPU memory info: {e}")
324
  return 0.0, 0.0
325
-
326
  def _lazy_init(self) -> None:
327
- """Lazy initialization of the MatAnyone inference core."""
328
  try:
329
  from matanyone.inference.inference_core import InferenceCore # type: ignore
330
  except ImportError as e:
331
- raise MatAnyError(f"Failed to import MatAnyone: {e}. Please ensure it's installed correctly.")
332
  except Exception as e:
333
  raise MatAnyError(f"Unexpected error during MatAnyone import: {e}")
334
 
335
- # Log GPU info
336
- if torch.cuda.is_available():
337
- log.info(f"[GPU] CUDA is available. Device: {torch.cuda.get_device_name(0)}")
338
- log.info(f"[GPU] Memory allocated: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
339
- log.info(f"[GPU] Memory cached: {torch.cuda.memory_reserved()/1024**2:.1f}MB")
340
- else:
341
- log.warning("[GPU] CUDA is not available. Using CPU (this will be slow!)")
342
-
343
- # Try zero-arg first, then repo-id variant
344
  try:
345
  self._core = InferenceCore()
346
  except TypeError:
347
- try:
348
- self._core = InferenceCore("PeiqingYang/MatAnyone")
349
- except Exception as e:
350
- raise MatAnyError(f"MatAnyone InferenceCore init failed: {e}")
351
-
352
- core = self._core
353
 
354
- # MODE SELECTION (prefer video) can be forced by env flags
355
  force_video = os.getenv("MATANY_FORCE_VIDEO", "1") == "1"
356
  force_step = os.getenv("MATANY_FORCE_STEP", "0") == "1"
357
 
358
- if force_step and hasattr(core, "step") and callable(getattr(core, "step")):
359
- self._api_mode = "step"
360
- elif force_video and hasattr(core, "process_video") and callable(getattr(core, "process_video")):
361
- self._api_mode = "process_video"
362
- elif hasattr(core, "process_video") and callable(getattr(core, "process_video")):
363
- self._api_mode = "process_video"
364
- elif hasattr(core, "process_frame") and callable(getattr(core, "process_frame")):
365
- self._api_mode = "process_frame"
366
- elif hasattr(core, "step") and callable(getattr(core, "step")):
367
  self._api_mode = "step"
368
  else:
369
- raise MatAnyError("No supported MatAnyone API found (process_video/process_frame/step).")
 
 
 
 
370
 
371
- log.info(f"[MATANY] Initialized on {self.device} | API mode = {self._api_mode}")
372
  self._initialized = True
373
 
374
  def _maybe_amp(self):
375
- # Use new API to silence deprecation warning
376
  if self.precision == "fp32":
377
  return torch.amp.autocast(device_type="cuda", enabled=False)
378
  if self.precision == "fp16":
379
- return torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
380
- return torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available())
 
381
 
382
  def _validate_input_frame(self, frame: np.ndarray) -> None:
383
- """Validate input frame dimensions and type."""
384
  if not isinstance(frame, np.ndarray):
385
- raise MatAnyError(f"Frame must be a numpy array, got {type(frame)}")
386
  if frame.dtype != np.uint8:
387
  raise MatAnyError(f"Frame must be uint8, got {frame.dtype}")
388
  if frame.ndim != 3 or frame.shape[2] != 3:
389
- raise MatAnyError(f"Frame must be HWC with 3 channels, got shape {frame.shape}")
390
 
391
- def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool = False) -> np.ndarray:
392
  """
393
- Process a single frame through MatAnyone to generate an alpha matte.
394
- Uses strict 3D image (CHW) and 2D mask (HW) formats to avoid dimension issues.
395
-
396
- Args:
397
- frame_bgr: Input frame in BGR format (H,W,3) uint8
398
- seed_1hw: Optional mask in 1HW or HW format (float32 [0,1])
399
- is_first: Whether this is the first frame in the sequence
400
-
401
- Returns:
402
- Alpha matte in HW format (float32 [0,1])
403
-
404
- Raises:
405
- MatAnyError: If processing fails or invalid input is provided
406
  """
407
- # --- Prepare image tensor (CHW float32 [0,1]) ---
408
- img_chw = _to_chw01(frame_bgr) # (3,H,W) float32
 
409
  img_t = torch.from_numpy(img_chw).to(self.device)
410
-
411
- # --- Prepare mask tensor (HW float32 [0,1]) ---
412
  mask_t = None
413
  if is_first and seed_1hw is not None:
414
  if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
415
- seed_hw = seed_1hw[0] # (H,W)
416
  elif seed_1hw.ndim == 2:
417
  seed_hw = seed_1hw
418
  else:
419
  raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
420
- mask_t = torch.from_numpy(seed_hw).to(self.device) # (H,W)
421
-
422
- # --- Validate shapes ---
423
- if img_t.ndim != 3 or img_t.shape[0] != 3:
424
- raise MatAnyError(f"img_t must be CHW; got {tuple(img_t.shape)}")
425
- if mask_t is not None and mask_t.ndim != 2:
426
- raise MatAnyError(f"mask_t must be HW; got {tuple(mask_t.shape)}")
427
 
428
- # --- Process with MatAnyone ---
429
- frame_start_time = time.time()
430
  try:
431
  with torch.no_grad(), self._maybe_amp():
432
  if self._api_mode == "step":
433
- alpha = self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
434
  elif self._api_mode == "process_frame":
435
- alpha = self._core.process_frame(img_t, mask_t)
436
  else:
437
- raise MatAnyError("Internal error: Invalid API mode")
438
-
439
- # Log performance metrics
440
- frame_time = time.time() - frame_start_time
441
- self._frame_times.append(frame_time)
442
- if len(self._frame_times) > 10: # Keep last 10 frame times
443
- self._frame_times.pop(0)
444
-
445
- # Log GPU memory every 10 frames
446
- if len(self._frame_times) % 10 == 0:
447
- self._log_gpu_memory()
448
-
449
- return alpha
450
-
451
- except torch.cuda.OutOfMemoryError:
452
  self._log_gpu_memory()
453
- raise MatAnyError("CUDA out of memory. Try reducing the input resolution or batch size.")
454
  except RuntimeError as e:
455
  if "CUDA" in str(e):
 
456
  self._log_gpu_memory()
457
- raise MatAnyError(f"CUDA error: {e}")
458
- raise MatAnyError(f"Runtime error: {e}")
459
  except Exception as e:
460
- raise MatAnyError(f"Processing failed: {e}")
 
 
 
461
 
462
- # --- Process output ---
463
- # Convert to numpy and ensure correct shape/range
464
- if isinstance(alpha, torch.Tensor):
465
- alpha_np = alpha.detach().float().clamp(0, 1).squeeze().cpu().numpy()
466
  else:
467
- alpha_np = np.asarray(alpha, dtype=np.float32)
468
  if alpha_np.max() > 1.0:
469
- alpha_np = (alpha_np / 255.0).clip(0, 1)
470
-
471
- # Ensure 2D output (H,W)
472
  alpha_np = np.squeeze(alpha_np)
473
  if alpha_np.ndim != 2:
474
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
475
-
476
- return alpha_np
477
 
478
  def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
479
  """
480
  Accepts varied return types from MatAnyone.process_video and produces
481
- (alpha.mp4, fg.mp4) inside out_dir. Strategies:
482
- - If res is a sequence of alpha arrays/tensors write our own videos.
483
- - If res is dict/tuple of paths copy/rename.
484
- - Else: glob typical output dirs for files matching base.
485
  """
486
- # Case A: sequence of masks
487
- import torch, numpy as np, cv2, glob, shutil
488
-
489
- def _as_np(a):
490
- if isinstance(a, torch.Tensor):
491
- a = a.detach().float().cpu().numpy()
492
- a = np.asarray(a)
493
- if a.ndim == 3 and a.shape[0] in (1,3): # (C,H,W) → prefer HW
494
- a = np.squeeze(a) if a.shape[0] == 1 else np.mean(a, axis=0)
495
- if a.max() > 1.0:
496
- a = a / 255.0
497
- return a.clip(0,1).astype(np.float32)
498
-
499
  alpha_mp4 = out_dir / "alpha.mp4"
500
  fg_mp4 = out_dir / "fg.mp4"
501
 
502
- # If we got arrays/tensors: we can't reconstruct FG without original frames here,
503
- # so prefer path-returning flows. If needed, you can extend this to re-read frames
504
- # and blend. For now, try to detect paths first.
505
  if isinstance(res, dict):
506
  cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
507
  cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
@@ -510,13 +289,13 @@ def _as_np(a):
510
  shutil.copy2(cand_alpha, alpha_mp4); moved += 1
511
  if cand_fg and Path(cand_fg).exists():
512
  shutil.copy2(cand_fg, fg_mp4); moved += 1
513
- if moved == 2: return alpha_mp4, fg_mp4
 
514
 
 
515
  if isinstance(res, (list, tuple)) and len(res) >= 1:
516
- # Heuristic: assume list/tuple of file paths
517
  paths = [Path(x) for x in res if isinstance(x, (str, Path))]
518
  if paths:
519
- # Pick best matches by name
520
  alpha_candidates = [p for p in paths if p.exists() and ("alpha" in p.name or "matte" in p.name)]
521
  fg_candidates = [p for p in paths if p.exists() and ("fg" in p.name or "fore" in p.name)]
522
  if alpha_candidates and fg_candidates:
@@ -524,23 +303,25 @@ def _as_np(a):
524
  shutil.copy2(fg_candidates[0], fg_mp4)
525
  return alpha_mp4, fg_mp4
526
 
527
- # As last resort, glob common dirs created by the lib
528
  search_dirs = [Path.cwd(), out_dir, Path("results"), Path("result"), Path("output"), Path("outputs")]
529
- hits = []
530
  for d in search_dirs:
531
  if d.exists():
532
  hits.extend(list(d.rglob(f"*{base}*.*")))
533
- # choose best alpha/fg
534
  alpha_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("alpha" in p.name or "matte" in p.name)]
535
  fg_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("fg" in p.name or "fore" in p.name)]
536
  if alpha_candidates and fg_candidates:
537
- import shutil
538
  shutil.copy2(alpha_candidates[0], alpha_mp4)
539
  shutil.copy2(fg_candidates[0], fg_mp4)
540
  return alpha_mp4, fg_mp4
541
 
542
- raise MatAnyError("MatAnyone.process_video did not yield discoverable outputs.")
 
543
 
 
 
 
544
  def process_stream(
545
  self,
546
  video_path: Path,
@@ -548,520 +329,209 @@ def process_stream(
548
  out_dir: Optional[Path] = None,
549
  progress_cb: Optional[Callable] = None,
550
  ) -> Tuple[Path, Path]:
551
- """Process video stream with MatAnyone.
552
-
553
- Args:
554
- video_path: Input video file path (must exist and be readable)
555
- seed_mask_path: Optional seed mask image (grayscale, same size as video)
556
- out_dir: Output directory (default: video_path.parent)
557
- progress_cb: Callback for progress updates (signature: (float, str) or (str,))
558
 
559
  Returns:
560
- Tuple of (alpha_path, fg_path) output video paths
561
 
562
  Raises:
563
- MatAnyError: If processing fails for any reason
564
- FileNotFoundError: If input files are not found
565
- ValueError: If input parameters are invalid
566
  """
567
- # Input validation
568
  if not video_path.exists():
569
  raise FileNotFoundError(f"Input video not found: {video_path}")
570
-
571
- if seed_mask_path is not None and not seed_mask_path.exists():
572
- raise FileNotFoundError(f"Seed mask not found: {seed_mask_path}")
573
-
574
  if out_dir is None:
575
  out_dir = video_path.parent
576
-
577
  out_dir = Path(out_dir)
578
  out_dir.mkdir(parents=True, exist_ok=True)
579
-
580
- # Initialize progress tracking
581
- self._frame_times = []
582
- self._start_time = time.time()
583
- _emit_progress(progress_cb, 0.0, "Initializing video processing...")
584
 
585
- # Log GPU status
586
- if torch.cuda.is_available():
587
- _emit_progress(progress_cb, 0.01, "GPU detected, initializing CUDA...")
588
- else:
589
- _emit_progress(progress_cb, 0.01, "No GPU detected, using CPU (slower)...")
590
 
591
- cap = cv2.VideoCapture(str(video_path))
592
- if not cap.isOpened():
 
593
  raise MatAnyError(f"Failed to open video: {video_path}")
 
 
 
 
 
594
 
595
- N = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
596
- fps = cap.get(cv2.CAP_PROP_FPS)
597
- W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
598
- H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
599
- cap.release()
600
 
601
- log.info(f"[MATANY] Processing {N} frames ({W}x{H} @ {fps:.1f}fps) from {video_path}")
602
- _emit_progress(progress_cb, 0.05, f"Processing {N} frames ({W}x{H} @ {fps:.1f}fps)")
 
 
603
 
604
  try:
605
  if self._api_mode == "process_video":
606
- # --- PATH-BASED CALL (this wheel expects a video path, not tensors) ---
607
- _emit_progress(progress_cb, 0.1, "Using MatAnyone video mode (GPU-accelerated)")
608
-
609
- # Log before starting video processing
610
  if torch.cuda.is_available():
611
- mem_alloc, _ = self._log_gpu_memory()
612
- _emit_progress(progress_cb, 0.12, f"GPU memory before processing: {mem_alloc:.1f}MB")
613
-
614
- # Some builds accept (video_path, seed_mask_path), others just (video_path)
615
- try:
616
- _emit_progress(progress_cb, 0.15, "Starting video processing with mask...")
617
- res = self._core.process_video(
618
- str(video_path),
619
- str(seed_mask_path) if seed_mask_path is not None else None
620
- )
621
- except TypeError as e:
622
- if "takes 2 positional arguments but 3 were given" in str(e):
623
- _emit_progress(progress_cb, 0.15, "Starting video processing without mask...")
624
- res = self._core.process_video(str(video_path))
625
- else:
626
- raise
627
-
628
- # Log after processing
629
- if torch.cuda.is_available():
630
- _emit_progress(progress_cb, 0.9, f"Processing complete. GPU memory used: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
631
- else:
632
- _emit_progress(progress_cb, 0.9, "Processing complete.")
633
-
634
- # Normalize output files
635
- _emit_progress(progress_cb, 0.95, "Finalizing output files...")
636
- alpha_path, fg_path = self._harvest_process_video_output(res, out_dir, base=video_path.stem)
637
- _validate_nonempty(alpha_path)
638
- _validate_nonempty(fg_path)
639
-
640
- _emit_progress(progress_cb, 1.0, "Processing complete!")
641
- return alpha_path, fg_path
642
-
643
- else:
644
- # Frame-by-frame (preferred)
645
- log.info(f"[MATANY] Using frame-by-frame mode: {self._api_mode}")
646
- _emit_progress(progress_cb, 0.1, f"Using {self._api_mode} mode (frame-by-frame)")
647
-
648
- cap = cv2.VideoCapture(str(video_path))
649
- alpha_path = out_dir / "alpha.mp4"
650
- fg_path = out_dir / "fg.mp4"
651
-
652
- # Initialize video writers
653
- _emit_progress(progress_cb, 0.12, "Initializing video writers...")
654
- alpha_writer = cv2.VideoWriter(
655
- str(alpha_path),
656
- cv2.VideoWriter_fourcc(*'mp4v'),
657
- fps,
658
- (W, H),
659
- isColor=False
660
- )
661
- fg_writer = cv2.VideoWriter(
662
- str(fg_path),
663
- cv2.VideoWriter_fourcc(*'mp4v'),
664
- fps,
665
- (W, H),
666
- isColor=True
667
- )
668
-
669
- if not alpha_writer.isOpened() or not fg_writer.isOpened():
670
- raise MatAnyError("Failed to initialize video writers")
671
 
 
672
  try:
673
- # Load seed mask if provided
674
- seed_1hw = None
675
- if seed_mask_path is not None:
676
- seed_1hw = _read_mask_hw(seed_mask_path, (H, W))
677
-
678
- idx = 0
679
- last_progress_update = 0
680
- frame_times = []
681
- start_time = time.time()
682
-
683
- while True:
684
- ret, frame = cap.read()
685
- if not ret:
686
- break
687
-
688
- frame_start_time = time.time()
689
-
690
- # Update progress more frequently (every 1% or 5 frames, whichever is more frequent)
691
- current_progress = (idx / N) if N > 0 else 0.0
692
- if idx % max(5, N//100) == 0 or time.time() - last_progress_update > 2.0:
693
- # Calculate progress metrics
694
- elapsed = time.time() - start_time
695
- if idx > 0 and current_progress > 0:
696
- # Calculate ETA
697
- eta_seconds = (elapsed / current_progress) * (1 - current_progress)
698
- if eta_seconds > 3600:
699
- eta_str = f"{eta_seconds/3600:.1f} hours"
700
- elif eta_seconds > 60:
701
- eta_str = f"{eta_seconds/60:.1f} minutes"
702
- else:
703
- eta_str = f"{eta_seconds:.0f} seconds"
704
-
705
- # Calculate processing speed
706
- fps = idx / elapsed if elapsed > 0 else 0
707
-
708
- # Add GPU memory info if available
709
- gpu_info = ""
710
- if torch.cuda.is_available():
711
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
712
- mem_cached = torch.cuda.memory_reserved() / 1024**2
713
- gpu_info = f" | GPU: {mem_alloc:.1f}/{mem_cached:.1f}MB"
714
-
715
- status = (f"Processing frame {idx+1}/{N} (ETA: {eta_str}, "
716
- f"{fps:.1f} FPS{gpu_info}")
717
- _emit_progress(progress_cb, min(0.99, current_progress), status)
718
- last_progress_update = time.time()
719
-
720
- # Process frame
721
- log.debug(f"[MATANY] Processing frame {idx+1}/{N}")
722
- # Only pass seed mask on first frame
723
- current_mask = seed_1hw if idx == 0 else None
724
- alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
725
-
726
- # Calculate frame processing time
727
- frame_time = time.time() - frame_start_time
728
- frame_times.append(frame_time)
729
- if len(frame_times) > 10: # Keep last 10 frame times for average
730
- frame_times.pop(0)
731
-
732
- # Log GPU memory usage occasionally
733
- if idx % 50 == 0 and torch.cuda.is_available():
734
- log.info(f"[GPU] Memory allocated: {torch.cuda.memory_allocated()/1024**2:.1f}MB, "
735
- f"Cached: {torch.cuda.memory_reserved()/1024**2:.1f}MB, "
736
- f"Avg frame time: {sum(frame_times)/len(frame_times)*1000:.1f}ms")
737
-
738
- # Compose output frames
739
- alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
740
- alpha_rgb = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
741
- fg_bgr = (frame.astype(np.float32) * (alpha_hw[..., None] / 255.0)).astype(np.uint8)
742
-
743
- # Write outputs
744
- alpha_writer.write(alpha_rgb)
745
- fg_writer.write(fg_bgr)
746
- idx += 1
747
-
748
- except Exception as e:
749
- # Log detailed error information
750
- error_msg = f"Error processing frame {idx+1}/{N}: {str(e)}"
751
- log.error(error_msg, exc_info=True)
752
-
753
- # Add GPU memory info if available
754
- if torch.cuda.is_available():
755
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
756
- mem_cached = torch.cuda.memory_reserved() / 1024**2
757
- error_msg += (f"\nGPU Memory - Allocated: {mem_alloc:.1f}MB, "
758
- f"Cached: {mem_cached:.1f}MB")
759
-
760
- # Add frame processing stats
761
- if frame_times:
762
- avg_time = sum(frame_times) / len(frame_times)
763
- error_msg += f"\nAvg frame time: {avg_time*1000:.1f}ms"
764
-
765
- _emit_progress(progress_cb, -1, f"ERROR: {error_msg}")
766
- raise MatAnyError(error_msg) from e
767
-
768
- finally:
769
- # Cleanup resources
770
- try:
771
- if 'cap' in locals() and hasattr(cap, 'isOpened') and cap.isOpened():
772
- cap.release()
773
- if 'alpha_writer' in locals() and alpha_writer is not None:
774
- if hasattr(alpha_writer, 'isOpened') and alpha_writer.isOpened():
775
- alpha_writer.release()
776
- if 'fg_writer' in locals() and fg_writer is not None:
777
- if hasattr(fg_writer, 'isOpened') and fg_writer.isOpened():
778
- fg_writer.release()
779
-
780
- # Log final stats
781
- total_time = time.time() - start_time
782
- fps = idx / total_time if total_time > 0 else 0
783
-
784
- # Log GPU memory info if available
785
- gpu_info = ""
786
- if torch.cuda.is_available():
787
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
788
- mem_cached = torch.cuda.memory_reserved() / 1024**2
789
- gpu_info = f"\nGPU Memory - Allocated: {mem_alloc:.1f}MB, Cached: {mem_cached:.1f}MB"
790
-
791
- log.info(
792
- f"[MATANY] Processed {idx} frames in {total_time:.1f}s ({fps:.1f} FPS){gpu_info}"
793
- )
794
-
795
- # Validate outputs
796
- _validate_nonempty(alpha_path)
797
- _validate_nonempty(fg_path)
798
-
799
- # Final progress update
800
- _emit_progress(
801
- progress_cb,
802
- 1.0,
803
- f"Complete! Processed {idx} frames at {fps:.1f} FPS{gpu_info}"
804
- )
805
-
806
- return alpha_path, fg_path
807
-
808
- except Exception as e:
809
- error_msg = f"Error during cleanup: {str(e)}"
810
- log.error(error_msg, exc_info=True)
811
- _emit_progress(progress_cb, -1, f"CLEANUP ERROR: {error_msg}")
812
- raise MatAnyError(error_msg) from e
813
-
814
- except Exception as e:
815
- error_msg = f"Error during video processing: {str(e)}"
816
- log.error(error_msg, exc_info=True)
817
- if torch.cuda.is_available():
818
- error_msg += f"\nGPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB allocated"
819
- _emit_progress(progress_cb, -1, error_msg)
820
- raise MatAnyError(error_msg) from e
821
- else:
822
- # Frame-by-frame (preferred)
823
- log.info(f"[MATANY] Using frame-by-frame mode: {self._api_mode}")
824
- _emit_progress(progress_cb, 0.1, f"Using {self._api_mode} mode (frame-by-frame)")
825
-
826
  cap = cv2.VideoCapture(str(video_path))
827
- alpha_path = out_dir / "alpha.mp4"
828
- fg_path = out_dir / "fg.mp4"
829
-
830
- # Initialize video writers
831
- _emit_progress(progress_cb, 0.12, "Initializing video writers...")
832
- alpha_writer = cv2.VideoWriter(
833
- str(alpha_path),
834
- cv2.VideoWriter_fourcc(*'mp4v'),
835
- fps,
836
- (W, H),
837
- isColor=False
838
- )
839
- fg_writer = cv2.VideoWriter(
840
- str(fg_path),
841
- cv2.VideoWriter_fourcc(*'mp4v'),
842
- fps,
843
- (W, H),
844
- isColor=True
845
- )
846
-
847
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
848
- raise MatAnyError("Failed to initialize video writers")
 
 
 
 
 
 
 
 
 
849
 
850
  try:
851
- # Load seed mask if provided
852
- seed_1hw = None
853
- if seed_mask_path is not None:
854
- seed_1hw = _read_mask_hw(seed_mask_path, (H, W))
855
-
856
- idx = 0
857
- last_progress_update = 0
858
- frame_times = []
859
- start_time = time.time()
860
-
861
- try:
862
- while True:
863
- ret, frame = cap.read()
864
- if not ret:
865
- break
866
-
867
- frame_start_time = time.time()
868
-
869
- # Update progress more frequently (every 1% or 5 frames, whichever is more frequent)
870
- current_progress = (idx / N) if N > 0 else 0.0
871
- if idx % max(5, N//100) == 0 or time.time() - last_progress_update > 2.0:
872
- # Calculate progress metrics
873
- elapsed = time.time() - start_time
874
- if idx > 0 and current_progress > 0:
875
- # Calculate ETA
876
- eta_seconds = (elapsed / current_progress) * (1 - current_progress)
877
- if eta_seconds > 3600:
878
- eta_str = f"{eta_seconds/3600:.1f} hours"
879
- elif eta_seconds > 60:
880
- eta_str = f"{eta_seconds/60:.1f} minutes"
881
- else:
882
- eta_str = f"{eta_seconds:.0f} seconds"
883
-
884
- # Calculate processing speed
885
- fps = idx / elapsed if elapsed > 0 else 0
886
-
887
- # Add GPU memory info if available
888
- gpu_info = ""
889
- if torch.cuda.is_available():
890
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
891
- mem_cached = torch.cuda.memory_reserved() / 1024**2
892
- gpu_info = f" | GPU: {mem_alloc:.1f}/{mem_cached:.1f}MB"
893
-
894
- status = (f"Processing frame {idx+1}/{N} (ETA: {eta_str}, "
895
- f"{fps:.1f} FPS{gpu_info}")
896
- _emit_progress(progress_cb, min(0.99, current_progress), status)
897
- last_progress_update = time.time()
898
-
899
- # Process frame
900
- log.debug(f"[MATANY] Processing frame {idx+1}/{N}")
901
- # Only pass seed mask on first frame
902
- current_mask = seed_1hw if idx == 0 else None
903
- alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
904
-
905
- # Calculate frame processing time
906
- frame_time = time.time() - frame_start_time
907
- frame_times.append(frame_time)
908
- if len(frame_times) > 10: # Keep last 10 frame times for average
909
- frame_times.pop(0)
910
-
911
- # Log GPU memory usage occasionally
912
- if idx % 50 == 0 and torch.cuda.is_available():
913
- log.info(f"[GPU] Memory allocated: {torch.cuda.memory_allocated()/1024**2:.1f}MB, "
914
- f"Cached: {torch.cuda.memory_reserved()/1024**2:.1f}MB, "
915
- f"Avg frame time: {sum(frame_times)/len(frame_times)*1000:.1f}ms")
916
-
917
- # Compose output frames
918
- alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
919
- alpha_rgb = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
920
- fg_bgr = (frame.astype(np.float32) * (alpha_hw[..., None] / 255.0)).astype(np.uint8)
921
-
922
- # Write outputs
923
- alpha_writer.write(alpha_rgb)
924
- fg_writer.write(fg_bgr)
925
- idx += 1
926
-
927
- except Exception as e:
928
- # Log detailed error information
929
- error_msg = f"Error processing frame {idx+1}/{N}: {str(e)}"
930
- log.error(error_msg, exc_info=True)
931
-
932
- # Add GPU memory info if available
933
- if torch.cuda.is_available():
934
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
935
- mem_cached = torch.cuda.memory_reserved() / 1024**2
936
- error_msg += (f"\nGPU Memory - Allocated: {mem_alloc:.1f}MB, "
937
- f"Cached: {mem_cached:.1f}MB")
938
-
939
- # Add frame processing stats
940
- if self._frame_times:
941
- avg_time = sum(self._frame_times) / len(self._frame_times)
942
- error_msg += f"\nAvg frame time: {avg_time*1000:.1f}ms"
943
-
944
- _emit_progress(progress_cb, -1, f"ERROR: {error_msg}")
945
- raise MatAnyError(error_msg) from e
946
-
947
- finally:
948
- # Cleanup resources
949
- # Cleanup resources in a single finally block
950
- try:
951
- if 'cap' in locals() and cap is not None:
952
- if hasattr(cap, 'isOpened') and cap.isOpened():
953
- cap.release()
954
- if 'alpha_writer' in locals() and alpha_writer is not None:
955
- if hasattr(alpha_writer, 'isOpened') and alpha_writer.isOpened():
956
- alpha_writer.release()
957
- if 'fg_writer' in locals() and fg_writer is not None:
958
- if hasattr(fg_writer, 'isOpened') and fg_writer.isOpened():
959
- fg_writer.release()
960
-
961
- # Log final stats
962
- total_time = time.time() - start_time
963
- fps = idx / total_time if total_time > 0 else 0
964
-
965
- # Log GPU memory info if available
966
- gpu_info = ""
967
  if torch.cuda.is_available():
968
- mem_alloc = torch.cuda.memory_allocated() / 1024**2
969
- mem_cached = torch.cuda.memory_reserved() / 1024**2
970
- gpu_info = f"\nGPU Memory - Allocated: {mem_alloc:.1f}MB, Cached: {mem_cached:.1f}MB"
971
-
972
- log.info(
973
- f"[MATANY] Processed {idx} frames in {total_time:.1f}s ({fps:.1f} FPS){gpu_info}"
974
- )
975
-
976
- # Validate outputs
977
- _validate_nonempty(alpha_path)
978
- _validate_nonempty(fg_path)
979
-
980
- # Final progress update
981
- _emit_progress(
982
- progress_cb,
983
- 1.0,
984
- f"Complete! Processed {idx} frames at {fps:.1f} FPS{gpu_info}"
985
- )
986
-
987
- return alpha_path, fg_path
988
-
989
- except Exception as e:
990
- error_msg = f"Error during cleanup: {str(e)}"
991
- log.error(error_msg, exc_info=True)
992
- _emit_progress(progress_cb, -1, f"CLEANUP ERROR: {error_msg}")
993
- raise MatAnyError(error_msg) from e
994
- finally:
995
- # Ensure all resources are cleaned up
996
- if 'cap' in locals() and cap is not None:
997
- if hasattr(cap, 'release'):
998
- cap.release()
999
- if 'alpha_writer' in locals() and alpha_writer is not None:
1000
- if hasattr(alpha_writer, 'release'):
1001
- alpha_writer.release()
1002
- if 'fg_writer' in locals() and fg_writer is not None:
1003
- if hasattr(fg_writer, 'release'):
1004
- fg_writer.release()
1005
- _safe_empty_cache()
1006
 
1007
- def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
 
 
 
 
 
 
 
 
 
 
 
 
 
1008
  """
1009
- Process an in-memory batch (list of uint8 BGR frames), write results via writers.
1010
- Strong CUDA guards + cleanup.
1011
  """
1012
- device = self.device
1013
- use_fp16 = (device.type == "cuda") and getattr(self, "use_fp16", True)
1014
  mode = _select_matany_mode(self._core)
1015
-
1016
- frames_04chw = None
1017
- alpha_n1hw = None
1018
- fg_n3hw = None
1019
-
1020
- try:
1021
- frames_04chw = _to_device_batch(frames_bgr, device, dtype=torch.float16 if use_fp16 else torch.float32)
1022
-
1023
- if device.type == "cuda":
1024
- stream = torch.cuda.Stream()
1025
- with torch.cuda.stream(stream):
1026
- with torch.autocast(device_type="cuda", enabled=use_fp16):
1027
- alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)
1028
- stream.synchronize()
1029
- else:
1030
- alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)
1031
-
1032
- alpha_cpu, fg_cpu = _to_uint8_cpu(alpha_n1hw, fg_n3hw)
1033
-
1034
- for i in range(alpha_cpu.shape[0]):
1035
- alpha_writer.write(alpha_cpu[i]) # [H,W] uint8
1036
- fg_writer.write(fg_cpu[i][..., ::-1].copy()) # RGB->BGR
1037
-
1038
- if hasattr(self._core, "last_mask"):
1039
- self._last_alpha_1hw = self._core.last_mask
1040
-
1041
- except torch.cuda.OutOfMemoryError as e:
1042
- snap = _cuda_snapshot()
1043
- _safe_empty_cache()
1044
- # Re-raise with context for pipeline to catch
1045
- raise MatAnyError(f"CUDA OOM in _flush_chunk | {snap}") from e
1046
-
1047
- except Exception as e:
1048
- snap = _cuda_snapshot()
1049
- raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e
1050
-
1051
- finally:
1052
- # ensure we release heavy tensors
1053
- try:
1054
- del alpha_n1hw, fg_n3hw, frames_04chw
1055
- except Exception:
1056
- pass
1057
- _safe_empty_cache()
1058
-
1059
- def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
1060
  """
1061
  Buffer frames from iterable and process in chunks.
1062
  On OOM, retry once with half chunk size; otherwise bubble up MatAnyError.
1063
  """
1064
- frames_buf = []
1065
  try:
1066
  for f in frames_iterable:
1067
  frames_buf.append(f)
@@ -1069,10 +539,8 @@ def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chu
1069
  try:
1070
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1071
  frames_buf.clear()
1072
- except torch.cuda.OutOfMemoryError:
1073
- # should be wrapped above, but double-guard
1074
- raise
1075
- except MatAnyError as inner:
1076
  # one-time downshift
1077
  if chunk_size > 4:
1078
  half = max(4, chunk_size // 2)
@@ -1081,19 +549,14 @@ def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chu
1081
  self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
1082
  frames_buf.clear()
1083
  else:
1084
- raise inner
1085
 
1086
  if frames_buf:
1087
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1088
  frames_buf.clear()
1089
 
1090
- except torch.cuda.OutOfMemoryError as e:
1091
- snap = _cuda_snapshot()
1092
- _safe_empty_cache()
1093
- raise MatAnyError(f"CUDA OOM in process_stream outer | {snap}") from e
1094
-
1095
  except Exception as e:
1096
- raise MatAnyError(f"Unexpected error in process_stream: {e}") from e
1097
 
1098
  finally:
1099
  frames_buf.clear()
 
2
  """
3
  MatAnyone Adapter (streaming, API-agnostic)
4
  -------------------------------------------
5
+ - Supports multiple MatAnyone variants:
6
+ * frame API: core.step(image[, mask]) or core.process_frame(image, mask)
7
+ * video API: core.process_video(video_path[, mask_path])
8
  - Streams frames: no full-video-in-RAM.
9
+ - Emits alpha.mp4 (grayscale-as-BGR for compatibility) and fg.mp4 (RGB-on-black) as it goes.
10
  - Validates outputs and raises MatAnyError on failure (so pipeline can fallback).
11
 
12
  I/O conventions:
 
21
  import os
22
  import cv2
23
  import sys
 
 
24
  import time
25
+ import glob
26
+ import shutil
27
  import torch
28
  import logging
 
29
  import numpy as np
30
  from pathlib import Path
31
+ from typing import Optional, Callable, Tuple, List, Union
32
 
33
  log = logging.getLogger(__name__)
34
 
35
+
36
+ # -----------------------------
37
+ # Small utilities
38
+ # -----------------------------
39
  def _emit_progress(cb, pct: float, msg: str):
40
  if not cb:
41
  return
 
45
  try:
46
  cb(msg) # legacy 1-arg
47
  except TypeError:
48
+ pass
49
+
50
 
51
  class MatAnyError(RuntimeError):
52
  """Custom exception for MatAnyone processing errors."""
53
  pass
54
 
55
 
56
+ def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  if not torch.cuda.is_available():
58
  return "CUDA: N/A"
59
+ idx = 0
60
+ if device is not None and isinstance(device, torch.device) and device.index is not None:
61
+ idx = device.index
62
+ name = torch.cuda.get_device_name(idx)
63
+ alloc = torch.cuda.memory_allocated(idx) / 1e9
64
+ resv = torch.cuda.memory_reserved(idx) / 1e9
65
+ return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
66
 
67
 
68
  def _safe_empty_cache():
 
74
  torch.cuda.empty_cache()
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
78
  """Read mask image, convert to float32 [0,1], resize to target (H,W)."""
79
  if not Path(mask_path).exists():
 
90
 
91
def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
    """Convert a BGR uint8 [H,W,3] frame to RGB float32 CHW in [0,1]."""
    as_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    scaled = as_rgb.astype(np.float32) / 255.0
    return np.transpose(scaled, (2, 0, 1))  # HWC -> CHW
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def _validate_nonempty(file_path: Path) -> None:
100
  if not file_path.exists() or file_path.stat().st_size == 0:
101
  raise MatAnyError(f"Output file missing/empty: {file_path}")
102
 
103
 
104
+ def _select_matany_mode(core) -> str:
105
+ """
106
+ Pick the best-available MatAnyone API at runtime.
107
+ Priority: process_video > process_frame > step
108
+ """
109
+ if hasattr(core, "process_video") and callable(getattr(core, "process_video")):
110
+ return "process_video"
111
+ if hasattr(core, "process_frame") and callable(getattr(core, "process_frame")):
112
+ return "process_frame"
113
+ if hasattr(core, "step") and callable(getattr(core, "step")):
114
+ return "step"
115
+ raise MatAnyError("No supported MatAnyone API on core (process_video/process_frame/step).")
116
+
117
+
118
+ # -----------------------------
119
+ # Main session
120
+ # -----------------------------
121
  class MatAnyoneSession:
122
  """
123
  Unified, streaming wrapper over MatAnyone variants.
124
 
125
  Public:
126
  - process_stream(video_path, seed_mask_path, out_dir, progress_cb)
127
+ -> returns (alpha_path, fg_path)
128
 
129
+ Private helper:
130
+ - _process_stream_chunks(frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size)
 
131
  """
132
 
133
    def __init__(self, device: Optional[str] = None, precision: str = "auto"):
        """
        Initialize the session and eagerly load the MatAnyone core.

        Args:
            device: 'cuda', 'cpu', 'cuda:0', etc. If None, auto-detects CUDA.
            precision: 'auto' | 'fp32' | 'fp16'

        Raises:
            MatAnyError: if MatAnyone cannot be imported or exposes no usable API
                (propagated from _lazy_init).
        """
        # Explicit device string wins; otherwise prefer CUDA when present.
        self.device = torch.device(device) if device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.precision = precision.lower()
        # fp16 when explicitly requested, or in 'auto' mode on a CUDA device.
        self.use_fp16 = (self.precision == "fp16") or (self.precision == "auto" and self.device.type == "cuda")
        self._core = None  # MatAnyone InferenceCore, set by _lazy_init
        self._api_mode = None  # "process_video" | "process_frame" | "step"
        self._initialized = False
        self._lazy_init()

        log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
        if torch.cuda.is_available():
            # Un-indexed torch.device('cuda') has index None -> report device 0.
            idx = self.device.index if isinstance(self.device, torch.device) and self.device.index is not None else 0
            log.info(f"CUDA device: {torch.cuda.get_device_name(idx)}")
            self._log_gpu_memory()
152
 
153
+ # ---- internals ----
154
+ def _log_gpu_memory(self) -> Tuple[float, float]:
155
  if torch.cuda.is_available():
156
+ idx = self.device.index if isinstance(self.device, torch.device) and self.device.index is not None else 0
157
  try:
158
+ allocated = torch.cuda.memory_allocated(idx) / 1024**2
159
+ reserved = torch.cuda.memory_reserved(idx) / 1024**2
160
+ log.info(f"GPU Memory - Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB")
161
+ return allocated, reserved
162
  except Exception as e:
163
+ log.warning(f"Failed to read GPU memory: {e}")
164
  return 0.0, 0.0
165
+
166
    def _lazy_init(self) -> None:
        """Import and initialize the MatAnyone InferenceCore and choose API mode.

        Sets self._core, self._api_mode, self._initialized.
        Raises MatAnyError if import fails or no supported API is found.
        """
        try:
            from matanyone.inference.inference_core import InferenceCore  # type: ignore
        except ImportError as e:
            raise MatAnyError(f"Failed to import MatAnyone: {e}. Ensure it's installed and on PYTHONPATH.")
        except Exception as e:
            raise MatAnyError(f"Unexpected error during MatAnyone import: {e}")

        # Some wheels accept zero-arg, some require a repo-id; try both
        try:
            self._core = InferenceCore()
        except TypeError:
            self._core = InferenceCore("PeiqingYang/MatAnyone")

        # Mode selection (env flags can influence)
        force_video = os.getenv("MATANY_FORCE_VIDEO", "1") == "1"
        force_step = os.getenv("MATANY_FORCE_STEP", "0") == "1"

        if force_step and hasattr(self._core, "step"):
            self._api_mode = "step"
        else:
            mode = _select_matany_mode(self._core)
            # NOTE(review): _select_matany_mode already prefers "process_video"
            # when it is callable, so mode != "process_video" with a
            # process_video attribute implies a non-callable attr; this branch
            # looks effectively dead — confirm against the installed wheel.
            if force_video and mode != "process_video" and hasattr(self._core, "process_video"):
                self._api_mode = "process_video"
            else:
                self._api_mode = mode

        log.info(f"[MATANY] API mode selected: {self._api_mode}")
        self._initialized = True
196
 
197
  def _maybe_amp(self):
198
+ enabled = (self.device.type == "cuda")
199
  if self.precision == "fp32":
200
  return torch.amp.autocast(device_type="cuda", enabled=False)
201
  if self.precision == "fp16":
202
+ return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=torch.float16)
203
+ # auto
204
+ return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
205
 
206
  def _validate_input_frame(self, frame: np.ndarray) -> None:
 
207
  if not isinstance(frame, np.ndarray):
208
+ raise MatAnyError(f"Frame must be numpy.ndarray, got {type(frame)}")
209
  if frame.dtype != np.uint8:
210
  raise MatAnyError(f"Frame must be uint8, got {frame.dtype}")
211
  if frame.ndim != 3 or frame.shape[2] != 3:
212
+ raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
213
 
214
+ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
215
  """
216
+ Returns alpha matte as 2D np.float32 in [0,1].
 
 
 
 
 
 
 
 
 
 
 
 
217
  """
218
+ self._validate_input_frame(frame_bgr)
219
+
220
+ img_chw = _to_chw01(frame_bgr) # (3,H,W) float32 [0,1]
221
  img_t = torch.from_numpy(img_chw).to(self.device)
222
+
 
223
  mask_t = None
224
  if is_first and seed_1hw is not None:
225
  if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
226
+ seed_hw = seed_1hw[0]
227
  elif seed_1hw.ndim == 2:
228
  seed_hw = seed_1hw
229
  else:
230
  raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
231
+ mask_t = torch.from_numpy(seed_hw).to(self.device)
 
 
 
 
 
 
232
 
233
+ # dispatch
234
+ frame_start = time.time()
235
  try:
236
  with torch.no_grad(), self._maybe_amp():
237
  if self._api_mode == "step":
238
+ out = self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
239
  elif self._api_mode == "process_frame":
240
+ out = self._core.process_frame(img_t, mask_t)
241
  else:
242
+ raise MatAnyError("Internal error: _run_frame used in non-frame mode")
243
+
244
+ except torch.cuda.OutOfMemoryError as e:
245
+ snap = _cuda_snapshot(self.device)
 
 
 
 
 
 
 
 
 
 
 
246
  self._log_gpu_memory()
247
+ raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
248
  except RuntimeError as e:
249
  if "CUDA" in str(e):
250
+ snap = _cuda_snapshot(self.device)
251
  self._log_gpu_memory()
252
+ raise MatAnyError(f"CUDA runtime error: {e} | {snap}") from e
253
+ raise MatAnyError(f"Runtime error: {e}") from e
254
  except Exception as e:
255
+ raise MatAnyError(f"Processing failed: {e}") from e
256
+ finally:
257
+ # optional: track times / stats (omitted to keep adapter slim)
258
+ pass
259
 
260
+ # Normalize to 2D numpy [0,1]
261
+ if isinstance(out, torch.Tensor):
262
+ alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy()
 
263
  else:
264
+ alpha_np = np.asarray(out, dtype=np.float32)
265
  if alpha_np.max() > 1.0:
266
+ alpha_np = alpha_np / 255.0
 
 
267
  alpha_np = np.squeeze(alpha_np)
268
  if alpha_np.ndim != 2:
269
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
270
+
271
+ return alpha_np.astype(np.float32)
272
 
273
  def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
274
  """
275
  Accepts varied return types from MatAnyone.process_video and produces
276
+ (alpha.mp4, fg.mp4) inside out_dir.
277
+ Strategy: prefer path returns; as a last resort, glob common output dirs.
278
+ NOTE: If backend returns arrays only, we raise (cannot reconstruct FG here).
 
279
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  alpha_mp4 = out_dir / "alpha.mp4"
281
  fg_mp4 = out_dir / "fg.mp4"
282
 
283
+ # Dict style: look for common keys
 
 
284
  if isinstance(res, dict):
285
  cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
286
  cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
 
289
  shutil.copy2(cand_alpha, alpha_mp4); moved += 1
290
  if cand_fg and Path(cand_fg).exists():
291
  shutil.copy2(cand_fg, fg_mp4); moved += 1
292
+ if moved == 2:
293
+ return alpha_mp4, fg_mp4
294
 
295
+ # Tuple/list of paths
296
  if isinstance(res, (list, tuple)) and len(res) >= 1:
 
297
  paths = [Path(x) for x in res if isinstance(x, (str, Path))]
298
  if paths:
 
299
  alpha_candidates = [p for p in paths if p.exists() and ("alpha" in p.name or "matte" in p.name)]
300
  fg_candidates = [p for p in paths if p.exists() and ("fg" in p.name or "fore" in p.name)]
301
  if alpha_candidates and fg_candidates:
 
303
  shutil.copy2(fg_candidates[0], fg_mp4)
304
  return alpha_mp4, fg_mp4
305
 
306
+ # Fallback: glob common dirs
307
  search_dirs = [Path.cwd(), out_dir, Path("results"), Path("result"), Path("output"), Path("outputs")]
308
+ hits: List[Path] = []
309
  for d in search_dirs:
310
  if d.exists():
311
  hits.extend(list(d.rglob(f"*{base}*.*")))
 
312
  alpha_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("alpha" in p.name or "matte" in p.name)]
313
  fg_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("fg" in p.name or "fore" in p.name)]
314
  if alpha_candidates and fg_candidates:
 
315
  shutil.copy2(alpha_candidates[0], alpha_mp4)
316
  shutil.copy2(fg_candidates[0], fg_mp4)
317
  return alpha_mp4, fg_mp4
318
 
319
+ # If we got arrays only, we cannot reconstruct FG here (we'd need to replay frames)
320
+ raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
321
 
322
+ # -----------------------------
323
+ # Public API
324
+ # -----------------------------
325
  def process_stream(
326
  self,
327
  video_path: Path,
 
329
  out_dir: Optional[Path] = None,
330
  progress_cb: Optional[Callable] = None,
331
  ) -> Tuple[Path, Path]:
332
+ """
333
+ Process a video with MatAnyone.
 
 
 
 
 
334
 
335
  Returns:
336
+ (alpha_path, fg_path)
337
 
338
  Raises:
339
+ MatAnyError / FileNotFoundError / ValueError
 
 
340
  """
341
+ video_path = Path(video_path)
342
  if not video_path.exists():
343
  raise FileNotFoundError(f"Input video not found: {video_path}")
 
 
 
 
344
  if out_dir is None:
345
  out_dir = video_path.parent
 
346
  out_dir = Path(out_dir)
347
  out_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
348
 
349
+ _emit_progress(progress_cb, 0.0, "Initializing video processing...")
 
 
 
 
350
 
351
+ # Inspect video
352
+ cap_probe = cv2.VideoCapture(str(video_path))
353
+ if not cap_probe.isOpened():
354
  raise MatAnyError(f"Failed to open video: {video_path}")
355
+ N = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
356
+ fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0
357
+ W = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
358
+ H = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
359
+ cap_probe.release()
360
 
361
+ log.info(f"[MATANY] {video_path.name}: {N} frames {W}x{H} @ {fps:.2f} fps")
362
+ _emit_progress(progress_cb, 0.05, f"Video: {N} frames {W}x{H} @ {fps:.2f} fps")
 
 
 
363
 
364
+ # If full-video API exists, prefer it
365
+ alpha_path = out_dir / "alpha.mp4"
366
+ fg_path = out_dir / "fg.mp4"
367
+ t0 = time.time()
368
 
369
  try:
370
  if self._api_mode == "process_video":
371
+ _emit_progress(progress_cb, 0.10, "Using MatAnyone video mode")
 
 
 
372
  if torch.cuda.is_available():
373
+ self._log_gpu_memory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
+ # Some builds accept (video, mask), some only (video)
376
  try:
377
+ res = self._core.process_video(
378
+ str(video_path),
379
+ str(seed_mask_path) if seed_mask_path is not None else None
380
+ )
381
+ except TypeError as e:
382
+ if "takes 2 positional arguments but 3 were given" in str(e):
383
+ res = self._core.process_video(str(video_path))
384
+ else:
385
+ raise
386
+
387
+ _emit_progress(progress_cb, 0.90, "Processing complete, collecting outputs…")
388
+ alpha_path, fg_path = self._harvest_process_video_output(res, out_dir, base=video_path.stem)
389
+ _validate_nonempty(alpha_path)
390
+ _validate_nonempty(fg_path)
391
+ _emit_progress(progress_cb, 1.0, "Done!")
392
+ return alpha_path, fg_path
393
+
394
+ # -----------------------------
395
+ # Frame-by-frame streaming path
396
+ # -----------------------------
397
+ _emit_progress(progress_cb, 0.10, f"Using {self._api_mode} (frame-by-frame)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  cap = cv2.VideoCapture(str(video_path))
399
+ if not cap.isOpened():
400
+ raise MatAnyError(f"Failed to open video for reading: {video_path}")
401
+
402
+ # Writers (alpha as BGR grayscale for broad mp4v compatibility)
403
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
404
+ alpha_writer = cv2.VideoWriter(str(alpha_path), fourcc, fps, (W, H), True) # isColor=True
405
+ fg_writer = cv2.VideoWriter(str(fg_path), fourcc, fps, (W, H), True)
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  if not alpha_writer.isOpened() or not fg_writer.isOpened():
407
+ raise MatAnyError("Failed to initialize VideoWriter(s)")
408
+
409
+ # Optional seed mask
410
+ seed_1hw = None
411
+ if seed_mask_path is not None:
412
+ seed_1hw = _read_mask_hw(Path(seed_mask_path), (H, W))
413
+
414
+ idx = 0
415
+ last_tick = time.time()
416
+ start = time.time()
417
 
418
  try:
419
+ while True:
420
+ ret, frame = cap.read()
421
+ if not ret:
422
+ break
423
+
424
+ current_mask = seed_1hw if idx == 0 else None
425
+ alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
426
+
427
+ # Compose outputs
428
+ alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
429
+ alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
430
+ # IMPORTANT: alpha_hw already [0,1]
431
+ fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
432
+
433
+ alpha_writer.write(alpha_bgr)
434
+ fg_writer.write(fg_bgr)
435
+
436
+ idx += 1
437
+ # progress & ETA
438
+ if N > 0 and (idx % max(5, N // 100) == 0 or (time.time() - last_tick) > 2.0):
439
+ elapsed = time.time() - start
440
+ prog = idx / max(1, N)
441
+ eta_s = (elapsed / prog) * (1.0 - prog) if prog > 0 else 0.0
442
+ if eta_s > 3600:
443
+ eta = f"{eta_s/3600:.1f} h"
444
+ elif eta_s > 60:
445
+ eta = f"{eta_s/60:.1f} m"
446
+ else:
447
+ eta = f"{eta_s:.0f} s"
448
+ fps_run = idx / elapsed if elapsed > 0 else 0.0
449
+ gpu_tail = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  if torch.cuda.is_available():
451
+ idx_dev = self.device.index if self.device.index is not None else 0
452
+ mem_a = torch.cuda.memory_allocated(idx_dev) / 1024**2
453
+ mem_r = torch.cuda.memory_reserved(idx_dev) / 1024**2
454
+ gpu_tail = f" | GPU {mem_a:.0f}/{mem_r:.0f}MB"
455
+ _emit_progress(progress_cb, min(0.99, prog), f"Frame {idx}/{N} • {fps_run:.1f} FPS • ETA {eta}{gpu_tail}")
456
+ last_tick = time.time()
457
+
458
+ # finalize
459
+ _validate_nonempty(alpha_path)
460
+ _validate_nonempty(fg_path)
461
+ total = time.time() - start
462
+ fps_run = idx / total if total > 0 else 0.0
463
+ _emit_progress(progress_cb, 1.0, f"Complete! {idx} frames at {fps_run:.1f} FPS")
464
+ return alpha_path, fg_path
465
+
466
+ finally:
467
+ try:
468
+ if cap and hasattr(cap, "isOpened") and cap.isOpened():
469
+ cap.release()
470
+ except Exception:
471
+ pass
472
+ try:
473
+ if alpha_writer:
474
+ alpha_writer.release()
475
+ except Exception:
476
+ pass
477
+ try:
478
+ if fg_writer:
479
+ fg_writer.release()
480
+ except Exception:
481
+ pass
482
+ _safe_empty_cache()
 
 
 
 
 
 
483
 
484
+ except Exception as e:
485
+ msg = f"Error during video processing: {e}"
486
+ log.error(msg, exc_info=True)
487
+ if torch.cuda.is_available():
488
+ msg += f" | {_cuda_snapshot(self.device)}"
489
+ _emit_progress(progress_cb, -1, msg)
490
+ raise MatAnyError(msg) from e
491
+
492
+ # -----------------------------
493
+ # Private chunk helper (not used by public API in this file,
494
+ # but available if your pipeline wants to feed frames itself)
495
+ # -----------------------------
496
+ def _flush_chunk(self, frames_bgr: List[np.ndarray], seed_1hw: Optional[np.ndarray],
497
+ alpha_writer: cv2.VideoWriter, fg_writer: cv2.VideoWriter):
498
  """
499
+ Process an in-memory batch (list of uint8 BGR frames) and write results.
500
+ This path assumes a core that can process batches; if not, it falls back per-frame.
501
  """
 
 
502
  mode = _select_matany_mode(self._core)
503
+ # If the core doesn't support tensor-batch processing, go per-frame
504
+ if mode in ("process_frame", "step"):
505
+ for i, frame in enumerate(frames_bgr):
506
+ alpha_hw = self._run_frame(frame, seed_1hw if i == 0 else None, is_first=(i == 0))
507
+ alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
508
+ alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
509
+ fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
510
+ alpha_writer.write(alpha_bgr)
511
+ fg_writer.write(fg_bgr)
512
+ return
513
+
514
+ # If we reach here, assume a tensor-video code path exists (rare in released wheels).
515
+ # For safety we still fallback per-frame because API signatures vary wildly.
516
+ for i, frame in enumerate(frames_bgr):
517
+ alpha_hw = self._run_frame(frame, seed_1hw if i == 0 else None, is_first=(i == 0))
518
+ alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
519
+ alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
520
+ fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
521
+ alpha_writer.write(alpha_bgr)
522
+ fg_writer.write(fg_bgr)
523
+
524
+ def _process_stream_chunks(self,
525
+ frames_iterable,
526
+ seed_1hw: Optional[np.ndarray],
527
+ alpha_writer: cv2.VideoWriter,
528
+ fg_writer: cv2.VideoWriter,
529
+ chunk_size: int = 32):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  """
531
  Buffer frames from iterable and process in chunks.
532
  On OOM, retry once with half chunk size; otherwise bubble up MatAnyError.
533
  """
534
+ frames_buf: List[np.ndarray] = []
535
  try:
536
  for f in frames_iterable:
537
  frames_buf.append(f)
 
539
  try:
540
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
541
  frames_buf.clear()
542
+ except torch.cuda.OutOfMemoryError as e:
543
+ _safe_empty_cache()
 
 
544
  # one-time downshift
545
  if chunk_size > 4:
546
  half = max(4, chunk_size // 2)
 
549
  self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
550
  frames_buf.clear()
551
  else:
552
+ raise MatAnyError(f"CUDA OOM in _process_stream_chunks | {_cuda_snapshot(self.device)}") from e
553
 
554
  if frames_buf:
555
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
556
  frames_buf.clear()
557
 
 
 
 
 
 
558
  except Exception as e:
559
+ raise MatAnyError(f"Unexpected error in _process_stream_chunks: {e}") from e
560
 
561
  finally:
562
  frames_buf.clear()