stmasson committed on
Commit
159c050
·
verified ·
1 Parent(s): 348bed7

Upload scripts/train_qwen3_sft_multitask.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_qwen3_sft_multitask.py +268 -0
scripts/train_qwen3_sft_multitask.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.13.0",
7
+ # "datasets>=3.0.0",
8
+ # "accelerate>=1.0.0",
9
+ # "huggingface_hub>=0.26.0",
10
+ # "torch>=2.4.0",
11
+ # ]
12
+ # [tool.uv]
13
+ # index-strategy = "unsafe-best-match"
14
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
15
+ # ///
16
+ """
17
+ SFT Multi-Task Training Script for n8n Agent
18
+
19
+ This script fine-tunes the DPO-trained Qwen3-0.6B model on multi-task n8n workflows.
20
+ It builds on the reasoning capabilities from DPO training and adds task-specific skills.
21
+
22
+ Tasks covered:
23
+ - generate: Create workflows from descriptions
24
+ - edit: Modify existing workflows
25
+ - fix: Correct errors in workflows
26
+ - explain: Explain what workflows do
27
+ - debug: Diagnose execution issues
28
+ - improve: Optimize and enhance workflows
29
+
30
+ Usage:
31
+ hf jobs uv run \
32
+ --script train_qwen3_sft_multitask.py \
33
+ --flavor l4x1 \
34
+ --timeout 24h
35
+ """
36
+
37
+ import os
38
+ import json
39
+ import torch
40
+ from datasets import Dataset
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer
42
+ from peft import LoraConfig, PeftModel, get_peft_model
43
+ from trl import SFTTrainer, SFTConfig
44
+ from huggingface_hub import login, hf_hub_download
45
+
46
+ # ============================================================================
47
+ # CONFIGURATION
48
+ # ============================================================================
49
+
50
# Base model (the DPO-trained model with reasoning capabilities)

def _env(name: str, default: str) -> str:
    """Return the environment variable *name*, or *default* when unset."""
    return os.environ.get(name, default)

BASE_MODEL = _env("BASE_MODEL", "Qwen/Qwen3-0.6B")
DPO_ADAPTER = _env("DPO_ADAPTER", "stmasson/qwen3-0.6b-n8n-reasoning")

# Dataset
DATASET_REPO = "stmasson/n8n-agentic-multitask"
TRAIN_FILE = "data/multitask_large/train.jsonl"
VAL_FILE = "data/multitask_large/val.jsonl"

# Output
OUTPUT_DIR = "./qwen3-sft-multitask"
HF_REPO = _env("HF_REPO", "stmasson/qwen3-0.6b-n8n-agent")

# Hyperparameters (each overridable through the environment)
NUM_EPOCHS = int(_env("NUM_EPOCHS", "1"))
BATCH_SIZE = int(_env("BATCH_SIZE", "1"))
GRAD_ACCUM = int(_env("GRAD_ACCUM", "8"))
LEARNING_RATE = float(_env("LEARNING_RATE", "1e-5"))
MAX_SEQ_LENGTH = int(_env("MAX_SEQ_LENGTH", "8192"))

# LoRA (continuing from DPO adapter)
LORA_R = int(_env("LORA_R", "32"))
LORA_ALPHA = int(_env("LORA_ALPHA", "64"))
LORA_DROPOUT = float(_env("LORA_DROPOUT", "0.05"))
74
+
75
+ # ============================================================================
76
+ # AUTHENTICATION
77
+ # ============================================================================
78
+
79
+ print("=" * 60)
80
+ print("SFT MULTI-TASK TRAINING - N8N AGENT")
81
+ print("=" * 60)
82
+
83
+ hf_token = os.environ.get("HF_TOKEN")
84
+ if hf_token:
85
+ login(token=hf_token)
86
+ print("Authenticated with HuggingFace")
87
+ else:
88
+ print("Warning: HF_TOKEN not set, push disabled")
89
+
90
+ # ============================================================================
91
+ # LOAD MODEL WITH DPO ADAPTER
92
+ # ============================================================================
93
+
94
+ print(f"\nLoading base model: {BASE_MODEL}")
95
+ print(f"Loading DPO adapter: {DPO_ADAPTER}")
96
+
97
+ # Load base model
98
+ model = AutoModelForCausalLM.from_pretrained(
99
+ BASE_MODEL,
100
+ torch_dtype=torch.bfloat16,
101
+ attn_implementation="sdpa",
102
+ device_map="auto",
103
+ trust_remote_code=True,
104
+ )
105
+
106
+ # Load tokenizer
107
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
108
+ if tokenizer.pad_token is None:
109
+ tokenizer.pad_token = tokenizer.eos_token
110
+ tokenizer.padding_side = "right"
111
+
112
+ # Load DPO adapter and merge it into the base model
113
+ print("Loading and merging DPO adapter...")
114
+ model = PeftModel.from_pretrained(model, DPO_ADAPTER)
115
+ model = model.merge_and_unload()
116
+ print("DPO adapter merged successfully!")
117
+
118
+ print(f"Model loaded: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
119
+
120
+ # ============================================================================
121
+ # NEW LORA CONFIG FOR SFT
122
+ # ============================================================================
123
+
124
+ print(f"\nNew LoRA config for SFT: r={LORA_R}, alpha={LORA_ALPHA}")
125
+
126
+ lora_config = LoraConfig(
127
+ r=LORA_R,
128
+ lora_alpha=LORA_ALPHA,
129
+ target_modules=[
130
+ "q_proj", "k_proj", "v_proj", "o_proj",
131
+ "gate_proj", "up_proj", "down_proj"
132
+ ],
133
+ lora_dropout=LORA_DROPOUT,
134
+ bias="none",
135
+ task_type="CAUSAL_LM"
136
+ )
137
+
138
+ # ============================================================================
139
+ # LOAD DATASET
140
+ # ============================================================================
141
+
142
+ print(f"\nLoading dataset: {DATASET_REPO}")
143
+
144
+ def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
145
+ """Load JSONL dataset and extract only messages column."""
146
+ local_path = hf_hub_download(
147
+ repo_id=repo_id,
148
+ filename=filename,
149
+ repo_type="dataset"
150
+ )
151
+
152
+ messages_list = []
153
+ with open(local_path, 'r', encoding='utf-8') as f:
154
+ for line in f:
155
+ data = json.loads(line)
156
+ messages_list.append({"messages": data["messages"]})
157
+
158
+ return Dataset.from_list(messages_list)
159
+
160
# Fetch both splits from the Hub with the same loader.
train_dataset, val_dataset = (
    load_jsonl_dataset(DATASET_REPO, split_file)
    for split_file in (TRAIN_FILE, VAL_FILE)
)

print(f"Train: {len(train_dataset)} examples")
print(f"Validation: {len(val_dataset)} examples")
166
+
167
+ # Format examples
168
def format_example(example):
    """Render one chat-format record into a single training string.

    Uses the module-level `tokenizer`'s chat template; no generation prompt
    is appended since the assistant turn is already present.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}
177
+
178
+ print("Formatting data...")
179
+ train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
180
+ val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)
181
+
182
+ # Show example
183
+ print("\nExample formatted data:")
184
+ print(train_dataset[0]["text"][:500] + "...")
185
+
186
+ # ============================================================================
187
+ # TRAINING CONFIGURATION
188
+ # ============================================================================
189
+
190
+ print(f"\nTraining configuration:")
191
+ print(f" - Epochs: {NUM_EPOCHS}")
192
+ print(f" - Batch size: {BATCH_SIZE}")
193
+ print(f" - Gradient accumulation: {GRAD_ACCUM}")
194
+ print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
195
+ print(f" - Learning rate: {LEARNING_RATE}")
196
+ print(f" - Max sequence length: {MAX_SEQ_LENGTH}")
197
+
198
+ training_args = SFTConfig(
199
+ output_dir=OUTPUT_DIR,
200
+ num_train_epochs=NUM_EPOCHS,
201
+ per_device_train_batch_size=BATCH_SIZE,
202
+ per_device_eval_batch_size=BATCH_SIZE,
203
+ gradient_accumulation_steps=GRAD_ACCUM,
204
+ learning_rate=LEARNING_RATE,
205
+ lr_scheduler_type="cosine",
206
+ warmup_ratio=0.05,
207
+ weight_decay=0.01,
208
+ bf16=True,
209
+ tf32=True,
210
+ logging_steps=50,
211
+ save_strategy="steps",
212
+ save_steps=1000,
213
+ save_total_limit=3,
214
+ eval_strategy="steps",
215
+ eval_steps=1000,
216
+ max_seq_length=MAX_SEQ_LENGTH,
217
+ packing=False,
218
+ gradient_checkpointing=True,
219
+ gradient_checkpointing_kwargs={"use_reentrant": False},
220
+ dataset_text_field="text",
221
+ report_to="none",
222
+ run_name="qwen3-sft-multitask",
223
+ hub_model_id=HF_REPO if hf_token else None,
224
+ push_to_hub=bool(hf_token),
225
+ hub_strategy="checkpoint",
226
+ )
227
+
228
+ # ============================================================================
229
+ # TRAINING
230
+ # ============================================================================
231
+
232
+ print("\nInitializing SFT trainer...")
233
+
234
+ trainer = SFTTrainer(
235
+ model=model,
236
+ args=training_args,
237
+ train_dataset=train_dataset,
238
+ eval_dataset=val_dataset,
239
+ peft_config=lora_config,
240
+ processing_class=tokenizer,
241
+ )
242
+
243
+ # Show trainable parameters
244
# Report trainable vs. total parameters on the model the trainer will actually
# optimize. `trainer.model` is the LoRA-wrapped model produced from
# `peft_config`; counting on the outer `model` variable could misreport the
# adapter's trainable fraction since SFTTrainer performs the wrapping itself.
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in trainer.model.parameters())
print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

print("\n" + "=" * 60)
print("STARTING SFT MULTI-TASK TRAINING")
print("=" * 60)

trainer.train()
253
+
254
+ # ============================================================================
255
+ # SAVE MODEL
256
+ # ============================================================================
257
+
258
+ print("\nSaving model...")
259
+ trainer.save_model(f"{OUTPUT_DIR}/final")
260
+
261
+ if hf_token:
262
+ print(f"Pushing to {HF_REPO}...")
263
+ trainer.push_to_hub()
264
+ print(f"Model available at: https://huggingface.co/{HF_REPO}")
265
+
266
+ print("\n" + "=" * 60)
267
+ print("SFT MULTI-TASK TRAINING COMPLETE")
268
+ print("=" * 60)