# Provenance (from the hosting page, kept as a comment so the file parses):
# uploaded by user "valsv" via huggingface_hub, commit ccd282b (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
from typing import Dict, Any, Optional
import argparse
import os
import io
from .model import DispersionTransformer
from .dataset import create_dataloaders, ParameterDistributions
from .utils import compute_rmse, compute_mae
class PredictionPlotCallback(Callback):
    """Callback that logs truth-vs-prediction scatter plots to TensorBoard.

    Every ``plot_every_n_epochs``-th validation epoch it runs the model over a
    handful of validation batches and logs one scatter subplot per target
    parameter (mu, beta, alpha), annotated with R² and RMSE.
    """

    def __init__(self, plot_every_n_epochs=5, max_samples=500):
        """
        Initialize plotting callback.

        Args:
            plot_every_n_epochs: How often (in epochs) to generate plots
            max_samples: Maximum number of samples to plot (for performance)
        """
        super().__init__()  # fix: the base Callback initializer was never called
        self.plot_every_n_epochs = plot_every_n_epochs
        self.max_samples = max_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate truth vs prediction plots at end of validation epoch."""
        if trainer.current_epoch % self.plot_every_n_epochs != 0:
            return

        # Ensure deterministic (no-dropout) predictions for plotting.
        pl_module.eval()

        predictions_list = []
        targets_list = []

        with torch.no_grad():
            val_loader = trainer.val_dataloaders
            # NOTE(review): in recent Lightning versions this attribute can be a
            # list of dataloaders; fall back to the first one in that case.
            if isinstance(val_loader, (list, tuple)):
                if not val_loader:
                    return
                val_loader = val_loader[0]
            for batch_idx, batch in enumerate(val_loader):
                if batch_idx >= 10:  # only the first 10 batches are needed for plotting
                    break
                set_1, set_2, set_1_mask, set_2_mask, targets = batch
                # Move inputs to the module's device before the forward pass.
                set_1 = set_1.to(pl_module.device)
                set_2 = set_2.to(pl_module.device)
                set_1_mask = set_1_mask.to(pl_module.device)
                set_2_mask = set_2_mask.to(pl_module.device)
                targets = targets.to(pl_module.device)
                predictions = pl_module(set_1, set_2, set_1_mask, set_2_mask)
                predictions_list.append(predictions.cpu())
                targets_list.append(targets.cpu())

        if not predictions_list:
            return

        all_predictions = torch.cat(predictions_list, dim=0)
        all_targets = torch.cat(targets_list, dim=0)

        # Randomly subsample to keep the scatter plots fast to render.
        if len(all_predictions) > self.max_samples:
            indices = torch.randperm(len(all_predictions))[:self.max_samples]
            all_predictions = all_predictions[indices]
            all_targets = all_targets[indices]

        self._create_plots(trainer, all_predictions, all_targets, trainer.current_epoch)

    def _create_plots(self, trainer, predictions, targets, epoch):
        """Create one scatter subplot per parameter and log the figure to TensorBoard."""
        param_names = ['μ', 'β', 'α']
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        for i, (param_name, ax) in enumerate(zip(param_names, axes)):
            pred_vals = predictions[:, i].numpy()
            true_vals = targets[:, i].numpy()

            ax.scatter(true_vals, pred_vals, alpha=0.6, s=20)

            # Perfect-prediction reference line y = x.
            min_val = min(true_vals.min(), pred_vals.min())
            max_val = max(true_vals.max(), pred_vals.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, linewidth=1)

            # R² from the Pearson correlation; guard against a constant column,
            # where corrcoef is undefined and previously emitted a warning/NaN.
            with np.errstate(invalid='ignore'):
                corr = np.corrcoef(true_vals, pred_vals)[0, 1]
            r_squared = corr ** 2 if np.isfinite(corr) else float('nan')

            rmse = np.sqrt(np.mean((pred_vals - true_vals) ** 2))

            ax.set_xlabel(f'True {param_name} (normalized)')
            ax.set_ylabel(f'Predicted {param_name} (normalized)')
            ax.set_title(f'{param_name}: R²={r_squared:.3f}, RMSE={rmse:.3f}')
            ax.grid(True, alpha=0.3)
            # Equal aspect makes deviations from y = x visually comparable.
            ax.set_aspect('equal', adjustable='box')

        plt.tight_layout()

        # Render the figure into an in-memory PNG and push it to TensorBoard.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)

        if hasattr(trainer.logger, 'experiment'):
            from PIL import Image
            import torchvision.transforms as transforms
            image = Image.open(buf)
            image_tensor = transforms.ToTensor()(image)
            trainer.logger.experiment.add_image(
                'Truth_vs_Prediction',
                image_tensor,
                global_step=epoch
            )

        # Always release the figure and buffer to avoid memory growth over epochs.
        plt.close(fig)
        buf.close()
class DispersionLightningModule(pl.LightningModule):
    """
    PyTorch Lightning module for training Dispersion transformer.

    Handles multi-output regression for NB GLM parameters (mu, beta, alpha)
    with separate loss tracking and metrics for each parameter.
    """

    def __init__(self,
                 model_config: Dict[str, Any],
                 learning_rate: float = 1e-4,
                 weight_decay: float = 1e-5,
                 scheduler_patience: int = 5,
                 scheduler_factor: float = 0.5,
                 loss_weights: Optional[Dict[str, float]] = None):
        """
        Initialize Dispersion Lightning module.

        Args:
            model_config: Configuration for DispersionTransformer model
            learning_rate: Learning rate for optimizer
            weight_decay: Weight decay for optimizer
            scheduler_patience: Patience for ReduceLROnPlateau scheduler
            scheduler_factor: Factor for ReduceLROnPlateau reduction
            loss_weights: Optional weights for different parameters in loss calculation
        """
        super().__init__()
        self.save_hyperparameters()

        # Create model
        self.model = DispersionTransformer(**model_config)

        # Training hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.scheduler_patience = scheduler_patience
        self.scheduler_factor = scheduler_factor

        # Loss weights for multi-task learning
        if loss_weights is None:
            # Equal weights since targets are now normalized to N(0,1)
            self.loss_weights = {
                'mu': 1.0,
                'beta': 1.0,
                'alpha': 1.0
            }
        else:
            self.loss_weights = loss_weights

        # Register as a non-persistent buffer so the tensor follows the module
        # across devices automatically (no manual .to() bookkeeping in the loss),
        # while staying out of checkpoints so existing state_dicts still load.
        self.register_buffer(
            'loss_weight_tensor',
            torch.tensor(
                [self.loss_weights[col] for col in self.model.TARGET_COLUMNS],
                dtype=torch.float32
            ),
            persistent=False
        )

    def forward(self, set_1, set_2, set_1_mask, set_2_mask):
        """Forward pass through the model."""
        return self.model(set_1, set_2, set_1_mask, set_2_mask)

    def compute_loss(self, predictions, targets):
        """
        Compute weighted multi-output MSE loss.

        Args:
            predictions: Model predictions (B, 3)
            targets: Target values (B, 3)

        Returns:
            Dictionary with total loss and per-parameter raw/weighted losses
        """
        # The buffer is moved with the module; .to() is a cheap no-op when the
        # devices already match and guards any unusual call sites.
        weights = self.loss_weight_tensor.to(predictions.device)

        # Per-output MSE, averaged over the batch dimension -> shape (3,)
        mse_per_output = F.mse_loss(predictions, targets, reduction='none').mean(dim=0)

        weighted_losses = mse_per_output * weights
        total_loss = weighted_losses.sum()

        loss_dict = {'total_loss': total_loss}
        for i, col in enumerate(self.model.TARGET_COLUMNS):
            loss_dict[f'loss_{col}'] = mse_per_output[i]
            loss_dict[f'weighted_loss_{col}'] = weighted_losses[i]
        return loss_dict

    def compute_metrics(self, predictions, targets, prefix=''):
        """
        Compute RMSE and MAE metrics for each parameter.

        Args:
            predictions: Model predictions (B, 3)
            targets: Target values (B, 3)
            prefix: Prefix for metric names (e.g., 'train_', 'val_')

        Returns:
            Dictionary with per-parameter and overall (averaged) metrics
        """
        metrics = {}
        for i, col in enumerate(self.model.TARGET_COLUMNS):
            pred_col = predictions[:, i]
            target_col = targets[:, i]
            metrics[f'{prefix}rmse_{col}'] = compute_rmse(pred_col, target_col)
            metrics[f'{prefix}mae_{col}'] = compute_mae(pred_col, target_col)

        # Overall metrics (averaged across parameters)
        all_rmse = [metrics[f'{prefix}rmse_{col}'] for col in self.model.TARGET_COLUMNS]
        all_mae = [metrics[f'{prefix}mae_{col}'] for col in self.model.TARGET_COLUMNS]
        metrics[f'{prefix}rmse_overall'] = sum(all_rmse) / len(all_rmse)
        metrics[f'{prefix}mae_overall'] = sum(all_mae) / len(all_mae)
        return metrics

    def training_step(self, batch, batch_idx):
        """Training step: forward pass, weighted loss, metric/diagnostic logging."""
        set_1, set_2, set_1_mask, set_2_mask, targets = batch

        predictions = self(set_1, set_2, set_1_mask, set_2_mask)
        loss_dict = self.compute_loss(predictions, targets)

        # Log losses (total loss also on the progress bar)
        for key, value in loss_dict.items():
            self.log(f'train_{key}', value, on_step=True, on_epoch=True, prog_bar=(key == 'total_loss'))

        # Compute and log metrics every N batches (metrics are cheaper to sample)
        if batch_idx % 100 == 0:
            metrics = self.compute_metrics(predictions, targets, prefix='train_')
            for key, value in metrics.items():
                self.log(key, value, on_step=False, on_epoch=True)

        # DIAGNOSTIC: log per-batch target statistics to detect batch-level artifacts
        for i, param_name in enumerate(['mu', 'beta', 'alpha']):
            param_targets = targets[:, i]
            self.log(f'train_batch_{param_name}_mean', param_targets.mean().item(), on_step=True, on_epoch=False)
            self.log(f'train_batch_{param_name}_std', param_targets.std().item(), on_step=True, on_epoch=False)

        return loss_dict['total_loss']

    def on_before_optimizer_step(self, optimizer):
        """Log the global gradient L2 norm for training stability monitoring."""
        grad_norm = 0.0
        param_count = 0
        for param in self.parameters():
            if param.grad is not None:
                grad_norm += param.grad.data.norm(2).item() ** 2
                param_count += 1
        if param_count > 0:
            grad_norm = grad_norm ** 0.5
            self.log('train_grad_norm', grad_norm, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        """Validation step: loss, metrics, and once-per-epoch diagnostics."""
        set_1, set_2, set_1_mask, set_2_mask, targets = batch

        predictions = self(set_1, set_2, set_1_mask, set_2_mask)
        loss_dict = self.compute_loss(predictions, targets)

        for key, value in loss_dict.items():
            self.log(f'val_{key}', value, on_step=False, on_epoch=True, prog_bar=(key == 'total_loss'))

        metrics = self.compute_metrics(predictions, targets, prefix='val_')
        for key, value in metrics.items():
            self.log(key, value, on_step=False, on_epoch=True)

        # DIAGNOSTIC: also compute loss with dropout active, once per epoch,
        # to compare train-mode vs eval-mode validation loss.
        if batch_idx == 0:
            self.train()  # temporarily switch to training mode
            with torch.no_grad():
                train_mode_predictions = self(set_1, set_2, set_1_mask, set_2_mask)
                train_mode_loss_dict = self.compute_loss(train_mode_predictions, targets)
            self.log('val_total_loss_with_dropout', train_mode_loss_dict['total_loss'], on_step=False, on_epoch=True)
            self.eval()  # switch back to eval mode

        # DIAGNOSTIC: per-batch target statistics, first batch only
        if batch_idx == 0:
            for i, param_name in enumerate(['mu', 'beta', 'alpha']):
                param_targets = targets[:, i]
                self.log(f'val_batch_{param_name}_mean', param_targets.mean().item(), on_step=False, on_epoch=True)
                self.log(f'val_batch_{param_name}_std', param_targets.std().item(), on_step=False, on_epoch=True)

            # Correlation between (flattened) predictions and targets within this batch
            pred_vs_target_correlation = torch.corrcoef(torch.stack([
                predictions.flatten(), targets.flatten()
            ]))[0, 1].item()
            self.log('val_batch_pred_target_corr', pred_vs_target_correlation, on_step=False, on_epoch=True)

        return loss_dict['total_loss']

    def configure_optimizers(self):
        """Configure AdamW optimizer with ReduceLROnPlateau on validation loss."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        # Note: the deprecated `verbose` kwarg was dropped (removed in newer torch).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=self.scheduler_factor,
            patience=self.scheduler_patience
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_total_loss',  # use validation loss for better generalization
                'interval': 'epoch',
                'frequency': 1
            }
        }
def train_dispersion_transformer(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train a Dispersion transformer model.

    Args:
        config: Configuration dictionary containing:
            - model_config: Model configuration
            - batch_size: Batch size
            - num_workers: Number of data loading workers
            - max_epochs: Maximum training epochs
            - examples_per_epoch: Number of examples per epoch
            - learning_rate: Learning rate
            - weight_decay: Weight decay
            - scheduler_patience / scheduler_factor: Optional LR scheduler settings
            - loss_weights: Optional loss weights
            - checkpoint_dir: Directory for checkpoints
            - seed: Random seed

    Returns:
        Dictionary with 'best_model_path', 'trainer' and 'model'
    """
    # Set random seed for reproducibility
    if 'seed' in config:
        pl.seed_everything(config['seed'])

    # Training data: seed=None so each epoch draws fresh random examples;
    # persistent workers avoid file-descriptor leaks across epochs.
    train_loader = create_dataloaders(
        batch_size=config.get('batch_size', 32),
        num_workers=config.get('num_workers', 4),
        num_examples_per_epoch=config.get('examples_per_epoch', 100000),
        parameter_distributions=config.get('parameter_distributions'),
        seed=None,  # random seed for training data diversity
        persistent_workers=True
    )

    # Validation data: fixed seed for a reproducible evaluation set.
    val_loader = create_dataloaders(
        batch_size=config.get('batch_size', 32),
        num_workers=1,  # single worker to minimize file descriptors
        num_examples_per_epoch=10000,  # smaller validation set is fine
        parameter_distributions=config.get('parameter_distributions'),
        seed=42,
        persistent_workers=True
    )

    # Target normalization stats for denormalization; ParameterDistributions is
    # already imported at module level (the previous redundant local import was removed).
    param_dist = config.get('parameter_distributions')
    if param_dist is None:
        param_dist = ParameterDistributions()

    model_config = config['model_config'].copy()
    model_config['target_stats'] = param_dist.target_stats

    # Create model; scheduler settings are now forwarded from config instead of
    # being silently ignored (defaults preserve the previous behavior).
    model = DispersionLightningModule(
        model_config=model_config,
        learning_rate=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-5),
        scheduler_patience=config.get('scheduler_patience', 5),
        scheduler_factor=config.get('scheduler_factor', 0.5),
        loss_weights=config.get('loss_weights')
    )

    # Setup logging
    logger = TensorBoardLogger(
        save_dir=config.get('log_dir', './logs'),
        name='dispersion_transformer'
    )

    # Checkpointing on validation loss for proper model selection
    checkpoint_callback = ModelCheckpoint(
        monitor='val_total_loss',
        dirpath=config.get('checkpoint_dir', './checkpoints'),
        filename='dispersion_transformer-epoch={epoch:02d}-val_total_loss={val_total_loss:.4f}',
        save_top_k=3,  # keep best 3 models by validation loss
        mode='min',
        save_last=True,  # always save the last checkpoint
        every_n_epochs=1,
        verbose=True
    )

    early_stopping = EarlyStopping(
        monitor='val_total_loss',
        patience=config.get('early_stopping_patience', 15),
        mode='min'
    )

    plot_callback = PredictionPlotCallback(
        plot_every_n_epochs=config.get('plot_every_n_epochs', 5),
        max_samples=config.get('plot_max_samples', 500)
    )

    # Create trainer; prefer MPS, then CUDA, then CPU.
    trainer = pl.Trainer(
        max_epochs=config.get('max_epochs', 100),
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping, plot_callback],
        accelerator='mps' if torch.backends.mps.is_available() else ('gpu' if torch.cuda.is_available() else 'cpu'),
        devices=1,
        gradient_clip_val=config.get('gradient_clip', 1.0),
        log_every_n_steps=config.get('log_every_n_steps', 100),
        val_check_interval=config.get('val_check_interval', 0.5),  # validate twice per epoch
        enable_progress_bar=True
    )

    # Train model
    trainer.fit(model, train_loader, val_loader)

    return {
        'best_model_path': checkpoint_callback.best_model_path,
        'trainer': trainer,
        'model': model
    }
def main():
    """Command-line entry point: parse arguments, build a config, run training."""
    parser = argparse.ArgumentParser(description='Train Dispersion Transformer')

    # Model configuration
    parser.add_argument('--d_model', type=int, default=128, help='Model dimension')
    parser.add_argument('--n_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--num_self_layers', type=int, default=3, help='Number of self-attention layers')
    parser.add_argument('--num_cross_layers', type=int, default=3, help='Number of cross-attention layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')

    # Training configuration
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--max_epochs', type=int, default=100, help='Maximum epochs')
    parser.add_argument('--examples_per_epoch', type=int, default=100000, help='Examples per epoch')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay')

    # Other configuration
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Checkpoint directory')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Log directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--early_stopping_patience', type=int, default=15, help='Early stopping patience')
    parser.add_argument('--plot_every_n_epochs', type=int, default=5, help='Generate plots every N epochs')
    parser.add_argument('--plot_max_samples', type=int, default=500, help='Max samples to use in plots')

    cli_args = parser.parse_args()

    # Split the parsed namespace into model hyperparameters and everything else.
    model_keys = ('d_model', 'n_heads', 'num_self_layers', 'num_cross_layers', 'dropout')
    run_keys = (
        'batch_size', 'num_workers', 'max_epochs', 'examples_per_epoch',
        'learning_rate', 'weight_decay', 'checkpoint_dir', 'log_dir', 'seed',
        'early_stopping_patience', 'plot_every_n_epochs', 'plot_max_samples'
    )
    config = {'model_config': {'dim_input': 1}}
    for key in model_keys:
        config['model_config'][key] = getattr(cli_args, key)
    for key in run_keys:
        config[key] = getattr(cli_args, key)

    results = train_dispersion_transformer(config)
    print(f"Best model saved at: {results['best_model_path']}")
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == '__main__':
    main()