import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional
import argparse
import io


from .model import DispersionTransformer
from .dataset import create_dataloaders, ParameterDistributions
from .utils import compute_rmse, compute_mae


class PredictionPlotCallback(Callback):
    """Callback that logs truth-vs-prediction scatter plots to TensorBoard."""

    def __init__(self, plot_every_n_epochs=5, max_samples=500):
        """
        Initialize plotting callback.

        Args:
            plot_every_n_epochs: How often (in epochs) to generate plots
            max_samples: Maximum number of samples to plot (for performance)
        """
        super().__init__()
        self.plot_every_n_epochs = plot_every_n_epochs
        self.max_samples = max_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate truth-vs-prediction plots at the end of a validation epoch."""
        # Skip the sanity-check pass and off-cycle epochs
        if trainer.sanity_checking or trainer.current_epoch % self.plot_every_n_epochs != 0:
            return

        pl_module.eval()

        predictions_list = []
        targets_list = []

        with torch.no_grad():
            # Lightning may expose a single dataloader or a list of them
            val_loader = trainer.val_dataloaders
            if isinstance(val_loader, (list, tuple)):
                val_loader = val_loader[0]
            for batch_idx, batch in enumerate(val_loader):
                if batch_idx >= 10:  # cap the number of batches for speed
                    break
                set_1, set_2, set_1_mask, set_2_mask, targets = batch

                # Move the batch to the module's device
                set_1 = set_1.to(pl_module.device)
                set_2 = set_2.to(pl_module.device)
                set_1_mask = set_1_mask.to(pl_module.device)
                set_2_mask = set_2_mask.to(pl_module.device)
                targets = targets.to(pl_module.device)

                predictions = pl_module(set_1, set_2, set_1_mask, set_2_mask)

                predictions_list.append(predictions.cpu())
                targets_list.append(targets.cpu())

        if not predictions_list:
            return

        all_predictions = torch.cat(predictions_list, dim=0)
        all_targets = torch.cat(targets_list, dim=0)

        # Randomly subsample so the scatter plots stay fast to render
        if len(all_predictions) > self.max_samples:
            indices = torch.randperm(len(all_predictions))[:self.max_samples]
            all_predictions = all_predictions[indices]
            all_targets = all_targets[indices]

        self._create_plots(trainer, all_predictions, all_targets, trainer.current_epoch)

    def _create_plots(self, trainer, predictions, targets, epoch):
        """Create truth-vs-prediction scatter plots, one panel per parameter."""
        # Display names, in the same order as the model's target columns
        param_names = ['μ', 'β', 'α']

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        for i, (param_name, ax) in enumerate(zip(param_names, axes)):
            pred_vals = predictions[:, i].numpy()
            true_vals = targets[:, i].numpy()

            ax.scatter(true_vals, pred_vals, alpha=0.6, s=20)

            # Identity line: perfect predictions would fall on it
            min_val = min(true_vals.min(), pred_vals.min())
            max_val = max(true_vals.max(), pred_vals.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, linewidth=1)

            # Squared Pearson correlation (note: not the coefficient of
            # determination about the identity line)
            correlation_matrix = np.corrcoef(true_vals, pred_vals)
            r_squared = correlation_matrix[0, 1] ** 2

            rmse = np.sqrt(np.mean((pred_vals - true_vals) ** 2))

            ax.set_xlabel(f'True {param_name} (normalized)')
            ax.set_ylabel(f'Predicted {param_name} (normalized)')
            ax.set_title(f'{param_name}: R²={r_squared:.3f}, RMSE={rmse:.3f}')
            ax.grid(True, alpha=0.3)
            ax.set_aspect('equal', adjustable='box')

        plt.tight_layout()

        # Render the figure to an in-memory PNG buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)

        # Log to TensorBoard if the logger exposes a SummaryWriter. PIL and
        # torchvision are imported lazily so they remain optional dependencies.
        if hasattr(trainer.logger, 'experiment'):
            from PIL import Image
            import torchvision.transforms as transforms

            image = Image.open(buf)
            transform = transforms.ToTensor()
            image_tensor = transform(image)

            trainer.logger.experiment.add_image(
                'Truth_vs_Prediction',
                image_tensor,
                global_step=epoch
            )

        plt.close(fig)
        buf.close()
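
        # Note: a simpler alternative (a sketch only, not what this module
        # uses) is TensorBoard's native matplotlib support, which converts
        # the figure to an image internally and skips the PIL/torchvision
        # round-trip:
        #
        #     if hasattr(trainer.logger, 'experiment'):
        #         trainer.logger.experiment.add_figure(
        #             'Truth_vs_Prediction', fig, global_step=epoch
        #         )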


class DispersionLightningModule(pl.LightningModule):
    """
    PyTorch Lightning module for training the Dispersion transformer.

    Handles multi-output regression for NB GLM parameters (mu, beta, alpha)
    with separate loss tracking and metrics for each parameter.
    """

    def __init__(self,
                 model_config: Dict[str, Any],
                 learning_rate: float = 1e-4,
                 weight_decay: float = 1e-5,
                 scheduler_patience: int = 5,
                 scheduler_factor: float = 0.5,
                 loss_weights: Optional[Dict[str, float]] = None):
        """
        Initialize Dispersion Lightning module.

        Args:
            model_config: Configuration for the DispersionTransformer model
            learning_rate: Learning rate for the optimizer
            weight_decay: Weight decay for the optimizer
            scheduler_patience: Patience for the ReduceLROnPlateau scheduler
            scheduler_factor: Reduction factor for ReduceLROnPlateau
            loss_weights: Optional per-parameter weights for the loss
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = DispersionTransformer(**model_config)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.scheduler_patience = scheduler_patience
        self.scheduler_factor = scheduler_factor

        # Default to equal weighting of the three parameters
        if loss_weights is None:
            self.loss_weights = {
                'mu': 1.0,
                'beta': 1.0,
                'alpha': 1.0
            }
        else:
            self.loss_weights = loss_weights

        # Register as a buffer so the tensor moves with the module across
        # devices and is saved in the state dict
        self.register_buffer('loss_weight_tensor', torch.tensor([
            self.loss_weights[col] for col in self.model.TARGET_COLUMNS
        ], dtype=torch.float32))

    def forward(self, set_1, set_2, set_1_mask, set_2_mask):
        """Forward pass through the model."""
        return self.model(set_1, set_2, set_1_mask, set_2_mask)

    def compute_loss(self, predictions, targets):
        """
        Compute weighted multi-output MSE loss.

        Args:
            predictions: Model predictions (B, 3)
            targets: Target values (B, 3)

        Returns:
            Dictionary with total loss and per-parameter losses
        """
        # Per-output MSE, averaged over the batch dimension -> shape (3,)
        mse_per_output = F.mse_loss(predictions, targets, reduction='none').mean(dim=0)

        # loss_weight_tensor is a registered buffer, so it already lives on
        # the same device as the predictions
        weighted_losses = mse_per_output * self.loss_weight_tensor

        total_loss = weighted_losses.sum()

        loss_dict = {'total_loss': total_loss}
        for i, col in enumerate(self.model.TARGET_COLUMNS):
            loss_dict[f'loss_{col}'] = mse_per_output[i]
            loss_dict[f'weighted_loss_{col}'] = weighted_losses[i]

        return loss_dict

    def compute_metrics(self, predictions, targets, prefix=''):
        """
        Compute RMSE and MAE metrics for each parameter.

        Args:
            predictions: Model predictions (B, 3)
            targets: Target values (B, 3)
            prefix: Prefix for metric names (e.g., 'train_', 'val_')

        Returns:
            Dictionary with metrics
        """
        metrics = {}

        for i, col in enumerate(self.model.TARGET_COLUMNS):
            pred_col = predictions[:, i]
            target_col = targets[:, i]

            rmse = compute_rmse(pred_col, target_col)
            mae = compute_mae(pred_col, target_col)

            metrics[f'{prefix}rmse_{col}'] = rmse
            metrics[f'{prefix}mae_{col}'] = mae

        all_rmse = [metrics[f'{prefix}rmse_{col}'] for col in self.model.TARGET_COLUMNS]
        all_mae = [metrics[f'{prefix}mae_{col}'] for col in self.model.TARGET_COLUMNS]

        metrics[f'{prefix}rmse_overall'] = sum(all_rmse) / len(all_rmse)
        metrics[f'{prefix}mae_overall'] = sum(all_mae) / len(all_mae)

        return metrics

    def training_step(self, batch, batch_idx):
        """Training step."""
        set_1, set_2, set_1_mask, set_2_mask, targets = batch

        predictions = self(set_1, set_2, set_1_mask, set_2_mask)

        loss_dict = self.compute_loss(predictions, targets)

        for key, value in loss_dict.items():
            self.log(f'train_{key}', value, on_step=True, on_epoch=True, prog_bar=(key == 'total_loss'))

        # Metrics are relatively expensive, so compute them only on every
        # 100th batch; the epoch-level aggregate covers those batches only
        if batch_idx % 100 == 0:
            metrics = self.compute_metrics(predictions, targets, prefix='train_')
            for key, value in metrics.items():
                self.log(key, value, on_step=False, on_epoch=True)

        # Log per-batch target statistics to monitor the sampling distribution
        for i, param_name in enumerate(self.model.TARGET_COLUMNS):
            param_targets = targets[:, i]
            batch_mean = param_targets.mean().item()
            batch_std = param_targets.std().item()
            self.log(f'train_batch_{param_name}_mean', batch_mean, on_step=True, on_epoch=False)
            self.log(f'train_batch_{param_name}_std', batch_std, on_step=True, on_epoch=False)

        return loss_dict['total_loss']

    def on_before_optimizer_step(self, optimizer):
        """Log the global gradient norm for training-stability monitoring."""
        # Sum squared per-parameter L2 norms, then take the square root to
        # obtain the global L2 norm over all gradients
        grad_norm = 0.0
        param_count = 0
        for param in self.parameters():
            if param.grad is not None:
                grad_norm += param.grad.detach().norm(2).item() ** 2
                param_count += 1

        if param_count > 0:
            grad_norm = grad_norm ** 0.5
            self.log('train_grad_norm', grad_norm, on_step=True, on_epoch=False)
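
    # Note: Lightning invokes on_before_optimizer_step before gradient
    # clipping is applied, so the norm logged above is the raw (pre-clipping)
    # value even though the Trainer below sets gradient_clip_val.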

    def validation_step(self, batch, batch_idx):
        """Validation step."""
        set_1, set_2, set_1_mask, set_2_mask, targets = batch

        predictions = self(set_1, set_2, set_1_mask, set_2_mask)

        loss_dict = self.compute_loss(predictions, targets)

        for key, value in loss_dict.items():
            self.log(f'val_{key}', value, on_step=False, on_epoch=True, prog_bar=(key == 'total_loss'))

        metrics = self.compute_metrics(predictions, targets, prefix='val_')
        for key, value in metrics.items():
            self.log(key, value, on_step=False, on_epoch=True)

        # On the first batch, measure how much dropout inflates the loss by
        # re-running the forward pass in train mode (gradients stay disabled)
        if batch_idx == 0:
            self.train()
            with torch.no_grad():
                train_mode_predictions = self(set_1, set_2, set_1_mask, set_2_mask)
                train_mode_loss_dict = self.compute_loss(train_mode_predictions, targets)

            self.log('val_total_loss_with_dropout', train_mode_loss_dict['total_loss'], on_step=False, on_epoch=True)
            self.eval()

        if batch_idx == 0:
            # Log target statistics for the first validation batch
            for i, param_name in enumerate(self.model.TARGET_COLUMNS):
                param_targets = targets[:, i]
                batch_mean = param_targets.mean().item()
                batch_std = param_targets.std().item()
                self.log(f'val_batch_{param_name}_mean', batch_mean, on_step=False, on_epoch=True)
                self.log(f'val_batch_{param_name}_std', batch_std, on_step=False, on_epoch=True)

            # Coarse sanity check: Pearson correlation between predictions
            # and targets with all three parameters flattened together
            # (this mixes scales, so treat it as a rough signal only)
            pred_vs_target_correlation = torch.corrcoef(torch.stack([
                predictions.flatten(), targets.flatten()
            ]))[0, 1].item()
            self.log('val_batch_pred_target_corr', pred_vs_target_correlation, on_step=False, on_epoch=True)

        return loss_dict['total_loss']

    def configure_optimizers(self):
        """Configure optimizer and learning-rate scheduler."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        # Note: ReduceLROnPlateau's `verbose` argument is deprecated in
        # recent PyTorch releases, so it is omitted here
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=self.scheduler_factor,
            patience=self.scheduler_patience
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_total_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }
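
    # A minimal construction sketch (illustrative values; the model_config
    # keys mirror the CLI defaults wired up in main() below, and
    # train_dispersion_transformer() additionally injects target_stats):
    #
    #     module = DispersionLightningModule(
    #         model_config={'dim_input': 1, 'd_model': 128, 'n_heads': 8,
    #                       'num_self_layers': 3, 'num_cross_layers': 3,
    #                       'dropout': 0.1},
    #         learning_rate=1e-4,
    #         loss_weights={'mu': 1.0, 'beta': 2.0, 'alpha': 1.0},
    #     )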


def train_dispersion_transformer(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train a Dispersion transformer model.

    Args:
        config: Configuration dictionary containing:
            - model_config: Model configuration
            - batch_size: Batch size
            - num_workers: Number of data loading workers
            - max_epochs: Maximum training epochs
            - examples_per_epoch: Number of examples per epoch
            - learning_rate: Learning rate
            - weight_decay: Weight decay
            - loss_weights: Optional loss weights
            - checkpoint_dir: Directory for checkpoints
            - seed: Random seed

    Returns:
        Dictionary with training results
    """
    if 'seed' in config:
        pl.seed_everything(config['seed'])

    # Training loader: seed=None so each epoch samples fresh synthetic data
    train_loader = create_dataloaders(
        batch_size=config.get('batch_size', 32),
        num_workers=config.get('num_workers', 4),
        num_examples_per_epoch=config.get('examples_per_epoch', 100000),
        parameter_distributions=config.get('parameter_distributions'),
        seed=None,
        persistent_workers=True
    )

    # Validation loader: fixed seed so the validation set is identical
    # across epochs and runs
    val_loader = create_dataloaders(
        batch_size=config.get('batch_size', 32),
        num_workers=1,
        num_examples_per_epoch=10000,
        parameter_distributions=config.get('parameter_distributions'),
        seed=42,
        persistent_workers=True
    )

    # Fall back to the default parameter distributions; ParameterDistributions
    # is already imported at module level
    param_dist = config.get('parameter_distributions')
    if param_dist is None:
        param_dist = ParameterDistributions()

    # Inject the target normalization statistics into the model config
    model_config = config['model_config'].copy()
    model_config['target_stats'] = param_dist.target_stats

    model = DispersionLightningModule(
        model_config=model_config,
        learning_rate=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-5),
        loss_weights=config.get('loss_weights')
    )

    logger = TensorBoardLogger(
        save_dir=config.get('log_dir', './logs'),
        name='dispersion_transformer'
    )

    # With Lightning's default auto_insert_metric_name=True, each placeholder
    # is prefixed with its name, yielding filenames like
    # 'dispersion_transformer-epoch=00-val_total_loss=0.1234.ckpt'
    checkpoint_callback = ModelCheckpoint(
        monitor='val_total_loss',
        dirpath=config.get('checkpoint_dir', './checkpoints'),
        filename='dispersion_transformer-{epoch:02d}-{val_total_loss:.4f}',
        save_top_k=3,
        mode='min',
        save_last=True,
        every_n_epochs=1,
        verbose=True
    )

    early_stopping = EarlyStopping(
        monitor='val_total_loss',
        patience=config.get('early_stopping_patience', 15),
        mode='min'
    )

    plot_callback = PredictionPlotCallback(
        plot_every_n_epochs=config.get('plot_every_n_epochs', 5),
        max_samples=config.get('plot_max_samples', 500)
    )

    # Prefer Apple MPS, then CUDA, then CPU
    trainer = pl.Trainer(
        max_epochs=config.get('max_epochs', 100),
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping, plot_callback],
        accelerator='mps' if torch.backends.mps.is_available() else ('gpu' if torch.cuda.is_available() else 'cpu'),
        devices=1,
        gradient_clip_val=config.get('gradient_clip', 1.0),
        log_every_n_steps=config.get('log_every_n_steps', 100),
        val_check_interval=config.get('val_check_interval', 0.5),
        enable_progress_bar=True
    )

    trainer.fit(model, train_loader, val_loader)

    return {
        'best_model_path': checkpoint_callback.best_model_path,
        'trainer': trainer,
        'model': model
    }
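
# Example programmatic use (a sketch; values and paths are illustrative):
#
#     results = train_dispersion_transformer({
#         'model_config': {'dim_input': 1, 'd_model': 128, 'n_heads': 8,
#                          'num_self_layers': 3, 'num_cross_layers': 3,
#                          'dropout': 0.1},
#         'batch_size': 64,
#         'max_epochs': 50,
#         'seed': 42,
#     })
#     print(results['best_model_path'])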


def main():
    """Main training script."""
    parser = argparse.ArgumentParser(description='Train Dispersion Transformer')

    # Model arguments
    parser.add_argument('--d_model', type=int, default=128, help='Model dimension')
    parser.add_argument('--n_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--num_self_layers', type=int, default=3, help='Number of self-attention layers')
    parser.add_argument('--num_cross_layers', type=int, default=3, help='Number of cross-attention layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--max_epochs', type=int, default=100, help='Maximum epochs')
    parser.add_argument('--examples_per_epoch', type=int, default=100000, help='Examples per epoch')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay')

    # Logging, checkpointing, and misc arguments
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Checkpoint directory')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Log directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--early_stopping_patience', type=int, default=15, help='Early stopping patience')
    parser.add_argument('--plot_every_n_epochs', type=int, default=5, help='Generate plots every N epochs')
    parser.add_argument('--plot_max_samples', type=int, default=500, help='Max samples to use in plots')

    args = parser.parse_args()

    config = {
        'model_config': {
            'dim_input': 1,
            'd_model': args.d_model,
            'n_heads': args.n_heads,
            'num_self_layers': args.num_self_layers,
            'num_cross_layers': args.num_cross_layers,
            'dropout': args.dropout
        },
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'max_epochs': args.max_epochs,
        'examples_per_epoch': args.examples_per_epoch,
        'learning_rate': args.learning_rate,
        'weight_decay': args.weight_decay,
        'checkpoint_dir': args.checkpoint_dir,
        'log_dir': args.log_dir,
        'seed': args.seed,
        'early_stopping_patience': args.early_stopping_patience,
        'plot_every_n_epochs': args.plot_every_n_epochs,
        'plot_max_samples': args.plot_max_samples
    }

    results = train_dispersion_transformer(config)
    print(f"Best model saved at: {results['best_model_path']}")


if __name__ == '__main__':
    main()