Spaces:
Running
Running
| """ | |
| Hyperparameter Optimization using Optuna | |
| Run this to find the best hyperparameters for your model | |
| """ | |
| import optuna | |
| import torch | |
| import argparse | |
| import os | |
| import sys | |
| from efficient_train import create_dataloaders, Encoder, Decoder, ImageCaptioningModel | |
| from efficient_train import train_epoch, validate, generate_caption | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau | |
| def train_with_config(trial, args): | |
| """Train model with suggested hyperparameters from Optuna""" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Suggest hyperparameters | |
| lr = trial.suggest_loguniform('lr', 1e-5, 1e-3) | |
| batch_size = trial.suggest_categorical('batch_size', [32, 64, 96, 128]) | |
| embed_dim = trial.suggest_categorical('embed_dim', [256, 512, 768]) | |
| num_layers = trial.suggest_int('num_layers', 4, 12) | |
| num_heads = trial.suggest_categorical('num_heads', [4, 8, 12, 16]) | |
| dropout = trial.suggest_uniform('dropout', 0.1, 0.5) | |
| weight_decay = trial.suggest_loguniform('weight_decay', 1e-5, 1e-2) | |
| warmup_epochs = trial.suggest_int('warmup_epochs', 0, 3) | |
| # Update args with suggested values | |
| args.lr = lr | |
| args.batch_size = batch_size | |
| args.embed_dim = embed_dim | |
| args.num_layers = num_layers | |
| args.num_heads = num_heads | |
| args.epochs = 5 # Fewer epochs for hyperparameter search | |
| # Create dataloaders | |
| train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args) | |
| # Initialize model | |
| encoder = Encoder(args.model_name, embed_dim) | |
| decoder = Decoder( | |
| vocab_size=tokenizer.vocab_size + 2, | |
| embed_dim=embed_dim, | |
| num_layers=num_layers, | |
| num_heads=num_heads, | |
| max_seq_length=64, | |
| dropout=dropout | |
| ) | |
| model = ImageCaptioningModel(encoder, decoder).to(device) | |
| # Optimizer | |
| optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) | |
| # Scheduler | |
| scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2) | |
| # Loss | |
| criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) | |
| # Mixed precision | |
| scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp) | |
| # Training loop (fewer epochs for hyperparameter search) | |
| best_val_loss = float('inf') | |
| for epoch in range(args.epochs): | |
| # Train | |
| train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, | |
| scheduler, device, args) | |
| # Validate | |
| val_loss = validate(model, val_loader, criterion, device) | |
| # Update scheduler | |
| scheduler.step(val_loss) | |
| # Report to Optuna | |
| trial.report(val_loss, epoch) | |
| # Prune trial if not promising | |
| if trial.should_prune(): | |
| raise optuna.exceptions.TrialPruned() | |
| if val_loss < best_val_loss: | |
| best_val_loss = val_loss | |
| return best_val_loss | |
| def objective(trial): | |
| """Optuna objective function""" | |
| # Create minimal args object | |
| args = argparse.Namespace( | |
| train_image_dir='Data/train2017/train2017', | |
| train_ann_file='Data/annotations_trainval2017/annotations/captions_train2017.json', | |
| val_image_dir='Data/val2017', | |
| val_ann_file='Data/annotations_trainval2017/annotations/captions_val2017.json', | |
| test_image_dir='Data/test2017/test2017', | |
| model_name='efficientnet_b3', | |
| embed_dim=512, # Will be overridden | |
| num_layers=8, # Will be overridden | |
| num_heads=8, # Will be overridden | |
| batch_size=96, # Will be overridden | |
| lr=3e-4, # Will be overridden | |
| epochs=5, | |
| seed=42, | |
| use_amp=True, | |
| grad_accum=1, | |
| checkpoint_dir='checkpoints', | |
| early_stopping_patience=3, | |
| distributed=False, | |
| local_rank=0, | |
| resume_checkpoint=None | |
| ) | |
| try: | |
| val_loss = train_with_config(trial, args) | |
| return val_loss | |
| except Exception as e: | |
| print(f"Trial failed: {e}") | |
| return float('inf') | |
| def main(): | |
| parser = argparse.ArgumentParser(description='Hyperparameter optimization with Optuna') | |
| parser.add_argument('--n_trials', type=int, default=50, help='Number of trials') | |
| parser.add_argument('--timeout', type=int, default=3600*24, help='Timeout in seconds') | |
| parser.add_argument('--study_name', type=str, default='efficientnet_captioning', | |
| help='Study name') | |
| parser.add_argument('--storage', type=str, default='sqlite:///optuna_study.db', | |
| help='Storage URL for study') | |
| args = parser.parse_args() | |
| # Create or load study | |
| study = optuna.create_study( | |
| direction='minimize', | |
| study_name=args.study_name, | |
| storage=args.storage, | |
| load_if_exists=True, | |
| pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3) | |
| ) | |
| print(f"Starting optimization with {args.n_trials} trials...") | |
| print(f"Study: {args.study_name}") | |
| # Optimize | |
| study.optimize(objective, n_trials=args.n_trials, timeout=args.timeout) | |
| # Print results | |
| print("\n" + "="*60) | |
| print("Optimization Complete!") | |
| print("="*60) | |
| print(f"Best trial: {study.best_trial.number}") | |
| print(f"Best validation loss: {study.best_value:.4f}") | |
| print("\nBest parameters:") | |
| for key, value in study.best_params.items(): | |
| print(f" {key}: {value}") | |
| # Save results | |
| import json | |
| with open('best_hyperparameters.json', 'w') as f: | |
| json.dump(study.best_params, f, indent=2) | |
| print("\nBest hyperparameters saved to best_hyperparameters.json") | |
| # Visualize (optional, requires plotly) | |
| try: | |
| import optuna.visualization as vis | |
| # Optimization history | |
| fig = vis.plot_optimization_history(study) | |
| fig.write_image("optimization_history.png") | |
| print("Saved optimization_history.png") | |
| # Parameter importances | |
| fig = vis.plot_param_importances(study) | |
| fig.write_image("param_importances.png") | |
| print("Saved param_importances.png") | |
| # Parallel coordinate plot | |
| fig = vis.plot_parallel_coordinate(study) | |
| fig.write_image("parallel_coordinate.png") | |
| print("Saved parallel_coordinate.png") | |
| except ImportError: | |
| print("Install plotly to generate visualizations: pip install plotly") | |
| if __name__ == '__main__': | |
| main() | |