MogensR commited on
Commit
0319b0d
·
verified ·
1 Parent(s): df850a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -16
app.py CHANGED
@@ -407,19 +407,31 @@ def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
407
 
408
  # =============================================================================
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
- # =============================================================================
 
411
  class MatAnyoneHandler:
 
 
 
 
 
 
 
 
 
412
  def __init__(self):
413
  self.core = None
414
  self.initialized = False
415
 
416
  # ----- tensor helpers -----
417
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
 
418
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
419
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # 3xHxW
420
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
421
 
422
  def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
 
423
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
424
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
425
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # 1xHxW
@@ -427,8 +439,14 @@ def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tens
427
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
428
 
429
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
 
 
 
 
 
430
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
431
  alpha_like = alpha_like[1] # handle (indices, probs)
 
432
  if isinstance(alpha_like, torch.Tensor):
433
  t = alpha_like.detach()
434
  if t.is_cuda:
@@ -437,14 +455,14 @@ def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
437
  else:
438
  a = np.asarray(alpha_like, dtype=np.float32)
439
  a = np.clip(a, 0, 1)
 
440
  a = np.squeeze(a)
 
 
441
  if a.ndim != 2:
442
- # handle shapes (1,H,W) or (K,H,W) pick first
443
- if a.ndim == 3 and a.shape[0] >= 1:
444
- a = a[0]
445
- else:
446
- raise ValueError(f"Alpha must be HxW; got {a.shape}")
447
- return (np.clip(a * 255.0, 0, 255).astype(np.uint8))
448
 
449
  def initialize(self) -> bool:
450
  if not TORCH_AVAILABLE:
@@ -485,8 +503,20 @@ def initialize(self) -> bool:
485
  state.matanyone_error = f"MatAnyone init error: {e}"
486
  return False
487
 
488
- # ----- video matting using first-frame PROB mask -----
489
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
490
  if not self.initialized or self.core is None:
491
  raise RuntimeError("MatAnyone not initialized")
492
 
@@ -502,34 +532,43 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
502
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
503
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
504
 
 
505
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
506
  if seed_mask is None:
507
  cap.release()
508
  raise RuntimeError("Seed mask read failed")
 
509
 
 
510
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
511
  tmp_dir.mkdir(parents=True, exist_ok=True)
512
  memory_manager.register_temp_file(str(tmp_dir))
513
 
 
 
 
 
 
 
 
514
  frame_idx = 0
515
 
516
- # First frame (with PROB mask)
517
  ok, frame_bgr = cap.read()
518
  if not ok or frame_bgr is None:
519
  cap.release()
520
  raise RuntimeError("Empty first frame")
521
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
522
- img_chw = self._to_chw_float(frame_rgb01) # 3xHxW
523
- prob_chw = self._prob_from_mask_u8(seed_mask, w, h) # 1xHxW
524
 
525
  with torch.no_grad():
526
- out_prob = self.core.step(img_chw, prob=prob_chw, matting=True)
527
 
528
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
529
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
530
  frame_idx += 1
531
 
532
- # Remaining frames (no mask)
533
  while True:
534
  ok, frame_bgr = cap.read()
535
  if not ok or frame_bgr is None:
@@ -538,7 +577,10 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
538
  img_chw = self._to_chw_float(frame_rgb01)
539
 
540
  with torch.no_grad():
541
- out_prob = self.core.step(img_chw)
 
 
 
542
 
543
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
544
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -546,7 +588,7 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
546
 
547
  cap.release()
548
 
549
- # Build MP4 from alpha pngs
550
  list_file = tmp_dir / "list.txt"
551
  with open(list_file, "w") as f:
552
  for i in range(frame_idx):
@@ -554,7 +596,8 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
554
 
555
  cmd = [
556
  "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
557
- "-f", "concat", "-safe", "0", "-r", f"{fps:.6f}",
 
558
  "-i", str(list_file),
559
  "-vf", f"format=gray,scale={w}:{h}:flags=area",
560
  "-pix_fmt", "yuv420p",
@@ -564,6 +607,7 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
564
  subprocess.run(cmd, check=True)
565
  return str(alpha_path)
566
 
 
567
  # =============================================================================
568
  # CHAPTER 7: AI BACKGROUNDS
569
  # =============================================================================
 
407
 
408
  # =============================================================================
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
+ # ============================================================================
411
+
412
  class MatAnyoneHandler:
413
+ """
414
+ MatAnyone loader + inference adapter.
415
+
416
+ Key points:
417
+ - Uses first-frame *soft probability* seed (1xHxW float in [0,1]), not an index mask.
418
+ - Calls InferenceCore.step with the prob map as a **positional** arg (some builds reject `prob=`).
419
+ - Tries `matting=True` when supported; falls back if the kwarg is not available.
420
+ - Always feeds CHW tensors for images (3,H,W) and 1xHxW for probs — no extra batch dims.
421
+ """
422
  def __init__(self):
423
  self.core = None
424
  self.initialized = False
425
 
426
  # ----- tensor helpers -----
427
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
428
+ """img01: HxWx3 in [0,1] -> torch float 3xHxW on DEVICE"""
429
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
430
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # 3xHxW
431
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
432
 
433
  def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
434
+ """mask_u8: HxW uint8 -> torch float 1xHxW on DEVICE, resized to (w,h) if needed"""
435
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
436
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
437
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # 1xHxW
 
439
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
440
 
441
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
442
+ """
443
+ Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
444
+ Handles shapes (H,W), (1,H,W), or (K,H,W) -> picks first channel.
445
+ Also handles MatAnyone tuples/lists like (indices, probs) by taking the 2nd item.
446
+ """
447
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
448
  alpha_like = alpha_like[1] # handle (indices, probs)
449
+
450
  if isinstance(alpha_like, torch.Tensor):
451
  t = alpha_like.detach()
452
  if t.is_cuda:
 
455
  else:
456
  a = np.asarray(alpha_like, dtype=np.float32)
457
  a = np.clip(a, 0, 1)
458
+
459
  a = np.squeeze(a)
460
+ if a.ndim == 3 and a.shape[0] >= 1:
461
+ a = a[0]
462
  if a.ndim != 2:
463
+ raise ValueError(f"Alpha must be HxW; got {a.shape}")
464
+
465
+ return np.clip(a * 255.0, 0, 255).astype(np.uint8)
 
 
 
466
 
467
  def initialize(self) -> bool:
468
  if not TORCH_AVAILABLE:
 
503
  state.matanyone_error = f"MatAnyone init error: {e}"
504
  return False
505
 
506
+ # ----- video matting using first-frame PROB mask (PATCHED) -----
507
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
508
+ """
509
+ Produce a single-channel alpha mp4 matching input fps & size.
510
+
511
+ First frame:
512
+ - Generate soft prob (1,H,W) from SAM2 mask and pass as positional arg to step().
513
+ - Try step(image, prob, matting=True); if TypeError, call step(image, prob).
514
+
515
+ Remaining frames:
516
+ - Try step(image, matting=True); fallback to step(image).
517
+
518
+ Returns: path to alpha.mp4
519
+ """
520
  if not self.initialized or self.core is None:
521
  raise RuntimeError("MatAnyone not initialized")
522
 
 
532
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
533
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
534
 
535
+ # soft seed prob (1,H,W) in [0,1]
536
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
537
  if seed_mask is None:
538
  cap.release()
539
  raise RuntimeError("Seed mask read failed")
540
+ prob_1hw = self._prob_from_mask_u8(seed_mask, w, h) # (1,H,W) float
541
 
542
+ # temp frames
543
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
544
  tmp_dir.mkdir(parents=True, exist_ok=True)
545
  memory_manager.register_temp_file(str(tmp_dir))
546
 
547
+ def _step_with_prob(image_chw: "torch.Tensor", prob_1hw_t: "torch.Tensor"):
548
+ """Call step with positional prob; fall back if 'matting' kwarg unsupported."""
549
+ try:
550
+ return self.core.step(image_chw, prob_1hw_t, matting=True)
551
+ except TypeError:
552
+ return self.core.step(image_chw, prob_1hw_t)
553
+
554
  frame_idx = 0
555
 
556
+ # --- first frame (with soft prob) ---
557
  ok, frame_bgr = cap.read()
558
  if not ok or frame_bgr is None:
559
  cap.release()
560
  raise RuntimeError("Empty first frame")
561
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
562
+ img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
 
563
 
564
  with torch.no_grad():
565
+ out_prob = _step_with_prob(img_chw, prob_1hw)
566
 
567
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
568
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
569
  frame_idx += 1
570
 
571
+ # --- remaining frames (no seed) ---
572
  while True:
573
  ok, frame_bgr = cap.read()
574
  if not ok or frame_bgr is None:
 
577
  img_chw = self._to_chw_float(frame_rgb01)
578
 
579
  with torch.no_grad():
580
+ try:
581
+ out_prob = self.core.step(img_chw, matting=True)
582
+ except TypeError:
583
+ out_prob = self.core.step(img_chw)
584
 
585
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
586
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
588
 
589
  cap.release()
590
 
591
+ # --- encode PNGs alpha mp4 ---
592
  list_file = tmp_dir / "list.txt"
593
  with open(list_file, "w") as f:
594
  for i in range(frame_idx):
 
596
 
597
  cmd = [
598
  "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
599
+ "-f", "concat", "-safe", "0",
600
+ "-r", f"{fps:.6f}",
601
  "-i", str(list_file),
602
  "-vf", f"format=gray,scale={w}:{h}:flags=area",
603
  "-pix_fmt", "yuv420p",
 
607
  subprocess.run(cmd, check=True)
608
  return str(alpha_path)
609
 
610
+
611
  # =============================================================================
612
  # CHAPTER 7: AI BACKGROUNDS
613
  # =============================================================================