MogensR committed on
Commit
8e6cc12
·
1 Parent(s): 6095d82

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +351 -148
models/loaders/matanyone_loader.py CHANGED
@@ -1,9 +1,14 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- MatAnyone Loader - Wrapper for Official InferenceCore API
5
- =========================================================
6
- Creates a callable wrapper around InferenceCore to maintain compatibility.
 
 
 
 
 
7
  """
8
 
9
  import os
@@ -12,133 +17,295 @@
12
  import tempfile
13
  import traceback
14
  from pathlib import Path
15
- from typing import Optional, Dict, Any, Tuple
16
 
17
  import numpy as np
18
  import torch
19
- import cv2
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class MatAnyoneCallableWrapper:
25
  """
26
- Callable wrapper around InferenceCore to maintain API compatibility.
27
- Makes the processor work like a callable session.
 
 
 
 
28
  """
29
-
30
- def __init__(self, inference_core):
31
  self.core = inference_core
32
  self.initialized = False
33
-
34
- def __call__(self, image, mask=None, **kwargs):
35
- """
36
- Make this wrapper callable like the old session interface.
37
-
38
- Args:
39
- image: Input image as numpy array
40
- mask: Optional mask for first frame
41
-
42
- Returns:
43
- Alpha mask as 2D numpy array
44
- """
 
 
45
  try:
46
- # For MatAnyone, the first frame needs initialization with a mask
 
 
 
47
  if not self.initialized:
48
  if mask is None:
49
- # Return a default mask if no mask provided for first frame
50
- logger.warning("First frame called without mask, returning default")
51
- if isinstance(image, np.ndarray):
52
- h, w = image.shape[:2]
53
- else:
54
- h, w = 512, 512
55
- return np.ones((h, w), dtype=np.float32) * 0.5
56
-
57
- # Initialize with first frame and mask
58
- # The exact API call depends on the InferenceCore implementation
59
- # This is a placeholder - adjust based on actual API
60
- if hasattr(self.core, 'step'):
61
- result = self.core.step(image=image, mask=mask)
62
- elif hasattr(self.core, 'process_frame'):
63
- result = self.core.process_frame(image, mask)
64
- else:
65
- # Fallback
66
- logger.warning("InferenceCore API unclear, returning input mask")
67
- return mask if isinstance(mask, np.ndarray) else np.array(mask)
68
-
69
  self.initialized = True
70
- return self._extract_alpha(result)
71
- else:
72
- # Subsequent frames - no mask needed
73
- if hasattr(self.core, 'step'):
74
- result = self.core.step(image=image)
75
- elif hasattr(self.core, 'process_frame'):
76
- result = self.core.process_frame(image)
77
- else:
78
- # Fallback - return neutral mask
79
- if isinstance(image, np.ndarray):
80
- h, w = image.shape[:2]
81
  else:
82
- h, w = 512, 512
83
- return np.ones((h, w), dtype=np.float32) * 0.5
84
-
85
- return self._extract_alpha(result)
86
-
 
87
  except Exception as e:
88
  logger.error(f"MatAnyone wrapper call failed: {e}")
89
- # Return a fallback mask
 
90
  if mask is not None:
91
- return mask if isinstance(mask, np.ndarray) else np.array(mask)
92
- if isinstance(image, np.ndarray):
93
- h, w = image.shape[:2]
94
- else:
95
- h, w = 512, 512
96
- return np.ones((h, w), dtype=np.float32) * 0.5
97
-
98
- def _extract_alpha(self, result):
99
- """Extract alpha channel from result."""
100
- if result is None:
101
- return np.ones((512, 512), dtype=np.float32) * 0.5
102
-
103
- if isinstance(result, np.ndarray):
104
- if result.ndim == 2:
105
- return result.astype(np.float32)
106
- elif result.ndim == 3:
107
- # Take first channel or average
108
- return result[..., 0].astype(np.float32)
109
- elif result.ndim == 4:
110
- # Batch dimension - take first
111
- return result[0, 0].astype(np.float32)
112
-
113
- # Try to convert to numpy
114
- try:
115
- arr = np.array(result)
116
- if arr.ndim >= 2:
117
- return arr[..., 0] if arr.ndim > 2 else arr
118
- except:
119
- pass
120
-
121
- return np.ones((512, 512), dtype=np.float32) * 0.5
122
-
123
  def reset(self):
124
- """Reset the session state."""
125
  self.initialized = False
126
- if hasattr(self.core, 'reset'):
127
- self.core.reset()
128
- elif hasattr(self.core, 'clear_memory'):
129
- self.core.clear_memory()
 
 
 
 
 
 
 
130
 
 
131
 
132
  class MatAnyoneLoader:
133
  """
134
- Official MatAnyone loader using InferenceCore API with callable wrapper.
 
 
 
 
 
135
  """
136
-
137
- def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
 
138
  self.device = self._select_device(device)
139
  self.cache_dir = cache_dir
140
  os.makedirs(self.cache_dir, exist_ok=True)
141
-
142
  self.processor = None
143
  self.wrapper = None
144
  self.model_id = "PeiqingYang/MatAnyone"
@@ -146,103 +313,139 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
146
  self.loaded = False
147
  self.load_error = None
148
  self.temp_dir = Path(tempfile.mkdtemp())
149
-
 
150
  def _select_device(self, pref: str) -> str:
151
- """Select best available device."""
152
  pref = (pref or "").lower()
153
  if pref.startswith("cuda"):
154
  return "cuda" if torch.cuda.is_available() else "cpu"
155
  if pref == "cpu":
156
  return "cpu"
157
  return "cuda" if torch.cuda.is_available() else "cpu"
158
-
159
- def load(self):
160
- """Load MatAnyone using official InferenceCore API and wrap it."""
161
- if self.loaded and self.wrapper:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return self.wrapper
163
-
164
- logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
165
  t0 = time.time()
166
-
167
  try:
168
- # Import the official API
169
- from matanyone.inference.inference_core import InferenceCore
170
-
171
- # Create the InferenceCore processor
172
- self.processor = InferenceCore(self.model_id)
173
-
174
- # Wrap it to make it callable
175
- self.wrapper = MatAnyoneCallableWrapper(self.processor)
176
-
 
 
 
 
177
  self.loaded = True
178
  self.load_time = time.time() - t0
179
- logger.info(f"MatAnyone loaded and wrapped successfully in {self.load_time:.2f}s")
180
  return self.wrapper
181
-
182
  except ImportError as e:
183
  self.load_error = f"MatAnyone not installed: {e}"
184
- logger.error(f"Failed to import MatAnyone. Install with: pip install git+https://github.com/pq-yang/MatAnyone.git@main")
 
185
  return None
186
-
187
  except Exception as e:
188
  self.load_error = str(e)
189
  logger.error(f"Failed to load MatAnyone: {e}")
190
  logger.debug(traceback.format_exc())
191
  return None
192
-
193
  def cleanup(self):
194
  """Cleanup temporary files and release resources."""
195
  self.processor = None
196
  self.wrapper = None
197
-
198
  # Clean temp directory
199
  if self.temp_dir.exists():
200
  import shutil
201
  shutil.rmtree(self.temp_dir, ignore_errors=True)
202
-
203
  # Clear CUDA cache if available
204
  if torch.cuda.is_available():
205
  torch.cuda.empty_cache()
206
-
207
  def get_info(self) -> Dict[str, Any]:
208
- """Get model information."""
209
  info = {
210
  "loaded": self.loaded,
211
  "model_id": self.model_id,
212
  "device": str(self.device),
213
- "load_time": self.load_time,
214
  "error": self.load_error,
215
- "api": "InferenceCore (wrapped)"
 
216
  }
217
-
218
- # Add interface info
219
- if self.processor:
220
- info["has_step"] = hasattr(self.processor, 'step')
221
- info["has_process_frame"] = hasattr(self.processor, 'process_frame')
222
- info["has_process_video"] = hasattr(self.processor, 'process_video')
223
-
224
  return info
225
-
226
  def reset(self):
227
  """Reset the processor for a new video."""
228
  if self.wrapper:
229
  self.wrapper.reset()
230
  logger.info("MatAnyone session reset")
231
-
232
- # Compatibility - make the loader itself callable
233
- def __call__(self, image, mask=None, **kwargs):
234
- """Direct call compatibility."""
235
- if not self.wrapper:
236
- if not self.load():
237
  # Fallback if loading fails
238
  if mask is not None:
239
- return mask if isinstance(mask, np.ndarray) else np.array(mask)
240
- return np.zeros(image.shape[:2], dtype=np.float32)
241
-
 
 
 
242
  return self.wrapper(image, mask, **kwargs)
243
 
244
 
245
- # For backwards compatibility
246
  _MatAnyoneSession = MatAnyoneCallableWrapper
247
 
248
- __all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ MatAnyone Loader - Stable Callable Wrapper for InferenceCore
5
+ ===========================================================
6
+
7
+ - Enforces image CHW float32 [0,1] and mask 1HW float32 [0,1]
8
+ - Adds internal batch dim (B=1) and removes it on output
9
+ - Works with multiple possible InferenceCore loading signatures
10
+ - Uses torch.inference_mode() + optional autocast for speed
11
+ - Returns a 2-D alpha mask (H,W) float32 in [0,1]
12
  """
13
 
14
  import os
 
17
  import tempfile
18
  import traceback
19
  from pathlib import Path
20
+ from typing import Optional, Dict, Any, Tuple, Union
21
 
22
  import numpy as np
23
  import torch
 
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
 
28
+ # ------------------------------ Helpers ------------------------------
29
+
30
+ def _to_float01_np(arr: np.ndarray) -> np.ndarray:
31
+ """Ensure numpy array is float32 in [0,1]."""
32
+ if arr.dtype == np.uint8:
33
+ arr = arr.astype(np.float32) / 255.0
34
+ else:
35
+ arr = arr.astype(np.float32)
36
+ # Clamp for safety
37
+ np.clip(arr, 0.0, 1.0, out=arr)
38
+ return arr
39
+
40
+
41
+ def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
42
+ """
43
+ Convert image to torch.FloatTensor CHW in [0,1].
44
+ Accepts HxWxC or CHW (numpy or tensor). Adds batch dim later.
45
+ """
46
+ if torch.is_tensor(image):
47
+ t = image
48
+ if t.ndim == 3 and t.shape[0] in (1, 3, 4): # already CHW
49
+ pass
50
+ elif t.ndim == 3 and t.shape[-1] in (1, 3, 4): # HWC -> CHW
51
+ t = t.permute(2, 0, 1)
52
+ elif t.ndim == 2: # HW (grayscale)
53
+ t = t.unsqueeze(0)
54
+ else:
55
+ raise ValueError(f"Unsupported image tensor shape: {tuple(t.shape)}")
56
+ t = t.to(dtype=torch.float32)
57
+ # If likely 0-255, scale; otherwise clamp to [0,1]
58
+ if torch.max(t) > 1.5:
59
+ t = t / 255.0
60
+ t = torch.clamp(t, 0.0, 1.0)
61
+ return t
62
+ else:
63
+ arr = np.asarray(image)
64
+ if arr.ndim == 3 and arr.shape[2] in (1, 3, 4): # HWC
65
+ arr = arr.transpose(2, 0, 1) # -> CHW
66
+ elif arr.ndim == 2: # HW
67
+ arr = arr[None, ...] # -> 1HW
68
+ elif arr.ndim == 3 and arr.shape[0] in (1, 3, 4): # already CHW
69
+ pass
70
+ else:
71
+ raise ValueError(f"Unsupported image numpy shape: {arr.shape}")
72
+ arr = _to_float01_np(arr)
73
+ return torch.from_numpy(arr)
74
+
75
+
76
+ def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
77
+ """
78
+ Convert mask to torch.FloatTensor 1HW in [0,1].
79
+ Accepts HW, 1HW, CHW (C=1), HxWx1.
80
+ """
81
+ if torch.is_tensor(mask):
82
+ m = mask
83
+ if m.ndim == 2: # HW
84
+ m = m.unsqueeze(0) # 1HW
85
+ elif m.ndim == 3:
86
+ if m.shape[0] == 1: # 1HW
87
+ pass
88
+ elif m.shape[-1] == 1: # HW1 -> 1HW
89
+ m = m.permute(2, 0, 1)
90
+ else:
91
+ raise ValueError(f"Mask has too many channels: {tuple(m.shape)}")
92
+ else:
93
+ raise ValueError(f"Unsupported mask tensor shape: {tuple(m.shape)}")
94
+ m = m.to(dtype=torch.float32)
95
+ if torch.max(m) > 1.5:
96
+ m = m / 255.0
97
+ m = torch.clamp(m, 0.0, 1.0)
98
+ return m
99
+ else:
100
+ arr = np.asarray(mask)
101
+ if arr.ndim == 2: # HW
102
+ arr = arr[None, ...] # 1HW
103
+ elif arr.ndim == 3:
104
+ if arr.shape[0] == 1: # 1HW
105
+ pass
106
+ elif arr.shape[-1] == 1: # HW1 -> 1HW
107
+ arr = arr.transpose(2, 0, 1)
108
+ else:
109
+ raise ValueError(f"Mask has too many channels: {arr.shape}")
110
+ else:
111
+ raise ValueError(f"Unsupported mask numpy shape: {arr.shape}")
112
+ arr = _to_float01_np(arr)
113
+ return torch.from_numpy(arr)
114
+
115
+
116
+ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
117
+ """
118
+ Extract a 2D alpha (H,W) float32 [0,1] from a variety of possible outputs.
119
+ Accepts numpy/tensor with shapes: HW, 1HW, CHW(C>=1), BHWC, BCHW, etc.
120
+ """
121
+ if result is None:
122
+ return np.full((512, 512), 0.5, dtype=np.float32)
123
+
124
+ if torch.is_tensor(result):
125
+ result = result.detach().float().cpu()
126
+
127
+ arr = np.asarray(result)
128
+ if arr.ndim == 2:
129
+ alpha = arr
130
+ elif arr.ndim == 3:
131
+ # Prefer first channel for CHW/HWC
132
+ if arr.shape[0] in (1, 3, 4): # CHW
133
+ alpha = arr[0]
134
+ elif arr.shape[-1] in (1, 3, 4): # HWC
135
+ alpha = arr[..., 0]
136
+ else:
137
+ # Unknown 3D shape – take first slice robustly
138
+ alpha = arr[0]
139
+ elif arr.ndim == 4:
140
+ # Batch first: BxCxHxW or BxHxWxC
141
+ if arr.shape[1] in (1, 3, 4): # BCHW
142
+ alpha = arr[0, 0]
143
+ elif arr.shape[-1] in (1, 3, 4): # BHWC
144
+ alpha = arr[0, ..., 0]
145
+ else:
146
+ alpha = arr[0, 0]
147
+ else:
148
+ # Fallback
149
+ alpha = np.full((512, 512), 0.5, dtype=np.float32)
150
+
151
+ alpha = alpha.astype(np.float32, copy=False)
152
+ np.clip(alpha, 0.0, 1.0, out=alpha)
153
+ return alpha
154
+
155
+
156
+ def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
157
+ """Best-effort get (H, W) from an image/mask input for neutral fallbacks."""
158
+ if torch.is_tensor(x):
159
+ shape = tuple(x.shape)
160
+ # Handle CHW / HWC / BCHW / BHWC / HW
161
+ if len(shape) == 2: # HW
162
+ return shape[0], shape[1]
163
+ if len(shape) == 3:
164
+ if shape[0] in (1, 3, 4): # CHW
165
+ return shape[1], shape[2]
166
+ if shape[-1] in (1, 3, 4): # HWC
167
+ return shape[0], shape[1]
168
+ if len(shape) == 4:
169
+ # Assume batch first
170
+ b, c_or_h, h_or_w, maybe_w = shape
171
+ # Try BCHW
172
+ if shape[1] in (1, 3, 4):
173
+ return shape[2], shape[3]
174
+ # Try BHWC
175
+ return shape[1], shape[2]
176
+ return 512, 512
177
+ else:
178
+ arr = np.asarray(x)
179
+ if arr.ndim == 2: # HW
180
+ return arr.shape[0], arr.shape[1]
181
+ if arr.ndim == 3:
182
+ if arr.shape[0] in (1, 3, 4): # CHW
183
+ return arr.shape[1], arr.shape[2]
184
+ if arr.shape[-1] in (1, 3, 4): # HWC
185
+ return arr.shape[0], arr.shape[1]
186
+ if arr.ndim == 4:
187
+ # Assume batch first
188
+ if arr.shape[1] in (1, 3, 4): # BCHW
189
+ return arr.shape[2], arr.shape[3]
190
+ return arr.shape[1], arr.shape[2]
191
+ return 512, 512
192
+
193
+
194
+ # --------------------------- Callable Wrapper ---------------------------
195
+
196
class MatAnyoneCallableWrapper:
    """
    Callable session-like wrapper around an InferenceCore instance.

    Contract:
      - First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
      - Subsequent calls do not require mask.
      - Returns 2D alpha (H,W) float32 in [0,1].
    """

    def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
        """
        Args:
            inference_core: MatAnyone InferenceCore (or API-compatible object
                exposing ``step`` or ``process_frame``).
            device: "cuda" or "cpu"; any other value falls back to auto-detect.
            mixed_precision: "fp16" | "bf16" | None; only honored on CUDA.
        """
        self.core = inference_core
        # Set to True after the first successful masked step.
        self.initialized = False
        self.device = device if (device in ("cuda", "cpu")) else ("cuda" if torch.cuda.is_available() else "cpu")
        self.mixed_precision = mixed_precision if self.device == "cuda" else None  # "fp16"|"bf16"|None

    def _maybe_autocast(self):
        """Return an autocast context on CUDA (when enabled), else a no-op context."""
        if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
            dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
            return torch.autocast(device_type="cuda", dtype=dtype)
        # fix: use the stdlib no-op context manager instead of defining an
        # ad-hoc _NullCtx class on every call.
        from contextlib import nullcontext
        return nullcontext()

    def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
        """Process one frame and return a 2D float32 alpha in [0, 1].

        On the first call a mask is required to initialize the core; without
        one a neutral 0.5 alpha of the input's size is returned. Any internal
        failure is logged and answered with a best-effort fallback alpha
        (the given mask, else neutral 0.5) — this method does not raise.
        """
        try:
            # Preprocess to CHW/1HW tensors, then add the batch dim (B=1).
            img_chw = _ensure_chw_float01(image).to(self.device, non_blocking=True)
            img_bchw = img_chw.unsqueeze(0)

            if not self.initialized:
                if mask is None:
                    h, w = _hw_from_image_like(image)
                    logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
                    return np.full((h, w), 0.5, dtype=np.float32)

                m_1hw = _ensure_1hw_float01(mask).to(self.device, non_blocking=True)
                m_b1hw = m_1hw.unsqueeze(0)  # B=1

                with torch.inference_mode():
                    with self._maybe_autocast():
                        if hasattr(self.core, "step"):
                            result = self.core.step(image=img_bchw, mask=m_b1hw, **kwargs)
                        elif hasattr(self.core, "process_frame"):
                            result = self.core.process_frame(img_bchw, m_b1hw, **kwargs)
                        else:
                            logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
                            return _alpha_from_result(mask)

                self.initialized = True
                return _alpha_from_result(result)

            # Subsequent frames (no mask)
            with torch.inference_mode():
                with self._maybe_autocast():
                    if hasattr(self.core, "step"):
                        result = self.core.step(image=img_bchw, **kwargs)
                    elif hasattr(self.core, "process_frame"):
                        result = self.core.process_frame(img_bchw, **kwargs)
                    else:
                        h, w = _hw_from_image_like(image)
                        logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
                        return np.full((h, w), 0.5, dtype=np.float32)

            return _alpha_from_result(result)

        except Exception as e:
            logger.error(f"MatAnyone wrapper call failed: {e}")
            logger.debug(traceback.format_exc())
            # Fallbacks: prefer echoing the provided mask, else neutral alpha.
            if mask is not None:
                try:
                    return _alpha_from_result(mask)
                except Exception:
                    pass
            h, w = _hw_from_image_like(image)
            return np.full((h, w), 0.5, dtype=np.float32)

    def reset(self):
        """Reset state between videos; core reset failures are logged, not raised."""
        self.initialized = False
        if hasattr(self.core, "reset"):
            try:
                self.core.reset()
            except Exception as e:
                logger.debug(f"Core reset() failed: {e}")
        elif hasattr(self.core, "clear_memory"):
            try:
                self.core.clear_memory()
            except Exception as e:
                logger.debug(f"Core clear_memory() failed: {e}")
289
+
290
 
291
+ # ------------------------------- Loader -------------------------------
292
 
293
class MatAnyoneLoader:
    """
    Loads MatAnyone's InferenceCore and returns a callable wrapper.

    Usage:
        loader = MatAnyoneLoader(device="cuda")
        session = loader.load()                    # callable
        alpha = session(frame, first_frame_mask)   # 2-D float32 [0,1]
    """

    def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
                 mixed_precision: Optional[str] = "fp16"):
        """
        Args:
            device: Preferred device; resolved to "cuda"/"cpu" by _select_device.
            cache_dir: Directory for model downloads (created if missing).
            mixed_precision: "fp16" | "bf16" | None; only honored on CUDA.
        """
        self.device = self._select_device(device)
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        self.processor = None   # raw InferenceCore instance (set by load())
        self.wrapper = None     # MatAnyoneCallableWrapper around it
        self.model_id = "PeiqingYang/MatAnyone"
        # fix: always defined so get_info()'s float(self.load_time) cannot
        # fail before load() has run.
        self.load_time = 0.0
        self.loaded = False
        self.load_error = None
        self.temp_dir = Path(tempfile.mkdtemp())
        self.mixed_precision = mixed_precision if self.device == "cuda" else None

    def _select_device(self, pref: str) -> str:
        """Resolve a device preference to an available backend ("cuda" or "cpu")."""
        pref = (pref or "").lower()
        if pref.startswith("cuda"):
            return "cuda" if torch.cuda.is_available() else "cpu"
        if pref == "cpu":
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _try_build_core(self):
        """
        Try multiple constructor patterns to survive API changes.

        Order: from_pretrained(...) -> ctor(model_id, device, cache_dir)
        -> ctor(model_id). Re-raises the last failure.
        """
        from matanyone.inference.inference_core import InferenceCore

        # 1) Preferred: from_pretrained(...)
        try:
            core = InferenceCore.from_pretrained(self.model_id, device=self.device, cache_dir=self.cache_dir)
            logger.info("Loaded MatAnyone via InferenceCore.from_pretrained(...)")
            return core
        except Exception as e:
            logger.debug(f"from_pretrained failed: {e}")

        # 2) Direct ctor with device/cache_dir
        try:
            core = InferenceCore(self.model_id, device=self.device, cache_dir=self.cache_dir)
            logger.info("Loaded MatAnyone via InferenceCore(model_id, device, cache_dir)")
            return core
        except Exception as e:
            logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")

        # 3) Minimal ctor
        try:
            core = InferenceCore(self.model_id)
            logger.info("Loaded MatAnyone via InferenceCore(model_id) [minimal]")
            return core
        except Exception as e:
            logger.debug(f"ctor(model_id) failed: {e}")
            raise  # Propagate last error

    def load(self) -> Optional["MatAnyoneCallableWrapper"]:
        """Load MatAnyone and return the callable wrapper (None on failure)."""
        if self.loaded and self.wrapper is not None:
            return self.wrapper

        logger.info(f"Loading MatAnyone: {self.model_id} (device={self.device})")
        t0 = time.time()

        try:
            self.processor = self._try_build_core()
            # If the core has an explicit to(device) or set_device, try to use it
            try:
                if hasattr(self.processor, "to"):
                    self.processor.to(self.device)
                elif hasattr(self.processor, "set_device"):
                    self.processor.set_device(self.device)
            except Exception as e:
                logger.debug(f"Optional device move failed: {e}")

            self.wrapper = MatAnyoneCallableWrapper(
                self.processor, device=self.device, mixed_precision=self.mixed_precision
            )
            self.loaded = True
            self.load_time = time.time() - t0
            logger.info(f"MatAnyone loaded and wrapped in {self.load_time:.2f}s")
            return self.wrapper

        except ImportError as e:
            self.load_error = f"MatAnyone not installed: {e}"
            logger.error("Failed to import MatAnyone. Install with: "
                         "pip install git+https://github.com/pq-yang/MatAnyone.git@main")
            return None

        except Exception as e:
            self.load_error = str(e)
            logger.error(f"Failed to load MatAnyone: {e}")
            logger.debug(traceback.format_exc())
            return None

    def cleanup(self):
        """Cleanup temporary files and release resources."""
        self.processor = None
        self.wrapper = None
        # fix: keep state consistent so a later load() rebuilds the core
        # instead of reporting loaded=True with no wrapper.
        self.loaded = False

        # Clean temp directory
        if self.temp_dir.exists():
            import shutil
            shutil.rmtree(self.temp_dir, ignore_errors=True)

        # Clear CUDA cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_info(self) -> Dict[str, Any]:
        """Get model information and interface flags."""
        info = {
            "loaded": self.loaded,
            "model_id": self.model_id,
            "device": str(self.device),
            "load_time": float(self.load_time),
            "error": self.load_error,
            "api": "InferenceCore (wrapped)",
            "mixed_precision": self.mixed_precision,
        }
        proc = self.processor
        if proc is not None:
            info["has_step"] = hasattr(proc, "step")
            info["has_process_frame"] = hasattr(proc, "process_frame")
            info["has_process_video"] = hasattr(proc, "process_video")
        return info

    def reset(self):
        """Reset the processor for a new video."""
        if self.wrapper:
            self.wrapper.reset()
            logger.info("MatAnyone session reset")

    # Make the loader itself callable (direct compatibility)
    def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
        """Lazy-load on first use, then delegate to the wrapper."""
        if self.wrapper is None:
            if self.load() is None:
                # Fallback if loading fails
                if mask is not None:
                    try:
                        return _alpha_from_result(mask)
                    except Exception:
                        pass
                h, w = _hw_from_image_like(image)
                return np.zeros((h, w), dtype=np.float32)
        return self.wrapper(image, mask, **kwargs)
446
 
447
 
448
# Backwards compatibility alias (legacy session naming)
_MatAnyoneSession = MatAnyoneCallableWrapper

# Explicit public API of this module.
__all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]