MogensR committed on
Commit
3cd16f5
·
1 Parent(s): 975ab1f
Files changed (1) hide show
  1. models/matanyone_loader.py +210 -79
models/matanyone_loader.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python3
2
  # =============================================================================
3
- # MatAnyone Adapter (streaming, API-agnostic) β€” with chapter markers
4
  # =============================================================================
5
  """
6
  - Supports multiple MatAnyone variants:
@@ -40,90 +40,152 @@
40
  # =============================================================================
41
  # CHAPTER 1 β€” Small utilities
42
  # =============================================================================
43
- def _emit_progress(cb, pct: float, msg: str):
44
- """Route progress to callback (supports new 2-arg and legacy 1-arg styles)."""
45
- if not cb:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  return
 
 
 
 
47
  try:
48
- cb(pct, msg) # preferred 2-arg
49
- except TypeError:
50
  try:
51
- cb(msg) # legacy 1-arg
52
  except TypeError:
 
 
 
 
 
 
 
 
 
53
  pass
54
 
55
 
 
56
  class MatAnyError(RuntimeError):
57
  """Custom exception for MatAnyone processing errors."""
58
  pass
59
 
60
 
 
61
  def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
62
- """Human-friendly GPU memory snapshot."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if not torch.cuda.is_available():
64
- return "CUDA: N/A"
65
- idx = 0
66
- if device is not None and isinstance(device, torch.device) and device.index is not None:
67
- idx = device.index
68
- name = torch.cuda.get_device_name(idx)
69
- alloc = torch.cuda.memory_allocated(idx) / 1e9
70
- resv = torch.cuda.memory_reserved(idx) / 1e9
71
- return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
72
-
73
-
74
- def _safe_empty_cache():
75
- """Synchronize and empty CUDA cache if present (best-effort)."""
76
- if torch.cuda.is_available():
77
- try:
78
- torch.cuda.synchronize()
79
- except Exception:
80
- pass
81
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
85
- """Read mask, convert to float32 [0,1], resize to target (H,W)."""
86
- if not Path(mask_path).exists():
 
 
 
 
 
 
87
  raise MatAnyError(f"Seed mask not found: {mask_path}")
 
88
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
89
- if mask is None:
90
- raise MatAnyError(f"Failed to read seed mask: {mask_path}")
 
91
  H, W = target_hw
 
 
92
  if mask.shape[:2] != (H, W):
93
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
 
94
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
95
  return maskf
96
 
97
 
98
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
99
- """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
 
 
 
 
100
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
101
  rgbf = rgb.astype(np.float32) / 255.0
102
- chw = np.transpose(rgbf, (2, 0, 1)) # C,H,W
103
  return chw
104
 
105
 
106
- def _validate_nonempty(file_path: Path) -> None:
107
- """Ensure output file exists and is non-empty."""
108
- if not file_path.exists() or file_path.stat().st_size == 0:
109
- raise MatAnyError(f"Output file missing/empty: {file_path}")
110
-
111
-
112
- def _select_matany_mode(core) -> str:
113
- """
114
- Inspect available APIs.
115
- Priority: process_video > process_frame > step
116
- (Note: we still force frame mode in _lazy_init; this helper is used by chunk helper.)
117
- """
118
- if hasattr(core, "process_video") and callable(getattr(core, "process_video")):
119
- return "process_video"
120
- if hasattr(core, "process_frame") and callable(getattr(core, "process_frame")):
121
- return "process_frame"
122
- if hasattr(core, "step") and callable(getattr(core, "step")):
123
- return "step"
124
- raise MatAnyError("No supported MatAnyone API on core (process_video/process_frame/step).")
125
-
126
-
127
  # =============================================================================
128
  # CHAPTER 2 β€” Main session
129
  # =============================================================================
@@ -153,6 +215,7 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
153
  self._core = None
154
  self._api_mode = None
155
  self._initialized = False
 
156
  self._lazy_init()
157
 
158
  log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
@@ -239,7 +302,7 @@ def _maybe_amp(self):
239
  return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
240
 
241
  # -------------------------------------------------------------------------
242
- # 2.4 β€” Frame validation & core call
243
  # -------------------------------------------------------------------------
244
  def _validate_input_frame(self, frame: np.ndarray) -> None:
245
  if not isinstance(frame, np.ndarray):
@@ -249,43 +312,105 @@ def _validate_input_frame(self, frame: np.ndarray) -> None:
249
  if frame.ndim != 3 or frame.shape[2] != 3:
250
  raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
251
 
252
- def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
 
 
 
253
  """
254
- Run a single frame through MatAnyone.
255
- Returns: alpha matte as 2D np.float32 in [0,1].
 
256
  """
257
- self._validate_input_frame(frame_bgr)
258
 
259
- # Image -> CHW float32 [0,1], then torch on device
260
- img_chw = _to_chw01(frame_bgr) # (3,H,W) float32
261
- img_t = torch.from_numpy(img_chw).to(self.device)
262
 
263
- # Optional seed mask on first frame: expect HW float32 [0,1]
264
- mask_t = None
265
  if is_first and seed_1hw is not None:
 
266
  if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
267
  seed_hw = seed_1hw[0]
268
  elif seed_1hw.ndim == 2:
269
  seed_hw = seed_1hw
270
  else:
271
  raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
272
- mask_t = torch.from_numpy(seed_hw).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- # Dispatch into the selected frame API
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  try:
276
  with torch.no_grad(), self._maybe_amp():
277
- if self._api_mode == "step":
278
- out = self._core.step(img_t, mask_t) if mask_t is not None else self._core.step(img_t)
279
- elif self._api_mode == "process_frame":
280
- out = self._core.process_frame(img_t, mask_t)
281
- else:
282
- raise MatAnyError("Internal error: _run_frame used in non-frame mode")
283
  except torch.cuda.OutOfMemoryError as e:
284
  snap = _cuda_snapshot(self.device)
285
  self._log_gpu_memory()
286
  raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
287
  except RuntimeError as e:
288
- # If it’s a CUDA-side runtime issue, annotate with snapshot
289
  if "CUDA" in str(e):
290
  snap = _cuda_snapshot(self.device)
291
  self._log_gpu_memory()
@@ -296,20 +421,26 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_f
296
 
297
  # Normalize to pure 2D numpy [0,1]
298
  if isinstance(out, torch.Tensor):
299
- alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy()
300
  else:
301
- alpha_np = np.asarray(out, dtype=np.float32)
302
- if alpha_np.max() > 1.0:
303
- alpha_np = alpha_np / 255.0
 
 
 
304
 
 
305
  alpha_np = np.squeeze(alpha_np)
306
  if alpha_np.ndim != 2:
307
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
308
 
309
- return alpha_np.astype(np.float32)
 
 
310
 
311
  # -------------------------------------------------------------------------
312
- # 2.5 β€” process_video harvesting (kept for completeness; not used in forced frame mode)
313
  # -------------------------------------------------------------------------
314
  def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
315
  """
@@ -320,7 +451,7 @@ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[
320
  alpha_mp4 = out_dir / "alpha.mp4"
321
  fg_mp4 = out_dir / "fg.mp4"
322
 
323
- # Dict style: look for common keys
324
  if isinstance(res, dict):
325
  cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
326
  cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
@@ -359,7 +490,7 @@ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[
359
  raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
360
 
361
  # -------------------------------------------------------------------------
362
- # 2.6 β€” Public API: process_stream
363
  # -------------------------------------------------------------------------
364
  def process_stream(
365
  self,
 
1
  #!/usr/bin/env python3
2
  # =============================================================================
3
+ # MatAnyone Adapter (streaming, API-agnostic) β€” with chapter markers + layout probe
4
  # =============================================================================
5
  """
6
  - Supports multiple MatAnyone variants:
 
40
  # =============================================================================
41
  # CHAPTER 1 β€” Small utilities
42
  # =============================================================================
43
+
44
+ # --- Progress callback controls ---
45
+ def _env_flag(name: str, default: str = "0") -> bool:
46
+ return os.getenv(name, default).strip() in {"1", "true", "TRUE", "yes", "YES", "on", "ON"}
47
+
48
+ _PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
49
+ _PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
50
+ _progress_state = {"t": 0.0, "last": None, "disabled": False}
51
+
52
+ def _emit_progress(cb, pct: float, msg: str, *, force: bool = False) -> None:
53
+ """
54
+ Safe progress emitter:
55
+ - Respects MATANY_PROGRESS and rate-limits updates.
56
+ - Never raises upstream; disables itself if the callback misbehaves.
57
+ - Accepts either 2-arg (pct, msg) or legacy 1-arg (msg) callbacks.
58
+ """
59
+ if not cb or not _PROGRESS_CB_ENABLED or _progress_state["disabled"]:
60
  return
61
+ now = time.time()
62
+ if not force and (now - _progress_state["t"] < _PROGRESS_MIN_INTERVAL) and msg == _progress_state["last"]:
63
+ return
64
+
65
  try:
 
 
66
  try:
67
+ cb(pct, msg) # preferred signature
68
  except TypeError:
69
+ cb(msg) # legacy signature
70
+ _progress_state["t"] = now
71
+ _progress_state["last"] = msg
72
+ except Exception as e:
73
+ # Permanently disable to avoid log spam and user-facing crashes
74
+ _progress_state["disabled"] = True
75
+ try:
76
+ log.warning(f"[progress-cb] disabled due to exception: {e}")
77
+ except Exception:
78
  pass
79
 
80
 
81
+ # --- Errors ---
82
  class MatAnyError(RuntimeError):
83
  """Custom exception for MatAnyone processing errors."""
84
  pass
85
 
86
 
87
+ # --- CUDA helpers ---
88
  def _cuda_snapshot(device: Optional[torch.device] = None) -> str:
89
+ """
90
+ Return a short, exception-safe string describing CUDA memory on a device.
91
+ """
92
+ try:
93
+ if not torch.cuda.is_available():
94
+ return "CUDA: N/A"
95
+ idx = 0
96
+ if isinstance(device, torch.device) and device.type == "cuda" and device.index is not None:
97
+ idx = device.index
98
+ name = torch.cuda.get_device_name(idx)
99
+ alloc = torch.cuda.memory_allocated(idx) / (1024 ** 3)
100
+ resv = torch.cuda.memory_reserved(idx) / (1024 ** 3)
101
+ return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
102
+ except Exception as e:
103
+ return f"CUDA: snapshot-error: {e!r}"
104
+
105
+
106
+ def _safe_empty_cache() -> None:
107
+ """Try hard to release CUDA cache; never raises."""
108
  if not torch.cuda.is_available():
109
+ return
110
+ try:
111
+ torch.cuda.synchronize()
112
+ except Exception:
113
+ pass
114
+ try:
 
 
 
 
 
 
 
 
 
 
 
115
  torch.cuda.empty_cache()
116
+ except Exception:
117
+ pass
118
+
119
+
120
+ def _supports_fp16(device: Optional[torch.device]) -> bool:
121
+ """
122
+ Best-effort check whether the device can benefit from fp16.
123
+ Returns False for CPU; True for most modern NVIDIA GPUs.
124
+ """
125
+ if not isinstance(device, torch.device) or device.type != "cuda" or not torch.cuda.is_available():
126
+ return False
127
+ try:
128
+ major, minor = torch.cuda.get_device_capability(device.index or 0)
129
+ # Volta (7.0)+ generally supports fast fp16 paths; T4 is 7.5.
130
+ return (major, minor) >= (7, 0)
131
+ except Exception:
132
+ return True # be optimistic if capability query fails
133
+
134
+
135
+ def _ensure_device_usable(device: torch.device) -> None:
136
+ """
137
+ Validate that the chosen device is actually usable.
138
+ Raise MatAnyError early if CUDA is requested but unavailable.
139
+ """
140
+ if device.type == "cuda" and not torch.cuda.is_available():
141
+ raise MatAnyError("CUDA device requested but torch.cuda.is_available() == False")
142
+ if device.type not in {"cuda", "cpu"}:
143
+ raise MatAnyError(f"Unsupported device type: {device.type!r}")
144
+
145
+
146
+ # --- File & image helpers ---
147
+ def _validate_nonempty(file_path: Path) -> None:
148
+ if (not isinstance(file_path, Path)) or (not file_path.exists()) or file_path.stat().st_size <= 0:
149
+ raise MatAnyError(f"Output file missing/empty: {file_path}")
150
 
151
 
152
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
153
+ """
154
+ Read a mask image, return float32 in [0,1] with shape (H, W).
155
+ Validates content and resizes to target (H, W).
156
+ """
157
+ if not isinstance(mask_path, (str, Path)):
158
+ raise MatAnyError(f"Seed mask path must be str/Path, got {type(mask_path)}")
159
+ mask_path = Path(mask_path)
160
+ if not mask_path.exists():
161
  raise MatAnyError(f"Seed mask not found: {mask_path}")
162
+
163
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
164
+ if mask is None or mask.size == 0:
165
+ raise MatAnyError(f"Failed to read seed mask or empty file: {mask_path}")
166
+
167
  H, W = target_hw
168
+ if mask.ndim != 2:
169
+ raise MatAnyError(f"Seed mask must be single-channel; got shape {mask.shape}")
170
  if mask.shape[:2] != (H, W):
171
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
172
+
173
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
174
  return maskf
175
 
176
 
177
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
178
+ """
179
+ BGR uint8 (H, W, 3) -> RGB float32 (C, H, W) in [0,1].
180
+ """
181
+ if not isinstance(img_bgr, np.ndarray) or img_bgr.dtype != np.uint8 or img_bgr.ndim != 3 or img_bgr.shape[2] != 3:
182
+ raise MatAnyError(f"Frame must be uint8 HWC BGR; got {type(img_bgr)}, shape={getattr(img_bgr, 'shape', None)}")
183
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
184
  rgbf = rgb.astype(np.float32) / 255.0
185
+ chw = np.transpose(rgbf, (2, 0, 1)) # C, H, W
186
  return chw
187
 
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  # =============================================================================
190
  # CHAPTER 2 β€” Main session
191
  # =============================================================================
 
215
  self._core = None
216
  self._api_mode = None
217
  self._initialized = False
218
+ self._layout_locked: Optional[str] = None # 'BCHW+B1HW', 'CHW+HW', etc.
219
  self._lazy_init()
220
 
221
  log.info(f"Initialized MatAnyoneSession on {self.device} | precision={self.precision}, use_fp16={self.use_fp16}")
 
302
  return torch.amp.autocast(device_type="cuda", enabled=enabled and self.use_fp16)
303
 
304
  # -------------------------------------------------------------------------
305
+ # 2.4 β€” Frame validation
306
  # -------------------------------------------------------------------------
307
  def _validate_input_frame(self, frame: np.ndarray) -> None:
308
  if not isinstance(frame, np.ndarray):
 
312
  if frame.ndim != 3 or frame.shape[2] != 3:
313
  raise MatAnyError(f"Frame must be HWC with 3 channels, got {frame.shape}")
314
 
315
+ # -------------------------------------------------------------------------
316
+ # 2.5 β€” Core call helper with first-frame layout probe (locks after success)
317
+ # -------------------------------------------------------------------------
318
+ def _call_core_frame(self, img_chw: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool):
319
  """
320
+ Calls MatAnyone frame API trying a small set of plausible layouts.
321
+ Locks layout after the first successful call to avoid repeated probing.
322
+ Returns the raw output (torch.Tensor or numpy).
323
  """
324
+ core = self._core
325
 
326
+ # Build base tensors
327
+ img_t_chw = torch.from_numpy(img_chw).to(self.device) # (3,H,W)
328
+ H, W = img_chw.shape[1], img_chw.shape[2]
329
 
330
+ mask_t_hw = None
 
331
  if is_first and seed_1hw is not None:
332
+ # Ensure pure HW float32 in [0,1]
333
  if seed_1hw.ndim == 3 and seed_1hw.shape[0] == 1:
334
  seed_hw = seed_1hw[0]
335
  elif seed_1hw.ndim == 2:
336
  seed_hw = seed_1hw
337
  else:
338
  raise MatAnyError(f"seed mask must be 1HW or HW; got {seed_1hw.shape}")
339
+ mask_t_hw = torch.from_numpy(seed_hw.astype(np.float32)).to(self.device)
340
+
341
+ def _do_call(layout: str):
342
+ """Dispatch according to a named layout."""
343
+ if layout == "BCHW+B1HW": # Preferred for many PyTorch models
344
+ img_in = img_t_chw.unsqueeze(0).contiguous() # (1,3,H,W)
345
+ mask_in = mask_t_hw.unsqueeze(0).unsqueeze(0).contiguous() if mask_t_hw is not None else None # (1,1,H,W)
346
+ elif layout == "CHW+HW": # Some APIs accept unbatched tensors
347
+ img_in = img_t_chw # (3,H,W)
348
+ mask_in = mask_t_hw if mask_t_hw is not None else None # (H,W)
349
+ elif layout == "BCHW+HW":
350
+ img_in = img_t_chw.unsqueeze(0).contiguous() # (1,3,H,W)
351
+ mask_in = mask_t_hw if mask_t_hw is not None else None # (H,W)
352
+ elif layout == "CHW+1HW":
353
+ img_in = img_t_chw # (3,H,W)
354
+ mask_in = mask_t_hw.unsqueeze(0).contiguous() if mask_t_hw is not None else None # (1,H,W)
355
+ else:
356
+ raise MatAnyError(f"Unknown layout spec: {layout}")
357
+
358
+ if self._api_mode == "step":
359
+ return core.step(img_in, mask_in) if mask_in is not None else core.step(img_in)
360
+ elif self._api_mode == "process_frame":
361
+ return core.process_frame(img_in, mask_in)
362
+ else:
363
+ raise MatAnyError("Internal error: frame dispatch used in non-frame mode")
364
 
365
+ # If layout was already found, use it directly
366
+ if self._layout_locked is not None:
367
+ try:
368
+ return _do_call(self._layout_locked)
369
+ except Exception as e:
370
+ # If a previously-working layout starts failing, surface clear error
371
+ raise MatAnyError(f"MatAnyone call failed with locked layout {self._layout_locked}: {e}")
372
+
373
+ # First-frame probe: try a few reasonable layouts in this order
374
+ probe_order = ["BCHW+B1HW", "CHW+HW", "BCHW+HW", "CHW+1HW"]
375
+ last_err: Optional[str] = None
376
+
377
+ for layout in probe_order:
378
+ try:
379
+ out = _do_call(layout)
380
+ # Success β€” lock layout for subsequent frames
381
+ self._layout_locked = layout
382
+ log.info(f"[MATANY] First-frame layout locked: {layout} (H={H}, W={W})")
383
+ return out
384
+ except Exception as e:
385
+ last_err = str(e)
386
+ log.warning(f"[MATANY] Layout attempt failed ({layout}): {last_err}")
387
+
388
+ # If we reach here, all attempts failed
389
+ snap = _cuda_snapshot(self.device)
390
+ raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
391
+
392
+ # -------------------------------------------------------------------------
393
+ # 2.6 β€” Frame runner (normalizes output to 2D [0,1])
394
+ # -------------------------------------------------------------------------
395
+ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
396
+ """
397
+ Run a single frame through MatAnyone.
398
+ Returns: alpha matte as 2D np.float32 in [0,1].
399
+ """
400
+ self._validate_input_frame(frame_bgr)
401
+
402
+ # Image -> CHW float32 [0,1]
403
+ img_chw = _to_chw01(frame_bgr) # (3,H,W)
404
+
405
+ # Dispatch (with autocast + no_grad)
406
  try:
407
  with torch.no_grad(), self._maybe_amp():
408
+ out = self._call_core_frame(img_chw, seed_1hw, is_first=is_first)
 
 
 
 
 
409
  except torch.cuda.OutOfMemoryError as e:
410
  snap = _cuda_snapshot(self.device)
411
  self._log_gpu_memory()
412
  raise MatAnyError(f"CUDA OOM while processing frame | {snap}") from e
413
  except RuntimeError as e:
 
414
  if "CUDA" in str(e):
415
  snap = _cuda_snapshot(self.device)
416
  self._log_gpu_memory()
 
421
 
422
  # Normalize to pure 2D numpy [0,1]
423
  if isinstance(out, torch.Tensor):
424
+ alpha_np = out.detach().float().squeeze().cpu().numpy()
425
  else:
426
+ alpha_np = np.asarray(out)
427
+
428
+ # Scale if it looks like 0..255
429
+ alpha_np = alpha_np.astype(np.float32)
430
+ if alpha_np.max() > 1.0:
431
+ alpha_np = alpha_np / 255.0
432
 
433
+ # In case model returns shape like (1,H,W) or (1,1,H,W), squeeze to (H,W)
434
  alpha_np = np.squeeze(alpha_np)
435
  if alpha_np.ndim != 2:
436
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
437
 
438
+ # Clamp to [0,1]
439
+ alpha_np = np.clip(alpha_np, 0.0, 1.0).astype(np.float32)
440
+ return alpha_np
441
 
442
  # -------------------------------------------------------------------------
443
+ # 2.7 β€” process_video harvesting (kept for completeness; not used in forced frame mode)
444
  # -------------------------------------------------------------------------
445
  def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
446
  """
 
451
  alpha_mp4 = out_dir / "alpha.mp4"
452
  fg_mp4 = out_dir / "fg.mp4"
453
 
454
+ # Dict style
455
  if isinstance(res, dict):
456
  cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
457
  cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
 
490
  raise MatAnyError("MatAnyone.process_video did not yield discoverable output paths.")
491
 
492
  # -------------------------------------------------------------------------
493
+ # 2.8 β€” Public API: process_stream
494
  # -------------------------------------------------------------------------
495
  def process_stream(
496
  self,