# /// script # requires-python = ">=3.10" # dependencies = [ # "transformers>=4.45.0", # "trl>=0.12.0", # "peft>=0.13.0", # "datasets>=3.0.0", # "accelerate>=1.0.0", # "bitsandbytes>=0.44.0", # "wandb>=0.18.0", # "huggingface_hub>=0.26.0", # "torch>=2.4.0", # "einops>=0.8.0", # "sentencepiece>=0.2.0", # ] # [tool.uv] # extra-index-url = ["https://download.pytorch.org/whl/cu124"] # /// """ Script d'entraînement DPO pour le modèle n8n Expert. À exécuter APRÈS l'entraînement SFT. Usage sur HuggingFace Jobs: hf jobs uv run \ --script train_n8n_dpo.py \ --flavor h100x1 \ --name n8n-expert-dpo \ --timeout 12h \ --env BASE_MODEL=stmasson/n8n-expert-14b-sft Variables d'environnement: - HF_TOKEN: Token HuggingFace - BASE_MODEL: Modèle SFT à utiliser comme base - WANDB_API_KEY: (optionnel) Pour le tracking """ import os import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from peft import LoraConfig, PeftModel from trl import DPOTrainer, DPOConfig from huggingface_hub import login # ============================================================================ # CONFIGURATION # ============================================================================ # Modèle SFT fine-tuné BASE_MODEL = os.environ.get("BASE_MODEL", "stmasson/n8n-expert-14b-sft") ORIGINAL_MODEL = os.environ.get("ORIGINAL_MODEL", "Qwen/Qwen2.5-14B-Instruct") # Dataset DPO DATASET_REPO = "stmasson/n8n-workflows-thinking" DPO_FILE = "n8n_dpo_train.jsonl" # Output OUTPUT_DIR = "./n8n-expert-dpo" HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-dpo") # Hyperparamètres DPO NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "2")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "16")) LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-6")) BETA = float(os.environ.get("DPO_BETA", "0.1")) MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "8192")) MAX_PROMPT_LENGTH = int(os.environ.get("MAX_PROMPT_LENGTH", "2048")) # LoRA (plus léger pour DPO) LORA_R = int(os.environ.get("LORA_R", "32")) LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64")) # ============================================================================ # AUTHENTIFICATION # ============================================================================ print("=" * 60) print("ENTRAÎNEMENT DPO - N8N EXPERT") print("=" * 60) hf_token = os.environ.get("HF_TOKEN") if hf_token: login(token=hf_token) print("Authentifié sur HuggingFace") wandb_key = os.environ.get("WANDB_API_KEY") if wandb_key: import wandb wandb.login(key=wandb_key) report_to = "wandb" else: report_to = "none" # ============================================================================ # CHARGEMENT DU MODÈLE # ============================================================================ print(f"\nChargement du modèle SFT: {BASE_MODEL}") # Charger le modèle de base model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto", trust_remote_code=True, ) # Charger le modèle de référence (pour DPO) ref_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto", trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("Modèle chargé") # ============================================================================ # CONFIGURATION LORA # ============================================================================ print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}") lora_config = LoraConfig( r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) # ============================================================================ # CHARGEMENT DU DATASET DPO # ============================================================================ print(f"\nChargement du dataset DPO: {DATASET_REPO}") dataset = load_dataset( DATASET_REPO, data_files={"train": DPO_FILE}, split="train" ) print(f"Exemples DPO: {len(dataset)}") # Fonction de formatage pour DPO def format_dpo_example(example): """ Format attendu par DPOTrainer: - prompt: le prompt de l'utilisateur - chosen: la bonne réponse - rejected: la mauvaise réponse """ return { "prompt": example["prompt"], "chosen": example["chosen"], "rejected": example["rejected"], } # Le dataset devrait déjà être au bon format print("\nExemple de données DPO:") print(f"Prompt: {dataset[0]['prompt'][:200]}...") print(f"Chosen: {dataset[0]['chosen'][:200]}...") print(f"Rejected: {dataset[0]['rejected'][:200]}...") # ============================================================================ # CONFIGURATION D'ENTRAÎNEMENT DPO # ============================================================================ print(f"\nConfiguration DPO:") print(f" - Beta: {BETA}") print(f" - Epochs: {NUM_EPOCHS}") print(f" - Batch size: {BATCH_SIZE}") print(f" - Gradient accumulation: {GRAD_ACCUM}") print(f" - Learning rate: {LEARNING_RATE}") dpo_config = DPOConfig( output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS, per_device_train_batch_size=BATCH_SIZE, gradient_accumulation_steps=GRAD_ACCUM, learning_rate=LEARNING_RATE, beta=BETA, lr_scheduler_type="cosine", warmup_ratio=0.1, bf16=True, logging_steps=10, save_strategy="steps", save_steps=200, save_total_limit=3, max_length=MAX_LENGTH, max_prompt_length=MAX_PROMPT_LENGTH, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, report_to=report_to, run_name="n8n-expert-dpo", hub_model_id=HF_REPO if hf_token else None, push_to_hub=bool(hf_token), ) # ============================================================================ # ENTRAÎNEMENT DPO # ============================================================================ print("\nInitialisation du DPO trainer...") trainer = DPOTrainer( model=model, ref_model=ref_model, args=dpo_config, train_dataset=dataset, peft_config=lora_config, tokenizer=tokenizer, ) print("\n" + "=" * 60) print("DÉMARRAGE DE L'ENTRAÎNEMENT DPO") print("=" * 60) trainer.train() # ============================================================================ # SAUVEGARDE # ============================================================================ print("\nSauvegarde du modèle...") trainer.save_model(f"{OUTPUT_DIR}/final") if hf_token: print(f"Push vers {HF_REPO}...") trainer.push_to_hub() print(f"Modèle disponible sur: https://huggingface.co/{HF_REPO}") print("\n" + "=" * 60) print("ENTRAÎNEMENT DPO TERMINÉ") print("=" * 60)