import torch
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset, concatenate_datasets
|
|
# Load the base GPT-2 model and its tokenizer.
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
|
|
# GPT-2 has no dedicated pad token, so reuse the EOS token for padding;
# truncation is requested per call in preprocess_function below.
tokenizer.pad_token = tokenizer.eos_token
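# Assumption (not in the original script): mirror the pad token id on the
# model config so padding-aware utilities pick it up consistently.
model.config.pad_token_id = tokenizer.eos_token_id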
|
|
# Load a 1% slice of each source dataset ('train[:1%]' is the datasets
# library's split-slicing syntax).
comb_ds = load_dataset("yoonholee/combined-preference-dataset", split='train[:1%]', trust_remote_code=True)
pref_ds = load_dataset("OpenRLHF/preference_dataset_mixture2_and_safe_pku", split='train[:1%]', trust_remote_code=True)
com_ds = load_dataset("community-datasets/generics_kb", "generics_kb_simplewiki", split='train[:1%]', trust_remote_code=True)
|
|
# The three datasets have different column schemas, so they cannot be
# concatenated directly; each is tokenized to a common format first, and the
# tokenized datasets are concatenated further below.
raw_datasets = {
    'combined-preference': comb_ds,
    'preference-mixture': pref_ds,
    'generics_kb': com_ds,
}
|
|
def preprocess_function(examples):
    # Each source dataset stores its text under a different column; use the
    # first recognized one (for preference data this picks the 'chosen' side).
    text_fields = ['text', 'chosen', 'rejected', 'content', 'sentence', 'concept_name']
    for field in text_fields:
        if field in examples:
            texts = examples[field]
            break
    else:
        raise ValueError(f"No recognized text field found among: {list(examples.keys())}")

    # Coerce missing entries to empty strings before tokenizing.
    texts = [str(text) if text is not None else "" for text in texts]
    return tokenizer(texts, truncation=True, padding='max_length', max_length=256)
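# Illustrative sanity check (not in the original script): a tiny batch pads
# out to exactly max_length tokens.
_probe = preprocess_function({'text': ['hello world']})
assert len(_probe['input_ids'][0]) == 256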
|
|
# Inspect each source's columns and a sample row to verify the field lookup.
for name, ds in raw_datasets.items():
    print(f"{name} columns:", ds.column_names)
    print(f"{name} sample:", ds[0])
|
|
# Tokenize each dataset separately (dropping its original columns), then
# concatenate the results, which now share identical features.
tokenized_parts = [
    ds.map(preprocess_function, batched=True, remove_columns=ds.column_names)
    for ds in raw_datasets.values()
]
tokenized_datasets = concatenate_datasets(tokenized_parts)
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask'])
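# Quick check (illustrative, not in the original script): with the torch
# format set, rows come back as tensors of shape [256].
print("tokenized example shape:", tokenized_datasets[0]['input_ids'].shape)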
|
|
# Shuffle once, then carve out disjoint train and eval subsets so that no
# evaluation example also appears in the training set.
shuffled = tokenized_datasets.shuffle(seed=42)
dataset_size = len(shuffled)

train_size = min(1000, dataset_size)
eval_size = min(200, dataset_size - train_size)

small_train_dataset = shuffled.select(range(train_size))
small_eval_dataset = shuffled.select(range(train_size, train_size + eval_size))
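# Illustrative size report (not in the original script).
print(f"train examples: {len(small_train_dataset)}, eval examples: {len(small_eval_dataset)}")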
|
|
# Lightweight training configuration for a small demonstration run.
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
)
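# Possible additions to the constructor above (assumptions, not in the
# original configuration): mixed precision on GPU and more frequent logging:
#   fp16=torch.cuda.is_available(),
#   logging_steps=50,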
|
|
# Collator for causal language modeling: mlm=False copies input_ids into
# labels and masks padding positions with -100. Because the pad token is the
# EOS token here, genuine EOS positions are masked out of the loss as well.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
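# Illustrative peek (not in the original script): collate two tokenized rows
# and confirm the batch carries labels alongside input_ids.
_batch = data_collator([tokenized_datasets[0], tokenized_datasets[1]])
print("collated labels shape:", _batch['labels'].shape)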
|
|
# Assemble the Trainer. On newer transformers releases the `tokenizer`
# argument is deprecated in favor of `processing_class`.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
|
|
# Run fine-tuning.
trainer.train()
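# Optional follow-up (assumption: the save path is illustrative): persist the
# fine-tuned model and tokenizer for later reuse.
trainer.save_model('./results/final')
tokenizer.save_pretrained('./results/final')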