Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -92,11 +92,14 @@ def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Ten
|
|
| 92 |
img = img.to(device)
|
| 93 |
|
| 94 |
# Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
|
| 95 |
-
|
| 96 |
-
if img.shape[
|
| 97 |
-
img = img.squeeze(1)
|
| 98 |
-
elif img.shape[0] == 1: # Single batch
|
| 99 |
img = img.squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Handle various input formats
|
| 102 |
if img.ndim == 3:
|
|
@@ -134,12 +137,12 @@ def ensure_image_nchw(img: torch.Tensor, want_batched: bool = True) -> torch.Ten
|
|
| 134 |
if nchw.max() > 1.0:
|
| 135 |
nchw = nchw / 255.0
|
| 136 |
|
| 137 |
-
return nchw if want_batched else nchw[0]
|
| 138 |
|
| 139 |
else:
|
| 140 |
logger.error(f"Unexpected image dimensions: {img.shape}")
|
| 141 |
# Return something safe
|
| 142 |
-
return torch.zeros((
|
| 143 |
|
| 144 |
def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
|
| 145 |
threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
|
|
@@ -228,8 +231,9 @@ def guarded_method(*args, **kwargs):
|
|
| 228 |
# Try unbatched first (most common)
|
| 229 |
try:
|
| 230 |
new_kwargs = dict(kwargs)
|
| 231 |
-
|
| 232 |
-
new_kwargs["
|
|
|
|
| 233 |
new_kwargs["idx_mask"] = bool(idx_mask)
|
| 234 |
|
| 235 |
result = original_method(**new_kwargs)
|
|
|
|
| 92 |
img = img.to(device)
|
| 93 |
|
| 94 |
# Handle 5D tensors (B,T,C,H,W) by squeezing time dimension
|
| 95 |
+
while img.ndim == 5:
|
| 96 |
+
if img.shape[0] == 1:
|
|
|
|
|
|
|
| 97 |
img = img.squeeze(0)
|
| 98 |
+
elif img.shape[1] == 1:
|
| 99 |
+
img = img.squeeze(1)
|
| 100 |
+
else:
|
| 101 |
+
# Can't auto-squeeze, take first time frame
|
| 102 |
+
img = img[:, 0]
|
| 103 |
|
| 104 |
# Handle various input formats
|
| 105 |
if img.ndim == 3:
|
|
|
|
| 137 |
if nchw.max() > 1.0:
|
| 138 |
nchw = nchw / 255.0
|
| 139 |
|
| 140 |
+
return nchw if want_batched else nchw.squeeze(0) if not want_batched and nchw.shape[0] == 1 else nchw[0]
|
| 141 |
|
| 142 |
else:
|
| 143 |
logger.error(f"Unexpected image dimensions: {img.shape}")
|
| 144 |
# Return something safe
|
| 145 |
+
return torch.zeros((3, 512, 512), device=device, dtype=torch.float32).unsqueeze(0) if want_batched else torch.zeros((3, 512, 512), device=device, dtype=torch.float32)
|
| 146 |
|
| 147 |
def ensure_mask_for_matanyone(mask: torch.Tensor, idx_mask: bool = False,
|
| 148 |
threshold: float = 0.5, keep_soft: bool = False) -> torch.Tensor:
|
|
|
|
| 231 |
# Try unbatched first (most common)
|
| 232 |
try:
|
| 233 |
new_kwargs = dict(kwargs)
|
| 234 |
+
# CRITICAL: Use unbatched (CHW) not batched for first attempt
|
| 235 |
+
new_kwargs["image"] = img_nchw.squeeze(0) if img_nchw.shape[0] == 1 else img_nchw[0] # CHW
|
| 236 |
+
new_kwargs["mask"] = m_fixed.squeeze(0) if m_fixed.shape[0] == 1 else m_fixed # HW or CHW
|
| 237 |
new_kwargs["idx_mask"] = bool(idx_mask)
|
| 238 |
|
| 239 |
result = original_method(**new_kwargs)
|