File size: 4,343 Bytes
634ff98 39f3734 634ff98 a18220e 634ff98 a18220e 634ff98 9de98a3 a18220e 9de98a3 a18220e 9de98a3 7dbc984 9de98a3 634ff98 a18220e 634ff98 a18220e 634ff98 9de98a3 634ff98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "transformers>=4.46.0",
# "accelerate>=0.24.0",
# "peft>=0.7.0",
# "trackio",
# "bitsandbytes",
# "sentencepiece",
# "protobuf",
# ]
# ///
"""
ORPO training for n8n workflows with chain-of-thought reasoning.
Fine-tunes stmasson/mistral-7b-n8n-workflows on the n8n-workflows-thinking dataset
to generate structured reasoning (<thinking>) before producing n8n workflow JSON.
ORPO (Odds Ratio Preference Optimization) combines SFT and preference learning
in a single training objective, making it more efficient than DPO for this use case.
"""
import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import ORPOTrainer, ORPOConfig
# Load ORPO dataset
print("Loading n8n-workflows-thinking dataset (ORPO split)...")

def _load_orpo_split(data_files: str):
    """Load one ORPO JSONL file from the Hub dataset repo.

    `datasets` exposes each explicitly-listed file as a single "train"
    split, so split="train" is requested for both train and validation.
    """
    return load_dataset(
        "stmasson/n8n-workflows-thinking",
        data_files=data_files,
        split="train",
    )

train_dataset = _load_orpo_split("data/orpo/train.jsonl")
eval_dataset = _load_orpo_split("data/orpo/validation.jsonl")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")

# Drop the metadata column (not needed for training). Guard against the
# column being absent so a future dataset-schema change does not crash
# the run with a missing-column error.
if "metadata" in train_dataset.column_names:
    train_dataset = train_dataset.remove_columns(["metadata"])
if "metadata" in eval_dataset.column_names:
    eval_dataset = eval_dataset.remove_columns(["metadata"])
# Load model and tokenizer
MODEL_NAME = "stmasson/mistral-7b-n8n-workflows"

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Mistral checkpoints ship without a pad token; reuse EOS so padded
# batches can be built.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Quantize weights to 4-bit NF4 (with double quantization) and compute
# in bfloat16 to fit the 7B model in limited GPU memory.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading model from {MODEL_NAME} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    attn_implementation="sdpa",  # scaled dot-product attention
    device_map="auto",
)
# LoRA adapters on every attention and MLP projection — trains a small
# rank-32 update instead of the full 7B parameter set.
_LORA_TARGETS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=_LORA_TARGETS,
)
# ORPO training configuration
config = ORPOConfig(
    # Hub persistence — pushes every checkpoint so progress survives
    # the job ending.
    output_dir="mistral-7b-n8n-thinking-orpo",
    push_to_hub=True,
    hub_model_id="stmasson/mistral-7b-n8n-thinking-orpo",
    hub_strategy="every_save",
    hub_private_repo=False,
    # The single ORPO-specific knob: weight of the odds-ratio loss term.
    beta=0.1,
    # Schedule: 2 epochs, cosine decay with a 10% warmup.
    num_train_epochs=2,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    # Batching: micro-batch of 1, accumulated to an effective batch of 32.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
    # Sequence budget, kept small for memory.
    max_length=2048,
    max_prompt_length=256,
    # Memory savers: checkpointed activations, bf16, 8-bit AdamW.
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_8bit",
    # Log every 10 steps; evaluate and checkpoint every 200, keeping
    # only the 3 newest saves.
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    # Trackio experiment tracking.
    report_to="trackio",
    project="n8n-thinking-training",
    run_name="mistral-7b-orpo-reasoning",
)
# Initialize trainer (ORPOTrainer applies the LoRA adapters via peft_config).
print("Initializing ORPO trainer...")
trainer = ORPOTrainer(
    model=model,
    args=config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
)

print("Starting ORPO training...")
# Plain strings: these previously carried an `f` prefix with no
# placeholders (ruff F541); output is unchanged.
print(" Model: stmasson/mistral-7b-n8n-workflows")
print(" Dataset: stmasson/n8n-workflows-thinking (ORPO)")
print(" Output: stmasson/mistral-7b-n8n-thinking-orpo")
trainer.train()

# Final push of the trained adapter/model to the Hub.
print("Pushing final model to Hub...")
trainer.push_to_hub()

# Flush and close the Trackio run before exiting.
trackio.finish()
print("Training complete!")
print("Model: https://huggingface.co/stmasson/mistral-7b-n8n-thinking-orpo")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")
|