MogensR committed on
Commit
b20702e
·
1 Parent(s): 0672ceb

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +19 -5
models/loaders/matanyone_loader.py CHANGED
@@ -217,23 +217,37 @@ def guarded_method(*args, **kwargs):
217
  return torch.ones((1, 512, 512), dtype=torch.float32) * 0.5
218
 
219
  try:
220
- # Coerce shapes
221
  img_nchw = ensure_image_nchw(image, want_batched=True)
222
 
 
 
 
 
 
 
 
 
 
223
  if idx_mask:
224
  m_fixed = ensure_mask_for_matanyone(mask, idx_mask=True)
225
  else:
226
  m_fixed = ensure_mask_for_matanyone(mask, idx_mask=False, threshold=0.5)
227
 
228
- # Log shapes for debugging
229
  logger.debug(f"MatAnyone input - image: {img_nchw.shape}, mask: {m_fixed.shape}, idx: {idx_mask}")
230
 
 
 
 
 
 
 
231
  # Try unbatched first (most common)
232
  try:
233
  new_kwargs = dict(kwargs)
234
- # CRITICAL: Use unbatched (CHW) not batched for first attempt
235
- new_kwargs["image"] = img_nchw.squeeze(0) if img_nchw.shape[0] == 1 else img_nchw[0] # CHW
236
- new_kwargs["mask"] = m_fixed.squeeze(0) if m_fixed.shape[0] == 1 else m_fixed # HW or CHW
237
  new_kwargs["idx_mask"] = bool(idx_mask)
238
 
239
  result = original_method(**new_kwargs)
 
217
  return torch.ones((1, 512, 512), dtype=torch.float32) * 0.5
218
 
219
  try:
220
+ # Coerce shapes - ensure we REALLY squeeze out extra dimensions
221
  img_nchw = ensure_image_nchw(image, want_batched=True)
222
 
223
+ # CRITICAL FIX: Force squeeze all unnecessary dimensions
224
+ while img_nchw.ndim > 4:
225
+ if img_nchw.shape[0] == 1:
226
+ img_nchw = img_nchw.squeeze(0)
227
+ elif img_nchw.shape[1] == 1:
228
+ img_nchw = img_nchw.squeeze(1)
229
+ else:
230
+ break
231
+
232
  if idx_mask:
233
  m_fixed = ensure_mask_for_matanyone(mask, idx_mask=True)
234
  else:
235
  m_fixed = ensure_mask_for_matanyone(mask, idx_mask=False, threshold=0.5)
236
 
237
+ # Log actual shapes being passed
238
  logger.debug(f"MatAnyone input - image: {img_nchw.shape}, mask: {m_fixed.shape}, idx: {idx_mask}")
239
 
240
+ # For MatAnyone, we need CHW not NCHW for unbatched
241
+ if img_nchw.ndim == 4 and img_nchw.shape[0] == 1:
242
+ img_chw = img_nchw[0] # Remove batch dimension
243
+ else:
244
+ img_chw = img_nchw
245
+
246
  # Try unbatched first (most common)
247
  try:
248
  new_kwargs = dict(kwargs)
249
+ new_kwargs["image"] = img_chw # CHW
250
+ new_kwargs["mask"] = m_fixed.squeeze(0) if m_fixed.ndim > 2 and m_fixed.shape[0] == 1 else m_fixed
 
251
  new_kwargs["idx_mask"] = bool(idx_mask)
252
 
253
  result = original_method(**new_kwargs)