Spaces:
Sleeping
Sleeping
AI Agent commited on
Commit ·
a9f1afc
1
Parent(s): cb7cf0f
HOTFIX: Remove global bfloat16 alias to unbreak torch.load checkpoint parsing
Browse files
app.py
CHANGED
|
@@ -11,17 +11,13 @@ import io
|
|
| 11 |
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
|
| 12 |
# dtype operations (linear, conv2d). We MUST patch unconditionally.
|
| 13 |
if torch.cuda.is_available():
|
| 14 |
-
#
|
| 15 |
-
_original_bf16 = torch.bfloat16
|
| 16 |
-
torch.bfloat16 = torch.float16
|
| 17 |
-
|
| 18 |
-
# 2. Intercept ALL autocast entry points to force float16
|
| 19 |
import torch.amp.autocast_mode
|
| 20 |
_OriginalAutocast = torch.amp.autocast_mode.autocast
|
| 21 |
class _Fp16Autocast(_OriginalAutocast):
|
| 22 |
def __init__(self, device_type, dtype=None, *args, **kwargs):
|
| 23 |
-
# Intercept
|
| 24 |
-
if dtype
|
| 25 |
dtype = torch.float16
|
| 26 |
super().__init__(device_type, dtype=dtype, *args, **kwargs)
|
| 27 |
|
|
|
|
| 11 |
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
|
| 12 |
# dtype operations (linear, conv2d). We MUST patch unconditionally.
|
| 13 |
if torch.cuda.is_available():
|
| 14 |
+
# Intercept ALL autocast entry points to force float16
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import torch.amp.autocast_mode
|
| 16 |
_OriginalAutocast = torch.amp.autocast_mode.autocast
|
| 17 |
class _Fp16Autocast(_OriginalAutocast):
|
| 18 |
def __init__(self, device_type, dtype=None, *args, **kwargs):
|
| 19 |
+
# Intercept Meta's bfloat16 request and force float16 for Turing support
|
| 20 |
+
if dtype == torch.bfloat16:
|
| 21 |
dtype = torch.float16
|
| 22 |
super().__init__(device_type, dtype=dtype, *args, **kwargs)
|
| 23 |
|