griffingoodwin04 committed on
Commit
c29556a
·
1 Parent(s): 0b74556

refactor attention visualization; update callback to log attention maps and adjust model configuration

Browse files
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -11,6 +11,9 @@ import numpy as np
11
  from pytorch_lightning.callbacks import Callback
12
  from PIL import Image
13
  import matplotlib.patches as patches
 
 
 
14
 
15
  # Custom Callback
16
  sdoaia94 = matplotlib.colormaps['sdoaia94']
@@ -148,29 +151,28 @@ class AttentionMapCallback(Callback):
148
  patch_size: Size of patches
149
  """
150
  # Convert image to numpy for plotting
 
151
  img_np = image.cpu().numpy()
152
- # Transpose from [C, H, W] to [H, W, C]
153
-
154
- # Normalize image for display
155
- #img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
156
 
157
 
158
- # Get attention from the last layer (or you can average across layers)
159
  last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
160
 
161
  # Extract attention for this sample
162
  sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
163
 
164
- # Average across heads (or you can visualize individual heads)
165
  avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
166
 
167
  # Get attention from CLS token to patches (exclude CLS->CLS)
168
  cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
169
 
170
- # Calculate grid size
171
- H, W = img_np.shape[:2]
172
  grid_h, grid_w = H // patch_size, W // patch_size
173
- #print(grid_h, grid_w)
174
  # Reshape attention to spatial grid
175
  attention_map = cls_attention.reshape(grid_h, grid_w)
176
 
@@ -178,46 +180,52 @@ class AttentionMapCallback(Callback):
178
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
 
180
  # Plot 1: Original image
181
- axes[0].imshow((img_np[:, :,0]+1)/2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  axes[0].set_title(f'Original Image (Epoch {epoch})')
183
  axes[0].axis('off')
184
 
185
  # Plot 2: Attention heatmap
186
- im = axes[1].imshow(attention_map.numpy(), cmap='hot', interpolation='nearest')
 
 
 
 
 
187
  axes[1].set_title(f'Attention Map (Sample {sample_idx})')
188
  axes[1].axis('off')
 
189
  plt.colorbar(im, ax=axes[1])
190
 
191
  # Plot 3: Overlay attention on image
192
- axes[2].imshow((img_np[:, :,0]+1)/2)
193
-
194
- # Overlay attention as colored patches
195
- max_attention = attention_map.max().numpy()
196
- for i in range(grid_h):
197
- for j in range(grid_w):
198
- attention_val = attention_map[i, j].item()
199
- # Create a colored rectangle with alpha based on attention
200
- rect = patches.Rectangle(
201
- (j * patch_size, i * patch_size),
202
- patch_size, patch_size,
203
- linewidth=0,
204
- facecolor='red',
205
- alpha=(attention_val/max_attention) * .9
206
- )
207
- axes[2].add_patch(rect)
208
 
209
- axes[2].set_title(f'Attention Overlay (Sample {sample_idx})')
 
 
210
  axes[2].axis('off')
211
 
212
  plt.tight_layout()
213
- return fig
214
 
215
- # Save the plot
216
- # import os
217
- # os.makedirs(self.save_dir, exist_ok=True)
218
- # plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
219
- # dpi=150, bbox_inches='tight')
220
- # plt.close()
221
 
222
 
223
  class MultiHeadAttentionCallback(AttentionMapCallback):
 
11
  from pytorch_lightning.callbacks import Callback
12
  from PIL import Image
13
  import matplotlib.patches as patches
14
+ import matplotlib.cm as cm
15
+ import matplotlib.colors as mcolors
16
+ from scipy.ndimage import zoom
17
 
18
  # Custom Callback
19
  sdoaia94 = matplotlib.colormaps['sdoaia94']
 
151
  patch_size: Size of patches
152
  """
153
  # Convert image to numpy for plotting
154
+ # Convert image to numpy and transpose
155
  img_np = image.cpu().numpy()
156
+ if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]: # Check if channels first
157
+ img_np = np.transpose(img_np, (1, 2, 0))
 
 
158
 
159
 
160
+ # Get attention from the last layer
161
  last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
162
 
163
  # Extract attention for this sample
164
  sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
165
 
166
+ # Average across heads
167
  avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
168
 
169
  # Get attention from CLS token to patches (exclude CLS->CLS)
170
  cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
171
 
172
+ # Calculate grid size - NOW USING CORRECT DIMENSIONS
173
+ H, W = img_np.shape[:2] # Now this is correct after transpose
174
  grid_h, grid_w = H // patch_size, W // patch_size
175
+
176
  # Reshape attention to spatial grid
177
  attention_map = cls_attention.reshape(grid_h, grid_w)
178
 
 
180
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
181
 
182
  # Plot 1: Original image
183
+ # if img_np.shape[2] == 1: # Grayscale
184
+ # img_display = (img_np[:, :, 0] + 1) / 2
185
+ # axes[0].imshow(img_display, cmap='gray')
186
+ # elif img_np.shape[2] == 3: # RGB
187
+ # # Normalize RGB image properly
188
+ # img_display = (img_np + 1) / 2 # Assuming images are in [-1, 1] range
189
+ # img_display = np.clip(img_display, 0, 1) # Ensure valid range
190
+ # axes[0].imshow(img_display)
191
+ # else: # Multi-channel (6 channels in your case)
192
+ # # Option 1: Display first channel as grayscale
193
+ # img_display = (img_np[:, :, 0] + 1) / 2
194
+ # axes[0].imshow(img_display, cmap='gray')
195
+
196
+ # Option 2: Create RGB composite from 3 channels (uncomment if preferred)
197
+ rgb_channels = [0, 2, 4] # Select which channels to use for R, G, B
198
+ img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
199
+ img_display = np.clip(img_display, 0, 1)
200
+ axes[0].imshow(img_display)
201
  axes[0].set_title(f'Original Image (Epoch {epoch})')
202
  axes[0].axis('off')
203
 
204
  # Plot 2: Attention heatmap
205
+ attention_np = np.log1p(attention_map.numpy())
206
+ # Resize attention map to match image size
207
+ attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
208
+
209
+ # Create colormap for attention - FIX: Use the scalar values, not RGB
210
+ im = axes[1].imshow(attention_resized, cmap='hot')
211
  axes[1].set_title(f'Attention Map (Sample {sample_idx})')
212
  axes[1].axis('off')
213
+ # FIXED: Create colorbar from the scalar image, not RGB
214
  plt.colorbar(im, ax=axes[1])
215
 
216
  # Plot 3: Overlay attention on image
217
+ #img_display_overlay = (img_np[:, :, 0] + 1) / 2
218
+ axes[2].imshow(img_display)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
+ # Overlay attention with proper alpha blending
221
+ axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
222
+ axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
223
  axes[2].axis('off')
224
 
225
  plt.tight_layout()
 
226
 
227
+ plt.tight_layout()
228
+ return fig
 
 
 
 
229
 
230
 
231
  class MultiHeadAttentionCallback(AttentionMapCallback):
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -5,8 +5,9 @@ base_checkpoint_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for d
5
 
6
  # Model configuration
7
  selected_model: "ViT" # Options: "cnn", "vit",
8
-
9
- model:
 
10
  architecture:
11
  "cnn"
12
  seed:
@@ -17,20 +18,16 @@ model:
17
  "resnet"
18
  cnn_dp:
19
  0.5
20
- epochs:
21
- 100
22
- batch_size:
23
- 16
24
 
25
  vit:
26
- embed_dim: 512
27
  num_channels: 6 # AIA has 6 channels
28
  num_classes: 1 # Regression task, predicting SXR flux
29
- patch_size: 16
30
- num_patches: 1024
31
- hidden_dim: 512
32
- num_heads: 4
33
- num_layers: 4
34
  dropout: 0.25
35
  lr: .0001
36
 
@@ -68,5 +65,5 @@ wandb:
68
  - aia
69
  - sxr
70
  - regression
71
- wb_name: flaring-vit-lr-scheduler
72
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
5
 
6
  # Model configuration
7
  selected_model: "ViT" # Options: "cnn", "vit",
8
+ batch_size: 64
9
+ epochs: 100
10
+ megsai:
11
  architecture:
12
  "cnn"
13
  seed:
 
18
  "resnet"
19
  cnn_dp:
20
  0.5
 
 
 
 
21
 
22
  vit:
23
+ embed_dim: 256
24
  num_channels: 6 # AIA has 6 channels
25
  num_classes: 1 # Regression task, predicting SXR flux
26
+ patch_size: 32
27
+ num_patches: 256
28
+ hidden_dim: 256
29
+ num_heads: 1
30
+ num_layers: 1
31
  dropout: 0.25
32
  lr: .0001
33
 
 
65
  - aia
66
  - sxr
67
  - regression
68
+ wb_name: mixed-vit-lr-scheduler
69
  notes: Regression from AIA images (6 channels) to GOES SXR flux
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -79,8 +79,8 @@ sxr_norm = np.load(config_data['data']['sxr_norm_path'])
79
 
80
  n = 0
81
 
82
- torch.manual_seed(config_data['model']['seed'])
83
- np.random.seed(config_data['model']['seed'])
84
 
85
  # DataModule
86
  data_loader = AIA_GOESDataModule(
@@ -90,7 +90,7 @@ data_loader = AIA_GOESDataModule(
90
  sxr_train_dir=config_data['data']['sxr_dir']+"/train",
91
  sxr_val_dir=config_data['data']['sxr_dir']+"/val",
92
  sxr_test_dir=config_data['data']['sxr_dir']+"/test",
93
- batch_size=config_data['model']['batch_size'],
94
  num_workers=os.cpu_count(),
95
  sxr_norm=sxr_norm,
96
  )
@@ -105,7 +105,7 @@ wandb_logger = WandbLogger(
105
  tags=config_data['wandb']['tags'],
106
  name=config_data['wandb']['wb_name'],
107
  notes=config_data['wandb']['notes'],
108
- config=config_data['model']
109
  )
110
 
111
  # Logging callback
@@ -155,7 +155,7 @@ class PTHCheckpointCallback(Callback):
155
  # Checkpoint callback
156
  checkpoint_callback = ModelCheckpoint(
157
  dirpath=config_data['data']['checkpoints_dir'],
158
- monitor='valid_loss',
159
  mode='min',
160
  save_top_k=1,
161
  filename=f"{config_data['wandb']['wb_name']}-{{epoch:02d}}-{{valid_loss:.4f}}.pth"
@@ -187,12 +187,6 @@ elif config_data['selected_model'] == 'hybrid':
187
  lr=config_data['model']['lr'],
188
  )
189
  elif config_data['selected_model'] == 'ViT':
190
- print("Using ViT")
191
- # model = ViT(embed_dim=config_data['vit']['embed_dim'], hidden_dim=config_data['vit']['hidden_dim'],
192
- # num_channels=config_data['vit']['num_channels'],num_heads=config_data['vit']['num_heads'],
193
- # num_layers=config_data['vit']['num_layers'], num_classes=config_data['vit']['num_classes'],
194
- # patch_size=config_data['vit']['patch_size'], num_patches=config_data['vit']['num_patches'],
195
- # dropout=config_data['vit']['dropout'], lr=config_data['vit']['lr'])
196
  model = ViT(model_kwargs=config_data['vit'])
197
  else:
198
  raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
@@ -202,8 +196,8 @@ trainer = Trainer(
202
  default_root_dir=config_data['data']['checkpoints_dir'],
203
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
204
  devices=1,
205
- max_epochs=config_data['model']['epochs'],
206
- callbacks=[attention, pth_callback],
207
  logger=wandb_logger,
208
  log_every_n_steps=10
209
  )
 
79
 
80
  n = 0
81
 
82
+ torch.manual_seed(config_data['megsai']['seed'])
83
+ np.random.seed(config_data['megsai']['seed'])
84
 
85
  # DataModule
86
  data_loader = AIA_GOESDataModule(
 
90
  sxr_train_dir=config_data['data']['sxr_dir']+"/train",
91
  sxr_val_dir=config_data['data']['sxr_dir']+"/val",
92
  sxr_test_dir=config_data['data']['sxr_dir']+"/test",
93
+ batch_size=config_data['batch_size'],
94
  num_workers=os.cpu_count(),
95
  sxr_norm=sxr_norm,
96
  )
 
105
  tags=config_data['wandb']['tags'],
106
  name=config_data['wandb']['wb_name'],
107
  notes=config_data['wandb']['notes'],
108
+ config=config_data['megsai']
109
  )
110
 
111
  # Logging callback
 
155
  # Checkpoint callback
156
  checkpoint_callback = ModelCheckpoint(
157
  dirpath=config_data['data']['checkpoints_dir'],
158
+ monitor='val_loss',
159
  mode='min',
160
  save_top_k=1,
161
  filename=f"{config_data['wandb']['wb_name']}-{{epoch:02d}}-{{valid_loss:.4f}}.pth"
 
187
  lr=config_data['model']['lr'],
188
  )
189
  elif config_data['selected_model'] == 'ViT':
 
 
 
 
 
 
190
  model = ViT(model_kwargs=config_data['vit'])
191
  else:
192
  raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
 
196
  default_root_dir=config_data['data']['checkpoints_dir'],
197
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
198
  devices=1,
199
+ max_epochs=config_data['epochs'],
200
+ callbacks=[attention, pth_callback,checkpoint_callback],
201
  logger=wandb_logger,
202
  log_every_n_steps=10
203
  )