stmasson committed on
Commit
a18220e
·
verified ·
1 Parent(s): 39f3734

Upload scripts/train_orpo_n8n_thinking.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_orpo_n8n_thinking.py +17 -6
scripts/train_orpo_n8n_thinking.py CHANGED
@@ -9,6 +9,7 @@
9
  # "bitsandbytes",
10
  # "sentencepiece",
11
  # "protobuf",
 
12
  # ]
13
  # ///
14
 
@@ -23,9 +24,10 @@ in a single training objective, making it more efficient than DPO for this use c
23
  """
24
 
25
  import trackio
 
26
  from datasets import load_dataset
27
  from peft import LoraConfig
28
- from transformers import AutoTokenizer, AutoModelForCausalLM
29
  from trl import ORPOTrainer, ORPOConfig
30
 
31
 
@@ -56,11 +58,20 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
56
  if tokenizer.pad_token is None:
57
  tokenizer.pad_token = tokenizer.eos_token
58
 
59
- print(f"Loading model from {MODEL_NAME}...")
 
 
 
 
 
 
 
 
60
  model = AutoModelForCausalLM.from_pretrained(
61
  MODEL_NAME,
62
- torch_dtype="auto",
63
  device_map="auto",
 
64
  )
65
 
66
  # LoRA configuration for efficient training on 7B model
@@ -87,10 +98,10 @@ config = ORPOConfig(
87
  # Training parameters
88
  num_train_epochs=2,
89
  per_device_train_batch_size=1,
90
- gradient_accumulation_steps=16, # Effective batch size = 16
91
  learning_rate=5e-5,
92
- max_length=4096, # Long context for workflows + thinking
93
- max_prompt_length=512,
94
 
95
  # Memory optimization
96
  gradient_checkpointing=True,
 
9
  # "bitsandbytes",
10
  # "sentencepiece",
11
  # "protobuf",
12
+ # "flash-attn",
13
  # ]
14
  # ///
15
 
 
24
  """
25
 
26
  import trackio
27
+ import torch
28
  from datasets import load_dataset
29
  from peft import LoraConfig
30
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
31
  from trl import ORPOTrainer, ORPOConfig
32
 
33
 
 
58
  if tokenizer.pad_token is None:
59
  tokenizer.pad_token = tokenizer.eos_token
60
 
61
+ # 4-bit quantization config to reduce memory
62
+ bnb_config = BitsAndBytesConfig(
63
+ load_in_4bit=True,
64
+ bnb_4bit_quant_type="nf4",
65
+ bnb_4bit_compute_dtype=torch.bfloat16,
66
+ bnb_4bit_use_double_quant=True,
67
+ )
68
+
69
+ print(f"Loading model from {MODEL_NAME} with 4-bit quantization...")
70
  model = AutoModelForCausalLM.from_pretrained(
71
  MODEL_NAME,
72
+ quantization_config=bnb_config,
73
  device_map="auto",
74
+ attn_implementation="flash_attention_2",
75
  )
76
 
77
  # LoRA configuration for efficient training on 7B model
 
98
  # Training parameters
99
  num_train_epochs=2,
100
  per_device_train_batch_size=1,
101
+ gradient_accumulation_steps=32, # Effective batch size = 32
102
  learning_rate=5e-5,
103
+ max_length=2048, # Reduced for memory
104
+ max_prompt_length=256,
105
 
106
  # Memory optimization
107
  gradient_checkpointing=True,