MogensR committed on
Commit
6e683d6
·
verified ·
1 Parent(s): cc21ef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -56
app.py CHANGED
@@ -407,56 +407,54 @@ def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
407
 
408
  # =============================================================================
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
- # ============================================================================
411
-
412
 
413
  class MatAnyoneHandler:
414
  """
415
- MatAnyone loader + inference adapter.
416
-
417
- Key decisions:
418
- - Image is fed as Bx3xHxW (B=1).
419
- - First-frame seed is *soft probability* Bx1xHxW in [0,1].
420
- - We try multiple safe call patterns (with/without `matting=True`, with prob
421
- as Bx1xHxW first, then HW) to satisfy different MatAnyone builds.
422
- - Subsequent frames: step(image) with no seed.
423
  """
424
  def __init__(self):
425
  self.core = None
426
  self.initialized = False
427
 
428
  # ----- tensor helpers -----
429
- def _to_bchw_float(self, img01: np.ndarray) -> "torch.Tensor":
430
- """img01: HxWx3 in [0,1] -> torch float (1,3,H,W) on DEVICE"""
431
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
432
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # (3,H,W)
433
- t = t.unsqueeze(0) # (1,3,H,W)
434
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
435
 
436
- def _prob_b1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
437
- """mask_u8: HxW uint8 -> torch float (1,1,H,W) on DEVICE, resized if needed"""
438
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
439
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
440
- prob = (mask_u8.astype(np.float32) / 255.0)[None, None, ...] # (1,1,H,W)
441
  t = torch.from_numpy(prob).contiguous().float()
442
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
443
 
444
- def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
445
- """Optional fallback: (H,W) prob (kept on DEVICE)"""
446
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
447
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
448
- prob = (mask_u8.astype(np.float32) / 255.0)
449
  t = torch.from_numpy(prob).contiguous().float()
450
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
451
 
452
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
453
  """
454
- Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
455
- Handles (H,W), (1,H,W), or (B,1,H,W)/(B,H,W) by squeezing → (H,W).
456
- Also handles tuples/lists (indices, probs) by taking the 2nd item.
457
  """
458
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
459
- alpha_like = alpha_like[1]
460
 
461
  if isinstance(alpha_like, torch.Tensor):
462
  t = alpha_like.detach()
@@ -468,11 +466,13 @@ def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
468
  a = np.clip(a, 0, 1)
469
 
470
  a = np.squeeze(a)
471
- # Accept (H,W) only at this point
472
- if a.ndim == 3 and a.shape[0] >= 1:
473
- a = a[0]
474
  if a.ndim != 2:
475
- raise ValueError(f"Alpha must be HxW; got {a.shape}")
 
 
 
 
476
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
477
 
478
  def initialize(self) -> bool:
@@ -514,42 +514,39 @@ def initialize(self) -> bool:
514
  state.matanyone_error = f"MatAnyone init error: {e}"
515
  return False
516
 
517
- # ----- robust call helpers ------------------------------------------------
518
  def _call_step_seed(self,
519
- img_bchw: "torch.Tensor",
520
- prob_b1hw: "torch.Tensor",
521
- prob_hw: "torch.Tensor"):
522
  """
523
- Try permutations that keep ranks aligned and avoid 5-vs-6D mismatches.
524
- Order:
525
- 1) step(img[B,3,H,W], prob[B,1,H,W])
526
- 2) step(img[B,3,H,W], prob[B,1,H,W], matting=True)
527
- 3) step(img[B,3,H,W], prob[H,W])
528
- 4) step(img[B,3,H,W], prob[H,W], matting=True)
529
  """
530
  try:
531
- return self.core.step(img_bchw, prob_b1hw)
532
  except (TypeError, RuntimeError):
533
  pass
534
  try:
535
- return self.core.step(img_bchw, prob_b1hw, matting=True)
536
  except (TypeError, RuntimeError):
537
  pass
538
  try:
539
- return self.core.step(img_bchw, prob_hw)
540
  except (TypeError, RuntimeError):
541
  pass
542
- return self.core.step(img_bchw, prob_hw, matting=True)
543
 
544
- # ----- video matting using first-frame PROB mask (PATCHED) ----------------
545
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
546
  """
547
  Produce a single-channel alpha mp4 matching input fps & size.
548
 
549
- First frame:
550
- - Build (1,1,H,W) seed prob from SAM2 and pass to step via _call_step_seed.
551
- Remaining frames:
552
- - Call step(image) with no seed.
553
  """
554
  if not self.initialized or self.core is None:
555
  raise RuntimeError("MatAnyone not initialized")
@@ -566,14 +563,13 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
566
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
567
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
568
 
569
- # soft seed prob
570
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
571
  if seed_mask is None:
572
  cap.release()
573
  raise RuntimeError("Seed mask read failed")
574
-
575
- prob_b1hw = self._prob_b1hw_from_mask_u8(seed_mask, w, h) # (1,1,H,W)
576
- prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
577
 
578
  # temp frames
579
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
@@ -589,10 +585,10 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
589
  raise RuntimeError("Empty first frame")
590
 
591
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
592
- img_bchw = self._to_bchw_float(frame_rgb01) # (1,3,H,W)
593
 
594
  with torch.no_grad():
595
- out_prob = self._call_step_seed(img_bchw, prob_b1hw, prob_hw)
596
 
597
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
598
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -605,14 +601,13 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
605
  break
606
 
607
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
608
- img_bchw = self._to_bchw_float(frame_rgb01)
609
 
610
  with torch.no_grad():
611
  try:
612
- out_prob = self.core.step(img_bchw)
613
  except TypeError:
614
- # very old/new variants: try permissive kw
615
- out_prob = self.core.step(img_bchw, matting=True)
616
 
617
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
618
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -640,6 +635,8 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
640
  return str(alpha_path)
641
 
642
 
 
 
643
  # =============================================================================
644
  # CHAPTER 7: AI BACKGROUNDS
645
  # =============================================================================
 
407
 
408
  # =============================================================================
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
+ # ==============================================================================
 
411
 
412
  class MatAnyoneHandler:
413
  """
414
+ MatAnyone loader + inference adapter (unbatched I/O).
415
+
416
+ This build of MatAnyone appears to add its own batch dimension internally.
417
+ To avoid 5D tensors, we feed:
418
+ - image as CHW (3,H,W)
419
+ - first-frame seed as HW (H,W) (soft probabilities in [0,1])
420
+ We try a few safe call signatures to handle minor API differences
421
+ (with/without `matting=True`, with prob as HW, then 1xHxW).
422
  """
423
  def __init__(self):
424
  self.core = None
425
  self.initialized = False
426
 
427
  # ----- tensor helpers -----
428
+ def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
429
+ """img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch!)."""
430
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
431
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # (3,H,W)
 
432
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
433
 
434
+ def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
435
+ """mask_u8: HxW -> torch float (H,W) in [0,1] on DEVICE (no batch, no channel)."""
436
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
437
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
438
+ prob = (mask_u8.astype(np.float32) / 255.0) # (H,W)
439
  t = torch.from_numpy(prob).contiguous().float()
440
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
441
 
442
+ def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
443
+ """backup: 1xHxW tensor if a variant expects a leading channel (still unbatched)."""
444
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
445
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
446
+ prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # (1,H,W)
447
  t = torch.from_numpy(prob).contiguous().float()
448
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
449
 
450
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
451
  """
452
+ Accepts torch / numpy / tuple(list) outputs.
453
+ Returns uint8 HxW (0..255).
454
+ Squeezes (1,H,W), (B,1,H,W) etc. down to (H,W) when possible.
455
  """
456
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
457
+ alpha_like = alpha_like[1] # (indices, probs) -> take probs
458
 
459
  if isinstance(alpha_like, torch.Tensor):
460
  t = alpha_like.detach()
 
466
  a = np.clip(a, 0, 1)
467
 
468
  a = np.squeeze(a)
469
+ # If still 3D like (1,H,W) or (H,W,1) after np.squeeze it should be (H,W)
 
 
470
  if a.ndim != 2:
471
+ # Try common forms: (1,H,W) or (B,H,W) -> pick first
472
+ if a.ndim == 3 and a.shape[0] >= 1:
473
+ a = a[0]
474
+ if a.ndim != 2:
475
+ raise ValueError(f"Alpha must be HxW; got {a.shape}")
476
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
477
 
478
  def initialize(self) -> bool:
 
514
  state.matanyone_error = f"MatAnyone init error: {e}"
515
  return False
516
 
517
+ # ----- robust call helpers (UNBATCHED) -----------------------------------
518
  def _call_step_seed(self,
519
+ img_chw: "torch.Tensor",
520
+ prob_hw: "torch.Tensor",
521
+ prob_1hw: "torch.Tensor"):
522
  """
523
+ Try signatures that keep inputs UNBATCHED:
524
+ 1) step(img[3,H,W], prob[H,W])
525
+ 2) step(img[3,H,W], prob[H,W], matting=True)
526
+ 3) step(img[3,H,W], prob[1,H,W])
527
+ 4) step(img[3,H,W], prob[1,H,W], matting=True)
 
528
  """
529
  try:
530
+ return self.core.step(img_chw, prob_hw)
531
  except (TypeError, RuntimeError):
532
  pass
533
  try:
534
+ return self.core.step(img_chw, prob_hw, matting=True)
535
  except (TypeError, RuntimeError):
536
  pass
537
  try:
538
+ return self.core.step(img_chw, prob_1hw)
539
  except (TypeError, RuntimeError):
540
  pass
541
+ return self.core.step(img_chw, prob_1hw, matting=True)
542
 
543
+ # ----- video matting using first-frame PROB mask --------------------------
544
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
545
  """
546
  Produce a single-channel alpha mp4 matching input fps & size.
547
 
548
+ First frame: pass a soft seed prob (HxW) alongside CHW image.
549
+ Remaining frames: call step(image) only.
 
 
550
  """
551
  if not self.initialized or self.core is None:
552
  raise RuntimeError("MatAnyone not initialized")
 
563
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
564
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
565
 
566
+ # soft seed prob (unbatched)
567
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
568
  if seed_mask is None:
569
  cap.release()
570
  raise RuntimeError("Seed mask read failed")
571
+ prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
572
+ prob_1hw = self._prob_1hw_from_mask_u8(seed_mask, w, h) # (1,H,W) fallback
 
573
 
574
  # temp frames
575
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
 
585
  raise RuntimeError("Empty first frame")
586
 
587
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
588
+ img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
589
 
590
  with torch.no_grad():
591
+ out_prob = self._call_step_seed(img_chw, prob_hw, prob_1hw)
592
 
593
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
594
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
601
  break
602
 
603
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
604
+ img_chw = self._to_chw_float(frame_rgb01)
605
 
606
  with torch.no_grad():
607
  try:
608
+ out_prob = self.core.step(img_chw)
609
  except TypeError:
610
+ out_prob = self.core.step(img_chw, matting=True)
 
611
 
612
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
613
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
635
  return str(alpha_path)
636
 
637
 
638
+
639
+
640
  # =============================================================================
641
  # CHAPTER 7: AI BACKGROUNDS
642
  # =============================================================================