stmasson committed on
Commit
1ba34a4
·
verified ·
1 Parent(s): 202ab61

Upload train_ministral_n8n.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_ministral_n8n.py +166 -0
train_ministral_n8n.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "transformers>=4.45.0",
6
+ # "datasets>=2.14.0",
7
+ # "accelerate>=0.24.0",
8
+ # "peft>=0.7.0",
9
+ # "trl>=0.12.0",
10
+ # "bitsandbytes>=0.41.0",
11
+ # "huggingface_hub>=0.20.0",
12
+ # "trackio",
13
+ # ]
14
+ # ///
15
+
16
+ """
17
+ Fine-tune Ministral-3B on n8n-workflows-thinking dataset for SFT.
18
+ This script trains the model to generate n8n workflows with chain-of-thought reasoning.
19
+ """

import importlib.util
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
import trackio

# Configuration
MODEL_NAME = "mistralai/Ministral-3b-instruct"
DATASET_NAME = "stmasson/n8n-workflows-thinking"
OUTPUT_MODEL = "stmasson/ministral-3b-n8n-workflows"
MAX_SEQ_LENGTH = 4096  # n8n workflows can be long

# Initialize Trackio for monitoring.
# Fix: trackio exposes a wandb-compatible API whose keyword is `project`,
# not `project_name` — the original call would raise TypeError at startup.
trackio.init(project="ministral-3b-n8n-sft")

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Some checkpoints ship without a pad token; reuse EOS so padded batches work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset - the SFT split is read straight from its JSONL data files
# on the Hub rather than through a dataset loading script.
print(f"Loading dataset {DATASET_NAME}...")
_sft_root = f"hf://datasets/{DATASET_NAME}/data/sft"
dataset = load_dataset(
    "json",
    data_files={
        "train": f"{_sft_root}/train.jsonl",
        "validation": f"{_sft_root}/validation.jsonl",
    },
)
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]
print(f"Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval examples")

# Preprocess: render each conversation into a single 'text' column.
print("Preprocessing dataset with chat template...")

def preprocess_function(example):
    """Apply chat template to messages."""
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,  # full conversations, not a prompt prefix
    )
    return {"text": rendered}

# Drop the original columns so only the rendered 'text' field remains.
train_dataset = train_dataset.map(
    preprocess_function,
    remove_columns=train_dataset.column_names,
    desc="Applying chat template to train",
)
eval_dataset = eval_dataset.map(
    preprocess_function,
    remove_columns=eval_dataset.column_names,
    desc="Applying chat template to eval",
)
print(f"Preprocessed: {len(train_dataset)} train, {len(eval_dataset)} eval")

# Quantization config for 4-bit training (saves VRAM)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Fix: flash-attn is not in this script's declared dependencies, so
# hard-coding attn_implementation="flash_attention_2" makes model loading
# crash whenever the package is absent. Use it only when installed and
# fall back to PyTorch's built-in SDPA otherwise.
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

print(f"Loading model {MODEL_NAME} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation=attn_impl,
)
# Re-enables input grads and casts norm layers for stable QLoRA training.
model = prepare_model_for_kbit_training(model)

# LoRA configuration: adapt all attention and MLP projection matrices.
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
lora_config = LoraConfig(
    r=64,  # higher rank for a complex generation task
    lora_alpha=128,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Training configuration
training_args = SFTConfig(
    output_dir="./ministral-3b-n8n-sft",
    # Optimization schedule
    num_train_epochs=2,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    # Effective batch size of 16 via gradient accumulation
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    # Memory savings
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Sequence handling
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,  # don't pack - workflows need full context
    dataset_text_field="text",
    # Logging / checkpointing / evaluation
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    eval_strategy="steps",
    eval_steps=200,
    save_total_limit=3,
    # Hub configuration
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL,
    hub_strategy="checkpoint",
    hub_private_repo=False,
    # Reporting
    report_to="trackio",
    run_name="ministral-3b-n8n-sft",
)

# Initialize trainer
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

# Train
print("Starting training...")
trainer.train()

# Save final model and push the final weights to the Hub.
print("Saving final model...")
trainer.save_model()
trainer.push_to_hub()

# Close the Trackio run opened at the top of the script so metrics flush.
trackio.finish()

# Fix: the original used f-strings with no placeholders on the first and
# last messages; plain literals produce identical output.
print("\nTraining complete!")
print(f"Model saved to: https://huggingface.co/{OUTPUT_MODEL}")
print("Training metrics: https://huggingface.co/spaces/stmasson/trackio")