Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
- Fix: Merge long_term overrides to preserve keys like count_usage
|
| 18 |
-
-
|
| 19 |
"""
|
| 20 |
from __future__ import annotations
|
| 21 |
import os
|
|
@@ -85,7 +85,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
|
|
| 85 |
x = x.float()
|
| 86 |
if x.ndim == 5:
|
| 87 |
x = x[:, 0] # -> 4D
|
| 88 |
-
if x.ndim
|
| 89 |
if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
|
| 90 |
x = x.permute(0, 3, 1, 2).contiguous()
|
| 91 |
elif x.ndim == 3:
|
|
@@ -214,10 +214,17 @@ def _info(name, v):
|
|
| 214 |
# Precision selection
|
| 215 |
# ---------------------------------------------------------------------------
|
| 216 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
| 217 |
-
"""Pick model weight dtype + autocast dtype (fp32 for
|
| 218 |
if device != "cuda":
|
| 219 |
return torch.float32, False, None
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
# ---------------------------------------------------------------------------
|
| 223 |
# Stateful Adapter around InferenceCore
|
|
|
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
- Fix: Merge long_term overrides to preserve keys like count_usage
|
| 18 |
+
- Fix: Syntax error in _to_bchw (use `==` instead of `=` in the `x.ndim` check)
|
| 19 |
"""
|
| 20 |
from __future__ import annotations
|
| 21 |
import os
|
|
|
|
| 85 |
x = x.float()
|
| 86 |
if x.ndim == 5:
|
| 87 |
x = x[:, 0] # -> 4D
|
| 88 |
+
if x.ndim == 4:
|
| 89 |
if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
|
| 90 |
x = x.permute(0, 3, 1, 2).contiguous()
|
| 91 |
elif x.ndim == 3:
|
|
|
|
| 214 |
# Precision selection
|
| 215 |
# ---------------------------------------------------------------------------
|
| 216 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
| 217 |
+
"""Pick model weight dtype + autocast dtype (fp16>bf16>fp32) for T4 compatibility."""
|
| 218 |
if device != "cuda":
|
| 219 |
return torch.float32, False, None
|
| 220 |
+
cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
|
| 221 |
+
fp16_ok = cc[0] >= 7 # Volta+
|
| 222 |
+
bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
|
| 223 |
+
if fp16_ok:
|
| 224 |
+
return torch.float16, True, torch.float16 # Prefer fp16 for T4
|
| 225 |
+
if bf16_ok:
|
| 226 |
+
return torch.bfloat16, True, torch.bfloat16
|
| 227 |
+
return torch.float32, False, None
|
| 228 |
|
| 229 |
# ---------------------------------------------------------------------------
|
| 230 |
# Stateful Adapter around InferenceCore
|