# shakespeare-transformer-learning / train_text_model.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import time
import os
import pickle
import requests
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("="*70)
print(f"πŸ€– NLP Text-to-Text Language Model Training")
print("="*70)
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("="*70)
# ============================================
# 1. DOWNLOAD AND PREPARE TEXT DATA
# ============================================
print("\nπŸ“₯ Step 1: Downloading text data...")
# Download a text corpus (Shakespeare as example - you can change this!)
def download_text_data():
    """Download text data for training."""
    # Option 1: Shakespeare (small, good for testing)
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    # Option 2: Larger corpus (uncomment to use)
    # url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
    data_file = "training_data.txt"
    if not os.path.exists(data_file):
        print(f"Downloading from {url}...")
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # fail loudly on a bad download
        with open(data_file, 'w', encoding='utf-8') as f:
            f.write(response.text)
        print(f"βœ“ Downloaded to {data_file}")
    else:
        print(f"βœ“ Using existing {data_file}")
    with open(data_file, 'r', encoding='utf-8') as f:
        text = f.read()
    return text
text = download_text_data()
print(f"βœ“ Corpus size: {len(text):,} characters")
print(f"βœ“ Sample text:\n{text[:200]}...\n")
# ============================================
# 2. CREATE VOCABULARY AND TOKENIZER
# ============================================
print("πŸ“š Step 2: Creating vocabulary...")
class CharTokenizer:
    """Simple character-level tokenizer"""
    def __init__(self, text):
        self.chars = sorted(set(text))
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

    def encode(self, text):
        return [self.char_to_idx[ch] for ch in text]

    def decode(self, indices):
        return ''.join(self.idx_to_char[i] for i in indices)
tokenizer = CharTokenizer(text)
print(f"βœ“ Vocabulary size: {tokenizer.vocab_size} characters")
print(f"βœ“ Characters: {''.join(tokenizer.chars[:50])}...")
# Encode entire text
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
print(f"βœ“ Encoded data shape: {data.shape}")
# ============================================
# 3. CREATE DATASET
# ============================================
print("\nπŸ“Š Step 3: Creating dataset...")
class TextDataset(Dataset):
    def __init__(self, data, seq_length=128):
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return x, y
seq_length = 128
dataset = TextDataset(data, seq_length)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
# Note: adjacent windows overlap, so a random split leaks some context between
# train and val; acceptable for a demo, but a contiguous split would be stricter.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)
# num_workers > 0 needs an `if __name__ == "__main__":` guard on spawn-based
# platforms (Windows/macOS); this script assumes a Linux-style fork start.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
print(f"βœ“ Training samples: {len(train_dataset):,}")
print(f"βœ“ Validation samples: {len(val_dataset):,}")
print(f"βœ“ Sequence length: {seq_length}")
print(f"βœ“ Batch size: 64")
# ============================================
# 4. DEFINE THE MODEL (~5M parameters)
# ============================================
print("\nπŸ”¨ Step 4: Building language model...")
class TransformerLanguageModel(nn.Module):
    """Transformer-based language model with ~5M parameters"""
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings, sized to the training context window
        self.pos_encoding = nn.Embedding(seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        batch_size, seq_len = x.shape
        # Token embedding
        token_emb = self.embedding(x)  # (batch, seq_len, d_model)
        # Position embedding
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_encoding(positions)
        # Combine embeddings
        x = self.dropout(token_emb + pos_emb)
        # Causal mask so each position attends only to earlier positions
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.transformer(x, mask=mask, is_causal=True)
        # Project back to vocabulary logits
        logits = self.fc_out(x)
        return logits

    def generate(self, tokenizer, prompt="", max_length=200, temperature=0.8):
        """Generate text given a prompt"""
        self.eval()
        if prompt == "":
            # Start with a random character
            indices = [np.random.randint(0, tokenizer.vocab_size)]
        else:
            indices = tokenizer.encode(prompt)
        with torch.no_grad():
            for _ in range(max_length):
                # Condition on at most the last seq_length tokens. The model
                # handles shorter contexts, so no left-padding is needed
                # (padding with index 0 would inject a real character).
                x = torch.tensor(indices[-seq_length:], dtype=torch.long).unsqueeze(0).to(device)
                logits = self(x)
                logits = logits[0, -1, :] / temperature
                # Sample the next character from the softmax distribution
                probs = torch.softmax(logits, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1).item()
                indices.append(next_idx)
        return tokenizer.decode(indices)
# Create model
model = TransformerLanguageModel(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=8,
    num_layers=6,
    dropout=0.2
).to(device)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"βœ“ Total parameters: {total_params:,}")
print(f"βœ“ Trainable parameters: {trainable_params:,}")
print(f"βœ“ Model size: {total_params * 4 / 1024 / 1024:.2f} MB")
# ============================================
# 5. TRAINING SETUP
# ============================================
print("\nβš™οΈ Step 5: Setting up training...")
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# T_max matches the 25 training epochs below, so the cosine anneal completes
# one full cycle by the end of training.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
print("βœ“ Loss function: CrossEntropyLoss")
print("βœ“ Optimizer: AdamW")
print("βœ“ Learning rate: 0.001")
# ============================================
# 6. TRAINING FUNCTIONS
# ============================================
def train_epoch(model, train_loader, criterion, optimizer, epoch):
    model.train()
    total_loss = 0
    total_correct = 0
    total_tokens = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch:2d} [Train]')
    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # Forward pass
        logits = model(x)  # (batch, seq_len, vocab_size)
        # Flatten (batch, seq_len) into one token axis for the loss
        loss = criterion(logits.view(-1, tokenizer.vocab_size), y.view(-1))
        # Backward pass with gradient clipping for stability
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Statistics
        total_loss += loss.item()
        # Next-character accuracy
        predictions = logits.argmax(dim=-1)
        total_correct += (predictions == y).sum().item()
        total_tokens += y.numel()
        pbar.set_postfix({
            'loss': f'{total_loss/(batch_idx+1):.3f}',
            'acc': f'{100.*total_correct/total_tokens:.2f}%',
            'ppl': f'{np.exp(total_loss/(batch_idx+1)):.1f}'
        })
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * total_correct / total_tokens
    perplexity = np.exp(avg_loss)
    return avg_loss, accuracy, perplexity
def validate(model, val_loader, criterion):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_tokens = 0
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validating')
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, tokenizer.vocab_size), y.view(-1))
            total_loss += loss.item()
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()
            total_tokens += y.numel()
            pbar.set_postfix({
                # Divide by batches seen so far, not the full loader length
                'loss': f'{total_loss/(batch_idx+1):.3f}',
                'acc': f'{100.*total_correct/total_tokens:.2f}%'
            })
    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * total_correct / total_tokens
    perplexity = np.exp(avg_loss)
    return avg_loss, accuracy, perplexity
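# Note on the metrics: perplexity is exp(mean cross-entropy), i.e. the model's
# effective branching factor per character. A uniform guesser over this
# vocabulary would sit near PPL == vocab_size, so single-digit values mean the
# model has learned real structure.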
# ============================================
# 7. TRAINING LOOP
# ============================================
print("\nπŸš€ Step 6: Starting training...")
print("="*70)
num_epochs = 25
best_val_loss = float('inf')
start_time = time.time()
for epoch in range(1, num_epochs + 1):
    # Train
    train_loss, train_acc, train_ppl = train_epoch(model, train_loader, criterion, optimizer, epoch)
    # Validate
    val_loss, val_acc, val_ppl = validate(model, val_loader, criterion)
    # Update learning rate
    scheduler.step()
    # Print results
    print(f'\nEpoch {epoch:2d}/{num_epochs} | '
          f'Train Loss: {train_loss:.3f} Acc: {train_acc:.1f}% PPL: {train_ppl:.1f} | '
          f'Val Loss: {val_loss:.3f} Acc: {val_acc:.1f}% PPL: {val_ppl:.1f}', end='')
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'tokenizer': tokenizer,
        }, 'best_model.pth')
        print(' βœ“ BEST', end='')
    print()
    # Generate sample text every 5 epochs
    if epoch % 5 == 0:
        print("\n" + "="*70)
        print(f"πŸ“ Sample generation after epoch {epoch}:")
        print("-"*70)
        sample = model.generate(tokenizer, prompt="", max_length=200, temperature=0.8)
        print(sample)
        print("="*70 + "\n")
# ============================================
# 8. FINAL RESULTS
# ============================================
total_time = time.time() - start_time
print("\n" + "="*70)
print("πŸŽ‰ Training Complete!")
print("="*70)
print(f"Total time: {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.3f}")
print(f"Best perplexity: {np.exp(best_val_loss):.1f}")
print("="*70)
# Save final model
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'tokenizer': tokenizer,
}, 'final_model.pth')
print("\nβœ“ Models saved:")
print(" - best_model.pth")
print(" - final_model.pth")
# ============================================
# 9. GENERATE TEXT SAMPLES
# ============================================
print("\n" + "="*70)
print("πŸ“ Final Text Generation Samples")
print("="*70)
# Load best model
# weights_only=False is needed on recent PyTorch (default flipped in 2.6),
# because the checkpoint pickles the CharTokenizer object alongside the
# weights. Safe here since this is our own local file.
checkpoint = torch.load('best_model.pth', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
# Generate multiple samples
prompts = ["", "The ", "To be", "Once upon"]
for i, prompt in enumerate(prompts, 1):
    print(f"\n--- Sample {i} (prompt: '{prompt}') ---")
    generated = model.generate(tokenizer, prompt=prompt, max_length=300, temperature=0.8)
    print(generated)
    print()
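# Temperature scales the logits before softmax: lower values sharpen the
# distribution (safer, more repetitive text), higher values flatten it
# (more varied, more errors). An illustrative sweep:
for temp in (0.5, 1.0):
    print(f"\n--- temperature={temp} ---")
    print(model.generate(tokenizer, prompt="The ", max_length=150, temperature=temp))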
print("="*70)
print("\nβœ… All done! Your text model is ready!")
print("\nTo generate text later:")
print(" python3 generate_text.py")