MogensR committed on
Commit
bf47972
·
verified ·
1 Parent(s): f41f37e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -0
app.py CHANGED
@@ -409,6 +409,244 @@ def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ==============================================================================
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  # =============================================================================
413
  # CHAPTER 6: MATANYONE HANDLER (Robust unbatched calls + fallbacks)
414
  # =============================================================================
 
409
  # CHAPTER 6: MATANYONE HANDLER (First-frame PROB mask)
410
  # ==============================================================================
411
 
412
class MatAnyoneHandler:
    """
    Fixed MatAnyone loader + inference adapter.

    Key fix: Only pass tensor inputs to MatAnyone.core.step() since the
    internal pad_divide_by function expects tensors, not numpy arrays.

    Usage: call initialize() once, then process_video() per clip.
    """

    def __init__(self):
        self.core = None          # InferenceCore instance, set by initialize()
        self.initialized = False  # True only after a successful initialize()

    # ----- tensor helpers -----
    def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
        """img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch)."""
        assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
        # HWC -> CHW; contiguous() because transpose yields a non-contiguous view.
        t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float()  # (3,H,W)
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """mask_u8: HxW uint8 -> torch float (H,W) in [0,1] on DEVICE (no batch, no channel)."""
        if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
            # Nearest-neighbour avoids introducing new grey levels into the mask.
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)  # (H,W) in [0,1]
        t = torch.from_numpy(prob).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """Optional variant: 1xHxW (channel-first, still unbatched)."""
        if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)[None, ...]  # (1,H,W)
        t = torch.from_numpy(prob).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
        """
        Accepts torch / numpy / tuple(list) outputs.
        Returns uint8 HxW (0..255). Squeezes common shapes down to HxW.

        Raises:
            ValueError: if the squeezed result is not 2-D.
        """
        if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
            alpha_like = alpha_like[1]  # (indices, probs) -> take probs

        if isinstance(alpha_like, torch.Tensor):
            t = alpha_like.detach()
            if t.is_cuda:
                t = t.cpu()
            a = t.float().clamp(0, 1).numpy()
        else:
            a = np.asarray(alpha_like, dtype=np.float32)
            a = np.clip(a, 0, 1)

        a = np.squeeze(a)
        if a.ndim == 3 and a.shape[0] >= 1:  # (C,H,W) survivor -> take first plane
            a = a[0]
        if a.ndim != 2:
            raise ValueError(f"Alpha must be HxW; got {a.shape}")

        return np.clip(a * 255.0, 0, 255).astype(np.uint8)

    def initialize(self) -> bool:
        """
        Clone the MatAnyone repo, fetch the checkpoint if needed, and build
        the InferenceCore. Returns True on success; on failure records the
        reason in state.matanyone_error and returns False (never raises).
        """
        if not TORCH_AVAILABLE:
            state.matanyone_error = "PyTorch required"
            return False
        with memory_manager.mem_context("MatAnyone init"):
            try:
                _reset_hydra()
                repo_path = ensure_repo("matanyone", "https://github.com/pq-yang/MatAnyone.git")
                if not repo_path:
                    state.matanyone_error = "Clone failed"
                    return False
                try:
                    from matanyone.inference.inference_core import InferenceCore
                    from matanyone.utils.get_default_model import get_matanyone_model
                except Exception as e:
                    state.matanyone_error = f"Import error: {e}"
                    return False

                ckpt = CHECKPOINTS / "matanyone.pth"
                net = None
                if ckpt.exists():
                    net = get_matanyone_model(str(ckpt), device=DEVICE)
                else:
                    url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
                    if download_file(url, ckpt, "MatAnyone"):
                        net = get_matanyone_model(str(ckpt), device=DEVICE)

                if net is None:
                    state.matanyone_error = "Model load failed"
                    return False

                self.core = InferenceCore(net)
                self.initialized = True
                state.matanyone_ready = True
                return True
            except Exception as e:
                state.matanyone_error = f"MatAnyone init error: {e}"
                return False

    # ----- FIXED: tensor-only call helpers --------------------------------
    def _try_step_variants_seed(self,
                                img_chw_t: "torch.Tensor",
                                prob_hw_t: "torch.Tensor",
                                prob_1hw_t: "torch.Tensor"):
        """
        Try multiple MatAnyone.step() signatures with TENSOR INPUTS ONLY.
        MatAnyone's internal functions expect tensors, not numpy arrays.

        Order (most to least preferred):
          1) step(CHW_tensor, HW_tensor)
          2) step(CHW_tensor, HW_tensor, matting=True)
          3) step(CHW_tensor, 1HW_tensor)
          4) step(CHW_tensor, 1HW_tensor, matting=True)
        """
        trials = [
            ((img_chw_t, prob_hw_t), {}),
            ((img_chw_t, prob_hw_t), {"matting": True}),
            ((img_chw_t, prob_1hw_t), {}),
            ((img_chw_t, prob_1hw_t), {"matting": True}),
        ]
        last_err = None
        for (args, kwargs) in trials:
            try:
                return self.core.step(*args, **kwargs)
            except Exception as e:
                last_err = e
                # Keep trying next variant
        raise last_err  # bubble up the most informative final error

    def _try_step_variants_noseed(self, img_chw_t: "torch.Tensor"):
        """
        Variants when no seed is provided on subsequent frames.
        TENSOR INPUT ONLY.
        """
        trials = [
            ((img_chw_t,), {}),
            ((img_chw_t,), {"matting": True}),
        ]
        last_err = None
        for (args, kwargs) in trials:
            try:
                return self.core.step(*args, **kwargs)
            except Exception as e:
                last_err = e
        raise last_err

    # ----- video matting using first-frame PROB mask --------------------------
    def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
        """
        Produce a single-channel alpha mp4 matching input fps & size.

        First frame: pass a soft seed prob (~HW) alongside the image.
        Remaining frames: call step(image) only.

        Returns the path of the written alpha video.

        Raises:
            RuntimeError: not initialized, unreadable video/mask, or empty first frame.
            subprocess.CalledProcessError: if the ffmpeg encode fails.
        """
        if not self.initialized or self.core is None:
            raise RuntimeError("MatAnyone not initialized")

        out_dir = Path(output_path)
        out_dir.mkdir(parents=True, exist_ok=True)
        alpha_path = out_dir / "alpha.mp4"

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError("Could not open input video")

        # FIX: release the capture on EVERY exit path (the original leaked it
        # if step()/imwrite raised mid-video).
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 24.0  # 0.0 (unknown) -> 24.0 default
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # soft seed prob - prepare tensor versions only
            seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if seed_mask is None:
                raise RuntimeError("Seed mask read failed")

            prob_hw_t = self._prob_hw_from_mask_u8(seed_mask, w, h)    # (H,W) torch
            prob_1hw_t = self._prob_1hw_from_mask_u8(seed_mask, w, h)  # (1,H,W) torch

            # temp frames
            tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
            tmp_dir.mkdir(parents=True, exist_ok=True)
            memory_manager.register_temp_file(str(tmp_dir))

            frame_idx = 0

            # --- first frame (with soft prob) ---
            ok, frame_bgr = cap.read()
            if not ok or frame_bgr is None:
                raise RuntimeError("Empty first frame")
            frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

            img_chw_t = self._to_chw_float(frame_rgb01)  # (3,H,W) torch

            with torch.no_grad():
                out_prob = self._try_step_variants_seed(
                    img_chw_t, prob_hw_t, prob_1hw_t
                )

            alpha_u8 = self._alpha_to_u8_hw(out_prob)
            cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
            frame_idx += 1

            # --- remaining frames (no seed) ---
            while True:
                ok, frame_bgr = cap.read()
                if not ok or frame_bgr is None:
                    break

                frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
                img_chw_t = self._to_chw_float(frame_rgb01)

                with torch.no_grad():
                    out_prob = self._try_step_variants_noseed(img_chw_t)

                alpha_u8 = self._alpha_to_u8_hw(out_prob)
                cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
                frame_idx += 1
        finally:
            cap.release()

        # --- encode PNGs → alpha mp4 ---
        list_file = tmp_dir / "list.txt"
        with open(list_file, "w") as f:
            for i in range(frame_idx):
                f.write(f"file '{(tmp_dir / f'{i:06d}.png').as_posix()}'\n")

        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-f", "concat", "-safe", "0",
            "-r", f"{fps:.6f}",
            "-i", str(list_file),
            "-vf", f"format=gray,scale={w}:{h}:flags=area",
            "-pix_fmt", "yuv420p",
            "-c:v", "libx264", "-preset", "medium", "-crf", "18",
            str(alpha_path)
        ]
        subprocess.run(cmd, check=True)
        return str(alpha_path)
650
  # =============================================================================
651
  # CHAPTER 6: MATANYONE HANDLER (Robust unbatched calls + fallbacks)
652
  # =============================================================================