MogensR commited on
Commit
4f1de42
·
1 Parent(s): f7afe05

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +9 -6
models/loaders/matanyone_loader.py CHANGED
@@ -232,8 +232,11 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
232
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
233
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
234
 
235
- img_chw = _to_chw_image(img_in)
236
- m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None
 
 
 
237
 
238
  # inference with autocast + inference_mode
239
  with torch.inference_mode():
@@ -247,12 +250,13 @@ def __exit__(self, *args): return False
247
 
248
  with amp_ctx:
249
  if not self.started:
250
- if m_1hw is None:
251
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
252
  return np.full((H, W), 0.5, dtype=np.float32)
253
 
254
- # encode/memorize
255
- _ = self.core.step(image=img_chw, mask=m_1hw)
 
256
  # warm-up predict
257
  if self._has_first_frame_pred:
258
  out_prob = self.core.step(image=img_chw, first_frame_pred=True)
@@ -455,4 +459,3 @@ def get_info(self) -> Dict[str, Any]:
455
  # Optional: instance-level shape debugging
456
  def debug_shapes(self, image, mask, tag: str = ""):
457
  debug_shapes(tag, image, mask)
458
-
 
232
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
233
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
234
 
235
+ # ---- IMPORTANT SHAPE CHANGES (only edit) ----
236
+ img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
237
+ m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
238
+ mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None # [H,W] or None
239
+ # ------------------------------------------------
240
 
241
  # inference with autocast + inference_mode
242
  with torch.inference_mode():
 
250
 
251
  with amp_ctx:
252
  if not self.started:
253
+ if mask_2d is None:
254
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
255
  return np.full((H, W), 0.5, dtype=np.float32)
256
 
257
+ # encode/memorize — pass 2-D mask (H,W)
258
+ _ = self.core.step(image=img_chw, mask=mask_2d)
259
+
260
  # warm-up predict
261
  if self._has_first_frame_pred:
262
  out_prob = self.core.step(image=img_chw, first_frame_pred=True)
 
459
  # Optional: instance-level shape debugging
460
  def debug_shapes(self, image, mask, tag: str = ""):
461
  debug_shapes(tag, image, mask)