MogensR committed on
Commit
80ac736
·
1 Parent(s): b8796b9
Files changed (1) hide show
  1. models/matanyone_loader.py +192 -105
models/matanyone_loader.py CHANGED
@@ -26,6 +26,7 @@
26
  import time
27
  import torch
28
  import logging
 
29
  import numpy as np
30
  from pathlib import Path
31
  from typing import Optional, Callable, Tuple, Union
@@ -229,114 +230,194 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_1hw: Optional[np.ndarray], is_f
229
 
230
  return alpha_np
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def process_stream(
233
  self,
234
  video_path: Path,
235
- seed_mask_path: Optional[Path],
236
- out_dir: Path,
237
- progress_cb: Optional[Callable[[float, str], None]] = None,
238
  ) -> Tuple[Path, Path]:
 
 
 
 
 
 
 
 
 
 
239
  """
240
- Stream the video, write alpha.mp4 and fg.mp4, return their paths.
241
- """
242
- log.info(f"[MATANY] Starting process_video: {video_path}")
243
- log.info(f"[MATANY] API mode: {self._api_mode}")
244
- log.info(f"[MATANY] Device: {self.device}")
245
-
246
- video_path = Path(video_path)
247
  out_dir = Path(out_dir)
248
- _ensure_dir(out_dir)
249
 
250
  cap = cv2.VideoCapture(str(video_path))
251
  if not cap.isOpened():
252
  raise MatAnyError(f"Failed to open video: {video_path}")
253
 
254
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
255
- W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
256
- H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
257
- N = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
258
-
259
- log.info(f"[MATANY] Video info: {W}x{H}, {N} frames, {fps} fps")
260
-
261
- alpha_writer, fg_writer = _open_video_writers(out_dir, fps, (W, H))
262
 
263
- seed_1hw = None
264
- if seed_mask_path is not None:
265
- seed_hw = _read_mask_hw(seed_mask_path, (H, W))
266
- seed_1hw = _mask_to_1hw(seed_hw)
267
 
268
- # If only process_video is available, we'll chunk to avoid RAM blow-ups.
269
  if self._api_mode == "process_video":
270
- log.info(f"[MATANY] Using chunked process_video mode")
271
- frames_buf = []
272
- idx = 0
273
- chunk = max(1, min(64, int(2048 * 1024 * 1024 / (H * W * 3 * 4)))) # ~2GB budget heuristic
274
- # SAFETY: never 0
275
- if chunk <= 0:
276
- chunk = 32
277
- log.info(f"[MATANY] Chunk size: {chunk} frames")
278
-
279
- while True:
280
- ret, frame = cap.read()
281
- if not ret: # flush tail
282
- if frames_buf:
283
- log.info(f"[MATANY] Flushing final chunk of {len(frames_buf)} frames")
284
- self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
285
- break
286
- frames_buf.append(frame.copy())
287
- if len(frames_buf) >= chunk:
288
- log.info(f"[MATANY] Processing chunk {idx//chunk + 1}: {len(frames_buf)} frames")
289
- self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
290
- frames_buf.clear()
291
-
292
- idx += 1
293
- if N > 0:
294
- _emit_progress(progress_cb, idx / N, f"MatAnyone chunking… ({idx}/{N})")
295
  else:
296
  # Frame-by-frame (preferred)
297
  log.info(f"[MATANY] Using frame-by-frame mode: {self._api_mode}")
298
- idx = 0
299
- while True:
300
- ret, frame = cap.read()
301
- if not ret:
302
- break
303
-
304
- if idx % 10 == 0:
305
- _emit_progress(progress_cb, min(0.999, (idx / N) if N > 0 else 0.0),
306
- f"MatAnyone matting… ({idx}/{N})")
307
-
308
- log.debug(f"[MATANY] Processing frame {idx+1}/{N}")
309
- # Only pass seed mask on first frame
310
- current_mask = seed_1hw if idx == 0 else None
311
- alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
312
-
313
- # compose fg for immediate write
314
- # alpha 0..1 -> 0..255 3-channel grayscale
315
- alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
316
- alpha_rgb = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
317
- # Blend: fg = alpha*frame + (1-alpha)*black == alpha*frame
318
- fg_bgr = (frame.astype(np.float32) * (alpha_hw[..., None])).clip(0, 255).astype(np.uint8)
319
-
320
- alpha_writer.write(alpha_rgb)
321
- fg_writer.write(fg_bgr)
322
-
323
- idx += 1
324
- if progress_cb and N > 0 and idx % 10 == 0:
325
- progress_cb(f"MatAnyone matting… ({idx}/{N})")
326
- log.info(f"[MATANY] Progress: {idx}/{N} frames processed")
327
-
328
- cap.release()
329
- alpha_writer.release()
330
- fg_writer.release()
331
 
332
- alpha_path = out_dir / "alpha.mp4"
333
- fg_path = out_dir / "fg.mp4"
334
- _validate_nonempty(alpha_path)
335
- _validate_nonempty(fg_path)
336
- return alpha_path, fg_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
339
- """Call core.process_video(frames, mask) safely, then write results."""
 
 
 
340
  # Prepare inputs
341
  frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
342
  frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
@@ -344,24 +425,30 @@ def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
344
 
345
  with torch.no_grad(), self._maybe_amp():
346
  try:
347
- # Preferred: T,C,H,W (+ 1,H,W mask)
348
- alphas = self._core.process_video(frames_t, mask_t)
349
- except RuntimeError as e:
350
- # Some wheels require B,T,C,H,W (+ B,T,1,H,W)
351
- msg = str(e)
352
- if "number of dimensions" in msg or "Expected" in msg or "got" in msg:
353
- frames_btchw = frames_t.unsqueeze(0) # 1,T,C,H,W
354
- mask_bt1hw = mask_t.unsqueeze(0) if mask_t is not None else None # 1,1,H,W -> (maybe ok) ; some expect 1,T,1,H,W
355
- # If mask still mismatches, try broadcast across T:
356
- try:
357
- alphas = self._core.process_video(frames_btchw, mask_bt1hw)
358
- except RuntimeError:
359
- if mask_t is not None:
360
- T = frames_t.shape[0]
361
- mask_bt1hw = mask_t.unsqueeze(0).unsqueeze(0).expand(1, T, 1, *mask_t.shape[-2:]) # 1,T,1,H,W
362
- alphas = self._core.process_video(frames_btchw, mask_bt1hw)
363
  else:
364
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  # Normalize to numpy list of HW float32 [0,1]
367
  if isinstance(alphas, torch.Tensor):
 
26
  import time
27
  import torch
28
  import logging
29
+ import tempfile
30
  import numpy as np
31
  from pathlib import Path
32
  from typing import Optional, Callable, Tuple, Union
 
230
 
231
  return alpha_np
232
 
233
+ def _harvest_process_video_output(self, res, out_dir: Path, base: str) -> Tuple[Path, Path]:
234
+ """
235
+ Accepts varied return types from MatAnyone.process_video and produces
236
+ (alpha.mp4, fg.mp4) inside out_dir. Strategies:
237
+ - If res is a sequence of alpha arrays/tensors → write our own videos.
238
+ - If res is dict/tuple of paths → copy/rename.
239
+ - Else: glob typical output dirs for files matching base.
240
+ """
241
+ # Case A: sequence of masks
242
+ import torch, numpy as np, cv2, glob, shutil
243
+
244
+ def _as_np(a):
245
+ if isinstance(a, torch.Tensor):
246
+ a = a.detach().float().cpu().numpy()
247
+ a = np.asarray(a)
248
+ if a.ndim == 3 and a.shape[0] in (1,3): # (C,H,W) → prefer HW
249
+ a = np.squeeze(a) if a.shape[0] == 1 else np.mean(a, axis=0)
250
+ if a.max() > 1.0:
251
+ a = a / 255.0
252
+ return a.clip(0,1).astype(np.float32)
253
+
254
+ alpha_mp4 = out_dir / "alpha.mp4"
255
+ fg_mp4 = out_dir / "fg.mp4"
256
+
257
+ # If we got arrays/tensors: we can't reconstruct FG without original frames here,
258
+ # so prefer path-returning flows. If needed, you can extend this to re-read frames
259
+ # and blend. For now, try to detect paths first.
260
+ if isinstance(res, dict):
261
+ cand_alpha = res.get("alpha") or res.get("alpha_path") or res.get("matte") or res.get("matte_path")
262
+ cand_fg = res.get("fg") or res.get("fg_path") or res.get("foreground") or res.get("foreground_path")
263
+ moved = 0
264
+ if cand_alpha and Path(cand_alpha).exists():
265
+ shutil.copy2(cand_alpha, alpha_mp4); moved += 1
266
+ if cand_fg and Path(cand_fg).exists():
267
+ shutil.copy2(cand_fg, fg_mp4); moved += 1
268
+ if moved == 2: return alpha_mp4, fg_mp4
269
+
270
+ if isinstance(res, (list, tuple)) and len(res) >= 1:
271
+ # Heuristic: assume list/tuple of file paths
272
+ paths = [Path(x) for x in res if isinstance(x, (str, Path))]
273
+ if paths:
274
+ # Pick best matches by name
275
+ alpha_candidates = [p for p in paths if p.exists() and ("alpha" in p.name or "matte" in p.name)]
276
+ fg_candidates = [p for p in paths if p.exists() and ("fg" in p.name or "fore" in p.name)]
277
+ if alpha_candidates and fg_candidates:
278
+ shutil.copy2(alpha_candidates[0], alpha_mp4)
279
+ shutil.copy2(fg_candidates[0], fg_mp4)
280
+ return alpha_mp4, fg_mp4
281
+
282
+ # As last resort, glob common dirs created by the lib
283
+ search_dirs = [Path.cwd(), out_dir, Path("results"), Path("result"), Path("output"), Path("outputs")]
284
+ hits = []
285
+ for d in search_dirs:
286
+ if d.exists():
287
+ hits.extend(list(d.rglob(f"*{base}*.*")))
288
+ # choose best alpha/fg
289
+ alpha_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("alpha" in p.name or "matte" in p.name)]
290
+ fg_candidates = [p for p in hits if p.suffix.lower() in (".mp4",".mov",".mkv",".avi") and ("fg" in p.name or "fore" in p.name)]
291
+ if alpha_candidates and fg_candidates:
292
+ import shutil
293
+ shutil.copy2(alpha_candidates[0], alpha_mp4)
294
+ shutil.copy2(fg_candidates[0], fg_mp4)
295
+ return alpha_mp4, fg_mp4
296
+
297
+ raise MatAnyError("MatAnyone.process_video did not yield discoverable outputs.")
298
+
299
  def process_stream(
300
  self,
301
  video_path: Path,
302
+ seed_mask_path: Optional[Path] = None,
303
+ out_dir: Optional[Path] = None,
304
+ progress_cb: Optional[Callable] = None,
305
  ) -> Tuple[Path, Path]:
306
+ """Process video stream with MatAnyone.
307
+
308
+ Args:
309
+ video_path: Input video file
310
+ seed_mask_path: Optional seed mask image (grayscale, same size as video)
311
+ out_dir: Output directory (default: video_path.parent)
312
+ progress_cb: Callback for progress updates (signature: (float, str) or (str,))
313
+
314
+ Returns:
315
+ Tuple of (alpha_path, fg_path) output video paths
316
  """
317
+ if out_dir is None:
318
+ out_dir = video_path.parent
 
 
 
 
 
319
  out_dir = Path(out_dir)
320
+ out_dir.mkdir(parents=True, exist_ok=True)
321
 
322
  cap = cv2.VideoCapture(str(video_path))
323
  if not cap.isOpened():
324
  raise MatAnyError(f"Failed to open video: {video_path}")
325
 
326
+ N = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
327
+ fps = cap.get(cv2.CAP_PROP_FPS)
328
+ W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
329
+ H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
330
+ cap.release()
 
 
 
331
 
332
+ log.info(f"[MATANY] Processing {N} frames ({W}x{H} @ {fps:.1f}fps) from {video_path}")
 
 
 
333
 
 
334
  if self._api_mode == "process_video":
335
+ # --- PATH-BASED CALL (this wheel expects a video path, not tensors) ---
336
+ _emit_progress(progress_cb, 0.05, "MatAnyone (video mode)…")
337
+
338
+ # Some builds accept (video_path, seed_mask_path), others just (video_path)
339
+ try:
340
+ res = self._core.process_video(str(video_path),
341
+ str(seed_mask_path) if seed_mask_path is not None else None)
342
+ except TypeError:
343
+ # Fallback: only video path
344
+ res = self._core.process_video(str(video_path))
345
+
346
+ # Normalize whatever we got back into alpha.mp4 + fg.mp4 in out_dir
347
+ alpha_path, fg_path = self._harvest_process_video_output(res, out_dir, base=video_path.stem)
348
+ _validate_nonempty(alpha_path)
349
+ _validate_nonempty(fg_path)
350
+ _emit_progress(progress_cb, 1.0, "MatAnyone complete")
351
+ return alpha_path, fg_path
 
 
 
 
 
 
 
 
352
  else:
353
  # Frame-by-frame (preferred)
354
  log.info(f"[MATANY] Using frame-by-frame mode: {self._api_mode}")
355
+ cap = cv2.VideoCapture(str(video_path))
356
+ alpha_path = out_dir / "alpha.mp4"
357
+ fg_path = out_dir / "fg.mp4"
358
+
359
+ alpha_writer = cv2.VideoWriter(
360
+ str(alpha_path),
361
+ cv2.VideoWriter_fourcc(*'mp4v'),
362
+ fps,
363
+ (W, H),
364
+ isColor=False
365
+ )
366
+ fg_writer = cv2.VideoWriter(
367
+ str(fg_path),
368
+ cv2.VideoWriter_fourcc(*'mp4v'),
369
+ fps,
370
+ (W, H),
371
+ isColor=True
372
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
+ try:
375
+ # Load seed mask if provided
376
+ seed_1hw = None
377
+ if seed_mask_path is not None:
378
+ seed_1hw = _read_mask_hw(seed_mask_path, (H, W))
379
+
380
+ idx = 0
381
+ while True:
382
+ ret, frame = cap.read()
383
+ if not ret:
384
+ break
385
+
386
+ if idx % 10 == 0:
387
+ _emit_progress(progress_cb, min(0.999, (idx / N) if N > 0 else 0.0),
388
+ f"MatAnyone matting… ({idx}/{N})")
389
+
390
+ log.debug(f"[MATANY] Processing frame {idx+1}/{N}")
391
+ # Only pass seed mask on first frame
392
+ current_mask = seed_1hw if idx == 0 else None
393
+ alpha_hw = self._run_frame(frame, current_mask, is_first=(idx == 0))
394
+
395
+ # compose fg for immediate write
396
+ # alpha 0..1 -> 0..255 3-channel grayscale
397
+ alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
398
+ alpha_rgb = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
399
+ # Blend: fg = alpha*frame + (1-alpha)*black == alpha*frame
400
+ fg_bgr = (frame.astype(np.float32) * (alpha_hw[..., None] / 255.0)).astype(np.uint8)
401
+
402
+ # Write outputs
403
+ alpha_writer.write(alpha_rgb)
404
+ fg_writer.write(fg_bgr)
405
+ idx += 1
406
+
407
+ finally:
408
+ cap.release()
409
+ alpha_writer.release()
410
+ fg_writer.release()
411
+ _validate_nonempty(alpha_path)
412
+ _validate_nonempty(fg_path)
413
+ _emit_progress(progress_cb, 1.0, "MatAnyone complete")
414
+ return alpha_path, fg_path
415
 
416
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
417
+ """Process a chunk of frames with MatAnyone."""
418
+ if not frames_bgr:
419
+ return
420
+
421
  # Prepare inputs
422
  frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
423
  frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
 
425
 
426
  with torch.no_grad(), self._maybe_amp():
427
  try:
428
+ # Try direct tensor processing first (newer versions)
429
+ if hasattr(self._core, '_process_tensor_video'):
430
+ alphas = self._core._process_tensor_video(frames_t, mask_t)
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  else:
432
+ # Fall back to file-based processing if tensor API not available
433
+ with tempfile.TemporaryDirectory() as tmpdir:
434
+ # Save frames to temp directory
435
+ frame_paths = []
436
+ for i, frame in enumerate(frames_bgr):
437
+ path = os.path.join(tmpdir, f'frame_{i:06d}.png')
438
+ cv2.imwrite(path, frame)
439
+ frame_paths.append(path)
440
+
441
+ # Process video from frames
442
+ alphas = self._core.process_video(tmpdir,
443
+ mask_path=seed_1hw_path if seed_1hw is not None else None)
444
+
445
+ # Ensure alphas is a tensor
446
+ if not isinstance(alphas, torch.Tensor):
447
+ alphas = torch.from_numpy(alphas).to(self.device)
448
+
449
+ except Exception as e:
450
+ log.error(f"Error in _flush_chunk: {str(e)}")
451
+ raise
452
 
453
  # Normalize to numpy list of HW float32 [0,1]
454
  if isinstance(alphas, torch.Tensor):