File size: 4,343 Bytes
634ff98
 
 
 
 
 
 
 
 
39f3734
 
634ff98
 
 
 
 
 
 
 
 
 
 
 
 
 
a18220e
634ff98
 
a18220e
634ff98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de98a3
 
 
 
 
 
 
a18220e
 
 
 
 
 
 
 
 
9de98a3
 
a18220e
9de98a3
7dbc984
9de98a3
 
634ff98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a18220e
634ff98
a18220e
 
634ff98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de98a3
 
634ff98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.46.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "trackio",
#     "bitsandbytes",
#     "sentencepiece",
#     "protobuf",
# ]
# ///

"""
ORPO training for n8n workflows with chain-of-thought reasoning.

Fine-tunes stmasson/mistral-7b-n8n-workflows on the n8n-workflows-thinking dataset
to generate structured reasoning (<thinking>) before producing n8n workflow JSON.

ORPO (Odds Ratio Preference Optimization) combines SFT and preference learning
in a single training objective, making it more efficient than DPO for this use case.
"""

import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import ORPOTrainer, ORPOConfig


# ---------------------------------------------------------------------------
# Dataset: preference pairs for ORPO. Both splits live in one Hub repo as
# JSONL files, so each file is loaded via data_files with split="train"
# (the file itself constitutes the whole split).
# ---------------------------------------------------------------------------
print("Loading n8n-workflows-thinking dataset (ORPO split)...")

_DATASET_REPO = "stmasson/n8n-workflows-thinking"
train_dataset = load_dataset(
    _DATASET_REPO, data_files="data/orpo/train.jsonl", split="train"
)
eval_dataset = load_dataset(
    _DATASET_REPO, data_files="data/orpo/validation.jsonl", split="train"
)

print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")

# Drop the bookkeeping column — the trainer only consumes the preference
# fields, not per-example metadata.
train_dataset = train_dataset.remove_columns(["metadata"])
eval_dataset = eval_dataset.remove_columns(["metadata"])

# ---------------------------------------------------------------------------
# Base model + tokenizer. The base checkpoint is an SFT model already tuned
# on n8n workflows; ORPO refines it with preference pairs on top.
# ---------------------------------------------------------------------------
MODEL_NAME = "stmasson/mistral-7b-n8n-workflows"

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Mistral tokenizers ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NF4 double-quantized 4-bit weights with bfloat16 compute: shrinks the 7B
# model's memory footprint while keeping the math in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading model from {MODEL_NAME} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",  # PyTorch scaled dot-product attention kernel
)

# LoRA adapters on every attention and MLP projection of the Mistral stack.
# r=32 with alpha=64 keeps the effective scaling (alpha / r) at 2.0.
_LORA_TARGETS = [
    "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
    "gate_proj", "up_proj", "down_proj",     # MLP projections
]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=_LORA_TARGETS,
)

# ---------------------------------------------------------------------------
# ORPO training configuration. ORPO folds SFT and preference optimization
# into one objective; beta weighs the odds-ratio term against the NLL term.
# ---------------------------------------------------------------------------
config = ORPOConfig(
    # Where checkpoints land, locally and on the Hub (pushed every save).
    output_dir="mistral-7b-n8n-thinking-orpo",
    hub_model_id="stmasson/mistral-7b-n8n-thinking-orpo",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_private_repo=False,

    # Odds-ratio loss weight (ORPO-specific).
    beta=0.1,

    # Schedule: 2 epochs, effective batch of 32 (1 per device x 32 accum),
    # cosine decay from 5e-5 with a 10% warmup.
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # Sequence budget — trimmed to keep activations in memory.
    max_length=2048,
    max_prompt_length=256,

    # Memory: recompute activations, bf16 autocast, 8-bit AdamW states.
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_8bit",

    # Checkpoint/eval cadence: every 200 steps, keep the last 3 checkpoints.
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=200,

    # Metric tracking via Trackio.
    report_to="trackio",
    project="n8n-thinking-training",
    run_name="mistral-7b-orpo-reasoning",
)

# ---------------------------------------------------------------------------
# Build the trainer, run training, push the result to the Hub, and close out
# the Trackio run.
# ---------------------------------------------------------------------------
print("Initializing ORPO trainer...")
trainer = ORPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,  # trainer wraps the quantized model with LoRA
    args=config,
)

print("Starting ORPO training...")
# The originals used f-strings with no placeholders (ruff F541); reference
# the actual constants so this banner cannot drift from the configuration.
print(f"  Model: {MODEL_NAME}")
print("  Dataset: stmasson/n8n-workflows-thinking (ORPO)")
print(f"  Output: {config.hub_model_id}")

trainer.train()

print("Pushing final model to Hub...")
trainer.push_to_hub()

# End the Trackio run so buffered metrics flush to the dashboard.
trackio.finish()

print("Training complete!")
print("Model: https://huggingface.co/stmasson/mistral-7b-n8n-thinking-orpo")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")