"""Kaggle single-file entrypoint for training Manthan-T1.
Copy-paste this file into Kaggle (repo root) and run:
- It optionally installs a compatible stack (avoids common torch/torchaudio mismatches).
- It guarantees `trust_remote_code=True` model loading by ensuring a `config.json` exists.
- Then runs Stage 1 (projector pretrain) and Stage 2 (instruction finetune) via
`scripts/train_unsloth_kaggle.py`.
Design goals:
- Minimal: no notebook-specific APIs.
- Robust: patches HF repo config if missing; sets cache dirs.
Environment variables (optional):
- MANTHAN_MODEL_ID (default: "zyxcisss/Manthan-T1")
- HF_HOME (default: /kaggle/working/hf_home)
- HF_TOKEN (if private repo)
- INSTALL_DEPS=1 to run pip installs (default: 0)
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent


def _run(cmd: list[str], *, env: dict[str, str] | None = None) -> None:
    """Echo a command, then run it, raising CalledProcessError on failure."""
    print("\n$", " ".join(cmd), flush=True)
    subprocess.check_call(cmd, env=env)


def _maybe_install_deps() -> None:
    """Optionally install a pinned dependency stack.

    Kaggle images ship a preinstalled CUDA/PyTorch stack; mismatched torch and
    torchaudio versions are the main source of hard errors. This function is
    intentionally conservative: it only runs if INSTALL_DEPS=1.
    """
if os.environ.get("INSTALL_DEPS", "0") != "1":
print("INSTALL_DEPS != 1; skipping pip installs.")
return
    # Pin a coherent torch/torchvision/torchaudio trio (these three versions
    # were released together). Kaggle runtimes ship CUDA 12.x; if your runtime
    # has a different CUDA, adjust these pins.
pins = [
"torch==2.8.0",
"torchvision==0.23.0",
"torchaudio==2.8.0",
"transformers>=4.46.0",
"accelerate>=0.34.0",
"datasets>=2.20.0",
"safetensors>=0.4.3",
"pillow>=10.0.0",
"tyro>=0.8.0",
"trl>=0.12.0",
# Optional:
"sentencepiece",
"protobuf",
]
    # Upgrade pip itself first.
_run([sys.executable, "-m", "pip", "install", "-U", "pip"])
# Install. We avoid extra-index URLs here; Kaggle generally resolves CUDA wheels.
_run([sys.executable, "-m", "pip", "install", "-U"] + pins)
# Try installing unsloth last; it may pin/reinstall torch deps.
_run([sys.executable, "-m", "pip", "install", "-U", "unsloth", "unsloth_zoo", "xformers"])
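    # Post-install sanity check (illustrative; run it in a fresh cell so the
    # newly installed torch is the one that gets imported):
    #   python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"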


def _setup_hf_env() -> dict[str, str]:
    """Point Hugging Face caches at /kaggle/working and return the environment.

    Mutates os.environ (not just a copy) so the settings apply both to
    in-process imports (the config sanity check) and to training subprocesses.
    """
    hf_home = os.environ.get("HF_HOME") or "/kaggle/working/hf_home"
    os.environ["HF_HOME"] = hf_home
    # Faster hub downloads; requires the hf_transfer package to be installed.
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    # Keep common caches inside /kaggle/working so they live with the outputs.
    os.environ.setdefault("TRANSFORMERS_CACHE", str(Path(hf_home) / "transformers"))
    os.environ.setdefault("HF_DATASETS_CACHE", str(Path(hf_home) / "datasets"))
    Path(os.environ["TRANSFORMERS_CACHE"]).mkdir(parents=True, exist_ok=True)
    Path(os.environ["HF_DATASETS_CACHE"]).mkdir(parents=True, exist_ok=True)
    return os.environ.copy()
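# Resulting cache layout (illustrative):
#   /kaggle/working/hf_home/
#     transformers/   <- model weights/config cache
#     datasets/       <- datasets cache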


def _ensure_local_config() -> None:
    """Fix a common failure mode: a missing or invalid config.json on the HF repo.

    If the Kaggle environment has this repo cloned locally, Transformers will
    load from the local path when you pass it, which is safer than relying on
    remote hub metadata. We make sure `./config.json` exists and has:

    - model_type: "manthan_t1"
    - auto_map: pointing at the remote-code modules

    This makes `AutoConfig.from_pretrained(local_path, trust_remote_code=True)` work.
    """
cfg_path = REPO_ROOT / "config.json"
if cfg_path.exists():
try:
cfg = json.loads(cfg_path.read_text())
except Exception:
cfg = {}
else:
cfg = {}
# If the repo already has a config, only patch missing fields.
cfg.setdefault("model_type", "manthan_t1")
cfg.setdefault("architectures", ["ManthanForCausalLM"])
    cfg.setdefault(
        "auto_map",
        {
            # HF auto_map entries use dotted "module.Class" references,
            # not "path/to/file.py:Class".
            "AutoConfig": "manthan_t1.configuration_manthan.ManthanConfig",
            "AutoModelForCausalLM": "manthan_t1.modeling_manthan.ManthanForCausalLM",
        },
    )
# Helpful defaults for stubs.
cfg.setdefault("torch_dtype", "float16")
cfg_path.write_text(json.dumps(cfg, indent=2) + "\n")
print(f"Ensured local config at: {cfg_path}")


def _sanity_load_config() -> None:
    """Fail fast if the local config cannot be loaded with trust_remote_code."""
    # Lazy import: transformers is only needed here. The training script runs
    # in a subprocess and manages its own unsloth-before-transformers import order.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(str(REPO_ROOT), trust_remote_code=True)
    mt = getattr(cfg, "model_type", None)
    print("Loaded config model_type:", mt)
    if mt != "manthan_t1":
        raise RuntimeError(f"Unexpected model_type={mt!r}; expected 'manthan_t1'.")
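# Equivalent manual check from a notebook cell (illustrative; the path is an
# assumption, point it at your local clone):
#   from transformers import AutoConfig
#   AutoConfig.from_pretrained("/kaggle/working/Manthan-T1", trust_remote_code=True).model_type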


def _run_stage(env: dict[str, str], stage: int, extra: list[str] | None = None) -> None:
    """Run one training stage via scripts/train_unsloth_kaggle.py in a subprocess."""
    extra = extra or []
script = REPO_ROOT / "scripts" / "train_unsloth_kaggle.py"
if not script.exists():
raise FileNotFoundError(f"Missing {script}. Did you clone the repo correctly?")
_run(
[
sys.executable,
str(script),
"--stage",
str(stage),
"--model_id",
str(REPO_ROOT), # load from local to avoid HF config issues
]
+ extra,
env=env,
)
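# `extra` forwards additional CLI flags to the stage script, e.g.
# (hypothetical flag name; see scripts/train_unsloth_kaggle.py for real options):
#   _run_stage(env, 2, ["--max_steps", "100"])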


def main() -> int:
model_id = os.environ.get("MANTHAN_MODEL_ID", "zyxcisss/Manthan-T1")
print("Manthan Kaggle trainer")
print("Repo root:", REPO_ROOT)
print("Model ID (for reference):", model_id)
_maybe_install_deps()
env = _setup_hf_env()
    # Patch the local config so Transformers can recognize our custom model.
    _ensure_local_config()
    # Quick fail-fast: the config should load via trust_remote_code.
    _sanity_load_config()
print("\n==== Stage 1: projector alignment/pretrain ====")
_run_stage(env, 1)
print("\n==== Stage 2: instruction finetune ====")
_run_stage(env, 2)
print("\nDone.")
return 0


if __name__ == "__main__":
    raise SystemExit(main())