MogensR committed on
Commit
183c1c8
·
1 Parent(s): a41fc30

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +57 -79
models/loaders/matanyone_loader.py CHANGED
@@ -3,33 +3,30 @@
3
  """
4
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
5
  - Canonical HF load (MatAnyone.from_pretrained → InferenceCore(model, cfg))
6
- - Mixed precision (bf16/fp16) with safe fallback to fp32
7
  - torch.autocast(device_type="cuda", dtype=...) + torch.inference_mode()
8
  - Progressive downscale ladder with graceful fallback
9
  - Strict image↔mask alignment on every path/scale
10
  - Returns 2-D float32 [H,W] alpha (OpenCV-friendly)
 
 
 
11
  """
12
-
13
  from __future__ import annotations
14
-
15
  import os
16
  import time
17
  import logging
18
  import traceback
19
  from typing import Optional, Dict, Any, Tuple, List
20
-
21
  import numpy as np
22
  import torch
23
  import torch.nn.functional as F
24
  import inspect
25
  import threading
26
-
27
  logger = logging.getLogger(__name__)
28
-
29
  # ---------------------------------------------------------------------------
30
  # Utilities (shapes, dtype, scaling)
31
  # ---------------------------------------------------------------------------
32
-
33
  def _select_device(pref: str) -> str:
34
  pref = (pref or "").lower()
35
  if pref.startswith("cuda"):
@@ -37,12 +34,10 @@ def _select_device(pref: str) -> str:
37
  if pref == "cpu":
38
  return "cpu"
39
  return "cuda" if torch.cuda.is_available() else "cpu"
40
-
41
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
42
  if isinstance(x, torch.Tensor):
43
  return x.to(device, non_blocking=True)
44
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
45
-
46
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
47
  """
48
  Normalize input to BCHW (image) or B1HW (mask).
@@ -54,7 +49,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
54
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
55
  x = x.float()
56
  if x.ndim == 5:
57
- x = x[:, 0] # -> 4D
58
  if x.ndim == 4:
59
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
60
  x = x.permute(0, 3, 1, 2).contiguous()
@@ -77,21 +72,18 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
77
  x = x.repeat(1, 3, 1, 1)
78
  x = x.clamp_(0.0, 1.0)
79
  return x
80
-
81
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
82
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
83
  return img_bchw[0]
84
  return img_bchw
85
-
86
  def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
87
  if msk_b1hw is None:
88
  return None
89
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
90
- return msk_b1hw[0] # -> [1,H,W]
91
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
92
  return msk_b1hw
93
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
94
-
95
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
96
  if x is None:
97
  return None
@@ -99,11 +91,10 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: b
99
  return x
100
  mode = "nearest" if is_mask else "bilinear"
101
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
102
-
103
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
104
  t = torch.as_tensor(alpha, device=device).float()
105
  if t.ndim == 2:
106
- t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
107
  elif t.ndim == 3:
108
  if t.shape[0] in (1, 3, 4):
109
  if t.shape[0] != 1:
@@ -126,7 +117,6 @@ def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
126
  if t.shape[1] != 1:
127
  t = t[:, :1]
128
  return t.clamp_(0.0, 1.0).contiguous()
129
-
130
  def _to_2d_alpha_numpy(x) -> np.ndarray:
131
  t = torch.as_tensor(x).float()
132
  while t.ndim > 2:
@@ -139,17 +129,32 @@ def _to_2d_alpha_numpy(x) -> np.ndarray:
139
  t = t.clamp_(0.0, 1.0)
140
  out = t.detach().cpu().numpy().astype(np.float32)
141
  return np.ascontiguousarray(out)
142
-
143
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
144
  if h <= 0 or w <= 0:
145
  return h, w, 1.0
146
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
147
  s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
148
  s = min(s1, s2)
149
- nh = max(1, int(round(h * s)))
150
- nw = max(1, int(round(w * s)))
151
  return nh, nw, s
152
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def debug_shapes(tag: str, image, mask) -> None:
154
  def _info(name, v):
155
  try:
@@ -161,28 +166,24 @@ def _info(name, v):
161
  logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
162
  _info("image", image)
163
  _info("mask", mask)
164
-
165
  # ---------------------------------------------------------------------------
166
  # Precision selection
167
  # ---------------------------------------------------------------------------
168
-
169
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
170
- """Pick model weight dtype + autocast dtype (bf16>fp16>fp32)."""
171
  if device != "cuda":
172
  return torch.float32, False, None
173
- bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
174
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
175
  fp16_ok = cc[0] >= 7 # Volta+
 
 
 
176
  if bf16_ok:
177
  return torch.bfloat16, True, torch.bfloat16
178
- if fp16_ok:
179
- return torch.float16, True, torch.float16
180
  return torch.float32, False, None
181
-
182
  # ---------------------------------------------------------------------------
183
  # Stateful Adapter around InferenceCore
184
  # ---------------------------------------------------------------------------
185
-
186
  class _MatAnyoneSession:
187
  """
188
  Stateful controller around InferenceCore with OOM-resilient inference.
@@ -196,7 +197,7 @@ def __init__(
196
  use_autocast: bool,
197
  autocast_dtype: Optional[torch.dtype],
198
  max_edge: int = 768,
199
- target_pixels: int = 600_000, # ~775x775 by area
200
  ):
201
  self.core = core
202
  self.device = device
@@ -207,7 +208,6 @@ def __init__(
207
  self.target_pixels = int(target_pixels)
208
  self.started = False
209
  self._lock = threading.Lock()
210
-
211
  # Introspect optional args
212
  try:
213
  sig = inspect.signature(self.core.step)
@@ -215,7 +215,6 @@ def __init__(
215
  except Exception:
216
  self._has_first_frame_pred = True
217
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
218
-
219
  def reset(self):
220
  with self._lock:
221
  try:
@@ -224,7 +223,6 @@ def reset(self):
224
  except Exception:
225
  pass
226
  self.started = False
227
-
228
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
229
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
230
  sizes = [(nh, nw)]
@@ -237,7 +235,6 @@ def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
237
  if sizes[-1] != (cur_h, cur_w):
238
  sizes.append((cur_h, cur_w))
239
  return sizes
240
-
241
  def _to_alpha(self, out_prob):
242
  if self._has_prob_to_mask:
243
  try:
@@ -250,7 +247,6 @@ def _to_alpha(self, out_prob):
250
  if t.ndim == 3:
251
  return t[0] if t.shape[0] >= 1 else t.mean(0)
252
  return t
253
-
254
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
255
  """
256
  Returns a 2-D float32 alpha [H,W].
@@ -258,19 +254,16 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
258
  - frames 1..N: pass mask=None (propagation)
259
  """
260
  with self._lock:
261
- img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
262
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
263
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
264
-
265
  # Normalize + align provided mask (if any) to **B1HW** at full res
266
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
267
  if msk_b1hw is not None and msk_b1hw.shape[-2:] != (H, W):
268
  msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
269
- mask_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None # ← 1HW!
270
-
271
  sizes = self._scaled_ladder(H, W)
272
  last_exc = None
273
-
274
  for (th, tw) in sizes:
275
  try:
276
  img_in = img_bchw if (th, tw) == (H, W) else F.interpolate(
@@ -283,9 +276,12 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
283
  else:
284
  # nearest to keep binary-like edges
285
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
286
-
287
- img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
288
-
 
 
 
289
  with torch.inference_mode():
290
  if self.use_autocast:
291
  amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
@@ -294,7 +290,6 @@ class _NoOp:
294
  def __enter__(self): return None
295
  def __exit__(self, *a): return False
296
  amp_ctx = _NoOp()
297
-
298
  with amp_ctx:
299
  if not self.started:
300
  if msk_in is None:
@@ -310,17 +305,15 @@ def __exit__(self, *a): return False
310
  self.started = True
311
  else:
312
  out_prob = self.core.step(image=img_chw)
313
-
314
  alpha = self._to_alpha(out_prob)
315
-
 
316
  # Upsample alpha back if we ran at a smaller scale
317
  if (th, tw) != (H, W):
318
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
319
  a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
320
  alpha = a_b1hw[0, 0]
321
-
322
  return _to_2d_alpha_numpy(alpha)
323
-
324
  except torch.cuda.OutOfMemoryError as e:
325
  last_exc = e
326
  torch.cuda.empty_cache()
@@ -332,16 +325,13 @@ def __exit__(self, *a): return False
332
  logger.debug(traceback.format_exc())
333
  logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
334
  continue
335
-
336
  logger.warning(f"MatAnyone calls failed; returning input mask or neutral alpha. {last_exc}")
337
  if mask_1hw is not None:
338
  return _to_2d_alpha_numpy(mask_1hw)
339
  return np.full((H, W), 0.5, dtype=np.float32)
340
-
341
  # ---------------------------------------------------------------------------
342
  # Loader
343
  # ---------------------------------------------------------------------------
344
-
345
  class MatAnyoneLoader:
346
  """
347
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
@@ -355,7 +345,6 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
355
  self.adapter = None
356
  self.model_id = "PeiqingYang/MatAnyone"
357
  self.load_time = 0.0
358
-
359
  # --- Robust imports (works with different packaging layouts) ---
360
  def _import_model_and_core(self):
361
  model_cls = core_cls = None
@@ -379,11 +368,10 @@ def _import_model_and_core(self):
379
  core_cls = getattr(m, cls)
380
  break
381
  except Exception as e:
382
- err_msgs.append(f"core {mod}.{cls}: {e}")
383
  if model_cls is None or core_cls is None:
384
  raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
385
  return model_cls, core_cls
386
-
387
  def load(self) -> Optional[Any]:
388
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
389
  t0 = time.time()
@@ -391,7 +379,6 @@ def load(self) -> Optional[Any]:
391
  model_cls, core_cls = self._import_model_and_core()
392
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
393
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
394
-
395
  # HF weights (safetensors)
396
  self.model = model_cls.from_pretrained(self.model_id)
397
  try:
@@ -399,21 +386,27 @@ def load(self) -> Optional[Any]:
399
  except Exception:
400
  self.model = self.model.to(self.device)
401
  self.model.eval()
402
-
403
- # Inference core (cfg may or may not exist on the model)
 
 
 
 
 
 
 
 
 
404
  try:
405
- cfg = getattr(self.model, "cfg", None)
406
- self.core = core_cls(self.model, cfg=cfg) if cfg is not None else core_cls(self.model)
407
  except TypeError:
408
  self.core = core_cls(self.model)
409
-
410
  # Some versions expose .to(), some don’t — best effort
411
  try:
412
  if hasattr(self.core, "to"):
413
  self.core.to(self.device)
414
  except Exception:
415
  pass
416
-
417
  # Build stateful adapter
418
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
419
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
@@ -429,12 +422,10 @@ def load(self) -> Optional[Any]:
429
  self.load_time = time.time() - t0
430
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
431
  return self.adapter
432
-
433
  except Exception as e:
434
  logger.error(f"Failed to load MatAnyone: {e}")
435
  logger.debug(traceback.format_exc())
436
  return None
437
-
438
  def cleanup(self):
439
  self.adapter = None
440
  self.core = None
@@ -446,7 +437,6 @@ def cleanup(self):
446
  self.model = None
447
  if torch.cuda.is_available():
448
  torch.cuda.empty_cache()
449
-
450
  def get_info(self) -> Dict[str, Any]:
451
  return {
452
  "loaded": self.adapter is not None,
@@ -455,7 +445,6 @@ def get_info(self) -> Dict[str, Any]:
455
  "load_time": self.load_time,
456
  "model_type": type(self.model).__name__ if self.model else None,
457
  }
458
-
459
  def debug_shapes(self, image, mask, tag: str = ""):
460
  try:
461
  tv_img = torch.as_tensor(image)
@@ -465,11 +454,9 @@ def debug_shapes(self, image, mask, tag: str = ""):
465
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
466
  except Exception as e:
467
  logger.info(f"[{tag}] debug error: {e}")
468
-
469
  # ---------------------------------------------------------------------------
470
  # Public symbols
471
  # ---------------------------------------------------------------------------
472
-
473
  __all__ = [
474
  "MatAnyoneLoader",
475
  "_MatAnyoneSession",
@@ -482,43 +469,34 @@ def debug_shapes(self, image, mask, tag: str = ""):
482
  "_compute_scaled_size",
483
  "debug_shapes",
484
  ]
485
-
486
  # ---------------------------------------------------------------------------
487
  # Optional CLI for quick testing (no circular imports)
488
  # ---------------------------------------------------------------------------
489
-
490
  if __name__ == "__main__":
491
  import sys
492
- import cv2 # only needed for this demo CLI
493
-
494
  logging.basicConfig(level=logging.INFO)
495
  device = "cuda" if torch.cuda.is_available() else "cpu"
496
-
497
  if len(sys.argv) < 2:
498
  print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
499
  raise SystemExit(1)
500
-
501
  image_path = sys.argv[1]
502
- mask_path = sys.argv[2] if len(sys.argv) > 2 else None
503
-
504
  img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
505
  if img_bgr is None:
506
  print(f"Could not load image {image_path}")
507
  raise SystemExit(2)
508
  img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
509
-
510
  mask = None
511
  if mask_path:
512
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
513
  if mask is not None and mask.max() > 1:
514
  mask = (mask.astype(np.float32) / 255.0)
515
-
516
  loader = MatAnyoneLoader(device=device)
517
  session = loader.load()
518
  if not session:
519
  print("Failed to load MatAnyone")
520
  raise SystemExit(3)
521
-
522
  alpha = session(img_rgb, mask if mask is not None else np.ones(img_rgb.shape[:2], np.float32))
523
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
524
- print("Alpha matte written to alpha_out.png")
 
3
  """
4
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
5
  - Canonical HF load (MatAnyone.from_pretrained → InferenceCore(model, cfg))
6
+ - Mixed precision (fp16 preferred over bf16) with safe fallback to fp32
7
  - torch.autocast(device_type="cuda", dtype=...) + torch.inference_mode()
8
  - Progressive downscale ladder with graceful fallback
9
  - Strict image↔mask alignment on every path/scale
10
  - Returns 2-D float32 [H,W] alpha (OpenCV-friendly)
11
+ - Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
12
+ - Added: Pad to multiple of 16 to avoid transformer patch issues
13
+ - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
14
  """
 
15
  from __future__ import annotations
 
16
  import os
17
  import time
18
  import logging
19
  import traceback
20
  from typing import Optional, Dict, Any, Tuple, List
 
21
  import numpy as np
22
  import torch
23
  import torch.nn.functional as F
24
  import inspect
25
  import threading
 
26
  logger = logging.getLogger(__name__)
 
27
  # ---------------------------------------------------------------------------
28
  # Utilities (shapes, dtype, scaling)
29
  # ---------------------------------------------------------------------------
 
30
  def _select_device(pref: str) -> str:
31
  pref = (pref or "").lower()
32
  if pref.startswith("cuda"):
 
34
  if pref == "cpu":
35
  return "cpu"
36
  return "cuda" if torch.cuda.is_available() else "cpu"
 
37
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
38
  if isinstance(x, torch.Tensor):
39
  return x.to(device, non_blocking=True)
40
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
 
41
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  """
43
  Normalize input to BCHW (image) or B1HW (mask).
 
49
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
50
  x = x.float()
51
  if x.ndim == 5:
52
+ x = x[:, 0] # -> 4D
53
  if x.ndim == 4:
54
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
55
  x = x.permute(0, 3, 1, 2).contiguous()
 
72
  x = x.repeat(1, 3, 1, 1)
73
  x = x.clamp_(0.0, 1.0)
74
  return x
 
75
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
76
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
77
  return img_bchw[0]
78
  return img_bchw
 
79
  def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
80
  if msk_b1hw is None:
81
  return None
82
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
83
+ return msk_b1hw[0] # -> [1,H,W]
84
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
85
  return msk_b1hw
86
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
 
87
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
88
  if x is None:
89
  return None
 
91
  return x
92
  mode = "nearest" if is_mask else "bilinear"
93
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
 
94
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
95
  t = torch.as_tensor(alpha, device=device).float()
96
  if t.ndim == 2:
97
+ t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
98
  elif t.ndim == 3:
99
  if t.shape[0] in (1, 3, 4):
100
  if t.shape[0] != 1:
 
117
  if t.shape[1] != 1:
118
  t = t[:, :1]
119
  return t.clamp_(0.0, 1.0).contiguous()
 
120
  def _to_2d_alpha_numpy(x) -> np.ndarray:
121
  t = torch.as_tensor(x).float()
122
  while t.ndim > 2:
 
129
  t = t.clamp_(0.0, 1.0)
130
  out = t.detach().cpu().numpy().astype(np.float32)
131
  return np.ascontiguousarray(out)
 
132
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
133
  if h <= 0 or w <= 0:
134
  return h, w, 1.0
135
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
136
  s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
137
  s = min(s1, s2)
138
+ nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
139
+ nw = max(128, int(round(w * s)))
140
  return nh, nw, s
141
+ def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
142
+ if t is None:
143
+ return None
144
+ if t.ndim == 3:
145
+ c, h, w = t.shape
146
+ elif t.ndim == 2:
147
+ h, w = t.shape
148
+ t = t.unsqueeze(0) # Temp to 3D for padding
149
+ else:
150
+ raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
151
+ pad_h = (multiple - h % multiple) % multiple
152
+ pad_w = (multiple - w % multiple) % multiple
153
+ if pad_h or pad_w:
154
+ t = F.pad(t, (0, pad_w, 0, pad_h))
155
+ if t.ndim == 2: # Shouldn't happen
156
+ t = t.squeeze(0)
157
+ return t
158
  def debug_shapes(tag: str, image, mask) -> None:
159
  def _info(name, v):
160
  try:
 
166
  logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
167
  _info("image", image)
168
  _info("mask", mask)
 
169
  # ---------------------------------------------------------------------------
170
  # Precision selection
171
  # ---------------------------------------------------------------------------
 
172
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
173
+ """Pick model weight dtype + autocast dtype (fp16>bf16>fp32) for T4 compatibility."""
174
  if device != "cuda":
175
  return torch.float32, False, None
 
176
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
177
  fp16_ok = cc[0] >= 7 # Volta+
178
+ bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
179
+ if fp16_ok:
180
+ return torch.float16, True, torch.float16 # Prefer fp16 for T4
181
  if bf16_ok:
182
  return torch.bfloat16, True, torch.bfloat16
 
 
183
  return torch.float32, False, None
 
184
  # ---------------------------------------------------------------------------
185
  # Stateful Adapter around InferenceCore
186
  # ---------------------------------------------------------------------------
 
187
  class _MatAnyoneSession:
188
  """
189
  Stateful controller around InferenceCore with OOM-resilient inference.
 
197
  use_autocast: bool,
198
  autocast_dtype: Optional[torch.dtype],
199
  max_edge: int = 768,
200
+ target_pixels: int = 600_000, # ~775x775 by area
201
  ):
202
  self.core = core
203
  self.device = device
 
208
  self.target_pixels = int(target_pixels)
209
  self.started = False
210
  self._lock = threading.Lock()
 
211
  # Introspect optional args
212
  try:
213
  sig = inspect.signature(self.core.step)
 
215
  except Exception:
216
  self._has_first_frame_pred = True
217
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
 
218
  def reset(self):
219
  with self._lock:
220
  try:
 
223
  except Exception:
224
  pass
225
  self.started = False
 
226
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
227
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
228
  sizes = [(nh, nw)]
 
235
  if sizes[-1] != (cur_h, cur_w):
236
  sizes.append((cur_h, cur_w))
237
  return sizes
 
238
  def _to_alpha(self, out_prob):
239
  if self._has_prob_to_mask:
240
  try:
 
247
  if t.ndim == 3:
248
  return t[0] if t.shape[0] >= 1 else t.mean(0)
249
  return t
 
250
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
251
  """
252
  Returns a 2-D float32 alpha [H,W].
 
254
  - frames 1..N: pass mask=None (propagation)
255
  """
256
  with self._lock:
257
+ img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
258
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
259
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
 
260
  # Normalize + align provided mask (if any) to **B1HW** at full res
261
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
262
  if msk_b1hw is not None and msk_b1hw.shape[-2:] != (H, W):
263
  msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
264
+ mask_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None # ← 1HW!
 
265
  sizes = self._scaled_ladder(H, W)
266
  last_exc = None
 
267
  for (th, tw) in sizes:
268
  try:
269
  img_in = img_bchw if (th, tw) == (H, W) else F.interpolate(
 
276
  else:
277
  # nearest to keep binary-like edges
278
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
279
+ img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
280
+ # Pad to multiple of 16
281
+ img_chw = _pad_to_multiple(img_chw)
282
+ if msk_in is not None:
283
+ msk_in = _pad_to_multiple(msk_in)
284
+ ph, pw = img_chw.shape[-2:]
285
  with torch.inference_mode():
286
  if self.use_autocast:
287
  amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
 
290
  def __enter__(self): return None
291
  def __exit__(self, *a): return False
292
  amp_ctx = _NoOp()
 
293
  with amp_ctx:
294
  if not self.started:
295
  if msk_in is None:
 
305
  self.started = True
306
  else:
307
  out_prob = self.core.step(image=img_chw)
 
308
  alpha = self._to_alpha(out_prob)
309
+ # Unpad to scaled size, then upsample if needed
310
+ alpha = alpha[:th, :tw]
311
  # Upsample alpha back if we ran at a smaller scale
312
  if (th, tw) != (H, W):
313
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
314
  a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
315
  alpha = a_b1hw[0, 0]
 
316
  return _to_2d_alpha_numpy(alpha)
 
317
  except torch.cuda.OutOfMemoryError as e:
318
  last_exc = e
319
  torch.cuda.empty_cache()
 
325
  logger.debug(traceback.format_exc())
326
  logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
327
  continue
 
328
  logger.warning(f"MatAnyone calls failed; returning input mask or neutral alpha. {last_exc}")
329
  if mask_1hw is not None:
330
  return _to_2d_alpha_numpy(mask_1hw)
331
  return np.full((H, W), 0.5, dtype=np.float32)
 
332
  # ---------------------------------------------------------------------------
333
  # Loader
334
  # ---------------------------------------------------------------------------
 
335
  class MatAnyoneLoader:
336
  """
337
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
 
345
  self.adapter = None
346
  self.model_id = "PeiqingYang/MatAnyone"
347
  self.load_time = 0.0
 
348
  # --- Robust imports (works with different packaging layouts) ---
349
  def _import_model_and_core(self):
350
  model_cls = core_cls = None
 
368
  core_cls = getattr(m, cls)
369
  break
370
  except Exception as e:
371
+ err_msgs.append(f"core {mod}.{cls}: {e}")
372
  if model_cls is None or core_cls is None:
373
  raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
374
  return model_cls, core_cls
 
375
  def load(self) -> Optional[Any]:
376
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
377
  t0 = time.time()
 
379
  model_cls, core_cls = self._import_model_and_core()
380
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
381
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
 
382
  # HF weights (safetensors)
383
  self.model = model_cls.from_pretrained(self.model_id)
384
  try:
 
386
  except Exception:
387
  self.model = self.model.to(self.device)
388
  self.model.eval()
389
+ # Override cfg to disable features causing dim mismatches
390
+ default_cfg = {
391
+ 'chunk_size': 1,
392
+ 'flip_aug': False,
393
+ }
394
+ cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
395
+ if isinstance(cfg, dict):
396
+ cfg.update(default_cfg) # Override
397
+ else:
398
+ cfg = default_cfg
399
+ # Inference core
400
  try:
401
+ self.core = core_cls(self.model, cfg=cfg)
 
402
  except TypeError:
403
  self.core = core_cls(self.model)
 
404
  # Some versions expose .to(), some don’t — best effort
405
  try:
406
  if hasattr(self.core, "to"):
407
  self.core.to(self.device)
408
  except Exception:
409
  pass
 
410
  # Build stateful adapter
411
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
412
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
 
422
  self.load_time = time.time() - t0
423
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
424
  return self.adapter
 
425
  except Exception as e:
426
  logger.error(f"Failed to load MatAnyone: {e}")
427
  logger.debug(traceback.format_exc())
428
  return None
 
429
  def cleanup(self):
430
  self.adapter = None
431
  self.core = None
 
437
  self.model = None
438
  if torch.cuda.is_available():
439
  torch.cuda.empty_cache()
 
440
  def get_info(self) -> Dict[str, Any]:
441
  return {
442
  "loaded": self.adapter is not None,
 
445
  "load_time": self.load_time,
446
  "model_type": type(self.model).__name__ if self.model else None,
447
  }
 
448
  def debug_shapes(self, image, mask, tag: str = ""):
449
  try:
450
  tv_img = torch.as_tensor(image)
 
454
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
455
  except Exception as e:
456
  logger.info(f"[{tag}] debug error: {e}")
 
457
  # ---------------------------------------------------------------------------
458
  # Public symbols
459
  # ---------------------------------------------------------------------------
 
460
  __all__ = [
461
  "MatAnyoneLoader",
462
  "_MatAnyoneSession",
 
469
  "_compute_scaled_size",
470
  "debug_shapes",
471
  ]
 
472
  # ---------------------------------------------------------------------------
473
  # Optional CLI for quick testing (no circular imports)
474
  # ---------------------------------------------------------------------------
 
475
  if __name__ == "__main__":
476
  import sys
477
+ import cv2 # only needed for this demo CLI
 
478
  logging.basicConfig(level=logging.INFO)
479
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
480
  if len(sys.argv) < 2:
481
  print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
482
  raise SystemExit(1)
 
483
  image_path = sys.argv[1]
484
+ mask_path = sys.argv[2] if len(sys.argv) > 2 else None
 
485
  img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
486
  if img_bgr is None:
487
  print(f"Could not load image {image_path}")
488
  raise SystemExit(2)
489
  img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
 
490
  mask = None
491
  if mask_path:
492
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
493
  if mask is not None and mask.max() > 1:
494
  mask = (mask.astype(np.float32) / 255.0)
 
495
  loader = MatAnyoneLoader(device=device)
496
  session = loader.load()
497
  if not session:
498
  print("Failed to load MatAnyone")
499
  raise SystemExit(3)
 
500
  alpha = session(img_rgb, mask if mask is not None else np.ones(img_rgb.shape[:2], np.float32))
501
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
502
+ print("Alpha matte written to alpha_out.png")