Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -232,8 +232,11 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 232 |
img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
|
| 233 |
msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
# inference with autocast + inference_mode
|
| 239 |
with torch.inference_mode():
|
|
@@ -247,12 +250,13 @@ def __exit__(self, *args): return False
|
|
| 247 |
|
| 248 |
with amp_ctx:
|
| 249 |
if not self.started:
|
| 250 |
-
if
|
| 251 |
logger.warning("First frame arrived without a mask; returning neutral alpha.")
|
| 252 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 253 |
|
| 254 |
-
# encode/memorize
|
| 255 |
-
_ = self.core.step(image=img_chw, mask=
|
|
|
|
| 256 |
# warm-up predict
|
| 257 |
if self._has_first_frame_pred:
|
| 258 |
out_prob = self.core.step(image=img_chw, first_frame_pred=True)
|
|
@@ -455,4 +459,3 @@ def get_info(self) -> Dict[str, Any]:
|
|
| 455 |
# Optional: instance-level shape debugging
|
| 456 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 457 |
debug_shapes(tag, image, mask)
|
| 458 |
-
|
|
|
|
| 232 |
img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
|
| 233 |
msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
|
| 234 |
|
| 235 |
+
# ---- IMPORTANT SHAPE CHANGES (only edit) ----
|
| 236 |
+
img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
|
| 237 |
+
m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
|
| 238 |
+
mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None # [H,W] or None
|
| 239 |
+
# ------------------------------------------------
|
| 240 |
|
| 241 |
# inference with autocast + inference_mode
|
| 242 |
with torch.inference_mode():
|
|
|
|
| 250 |
|
| 251 |
with amp_ctx:
|
| 252 |
if not self.started:
|
| 253 |
+
if mask_2d is None:
|
| 254 |
logger.warning("First frame arrived without a mask; returning neutral alpha.")
|
| 255 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 256 |
|
| 257 |
+
# encode/memorize — pass 2-D mask (H,W)
|
| 258 |
+
_ = self.core.step(image=img_chw, mask=mask_2d)
|
| 259 |
+
|
| 260 |
# warm-up predict
|
| 261 |
if self._has_first_frame_pred:
|
| 262 |
out_prob = self.core.step(image=img_chw, first_frame_pred=True)
|
|
|
|
| 459 |
# Optional: instance-level shape debugging
|
| 460 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 461 |
debug_shapes(tag, image, mask)
|
|
|