Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -13,9 +13,7 @@
|
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
-
-
|
| 17 |
-
- Fix: Merge long_term overrides to preserve keys like count_usage
|
| 18 |
-
- Fix: Syntax error in _to_bchw (== instead of =)
|
| 19 |
"""
|
| 20 |
from __future__ import annotations
|
| 21 |
import os
|
|
@@ -188,6 +186,8 @@ def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[
|
|
| 188 |
elif t.ndim == 2:
|
| 189 |
h, w = t.shape
|
| 190 |
t = t.unsqueeze(0) # Temp to 3D for padding
|
|
|
|
|
|
|
| 191 |
else:
|
| 192 |
raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
|
| 193 |
pad_h = (multiple - h % multiple) % multiple
|
|
@@ -326,11 +326,20 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 326 |
# nearest to keep binary-like edges
|
| 327 |
msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
|
| 328 |
img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
|
| 329 |
-
|
| 330 |
-
|
|
|
|
| 331 |
if msk_in is not None:
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
with torch.inference_mode():
|
| 335 |
if self.use_autocast:
|
| 336 |
amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
|
|
@@ -341,22 +350,23 @@ def __exit__(self, *a): return False
|
|
| 341 |
amp_ctx = _NoOp()
|
| 342 |
with amp_ctx:
|
| 343 |
if not self.started:
|
| 344 |
-
if
|
| 345 |
# Should not happen when used correctly — still be defensive
|
| 346 |
logger.warning("First frame arrived without a mask; returning neutral alpha.")
|
| 347 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 348 |
-
#
|
| 349 |
-
_ = self.core.step(image=
|
| 350 |
if self._has_first_frame_pred:
|
| 351 |
-
out_prob = self.core.step(image=
|
| 352 |
else:
|
| 353 |
-
out_prob = self.core.step(image=
|
| 354 |
self.started = True
|
| 355 |
else:
|
| 356 |
-
out_prob = self.core.step(image=
|
| 357 |
alpha = self._to_alpha(out_prob)
|
| 358 |
# Unpad to scaled size, then upsample if needed
|
| 359 |
-
alpha =
|
|
|
|
| 360 |
# Upsample alpha back if we ran at a smaller scale
|
| 361 |
if (th, tw) != (H, W):
|
| 362 |
a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
|
|
@@ -441,7 +451,7 @@ def load(self) -> Optional[Any]:
|
|
| 441 |
# Full default cfg from official config.json
|
| 442 |
default_cfg = {
|
| 443 |
"amp": False,
|
| 444 |
-
"chunk_size":
|
| 445 |
"flip_aug": False,
|
| 446 |
"long_term": {
|
| 447 |
"buffer_tokens": 2000,
|
|
@@ -527,7 +537,7 @@ def load(self) -> Optional[Any]:
|
|
| 527 |
"stagger_updates": 5,
|
| 528 |
"top_k": 30,
|
| 529 |
"use_all_masks": False,
|
| 530 |
-
"use_long_term":
|
| 531 |
"visualize": False,
|
| 532 |
"weights": "pretrained_models/matanyone.pth"
|
| 533 |
}
|
|
@@ -535,28 +545,13 @@ def load(self) -> Optional[Any]:
|
|
| 535 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 536 |
if isinstance(cfg, dict):
|
| 537 |
cfg = dict(cfg) # Copy to avoid modifying model.cfg
|
| 538 |
-
#
|
| 539 |
overrides = {
|
| 540 |
-
'chunk_size': 1,
|
| 541 |
-
'flip_aug': False,
|
| 542 |
-
|
| 543 |
-
'max_mem_frames': 0,
|
| 544 |
-
'use_long_term': False,
|
| 545 |
}
|
| 546 |
cfg.update(overrides)
|
| 547 |
-
# Merge long_term overrides without removing keys
|
| 548 |
-
long_term_defaults = {
|
| 549 |
-
"buffer_tokens": 2000,
|
| 550 |
-
"count_usage": True,
|
| 551 |
-
"max_mem_frames": 0,
|
| 552 |
-
"max_num_tokens": 10000,
|
| 553 |
-
"min_mem_frames": 0,
|
| 554 |
-
"num_prototypes": 128
|
| 555 |
-
}
|
| 556 |
-
if 'long_term' in cfg:
|
| 557 |
-
cfg['long_term'].update({k: v for k, v in long_term_defaults.items() if k not in cfg['long_term'] or k in ['max_mem_frames', 'min_mem_frames']})
|
| 558 |
-
else:
|
| 559 |
-
cfg['long_term'] = long_term_defaults
|
| 560 |
# Convert to EasyDict for dot access
|
| 561 |
cfg = EasyDict(cfg)
|
| 562 |
# Inference core
|
|
@@ -564,7 +559,7 @@ def load(self) -> Optional[Any]:
|
|
| 564 |
self.core = core_cls(self.model, cfg=cfg)
|
| 565 |
except TypeError:
|
| 566 |
self.core = core_cls(self.model)
|
| 567 |
-
# Some versions expose .to(), some don
|
| 568 |
try:
|
| 569 |
if hasattr(self.core, "to"):
|
| 570 |
self.core.to(self.device)
|
|
|
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
+
- FIXED: Re-enabled memory features and added temporal dimension support
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
from __future__ import annotations
|
| 19 |
import os
|
|
|
|
| 186 |
elif t.ndim == 2:
|
| 187 |
h, w = t.shape
|
| 188 |
t = t.unsqueeze(0) # Temp to 3D for padding
|
| 189 |
+
elif t.ndim == 4: # Handle [T, C, H, W] or similar
|
| 190 |
+
return t # Skip padding for temporal tensors
|
| 191 |
else:
|
| 192 |
raise ValueError(f"Unsupported ndim for padding: {t.ndim}")
|
| 193 |
pad_h = (multiple - h % multiple) % multiple
|
|
|
|
| 326 |
# nearest to keep binary-like edges
|
| 327 |
msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
|
| 328 |
img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
|
| 329 |
+
|
| 330 |
+
# ADD TEMPORAL DIMENSION for video processing mode
|
| 331 |
+
img_tchw = img_chw.unsqueeze(0) # [C,H,W] -> [T=1,C,H,W]
|
| 332 |
if msk_in is not None:
|
| 333 |
+
msk_t1hw = msk_in.unsqueeze(0) # [1,H,W] -> [T=1,1,H,W]
|
| 334 |
+
else:
|
| 335 |
+
msk_t1hw = None
|
| 336 |
+
|
| 337 |
+
# Pad to multiple of 16 (skip for temporal tensors)
|
| 338 |
+
img_tchw = _pad_to_multiple(img_tchw)
|
| 339 |
+
if msk_t1hw is not None:
|
| 340 |
+
msk_t1hw = _pad_to_multiple(msk_t1hw)
|
| 341 |
+
|
| 342 |
+
ph, pw = img_tchw.shape[-2:]
|
| 343 |
with torch.inference_mode():
|
| 344 |
if self.use_autocast:
|
| 345 |
amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
|
|
|
|
| 350 |
amp_ctx = _NoOp()
|
| 351 |
with amp_ctx:
|
| 352 |
if not self.started:
|
| 353 |
+
if msk_t1hw is None:
|
| 354 |
# Should not happen when used correctly — still be defensive
|
| 355 |
logger.warning("First frame arrived without a mask; returning neutral alpha.")
|
| 356 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 357 |
+
# Pass temporal tensors to core
|
| 358 |
+
_ = self.core.step(image=img_tchw, mask=msk_t1hw)
|
| 359 |
if self._has_first_frame_pred:
|
| 360 |
+
out_prob = self.core.step(image=img_tchw, first_frame_pred=True)
|
| 361 |
else:
|
| 362 |
+
out_prob = self.core.step(image=img_tchw)
|
| 363 |
self.started = True
|
| 364 |
else:
|
| 365 |
+
out_prob = self.core.step(image=img_tchw)
|
| 366 |
alpha = self._to_alpha(out_prob)
|
| 367 |
# Unpad to scaled size, then upsample if needed
|
| 368 |
+
if alpha.ndim >= 2:
|
| 369 |
+
alpha = alpha[..., :th, :tw]
|
| 370 |
# Upsample alpha back if we ran at a smaller scale
|
| 371 |
if (th, tw) != (H, W):
|
| 372 |
a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
|
|
|
|
| 451 |
# Full default cfg from official config.json
|
| 452 |
default_cfg = {
|
| 453 |
"amp": False,
|
| 454 |
+
"chunk_size": 1, # Keep at 1 for single frame processing
|
| 455 |
"flip_aug": False,
|
| 456 |
"long_term": {
|
| 457 |
"buffer_tokens": 2000,
|
|
|
|
| 537 |
"stagger_updates": 5,
|
| 538 |
"top_k": 30,
|
| 539 |
"use_all_masks": False,
|
| 540 |
+
"use_long_term": True, # Enable long-term memory
|
| 541 |
"visualize": False,
|
| 542 |
"weights": "pretrained_models/matanyone.pth"
|
| 543 |
}
|
|
|
|
| 545 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 546 |
if isinstance(cfg, dict):
|
| 547 |
cfg = dict(cfg) # Copy to avoid modifying model.cfg
|
| 548 |
+
# Only override minimal settings for compatibility
|
| 549 |
overrides = {
|
| 550 |
+
'chunk_size': 1, # Process one frame at a time
|
| 551 |
+
'flip_aug': False, # Disable augmentation
|
| 552 |
+
# Keep memory features enabled!
|
|
|
|
|
|
|
| 553 |
}
|
| 554 |
cfg.update(overrides)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
# Convert to EasyDict for dot access
|
| 556 |
cfg = EasyDict(cfg)
|
| 557 |
# Inference core
|
|
|
|
| 559 |
self.core = core_cls(self.model, cfg=cfg)
|
| 560 |
except TypeError:
|
| 561 |
self.core = core_cls(self.model)
|
| 562 |
+
# Some versions expose .to(), some don't — best effort
|
| 563 |
try:
|
| 564 |
if hasattr(self.core, "to"):
|
| 565 |
self.core.to(self.device)
|