#!/usr/bin/env python3
"""
Optimized training script for morphological reinflection using TagTransformer
"""
import argparse
import json
import logging
import math
import os
import time
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast  # newer PyTorch also exposes these under torch.amp

from transformer import TagTransformer, PAD_IDX, DEVICE
from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn, analyze_vocabulary
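# transformer.py and morphological_dataset.py are local project modules;
# TagTransformer is assumed to expose loss() and count_nb_params(), both of
# which are used below.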
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def create_model(config: Dict, src_vocab: Dict[str, int], tgt_vocab: Dict[str, int]) -> TagTransformer:
    """Create and initialize the TagTransformer model"""
    # Count feature tokens (those starting with < and ending with >)
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    logger.info(f"Found {nb_attr} feature tokens")
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},  # Not used in this implementation
    )
    # Xavier init for weight matrices, small uniform init for vectors.
    # Note: the 1-D branch also overwrites bias and LayerNorm parameters,
    # which standard initialization would leave at 0 and 1 respectively.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        elif p.dim() == 1:
            nn.init.uniform_(p, -0.1, 0.1)
    return model
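# Example (hypothetical toy vocabularies, for illustration only):
#   src_vocab = {'<pad>': 0, 'a': 1, 'b': 2, '<V>': 3, '<PST>': 4}
#   tgt_vocab = {'<pad>': 0, 'a': 1, 'b': 2}
#   model = create_model(config, src_vocab, tgt_vocab)
# Here '<V>', '<PST>', and '<pad>' all match the <...> pattern, so nb_attr
# would be 3; whether special symbols should count as feature tokens depends
# on how build_vocabulary names them.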
def train_epoch(model: TagTransformer,
                dataloader: DataLoader,
                optimizer: optim.Optimizer,
                scheduler: optim.lr_scheduler.LambdaLR,
                device: torch.device,
                epoch: int,
                config: Dict,
                scaler: GradScaler) -> Tuple[float, float]:
    """Train for one epoch with optimizations"""
    model.train()
    total_loss = 0.0
    num_batches = 0
    # Gradient accumulation
    accumulation_steps = config.get('gradient_accumulation_steps', 1)
    optimizer.zero_grad()
    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        src, src_mask, tgt, tgt_mask = (
            src.to(device, non_blocking=True),
            src_mask.to(device, non_blocking=True),
            tgt.to(device, non_blocking=True),
            tgt_mask.to(device, non_blocking=True)
        )
        # Mixed precision forward pass
        with autocast(enabled=config.get('use_amp', True)):
            output = model(src, src_mask, tgt, tgt_mask)
            # Compute loss (shift sequences for next-token prediction)
            loss = model.loss(output[:-1], tgt[1:])
        # Scale loss for gradient accumulation
        loss = loss / accumulation_steps
        # Mixed precision backward pass (outside autocast, as recommended)
        scaler.scale(loss).backward()
        # Step the optimizer once every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping on unscaled gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['gradient_clip'])
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # The LR schedule is defined in optimizer updates, so step it here
            scheduler.step()
        total_loss += loss.item() * accumulation_steps
        num_batches += 1
        if batch_idx % 100 == 0:
            logger.info(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item() * accumulation_steps:.4f}')
    # Flush remaining gradients if the batch count is not evenly divisible
    if num_batches % accumulation_steps != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
    avg_loss = total_loss / num_batches
    return avg_loss, total_loss
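# Note: model.loss(output[:-1], tgt[1:]) assumes time-major tensors
# (seq_len, batch): decoder outputs at positions 0..T-2 predict target
# symbols at positions 1..T-1. If TagTransformer were batch-major, the
# slicing would have to happen along dim 1 instead.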
def validate(model: TagTransformer,
             dataloader: DataLoader,
             device: torch.device,
             config: Dict) -> float:
    """Validate the model with optimizations"""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for src, src_mask, tgt, tgt_mask in dataloader:
            src, src_mask, tgt, tgt_mask = (
                src.to(device, non_blocking=True),
                src_mask.to(device, non_blocking=True),
                tgt.to(device, non_blocking=True),
                tgt_mask.to(device, non_blocking=True)
            )
            # Mixed precision forward pass
            with autocast(enabled=config.get('use_amp', True)):
                output = model(src, src_mask, tgt, tgt_mask)
                # Compute loss with the same shift as in training
                loss = model.loss(output[:-1], tgt[1:])
            total_loss += loss.item()
            num_batches += 1
    avg_loss = total_loss / num_batches
    return avg_loss
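# Validation reports teacher-forced cross-entropy on the dev set; measuring
# exact-match accuracy would require a separate greedy or beam decoding pass.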
def save_checkpoint(model: TagTransformer,
                    optimizer: optim.Optimizer,
                    epoch: int,
                    loss: float,
                    save_path: str,
                    scaler: Optional[GradScaler] = None):
    """Save model checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()
    torch.save(checkpoint, save_path)
    logger.info(f'Checkpoint saved to {save_path}')
def load_checkpoint(model: TagTransformer,
                    optimizer: optim.Optimizer,
                    checkpoint_path: str,
                    scaler: Optional[GradScaler] = None) -> int:
    """Load model checkpoint"""
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    logger.info(f'Checkpoint loaded from {checkpoint_path}, Epoch: {epoch}, Loss: {loss:.4f}')
    return epoch
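# Example resume call (hypothetical path; main() wires this to --resume):
#   start_epoch = load_checkpoint(model, optimizer,
#                                 './models/checkpoints/best_model.pth', scaler)
# Note that scheduler state is not checkpointed, so the LR schedule restarts
# from step 0 on resume.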
def main():
    parser = argparse.ArgumentParser(description='Train TagTransformer for morphological reinflection')
    parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default='./models', help='Output directory')
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision training')
    args = parser.parse_args()
    # Training configuration
    config = {
        'embed_dim': 256,
        'nb_heads': 4,
        'src_hid_size': 1024,
        'src_nb_layers': 4,
        'trg_hid_size': 1024,
        'trg_nb_layers': 4,
        'dropout_p': 0.1,
        'tie_trg_embed': True,
        'label_smooth': 0.1,
        'batch_size': 400,
        'learning_rate': 0.001,
        'max_epochs': 1000,
        'max_updates': 10000,
        'warmup_steps': 4000,   # In optimizer updates
        'weight_decay': 0.01,
        'gradient_clip': 1.0,
        'save_every': 10,       # Epochs between periodic checkpoints
        'eval_every': 5,        # Epochs between validation runs
        'max_length': 100,
        'use_amp': not args.no_amp,        # Mixed precision training
        'gradient_accumulation_steps': 2,  # Gradient accumulation
        'pin_memory': True,                # Faster host-to-GPU transfers
        'persistent_workers': True,        # Keep dataloader workers alive
        'prefetch_factor': 2,              # Batches prefetched per worker
    }
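    # With batch_size=400 and gradient_accumulation_steps=2, each optimizer
    # update effectively sees 400 * 2 = 800 examples.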
    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'logs'), exist_ok=True)
    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)
    # Set device
    device = DEVICE
    logger.info(f'Using device: {device}')
    # Enable cuDNN autotuning (non-deterministic, but faster)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        logger.info("CUDA optimizations enabled")
    # Data file paths
    train_src = '../10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '../10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
    dev_src = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.src'
    dev_tgt = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.tgt'
    test_src = '../10L_90NL/test/run1/test.10L_90NL_1_1.src'
    test_tgt = '../10L_90NL/test/run1/test.10L_90NL_1_1.tgt'
    # Analyze vocabulary
    logger.info("Building vocabulary...")
    all_data_files = [train_src, train_tgt, dev_src, dev_tgt, test_src, test_tgt]
    vocab_stats = analyze_vocabulary(all_data_files)
    logger.info(f"Vocabulary statistics: {vocab_stats}")
    # Build source and target vocabularies over all splits, so decoding
    # never encounters out-of-vocabulary symbols
    src_vocab = build_vocabulary([train_src, dev_src, test_src])
    tgt_vocab = build_vocabulary([train_tgt, dev_tgt, test_tgt])
    logger.info(f"Source vocabulary size: {len(src_vocab)}")
    logger.info(f"Target vocabulary size: {len(tgt_vocab)}")
    # Create datasets
    train_dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, config['max_length'])
    dev_dataset = MorphologicalDataset(dev_src, dev_tgt, src_vocab, tgt_vocab, config['max_length'])
    # Use up to 8 dataloader workers, bounded by available CPUs
    num_workers = min(8, os.cpu_count() or 1)
    # Create optimized dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, config['max_length']),
        num_workers=num_workers,
        pin_memory=config['pin_memory'],
        persistent_workers=config['persistent_workers'],
        prefetch_factor=config['prefetch_factor'],
        drop_last=True  # Drop the incomplete final batch for consistent training
    )
    dev_loader = DataLoader(
        dev_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, config['max_length']),
        num_workers=num_workers,
        pin_memory=config['pin_memory'],
        persistent_workers=config['persistent_workers'],
        prefetch_factor=config['prefetch_factor']
    )
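    # Note: lambda collate functions pickle only under fork-based worker
    # start (the Linux default); on spawn platforms (Windows, newer macOS)
    # use a top-level function or functools.partial instead.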
    # Create model
    model = create_model(config, src_vocab, tgt_vocab)
    model = model.to(device)
    # Count parameters
    total_params = model.count_nb_params()
    logger.info(f'Total parameters: {total_params:,}')
    # AdamW decouples weight decay from the gradient update
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
        eps=1e-8
    )
    # Learning rate schedule: linear warmup, then cosine decay to zero.
    # (torch.cos expects a tensor, so math.cos is used on the float progress.)
    def lr_lambda(step):
        if step < config['warmup_steps']:
            return float(step) / float(max(1, config['warmup_steps']))
        # Cosine decay from peak LR at warmup_steps to zero at max_updates
        progress = (step - config['warmup_steps']) / max(1, config['max_updates'] - config['warmup_steps'])
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
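    # Worked values with the config above (warmup_steps=4000, max_updates=10000):
    #   lr_lambda(0)     = 0.0  -> lr 0.0
    #   lr_lambda(4000)  = 1.0  -> lr 0.001 (peak)
    #   lr_lambda(7000)  = 0.5  -> lr 0.0005
    #   lr_lambda(10000) = 0.0  -> lr 0.0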
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # Mixed precision training
    scaler = GradScaler(enabled=config['use_amp'])
    if config['use_amp']:
        logger.info("Mixed precision training enabled")
    # Resume from checkpoint if specified
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint(model, optimizer, args.resume, scaler)
    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'logs'))
    # Training loop
    best_val_loss = float('inf')
    global_step = 0
    updates = 0
    for epoch in range(start_epoch, config['max_epochs']):
        start_time = time.time()
        # Train (the scheduler is stepped per optimizer update inside train_epoch)
        train_loss, _ = train_epoch(
            model, train_loader, optimizer, scheduler, device, epoch, config, scaler
        )
        current_lr = scheduler.get_last_lr()[0]
        # Validate
        if epoch % config['eval_every'] == 0:
            val_loss = validate(model, dev_loader, device, config)
            # Log metrics
            writer.add_scalar('Loss/Train', train_loss, global_step)
            writer.add_scalar('Loss/Val', val_loss, global_step)
            writer.add_scalar('Learning_Rate', current_lr, global_step)
            logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}')
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(
                    model, optimizer, epoch, val_loss,
                    os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'),
                    scaler
                )
        else:
            logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, LR: {current_lr:.6f}')
        # Save checkpoint periodically
        if epoch % config['save_every'] == 0:
            save_checkpoint(
                model, optimizer, epoch, train_loss,
                os.path.join(args.output_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pth'),
                scaler
            )
        epoch_time = time.time() - start_time
        logger.info(f'Epoch {epoch} completed in {epoch_time:.2f}s')
        # Count optimizer updates (batches grouped by gradient accumulation)
        updates_this_epoch = math.ceil(len(train_loader) / config['gradient_accumulation_steps'])
        updates += updates_this_epoch
        global_step += updates_this_epoch
        # Check if we've reached max updates
        if updates >= config['max_updates']:
            logger.info(f'Reached maximum updates ({config["max_updates"]}), stopping training')
            break
    # Save final model
    save_checkpoint(
        model, optimizer, epoch, train_loss,
        os.path.join(args.output_dir, 'checkpoints', 'final_model.pth'),
        scaler
    )
    writer.close()
    logger.info('Training completed!')

if __name__ == '__main__':
    main()
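# Usage (assuming this script is saved as train.py):
#   python train.py --output_dir ./models
#   python train.py --resume ./models/checkpoints/best_model.pth
#   python train.py --no_amp   # disable mixed precision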