griffingoodwin04 committed on
Commit
d8be1d3
·
1 Parent(s): a6eddcd

Refactor attention map callback to log attention plots and add multi-head visualization

Browse files
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -87,7 +87,7 @@ class ImagePredictionLogger_SXR(Callback):
87
 
88
 
89
  class AttentionMapCallback(Callback):
90
- def __init__(self, log_every_n_epochs=10, num_samples=4, save_dir="attention_maps"):
91
  """
92
  Callback to visualize attention maps during training.
93
 
@@ -125,13 +125,16 @@ class AttentionMapCallback(Callback):
125
 
126
  # Visualize attention for each sample
127
  for sample_idx in range(min(self.num_samples, imgs.size(0))):
128
- self._plot_attention_map(
 
129
  imgs[sample_idx],
130
  attention_weights,
131
  sample_idx,
132
  trainer.current_epoch,
133
  pl_module.model.patch_size
134
  )
 
 
135
 
136
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
137
  """
@@ -145,11 +148,12 @@ class AttentionMapCallback(Callback):
145
  patch_size: Size of patches
146
  """
147
  # Convert image to numpy for plotting
148
- print(image.shape)
149
  img_np = image.cpu().numpy()
150
- print(img_np.shape)
 
151
  # Normalize image for display
152
- img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
 
153
 
154
  # Get attention from the last layer (or you can average across layers)
155
  last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
@@ -166,7 +170,7 @@ class AttentionMapCallback(Callback):
166
  # Calculate grid size
167
  H, W = img_np.shape[:2]
168
  grid_h, grid_w = H // patch_size, W // patch_size
169
- print(grid_h, grid_w)
170
  # Reshape attention to spatial grid
171
  attention_map = cls_attention.reshape(grid_h, grid_w)
172
 
@@ -174,7 +178,7 @@ class AttentionMapCallback(Callback):
174
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
175
 
176
  # Plot 1: Original image
177
- axes[0].imshow(img_np[::0])
178
  axes[0].set_title(f'Original Image (Epoch {epoch})')
179
  axes[0].axis('off')
180
 
@@ -185,9 +189,10 @@ class AttentionMapCallback(Callback):
185
  plt.colorbar(im, ax=axes[1])
186
 
187
  # Plot 3: Overlay attention on image
188
- axes[2].imshow(img_np[::0])
189
 
190
  # Overlay attention as colored patches
 
191
  for i in range(grid_h):
192
  for j in range(grid_w):
193
  attention_val = attention_map[i, j].item()
@@ -197,7 +202,7 @@ class AttentionMapCallback(Callback):
197
  patch_size, patch_size,
198
  linewidth=0,
199
  facecolor='red',
200
- alpha=attention_val * 0.7 # Scale alpha by attention
201
  )
202
  axes[2].add_patch(rect)
203
 
@@ -213,3 +218,62 @@ class AttentionMapCallback(Callback):
213
  # plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
214
  # dpi=150, bbox_inches='tight')
215
  # plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  class AttentionMapCallback(Callback):
90
+ def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps"):
91
  """
92
  Callback to visualize attention maps during training.
93
 
 
125
 
126
  # Visualize attention for each sample
127
  for sample_idx in range(min(self.num_samples, imgs.size(0))):
128
+
129
+ map = self._plot_attention_map(
130
  imgs[sample_idx],
131
  attention_weights,
132
  sample_idx,
133
  trainer.current_epoch,
134
  pl_module.model.patch_size
135
  )
136
+ trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
137
+ plt.close(map)
138
 
139
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
140
  """
 
148
  patch_size: Size of patches
149
  """
150
  # Convert image to numpy for plotting
 
151
  img_np = image.cpu().numpy()
152
+ # Transpose from [C, H, W] to [H, W, C]
153
+
154
  # Normalize image for display
155
+ #img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
156
+
157
 
158
  # Get attention from the last layer (or you can average across layers)
159
  last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
 
170
  # Calculate grid size
171
  H, W = img_np.shape[:2]
172
  grid_h, grid_w = H // patch_size, W // patch_size
173
+ #print(grid_h, grid_w)
174
  # Reshape attention to spatial grid
175
  attention_map = cls_attention.reshape(grid_h, grid_w)
176
 
 
178
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
 
180
  # Plot 1: Original image
181
+ axes[0].imshow(img_np[:, :, :3]) # only first 3 channels if more than 3
182
  axes[0].set_title(f'Original Image (Epoch {epoch})')
183
  axes[0].axis('off')
184
 
 
189
  plt.colorbar(im, ax=axes[1])
190
 
191
  # Plot 3: Overlay attention on image
192
+ axes[2].imshow(img_np[:, :, :3])
193
 
194
  # Overlay attention as colored patches
195
+ max_attention = attention_map.max().numpy()
196
  for i in range(grid_h):
197
  for j in range(grid_w):
198
  attention_val = attention_map[i, j].item()
 
202
  patch_size, patch_size,
203
  linewidth=0,
204
  facecolor='red',
205
+ alpha=(attention_val/max_attention) * .9
206
  )
207
  axes[2].add_patch(rect)
208
 
 
218
  # plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
219
  # dpi=150, bbox_inches='tight')
220
  # plt.close()
221
+
222
+
223
+ class MultiHeadAttentionCallback(AttentionMapCallback):
224
+ """Extended callback to visualize individual attention heads."""
225
+
226
+ def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
227
+ # Call parent method for average attention
228
+ super()._plot_attention_map(image, attention_weights, sample_idx, epoch, patch_size)
229
+
230
+ # Also plot individual heads
231
+ self._plot_individual_heads(image, attention_weights, sample_idx, epoch, patch_size)
232
+
233
+ def _plot_individual_heads(self, image, attention_weights, sample_idx, epoch, patch_size):
234
+ """Plot attention maps for individual heads."""
235
+ img_np = image.cpu().numpy()
236
+ #img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
237
+
238
+ last_layer_attention = attention_weights[-1][sample_idx] # [num_heads, seq_len, seq_len]
239
+ num_heads = last_layer_attention.size(0)
240
+
241
+ # Calculate grid size
242
+ H, W = img_np.shape[:2]
243
+ grid_h, grid_w = H // patch_size, W // patch_size
244
+
245
+ # Create subplot grid
246
+ cols = min(4, num_heads)
247
+ rows = (num_heads + cols - 1) // cols
248
+
249
+ fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
250
+ if num_heads == 1:
251
+ axes = [axes]
252
+ elif rows == 1:
253
+ axes = axes.reshape(1, -1)
254
+
255
+ for head_idx in range(num_heads):
256
+ row = head_idx // cols
257
+ col = head_idx % cols
258
+
259
+ # Get attention for this head
260
+ head_attention = last_layer_attention[head_idx, 0, 1:].cpu() # CLS to patches
261
+ attention_map = head_attention.reshape(grid_h, grid_w)
262
+
263
+ ax = axes[row, col] if rows > 1 else axes[col]
264
+ im = ax.imshow(attention_map.numpy(), cmap='hot', interpolation='nearest')
265
+ ax.set_title(f'Head {head_idx}')
266
+ ax.axis('off')
267
+ plt.colorbar(im, ax=ax)
268
+
269
+ # Hide unused subplots
270
+ for idx in range(num_heads, rows * cols):
271
+ row = idx // cols
272
+ col = idx % cols
273
+ ax = axes[row, col] if rows > 1 else axes[col]
274
+ ax.axis('off')
275
+
276
+ plt.tight_layout()
277
+ plt.savefig(f'{self.save_dir}/heads_epoch_{epoch}_sample_{sample_idx}.png',
278
+ dpi=150, bbox_inches='tight')
279
+ plt.close()
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -29,10 +29,10 @@ vit:
29
  patch_size: 16
30
  num_patches: 262144
31
  hidden_dim: 512
32
- num_heads: 8
33
  num_layers: 6
34
  dropout: 0.1
35
- lr: .00001
36
 
37
  # Data paths (automatically constructed from base directories)
38
  data:
 
29
  patch_size: 16
30
  num_patches: 262144
31
  hidden_dim: 512
32
+ num_heads: 4
33
  num_layers: 6
34
  dropout: 0.1
35
+ lr: .0001
36
 
37
  # Data paths (automatically constructed from base directories)
38
  data: