miyuki2026 commited on
Commit
e96faee
·
1 Parent(s): 3b65b42
examples/tutorials/dpo/ultrachat-sft/step_2_train_sft_model_ddp.py CHANGED
@@ -34,7 +34,7 @@ from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
34
 
35
  def get_args():
36
  parser = argparse.ArgumentParser()
37
- parser.add_argument("--local_rank", type=int, default=0) # torchrun会自动传递这个参数
38
 
39
  parser.add_argument(
40
  "--model_name",
 
34
 
35
  def get_args():
36
  parser = argparse.ArgumentParser()
37
+ parser.add_argument("--local_rank", type=int, default=-1) # torchrun会自动传递这个参数
38
 
39
  parser.add_argument(
40
  "--model_name",
examples/tutorials/dpo/ultrafeedback-dpo/step_2_train_dpo_model_ddp_qlora.py CHANGED
@@ -11,6 +11,16 @@ torchrun --nproc_per_node=2 step_2_train_dpo_model_ddp_qlora.py
11
  DPO本来就是风格微调,用LoRA 训练更合理,更科学。
12
 
13
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
  import argparse
16
  import os
@@ -69,7 +79,8 @@ def get_args():
69
  type=str
70
  ),
71
 
72
- parser.add_argument("--beta", default=0.5, type=float),
 
73
 
74
  parser.add_argument(
75
  "--num_workers",
@@ -166,8 +177,8 @@ def main():
166
  ref_model = prepare_model_for_kbit_training(ref_model)
167
 
168
  lora_config = LoraConfig(
169
- r=16,
170
- lora_alpha=32,
171
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
172
  lora_dropout=0.1,
173
  bias="none",
@@ -227,7 +238,7 @@ def main():
227
  report_to="none",
228
  max_length=1024 if debug_mode else 2048, # prompt + chosen 的最大长度
229
  # DPO 特定参数
230
- beta=args.beta, # DPO 的温度参数,控制对 preference 的置信度
231
  remove_unused_columns=False,
232
  dataloader_pin_memory=False,
233
 
 
11
  DPO本来就是风格微调,用LoRA 训练更合理,更科学。
12
 
13
 
14
+ ----------
15
+
16
+ nohup torchrun --nproc_per_node=2 step_2_train_dpo_model_ddp_qlora.py \
17
+ --dpo_beta 0.5 \
18
+ --lora_rank 32 \
19
+ &
20
+
21
+ kill -9 `ps -aef | grep 'step_2_train_dpo_model_ddp_qlora.py' | grep -v grep | awk '{print $2}'`
22
+
23
+
24
  """
25
  import argparse
26
  import os
 
79
  type=str
80
  ),
81
 
82
+ parser.add_argument("--dpo_beta", default=0.5, type=float),
83
+ parser.add_argument("--lora_rank", default=32, type=int),
84
 
85
  parser.add_argument(
86
  "--num_workers",
 
177
  ref_model = prepare_model_for_kbit_training(ref_model)
178
 
179
  lora_config = LoraConfig(
180
+ r=args.lora_rank,
181
+ lora_alpha=args.lora_rank * 2,
182
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
183
  lora_dropout=0.1,
184
  bias="none",
 
238
  report_to="none",
239
  max_length=1024 if debug_mode else 2048, # prompt + chosen 的最大长度
240
  # DPO 特定参数
241
+ beta=args.dpo_beta, # DPO 的温度参数,控制对 preference 的置信度
242
  remove_unused_columns=False,
243
  dataloader_pin_memory=False,
244