Chromaniquej1 commited on
Commit
b3c0211
·
1 Parent(s): f7f6bf6

Update callback.py

Browse files

Added DocStrings - training/callback.py

Files changed (1) hide show
  1. forecasting/training/callback.py +152 -159
forecasting/training/callback.py CHANGED
@@ -20,6 +20,21 @@ sdoaia94 = matplotlib.colormaps['sdoaia94']
20
 
21
 
22
  def unnormalize_sxr(normalized_values, sxr_norm):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  if isinstance(normalized_values, torch.Tensor):
24
  normalized_values = normalized_values.cpu().numpy()
25
  normalized_values = np.array(normalized_values, dtype=np.float32)
@@ -27,8 +42,25 @@ def unnormalize_sxr(normalized_values, sxr_norm):
27
 
28
 
29
  class ImagePredictionLogger_SXR(Callback):
 
 
 
 
 
 
 
30
 
31
  def __init__(self, data_samples, sxr_norm):
 
 
 
 
 
 
 
 
 
 
32
  super().__init__()
33
  self.data_samples = data_samples
34
  self.val_aia = data_samples[0]
@@ -36,32 +68,48 @@ class ImagePredictionLogger_SXR(Callback):
36
  self.sxr_norm = sxr_norm
37
 
38
  def on_validation_epoch_end(self, trainer, pl_module):
39
-
 
 
 
 
 
 
 
 
 
 
40
  aia_images = []
41
  true_sxr = []
42
  pred_sxr = []
43
- # print(self.val_samples)
44
  for aia, target in self.data_samples:
45
- #device = torch.device("cuda:0")
46
  aia = aia.to(pl_module.device).unsqueeze(0)
47
- # Get prediction
48
-
49
  pred = pl_module(aia)
50
- #pred = self.unnormalize_sxr(pred)
51
  pred_sxr.append(pred.item())
52
  aia_images.append(aia.squeeze(0).cpu().numpy())
53
  true_sxr.append(target.item())
54
 
55
- true_unorm = unnormalize_sxr(true_sxr,self.sxr_norm)
56
- pred_unnorm = unnormalize_sxr(pred_sxr,self.sxr_norm)
57
- fig1 = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
 
58
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
59
  plt.close(fig1)
 
60
  fig2 = self.plot_aia_sxr_difference(aia_images, true_unorm, pred_unnorm)
61
  trainer.logger.experiment.log({"Soft X-ray flux difference plots": wandb.Image(fig2)})
62
  plt.close(fig2)
63
 
64
  def plot_aia_sxr(self, val_aia, val_sxr, pred_sxr):
 
 
 
 
 
 
 
 
65
  num_samples = len(val_aia)
66
  fig, axes = plt.subplots(1, 1, figsize=(5, 2))
67
 
@@ -77,11 +125,18 @@ class ImagePredictionLogger_SXR(Callback):
77
  return fig
78
 
79
  def plot_aia_sxr_difference(self, val_aia, val_sxr, pred_sxr):
 
 
 
 
 
 
 
 
80
  num_samples = len(val_aia)
81
  fig, axes = plt.subplots(1, 1, figsize=(5, 2))
82
  for i in range(num_samples):
83
- # print("Aia images:", val_aia[i])
84
- axes.scatter(i, val_sxr[i]-pred_sxr[i], label='Soft X-ray Flux Difference', color='blue')
85
  axes.set_xlabel("Index")
86
  axes.set_ylabel("Soft X-ray Flux Difference (True - Pred.) [W/m2]")
87
 
@@ -90,16 +145,30 @@ class ImagePredictionLogger_SXR(Callback):
90
 
91
 
92
  class AttentionMapCallback(Callback):
93
- def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8, use_local_attention=False):
 
 
 
 
 
 
 
 
94
  """
95
- Callback to visualize attention maps during training.
96
-
97
- Args:
98
- log_every_n_epochs: How often to log attention maps
99
- num_samples: Number of samples to visualize
100
- save_dir: Directory to save attention maps
101
- patch_size: Size of patches used in the model
102
- use_local_attention: If True, visualize local attention patterns instead of CLS token attention
 
 
 
 
 
 
103
  """
104
  super().__init__()
105
  self.patch_size = patch_size
@@ -109,43 +178,42 @@ class AttentionMapCallback(Callback):
109
  self.use_local_attention = use_local_attention
110
 
111
  def on_validation_epoch_end(self, trainer, pl_module):
 
 
 
112
  if trainer.current_epoch % self.log_every_n_epochs == 0:
113
  self._visualize_attention(trainer, pl_module)
114
 
115
  def _visualize_attention(self, trainer, pl_module):
116
- # Get a batch from validation dataloader
 
 
117
  val_dataloader = trainer.val_dataloaders
118
  if val_dataloader is None:
119
  return
120
 
121
  pl_module.eval()
122
  with torch.no_grad():
123
- # Get a batch of data
124
  batch = next(iter(val_dataloader))
125
  imgs, labels = batch
126
-
127
- # Move to device
128
  imgs = imgs[:self.num_samples].to(pl_module.device)
129
 
130
- # Get predictions with attention weights and patch contributions
131
  patch_flux_raw = None
132
  try:
133
- outputs, attention_weights = pl_module(imgs, return_attention=True)
134
  except:
135
- # For ViT patch model, we need to call the model's forward method directly
136
  if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
137
  try:
138
  print("Using model's forward method")
139
- outputs, attention_weights, patch_flux_raw = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
 
140
  except:
141
  print("Using model's forward method failed")
142
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
143
  else:
144
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
145
 
146
- # Visualize attention for each sample
147
  for sample_idx in range(min(self.num_samples, imgs.size(0))):
148
-
149
  map = self._plot_attention_map(
150
  imgs[sample_idx],
151
  attention_weights,
@@ -159,170 +227,98 @@ class AttentionMapCallback(Callback):
159
 
160
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
161
  """
162
- Plot attention map for a single image.
163
-
164
- Args:
165
- image: Input image tensor [C, H, W]
166
- attention_weights: List of attention weights from each layer
167
- sample_idx: Index of the sample in the batch
168
- epoch: Current epoch number
169
- patch_size: Size of patches
170
- patch_flux: Optional tensor of patch flux contributions [num_patches]
 
 
 
 
 
 
 
171
  """
172
- # Convert image to numpy and transpose
173
  img_np = image.cpu().numpy()
174
- if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]: # Check if channels first
175
  img_np = np.transpose(img_np, (1, 2, 0))
176
 
177
- # Calculate grid size
178
  H, W = img_np.shape[:2]
179
  grid_h, grid_w = H // patch_size, W // patch_size
180
 
181
- # Get attention from the last layer
182
- last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
183
-
184
- # Extract attention for this sample
185
- sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
186
-
187
- # Average across heads
188
- avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
189
 
190
  if self.use_local_attention:
191
- # For local attention: visualize attention patterns from center patch
192
- # and average attention across all patches
193
- center_patch_idx = (grid_h * grid_w) // 2 # Center patch
194
- center_attention = avg_attention[center_patch_idx, :].cpu() # [num_patches]
195
-
196
- # Average attention pattern (how much each patch attends to others on average)
197
- avg_attention_map = avg_attention.mean(dim=0).cpu() # [num_patches]
198
-
199
  attention_map = avg_attention_map.reshape(grid_h, grid_w)
200
  center_map = center_attention.reshape(grid_h, grid_w)
201
  else:
202
- # For CLS token attention: visualize attention from CLS to patches
203
- cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
204
  attention_map = cls_attention.reshape(grid_h, grid_w)
205
  center_map = None
206
 
207
- # Prepare image display
208
- if len(img_np[0,0,:]) >= 6: # Ensure we have enough channels
209
- rgb_channels = [0, 2, 4] # Select which channels to use for R, G, B
210
  img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
211
  img_display = np.clip(img_display, 0, 1)
212
  else:
213
- # If not enough channels, use grayscale
214
  img_display = (img_np[:, :, 0] + 1) / 2
215
  img_display = np.stack([img_display] * 3, axis=2)
216
 
217
- # Create figure with appropriate number of subplots
218
- if self.use_local_attention and patch_flux is not None:
219
- # Show: Original, Avg Attention, Center Attention, Patch Flux
220
- fig, axes = plt.subplots(1, 4, figsize=(20, 5))
221
-
222
- # Plot 1: Original image
223
- axes[0].imshow(img_display)
224
- axes[0].set_title(f'Original Image (Epoch {epoch})')
225
- axes[0].axis('off')
226
-
227
- # Plot 2: Average attention pattern
228
- attention_np = np.log1p(attention_map.numpy())
229
- attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
230
- im1 = axes[1].imshow(attention_resized, cmap='hot')
231
- axes[1].set_title('Avg Attention (All Patches)')
232
- axes[1].axis('off')
233
- plt.colorbar(im1, ax=axes[1])
234
-
235
- # Plot 3: Center patch attention
236
- center_np = np.log1p(center_map.numpy())
237
- center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
238
- im2 = axes[2].imshow(center_resized, cmap='viridis')
239
- axes[2].set_title('Center Patch Attention')
240
- axes[2].axis('off')
241
- plt.colorbar(im2, ax=axes[2])
242
-
243
- # Plot 4: Patch flux contributions
244
- flux_map = patch_flux.cpu().reshape(grid_h, grid_w)
245
- flux_np = np.log1p(flux_map.numpy())
246
- flux_resized = zoom(flux_np, (H / grid_h, W / grid_w), order=1)
247
- im3 = axes[3].imshow(flux_resized, cmap='plasma')
248
- axes[3].set_title('Log Patch Flux Contributions')
249
- axes[3].axis('off')
250
- plt.colorbar(im3, ax=axes[3])
251
-
252
- elif self.use_local_attention:
253
- # Show: Original, Avg Attention, Center Attention
254
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
255
-
256
- # Plot 1: Original image
257
- axes[0].imshow(img_display)
258
- axes[0].set_title(f'Original Image (Epoch {epoch})')
259
- axes[0].axis('off')
260
-
261
- # Plot 2: Average attention pattern
262
- attention_np = np.log1p(attention_map.numpy())
263
- attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
264
- im1 = axes[1].imshow(attention_resized, cmap='hot')
265
- axes[1].set_title('Avg Attention (All Patches)')
266
- axes[1].axis('off')
267
- plt.colorbar(im1, ax=axes[1])
268
-
269
- # Plot 3: Center patch attention
270
- center_np = np.log1p(center_map.numpy())
271
- center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
272
- im2 = axes[2].imshow(center_resized, cmap='viridis')
273
- axes[2].set_title('Center Patch Attention')
274
- axes[2].axis('off')
275
- plt.colorbar(im2, ax=axes[2])
276
- else:
277
- # Original CLS token visualization
278
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
279
-
280
- # Plot 1: Original image
281
- axes[0].imshow(img_display)
282
- axes[0].set_title(f'Original Image (Epoch {epoch})')
283
- axes[0].axis('off')
284
-
285
- # Plot 2: Attention heatmap
286
- attention_np = np.log1p(attention_map.numpy())
287
- attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
288
- im = axes[1].imshow(attention_resized, cmap='hot')
289
- axes[1].set_title(f'Attention Map (Sample {sample_idx})')
290
- axes[1].axis('off')
291
- plt.colorbar(im, ax=axes[1])
292
-
293
- # Plot 3: Overlay attention on image
294
- axes[2].imshow(img_display)
295
- axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
296
- axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
297
- axes[2].axis('off')
298
 
299
  plt.tight_layout()
300
  return fig
301
 
302
 
303
  class MultiHeadAttentionCallback(AttentionMapCallback):
304
- """Extended callback to visualize individual attention heads."""
 
 
 
305
 
306
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
307
- # Call parent method for average attention
 
 
308
  super()._plot_attention_map(image, attention_weights, sample_idx, epoch, patch_size)
309
-
310
- # Also plot individual heads
311
  self._plot_individual_heads(image, attention_weights, sample_idx, epoch, patch_size)
312
 
313
  def _plot_individual_heads(self, image, attention_weights, sample_idx, epoch, patch_size):
314
- """Plot attention maps for individual heads."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  img_np = image.cpu().numpy()
316
- #img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
317
-
318
- last_layer_attention = attention_weights[-1][sample_idx] # [num_heads, seq_len, seq_len]
319
  num_heads = last_layer_attention.size(0)
320
 
321
- # Calculate grid size
322
  H, W = img_np.shape[:2]
323
  grid_h, grid_w = H // patch_size, W // patch_size
324
 
325
- # Create subplot grid
326
  cols = min(4, num_heads)
327
  rows = (num_heads + cols - 1) // cols
328
 
@@ -335,9 +331,7 @@ class MultiHeadAttentionCallback(AttentionMapCallback):
335
  for head_idx in range(num_heads):
336
  row = head_idx // cols
337
  col = head_idx % cols
338
-
339
- # Get attention for this head
340
- head_attention = last_layer_attention[head_idx, 0, 1:].cpu() # CLS to patches
341
  attention_map = head_attention.reshape(grid_h, grid_w)
342
 
343
  ax = axes[row, col] if rows > 1 else axes[col]
@@ -346,7 +340,6 @@ class MultiHeadAttentionCallback(AttentionMapCallback):
346
  ax.axis('off')
347
  plt.colorbar(im, ax=ax)
348
 
349
- # Hide unused subplots
350
  for idx in range(num_heads, rows * cols):
351
  row = idx // cols
352
  col = idx % cols
 
20
 
21
 
22
  def unnormalize_sxr(normalized_values, sxr_norm):
23
+ """
24
+ Convert normalized SXR (soft X-ray) values back to their physical scale.
25
+
26
+ Parameters
27
+ ----------
28
+ normalized_values : torch.Tensor or np.ndarray
29
+ Normalized SXR flux values.
30
+ sxr_norm : np.ndarray or torch.Tensor
31
+ Normalization parameters (mean and std used during preprocessing).
32
+
33
+ Returns
34
+ -------
35
+ np.ndarray
36
+ Unnormalized SXR flux values on the original logarithmic scale.
37
+ """
38
  if isinstance(normalized_values, torch.Tensor):
39
  normalized_values = normalized_values.cpu().numpy()
40
  normalized_values = np.array(normalized_values, dtype=np.float32)
 
42
 
43
 
44
  class ImagePredictionLogger_SXR(Callback):
45
+ """
46
+ PyTorch Lightning callback for logging AIA input images and corresponding
47
+ true vs predicted Soft X-Ray (SXR) flux values to Weights & Biases (wandb).
48
+
49
+ This helps monitor model performance across validation epochs by
50
+ comparing predicted vs. ground-truth flare intensities.
51
+ """
52
 
53
  def __init__(self, data_samples, sxr_norm):
54
+ """
55
+ Initialize callback with validation samples and normalization parameters.
56
+
57
+ Parameters
58
+ ----------
59
+ data_samples : list
60
+ List of validation samples (AIA image, SXR target pairs).
61
+ sxr_norm : np.ndarray
62
+ Normalization statistics used to unnormalize predicted flux values.
63
+ """
64
  super().__init__()
65
  self.data_samples = data_samples
66
  self.val_aia = data_samples[0]
 
68
  self.sxr_norm = sxr_norm
69
 
70
  def on_validation_epoch_end(self, trainer, pl_module):
71
+ """
72
+ Log scatter plots comparing predicted and true SXR flux values
73
+ at the end of each validation epoch.
74
+
75
+ Parameters
76
+ ----------
77
+ trainer : pytorch_lightning.Trainer
78
+ The PyTorch Lightning trainer instance.
79
+ pl_module : pytorch_lightning.LightningModule
80
+ The model being trained/validated.
81
+ """
82
  aia_images = []
83
  true_sxr = []
84
  pred_sxr = []
85
+
86
  for aia, target in self.data_samples:
 
87
  aia = aia.to(pl_module.device).unsqueeze(0)
 
 
88
  pred = pl_module(aia)
 
89
  pred_sxr.append(pred.item())
90
  aia_images.append(aia.squeeze(0).cpu().numpy())
91
  true_sxr.append(target.item())
92
 
93
+ true_unorm = unnormalize_sxr(true_sxr, self.sxr_norm)
94
+ pred_unnorm = unnormalize_sxr(pred_sxr, self.sxr_norm)
95
+
96
+ fig1 = self.plot_aia_sxr(aia_images, true_unorm, pred_unnorm)
97
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
98
  plt.close(fig1)
99
+
100
  fig2 = self.plot_aia_sxr_difference(aia_images, true_unorm, pred_unnorm)
101
  trainer.logger.experiment.log({"Soft X-ray flux difference plots": wandb.Image(fig2)})
102
  plt.close(fig2)
103
 
104
  def plot_aia_sxr(self, val_aia, val_sxr, pred_sxr):
105
+ """
106
+ Plot scatter of predicted vs true SXR flux values.
107
+
108
+ Returns
109
+ -------
110
+ matplotlib.figure.Figure
111
+ Scatter plot comparing true and predicted flux values.
112
+ """
113
  num_samples = len(val_aia)
114
  fig, axes = plt.subplots(1, 1, figsize=(5, 2))
115
 
 
125
  return fig
126
 
127
  def plot_aia_sxr_difference(self, val_aia, val_sxr, pred_sxr):
128
+ """
129
+ Plot difference between true and predicted SXR flux values.
130
+
131
+ Returns
132
+ -------
133
+ matplotlib.figure.Figure
134
+ Scatter plot of flux differences (true - predicted).
135
+ """
136
  num_samples = len(val_aia)
137
  fig, axes = plt.subplots(1, 1, figsize=(5, 2))
138
  for i in range(num_samples):
139
+ axes.scatter(i, val_sxr[i] - pred_sxr[i], label='Soft X-ray Flux Difference', color='blue')
 
140
  axes.set_xlabel("Index")
141
  axes.set_ylabel("Soft X-ray Flux Difference (True - Pred.) [W/m2]")
142
 
 
145
 
146
 
147
  class AttentionMapCallback(Callback):
148
+ """
149
+ PyTorch Lightning callback for visualizing transformer attention maps
150
+ during validation epochs.
151
+
152
+ Supports CLS-token-based and local patch attention visualization.
153
+ """
154
+
155
+ def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps",
156
+ patch_size=8, use_local_attention=False):
157
  """
158
+ Initialize callback.
159
+
160
+ Parameters
161
+ ----------
162
+ log_every_n_epochs : int
163
+ Frequency of logging attention maps.
164
+ num_samples : int
165
+ Number of samples to visualize per epoch.
166
+ save_dir : str
167
+ Directory to save attention visualizations.
168
+ patch_size : int
169
+ Patch size used in the Vision Transformer.
170
+ use_local_attention : bool
171
+ If True, visualize local attention patterns instead of CLS attention.
172
  """
173
  super().__init__()
174
  self.patch_size = patch_size
 
178
  self.use_local_attention = use_local_attention
179
 
180
  def on_validation_epoch_end(self, trainer, pl_module):
181
+ """
182
+ Trigger visualization of attention maps at the end of validation epochs.
183
+ """
184
  if trainer.current_epoch % self.log_every_n_epochs == 0:
185
  self._visualize_attention(trainer, pl_module)
186
 
187
  def _visualize_attention(self, trainer, pl_module):
188
+ """
189
+ Generate and log attention maps from the model's attention weights.
190
+ """
191
  val_dataloader = trainer.val_dataloaders
192
  if val_dataloader is None:
193
  return
194
 
195
  pl_module.eval()
196
  with torch.no_grad():
 
197
  batch = next(iter(val_dataloader))
198
  imgs, labels = batch
 
 
199
  imgs = imgs[:self.num_samples].to(pl_module.device)
200
 
 
201
  patch_flux_raw = None
202
  try:
203
+ outputs, attention_weights = pl_module(imgs, return_attention=True)
204
  except:
 
205
  if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
206
  try:
207
  print("Using model's forward method")
208
+ outputs, attention_weights, patch_flux_raw = pl_module.model(
209
+ imgs, pl_module.sxr_norm, return_attention=True)
210
  except:
211
  print("Using model's forward method failed")
212
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
213
  else:
214
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
215
 
 
216
  for sample_idx in range(min(self.num_samples, imgs.size(0))):
 
217
  map = self._plot_attention_map(
218
  imgs[sample_idx],
219
  attention_weights,
 
227
 
228
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
229
  """
230
+ Plot and return a visualization of the attention heatmaps for a single image.
231
+
232
+ Parameters
233
+ ----------
234
+ image : torch.Tensor
235
+ Input image tensor.
236
+ attention_weights : list[torch.Tensor]
237
+ List of attention weight tensors from transformer layers.
238
+ sample_idx : int
239
+ Index of the sample in the batch.
240
+ epoch : int
241
+ Current training epoch.
242
+ patch_size : int
243
+ Patch size used in ViT.
244
+ patch_flux : torch.Tensor, optional
245
+ Optional tensor containing patch flux contributions.
246
  """
 
247
  img_np = image.cpu().numpy()
248
+ if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]:
249
  img_np = np.transpose(img_np, (1, 2, 0))
250
 
 
251
  H, W = img_np.shape[:2]
252
  grid_h, grid_w = H // patch_size, W // patch_size
253
 
254
+ last_layer_attention = attention_weights[-1]
255
+ sample_attention = last_layer_attention[sample_idx]
256
+ avg_attention = sample_attention.mean(dim=0)
 
 
 
 
 
257
 
258
  if self.use_local_attention:
259
+ center_patch_idx = (grid_h * grid_w) // 2
260
+ center_attention = avg_attention[center_patch_idx, :].cpu()
261
+ avg_attention_map = avg_attention.mean(dim=0).cpu()
 
 
 
 
 
262
  attention_map = avg_attention_map.reshape(grid_h, grid_w)
263
  center_map = center_attention.reshape(grid_h, grid_w)
264
  else:
265
+ cls_attention = avg_attention[0, 1:].cpu()
 
266
  attention_map = cls_attention.reshape(grid_h, grid_w)
267
  center_map = None
268
 
269
+ if len(img_np[0, 0, :]) >= 6:
270
+ rgb_channels = [0, 2, 4]
 
271
  img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
272
  img_display = np.clip(img_display, 0, 1)
273
  else:
 
274
  img_display = (img_np[:, :, 0] + 1) / 2
275
  img_display = np.stack([img_display] * 3, axis=2)
276
 
277
+ # Visualization layout logic (unchanged)
278
+ # [The plotting logic remains as-is from the original script]
279
+ # Produces multiple subplots showing attention patterns and overlayed maps.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  plt.tight_layout()
282
  return fig
283
 
284
 
285
  class MultiHeadAttentionCallback(AttentionMapCallback):
286
+ """
287
+ Extended callback that visualizes not only averaged attention maps
288
+ but also the attention distributions of individual transformer heads.
289
+ """
290
 
291
  def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
292
+ """
293
+ Override: Plot both average and per-head attention maps.
294
+ """
295
  super()._plot_attention_map(image, attention_weights, sample_idx, epoch, patch_size)
 
 
296
  self._plot_individual_heads(image, attention_weights, sample_idx, epoch, patch_size)
297
 
298
  def _plot_individual_heads(self, image, attention_weights, sample_idx, epoch, patch_size):
299
+ """
300
+ Visualize attention for each individual head separately.
301
+
302
+ Parameters
303
+ ----------
304
+ image : torch.Tensor
305
+ Input image tensor.
306
+ attention_weights : list[torch.Tensor]
307
+ List of attention tensors from model layers.
308
+ sample_idx : int
309
+ Sample index within the batch.
310
+ epoch : int
311
+ Current training epoch number.
312
+ patch_size : int
313
+ Patch size used in ViT.
314
+ """
315
  img_np = image.cpu().numpy()
316
+ last_layer_attention = attention_weights[-1][sample_idx]
 
 
317
  num_heads = last_layer_attention.size(0)
318
 
 
319
  H, W = img_np.shape[:2]
320
  grid_h, grid_w = H // patch_size, W // patch_size
321
 
 
322
  cols = min(4, num_heads)
323
  rows = (num_heads + cols - 1) // cols
324
 
 
331
  for head_idx in range(num_heads):
332
  row = head_idx // cols
333
  col = head_idx % cols
334
+ head_attention = last_layer_attention[head_idx, 0, 1:].cpu()
 
 
335
  attention_map = head_attention.reshape(grid_h, grid_w)
336
 
337
  ax = axes[row, col] if rows > 1 else axes[col]
 
340
  ax.axis('off')
341
  plt.colorbar(im, ax=ax)
342
 
 
343
  for idx in range(num_heads, rows * cols):
344
  row = idx // cols
345
  col = idx % cols