training-scripts / scripts /train_n8n_dpo.py
stmasson's picture
Upload scripts/train_n8n_dpo.py with huggingface_hub
70f98a4 verified
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "trl>=0.12.0",
# "peft>=0.13.0",
# "datasets>=3.0.0",
# "accelerate>=1.0.0",
# "bitsandbytes>=0.44.0",
# "wandb>=0.18.0",
# "huggingface_hub>=0.26.0",
# "torch>=2.4.0",
# "einops>=0.8.0",
# "sentencepiece>=0.2.0",
# ]
# [tool.uv]
# extra-index-url = ["https://download.pytorch.org/whl/cu124"]
# ///
"""
Script d'entraînement DPO pour le modèle n8n Expert.
À exécuter APRÈS l'entraînement SFT.
Usage sur HuggingFace Jobs:
hf jobs uv run \
--script train_n8n_dpo.py \
--flavor h100x1 \
--name n8n-expert-dpo \
--timeout 12h \
--env BASE_MODEL=stmasson/n8n-expert-14b-sft
Variables d'environnement:
- HF_TOKEN: Token HuggingFace
- BASE_MODEL: Modèle SFT à utiliser comme base
- WANDB_API_KEY: (optionnel) Pour le tracking
"""
import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel
from trl import DPOTrainer, DPOConfig
from huggingface_hub import login
# ============================================================================
# CONFIGURATION
# ============================================================================
# Modèle SFT fine-tuné
BASE_MODEL = os.environ.get("BASE_MODEL", "stmasson/n8n-expert-14b-sft")
ORIGINAL_MODEL = os.environ.get("ORIGINAL_MODEL", "Qwen/Qwen2.5-14B-Instruct")
# Dataset DPO
DATASET_REPO = "stmasson/n8n-workflows-thinking"
DPO_FILE = "n8n_dpo_train.jsonl"
# Output
OUTPUT_DIR = "./n8n-expert-dpo"
HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-dpo")
# Hyperparamètres DPO
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "2"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-6"))
BETA = float(os.environ.get("DPO_BETA", "0.1"))
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "8192"))
MAX_PROMPT_LENGTH = int(os.environ.get("MAX_PROMPT_LENGTH", "2048"))
# LoRA (plus léger pour DPO)
LORA_R = int(os.environ.get("LORA_R", "32"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64"))
# ============================================================================
# AUTHENTIFICATION
# ============================================================================
print("=" * 60)
print("ENTRAÎNEMENT DPO - N8N EXPERT")
print("=" * 60)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
print("Authentifié sur HuggingFace")
wandb_key = os.environ.get("WANDB_API_KEY")
if wandb_key:
import wandb
wandb.login(key=wandb_key)
report_to = "wandb"
else:
report_to = "none"
# ============================================================================
# CHARGEMENT DU MODÈLE
# ============================================================================
print(f"\nChargement du modèle SFT: {BASE_MODEL}")
# Charger le modèle de base
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto",
trust_remote_code=True,
)
# Charger le modèle de référence (pour DPO)
ref_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Modèle chargé")
# ============================================================================
# CONFIGURATION LORA
# ============================================================================
print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}")
lora_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# ============================================================================
# CHARGEMENT DU DATASET DPO
# ============================================================================
print(f"\nChargement du dataset DPO: {DATASET_REPO}")
dataset = load_dataset(
DATASET_REPO,
data_files={"train": DPO_FILE},
split="train"
)
print(f"Exemples DPO: {len(dataset)}")
# Fonction de formatage pour DPO
def format_dpo_example(example):
"""
Format attendu par DPOTrainer:
- prompt: le prompt de l'utilisateur
- chosen: la bonne réponse
- rejected: la mauvaise réponse
"""
return {
"prompt": example["prompt"],
"chosen": example["chosen"],
"rejected": example["rejected"],
}
# Le dataset devrait déjà être au bon format
print("\nExemple de données DPO:")
print(f"Prompt: {dataset[0]['prompt'][:200]}...")
print(f"Chosen: {dataset[0]['chosen'][:200]}...")
print(f"Rejected: {dataset[0]['rejected'][:200]}...")
# ============================================================================
# CONFIGURATION D'ENTRAÎNEMENT DPO
# ============================================================================
print(f"\nConfiguration DPO:")
print(f" - Beta: {BETA}")
print(f" - Epochs: {NUM_EPOCHS}")
print(f" - Batch size: {BATCH_SIZE}")
print(f" - Gradient accumulation: {GRAD_ACCUM}")
print(f" - Learning rate: {LEARNING_RATE}")
dpo_config = DPOConfig(
output_dir=OUTPUT_DIR,
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LEARNING_RATE,
beta=BETA,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
bf16=True,
logging_steps=10,
save_strategy="steps",
save_steps=200,
save_total_limit=3,
max_length=MAX_LENGTH,
max_prompt_length=MAX_PROMPT_LENGTH,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
report_to=report_to,
run_name="n8n-expert-dpo",
hub_model_id=HF_REPO if hf_token else None,
push_to_hub=bool(hf_token),
)
# ============================================================================
# ENTRAÎNEMENT DPO
# ============================================================================
print("\nInitialisation du DPO trainer...")
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=dpo_config,
train_dataset=dataset,
peft_config=lora_config,
tokenizer=tokenizer,
)
print("\n" + "=" * 60)
print("DÉMARRAGE DE L'ENTRAÎNEMENT DPO")
print("=" * 60)
trainer.train()
# ============================================================================
# SAUVEGARDE
# ============================================================================
print("\nSauvegarde du modèle...")
trainer.save_model(f"{OUTPUT_DIR}/final")
if hf_token:
print(f"Push vers {HF_REPO}...")
trainer.push_to_hub()
print(f"Modèle disponible sur: https://huggingface.co/{HF_REPO}")
print("\n" + "=" * 60)
print("ENTRAÎNEMENT DPO TERMINÉ")
print("=" * 60)