"""
Training script for the AIA-GOES multimodal solar flare forecasting model using PyTorch Lightning.
This script:
1. Loads configuration from a YAML file with variable substitution (e.g., ${base_dir} references).
2. Initializes the AIA-GOES DataModule.
3. Configures logging with Weights & Biases.
4. Builds and trains a Vision Transformer (ViTLocal) model.
5. Computes dynamic base class weights for flare categories (Quiet, C, M, X).
6. Saves model checkpoints (.ckpt and .pth formats).
"""
import argparse
import os
from datetime import datetime
import re
import yaml
import wandb
import torch
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pathlib import Path
import sys
# Add project root to Python path
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute()
sys.path.insert(0, str(PROJECT_ROOT))
from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
from forecasting.models.vit_patch_model_local import ViTLocal
from callback import ImagePredictionLogger_SXR, AttentionMapCallback
def resolve_config_variables(config_dict):
"""
Recursively resolve variable references within a YAML config dictionary.
This function substitutes placeholders like `${variable}` with the
corresponding values defined at the root level of the configuration.
Parameters
----------
config_dict : dict
Configuration dictionary loaded from a YAML file.
Returns
-------
dict
Configuration dictionary with all ${variable} references resolved.
"""
# Extract variables defined at root level (like base_data_dir, base_checkpoint_dir)
variables = {}
for key, value in config_dict.items():
if isinstance(value, str) and not value.startswith('${'):
variables[key] = value
def substitute_value(value, variables):
"""Helper function to replace ${var_name} with actual values."""
if isinstance(value, str):
pattern = r'\$\{([^}]+)\}'
for match in re.finditer(pattern, value):
var_name = match.group(1)
if var_name in variables:
value = value.replace(f'${{{var_name}}}', variables[var_name])
return value
def recursive_substitute(obj, variables):
"""Recursively substitute variables in nested structures."""
if isinstance(obj, dict):
return {k: recursive_substitute(v, variables) for k, v in obj.items()}
elif isinstance(obj, list):
return [recursive_substitute(item, variables) for item in obj]
else:
return substitute_value(obj, variables)
return recursive_substitute(config_dict, variables)
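# Example (hypothetical dictionary mirroring the YAML layout sketched above):
#   resolve_config_variables({'base_dir': '/data',
#                             'data': {'aia_dir': '${base_dir}/aia'}})
#   -> {'base_dir': '/data', 'data': {'aia_dir': '/data/aia'}}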
class PTHCheckpointCallback(Callback):
"""
Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.
This is in addition to Lightning's `.ckpt` files, allowing for
standalone PyTorch model loading without Lightning dependencies.
Parameters
----------
dirpath : str
Directory to save checkpoints.
monitor : str, optional
Metric name to monitor for best model saving (default: 'val_total_loss').
mode : str, optional
Optimization direction: 'min' or 'max' (default: 'min').
save_top_k : int, optional
Number of best checkpoints to keep (default: 1).
filename_prefix : str, optional
Prefix for checkpoint filenames.
"""
    def __init__(self, dirpath, monitor='val_total_loss', mode='min', save_top_k=1, filename_prefix="model"):
        super().__init__()
        self.dirpath = dirpath
self.monitor = monitor
self.mode = mode
self.save_top_k = save_top_k
self.filename_prefix = filename_prefix
self.best_score = float('inf') if mode == 'min' else float('-inf')
def on_validation_end(self, trainer, pl_module):
"""
Save the model checkpoint as a `.pth` file if validation metric improves.
Parameters
----------
trainer : pytorch_lightning.Trainer
Lightning trainer instance.
pl_module : pytorch_lightning.LightningModule
The model being trained.
"""
current_score = trainer.callback_metrics.get(self.monitor)
if current_score is None:
return
is_better = (self.mode == 'min' and current_score < self.best_score) or \
(self.mode == 'max' and current_score > self.best_score)
        if is_better:
            self.best_score = current_score
            # Save as a standalone .pth file (create the directory if needed)
            os.makedirs(self.dirpath, exist_ok=True)
            filename = f"{self.filename_prefix}-epoch={trainer.current_epoch:02d}-{self.monitor}={current_score:.4f}.pth"
            filepath = os.path.join(self.dirpath, filename)
torch.save({
'model': pl_module,
'epoch': trainer.current_epoch,
'optimizer_state_dict': trainer.optimizers[0].state_dict(),
'loss': current_score,
}, filepath)
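# Loading a .pth file written by this callback (sketch; the path is
# hypothetical). The checkpoint pickles the full LightningModule, so ViTLocal
# must be importable, and recent PyTorch versions need weights_only=False to
# unpickle arbitrary objects:
#
#   ckpt = torch.load('checkpoints/model-epoch=07-val_total_loss=0.1234.pth',
#                     map_location='cpu', weights_only=False)
#   model = ckpt['model']
#   model.eval()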
def get_base_weights(data_loader, sxr_norm):
"""
Compute inverse-frequency weights for flare classes based on training data.
The weights help balance loss contributions from imbalanced flare categories.
Parameters
----------
data_loader : AIA_GOESDataModule
Initialized DataModule providing the train_dataloader.
sxr_norm : np.ndarray
Normalization parameters for SXR.
Returns
-------
dict
Dictionary containing class weights for quiet, C, M, and X classes.
"""
print("Calculating base weights from DataModule...")
    # GOES SXR flux thresholds (W/m^2) separating quiet, C-, M-, and X-class flares
    c_threshold = 1e-6
    m_threshold = 1e-5
    x_threshold = 1e-4
from forecasting.models.vit_patch_model_local import unnormalize_sxr
quiet_count = 0
c_count = 0
m_count = 0
x_count = 0
total = 0
# Use the train_dataloader which already exists
train_loader = data_loader.train_dataloader()
print(f"Processing {len(train_loader)} batches...")
for batch_idx, (aia_batch, sxr_batch) in enumerate(train_loader):
if batch_idx % 50 == 0:
print(f"Processed {batch_idx}/{len(train_loader)} batches...")
# Unnormalize the SXR batch
sxr_un = unnormalize_sxr(sxr_batch, sxr_norm)
sxr_un_flat = sxr_un.view(-1).cpu().numpy()
batch_total = len(sxr_un_flat)
        batch_quiet = (sxr_un_flat < c_threshold).sum()
        batch_c = ((sxr_un_flat >= c_threshold) & (sxr_un_flat < m_threshold)).sum()
        batch_m = ((sxr_un_flat >= m_threshold) & (sxr_un_flat < x_threshold)).sum()
        batch_x = (sxr_un_flat >= x_threshold).sum()
total += batch_total
quiet_count += batch_quiet
c_count += batch_c
m_count += batch_m
x_count += batch_x
# Avoid division by zero
quiet_count = max(quiet_count, 1)
c_count = max(c_count, 1)
m_count = max(m_count, 1)
x_count = max(x_count, 1)
    # Inverse frequency weighting
    quiet_weight = total / quiet_count
    c_weight = total / c_count
    m_weight = total / m_count
    x_weight = total / x_count
print("Base weights calculated")
print(f"Total samples: {total}")
print(f"Quiet samples: {quiet_count}, weight: {quiet_weight:.4f}")
print(f"C samples: {c_count}, weight: {c_weight:.4f}")
print(f"M samples: {m_count}, weight: {m_weight:.4f}")
print(f"X samples: {x_count}, weight: {x_weight:.4f}")
return {
'quiet': quiet_weight,
'c_class': c_weight,
'm_class': m_weight,
'x_class': x_weight
}
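# Worked example (hypothetical counts): with total=1000 samples split as
# quiet=900, C=90, M=9, X=1, the inverse-frequency weights come out to
# 1000/900 ≈ 1.11, 1000/90 ≈ 11.1, 1000/9 ≈ 111.1, and 1000/1 = 1000,
# so the rare X-class samples contribute most heavily to the loss.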
if __name__ == '__main__':
# Parser
parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, required=True, help='Path to the config YAML file.')
args = parser.parse_args()
# Load config with variable substitution
with open(args.config, 'r') as stream:
        config_data = yaml.safe_load(stream)
config_data = resolve_config_variables(config_data)
print("Resolved paths:")
print(f"AIA dir: {config_data['data']['aia_dir']}")
print(f"SXR dir: {config_data['data']['sxr_dir']}")
print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
sxr_norm = np.load(config_data['data']['sxr_norm_path'])
training_wavelengths = config_data['wavelengths']
# DataModule
data_loader = AIA_GOESDataModule(
aia_train_dir=config_data['data']['aia_dir'] + "/train",
aia_val_dir=config_data['data']['aia_dir'] + "/val",
aia_test_dir=config_data['data']['aia_dir'] + "/test",
sxr_train_dir=config_data['data']['sxr_dir'] + "/train",
sxr_val_dir=config_data['data']['sxr_dir'] + "/val",
sxr_test_dir=config_data['data']['sxr_dir'] + "/test",
batch_size=config_data['batch_size'],
        num_workers=min(8, os.cpu_count() or 1),
sxr_norm=sxr_norm,
wavelengths=training_wavelengths,
oversample=config_data['oversample'],
balance_strategy=config_data['balance_strategy'],
)
data_loader.setup()
# Logger
wandb_logger = WandbLogger(
entity=config_data['wandb']['entity'],
project=config_data['wandb']['project'],
job_type=config_data['wandb']['job_type'],
tags=config_data['wandb']['tags'],
name=config_data['wandb']['run_name'],
notes=config_data['wandb']['notes'],
config=config_data
)
    # Callbacks: log predictions for a handful of evenly spaced validation samples
    total_n_valid = len(data_loader.val_ds)
    plot_samples = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
model = ViTLocal(model_kwargs=config_data['vit_architecture'], sxr_norm=sxr_norm, base_weights=base_weights)
# Checkpoint callbacks
checkpoint_callback = ModelCheckpoint(
dirpath=config_data['data']['checkpoints_dir'],
monitor='val_total_loss',
mode='min',
save_top_k=10,
filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
)
pth_callback = PTHCheckpointCallback(
dirpath=config_data['data']['checkpoints_dir'],
monitor='val_total_loss',
mode='min',
save_top_k=1,
filename_prefix=config_data['wandb']['run_name']
)
# Set device based on config
gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
if gpu_config == -1:
accelerator, devices, strategy = "cpu", 1, "auto"
print("Using CPU for training")
elif gpu_config == "all":
if torch.cuda.is_available():
accelerator, devices, strategy = "gpu", -1, "auto"
num_gpus = torch.cuda.device_count()
print(f"Using all available GPUs ({num_gpus} GPUs)")
else:
accelerator, devices, strategy = "cpu", 1, "auto"
print("No GPUs available, falling back to CPU")
elif isinstance(gpu_config, list):
if torch.cuda.is_available():
accelerator, devices, strategy = "gpu", gpu_config, "auto"
print(f"Using GPUs: {gpu_config}")
else:
accelerator, devices, strategy = "cpu", 1, "auto"
print("No GPUs available, falling back to CPU")
else:
if torch.cuda.is_available():
accelerator, devices, strategy = "gpu", [gpu_config], "auto"
print(f"Using GPU {gpu_config}")
else:
accelerator, devices, strategy = "cpu", 1, "auto"
print(f"GPU {gpu_config} not available, falling back to CPU")
# Trainer
trainer = Trainer(
default_root_dir=config_data['data']['checkpoints_dir'],
accelerator=accelerator,
devices=devices,
strategy=strategy,
max_epochs=config_data['epochs'],
        callbacks=[attention, sxr_plot_callback, checkpoint_callback, pth_callback],
logger=wandb_logger,
log_every_n_steps=10,
)
trainer.fit(model, data_loader)
# Save final checkpoint
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
final_checkpoint_path = os.path.join(
config_data['data']['checkpoints_dir'],
f"{config_data['wandb']['run_name']}-final-{timestamp}.pth"
)
torch.save({'model': model, 'state_dict': model.state_dict()}, final_checkpoint_path)
print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
wandb.finish()