MogensR committed on
Commit
5711ea9
·
1 Parent(s): 19f3f1c

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +31 -36
models/loaders/matanyone_loader.py CHANGED
@@ -13,9 +13,7 @@
13
  - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
14
  - New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
- - Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
17
- - Fix: Merge long_term overrides to preserve keys like count_usage
18
- - Fix: Syntax error in _to_bchw (== instead of =)
19
  """
20
  from __future__ import annotations
21
  import os
@@ -188,6 +186,8 @@ def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[
188
  elif t.ndim == 2:
189
  h, w = t.shape
190
  t = t.unsqueeze(0) # Temp to 3D for padding
 
 
191
  else:
192
  raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
193
  pad_h = (multiple - h % multiple) % multiple
@@ -326,11 +326,20 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
326
  # nearest to keep binary-like edges
327
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
328
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
329
- # Pad to multiple of 16
330
- img_chw = _pad_to_multiple(img_chw)
 
331
  if msk_in is not None:
332
- msk_in = _pad_to_multiple(msk_in)
333
- ph, pw = img_chw.shape[-2:]
 
 
 
 
 
 
 
 
334
  with torch.inference_mode():
335
  if self.use_autocast:
336
  amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
@@ -341,22 +350,23 @@ def __exit__(self, *a): return False
341
  amp_ctx = _NoOp()
342
  with amp_ctx:
343
  if not self.started:
344
- if msk_in is None:
345
  # Should not happen when used correctly — still be defensive
346
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
347
  return np.full((H, W), 0.5, dtype=np.float32)
348
- # CRITICAL: pass **1HW** to .step(mask=...)
349
- _ = self.core.step(image=img_chw, mask=msk_in)
350
  if self._has_first_frame_pred:
351
- out_prob = self.core.step(image=img_chw, first_frame_pred=True)
352
  else:
353
- out_prob = self.core.step(image=img_chw)
354
  self.started = True
355
  else:
356
- out_prob = self.core.step(image=img_chw)
357
  alpha = self._to_alpha(out_prob)
358
  # Unpad to scaled size, then upsample if needed
359
- alpha = alpha[:th, :tw]
 
360
  # Upsample alpha back if we ran at a smaller scale
361
  if (th, tw) != (H, W):
362
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
@@ -441,7 +451,7 @@ def load(self) -> Optional[Any]:
441
  # Full default cfg from official config.json
442
  default_cfg = {
443
  "amp": False,
444
- "chunk_size": -1,
445
  "flip_aug": False,
446
  "long_term": {
447
  "buffer_tokens": 2000,
@@ -527,7 +537,7 @@ def load(self) -> Optional[Any]:
527
  "stagger_updates": 5,
528
  "top_k": 30,
529
  "use_all_masks": False,
530
- "use_long_term": False,
531
  "visualize": False,
532
  "weights": "pretrained_models/matanyone.pth"
533
  }
@@ -535,28 +545,13 @@ def load(self) -> Optional[Any]:
535
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
536
  if isinstance(cfg, dict):
537
  cfg = dict(cfg) # Copy to avoid modifying model.cfg
538
- # Override specific values to disable memory and potential dim issues
539
  overrides = {
540
- 'chunk_size': 1,
541
- 'flip_aug': False,
542
- 'mem_every': -1,
543
- 'max_mem_frames': 0,
544
- 'use_long_term': False,
545
  }
546
  cfg.update(overrides)
547
- # Merge long_term overrides without removing keys
548
- long_term_defaults = {
549
- "buffer_tokens": 2000,
550
- "count_usage": True,
551
- "max_mem_frames": 0,
552
- "max_num_tokens": 10000,
553
- "min_mem_frames": 0,
554
- "num_prototypes": 128
555
- }
556
- if 'long_term' in cfg:
557
- cfg['long_term'].update({k: v for k, v in long_term_defaults.items() if k not in cfg['long_term'] or k in ['max_mem_frames', 'min_mem_frames']})
558
- else:
559
- cfg['long_term'] = long_term_defaults
560
  # Convert to EasyDict for dot access
561
  cfg = EasyDict(cfg)
562
  # Inference core
@@ -564,7 +559,7 @@ def load(self) -> Optional[Any]:
564
  self.core = core_cls(self.model, cfg=cfg)
565
  except TypeError:
566
  self.core = core_cls(self.model)
567
- # Some versions expose .to(), some dont — best effort
568
  try:
569
  if hasattr(self.core, "to"):
570
  self.core.to(self.device)
 
13
  - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
14
  - New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
+ - FIXED: Re-enabled memory features and added temporal dimension support
 
 
17
  """
18
  from __future__ import annotations
19
  import os
 
186
  elif t.ndim == 2:
187
  h, w = t.shape
188
  t = t.unsqueeze(0) # Temp to 3D for padding
189
+ elif t.ndim == 4: # Handle [T, C, H, W] or similar
190
+ return t # Skip padding for temporal tensors
191
  else:
192
  raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
193
  pad_h = (multiple - h % multiple) % multiple
 
326
  # nearest to keep binary-like edges
327
  msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
328
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
329
+
330
+ # ADD TEMPORAL DIMENSION for video processing mode
331
+ img_tchw = img_chw.unsqueeze(0) # [C,H,W] -> [T=1,C,H,W]
332
  if msk_in is not None:
333
+ msk_t1hw = msk_in.unsqueeze(0) # [1,H,W] -> [T=1,1,H,W]
334
+ else:
335
+ msk_t1hw = None
336
+
337
+ # Pad to multiple of 16 (skip for temporal tensors)
338
+ img_tchw = _pad_to_multiple(img_tchw)
339
+ if msk_t1hw is not None:
340
+ msk_t1hw = _pad_to_multiple(msk_t1hw)
341
+
342
+ ph, pw = img_tchw.shape[-2:]
343
  with torch.inference_mode():
344
  if self.use_autocast:
345
  amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
 
350
  amp_ctx = _NoOp()
351
  with amp_ctx:
352
  if not self.started:
353
+ if msk_t1hw is None:
354
  # Should not happen when used correctly — still be defensive
355
  logger.warning("First frame arrived without a mask; returning neutral alpha.")
356
  return np.full((H, W), 0.5, dtype=np.float32)
357
+ # Pass temporal tensors to core
358
+ _ = self.core.step(image=img_tchw, mask=msk_t1hw)
359
  if self._has_first_frame_pred:
360
+ out_prob = self.core.step(image=img_tchw, first_frame_pred=True)
361
  else:
362
+ out_prob = self.core.step(image=img_tchw)
363
  self.started = True
364
  else:
365
+ out_prob = self.core.step(image=img_tchw)
366
  alpha = self._to_alpha(out_prob)
367
  # Unpad to scaled size, then upsample if needed
368
+ if alpha.ndim >= 2:
369
+ alpha = alpha[..., :th, :tw]
370
  # Upsample alpha back if we ran at a smaller scale
371
  if (th, tw) != (H, W):
372
  a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
 
451
  # Full default cfg from official config.json
452
  default_cfg = {
453
  "amp": False,
454
+ "chunk_size": 1, # Keep at 1 for single frame processing
455
  "flip_aug": False,
456
  "long_term": {
457
  "buffer_tokens": 2000,
 
537
  "stagger_updates": 5,
538
  "top_k": 30,
539
  "use_all_masks": False,
540
+ "use_long_term": True, # Enable long-term memory
541
  "visualize": False,
542
  "weights": "pretrained_models/matanyone.pth"
543
  }
 
545
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
546
  if isinstance(cfg, dict):
547
  cfg = dict(cfg) # Copy to avoid modifying model.cfg
548
+ # Only override minimal settings for compatibility
549
  overrides = {
550
+ 'chunk_size': 1, # Process one frame at a time
551
+ 'flip_aug': False, # Disable augmentation
552
+ # Keep memory features enabled!
 
 
553
  }
554
  cfg.update(overrides)
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  # Convert to EasyDict for dot access
556
  cfg = EasyDict(cfg)
557
  # Inference core
 
559
  self.core = core_cls(self.model, cfg=cfg)
560
  except TypeError:
561
  self.core = core_cls(self.model)
562
+ # Some versions expose .to(), some don't — best effort
563
  try:
564
  if hasattr(self.core, "to"):
565
  self.core.to(self.device)