#!/usr/bin/env python3
"""
Optimized training script for morphological reinflection using TagTransformer
"""

import argparse
import json
import logging
import math
import os
import time
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

from transformer import TagTransformer, PAD_IDX, DEVICE
from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn, analyze_vocabulary

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def create_model(config: Dict, src_vocab: Dict[str, int], tgt_vocab: Dict[str, int]) -> TagTransformer:
    """Create and initialize the TagTransformer model."""
    # Count feature tokens (those starting with < and ending with >)
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    logger.info(f"Found {nb_attr} feature tokens")

    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},  # Not used in this implementation
    )

    # Initialize weights: Xavier uniform for matrices, small uniform for vectors
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        elif p.dim() == 1:
            nn.init.uniform_(p, -0.1, 0.1)

    return model


def train_epoch(model: TagTransformer, dataloader: DataLoader, optimizer: optim.Optimizer,
                device: torch.device, epoch: int, config: Dict,
                scaler: GradScaler) -> Tuple[float, float]:
    """Train for one epoch with mixed precision and gradient accumulation.

    No external criterion is needed: model.loss handles label smoothing internally.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Gradient accumulation: step the optimizer every `accumulation_steps` batches
    accumulation_steps = config.get('gradient_accumulation_steps', 1)
    optimizer.zero_grad()

    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        src, src_mask, tgt, tgt_mask = (
            src.to(device, non_blocking=True),
            src_mask.to(device, non_blocking=True),
            tgt.to(device, non_blocking=True),
            tgt_mask.to(device, non_blocking=True)
        )

        # Mixed precision forward pass
        with autocast(enabled=config.get('use_amp', True)):
            output = model(src, src_mask, tgt, tgt_mask)
            # Shift sequences for next-token prediction: output[t] predicts tgt[t+1]
            loss = model.loss(output[:-1], tgt[1:])
            # Scale loss so the accumulated gradient averages over the effective batch
            loss = loss / accumulation_steps

        # Mixed precision backward pass
        scaler.scale(loss).backward()

        # Take an optimizer step once enough gradients have accumulated
        if (batch_idx + 1) % accumulation_steps == 0:
            # Unscale before clipping so the threshold applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['gradient_clip'])
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps
        num_batches += 1

        if batch_idx % 100 == 0:
            logger.info(f'Epoch {epoch}, Batch {batch_idx}, '
                        f'Loss: {loss.item() * accumulation_steps:.4f}')

    # Flush leftover gradients when the batch count is not a multiple of accumulation_steps
    if num_batches % accumulation_steps != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    avg_loss = total_loss / num_batches
    return avg_loss, total_loss
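
# NOTE: with the default config (batch_size=400, gradient_accumulation_steps=2),
# each optimizer step averages gradients over an effective batch of
# 400 * 2 = 800 examples. Dividing the loss by accumulation_steps above keeps
# that accumulated gradient a mean over the effective batch rather than a sum
# of two per-batch means.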


def validate(model: TagTransformer, dataloader: DataLoader,
             device: torch.device, config: Dict) -> float:
    """Validate the model (loss only; model.loss supplies the criterion)."""
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for src, src_mask, tgt, tgt_mask in dataloader:
            src, src_mask, tgt, tgt_mask = (
                src.to(device, non_blocking=True),
                src_mask.to(device, non_blocking=True),
                tgt.to(device, non_blocking=True),
                tgt_mask.to(device, non_blocking=True)
            )

            # Mixed precision forward pass
            with autocast(enabled=config.get('use_amp', True)):
                output = model(src, src_mask, tgt, tgt_mask)
                loss = model.loss(output[:-1], tgt[1:])

            total_loss += loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches
    return avg_loss


def save_checkpoint(model: TagTransformer, optimizer: optim.Optimizer, epoch: int,
                    loss: float, save_path: str, scaler: GradScaler = None):
    """Save model checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()
    torch.save(checkpoint, save_path)
    logger.info(f'Checkpoint saved to {save_path}')


def load_checkpoint(model: TagTransformer, optimizer: optim.Optimizer,
                    checkpoint_path: str, scaler: GradScaler = None) -> int:
    """Load model checkpoint and return the epoch to resume from."""
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    logger.info(f'Checkpoint loaded from {checkpoint_path}, Epoch: {epoch}, Loss: {loss:.4f}')
    return epoch
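
# Example invocations (illustrative; assumes this file is saved as train.py):
#   python train.py --output_dir ./models
#   python train.py --resume ./models/checkpoints/best_model.pth
#   python train.py --no_amp   # disable mixed precision on unsupported hardware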


def main():
    parser = argparse.ArgumentParser(description='Train TagTransformer for morphological reinflection')
    parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default='./models', help='Output directory')
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision training')
    args = parser.parse_args()

    # Training configuration
    config = {
        'embed_dim': 256,
        'nb_heads': 4,
        'src_hid_size': 1024,
        'src_nb_layers': 4,
        'trg_hid_size': 1024,
        'trg_nb_layers': 4,
        'dropout_p': 0.1,
        'tie_trg_embed': True,
        'label_smooth': 0.1,
        'batch_size': 400,
        'learning_rate': 0.001,
        'max_epochs': 1000,
        'max_updates': 10000,
        'warmup_steps': 4000,
        'weight_decay': 0.01,
        'gradient_clip': 1.0,
        'save_every': 10,
        'eval_every': 5,
        'max_length': 100,
        'use_amp': not args.no_amp,        # Mixed precision training
        'gradient_accumulation_steps': 2,  # Gradient accumulation
        'pin_memory': True,                # Faster host-to-device transfers
        'persistent_workers': True,        # Keep dataloader workers alive between epochs
        'prefetch_factor': 2,              # Batches prefetched per worker
    }

    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'logs'), exist_ok=True)

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # Set device
    device = DEVICE
    logger.info(f'Using device: {device}')

    # Enable CUDA-specific optimizations if available
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        logger.info("CUDA optimizations enabled")

    # Data file paths
    train_src = '../10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '../10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
    dev_src = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.src'
    dev_tgt = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.tgt'
    test_src = '../10L_90NL/test/run1/test.10L_90NL_1_1.src'
    test_tgt = '../10L_90NL/test/run1/test.10L_90NL_1_1.tgt'

    # Analyze vocabulary
    logger.info("Building vocabulary...")
    all_data_files = [train_src, train_tgt, dev_src, dev_tgt, test_src, test_tgt]
    vocab_stats = analyze_vocabulary(all_data_files)
    logger.info(f"Vocabulary statistics: {vocab_stats}")

    # Build source and target vocabularies
    src_vocab = build_vocabulary([train_src, dev_src, test_src])
    tgt_vocab = build_vocabulary([train_tgt, dev_tgt, test_tgt])
    logger.info(f"Source vocabulary size: {len(src_vocab)}")
    logger.info(f"Target vocabulary size: {len(tgt_vocab)}")

    # Create datasets
    train_dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, config['max_length'])
    dev_dataset = MorphologicalDataset(dev_src, dev_tgt, src_vocab, tgt_vocab, config['max_length'])

    # Choose a sensible number of dataloader workers
    num_workers = min(8, os.cpu_count() or 1)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, config['max_length']),
        num_workers=num_workers,
        pin_memory=config['pin_memory'],
        persistent_workers=config['persistent_workers'],
        prefetch_factor=config['prefetch_factor'],
        drop_last=True  # Drop incomplete batches for consistent training
    )
    dev_loader = DataLoader(
        dev_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, config['max_length']),
        num_workers=num_workers,
        pin_memory=config['pin_memory'],
        persistent_workers=config['persistent_workers'],
        prefetch_factor=config['prefetch_factor']
    )

    # Create model
    model = create_model(config, src_vocab, tgt_vocab)
    model = model.to(device)

    # Count parameters
    total_params = model.count_nb_params()
    logger.info(f'Total parameters: {total_params:,}')

    # AdamW decouples weight decay from the gradient update
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # Learning rate schedule: linear warmup followed by cosine decay
    def lr_lambda(step):
        if step < config['warmup_steps']:
            return float(step) / float(max(1, config['warmup_steps']))
        progress = (step - config['warmup_steps']) / (config['max_updates'] - config['warmup_steps'])
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
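
    # The lambda above implements lr(t) = base_lr * t / warmup_steps during
    # warmup, then base_lr * 0.5 * (1 + cos(pi * progress)) decaying toward
    # zero at max_updates. Note that scheduler.step() is called once per
    # *epoch* in the loop below, so `step` counts epochs, not optimizer updates.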

    # Mixed precision gradient scaler (a no-op passthrough when AMP is disabled)
    scaler = GradScaler(enabled=config['use_amp'])
    if config['use_amp']:
        logger.info("Mixed precision training enabled")

    # Resume from checkpoint if specified
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint(model, optimizer, args.resume, scaler)

    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'logs'))

    # Training loop
    best_val_loss = float('inf')
    global_step = 0
    updates = 0

    for epoch in range(start_epoch, config['max_epochs']):
        start_time = time.time()

        # Train
        train_loss, _ = train_epoch(model, train_loader, optimizer, device, epoch, config, scaler)

        # Update the learning rate once per epoch
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # Validate
        if epoch % config['eval_every'] == 0:
            val_loss = validate(model, dev_loader, device, config)

            # Log metrics
            writer.add_scalar('Loss/Train', train_loss, global_step)
            writer.add_scalar('Loss/Val', val_loss, global_step)
            writer.add_scalar('Learning_Rate', current_lr, global_step)

            logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, '
                        f'Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}')

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(
                    model, optimizer, epoch, val_loss,
                    os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'),
                    scaler
                )
        else:
            logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, LR: {current_lr:.6f}')

        # Save checkpoint periodically
        if epoch % config['save_every'] == 0:
            save_checkpoint(
                model, optimizer, epoch, train_loss,
                os.path.join(args.output_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pth'),
                scaler
            )

        epoch_time = time.time() - start_time
        logger.info(f'Epoch {epoch} completed in {epoch_time:.2f}s')

        # Count batches processed (an upper bound on optimizer updates under accumulation)
        updates += len(train_loader)
        global_step += len(train_loader)

        # Stop once the update budget is exhausted
        if updates >= config['max_updates']:
            logger.info(f'Reached maximum updates ({config["max_updates"]}), stopping training')
            break

    # Save final model
    save_checkpoint(
        model, optimizer, epoch, train_loss,
        os.path.join(args.output_dir, 'checkpoints', 'final_model.pth'),
        scaler
    )

    writer.close()
    logger.info('Training completed!')


if __name__ == '__main__':
    main()
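
# To monitor the curves logged by the SummaryWriter above (default output dir):
#   tensorboard --logdir ./models/logs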