AI Agent commited on
Commit
32691c0
·
1 Parent(s): a998b70

Cast model explicitly to bfloat16 on T4 to override AMP dtype mismatch

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -31,17 +31,6 @@ if not torch.cuda.is_available():
31
  return __orig_fn(*args, **kwargs)
32
  setattr(torch, name, patched_fn)
33
 
34
- # Intercept Meta's hardcoded BFloat16 autocast (T4 Turing GPUs don't support BFloat16 hardware math)
35
- original_autocast = torch.autocast
36
- class PatchedAutocast(original_autocast):
37
- def __init__(self, device_type, dtype=None, *args, **kwargs):
38
- if dtype == torch.bfloat16 and torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
39
- dtype = torch.float16 # Fallback to fp16, supported perfectly by T4 Turing NVidia cards
40
- if device_type == 'cuda' and not torch.cuda.is_available():
41
- device_type = 'cpu'
42
- super().__init__(device_type, dtype, *args, **kwargs)
43
- torch.autocast = PatchedAutocast
44
-
45
  # ── SAM 3 Imports ────────────────────────────────────────────────
46
  try:
47
  from sam3.model_builder import build_sam3_image_model
@@ -75,8 +64,13 @@ if model_installed:
75
 
76
  model.load_state_dict(image_state_dict, strict=False)
77
  model.to(device)
78
- if not torch.cuda.is_available():
 
 
 
 
79
  model.to(torch.float32) # Force upcast from checkpoint's native bfloat16 to float32 for CPU inference
 
80
  processor = Sam3Processor(model)
81
  if not torch.cuda.is_available():
82
  processor.device = "cpu"
 
31
  return __orig_fn(*args, **kwargs)
32
  setattr(torch, name, patched_fn)
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  # ── SAM 3 Imports ────────────────────────────────────────────────
35
  try:
36
  from sam3.model_builder import build_sam3_image_model
 
64
 
65
  model.load_state_dict(image_state_dict, strict=False)
66
  model.to(device)
67
+
68
+ # BFloat16 alignment for T4 / Turing GPUs
69
+ if torch.cuda.is_available():
70
+ model.to(torch.bfloat16)
71
+ else:
72
  model.to(torch.float32) # Force upcast from checkpoint's native bfloat16 to float32 for CPU inference
73
+
74
  processor = Sam3Processor(model)
75
  if not torch.cuda.is_available():
76
  processor.device = "cpu"