MogensR commited on
Commit
874e937
·
1 Parent(s): 28e0f6c

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +75 -103
models/loaders/matanyone_loader.py CHANGED
@@ -1,10 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- MatAnyone Loader + Stateful Adapter (OOM-resilient)
4
  - Canonical HF load (MatAnyone.from_pretrained -> InferenceCore(model, cfg))
5
  - Mixed precision (bf16/fp16) with safe fallback to fp32
6
  - Autocast + inference_mode around every call
7
  - Auto downscale with progressive retry on OOM, then upsample alpha back
 
8
  - Returns 2-D float32 [H,W] alpha for OpenCV
9
  """
10
 
@@ -42,38 +43,25 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  Accepts: HWC, CHW, BCHW, BHWC, BTCHW/BTHWC, TCHW/THWC, HW.
43
  """
44
  x = _as_tensor_on_device(x, device)
45
-
46
- # dtype / range
47
  if x.dtype == torch.uint8:
48
  x = x.float().div_(255.0)
49
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
50
  x = x.float()
51
-
52
- # 5D -> take first time slice
53
  if x.ndim == 5:
54
  x = x[:, 0] # -> 4D
55
-
56
- # 4D: BHWC -> BCHW
57
  if x.ndim == 4:
58
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
59
  x = x.permute(0, 3, 1, 2).contiguous()
60
-
61
- # 3D: HWC -> CHW; add batch
62
  elif x.ndim == 3:
63
  if x.shape[-1] in (1, 3, 4):
64
  x = x.permute(2, 0, 1).contiguous()
65
  x = x.unsqueeze(0)
66
-
67
- # 2D: add channel & batch
68
  elif x.ndim == 2:
69
  x = x.unsqueeze(0).unsqueeze(0)
70
  if not is_mask:
71
  x = x.repeat(1, 3, 1, 1)
72
-
73
  else:
74
  raise ValueError(f"Unsupported ndim={x.ndim}")
75
-
76
- # finalize channels
77
  if is_mask:
78
  if x.shape[1] > 1:
79
  x = x[:, :1]
@@ -82,7 +70,6 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
82
  if x.shape[1] == 1:
83
  x = x.repeat(1, 3, 1, 1)
84
  x = x.clamp_(0.0, 1.0)
85
-
86
  return x
87
 
88
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
@@ -108,32 +95,24 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask=Fa
108
  return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
109
 
110
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
111
- """
112
- Convert any plausible alpha/prob output into [1,1,H,W] float in [0,1].
113
- Prevents 5D/6D mishaps when upsampling.
114
- """
115
  t = torch.as_tensor(alpha, device=device).float()
116
  if t.ndim == 2:
117
  t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
118
  elif t.ndim == 3:
119
- # CHW or 1HW
120
  if t.shape[0] in (1, 3, 4):
121
  if t.shape[0] != 1:
122
- t = t[:1] # keep first channel
123
- t = t.unsqueeze(0) # -> [1,1,H,W]
124
- elif t.shape[-1] in (1, 3, 4): # HWC (unexpected, but handle)
125
  t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
126
  else:
127
- # assume [H,W,C?] incompatible → fallback to first dim semantics
128
  t = t[:1].unsqueeze(0)
129
  elif t.ndim == 4:
130
- # [B,C,H,W] → ensure C=1 and B=1
131
  if t.shape[1] != 1:
132
  t = t[:, :1]
133
  if t.shape[0] != 1:
134
  t = t[:1]
135
  else:
136
- # squeeze weird shapes down to [1,1,H,W] best-effort
137
  while t.ndim > 4:
138
  t = t.squeeze(0)
139
  while t.ndim < 4:
@@ -213,7 +192,6 @@ def reset(self):
213
  pass
214
  self.started = False
215
 
216
- # ---- helpers ----
217
  def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
218
  if h <= 0 or w <= 0:
219
  return h, w, 1.0
@@ -225,48 +203,41 @@ def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
225
  return nh, nw, s
226
 
227
  def _to_alpha(self, out_prob):
228
- # Prefer library conversion if available
229
  if self._has_prob_to_mask:
230
  try:
231
  return self.core.output_prob_to_mask(out_prob, matting=True)
232
  except Exception:
233
  pass
234
  t = torch.as_tensor(out_prob).float()
235
- # Normalize common cases to 2-D alpha
236
- if t.ndim == 4: # [B,C,H,W]
237
  c = 0 if t.shape[1] > 0 else None
238
  b = 0 if t.shape[0] > 0 else None
239
  if b is not None and c is not None:
240
  return t[b, c]
241
- if t.ndim == 3: # [C,H,W]
242
  return t[0] if t.shape[0] >= 1 else t.mean(0)
243
- return t # already 2-D or degenerate -> let caller sanitize
244
-
245
- # ---- main call ----
246
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
247
  """
248
  Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
249
  Subsequent calls propagate without a mask.
250
  """
251
- # Boundary normalization
252
  img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
253
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
254
 
255
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
256
- if msk_b1hw is not None:
257
- msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
 
 
258
 
259
- # dtype alignment for activations
260
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
261
-
262
- # build a deeper downscale ladder to survive tight VRAM
263
  nh, nw, s = self._compute_scaled_size(H, W)
264
  scales = [(nh, nw)]
265
- # add progressive reductions until fairly small, but not tiny
266
  if s < 1.0:
267
  f = 0.85
268
  cur_h, cur_w = nh, nw
269
- for _ in range(6): # up to 8 attempts total
270
  cur_h = max(128, int(cur_h * f))
271
  cur_w = max(128, int(cur_w * f))
272
  if (cur_h, cur_w) != scales[-1]:
@@ -278,15 +249,17 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
278
 
279
  for (th, tw) in scales:
280
  try:
281
- # downscale for inference if needed
282
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
283
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
 
 
 
 
284
 
285
- img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
286
- m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
287
- mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None# [H,W] or None
288
 
289
- # inference with autocast + inference_mode
290
  with torch.inference_mode():
291
  if self.use_autocast:
292
  amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
@@ -295,17 +268,12 @@ class _NoOp:
295
  def __enter__(self): return None
296
  def __exit__(self, *args): return False
297
  amp_ctx = _NoOp()
298
-
299
  with amp_ctx:
300
  if not self.started:
301
  if mask_2d is None:
302
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
303
  return np.full((H, W), 0.5, dtype=np.float32)
304
-
305
- # encode/memorize — pass 2-D mask (H,W)
306
  _ = self.core.step(image=img_chw, mask=mask_2d)
307
-
308
- # warm-up predict
309
  if self._has_first_frame_pred:
310
  out_prob = self.core.step(image=img_chw, first_frame_pred=True)
311
  else:
@@ -316,13 +284,10 @@ def __exit__(self, *args): return False
316
  out_prob = self.core.step(image=img_chw)
317
  alpha = self._to_alpha(out_prob)
318
 
319
- # ---- SAFE UPSAMPLE PATH (always 4D -> 2D) ----
320
  if (th, tw) != (H, W):
321
- a_b1hw = _to_b1hw_alpha(alpha, device=img_chw.device) # [1,1,th,tw]
322
- a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False) # [1,1,H,W]
323
- alpha = a_b1hw[0, 0] # -> [H,W]
324
- # ------------------------------------------------
325
-
326
  return _to_2d_alpha_numpy(alpha)
327
 
328
  except torch.cuda.OutOfMemoryError as e:
@@ -337,7 +302,6 @@ def __exit__(self, *args): return False
337
  torch.cuda.empty_cache()
338
  continue
339
 
340
- # All attempts failed → return fallback
341
  logger.warning(f"MatAnyone calls failed; returning input mask as fallback. {last_exc}")
342
  if msk_b1hw is not None:
343
  return _to_2d_alpha_numpy(msk_b1hw)
@@ -346,51 +310,34 @@ def __exit__(self, *args): return False
346
  # -------------------------------- Loader ---------------------------------- #
347
 
348
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
349
- """
350
- Decide model+autocast dtypes.
351
- Strategy:
352
- - Prefer bf16 autocast if supported (Ampere+), keep weights bf16 if possible.
353
- - Else use fp16 autocast, keep weights fp16 if safe.
354
- - Else fp32 without autocast.
355
- """
356
  if device != "cuda":
357
  return torch.float32, False, None
358
-
359
  bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
360
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
361
  fp16_ok = cc[0] >= 7 # Volta+
362
-
363
  if bf16_ok:
364
  return torch.bfloat16, True, torch.bfloat16
365
  if fp16_ok:
366
  return torch.float16, True, torch.float16
367
  return torch.float32, False, None
368
 
369
-
370
  class MatAnyoneLoader:
371
  """
372
  Official MatAnyone loader with stateful, OOM-resilient adapter.
373
  """
374
-
375
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
376
  self.device = _select_device(device)
377
  self.cache_dir = cache_dir
378
  os.makedirs(self.cache_dir, exist_ok=True)
379
-
380
- self.model = None # torch.nn.Module (MatAnyone)
381
- self.core = None # InferenceCore
382
- self.adapter = None # _MatAnyoneSession
383
  self.model_id = "PeiqingYang/MatAnyone"
384
  self.load_time = 0.0
385
 
386
  def _import_model_and_core(self):
387
- """
388
- Import MatAnyone + InferenceCore with resilient fallbacks (different dist layouts).
389
- """
390
  model_cls = core_cls = None
391
  err_msgs = []
392
-
393
- # Candidates for model class
394
  for mod, cls in [
395
  ("matanyone.model.matanyone", "MatAnyone"),
396
  ("matanyone", "MatAnyone"),
@@ -401,8 +348,6 @@ def _import_model_and_core(self):
401
  break
402
  except Exception as e:
403
  err_msgs.append(f"model {mod}.{cls}: {e}")
404
-
405
- # Candidates for InferenceCore
406
  for mod, cls in [
407
  ("matanyone.inference.inference_core", "InferenceCore"),
408
  ("matanyone", "InferenceCore"),
@@ -413,39 +358,24 @@ def _import_model_and_core(self):
413
  break
414
  except Exception as e:
415
  err_msgs.append(f"core {mod}.{cls}: {e}")
416
-
417
  if model_cls is None or core_cls is None:
418
  msg = " | ".join(err_msgs)
419
  raise ImportError(f"Could not import MatAnyone/InferenceCore: {msg}")
420
-
421
  return model_cls, core_cls
422
 
423
  def load(self) -> Optional[Any]:
424
- """
425
- Load MatAnyone and return the stateful callable adapter.
426
- """
427
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
428
  start = time.time()
429
  try:
430
  model_cls, core_cls = self._import_model_and_core()
431
-
432
- # pick precision strategy
433
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
434
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
435
-
436
- # Official pattern: model -> eval -> core(model, cfg=model.cfg)
437
  self.model = model_cls.from_pretrained(self.model_id)
438
-
439
- # Try to move weights to selected dtype (safe try)
440
  try:
441
  self.model = self.model.to(self.device).to(model_dtype)
442
  except Exception:
443
  self.model = self.model.to(self.device)
444
- # keep weights fp32; still benefit from autocast
445
-
446
  self.model.eval()
447
-
448
- # Some builds require cfg; fall back if not present
449
  try:
450
  cfg = getattr(self.model, "cfg", None)
451
  if cfg is not None:
@@ -454,17 +384,13 @@ def load(self) -> Optional[Any]:
454
  self.core = core_cls(self.model)
455
  except TypeError:
456
  self.core = core_cls(self.model)
457
-
458
  try:
459
  if hasattr(self.core, "to"):
460
  self.core.to(self.device)
461
  except Exception:
462
  pass
463
-
464
- # tune scaling from env (optional)
465
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
466
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
467
-
468
  self.adapter = _MatAnyoneSession(
469
  self.core,
470
  device=self.device,
@@ -474,11 +400,9 @@ def load(self) -> Optional[Any]:
474
  max_edge=max_edge,
475
  target_pixels=target_pixels,
476
  )
477
-
478
  self.load_time = time.time() - start
479
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
480
  return self.adapter
481
-
482
  except Exception as e:
483
  logger.error(f"Failed to load MatAnyone: {e}")
484
  logger.debug(traceback.format_exc())
@@ -505,6 +429,54 @@ def get_info(self) -> Dict[str, Any]:
505
  "model_type": type(self.model).__name__ if self.model else None,
506
  }
507
 
508
- # Optional: instance-level shape debugging
509
  def debug_shapes(self, image, mask, tag: str = ""):
510
  debug_shapes(tag, image, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
4
  - Canonical HF load (MatAnyone.from_pretrained -> InferenceCore(model, cfg))
5
  - Mixed precision (bf16/fp16) with safe fallback to fp32
6
  - Autocast + inference_mode around every call
7
  - Auto downscale with progressive retry on OOM, then upsample alpha back
8
+ - Always aligns mask/image dimensions before inference to avoid all size errors
9
  - Returns 2-D float32 [H,W] alpha for OpenCV
10
  """
11
 
 
43
  Accepts: HWC, CHW, BCHW, BHWC, BTCHW/BTHWC, TCHW/THWC, HW.
44
  """
45
  x = _as_tensor_on_device(x, device)
 
 
46
  if x.dtype == torch.uint8:
47
  x = x.float().div_(255.0)
48
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
49
  x = x.float()
 
 
50
  if x.ndim == 5:
51
  x = x[:, 0] # -> 4D
 
 
52
  if x.ndim == 4:
53
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
54
  x = x.permute(0, 3, 1, 2).contiguous()
 
 
55
  elif x.ndim == 3:
56
  if x.shape[-1] in (1, 3, 4):
57
  x = x.permute(2, 0, 1).contiguous()
58
  x = x.unsqueeze(0)
 
 
59
  elif x.ndim == 2:
60
  x = x.unsqueeze(0).unsqueeze(0)
61
  if not is_mask:
62
  x = x.repeat(1, 3, 1, 1)
 
63
  else:
64
  raise ValueError(f"Unsupported ndim={x.ndim}")
 
 
65
  if is_mask:
66
  if x.shape[1] > 1:
67
  x = x[:, :1]
 
70
  if x.shape[1] == 1:
71
  x = x.repeat(1, 3, 1, 1)
72
  x = x.clamp_(0.0, 1.0)
 
73
  return x
74
 
75
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
 
95
  return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
96
 
97
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
 
 
 
 
98
  t = torch.as_tensor(alpha, device=device).float()
99
  if t.ndim == 2:
100
  t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
101
  elif t.ndim == 3:
 
102
  if t.shape[0] in (1, 3, 4):
103
  if t.shape[0] != 1:
104
+ t = t[:1]
105
+ t = t.unsqueeze(0)
106
+ elif t.shape[-1] in (1, 3, 4):
107
  t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
108
  else:
 
109
  t = t[:1].unsqueeze(0)
110
  elif t.ndim == 4:
 
111
  if t.shape[1] != 1:
112
  t = t[:, :1]
113
  if t.shape[0] != 1:
114
  t = t[:1]
115
  else:
 
116
  while t.ndim > 4:
117
  t = t.squeeze(0)
118
  while t.ndim < 4:
 
192
  pass
193
  self.started = False
194
 
 
195
  def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
196
  if h <= 0 or w <= 0:
197
  return h, w, 1.0
 
203
  return nh, nw, s
204
 
205
  def _to_alpha(self, out_prob):
 
206
  if self._has_prob_to_mask:
207
  try:
208
  return self.core.output_prob_to_mask(out_prob, matting=True)
209
  except Exception:
210
  pass
211
  t = torch.as_tensor(out_prob).float()
212
+ if t.ndim == 4:
 
213
  c = 0 if t.shape[1] > 0 else None
214
  b = 0 if t.shape[0] > 0 else None
215
  if b is not None and c is not None:
216
  return t[b, c]
217
+ if t.ndim == 3:
218
  return t[0] if t.shape[0] >= 1 else t.mean(0)
219
+ return t
 
 
220
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
221
  """
222
  Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
223
  Subsequent calls propagate without a mask.
224
  """
 
225
  img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
226
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
227
 
228
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
229
+ # --- Guarantee same shape for mask/image at input resolution ---
230
+ if msk_b1hw is not None and img_bchw.shape[-2:] != msk_b1hw.shape[-2:]:
231
+ logger.warning(f"Fixing mask shape: {msk_b1hw.shape[-2:]} → {img_bchw.shape[-2:]}")
232
+ msk_b1hw = _resize_bchw(msk_b1hw, img_bchw.shape[-2:], is_mask=True)
233
 
 
234
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
 
 
235
  nh, nw, s = self._compute_scaled_size(H, W)
236
  scales = [(nh, nw)]
 
237
  if s < 1.0:
238
  f = 0.85
239
  cur_h, cur_w = nh, nw
240
+ for _ in range(6):
241
  cur_h = max(128, int(cur_h * f))
242
  cur_w = max(128, int(cur_w * f))
243
  if (cur_h, cur_w) != scales[-1]:
 
249
 
250
  for (th, tw) in scales:
251
  try:
 
252
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
253
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
254
+ # --- Guarantee same shape for mask/image at each retry scale ---
255
+ if msk_in is not None and img_in.shape[-2:] != msk_in.shape[-2:]:
256
+ logger.warning(f"Progressive retry: resizing mask from {msk_in.shape[-2:]} to {img_in.shape[-2:]}")
257
+ msk_in = _resize_bchw(msk_in, img_in.shape[-2:], is_mask=True)
258
 
259
+ img_chw = _to_chw_image(img_in).contiguous()
260
+ m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None
261
+ mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None
262
 
 
263
  with torch.inference_mode():
264
  if self.use_autocast:
265
  amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
 
268
  def __enter__(self): return None
269
  def __exit__(self, *args): return False
270
  amp_ctx = _NoOp()
 
271
  with amp_ctx:
272
  if not self.started:
273
  if mask_2d is None:
274
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
275
  return np.full((H, W), 0.5, dtype=np.float32)
 
 
276
  _ = self.core.step(image=img_chw, mask=mask_2d)
 
 
277
  if self._has_first_frame_pred:
278
  out_prob = self.core.step(image=img_chw, first_frame_pred=True)
279
  else:
 
284
  out_prob = self.core.step(image=img_chw)
285
  alpha = self._to_alpha(out_prob)
286
 
 
287
  if (th, tw) != (H, W):
288
+ a_b1hw = _to_b1hw_alpha(alpha, device=img_chw.device)
289
+ a_b1hw = torch.nn.functional.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
290
+ alpha = a_b1hw[0, 0]
 
 
291
  return _to_2d_alpha_numpy(alpha)
292
 
293
  except torch.cuda.OutOfMemoryError as e:
 
302
  torch.cuda.empty_cache()
303
  continue
304
 
 
305
  logger.warning(f"MatAnyone calls failed; returning input mask as fallback. {last_exc}")
306
  if msk_b1hw is not None:
307
  return _to_2d_alpha_numpy(msk_b1hw)
 
310
  # -------------------------------- Loader ---------------------------------- #
311
 
312
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
 
 
 
 
 
 
 
313
  if device != "cuda":
314
  return torch.float32, False, None
 
315
  bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
316
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
317
  fp16_ok = cc[0] >= 7 # Volta+
 
318
  if bf16_ok:
319
  return torch.bfloat16, True, torch.bfloat16
320
  if fp16_ok:
321
  return torch.float16, True, torch.float16
322
  return torch.float32, False, None
323
 
 
324
  class MatAnyoneLoader:
325
  """
326
  Official MatAnyone loader with stateful, OOM-resilient adapter.
327
  """
 
328
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
329
  self.device = _select_device(device)
330
  self.cache_dir = cache_dir
331
  os.makedirs(self.cache_dir, exist_ok=True)
332
+ self.model = None
333
+ self.core = None
334
+ self.adapter = None
 
335
  self.model_id = "PeiqingYang/MatAnyone"
336
  self.load_time = 0.0
337
 
338
  def _import_model_and_core(self):
 
 
 
339
  model_cls = core_cls = None
340
  err_msgs = []
 
 
341
  for mod, cls in [
342
  ("matanyone.model.matanyone", "MatAnyone"),
343
  ("matanyone", "MatAnyone"),
 
348
  break
349
  except Exception as e:
350
  err_msgs.append(f"model {mod}.{cls}: {e}")
 
 
351
  for mod, cls in [
352
  ("matanyone.inference.inference_core", "InferenceCore"),
353
  ("matanyone", "InferenceCore"),
 
358
  break
359
  except Exception as e:
360
  err_msgs.append(f"core {mod}.{cls}: {e}")
 
361
  if model_cls is None or core_cls is None:
362
  msg = " | ".join(err_msgs)
363
  raise ImportError(f"Could not import MatAnyone/InferenceCore: {msg}")
 
364
  return model_cls, core_cls
365
 
366
  def load(self) -> Optional[Any]:
 
 
 
367
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
368
  start = time.time()
369
  try:
370
  model_cls, core_cls = self._import_model_and_core()
 
 
371
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
372
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
 
 
373
  self.model = model_cls.from_pretrained(self.model_id)
 
 
374
  try:
375
  self.model = self.model.to(self.device).to(model_dtype)
376
  except Exception:
377
  self.model = self.model.to(self.device)
 
 
378
  self.model.eval()
 
 
379
  try:
380
  cfg = getattr(self.model, "cfg", None)
381
  if cfg is not None:
 
384
  self.core = core_cls(self.model)
385
  except TypeError:
386
  self.core = core_cls(self.model)
 
387
  try:
388
  if hasattr(self.core, "to"):
389
  self.core.to(self.device)
390
  except Exception:
391
  pass
 
 
392
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
393
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
 
394
  self.adapter = _MatAnyoneSession(
395
  self.core,
396
  device=self.device,
 
400
  max_edge=max_edge,
401
  target_pixels=target_pixels,
402
  )
 
403
  self.load_time = time.time() - start
404
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
405
  return self.adapter
 
406
  except Exception as e:
407
  logger.error(f"Failed to load MatAnyone: {e}")
408
  logger.debug(traceback.format_exc())
 
429
  "model_type": type(self.model).__name__ if self.model else None,
430
  }
431
 
 
432
  def debug_shapes(self, image, mask, tag: str = ""):
433
  debug_shapes(tag, image, mask)
434
+ # -------------------------- Optional: Module-level symbols --------------------------
435
+
436
+ __all__ = [
437
+ "MatAnyoneLoader",
438
+ "_MatAnyoneSession",
439
+ "_to_bchw",
440
+ "_resize_bchw",
441
+ "_to_chw_image",
442
+ "_to_1hw_mask",
443
+ "_to_b1hw_alpha",
444
+ "_to_2d_alpha_numpy",
445
+ "debug_shapes"
446
+ ]
447
+
448
+ # -------------------------- (Optional) Simple CLI for quick testing --------------------------
449
+
450
+ if __name__ == "__main__":
451
+ import sys
452
+
453
+ logging.basicConfig(level=logging.INFO)
454
+ device = "cuda" if torch.cuda.is_available() else "cpu"
455
+
456
+ if len(sys.argv) < 2:
457
+ print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
458
+ sys.exit(1)
459
+ image_path = sys.argv[1]
460
+ mask_path = sys.argv[2] if len(sys.argv) > 2 else None
461
+
462
+ img = cv2.imread(image_path, cv2.IMREAD_COLOR)
463
+ if img is None:
464
+ print(f"Could not load image {image_path}")
465
+ sys.exit(2)
466
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
467
+
468
+ mask = None
469
+ if mask_path:
470
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
471
+ if mask is not None and mask.max() > 1:
472
+ mask = (mask.astype(np.float32) / 255.0)
473
+
474
+ loader = MatAnyoneLoader(device=device)
475
+ session = loader.load()
476
+ if not session:
477
+ print("Failed to load MatAnyone")
478
+ sys.exit(3)
479
+
480
+ alpha = session(img_rgb, mask)
481
+ cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
482
+ print("Alpha matte written to alpha_out.png")