Spaces:
Running on Zero
Running on Zero
| """ | |
| Runtime patches for third-party libraries used in GRPO training. | |
| Python automatically imports `sitecustomize` at interpreter startup (via `site`), | |
| including in multiprocessing "spawn" workers. Keep patches minimal and gated. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from typing import Any | |
| def _patch_art_trainconfig_lr_alias() -> None: | |
| """ | |
| ART's UnslothService warmup does: | |
| config.model_copy(update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0}) | |
| but `art.types.TrainConfig` uses the field name `learning_rate`, not `lr`. | |
| Without this patch, the warmup step can accidentally run at the *full* | |
| learning rate with `beta=0`, causing large KL/grad spikes on the first | |
| trainable batch after a service restart. | |
| """ | |
| try: | |
| from art.types import TrainConfig | |
| except Exception: | |
| return | |
| original_model_copy = TrainConfig.model_copy | |
| # Avoid double-patching (important for interactive sessions / reloads). | |
| if getattr(original_model_copy, "__linalgzero_patched__", False): | |
| return | |
| def model_copy(self: Any, *, update: dict[str, Any] | None = None, deep: bool = False): | |
| if isinstance(update, dict): | |
| patched = dict(update) | |
| if "lr" in patched and "learning_rate" not in patched: | |
| # Keep both keys: some call sites may use `lr`, but ART TrainConfig uses | |
| # `learning_rate`. Unknown keys are ignored by pydantic, so retaining | |
| # `lr` is harmless and keeps the original update dict semantics. | |
| patched["learning_rate"] = patched["lr"] | |
| update = patched | |
| return original_model_copy(self, update=update, deep=deep) | |
| model_copy.__linalgzero_patched__ = True # type: ignore[attr-defined] | |
| TrainConfig.model_copy = model_copy # type: ignore[assignment] | |
def _install_art_unsloth_kl_guard_patch() -> None:
    """
    Install the optional KL safety-guard patch for ART's `art.unsloth.train`.

    The guard lives in this repository (no site-packages edits). Its `install`
    entry point is expected to register an import hook so the patch also reaches
    subprocesses started by ART/Unsloth.

    If the guard module, or its `install` callable, cannot be imported, this
    function silently does nothing. Errors raised by `install()` itself are not
    suppressed and propagate to the caller.
    """
    try:
        from linalg_zero.grpo.art_unsloth_kl_guard import install as install_kl_guard
    except Exception:
        # The guard is optional; skip it quietly when it is unavailable.
        return None
    else:
        install_kl_guard()
| if os.getenv("LINALGZERO_DISABLE_ART_PATCHES") != "1": | |
| _patch_art_trainconfig_lr_alias() | |