krasserm commited on
Commit
e3cb970
·
verified ·
1 Parent(s): 1930fb7

Upload sft_qwen2_capybara.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sft_qwen2_capybara.py +203 -0
sft_qwen2_capybara.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl==1.5.1",
5
+ # "peft==0.19.1",
6
+ # "transformers==5.9.0",
7
+ # "datasets==4.8.5",
8
+ # "accelerate==1.13.0",
9
+ # "trackio==0.26.0",
10
+ # "torch",
11
+ # ]
12
+ # ///
13
+ """
14
+ LoRA SFT of Qwen/Qwen2-0.5B (BASE) on trl-lib/Capybara (conversational 'messages').
15
+
16
+ Grounded in research (Principle 1), cross-checked against the canonical templates:
17
+ - Skill template: huggingface-llm-trainer/scripts/train_sft_example.py
18
+ (SFTTrainer(model=str, train_dataset, peft_config=LoraConfig(...), args=SFTConfig(...)))
19
+ - Canonical TRL SFT: https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py
20
+ - SFTTrainer consumes the 'messages' column directly and applies the chat template;
21
+ NO formatting_func needed (TRL "conversational" dataset support).
22
+
23
+ KEY VERIFIED FACTS (by inspection, Principle 2):
24
+ - Qwen/Qwen2-0.5B (BASE) ALREADY ships a ChatML chat template in its
25
+ tokenizer_config.json (327 chars, uses <|im_start|>/<|im_end|>). So we do NOT
26
+ inject a template; SFTTrainer applies the model's own template to 'messages'.
27
+ Subtlety: the base tokenizer's eos_token is <|endoftext|> while the template
28
+ ends turns with <|im_end|>. For SFT we set eos_token="<|im_end|>" via SFTConfig
29
+ so the model learns to stop at end-of-turn (and assistant_only_loss aligns).
30
+ - trl-lib/Capybara: config 'default', splits {train:15806, test}, columns
31
+ ['source','messages','num_turns']; 'messages' is multi-turn alternating
32
+ user/assistant. Already in required schema (no mapping).
33
+ - SFTConfig fields eos_token / assistant_only_loss / max_length / packing all
34
+ present in current trl (verified in trl/trainer/sft_config.py).
35
+
36
+ LoRA recipe for a small instruct model (literature):
37
+ - target ALL linear layers (q,k,v,o,gate,up,down) > attention-only
38
+ (QLoRA, Dettmers et al. 2023, arXiv:2305.14314; PEFT/TRL all-linear guidance).
39
+ - r=16, alpha=32 (alpha=2r), dropout=0.05 (TRL ModelConfig LoRA defaults).
40
+ - lr=2e-4, cosine schedule, warmup_ratio=0.03 (standard LoRA SFT LR for small models).
41
+ - assistant_only_loss=True: loss only on assistant turns for multi-turn SFT.
42
+
43
+ Monitoring: Trackio (report_to) + structured alerts at decision points (§5.6/§5.7).
44
+
45
+ Resources (R14):
46
+ - Model: https://huggingface.co/Qwen/Qwen2-0.5B
47
+ - Dataset: https://huggingface.co/datasets/trl-lib/Capybara
48
+ """
49
+
50
+ import os
51
+
52
+ from datasets import load_dataset
53
+ from peft import LoraConfig
54
+ from trl import SFTConfig, SFTTrainer
55
+
56
+ import trackio
57
+
58
+
59
+ # --------------------------------------------------------------------------- #
60
+ # Config (overridable via env so the SAME script serves smoke + full run)
61
+ # --------------------------------------------------------------------------- #
62
+ MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2-0.5B")
63
+ DATASET_ID = os.environ.get("DATASET_ID", "trl-lib/Capybara")
64
+ EOS_TOKEN = os.environ.get("EOS_TOKEN", "<|im_end|>") # base eos is <|endoftext|>
65
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "krasserm/Qwen2-0.5B-Capybara-LoRA")
66
+
67
+ SMOKE = os.environ.get("SMOKE", "0") == "1"
68
+ LIMIT = int(os.environ.get("LIMIT", "0")) # 0 = full split
69
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "2048"))
70
+
71
+ TRACKIO_PROJECT = os.environ.get("TRACKIO_PROJECT", "qwen2-0.5b-capybara-lora")
72
+ TRACKIO_SPACE = os.environ.get("TRACKIO_SPACE_ID", "") # e.g. krasserm/trackio
73
+
74
+
75
+ ALERT_WEBHOOK = os.environ.get("ALERT_WEBHOOK_URL") or None
76
+
77
+ # Verified API (trackio 0.26.x, gradio-app/trackio __init__.py + skill alerts.md):
78
+ # trackio.alert(title, text, level=trackio.AlertLevel.{INFO,WARN,ERROR}, webhook_url)
79
+ _LEVELS = {
80
+ "info": "INFO",
81
+ "success": "INFO",
82
+ "warn": "WARN",
83
+ "warning": "WARN",
84
+ "error": "ERROR",
85
+ }
86
+
87
+
88
+ def alert(level: str, title: str, message: str):
89
+ """Structured Trackio alert at a decision point (§5.7); never crash training."""
90
+ try:
91
+ lvl = getattr(trackio.AlertLevel, _LEVELS.get(level, "INFO"))
92
+ trackio.alert(title=title, text=message, level=lvl, webhook_url=ALERT_WEBHOOK)
93
+ except Exception as e:
94
+ print(f"[alert:{level}] {title} :: {message} (trackio.alert failed: {e})")
95
+
96
+
97
+ def main():
98
+ # ------------------------------------------------------------------ #
99
+ # Monitoring init (live dashboard)
100
+ # ------------------------------------------------------------------ #
101
+ init_kwargs = {"project": TRACKIO_PROJECT}
102
+ if TRACKIO_SPACE:
103
+ init_kwargs["space_id"] = TRACKIO_SPACE
104
+ trackio.init(
105
+ config={
106
+ "model": MODEL_ID,
107
+ "dataset": DATASET_ID,
108
+ "method": "SFT+LoRA",
109
+ "eos_token": EOS_TOKEN,
110
+ "smoke": SMOKE,
111
+ "limit": LIMIT,
112
+ "max_length": MAX_LENGTH,
113
+ },
114
+ **init_kwargs,
115
+ )
116
+ alert("info", "Run started", f"SFT+LoRA {MODEL_ID} on {DATASET_ID} (smoke={SMOKE})")
117
+
118
+ # ------------------------------------------------------------------ #
119
+ # Data — consumed directly by SFTTrainer (no formatting_func)
120
+ # ------------------------------------------------------------------ #
121
+ ds = load_dataset(DATASET_ID, split="train")
122
+ if LIMIT > 0:
123
+ ds = ds.select(range(min(LIMIT, len(ds))))
124
+ if "messages" not in ds.column_names:
125
+ alert("error", "Schema mismatch", f"'messages' not in {ds.column_names}")
126
+ raise ValueError(f"Expected 'messages' column, got {ds.column_names}")
127
+ print(f"Loaded {len(ds)} examples; columns={ds.column_names}")
128
+
129
+ # ------------------------------------------------------------------ #
130
+ # LoRA config (all-linear, r16/alpha32) — research recipe
131
+ # ------------------------------------------------------------------ #
132
+ peft_config = LoraConfig(
133
+ r=16,
134
+ lora_alpha=32,
135
+ lora_dropout=0.05,
136
+ bias="none",
137
+ task_type="CAUSAL_LM",
138
+ target_modules=[
139
+ "q_proj", "k_proj", "v_proj", "o_proj",
140
+ "gate_proj", "up_proj", "down_proj",
141
+ ],
142
+ )
143
+
144
+ # ------------------------------------------------------------------ #
145
+ # SFTConfig
146
+ # ------------------------------------------------------------------ #
147
+ sft_config = SFTConfig(
148
+ output_dir="/tmp/sft-out",
149
+ # BASE already ships a ChatML template; just set EOS to end-of-turn token
150
+ # so the model learns to stop (base eos is <|endoftext|>).
151
+ eos_token=EOS_TOKEN,
152
+ max_length=MAX_LENGTH,
153
+ packing=not SMOKE, # packing on for the real run; off for tiny smoke
154
+ assistant_only_loss=True, # loss only on assistant turns (multi-turn SFT)
155
+ # Optimization (LoRA SFT recipe)
156
+ learning_rate=2e-4,
157
+ lr_scheduler_type="cosine",
158
+ warmup_ratio=0.03,
159
+ weight_decay=0.0,
160
+ num_train_epochs=1 if SMOKE else 2,
161
+ max_steps=8 if SMOKE else -1,
162
+ per_device_train_batch_size=2 if SMOKE else 8,
163
+ gradient_accumulation_steps=1 if SMOKE else 4,
164
+ gradient_checkpointing=True,
165
+ bf16=True,
166
+ # Logging / saving / monitoring
167
+ logging_steps=1 if SMOKE else 10,
168
+ save_strategy="no" if SMOKE else "epoch",
169
+ report_to=["trackio"],
170
+ run_name=TRACKIO_PROJECT,
171
+ # Persistence (Principle 4) — disabled in smoke
172
+ push_to_hub=not SMOKE,
173
+ hub_model_id=HUB_MODEL_ID,
174
+ seed=42,
175
+ )
176
+
177
+ trainer = SFTTrainer(
178
+ model=MODEL_ID,
179
+ args=sft_config,
180
+ train_dataset=ds,
181
+ peft_config=peft_config,
182
+ )
183
+
184
+ alert("info", "Training start",
185
+ f"{len(ds)} ex | bs={sft_config.per_device_train_batch_size}"
186
+ f" x ga={sft_config.gradient_accumulation_steps} |"
187
+ f" epochs={sft_config.num_train_epochs} max_steps={sft_config.max_steps}")
188
+
189
+ result = trainer.train()
190
+ tr_loss = result.metrics.get("train_loss")
191
+ alert("success", "Training done", f"train_loss={tr_loss}")
192
+
193
+ if sft_config.push_to_hub:
194
+ trainer.push_to_hub(dataset_name=DATASET_ID)
195
+ alert("success", "Pushed to Hub", f"https://huggingface.co/{HUB_MODEL_ID}")
196
+ else:
197
+ trainer.save_model(sft_config.output_dir)
198
+
199
+ trackio.finish()
200
+
201
+
202
+ if __name__ == "__main__":
203
+ main()