SwapnilPatil28 commited on
Commit
c575679
·
verified ·
1 Parent(s): eb2d131

Upload train_trl.py

Browse files
Files changed (1) hide show
  1. train_trl.py +7 -7
train_trl.py CHANGED
@@ -68,7 +68,8 @@ def rollout(policy_name: str, task_name: str, collect_dataset: bool = False):
68
  records.append(
69
  {
70
  "prompt": obs_to_prompt(result.observation),
71
- "response": action_to_json(action),
 
72
  }
73
  )
74
 
@@ -114,26 +115,25 @@ def run_trl_sft(dataset: Dataset) -> None:
114
 
115
  model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
116
 
117
- def formatting_func(example):
118
- return f"<|user|>\n{example['prompt']}\n<|assistant|>\n{example['response']}"
119
-
120
  config = SFTConfig(
121
  output_dir="outputs/sft_run",
122
  per_device_train_batch_size=1,
123
  gradient_accumulation_steps=2,
124
  learning_rate=2e-5,
125
  num_train_epochs=1,
126
- max_seq_length=768,
127
  logging_steps=5,
128
  save_strategy="no",
129
- report_to=[],
130
  )
131
 
 
132
  trainer = SFTTrainer(
133
  model=model,
134
  args=config,
135
  train_dataset=dataset,
136
- formatting_func=formatting_func,
137
  )
138
  trainer.train()
139
 
 
68
  records.append(
69
  {
70
  "prompt": obs_to_prompt(result.observation),
71
+ # TRL 0.20+ expects `completion` (not `response`) for prompt/completion SFT.
72
+ "completion": action_to_json(action),
73
  }
74
  )
75
 
 
115
 
116
  model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
117
 
118
+ # TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
 
 
119
  config = SFTConfig(
120
  output_dir="outputs/sft_run",
121
  per_device_train_batch_size=1,
122
  gradient_accumulation_steps=2,
123
  learning_rate=2e-5,
124
  num_train_epochs=1,
125
+ max_length=768,
126
  logging_steps=5,
127
  save_strategy="no",
128
+ report_to="none",
129
  )
130
 
131
+ # Use prompt + completion columns; pass tokenizer as processing_class (TRL 0.20+).
132
  trainer = SFTTrainer(
133
  model=model,
134
  args=config,
135
  train_dataset=dataset,
136
+ processing_class=tokenizer,
137
  )
138
  trainer.train()
139