Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
- Fix: Merge long_term overrides to preserve keys like count_usage
|
|
|
|
| 18 |
"""
|
| 19 |
from __future__ import annotations
|
| 20 |
import os
|
|
@@ -84,7 +85,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
|
|
| 84 |
x = x.float()
|
| 85 |
if x.ndim == 5:
|
| 86 |
x = x[:, 0] # -> 4D
|
| 87 |
-
if x.ndim == 4:
|
| 88 |
if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
|
| 89 |
x = x.permute(0, 3, 1, 2).contiguous()
|
| 90 |
elif x.ndim == 3:
|
|
@@ -213,17 +214,10 @@ def _info(name, v):
|
|
| 213 |
# Precision selection
|
| 214 |
# ---------------------------------------------------------------------------
|
| 215 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
| 216 |
-
"""Pick model weight dtype + autocast dtype (
|
| 217 |
if device != "cuda":
|
| 218 |
return torch.float32, False, None
|
| 219 |
-
|
| 220 |
-
fp16_ok = cc[0] >= 7 # Volta+
|
| 221 |
-
bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
|
| 222 |
-
if fp16_ok:
|
| 223 |
-
return torch.float16, True, torch.float16 # Prefer fp16 for T4
|
| 224 |
-
if bf16_ok:
|
| 225 |
-
return torch.bfloat16, True, torch.bfloat16
|
| 226 |
-
return torch.float32, False, None
|
| 227 |
|
| 228 |
# ---------------------------------------------------------------------------
|
| 229 |
# Stateful Adapter around InferenceCore
|
|
@@ -543,22 +537,19 @@ def load(self) -> Optional[Any]:
|
|
| 543 |
'use_long_term': False,
|
| 544 |
}
|
| 545 |
cfg.update(overrides)
|
| 546 |
-
# Merge long_term overrides without
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
if 'long_term' in cfg:
|
| 548 |
-
cfg['long_term'].update({
|
| 549 |
-
'max_mem_frames': 0,
|
| 550 |
-
'min_mem_frames': 0,
|
| 551 |
-
'count_usage': False, # Disable since memory off
|
| 552 |
-
})
|
| 553 |
else:
|
| 554 |
-
cfg['long_term'] = {
|
| 555 |
-
'buffer_tokens': 2000,
|
| 556 |
-
'count_usage': False,
|
| 557 |
-
'max_mem_frames': 0,
|
| 558 |
-
'max_num_tokens': 10000,
|
| 559 |
-
'min_mem_frames': 0,
|
| 560 |
-
'num_prototypes': 128
|
| 561 |
-
}
|
| 562 |
# Convert to EasyDict for dot access
|
| 563 |
cfg = EasyDict(cfg)
|
| 564 |
# Inference core
|
|
|
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
- Fix: Merge long_term overrides to preserve keys like count_usage
|
| 18 |
+
- New: Force fp32 for T4 to avoid fp16 fallback bugs in tensor ops
|
| 19 |
"""
|
| 20 |
from __future__ import annotations
|
| 21 |
import os
|
|
|
|
| 85 |
x = x.float()
|
| 86 |
if x.ndim == 5:
|
| 87 |
x = x[:, 0] # -> 4D
|
| 88 |
+
if x.ndim == 4:
|
| 89 |
if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
|
| 90 |
x = x.permute(0, 3, 1, 2).contiguous()
|
| 91 |
elif x.ndim == 3:
|
|
|
|
| 214 |
# Precision selection
|
| 215 |
# ---------------------------------------------------------------------------
|
| 216 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
| 217 |
+
"""Pick model weight dtype + autocast dtype (fp32 for stability on T4)."""
|
| 218 |
if device != "cuda":
|
| 219 |
return torch.float32, False, None
|
| 220 |
+
return torch.float32, False, None # Force fp32 to avoid fp16 bugs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
# ---------------------------------------------------------------------------
|
| 223 |
# Stateful Adapter around InferenceCore
|
|
|
|
| 537 |
'use_long_term': False,
|
| 538 |
}
|
| 539 |
cfg.update(overrides)
|
| 540 |
+
# Merge long_term overrides without removing keys
|
| 541 |
+
long_term_defaults = {
|
| 542 |
+
"buffer_tokens": 2000,
|
| 543 |
+
"count_usage": True,
|
| 544 |
+
"max_mem_frames": 0,
|
| 545 |
+
"max_num_tokens": 10000,
|
| 546 |
+
"min_mem_frames": 0,
|
| 547 |
+
"num_prototypes": 128
|
| 548 |
+
}
|
| 549 |
if 'long_term' in cfg:
|
| 550 |
+
cfg['long_term'].update({k: v for k, v in long_term_defaults.items() if k not in cfg['long_term'] or k in ['max_mem_frames', 'min_mem_frames']})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
else:
|
| 552 |
+
cfg['long_term'] = long_term_defaults
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
# Convert to EasyDict for dot access
|
| 554 |
cfg = EasyDict(cfg)
|
| 555 |
# Inference core
|