# shakespeareGPT/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import os
import math
# Adjusted hyperparameters
BATCH_SIZE = 32
BLOCK_SIZE = 128
LEARNING_RATE = 3e-4
N_EMBD = 512 # Reduced from 768
N_HEAD = 8 # Reduced from 12
N_LAYER = 8 # Reduced from 12
DROPOUT = 0.2 # Increased from 0.1
WEIGHT_DECAY = 0.01
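# With a character-level vocabulary (around 65 symbols for a Shakespeare corpus),
# this configuration comes to roughly 21M parameters.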
class DecoderBlock(nn.Module):
def __init__(self):
super().__init__()
assert N_EMBD % N_HEAD == 0
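        # Each of the N_HEAD = 8 heads attends over N_EMBD // N_HEAD = 64 dimensions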
# Added dropout to attention
self.attention = nn.MultiheadAttention(
N_EMBD,
N_HEAD,
dropout=DROPOUT,
batch_first=True,
bias=False # Reduce parameters
)
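        # batch_first=True: attention inputs/outputs are shaped (batch, seq_len, embed_dim)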
# Modified feed forward with additional dropout
self.feed_forward = nn.Sequential(
nn.Linear(N_EMBD, 3 * N_EMBD, bias=False), # Reduced multiplier from 4 to 3
nn.GELU(),
nn.Dropout(DROPOUT),
nn.Linear(3 * N_EMBD, N_EMBD, bias=False),
nn.Dropout(DROPOUT)
)
# Layer normalization with improved epsilon
self.ln1 = nn.LayerNorm(N_EMBD, eps=1e-5)
self.ln2 = nn.LayerNorm(N_EMBD, eps=1e-5)
def forward(self, x, mask=None):
# Pre-norm architecture for better training stability
        x_norm = self.ln1(x)  # compute the pre-attention LayerNorm once and reuse it for q, k, v
        attn_out = self.attention(x_norm, x_norm, x_norm, attn_mask=mask, need_weights=False)[0]
x = x + attn_out
x = x + self.feed_forward(self.ln2(x))
return x
class ShakespeareModel(nn.Module):
def __init__(self, vocab_size):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, N_EMBD)
self.position_embedding = nn.Embedding(BLOCK_SIZE, N_EMBD)
self.drop = nn.Dropout(DROPOUT)
self.blocks = nn.ModuleList([DecoderBlock() for _ in range(N_LAYER)])
self.ln_f = nn.LayerNorm(N_EMBD, eps=1e-5)
self.lm_head = nn.Linear(N_EMBD, vocab_size, bias=False)
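        # The output head is a separate projection; its weights are not tied to token_embedding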
        # Scaled-down normal init (std reduced by 1/sqrt(2 * N_LAYER)) to keep residual activations stable with depth
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * N_LAYER))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * N_LAYER))
elif isinstance(module, nn.LayerNorm):
torch.nn.init.ones_(module.weight)
torch.nn.init.zeros_(module.bias)
def forward(self, idx):
B, T = idx.shape
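        # Guard: the learned position table only covers BLOCK_SIZE positions
        assert T <= BLOCK_SIZE, f"sequence length {T} exceeds BLOCK_SIZE ({BLOCK_SIZE})"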
# Get embeddings and apply dropout
tok_emb = self.token_embedding(idx)
pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
pos_emb = self.position_embedding(pos)
# Apply dropout to combined embeddings
x = self.drop(tok_emb + pos_emb)
        # Causal mask: -inf above the diagonal blocks attention to future positions
        mask = torch.triu(torch.full((T, T), float('-inf'), device=idx.device), diagonal=1)
# Apply transformer blocks
for block in self.blocks:
x = block(x, mask)
x = self.ln_f(x)
logits = self.lm_head(x)
return logits
class TextDataset(Dataset):
def __init__(self, text, block_size):
chars = sorted(list(set(text)))
self.vocab_size = len(chars)
self.stoi = {ch: i for i, ch in enumerate(chars)}
self.itos = {i: ch for i, ch in enumerate(chars)}
data = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
# Create overlapping sequences for better context
self.examples = []
stride = block_size // 2 # Add stride for overlapping sequences
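        # A stride of half the block size gives 50% overlap between consecutive windows,
        # roughly doubling the number of examples versus non-overlapping chunks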
for i in range(0, len(data) - block_size, stride):
x = data[i:i + block_size]
y = data[i + 1:i + block_size + 1]
self.examples.append((x, y))
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
return self.examples[idx]
def evaluate_model(model, dataloader, device):
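    # Returns the mean cross-entropy per character over the dataloader
    # (its exponential is the per-character perplexity)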
model.eval()
total_loss = 0
with torch.no_grad():
for x, y in dataloader:
x, y = x.to(device), y.to(device)
logits = model(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
total_loss += loss.item()
return total_loss / len(dataloader)
def train_model(model, train_dataloader, optimizer, scheduler, device):
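    # One full pass over the training data. The scheduler argument is not stepped here;
    # main() steps it once per epoch on the validation loss.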
model.train()
total_loss = 0
for x, y in train_dataloader:
x, y = x.to(device), y.to(device)
        # Forward pass
logits = model(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
# Backward pass
        optimizer.zero_grad(set_to_none=True)  # frees gradients instead of zero-filling them
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_dataloader)
def main():
# Load and preprocess data
with open('input.txt', 'r', encoding='utf-8') as f:
text = f.read()
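    # input.txt is expected to hold the raw training corpus (e.g. the Tiny Shakespeare text)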
# Create full dataset
full_dataset = TextDataset(text, BLOCK_SIZE)
# Split into train and validation sets (90-10 split)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
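    # Note: because the dataset windows overlap (50% stride), a random split shares some
    # raw text between train and validation, so the validation loss is a slightly
    # optimistic estimate; a contiguous split of the source text would be stricter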
# Create dataloaders
train_dataloader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
pin_memory=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShakespeareModel(full_dataset.vocab_size).to(device)
# Optimizer with weight decay
optimizer = torch.optim.AdamW(
model.parameters(),
lr=LEARNING_RATE,
betas=(0.9, 0.95),
weight_decay=WEIGHT_DECAY
)
# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=3,
        verbose=True  # prints a message on each LR reduction; deprecated in recent PyTorch releases
)
# Load checkpoint if exists
start_epoch = 0
best_val_loss = float('inf')
if os.path.exists('shakespeare_model_best.pth'):
print("Loading checkpoint 'shakespeare_model_best.pth'")
        checkpoint = torch.load('shakespeare_model_best.pth', map_location=device)  # map tensors to the active device when resuming
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
best_val_loss = checkpoint['best_loss']
print(f"Loaded checkpoint (epoch {start_epoch})")
# Training loop
n_epochs = 1000
for epoch in range(start_epoch, n_epochs):
train_loss = train_model(model, train_dataloader, optimizer, scheduler, device)
val_loss = evaluate_model(model, val_dataloader, device)
# Update learning rate
scheduler.step(val_loss)
print(f'Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')
# Save checkpoint if validation loss improved
if val_loss < best_val_loss:
best_val_loss = val_loss
print(f'Validation loss improved to {val_loss:.6f}. Saving checkpoint...')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': val_loss,
}, 'shakespeare_model_best.pth')
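            # Note: scheduler state is not saved, so ReduceLROnPlateau restarts from its defaults when resuming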
        # Stop once the target training loss is reached
        if train_loss < 0.1:
print(f'Target loss achieved! Training completed at epoch {epoch+1}')
break
if __name__ == '__main__':
main()