MogensR commited on
Commit
8b8d050
·
1 Parent(s): 5711ea9

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +243 -197
models/loaders/matanyone_loader.py CHANGED
@@ -1,35 +1,54 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
5
- - Canonical HF load (MatAnyone.from_pretrained → InferenceCore(model, cfg))
6
- - Mixed precision (fp16 preferred over bf16) with safe fallback to fp32
7
- - torch.autocast(device_type="cuda", dtype=...) + torch.inference_mode()
8
- - Progressive downscale ladder with graceful fallback
9
- - Strict image↔mask alignment on every path/scale
10
- - Returns 2-D float32 [H,W] alpha (OpenCV-friendly)
11
- - Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
12
- - Added: Pad to multiple of 16 to avoid transformer patch issues
13
- - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
14
- - New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
15
- - New: Full default cfg from official config.json to fix 'mem_every' issues
16
- - FIXED: Re-enabled memory features and added temporal dimension support
 
 
 
 
 
 
 
 
17
  """
 
 
 
 
18
  from __future__ import annotations
19
  import os
20
  import time
21
  import logging
22
  import traceback
23
  from typing import Optional, Dict, Any, Tuple, List
 
24
  import numpy as np
25
  import torch
26
  import torch.nn.functional as F
27
  import inspect
28
  import threading
 
 
29
  logger = logging.getLogger(__name__)
30
 
31
- # EasyDict polyfill (recursive dict with dot access)
 
 
 
32
  class EasyDict(dict):
 
33
  def __init__(self, d=None, **kwargs):
34
  if d is None:
35
  d = {}
@@ -43,21 +62,22 @@ def __init__(self, d=None, **kwargs):
43
  else:
44
  self[k] = v
45
 
46
- def __getattr__(self, name):
47
  try:
48
  return self[name]
49
  except KeyError:
50
  raise AttributeError(name)
51
 
52
- def __setattr__(self, name, value):
53
  self[name] = value
54
 
55
- def __delattr__(self, name):
56
  del self[name]
57
 
58
- # ---------------------------------------------------------------------------
59
- # Utilities (shapes, dtype, scaling)
60
- # ---------------------------------------------------------------------------
 
61
  def _select_device(pref: str) -> str:
62
  pref = (pref or "").lower()
63
  if pref.startswith("cuda"):
@@ -66,61 +86,89 @@ def _select_device(pref: str) -> str:
66
  return "cpu"
67
  return "cuda" if torch.cuda.is_available() else "cpu"
68
 
 
69
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
70
  if isinstance(x, torch.Tensor):
71
  return x.to(device, non_blocking=True)
72
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
73
 
 
74
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
75
  """
76
  Normalize input to BCHW (image) or B1HW (mask).
77
- Accepts: HWC, CHW, BCHW, BHWC, BTCHW/BTHWC, TCHW/THWC, HW.
 
78
  """
79
  x = _as_tensor_on_device(x, device)
80
  if x.dtype == torch.uint8:
81
  x = x.float().div_(255.0)
82
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
83
  x = x.float()
 
 
84
  if x.ndim == 5:
85
- x = x[:, 0] # -> 4D
 
 
 
 
 
86
  if x.ndim == 4:
 
87
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
88
  x = x.permute(0, 3, 1, 2).contiguous()
89
  elif x.ndim == 3:
 
90
  if x.shape[-1] in (1, 3, 4):
91
  x = x.permute(2, 0, 1).contiguous()
 
92
  x = x.unsqueeze(0)
93
  elif x.ndim == 2:
 
94
  x = x.unsqueeze(0).unsqueeze(0)
95
  if not is_mask:
96
  x = x.repeat(1, 3, 1, 1)
97
  else:
98
- raise ValueError(f"Unsupported ndim={x.ndim}")
 
99
  if is_mask:
 
100
  if x.shape[1] > 1:
101
  x = x[:, :1]
102
  x = x.clamp_(0.0, 1.0).to(torch.float32)
103
  else:
104
- if x.shape[1] == 1:
 
 
 
105
  x = x.repeat(1, 3, 1, 1)
106
  x = x.clamp_(0.0, 1.0)
107
- return x
 
 
108
 
109
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
 
110
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
111
  return img_bchw[0]
112
- return img_bchw
 
 
113
 
114
- def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
 
 
115
  if msk_b1hw is None:
116
- return None
117
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
118
- return msk_b1hw[0] # -> [1,H,W]
119
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
120
  return msk_b1hw
121
- raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
 
122
 
123
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
 
124
  if x is None:
125
  return None
126
  if x.shape[-2:] == size_hw:
@@ -128,35 +176,40 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: b
128
  mode = "nearest" if is_mask else "bilinear"
129
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
130
 
 
131
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
 
132
  t = torch.as_tensor(alpha, device=device).float()
133
- if t.ndim == 2:
134
- t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
135
- elif t.ndim == 3:
136
- if t.shape[0] in (1, 3, 4):
137
- if t.shape[0] != 1:
138
- t = t[:1]
139
- t = t.unsqueeze(0)
140
- elif t.shape[-1] in (1, 3, 4):
141
- t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
142
- else:
143
- t = t[:1].unsqueeze(0)
144
- elif t.ndim == 4:
145
- if t.shape[1] != 1:
146
- t = t[:, :1]
147
  if t.shape[0] != 1:
148
  t = t[:1]
149
- else:
150
- while t.ndim > 4:
151
- t = t.squeeze(0)
152
- while t.ndim < 4:
153
- t = t.unsqueeze(0)
154
  if t.shape[1] != 1:
155
  t = t[:, :1]
156
- return t.clamp_(0.0, 1.0).contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def _to_2d_alpha_numpy(x) -> np.ndarray:
 
159
  t = torch.as_tensor(x).float()
 
160
  while t.ndim > 2:
161
  if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
162
  t = t[0, 0]
@@ -168,37 +221,36 @@ def _to_2d_alpha_numpy(x) -> np.ndarray:
168
  out = t.detach().cpu().numpy().astype(np.float32)
169
  return np.ascontiguousarray(out)
170
 
 
171
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
 
172
  if h <= 0 or w <= 0:
173
  return h, w, 1.0
174
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
175
  s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
176
  s = min(s1, s2)
177
- nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
178
  nw = max(128, int(round(w * s)))
179
  return nh, nw, s
180
 
181
- def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
182
- if t is None:
183
- return None
184
- if t.ndim == 3:
185
- c, h, w = t.shape
186
- elif t.ndim == 2:
187
- h, w = t.shape
188
- t = t.unsqueeze(0) # Temp to 3D for padding
189
- elif t.ndim == 4: # Handle [T, C, H, W] or similar
190
- return t # Skip padding for temporal tensors
191
- else:
192
- raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
193
  pad_h = (multiple - h % multiple) % multiple
194
  pad_w = (multiple - w % multiple) % multiple
195
  if pad_h or pad_w:
196
- t = F.pad(t, (0, pad_w, 0, pad_h))
197
- if t.ndim == 2: # Shouldn't happen
198
- t = t.squeeze(0)
199
  return t
200
 
 
201
  def debug_shapes(tag: str, image, mask) -> None:
 
202
  def _info(name, v):
203
  try:
204
  tv = torch.as_tensor(v)
@@ -210,29 +262,35 @@ def _info(name, v):
210
  _info("image", image)
211
  _info("mask", mask)
212
 
213
- # ---------------------------------------------------------------------------
214
- # Precision selection
215
- # ---------------------------------------------------------------------------
 
216
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
217
- """Pick model weight dtype + autocast dtype (fp16>bf16>fp32) for T4 compatibility."""
 
 
 
218
  if device != "cuda":
219
  return torch.float32, False, None
220
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
221
  fp16_ok = cc[0] >= 7 # Volta+
222
- bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
223
  if fp16_ok:
224
- return torch.float16, True, torch.float16 # Prefer fp16 for T4
225
  if bf16_ok:
226
  return torch.bfloat16, True, torch.bfloat16
227
  return torch.float32, False, None
228
 
229
- # ---------------------------------------------------------------------------
230
- # Stateful Adapter around InferenceCore
231
- # ---------------------------------------------------------------------------
 
232
  class _MatAnyoneSession:
233
  """
234
  Stateful controller around InferenceCore with OOM-resilient inference.
235
  First call MUST supply a coarse mask (we enforce 1HW internally).
 
236
  """
237
  def __init__(
238
  self,
@@ -242,7 +300,7 @@ def __init__(
242
  use_autocast: bool,
243
  autocast_dtype: Optional[torch.dtype],
244
  max_edge: int = 768,
245
- target_pixels: int = 600_000, # ~775x775 by area
246
  ):
247
  self.core = core
248
  self.device = device
@@ -253,7 +311,8 @@ def __init__(
253
  self.target_pixels = int(target_pixels)
254
  self.started = False
255
  self._lock = threading.Lock()
256
- # Introspect optional args
 
257
  try:
258
  sig = inspect.signature(self.core.step)
259
  self._has_first_frame_pred = "first_frame_pred" in sig.parameters
@@ -271,6 +330,9 @@ def reset(self):
271
  self.started = False
272
 
273
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
 
 
 
274
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
275
  sizes = [(nh, nw)]
276
  if s < 1.0:
@@ -284,15 +346,16 @@ def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
284
  return sizes
285
 
286
  def _to_alpha(self, out_prob):
 
287
  if self._has_prob_to_mask:
288
  try:
289
  return self.core.output_prob_to_mask(out_prob, matting=True)
290
  except Exception:
291
  pass
292
  t = torch.as_tensor(out_prob).float()
293
- if t.ndim == 4:
294
  return t[0, 0] if t.shape[1] >= 1 else t[0].mean(0)
295
- if t.ndim == 3:
296
  return t[0] if t.shape[0] >= 1 else t.mean(0)
297
  return t
298
 
@@ -303,76 +366,76 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
303
  - frames 1..N: pass mask=None (propagation)
304
  """
305
  with self._lock:
306
- img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
 
307
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
308
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
309
- # Normalize + align provided mask (if any) to **B1HW** at full res
310
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
311
  if msk_b1hw is not None and msk_b1hw.shape[-2:] != (H, W):
312
  msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
313
- mask_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None # ← 1HW!
 
 
 
 
314
  sizes = self._scaled_ladder(H, W)
315
  last_exc = None
 
316
  for (th, tw) in sizes:
317
  try:
318
- img_in = img_bchw if (th, tw) == (H, W) else F.interpolate(
319
- img_bchw, size=(th, tw), mode="bilinear", align_corners=False
320
- )
321
- msk_in = None
322
- if mask_1hw is not None:
323
- if (th, tw) == (H, W):
324
- msk_in = mask_1hw
325
- else:
326
- # nearest to keep binary-like edges
327
- msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
328
- img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
329
-
330
- # ADD TEMPORAL DIMENSION for video processing mode
331
- img_tchw = img_chw.unsqueeze(0) # [C,H,W] -> [T=1,C,H,W]
332
- if msk_in is not None:
333
- msk_t1hw = msk_in.unsqueeze(0) # [1,H,W] -> [T=1,1,H,W]
334
  else:
335
- msk_t1hw = None
336
-
337
- # Pad to multiple of 16 (skip for temporal tensors)
338
- img_tchw = _pad_to_multiple(img_tchw)
339
- if msk_t1hw is not None:
340
- msk_t1hw = _pad_to_multiple(msk_t1hw)
341
-
342
- ph, pw = img_tchw.shape[-2:]
 
 
 
 
 
343
  with torch.inference_mode():
344
- if self.use_autocast:
345
- amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
346
- else:
347
- class _NoOp:
348
- def __enter__(self): return None
349
- def __exit__(self, *a): return False
350
- amp_ctx = _NoOp()
351
  with amp_ctx:
352
  if not self.started:
353
- if msk_t1hw is None:
354
- # Should not happen when used correctly — still be defensive
355
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
356
  return np.full((H, W), 0.5, dtype=np.float32)
357
- # Pass temporal tensors to core
358
- _ = self.core.step(image=img_tchw, mask=msk_t1hw)
 
359
  if self._has_first_frame_pred:
360
- out_prob = self.core.step(image=img_tchw, first_frame_pred=True)
361
  else:
362
- out_prob = self.core.step(image=img_tchw)
363
  self.started = True
364
  else:
365
- out_prob = self.core.step(image=img_tchw)
 
 
 
366
  alpha = self._to_alpha(out_prob)
367
- # Unpad to scaled size, then upsample if needed
368
  if alpha.ndim >= 2:
369
- alpha = alpha[..., :th, :tw]
370
- # Upsample alpha back if we ran at a smaller scale
371
  if (th, tw) != (H, W):
372
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
373
  a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
374
  alpha = a_b1hw[0, 0]
 
375
  return _to_2d_alpha_numpy(alpha)
 
376
  except torch.cuda.OutOfMemoryError as e:
377
  last_exc = e
378
  torch.cuda.empty_cache()
@@ -384,14 +447,17 @@ def __exit__(self, *a): return False
384
  logger.debug(traceback.format_exc())
385
  logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
386
  continue
 
 
387
  logger.warning(f"MatAnyone calls failed; returning input mask or neutral alpha. {last_exc}")
388
  if mask_1hw is not None:
389
  return _to_2d_alpha_numpy(mask_1hw)
390
  return np.full((H, W), 0.5, dtype=np.float32)
391
 
392
- # ---------------------------------------------------------------------------
393
- # Loader
394
- # ---------------------------------------------------------------------------
 
395
  class MatAnyoneLoader:
396
  """
397
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
@@ -441,17 +507,21 @@ def load(self) -> Optional[Any]:
441
  model_cls, core_cls = self._import_model_and_core()
442
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
443
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
 
444
  # HF weights (safetensors)
445
  self.model = model_cls.from_pretrained(self.model_id)
 
 
446
  try:
447
  self.model = self.model.to(self.device).to(model_dtype)
448
  except Exception:
449
  self.model = self.model.to(self.device)
450
  self.model.eval()
451
- # Full default cfg from official config.json
 
452
  default_cfg = {
453
  "amp": False,
454
- "chunk_size": 1, # Keep at 1 for single frame processing
455
  "flip_aug": False,
456
  "long_term": {
457
  "buffer_tokens": 2000,
@@ -465,63 +535,25 @@ def load(self) -> Optional[Any]:
465
  "max_mem_frames": 5,
466
  "mem_every": 5,
467
  "model": {
468
- "aux_loss": {
469
- "query": {
470
- "enabled": True,
471
- "weight": 0.01
472
- },
473
- "sensory": {
474
- "enabled": True,
475
- "weight": 0.01
476
- }
477
- },
478
  "embed_dim": 256,
479
  "key_dim": 64,
480
- "mask_decoder": {
481
- "up_dims": [256, 128, 128, 64, 16]
482
- },
483
- "mask_encoder": {
484
- "final_dim": 256,
485
- "type": "resnet18"
486
- },
487
- "object_summarizer": {
488
- "add_pe": True,
489
- "embed_dim": 256,
490
- "num_summaries": 16
491
- },
492
  "object_transformer": {
493
- "embed_dim": 256,
494
- "ff_dim": 2048,
495
- "num_blocks": 3,
496
- "num_heads": 8,
497
  "num_queries": 16,
498
- "pixel_self_attention": {
499
- "add_pe_to_qkv": [True, True, False]
500
- },
501
- "query_self_attention": {
502
- "add_pe_to_qkv": [True, True, False]
503
- },
504
- "read_from_memory": {
505
- "add_pe_to_qkv": [True, True, False]
506
- },
507
- "read_from_past": {
508
- "add_pe_to_qkv": [True, True, False]
509
- },
510
- "read_from_pixel": {
511
- "add_pe_to_qkv": [True, True, False],
512
- "input_add_pe": False,
513
- "input_norm": False
514
- },
515
- "read_from_query": {
516
- "add_pe_to_qkv": [True, True, False],
517
- "output_norm": False
518
- }
519
  },
520
  "pixel_dim": 256,
521
- "pixel_encoder": {
522
- "ms_dims": [1024, 512, 256, 64, 3],
523
- "type": "resnet50"
524
- },
525
  "pixel_mean": [0.485, 0.456, 0.406],
526
  "pixel_pe_scale": 32,
527
  "pixel_pe_temperature": 128,
@@ -537,34 +569,35 @@ def load(self) -> Optional[Any]:
537
  "stagger_updates": 5,
538
  "top_k": 30,
539
  "use_all_masks": False,
540
- "use_long_term": True, # Enable long-term memory
541
  "visualize": False,
542
  "weights": "pretrained_models/matanyone.pth"
543
  }
544
- # Get cfg from model if available, else default
 
545
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
546
  if isinstance(cfg, dict):
547
- cfg = dict(cfg) # Copy to avoid modifying model.cfg
548
- # Only override minimal settings for compatibility
549
  overrides = {
550
- 'chunk_size': 1, # Process one frame at a time
551
- 'flip_aug': False, # Disable augmentation
552
- # Keep memory features enabled!
553
  }
554
  cfg.update(overrides)
555
- # Convert to EasyDict for dot access
556
  cfg = EasyDict(cfg)
557
- # Inference core
 
558
  try:
559
  self.core = core_cls(self.model, cfg=cfg)
560
  except TypeError:
561
  self.core = core_cls(self.model)
562
- # Some versions expose .to(), some don't — best effort
 
563
  try:
564
  if hasattr(self.core, "to"):
565
  self.core.to(self.device)
566
  except Exception:
567
  pass
 
568
  # Build stateful adapter
569
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
570
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
@@ -580,12 +613,14 @@ def load(self) -> Optional[Any]:
580
  self.load_time = time.time() - t0
581
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
582
  return self.adapter
 
583
  except Exception as e:
584
  logger.error(f"Failed to load MatAnyone: {e}")
585
  logger.debug(traceback.format_exc())
586
  return None
587
 
588
  def cleanup(self):
 
589
  self.adapter = None
590
  self.core = None
591
  if self.model:
@@ -598,6 +633,7 @@ def cleanup(self):
598
  torch.cuda.empty_cache()
599
 
600
  def get_info(self) -> Dict[str, Any]:
 
601
  return {
602
  "loaded": self.adapter is not None,
603
  "model_id": self.model_id,
@@ -607,6 +643,7 @@ def get_info(self) -> Dict[str, Any]:
607
  }
608
 
609
  def debug_shapes(self, image, mask, tag: str = ""):
 
610
  try:
611
  tv_img = torch.as_tensor(image)
612
  tv_msk = torch.as_tensor(mask) if mask is not None else None
@@ -616,9 +653,10 @@ def debug_shapes(self, image, mask, tag: str = ""):
616
  except Exception as e:
617
  logger.info(f"[{tag}] debug error: {e}")
618
 
619
- # ---------------------------------------------------------------------------
620
- # Public symbols
621
- # ---------------------------------------------------------------------------
 
622
  __all__ = [
623
  "MatAnyoneLoader",
624
  "_MatAnyoneSession",
@@ -632,34 +670,42 @@ def debug_shapes(self, image, mask, tag: str = ""):
632
  "debug_shapes",
633
  ]
634
 
635
- # ---------------------------------------------------------------------------
636
- # Optional CLI for quick testing (no circular imports)
637
- # ---------------------------------------------------------------------------
 
638
  if __name__ == "__main__":
639
  import sys
640
- import cv2 # only needed for this demo CLI
641
  logging.basicConfig(level=logging.INFO)
642
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
643
  if len(sys.argv) < 2:
644
  print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
645
  raise SystemExit(1)
 
646
  image_path = sys.argv[1]
647
  mask_path = sys.argv[2] if len(sys.argv) > 2 else None
 
648
  img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
649
  if img_bgr is None:
650
  print(f"Could not load image {image_path}")
651
  raise SystemExit(2)
 
652
  img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
 
653
  mask = None
654
  if mask_path:
655
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
656
  if mask is not None and mask.max() > 1:
657
  mask = (mask.astype(np.float32) / 255.0)
 
658
  loader = MatAnyoneLoader(device=device)
659
  session = loader.load()
660
  if not session:
661
  print("Failed to load MatAnyone")
662
  raise SystemExit(3)
 
663
  alpha = session(img_rgb, mask if mask is not None else np.ones(img_rgb.shape[:2], np.float32))
664
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
665
- print("Alpha matte written to alpha_out.png")
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ MatAnyone Loader + Stateful Adapter (Fixed Tensor Shapes, OOM-resilient)
5
+ =======================================================================
6
+
7
+ CHAPTERS
8
+ 1) Overview & Rationale
9
+ 2) Imports & Logger
10
+ 3) EasyDict Polyfill
11
+ 4) Tensor Utilities (device, shape, resize, padding)
12
+ 5) Precision Selection (fp16/bf16/fp32)
13
+ 6) Stateful Session (_MatAnyoneSession) ← FIX: CHW / 1HW only (no temporal axis)
14
+ 7) Loader (MatAnyoneLoader)
15
+ 8) Public Symbols
16
+ 9) CLI Demo (optional quick test)
17
+
18
+ Key Fix vs. previous version
19
+ ----------------------------
20
+ - Removed the extra “temporal” axis that produced 5D tensors like [1,1,3,H,W].
21
+ - MatAnyone now receives:
22
+ • Image: CHW (float, in [0,1]) — or internally BCHW collapsed to CHW.
23
+ • Mask : 1HW (float, in [0,1]) on the first frame only; later frames mask=None.
24
+ - Kept: downscale ladder, padding to multiple of 16, mixed precision, long-term memory config.
25
  """
26
+
27
+ # ============================================================================
28
+ # 2) IMPORTS & LOGGER
29
+ # ============================================================================
30
  from __future__ import annotations
31
  import os
32
  import time
33
  import logging
34
  import traceback
35
  from typing import Optional, Dict, Any, Tuple, List
36
+
37
  import numpy as np
38
  import torch
39
  import torch.nn.functional as F
40
  import inspect
41
  import threading
42
+ import contextlib
43
+
44
  logger = logging.getLogger(__name__)
45
 
46
+
47
+ # ============================================================================
48
+ # 3) EASYDICT POLYFILL
49
+ # ============================================================================
50
  class EasyDict(dict):
51
+ """Recursive dict with dot access."""
52
  def __init__(self, d=None, **kwargs):
53
  if d is None:
54
  d = {}
 
62
  else:
63
  self[k] = v
64
 
65
+ def __getattr__(self, name): # dot-get
66
  try:
67
  return self[name]
68
  except KeyError:
69
  raise AttributeError(name)
70
 
71
+ def __setattr__(self, name, value): # dot-set
72
  self[name] = value
73
 
74
+ def __delattr__(self, name): # dot-del
75
  del self[name]
76
 
77
+
78
+ # ============================================================================
79
+ # 4) TENSOR UTILITIES (DEVICE, SHAPE, RESIZE, PADDING)
80
+ # ============================================================================
81
  def _select_device(pref: str) -> str:
82
  pref = (pref or "").lower()
83
  if pref.startswith("cuda"):
 
86
  return "cpu"
87
  return "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
+
90
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
91
  if isinstance(x, torch.Tensor):
92
  return x.to(device, non_blocking=True)
93
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
94
 
95
+
96
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
97
  """
98
  Normalize input to BCHW (image) or B1HW (mask).
99
+ Accepts: HWC, CHW, BCHW, BHWC, (accidental) 5D, and HW.
100
+ Defensive against dtype/range; output is clamped to [0,1].
101
  """
102
  x = _as_tensor_on_device(x, device)
103
  if x.dtype == torch.uint8:
104
  x = x.float().div_(255.0)
105
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
106
  x = x.float()
107
+
108
+ # If upstream passed a 5D tensor (e.g., (B,1,C,H,W) or (B,T,C,H,W)), squeeze a singleton middle axis.
109
  if x.ndim == 5:
110
+ # Prefer to squeeze the 2nd dim if it's 1; otherwise take the first slice.
111
+ if x.shape[1] == 1:
112
+ x = x.squeeze(1) # -> BCHW
113
+ else:
114
+ x = x[:, 0, ...] # -> BCHW
115
+
116
  if x.ndim == 4:
117
+ # Handle BHWC → BCHW
118
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
119
  x = x.permute(0, 3, 1, 2).contiguous()
120
  elif x.ndim == 3:
121
+ # HWC → CHW
122
  if x.shape[-1] in (1, 3, 4):
123
  x = x.permute(2, 0, 1).contiguous()
124
+ # CHW → BCHW
125
  x = x.unsqueeze(0)
126
  elif x.ndim == 2:
127
+ # HW → B1HW (mask) or B3HW (image)
128
  x = x.unsqueeze(0).unsqueeze(0)
129
  if not is_mask:
130
  x = x.repeat(1, 3, 1, 1)
131
  else:
132
+ raise ValueError(f"_to_bchw: unsupported ndim={x.ndim}")
133
+
134
  if is_mask:
135
+ # Ensure single-channel B1HW, clamped and float32
136
  if x.shape[1] > 1:
137
  x = x[:, :1]
138
  x = x.clamp_(0.0, 1.0).to(torch.float32)
139
  else:
140
+ # Ensure RGB
141
+ if x.shape[1] == 4:
142
+ x = x[:, :3, ...]
143
+ elif x.shape[1] == 1:
144
  x = x.repeat(1, 3, 1, 1)
145
  x = x.clamp_(0.0, 1.0)
146
+
147
+ return x.contiguous()
148
+
149
 
150
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
151
+ """BCHW → CHW (take batch 0 if present)."""
152
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
153
  return img_bchw[0]
154
+ if img_bchw.ndim == 3:
155
+ return img_bchw
156
+ raise ValueError(f"_to_chw_image: expected BCHW or CHW, got {tuple(img_bchw.shape)}")
157
 
158
+
159
+ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> torch.Tensor:
160
+ """B1HW → 1HW (drop batch)."""
161
  if msk_b1hw is None:
162
+ raise ValueError("_to_1hw_mask: mask is None")
163
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
164
+ return msk_b1hw[0] # 1HW
165
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
166
  return msk_b1hw
167
+ raise ValueError(f"_to_1hw_mask: expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
168
+
169
 
170
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
171
+ """Resize BCHW or B1HW to (H, W) using bilinear (image) or nearest (mask)."""
172
  if x is None:
173
  return None
174
  if x.shape[-2:] == size_hw:
 
176
  mode = "nearest" if is_mask else "bilinear"
177
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
178
 
179
+
180
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
181
+ """Convert arbitrary mask-like input to B1HW float32 [0,1]."""
182
  t = torch.as_tensor(alpha, device=device).float()
183
+ # Squeeze extra dims down to HW/1HW first
184
+ while t.ndim > 4:
185
+ t = t.squeeze(0)
186
+ if t.ndim == 4:
187
+ # Expecting BxCxHxW; force B=1, C=1
 
 
 
 
 
 
 
 
 
188
  if t.shape[0] != 1:
189
  t = t[:1]
 
 
 
 
 
190
  if t.shape[1] != 1:
191
  t = t[:, :1]
192
+ elif t.ndim == 3:
193
+ # Could be CxHxW or HxWx1
194
+ if t.shape[0] == 1:
195
+ t = t.unsqueeze(0) # 1x1xHxW
196
+ elif t.shape[-1] == 1:
197
+ t = t.permute(2, 0, 1).unsqueeze(0) # 1x1xHxW
198
+ else:
199
+ # If C>1, take first channel
200
+ t = t[:1, ...].unsqueeze(0)
201
+ elif t.ndim == 2:
202
+ t = t.unsqueeze(0).unsqueeze(0)
203
+ else:
204
+ raise ValueError(f"_to_b1hw_alpha: unsupported ndim={t.ndim}")
205
+ t = t.clamp_(0.0, 1.0).contiguous()
206
+ return t
207
+
208
 
209
  def _to_2d_alpha_numpy(x) -> np.ndarray:
210
+ """Convert any mask-like tensor to 2D float32 numpy [H,W] in [0,1]."""
211
  t = torch.as_tensor(x).float()
212
+ # Squeeze down to 2D
213
  while t.ndim > 2:
214
  if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
215
  t = t[0, 0]
 
221
  out = t.detach().cpu().numpy().astype(np.float32)
222
  return np.ascontiguousarray(out)
223
 
224
+
225
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
226
+ """Compute a safe scaled size that respects a max edge and total pixels."""
227
  if h <= 0 or w <= 0:
228
  return h, w, 1.0
229
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
230
  s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
231
  s = min(s1, s2)
232
+ nh = max(128, int(round(h * s))) # minimum of 128 to avoid very small feature maps
233
  nw = max(128, int(round(w * s)))
234
  return nh, nw, s
235
 
236
+
237
+ def _pad_to_multiple_3d(t: torch.Tensor, multiple: int = 16) -> torch.Tensor:
238
+ """
239
+ Pad a 3D tensor (C,H,W) to multiples of `multiple`. Works for CHW and 1HW.
240
+ Returns a tensor with same ndim.
241
+ """
242
+ if t.ndim != 3:
243
+ raise ValueError(f"_pad_to_multiple_3d: expected 3D, got {t.ndim}D")
244
+ c, h, w = t.shape
 
 
 
245
  pad_h = (multiple - h % multiple) % multiple
246
  pad_w = (multiple - w % multiple) % multiple
247
  if pad_h or pad_w:
248
+ t = F.pad(t, (0, pad_w, 0, pad_h)) # (left,right,top,bottom)
 
 
249
  return t
250
 
251
+
252
  def debug_shapes(tag: str, image, mask) -> None:
253
+ """Log shapes/dtypes/min/max for quick inspection."""
254
  def _info(name, v):
255
  try:
256
  tv = torch.as_tensor(v)
 
262
  _info("image", image)
263
  _info("mask", mask)
264
 
265
+
266
+ # ============================================================================
267
+ # 5) PRECISION SELECTION (fp16/bf16/fp32)
268
+ # ============================================================================
269
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
270
+ """
271
+ Pick model weights dtype and autocast dtype (fp16>bf16>fp32), preferring fp16 for T4.
272
+ Returns: (model_dtype, use_autocast, autocast_dtype)
273
+ """
274
  if device != "cuda":
275
  return torch.float32, False, None
276
  cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
277
  fp16_ok = cc[0] >= 7 # Volta+
278
+ bf16_ok = (cc[0] >= 8) and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
279
  if fp16_ok:
280
+ return torch.float16, True, torch.float16 # T4 prefers fp16
281
  if bf16_ok:
282
  return torch.bfloat16, True, torch.bfloat16
283
  return torch.float32, False, None
284
 
285
+
286
+ # ============================================================================
287
+ # 6) STATEFUL SESSION (NO TEMPORAL AXIS; STRICT CHW/1HW)
288
+ # ============================================================================
289
  class _MatAnyoneSession:
290
  """
291
  Stateful controller around InferenceCore with OOM-resilient inference.
292
  First call MUST supply a coarse mask (we enforce 1HW internally).
293
+ Subsequent calls should pass mask=None (temporal propagation handled by core).
294
  """
295
  def __init__(
296
  self,
 
300
  use_autocast: bool,
301
  autocast_dtype: Optional[torch.dtype],
302
  max_edge: int = 768,
303
+ target_pixels: int = 600_000, # ~775x775 by area
304
  ):
305
  self.core = core
306
  self.device = device
 
311
  self.target_pixels = int(target_pixels)
312
  self.started = False
313
  self._lock = threading.Lock()
314
+
315
+ # Introspect optional API surfaces
316
  try:
317
  sig = inspect.signature(self.core.step)
318
  self._has_first_frame_pred = "first_frame_pred" in sig.parameters
 
330
  self.started = False
331
 
332
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
333
+ """
334
+ Build a list of decreasing (H,W) resolutions to attempt to avoid OOM.
335
+ """
336
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
337
  sizes = [(nh, nw)]
338
  if s < 1.0:
 
346
  return sizes
347
 
348
def _to_alpha(self, out_prob):
    """Turn raw model output probabilities into a single-channel matte tensor."""
    # Prefer the core's own converter when that API surface was detected.
    if self._has_prob_to_mask:
        try:
            return self.core.output_prob_to_mask(out_prob, matting=True)
        except Exception:
            pass  # best-effort: fall through to the manual reduction below
    probs = torch.as_tensor(out_prob).float()
    if probs.ndim == 4:
        # BxCxHxW: batch 0, channel 0; mean over channels only if C is empty.
        return probs[0, 0] if probs.shape[1] >= 1 else probs[0].mean(0)
    if probs.ndim == 3:
        # CxHxW: first channel, or channel mean in the degenerate case.
        return probs[0] if probs.shape[0] >= 1 else probs.mean(0)
    return probs
361
 
 
366
  - frames 1..N: pass mask=None (propagation)
367
  """
368
  with self._lock:
369
+ # ---- 1) Normalize inputs to BCHW (image) and B1HW (mask), then collapse to CHW / 1HW
370
+ img_bchw = _to_bchw(image, self.device, is_mask=False) # BCHW
371
  H, W = img_bchw.shape[-2], img_bchw.shape[-1]
372
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
 
373
  msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
374
  if msk_b1hw is not None and msk_b1hw.shape[-2:] != (H, W):
375
  msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
376
+
377
+ img_chw = _to_chw_image(img_bchw) # CHW
378
+ mask_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None # 1HW or None
379
+
380
+ # ---- 2) Downscale ladder to avoid OOM
381
  sizes = self._scaled_ladder(H, W)
382
  last_exc = None
383
+
384
  for (th, tw) in sizes:
385
  try:
386
+ # 2a) Resize image (bilinear) and mask (nearest) to ladder size
387
+ if (th, tw) == (H, W):
388
+ img_in = img_chw
389
+ msk_in = mask_1hw
 
 
 
 
 
 
 
 
 
 
 
 
390
  else:
391
+ img_in = F.interpolate(img_chw.unsqueeze(0), size=(th, tw),
392
+ mode="bilinear", align_corners=False)[0] # CHW
393
+ msk_in = None
394
+ if mask_1hw is not None:
395
+ msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw),
396
+ mode="nearest")[0] # 1HW
397
+
398
+ # 2b) Pad to multiple of 16 (per-model stability)
399
+ img_in = _pad_to_multiple_3d(img_in) # CHW
400
+ if msk_in is not None:
401
+ msk_in = _pad_to_multiple_3d(msk_in) # 1HW
402
+
403
+ # ---- 3) Forward pass (STRICT CHW / 1HW; NO TEMPORAL AXIS)
404
  with torch.inference_mode():
405
+ amp_ctx = (
406
+ torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
407
+ if self.use_autocast else
408
+ contextlib.nullcontext()
409
+ )
 
 
410
  with amp_ctx:
411
  if not self.started:
412
+ if msk_in is None:
 
413
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
414
  return np.full((H, W), 0.5, dtype=np.float32)
415
+
416
+ # Initialize with first frame (explicit mask)
417
+ _ = self.core.step(image=img_in, mask=msk_in) # ← CHW + 1HW
418
  if self._has_first_frame_pred:
419
+ out_prob = self.core.step(image=img_in, first_frame_pred=True)
420
  else:
421
+ out_prob = self.core.step(image=img_in)
422
  self.started = True
423
  else:
424
+ # Subsequent frames; core uses memory internally
425
+ out_prob = self.core.step(image=img_in) # ← CHW
426
+
427
+ # ---- 4) Convert to alpha + unpad/upsample back to full res if needed
428
  alpha = self._to_alpha(out_prob)
 
429
  if alpha.ndim >= 2:
430
+ alpha = alpha[..., :th, :tw] # remove pad
431
+
432
  if (th, tw) != (H, W):
433
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
434
  a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
435
  alpha = a_b1hw[0, 0]
436
+
437
  return _to_2d_alpha_numpy(alpha)
438
+
439
  except torch.cuda.OutOfMemoryError as e:
440
  last_exc = e
441
  torch.cuda.empty_cache()
 
447
  logger.debug(traceback.format_exc())
448
  logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
449
  continue
450
+
451
+ # ---- 5) All attempts failed – return input mask or neutral alpha
452
  logger.warning(f"MatAnyone calls failed; returning input mask or neutral alpha. {last_exc}")
453
  if mask_1hw is not None:
454
  return _to_2d_alpha_numpy(mask_1hw)
455
  return np.full((H, W), 0.5, dtype=np.float32)
456
 
457
+
458
+ # ============================================================================
459
+ # 7) LOADER (MatAnyoneLoader)
460
+ # ============================================================================
461
  class MatAnyoneLoader:
462
  """
463
  Official MatAnyone loader with stateful, OOM-resilient session adapter.
 
507
  model_cls, core_cls = self._import_model_and_core()
508
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
509
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
510
+
511
  # HF weights (safetensors)
512
  self.model = model_cls.from_pretrained(self.model_id)
513
+
514
+ # Move to device + dtype when possible
515
  try:
516
  self.model = self.model.to(self.device).to(model_dtype)
517
  except Exception:
518
  self.model = self.model.to(self.device)
519
  self.model.eval()
520
+
521
+ # Full default cfg from official config.json (kept; enables memory features)
522
  default_cfg = {
523
  "amp": False,
524
+ "chunk_size": 1, # single-frame stepping
525
  "flip_aug": False,
526
  "long_term": {
527
  "buffer_tokens": 2000,
 
535
  "max_mem_frames": 5,
536
  "mem_every": 5,
537
  "model": {
538
+ "aux_loss": {"query": {"enabled": True, "weight": 0.01},
539
+ "sensory": {"enabled": True, "weight": 0.01}},
 
 
 
 
 
 
 
 
540
  "embed_dim": 256,
541
  "key_dim": 64,
542
+ "mask_decoder": {"up_dims": [256, 128, 128, 64, 16]},
543
+ "mask_encoder": {"final_dim": 256, "type": "resnet18"},
544
+ "object_summarizer": {"add_pe": True, "embed_dim": 256, "num_summaries": 16},
 
 
 
 
 
 
 
 
 
545
  "object_transformer": {
546
+ "embed_dim": 256, "ff_dim": 2048, "num_blocks": 3, "num_heads": 8,
 
 
 
547
  "num_queries": 16,
548
+ "pixel_self_attention": {"add_pe_to_qkv": [True, True, False]},
549
+ "query_self_attention": {"add_pe_to_qkv": [True, True, False]},
550
+ "read_from_memory": {"add_pe_to_qkv": [True, True, False]},
551
+ "read_from_past": {"add_pe_to_qkv": [True, True, False]},
552
+ "read_from_pixel": {"add_pe_to_qkv": [True, True, False], "input_add_pe": False, "input_norm": False},
553
+ "read_from_query": {"add_pe_to_qkv": [True, True, False], "output_norm": False}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  },
555
  "pixel_dim": 256,
556
+ "pixel_encoder": {"ms_dims": [1024, 512, 256, 64, 3], "type": "resnet50"},
 
 
 
557
  "pixel_mean": [0.485, 0.456, 0.406],
558
  "pixel_pe_scale": 32,
559
  "pixel_pe_temperature": 128,
 
569
  "stagger_updates": 5,
570
  "top_k": 30,
571
  "use_all_masks": False,
572
+ "use_long_term": True,
573
  "visualize": False,
574
  "weights": "pretrained_models/matanyone.pth"
575
  }
576
+
577
+ # Merge with model.cfg if present; apply minimal overrides
578
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
579
  if isinstance(cfg, dict):
580
+ cfg = dict(cfg)
 
581
  overrides = {
582
+ "chunk_size": 1,
583
+ "flip_aug": False,
 
584
  }
585
  cfg.update(overrides)
 
586
  cfg = EasyDict(cfg)
587
+
588
+ # Build inference core
589
  try:
590
  self.core = core_cls(self.model, cfg=cfg)
591
  except TypeError:
592
  self.core = core_cls(self.model)
593
+
594
+ # Some versions expose .to()
595
  try:
596
  if hasattr(self.core, "to"):
597
  self.core.to(self.device)
598
  except Exception:
599
  pass
600
+
601
  # Build stateful adapter
602
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
603
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
 
613
  self.load_time = time.time() - t0
614
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
615
  return self.adapter
616
+
617
  except Exception as e:
618
  logger.error(f"Failed to load MatAnyone: {e}")
619
  logger.debug(traceback.format_exc())
620
  return None
621
 
622
  def cleanup(self):
623
+ """Release model/core and clear CUDA cache."""
624
  self.adapter = None
625
  self.core = None
626
  if self.model:
 
633
  torch.cuda.empty_cache()
634
 
635
  def get_info(self) -> Dict[str, Any]:
636
+ """Lightweight status for UI/self-check."""
637
  return {
638
  "loaded": self.adapter is not None,
639
  "model_id": self.model_id,
 
643
  }
644
 
645
  def debug_shapes(self, image, mask, tag: str = ""):
646
+ """Quick shape/dtype logger."""
647
  try:
648
  tv_img = torch.as_tensor(image)
649
  tv_msk = torch.as_tensor(mask) if mask is not None else None
 
653
  except Exception as e:
654
  logger.info(f"[{tag}] debug error: {e}")
655
 
656
+
657
+ # ============================================================================
658
+ # 8) PUBLIC SYMBOLS
659
+ # ============================================================================
660
  __all__ = [
661
  "MatAnyoneLoader",
662
  "_MatAnyoneSession",
 
670
  "debug_shapes",
671
  ]
672
 
673
+
674
+ # ============================================================================
675
+ # 9) CLI DEMO (OPTIONAL QUICK TEST)
676
+ # ============================================================================
677
if __name__ == "__main__":
    import sys
    import cv2  # only needed for this demo

    logging.basicConfig(level=logging.INFO)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Require at least an image path; the mask path is optional.
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
        raise SystemExit(1)

    image_path = sys.argv[1]
    mask_path = sys.argv[2] if len(sys.argv) > 2 else None

    frame_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if frame_bgr is None:
        print(f"Could not load image {image_path}")
        raise SystemExit(2)
    # OpenCV loads BGR; the session expects RGB.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    seed_mask = None
    if mask_path:
        seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # Normalize 0..255 masks to 0..1 floats; leave already-binary masks alone.
        if seed_mask is not None and seed_mask.max() > 1:
            seed_mask = (seed_mask.astype(np.float32) / 255.0)

    loader = MatAnyoneLoader(device=device)
    session = loader.load()
    if not session:
        print("Failed to load MatAnyone")
        raise SystemExit(3)

    # With no mask on disk, seed the session with an all-ones coarse mask.
    if seed_mask is None:
        seed_mask = np.ones(frame_rgb.shape[:2], np.float32)
    alpha = session(frame_rgb, seed_mask)
    cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
    print("Alpha matte written to alpha_out.png")