MogensR committed on
Commit
880475f
·
1 Parent(s): a28f292

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +15 -24
models/loaders/matanyone_loader.py CHANGED
@@ -15,6 +15,7 @@
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
  - Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
17
  - Fix: Merge long_term overrides to preserve keys like count_usage
 
18
  """
19
  from __future__ import annotations
20
  import os
@@ -84,7 +85,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
84
  x = x.float()
85
  if x.ndim == 5:
86
  x = x[:, 0] # -> 4D
87
- if x.ndim == 4:
88
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
89
  x = x.permute(0, 3, 1, 2).contiguous()
90
  elif x.ndim == 3:
@@ -213,17 +214,10 @@ def _info(name, v):
213
  # Precision selection
214
  # ---------------------------------------------------------------------------
215
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
216
- """Pick model weight dtype + autocast dtype (fp16>bf16>fp32) for T4 compatibility."""
217
  if device != "cuda":
218
  return torch.float32, False, None
219
- cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
220
- fp16_ok = cc[0] >= 7 # Volta+
221
- bf16_ok = cc[0] >= 8 and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() # Ampere+ strict
222
- if fp16_ok:
223
- return torch.float16, True, torch.float16 # Prefer fp16 for T4
224
- if bf16_ok:
225
- return torch.bfloat16, True, torch.bfloat16
226
- return torch.float32, False, None
227
 
228
  # ---------------------------------------------------------------------------
229
  # Stateful Adapter around InferenceCore
@@ -543,22 +537,19 @@ def load(self) -> Optional[Any]:
543
  'use_long_term': False,
544
  }
545
  cfg.update(overrides)
546
- # Merge long_term overrides without overwriting whole dict
 
 
 
 
 
 
 
 
547
  if 'long_term' in cfg:
548
- cfg['long_term'].update({
549
- 'max_mem_frames': 0,
550
- 'min_mem_frames': 0,
551
- 'count_usage': False, # Disable since memory off
552
- })
553
  else:
554
- cfg['long_term'] = {
555
- 'buffer_tokens': 2000,
556
- 'count_usage': False,
557
- 'max_mem_frames': 0,
558
- 'max_num_tokens': 10000,
559
- 'min_mem_frames': 0,
560
- 'num_prototypes': 128
561
- }
562
  # Convert to EasyDict for dot access
563
  cfg = EasyDict(cfg)
564
  # Inference core
 
15
  - New: Full default cfg from official config.json to fix 'mem_every' issues
16
  - Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
17
  - Fix: Merge long_term overrides to preserve keys like count_usage
18
+ - New: Force fp32 for T4 to avoid fp16 fallback bugs in tensor ops
19
  """
20
  from __future__ import annotations
21
  import os
 
85
  x = x.float()
86
  if x.ndim == 5:
87
  x = x[:, 0] # -> 4D
88
  + if x.ndim == 4:
89
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
90
  x = x.permute(0, 3, 1, 2).contiguous()
91
  elif x.ndim == 3:
 
214
  # Precision selection
215
  # ---------------------------------------------------------------------------
216
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
217
+ """Pick model weight dtype + autocast dtype (fp32 for stability on T4)."""
218
  if device != "cuda":
219
  return torch.float32, False, None
220
+ return torch.float32, False, None # Force fp32 to avoid fp16 bugs
 
 
 
 
 
 
 
221
 
222
  # ---------------------------------------------------------------------------
223
  # Stateful Adapter around InferenceCore
 
537
  'use_long_term': False,
538
  }
539
  cfg.update(overrides)
540
+ # Merge long_term overrides without removing keys
541
+ long_term_defaults = {
542
+ "buffer_tokens": 2000,
543
+ "count_usage": True,
544
+ "max_mem_frames": 0,
545
+ "max_num_tokens": 10000,
546
+ "min_mem_frames": 0,
547
+ "num_prototypes": 128
548
+ }
549
  if 'long_term' in cfg:
550
+ cfg['long_term'].update({k: v for k, v in long_term_defaults.items() if k not in cfg['long_term'] or k in ['max_mem_frames', 'min_mem_frames']})
 
 
 
 
551
  else:
552
+ cfg['long_term'] = long_term_defaults
 
 
 
 
 
 
 
553
  # Convert to EasyDict for dot access
554
  cfg = EasyDict(cfg)
555
  # Inference core