MogensR commited on
Commit
7b01a2f
·
1 Parent(s): 8e6cc12

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +187 -126
models/loaders/matanyone_loader.py CHANGED
@@ -1,14 +1,18 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- MatAnyone Loader - Stable Callable Wrapper for InferenceCore
5
- ===========================================================
6
-
7
- - Enforces image CHW float32 [0,1] and mask 1HW float32 [0,1]
8
- - Adds internal batch dim (B=1) and removes it on output
9
- - Works with multiple possible InferenceCore loading signatures
10
- - Uses torch.inference_mode() + optional autocast for speed
11
- - Returns a 2-D alpha mask (H,W) float32 in [0,1]
 
 
 
 
12
  """
13
 
14
  import os
@@ -32,92 +36,170 @@ def _to_float01_np(arr: np.ndarray) -> np.ndarray:
32
  if arr.dtype == np.uint8:
33
  arr = arr.astype(np.float32) / 255.0
34
  else:
35
- arr = arr.astype(np.float32)
36
- # Clamp for safety
37
  np.clip(arr, 0.0, 1.0, out=arr)
38
  return arr
39
 
40
 
41
- def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
42
  """
43
- Convert image to torch.FloatTensor CHW in [0,1].
44
- Accepts HxWxC or CHW (numpy or tensor). Adds batch dim later.
 
 
45
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if torch.is_tensor(image):
47
  t = image
48
- if t.ndim == 3 and t.shape[0] in (1, 3, 4): # already CHW
49
- pass
50
- elif t.ndim == 3 and t.shape[-1] in (1, 3, 4): # HWC -> CHW
51
- t = t.permute(2, 0, 1)
52
- elif t.ndim == 2: # HW (grayscale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  t = t.unsqueeze(0)
54
  else:
55
- raise ValueError(f"Unsupported image tensor shape: {tuple(t.shape)}")
 
56
  t = t.to(dtype=torch.float32)
57
- # If likely 0-255, scale; otherwise clamp to [0,1]
58
  if torch.max(t) > 1.5:
59
  t = t / 255.0
60
  t = torch.clamp(t, 0.0, 1.0)
 
61
  return t
62
- else:
63
- arr = np.asarray(image)
64
- if arr.ndim == 3 and arr.shape[2] in (1, 3, 4): # HWC
65
- arr = arr.transpose(2, 0, 1) # -> CHW
66
- elif arr.ndim == 2: # HW
67
- arr = arr[None, ...] # -> 1HW
68
- elif arr.ndim == 3 and arr.shape[0] in (1, 3, 4): # already CHW
 
69
  pass
 
 
70
  else:
71
- raise ValueError(f"Unsupported image numpy shape: {arr.shape}")
72
- arr = _to_float01_np(arr)
73
- return torch.from_numpy(arr)
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
- def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
77
  """
78
- Convert mask to torch.FloatTensor 1HW in [0,1].
79
- Accepts HW, 1HW, CHW (C=1), HxWx1.
80
  """
 
 
 
81
  if torch.is_tensor(mask):
82
  m = mask
83
- if m.ndim == 2: # HW
84
- m = m.unsqueeze(0) # 1HW
85
- elif m.ndim == 3:
86
- if m.shape[0] == 1: # 1HW
87
- pass
88
- elif m.shape[-1] == 1: # HW1 -> 1HW
89
- m = m.permute(2, 0, 1)
90
  else:
91
- raise ValueError(f"Mask has too many channels: {tuple(m.shape)}")
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
- raise ValueError(f"Unsupported mask tensor shape: {tuple(m.shape)}")
 
94
  m = m.to(dtype=torch.float32)
95
  if torch.max(m) > 1.5:
96
  m = m / 255.0
97
  m = torch.clamp(m, 0.0, 1.0)
 
98
  return m
99
- else:
100
- arr = np.asarray(mask)
101
- if arr.ndim == 2: # HW
102
- arr = arr[None, ...] # 1HW
103
- elif arr.ndim == 3:
104
- if arr.shape[0] == 1: # 1HW
105
- pass
106
- elif arr.shape[-1] == 1: # HW1 -> 1HW
107
- arr = arr.transpose(2, 0, 1)
108
- else:
109
- raise ValueError(f"Mask has too many channels: {arr.shape}")
110
  else:
111
- raise ValueError(f"Unsupported mask numpy shape: {arr.shape}")
112
- arr = _to_float01_np(arr)
113
- return torch.from_numpy(arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
117
- """
118
- Extract a 2D alpha (H,W) float32 [0,1] from a variety of possible outputs.
119
- Accepts numpy/tensor with shapes: HW, 1HW, CHW(C>=1), BHWC, BCHW, etc.
120
- """
121
  if result is None:
122
  return np.full((512, 512), 0.5, dtype=np.float32)
123
 
@@ -125,27 +207,23 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
125
  result = result.detach().float().cpu()
126
 
127
  arr = np.asarray(result)
 
 
 
 
 
 
128
  if arr.ndim == 2:
129
  alpha = arr
130
  elif arr.ndim == 3:
131
- # Prefer first channel for CHW/HWC
132
- if arr.shape[0] in (1, 3, 4): # CHW
133
  alpha = arr[0]
134
- elif arr.shape[-1] in (1, 3, 4): # HWC
135
  alpha = arr[..., 0]
136
  else:
137
- # Unknown 3D shape – take first slice robustly
138
- alpha = arr[0]
139
- elif arr.ndim == 4:
140
- # Batch first: BxCxHxW or BxHxWxC
141
- if arr.shape[1] in (1, 3, 4): # BCHW
142
- alpha = arr[0, 0]
143
- elif arr.shape[-1] in (1, 3, 4): # BHWC
144
- alpha = arr[0, ..., 0]
145
- else:
146
- alpha = arr[0, 0]
147
  else:
148
- # Fallback
149
  alpha = np.full((512, 512), 0.5, dtype=np.float32)
150
 
151
  alpha = alpha.astype(np.float32, copy=False)
@@ -154,41 +232,29 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
154
 
155
 
156
  def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
157
- """Best-effort get (H, W) from an image/mask input for neutral fallbacks."""
158
  if torch.is_tensor(x):
159
  shape = tuple(x.shape)
160
- # Handle CHW / HWC / BCHW / BHWC / HW
161
- if len(shape) == 2: # HW
162
- return shape[0], shape[1]
163
- if len(shape) == 3:
164
- if shape[0] in (1, 3, 4): # CHW
165
- return shape[1], shape[2]
166
- if shape[-1] in (1, 3, 4): # HWC
167
- return shape[0], shape[1]
168
- if len(shape) == 4:
169
- # Assume batch first
170
- b, c_or_h, h_or_w, maybe_w = shape
171
- # Try BCHW
172
- if shape[1] in (1, 3, 4):
173
- return shape[2], shape[3]
174
- # Try BHWC
175
- return shape[1], shape[2]
176
- return 512, 512
177
  else:
178
- arr = np.asarray(x)
179
- if arr.ndim == 2: # HW
180
- return arr.shape[0], arr.shape[1]
181
- if arr.ndim == 3:
182
- if arr.shape[0] in (1, 3, 4): # CHW
183
- return arr.shape[1], arr.shape[2]
184
- if arr.shape[-1] in (1, 3, 4): # HWC
185
- return arr.shape[0], arr.shape[1]
186
- if arr.ndim == 4:
187
- # Assume batch first
188
- if arr.shape[1] in (1, 3, 4): # BCHW
189
- return arr.shape[2], arr.shape[3]
190
- return arr.shape[1], arr.shape[2]
191
- return 512, 512
 
 
 
 
 
192
 
193
 
194
  # --------------------------- Callable Wrapper ---------------------------
@@ -201,6 +267,7 @@ class MatAnyoneCallableWrapper:
201
  - First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
202
  - Subsequent calls do not require mask.
203
  - Returns 2D alpha (H,W) float32 in [0,1].
 
204
  """
205
 
206
  def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
@@ -213,7 +280,7 @@ def _maybe_autocast(self):
213
  if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
214
  dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
215
  return torch.autocast(device_type="cuda", dtype=dtype)
216
- # no-op context manager
217
  class _NullCtx:
218
  def __enter__(self): return None
219
  def __exit__(self, *exc): return False
@@ -221,9 +288,8 @@ def __exit__(self, *exc): return False
221
 
222
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
223
  try:
224
- # Preprocess → CHW/1HW tensors, then add batch
225
- img_chw = _ensure_chw_float01(image).to(self.device, non_blocking=True)
226
- img_bchw = img_chw.unsqueeze(0) # B=1
227
 
228
  if not self.initialized:
229
  if mask is None:
@@ -231,15 +297,14 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
231
  logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
232
  return np.full((h, w), 0.5, dtype=np.float32)
233
 
234
- m_1hw = _ensure_1hw_float01(mask).to(self.device, non_blocking=True)
235
- m_b1hw = m_1hw.unsqueeze(0) # B=1
236
 
237
  with torch.inference_mode():
238
  with self._maybe_autocast():
239
  if hasattr(self.core, "step"):
240
- result = self.core.step(image=img_bchw, mask=m_b1hw, **kwargs)
241
  elif hasattr(self.core, "process_frame"):
242
- result = self.core.process_frame(img_bchw, m_b1hw, **kwargs)
243
  else:
244
  logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
245
  return _alpha_from_result(mask)
@@ -251,9 +316,9 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
251
  with torch.inference_mode():
252
  with self._maybe_autocast():
253
  if hasattr(self.core, "step"):
254
- result = self.core.step(image=img_bchw, **kwargs)
255
  elif hasattr(self.core, "process_frame"):
256
- result = self.core.process_frame(img_bchw, **kwargs)
257
  else:
258
  h, w = _hw_from_image_like(image)
259
  logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
@@ -297,7 +362,7 @@ class MatAnyoneLoader:
297
  Usage:
298
  loader = MatAnyoneLoader(device="cuda")
299
  session = loader.load() # callable
300
- alpha = session(frame, first_frame_mask) # 2-D float32 [0,1]
301
  """
302
 
303
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
@@ -346,13 +411,9 @@ def _try_build_core(self):
346
  logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")
347
 
348
  # 3) Minimal ctor
349
- try:
350
- core = InferenceCore(self.model_id)
351
- logger.info("Loaded MatAnyone via InferenceCore(model_id) [minimal]")
352
- return core
353
- except Exception as e:
354
- logger.debug(f"ctor(model_id) failed: {e}")
355
- raise # Propagate last error
356
 
357
  def load(self) -> Optional[MatAnyoneCallableWrapper]:
358
  """Load MatAnyone and return the callable wrapper."""
@@ -364,7 +425,7 @@ def load(self) -> Optional[MatAnyoneCallableWrapper]:
364
 
365
  try:
366
  self.processor = self._try_build_core()
367
- # If the core has an explicit to(device) or set_device, try to use it
368
  try:
369
  if hasattr(self.processor, "to"):
370
  self.processor.to(self.device)
@@ -445,7 +506,7 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
445
  return self.wrapper(image, mask, **kwargs)
446
 
447
 
448
- # Backwards compatibility alias (legacy session naming)
449
  _MatAnyoneSession = MatAnyoneCallableWrapper
450
 
451
  __all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
5
+ =================================================================================
6
+
7
+ - Always call InferenceCore UNBATCHED:
8
+ image -> CHW float32 [0,1]
9
+ mask -> 1HW float32 [0,1]
10
+ - Aggressively strip extra dims:
11
+ e.g. [B,T,C,H,W] -> [C,H,W] (use first slice when B/T > 1 with a warning)
12
+ e.g. [B,C,H,W] -> [C,H,W]
13
+ e.g. [H,W,C,1] -> [H,W,C]
14
+ - Optional CUDA mixed precision (fp16/bf16)
15
+ - Robust alpha extraction -> (H,W) float32 [0,1]
16
  """
17
 
18
  import os
 
36
  if arr.dtype == np.uint8:
37
  arr = arr.astype(np.float32) / 255.0
38
  else:
39
+ arr = arr.astype(np.float32, copy=False)
 
40
  np.clip(arr, 0.0, 1.0, out=arr)
41
  return arr
42
 
43
 
44
+ def _strip_leading_extras_to_ndim(x: Union[np.ndarray, torch.Tensor], target_ndim: int) -> Union[np.ndarray, torch.Tensor]:
45
  """
46
+ Reduce x to at most target_ndim by removing leading dims.
47
+ - If a leading dim == 1, squeeze it.
48
+ - If a leading dim > 1, take the first slice and log a warning.
49
+ Repeat until ndim <= target_ndim.
50
  """
51
+ is_tensor = torch.is_tensor(x)
52
+ get_shape = (lambda t: tuple(t.shape)) if is_tensor else (lambda a: a.shape)
53
+ index_first = (lambda t: t[0]) if is_tensor else (lambda a: a[0])
54
+ squeeze_first = (lambda t: t.squeeze(0)) if is_tensor else (lambda a: np.squeeze(a, axis=0))
55
+
56
+ while len(get_shape(x)) > target_ndim:
57
+ dim0 = get_shape(x)[0]
58
+ if dim0 == 1:
59
+ x = squeeze_first(x)
60
+ else:
61
+ logger.warning(f"Input has extra leading dim >1 ({dim0}); taking the first slice.")
62
+ x = index_first(x)
63
+ return x
64
+
65
+
66
+ def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "image") -> torch.Tensor:
67
+ """
68
+ Convert image to torch.FloatTensor CHW in [0,1], stripping extras.
69
+ Accepts shapes up to 5D (e.g. B,T,C,H,W / B,C,H,W / H,W,C / CHW / HW / ...).
70
+ If ambiguous multi-channel, picks first channel with a warning.
71
+ """
72
+ orig_shape = tuple(image.shape) if not torch.is_tensor(image) else tuple(image.shape)
73
+ # Reduce to <= 3 dims
74
+ image = _strip_leading_extras_to_ndim(image, 3)
75
+
76
  if torch.is_tensor(image):
77
  t = image
78
+ # Convert 4D (rare if caller passes) once more
79
+ if t.ndim == 4:
80
+ t = _strip_leading_extras_to_ndim(t, 3)
81
+
82
+ if t.ndim == 3:
83
+ c0, c1, c2 = t.shape
84
+ if c0 in (1, 3, 4):
85
+ # CHW
86
+ pass
87
+ elif c2 in (1, 3, 4):
88
+ # HWC -> CHW
89
+ t = t.permute(2, 0, 1)
90
+ else:
91
+ # Ambiguous, assume HWC-like and take first channel after moving to CHW
92
+ logger.warning(f"{name}: ambiguous 3D shape {tuple(t.shape)}; attempting HWC->CHW then selecting first channel.")
93
+ t = t.permute(2, 0, 1)
94
+ if t.shape[0] > 1:
95
+ t = t[0]
96
+ t = t.unsqueeze(0) # back to 1HW
97
+ elif t.ndim == 2:
98
+ # HW -> 1HW
99
  t = t.unsqueeze(0)
100
  else:
101
+ raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")
102
+
103
  t = t.to(dtype=torch.float32)
 
104
  if torch.max(t) > 1.5:
105
  t = t / 255.0
106
  t = torch.clamp(t, 0.0, 1.0)
107
+ logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
108
  return t
109
+
110
+ # numpy path
111
+ arr = np.asarray(image)
112
+ if arr.ndim == 4:
113
+ arr = _strip_leading_extras_to_ndim(arr, 3)
114
+
115
+ if arr.ndim == 3:
116
+ if arr.shape[0] in (1, 3, 4): # CHW
117
  pass
118
+ elif arr.shape[-1] in (1, 3, 4): # HWC -> CHW
119
+ arr = arr.transpose(2, 0, 1)
120
  else:
121
+ logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
122
+ arr = arr.transpose(2, 0, 1) # HWC->CHW
123
+ if arr.shape[0] > 1:
124
+ arr = arr[0:1, ...] # 1HW
125
+ elif arr.ndim == 2:
126
+ arr = arr[None, ...] # 1HW
127
+ else:
128
+ raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
129
+
130
+ arr = _to_float01_np(arr)
131
+ t = torch.from_numpy(arr)
132
+ logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
133
+ return t
134
 
135
 
136
+ def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "mask") -> torch.Tensor:
137
  """
138
+ Convert mask to torch.FloatTensor 1HW in [0,1], stripping extras.
139
+ Accepts up to 4D inputs; collapses leading dims; picks first slice/channel if needed.
140
  """
141
+ orig_shape = tuple(mask.shape) if not torch.is_tensor(mask) else tuple(mask.shape)
142
+ mask = _strip_leading_extras_to_ndim(mask, 3)
143
+
144
  if torch.is_tensor(mask):
145
  m = mask
146
+ if m.ndim == 3:
147
+ # 1HW or CHW or HWC-like
148
+ if m.shape[0] == 1:
149
+ pass # 1HW
150
+ elif m.shape[-1] == 1:
151
+ m = m.permute(2, 0, 1) # HW1 -> 1HW
 
152
  else:
153
+ # If multi-channel, take first
154
+ logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
155
+ # Assume CHW or HWC-like already normalized earlier; prefer leading as channel
156
+ if m.shape[0] in (3, 4):
157
+ m = m[0:1, ...]
158
+ elif m.shape[-1] in (3, 4):
159
+ m = m.permute(2, 0, 1)[0:1, ...]
160
+ else:
161
+ # Ambiguous -> take first along first axis and ensure 1HW
162
+ m = m[0:1, ...]
163
+ elif m.ndim == 2:
164
+ m = m.unsqueeze(0) # 1HW
165
  else:
166
+ raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")
167
+
168
  m = m.to(dtype=torch.float32)
169
  if torch.max(m) > 1.5:
170
  m = m / 255.0
171
  m = torch.clamp(m, 0.0, 1.0)
172
+ logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
173
  return m
174
+
175
+ # numpy path
176
+ arr = np.asarray(mask)
177
+ if arr.ndim == 3:
178
+ if arr.shape[0] == 1:
179
+ pass # 1HW
180
+ elif arr.shape[-1] == 1:
181
+ arr = arr.transpose(2, 0, 1) # HW1 -> 1HW
 
 
 
182
  else:
183
+ logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
184
+ if arr.shape[0] in (3, 4):
185
+ arr = arr[0:1, ...] # CHW -> 1HW
186
+ elif arr.shape[-1] in (3, 4):
187
+ arr = arr.transpose(2, 0, 1)[0:1, ...] # HWC -> CHW -> 1HW
188
+ else:
189
+ arr = arr[0:1, ...] # ambiguous -> 1HW by slice
190
+ elif arr.ndim == 2:
191
+ arr = arr[None, ...] # 1HW
192
+ else:
193
+ raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
194
+
195
+ arr = _to_float01_np(arr)
196
+ t = torch.from_numpy(arr)
197
+ logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (1HW)")
198
+ return t
199
 
200
 
201
  def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
202
+ """Extract a 2D alpha (H,W) float32 [0,1] from various outputs."""
 
 
 
203
  if result is None:
204
  return np.full((512, 512), 0.5, dtype=np.float32)
205
 
 
207
  result = result.detach().float().cpu()
208
 
209
  arr = np.asarray(result)
210
+ # Strip to <= 3 dims, then extract
211
+ while arr.ndim > 3:
212
+ if arr.shape[0] > 1:
213
+ logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
214
+ arr = arr[0]
215
+
216
  if arr.ndim == 2:
217
  alpha = arr
218
  elif arr.ndim == 3:
219
+ if arr.shape[0] in (1, 3, 4): # CHW -> take channel 0
 
220
  alpha = arr[0]
221
+ elif arr.shape[-1] in (1, 3, 4): # HWC -> take channel 0
222
  alpha = arr[..., 0]
223
  else:
224
+ alpha = arr[0] # ambiguous
 
 
 
 
 
 
 
 
 
225
  else:
226
+ # 1D or 0D shouldn't happen; fallback
227
  alpha = np.full((512, 512), 0.5, dtype=np.float32)
228
 
229
  alpha = alpha.astype(np.float32, copy=False)
 
232
 
233
 
234
  def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
235
+ """Best-effort infer (H, W) for fallback mask sizing."""
236
  if torch.is_tensor(x):
237
  shape = tuple(x.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  else:
239
+ shape = np.asarray(x).shape
240
+
241
+ # Try common orders
242
+ if len(shape) == 2: # HW
243
+ return shape[0], shape[1]
244
+ if len(shape) == 3:
245
+ if shape[0] in (1, 3, 4): # CHW
246
+ return shape[1], shape[2]
247
+ if shape[-1] in (1, 3, 4): # HWC
248
+ return shape[0], shape[1]
249
+ # Ambiguous -> treat as CHW
250
+ return shape[1], shape[2]
251
+ if len(shape) >= 4:
252
+ # Assume leading are batch/time; try BCHW first
253
+ if len(shape) >= 4 and (shape[1] in (1, 3, 4)):
254
+ return shape[2], shape[3]
255
+ # Else BHWC-ish
256
+ return shape[-3], shape[-2]
257
+ return 512, 512
258
 
259
 
260
  # --------------------------- Callable Wrapper ---------------------------
 
267
  - First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
268
  - Subsequent calls do not require mask.
269
  - Returns 2D alpha (H,W) float32 in [0,1].
270
+ - Strips any extra dims from inputs before calling core.
271
  """
272
 
273
  def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
 
280
  if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
281
  dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
282
  return torch.autocast(device_type="cuda", dtype=dtype)
283
+ # no-op ctx
284
  class _NullCtx:
285
  def __enter__(self): return None
286
  def __exit__(self, *exc): return False
 
288
 
289
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
290
  try:
291
+ # Preprocess (unbatched)
292
+ img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)
 
293
 
294
  if not self.initialized:
295
  if mask is None:
 
297
  logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
298
  return np.full((h, w), 0.5, dtype=np.float32)
299
 
300
+ m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
 
301
 
302
  with torch.inference_mode():
303
  with self._maybe_autocast():
304
  if hasattr(self.core, "step"):
305
+ result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
306
  elif hasattr(self.core, "process_frame"):
307
+ result = self.core.process_frame(img_chw, m_1hw, **kwargs)
308
  else:
309
  logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
310
  return _alpha_from_result(mask)
 
316
  with torch.inference_mode():
317
  with self._maybe_autocast():
318
  if hasattr(self.core, "step"):
319
+ result = self.core.step(image=img_chw, **kwargs)
320
  elif hasattr(self.core, "process_frame"):
321
+ result = self.core.process_frame(img_chw, **kwargs)
322
  else:
323
  h, w = _hw_from_image_like(image)
324
  logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
 
362
  Usage:
363
  loader = MatAnyoneLoader(device="cuda")
364
  session = loader.load() # callable
365
+ alpha = session(frame, first_frame_mask) # returns (H, W) float32
366
  """
367
 
368
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
 
411
  logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")
412
 
413
  # 3) Minimal ctor
414
+ core = InferenceCore(self.model_id)
415
+ logger.info("Loaded MatAnyone via InferenceCore(model_id) [minimal]")
416
+ return core
 
 
 
 
417
 
418
  def load(self) -> Optional[MatAnyoneCallableWrapper]:
419
  """Load MatAnyone and return the callable wrapper."""
 
425
 
426
  try:
427
  self.processor = self._try_build_core()
428
+ # Optional device move
429
  try:
430
  if hasattr(self.processor, "to"):
431
  self.processor.to(self.device)
 
506
  return self.wrapper(image, mask, **kwargs)
507
 
508
 
509
+ # Backwards compatibility alias
510
  _MatAnyoneSession = MatAnyoneCallableWrapper
511
 
512
  __all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]