# NOTE(review): removed non-code residue ("Spaces: Sleeping Sleeping") — this looks
# like Hugging Face Spaces status text captured during scraping, not part of the module.
import math
from collections import deque

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
def normalize_sxr(unnormalized_values, sxr_norm):
    """Map raw (linear-scale) SXR flux values into normalized log10 space.

    Applies log10 with a 1e-8 offset (so zero flux stays finite), then
    standardizes with the provided mean/std.

    Args:
        unnormalized_values: Tensor of raw SXR flux values.
        sxr_norm: Two-element sequence (mean, std) of log10 flux; entries may
            be scalar tensors or plain Python numbers.

    Returns:
        Tensor of normalized values, same shape as the input.
    """
    log_values = torch.log10(unnormalized_values + 1e-8)
    # float() accepts both 0-dim/1-element tensors and plain numbers, so
    # sxr_norm no longer has to be a tensor exposing .item().
    mean = float(sxr_norm[0])
    std = float(sxr_norm[1])
    return (log_values - mean) / std
def unnormalize_sxr(normalized_values, sxr_norm):
    """Invert normalize_sxr: map normalized log10 values back to raw flux.

    Args:
        normalized_values: Tensor of values in normalized log10 space.
        sxr_norm: Two-element sequence (mean, std) of log10 flux; entries may
            be scalar tensors or plain Python numbers.

    Returns:
        Tensor of raw (linear-scale) SXR flux values, same shape as the input.
        The 1e-8 offset added during normalization is subtracted back out.
    """
    # float() accepts both 0-dim/1-element tensors and plain numbers.
    mean = float(sxr_norm[0])
    std = float(sxr_norm[1])
    return 10 ** (normalized_values * std + mean) - 1e-8
class ViTLocal(pl.LightningModule):
    """LightningModule wrapping VisionTransformerLocal for SXR flux regression.

    The wrapped model predicts raw (linear-scale) global SXR flux; training
    loss is computed in normalized log10 space via an adaptive, class-weighted
    Huber loss (SXRRegressionDynamicLoss). Auxiliary Huber/MSE/MAE/RMSE
    metrics are logged for monitoring only.
    """

    def __init__(self, model_kwargs, sxr_norm, base_weights=None):
        """
        Args:
            model_kwargs: Constructor kwargs for VisionTransformerLocal.
                Must contain 'lr'; 'lr' and 'num_classes' are stripped before
                the model is built.
            sxr_norm: Two-element (mean, std) of log10 SXR flux used to move
                between raw and normalized space.
            base_weights: Optional per-class base loss weights dict; None lets
                SXRRegressionDynamicLoss fall back to its built-in defaults.
        """
        super().__init__()
        self.model_kwargs = model_kwargs
        self.lr = model_kwargs['lr']
        self.save_hyperparameters()
        # 'lr' and 'num_classes' are training-level options, not model kwargs.
        filtered_kwargs = dict(model_kwargs)
        filtered_kwargs.pop('lr', None)
        filtered_kwargs.pop('num_classes', None)
        self.model = VisionTransformerLocal(**filtered_kwargs)
        # Set the base weights based on the number of samples in each class within training data
        self.base_weights = base_weights
        self.adaptive_loss = SXRRegressionDynamicLoss(window_size=15000, base_weights=self.base_weights)
        self.sxr_norm = sxr_norm

    def forward(self, x, return_attention=True):
        """Run the wrapped model; returns attention maps by default."""
        return self.model(x, self.sxr_norm, return_attention=return_attention)

    def forward_for_callback(self, x, return_attention=True):
        # Identical to forward(); kept as a separate entry point for callbacks.
        return self.model(x, self.sxr_norm, return_attention=return_attention)

    def configure_optimizers(self):
        # Use AdamW with weight decay for better regularization
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=0.00001,
        )
        # Warm restarts: first cosine cycle lasts 250 epochs, then doubles each cycle.
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=250,
            T_mult=2,
            eta_min=1e-7
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                'name': 'learning_rate'
            }
        }

    def _calculate_loss(self, batch, mode="train"):
        """Compute the adaptive loss for one batch and log mode-specific metrics.

        Args:
            batch: (imgs, sxr) pair; sxr targets are in normalized log10 space
                (presumably — normalize/unnormalize usage below implies it; TODO confirm).
            mode: "train", "val" or "test" — selects which metrics are logged.
                NOTE(review): mode "test" currently logs nothing; confirm intended.

        Returns:
            Scalar weighted loss.
        """
        imgs, sxr = batch
        # Model returns raw global flux and per-patch raw contributions.
        raw_preds, raw_patch_contributions = self.model(imgs, self.sxr_norm)
        raw_preds_squeezed = torch.squeeze(raw_preds)
        # Raw-space targets are needed only for flux-class bucketing in the loss.
        sxr_un = unnormalize_sxr(sxr, self.sxr_norm)
        # Predictions are produced in raw space; loss is computed in normalized space.
        norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
        # Use adaptive rare event loss
        loss, weights = self.adaptive_loss.calculate_loss(
            norm_preds_squeezed, sxr, sxr_un
        )
        # Also calculate huber loss for logging
        huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
        mse_loss = F.mse_loss(norm_preds_squeezed, sxr)
        mae_loss = F.l1_loss(norm_preds_squeezed, sxr)
        rmse_loss = torch.sqrt(mse_loss)
        # Log adaptation info
        if mode == "train":
            # Always log learning rate (every step)
            current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
            self.log('train/learning_rate', current_lr, on_step=True, on_epoch=False,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train/total_loss", loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train/huber_loss", huber_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train/mse_loss", mse_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train/mae_loss", mae_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train/rmse_loss", rmse_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            # Detailed diagnostics only every 200 steps
            if self.global_step % 200 == 0:
                multipliers = self.adaptive_loss.get_current_multipliers()
                for key, value in multipliers.items():
                    self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False, sync_dist=True)
        if mode == "val":
            # Validation: typically only log epoch aggregates
            multipliers = self.adaptive_loss.get_current_multipliers()
            for key, value in multipliers.items():
                self.log(f"val/adaptive/{key}", value, on_step=False, on_epoch=True)
            self.log("val_total_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_huber_loss", huber_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                     sync_dist=True)
            self.log("val_mse_loss", mse_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_mae_loss", mae_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_rmse_loss", rmse_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                     sync_dist=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        # NOTE(review): _calculate_loss logs nothing for mode="test" and the
        # loss is discarded here — confirm whether test metrics were intended.
        self._calculate_loss(batch, mode="test")
class VisionTransformerLocal(nn.Module):
    """Vision Transformer predicting raw global SXR flux as a sum of per-patch
    flux contributions (no CLS token; learned 2D positional encoding)."""

    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        patch_size,
        num_patches,
        dropout,
    ):
        """Vision Transformer that outputs flux contributions per patch.

        Args:
            embed_dim: Dimensionality of the input feature vectors to the Transformer
            hidden_dim: Dimensionality of the hidden layer in the feed-forward networks
                within the Transformer
            num_channels: Number of channels of the input (3 for RGB)
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_layers: Number of layers to use in the Transformer
            patch_size: Number of pixels that the patches have per dimension
            num_patches: Maximum number of patches an image can have
            dropout: Amount of dropout to apply in the feed-forward network and
                on the input encoding
        """
        super().__init__()
        self.patch_size = patch_size
        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)
        self.transformer_blocks = nn.ModuleList([
            InvertedAttentionBlock(embed_dim, hidden_dim, num_heads, num_patches, dropout=dropout)
            for _ in range(num_layers)
        ])
        # Per-patch scalar head: normalized log10 flux for each patch.
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.dropout = nn.Dropout(dropout)
        # Parameters/Embeddings - using 2D positional encoding for local attention
        # No CLS token needed for local attention architecture
        # NOTE(review): assumes num_patches is a perfect square; otherwise
        # grid_h * grid_w != num_patches and the reshape in
        # _add_2d_positional_encoding fails.
        self.grid_h = int(math.sqrt(num_patches))
        self.grid_w = int(math.sqrt(num_patches))
        self.pos_embedding_2d = nn.Parameter(torch.randn(1, self.grid_h, self.grid_w, embed_dim))

    def forward(self, x, sxr_norm, return_attention=False):
        """Predict raw global flux (and optionally per-layer attention weights).

        Args:
            x: Image batch of shape [B, H, W, C] (channels-last; see img_to_patch).
            sxr_norm: (mean, std) of log10 flux used to convert patch logits
                back to raw flux.
            return_attention: If True, also return the per-block attention maps.

        Returns:
            (global_flux_raw, attention_weights, patch_flux_raw) if
            return_attention else (global_flux_raw, patch_flux_raw).
        """
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)
        # Add positional encoding (no CLS token for local attention)
        x = self._add_2d_positional_encoding(x)
        # Apply Transformer blocks
        x = self.dropout(x)
        x = x.transpose(0, 1)  # [T, B, embed_dim]
        attention_weights = []
        for block in self.transformer_blocks:
            if return_attention:
                x, attn_weights = block(x, return_attention=True)
                attention_weights.append(attn_weights)
            else:
                x = block(x)
        patch_embeddings = x.transpose(0, 1)  # [B, num_patches, embed_dim]
        patch_logits = self.mlp_head(patch_embeddings).squeeze(-1)  # normalized log predictions [B, num_patches]
        # --- Convert to raw SXR ---
        mean, std = sxr_norm  # in log10 space
        # Undo normalization per patch; clamp keeps each patch's flux in a
        # physically plausible, strictly positive range.
        patch_flux_raw = torch.clamp(10 ** (patch_logits * std + mean) - 1e-8, min=1e-15, max=1)
        # Sum over patches for raw global flux
        global_flux_raw = patch_flux_raw.sum(dim=1, keepdim=True)
        # Ensure global flux is never zero (add small epsilon if needed)
        global_flux_raw = torch.clamp(global_flux_raw, min=1e-15)
        if return_attention:
            return global_flux_raw, attention_weights, patch_flux_raw
        else:
            return global_flux_raw, patch_flux_raw

    def _add_2d_positional_encoding(self, x):
        """Add learned 2D positional encoding to patch embeddings"""
        B, T, embed_dim = x.shape
        num_patches = T  # All tokens are patches (no CLS token)
        # Reshape patches to 2D grid: [B, grid_h, grid_w, embed_dim]
        patch_embeddings = x.reshape(B, self.grid_h, self.grid_w, embed_dim)
        # Add learned 2D positional encoding
        # Broadcasting: [B, grid_h, grid_w, embed_dim] + [1, grid_h, grid_w, embed_dim]
        patch_embeddings = patch_embeddings + self.pos_embedding_2d
        # Reshape back to sequence format: [B, num_patches, embed_dim]
        x = patch_embeddings.reshape(B, num_patches, embed_dim)
        return x

    def forward_for_callback(self, x, return_attention=True):
        """Forward method compatible with AttentionMapCallback"""
        # NOTE(review): forward() requires sxr_norm as its second positional
        # argument, which is not supplied here — as written this call raises
        # TypeError. Callbacks likely use ViTLocal.forward_for_callback
        # instead; confirm and either remove this method or add sxr_norm.
        global_flux_raw, attention_weights, patch_flux_raw = self.forward(x, return_attention=return_attention)
        # Callback expects (outputs, attention_weights, _)
        return global_flux_raw, attention_weights
class AttentionBlock(nn.Module):
    """Standard pre-norm Transformer encoder block.

    Layout: LayerNorm -> multi-head self-attention -> residual add, then
    LayerNorm -> feed-forward MLP (GELU + dropout) -> residual add.
    Expects sequence-first input of shape [T, B, embed_dim].
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Args:
            embed_dim: Dimensionality of input and attention feature vectors.
            hidden_dim: Hidden width of the feed-forward network
                (usually 2-4x larger than embed_dim).
            num_heads: Number of heads in the Multi-Head Attention block.
            dropout: Dropout probability inside the feed-forward network.
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attention=False):
        """Apply the block; optionally also return per-head attention maps."""
        normed = self.layer_norm_1(x)
        if return_attention:
            attn_out, weights = self.attn(normed, normed, normed, average_attn_weights=False)
        else:
            attn_out, weights = self.attn(normed, normed, normed)[0], None
        x = x + attn_out
        x = x + self.linear(self.layer_norm_2(x))
        return (x, weights) if return_attention else x
class InvertedAttentionBlock(nn.Module):
    """Pre-norm Transformer block whose self-attention is restricted by a
    fixed boolean mask over the patch grid (see _create_inverted_attention_mask).

    Expects sequence-first input of shape [num_patches, B, embed_dim].
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=9):
        """
        Args:
            embed_dim: Dimensionality of input and attention feature vectors.
            hidden_dim: Hidden width of the feed-forward network.
            num_heads: Number of attention heads.
            num_patches: Total number of patches (assumed a perfect square grid).
            dropout: Dropout probability in the feed-forward network.
            local_window: Side length of the neighborhood used to build the
                mask; patches within local_window // 2 rows AND columns count
                as "local".
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.local_window = local_window
        self.num_patches = num_patches
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )
        # Pre-compute the fixed boolean attention mask once; register_buffer
        # moves it with the module across devices without making it a parameter.
        self.register_buffer('attention_mask', self._create_inverted_attention_mask())

    def _create_inverted_attention_mask(self):
        """Build a [num_patches, num_patches] boolean mask that is True for
        pairs of patches within local_window distance of each other.

        NOTE(review): nn.MultiheadAttention treats True entries of a boolean
        attn_mask as positions that are NOT allowed to attend. Since this mask
        is True for *local* neighbors, the effective behavior is that local
        attention is blocked and only distant patches attend to each other —
        consistent with the class name "Inverted", but the opposite of what
        "local interactions only" would suggest. Confirm which behavior is
        intended before changing anything.
        """
        num_patches = self.num_patches
        # Patch indices are laid out row-major on a square grid.
        grid_size = int(math.sqrt(num_patches))
        # Create mask for patches only: [num_patches, num_patches]
        mask = torch.zeros(num_patches, num_patches)
        # Mark pairs of patches that are within local_window distance.
        for i in range(num_patches):
            row_i, col_i = i // grid_size, i % grid_size
            for j in range(num_patches):
                row_j, col_j = j // grid_size, j % grid_size
                # Chebyshev-style neighborhood: within local_window // 2 in
                # both the row and the column direction.
                if abs(row_i - row_j) <= self.local_window // 2 and abs(col_i - col_j) <= self.local_window // 2:
                    mask[i, j] = 1
        return mask.bool()

    def forward(self, x, return_attention=False):
        """Apply masked self-attention + MLP with residual connections.

        Args:
            x: [num_patches, B, embed_dim] sequence-first input.
            return_attention: If True, also return per-head attention weights.
        """
        inp_x = self.layer_norm_1(x)
        if return_attention:
            # Apply the precomputed boolean attention mask (see note above on
            # its True-means-blocked semantics).
            attn_output, attn_weights = self.attn(
                inp_x, inp_x, inp_x,
                attn_mask=self.attention_mask,
                average_attn_weights=False
            )
            x = x + attn_output
            x = x + self.linear(self.layer_norm_2(x))
            return x, attn_weights
        else:
            attn_output = self.attn(
                inp_x, inp_x, inp_x,
                attn_mask=self.attention_mask
            )[0]
            x = x + attn_output
            x = x + self.linear(self.layer_norm_2(x))
            return x
def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a batch of channels-last images into non-overlapping square patches.

    Args:
        x: Tensor representing the image batch, shape [B, H, W, C].
        patch_size: Edge length in pixels of each square patch.
        flatten_channels: If True, each patch is flattened into a feature
            vector; otherwise patches keep their [C, p, p] image layout.

    Returns:
        [B, H'*W', C*p*p] if flatten_channels, else [B, H'*W', C, p, p],
        where H' = H // patch_size and W' = W // patch_size (row-major order).
    """
    channels_first = x.permute(0, 3, 1, 2)
    batch, chans, height, width = channels_first.shape
    # Carve the spatial dims into (grid, patch) pairs.
    grid = channels_first.reshape(
        batch, chans,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    # [B, H', W', C, p, p] -> merge the grid dims into one patch axis.
    patches = grid.permute(0, 2, 4, 1, 3, 5).flatten(1, 2)
    return patches.flatten(2, 4) if flatten_channels else patches
class SXRRegressionDynamicLoss:
    """Adaptive, class-weighted Huber loss for SXR flux regression.

    Samples are bucketed by raw flux into quiet / C / M / X classes using
    thresholds 1e-6 / 1e-5 / 1e-4 (GOES-style flare class boundaries —
    TODO confirm units are W/m^2). Each class keeps a rolling window of its
    recent batch-mean Huber errors; a class whose recent error exceeds its
    long-run average gets its base weight boosted (and vice versa), so rare
    flare classes are not drowned out by the quiet majority.

    Plain stateful object (not an nn.Module): the error histories are held in
    deques and are NOT part of any checkpoint.
    """

    def __init__(self, window_size=15000, base_weights=None):
        """
        Args:
            window_size: Maximum number of batch-mean errors retained per class.
            base_weights: Optional dict with keys 'quiet', 'c_class',
                'm_class', 'x_class'; None uses the built-in defaults.
        """
        # Raw-flux class boundaries.
        self.c_threshold = 1e-6
        self.m_threshold = 1e-5
        self.x_threshold = 1e-4
        self.window_size = window_size
        # One scalar (batch-mean Huber error) appended per update, per class.
        self.quiet_errors = deque(maxlen=window_size)
        self.c_errors = deque(maxlen=window_size)
        self.m_errors = deque(maxlen=window_size)
        self.x_errors = deque(maxlen=window_size)
        # Calculate the base weights based on the number of samples in each class within training data
        if base_weights is None:
            self.base_weights = self._get_base_weights()
        else:
            self.base_weights = base_weights

    def _get_base_weights(self):
        # Base weights based on the number of samples in each class within training data
        return {
            'quiet': 6.643528005464481,  # Increase from current value
            'c_class': 1.626986450317832,  # Keep as baseline
            'm_class': 4.724573441010383,  # Maintain M-class focus
            'x_class': 43.13137472283814  # Maintain X-class focus
        }

    def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
        """Compute the weighted Huber loss for one batch.

        Args:
            preds_norm: Predictions in normalized log10 space.
            sxr_norm: Targets in normalized log10 space (a tensor of targets —
                not the (mean, std) pair called sxr_norm elsewhere in this file).
            sxr_un: Targets in raw flux space, used only for class bucketing.

        Returns:
            (scalar mean loss, per-sample weight tensor).
        """
        base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        # Weights come from the history as it stood *before* this batch;
        # tracking is updated afterwards.
        weights = self._get_adaptive_weights(sxr_un)
        self._update_tracking(sxr_un, sxr_norm, preds_norm)
        weighted_loss = base_loss * weights
        loss = weighted_loss.mean()
        return loss, weights

    def _get_adaptive_weights(self, sxr_un):
        """Per-sample weights: base class weight x performance multiplier,
        normalized so the batch-mean weight is ~1. The weight tensor is built
        from Python floats, so no gradient flows through the weights."""
        device = sxr_un.device
        # Get continuous multipliers per class with custom params
        quiet_mult = self._get_performance_multiplier(
            self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet'
        )
        c_mult = self._get_performance_multiplier(
            self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class'
        )
        m_mult = self._get_performance_multiplier(
            self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class'
        )
        x_mult = self._get_performance_multiplier(
            self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class'
        )
        quiet_weight = self.base_weights['quiet'] * quiet_mult
        c_weight = self.base_weights['c_class'] * c_mult
        m_weight = self.base_weights['m_class'] * m_mult
        x_weight = self.base_weights['x_class'] * x_mult
        # Assign each sample its class weight via half-open flux intervals.
        weights = torch.ones_like(sxr_un, device=device)
        weights = torch.where(sxr_un < self.c_threshold, quiet_weight, weights)
        weights = torch.where((sxr_un >= self.c_threshold) & (sxr_un < self.m_threshold), c_weight, weights)
        weights = torch.where((sxr_un >= self.m_threshold) & (sxr_un < self.x_threshold), m_weight, weights)
        weights = torch.where(sxr_un >= self.x_threshold, x_weight, weights)
        # Normalize so mean weight ~1.0 (optional, helps stability)
        mean_weight = torch.mean(weights)
        weights = weights / (mean_weight)
        # Save for logging
        self.current_multipliers = {
            'quiet_mult': quiet_mult,
            'c_mult': c_mult,
            'm_mult': m_mult,
            'x_mult': x_mult,
            'quiet_weight': quiet_weight,
            'c_weight': c_weight,
            'm_weight': m_weight,
            'x_weight': x_weight
        }
        return weights

    def _get_performance_multiplier(self, error_history, max_multiplier=10.0, min_multiplier=0.5, sensitivity=3.0,
                                    sxrclass='quiet'):
        """Class-dependent performance multiplier.

        Returns exp(sensitivity * (recent_mean / overall_mean - 1)), clipped to
        [min_multiplier, max_multiplier]: > 1 when the class has recently been
        doing worse than its long-run average. Returns 1.0 until the class has
        accumulated min_samples history entries. (The large defaults above are
        never used — every caller passes explicit values.)
        """
        class_params = {
            'quiet': {'min_samples': 2500, 'recent_window': 800},
            'c_class': {'min_samples': 2500, 'recent_window': 800},
            'm_class': {'min_samples': 1500, 'recent_window': 500},
            'x_class': {'min_samples': 1000, 'recent_window': 300}
        }
        if len(error_history) < class_params[sxrclass]['min_samples']:
            return 1.0
        recent_window = class_params[sxrclass]['recent_window']
        recent = np.mean(list(error_history)[-recent_window:])
        overall = np.mean(list(error_history))
        ratio = recent / overall
        multiplier = np.exp(sensitivity * (ratio - 1))
        return np.clip(multiplier, min_multiplier, max_multiplier)

    def _update_tracking(self, sxr_un, sxr_norm, preds_norm):
        """Append this batch's per-class mean Huber error to each class history
        (classes with no samples in the batch are skipped)."""
        sxr_un_np = sxr_un.detach().cpu().numpy()
        # Huber loss
        error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        error = error.detach().cpu().numpy()
        quiet_mask = sxr_un_np < self.c_threshold
        if quiet_mask.sum() > 0:
            self.quiet_errors.append(float(np.mean(error[quiet_mask])))
        c_mask = (sxr_un_np >= self.c_threshold) & (sxr_un_np < self.m_threshold)
        if c_mask.sum() > 0:
            self.c_errors.append(float(np.mean(error[c_mask])))
        m_mask = (sxr_un_np >= self.m_threshold) & (sxr_un_np < self.x_threshold)
        if m_mask.sum() > 0:
            self.m_errors.append(float(np.mean(error[m_mask])))
        x_mask = sxr_un_np >= self.x_threshold
        if x_mask.sum() > 0:
            self.x_errors.append(float(np.mean(error[x_mask])))

    def get_current_multipliers(self):
        """Get current performance multipliers for logging.

        NOTE(review): the sensitivities here (0.2 / 0.3 / 0.8 / 1.0) differ
        from those used in _get_adaptive_weights (0.05 / 0.08 / 0.1 / 0.12),
        so the logged '*_mult' values do NOT equal the multipliers actually
        applied during training — confirm whether this is intentional. The
        '*_weight' entries, by contrast, are the last values actually used
        (cached in current_multipliers by _get_adaptive_weights).
        """
        return {
            'quiet_mult': self._get_performance_multiplier(
                self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.2, sxrclass='quiet'
            ),
            'c_mult': self._get_performance_multiplier(
                self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.3, sxrclass='c_class'
            ),
            'm_mult': self._get_performance_multiplier(
                self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.8, sxrclass='m_class'
            ),
            'x_mult': self._get_performance_multiplier(
                self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=1.0, sxrclass='x_class'
            ),
            'quiet_count': len(self.quiet_errors),
            'c_count': len(self.c_errors),
            'm_count': len(self.m_errors),
            'x_count': len(self.x_errors),
            'quiet_error': np.mean(self.quiet_errors) if self.quiet_errors else 0.0,
            'c_error': np.mean(self.c_errors) if self.c_errors else 0.0,
            'm_error': np.mean(self.m_errors) if self.m_errors else 0.0,
            'x_error': np.mean(self.x_errors) if self.x_errors else 0.0,
            'quiet_weight': getattr(self, 'current_multipliers', {}).get('quiet_weight', 0.0),
            'c_weight': getattr(self, 'current_multipliers', {}).get('c_weight', 0.0),
            'm_weight': getattr(self, 'current_multipliers', {}).get('m_weight', 0.0),
            'x_weight': getattr(self, 'current_multipliers', {}).get('x_weight', 0.0)
        }