training-scripts / scripts /train_orpo_n8n_thinking.py
stmasson's picture
Upload scripts/train_orpo_n8n_thinking.py with huggingface_hub
7dbc984 verified
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "transformers>=4.46.0",
# "accelerate>=0.24.0",
# "peft>=0.7.0",
# "trackio",
# "bitsandbytes",
# "sentencepiece",
# "protobuf",
# ]
# ///
"""
ORPO training for n8n workflows with chain-of-thought reasoning.
Fine-tunes stmasson/mistral-7b-n8n-workflows on the n8n-workflows-thinking dataset
to generate structured reasoning (<thinking>) before producing n8n workflow JSON.
ORPO (Odds Ratio Preference Optimization) combines SFT and preference learning
in a single training objective, making it more efficient than DPO for this use case.
"""
import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import ORPOTrainer, ORPOConfig
# Fetch the ORPO preference data (train + validation files) from the Hub
print("Loading n8n-workflows-thinking dataset (ORPO split)...")
_DATASET_REPO = "stmasson/n8n-workflows-thinking"
# NOTE: split="train" is used for the validation file as well; a file passed
# through `data_files` on its own is exposed under the default "train" split.
train_dataset = load_dataset(
    _DATASET_REPO,
    data_files="data/orpo/train.jsonl",
    split="train",
)
eval_dataset = load_dataset(
    _DATASET_REPO,
    data_files="data/orpo/validation.jsonl",
    split="train",
)
# Report how many examples each split contains
for _label, _split in (("Train", train_dataset), ("Eval", eval_dataset)):
    print(f"{_label}: {len(_split)} examples")
# Drop the "metadata" column from both splits (not needed for training)
train_dataset, eval_dataset = (
    _split.remove_columns(["metadata"]) for _split in (train_dataset, eval_dataset)
)
# Load model and tokenizer
MODEL_NAME = "stmasson/mistral-7b-n8n-workflows"
print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the EOS token for padding when the tokenizer does not define a pad
# token. The assignment must be indented under the `if`; without the indent
# the script does not parse.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 4-bit quantization config to reduce memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # same precision as bf16=True in the training config
    bnb_4bit_use_double_quant=True,
)
print(f"Loading model from {MODEL_NAME} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",  # Use scaled dot-product attention
)
# LoRA adapter setup so the 7B model can be fine-tuned efficiently
_ATTENTION_PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj"]
_MLP_PROJECTIONS = ["gate_proj", "up_proj", "down_proj"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapters are attached to every projection listed in the two groups above
    target_modules=_ATTENTION_PROJECTIONS + _MLP_PROJECTIONS,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
)
# ORPO training configuration, assembled from themed groups of settings.
# Every key below is passed to ORPOConfig as a keyword argument, unchanged.

# Hub settings - CRITICAL for saving
_hub_settings = {
    "output_dir": "mistral-7b-n8n-thinking-orpo",
    "push_to_hub": True,
    "hub_model_id": "stmasson/mistral-7b-n8n-thinking-orpo",
    "hub_strategy": "every_save",
    "hub_private_repo": False,
}
# ORPO-specific parameter
_orpo_settings = {
    "beta": 0.1,  # Weight for the odds ratio loss
}
# Core training parameters
_training_settings = {
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,  # 1 x 32 = effective batch size of 32 per device
    "learning_rate": 5e-5,
    "max_length": 2048,  # Reduced for memory
    "max_prompt_length": 256,
}
# Memory optimization
_memory_settings = {
    "gradient_checkpointing": True,
    "bf16": True,
}
# Logging, checkpointing and evaluation all run on a 200-step cadence
_checkpoint_settings = {
    "logging_steps": 10,
    "save_strategy": "steps",
    "save_steps": 200,
    "save_total_limit": 3,
    "eval_strategy": "steps",
    "eval_steps": 200,
}
# Optimizer and learning-rate schedule
_optimizer_settings = {
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "adamw_8bit",  # Memory-efficient optimizer
}
# Monitoring with Trackio
_tracking_settings = {
    "report_to": "trackio",
    "project": "n8n-thinking-training",
    "run_name": "mistral-7b-orpo-reasoning",
}
config = ORPOConfig(
    **_hub_settings,
    **_orpo_settings,
    **_training_settings,
    **_memory_settings,
    **_checkpoint_settings,
    **_optimizer_settings,
    **_tracking_settings,
)
# Initialize trainer: wires together the quantized base model, tokenizer,
# both dataset splits, the LoRA adapter config and the ORPO training arguments.
print("Initializing ORPO trainer...")
trainer = ORPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=config,
)
# Run summary. The model id and the output repo are read from MODEL_NAME and
# config.hub_model_id so the banner cannot drift from what is actually used.
print("Starting ORPO training...")
print(f" Model: {MODEL_NAME}")
print(" Dataset: stmasson/n8n-workflows-thinking (ORPO)")
print(f" Output: {config.hub_model_id}")
trainer.train()
# Explicit push of the end-of-training state to the Hub
print("Pushing final model to Hub...")
trainer.push_to_hub()
# Finish Trackio tracking
trackio.finish()
print("Training complete!")
print(f"Model: https://huggingface.co/{config.hub_model_id}")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")