MogensR committed on
Commit
b91eb11
·
1 Parent(s): a6e89c1

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +11 -4
models/loaders/matanyone_loader.py CHANGED
@@ -15,7 +15,7 @@
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
  - Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
17
  - Fix: Merge long_term overrides to preserve keys like count_usage
18
- - New: Force fp32 for T4 to avoid fp16 fallback bugs in tensor ops
19
  """
20
  from __future__ import annotations
21
  import os
@@ -85,7 +85,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
85
  x = x.float()
86
  if x.ndim == 5:
87
  x = x[:, 0] # -> 4D
88
- if x.ndim = 4:
89
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
90
  x = x.permute(0, 3, 1, 2).contiguous()
91
  elif x.ndim == 3:
@@ -214,10 +214,17 @@ def _info(name, v):
214
  # Precision selection
215
  # ---------------------------------------------------------------------------
216
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
217
- """Pick model weight dtype + autocast dtype (fp32 for stability on T4)."""
218
  if device != "cuda":
219
  return torch.float32, False, None
220
- return torch.float32, False, None # Force fp32 to avoid fp16 bugs
 
 
 
 
 
 
 
221
 
222
  # ---------------------------------------------------------------------------
223
  # Stateful Adapter around InferenceCore
 
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
  - Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
17
  - Fix: Merge long_term overrides to preserve keys like count_usage
18
+ - Fix: Syntax error in _to_bchw (= instead of ==)
19
  """
20
  from __future__ import annotations
21
  import os
 
85
  x = x.float()
86
  if x.ndim == 5:
87
  x = x[:, 0] # -> 4D
88
+ if x.ndim == 4:
89
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
90
  x = x.permute(0, 3, 1, 2).contiguous()
91
  elif x.ndim == 3:
 
214
  # Precision selection
215
  # ---------------------------------------------------------------------------
216
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
217
+ """Pick model weight dtype + autocast dtype (fp16>bf16>fp32) for T4 compatibility."""
218
  if device != "cuda":
219
  return torch.float32, False, None
220
+ cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
221
+ fp16_ok = cc[0] >= 7 # Volta+
222
+ bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
223
+ if fp16_ok:
224
+ return torch.float16, True, torch.float16 # Prefer fp16 for T4
225
+ if bf16_ok:
226
+ return torch.bfloat16, True, torch.bfloat16
227
+ return torch.float32, False, None
228
 
229
  # ---------------------------------------------------------------------------
230
  # Stateful Adapter around InferenceCore