Spaces:
Sleeping
Sleeping
Commit ·
a262689
1
Parent(s): 98317c2
fix: patch torchao dtype imports for unsloth
Browse files
Allow the training Space to import Unsloth with stable CUDA PyTorch by shimming optional TorchAO dtype probes before Transformers loads.
Made-with: Cursor
- Dockerfile.train +1 -1
- training/space_runner.py +34 -6
Dockerfile.train
CHANGED
|
@@ -4,7 +4,7 @@ WORKDIR /app
|
|
| 4 |
|
| 5 |
RUN apt-get update && apt-get install -y git build-essential && rm -rf /var/lib/apt/lists/*
|
| 6 |
|
| 7 |
-
# Stable cu121 torch —
|
| 8 |
RUN pip install --no-cache-dir \
|
| 9 |
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \
|
| 10 |
pip install --no-cache-dir \
|
|
|
|
| 4 |
|
| 5 |
RUN apt-get update && apt-get install -y git build-essential && rm -rf /var/lib/apt/lists/*
|
| 6 |
|
| 7 |
+
# Stable cu121 torch — TorchAO import shims are applied at runtime before Unsloth loads.
|
| 8 |
RUN pip install --no-cache-dir \
|
| 9 |
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \
|
| 10 |
pip install --no-cache-dir \
|
training/space_runner.py
CHANGED
|
@@ -57,13 +57,14 @@ def _attach_log_capture():
|
|
| 57 |
lg.setLevel(logging.INFO)
|
| 58 |
|
| 59 |
|
| 60 |
-
def
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
# torchao uses register_constant which only exists in torch nightly.
|
| 66 |
-
# We don't use torchao quantization, so a no-op shim is safe.
|
| 67 |
try:
|
| 68 |
import torch.utils._pytree as _pytree
|
| 69 |
if not hasattr(_pytree, "register_constant"):
|
|
@@ -72,6 +73,33 @@ def _run_training():
|
|
| 72 |
except Exception as _e:
|
| 73 |
_append_log(f"Warning: could not patch _pytree: {_e}")
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
try:
|
| 76 |
import unsloth # noqa: F401
|
| 77 |
_append_log(f"✅ unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
|
|
|
|
| 57 |
lg.setLevel(logging.INFO)
|
| 58 |
|
| 59 |
|
| 60 |
+
def _patch_torchao_import_compat():
|
| 61 |
+
"""Patch TorchAO import probes that expect newer/nightly torch symbols.
|
| 62 |
+
|
| 63 |
+
Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
|
| 64 |
+
quantization. These aliases are only to let optional TorchAO modules import.
|
| 65 |
+
"""
|
| 66 |
+
import torch
|
| 67 |
|
|
|
|
|
|
|
| 68 |
try:
|
| 69 |
import torch.utils._pytree as _pytree
|
| 70 |
if not hasattr(_pytree, "register_constant"):
|
|
|
|
| 73 |
except Exception as _e:
|
| 74 |
_append_log(f"Warning: could not patch _pytree: {_e}")
|
| 75 |
|
| 76 |
+
patched_dtypes = []
|
| 77 |
+
dtype_aliases = {
|
| 78 |
+
**{f"int{i}": torch.int8 for i in range(1, 8)},
|
| 79 |
+
**{f"uint{i}": torch.uint8 for i in range(1, 8)},
|
| 80 |
+
}
|
| 81 |
+
for name, fallback in dtype_aliases.items():
|
| 82 |
+
if not hasattr(torch, name):
|
| 83 |
+
setattr(torch, name, fallback)
|
| 84 |
+
patched_dtypes.append(name)
|
| 85 |
+
|
| 86 |
+
if patched_dtypes:
|
| 87 |
+
_append_log(
|
| 88 |
+
"Applied torch dtype shims for torchao import: "
|
| 89 |
+
+ ", ".join(patched_dtypes)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _run_training():
|
| 94 |
+
global _training_status
|
| 95 |
+
_training_status = "running"
|
| 96 |
+
_append_log("Thread started — patching torch/torchao compat then importing unsloth...")
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
_patch_torchao_import_compat()
|
| 100 |
+
except Exception as _e:
|
| 101 |
+
_append_log(f"Warning: could not patch torchao import compat: {_e}")
|
| 102 |
+
|
| 103 |
try:
|
| 104 |
import unsloth # noqa: F401
|
| 105 |
_append_log(f"✅ unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
|