Spaces:
Running
Running
Fix Unsloth startup: avoid pre-importing trl/transformers; mock vllm as real package modules.
Browse files- ultimate_sota_training.py +26 -21
ultimate_sota_training.py
CHANGED
|
@@ -119,16 +119,7 @@ def bootstrap_deps() -> None:
|
|
| 119 |
_pip(["uninstall", "-y", "torchao"], check=False)
|
| 120 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 121 |
|
| 122 |
-
|
| 123 |
-
import accelerate # noqa: F401
|
| 124 |
-
import transformers # noqa: F401
|
| 125 |
-
from trl import GRPOConfig as _BootstrapGRPOConfig # noqa: F401
|
| 126 |
-
|
| 127 |
-
_ = _BootstrapGRPOConfig
|
| 128 |
-
except Exception as e:
|
| 129 |
-
raise RuntimeError(
|
| 130 |
-
"Post-bootstrap import check failed. Adjust BOOTSTRAP_*_VERSION or SKIP_BOOTSTRAP=1."
|
| 131 |
-
) from e
|
| 132 |
|
| 133 |
|
| 134 |
bootstrap_deps()
|
|
@@ -157,21 +148,35 @@ importlib_metadata.version = _safe_pkg_version
|
|
| 157 |
import sys
|
| 158 |
import types
|
| 159 |
import importlib.machinery
|
| 160 |
-
from unittest.mock import MagicMock
|
| 161 |
|
| 162 |
def mock_vllm_hierarchy():
|
| 163 |
-
|
| 164 |
-
"vllm",
|
| 165 |
-
"vllm.distributed",
|
| 166 |
-
"vllm.distributed.device_communicators",
|
| 167 |
-
"vllm.distributed.device_communicators.pynccl",
|
| 168 |
"vllm.model_executor",
|
| 169 |
"vllm.model_executor.parallel_utils",
|
| 170 |
-
]
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
mock_vllm_hierarchy()
|
| 177 |
|
|
|
|
| 119 |
_pip(["uninstall", "-y", "torchao"], check=False)
|
| 120 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 121 |
|
| 122 |
+
# Do not import transformers/trl here. Unsloth must be imported first later.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
bootstrap_deps()
|
|
|
|
| 148 |
import sys
|
| 149 |
import types
|
| 150 |
import importlib.machinery
|
|
|
|
| 151 |
|
| 152 |
def mock_vllm_hierarchy():
|
| 153 |
+
pkg_names = [
|
| 154 |
+
"vllm",
|
| 155 |
+
"vllm.distributed",
|
| 156 |
+
"vllm.distributed.device_communicators",
|
|
|
|
| 157 |
"vllm.model_executor",
|
| 158 |
"vllm.model_executor.parallel_utils",
|
| 159 |
+
]
|
| 160 |
+
leaf_names = [
|
| 161 |
+
"vllm.distributed.device_communicators.pynccl",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
# Create proper package-like modules with submodule_search_locations so
|
| 165 |
+
# unsloth's import fixes that inspect package paths don't crash.
|
| 166 |
+
for m_name in pkg_names:
|
| 167 |
+
mod = types.ModuleType(m_name)
|
| 168 |
+
mod.__package__ = m_name
|
| 169 |
+
mod.__path__ = [f"/tmp/mock_{m_name.replace('.', '_')}"]
|
| 170 |
+
spec = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=True)
|
| 171 |
+
spec.submodule_search_locations = mod.__path__
|
| 172 |
+
mod.__spec__ = spec
|
| 173 |
+
sys.modules[m_name] = mod
|
| 174 |
+
|
| 175 |
+
for m_name in leaf_names:
|
| 176 |
+
mod = types.ModuleType(m_name)
|
| 177 |
+
mod.__package__ = m_name.rsplit(".", 1)[0]
|
| 178 |
+
mod.__spec__ = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=False)
|
| 179 |
+
sys.modules[m_name] = mod
|
| 180 |
|
| 181 |
mock_vllm_hierarchy()
|
| 182 |
|