MogensR committed on
Commit
7280fe8
·
1 Parent(s): b20702e

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +284 -412
models/loaders/matanyone_loader.py CHANGED
@@ -1,48 +1,258 @@
1
  #!/usr/bin/env python3
2
  """
3
- MatAnyone Model Loader
4
- Handles MatAnyone loading with proper device initialization
 
 
 
 
 
5
  """
6
 
7
  import os
8
  import time
9
  import logging
10
  import traceback
11
- from pathlib import Path
12
- from typing import Optional, Dict, Any
13
 
14
- import torch
15
  import numpy as np
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class MatAnyoneLoader:
21
- """Dedicated loader for MatAnyone models"""
22
-
23
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
24
- self.device = device
25
  self.cache_dir = cache_dir
26
  os.makedirs(self.cache_dir, exist_ok=True)
27
-
28
- self.model = None
29
  self.model_id = "PeiqingYang/MatAnyone"
30
  self.load_time = 0.0
31
-
32
  def load(self) -> Optional[Any]:
33
  """
34
- Load MatAnyone model
35
  Returns:
36
- Loaded model or None
37
  """
38
- logger.info(f"Loading MatAnyone model: {self.model_id}")
39
-
40
- # Try loading strategies in order
41
  strategies = [
42
  ("official", self._load_official),
43
- ("fallback", self._load_fallback)
44
  ]
45
-
46
  for strategy_name, strategy_func in strategies:
47
  try:
48
  logger.info(f"Trying MatAnyone loading strategy: {strategy_name}")
@@ -51,423 +261,85 @@ def load(self) -> Optional[Any]:
51
  if model:
52
  self.load_time = time.time() - start_time
53
  self.model = model
54
- logger.info(f"MatAnyone loaded successfully via {strategy_name} in {self.load_time:.2f}s")
55
  return model
56
  except Exception as e:
57
  logger.error(f"MatAnyone {strategy_name} strategy failed: {e}")
58
  logger.debug(traceback.format_exc())
59
  continue
60
-
61
  logger.error("All MatAnyone loading strategies failed")
62
  return None
63
-
64
  def _load_official(self) -> Optional[Any]:
65
- """Load using official MatAnyone API with comprehensive shape guard"""
66
- from matanyone import InferenceCore
67
-
68
- # Create processor - pass model ID as positional argument
69
- processor = InferenceCore(self.model_id)
70
-
71
- # Install the critical shape guard patch from original loader
72
- self._install_shape_guard(processor)
73
-
74
- return processor
75
-
76
- def _install_shape_guard(self, processor):
77
  """
78
- Install the comprehensive shape guard from the original loader.
79
- This is CRITICAL for preventing 5D tensor issues and ensuring compatibility.
80
  """
81
- import torch
82
- import numpy as np
83
-
84
- device = self.device
85
-
86
- # Helper functions for tensor manipulation
87
- def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Tensor:
88
- """Ensure image is in NCHW format"""
89
- if isinstance(img, np.ndarray):
90
- img = torch.from_numpy(img)
91
-
92
- img = img.to(device)
93
-
94
- # Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
95
- while img.ndim == 5:
96
- if img.shape[0] == 1:
97
- img = img.squeeze(0)
98
- elif img.shape[1] == 1:
99
- img = img.squeeze(1)
100
- else:
101
- # Can't auto-squeeze, take first time frame
102
- img = img[:, 0]
103
-
104
- # Handle various input formats
105
- if img.ndim == 3:
106
- # CHW or HWC
107
- if img.shape[0] in (1, 3, 4): # Likely CHW
108
- chw = img
109
- elif img.shape[-1] in (1, 3, 4): # Likely HWC
110
- chw = img.permute(2, 0, 1)
111
- else:
112
- # Assume CHW
113
- chw = img
114
-
115
- # Ensure float and normalized
116
- if chw.dtype != torch.float32:
117
- chw = chw.float()
118
- if chw.max() > 1.0:
119
- chw = chw / 255.0
120
-
121
- return chw.unsqueeze(0) if want_batched else chw
122
-
123
- elif img.ndim == 4:
124
- # NCHW or NHWC
125
- N, A, B, C = img.shape
126
- if A in (1, 3, 4): # NCHW
127
- nchw = img
128
- elif C in (1, 3, 4): # NHWC
129
- nchw = img.permute(0, 3, 1, 2)
130
- else:
131
- # Assume NCHW
132
- nchw = img
133
-
134
- # Ensure float and normalized
135
- if nchw.dtype != torch.float32:
136
- nchw = nchw.float()
137
- if nchw.max() > 1.0:
138
- nchw = nchw / 255.0
139
-
140
- return nchw if want_batched else nchw.squeeze(0) if not want_batched and nchw.shape[0] == 1 else nchw[0]
141
-
142
- else:
143
- logger.error(f"Unexpected image dimensions: {img.shape}")
144
- # Return something safe
145
- return torch.zeros((3, 512, 512), device=device, dtype=torch.float32).unsqueeze(0) if want_batched else torch.zeros((3, 512, 512), device=device, dtype=torch.float32)
146
-
147
- def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
148
- threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
149
- """Ensure mask is in correct format for MatAnyone"""
150
- if isinstance(mask, np.ndarray):
151
- mask = torch.from_numpy(mask)
152
-
153
- mask = mask.to(device)
154
-
155
- # Handle 5D tensors
156
- if mask.ndim == 5:
157
- if mask.shape[1] == 1:
158
- mask = mask.squeeze(1)
159
- if mask.shape[0] == 1 and mask.ndim == 5:
160
- mask = mask.squeeze(0)
161
-
162
- # Handle index masks
163
- if idx_mask:
164
- if mask.ndim == 3:
165
- if mask.shape[0] == 1:
166
- idx = (mask[0] >= threshold).to(torch.long)
167
- else:
168
- idx = torch.argmax(mask, dim=0).to(torch.long)
169
- idx = (idx > 0).to(torch.long)
170
- elif mask.ndim == 2:
171
- idx = (mask >= threshold).to(torch.long)
172
- else:
173
- logger.warning(f"Unexpected idx mask shape: {mask.shape}")
174
- idx = torch.zeros((512, 512), device=device, dtype=torch.long)
175
- return idx
176
-
177
- # Handle channel masks
178
- if mask.ndim == 2:
179
- out = mask.unsqueeze(0) # Add channel dimension
180
- elif mask.ndim == 3:
181
- if mask.shape[0] == 1:
182
- out = mask
183
  else:
184
- # Choose channel with largest area
185
- areas = mask.sum(dim=(-2, -1))
186
- best_idx = areas.argmax()
187
- out = mask[best_idx:best_idx+1]
188
- else:
189
- logger.warning(f"Unexpected mask shape: {mask.shape}")
190
- out = torch.ones((1, 512, 512), device=device, dtype=torch.float32)
191
-
192
- # Convert to float and normalize
193
- out = out.to(torch.float32)
194
- if not keep_soft:
195
- out = (out >= threshold).to(torch.float32)
196
-
197
- return out.clamp_(0.0, 1.0).contiguous()
198
-
199
- # Create the guarded wrapper
200
- def create_guarded_method(original_method):
201
- """Create a guarded version of a MatAnyone method"""
202
- def guarded_method(*args, **kwargs):
203
- # Extract image and mask
204
- image = kwargs.get("image", None)
205
- mask = kwargs.get("mask", None)
206
- idx_mask = kwargs.get("idx_mask", kwargs.get("index_mask", False))
207
-
208
- # Handle positional arguments
209
- if image is None and len(args) >= 1:
210
- image = args[0]
211
- if mask is None and len(args) >= 2:
212
- mask = args[1]
213
-
214
- if image is None or mask is None:
215
- logger.error(f"MatAnyone called without image/mask: args={len(args)}, kwargs={list(kwargs.keys())}")
216
- # Return something safe
217
- return torch.ones((1, 512, 512), dtype=torch.float32) * 0.5
218
-
219
  try:
220
- # Coerce shapes - ensure we REALLY squeeze out extra dimensions
221
- img_nchw = ensure_image_nchw(image, want_batched=True)
222
-
223
- # CRITICAL FIX: Force squeeze all unnecessary dimensions
224
- while img_nchw.ndim > 4:
225
- if img_nchw.shape[0] == 1:
226
- img_nchw = img_nchw.squeeze(0)
227
- elif img_nchw.shape[1] == 1:
228
- img_nchw = img_nchw.squeeze(1)
229
- else:
230
- break
231
-
232
- if idx_mask:
233
- m_fixed = ensure_mask_for_matanyone(mask, idx_mask=True)
234
- else:
235
- m_fixed = ensure_mask_for_matanyone(mask, idx_mask=False, threshold=0.5)
236
-
237
- # Log actual shapes being passed
238
- logger.debug(f"MatAnyone input - image: {img_nchw.shape}, mask: {m_fixed.shape}, idx: {idx_mask}")
239
-
240
- # For MatAnyone, we need CHW not NCHW for unbatched
241
- if img_nchw.ndim == 4 and img_nchw.shape[0] == 1:
242
- img_chw = img_nchw[0] # Remove batch dimension
243
- else:
244
- img_chw = img_nchw
245
-
246
- # Try unbatched first (most common)
247
- try:
248
- new_kwargs = dict(kwargs)
249
- new_kwargs["image"] = img_chw # CHW
250
- new_kwargs["mask"] = m_fixed.squeeze(0) if m_fixed.ndim > 2 and m_fixed.shape[0] == 1 else m_fixed
251
- new_kwargs["idx_mask"] = bool(idx_mask)
252
-
253
- result = original_method(**new_kwargs)
254
- return result
255
-
256
- except Exception as e1:
257
- logger.debug(f"Unbatched call failed, trying batched: {e1}")
258
- # Try with batch dimension
259
- new_kwargs = dict(kwargs)
260
- new_kwargs["image"] = img_nchw # NCHW
261
- new_kwargs["mask"] = m_fixed
262
- new_kwargs["idx_mask"] = bool(idx_mask)
263
-
264
- result = original_method(**new_kwargs)
265
- return result
266
-
267
- except Exception as e:
268
- logger.error(f"MatAnyone guarded call failed: {e}")
269
- import traceback
270
- logger.debug(traceback.format_exc())
271
- # Return input mask as fallback
272
- if isinstance(mask, torch.Tensor):
273
- return mask.cpu().numpy()
274
- elif isinstance(mask, np.ndarray):
275
- return mask
276
- else:
277
- return np.ones((512, 512), dtype=np.float32) * 0.5
278
-
279
- return guarded_method
280
-
281
- # Apply guard to both step and process methods
282
- if hasattr(processor, 'step'):
283
- original_step = processor.step
284
- processor.step = create_guarded_method(original_step)
285
- logger.info("Installed shape guard on MatAnyone.step")
286
-
287
- if hasattr(processor, 'process'):
288
- original_process = processor.process
289
- processor.process = create_guarded_method(original_process)
290
- logger.info("Installed shape guard on MatAnyone.process")
291
-
292
- def _patch_processor(self, processor):
293
- """
294
- Patch the MatAnyone processor to handle device placement and tensor formats correctly
295
- """
296
- original_step = getattr(processor, 'step', None)
297
- original_process = getattr(processor, 'process', None)
298
-
299
- device = self.device
300
-
301
- def safe_wrapper(*args, **kwargs):
302
- """Universal wrapper that handles both step and process calls"""
303
- try:
304
- # Handle different calling patterns
305
- # Pattern 1: step(image, mask, idx_mask=False)
306
- # Pattern 2: process(image, mask)
307
- # Pattern 3: Called with just args
308
- # Pattern 4: Called with kwargs
309
-
310
- image = None
311
- mask = None
312
- idx_mask = kwargs.get('idx_mask', False)
313
-
314
- # Extract image and mask
315
- if 'image' in kwargs and 'mask' in kwargs:
316
- image = kwargs['image']
317
- mask = kwargs['mask']
318
- elif len(args) >= 2:
319
- image = args[0]
320
- mask = args[1]
321
- if len(args) > 2:
322
- idx_mask = args[2]
323
- elif len(args) == 1:
324
- # Might be called with just mask for refinement
325
- mask = args[0]
326
- # Create dummy image if needed
327
- if isinstance(mask, np.ndarray):
328
- h, w = mask.shape[:2] if mask.ndim >= 2 else (512, 512)
329
- image = np.zeros((h, w, 3), dtype=np.uint8)
330
- elif isinstance(mask, torch.Tensor):
331
- h, w = mask.shape[-2:] if mask.dim() >= 2 else (512, 512)
332
- image = torch.zeros((h, w, 3), dtype=torch.uint8)
333
-
334
- if image is None or mask is None:
335
- logger.error(f"MatAnyone called with invalid args: {len(args)} args, kwargs: {kwargs.keys()}")
336
- # Return something safe
337
- if mask is not None:
338
- return mask
339
- return np.ones((512, 512), dtype=np.float32) * 0.5
340
-
341
- # Convert to tensors on correct device
342
- if isinstance(image, np.ndarray):
343
- image = torch.from_numpy(image).to(device)
344
- elif isinstance(image, torch.Tensor):
345
- image = image.to(device)
346
-
347
- if isinstance(mask, np.ndarray):
348
- mask = torch.from_numpy(mask).to(device)
349
- elif isinstance(mask, torch.Tensor):
350
- mask = mask.to(device)
351
-
352
- # Fix image format (ensure CHW or NCHW)
353
- if image.dim() == 2: # Grayscale HW
354
- image = image.unsqueeze(0) # CHW
355
- elif image.dim() == 3:
356
- # Check if HWC or CHW
357
- if image.shape[-1] in [1, 3, 4]: # HWC
358
- image = image.permute(2, 0, 1) # CHW
359
- # Add batch if needed
360
- if image.shape[0] in [1, 3, 4]: # CHW
361
- image = image.unsqueeze(0) # NCHW
362
- elif image.dim() == 4:
363
- # Already NCHW, ensure correct channel position
364
- if image.shape[-1] in [1, 3, 4]: # NHWC
365
- image = image.permute(0, 3, 1, 2) # NCHW
366
-
367
- # Fix mask format
368
- if mask.dim() == 2:
369
- mask = mask.unsqueeze(0) # Add channel: CHW
370
- elif mask.dim() == 3:
371
- if mask.shape[0] > 4: # Likely HWC
372
- mask = mask.permute(2, 0, 1) # CHW
373
-
374
- # Ensure float and normalized
375
- if image.dtype != torch.float32:
376
- image = image.float()
377
- if not idx_mask and mask.dtype != torch.float32:
378
- mask = mask.float()
379
-
380
- if image.max() > 1.0:
381
- image = image / 255.0
382
- if not idx_mask and mask.max() > 1.0:
383
- mask = mask / 255.0
384
-
385
- # Call original method if it exists
386
- if original_step:
387
- try:
388
- result = original_step(image, mask, idx_mask=idx_mask)
389
- # Convert result back to numpy if needed
390
- if isinstance(result, torch.Tensor):
391
- result = result.cpu().numpy()
392
- return result
393
- except Exception as e:
394
- logger.error(f"MatAnyone original step failed: {e}")
395
-
396
- # Fallback: return slightly processed mask
397
- if isinstance(mask, torch.Tensor):
398
- # Apply slight smoothing
399
- import torch.nn.functional as F
400
- mask = F.avg_pool2d(mask.unsqueeze(0), 3, stride=1, padding=1)
401
- mask = mask.squeeze(0).cpu().numpy()
402
-
403
- return mask
404
-
405
- except Exception as e:
406
- logger.error(f"MatAnyone safe_wrapper failed: {e}")
407
- import traceback
408
- logger.debug(traceback.format_exc())
409
- # Return safe fallback
410
- if 'mask' in locals() and mask is not None:
411
- if isinstance(mask, torch.Tensor):
412
- return mask.cpu().numpy()
413
- return mask
414
- return np.ones((512, 512), dtype=np.float32) * 0.5
415
-
416
- # Apply patches to both methods
417
- if hasattr(processor, 'step'):
418
- processor.step = safe_wrapper
419
- logger.info("Patched MatAnyone step method")
420
-
421
- if hasattr(processor, 'process'):
422
- processor.process = safe_wrapper
423
- logger.info("Patched MatAnyone process method")
424
-
425
- # Also add a direct call method
426
- processor.__call__ = safe_wrapper
427
-
428
- def _load_fallback(self) -> Optional[Any]:
429
- """Create fallback processor for testing"""
430
-
431
- class FallbackMatAnyone:
432
- def __init__(self, device):
433
- self.device = device
434
-
435
- def step(self, image, mask, idx_mask=False, **kwargs):
436
- """Pass through mask with minor smoothing"""
437
- if isinstance(mask, np.ndarray):
438
- # Apply slight Gaussian blur for edge smoothing
439
  import cv2
440
- if mask.ndim == 2:
441
- smoothed = cv2.GaussianBlur(mask, (5, 5), 1.0)
442
- return smoothed
443
- elif mask.ndim == 3:
444
- smoothed = np.zeros_like(mask)
445
- for i in range(mask.shape[0]):
446
- smoothed[i] = cv2.GaussianBlur(mask[i], (5, 5), 1.0)
447
- return smoothed
448
- return mask
449
-
 
 
 
450
  def process(self, image, mask, **kwargs):
451
- """Alias for step"""
452
  return self.step(image, mask, **kwargs)
453
-
454
- logger.warning("Using fallback MatAnyone (limited refinement)")
455
- return FallbackMatAnyone(self.device)
456
-
 
 
 
457
  def cleanup(self):
458
- """Clean up resources"""
459
  if self.model:
460
- del self.model
 
 
 
461
  self.model = None
462
  if torch.cuda.is_available():
463
  torch.cuda.empty_cache()
464
-
465
  def get_info(self) -> Dict[str, Any]:
466
- """Get loader information"""
467
  return {
468
  "loaded": self.model is not None,
469
  "model_id": self.model_id,
470
  "device": self.device,
471
  "load_time": self.load_time,
472
- "model_type": type(self.model).__name__ if self.model else None
473
- }
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ MatAnyone Model Loader (Hardened)
4
+ - Prevents 5D (B,T,C,H,W) tensors from ever reaching conv2d.
5
+ - Normalizes images to BCHW [B,C,H,W] and masks to B1HW [B,1,H,W].
6
+ - If idx_mask=True, converts masks to integer labels (long) safely.
7
+ - Tries unbatched then batched calls for maximum compatibility.
8
+ - Resizes masks with 'nearest' to preserve label integrity.
9
+ - Includes a debug_shapes() helper for quick diagnostics.
10
  """
11
 
12
  import os
13
  import time
14
  import logging
15
  import traceback
16
+ from typing import Optional, Dict, Any, Tuple
 
17
 
 
18
  import numpy as np
19
+ import torch
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
24
+ # ------------------------------- Utilities -------------------------------- #
25
+
26
def _select_device(pref: str) -> str:
    """
    Resolve a usable device string from a caller preference.

    "cpu" is honored verbatim; any other value (including "cuda*", None, or
    an unrecognized string) resolves to "cuda" when available, else "cpu".
    """
    normalized = (pref or "").lower()
    if normalized == "cpu":
        return "cpu"
    # Explicit CUDA preference and the default case share the same fallback.
    return "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+
38
def _as_tensor_on_device(x, device: str) -> torch.Tensor:
    """Return *x* as a torch.Tensor on *device*; array-likes are converted via numpy."""
    if not isinstance(x, torch.Tensor):
        x = torch.from_numpy(np.asarray(x))
    return x.to(device)
43
+
44
+
45
def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
    """
    Normalize an array/tensor to BCHW (image) or B1HW (mask) on *device*.

    Accepts HW, HWC, CHW, BCHW, BHWC and 5D clip layouts (BTCHW / BTHWC);
    a time/clip dimension is collapsed by taking the first frame.

    - Images: float32 in [0, 1], shape [B, C, H, W] (C=1 expanded to 3).
    - Masks:  float32 in [0, 1], shape [B, 1, H, W] (extra channels dropped).

    Raises:
        ValueError: for unsupported ranks (ndim < 2 or > 5) or a 5D input
            with an empty time dimension.
    """
    x = _as_tensor_on_device(x, device)

    # Promote to float; uint8 is assumed to be 0..255 and rescaled.
    if x.dtype == torch.uint8:
        x = x.float().div_(255.0)
    elif x.dtype in (torch.bool, torch.int16, torch.int32, torch.int64):
        # bool/integer labels become plain floats without rescaling
        x = x.float()

    # 5D: [B,T,C,H,W] or [B,T,H,W,C] -> drop the clip dimension (first frame).
    if x.ndim == 5:
        if x.shape[1] == 0:
            raise ValueError("Cannot normalize a clip with an empty time dimension")
        x = x[:, 0]  # -> [B,C,H,W] or [B,H,W,C]

    if x.ndim == 4:
        # BHWC -> BCHW, only when the channel position is unambiguous.
        if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
            x = x.permute(0, 3, 1, 2).contiguous()
    elif x.ndim == 3:
        # HWC -> CHW, then add the batch dimension.
        if x.shape[-1] in (1, 3, 4):
            x = x.permute(2, 0, 1).contiguous()
        x = x.unsqueeze(0)
    elif x.ndim == 2:
        x = x.unsqueeze(0).unsqueeze(0)  # -> [1,1,H,W]
        if not is_mask:
            x = x.repeat(1, 3, 1, 1)     # grayscale image -> 3 channels
    else:
        raise ValueError(f"Unsupported tensor ndim={x.ndim} for normalization")

    # x is now BCHW.
    if is_mask:
        if x.shape[1] > 1:
            x = x[:, :1]  # keep a single mask channel
        x = x.to(torch.float32).clamp_(0.0, 1.0)
    else:
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)  # expand grayscale to RGB
        x = x.to(torch.float32)
        # Guard numel(): min()/max() raise on empty tensors.
        if x.numel() and (x.min() < 0.0 or x.max() > 1.0):
            x = x.clamp_(0.0, 1.0)

    return x
107
+
108
+
109
def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
    """
    Match the mask's spatial size to the image.

    Nearest-neighbour interpolation is used so label values stay crisp.
    """
    target = img_bchw.shape[-2:]
    if mask_b1hw.shape[-2:] == target:
        return mask_b1hw
    import torch.nn.functional as F
    return F.interpolate(mask_b1hw, size=target, mode="nearest")
117
+
118
+
119
def debug_shapes(tag: str, image, mask) -> None:
    """
    Quick diagnostics: log shape/dtype/min/max for an image/mask pair.

    Any object that cannot be viewed as a tensor is logged by type instead.
    """
    for name, value in (("image", image), ("mask", mask)):
        try:
            t = torch.as_tensor(value)
            lo = float(t.min()) if t.numel() else float("nan")
            hi = float(t.max()) if t.numel() else float("nan")
            logger.info(f"[{tag}:{name}] shape={tuple(t.shape)} dtype={t.dtype} "
                        f"min={lo:.4f} max={hi:.4f}")
        except Exception as e:
            logger.info(f"[{tag}:{name}] type={type(value)} err={e}")
135
+
136
+
137
+ # --------------------------- Boundary Wrapper ------------------------------ #
138
+
139
class _MatAnyoneWrapper:
    """
    Defensive boundary around the MatAnyone InferenceCore.

    All inputs are normalized here so the core never sees tensors above 4D.
    """

    def __init__(self, core: Any, device: str):
        self.core = core
        self.device = device
        # Best effort: relocate the core onto the chosen device when supported.
        try:
            if hasattr(self.core, "to"):
                self.core.to(self.device)
        except Exception as e:
            logger.debug(f"MatAnyone core .to({self.device}) not applied: {e}")

    @staticmethod
    def _to_numpy(x):
        """Detach/copy *x* to a host-side numpy array."""
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    def _normalize_pair(
        self, image, mask, idx_mask: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Return (image as BCHW, mask as B1HW resized to the image, idx_mask as bool)."""
        frame = _to_bchw(image, self.device, is_mask=False)   # [B,C,H,W]
        alpha = _resize_mask_to(frame, _to_bchw(mask, self.device, is_mask=True))  # [B,1,H,W]
        return frame, alpha, bool(idx_mask)

    def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
        """
        Preferred entry point: normalize inputs, then try an unbatched step()
        first and batched step()/process() calls as fallbacks.
        """
        frame, alpha, idx_mask = self._normalize_pair(image, mask, idx_mask)
        single = frame.shape[0] == 1

        if idx_mask:
            # Integer label path: threshold to {0,1} and drop the channel axis.
            labels = (alpha > 0.5).long()[:, 0]  # [B,H,W]
            if single:
                try:
                    if hasattr(self.core, "step"):
                        result = self.core.step(image=frame[0], mask=labels[0],
                                                idx_mask=True, **kwargs)
                        return self._to_numpy(result)
                except Exception as e:
                    logger.debug(f"MatAnyone unbatched idx_mask step() failed: {e}")
            for name in ("step", "process"):
                if not hasattr(self.core, name):
                    continue
                try:
                    result = getattr(self.core, name)(image=frame, mask=labels,
                                                      idx_mask=True, **kwargs)
                    return self._to_numpy(result)
                except Exception as e:
                    logger.debug(f"MatAnyone {name} idx_mask batched call failed: {e}")

            logger.warning("MatAnyone idx_mask calls failed; returning integer mask as fallback.")
            return self._to_numpy(labels if labels.shape[0] > 1 else labels[0])

        # Soft/binary mask path: unbatched first (common CHW / 1HW case).
        try:
            if hasattr(self.core, "step") and single:
                result = self.core.step(image=frame[0], mask=alpha[0],
                                        idx_mask=False, **kwargs)
                return self._to_numpy(result)
        except Exception as e:
            logger.debug(f"MatAnyone unbatched step() failed: {e}")

        for name in ("step", "process"):
            if not hasattr(self.core, name):
                continue
            try:
                result = getattr(self.core, name)(image=frame, mask=alpha,
                                                  idx_mask=False, **kwargs)
                return self._to_numpy(result)
            except Exception as e:
                logger.debug(f"MatAnyone {name} batched call failed: {e}")

        logger.warning("MatAnyone calls failed; returning input mask as fallback.")
        return self._to_numpy(alpha.squeeze(1))  # [B,H,W]
227
+
228
+
229
+ # ------------------------------- Loader ----------------------------------- #
230
+
231
  class MatAnyoneLoader:
232
+ """Dedicated loader for MatAnyone models (with boundary normalization)."""
233
+
234
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
235
+ self.device = _select_device(device)
236
  self.cache_dir = cache_dir
237
  os.makedirs(self.cache_dir, exist_ok=True)
238
+
239
+ self.model: Optional[Any] = None
240
  self.model_id = "PeiqingYang/MatAnyone"
241
  self.load_time = 0.0
242
+
243
  def load(self) -> Optional[Any]:
244
  """
245
+ Load MatAnyone model and return a callable wrapper.
246
  Returns:
247
+ _MatAnyoneWrapper or None
248
  """
249
+ logger.info(f"Loading MatAnyone model: {self.model_id} (device={self.device})")
250
+
 
251
  strategies = [
252
  ("official", self._load_official),
253
+ ("fallback", self._load_fallback),
254
  ]
255
+
256
  for strategy_name, strategy_func in strategies:
257
  try:
258
  logger.info(f"Trying MatAnyone loading strategy: {strategy_name}")
 
261
  if model:
262
  self.load_time = time.time() - start_time
263
  self.model = model
264
+ logger.info(f"MatAnyone loaded via {strategy_name} in {self.load_time:.2f}s")
265
  return model
266
  except Exception as e:
267
  logger.error(f"MatAnyone {strategy_name} strategy failed: {e}")
268
  logger.debug(traceback.format_exc())
269
  continue
270
+
271
  logger.error("All MatAnyone loading strategies failed")
272
  return None
273
+
274
  def _load_official(self) -> Optional[Any]:
 
 
 
 
 
 
 
 
 
 
 
 
275
  """
276
+ Load using the official MatAnyone API and wrap with boundary normalizer.
 
277
  """
278
+ try:
279
+ from matanyone import InferenceCore # type: ignore
280
+ except Exception as e:
281
+ logger.error(f"Failed to import official MatAnyone: {e}")
282
+ return None
283
+
284
+ core = InferenceCore(self.model_id)
285
+ wrapped = _MatAnyoneWrapper(core, device=self.device)
286
+ return wrapped
287
+
288
+ def _load_fallback(self) -> Optional[Any]:
289
+ """Create a minimal fallback that smooths/returns the mask."""
290
+
291
+ class _FallbackCore:
292
+ def step(self, image, mask, idx_mask: bool = False, **kwargs):
293
+ # Convert mask to numpy
294
+ if isinstance(mask, torch.Tensor):
295
+ mask_np = mask.detach().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  else:
297
+ mask_np = np.asarray(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  import cv2
300
+ if mask_np.ndim == 2:
301
+ return cv2.GaussianBlur(mask_np, (5, 5), 1.0)
302
+ if mask_np.ndim == 3:
303
+ # Handle CHW-style smoothing (per-channel)
304
+ if mask_np.shape[0] in (1, 3, 4):
305
+ sm = np.empty_like(mask_np)
306
+ for i in range(mask_np.shape[0]):
307
+ sm[i] = cv2.GaussianBlur(mask_np[i], (5, 5), 1.0)
308
+ return sm
309
+ return mask_np
310
+ except Exception:
311
+ return mask_np
312
+
313
  def process(self, image, mask, **kwargs):
 
314
  return self.step(image, mask, **kwargs)
315
+
316
+ logger.warning("Using fallback MatAnyone (limited refinement).")
317
+ core = _FallbackCore()
318
+ return _MatAnyoneWrapper(core, device=self.device)
319
+
320
+ # --------------------------- Housekeeping --------------------------- #
321
+
322
  def cleanup(self):
323
+ """Clean up resources."""
324
  if self.model:
325
+ try:
326
+ del self.model
327
+ except Exception:
328
+ pass
329
  self.model = None
330
  if torch.cuda.is_available():
331
  torch.cuda.empty_cache()
332
+
333
  def get_info(self) -> Dict[str, Any]:
334
+ """Get loader information."""
335
  return {
336
  "loaded": self.model is not None,
337
  "model_id": self.model_id,
338
  "device": self.device,
339
  "load_time": self.load_time,
340
+ "model_type": type(self.model).__name__ if self.model else None,
341
+ }
342
+
343
+ # Optional: instance-level shape debugging hook
344
+ def debug_shapes(self, image, mask, tag: str = ""):
345
+ debug_shapes(tag, image, mask)