Upload train_trl.py
Browse files
- train_trl.py +7 -7
train_trl.py
CHANGED
|
@@ -68,7 +68,8 @@ def rollout(policy_name: str, task_name: str, collect_dataset: bool = False):
|
|
| 68 |
records.append(
|
| 69 |
{
|
| 70 |
"prompt": obs_to_prompt(result.observation),
|
| 71 |
-
|
|
|
|
| 72 |
}
|
| 73 |
)
|
| 74 |
|
|
@@ -114,26 +115,25 @@ def run_trl_sft(dataset: Dataset) -> None:
|
|
| 114 |
|
| 115 |
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
return f"<|user|>\n{example['prompt']}\n<|assistant|>\n{example['response']}"
|
| 119 |
-
|
| 120 |
config = SFTConfig(
|
| 121 |
output_dir="outputs/sft_run",
|
| 122 |
per_device_train_batch_size=1,
|
| 123 |
gradient_accumulation_steps=2,
|
| 124 |
learning_rate=2e-5,
|
| 125 |
num_train_epochs=1,
|
| 126 |
-
|
| 127 |
logging_steps=5,
|
| 128 |
save_strategy="no",
|
| 129 |
-
report_to=
|
| 130 |
)
|
| 131 |
|
|
|
|
| 132 |
trainer = SFTTrainer(
|
| 133 |
model=model,
|
| 134 |
args=config,
|
| 135 |
train_dataset=dataset,
|
| 136 |
-
|
| 137 |
)
|
| 138 |
trainer.train()
|
| 139 |
|
|
|
|
| 68 |
records.append(
|
| 69 |
{
|
| 70 |
"prompt": obs_to_prompt(result.observation),
|
| 71 |
+
# TRL 0.20+ expects `completion` (not `response`) for prompt/completion SFT.
|
| 72 |
+
"completion": action_to_json(action),
|
| 73 |
}
|
| 74 |
)
|
| 75 |
|
|
|
|
| 115 |
|
| 116 |
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
|
| 117 |
|
| 118 |
+
# TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
|
|
|
|
|
|
|
| 119 |
config = SFTConfig(
|
| 120 |
output_dir="outputs/sft_run",
|
| 121 |
per_device_train_batch_size=1,
|
| 122 |
gradient_accumulation_steps=2,
|
| 123 |
learning_rate=2e-5,
|
| 124 |
num_train_epochs=1,
|
| 125 |
+
max_length=768,
|
| 126 |
logging_steps=5,
|
| 127 |
save_strategy="no",
|
| 128 |
+
report_to="none",
|
| 129 |
)
|
| 130 |
|
| 131 |
+
# Use prompt + completion columns; pass tokenizer as processing_class (TRL 0.20+).
|
| 132 |
trainer = SFTTrainer(
|
| 133 |
model=model,
|
| 134 |
args=config,
|
| 135 |
train_dataset=dataset,
|
| 136 |
+
processing_class=tokenizer,
|
| 137 |
)
|
| 138 |
trainer.train()
|
| 139 |
|