# FOXES / forecasting / training / callback.py
# Last commit by griffingoodwin04 (3fb991b): updates to evaluation script and callbacks
import wandb
from pytorch_lightning import Callback
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import sunpy.visualization.colormaps as cm
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.callbacks import Callback
from PIL import Image
import matplotlib.patches as patches
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from scipy.ndimage import zoom
# Matplotlib colormap registered under the name "sdoaia94". The name is
# presumably registered by the sunpy.visualization.colormaps import above
# (TODO confirm); looking it up fails with KeyError otherwise.
_SDOAIA94_CMAP_NAME = 'sdoaia94'
sdoaia94 = matplotlib.colormaps[_SDOAIA94_CMAP_NAME]
def unnormalize_sxr(normalized_values, sxr_norm):
    """
    Convert normalized SXR (soft X-ray) values back to their physical scale.

    Applies the inverse of ``(log10(flux + 1e-8) - mean) / std``:
    ``10 ** (x * std + mean) - 1e-8``.

    Parameters
    ----------
    normalized_values : torch.Tensor, np.ndarray or sequence of float
        Normalized SXR flux values. Tensors may live on any device and may
        still require grad; they are detached before conversion.
    sxr_norm : np.ndarray, torch.Tensor or sequence of float
        Normalization parameters. After flattening, element 0 is the mean and
        element 1 is the standard deviation used during preprocessing.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``normalized_values``, holding
        the flux on the original *linear* scale (the log10 is undone).
    """
    if isinstance(normalized_values, torch.Tensor):
        # detach(): .numpy() refuses tensors that are part of an autograd graph.
        # float(): numpy has no bfloat16; the result is float32 anyway.
        normalized_values = normalized_values.detach().cpu().float().numpy()
    values = np.asarray(normalized_values, dtype=np.float32)

    if isinstance(sxr_norm, torch.Tensor):
        sxr_norm = sxr_norm.detach().cpu().double().numpy()
    # Flatten so (2,), (2, 1) arrays and plain [mean, std] lists all work.
    norm = np.asarray(sxr_norm, dtype=np.float64).ravel()
    mean, std = float(norm[0]), float(norm[1])

    return 10 ** (values * std + mean) - 1e-8
class ImagePredictionLogger_SXR(Callback):
    """
    PyTorch Lightning callback that logs true vs predicted Soft X-Ray (SXR)
    flux values for a fixed set of validation samples to Weights & Biases.

    At the end of every validation epoch the model is run on each stored
    sample, targets and predictions are unnormalized with ``unnormalize_sxr``,
    and two scatter plots (values, and true - predicted differences) are
    logged through ``trainer.logger.experiment``.
    """

    def __init__(self, data_samples, sxr_norm):
        """
        Initialize callback with validation samples and normalization parameters.

        Parameters
        ----------
        data_samples : list
            Validation samples as ``(aia_image, sxr_target)`` pairs. Each image
            must be a tensor (it is moved to the model's device); each target
            must support ``.item()``.
        sxr_norm : np.ndarray or torch.Tensor
            Normalization statistics (mean, std) used to unnormalize flux values.
        """
        super().__init__()
        self.data_samples = data_samples
        # Per-sample AIA inputs and SXR targets, in data_samples order, so the
        # attribute names match their contents.
        self.val_aia = [aia for aia, _ in data_samples]
        self.val_sxr = [sxr for _, sxr in data_samples]
        self.sxr_norm = sxr_norm

    def on_validation_epoch_end(self, trainer, pl_module):
        """
        Log scatter plots comparing predicted and true SXR flux values
        at the end of each validation epoch.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The model being trained/validated.
        """
        if not self.data_samples:
            return

        aia_images, true_sxr, pred_sxr = [], [], []
        # Logging-only inference: no autograd graph is needed.
        with torch.no_grad():
            for aia, target in self.data_samples:
                aia = aia.to(pl_module.device).unsqueeze(0)  # add a batch axis
                pred_sxr.append(pl_module(aia).item())
                aia_images.append(aia.squeeze(0).cpu().numpy())
                true_sxr.append(target.item())

        true_unnorm = unnormalize_sxr(true_sxr, self.sxr_norm)
        pred_unnorm = unnormalize_sxr(pred_sxr, self.sxr_norm)

        self._log_and_close(trainer, "Soft X-ray flux plots",
                            self.plot_aia_sxr(aia_images, true_unnorm, pred_unnorm))
        self._log_and_close(trainer, "Soft X-ray flux difference plots",
                            self.plot_aia_sxr_difference(aia_images, true_unnorm, pred_unnorm))

    @staticmethod
    def _log_and_close(trainer, key, fig):
        """Log ``fig`` under ``key`` as a wandb image, closing it even if logging fails."""
        try:
            trainer.logger.experiment.log({key: wandb.Image(fig)})
        finally:
            plt.close(fig)

    def plot_aia_sxr(self, val_aia, val_sxr, pred_sxr):
        """
        Plot true and predicted SXR flux per sample index on a log y-axis.

        Only the length of ``val_aia`` is used; it selects how many entries of
        ``val_sxr`` / ``pred_sxr`` are plotted.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot comparing true and predicted flux values.
        """
        n = len(val_aia)
        index = np.arange(n)
        fig, ax = plt.subplots(1, 1, figsize=(5, 2))
        ax.scatter(index, np.asarray(val_sxr)[:n], label='Ground truth', color='blue')
        ax.scatter(index, np.asarray(pred_sxr)[:n], label='Prediction', color='orange')
        ax.set_xlabel("Index")
        ax.set_ylabel("Soft x-ray flux [W/m2]")
        ax.set_yscale('log')
        ax.legend()
        fig.tight_layout()
        return fig

    def plot_aia_sxr_difference(self, val_aia, val_sxr, pred_sxr):
        """
        Plot the per-sample difference between true and predicted SXR flux.

        Only the length of ``val_aia`` is used; it selects how many entries of
        ``val_sxr`` / ``pred_sxr`` are plotted.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of flux differences (true - predicted).
        """
        n = len(val_aia)
        difference = np.asarray(val_sxr)[:n] - np.asarray(pred_sxr)[:n]
        fig, ax = plt.subplots(1, 1, figsize=(5, 2))
        ax.scatter(np.arange(n), difference, label='Soft X-ray Flux Difference', color='blue')
        ax.set_xlabel("Index")
        ax.set_ylabel("Soft X-ray Flux Difference (True - Pred.) [W/m2]")
        fig.tight_layout()
        return fig
class AttentionMapCallback(Callback):
    """
    PyTorch Lightning callback for visualizing transformer attention maps
    during validation epochs.

    Supports CLS-token-based and local patch attention visualization. For up
    to ``num_samples`` images of the first validation batch, a 2x3 panel
    figure is logged through ``trainer.logger.experiment`` under the key
    ``"Attention plots"``.
    """

    def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps",
                 patch_size=8, use_local_attention=False):
        """
        Initialize callback.

        Parameters
        ----------
        log_every_n_epochs : int
            Frequency (in epochs) of logging attention maps.
        num_samples : int
            Maximum number of samples from the first validation batch to visualize.
        save_dir : str
            Directory for saved attention visualizations. This class does not
            write to it itself; subclasses (e.g. MultiHeadAttentionCallback) do.
        patch_size : int
            Patch size used in the Vision Transformer.
        use_local_attention : bool
            If True, visualize local attention patterns instead of CLS attention.
        """
        super().__init__()
        self.patch_size = patch_size
        self.log_every_n_epochs = log_every_n_epochs
        self.num_samples = num_samples
        self.save_dir = save_dir
        self.use_local_attention = use_local_attention

    def on_validation_epoch_end(self, trainer, pl_module):
        """
        Trigger visualization of attention maps every ``log_every_n_epochs`` epochs.
        """
        if trainer.current_epoch % self.log_every_n_epochs == 0:
            self._visualize_attention(trainer, pl_module)

    def _visualize_attention(self, trainer, pl_module):
        """
        Generate and log attention maps for the first validation batch.

        The module is put in eval mode for the forward pass and restored to its
        previous mode afterwards.
        """
        val_dataloader = trainer.val_dataloaders
        # Several validation loaders may be exposed as a list/tuple; use the first.
        if isinstance(val_dataloader, (list, tuple)):
            val_dataloader = val_dataloader[0] if val_dataloader else None
        if val_dataloader is None:
            return

        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                batch = next(iter(val_dataloader))
                # Only the images (first batch element) are needed.
                imgs = batch[0][:self.num_samples].to(pl_module.device)
                attention_weights, patch_flux_raw = self._forward_with_attention(pl_module, imgs)

                for sample_idx in range(imgs.size(0)):
                    fig = self._plot_attention_map(
                        imgs[sample_idx],
                        attention_weights,
                        sample_idx,
                        trainer.current_epoch,
                        patch_size=self.patch_size,
                        patch_flux=patch_flux_raw[sample_idx] if patch_flux_raw is not None else None,
                    )
                    try:
                        trainer.logger.experiment.log({"Attention plots": wandb.Image(fig)})
                    finally:
                        plt.close(fig)
        finally:
            pl_module.train(was_training)

    @staticmethod
    def _forward_with_attention(pl_module, imgs):
        """
        Run ``imgs`` through the model and return ``(attention_weights, patch_flux)``.

        Tries, in order:
        1. ``pl_module(imgs, return_attention=True)``, which must return two values.
        2. ``pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)``,
           which must return three values; the third is used as per-patch flux.
        3. ``pl_module.forward_for_callback(imgs, return_attention=True)``,
           which must return two values.

        ``patch_flux`` is None unless option 2 succeeds. Exceptions from the
        last option propagate.
        """
        try:
            _, attention_weights = pl_module(imgs, return_attention=True)
            return attention_weights, None
        except Exception:
            # Best-effort: fall through to the alternative entry points below.
            pass

        if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
            try:
                print("Using model's forward method")
                _, attention_weights, patch_flux_raw = pl_module.model(
                    imgs, pl_module.sxr_norm, return_attention=True)
                return attention_weights, patch_flux_raw
            except Exception:
                print("Using model's forward method failed")

        _, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
        return attention_weights, None

    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
        """
        Plot and return a visualization of the attention heatmaps for a single image.

        Parameters
        ----------
        image : torch.Tensor
            Input image tensor: (H, W), (H, W, C), or channel-first (C, H, W)
            when C is 1 or 3.
        attention_weights : list[torch.Tensor]
            Attention weight tensors from transformer layers; only the last is
            used. It is indexed as [sample][head][token, token].
        sample_idx : int
            Index of the sample in the batch.
        epoch : int
            Current training epoch.
        patch_size : int
            Patch size used in ViT.
        patch_flux : torch.Tensor, optional
            Per-patch flux contributions; reshaped to the patch grid if given.

        Returns
        -------
        matplotlib.figure.Figure
            2x3 panel figure; the caller is responsible for closing it.
        """
        img_np = image.detach().cpu().numpy()
        if img_np.ndim == 2:
            # Bare (H, W) image: add a channel axis so the code below is uniform.
            img_np = img_np[:, :, np.newaxis]
        elif img_np.ndim == 3 and img_np.shape[0] in (1, 3):
            # Channel-first with 1 or 3 channels -> channel-last.
            img_np = np.transpose(img_np, (1, 2, 0))
        H, W = img_np.shape[:2]
        grid_h, grid_w = H // patch_size, W // patch_size

        def to_grid(tensor):
            # float() first: numpy cannot represent bfloat16.
            return tensor.detach().float().cpu().reshape(grid_h, grid_w).numpy()

        # Last layer, this sample, averaged over heads -> (tokens x tokens).
        avg_attention = attention_weights[-1][sample_idx].mean(dim=0)
        if self.use_local_attention:
            center_patch_idx = (grid_h * grid_w) // 2
            center_map = to_grid(avg_attention[center_patch_idx, :])
            # Attention each token receives, averaged over all rows.
            attention_map = to_grid(avg_attention.mean(dim=0))
        else:
            # Row 0 is the CLS token's attention; drop its own (first) column.
            attention_map = to_grid(avg_attention[0, 1:])
            center_map = None

        # (x + 1) / 2 maps [-1, 1] to [0, 1] for display; assumes inputs were
        # scaled to [-1, 1] -- TODO confirm against the dataset normalization.
        if img_np.shape[2] >= 6:
            img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in (0, 2, 4)], axis=2)
        else:
            img_display = np.repeat((img_np[:, :, :1] + 1) / 2, 3, axis=2)
        img_display = np.clip(img_display, 0, 1)

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle(f'Attention Visualization - Epoch {epoch}, Sample {sample_idx}', fontsize=16)

        # Plot 1: Original image
        axes[0, 0].imshow(img_display)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')

        # Plot 2: Attention map at patch-grid resolution
        im1 = axes[0, 1].imshow(attention_map, cmap='hot', interpolation='nearest')
        axes[0, 1].set_title('Attention Map')
        axes[0, 1].axis('off')
        fig.colorbar(im1, ax=axes[0, 1])

        # Plot 3: Overlay. The attention grid is grid_h x grid_w while the image
        # is H x W; giving the overlay the image's pixel extent stretches it over
        # the whole image instead of covering (and re-zooming the axes to) only
        # the top-left grid_w x grid_h pixels.
        axes[0, 2].imshow(img_display)
        axes[0, 2].imshow(attention_map, cmap='hot', alpha=0.6, interpolation='nearest',
                          extent=(-0.5, W - 0.5, H - 0.5, -0.5))
        axes[0, 2].set_title('Attention Overlay')
        axes[0, 2].axis('off')

        # Plot 4: Center attention (local-attention mode only)
        if center_map is not None:
            im2 = axes[1, 0].imshow(center_map, cmap='hot', interpolation='nearest')
            fig.colorbar(im2, ax=axes[1, 0])
        else:
            axes[1, 0].text(0.5, 0.5, 'Center attention\nnot available',
                            ha='center', va='center', transform=axes[1, 0].transAxes)
        axes[1, 0].set_title('Center Patch Attention')
        axes[1, 0].axis('off')

        # Plot 5: Patch flux (if available)
        if patch_flux is not None:
            patch_flux_np = patch_flux.detach().float().cpu().numpy().reshape(grid_h, grid_w)
            im3 = axes[1, 1].imshow(patch_flux_np, cmap='viridis', interpolation='nearest')
            fig.colorbar(im3, ax=axes[1, 1])
        else:
            axes[1, 1].text(0.5, 0.5, 'Patch flux\nnot available',
                            ha='center', va='center', transform=axes[1, 1].transAxes)
        axes[1, 1].set_title('Patch Flux')
        axes[1, 1].axis('off')

        # Plot 6: Attention statistics
        axes[1, 2].hist(attention_map.ravel(), bins=50, alpha=0.7)
        axes[1, 2].set_title('Attention Distribution')
        axes[1, 2].set_xlabel('Attention Weight')
        axes[1, 2].set_ylabel('Frequency')

        fig.tight_layout()
        return fig
class MultiHeadAttentionCallback(AttentionMapCallback):
    """
    Extended callback that visualizes not only averaged attention maps
    but also the attention maps of individual transformer heads.

    The averaged figure is returned to the base class for logging; the
    per-head figure is saved as a PNG under ``save_dir``.
    """

    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
        """
        Override: build the averaged attention figure and also save per-head maps.

        Accepts the same ``patch_flux`` keyword the base class passes, and
        returns the averaged figure so the caller can log and close it.
        """
        fig = super()._plot_attention_map(image, attention_weights, sample_idx, epoch,
                                          patch_size, patch_flux=patch_flux)
        self._plot_individual_heads(image, attention_weights, sample_idx, epoch, patch_size)
        return fig

    def _plot_individual_heads(self, image, attention_weights, sample_idx, epoch, patch_size):
        """
        Save a grid figure with one attention map per head of the last layer.

        The figure is written to
        ``<save_dir>/heads_epoch_<epoch>_sample_<sample_idx>.png``; ``save_dir``
        is created if missing.

        Parameters
        ----------
        image : torch.Tensor
            Input image tensor (channel-last, or channel-first with 1 or 3 channels).
        attention_weights : list[torch.Tensor]
            Attention tensors from model layers; the last is indexed as
            [sample][head][token, token].
        sample_idx : int
            Sample index within the batch.
        epoch : int
            Current training epoch number.
        patch_size : int
            Patch size used in ViT.
        """
        import os

        img_np = image.detach().cpu().numpy()
        # Same layout handling as the base class: channel-first 1/3-channel -> channel-last.
        if img_np.ndim == 3 and img_np.shape[0] in (1, 3):
            img_np = np.transpose(img_np, (1, 2, 0))
        H, W = img_np.shape[:2]
        grid_h, grid_w = H // patch_size, W // patch_size

        last_layer_attention = attention_weights[-1][sample_idx]
        num_heads = last_layer_attention.size(0)
        cols = min(4, num_heads)
        rows = (num_heads + cols - 1) // cols
        # squeeze=False keeps `axes` 2-D for every grid shape, so axes[row, col] always works.
        fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)

        for head_idx in range(num_heads):
            row, col = divmod(head_idx, cols)
            head = last_layer_attention[head_idx]
            if self.use_local_attention:
                # Match the base class's local mode: average over rows.
                head_attention = head.mean(dim=0)
            else:
                # CLS row, excluding the CLS column itself.
                head_attention = head[0, 1:]
            attention_map = head_attention.detach().float().cpu().reshape(grid_h, grid_w).numpy()
            ax = axes[row, col]
            im = ax.imshow(attention_map, cmap='hot', interpolation='nearest')
            ax.set_title(f'Head {head_idx}')
            ax.axis('off')
            fig.colorbar(im, ax=ax)

        # Blank out unused cells of the grid.
        for idx in range(num_heads, rows * cols):
            row, col = divmod(idx, cols)
            axes[row, col].axis('off')

        fig.tight_layout()
        os.makedirs(self.save_dir, exist_ok=True)
        try:
            fig.savefig(os.path.join(self.save_dir, f'heads_epoch_{epoch}_sample_{sample_idx}.png'),
                        dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)