MogensR commited on
Commit
aa52ec9
·
1 Parent(s): 9e03b6b

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +200 -12
models/loaders/matanyone_loader.py CHANGED
@@ -62,26 +62,214 @@ def load(self) -> Optional[Any]:
62
  return None
63
 
64
  def _load_official(self) -> Optional[Any]:
65
- """Load using official MatAnyone API"""
66
  from matanyone import InferenceCore
67
 
68
  # Create processor - pass model ID as positional argument
69
  processor = InferenceCore(self.model_id)
70
 
71
- # Ensure processor is properly initialized for the device
72
- if hasattr(processor, 'device'):
73
- processor.device = self.device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Move model components to device if they exist
76
- if hasattr(processor, 'model'):
77
- if hasattr(processor.model, 'to'):
78
- processor.model = processor.model.to(self.device)
79
- processor.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # Patch the processor to handle inputs properly
82
- self._patch_processor(processor)
 
 
 
 
83
 
84
- return processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def _patch_processor(self, processor):
87
  """
 
62
  return None
63
 
64
  def _load_official(self) -> Optional[Any]:
65
+ """Load using official MatAnyone API with comprehensive shape guard"""
66
  from matanyone import InferenceCore
67
 
68
  # Create processor - pass model ID as positional argument
69
  processor = InferenceCore(self.model_id)
70
 
71
+ # Install the critical shape guard patch from original loader
72
+ self._install_shape_guard(processor)
73
+
74
+ return processor
75
+
76
+ def _install_shape_guard(self, processor):
77
+ """
78
+ Install the comprehensive shape guard from the original loader.
79
+ This is CRITICAL for preventing 5D tensor issues and ensuring compatibility.
80
+ """
81
+ import torch
82
+ import numpy as np
83
+
84
+ device = self.device
85
+
86
+ # Helper functions for tensor manipulation
87
+ def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Tensor:
88
+ """Ensure image is in NCHW format"""
89
+ if isinstance(img, np.ndarray):
90
+ img = torch.from_numpy(img)
91
+
92
+ img = img.to(device)
93
+
94
+ # Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
95
+ if img.ndim == 5:
96
+ if img.shape[1] == 1: # Single time frame
97
+ img = img.squeeze(1)
98
+ elif img.shape[0] == 1: # Single batch
99
+ img = img.squeeze(0)
100
+
101
+ # Handle various input formats
102
+ if img.ndim == 3:
103
+ # CHW or HWC
104
+ if img.shape[0] in (1, 3, 4): # Likely CHW
105
+ chw = img
106
+ elif img.shape[-1] in (1, 3, 4): # Likely HWC
107
+ chw = img.permute(2, 0, 1)
108
+ else:
109
+ # Assume CHW
110
+ chw = img
111
+
112
+ # Ensure float and normalized
113
+ if chw.dtype != torch.float32:
114
+ chw = chw.float()
115
+ if chw.max() > 1.0:
116
+ chw = chw / 255.0
117
+
118
+ return chw.unsqueeze(0) if want_batched else chw
119
+
120
+ elif img.ndim == 4:
121
+ # NCHW or NHWC
122
+ N, A, B, C = img.shape
123
+ if A in (1, 3, 4): # NCHW
124
+ nchw = img
125
+ elif C in (1, 3, 4): # NHWC
126
+ nchw = img.permute(0, 3, 1, 2)
127
+ else:
128
+ # Assume NCHW
129
+ nchw = img
130
+
131
+ # Ensure float and normalized
132
+ if nchw.dtype != torch.float32:
133
+ nchw = nchw.float()
134
+ if nchw.max() > 1.0:
135
+ nchw = nchw / 255.0
136
+
137
+ return nchw if want_batched else nchw[0]
138
+
139
+ else:
140
+ logger.error(f"Unexpected image dimensions: {img.shape}")
141
+ # Return something safe
142
+ return torch.zeros((1, 3, 512, 512), device=device, dtype=torch.float32)
143
+
144
+ def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
145
+ threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
146
+ """Ensure mask is in correct format for MatAnyone"""
147
+ if isinstance(mask, np.ndarray):
148
+ mask = torch.from_numpy(mask)
149
+
150
+ mask = mask.to(device)
151
+
152
+ # Handle 5D tensors
153
+ if mask.ndim == 5:
154
+ if mask.shape[1] == 1:
155
+ mask = mask.squeeze(1)
156
+ if mask.shape[0] == 1 and mask.ndim == 5:
157
+ mask = mask.squeeze(0)
158
 
159
+ # Handle index masks
160
+ if idx_mask:
161
+ if mask.ndim == 3:
162
+ if mask.shape[0] == 1:
163
+ idx = (mask[0] >= threshold).to(torch.long)
164
+ else:
165
+ idx = torch.argmax(mask, dim=0).to(torch.long)
166
+ idx = (idx > 0).to(torch.long)
167
+ elif mask.ndim == 2:
168
+ idx = (mask >= threshold).to(torch.long)
169
+ else:
170
+ logger.warning(f"Unexpected idx mask shape: {mask.shape}")
171
+ idx = torch.zeros((512, 512), device=device, dtype=torch.long)
172
+ return idx
173
+
174
+ # Handle channel masks
175
+ if mask.ndim == 2:
176
+ out = mask.unsqueeze(0) # Add channel dimension
177
+ elif mask.ndim == 3:
178
+ if mask.shape[0] == 1:
179
+ out = mask
180
+ else:
181
+ # Choose channel with largest area
182
+ areas = mask.sum(dim=(-2, -1))
183
+ best_idx = areas.argmax()
184
+ out = mask[best_idx:best_idx+1]
185
+ else:
186
+ logger.warning(f"Unexpected mask shape: {mask.shape}")
187
+ out = torch.ones((1, 512, 512), device=device, dtype=torch.float32)
188
 
189
+ # Convert to float and normalize
190
+ out = out.to(torch.float32)
191
+ if not keep_soft:
192
+ out = (out >= threshold).to(torch.float32)
193
+
194
+ return out.clamp_(0.0, 1.0).contiguous()
195
 
196
+ # Create the guarded wrapper
197
+ def create_guarded_method(original_method):
198
+ """Create a guarded version of a MatAnyone method"""
199
+ def guarded_method(*args, **kwargs):
200
+ # Extract image and mask
201
+ image = kwargs.get("image", None)
202
+ mask = kwargs.get("mask", None)
203
+ idx_mask = kwargs.get("idx_mask", kwargs.get("index_mask", False))
204
+
205
+ # Handle positional arguments
206
+ if image is None and len(args) >= 1:
207
+ image = args[0]
208
+ if mask is None and len(args) >= 2:
209
+ mask = args[1]
210
+
211
+ if image is None or mask is None:
212
+ logger.error(f"MatAnyone called without image/mask: args={len(args)}, kwargs={list(kwargs.keys())}")
213
+ # Return something safe
214
+ return torch.ones((1, 512, 512), dtype=torch.float32) * 0.5
215
+
216
+ try:
217
+ # Coerce shapes
218
+ img_nchw = ensure_image_nchw(image, want_batched=True)
219
+
220
+ if idx_mask:
221
+ m_fixed = ensure_mask_for_matanyone(mask, idx_mask=True)
222
+ else:
223
+ m_fixed = ensure_mask_for_matanyone(mask, idx_mask=False, threshold=0.5)
224
+
225
+ # Log shapes for debugging
226
+ logger.debug(f"MatAnyone input - image: {img_nchw.shape}, mask: {m_fixed.shape}, idx: {idx_mask}")
227
+
228
+ # Try unbatched first (most common)
229
+ try:
230
+ new_kwargs = dict(kwargs)
231
+ new_kwargs["image"] = img_nchw[0] # CHW
232
+ new_kwargs["mask"] = m_fixed if idx_mask else m_fixed # Already correct shape
233
+ new_kwargs["idx_mask"] = bool(idx_mask)
234
+
235
+ result = original_method(**new_kwargs)
236
+ return result
237
+
238
+ except Exception as e1:
239
+ logger.debug(f"Unbatched call failed, trying batched: {e1}")
240
+ # Try with batch dimension
241
+ new_kwargs = dict(kwargs)
242
+ new_kwargs["image"] = img_nchw # NCHW
243
+ new_kwargs["mask"] = m_fixed
244
+ new_kwargs["idx_mask"] = bool(idx_mask)
245
+
246
+ result = original_method(**new_kwargs)
247
+ return result
248
+
249
+ except Exception as e:
250
+ logger.error(f"MatAnyone guarded call failed: {e}")
251
+ import traceback
252
+ logger.debug(traceback.format_exc())
253
+ # Return input mask as fallback
254
+ if isinstance(mask, torch.Tensor):
255
+ return mask.cpu().numpy()
256
+ elif isinstance(mask, np.ndarray):
257
+ return mask
258
+ else:
259
+ return np.ones((512, 512), dtype=np.float32) * 0.5
260
+
261
+ return guarded_method
262
+
263
+ # Apply guard to both step and process methods
264
+ if hasattr(processor, 'step'):
265
+ original_step = processor.step
266
+ processor.step = create_guarded_method(original_step)
267
+ logger.info("Installed shape guard on MatAnyone.step")
268
+
269
+ if hasattr(processor, 'process'):
270
+ original_process = processor.process
271
+ processor.process = create_guarded_method(original_process)
272
+ logger.info("Installed shape guard on MatAnyone.process")
273
 
274
  def _patch_processor(self, processor):
275
  """