MogensR committed on
Commit
f41f37e
·
verified ·
1 Parent(s): 6e683d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -55
app.py CHANGED
@@ -409,16 +409,22 @@ def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ==============================================================================
411
 
 
 
 
 
412
  class MatAnyoneHandler:
413
  """
414
- MatAnyone loader + inference adapter (unbatched I/O).
415
-
416
- This build of MatAnyone appears to add its own batch dimension internally.
417
- To avoid 5D tensors, we feed:
418
- - image as CHW (3,H,W)
419
- - first-frame seed as HW (H,W) (soft probabilities in [0,1])
420
- We try a few safe call signatures to handle minor API differences
421
- (with/without `matting=True`, with prob as HW, then 1xHxW).
 
 
422
  """
423
  def __init__(self):
424
  self.core = None
@@ -426,7 +432,7 @@ def __init__(self):
426
 
427
  # ----- tensor helpers -----
428
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
429
- """img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch!)."""
430
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
431
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # (3,H,W)
432
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
@@ -440,7 +446,7 @@ def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.T
440
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
441
 
442
  def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
443
- """backup: 1xHxW tensor if a variant expects a leading channel (still unbatched)."""
444
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
445
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
446
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # (1,H,W)
@@ -450,8 +456,7 @@ def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.
450
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
451
  """
452
  Accepts torch / numpy / tuple(list) outputs.
453
- Returns uint8 HxW (0..255).
454
- Squeezes (1,H,W), (B,1,H,W) etc. down to (H,W) when possible.
455
  """
456
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
457
  alpha_like = alpha_like[1] # (indices, probs) -> take probs
@@ -466,13 +471,11 @@ def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
466
  a = np.clip(a, 0, 1)
467
 
468
  a = np.squeeze(a)
469
- # If still 3D like (1,H,W) or (H,W,1) after np.squeeze it should be (H,W)
 
470
  if a.ndim != 2:
471
- # Try common forms: (1,H,W) or (B,H,W) -> pick first
472
- if a.ndim == 3 and a.shape[0] >= 1:
473
- a = a[0]
474
- if a.ndim != 2:
475
- raise ValueError(f"Alpha must be HxW; got {a.shape}")
476
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
477
 
478
  def initialize(self) -> bool:
@@ -514,38 +517,73 @@ def initialize(self) -> bool:
514
  state.matanyone_error = f"MatAnyone init error: {e}"
515
  return False
516
 
517
- # ----- robust call helpers (UNBATCHED) -----------------------------------
518
- def _call_step_seed(self,
519
- img_chw: "torch.Tensor",
520
- prob_hw: "torch.Tensor",
521
- prob_1hw: "torch.Tensor"):
 
 
 
522
  """
523
- Try signatures that keep inputs UNBATCHED:
524
- 1) step(img[3,H,W], prob[H,W])
525
- 2) step(img[3,H,W], prob[H,W], matting=True)
526
- 3) step(img[3,H,W], prob[1,H,W])
527
- 4) step(img[3,H,W], prob[1,H,W], matting=True)
 
 
 
 
 
 
 
528
  """
529
- try:
530
- return self.core.step(img_chw, prob_hw)
531
- except (TypeError, RuntimeError):
532
- pass
533
- try:
534
- return self.core.step(img_chw, prob_hw, matting=True)
535
- except (TypeError, RuntimeError):
536
- pass
537
- try:
538
- return self.core.step(img_chw, prob_1hw)
539
- except (TypeError, RuntimeError):
540
- pass
541
- return self.core.step(img_chw, prob_1hw, matting=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
  # ----- video matting using first-frame PROB mask --------------------------
544
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
545
  """
546
  Produce a single-channel alpha mp4 matching input fps & size.
547
 
548
- First frame: pass a soft seed prob (HxW) alongside CHW image.
549
  Remaining frames: call step(image) only.
550
  """
551
  if not self.initialized or self.core is None:
@@ -563,13 +601,16 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
563
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
564
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
565
 
566
- # soft seed prob (unbatched)
567
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
568
  if seed_mask is None:
569
  cap.release()
570
  raise RuntimeError("Seed mask read failed")
571
- prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
572
- prob_1hw = self._prob_1hw_from_mask_u8(seed_mask, w, h) # (1,H,W) fallback
 
 
 
573
 
574
  # temp frames
575
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
@@ -583,12 +624,17 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
583
  if not ok or frame_bgr is None:
584
  cap.release()
585
  raise RuntimeError("Empty first frame")
586
-
587
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
588
- img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
 
 
589
 
590
  with torch.no_grad():
591
- out_prob = self._call_step_seed(img_chw, prob_hw, prob_1hw)
 
 
 
 
592
 
593
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
594
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -601,13 +647,11 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
601
  break
602
 
603
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
604
- img_chw = self._to_chw_float(frame_rgb01)
 
605
 
606
  with torch.no_grad():
607
- try:
608
- out_prob = self.core.step(img_chw)
609
- except TypeError:
610
- out_prob = self.core.step(img_chw, matting=True)
611
 
612
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
613
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -635,8 +679,6 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
635
  return str(alpha_path)
636
 
637
 
638
-
639
-
640
  # =============================================================================
641
  # CHAPTER 7: AI BACKGROUNDS
642
  # =============================================================================
 
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ==============================================================================
411
 
412
+ # =============================================================================
413
+ # CHAPTER 6: MATANYONE HANDLER (Robust unbatched calls + fallbacks)
414
+ # =============================================================================
415
+
416
  class MatAnyoneHandler:
417
  """
418
+ Robust MatAnyone loader + inference adapter.
419
+
420
+ What this does:
421
+ - Prefers UNBATCHED inputs:
422
+ image: (3, H, W) float32 in [0,1]
423
+ prob : (H, W) float32 in [0,1] (soft seed from first frame)
424
+ - Falls back to other safe permutations that some MatAnyone builds expect:
425
+ prob 1xHxW, numpy HxWx3 + HxW, etc.
426
+ - Never uses idx_mask/objects (your build asserts on idx mask path).
427
+ - Squeezes model outputs back to HxW uint8.
428
  """
429
  def __init__(self):
430
  self.core = None
 
432
 
433
  # ----- tensor helpers -----
434
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
435
+ """img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch)."""
436
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
437
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # (3,H,W)
438
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
 
446
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
447
 
448
  def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
449
+ """Optional: 1xHxW (channel-first, still unbatched)."""
450
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
451
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
452
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # (1,H,W)
 
456
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
457
  """
458
  Accepts torch / numpy / tuple(list) outputs.
459
+ Returns uint8 HxW (0..255). Squeezes common shapes down to HxW.
 
460
  """
461
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
462
  alpha_like = alpha_like[1] # (indices, probs) -> take probs
 
471
  a = np.clip(a, 0, 1)
472
 
473
  a = np.squeeze(a)
474
+ if a.ndim == 3 and a.shape[0] >= 1: # (1,H,W) -> (H,W)
475
+ a = a[0]
476
  if a.ndim != 2:
477
+ raise ValueError(f"Alpha must be HxW; got {a.shape}")
478
+
 
 
 
479
  return np.clip(a * 255.0, 0, 255).astype(np.uint8)
480
 
481
  def initialize(self) -> bool:
 
517
  state.matanyone_error = f"MatAnyone init error: {e}"
518
  return False
519
 
520
+ # ----- robust call helpers ------------------------------------------------
521
+ def _try_step_variants_seed(self,
522
+ img_chw_t: "torch.Tensor",
523
+ img_hwc_np: np.ndarray,
524
+ prob_hw_t: "torch.Tensor",
525
+ prob_1hw_t: "torch.Tensor",
526
+ prob_hw_np: np.ndarray,
527
+ prob_hwc1_np: np.ndarray):
528
  """
529
+ Try multiple MatAnyone.step() signatures in a safe order.
530
+ We avoid idx_mask/objects because this build asserts on idx path.
531
+
532
+ Order (from most to least strict about tensors):
533
+ 1) step(CHW, HW)
534
+ 2) step(CHW, HW, matting=True)
535
+ 3) step(CHW, 1HW)
536
+ 4) step(CHW, 1HW, matting=True)
537
+ 5) step(HWC, HW) # numpy fallbacks
538
+ 6) step(HWC, HW, matting=True)
539
+ 7) step(HWC, HWC1)
540
+ 8) step(HWC, HWC1, matting=True)
541
  """
542
+ trials = [
543
+ ( (img_chw_t, prob_hw_t), {} ),
544
+ ( (img_chw_t, prob_hw_t), {"matting": True} ),
545
+ ( (img_chw_t, prob_1hw_t), {} ),
546
+ ( (img_chw_t, prob_1hw_t), {"matting": True} ),
547
+ ( (img_hwc_np, prob_hw_np), {} ),
548
+ ( (img_hwc_np, prob_hw_np), {"matting": True} ),
549
+ ( (img_hwc_np, prob_hwc1_np), {} ),
550
+ ( (img_hwc_np, prob_hwc1_np), {"matting": True} ),
551
+ ]
552
+ last_err = None
553
+ for (args, kwargs) in trials:
554
+ try:
555
+ return self.core.step(*args, **kwargs)
556
+ except Exception as e:
557
+ last_err = e
558
+ # Keep trying next variant
559
+ raise last_err # bubble up the most informative final error
560
+
561
+ def _try_step_variants_noseed(self,
562
+ img_chw_t: "torch.Tensor",
563
+ img_hwc_np: np.ndarray):
564
+ """
565
+ Variants when no seed is provided on subsequent frames.
566
+ """
567
+ trials = [
568
+ ( (img_chw_t,), {} ),
569
+ ( (img_chw_t,), {"matting": True} ),
570
+ ( (img_hwc_np,), {} ),
571
+ ( (img_hwc_np,), {"matting": True} ),
572
+ ]
573
+ last_err = None
574
+ for (args, kwargs) in trials:
575
+ try:
576
+ return self.core.step(*args, **kwargs)
577
+ except Exception as e:
578
+ last_err = e
579
+ raise last_err
580
 
581
  # ----- video matting using first-frame PROB mask --------------------------
582
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
583
  """
584
  Produce a single-channel alpha mp4 matching input fps & size.
585
 
586
+ First frame: pass a soft seed prob (~HW) alongside the image.
587
  Remaining frames: call step(image) only.
588
  """
589
  if not self.initialized or self.core is None:
 
601
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
602
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
603
 
604
+ # soft seed prob
605
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
606
  if seed_mask is None:
607
  cap.release()
608
  raise RuntimeError("Seed mask read failed")
609
+
610
+ prob_hw_t = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W) torch
611
+ prob_1hw_t = self._prob_1hw_from_mask_u8(seed_mask, w, h) # (1,H,W) torch
612
+ prob_hw_np = (seed_mask.astype(np.float32) / 255.0) # (H,W) numpy
613
+ prob_hwc1_np = prob_hw_np[:, :, None] # (H,W,1) numpy
614
 
615
  # temp frames
616
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
 
624
  if not ok or frame_bgr is None:
625
  cap.release()
626
  raise RuntimeError("Empty first frame")
 
627
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
628
+
629
+ img_chw_t = self._to_chw_float(frame_rgb01) # (3,H,W) torch
630
+ img_hwc_np = frame_rgb01 # (H,W,3) numpy
631
 
632
  with torch.no_grad():
633
+ out_prob = self._try_step_variants_seed(
634
+ img_chw_t, img_hwc_np,
635
+ prob_hw_t, prob_1hw_t,
636
+ prob_hw_np, prob_hwc1_np
637
+ )
638
 
639
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
640
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
647
  break
648
 
649
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
650
+ img_chw_t = self._to_chw_float(frame_rgb01)
651
+ img_hwc_np = frame_rgb01
652
 
653
  with torch.no_grad():
654
+ out_prob = self._try_step_variants_noseed(img_chw_t, img_hwc_np)
 
 
 
655
 
656
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
657
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
679
  return str(alpha_path)
680
 
681
 
 
 
682
  # =============================================================================
683
  # CHAPTER 7: AI BACKGROUNDS
684
  # =============================================================================