MogensR committed on
Commit
6aec771
·
1 Parent(s): 6d44f52

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +170 -24
models/loaders/matanyone_loader.py CHANGED
@@ -1,26 +1,5 @@
1
- from matanyone_loader import MatAnyoneLoader
2
- import cv2, numpy as np, torch
3
-
4
- # Load session (stateful per video)
5
- loader = MatAnyoneLoader(device="cuda")
6
- session = loader.load()
7
- assert session, "MatAnyone failed to load"
8
-
9
- # Frame 0 (must supply a coarse mask, even a fallback like 0.5 or ones)
10
- bgr0 = cv2.imread("frame0001.jpg")
11
- rgb0 = cv2.cvtColor(bgr0, cv2.COLOR_BGR2RGB)
12
- coarse0 = np.ones((rgb0.shape[0], rgb0.shape[1]), dtype=np.float32) # example fallback
13
-
14
- alpha0 = session(rgb0, coarse0) # -> 2-D float32 [H,W]
15
-
16
- # Frames 1..N (mask=None, stateful propagation)
17
- for i in range(2, 6):
18
- bgr = cv2.imread(f"frame000{i}.jpg")
19
- rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
20
- alpha = session(rgb, mask=None) # -> 2-D float32 [H,W]
21
  #!/usr/bin/env python3
22
  # -*- coding: utf-8 -*-
23
-
24
  """
25
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
26
  - Canonical HF load (MatAnyone.from_pretrained → InferenceCore(model, cfg))
@@ -44,6 +23,13 @@
44
  import torch.nn.functional as F
45
  import inspect
46
  import threading
 
 
 
 
 
 
 
47
  def _select_device(pref: str) -> str:
48
  pref = (pref or "").lower()
49
  if pref.startswith("cuda"):
@@ -52,6 +38,133 @@ def _select_device(pref: str) -> str:
52
  return "cpu"
53
  return "cuda" if torch.cuda.is_available() else "cpu"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
57
  """Pick model weight dtype + autocast dtype (bf16>fp16>fp32)."""
@@ -65,6 +178,11 @@ def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dt
65
  if fp16_ok:
66
  return torch.float16, True, torch.float16
67
  return torch.float32, False, None
 
 
 
 
 
68
  class _MatAnyoneSession:
69
  """
70
  Stateful controller around InferenceCore with OOM-resilient inference.
@@ -163,6 +281,7 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
163
  if (th, tw) == (H, W):
164
  msk_in = mask_1hw
165
  else:
 
166
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
167
 
168
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
@@ -218,6 +337,11 @@ def __exit__(self, *a): return False
218
  if mask_1hw is not None:
219
  return _to_2d_alpha_numpy(mask_1hw)
220
  return np.full((H, W), 0.5, dtype=np.float32)
 
 
 
 
 
221
  class MatAnyoneLoader:
222
  """
223
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
@@ -268,7 +392,7 @@ def load(self) -> Optional[Any]:
268
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
269
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
270
 
271
- # HF weights (safetensors); keep trust defaults of library itself
272
  self.model = model_cls.from_pretrained(self.model_id)
273
  try:
274
  self.model = self.model.to(self.device).to(model_dtype)
@@ -341,9 +465,31 @@ def debug_shapes(self, image, mask, tag: str = ""):
341
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
342
  except Exception as e:
343
  logger.info(f"[{tag}] debug error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  if __name__ == "__main__":
345
  import sys
346
- import cv2
347
 
348
  logging.basicConfig(level=logging.INFO)
349
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -373,6 +519,6 @@ def debug_shapes(self, image, mask, tag: str = ""):
373
  print("Failed to load MatAnyone")
374
  raise SystemExit(3)
375
 
376
- alpha = session(img_rgb, mask)
377
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
378
  print("Alpha matte written to alpha_out.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
 
3
  """
4
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
5
  - Canonical HF load (MatAnyone.from_pretrained → InferenceCore(model, cfg))
 
23
  import torch.nn.functional as F
24
  import inspect
25
  import threading
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Utilities (shapes, dtype, scaling)
31
+ # ---------------------------------------------------------------------------
32
+
33
  def _select_device(pref: str) -> str:
34
  pref = (pref or "").lower()
35
  if pref.startswith("cuda"):
 
38
  return "cpu"
39
  return "cuda" if torch.cuda.is_available() else "cpu"
40
 
41
+ def _as_tensor_on_device(x, device: str) -> torch.Tensor:
42
+ if isinstance(x, torch.Tensor):
43
+ return x.to(device, non_blocking=True)
44
+ return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
45
+
46
def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
    """
    Normalize an image/mask to a float BCHW (image) or B1HW (mask) tensor.

    Accepts HW, HWC, CHW, BCHW, BHWC and BTCHW/BTHWC layouts (for 5-D input
    only the first frame along dim 1 is kept). uint8 data is rescaled to
    [0, 1]; other integer dtypes are cast to float. The output is always
    clamped to [0, 1].

    Bug fix: clamping is now out-of-place. The previous in-place clamp_()
    could silently modify a caller-owned tensor whenever the input was
    already on `device` in the target layout/dtype (no copy had been made
    anywhere along the path).
    """
    x = _as_tensor_on_device(x, device)
    # .float() allocates a new tensor, so the in-place div_ below never
    # touches caller storage.
    if x.dtype == torch.uint8:
        x = x.float().div_(255.0)
    elif x.dtype in (torch.int16, torch.int32, torch.int64):
        x = x.float()
    if x.ndim == 5:
        x = x[:, 0]  # BTCHW/BTHWC -> keep first frame -> 4D
    if x.ndim == 4:
        # Heuristic: channels-last if the last dim looks like C and dim 1
        # does not.
        if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
            x = x.permute(0, 3, 1, 2).contiguous()
    elif x.ndim == 3:
        if x.shape[-1] in (1, 3, 4):  # HWC -> CHW
            x = x.permute(2, 0, 1).contiguous()
        x = x.unsqueeze(0)
    elif x.ndim == 2:
        x = x.unsqueeze(0).unsqueeze(0)
        if not is_mask:
            x = x.repeat(1, 3, 1, 1)  # grayscale -> 3-channel image
    else:
        raise ValueError(f"Unsupported ndim={x.ndim}")
    if is_mask:
        if x.shape[1] > 1:
            x = x[:, :1]  # keep a single mask channel
        x = x.clamp(0.0, 1.0).to(torch.float32)
    else:
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x = x.clamp(0.0, 1.0)
    return x
80
+
81
+ def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
82
+ if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
83
+ return img_bchw[0]
84
+ return img_bchw
85
+
86
+ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
87
+ if msk_b1hw is None:
88
+ return None
89
+ if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
90
+ return msk_b1hw[0] # -> [1,H,W]
91
+ if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
92
+ return msk_b1hw
93
+ raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
94
+
95
+ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
96
+ if x is None:
97
+ return None
98
+ if x.shape[-2:] == size_hw:
99
+ return x
100
+ mode = "nearest" if is_mask else "bilinear"
101
+ return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
102
+
103
+ def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
104
+ t = torch.as_tensor(alpha, device=device).float()
105
+ if t.ndim == 2:
106
+ t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
107
+ elif t.ndim == 3:
108
+ if t.shape[0] in (1, 3, 4):
109
+ if t.shape[0] != 1:
110
+ t = t[:1]
111
+ t = t.unsqueeze(0)
112
+ elif t.shape[-1] in (1, 3, 4):
113
+ t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
114
+ else:
115
+ t = t[:1].unsqueeze(0)
116
+ elif t.ndim == 4:
117
+ if t.shape[1] != 1:
118
+ t = t[:, :1]
119
+ if t.shape[0] != 1:
120
+ t = t[:1]
121
+ else:
122
+ while t.ndim > 4:
123
+ t = t.squeeze(0)
124
+ while t.ndim < 4:
125
+ t = t.unsqueeze(0)
126
+ if t.shape[1] != 1:
127
+ t = t[:, :1]
128
+ return t.clamp_(0.0, 1.0).contiguous()
129
+
130
+ def _to_2d_alpha_numpy(x) -> np.ndarray:
131
+ t = torch.as_tensor(x).float()
132
+ while t.ndim > 2:
133
+ if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
134
+ t = t[0, 0]
135
+ elif t.ndim == 3 and t.shape[0] == 1:
136
+ t = t[0]
137
+ else:
138
+ t = t.squeeze(0)
139
+ t = t.clamp_(0.0, 1.0)
140
+ out = t.detach().cpu().numpy().astype(np.float32)
141
+ return np.ascontiguousarray(out)
142
+
143
+ def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
144
+ if h <= 0 or w <= 0:
145
+ return h, w, 1.0
146
+ s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
147
+ s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
148
+ s = min(s1, s2)
149
+ nh = max(1, int(round(h * s)))
150
+ nw = max(1, int(round(w * s)))
151
+ return nh, nw, s
152
+
153
def debug_shapes(tag: str, image, mask) -> None:
    """Log shape/dtype/min/max for an (image, mask) pair under *tag*; never raises."""
    for name, v in (("image", image), ("mask", mask)):
        try:
            tv = torch.as_tensor(v)
            if tv.numel():
                mn, mx = float(tv.min()), float(tv.max())
            else:
                mn = mx = float("nan")
            logger.info(f"[{tag}:{name}] shape={tuple(tv.shape)} dtype={tv.dtype} min={mn:.4f} max={mx:.4f}")
        except Exception as e:
            logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # Precision selection
167
+ # ---------------------------------------------------------------------------
168
 
169
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
170
  """Pick model weight dtype + autocast dtype (bf16>fp16>fp32)."""
 
178
  if fp16_ok:
179
  return torch.float16, True, torch.float16
180
  return torch.float32, False, None
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Stateful Adapter around InferenceCore
184
+ # ---------------------------------------------------------------------------
185
+
186
  class _MatAnyoneSession:
187
  """
188
  Stateful controller around InferenceCore with OOM-resilient inference.
 
281
  if (th, tw) == (H, W):
282
  msk_in = mask_1hw
283
  else:
284
+ # nearest to keep binary-like edges
285
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
286
 
287
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
 
337
  if mask_1hw is not None:
338
  return _to_2d_alpha_numpy(mask_1hw)
339
  return np.full((H, W), 0.5, dtype=np.float32)
340
+
341
+ # ---------------------------------------------------------------------------
342
+ # Loader
343
+ # ---------------------------------------------------------------------------
344
+
345
  class MatAnyoneLoader:
346
  """
347
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
 
392
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
393
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
394
 
395
+ # HF weights (safetensors)
396
  self.model = model_cls.from_pretrained(self.model_id)
397
  try:
398
  self.model = self.model.to(self.device).to(model_dtype)
 
465
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
466
  except Exception as e:
467
  logger.info(f"[{tag}] debug error: {e}")
468
+
469
+ # ---------------------------------------------------------------------------
470
+ # Public symbols
471
+ # ---------------------------------------------------------------------------
472
+
473
+ __all__ = [
474
+ "MatAnyoneLoader",
475
+ "_MatAnyoneSession",
476
+ "_to_bchw",
477
+ "_resize_bchw",
478
+ "_to_chw_image",
479
+ "_to_1hw_mask",
480
+ "_to_b1hw_alpha",
481
+ "_to_2d_alpha_numpy",
482
+ "_compute_scaled_size",
483
+ "debug_shapes",
484
+ ]
485
+
486
+ # ---------------------------------------------------------------------------
487
+ # Optional CLI for quick testing (no circular imports)
488
+ # ---------------------------------------------------------------------------
489
+
490
  if __name__ == "__main__":
491
  import sys
492
+ import cv2 # only needed for this demo CLI
493
 
494
  logging.basicConfig(level=logging.INFO)
495
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
519
  print("Failed to load MatAnyone")
520
  raise SystemExit(3)
521
 
522
+ alpha = session(img_rgb, mask if mask is not None else np.ones(img_rgb.shape[:2], np.float32))
523
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
524
  print("Alpha matte written to alpha_out.png")