Commit ·
c29556a
1
Parent(s): 0b74556
refactor attention visualization; update callback to log attention maps and adjust model configuration
Browse files
flaring/MEGS_AI_baseline/callback.py
CHANGED
|
@@ -11,6 +11,9 @@ import numpy as np
|
|
| 11 |
from pytorch_lightning.callbacks import Callback
|
| 12 |
from PIL import Image
|
| 13 |
import matplotlib.patches as patches
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Custom Callback
|
| 16 |
sdoaia94 = matplotlib.colormaps['sdoaia94']
|
|
@@ -148,29 +151,28 @@ class AttentionMapCallback(Callback):
|
|
| 148 |
patch_size: Size of patches
|
| 149 |
"""
|
| 150 |
# Convert image to numpy for plotting
|
|
|
|
| 151 |
img_np = image.cpu().numpy()
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
# Normalize image for display
|
| 155 |
-
#img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
|
| 156 |
|
| 157 |
|
| 158 |
-
# Get attention from the last layer
|
| 159 |
last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
|
| 160 |
|
| 161 |
# Extract attention for this sample
|
| 162 |
sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
|
| 163 |
|
| 164 |
-
# Average across heads
|
| 165 |
avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
|
| 166 |
|
| 167 |
# Get attention from CLS token to patches (exclude CLS->CLS)
|
| 168 |
cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
|
| 169 |
|
| 170 |
-
# Calculate grid size
|
| 171 |
-
H, W = img_np.shape[:2]
|
| 172 |
grid_h, grid_w = H // patch_size, W // patch_size
|
| 173 |
-
|
| 174 |
# Reshape attention to spatial grid
|
| 175 |
attention_map = cls_attention.reshape(grid_h, grid_w)
|
| 176 |
|
|
@@ -178,46 +180,52 @@ class AttentionMapCallback(Callback):
|
|
| 178 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 179 |
|
| 180 |
# Plot 1: Original image
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
axes[0].set_title(f'Original Image (Epoch {epoch})')
|
| 183 |
axes[0].axis('off')
|
| 184 |
|
| 185 |
# Plot 2: Attention heatmap
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
axes[1].set_title(f'Attention Map (Sample {sample_idx})')
|
| 188 |
axes[1].axis('off')
|
|
|
|
| 189 |
plt.colorbar(im, ax=axes[1])
|
| 190 |
|
| 191 |
# Plot 3: Overlay attention on image
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
# Overlay attention as colored patches
|
| 195 |
-
max_attention = attention_map.max().numpy()
|
| 196 |
-
for i in range(grid_h):
|
| 197 |
-
for j in range(grid_w):
|
| 198 |
-
attention_val = attention_map[i, j].item()
|
| 199 |
-
# Create a colored rectangle with alpha based on attention
|
| 200 |
-
rect = patches.Rectangle(
|
| 201 |
-
(j * patch_size, i * patch_size),
|
| 202 |
-
patch_size, patch_size,
|
| 203 |
-
linewidth=0,
|
| 204 |
-
facecolor='red',
|
| 205 |
-
alpha=(attention_val/max_attention) * .9
|
| 206 |
-
)
|
| 207 |
-
axes[2].add_patch(rect)
|
| 208 |
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
axes[2].axis('off')
|
| 211 |
|
| 212 |
plt.tight_layout()
|
| 213 |
-
return fig
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
# os.makedirs(self.save_dir, exist_ok=True)
|
| 218 |
-
# plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
|
| 219 |
-
# dpi=150, bbox_inches='tight')
|
| 220 |
-
# plt.close()
|
| 221 |
|
| 222 |
|
| 223 |
class MultiHeadAttentionCallback(AttentionMapCallback):
|
|
|
|
| 11 |
from pytorch_lightning.callbacks import Callback
|
| 12 |
from PIL import Image
|
| 13 |
import matplotlib.patches as patches
|
| 14 |
+
import matplotlib.cm as cm
|
| 15 |
+
import matplotlib.colors as mcolors
|
| 16 |
+
from scipy.ndimage import zoom
|
| 17 |
|
| 18 |
# Custom Callback
|
| 19 |
sdoaia94 = matplotlib.colormaps['sdoaia94']
|
|
|
|
| 151 |
patch_size: Size of patches
|
| 152 |
"""
|
| 153 |
# Convert image to numpy for plotting
|
| 154 |
+
# Convert image to numpy and transpose
|
| 155 |
img_np = image.cpu().numpy()
|
| 156 |
+
if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]: # Check if channels first
|
| 157 |
+
img_np = np.transpose(img_np, (1, 2, 0))
|
|
|
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
+
# Get attention from the last layer
|
| 161 |
last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
|
| 162 |
|
| 163 |
# Extract attention for this sample
|
| 164 |
sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
|
| 165 |
|
| 166 |
+
# Average across heads
|
| 167 |
avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
|
| 168 |
|
| 169 |
# Get attention from CLS token to patches (exclude CLS->CLS)
|
| 170 |
cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
|
| 171 |
|
| 172 |
+
# Calculate grid size - NOW USING CORRECT DIMENSIONS
|
| 173 |
+
H, W = img_np.shape[:2] # Now this is correct after transpose
|
| 174 |
grid_h, grid_w = H // patch_size, W // patch_size
|
| 175 |
+
|
| 176 |
# Reshape attention to spatial grid
|
| 177 |
attention_map = cls_attention.reshape(grid_h, grid_w)
|
| 178 |
|
|
|
|
| 180 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 181 |
|
| 182 |
# Plot 1: Original image
|
| 183 |
+
# if img_np.shape[2] == 1: # Grayscale
|
| 184 |
+
# img_display = (img_np[:, :, 0] + 1) / 2
|
| 185 |
+
# axes[0].imshow(img_display, cmap='gray')
|
| 186 |
+
# elif img_np.shape[2] == 3: # RGB
|
| 187 |
+
# # Normalize RGB image properly
|
| 188 |
+
# img_display = (img_np + 1) / 2 # Assuming images are in [-1, 1] range
|
| 189 |
+
# img_display = np.clip(img_display, 0, 1) # Ensure valid range
|
| 190 |
+
# axes[0].imshow(img_display)
|
| 191 |
+
# else: # Multi-channel (6 channels in your case)
|
| 192 |
+
# # Option 1: Display first channel as grayscale
|
| 193 |
+
# img_display = (img_np[:, :, 0] + 1) / 2
|
| 194 |
+
# axes[0].imshow(img_display, cmap='gray')
|
| 195 |
+
|
| 196 |
+
# Option 2: Create RGB composite from 3 channels (uncomment if preferred)
|
| 197 |
+
rgb_channels = [0, 2, 4] # Select which channels to use for R, G, B
|
| 198 |
+
img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
|
| 199 |
+
img_display = np.clip(img_display, 0, 1)
|
| 200 |
+
axes[0].imshow(img_display)
|
| 201 |
axes[0].set_title(f'Original Image (Epoch {epoch})')
|
| 202 |
axes[0].axis('off')
|
| 203 |
|
| 204 |
# Plot 2: Attention heatmap
|
| 205 |
+
attention_np = np.log1p(attention_map.numpy())
|
| 206 |
+
# Resize attention map to match image size
|
| 207 |
+
attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
|
| 208 |
+
|
| 209 |
+
# Create colormap for attention - FIX: Use the scalar values, not RGB
|
| 210 |
+
im = axes[1].imshow(attention_resized, cmap='hot')
|
| 211 |
axes[1].set_title(f'Attention Map (Sample {sample_idx})')
|
| 212 |
axes[1].axis('off')
|
| 213 |
+
# FIXED: Create colorbar from the scalar image, not RGB
|
| 214 |
plt.colorbar(im, ax=axes[1])
|
| 215 |
|
| 216 |
# Plot 3: Overlay attention on image
|
| 217 |
+
#img_display_overlay = (img_np[:, :, 0] + 1) / 2
|
| 218 |
+
axes[2].imshow(img_display)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
+
# Overlay attention with proper alpha blending
|
| 221 |
+
axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
|
| 222 |
+
axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
|
| 223 |
axes[2].axis('off')
|
| 224 |
|
| 225 |
plt.tight_layout()
|
|
|
|
| 226 |
|
| 227 |
+
plt.tight_layout()
|
| 228 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
class MultiHeadAttentionCallback(AttentionMapCallback):
|
flaring/MEGS_AI_baseline/config.yaml
CHANGED
|
@@ -5,8 +5,9 @@ base_checkpoint_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for d
|
|
| 5 |
|
| 6 |
# Model configuration
|
| 7 |
selected_model: "ViT" # Options: "cnn", "vit",
|
| 8 |
-
|
| 9 |
-
|
|
|
|
| 10 |
architecture:
|
| 11 |
"cnn"
|
| 12 |
seed:
|
|
@@ -17,20 +18,16 @@ model:
|
|
| 17 |
"resnet"
|
| 18 |
cnn_dp:
|
| 19 |
0.5
|
| 20 |
-
epochs:
|
| 21 |
-
100
|
| 22 |
-
batch_size:
|
| 23 |
-
16
|
| 24 |
|
| 25 |
vit:
|
| 26 |
-
embed_dim:
|
| 27 |
num_channels: 6 # AIA has 6 channels
|
| 28 |
num_classes: 1 # Regression task, predicting SXR flux
|
| 29 |
-
patch_size:
|
| 30 |
-
num_patches:
|
| 31 |
-
hidden_dim:
|
| 32 |
-
num_heads:
|
| 33 |
-
num_layers:
|
| 34 |
dropout: 0.25
|
| 35 |
lr: .0001
|
| 36 |
|
|
@@ -68,5 +65,5 @@ wandb:
|
|
| 68 |
- aia
|
| 69 |
- sxr
|
| 70 |
- regression
|
| 71 |
-
wb_name:
|
| 72 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 5 |
|
| 6 |
# Model configuration
|
| 7 |
selected_model: "ViT" # Options: "cnn", "vit",
|
| 8 |
+
batch_size: 64
|
| 9 |
+
epochs: 100
|
| 10 |
+
megsai:
|
| 11 |
architecture:
|
| 12 |
"cnn"
|
| 13 |
seed:
|
|
|
|
| 18 |
"resnet"
|
| 19 |
cnn_dp:
|
| 20 |
0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
vit:
|
| 23 |
+
embed_dim: 256
|
| 24 |
num_channels: 6 # AIA has 6 channels
|
| 25 |
num_classes: 1 # Regression task, predicting SXR flux
|
| 26 |
+
patch_size: 32
|
| 27 |
+
num_patches: 256
|
| 28 |
+
hidden_dim: 256
|
| 29 |
+
num_heads: 1
|
| 30 |
+
num_layers: 1
|
| 31 |
dropout: 0.25
|
| 32 |
lr: .0001
|
| 33 |
|
|
|
|
| 65 |
- aia
|
| 66 |
- sxr
|
| 67 |
- regression
|
| 68 |
+
wb_name: mixed-vit-lr-scheduler
|
| 69 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
flaring/MEGS_AI_baseline/train.py
CHANGED
|
@@ -79,8 +79,8 @@ sxr_norm = np.load(config_data['data']['sxr_norm_path'])
|
|
| 79 |
|
| 80 |
n = 0
|
| 81 |
|
| 82 |
-
torch.manual_seed(config_data['
|
| 83 |
-
np.random.seed(config_data['
|
| 84 |
|
| 85 |
# DataModule
|
| 86 |
data_loader = AIA_GOESDataModule(
|
|
@@ -90,7 +90,7 @@ data_loader = AIA_GOESDataModule(
|
|
| 90 |
sxr_train_dir=config_data['data']['sxr_dir']+"/train",
|
| 91 |
sxr_val_dir=config_data['data']['sxr_dir']+"/val",
|
| 92 |
sxr_test_dir=config_data['data']['sxr_dir']+"/test",
|
| 93 |
-
batch_size=config_data['
|
| 94 |
num_workers=os.cpu_count(),
|
| 95 |
sxr_norm=sxr_norm,
|
| 96 |
)
|
|
@@ -105,7 +105,7 @@ wandb_logger = WandbLogger(
|
|
| 105 |
tags=config_data['wandb']['tags'],
|
| 106 |
name=config_data['wandb']['wb_name'],
|
| 107 |
notes=config_data['wandb']['notes'],
|
| 108 |
-
config=config_data['
|
| 109 |
)
|
| 110 |
|
| 111 |
# Logging callback
|
|
@@ -155,7 +155,7 @@ class PTHCheckpointCallback(Callback):
|
|
| 155 |
# Checkpoint callback
|
| 156 |
checkpoint_callback = ModelCheckpoint(
|
| 157 |
dirpath=config_data['data']['checkpoints_dir'],
|
| 158 |
-
monitor='
|
| 159 |
mode='min',
|
| 160 |
save_top_k=1,
|
| 161 |
filename=f"{config_data['wandb']['wb_name']}-{{epoch:02d}}-{{valid_loss:.4f}}.pth"
|
|
@@ -187,12 +187,6 @@ elif config_data['selected_model'] == 'hybrid':
|
|
| 187 |
lr=config_data['model']['lr'],
|
| 188 |
)
|
| 189 |
elif config_data['selected_model'] == 'ViT':
|
| 190 |
-
print("Using ViT")
|
| 191 |
-
# model = ViT(embed_dim=config_data['vit']['embed_dim'], hidden_dim=config_data['vit']['hidden_dim'],
|
| 192 |
-
# num_channels=config_data['vit']['num_channels'],num_heads=config_data['vit']['num_heads'],
|
| 193 |
-
# num_layers=config_data['vit']['num_layers'], num_classes=config_data['vit']['num_classes'],
|
| 194 |
-
# patch_size=config_data['vit']['patch_size'], num_patches=config_data['vit']['num_patches'],
|
| 195 |
-
# dropout=config_data['vit']['dropout'], lr=config_data['vit']['lr'])
|
| 196 |
model = ViT(model_kwargs=config_data['vit'])
|
| 197 |
else:
|
| 198 |
raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
|
|
@@ -202,8 +196,8 @@ trainer = Trainer(
|
|
| 202 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 203 |
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
| 204 |
devices=1,
|
| 205 |
-
max_epochs=config_data['
|
| 206 |
-
callbacks=[attention, pth_callback],
|
| 207 |
logger=wandb_logger,
|
| 208 |
log_every_n_steps=10
|
| 209 |
)
|
|
|
|
| 79 |
|
| 80 |
n = 0
|
| 81 |
|
| 82 |
+
torch.manual_seed(config_data['megsai']['seed'])
|
| 83 |
+
np.random.seed(config_data['megsai']['seed'])
|
| 84 |
|
| 85 |
# DataModule
|
| 86 |
data_loader = AIA_GOESDataModule(
|
|
|
|
| 90 |
sxr_train_dir=config_data['data']['sxr_dir']+"/train",
|
| 91 |
sxr_val_dir=config_data['data']['sxr_dir']+"/val",
|
| 92 |
sxr_test_dir=config_data['data']['sxr_dir']+"/test",
|
| 93 |
+
batch_size=config_data['batch_size'],
|
| 94 |
num_workers=os.cpu_count(),
|
| 95 |
sxr_norm=sxr_norm,
|
| 96 |
)
|
|
|
|
| 105 |
tags=config_data['wandb']['tags'],
|
| 106 |
name=config_data['wandb']['wb_name'],
|
| 107 |
notes=config_data['wandb']['notes'],
|
| 108 |
+
config=config_data['megsai']
|
| 109 |
)
|
| 110 |
|
| 111 |
# Logging callback
|
|
|
|
| 155 |
# Checkpoint callback
|
| 156 |
checkpoint_callback = ModelCheckpoint(
|
| 157 |
dirpath=config_data['data']['checkpoints_dir'],
|
| 158 |
+
monitor='val_loss',
|
| 159 |
mode='min',
|
| 160 |
save_top_k=1,
|
| 161 |
filename=f"{config_data['wandb']['wb_name']}-{{epoch:02d}}-{{valid_loss:.4f}}.pth"
|
|
|
|
| 187 |
lr=config_data['model']['lr'],
|
| 188 |
)
|
| 189 |
elif config_data['selected_model'] == 'ViT':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
model = ViT(model_kwargs=config_data['vit'])
|
| 191 |
else:
|
| 192 |
raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
|
|
|
|
| 196 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 197 |
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
| 198 |
devices=1,
|
| 199 |
+
max_epochs=config_data['epochs'],
|
| 200 |
+
callbacks=[attention, pth_callback,checkpoint_callback],
|
| 201 |
logger=wandb_logger,
|
| 202 |
log_every_n_steps=10
|
| 203 |
)
|