stmasson committed on
Commit
634ff98
·
verified ·
1 Parent(s): 2cd4c7b

Upload scripts/train_orpo_n8n_thinking.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_orpo_n8n_thinking.py +128 -0
scripts/train_orpo_n8n_thinking.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.46.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "trackio",
#     "bitsandbytes",
# ]
# ///

"""
ORPO training for n8n workflows with chain-of-thought reasoning.

Fine-tunes stmasson/mistral-7b-n8n-workflows on the n8n-workflows-thinking dataset
to generate structured reasoning (<thinking>) before producing n8n workflow JSON.

ORPO (Odds Ratio Preference Optimization) combines SFT and preference learning
in a single training objective, making it more efficient than DPO for this use case.
"""

import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import ORPOConfig, ORPOTrainer

# Hub identifiers hoisted to one place so a fork only has to edit these lines
# (previously the same repo IDs were repeated in five different spots).
BASE_MODEL = "stmasson/mistral-7b-n8n-workflows"
DATASET_REPO = "stmasson/n8n-workflows-thinking"
OUTPUT_REPO = "stmasson/mistral-7b-n8n-thinking-orpo"


def _load_orpo_split(data_files: str):
    """Load one JSONL file of the ORPO dataset, dropping the unused metadata column.

    Each split lives in its own file inside the dataset repo, so every call uses
    ``split="train"`` together with an explicit ``data_files`` override.
    """
    ds = load_dataset(DATASET_REPO, data_files=data_files, split="train")
    # "metadata" is bookkeeping only; ORPO consumes the preference columns.
    return ds.remove_columns(["metadata"])


def main() -> None:
    """Fine-tune the base model with ORPO + LoRA and push the adapter to the Hub."""
    print("Loading n8n-workflows-thinking dataset (ORPO split)...")
    train_dataset = _load_orpo_split("data/orpo/train.jsonl")
    eval_dataset = _load_orpo_split("data/orpo/validation.jsonl")

    print(f"Train: {len(train_dataset)} examples")
    print(f"Eval: {len(eval_dataset)} examples")

    # LoRA configuration for efficient training on a 7B model.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=64,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )

    # ORPO training configuration.
    config = ORPOConfig(
        # Hub settings - CRITICAL for saving
        output_dir="mistral-7b-n8n-thinking-orpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
        hub_strategy="every_save",
        hub_private_repo=False,

        # ORPO-specific parameter
        beta=0.1,  # Weight for the odds ratio loss

        # Training parameters
        num_train_epochs=2,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,  # Effective batch size = 16
        learning_rate=5e-5,
        max_length=4096,  # Long context for workflows + thinking
        max_prompt_length=512,

        # Memory optimization
        gradient_checkpointing=True,
        bf16=True,

        # Logging & checkpointing
        logging_steps=10,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=3,

        # Evaluation
        eval_strategy="steps",
        eval_steps=200,

        # Optimization
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",  # Memory-efficient optimizer

        # Monitoring with Trackio
        report_to="trackio",
        project="n8n-thinking-training",
        run_name="mistral-7b-orpo-reasoning",
    )

    print("Initializing ORPO trainer...")
    trainer = ORPOTrainer(
        model=BASE_MODEL,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        args=config,
    )

    print("Starting ORPO training...")
    print(f"  Model: {BASE_MODEL}")
    print(f"  Dataset: {DATASET_REPO} (ORPO)")
    print(f"  Output: {OUTPUT_REPO}")

    try:
        trainer.train()
        print("Pushing final model to Hub...")
        trainer.push_to_hub()
    finally:
        # Close the Trackio run even if training or the push fails, so the
        # dashboard does not show a perpetually "running" experiment.
        trackio.finish()

    print("Training complete!")
    print(f"Model: https://huggingface.co/{OUTPUT_REPO}")
    print("Metrics: https://huggingface.co/spaces/stmasson/trackio")


if __name__ == "__main__":
    # Guard so that importing this module (e.g. for tooling) does not
    # immediately kick off a multi-hour training run.
    main()