SwapnilPatil28 commited on
Commit
1d6a71e
·
verified ·
1 Parent(s): c575679

Upload train_trl.py

Browse files
Files changed (1) hide show
  1. train_trl.py +54 -2
train_trl.py CHANGED
@@ -95,6 +95,55 @@ def build_training_dataset(episodes_per_task: int = 4) -> Dataset:
95
  return Dataset.from_list(all_rows)
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def run_trl_sft(dataset: Dataset) -> None:
99
  """
100
  Minimal TRL script.
@@ -115,6 +164,9 @@ def run_trl_sft(dataset: Dataset) -> None:
115
 
116
  model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
117
 
 
 
 
118
  # TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
119
  config = SFTConfig(
120
  output_dir="outputs/sft_run",
@@ -123,16 +175,16 @@ def run_trl_sft(dataset: Dataset) -> None:
123
  learning_rate=2e-5,
124
  num_train_epochs=1,
125
  max_length=768,
 
126
  logging_steps=5,
127
  save_strategy="no",
128
  report_to="none",
129
  )
130
 
131
- # Use prompt + completion columns; pass tokenizer as processing_class (TRL 0.20+).
132
  trainer = SFTTrainer(
133
  model=model,
134
  args=config,
135
- train_dataset=dataset,
136
  processing_class=tokenizer,
137
  )
138
  trainer.train()
 
95
  return Dataset.from_list(all_rows)
96
 
97
 
98
+ def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
99
+ """
100
+ TRL 0.20+ tokenization can fail or mis-detect `prompt`/`completion` (e.g. old `response` key, or
101
+ `formatting_func` that drops columns). A single `text` column + `dataset_text_field` uses the
102
+ standard LM code path in SFT and is the most reliable across TRL versions.
103
+ """
104
+ from transformers import PreTrainedTokenizerBase
105
+
106
+ if not isinstance(tokenizer, PreTrainedTokenizerBase):
107
+ return dataset
108
+
109
+ # Accept either column name (old notebooks / stale clones)
110
+ cols = set(dataset.column_names)
111
+ if "completion" not in cols and "response" in cols:
112
+ dataset = dataset.rename_column("response", "completion")
113
+ if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
114
+ raise ValueError(
115
+ f"Expected columns 'prompt' and 'completion' (or 'response'). Got: {dataset.column_names}"
116
+ )
117
+
118
+ has_template = bool(getattr(tokenizer, "chat_template", None))
119
+
120
+ def to_text_batched(examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
121
+ out: List[str] = []
122
+ for prompt, completion in zip(examples["prompt"], examples["completion"]):
123
+ if has_template:
124
+ messages = [
125
+ {"role": "user", "content": prompt},
126
+ {"role": "assistant", "content": completion},
127
+ ]
128
+ out.append(
129
+ tokenizer.apply_chat_template(
130
+ messages,
131
+ tokenize=False,
132
+ add_generation_prompt=False,
133
+ )
134
+ )
135
+ else:
136
+ out.append(f"User: {prompt}\n\nAssistant: {completion}")
137
+ return {"text": out}
138
+
139
+ to_drop = [c for c in dataset.column_names if c != "text"]
140
+ return dataset.map(
141
+ to_text_batched,
142
+ batched=True,
143
+ remove_columns=to_drop,
144
+ )
145
+
146
+
147
  def run_trl_sft(dataset: Dataset) -> None:
148
  """
149
  Minimal TRL script.
 
164
 
165
  model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
166
 
167
+ # Single `text` column — avoids TRL's prompt+completion tokenize path KeyErrors across versions.
168
+ train_ds = _dataset_to_sft_text_column(dataset, tokenizer)
169
+
170
  # TRL >= 0.20 uses `max_length`; older versions used `max_seq_length`.
171
  config = SFTConfig(
172
  output_dir="outputs/sft_run",
 
175
  learning_rate=2e-5,
176
  num_train_epochs=1,
177
  max_length=768,
178
+ dataset_text_field="text",
179
  logging_steps=5,
180
  save_strategy="no",
181
  report_to="none",
182
  )
183
 
 
184
  trainer = SFTTrainer(
185
  model=model,
186
  args=config,
187
+ train_dataset=train_ds,
188
  processing_class=tokenizer,
189
  )
190
  trainer.train()