AI Agent commited on
Commit
cb7cf0f
·
1 Parent(s): 4b23049

CRITICAL FIX: Remove is_bf16_supported gate - T4 reports True but crashes on bf16 ops. Patches now unconditional + model cast to fp16

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -6,27 +6,31 @@ from PIL import Image
6
  import os
7
  import io
8
 
9
- # ── Native Runtime BFloat16 Nullification for T4 Turing GPUs ────
10
- # Hugging Face containers lock `site-packages` on boot, dropping our dynamic writer.
11
- # Instead, we directly alias PyTorch's global datatypes and AMP engine modes to silently
12
- # intercept Meta's `bfloat16` requests anywhere in the active memory loop.
13
- if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
14
- # 1. Bruteforce global datatype alias mapping
 
15
  torch.bfloat16 = torch.float16
16
 
17
- # 2. Intercept inner PyTorch AMP decorators
18
  import torch.amp.autocast_mode
19
- original_amp = torch.amp.autocast_mode.autocast
20
- class PatchedAutocast(original_amp):
21
  def __init__(self, device_type, dtype=None, *args, **kwargs):
22
- if dtype == torch.bfloat16 or dtype == torch.float16:
23
- dtype = torch.float16 # Always force Turing FP16
24
- super().__init__(device_type, dtype, *args, **kwargs)
25
-
26
- torch.autocast = PatchedAutocast
27
- torch.amp.autocast_mode.autocast = PatchedAutocast
 
28
  if hasattr(torch.amp, 'autocast'):
29
- torch.amp.autocast = PatchedAutocast
 
 
30
 
31
  # ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
32
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
@@ -86,7 +90,11 @@ if model_installed:
86
 
87
  model.load_state_dict(image_state_dict, strict=False)
88
  model.to(device)
89
- model.to(torch.float32) # Maintain standard Floats parameters; the patched Float16 autocast will natively handle precision math.
 
 
 
 
90
 
91
  processor = Sam3Processor(model)
92
  if not torch.cuda.is_available():
 
6
  import os
7
  import io
8
 
9
+ # ── UNCONDITIONAL BFloat16 Float16 Patch for T4 Turing GPUs ────
10
+ # CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
11
+ # can *emulate* bfloat16 in software, but the actual kernels crash on mixed
12
+ # dtype operations (linear, conv2d). We MUST patch unconditionally.
13
+ if torch.cuda.is_available():
14
+ # 1. Globally alias bfloat16 float16 so all future lookups resolve to fp16
15
+ _original_bf16 = torch.bfloat16
16
  torch.bfloat16 = torch.float16
17
 
18
+ # 2. Intercept ALL autocast entry points to force float16
19
  import torch.amp.autocast_mode
20
+ _OriginalAutocast = torch.amp.autocast_mode.autocast
21
+ class _Fp16Autocast(_OriginalAutocast):
22
  def __init__(self, device_type, dtype=None, *args, **kwargs):
23
+ # Intercept any bfloat16 request (original C enum or aliased)
24
+ if dtype is not None and dtype in (_original_bf16, torch.float16):
25
+ dtype = torch.float16
26
+ super().__init__(device_type, dtype=dtype, *args, **kwargs)
27
+
28
+ torch.autocast = _Fp16Autocast
29
+ torch.amp.autocast_mode.autocast = _Fp16Autocast
30
  if hasattr(torch.amp, 'autocast'):
31
+ torch.amp.autocast = _Fp16Autocast
32
+ if hasattr(torch.cuda.amp, 'autocast'):
33
+ torch.cuda.amp.autocast = _Fp16Autocast
34
 
35
  # ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
36
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
 
90
 
91
  model.load_state_dict(image_state_dict, strict=False)
92
  model.to(device)
93
+ # Cast to float16 on GPU (matches our patched autocast dtype) or float32 on CPU
94
+ if torch.cuda.is_available():
95
+ model.to(torch.float16)
96
+ else:
97
+ model.to(torch.float32)
98
 
99
  processor = Sam3Processor(model)
100
  if not torch.cuda.is_available():