MogensR committed on
Commit
cc21ef4
·
verified ·
1 Parent(s): a1c7ef1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -409,52 +409,54 @@ def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ============================================================================
411
 
 
412
  class MatAnyoneHandler:
413
  """
414
  MatAnyone loader + inference adapter.
415
 
416
- Strategy:
417
- - Image CHW (3,H,W), no batch/time dims.
418
- - Seed mask try **2D prob** (H,W) first to avoid 5 vs 6-D concat, then fall back to 1xHxW.
419
- - Call InferenceCore.step with prob as a **positional** argument.
420
- - Try with/without `matting=True` (some builds don't accept it).
421
- - Subsequent frames call step(image) with no seed.
422
  """
423
  def __init__(self):
424
  self.core = None
425
  self.initialized = False
426
 
427
  # ----- tensor helpers -----
428
def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
    """Convert an HxWx3 image array into a float 3xHxW tensor on DEVICE.

    Args:
        img01: Array of shape (H, W, 3). Values are expected in [0, 1];
            the range itself is not checked here.

    Returns:
        A contiguous float32 tensor of shape (3, H, W), moved to DEVICE.

    Raises:
        AssertionError: If ``img01`` is not 3-D with 3 channels on the last axis.
    """
    assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
    channels_first = np.transpose(img01, (2, 0, 1))  # HWC -> CHW
    chw = torch.from_numpy(channels_first).contiguous().float()
    # A non-blocking copy is only requested when CUDA is available.
    return chw.to(DEVICE, non_blocking=CUDA_AVAILABLE)
433
 
434
def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
    """Turn a uint8 mask into a float (H,W) probability tensor on DEVICE.

    Args:
        mask_u8: Mask array, expected HxW uint8 (not validated here).
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        Float32 tensor of shape (h, w) on DEVICE, with values scaled by 1/255.
    """
    size_matches = mask_u8.shape[0] == h and mask_u8.shape[1] == w
    if not size_matches:
        # Nearest-neighbour resampling introduces no new intermediate mask values.
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    scaled = mask_u8.astype(np.float32) / 255.0  # (H,W)
    prob = torch.from_numpy(scaled).contiguous().float()
    return prob.to(DEVICE, non_blocking=CUDA_AVAILABLE)
441
 
442
def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
    """Turn a uint8 mask into a float (1,H,W) probability tensor on DEVICE.

    Args:
        mask_u8: Mask array, expected HxW uint8 (not validated here).
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        Float32 tensor of shape (1, h, w) on DEVICE, with values scaled by 1/255.
    """
    size_matches = mask_u8.shape[0] == h and mask_u8.shape[1] == w
    if not size_matches:
        # Nearest-neighbour resampling introduces no new intermediate mask values.
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    scaled = mask_u8.astype(np.float32) / 255.0
    with_channel = scaled[np.newaxis, ...]  # (1,H,W)
    prob = torch.from_numpy(with_channel).contiguous().float()
    return prob.to(DEVICE, non_blocking=CUDA_AVAILABLE)
449
 
450
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
451
  """
452
  Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
453
- Handles (H,W), (1,H,W), or (K,H,W) by taking the first channel if needed.
454
- Also handles MatAnyone tuples/lists like (indices, probs) by taking the 2nd item.
455
  """
456
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
457
- alpha_like = alpha_like[1] # handle (indices, probs)
458
 
459
  if isinstance(alpha_like, torch.Tensor):
460
  t = alpha_like.detach()
@@ -466,11 +468,11 @@ def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
466
  a = np.clip(a, 0, 1)
467
 
468
  a = np.squeeze(a)
 
469
  if a.ndim == 3 and a.shape[0] >= 1:
470
  a = a[0]
471
  if a.ndim != 2:
472
  raise ValueError(f"Alpha must be HxW; got {a.shape}")
473
-
474
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
475
 
476
  def initialize(self) -> bool:
@@ -513,33 +515,31 @@ def initialize(self) -> bool:
513
  return False
514
 
515
  # ----- robust call helpers ------------------------------------------------
516
- def _call_step_seed(self, img_chw: "torch.Tensor",
517
- prob_hw: "torch.Tensor",
518
- prob_1hw: "torch.Tensor"):
 
519
  """
520
- Try a few safe permutations to satisfy different MatAnyone builds.
521
- Order chosen to avoid the 5-vs-6D concat error seen in group_modules.py.
 
 
 
 
522
  """
523
- # 1) image (3,H,W), prob (H,W)
524
  try:
525
- return self.core.step(img_chw, prob_hw)
526
  except (TypeError, RuntimeError):
527
  pass
528
-
529
- # 2) image (3,H,W), prob (1,H,W)
530
  try:
531
- return self.core.step(img_chw, prob_1hw)
532
  except (TypeError, RuntimeError):
533
  pass
534
-
535
- # 3) image (3,H,W), prob (H,W), matting kw
536
  try:
537
- return self.core.step(img_chw, prob_hw, matting=True)
538
  except (TypeError, RuntimeError):
539
  pass
540
-
541
- # 4) image (3,H,W), prob (1,H,W), matting kw
542
- return self.core.step(img_chw, prob_1hw, matting=True)
543
 
544
  # ----- video matting using first-frame PROB mask (PATCHED) ----------------
545
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
@@ -547,8 +547,7 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
547
  Produce a single-channel alpha mp4 matching input fps & size.
548
 
549
  First frame:
550
- - Build both (H,W) and (1,H,W) soft prob tensors from SAM2 mask.
551
- - Call step(...) via _call_step_seed (tries 2D first to dodge 6D concat).
552
  Remaining frames:
553
  - Call step(image) with no seed.
554
  """
@@ -567,13 +566,14 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
567
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
568
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
569
 
 
570
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
571
  if seed_mask is None:
572
  cap.release()
573
  raise RuntimeError("Seed mask read failed")
574
 
575
- prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
576
- prob_1hw = self._prob_1hw_from_mask_u8(seed_mask, w, h) # (1,H,W)
577
 
578
  # temp frames
579
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
@@ -589,10 +589,10 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
589
  raise RuntimeError("Empty first frame")
590
 
591
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
592
- img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
593
 
594
  with torch.no_grad():
595
- out_prob = self._call_step_seed(img_chw, prob_hw, prob_1hw)
596
 
597
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
598
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -605,14 +605,14 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
605
  break
606
 
607
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
608
- img_chw = self._to_chw_float(frame_rgb01)
609
 
610
  with torch.no_grad():
611
  try:
612
- out_prob = self.core.step(img_chw) # simplest path
613
  except TypeError:
614
- # Extremely old/new variants: try permissive kw
615
- out_prob = self.core.step(img_chw, matting=True)
616
 
617
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
618
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ============================================================================
411
 
412
+
413
  class MatAnyoneHandler:
414
  """
415
  MatAnyone loader + inference adapter.
416
 
417
+ Key decisions:
418
+ - Image is fed as Bx3xHxW (B=1).
419
+ - First-frame seed is *soft probability* Bx1xHxW in [0,1].
420
+ - We try multiple safe call patterns (with/without `matting=True`, with prob
421
+ as Bx1xHxW first, then HW) to satisfy different MatAnyone builds.
422
+ - Subsequent frames: step(image) with no seed.
423
  """
424
  def __init__(self):
425
  self.core = None
426
  self.initialized = False
427
 
428
  # ----- tensor helpers -----
429
def _to_bchw_float(self, img01: np.ndarray) -> "torch.Tensor":
    """Convert an HxWx3 image array into a float (1,3,H,W) tensor on DEVICE.

    Args:
        img01: Array of shape (H, W, 3). Values are expected in [0, 1];
            the range itself is not checked here.

    Returns:
        A float32 tensor of shape (1, 3, H, W), moved to DEVICE.

    Raises:
        AssertionError: If ``img01`` is not 3-D with 3 channels on the last axis.
    """
    assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
    channels_first = np.transpose(img01, (2, 0, 1))  # HWC -> CHW
    chw = torch.from_numpy(channels_first).contiguous().float()  # (3,H,W)
    batched = chw[None]  # prepend batch axis -> (1,3,H,W)
    # A non-blocking copy is only requested when CUDA is available.
    return batched.to(DEVICE, non_blocking=CUDA_AVAILABLE)
435
 
436
def _prob_b1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
    """Turn a uint8 mask into a float (1,1,H,W) probability tensor on DEVICE.

    Args:
        mask_u8: Mask array, expected HxW uint8 (not validated here).
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        Float32 tensor of shape (1, 1, h, w) on DEVICE, with values scaled by 1/255.
    """
    size_matches = mask_u8.shape[0] == h and mask_u8.shape[1] == w
    if not size_matches:
        # Nearest-neighbour resampling introduces no new intermediate mask values.
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    scaled = mask_u8.astype(np.float32) / 255.0
    batched = scaled[np.newaxis, np.newaxis, ...]  # (1,1,H,W)
    prob = torch.from_numpy(batched).contiguous().float()
    return prob.to(DEVICE, non_blocking=CUDA_AVAILABLE)
443
 
444
def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
    """Build the fallback (H,W) probability tensor from a uint8 mask, on DEVICE.

    Args:
        mask_u8: Mask array, expected HxW uint8 (not validated here).
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        Float32 tensor of shape (h, w) on DEVICE, with values scaled by 1/255.
    """
    needs_resize = not (mask_u8.shape[0] == h and mask_u8.shape[1] == w)
    if needs_resize:
        # Nearest-neighbour resampling introduces no new intermediate mask values.
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    scaled = mask_u8.astype(np.float32) / 255.0
    prob = torch.from_numpy(scaled).contiguous().float()
    return prob.to(DEVICE, non_blocking=CUDA_AVAILABLE)
451
 
452
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
453
  """
454
  Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
455
+ Handles (H,W), (1,H,W), or (B,1,H,W)/(B,H,W) by squeezing (H,W).
456
+ Also handles tuples/lists (indices, probs) by taking the 2nd item.
457
  """
458
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
459
+ alpha_like = alpha_like[1]
460
 
461
  if isinstance(alpha_like, torch.Tensor):
462
  t = alpha_like.detach()
 
468
  a = np.clip(a, 0, 1)
469
 
470
  a = np.squeeze(a)
471
+ # Accept (H,W) only at this point
472
  if a.ndim == 3 and a.shape[0] >= 1:
473
  a = a[0]
474
  if a.ndim != 2:
475
  raise ValueError(f"Alpha must be HxW; got {a.shape}")
 
476
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
477
 
478
  def initialize(self) -> bool:
 
515
  return False
516
 
517
  # ----- robust call helpers ------------------------------------------------
518
+ def _call_step_seed(self,
519
+ img_bchw: "torch.Tensor",
520
+ prob_b1hw: "torch.Tensor",
521
+ prob_hw: "torch.Tensor"):
522
  """
523
+ Try permutations that keep ranks aligned and avoid 5-vs-6D mismatches.
524
+ Order:
525
+ 1) step(img[B,3,H,W], prob[B,1,H,W])
526
+ 2) step(img[B,3,H,W], prob[B,1,H,W], matting=True)
527
+ 3) step(img[B,3,H,W], prob[H,W])
528
+ 4) step(img[B,3,H,W], prob[H,W], matting=True)
529
  """
 
530
  try:
531
+ return self.core.step(img_bchw, prob_b1hw)
532
  except (TypeError, RuntimeError):
533
  pass
 
 
534
  try:
535
+ return self.core.step(img_bchw, prob_b1hw, matting=True)
536
  except (TypeError, RuntimeError):
537
  pass
 
 
538
  try:
539
+ return self.core.step(img_bchw, prob_hw)
540
  except (TypeError, RuntimeError):
541
  pass
542
+ return self.core.step(img_bchw, prob_hw, matting=True)
 
 
543
 
544
  # ----- video matting using first-frame PROB mask (PATCHED) ----------------
545
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
 
547
  Produce a single-channel alpha mp4 matching input fps & size.
548
 
549
  First frame:
550
+ - Build (1,1,H,W) seed prob from SAM2 and pass to step via _call_step_seed.
 
551
  Remaining frames:
552
  - Call step(image) with no seed.
553
  """
 
566
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
567
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
568
 
569
+ # soft seed prob
570
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
571
  if seed_mask is None:
572
  cap.release()
573
  raise RuntimeError("Seed mask read failed")
574
 
575
+ prob_b1hw = self._prob_b1hw_from_mask_u8(seed_mask, w, h) # (1,1,H,W)
576
+ prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
577
 
578
  # temp frames
579
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
 
589
  raise RuntimeError("Empty first frame")
590
 
591
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
592
+ img_bchw = self._to_bchw_float(frame_rgb01) # (1,3,H,W)
593
 
594
  with torch.no_grad():
595
+ out_prob = self._call_step_seed(img_bchw, prob_b1hw, prob_hw)
596
 
597
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
598
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
605
  break
606
 
607
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
608
+ img_bchw = self._to_bchw_float(frame_rgb01)
609
 
610
  with torch.no_grad():
611
  try:
612
+ out_prob = self.core.step(img_bchw)
613
  except TypeError:
614
+ # very old/new variants: try permissive kw
615
+ out_prob = self.core.step(img_bchw, matting=True)
616
 
617
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
618
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)