MogensR committed on
Commit
58a43ef
·
1 Parent(s): bb9d73e

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +239 -218
models/loaders/matanyone_loader.py CHANGED
@@ -1,12 +1,10 @@
1
  #!/usr/bin/env python3
2
  """
3
- MatAnyone Model Loader (Hardened v2)
4
- - Prevents 5D (B,T,C,H,W) tensors from reaching conv2d.
5
- - Normalizes images to BCHW [B,C,H,W] and masks to B1HW [B,1,H,W].
6
- - idx_mask=True -> integer label map, but final output still a 2-D [H,W] mask for OpenCV.
7
- - ALWAYS returns a 2-D, contiguous, float32 mask [H,W] to downstream code.
8
- - Tries unbatched then batched calls; resizes masks with NEAREST to preserve labels.
9
- - Includes debug_shapes() for quick diagnostics.
10
  """
11
 
12
  import os
@@ -17,14 +15,16 @@
17
 
18
  import numpy as np
19
  import torch
 
 
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
24
- # ------------------------------- Utilities -------------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
27
- pref = (pref or "").lower()
28
  if pref.startswith("cuda"):
29
  return "cuda" if torch.cuda.is_available() else "cpu"
30
  if pref == "cpu":
@@ -41,296 +41,317 @@ def _as_tensor_on_device(x, device: str) -> torch.Tensor:
41
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  """
43
  Normalize input to BCHW (image) or B1HW (mask).
44
- Accepts: HWC, CHW, BCHW, BHWC, BTCHW, BTHWC, TCHW, THWC, HW.
45
  """
46
  x = _as_tensor_on_device(x, device)
47
 
48
- # Promote to float and normalize if needed
49
  if x.dtype == torch.uint8:
50
  x = x.float().div_(255.0)
51
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
52
  x = x.float()
53
 
54
- # 5D: [B,T,C,H,W] or [B,T,H,W,C] -> take first frame
55
  if x.ndim == 5:
56
- B, T = x.shape[0], x.shape[1]
57
- x = x[:, 0] if T > 0 else x.squeeze(1)
58
 
59
- # 4D
60
  if x.ndim == 4:
61
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
62
- x = x.permute(0, 3, 1, 2).contiguous() # BHWC -> BCHW
63
 
64
- # 3D
65
  elif x.ndim == 3:
66
- if x.shape[-1] in (1, 3, 4): # HWC -> CHW
67
  x = x.permute(2, 0, 1).contiguous()
68
- x = x.unsqueeze(0) # -> BCHW
69
 
70
- # 2D
71
  elif x.ndim == 2:
72
- if is_mask:
73
- x = x.unsqueeze(0).unsqueeze(0) # -> B1HW
74
- else:
75
- x = x.unsqueeze(0).unsqueeze(0) # 1,1,H,W
76
- x = x.repeat(1, 3, 1, 1) # 1,3,H,W
77
 
78
  else:
79
- raise ValueError(f"Unsupported tensor ndim={x.ndim} for normalization")
80
 
81
- # Finalize channels / clamp
82
  if is_mask:
83
  if x.shape[1] > 1:
84
  x = x[:, :1]
85
  x = x.clamp_(0.0, 1.0).to(torch.float32)
86
  else:
87
- C = x.shape[1]
88
- if C == 1:
89
  x = x.repeat(1, 3, 1, 1)
90
- if x.min() < 0.0 or x.max() > 1.0:
91
- x = x.clamp_(0.0, 1.0)
92
- x = x.to(torch.float32)
93
 
94
  return x
95
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
98
- """Ensure mask spatial dims match image. Use NEAREST to keep labels crisp."""
 
99
  if img_bchw.shape[-2:] == mask_b1hw.shape[-2:]:
100
  return mask_b1hw
101
- import torch.nn.functional as F
102
  return F.interpolate(mask_b1hw, size=img_bchw.shape[-2:], mode="nearest")
103
 
104
 
105
- def debug_shapes(tag: str, image, mask) -> None:
106
- """Quick diagnostics: logs shape/dtype/min/max for image/mask."""
107
- def _info(name, t):
108
- try:
109
- tt = torch.as_tensor(t)
110
- mn = float(tt.min()) if tt.numel() else float("nan")
111
- mx = float(tt.max()) if tt.numel() else float("nan")
112
- logger.info(f"[{tag}:{name}] shape={tuple(tt.shape)} dtype={tt.dtype} "
113
- f"min={mn:.4f} max={mx:.4f}")
114
- except Exception as e:
115
- logger.info(f"[{tag}:{name}] type={type(t)} err={e}")
116
- _info("image", image)
117
- _info("mask", mask)
118
-
119
-
120
- def _to_2d_numpy_mask(x) -> np.ndarray:
121
  """
122
- Convert any tensor/ndarray mask to a 2-D contiguous float32 array [H,W] in [0,1].
123
- Handles inputs like: B1HW, BCHW, 1HW, CHW, HWC, HW, etc.
124
  """
125
- if isinstance(x, torch.Tensor):
126
- t = x.detach()
127
- else:
128
- t = torch.as_tensor(x)
129
-
130
- # Bring to float in [0,1] if likely 0..255
131
- if t.dtype == torch.uint8:
132
- t = t.float().div_(255.0)
133
- elif t.dtype in (torch.int16, torch.int32, torch.int64):
134
- t = t.float()
135
- else:
136
- t = t.float()
137
-
138
- # Reduce dimensions to [H,W]
139
- if t.ndim == 4: # e.g., [B, C, H, W]
140
- if t.shape[0] > 1:
141
- t = t[0]
142
- # now [C,H,W]
143
- if t.shape[0] > 1: # multiple channels -> take first (or could mean)
144
- t = t[0]
145
- else:
146
- t = t[0] # squeeze channel -> [H,W]
147
- elif t.ndim == 3:
148
- # Could be [1,H,W], [C,H,W], or [H,W,1]
149
- if t.shape[0] in (1, 3, 4): # CHW/1HW
150
- t = t[0] # -> [H,W] (first channel)
151
- elif t.shape[-1] == 1: # HWC with single channel
152
- t = t[..., 0] # -> [H,W]
153
  else:
154
- # Unknown 3D -> take first slice
155
- t = t[0]
156
- elif t.ndim == 2:
157
- pass # already [H,W]
158
- else:
159
- # Any other: try to squeeze to 2-D
160
- t = t.squeeze()
161
- if t.ndim != 2:
162
- # fallback to a tiny neutral mask
163
- h = int(t.shape[-2]) if t.ndim >= 2 else 512
164
- w = int(t.shape[-1]) if t.ndim >= 2 else 512
165
- t = torch.full((h, w), 0.5, dtype=torch.float32)
166
-
167
- # Clamp and convert to contiguous numpy
168
  t = t.clamp_(0.0, 1.0)
169
- m = t.cpu().numpy().astype(np.float32)
170
- return np.ascontiguousarray(m)
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
- # --------------------------- Boundary Wrapper ------------------------------ #
174
 
175
- class _MatAnyoneWrapper:
176
- """
177
- Thin, defensive wrapper around the MatAnyone InferenceCore.
178
- Normalizes inputs at the boundary and always outputs a 2-D mask for OpenCV.
179
  """
 
180
 
181
- def __init__(self, core: Any, device: str):
 
 
 
 
 
 
182
  self.core = core
183
  self.device = device
 
184
 
185
- # Try to move the core to device, if supported.
186
  try:
187
- if hasattr(self.core, "to"):
188
- self.core.to(self.device)
189
- except Exception as e:
190
- logger.debug(f"MatAnyone core .to({self.device}) not applied: {e}")
191
-
192
- def _normalize_pair(
193
- self, image, mask, idx_mask: bool
194
- ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
195
- img_bchw = _to_bchw(image, self.device, is_mask=False) # [B,C,H,W]
196
- msk_b1hw = _to_bchw(mask, self.device, is_mask=True) # [B,1,H,W]
197
- msk_b1hw = _resize_mask_to(img_bchw, msk_b1hw)
198
- return img_bchw, msk_b1hw, bool(idx_mask)
 
 
 
 
 
 
199
 
200
- def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
201
  """
202
- Entry point: returns a 2-D float32 mask [H,W] for downstream OpenCV.
 
203
  """
204
- img_bchw, msk_b1hw, idx_mask = self._normalize_pair(image, mask, idx_mask)
205
-
206
- # idx_mask path -> integer labels; still output 2-D for downstream
207
- if idx_mask:
208
- m_bhw = (msk_b1hw > 0.5).long()[:, 0] # [B,H,W]
209
- # Try unbatched if B==1
210
- if img_bchw.shape[0] == 1:
211
- img_chw = img_bchw[0] # [C,H,W]
212
- m_hw = m_bhw[0] # [H,W]
213
- try:
214
- if hasattr(self.core, "step"):
215
- out = self.core.step(image=img_chw, mask=m_hw, idx_mask=True, **kwargs)
216
- return _to_2d_numpy_mask(out)
217
- except Exception as e_unbatched_idx:
218
- logger.debug(f"MatAnyone unbatched idx_mask step() failed: {e_unbatched_idx}")
219
- # Batched fallback
220
- for method_name in ("step", "process"):
221
- try:
222
- if hasattr(self.core, method_name):
223
- method = getattr(self.core, method_name)
224
- out = method(image=img_bchw, mask=m_bhw, idx_mask=True, **kwargs)
225
- return _to_2d_numpy_mask(out)
226
- except Exception as e_batched_idx:
227
- logger.debug(f"MatAnyone {method_name} idx_mask batched call failed: {e_batched_idx}")
228
-
229
- logger.warning("MatAnyone idx_mask calls failed; returning integer mask as fallback.")
230
- return _to_2d_numpy_mask(m_bhw)
231
-
232
- # Non-index mask path (soft/binary)
233
  try:
234
- if hasattr(self.core, "step") and img_bchw.shape[0] == 1:
235
- img_chw = img_bchw[0] # [C,H,W]
236
- m_1hw = msk_b1hw[0] # [1,H,W]
237
- out = self.core.step(image=img_chw, mask=m_1hw, idx_mask=False, **kwargs)
238
- return _to_2d_numpy_mask(out)
239
- except Exception as e_unbatched:
240
- logger.debug(f"MatAnyone unbatched step() failed: {e_unbatched}")
241
-
242
- # Batched fallback
243
- for method_name in ("step", "process"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  try:
245
- if hasattr(self.core, method_name):
246
- method = getattr(self.core, method_name)
247
- out = method(image=img_bchw, mask=msk_b1hw, idx_mask=False, **kwargs)
248
- return _to_2d_numpy_mask(out)
249
- except Exception as e_batched:
250
- logger.debug(f"MatAnyone {method_name} batched call failed: {e_batched}")
 
 
 
 
251
 
252
- logger.warning("MatAnyone calls failed; returning input mask as fallback.")
253
- # Return a valid 2-D mask even on total failure
254
- return _to_2d_numpy_mask(msk_b1hw)
255
 
256
- # ------------------------------- Loader ----------------------------------- #
257
 
258
  class MatAnyoneLoader:
259
- """Dedicated loader for MatAnyone models (with boundary normalization)."""
 
 
260
 
261
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
262
  self.device = _select_device(device)
263
  self.cache_dir = cache_dir
264
  os.makedirs(self.cache_dir, exist_ok=True)
265
 
266
- self.model: Optional[Any] = None
 
 
267
  self.model_id = "PeiqingYang/MatAnyone"
268
  self.load_time = 0.0
269
 
270
- def load(self) -> Optional[Any]:
271
  """
272
- Load MatAnyone model and return a callable wrapper.
273
- Returns: _MatAnyoneWrapper or None
274
  """
275
- logger.info(f"Loading MatAnyone model: {self.model_id} (device={self.device})")
276
-
277
- strategies = [
278
- ("official", self._load_official),
279
- ("fallback", self._load_fallback),
 
 
 
280
  ]
 
 
 
 
 
 
 
281
 
282
- for strategy_name, strategy_func in strategies:
 
 
 
 
 
283
  try:
284
- logger.info(f"Trying MatAnyone loading strategy: {strategy_name}")
285
- start_time = time.time()
286
- model = strategy_func()
287
- if model:
288
- self.load_time = time.time() - start_time
289
- self.model = model
290
- logger.info(f"MatAnyone loaded via {strategy_name} in {self.load_time:.2f}s")
291
- return model
292
  except Exception as e:
293
- logger.error(f"MatAnyone {strategy_name} strategy failed: {e}")
294
- logger.debug(traceback.format_exc())
295
- continue
296
 
297
- logger.error("All MatAnyone loading strategies failed")
298
- return None
 
299
 
300
- def _load_official(self) -> Optional[Any]:
301
- """Load using the official MatAnyone API and wrap with boundary normalizer."""
302
- try:
303
- from matanyone import InferenceCore # type: ignore
304
- except Exception as e:
305
- logger.error(f"Failed to import official MatAnyone: {e}")
306
- return None
307
 
308
- core = InferenceCore(self.model_id)
309
- wrapped = _MatAnyoneWrapper(core, device=self.device)
310
- return wrapped
 
 
 
 
 
311
 
312
- def _load_fallback(self) -> Optional[Any]:
313
- """Create a minimal fallback that smooths/returns the mask."""
314
- class _FallbackCore:
315
- def step(self, image, mask, idx_mask: bool = False, **kwargs):
316
- # Convert to 2-D numpy mask as final step
317
- m2d = _to_2d_numpy_mask(mask)
318
- try:
319
- import cv2
320
- return cv2.GaussianBlur(m2d, (5, 5), 1.0)
321
- except Exception:
322
- return m2d
323
 
324
- def process(self, image, mask, **kwargs):
325
- return self.step(image, mask, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- logger.warning("Using fallback MatAnyone (limited refinement).")
328
- core = _FallbackCore()
329
- return _MatAnyoneWrapper(core, device=self.device)
 
330
 
331
- # --------------------------- Housekeeping --------------------------- #
 
 
 
332
 
333
  def cleanup(self):
 
 
 
 
 
 
 
334
  if self.model:
335
  try:
336
  del self.model
@@ -342,7 +363,7 @@ def cleanup(self):
342
 
343
  def get_info(self) -> Dict[str, Any]:
344
  return {
345
- "loaded": self.model is not None,
346
  "model_id": self.model_id,
347
  "device": self.device,
348
  "load_time": self.load_time,
 
1
  #!/usr/bin/env python3
2
  """
3
+ MatAnyone Loader + Stateful Adapter
4
+ - Loads the official model from Hugging Face.
5
+ - Drives InferenceCore as intended: first-frame encode + warm-up, then propagation.
6
+ - Normalizes inputs so conv2d never sees 5-D tensors.
7
+ - Always outputs a 2-D, contiguous float32 mask [H,W] for OpenCV.
 
 
8
  """
9
 
10
  import os
 
15
 
16
  import numpy as np
17
  import torch
18
+ import torch.nn.functional as F
19
+ import inspect
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
24
+ # ------------------------- Shape & dtype utilities ------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
27
+ pref = (pref or "").lower() if pref else ""
28
  if pref.startswith("cuda"):
29
  return "cuda" if torch.cuda.is_available() else "cpu"
30
  if pref == "cpu":
 
41
def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
    """
    Normalize input to BCHW (image) or B1HW (mask), float32, clamped to [0, 1].

    Accepts: HWC, CHW, BCHW, BHWC, BTCHW/BTHWC, TCHW/THWC, HW.

    Args:
        x: Array-like or tensor; handed to ``_as_tensor_on_device`` first.
        device: Target device string forwarded to ``_as_tensor_on_device``.
        is_mask: When True the result keeps only the first channel (B1HW).
            When False a single-channel input is repeated to 3 channels.

    Returns:
        A float32 tensor of rank 4. The caller's data is never modified.

    Raises:
        ValueError: If the input rank is not one of 2, 3, 4 or 5.
    """
    # NOTE(review): _as_tensor_on_device is defined outside this view; it may
    # return the caller's own tensor, so nothing below mutates ``x`` in place
    # unless ``x`` was freshly created in this function.
    x = _as_tensor_on_device(x, device)

    # dtype / range: uint8 is treated as 0..255 and rescaled; every other
    # non-floating dtype (bool, int8, int16, int32, int64) is cast as-is so
    # that the clamp below always operates on a floating tensor.
    if x.dtype == torch.uint8:
        x = x.float().div_(255.0)  # in-place is safe: .float() made a copy
    elif not x.is_floating_point():
        x = x.float()

    # 5D [B,T,*,H,W] or [B,T,H,W,*] -> take first frame
    if x.ndim == 5:
        x = x[:, 0]  # -> 4D

    # 4D: BHWC -> BCHW (only when the last axis looks like channels and
    # axis 1 does not)
    if x.ndim == 4:
        if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
            x = x.permute(0, 3, 1, 2).contiguous()

    # 3D: HWC -> CHW; add batch
    elif x.ndim == 3:
        if x.shape[-1] in (1, 3, 4):
            x = x.permute(2, 0, 1).contiguous()
        x = x.unsqueeze(0)

    # 2D: add channel & batch
    elif x.ndim == 2:
        x = x.unsqueeze(0).unsqueeze(0)
        if not is_mask:
            x = x.repeat(1, 3, 1, 1)

    else:
        raise ValueError(f"Unsupported ndim={x.ndim}")

    # finalize channels
    if is_mask:
        if x.shape[1] > 1:
            x = x[:, :1]
    elif x.shape[1] == 1:
        x = x.repeat(1, 3, 1, 1)

    # Out-of-place clamp: ``x`` may still alias the caller's storage here.
    return x.clamp(0.0, 1.0).to(torch.float32)
89
 
90
 
91
+ def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
92
+ """Prefer CHW for InferenceCore.step."""
93
+ if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
94
+ return img_bchw[0]
95
+ return img_bchw # some builds may accept batched; we try CHW first
96
+
97
+
98
+ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> torch.Tensor:
99
+ """Non-idx path expects [1,H,W] for single target."""
100
+ if msk_b1hw is None:
101
+ return None
102
+ if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
103
+ return msk_b1hw[0] # -> [1,H,W]
104
+ if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
105
+ return msk_b1hw
106
+ raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
107
+
108
+
109
  def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
110
+ if mask_b1hw is None:
111
+ return None
112
  if img_bchw.shape[-2:] == mask_b1hw.shape[-2:]:
113
  return mask_b1hw
 
114
  return F.interpolate(mask_b1hw, size=img_bchw.shape[-2:], mode="nearest")
115
 
116
 
117
+ def _to_2d_alpha_numpy(x) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  """
119
+ Convert probabilities/mattes to 2-D float32 [H,W] contiguous.
 
120
  """
121
+ t = torch.as_tensor(x).float()
122
+ while t.ndim > 2:
123
+ if t.ndim == 3:
124
+ t = t[0] if t.shape[0] >= 1 else t.squeeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  else:
126
+ t = t.squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  t = t.clamp_(0.0, 1.0)
128
+ out = t.detach().cpu().numpy().astype(np.float32)
129
+ return np.ascontiguousarray(out)
130
+
131
+
132
def debug_shapes(tag: str, image, mask) -> None:
    """
    Log shape, dtype, min and max of ``image`` and ``mask`` at INFO level.

    Each value is converted with ``torch.as_tensor``; if that (or the
    statistics) fails, the value's Python type and the error are logged
    instead. Empty tensors report NaN for min and max.
    """
    for label, value in (("image", image), ("mask", mask)):
        try:
            as_t = torch.as_tensor(value)
            if as_t.numel():
                lo = float(as_t.min())
                hi = float(as_t.max())
            else:
                lo = hi = float("nan")
            logger.info(f"[{tag}:{label}] shape={tuple(as_t.shape)} dtype={as_t.dtype} min={lo:.4f} max={hi:.4f}")
        except Exception as exc:
            logger.info(f"[{tag}:{label}] type={type(value)} err={exc}")
143
 
144
 
145
+ # ------------------------------ Stateful Adapter --------------------------- #
146
 
147
+ class _MatAnyoneSession:
 
 
 
148
  """
149
+ Minimal stateful controller around InferenceCore.
150
 
151
+ Usage:
152
+ # frame 0 (has initial coarse mask):
153
+ alpha0 = session(frame0_rgb, mask0) # encode + warm-up predict
154
+ # frames 1..N (no mask):
155
+ alpha = session(frame_rgb) # propagate/refine
156
+ """
157
+ def __init__(self, core, device: str):
158
  self.core = core
159
  self.device = device
160
+ self.started = False
161
 
162
+ # discover supported step() kwargs
163
  try:
164
+ self._step_sig = inspect.signature(self.core.step)
165
+ self._has_first_frame_pred = "first_frame_pred" in self._step_sig.parameters
166
+ self._has_idx_mask = "idx_mask" in self._step_sig.parameters
167
+ except Exception:
168
+ self._step_sig = None
169
+ self._has_first_frame_pred = True
170
+ self._has_idx_mask = True
171
+
172
+ # discover output conversion helper
173
+ self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
174
+
175
+ def reset(self):
176
+ try:
177
+ if hasattr(self.core, "clear_memory"):
178
+ self.core.clear_memory()
179
+ except Exception:
180
+ pass
181
+ self.started = False
182
 
183
+ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
184
  """
185
+ Returns a 2-D float32 alpha [H,W] suitable for OpenCV.
186
+ Expects RGB image in HWC or similar; mask as [H,W] or broadcastable.
187
  """
188
+ # Normalize inputs
189
+ img_bchw = _to_bchw(image, self.device, is_mask=False) # [B,C,H,W]
190
+ msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
191
+ if msk_b1hw is not None:
192
+ msk_b1hw = _resize_mask_to(img_bchw, msk_b1hw)
193
+ img_chw = _to_chw_image(img_bchw)
194
+ m_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None
195
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  try:
197
+ if not self.started:
198
+ if m_1hw is None:
199
+ logger.warning("First frame arrived without a mask; returning neutral alpha.")
200
+ return np.full(img_chw.shape[-2:], 0.5, dtype=np.float32)
201
+
202
+ # 1) Encode target on first frame
203
+ kwargs1 = {}
204
+ if self._has_idx_mask:
205
+ kwargs1["idx_mask"] = False
206
+ _ = self.core.step(image=img_chw, mask=m_1hw, **kwargs1)
207
+
208
+ # 2) First-frame warm-up prediction + memorize
209
+ kwargs2 = {}
210
+ if self._has_first_frame_pred:
211
+ kwargs2["first_frame_pred"] = True
212
+ out_prob = self.core.step(image=img_chw, **kwargs2)
213
+
214
+ alpha = self._to_alpha(out_prob)
215
+ self.started = True
216
+ return _to_2d_alpha_numpy(alpha)
217
+
218
+ # Subsequent frames: propagate without mask
219
+ out_prob = self.core.step(image=img_chw)
220
+ alpha = self._to_alpha(out_prob)
221
+ return _to_2d_alpha_numpy(alpha)
222
+
223
+ except Exception as e:
224
+ logger.debug(traceback.format_exc())
225
+ logger.warning(f"MatAnyone call failed; returning input mask as fallback: {e}")
226
+ if m_1hw is not None:
227
+ return _to_2d_alpha_numpy(m_1hw)
228
+ return np.full(img_chw.shape[-2:], 0.5, dtype=np.float32)
229
+
230
+ def _to_alpha(self, out_prob):
231
+ """
232
+ Convert core output to alpha. Prefer core.output_prob_to_mask(matting=True) if available.
233
+ """
234
+ if self._has_prob_to_mask:
235
  try:
236
+ return self.core.output_prob_to_mask(out_prob, matting=True)
237
+ except Exception:
238
+ pass
239
+ # Fallback heuristics
240
+ t = torch.as_tensor(out_prob).float()
241
+ if t.ndim == 3 and t.shape[0] >= 1:
242
+ return t[0]
243
+ if t.ndim >= 2:
244
+ return t
245
+ return torch.full((1, 1), 0.5, dtype=torch.float32, device=t.device if t.is_cuda else "cpu")
246
 
 
 
 
247
 
248
+ # -------------------------------- Loader ---------------------------------- #
249
 
250
  class MatAnyoneLoader:
251
+ """
252
+ Official MatAnyone loader with stateful adapter.
253
+ """
254
 
255
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
256
  self.device = _select_device(device)
257
  self.cache_dir = cache_dir
258
  os.makedirs(self.cache_dir, exist_ok=True)
259
 
260
+ self.model = None # torch.nn.Module (MatAnyone)
261
+ self.core = None # InferenceCore
262
+ self.adapter = None # _MatAnyoneSession
263
  self.model_id = "PeiqingYang/MatAnyone"
264
  self.load_time = 0.0
265
 
266
+ def _import_model_and_core(self):
267
  """
268
+ Import MatAnyone + InferenceCore with resilient fallbacks (different dist layouts).
 
269
  """
270
+ # Try several possible import paths to be robust
271
+ model_cls = core_cls = None
272
+ err_msgs = []
273
+
274
+ # Candidates for model class
275
+ model_paths = [
276
+ ("matanyone.model.matanyone", "MatAnyone"),
277
+ ("matanyone", "MatAnyone"),
278
  ]
279
+ for mod, cls in model_paths:
280
+ try:
281
+ m = __import__(mod, fromlist=[cls])
282
+ model_cls = getattr(m, cls)
283
+ break
284
+ except Exception as e:
285
+ err_msgs.append(f"model {mod}.{cls}: {e}")
286
 
287
+ # Candidates for InferenceCore
288
+ core_paths = [
289
+ ("matanyone.inference.inference_core", "InferenceCore"),
290
+ ("matanyone", "InferenceCore"),
291
+ ]
292
+ for mod, cls in core_paths:
293
  try:
294
+ m = __import__(mod, fromlist=[cls])
295
+ core_cls = getattr(m, cls)
296
+ break
 
 
 
 
 
297
  except Exception as e:
298
+ err_msgs.append(f"core {mod}.{cls}: {e}")
 
 
299
 
300
+ if model_cls is None or core_cls is None:
301
+ msg = " | ".join(err_msgs)
302
+ raise ImportError(f"Could not import MatAnyone/InferenceCore: {msg}")
303
 
304
+ return model_cls, core_cls
 
 
 
 
 
 
305
 
306
+ def load(self) -> Optional[Any]:
307
+ """
308
+ Load MatAnyone and return the stateful callable adapter.
309
+ """
310
+ logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
311
+ start = time.time()
312
+ try:
313
+ model_cls, core_cls = self._import_model_and_core()
314
 
315
+ # Official pattern: model -> eval -> core(model, cfg=model.cfg)
316
+ self.model = model_cls.from_pretrained(self.model_id)
317
+ self.model = self.model.to(self.device).eval()
 
 
 
 
 
 
 
 
318
 
319
+ # Some builds require cfg; fall back if not present
320
+ try:
321
+ cfg = getattr(self.model, "cfg", None)
322
+ if cfg is not None:
323
+ self.core = core_cls(self.model, cfg=cfg)
324
+ else:
325
+ self.core = core_cls(self.model)
326
+ except TypeError:
327
+ # signature without cfg
328
+ self.core = core_cls(self.model)
329
+
330
+ # Move core to device if it supports .to
331
+ try:
332
+ if hasattr(self.core, "to"):
333
+ self.core.to(self.device)
334
+ except Exception:
335
+ pass
336
 
337
+ self.adapter = _MatAnyoneSession(self.core, self.device)
338
+ self.load_time = time.time() - start
339
+ logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
340
+ return self.adapter
341
 
342
+ except Exception as e:
343
+ logger.error(f"Failed to load MatAnyone: {e}")
344
+ logger.debug(traceback.format_exc())
345
+ return None
346
 
347
  def cleanup(self):
348
+ if self.adapter:
349
+ try:
350
+ self.adapter.reset()
351
+ except Exception:
352
+ pass
353
+ self.adapter = None
354
+ self.core = None
355
  if self.model:
356
  try:
357
  del self.model
 
363
 
364
  def get_info(self) -> Dict[str, Any]:
365
  return {
366
+ "loaded": self.adapter is not None,
367
  "model_id": self.model_id,
368
  "device": self.device,
369
  "load_time": self.load_time,