# RomanGPT / app.py
# SunX45: Update app.py (commit e8d85ea)
import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
class TextDataset(Dataset):
    """Chunk raw text into fixed-length tokenized blocks for causal-LM training.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    (pad positions masked to -100), which is the per-example format
    ``transformers.Trainer`` expects from a map-style dataset.
    """

    BLOCK_SIZE = 1024  # characters per chunk and token max_length per example

    def __init__(self, text, tokenizer):
        """Tokenize *text* in BLOCK_SIZE-character windows.

        text      -- the full training corpus as one string
        tokenizer -- a HF tokenizer providing ``encode_plus``; if it has no
                     pad token (GPT-2/GPT-Neo don't), EOS is reused so that
                     padding="max_length" below does not raise.
        """
        self.tokenizer = tokenizer
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.examples = []
        block = self.BLOCK_SIZE
        # Step through the text in fixed-size character windows. Unlike
        # range(0, len(text) - block + 1, block), this keeps the final
        # partial chunk instead of silently dropping it.
        for start in range(0, len(text), block):
            enc = tokenizer.encode_plus(
                text[start:start + block],
                truncation=True,
                max_length=block,
                padding="max_length",
                return_tensors="pt",
            )
            # Drop the batch dim added by return_tensors="pt": Trainer's
            # collator stacks per-example (block,) tensors itself.
            input_ids = enc["input_ids"].squeeze(0)
            attn_mask = enc["attention_mask"].squeeze(0)
            labels = input_ids.clone()
            labels[attn_mask == 0] = -100  # ignore padding in the LM loss
            self.examples.append(
                {"input_ids": input_ids,
                 "attention_mask": attn_mask,
                 "labels": labels}
            )

    def __len__(self):
        """Number of fixed-length training examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return example *idx* as a Trainer-compatible feature dict."""
        return self.examples[idx]
class GPTNeoTrainer:
    """Fine-tune a GPT-Neo causal language model on a plain-text file.

    Wraps model/tokenizer loading, dataset preparation and the HF ``Trainer``
    loop behind a small three-method API.
    """

    def __init__(self, model_name, dataset_path):
        """Load the pretrained model/tokenizer and build the dataset.

        model_name   -- HF hub id, e.g. "EleutherAI/gpt-neo-1.3B"
        dataset_path -- path to a UTF-8 plain-text training corpus
        """
        self.model = GPTNeoForCausalLM.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # GPT-2 tokenizers ship without a pad token; reuse EOS so that
        # max_length padding in the dataset does not raise.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Explicit encoding: the platform default may not decode the corpus.
        with open(dataset_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.dataset = TextDataset(data, self.tokenizer)
        self.training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=10,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=64,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
        )

    def train(self):
        """Run the HF training loop over the prepared dataset."""
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset,
        )
        trainer.train()

    def save_model(self, output_dir):
        """Persist model weights to *output_dir*.

        Also saves the tokenizer so the directory is self-contained for a
        later ``from_pretrained`` call.
        """
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
def main():
    """Fine-tune GPT-Neo on dataset.txt and save the resulting model."""
    trainer = GPTNeoTrainer("EleutherAI/gpt-neo-1.3B", "dataset.txt")
    trainer.train()
    trainer.save_model("model_directory")


# Guarded entry point: importing this module must not kick off a training run.
if __name__ == "__main__":
    main()