import torch
from torch.utils.data import Dataset
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments

class TextDataset(Dataset):
    """Chunks raw text into 1024-character windows and tokenizes each one."""

    def __init__(self, text, tokenizer):
        self.tokenizer = tokenizer
        self.examples = []

        # Walk the text in 1024-character steps; a trailing window shorter
        # than 1024 characters is dropped. Each chunk is tokenized, then
        # truncated/padded to exactly 1024 tokens.
        for i in range(0, len(text) - 1024 + 1, 1024):
            inputs = tokenizer(text[i:i + 1024], truncation=True, max_length=1024,
                               padding="max_length", return_tensors="pt")
            input_ids = inputs["input_ids"].squeeze(0)
            attention_mask = inputs["attention_mask"].squeeze(0)
            # For causal LM training the labels are the input ids themselves
            # (the model shifts them internally); padding positions are set to
            # -100 so the loss ignores them.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            self.examples.append({"input_ids": input_ids,
                                  "attention_mask": attention_mask,
                                  "labels": labels})

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Trainer's default collator expects dict-shaped examples, not tuples.
        return self.examples[idx]
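
# Quick sanity check for TextDataset (a sketch, not part of the training flow;
# assumes the GPT-Neo tokenizer is downloadable and uses throwaway sample text):
#
#   tok = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
#   tok.pad_token = tok.eos_token          # required before padding
#   ds = TextDataset("sample training text " * 200, tok)
#   print(len(ds), ds[0]["input_ids"].shape)   # 4 torch.Size([1024])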

class GPTNeoTrainer:
    def __init__(self, model_name, dataset_path):
        self.model = GPTNeoForCausalLM.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # GPT-2/GPT-Neo tokenizers ship without a pad token; reuse EOS so that
        # padding="max_length" in TextDataset works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        with open(dataset_path, "r", encoding="utf-8") as f:
            data = f.read()

        self.dataset = TextDataset(data, self.tokenizer)

        self.training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=10,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=64,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
        )

    def train(self):
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset,
        )
        trainer.train()

    def save_model(self, output_dir):
        # Save the tokenizer alongside the weights so the output directory can
        # be reloaded on its own with from_pretrained.
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

if __name__ == "__main__":
    trainer = GPTNeoTrainer("EleutherAI/gpt-neo-1.3B", "dataset.txt")
    trainer.train()
    trainer.save_model("model_directory")
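
# Usage sketch for the saved model (the prompt and sampling settings below are
# illustrative assumptions, not part of the training script):
#
#   model = GPTNeoForCausalLM.from_pretrained("model_directory")
#   tokenizer = GPT2Tokenizer.from_pretrained("model_directory")
#   inputs = tokenizer("Once upon a time", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_length=50, do_sample=True, top_p=0.95)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))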