| """ | |
| Advanced Training Script with Best Practices | |
| - Learning rate scheduling | |
| - Mixed precision training | |
| - Experiment tracking (W&B optional) | |
| - Comprehensive evaluation | |
| - Model checkpointing | |
| """ | |
| import argparse | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, LambdaLR | |
| import math | |
| from efficient_train import ( | |
| create_dataloaders, Encoder, Decoder, ImageCaptioningModel, | |
| train_epoch, validate, generate_caption | |
| ) | |
| from datetime import datetime | |
| # Optional: Weights & Biases | |
| try: | |
| import wandb | |
| WANDB_AVAILABLE = True | |
| except ImportError: | |
| WANDB_AVAILABLE = False | |
| print("W&B not available. Install with: pip install wandb") | |


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Create a learning rate schedule with linear warmup followed by cosine annealing."""
    def lr_lambda(current_step):
        # Linear warmup: ramp the multiplier from 0 to 1 over the warmup steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay: fall from 1 to 0 over the remaining training steps.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return LambdaLR(optimizer, lr_lambda)
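
# Sanity check of the schedule above: with num_warmup_steps=1000 and
# num_training_steps=10000, the multiplier is 0.5 at step 500 (halfway through
# warmup), 1.0 at step 1000, 0.5 again at step 5500 (cosine progress = 0.5),
# and approaches 0.0 at step 10000. LambdaLR multiplies each param group's base
# LR by this factor on every scheduler.step() call.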


def train_advanced(args):
    """Advanced training with all best practices"""
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # GPU optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True  # Optimize for consistent input sizes
        torch.backends.cudnn.deterministic = False  # Faster, but non-deterministic
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Initialize W&B
    if args.use_wandb and WANDB_AVAILABLE:
        wandb.init(
            project=args.wandb_project,
            name=f"{args.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            config=vars(args)
        )

    # Create dataloaders
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model
    encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=64,
        dropout=args.dropout
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)
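
    # The "+ 2" on vocab_size presumably covers special tokens added on top of
    # the base tokenizer vocabulary in efficient_train; this is an assumption,
    # so check create_dataloaders for the exact tokens it registers.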

    # Resume from checkpoint if provided
    start_epoch = 0
    best_val_loss = float('inf')
    checkpoint = None
    if args.resume_checkpoint:
        print(f"Loading checkpoint from {args.resume_checkpoint}")
        # Handle PyTorch 2.6+ security: allowlist the tokenizer class so the
        # pickled tokenizer stored inside the checkpoint can be deserialized
        try:
            from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
            torch.serialization.add_safe_globals([GPT2TokenizerFast])
        except ImportError:
            pass
        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        print(f"Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")

    # Optimizer, with optionally different learning rates for encoder/decoder
    if args.different_lr:
        # Lower learning rate for the encoder (fine-tuning a pretrained backbone)
        encoder_params = [p for n, p in model.named_parameters() if 'encoder' in n]
        decoder_params = [p for n, p in model.named_parameters() if 'decoder' in n]
        optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': args.lr * 0.1},
            {'params': decoder_params, 'lr': args.lr}
        ], weight_decay=args.weight_decay)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
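
    # The 10x reduction for the encoder is a common fine-tuning heuristic rather
    # than anything this codebase prescribes: the pretrained backbone already
    # encodes useful features that large updates would disturb, while the freshly
    # initialized decoder needs the full learning rate.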

    # Learning rate scheduler
    if args.scheduler == 'cosine':
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=args.epochs * len(train_loader),
            eta_min=args.min_lr
        )
    elif args.scheduler == 'plateau':
        scheduler = ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=args.patience
        )
    elif args.scheduler == 'warmup_cosine':
        num_training_steps = args.epochs * len(train_loader)
        num_warmup_steps = args.warmup_epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps, num_training_steps
        )
    else:
        scheduler = None
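
    # Scheduler stepping differs by type: the cosine variants are stepped once
    # per optimizer update inside train_epoch (their horizons above are counted
    # in batches), while 'plateau' is stepped once per epoch on the validation
    # loss further down.

    # Restore optimizer/scheduler state when resuming, so the AdamW moment
    # estimates and the LR trajectory continue where the checkpoint left off.
    # (These states are saved below but could not be loaded earlier, before the
    # optimizer and scheduler existed.)
    if checkpoint is not None:
        if 'optimizer_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        if scheduler is not None and checkpoint.get('scheduler_state') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state'])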

    # Loss function; padding positions are ignored so they contribute no gradient
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Mixed precision training - prefer the torch.amp API on newer PyTorch,
    # falling back to the older torch.cuda.amp namespace
    if hasattr(torch, 'amp') and hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
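
    # GradScaler scales the loss up before backward() so small fp16 gradients
    # don't underflow to zero, then unscales them before the optimizer step.
    # With enabled=False it is a no-op, so the same loop runs with or without AMP.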

    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    def build_checkpoint(epoch, train_loss, val_loss):
        """Bundle everything needed to resume or deploy: model weights,
        optimizer/scheduler state, the tokenizer, and the training config."""
        return {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict() if scheduler else None,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'tokenizer': tokenizer,
            'config': vars(args)
        }

    # Training loop
    patience_counter = 0
    for epoch in range(start_epoch, args.epochs):
        args.epoch = epoch  # Set epoch for train_epoch function
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        print("-" * 60)

        # Train
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, scaler,
            scheduler if args.scheduler in ('cosine', 'warmup_cosine') else None,
            device, args
        )

        # Validate
        val_loss = validate(model, val_loader, criterion, device)

        # Update scheduler ('plateau' steps on validation loss once per epoch;
        # the cosine variants were already stepped per batch in train_epoch)
        if args.scheduler == 'plateau':
            scheduler.step(val_loss)

        current_lr = optimizer.param_groups[0]['lr']  # with --different_lr, this is the encoder group
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

        # Log to W&B
        if args.use_wandb and WANDB_AVAILABLE:
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'learning_rate': current_lr
            })

        # Checkpointing
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            patience_counter = 0
            # Save best model
            best_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
            torch.save(build_checkpoint(epoch, train_loss, val_loss), best_path)
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")
        else:
            patience_counter += 1

        # Save periodic checkpoints
        if (epoch + 1) % args.save_every == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'
            )
            torch.save(build_checkpoint(epoch, train_loss, val_loss), checkpoint_path)
            print(f"✓ Saved periodic checkpoint (epoch {epoch+1})")

        # Early stopping
        if patience_counter >= args.early_stopping_patience:
            print(f"\nEarly stopping triggered after {args.early_stopping_patience} epochs without improvement")
            break

    print("\n" + "=" * 60)
    print("Training Complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {os.path.join(args.checkpoint_dir, 'best_model.pth')}")
    print("=" * 60)

    if args.use_wandb and WANDB_AVAILABLE:
        wandb.finish()


def main():
    parser = argparse.ArgumentParser(description='Advanced training with best practices')

    # Data arguments
    parser.add_argument('--train_image_dir', type=str, required=True)
    parser.add_argument('--train_ann_file', type=str, required=True)
    parser.add_argument('--val_image_dir', type=str, required=True)
    parser.add_argument('--val_ann_file', type=str, required=True)
    parser.add_argument('--test_image_dir', type=str, required=True)

    # Model arguments
    parser.add_argument('--model_name', type=str, default='efficientnet_b3')
    parser.add_argument('--embed_dim', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=8)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--dropout', type=float, default=0.1)

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=96)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--use_amp', action='store_true', help='Use mixed precision')
    parser.add_argument('--grad_accum', type=int, default=1)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--different_lr', action='store_true',
                        help='Use different LR for encoder/decoder')

    # Scheduler arguments
    parser.add_argument('--scheduler', type=str, default='plateau',
                        choices=['cosine', 'plateau', 'warmup_cosine', 'none'])
    parser.add_argument('--patience', type=int, default=3)
    parser.add_argument('--min_lr', type=float, default=1e-6)
    parser.add_argument('--warmup_epochs', type=int, default=2)

    # Checkpointing
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--resume_checkpoint', type=str, default=None)
    parser.add_argument('--save_every', type=int, default=5)
    parser.add_argument('--early_stopping_patience', type=int, default=5)

    # Experiment tracking
    parser.add_argument('--use_wandb', action='store_true', help='Use Weights & Biases')
    parser.add_argument('--wandb_project', type=str, default='image-captioning')

    # Additional args needed by create_dataloaders and train_epoch
    parser.add_argument('--distributed', action='store_true', help='Use distributed training')
    parser.add_argument('--local_rank', type=int, default=0, help='Local rank for distributed training')

    args = parser.parse_args()
    # Set epoch attribute (will be updated during training)
    args.epoch = 0
    train_advanced(args)


if __name__ == '__main__':
    main()
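
# Example invocation (a sketch; the script name and data paths are placeholders
# assuming a COCO-style layout, so adjust them to your own files):
#
#   python advanced_train.py \
#       --train_image_dir data/train2017 \
#       --train_ann_file data/annotations/captions_train2017.json \
#       --val_image_dir data/val2017 \
#       --val_ann_file data/annotations/captions_val2017.json \
#       --test_image_dir data/test2017 \
#       --use_amp --scheduler warmup_cosine --different_lr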