Upload train_trl.py
Browse files- train_trl.py +54 -2
train_trl.py
CHANGED
|
@@ -95,6 +95,55 @@ def build_training_dataset(episodes_per_task: int = 4) -> Dataset:
|
|
| 95 |
return Dataset.from_list(all_rows)
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def run_trl_sft(dataset: Dataset) -> None:
|
| 99 |
"""
|
| 100 |
Minimal TRL script.
|
|
@@ -115,6 +164,9 @@ def run_trl_sft(dataset: Dataset) -> None:
|
|
| 115 |
|
| 116 |
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
| 118 |
# TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
|
| 119 |
config = SFTConfig(
|
| 120 |
output_dir="outputs/sft_run",
|
|
@@ -123,16 +175,16 @@ def run_trl_sft(dataset: Dataset) -> None:
|
|
| 123 |
learning_rate=2e-5,
|
| 124 |
num_train_epochs=1,
|
| 125 |
max_length=768,
|
|
|
|
| 126 |
logging_steps=5,
|
| 127 |
save_strategy="no",
|
| 128 |
report_to="none",
|
| 129 |
)
|
| 130 |
|
| 131 |
-
# Use prompt + completion columns; pass tokenizer as processing_class (TRL 0.20+).
|
| 132 |
trainer = SFTTrainer(
|
| 133 |
model=model,
|
| 134 |
args=config,
|
| 135 |
-
train_dataset=
|
| 136 |
processing_class=tokenizer,
|
| 137 |
)
|
| 138 |
trainer.train()
|
|
|
|
| 95 |
return Dataset.from_list(all_rows)
|
| 96 |
|
| 97 |
|
| 98 |
+
def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
|
| 99 |
+
"""
|
| 100 |
+
TRL 0.20+ tokenization can fail or mis-detect `prompt`/`completion` (e.g. old `response` key, or
|
| 101 |
+
`formatting_func` that drops columns). A single `text` column + `dataset_text_field` uses the
|
| 102 |
+
standard LM code path in SFT and is the most reliable across TRL versions.
|
| 103 |
+
"""
|
| 104 |
+
from transformers import PreTrainedTokenizerBase
|
| 105 |
+
|
| 106 |
+
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
| 107 |
+
return dataset
|
| 108 |
+
|
| 109 |
+
# Accept either column name (old notebooks / stale clones)
|
| 110 |
+
cols = set(dataset.column_names)
|
| 111 |
+
if "completion" not in cols and "response" in cols:
|
| 112 |
+
dataset = dataset.rename_column("response", "completion")
|
| 113 |
+
if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
|
| 114 |
+
raise ValueError(
|
| 115 |
+
f"Expected columns 'prompt' and 'completion' (or 'response'). Got: {dataset.column_names}"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
has_template = bool(getattr(tokenizer, "chat_template", None))
|
| 119 |
+
|
| 120 |
+
def to_text_batched(examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
|
| 121 |
+
out: List[str] = []
|
| 122 |
+
for prompt, completion in zip(examples["prompt"], examples["completion"]):
|
| 123 |
+
if has_template:
|
| 124 |
+
messages = [
|
| 125 |
+
{"role": "user", "content": prompt},
|
| 126 |
+
{"role": "assistant", "content": completion},
|
| 127 |
+
]
|
| 128 |
+
out.append(
|
| 129 |
+
tokenizer.apply_chat_template(
|
| 130 |
+
messages,
|
| 131 |
+
tokenize=False,
|
| 132 |
+
add_generation_prompt=False,
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
else:
|
| 136 |
+
out.append(f"User: {prompt}\n\nAssistant: {completion}")
|
| 137 |
+
return {"text": out}
|
| 138 |
+
|
| 139 |
+
to_drop = [c for c in dataset.column_names if c != "text"]
|
| 140 |
+
return dataset.map(
|
| 141 |
+
to_text_batched,
|
| 142 |
+
batched=True,
|
| 143 |
+
remove_columns=to_drop,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
def run_trl_sft(dataset: Dataset) -> None:
|
| 148 |
"""
|
| 149 |
Minimal TRL script.
|
|
|
|
| 164 |
|
| 165 |
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
|
| 166 |
|
| 167 |
+
# Single `text` column — avoids TRL's prompt+completion tokenize path KeyErrors across versions.
|
| 168 |
+
train_ds = _dataset_to_sft_text_column(dataset, tokenizer)
|
| 169 |
+
|
| 170 |
# TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
|
| 171 |
config = SFTConfig(
|
| 172 |
output_dir="outputs/sft_run",
|
|
|
|
| 175 |
learning_rate=2e-5,
|
| 176 |
num_train_epochs=1,
|
| 177 |
max_length=768,
|
| 178 |
+
dataset_text_field="text",
|
| 179 |
logging_steps=5,
|
| 180 |
save_strategy="no",
|
| 181 |
report_to="none",
|
| 182 |
)
|
| 183 |
|
|
|
|
| 184 |
trainer = SFTTrainer(
|
| 185 |
model=model,
|
| 186 |
args=config,
|
| 187 |
+
train_dataset=train_ds,
|
| 188 |
processing_class=tokenizer,
|
| 189 |
)
|
| 190 |
trainer.train()
|