Merge remote-tracking branch 'origin/main' into dev-patch

Files changed:
- README.md +5 -1
- forecasting/models/__init__.py +0 -1
- forecasting/models/vit_patch_model_local.py +4 -5
- forecasting/training/callback.py +113 -60
- forecasting/training/localpatch.yaml +13 -8
- forecasting/training/train.py +46 -63
README.md
CHANGED
@@ -1 +1,5 @@
-
+<<<<<<< HEAD
+# 2025-HL-flaring-intelligence
+=======
+# FOXES
+>>>>>>> origin/main

forecasting/models/__init__.py
CHANGED
@@ -1,2 +1 @@
-from .fusion_vit_hybrid import FusionViTHybrid
 

forecasting/models/vit_patch_model_local.py
CHANGED
@@ -108,7 +108,7 @@ class ViTLocal(pl.LightningModule):
         if self.global_step % 200 == 0:
             multipliers = self.adaptive_loss.get_current_multipliers()
             for key, value in multipliers.items():
-                self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False)
+                self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False, sync_dist=True)

         if mode == "val":
             # Validation: typically only log epoch aggregates

@@ -177,9 +177,8 @@ class VisionTransformerLocal(nn.Module):
         self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
         self.dropout = nn.Dropout(dropout)

-        # Parameters/Embeddings
-
-        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
+        # Parameters/Embeddings - using 2D positional encoding for local attention
+        # No CLS token needed for local attention architecture
         self.grid_h = int(math.sqrt(num_patches))
         self.grid_w = int(math.sqrt(num_patches))
         self.pos_embedding_2d = nn.Parameter(torch.randn(1, self.grid_h, self.grid_w, embed_dim))

@@ -289,7 +288,7 @@ class AttentionBlock(nn.Module):


 class LocalAttentionBlock(nn.Module):
-    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=
+    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=9):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads

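Note on the LocalAttentionBlock change: the model now relies on a 2D positional embedding over the grid_h x grid_w patch grid (no CLS token) and a default local_window of 9. The helper below is a minimal, hypothetical sketch (not code from this repo) of how a local attention mask over such a patch grid can be built; the function name and shapes are illustrative assumptions only.

```python
import torch

def build_local_attention_mask(grid_h: int, grid_w: int, local_window: int = 9) -> torch.Tensor:
    """Boolean mask [N, N] (N = grid_h * grid_w): True where patch j lies inside the
    local_window x local_window neighbourhood of patch i on the 2D patch grid."""
    ys, xs = torch.meshgrid(torch.arange(grid_h), torch.arange(grid_w), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=1)   # [N, 2] (row, col) per patch
    dist = (coords[:, None, :] - coords[None, :, :]).abs()      # [N, N, 2] per-axis distances
    half = local_window // 2
    return (dist[..., 0] <= half) & (dist[..., 1] <= half)

# Example: a 4x4 patch grid with a 3x3 window; interior patches see 9 neighbours,
# edge patches 6, corner patches 4.
mask = build_local_attention_mask(4, 4, local_window=3)
print(mask.sum(dim=1))
```
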
forecasting/training/callback.py
CHANGED
@@ -90,7 +90,7 @@ class ImagePredictionLogger_SXR(Callback):
 
 
 class AttentionMapCallback(Callback):
-    def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8):
+    def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8, use_local_attention=False):
         """
         Callback to visualize attention maps during training.
 
@@ -99,12 +99,14 @@ class AttentionMapCallback(Callback):
             num_samples: Number of samples to visualize
             save_dir: Directory to save attention maps
             patch_size: Size of patches used in the model
+            use_local_attention: If True, visualize local attention patterns instead of CLS token attention
         """
         super().__init__()
         self.patch_size = patch_size
         self.log_every_n_epochs = log_every_n_epochs
         self.num_samples = num_samples
         self.save_dir = save_dir
+        self.use_local_attention = use_local_attention
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if trainer.current_epoch % self.log_every_n_epochs == 0:
@@ -125,8 +127,8 @@ class AttentionMapCallback(Callback):
             # Move to device
             imgs = imgs[:self.num_samples].to(pl_module.device)
 
-            # Get predictions with attention weights
-
+            # Get predictions with attention weights and patch contributions
+            patch_flux_raw = None
             try:
                 outputs, attention_weights = pl_module(imgs, return_attention=True)
             except:
@@ -134,7 +136,7 @@ class AttentionMapCallback(Callback):
                 if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
                     try:
                         print("Using model's forward method")
-                        outputs, attention_weights,
+                        outputs, attention_weights, patch_flux_raw = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
                     except:
                         print("Using model's forward method failed")
                         outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
@@ -149,12 +151,13 @@ class AttentionMapCallback(Callback):
                     attention_weights,
                     sample_idx,
                     trainer.current_epoch,
-                    patch_size=self.patch_size
+                    patch_size=self.patch_size,
+                    patch_flux=patch_flux_raw[sample_idx] if patch_flux_raw is not None else None
                 )
                 trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
                 plt.close(map)
 
-    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
+    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
         """
         Plot attention map for a single image.
 
@@ -164,50 +167,44 @@ class AttentionMapCallback(Callback):
             sample_idx: Index of the sample in the batch
             epoch: Current epoch number
             patch_size: Size of patches
+            patch_flux: Optional tensor of patch flux contributions [num_patches]
         """
         # Convert image to numpy and transpose
         img_np = image.cpu().numpy()
         if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]:  # Check if channels first
             img_np = np.transpose(img_np, (1, 2, 0))
 
+        # Calculate grid size
+        H, W = img_np.shape[:2]
+        grid_h, grid_w = H // patch_size, W // patch_size
 
         # Get attention from the last layer
         last_layer_attention = attention_weights[-1]  # [B, num_heads, seq_len, seq_len]
-
+
         # Extract attention for this sample
         sample_attention = last_layer_attention[sample_idx]  # [num_heads, seq_len, seq_len]
-
+
         # Average across heads
         avg_attention = sample_attention.mean(dim=0)  # [seq_len, seq_len]
 
-
-
-
-
-
-
+        if self.use_local_attention:
+            # For local attention: visualize attention patterns from center patch
+            # and average attention across all patches
+            center_patch_idx = (grid_h * grid_w) // 2  # Center patch
+            center_attention = avg_attention[center_patch_idx, :].cpu()  # [num_patches]
+
+            # Average attention pattern (how much each patch attends to others on average)
+            avg_attention_map = avg_attention.mean(dim=0).cpu()  # [num_patches]
+
+            attention_map = avg_attention_map.reshape(grid_h, grid_w)
+            center_map = center_attention.reshape(grid_h, grid_w)
+        else:
+            # For CLS token attention: visualize attention from CLS to patches
+            cls_attention = avg_attention[0, 1:].cpu()  # [num_patches]
+            attention_map = cls_attention.reshape(grid_h, grid_w)
+            center_map = None
 
-        #
-        attention_map = cls_attention.reshape(grid_h, grid_w)
-
-        # Create figure with subplots
-        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-
-        # Plot 1: Original image
-        # if img_np.shape[2] == 1:  # Grayscale
-        #     img_display = (img_np[:, :, 0] + 1) / 2
-        #     axes[0].imshow(img_display, cmap='gray')
-        # elif img_np.shape[2] == 3:  # RGB
-        #     # Normalize RGB image properly
-        #     img_display = (img_np + 1) / 2  # Assuming images are in [-1, 1] range
-        #     img_display = np.clip(img_display, 0, 1)  # Ensure valid range
-        #     axes[0].imshow(img_display)
-        # else:  # Multi-channel (6 channels in your case)
-        #     # Option 1: Display first channel as grayscale
-        #     img_display = (img_np[:, :, 0] + 1) / 2
-        #     axes[0].imshow(img_display, cmap='gray')
-
-        # Option 2: Create RGB composite from 3 channels (uncomment if preferred)
+        # Prepare image display
         if len(img_np[0,0,:]) >= 6:  # Ensure we have enough channels
             rgb_channels = [0, 2, 4]  # Select which channels to use for R, G, B
             img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
@@ -216,32 +213,88 @@ class AttentionMapCallback(Callback):
             # If not enough channels, use grayscale
             img_display = (img_np[:, :, 0] + 1) / 2
             img_display = np.stack([img_display] * 3, axis=2)
-        axes[0].imshow(img_display)
-        axes[0].set_title(f'Original Image (Epoch {epoch})')
-        axes[0].axis('off')
-
-        # Plot 2: Attention heatmap
-        attention_np = np.log1p(attention_map.numpy())
-        # Resize attention map to match image size
-        attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
-
-        # Create colormap for attention - FIX: Use the scalar values, not RGB
-        im = axes[1].imshow(attention_resized, cmap='hot')
-        axes[1].set_title(f'Attention Map (Sample {sample_idx})')
-        axes[1].axis('off')
-        # FIXED: Create colorbar from the scalar image, not RGB
-        plt.colorbar(im, ax=axes[1])
-
-        # Plot 3: Overlay attention on image
-        #img_display_overlay = (img_np[:, :, 0] + 1) / 2
-        axes[2].imshow(img_display)
-
-        # Overlay attention with proper alpha blending
-        axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
-        axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
-        axes[2].axis('off')
 
-
+        # Create figure with appropriate number of subplots
+        if self.use_local_attention and patch_flux is not None:
+            # Show: Original, Avg Attention, Center Attention, Patch Flux
+            fig, axes = plt.subplots(1, 4, figsize=(20, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Average attention pattern
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im1 = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title('Avg Attention (All Patches)')
+            axes[1].axis('off')
+            plt.colorbar(im1, ax=axes[1])
+
+            # Plot 3: Center patch attention
+            center_np = np.log1p(center_map.numpy())
+            center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
+            im2 = axes[2].imshow(center_resized, cmap='viridis')
+            axes[2].set_title('Center Patch Attention')
+            axes[2].axis('off')
+            plt.colorbar(im2, ax=axes[2])
+
+            # Plot 4: Patch flux contributions
+            flux_map = patch_flux.cpu().reshape(grid_h, grid_w)
+            flux_np = np.log1p(flux_map.numpy())
+            flux_resized = zoom(flux_np, (H / grid_h, W / grid_w), order=1)
+            im3 = axes[3].imshow(flux_resized, cmap='plasma')
+            axes[3].set_title('Log Patch Flux Contributions')
+            axes[3].axis('off')
+            plt.colorbar(im3, ax=axes[3])
+
+        elif self.use_local_attention:
+            # Show: Original, Avg Attention, Center Attention
+            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Average attention pattern
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im1 = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title('Avg Attention (All Patches)')
+            axes[1].axis('off')
+            plt.colorbar(im1, ax=axes[1])
+
+            # Plot 3: Center patch attention
+            center_np = np.log1p(center_map.numpy())
+            center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
+            im2 = axes[2].imshow(center_resized, cmap='viridis')
+            axes[2].set_title('Center Patch Attention')
+            axes[2].axis('off')
+            plt.colorbar(im2, ax=axes[2])
+        else:
+            # Original CLS token visualization
+            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Attention heatmap
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title(f'Attention Map (Sample {sample_idx})')
+            axes[1].axis('off')
+            plt.colorbar(im, ax=axes[1])
+
+            # Plot 3: Overlay attention on image
+            axes[2].imshow(img_display)
+            axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
+            axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
+            axes[2].axis('off')
 
         plt.tight_layout()
         return fig

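For reference, the resize step used in _plot_attention_map (log-scale the per-patch map, then bilinearly upsample it to image resolution with scipy.ndimage.zoom) can be exercised in isolation. The sketch below uses illustrative sizes (512x512 image, 16-pixel patches), not values taken from this repository.

```python
import numpy as np
from scipy.ndimage import zoom

H, W, patch_size = 512, 512, 16                  # illustrative image/patch sizes
grid_h, grid_w = H // patch_size, W // patch_size

attention_map = np.random.rand(grid_h, grid_w)   # stand-in for the per-patch attention grid
attention_np = np.log1p(attention_map)           # log scaling, as in the callback
attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)  # bilinear upsample

print(attention_resized.shape)  # (512, 512), ready to overlay on the original image
```
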
forecasting/training/localpatch.yaml
CHANGED
@@ -1,14 +1,19 @@
 
 #Base directories - change these to switch datasets
-base_data_dir: "/mnt/data/
-base_checkpoint_dir: "/mnt/data/
+base_data_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
+base_checkpoint_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
 wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
 
 # GPU configuration
-
+# Options:
+# - Single GPU: gpu_ids: 0 or gpu_ids: [0]
+# - Multi GPU: gpu_ids: [0, 1] (uses both GPU 0 and 1)
+# - All GPUs: gpu_ids: "all" (uses all available GPUs)
+# - CPU only: gpu_ids: -1
+gpu_ids: "all" # Use both GPUs
 # Model configuration
 selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
-batch_size:
+batch_size: 48
 epochs: 250
 oversample: false
 balance_strategy: "upsample_minority"
@@ -20,9 +25,9 @@ vit_custom:
   num_classes: 1
   patch_size: 16
   num_patches: 1024
-  hidden_dim:
+  hidden_dim: 2048
   num_heads: 8
-  num_layers:
+  num_layers: 10
   dropout: 0.1
   lr: 0.0001
@@ -40,11 +45,11 @@ data:
 
 wandb:
   entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
-  project:
+  project: Paper
   job_type: training
   tags:
    - aia
    - sxr
    - regression
-  wb_name:
+  wb_name: paper-testing-16-patch-deeper-model-9x9-attention
   notes: Regression from AIA images (6 channels) to GOES SXR flux

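The data paths in this config are built from ${base_data_dir} references that train.py expands via resolve_config_variables ("Resolve variables like ${base_data_dir}"). The resolver's body is not part of this diff; the snippet below is only a plausible sketch of that idea (recursive substitution of ${key} against top-level keys), with a made-up sxr_dir entry for the usage example.

```python
import re
import yaml

def resolve_config_variables(config_dict):
    """Illustrative sketch: recursively substitute ${key} with top-level config values."""
    def resolve(value):
        if isinstance(value, str):
            return re.sub(r"\$\{(\w+)\}",
                          lambda m: str(config_dict.get(m.group(1), m.group(0))), value)
        if isinstance(value, dict):
            return {k: resolve(v) for k, v in value.items()}
        if isinstance(value, list):
            return [resolve(v) for v in value]
        return value
    return resolve(config_dict)

cfg = yaml.safe_load("""
base_data_dir: "/mnt/data/PAPER_DATA_B"
data:
  sxr_dir: "${base_data_dir}/sxr"   # hypothetical entry for illustration
""")
print(resolve_config_variables(cfg)["data"]["sxr_dir"])  # /mnt/data/PAPER_DATA_B/sxr
```
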
forecasting/training/train.py
CHANGED
@@ -13,7 +13,6 @@ import numpy as np
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
-from torch.nn import MSELoss, HuberLoss
 from pathlib import Path
 import sys
 # Add project root to Python path
@@ -21,33 +20,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute()
 sys.path.insert(0, str(PROJECT_ROOT))
 
 from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
-
-from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
-from forecasting.models.vit_patch_model import ViT as ViTPatch
-from forecasting.models.vit_patch_model_uncertainty import ViTUncertainty
-from forecasting.models import FusionViTHybrid
-from forecasting.models.CNN_Patch import CNNPatch
+
 from forecasting.models.vit_patch_model_local import ViTLocal
 from callback import ImagePredictionLogger_SXR, AttentionMapCallback
 
 from pytorch_lightning.callbacks import Callback
 
-from forecasting.models.FastSpectralNet import FastViTFlaringModel
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-os.environ["NCCL_DEBUG"] = "WARN"
-# Shared memory optimizations
-os.environ["OMP_NUM_THREADS"] = "1"  # Limit OpenMP threads
-os.environ["MKL_NUM_THREADS"] = "1"  # Limit MKL threads
 
-def print_gpu_memory(stage=""):
-    """Print GPU memory usage for monitoring"""
-    if torch.cuda.is_available():
-        allocated = torch.cuda.memory_allocated() / 1e9
-        reserved = torch.cuda.memory_reserved() / 1e9
-        print(f"GPU Memory {stage} - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
-    else:
-        print(f"No GPU available for memory monitoring {stage}")
 
 def resolve_config_variables(config_dict):
     """Recursively resolve ${variable} references within the config"""
@@ -91,27 +70,6 @@ with open(args.config, 'r') as stream:
     # Resolve variables like ${base_data_dir}
     config_data = resolve_config_variables(config_data)
 
-    # GPU Memory Isolation for Multi-GPU Systems
-    gpu_id = config_data.get('gpu_id', 0)
-    if gpu_id != -1:  # Only if using GPU
-        # Set CUDA device visibility to only the specified GPU
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-        print(f"Set CUDA_VISIBLE_DEVICES to GPU {gpu_id}")
-
-        # Clear any existing CUDA cache
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            print(f"Cleared CUDA cache for GPU {gpu_id}")
-
-        # Set memory allocation strategy for better isolation
-        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:16"
-
-        # Disable memory sharing between processes
-        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-        print(f"GPU Memory Isolation configured for GPU {gpu_id}")
-    else:
-        print("Using CPU - no GPU memory isolation needed")
 
 # Debug: Print resolved paths
 print("Resolved paths:")
@@ -120,12 +78,6 @@ print(f"SXR dir: {config_data['data']['sxr_dir']}")
 print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
 
 sxr_norm = np.load(config_data['data']['sxr_norm_path'])
-
-n = 0
-
-torch.manual_seed(config_data['megsai']['seed'])
-np.random.seed(config_data['megsai']['seed'])
-
 training_wavelengths = config_data['wavelengths']
 
 
@@ -145,10 +97,6 @@ data_loader = AIA_GOESDataModule(
     balance_strategy=config_data['balance_strategy'],
 )
 data_loader.setup()
-
-# Monitor memory after data loading
-print_gpu_memory("after data loading")
-
 # Logger
 #wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
 wandb_logger = WandbLogger(
@@ -158,7 +106,7 @@ wandb_logger = WandbLogger(
     tags=config_data['wandb']['tags'],
     name=config_data['wandb']['wb_name'],
     notes=config_data['wandb']['notes'],
-    config=config_data
+    config=config_data
 )
 
 # Logging callback
@@ -169,8 +117,8 @@ plot_samples = plot_data  # Keep as list of ((aia, sxr), target)
 
 sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
 # Attention map callback - get patch size from config
-patch_size = config_data.get('vit_custom', {}).get('patch_size',
-attention = AttentionMapCallback(patch_size=patch_size)
+patch_size = config_data.get('vit_custom', {}).get('patch_size', 16)
+attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
 
 
 class PTHCheckpointCallback(Callback):
@@ -323,27 +271,63 @@ else:
     raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
 
 # Set device based on config
-
-
+# Support both old 'gpu_id' and new 'gpu_ids' config keys for backward compatibility
+gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
+
+if gpu_config == -1:
+    # CPU only
     accelerator = "cpu"
     devices = 1
+    strategy = "auto"
     print("Using CPU for training")
+elif gpu_config == "all":
+    # Use all available GPUs
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+        devices = -1  # -1 means use all available GPUs
+        num_gpus = torch.cuda.device_count()
+        strategy = "auto"
+        print(f"Using all available GPUs ({num_gpus} GPUs)")
+        if num_gpus > 1:
+            print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {num_gpus} GPUs = {config_data['batch_size'] * num_gpus}")
+    else:
+        accelerator = "cpu"
+        devices = 1
+        strategy = "auto"
+        print("No GPUs available, falling back to CPU")
+elif isinstance(gpu_config, list):
+    # Multiple specific GPUs
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+        devices = gpu_config
+        strategy = "auto"
+        print(f"Using GPUs: {gpu_config}")
+        if len(gpu_config) > 1:
+            print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {len(gpu_config)} GPUs = {config_data['batch_size'] * len(gpu_config)}")
+    else:
+        accelerator = "cpu"
+        devices = 1
+        strategy = "auto"
+        print("No GPUs available, falling back to CPU")
 else:
+    # Single GPU (integer)
     if torch.cuda.is_available():
         accelerator = "gpu"
-
-
-        print(f"Using GPU {
+        devices = [gpu_config]
+        strategy = "auto"
+        print(f"Using GPU {gpu_config}")
     else:
         accelerator = "cpu"
         devices = 1
-
+        strategy = "auto"
+        print(f"GPU {gpu_config} not available, falling back to CPU")
 
 # Trainer
 trainer = Trainer(
     default_root_dir=config_data['data']['checkpoints_dir'],
     accelerator=accelerator,
     devices=devices,
+    strategy=strategy,
     max_epochs=config_data['epochs'],
     callbacks=[attention, checkpoint_callback],
     logger=wandb_logger,
@@ -359,6 +343,5 @@ torch.save({
     'state_dict': model.state_dict()
 }, final_checkpoint_path)
 print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
-n += 1
 # Finalize
 wandb.finish()

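A note on the final checkpoint: train.py saves only the model weights wrapped as {'state_dict': model.state_dict()}, so it is restored with load_state_dict rather than through Lightning's checkpoint loading. Below is a minimal round-trip sketch using a stand-in nn.Linear instead of the actual ViTLocal (whose constructor arguments come from the config).

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)                      # stand-in for the trained ViTLocal model

# Save in the same format train.py uses for the final .pth checkpoint.
final_checkpoint_path = "final_model.pth"
torch.save({"state_dict": model.state_dict()}, final_checkpoint_path)

# Later (e.g., for inference): rebuild the architecture and restore the weights.
restored = nn.Linear(8, 1)
checkpoint = torch.load(final_checkpoint_path, map_location="cpu")
restored.load_state_dict(checkpoint["state_dict"])
restored.eval()
```
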