# chatbot/old/train_script_v2.py
# frc 10252
# add files
# 3905c4a
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from datasets import load_dataset
from tqdm import tqdm
# Load dataset
dataset = load_dataset("Bingsu/openwebtext_20p")
# Each record of the 'train' split exposes its article text under the 'text' key.
# Sanity check 1: print the first 500 characters of the record at index 100.
print(dataset['train'][100]['text'][:500]) # pyright: ignore[reportArgumentType]
# Sanity check 2: print the whole record at index 600000.
# NOTE(review): this runs at import time and assumes the split has more than 600000 records — confirm.
print(dataset['train'][600000]) # pyright: ignore[reportArgumentType]
class TextDataset(Dataset):
    """Dataset that yields random fixed-length token windows for next-token training.

    The index passed to ``__getitem__`` is ignored: every call draws records
    at random from ``hf_dataset['train']``, tokenizes their ``'text'`` field and
    concatenates them (space-separated) until at least ``block_size + 1``
    tokens are available.

    Args:
        hf_dataset: Mapping with a ``'train'`` entry that supports ``len()`` and
            integer indexing, each record having a ``'text'`` string.
        tokenizer: Object with an ``encode(str) -> list[int]`` method.
        block_size: Number of tokens in each input/target sequence.
    """

    def __init__(self, hf_dataset, tokenizer, block_size):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        """Return the number of records in the 'train' split."""
        return len(self.dataset['train'])

    def _encode_random_record(self, prefix=""):
        """Tokenize ``prefix`` + the text of one uniformly drawn 'train' record."""
        pick = torch.randint(0, len(self.dataset['train']), (1,)).item()
        return self.tokenizer.encode(prefix + self.dataset['train'][pick]['text'])

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of length ``block_size``; ``y`` is ``x`` shifted by one."""
        window = self.block_size + 1

        # Start from one random record, then keep appending further random
        # records (joined by a space) while the sample is still too short.
        # The loop stops as soon as the window is filled, which also bounds
        # how far the token list can grow.
        ids = self._encode_random_record()
        while len(ids) < window:
            ids.extend(self._encode_random_record(prefix=" "))

        # Keep exactly block_size + 1 tokens and split into input / target.
        sample = torch.tensor(ids[:window])
        inputs = sample[: self.block_size]
        targets = sample[1:window]
        return inputs.long(), targets.long()
# hyperparameters
# True: main() trains and saves a checkpoint; False: main() loads the saved checkpoint instead.
train_model = True
# Sequence length in tokens: size of the position-embedding table and of every training sample.
block_size = 256
# Number of stacked transformer layers in GPTModel.
n_layers = 8
# Number of attention heads per transformer layer.
n_heads = 8
# Dropout probability used after the embeddings and inside each transformer layer.
dropout_p = 0.1
# Number of sequences per DataLoader batch.
batch_size = 8
# Learning rate passed to AdamW.
learning_rate = 3e-4
# Embedding width (d_model of the transformer layers).
n_embedding = 512
# Number of optimizer steps performed by train().
max_iters = 5000
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GPTModel(nn.Module):
    """GPT-style causal language model built from ``nn.TransformerEncoderLayer`` blocks.

    Token and learned position embeddings are summed, passed through
    ``n_layers`` self-attention layers restricted by a causal mask, layer
    normalized and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        n_embedding: Embedding width (``d_model`` of every layer).
        n_layers: Number of stacked transformer layers.
        n_heads: Attention heads per layer.
        dropout_p: Dropout probability (embeddings and inside the layers).
        block_size: Maximum sequence length the position table supports.
    """

    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embedding)
        self.position_embedding = nn.Embedding(block_size, n_embedding)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embedding,
                nhead=n_heads,
                dropout=dropout_p,
                # forward() feeds (batch, seq, emb) tensors. With the default
                # batch_first=False the layer would treat the batch axis as
                # the sequence axis and attend across unrelated samples.
                batch_first=True,
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embedding)
        self.head = nn.Linear(n_embedding, vocab_size)
        self.dropout = nn.Dropout(dropout_p)
        self.block_size = block_size

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Long tensor of token ids with shape ``(batch, seq_len)``, where
                ``seq_len <= block_size``.

        Returns:
            Float tensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``block_size``.
        """
        bsz, seq_len = x.size()
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )

        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
        hidden = self.dropout(self.token_embedding(x) + self.position_embedding(positions))

        # Boolean causal mask: True marks pairs that must NOT attend, i.e.
        # every position strictly to the right of the query. Without it the
        # model can read the very token it is trained to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        for layer in self.layers:
            hidden = layer(hidden, src_mask=causal_mask)

        return self.head(self.ln_f(hidden))
# Initialize tokenizer and dataset
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)
# NOTE(review): TextDataset.__getitem__ ignores the index it receives and draws
# random records itself, so shuffle=True does not change what is sampled — confirm it is still wanted.
# NOTE(review): num_workers=16 is hard-coded — confirm it suits the target machine.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)
# Define model objects
# The model's vocabulary size is taken from the tokenizer.
vocab_size = tokenizer.n_vocab
model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Cross-entropy between flattened logits and flattened target ids (see train()).
loss_fn = nn.CrossEntropyLoss()
# Training loop
def train():
    """Run ``max_iters`` mixed-precision optimizer steps on the module-level model.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_dataloader``, ``device``, ``vocab_size`` and ``max_iters``.
    Does nothing beyond setting the matmul precision when ``train_model``
    is false. Returns ``None``; the trained weights live in ``model``.
    """
    torch.set_float32_matmul_precision('high')
    if not train_model:
        return

    scaler = GradScaler(device)
    # Make sure dropout is active even if the model was put in eval mode earlier.
    model.train()
    # The compiled wrapper shares parameters with ``model``, so saving
    # ``model.state_dict()`` afterwards captures the trained weights.
    compiled_model = torch.compile(model)

    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)
    for _ in pbar:
        try:
            xb, yb = next(data_iter)
        except StopIteration:
            # The loader ran out before max_iters steps: start a new pass
            # instead of crashing the run.
            data_iter = iter(train_dataloader)
            xb, yb = next(data_iter)
        xb, yb = xb.to(device), yb.to(device)

        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update bar text dynamically
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
    """Sample a continuation of ``prompt`` one token at a time.

    Args:
        model: Callable mapping a ``(1, T)`` long tensor of token ids to
            ``(1, T, vocab)`` logits; it is switched to eval mode.
        tokenizer: Object providing ``encode(str)`` and ``decode(list[int])``.
        prompt: Text the generation starts from.
        max_new_tokens: Number of tokens to sample and append.
        block_size: Maximum number of trailing tokens fed to the model.
        device: Device the token tensor is moved to.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.
    """
    model.eval()

    # Token ids of the prompt as a single-row batch.
    prompt_ids = tokenizer.encode(prompt)
    sequence = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        # The model only ever sees the most recent block_size tokens.
        context = sequence[:, -block_size:]
        # Logits for the final position only: shape (1, vocab).
        last_step_logits = model(context)[:, -1, :]
        distribution = F.softmax(last_step_logits, dim=-1)
        # Draw the next token from the predicted distribution and append it.
        sampled = torch.multinomial(distribution, num_samples=1)
        sequence = torch.cat((sequence, sampled), dim=1)

    # Decode the full id sequence (prompt + samples) back into text.
    return tokenizer.decode(sequence[0].tolist())
def save_model(model, filepath):
    """Write ``model.state_dict()`` to ``filepath``, creating parent directories.

    Args:
        model: Module whose ``state_dict()`` is saved.
        filepath: Destination path; may be a bare file name.
    """
    parent = os.path.dirname(filepath)
    # A bare file name has an empty dirname, and os.makedirs('') raises.
    # exist_ok=True also avoids the check-then-create race.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), filepath)
def load_model(model, filepath):
    """Load a saved state dict from ``filepath`` into ``model`` and return ``model``.

    Args:
        model: Module whose parameters are overwritten in place.
        filepath: Path to a file written with ``torch.save(state_dict, ...)``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; load_state_dict then copies the values onto whatever device the
    # model's parameters live on. weights_only=True restricts unpickling to
    # tensors and plain containers.
    state = torch.load(filepath, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model
def main():
    """Train and save the model, or load a saved checkpoint, then print a sample.

    The branch is selected by the module-level ``train_model`` flag. In both
    cases a 50-token continuation of a fixed prompt is generated and printed.
    """
    # Single source of truth for the checkpoint location used by both branches.
    checkpoint_path = "checkpoints/gpt_model-1.pth"

    if train_model:
        train()
        save_model(model, checkpoint_path)
    else:
        # Go through the shared helper rather than re-implementing the load here.
        load_model(model, checkpoint_path)

    # Example of generating text after training or loading
    prompt = "me when the "
    generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device)
    print(generated_text)


# Run main() only when the file is executed as a script.
if __name__ == "__main__":
    main()