AI Agent commited on
Commit
a9f1afc
·
1 Parent(s): cb7cf0f

HOTFIX: Remove global bfloat16 alias to unbreak torch.load checkpoint parsing

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -11,17 +11,13 @@ import io
11
  # can *emulate* bfloat16 in software, but the actual kernels crash on mixed
12
  # dtype operations (linear, conv2d). We MUST patch unconditionally.
13
  if torch.cuda.is_available():
14
- # 1. Globally alias bfloat16 float16 so all future lookups resolve to fp16
15
- _original_bf16 = torch.bfloat16
16
- torch.bfloat16 = torch.float16
17
-
18
- # 2. Intercept ALL autocast entry points to force float16
19
  import torch.amp.autocast_mode
20
  _OriginalAutocast = torch.amp.autocast_mode.autocast
21
  class _Fp16Autocast(_OriginalAutocast):
22
  def __init__(self, device_type, dtype=None, *args, **kwargs):
23
- # Intercept any bfloat16 request (original C enum or aliased)
24
- if dtype is not None and dtype in (_original_bf16, torch.float16):
25
  dtype = torch.float16
26
  super().__init__(device_type, dtype=dtype, *args, **kwargs)
27
 
 
11
  # can *emulate* bfloat16 in software, but the actual kernels crash on mixed
12
  # dtype operations (linear, conv2d). We MUST patch unconditionally.
13
  if torch.cuda.is_available():
14
+ # Intercept ALL autocast entry points to force float16
 
 
 
 
15
  import torch.amp.autocast_mode
16
  _OriginalAutocast = torch.amp.autocast_mode.autocast
17
  class _Fp16Autocast(_OriginalAutocast):
18
  def __init__(self, device_type, dtype=None, *args, **kwargs):
19
+ # Intercept Meta's bfloat16 request and force float16 for Turing support
20
+ if dtype == torch.bfloat16:
21
  dtype = torch.float16
22
  super().__init__(device_type, dtype=dtype, *args, **kwargs)
23