md896 commited on
Commit
d21de11
·
1 Parent(s): 1fdba13

Fix Unsloth startup: avoid pre-importing trl/transformers; mock vllm as real package modules.

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +26 -21
ultimate_sota_training.py CHANGED
@@ -119,16 +119,7 @@ def bootstrap_deps() -> None:
119
  _pip(["uninstall", "-y", "torchao"], check=False)
120
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
121
 
122
- try:
123
- import accelerate # noqa: F401
124
- import transformers # noqa: F401
125
- from trl import GRPOConfig as _BootstrapGRPOConfig # noqa: F401
126
-
127
- _ = _BootstrapGRPOConfig
128
- except Exception as e:
129
- raise RuntimeError(
130
- "Post-bootstrap import check failed. Adjust BOOTSTRAP_*_VERSION or SKIP_BOOTSTRAP=1."
131
- ) from e
132
 
133
 
134
  bootstrap_deps()
@@ -157,21 +148,35 @@ importlib_metadata.version = _safe_pkg_version
157
  import sys
158
  import types
159
  import importlib.machinery
160
- from unittest.mock import MagicMock
161
 
162
  def mock_vllm_hierarchy():
163
- for m_name in [
164
- "vllm",
165
- "vllm.distributed",
166
- "vllm.distributed.device_communicators",
167
- "vllm.distributed.device_communicators.pynccl",
168
  "vllm.model_executor",
169
  "vllm.model_executor.parallel_utils",
170
- ]:
171
- mock_m = MagicMock(spec=types.ModuleType)
172
- mock_m.__name__ = m_name
173
- mock_m.__spec__ = importlib.machinery.ModuleSpec(m_name, None)
174
- sys.modules[m_name] = mock_m
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  mock_vllm_hierarchy()
177
 
 
119
  _pip(["uninstall", "-y", "torchao"], check=False)
120
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
121
 
122
+ # Do not import transformers/trl here. Unsloth must be imported first later.
 
 
 
 
 
 
 
 
 
123
 
124
 
125
  bootstrap_deps()
 
148
  import sys
149
  import types
150
  import importlib.machinery
 
151
 
152
  def mock_vllm_hierarchy():
153
+ pkg_names = [
154
+ "vllm",
155
+ "vllm.distributed",
156
+ "vllm.distributed.device_communicators",
 
157
  "vllm.model_executor",
158
  "vllm.model_executor.parallel_utils",
159
+ ]
160
+ leaf_names = [
161
+ "vllm.distributed.device_communicators.pynccl",
162
+ ]
163
+
164
+ # Create proper package-like modules with submodule_search_locations so
165
+ # unsloth's import fixes that inspect package paths don't crash.
166
+ for m_name in pkg_names:
167
+ mod = types.ModuleType(m_name)
168
+ mod.__package__ = m_name
169
+ mod.__path__ = [f"/tmp/mock_{m_name.replace('.', '_')}"]
170
+ spec = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=True)
171
+ spec.submodule_search_locations = mod.__path__
172
+ mod.__spec__ = spec
173
+ sys.modules[m_name] = mod
174
+
175
+ for m_name in leaf_names:
176
+ mod = types.ModuleType(m_name)
177
+ mod.__package__ = m_name.rsplit(".", 1)[0]
178
+ mod.__spec__ = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=False)
179
+ sys.modules[m_name] = mod
180
 
181
  mock_vllm_hierarchy()
182