MogensR commited on
Commit
0672ceb
·
1 Parent(s): aa52ec9

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +12 -8
models/loaders/matanyone_loader.py CHANGED
@@ -92,11 +92,14 @@ def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Ten
92
  img = img.to(device)
93
 
94
  # Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
95
- if img.ndim == 5:
96
- if img.shape[1] == 1: # Single time frame
97
- img = img.squeeze(1)
98
- elif img.shape[0] == 1: # Single batch
99
  img = img.squeeze(0)
 
 
 
 
 
100
 
101
  # Handle various input formats
102
  if img.ndim == 3:
@@ -134,12 +137,12 @@ def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Ten
134
  if nchw.max() > 1.0:
135
  nchw = nchw / 255.0
136
 
137
- return nchw if want_batched else nchw[0]
138
 
139
  else:
140
  logger.error(f"Unexpected image dimensions: {img.shape}")
141
  # Return something safe
142
- return torch.zeros((1, 3, 512, 512), device=device, dtype=torch.float32)
143
 
144
  def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
145
  threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
@@ -228,8 +231,9 @@ def guarded_method(*args, **kwargs):
228
  # Try unbatched first (most common)
229
  try:
230
  new_kwargs = dict(kwargs)
231
- new_kwargs["image"] = img_nchw[0] # CHW
232
- new_kwargs["mask"] = m_fixed if idx_mask else m_fixed # Already correct shape
 
233
  new_kwargs["idx_mask"] = bool(idx_mask)
234
 
235
  result = original_method(**new_kwargs)
 
92
  img = img.to(device)
93
 
94
  # Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
95
+ while img.ndim == 5:
96
+ if img.shape[0] == 1:
 
 
97
  img = img.squeeze(0)
98
+ elif img.shape[1] == 1:
99
+ img = img.squeeze(1)
100
+ else:
101
+ # Can't auto-squeeze, take first time frame
102
+ img = img[:, 0]
103
 
104
  # Handle various input formats
105
  if img.ndim == 3:
 
137
  if nchw.max() > 1.0:
138
  nchw = nchw / 255.0
139
 
140
+ return nchw if want_batched else nchw.squeeze(0) if not want_batched and nchw.shape[0] == 1 else nchw[0]
141
 
142
  else:
143
  logger.error(f"Unexpected image dimensions: {img.shape}")
144
  # Return something safe
145
+ return torch.zeros((3, 512, 512), device=device, dtype=torch.float32).unsqueeze(0) if want_batched else torch.zeros((3, 512, 512), device=device, dtype=torch.float32)
146
 
147
  def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
148
  threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
 
231
  # Try unbatched first (most common)
232
  try:
233
  new_kwargs = dict(kwargs)
234
+ # CRITICAL: Use unbatched (CHW) not batched for first attempt
235
+ new_kwargs["image"] = img_nchw.squeeze(0) if img_nchw.shape[0] == 1 else img_nchw[0] # CHW
236
+ new_kwargs["mask"] = m_fixed.squeeze(0) if m_fixed.shape[0] == 1 else m_fixed # HW or CHW
237
  new_kwargs["idx_mask"] = bool(idx_mask)
238
 
239
  result = original_method(**new_kwargs)