import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import os
import math

# Adjusted hyperparameters
BATCH_SIZE = 32
BLOCK_SIZE = 128
LEARNING_RATE = 3e-4
N_EMBD = 512   # Reduced from 768
N_HEAD = 8     # Reduced from 12
N_LAYER = 8    # Reduced from 12
DROPOUT = 0.2  # Increased from 0.1
WEIGHT_DECAY = 0.01
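# Rough size estimate (an assumption, based on a ~65-character Shakespeare
# vocabulary): about 21M parameters in total — 8 blocks of ~2.6M each
# (attention ~1.05M, feed-forward ~1.57M), plus token/position embeddings
# and the output head.
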
class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        assert N_EMBD % N_HEAD == 0
        # Added dropout to attention
        self.attention = nn.MultiheadAttention(
            N_EMBD,
            N_HEAD,
            dropout=DROPOUT,
            batch_first=True,
            bias=False  # Reduce parameters
        )
        # Modified feed forward with additional dropout
        self.feed_forward = nn.Sequential(
            nn.Linear(N_EMBD, 3 * N_EMBD, bias=False),  # Reduced multiplier from 4 to 3
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(3 * N_EMBD, N_EMBD, bias=False),
            nn.Dropout(DROPOUT)
        )
        # Layer normalization (eps=1e-5 is PyTorch's default, made explicit here)
        self.ln1 = nn.LayerNorm(N_EMBD, eps=1e-5)
        self.ln2 = nn.LayerNorm(N_EMBD, eps=1e-5)
    def forward(self, x, mask=None):
        # Pre-norm architecture for better training stability;
        # normalize once and reuse for query, key, and value.
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm, attn_mask=mask, need_weights=False)
        x = x + attn_out
        x = x + self.feed_forward(self.ln2(x))
        return x

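# Quick shape check (illustrative): a block maps (B, T, N_EMBD) -> (B, T, N_EMBD).
#   block = DecoderBlock()
#   y = block(torch.randn(4, BLOCK_SIZE, N_EMBD))
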
class ShakespeareModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, N_EMBD)
        self.position_embedding = nn.Embedding(BLOCK_SIZE, N_EMBD)
        self.drop = nn.Dropout(DROPOUT)
        self.blocks = nn.ModuleList([DecoderBlock() for _ in range(N_LAYER)])
        self.ln_f = nn.LayerNorm(N_EMBD, eps=1e-5)
        self.lm_head = nn.Linear(N_EMBD, vocab_size, bias=False)
        # Improved weight initialization
        self.apply(self._init_weights)
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * N_LAYER))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * N_LAYER))
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)
    def forward(self, idx):
        B, T = idx.shape
        # Get embeddings and apply dropout
        tok_emb = self.token_embedding(idx)
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.position_embedding(pos)
        # Apply dropout to combined embeddings
        x = self.drop(tok_emb + pos_emb)
        # Create causal mask directly on the target device
        mask = torch.triu(torch.full((T, T), float('-inf'), device=idx.device), diagonal=1)
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits
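    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        # Illustrative sampling sketch (not part of the original script):
        # autoregressively extends idx, a (B, T) tensor of token indices,
        # by sampling each next character from the softmax over the logits.
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -BLOCK_SIZE:]           # crop to the context window
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature   # logits for the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
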
class TextDataset(Dataset):
    def __init__(self, text, block_size):
        chars = sorted(list(set(text)))
        self.vocab_size = len(chars)
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        data = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
        # Create overlapping sequences for better context
        self.examples = []
        stride = block_size // 2  # Half-block stride for overlapping sequences
        for i in range(0, len(data) - block_size, stride):
            x = data[i:i + block_size]
            y = data[i + 1:i + block_size + 1]
            self.examples.append((x, y))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

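# Usage sketch (illustrative): the dataset doubles as the character tokenizer.
#   ds = TextDataset(text, BLOCK_SIZE)
#   ids = [ds.stoi[c] for c in "ROMEO"]
#   assert ''.join(ds.itos[i] for i in ids) == "ROMEO"
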
def evaluate_model(model, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)
def train_model(model, train_dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)
        # Forward pass
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        # Backward pass
        optimizer.zero_grad(set_to_none=True)  # Cheaper than zeroing gradients in place
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_dataloader)

def main():
    # Load and preprocess data
    with open('input.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    # Create full dataset
    full_dataset = TextDataset(text, BLOCK_SIZE)
    # Split into train and validation sets (90-10 split).
    # Note: because sequences overlap (half-block stride), a random split
    # shares some text between train and validation.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True
    )
    # Initialize model and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ShakespeareModel(full_dataset.vocab_size).to(device)
    # Optimizer with weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
        weight_decay=WEIGHT_DECAY
    )
    # Learning rate scheduler (the `verbose` flag is omitted; it is
    # deprecated in recent PyTorch releases)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3
    )
    # Load checkpoint if it exists
    start_epoch = 0
    best_val_loss = float('inf')
    if os.path.exists('shakespeare_model_best.pth'):
        print("Loading checkpoint 'shakespeare_model_best.pth'")
        checkpoint = torch.load('shakespeare_model_best.pth', map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_val_loss = checkpoint['best_loss']
        print(f"Loaded checkpoint (epoch {start_epoch})")
    # Training loop
    n_epochs = 1000
    for epoch in range(start_epoch, n_epochs):
        train_loss = train_model(model, train_dataloader, optimizer, device)
        val_loss = evaluate_model(model, val_dataloader, device)
        # Update learning rate
        scheduler.step(val_loss)
        print(f'Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')
        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f'Validation loss improved to {val_loss:.6f}. Saving checkpoint...')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': val_loss,
            }, 'shakespeare_model_best.pth')
        # Stop once the target training loss is reached
        if train_loss < 0.0999999:
            print(f'Target loss achieved! Training completed at epoch {epoch+1}')
            break

if __name__ == '__main__':
    main()
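
# Illustrative follow-up (assumes training has produced
# 'shakespeare_model_best.pth' and uses the generate() sketch above):
#   checkpoint = torch.load('shakespeare_model_best.pth', map_location=device)
#   model.load_state_dict(checkpoint['model_state_dict'])
#   context = torch.zeros((1, 1), dtype=torch.long, device=device)
#   sample = model.generate(context, max_new_tokens=500)[0].tolist()
#   print(''.join(full_dataset.itos[i] for i in sample))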