Atah Alam committed on
Commit
d3df1cb
·
1 Parent(s): 7f7a72e

Add Kaggle root trainer + fix Unsloth import order

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. config.json +11 -0
  3. kaggle_train.py +195 -0
  4. scripts/train_unsloth_kaggle.py +3 -2
.gitignore CHANGED
@@ -3,6 +3,7 @@ venv/
3
  __pycache__/
4
  *.pyc
5
  .DS_Store
 
6
 
7
  # python tooling
8
  .pytest_cache/
 
3
  __pycache__/
4
  *.pyc
5
  .DS_Store
6
+ .worktrees/
7
 
8
  # python tooling
9
  .pytest_cache/
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "manthan_t1",
3
+ "architectures": [
4
+ "ManthanForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "manthan_t1.configuration_manthan.ManthanConfig",
8
+ "AutoModelForCausalLM": "manthan_t1.modeling_manthan.ManthanForCausalLM"
9
+ },
10
+ "torch_dtype": "float16"
11
+ }
kaggle_train.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kaggle single-file entrypoint for training Manthan-T1.
2
+
3
+ Copy-paste this file into Kaggle (repo root) and run:
4
+ - It optionally installs a compatible stack (avoids common torch/torchaudio mismatches).
5
+ - It guarantees `trust_remote_code=True` model loading by ensuring a `config.json` exists.
6
+ - Then runs Stage 1 (projector pretrain) and Stage 2 (instruction finetune) via
7
+ `scripts/train_unsloth_kaggle.py`.
8
+
9
+ Design goals:
10
+ - Minimal: no notebook-specific APIs.
11
+ - Robust: patches HF repo config if missing; sets cache dirs.
12
+
13
+ Environment variables (optional):
14
+ - MANTHAN_MODEL_ID (default: "zyxcisss/Manthan-T1")
15
+ - HF_HOME (default: /kaggle/working/hf_home)
16
+ - HF_TOKEN (if private repo)
17
+ - INSTALL_DEPS=1 to run pip installs (default: 0)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import os
24
+ import subprocess
25
+ import sys
26
+ from pathlib import Path
27
+
28
+
29
+ REPO_ROOT = Path(__file__).resolve().parent
30
+
31
+
32
def _run(cmd: list[str], *, env: dict[str, str] | None = None) -> None:
    """Echo a command, then run it and fail loudly on a non-zero exit.

    Args:
        cmd: Program and arguments, passed to the OS as a list (no shell).
        env: Environment for the child process; ``None`` inherits the
            current process environment.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    import shlex

    # shlex.join quotes arguments containing spaces or shell metacharacters,
    # so the echoed line is unambiguous and can be pasted into a shell.
    print("\n$", shlex.join(cmd), flush=True)
    subprocess.check_call(cmd, env=env)
35
+
36
+
37
def _maybe_install_deps() -> None:
    """Optionally install a pinned dependency stack with pip.

    Kaggle images often come with a preinstalled CUDA stack; mixing torch and
    torchaudio versions is the main source of hard errors, so the torch trio
    is pinned to releases that belong together.

    This function is intentionally conservative: it does nothing unless the
    environment variable ``INSTALL_DEPS`` is exactly ``"1"``.

    Raises:
        subprocess.CalledProcessError: If any pip invocation fails.
    """
    if os.environ.get("INSTALL_DEPS", "0") != "1":
        print("INSTALL_DEPS != 1; skipping pip installs.")
        return

    # Pin to a coherent torch/torchvision/torchaudio trio.
    # Note: torch 2.8.0 is not published for cu121; the default PyPI wheels
    # bundle a newer CUDA 12.x runtime. If your Kaggle runtime needs a
    # different CUDA build, adjust these pins (and add the matching index URL).
    pins = [
        "torch==2.8.0",
        "torchvision==0.23.0",
        "torchaudio==2.8.0",
        "transformers>=4.46.0",
        "accelerate>=0.34.0",
        "datasets>=2.20.0",
        "safetensors>=0.4.3",
        "pillow>=10.0.0",
        "tyro>=0.8.0",
        "trl>=0.12.0",
        # Needed when HF_HUB_ENABLE_HF_TRANSFER=1: huggingface_hub refuses to
        # download if the flag is on and this package is missing.
        "hf_transfer",
        # Optional:
        "sentencepiece",
        "protobuf",
    ]

    pip = [sys.executable, "-m", "pip", "install", "-U"]

    # Upgrade pip itself first so the resolver is current.
    _run(pip + ["pip"])

    # No extra index URLs here; Kaggle generally resolves CUDA wheels.
    _run(pip + pins)

    # Install unsloth last; it may pin/reinstall torch deps.
    _run(pip + ["unsloth", "unsloth_zoo", "xformers"])
77
+
78
+
79
def _setup_hf_env() -> dict[str, str]:
    """Build the environment passed to the training subprocesses.

    Starts from a copy of the current environment and fills in Hugging Face
    cache locations and a few defaults without overriding values the user
    already set (``HF_HOME`` is only defaulted when unset or empty).

    Returns:
        A new environment mapping; ``os.environ`` itself is not modified.

    Raises:
        OSError: If the cache directories cannot be created.
    """
    import importlib
    import importlib.util

    env = os.environ.copy()

    hf_home = env.get("HF_HOME") or "/kaggle/working/hf_home"
    env["HF_HOME"] = hf_home
    env.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Only opt in to hf_transfer when the package is importable:
    # huggingface_hub raises on download if the flag is set but the package
    # is missing. An explicit user-provided value is left untouched.
    importlib.invalidate_caches()  # pick up packages pip-installed this run
    if importlib.util.find_spec("hf_transfer") is not None:
        env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
    elif "HF_HUB_ENABLE_HF_TRANSFER" not in env:
        print("hf_transfer not installed; leaving HF_HUB_ENABLE_HF_TRANSFER unset.")

    # Keep common caches under HF_HOME (inside /kaggle/working by default).
    env.setdefault("TRANSFORMERS_CACHE", str(Path(hf_home) / "transformers"))
    env.setdefault("HF_DATASETS_CACHE", str(Path(hf_home) / "datasets"))

    for cache_var in ("TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
        Path(env[cache_var]).mkdir(parents=True, exist_ok=True)

    return env
95
+
96
+
97
def _ensure_local_config(model_id: str) -> None:
    """Make sure ``REPO_ROOT / "config.json"`` exists with the required keys.

    The training stages load the model from the local repo path rather than
    from the Hub, so the local ``config.json`` must identify the custom model.
    Existing values are never overwritten; only missing top-level keys are
    filled in:

    - ``model_type``: "manthan_t1"
    - ``architectures``: ["ManthanForCausalLM"]
    - ``auto_map``: dotted references to the remote-code classes
    - ``torch_dtype``: "float16"

    If the file is unreadable, is not valid JSON, or does not contain a JSON
    object, a warning is printed and the config is regenerated from defaults.

    Args:
        model_id: Hub model id. Currently unused; kept for interface
            compatibility with callers.
    """
    cfg_path = REPO_ROOT / "config.json"
    cfg: dict[str, object] = {}

    if cfg_path.exists():
        try:
            loaded = json.loads(cfg_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            print(f"Warning: could not read {cfg_path} ({exc}); regenerating it.")
        else:
            if isinstance(loaded, dict):
                cfg = loaded
            else:
                print(f"Warning: {cfg_path} is not a JSON object; regenerating it.")

    # If the repo already has a config, only patch missing fields.
    cfg.setdefault("model_type", "manthan_t1")
    cfg.setdefault("architectures", ["ManthanForCausalLM"])
    # auto_map values use the dotted "module.ClassName" form, matching the
    # config.json shipped with the repo. A "path/file.py:Class" string is not
    # a format Transformers understands.
    # NOTE(review): confirm that the nested "manthan_t1." package prefix
    # resolves with the Transformers version in use.
    cfg.setdefault(
        "auto_map",
        {
            "AutoConfig": "manthan_t1.configuration_manthan.ManthanConfig",
            "AutoModelForCausalLM": "manthan_t1.modeling_manthan.ManthanForCausalLM",
        },
    )

    # Helpful defaults for stubs.
    cfg.setdefault("torch_dtype", "float16")

    cfg_path.write_text(json.dumps(cfg, indent=2) + "\n", encoding="utf-8")
    print(f"Ensured local config at: {cfg_path}")
136
+
137
+
138
def _sanity_load_config(env: dict[str, str]) -> None:
    """Fail fast if the local config does not load as a Manthan-T1 config.

    Loads the config from ``REPO_ROOT`` with ``trust_remote_code=True`` and
    checks its ``model_type``.

    Args:
        env: Subprocess environment built by the caller; not read here.

    Raises:
        RuntimeError: If the loaded ``model_type`` is not "manthan_t1".
    """
    # Imported here rather than at module level, so transformers is only
    # loaded when this check actually runs.
    from transformers import AutoConfig

    loaded = AutoConfig.from_pretrained(str(REPO_ROOT), trust_remote_code=True)
    model_type = getattr(loaded, "model_type", None)
    print("Loaded config model_type:", model_type)

    if model_type == "manthan_t1":
        return
    raise RuntimeError(f"Unexpected model_type={model_type!r}; expected 'manthan_t1'.")
147
+
148
+
149
def _run_stage(env: dict[str, str], stage: int, extra: list[str] | None = None) -> None:
    """Run one training stage of ``scripts/train_unsloth_kaggle.py``.

    Args:
        env: Environment for the child process.
        stage: Stage number, forwarded as ``--stage``.
        extra: Optional additional command-line arguments appended verbatim.

    Raises:
        FileNotFoundError: If the trainer script is not present under
            ``REPO_ROOT``.
    """
    trainer = REPO_ROOT / "scripts" / "train_unsloth_kaggle.py"
    if not trainer.exists():
        raise FileNotFoundError(f"Missing {trainer}. Did you clone the repo correctly?")

    command = [
        sys.executable,
        str(trainer),
        "--stage",
        str(stage),
        # Point --model_id at the local checkout to avoid HF config issues.
        "--model_id",
        str(REPO_ROOT),
    ]
    _run(command + (extra or []), env=env)
167
+
168
+
169
def main() -> int:
    """Run the full Kaggle training pipeline and return the exit code.

    Steps, in order: optional dependency install, HF environment setup,
    local config patching, a fail-fast config load, then training stages 1
    and 2. Any failure propagates as an exception.

    Returns:
        0 when every step completes.
    """
    model_id = os.environ.get("MANTHAN_MODEL_ID", "zyxcisss/Manthan-T1")
    for label, value in (("Repo root:", REPO_ROOT), ("Model ID (for reference):", model_id)):
        if label == "Repo root:":
            print("Manthan Kaggle trainer")
        print(label, value)

    _maybe_install_deps()
    env = _setup_hf_env()

    # Patch the local config so Transformers can recognize the custom model,
    # then verify it loads via trust_remote_code before any long training run.
    _ensure_local_config(model_id)
    _sanity_load_config(env)

    stage_banners = (
        (1, "\n==== Stage 1: projector alignment/pretrain ===="),
        (2, "\n==== Stage 2: instruction finetune ===="),
    )
    for stage_number, banner in stage_banners:
        print(banner)
        _run_stage(env, stage_number)

    print("\nDone.")
    return 0
192
+
193
+
194
# Script entry point: exit with main()'s return code.
if __name__ == "__main__":
    sys.exit(main())
scripts/train_unsloth_kaggle.py CHANGED
@@ -29,8 +29,6 @@ import torch
29
  from torch import nn
30
  from torch.utils.data import Dataset
31
 
32
- from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
33
-
34
  try:
35
  # Fallback for non-Unsloth environments
36
  from peft import LoraConfig, get_peft_model
@@ -40,10 +38,13 @@ except Exception: # pragma: no cover
40
 
41
  try:
42
  # Kaggle + Unsloth
 
43
  from unsloth import FastLanguageModel
44
  except Exception: # pragma: no cover
45
  FastLanguageModel = None
46
 
 
 
47
  try:
48
  from datasets import load_dataset
49
  except Exception as e: # pragma: no cover
 
29
  from torch import nn
30
  from torch.utils.data import Dataset
31
 
 
 
32
  try:
33
  # Fallback for non-Unsloth environments
34
  from peft import LoraConfig, get_peft_model
 
38
 
39
  try:
40
  # Kaggle + Unsloth
41
+ import unsloth # noqa: F401
42
  from unsloth import FastLanguageModel
43
  except Exception: # pragma: no cover
44
  FastLanguageModel = None
45
 
46
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
47
+
48
  try:
49
  from datasets import load_dataset
50
  except Exception as e: # pragma: no cover