MogensR committed on
Commit
a1c7ef1
·
verified ·
1 Parent(s): 0319b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -29
app.py CHANGED
@@ -413,11 +413,12 @@ class MatAnyoneHandler:
413
  """
414
  MatAnyone loader + inference adapter.
415
 
416
- Key points:
417
- - Uses first-frame *soft probability* seed (1xHxW float in [0,1]), not an index mask.
418
- - Calls InferenceCore.step with the prob map as a **positional** arg (some builds reject `prob=`).
419
- - Tries `matting=True` when supported; falls back if the kwarg is not available.
420
- - Always feeds CHW tensors for images (3,H,W) and 1xHxW for probs — no extra batch dims.
 
421
  """
422
  def __init__(self):
423
  self.core = None
@@ -427,21 +428,29 @@ def __init__(self):
427
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
428
  """img01: HxWx3 in [0,1] -> torch float 3xHxW on DEVICE"""
429
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
430
- t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # 3xHxW
431
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
432
 
433
- def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
434
- """mask_u8: HxW uint8 -> torch float 1xHxW on DEVICE, resized to (w,h) if needed"""
435
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
436
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
437
- prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # 1xHxW
 
 
 
 
 
 
 
 
438
  t = torch.from_numpy(prob).contiguous().float()
439
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
440
 
441
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
442
  """
443
  Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
444
- Handles shapes (H,W), (1,H,W), or (K,H,W) -> picks first channel.
445
  Also handles MatAnyone tuples/lists like (indices, probs) by taking the 2nd item.
446
  """
447
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
@@ -503,19 +512,45 @@ def initialize(self) -> bool:
503
  state.matanyone_error = f"MatAnyone init error: {e}"
504
  return False
505
 
506
- # ----- video matting using first-frame PROB mask (PATCHED) -----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
508
  """
509
  Produce a single-channel alpha mp4 matching input fps & size.
510
 
511
  First frame:
512
- - Generate soft prob (1,H,W) from SAM2 mask and pass as positional arg to step().
513
- - Try step(image, prob, matting=True); if TypeError, call step(image, prob).
514
-
515
  Remaining frames:
516
- - Try step(image, matting=True); fallback to step(image).
517
-
518
- Returns: path to alpha.mp4
519
  """
520
  if not self.initialized or self.core is None:
521
  raise RuntimeError("MatAnyone not initialized")
@@ -532,25 +567,19 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
532
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
533
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
534
 
535
- # soft seed prob (1,H,W) in [0,1]
536
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
537
  if seed_mask is None:
538
  cap.release()
539
  raise RuntimeError("Seed mask read failed")
540
- prob_1hw = self._prob_from_mask_u8(seed_mask, w, h) # (1,H,W) float
 
 
541
 
542
  # temp frames
543
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
544
  tmp_dir.mkdir(parents=True, exist_ok=True)
545
  memory_manager.register_temp_file(str(tmp_dir))
546
 
547
- def _step_with_prob(image_chw: "torch.Tensor", prob_1hw_t: "torch.Tensor"):
548
- """Call step with positional prob; fall back if 'matting' kwarg unsupported."""
549
- try:
550
- return self.core.step(image_chw, prob_1hw_t, matting=True)
551
- except TypeError:
552
- return self.core.step(image_chw, prob_1hw_t)
553
-
554
  frame_idx = 0
555
 
556
  # --- first frame (with soft prob) ---
@@ -558,11 +587,12 @@ def _step_with_prob(image_chw: "torch.Tensor", prob_1hw_t: "torch.Tensor"):
558
  if not ok or frame_bgr is None:
559
  cap.release()
560
  raise RuntimeError("Empty first frame")
 
561
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
562
  img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
563
 
564
  with torch.no_grad():
565
- out_prob = _step_with_prob(img_chw, prob_1hw)
566
 
567
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
568
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
@@ -573,14 +603,16 @@ def _step_with_prob(image_chw: "torch.Tensor", prob_1hw_t: "torch.Tensor"):
573
  ok, frame_bgr = cap.read()
574
  if not ok or frame_bgr is None:
575
  break
 
576
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
577
  img_chw = self._to_chw_float(frame_rgb01)
578
 
579
  with torch.no_grad():
580
  try:
581
- out_prob = self.core.step(img_chw, matting=True)
582
  except TypeError:
583
- out_prob = self.core.step(img_chw)
 
584
 
585
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
586
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
413
  """
414
  MatAnyone loader + inference adapter.
415
 
416
+ Strategy:
417
+ - Image CHW (3,H,W), no batch/time dims.
418
+ - Seed mask try **2D prob** (H,W) first to avoid 5 vs 6-D concat, then fall back to 1xHxW.
419
+ - Call InferenceCore.step with prob as a **positional** argument.
420
+ - Try with/without `matting=True` (some builds don't accept it).
421
+ - Subsequent frames call step(image) with no seed.
422
  """
423
  def __init__(self):
424
  self.core = None
 
428
  def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
429
  """img01: HxWx3 in [0,1] -> torch float 3xHxW on DEVICE"""
430
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
431
+ t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # (3,H,W)
432
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
433
 
434
+ def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
435
+ """mask_u8: HxW uint8 -> torch float (H,W) on DEVICE, resized if needed"""
436
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
437
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
438
+ prob = (mask_u8.astype(np.float32) / 255.0) # (H,W)
439
+ t = torch.from_numpy(prob).contiguous().float()
440
+ return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
441
+
442
+ def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
443
+ """mask_u8: HxW uint8 -> torch float (1,H,W) on DEVICE, resized if needed"""
444
+ if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
445
+ mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
446
+ prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # (1,H,W)
447
  t = torch.from_numpy(prob).contiguous().float()
448
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
449
 
450
  def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
451
  """
452
  Accepts torch Tensor or numpy-like. Returns uint8 HxW (0..255).
453
+ Handles (H,W), (1,H,W), or (K,H,W) by taking the first channel if needed.
454
  Also handles MatAnyone tuples/lists like (indices, probs) by taking the 2nd item.
455
  """
456
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
 
512
  state.matanyone_error = f"MatAnyone init error: {e}"
513
  return False
514
 
515
+ # ----- robust call helpers ------------------------------------------------
516
+ def _call_step_seed(self, img_chw: "torch.Tensor",
517
+ prob_hw: "torch.Tensor",
518
+ prob_1hw: "torch.Tensor"):
519
+ """
520
+ Try a few safe permutations to satisfy different MatAnyone builds.
521
+ Order chosen to avoid the 5-vs-6D concat error seen in group_modules.py.
522
+ """
523
+ # 1) image (3,H,W), prob (H,W)
524
+ try:
525
+ return self.core.step(img_chw, prob_hw)
526
+ except (TypeError, RuntimeError):
527
+ pass
528
+
529
+ # 2) image (3,H,W), prob (1,H,W)
530
+ try:
531
+ return self.core.step(img_chw, prob_1hw)
532
+ except (TypeError, RuntimeError):
533
+ pass
534
+
535
+ # 3) image (3,H,W), prob (H,W), matting kw
536
+ try:
537
+ return self.core.step(img_chw, prob_hw, matting=True)
538
+ except (TypeError, RuntimeError):
539
+ pass
540
+
541
+ # 4) image (3,H,W), prob (1,H,W), matting kw
542
+ return self.core.step(img_chw, prob_1hw, matting=True)
543
+
544
+ # ----- video matting using first-frame PROB mask (PATCHED) ----------------
545
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
546
  """
547
  Produce a single-channel alpha mp4 matching input fps & size.
548
 
549
  First frame:
550
+ - Build both (H,W) and (1,H,W) soft prob tensors from SAM2 mask.
551
+ - Call step(...) via _call_step_seed (tries 2D first to dodge 6D concat).
 
552
  Remaining frames:
553
+ - Call step(image) with no seed.
 
 
554
  """
555
  if not self.initialized or self.core is None:
556
  raise RuntimeError("MatAnyone not initialized")
 
567
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
568
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
569
 
 
570
  seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
571
  if seed_mask is None:
572
  cap.release()
573
  raise RuntimeError("Seed mask read failed")
574
+
575
+ prob_hw = self._prob_hw_from_mask_u8(seed_mask, w, h) # (H,W)
576
+ prob_1hw = self._prob_1hw_from_mask_u8(seed_mask, w, h) # (1,H,W)
577
 
578
  # temp frames
579
  tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
580
  tmp_dir.mkdir(parents=True, exist_ok=True)
581
  memory_manager.register_temp_file(str(tmp_dir))
582
 
 
 
 
 
 
 
 
583
  frame_idx = 0
584
 
585
  # --- first frame (with soft prob) ---
 
587
  if not ok or frame_bgr is None:
588
  cap.release()
589
  raise RuntimeError("Empty first frame")
590
+
591
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
592
  img_chw = self._to_chw_float(frame_rgb01) # (3,H,W)
593
 
594
  with torch.no_grad():
595
+ out_prob = self._call_step_seed(img_chw, prob_hw, prob_1hw)
596
 
597
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
598
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
 
603
  ok, frame_bgr = cap.read()
604
  if not ok or frame_bgr is None:
605
  break
606
+
607
  frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
608
  img_chw = self._to_chw_float(frame_rgb01)
609
 
610
  with torch.no_grad():
611
  try:
612
+ out_prob = self.core.step(img_chw) # simplest path
613
  except TypeError:
614
+ # Extremely old/new variants: try permissive kw
615
+ out_prob = self.core.step(img_chw, matting=True)
616
 
617
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
618
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)