Siddh12334 commited on
Commit
a262689
Β·
1 Parent(s): 98317c2

fix: patch torchao dtype imports for unsloth

Browse files

Allow the training Space to import Unsloth with stable CUDA PyTorch by shimming optional TorchAO dtype probes before Transformers loads.

Made-with: Cursor

Files changed (2) hide show
  1. Dockerfile.train +1 -1
  2. training/space_runner.py +34 -6
Dockerfile.train CHANGED
@@ -4,7 +4,7 @@ WORKDIR /app
4
 
5
  RUN apt-get update && apt-get install -y git build-essential && rm -rf /var/lib/apt/lists/*
6
 
7
- # Stable cu121 torch β€” register_constant shim applied at runtime for torchao compat
8
  RUN pip install --no-cache-dir \
9
  torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \
10
  pip install --no-cache-dir \
 
4
 
5
  RUN apt-get update && apt-get install -y git build-essential && rm -rf /var/lib/apt/lists/*
6
 
7
+ # Stable cu121 torch β€” TorchAO import shims are applied at runtime before Unsloth loads.
8
  RUN pip install --no-cache-dir \
9
  torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \
10
  pip install --no-cache-dir \
training/space_runner.py CHANGED
@@ -57,13 +57,14 @@ def _attach_log_capture():
57
  lg.setLevel(logging.INFO)
58
 
59
 
60
- def _run_training():
61
- global _training_status
62
- _training_status = "running"
63
- _append_log("Thread started β€” patching torch/_pytree then importing unsloth...")
 
 
 
64
 
65
- # torchao uses register_constant which only exists in torch nightly.
66
- # We don't use torchao quantization, so a no-op shim is safe.
67
  try:
68
  import torch.utils._pytree as _pytree
69
  if not hasattr(_pytree, "register_constant"):
@@ -72,6 +73,33 @@ def _run_training():
72
  except Exception as _e:
73
  _append_log(f"Warning: could not patch _pytree: {_e}")
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  try:
76
  import unsloth # noqa: F401
77
  _append_log(f"βœ… unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
 
57
  lg.setLevel(logging.INFO)
58
 
59
 
60
+ def _patch_torchao_import_compat():
61
+ """Patch TorchAO import probes that expect newer/nightly torch symbols.
62
+
63
+ Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
64
+ quantization. These aliases are only to let optional TorchAO modules import.
65
+ """
66
+ import torch
67
 
 
 
68
  try:
69
  import torch.utils._pytree as _pytree
70
  if not hasattr(_pytree, "register_constant"):
 
73
  except Exception as _e:
74
  _append_log(f"Warning: could not patch _pytree: {_e}")
75
 
76
+ patched_dtypes = []
77
+ dtype_aliases = {
78
+ **{f"int{i}": torch.int8 for i in range(1, 8)},
79
+ **{f"uint{i}": torch.uint8 for i in range(1, 8)},
80
+ }
81
+ for name, fallback in dtype_aliases.items():
82
+ if not hasattr(torch, name):
83
+ setattr(torch, name, fallback)
84
+ patched_dtypes.append(name)
85
+
86
+ if patched_dtypes:
87
+ _append_log(
88
+ "Applied torch dtype shims for torchao import: "
89
+ + ", ".join(patched_dtypes)
90
+ )
91
+
92
+
93
+ def _run_training():
94
+ global _training_status
95
+ _training_status = "running"
96
+ _append_log("Thread started β€” patching torch/torchao compat then importing unsloth...")
97
+
98
+ try:
99
+ _patch_torchao_import_compat()
100
+ except Exception as _e:
101
+ _append_log(f"Warning: could not patch torchao import compat: {_e}")
102
+
103
  try:
104
  import unsloth # noqa: F401
105
  _append_log(f"βœ… unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")