"""Patch-wise Vision Transformer for SXR flux regression with a class-adaptive loss."""
from collections import deque
import math

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


def _sxr_stats(sxr_norm):
    """Return ``(mean, std)`` of the log10 normalisation as plain Python floats."""
    return float(sxr_norm[0].item()), float(sxr_norm[1].item())


def normalize_sxr(unnormalized_values, sxr_norm):
    """Convert raw values to normalized log10 space.

    Args:
        unnormalized_values: Tensor of raw (linear-space) values.
        sxr_norm: Two-element container whose items expose ``.item()``;
            element 0 is the log10 mean, element 1 the log10 std.

    Returns:
        ``(log10(values + 1e-8) - mean) / std``.
    """
    mean, std = _sxr_stats(sxr_norm)
    log_values = torch.log10(unnormalized_values + 1e-8)
    return (log_values - mean) / std


def unnormalize_sxr(normalized_values, sxr_norm):
    """Invert :func:`normalize_sxr`: map normalized log10 values back to raw space."""
    mean, std = _sxr_stats(sxr_norm)
    return 10 ** (normalized_values * std + mean) - 1e-8


class ViTLocal(pl.LightningModule):
    """Lightning wrapper around :class:`VisionTransformerLocal`.

    Trains with :class:`SXRRegressionDynamicLoss` and additionally logs plain
    Huber / MSE / MAE / RMSE computed in normalized space.
    """

    def __init__(self, model_kwargs, sxr_norm, base_weights=None):
        """
        Args:
            model_kwargs: Keyword arguments for ``VisionTransformerLocal`` plus
                ``'lr'`` (required) and optionally ``'num_classes'``; both are
                stripped before the model is built.
            sxr_norm: ``(mean, std)`` of the log10 normalisation.
            base_weights: Optional per-class base weights forwarded to the
                loss; ``None`` selects the loss's built-in defaults.
        """
        super().__init__()
        self.model_kwargs = model_kwargs
        self.lr = model_kwargs['lr']
        self.save_hyperparameters()

        filtered_kwargs = dict(model_kwargs)
        filtered_kwargs.pop('lr', None)
        filtered_kwargs.pop('num_classes', None)
        self.model = VisionTransformerLocal(**filtered_kwargs)

        # Base weights reflect the number of samples per class in the training data.
        self.base_weights = base_weights
        self.adaptive_loss = SXRRegressionDynamicLoss(
            window_size=15000, base_weights=self.base_weights
        )
        self.sxr_norm = sxr_norm

    def forward(self, x, return_attention=True):
        """Run the wrapped model; see ``VisionTransformerLocal.forward`` for the outputs."""
        return self.model(x, self.sxr_norm, return_attention=return_attention)

    def forward_for_callback(self, x, return_attention=True):
        """Alias of :meth:`forward`, kept for callbacks that look this name up."""
        return self.forward(x, return_attention=return_attention)

    def configure_optimizers(self):
        """AdamW with light weight decay and a per-epoch cosine warm-restart schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=0.00001,
        )
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=250,
            T_mult=2,
            eta_min=1e-7,
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                'name': 'learning_rate',
            },
        }

    def _calculate_loss(self, batch, mode="train"):
        """Compute the adaptive loss for one batch and log metrics for ``mode``.

        Args:
            batch: ``(imgs, sxr)`` where ``sxr`` holds normalized targets.
            mode: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The weighted adaptive loss (scalar tensor).
        """
        imgs, sxr = batch
        raw_preds, _ = self.model(imgs, self.sxr_norm)
        # Drop only the trailing keepdim axis: a blanket squeeze() would also
        # collapse a batch of one to a 0-d tensor and mismatch the target.
        raw_preds = raw_preds.squeeze(-1)

        sxr_un = unnormalize_sxr(sxr, self.sxr_norm)
        norm_preds = normalize_sxr(raw_preds, self.sxr_norm)

        # Only training batches feed the error history that drives the
        # adaptive weights, so evaluation data never influences training.
        loss, _ = self.adaptive_loss.calculate_loss(
            norm_preds, sxr, sxr_un, update_tracking=(mode == "train")
        )

        # Unweighted reference metrics, for logging only.
        mse_loss = F.mse_loss(norm_preds, sxr)
        metrics = {
            "total_loss": loss,
            "huber_loss": F.huber_loss(norm_preds, sxr, delta=0.3),
            "mse_loss": mse_loss,
            "mae_loss": F.l1_loss(norm_preds, sxr),
            "rmse_loss": torch.sqrt(mse_loss),
        }
        self._log_metrics(metrics, mode)
        return loss

    def _log_metrics(self, metrics, mode):
        """Log ``metrics`` using the key layout of the given ``mode``."""
        shared = dict(prog_bar=True, logger=True, sync_dist=True)

        if mode == "train":
            current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
            self.log('train/learning_rate', current_lr,
                     on_step=True, on_epoch=False, **shared)
            for name, value in metrics.items():
                self.log(f"train/{name}", value,
                         on_step=True, on_epoch=True, **shared)

            # Detailed adaptive-loss diagnostics only every 200 steps.
            if self.global_step % 200 == 0:
                multipliers = self.adaptive_loss.get_current_multipliers()
                for key, value in multipliers.items():
                    self.log(f"adaptive/{key}", value,
                             on_step=True, on_epoch=False, sync_dist=True)
        elif mode in ("val", "test"):
            # Evaluation: epoch aggregates only.
            multipliers = self.adaptive_loss.get_current_multipliers()
            for key, value in multipliers.items():
                self.log(f"{mode}/adaptive/{key}", value,
                         on_step=False, on_epoch=True, sync_dist=True)
            for name, value in metrics.items():
                self.log(f"{mode}_{name}", value,
                         on_step=False, on_epoch=True, **shared)

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")


class VisionTransformerLocal(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        patch_size,
        num_patches,
        dropout,
    ):
        """Vision Transformer that outputs flux contributions per patch.

        Args:
            embed_dim: Dimensionality of the input feature vectors to the Transformer
            hidden_dim: Dimensionality of the hidden layer in the feed-forward
                networks within the Transformer
            num_channels: Number of channels of the input (3 for RGB)
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_layers: Number of layers to use in the Transformer
            patch_size: Number of pixels that the patches have per dimension
            num_patches: Number of patches per image. Must be a perfect square,
                because the positional embedding is laid out on a square grid
                and the attention mask is sized from this value.
            dropout: Amount of dropout to apply in the feed-forward network and
                on the input encoding

        Raises:
            ValueError: If ``num_patches`` is not a positive perfect square.
        """
        super().__init__()

        grid = int(round(math.sqrt(num_patches))) if num_patches > 0 else 0
        if grid <= 0 or grid * grid != num_patches:
            # Fail here instead of with an opaque reshape error in forward().
            raise ValueError(
                f"num_patches must be a positive perfect square, got {num_patches}"
            )

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)
        self.transformer_blocks = nn.ModuleList([
            InvertedAttentionBlock(embed_dim, hidden_dim, num_heads, num_patches, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.dropout = nn.Dropout(dropout)

        # Learned 2D positional embedding; this architecture has no CLS token.
        self.grid_h = grid
        self.grid_w = grid
        self.pos_embedding_2d = nn.Parameter(
            torch.randn(1, self.grid_h, self.grid_w, embed_dim)
        )

    def forward(self, x, sxr_norm, return_attention=False):
        """Predict per-patch and global raw flux.

        Args:
            x: Image batch of shape [B, H, W, C].
            sxr_norm: ``(mean, std)`` of the log10 normalisation.
            return_attention: Also return the per-layer attention weights.

        Returns:
            ``(global_flux_raw, attention_weights, patch_flux_raw)`` when
            ``return_attention`` is true, else ``(global_flux_raw, patch_flux_raw)``.
            ``global_flux_raw`` is [B, 1]; ``patch_flux_raw`` is [B, num_patches].
        """
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        x = self.input_layer(x)
        x = self._add_2d_positional_encoding(x)

        # Apply Transformer blocks (sequence-first layout: [T, B, embed_dim])
        x = self.dropout(x)
        x = x.transpose(0, 1)

        attention_weights = []
        for block in self.transformer_blocks:
            if return_attention:
                x, attn_weights = block(x, return_attention=True)
                attention_weights.append(attn_weights)
            else:
                x = block(x)

        patch_embeddings = x.transpose(0, 1)  # [B, num_patches, embed_dim]
        # Normalized log predictions, [B, num_patches]
        patch_logits = self.mlp_head(patch_embeddings).squeeze(-1)

        # Convert to raw flux; mean/std are in log10 space.
        mean, std = sxr_norm
        patch_flux_raw = torch.clamp(
            10 ** (patch_logits * std + mean) - 1e-8, min=1e-15, max=1
        )

        # Global flux is the sum over patches, kept strictly positive.
        global_flux_raw = patch_flux_raw.sum(dim=1, keepdim=True)
        global_flux_raw = torch.clamp(global_flux_raw, min=1e-15)

        if return_attention:
            return global_flux_raw, attention_weights, patch_flux_raw
        return global_flux_raw, patch_flux_raw

    def _add_2d_positional_encoding(self, x):
        """Add the learned 2D positional encoding to patch embeddings [B, T, E]."""
        batch, num_tokens, embed_dim = x.shape  # every token is a patch
        # [B, T, E] -> [B, grid_h, grid_w, E]; the embedding broadcasts over B.
        grid = x.reshape(batch, self.grid_h, self.grid_w, embed_dim)
        grid = grid + self.pos_embedding_2d
        return grid.reshape(batch, num_tokens, embed_dim)

    def forward_for_callback(self, x, return_attention=True, sxr_norm=None):
        """Forward pass returning ``(global_flux_raw, attention_weights)``.

        Args:
            x: Image batch of shape [B, H, W, C].
            return_attention: Whether to collect attention weights; when false
                the returned attention list is empty.
            sxr_norm: ``(mean, std)`` of the log10 normalisation. Required:
                this module does not store the statistics itself.

        Raises:
            ValueError: If ``sxr_norm`` is not supplied.
        """
        if sxr_norm is None:
            raise ValueError(
                "sxr_norm is required: VisionTransformerLocal does not store "
                "the normalisation statistics"
            )
        if return_attention:
            global_flux_raw, attention_weights, _ = self.forward(
                x, sxr_norm, return_attention=True
            )
        else:
            global_flux_raw, _ = self.forward(x, sxr_norm, return_attention=False)
            attention_weights = []
        return global_flux_raw, attention_weights


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """Pre-norm attention block (unmasked).

        Args:
            embed_dim: Dimensionality of input and attention feature vectors
            hidden_dim: Dimensionality of hidden layer in feed-forward network
                (usually 2-4x larger than embed_dim)
            num_heads: Number of heads to use in the Multi-Head Attention block
            dropout: Amount of dropout to apply in the feed-forward network
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attention=False):
        """Apply attention and feed-forward with residuals to ``x`` ([T, B, E]).

        Returns ``x`` or, when ``return_attention`` is true, ``(x, attn_weights)``
        with per-head (non-averaged) weights.
        """
        normed = self.layer_norm_1(x)
        if return_attention:
            attn_output, attn_weights = self.attn(
                normed, normed, normed, average_attn_weights=False
            )
        else:
            attn_output = self.attn(normed, normed, normed)[0]
            attn_weights = None

        x = x + attn_output
        x = x + self.linear(self.layer_norm_2(x))
        if return_attention:
            return x, attn_weights
        return x


class InvertedAttentionBlock(nn.Module):
    """Pre-norm attention block with a fixed, window-based boolean attention mask.

    NOTE(review): ``nn.MultiheadAttention`` treats ``True`` entries of a boolean
    ``attn_mask`` as positions that are NOT allowed to attend. The mask built
    here is ``True`` for patch pairs that lie inside the local window, so as
    wired up the block suppresses attention between neighbouring patches
    (including a patch with itself) and keeps it between distant ones. That
    matches the "Inverted" name but not the "local attention" wording used
    elsewhere in this file. Behaviour is deliberately left unchanged here;
    confirm which of the two is intended. Also note that when the window covers
    the whole grid every row of the mask is fully blocked.
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=9):
        """
        Args:
            embed_dim: Dimensionality of input and attention feature vectors
            hidden_dim: Dimensionality of hidden layer in feed-forward network
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_patches: Sequence length the mask is built for
            dropout: Amount of dropout to apply in the feed-forward network
            local_window: Side length of the square window; pairs whose row and
                column offsets are both at most ``local_window // 2`` are flagged
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.local_window = local_window
        self.num_patches = num_patches

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

        # Pre-computed once; registered as a buffer so it follows the module's device.
        self.register_buffer('attention_mask', self._create_inverted_attention_mask())

    def _create_inverted_attention_mask(self):
        """Build the boolean [num_patches, num_patches] window mask.

        Entry ``[i, j]`` is ``True`` when patches ``i`` and ``j`` are within
        ``local_window // 2`` of each other along both grid axes. See the class
        docstring for how ``nn.MultiheadAttention`` interprets ``True``.
        """
        num_patches = self.num_patches
        # max(..., 1) only guards the degenerate empty case against division by zero.
        grid_size = max(int(math.sqrt(num_patches)), 1)
        half_window = self.local_window // 2

        index = torch.arange(num_patches)
        rows = torch.div(index, grid_size, rounding_mode='floor')
        cols = index % grid_size

        # Pairwise grid offsets via broadcasting instead of a Python double loop.
        row_offset = (rows[:, None] - rows[None, :]).abs()
        col_offset = (cols[:, None] - cols[None, :]).abs()
        return (row_offset <= half_window) & (col_offset <= half_window)

    def forward(self, x, return_attention=False):
        """Apply masked attention and feed-forward with residuals to ``x`` ([T, B, E]).

        Returns ``x`` or, when ``return_attention`` is true, ``(x, attn_weights)``
        with per-head (non-averaged) weights.
        """
        normed = self.layer_norm_1(x)
        if return_attention:
            attn_output, attn_weights = self.attn(
                normed, normed, normed,
                attn_mask=self.attention_mask,
                average_attn_weights=False,
            )
        else:
            attn_output = self.attn(
                normed, normed, normed,
                attn_mask=self.attention_mask,
            )[0]
            attn_weights = None

        x = x + attn_output
        x = x + self.linear(self.layer_norm_2(x))
        if return_attention:
            return x, attn_weights
        return x


def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a batch of images into non-overlapping square patches.

    Args:
        x: Tensor representing the image of shape [B, H, W, C]
        patch_size: Number of pixels per dimension of the patches (integer)
        flatten_channels: If True, the patches will be returned in a flattened
            format as a feature vector instead of an image grid.

    Returns:
        [B, H'*W', C*p_H*p_W] if ``flatten_channels`` else [B, H'*W', C, p_H, p_W],
        with H' = H // patch_size and W' = W // patch_size.
    """
    x = x.permute(0, 3, 1, 2)  # channels-first: [B, C, H, W]
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
    return x


class SXRRegressionDynamicLoss:
    """Huber loss reweighted per flux class, with performance-adaptive multipliers.

    Targets are bucketed by their raw value into quiet / C / M / X classes.
    Each class weight is a static base weight times a multiplier that grows
    when the class's recent error exceeds its long-run error, and shrinks
    otherwise.
    """

    HUBER_DELTA = 0.3

    # (class key, prefix used in the logging dictionaries)
    _CLASSES = (
        ('quiet', 'quiet'),
        ('c_class', 'c'),
        ('m_class', 'm'),
        ('x_class', 'x'),
    )

    # Single source of truth for the multiplier settings, shared by the loss
    # itself and by get_current_multipliers() so that logs match what is applied.
    _MULTIPLIER_CONFIG = {
        'quiet': {'max_multiplier': 1.5, 'min_multiplier': 0.6, 'sensitivity': 0.05},
        'c_class': {'max_multiplier': 2, 'min_multiplier': 0.7, 'sensitivity': 0.08},
        'm_class': {'max_multiplier': 5.0, 'min_multiplier': 0.8, 'sensitivity': 0.1},
        'x_class': {'max_multiplier': 8.0, 'min_multiplier': 0.8, 'sensitivity': 0.12},
    }

    # History needed before a class adapts, and the size of its "recent" tail.
    _HISTORY_PARAMS = {
        'quiet': {'min_samples': 2500, 'recent_window': 800},
        'c_class': {'min_samples': 2500, 'recent_window': 800},
        'm_class': {'min_samples': 1500, 'recent_window': 500},
        'x_class': {'min_samples': 1000, 'recent_window': 300},
    }

    def __init__(self, window_size=15000, base_weights=None):
        """
        Args:
            window_size: Maximum number of per-batch errors kept per class.
            base_weights: Dict with keys ``'quiet'``, ``'c_class'``,
                ``'m_class'``, ``'x_class'``; ``None`` selects the defaults
                from :meth:`_get_base_weights`.
        """
        # Raw-flux class boundaries: quiet < C <= ... < M <= ... < X <= ...
        self.c_threshold = 1e-6
        self.m_threshold = 1e-5
        self.x_threshold = 1e-4

        self.window_size = window_size
        self.quiet_errors = deque(maxlen=window_size)
        self.c_errors = deque(maxlen=window_size)
        self.m_errors = deque(maxlen=window_size)
        self.x_errors = deque(maxlen=window_size)

        # Last multipliers/weights applied; filled by _get_adaptive_weights().
        self.current_multipliers = {}

        if base_weights is None:
            self.base_weights = self._get_base_weights()
        else:
            self.base_weights = base_weights

    def _get_base_weights(self):
        """Default base weights, derived from per-class sample counts in the training data."""
        return {
            'quiet': 6.643528005464481,
            'c_class': 1.626986450317832,
            'm_class': 4.724573441010383,
            'x_class': 43.13137472283814,
        }

    def _histories(self):
        """Map each class key to its error history deque."""
        return {
            'quiet': self.quiet_errors,
            'c_class': self.c_errors,
            'm_class': self.m_errors,
            'x_class': self.x_errors,
        }

    def _class_masks(self, sxr_un):
        """Boolean membership mask per class; works for torch tensors and numpy arrays."""
        c, m, x = self.c_threshold, self.m_threshold, self.x_threshold
        return {
            'quiet': sxr_un < c,
            'c_class': (sxr_un >= c) & (sxr_un < m),
            'm_class': (sxr_un >= m) & (sxr_un < x),
            'x_class': sxr_un >= x,
        }

    def _class_multiplier(self, name):
        """Current performance multiplier for class ``name`` using the shared config."""
        return self._get_performance_multiplier(
            self._histories()[name], sxrclass=name, **self._MULTIPLIER_CONFIG[name]
        )

    def calculate_loss(self, preds_norm, sxr_norm, sxr_un, update_tracking=True):
        """Compute the class-weighted Huber loss.

        Args:
            preds_norm: Predictions in normalized space.
            sxr_norm: Targets in normalized space.
            sxr_un: Targets in raw space (used only to assign classes).
            update_tracking: Whether to append this batch's per-class errors to
                the histories that drive the adaptive multipliers.

        Returns:
            ``(loss, weights)``: the mean weighted loss and the per-sample weights.
        """
        base_loss = F.huber_loss(
            preds_norm, sxr_norm, delta=self.HUBER_DELTA, reduction='none'
        )
        # Weights are derived from the history *before* this batch is added.
        weights = self._get_adaptive_weights(sxr_un)
        if update_tracking:
            self._update_tracking(sxr_un, sxr_norm, preds_norm, errors=base_loss)
        loss = (base_loss * weights).mean()
        return loss, weights

    def _get_adaptive_weights(self, sxr_un):
        """Per-sample weights (mean-normalized) for the raw targets ``sxr_un``."""
        multipliers = {name: self._class_multiplier(name) for name, _ in self._CLASSES}
        class_weights = {
            name: self.base_weights[name] * multipliers[name]
            for name, _ in self._CLASSES
        }

        # Samples matching no class (e.g. NaN targets) keep a weight of 1.
        weights = torch.ones_like(sxr_un)
        for name, mask in self._class_masks(sxr_un).items():
            weights = weights.masked_fill(mask, float(class_weights[name]))

        # Normalize so the mean weight is 1.0, which keeps the loss scale stable.
        weights = weights / torch.mean(weights)

        # Save for logging
        snapshot = {}
        for name, prefix in self._CLASSES:
            snapshot[f'{prefix}_mult'] = multipliers[name]
        for name, prefix in self._CLASSES:
            snapshot[f'{prefix}_weight'] = class_weights[name]
        self.current_multipliers = snapshot

        return weights

    def _get_performance_multiplier(self, error_history, max_multiplier=10.0, min_multiplier=0.5,
                                    sensitivity=3.0, sxrclass='quiet'):
        """Class-dependent performance multiplier.

        Returns ``exp(sensitivity * (recent / overall - 1))`` clipped to
        ``[min_multiplier, max_multiplier]``, where ``recent`` is the mean of
        the newest entries and ``overall`` the mean of the whole history.
        Returns the neutral ``1.0`` while the history is shorter than the
        class's minimum, or when the overall mean is zero or non-finite.

        Raises:
            KeyError: If ``sxrclass`` is not a known class key.
        """
        params = self._HISTORY_PARAMS[sxrclass]
        if len(error_history) < params['min_samples']:
            return 1.0

        history = np.asarray(list(error_history), dtype=np.float64)
        overall = history.mean()
        if not np.isfinite(overall) or overall <= 0:
            # A zero or NaN baseline would poison every weight with NaN/inf.
            return 1.0

        recent = history[-params['recent_window']:].mean()
        multiplier = np.exp(sensitivity * (recent / overall - 1))
        return float(np.clip(multiplier, min_multiplier, max_multiplier))

    def _update_tracking(self, sxr_un, sxr_norm, preds_norm, errors=None):
        """Append this batch's mean Huber error to each class present in it.

        Args:
            sxr_un: Raw-space targets, used to assign classes.
            sxr_norm: Normalized targets.
            preds_norm: Normalized predictions.
            errors: Optional pre-computed element-wise Huber errors; computed
                from ``preds_norm`` and ``sxr_norm`` when omitted.
        """
        if errors is None:
            errors = F.huber_loss(
                preds_norm, sxr_norm, delta=self.HUBER_DELTA, reduction='none'
            )
        errors_np = errors.detach().cpu().numpy()
        sxr_un_np = sxr_un.detach().cpu().numpy()

        histories = self._histories()
        for name, mask in self._class_masks(sxr_un_np).items():
            if mask.any():
                histories[name].append(float(np.mean(errors_np[mask])))

    def get_current_multipliers(self):
        """Get current performance multipliers, history sizes, mean errors and weights for logging."""
        histories = self._histories()
        last_applied = getattr(self, 'current_multipliers', {})

        report = {}
        for name, prefix in self._CLASSES:
            report[f'{prefix}_mult'] = self._class_multiplier(name)
        for name, prefix in self._CLASSES:
            report[f'{prefix}_count'] = len(histories[name])
        for name, prefix in self._CLASSES:
            history = histories[name]
            report[f'{prefix}_error'] = float(np.mean(list(history))) if history else 0.0
        for _, prefix in self._CLASSES:
            report[f'{prefix}_weight'] = last_applied.get(f'{prefix}_weight', 0.0)
        return report