stmasson committed on
Commit
3b4be13
·
verified ·
1 Parent(s): 4a23090

Upload train_qwen3_sft_multitask.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_sft_multitask.py +279 -0
train_qwen3_sft_multitask.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.13.0",
7
+ # "datasets>=3.0.0",
8
+ # "accelerate>=1.0.0",
9
+ # "huggingface_hub>=0.26.0",
10
+ # "torch>=2.4.0",
11
+ # ]
12
+ # [tool.uv]
13
+ # index-strategy = "unsafe-best-match"
14
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
15
+ # ///
16
+ """
17
+ SFT Multi-Task Training Script for n8n Agent
18
+
19
+ This script fine-tunes the DPO-trained Qwen3-0.6B model on multi-task n8n workflows.
20
+ It builds on the reasoning capabilities from DPO training and adds task-specific skills.
21
+
22
+ Tasks covered:
23
+ - generate: Create workflows from descriptions
24
+ - edit: Modify existing workflows
25
+ - fix: Correct errors in workflows
26
+ - explain: Explain what workflows do
27
+ - debug: Diagnose execution issues
28
+ - improve: Optimize and enhance workflows
29
+
30
+ Usage:
31
+ hf jobs uv run \
32
+ --script train_qwen3_sft_multitask.py \
33
+ --flavor l4x1 \
34
+ --timeout 24h
35
+ """
36
+
37
+ import os
38
+ import json
39
+ import torch
40
+ from datasets import Dataset
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer
42
+ from peft import LoraConfig, PeftModel, get_peft_model
43
+ from trl import SFTTrainer, SFTConfig
44
+ from huggingface_hub import login, hf_hub_download
45
+
46
# ============================================================================
# CONFIGURATION
# ============================================================================
# Every setting below can be overridden via environment variables so the same
# script works unchanged across hf-jobs runs.

# Base model (the DPO-trained model with reasoning capabilities)
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen3-0.6B")
DPO_ADAPTER = os.environ.get("DPO_ADAPTER", "stmasson/qwen3-0.6b-n8n-reasoning")

# Dataset: chat-format JSONL splits hosted in a HF dataset repository
DATASET_REPO = "stmasson/n8n-agentic-multitask"
TRAIN_FILE = "data/multitask_large/train.jsonl"
VAL_FILE = "data/multitask_large/val.jsonl"

# Output: local checkpoint dir and the Hub repo the trainer pushes to
OUTPUT_DIR = "./qwen3-sft-multitask"
HF_REPO = os.environ.get("HF_REPO", "stmasson/qwen3-0.6b-n8n-agent")

# Hyperparameters (default effective batch size = BATCH_SIZE * GRAD_ACCUM = 8)
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "1"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "1e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))

# LoRA (continuing from DPO adapter); alpha defaults to 2*r
LORA_R = int(os.environ.get("LORA_R", "32"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64"))
LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
74
+
75
# ============================================================================
# AUTHENTICATION
# ============================================================================

_banner = "=" * 60
print(_banner)
print("SFT MULTI-TASK TRAINING - N8N AGENT")
print(_banner)

# Hub auth is optional: without HF_TOKEN the job still trains locally,
# it just cannot push checkpoints to the Hub.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not set, push disabled")
else:
    login(token=hf_token)
    print("Authenticated with HuggingFace")
89
+
90
# ============================================================================
# LOAD MODEL WITH DPO ADAPTER
# ============================================================================

print(f"\nLoading base model: {BASE_MODEL}")
print(f"Loading DPO adapter: {DPO_ADAPTER}")

# Load base model in bf16 with SDPA attention; device_map="auto" lets
# accelerate place the weights on the available GPU(s).
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
    trust_remote_code=True,
)

# Load tokenizer; fall back to EOS as the pad token when none is defined,
# and pad on the right for training batches.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the DPO adapter and fold its weights into the base model, so the SFT
# stage can attach a *fresh* LoRA adapter on top of the DPO-improved weights.
print("Loading and merging DPO adapter...")
model = PeftModel.from_pretrained(model, DPO_ADAPTER)
model = model.merge_and_unload()
print("DPO adapter merged successfully!")

print(f"Model loaded: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
119
+
120
# ============================================================================
# NEW LORA CONFIG FOR SFT
# ============================================================================

print(f"\nNew LoRA config for SFT: r={LORA_R}, alpha={LORA_ALPHA}")

# Fresh adapter covering all attention and MLP projection matrices; it is
# applied by the trainer via peft_config, so only adapter weights train.
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)
137
+
138
# ============================================================================
# LOAD DATASET
# ============================================================================

print(f"\nLoading dataset: {DATASET_REPO}")

def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
    """Download a JSONL file from a HF dataset repo and load it as a Dataset.

    Keeps only the "messages" field of each record (chat-format
    conversations); every other field is dropped.

    Args:
        repo_id: Hugging Face dataset repository id.
        filename: Path of the JSONL file inside the repository.

    Returns:
        Dataset with a single "messages" column.

    Raises:
        KeyError: If a record has no "messages" field.
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset"
    )

    messages_list = []
    with open(local_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline at EOF) instead of
            # crashing with a JSONDecodeError on an empty string.
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            messages_list.append({"messages": data["messages"]})

    return Dataset.from_list(messages_list)
159
+
160
# Load train and validation splits from the Hub dataset repo
train_dataset = load_jsonl_dataset(DATASET_REPO, TRAIN_FILE)
val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)

print(f"Train: {len(train_dataset)} examples")
print(f"Validation: {len(val_dataset)} examples")
166
+
167
# Filter out very long examples to avoid OOM
def filter_by_length(example):
    """Return True when the example's total content length is under 30k chars.

    Character count is a cheap proxy for token count (~7500 tokens max);
    messages without a "content" field count as empty.
    """
    char_total = 0
    for message in example['messages']:
        char_total += len(message.get('content', ''))
    return char_total < 30000
172
+
173
print("Filtering long examples...")
# Drop over-length examples from both splits before tokenization.
train_dataset = train_dataset.filter(filter_by_length)
val_dataset = val_dataset.filter(filter_by_length)
print(f"After filtering - Train: {len(train_dataset)}, Val: {len(val_dataset)}")
177
+
178
# Render chat messages into a single training string.
def format_example(example):
    """Serialize one chat example using the model's chat template.

    Relies on the module-level ``tokenizer``. No generation prompt is
    appended because assistant turns are already present in the data.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    return {"text": rendered}
188
+
189
print("Formatting data...")
# Replace the "messages" column with a single rendered "text" column.
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)

# Show example (truncated to the first 500 characters)
print("\nExample formatted data:")
print(train_dataset[0]["text"][:500] + "...")
196
+
197
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

print(f"\nTraining configuration:")
print(f" - Epochs: {NUM_EPOCHS}")
print(f" - Batch size: {BATCH_SIZE}")
print(f" - Gradient accumulation: {GRAD_ACCUM}")
print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print(f" - Learning rate: {LEARNING_RATE}")
print(f" - Max sequence length: {MAX_SEQ_LENGTH}")

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    # bf16 matches the dtype the model was loaded with; tf32 speeds up
    # matmuls on Ampere+ GPUs.
    bf16=True,
    tf32=True,
    logging_steps=50,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=1000,
    max_length=MAX_SEQ_LENGTH,
    packing=False,
    # Trade compute for memory; non-reentrant checkpointing is the variant
    # recommended for newer torch — NOTE(review): confirm with PEFT setup.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_text_field="text",
    report_to="none",
    run_name="qwen3-sft-multitask",
    # Push checkpoints to the Hub only when a token was provided.
    hub_model_id=HF_REPO if hf_token else None,
    push_to_hub=bool(hf_token),
    hub_strategy="checkpoint",
)
238
+
239
# ============================================================================
# TRAINING
# ============================================================================

print("\nInitializing SFT trainer...")

# The trainer wraps the (DPO-merged) model with the fresh LoRA adapter via
# peft_config; processing_class supplies the tokenizer for text encoding.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
)
253
+
254
# Report how many parameters are left trainable before starting.
trainable_params = 0
total_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

print("\n" + "=" * 60)
print("STARTING SFT MULTI-TASK TRAINING")
print("=" * 60)

trainer.train()
264
+
265
# ============================================================================
# SAVE MODEL
# ============================================================================

# Persist the final trainer state locally, then optionally push to the Hub.
print("\nSaving model...")
trainer.save_model(f"{OUTPUT_DIR}/final")

if hf_token:
    print(f"Pushing to {HF_REPO}...")
    trainer.push_to_hub()
    print(f"Model available at: https://huggingface.co/{HF_REPO}")

print("\n" + "=" * 60)
print("SFT MULTI-TASK TRAINING COMPLETE")
print("=" * 60)