GPT2LMHeadModel / train.py
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

MODEL_DIR = "./68h"

# Load the tokenizer and model from the local checkpoint directory.
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)

# GPT-2 defines no pad token by default; reuse the EOS token so
# padding="max_length" below does not raise an error.
tokenizer.pad_token = tokenizer.eos_token

# Plain-text dataset: each line of the file becomes one training example.
dataset = load_dataset("text", data_files={"train": "html_text_dataset.txt"})
def tokenize(batch):
    # Truncate/pad every example to a fixed length of 512 tokens.
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)

tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    save_steps=500,
    logging_steps=100,
)
# Causal-LM collator (mlm=False) copies input_ids into labels so the Trainer
# can compute the language-modeling loss; without it the model returns no loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    data_collator=data_collator,
)
trainer.train()
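
# Optional sketch: persist the final weights and tokenizer together so they can
# be reloaded later with from_pretrained(). The "./out/final" path is
# illustrative, not part of the original script.
trainer.save_model("./out/final")
tokenizer.save_pretrained("./out/final")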