# ImageCaptionner/training/train_advanced.py
"""
Advanced Training Script with Best Practices
- Learning rate scheduling
- Mixed precision training
- Experiment tracking (W&B optional)
- Comprehensive evaluation
- Model checkpointing
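
Example invocation (a sketch; dataset paths are placeholders):
    python train_advanced.py \
        --train_image_dir data/train2017 --train_ann_file data/annotations/captions_train2017.json \
        --val_image_dir data/val2017 --val_ann_file data/annotations/captions_val2017.json \
        --test_image_dir data/test2017 \
        --use_amp --scheduler warmup_cosine --different_lr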
"""
import argparse
import math
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, LambdaLR

from efficient_train import (
    create_dataloaders, Encoder, Decoder, ImageCaptioningModel,
    train_epoch, validate, generate_caption
)
# Optional: Weights & Biases
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
print("W&B not available. Install with: pip install wandb")
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
"""Create learning rate schedule with warmup and cosine annealing"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
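
# Example: with num_warmup_steps=100 and num_training_steps=1000, the multiplier
# returned by lr_lambda ramps linearly from 0 to 1 over steps 0-99, then decays
# along a half-cosine from 1.0 at step 100 toward 0.0 at step 1000.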
def train_advanced(args):
"""Advanced training with all best practices"""
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
# GPU optimizations
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True # Optimize for consistent input sizes
torch.backends.cudnn.deterministic = False # Faster, but non-deterministic
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
# Initialize W&B
if args.use_wandb and WANDB_AVAILABLE:
wandb.init(
project=args.wandb_project,
name=f"{args.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
config=vars(args)
)
# Create dataloaders
train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)
# Initialize model
encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,  # +2 for the extra special tokens added by efficient_train's tokenizer setup
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=64,  # must match the caption length used by the dataloaders
        dropout=args.dropout
    )
model = ImageCaptioningModel(encoder, decoder).to(device)
# Resume from checkpoint if provided
    start_epoch = 0
    best_val_loss = float('inf')
if args.resume_checkpoint:
print(f"Loading checkpoint from {args.resume_checkpoint}")
# Handle PyTorch 2.6+ security: allow tokenizer classes
try:
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
torch.serialization.add_safe_globals([GPT2TokenizerFast])
except ImportError:
pass
        # weights_only=False is needed because the checkpoint pickles the tokenizer;
        # only load checkpoints from trusted sources
        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state'])
start_epoch = checkpoint.get('epoch', 0) + 1
best_val_loss = checkpoint.get('val_loss', float('inf'))
print(f"Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")
# Optimizer with different learning rates for encoder/decoder
    # Match on the top-level module prefix to avoid accidental substring hits
    encoder_params = [p for n, p in model.named_parameters() if n.startswith('encoder')]
    decoder_params = [p for n, p in model.named_parameters() if n.startswith('decoder')]
    if args.different_lr:
        # Lower LR for the pretrained encoder, full LR for the from-scratch decoder
        # (e.g., with --lr 3e-4: encoder groups get 3e-5, decoder groups 3e-4)
        optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': args.lr * 0.1},
            {'params': decoder_params, 'lr': args.lr}
        ], weight_decay=args.weight_decay)
else:
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Learning rate scheduler
if args.scheduler == 'cosine':
scheduler = CosineAnnealingLR(
optimizer,
T_max=args.epochs * len(train_loader),
eta_min=args.min_lr
)
elif args.scheduler == 'plateau':
scheduler = ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=args.patience
)
    elif args.scheduler == 'warmup_cosine':
        num_training_steps = args.epochs * len(train_loader)
        num_warmup_steps = args.warmup_epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps, num_training_steps
        )
    else:
        scheduler = None
    # When resuming, also restore optimizer and scheduler state so momentum and
    # the LR schedule pick up where they left off (both are saved in checkpoints below)
    if args.resume_checkpoint:
        if checkpoint.get('optimizer_state') is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        if scheduler is not None and checkpoint.get('scheduler_state') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state'])
# Loss function
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    # Mixed precision training: prefer the torch.amp API where available,
    # fall back to the older torch.cuda.amp.GradScaler otherwise
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
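    # train_epoch (in efficient_train.py) is expected to drive the scaler with
    # the standard AMP pattern, roughly:
    #     with torch.autocast(device_type='cuda', enabled=args.use_amp):
    #         loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)
    #     scaler.update()
    #     optimizer.zero_grad()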
# Create checkpoint directory
os.makedirs(args.checkpoint_dir, exist_ok=True)
# Training loop
patience_counter = 0
for epoch in range(start_epoch, args.epochs):
args.epoch = epoch # Set epoch for train_epoch function
print(f"\nEpoch {epoch+1}/{args.epochs}")
print("-" * 60)
# Train
train_loss = train_epoch(
model, train_loader, optimizer, criterion, scaler,
            scheduler if args.scheduler in ('cosine', 'warmup_cosine') else None,
device, args
)
# Validate
val_loss = validate(model, val_loader, criterion, device)
        # Update scheduler (cosine/warmup_cosine are stepped per-batch inside train_epoch)
        if args.scheduler == 'plateau':
            scheduler.step(val_loss)
current_lr = optimizer.param_groups[0]['lr']
print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
# Log to W&B
log_dict = {
'epoch': epoch,
'train_loss': train_loss,
'val_loss': val_loss,
'learning_rate': current_lr
}
if args.use_wandb and WANDB_AVAILABLE:
wandb.log(log_dict)
        # Checkpointing: build the state dict once, reuse for best/periodic saves
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict() if scheduler else None,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'tokenizer': tokenizer,
            'config': vars(args)
        }
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            patience_counter = 0
            best_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")
        else:
            patience_counter += 1
        # Save periodic checkpoints
        if (epoch + 1) % args.save_every == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'
            )
            torch.save(checkpoint, checkpoint_path)
            print(f"✓ Saved periodic checkpoint (epoch {epoch+1})")
# Early stopping
if patience_counter >= args.early_stopping_patience:
print(f"\nEarly stopping triggered after {args.early_stopping_patience} epochs without improvement")
break
print("\n" + "="*60)
print("Training Complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Best model saved to: {os.path.join(args.checkpoint_dir, 'best_model.pth')}")
print("="*60)
if args.use_wandb and WANDB_AVAILABLE:
wandb.finish()
def main():
parser = argparse.ArgumentParser(description='Advanced training with best practices')
# Data arguments
parser.add_argument('--train_image_dir', type=str, required=True)
parser.add_argument('--train_ann_file', type=str, required=True)
parser.add_argument('--val_image_dir', type=str, required=True)
parser.add_argument('--val_ann_file', type=str, required=True)
parser.add_argument('--test_image_dir', type=str, required=True)
# Model arguments
parser.add_argument('--model_name', type=str, default='efficientnet_b3')
parser.add_argument('--embed_dim', type=int, default=512)
parser.add_argument('--num_layers', type=int, default=8)
parser.add_argument('--num_heads', type=int, default=8)
parser.add_argument('--dropout', type=float, default=0.1)
# Training arguments
parser.add_argument('--batch_size', type=int, default=96)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--use_amp', action='store_true', help='Use mixed precision')
parser.add_argument('--grad_accum', type=int, default=1)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--different_lr', action='store_true',
help='Use different LR for encoder/decoder')
# Scheduler arguments
parser.add_argument('--scheduler', type=str, default='plateau',
choices=['cosine', 'plateau', 'warmup_cosine', 'none'])
parser.add_argument('--patience', type=int, default=3)
parser.add_argument('--min_lr', type=float, default=1e-6)
parser.add_argument('--warmup_epochs', type=int, default=2)
# Checkpointing
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
parser.add_argument('--resume_checkpoint', type=str, default=None)
parser.add_argument('--save_every', type=int, default=5)
parser.add_argument('--early_stopping_patience', type=int, default=5)
# Experiment tracking
parser.add_argument('--use_wandb', action='store_true', help='Use Weights & Biases')
parser.add_argument('--wandb_project', type=str, default='image-captioning')
# Additional args needed by create_dataloaders and train_epoch
parser.add_argument('--distributed', action='store_true', help='Use distributed training')
parser.add_argument('--local_rank', type=int, default=0, help='Local rank for distributed training')
args = parser.parse_args()
# Set epoch attribute (will be updated during training)
args.epoch = 0
train_advanced(args)
if __name__ == '__main__':
main()