stmasson committed on
Commit
9de98a3
·
verified ·
1 Parent(s): 634ff98

Upload scripts/train_orpo_n8n_thinking.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_orpo_n8n_thinking.py +17 -1
scripts/train_orpo_n8n_thinking.py CHANGED
@@ -23,6 +23,7 @@ in a single training objective, making it more efficient than DPO for this use c
23
  import trackio
24
  from datasets import load_dataset
25
  from peft import LoraConfig
 
26
  from trl import ORPOTrainer, ORPOConfig
27
 
28
 
@@ -46,6 +47,20 @@ print(f"Eval: {len(eval_dataset)} examples")
46
  train_dataset = train_dataset.remove_columns(["metadata"])
47
  eval_dataset = eval_dataset.remove_columns(["metadata"])
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # LoRA configuration for efficient training on 7B model
50
  lora_config = LoraConfig(
51
  r=32,
@@ -103,7 +118,8 @@ config = ORPOConfig(
103
  # Initialize trainer
104
  print("Initializing ORPO trainer...")
105
  trainer = ORPOTrainer(
106
- model="stmasson/mistral-7b-n8n-workflows",
 
107
  train_dataset=train_dataset,
108
  eval_dataset=eval_dataset,
109
  peft_config=lora_config,
 
23
  import trackio
24
  from datasets import load_dataset
25
  from peft import LoraConfig
26
+ from transformers import AutoTokenizer, AutoModelForCausalLM
27
  from trl import ORPOTrainer, ORPOConfig
28
 
29
 
 
47
  train_dataset = train_dataset.remove_columns(["metadata"])
48
  eval_dataset = eval_dataset.remove_columns(["metadata"])
49
 
50
+ # Load model and tokenizer
51
+ MODEL_NAME = "stmasson/mistral-7b-n8n-workflows"
52
+ print(f"Loading tokenizer from {MODEL_NAME}...")
53
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
54
+ if tokenizer.pad_token is None:
55
+ tokenizer.pad_token = tokenizer.eos_token
56
+
57
+ print(f"Loading model from {MODEL_NAME}...")
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ MODEL_NAME,
60
+ torch_dtype="auto",
61
+ device_map="auto",
62
+ )
63
+
64
  # LoRA configuration for efficient training on 7B model
65
  lora_config = LoraConfig(
66
  r=32,
 
118
  # Initialize trainer
119
  print("Initializing ORPO trainer...")
120
  trainer = ORPOTrainer(
121
+ model=model,
122
+ processing_class=tokenizer,
123
  train_dataset=train_dataset,
124
  eval_dataset=eval_dataset,
125
  peft_config=lora_config,