| """ |
| Training script for the AIA-GOES multimodal solar flare forecasting model using PyTorch Lightning. |
| |
| This script: |
| 1. Loads configuration from a YAML file with variable substitution (e.g., ${base_dir} references). |
| 2. Initializes the AIA-GOES DataModule. |
| 3. Configures logging with Weights & Biases. |
| 4. Builds and trains a Vision Transformer (ViTLocal) model. |
| 5. Computes dynamic base class weights for flare categories (Quiet, C, M, X). |
| 6. Saves model checkpoints (.ckpt and .pth formats). |
| """ |
|
|
| import argparse |
| import os |
| from datetime import datetime |
| import re |
| from multiprocessing import Pool, cpu_count |
| from functools import partial |
|
|
| import yaml |
| import wandb |
| import torch |
| import numpy as np |
| from pytorch_lightning import Trainer |
| from pytorch_lightning.loggers import WandbLogger |
| from pytorch_lightning.callbacks import ModelCheckpoint |
| from pathlib import Path |
| import sys |
| |
| PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute() |
| sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
| from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule |
|
|
| from forecasting.models.vit_patch_model_local import ViTLocal |
| from callback import ImagePredictionLogger_SXR, AttentionMapCallback |
|
|
| from pytorch_lightning.callbacks import Callback |
|
|
|
|
|
|
def resolve_config_variables(config_dict):
    """
    Recursively resolve ``${variable}`` references within a YAML config dictionary.

    Placeholders such as ``${base_dir}`` are substituted with the value of the
    corresponding root-level string key. Root-level variables may themselves
    reference other root-level variables (e.g. ``sub_dir: ${base_dir}/sub``);
    resolution is iterated to a fixed point, bounded so that reference cycles
    cannot loop forever.

    Parameters
    ----------
    config_dict : dict
        Configuration dictionary loaded from a YAML file.

    Returns
    -------
    dict
        New configuration dictionary with all resolvable ``${variable}``
        references replaced. References to unknown variables are left as-is.
    """
    # Compile once instead of once per value.
    pattern = re.compile(r'\$\{([^}]+)\}')

    def substitute_value(value, variables):
        """Replace every known ${var_name} occurrence inside a string value."""
        if isinstance(value, str):
            for match in pattern.finditer(value):
                var_name = match.group(1)
                if var_name in variables:
                    value = value.replace(f'${{{var_name}}}', variables[var_name])
        return value

    def recursive_substitute(obj, variables):
        """Recursively substitute variables in nested dicts and lists."""
        if isinstance(obj, dict):
            return {k: recursive_substitute(v, variables) for k, v in obj.items()}
        if isinstance(obj, list):
            return [recursive_substitute(item, variables) for item in obj]
        return substitute_value(obj, variables)

    # Collect every root-level string as a substitution variable, then resolve
    # references *between* those variables to a fixed point so that chained
    # definitions (a -> b -> c) work. The iteration count is bounded by the
    # number of variables, which guards against cyclic references.
    variables = {k: v for k, v in config_dict.items() if isinstance(v, str)}
    for _ in range(len(variables) + 1):
        resolved = {k: substitute_value(v, variables) for k, v in variables.items()}
        if resolved == variables:
            break
        variables = resolved

    return recursive_substitute(config_dict, variables)
|
|
|
|
class PTHCheckpointCallback(Callback):
    """
    Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.

    This is in addition to Lightning's `.ckpt` files, allowing for
    standalone PyTorch model loading without Lightning dependencies.
    Checkpoints contain both the pickled module (backward compatible with
    earlier versions of this callback) and a portable ``model_state_dict``.
    Only the ``save_top_k`` best checkpoints are kept on disk.

    Parameters
    ----------
    dirpath : str
        Directory to save checkpoints (created on demand).
    monitor : str, optional
        Metric name to monitor for best model saving (default: 'val_total_loss').
    mode : str, optional
        Optimization direction: 'min' or 'max' (default: 'min').
    save_top_k : int, optional
        Number of best checkpoints to keep (default: 1).
    filename_prefix : str, optional
        Prefix for checkpoint filenames.
    """
    def __init__(self, dirpath, monitor='val_total_loss', mode='min', save_top_k=1, filename_prefix="model"):
        super().__init__()  # Base Callback initialization was previously skipped.
        self.dirpath = dirpath
        self.monitor = monitor
        self.mode = mode
        self.save_top_k = save_top_k
        self.filename_prefix = filename_prefix
        self.best_score = float('inf') if mode == 'min' else float('-inf')
        # (score, filepath) pairs for checkpoints currently on disk, used to
        # enforce save_top_k (previously the parameter was stored but ignored,
        # so stale .pth files accumulated without bound).
        self._saved = []

    def on_validation_end(self, trainer, pl_module):
        """
        Save the model checkpoint as a `.pth` file if the validation metric improves.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The model being trained.
        """
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            return

        # callback_metrics values are typically 0-dim tensors; compare and
        # store as plain floats for robustness and pickle portability.
        score = float(current_score)
        is_better = (self.mode == 'min' and score < self.best_score) or \
                    (self.mode == 'max' and score > self.best_score)
        if not is_better:
            return

        self.best_score = score

        # torch.save fails if the target directory does not exist yet.
        os.makedirs(self.dirpath, exist_ok=True)
        filename = f"{self.filename_prefix}-epoch={trainer.current_epoch:02d}-{self.monitor}={score:.4f}.pth"
        filepath = os.path.join(self.dirpath, filename)

        torch.save({
            'model': pl_module,  # kept for backward compatibility with old loaders
            'model_state_dict': pl_module.state_dict(),  # portable, class-independent
            'epoch': trainer.current_epoch,
            'optimizer_state_dict': trainer.optimizers[0].state_dict(),
            'loss': score,
        }, filepath)

        # Prune checkpoints beyond save_top_k, removing the worst-scoring ones.
        self._saved.append((score, filepath))
        if len(self._saved) > self.save_top_k:
            self._saved.sort(key=lambda entry: entry[0], reverse=(self.mode == 'max'))
            while len(self._saved) > self.save_top_k:
                _, stale_path = self._saved.pop()
                try:
                    os.remove(stale_path)
                except OSError:
                    pass  # best-effort cleanup; never fail training over it
|
|
|
|
|
|
|
|
def get_base_weights(data_loader, sxr_norm):
    """
    Compute inverse-frequency weights for flare classes based on training data.

    Iterates the full training dataloader, bins the unnormalized SXR flux of
    every sample into the four flare categories (Quiet, C, M, X), and returns
    ``total / count`` per class so that rare classes receive larger weights.

    Parameters
    ----------
    data_loader : AIA_GOESDataModule
        Initialized DataModule providing the train_dataloader.
    sxr_norm : np.ndarray
        Normalization parameters for SXR.

    Returns
    -------
    dict
        Dictionary containing class weights for quiet, C, M, and X classes.
    """
    print("Calculating base weights from DataModule...")

    # Class boundaries on the unnormalized SXR flux. Bin semantics match the
    # original threshold comparisons: each bin is [lo, hi) except the last,
    # which also includes the upper edge.
    c_threshold = 1e-6
    m_threshold = 1e-5
    x_threshold = 1e-4
    bin_edges = np.array([-np.inf, c_threshold, m_threshold, x_threshold, np.inf])

    from forecasting.models.vit_patch_model_local import unnormalize_sxr

    class_counts = np.zeros(4, dtype=np.int64)  # [quiet, C, M, X]
    total = 0

    train_loader = data_loader.train_dataloader()
    n_batches = len(train_loader)
    print(f"Processing {n_batches} batches...")

    for batch_idx, (aia_batch, sxr_batch) in enumerate(train_loader):
        if batch_idx % 50 == 0:
            print(f"Processed {batch_idx}/{n_batches} batches...")

        flux = unnormalize_sxr(sxr_batch, sxr_norm).view(-1).cpu().numpy()
        hist, _ = np.histogram(flux, bins=bin_edges)
        class_counts += hist
        total += flux.size

    # Guard against division by zero when a class is absent from the data.
    quiet_count, c_count, m_count, x_count = (int(max(c, 1)) for c in class_counts)

    quiet_weight = total / quiet_count
    c_weight = total / c_count
    m_weight = total / m_count
    x_weight = total / x_count

    print("Base weights calculated")
    print(f"Total samples: {total}")
    print(f"Quiet samples: {quiet_count}, weight: {quiet_weight:.4f}")
    print(f"C samples: {c_count}, weight: {c_weight:.4f}")
    print(f"M samples: {m_count}, weight: {m_weight:.4f}")
    print(f"X samples: {x_count}, weight: {x_weight:.4f}")

    return {
        'quiet': quiet_weight,
        'c_class': c_weight,
        'm_class': m_weight,
        'x_class': x_weight
    }
|
|
|
|
|
|
if __name__ == '__main__':

    # ---------------------------------------------------------------- CLI --
    parser = argparse.ArgumentParser()
    # NOTE: the previous `default='config.yaml'` was dead code — argparse
    # ignores defaults on required arguments.
    parser.add_argument('-config', type=str, required=True, help='Path to config YAML.')
    args = parser.parse_args()

    # ------------------------------------------------------- configuration --
    with open(args.config, 'r') as stream:
        config_data = yaml.safe_load(stream)  # equivalent to load(..., SafeLoader)
    config_data = resolve_config_variables(config_data)

    print("Resolved paths:")
    print(f"AIA dir: {config_data['data']['aia_dir']}")
    print(f"SXR dir: {config_data['data']['sxr_dir']}")
    print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")

    sxr_norm = np.load(config_data['data']['sxr_norm_path'])
    training_wavelengths = config_data['wavelengths']

    # --------------------------------------------------------------- data --
    data_loader = AIA_GOESDataModule(
        aia_train_dir=config_data['data']['aia_dir'] + "/train",
        aia_val_dir=config_data['data']['aia_dir'] + "/val",
        aia_test_dir=config_data['data']['aia_dir'] + "/test",
        sxr_train_dir=config_data['data']['sxr_dir'] + "/train",
        sxr_val_dir=config_data['data']['sxr_dir'] + "/val",
        sxr_test_dir=config_data['data']['sxr_dir'] + "/test",
        batch_size=config_data['batch_size'],
        # os.cpu_count() can return None, which would make min() raise TypeError.
        num_workers=min(8, os.cpu_count() or 1),
        sxr_norm=sxr_norm,
        wavelengths=training_wavelengths,
        oversample=config_data['oversample'],
        balance_strategy=config_data['balance_strategy'],
    )
    data_loader.setup()

    # ------------------------------------------------------------ logging --
    wandb_logger = WandbLogger(
        entity=config_data['wandb']['entity'],
        project=config_data['wandb']['project'],
        job_type=config_data['wandb']['job_type'],
        tags=config_data['wandb']['tags'],
        name=config_data['wandb']['run_name'],
        notes=config_data['wandb']['notes'],
        config=config_data
    )

    # ---------------------------------------------------------- callbacks --
    # Pick ~4 evenly spaced validation samples for prediction plots.
    total_n_valid = len(data_loader.val_ds)
    plot_samples = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
    sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
    patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
    attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)

    # -------------------------------------------------------------- model --
    base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
    model = ViTLocal(model_kwargs=config_data['vit_architecture'], sxr_norm=sxr_norm, base_weights=base_weights)

    # ------------------------------------------------------- checkpointing --
    checkpoint_callback = ModelCheckpoint(
        dirpath=config_data['data']['checkpoints_dir'],
        monitor='val_total_loss',
        mode='min',
        save_top_k=10,
        filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
    )
    pth_callback = PTHCheckpointCallback(
        dirpath=config_data['data']['checkpoints_dir'],
        monitor='val_total_loss',
        mode='min',
        save_top_k=1,
        filename_prefix=config_data['wandb']['run_name']
    )

    # ---------------------------------------------------- device selection --
    # gpu_ids may be -1 (CPU), "all", a list of ids, or a single id; every
    # GPU option falls back to CPU when CUDA is unavailable.
    gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
    if gpu_config == -1:
        accelerator, devices, strategy = "cpu", 1, "auto"
        print("Using CPU for training")
    elif gpu_config == "all":
        if torch.cuda.is_available():
            accelerator, devices, strategy = "gpu", -1, "auto"
            num_gpus = torch.cuda.device_count()
            print(f"Using all available GPUs ({num_gpus} GPUs)")
        else:
            accelerator, devices, strategy = "cpu", 1, "auto"
            print("No GPUs available, falling back to CPU")
    elif isinstance(gpu_config, list):
        if torch.cuda.is_available():
            accelerator, devices, strategy = "gpu", gpu_config, "auto"
            print(f"Using GPUs: {gpu_config}")
        else:
            accelerator, devices, strategy = "cpu", 1, "auto"
            print("No GPUs available, falling back to CPU")
    else:
        if torch.cuda.is_available():
            accelerator, devices, strategy = "gpu", [gpu_config], "auto"
            print(f"Using GPU {gpu_config}")
        else:
            accelerator, devices, strategy = "cpu", 1, "auto"
            print(f"GPU {gpu_config} not available, falling back to CPU")

    # ----------------------------------------------------------- training --
    trainer = Trainer(
        default_root_dir=config_data['data']['checkpoints_dir'],
        accelerator=accelerator,
        devices=devices,
        strategy=strategy,
        max_epochs=config_data['epochs'],
        # FIX: sxr_plot_callback and pth_callback were instantiated but never
        # registered, so no .pth checkpoints or SXR prediction plots were
        # produced during training (the module docstring promises both).
        callbacks=[attention, checkpoint_callback, sxr_plot_callback, pth_callback],
        logger=wandb_logger,
        log_every_n_steps=10,
    )
    trainer.fit(model, data_loader)

    # --------------------------------------------------- final checkpoint --
    # Save a timestamped standalone .pth regardless of the best-metric files.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    final_checkpoint_path = os.path.join(
        config_data['data']['checkpoints_dir'],
        f"{config_data['wandb']['run_name']}-final-{timestamp}.pth"
    )
    torch.save({'model': model, 'state_dict': model.state_dict()}, final_checkpoint_path)
    print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
    wandb.finish()
|