Commit ·
d8be1d3
1
Parent(s): a6eddcd
Refactor attention map callback to log attention plots and add multi-head visualization
Browse files
flaring/MEGS_AI_baseline/callback.py
CHANGED
|
@@ -87,7 +87,7 @@ class ImagePredictionLogger_SXR(Callback):
|
|
| 87 |
|
| 88 |
|
| 89 |
class AttentionMapCallback(Callback):
|
| 90 |
-
def __init__(self, log_every_n_epochs=
|
| 91 |
"""
|
| 92 |
Callback to visualize attention maps during training.
|
| 93 |
|
|
@@ -125,13 +125,16 @@ class AttentionMapCallback(Callback):
|
|
| 125 |
|
| 126 |
# Visualize attention for each sample
|
| 127 |
for sample_idx in range(min(self.num_samples, imgs.size(0))):
|
| 128 |
-
|
|
|
|
| 129 |
imgs[sample_idx],
|
| 130 |
attention_weights,
|
| 131 |
sample_idx,
|
| 132 |
trainer.current_epoch,
|
| 133 |
pl_module.model.patch_size
|
| 134 |
)
|
|
|
|
|
|
|
| 135 |
|
| 136 |
def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
|
| 137 |
"""
|
|
@@ -145,11 +148,12 @@ class AttentionMapCallback(Callback):
|
|
| 145 |
patch_size: Size of patches
|
| 146 |
"""
|
| 147 |
# Convert image to numpy for plotting
|
| 148 |
-
print(image.shape)
|
| 149 |
img_np = image.cpu().numpy()
|
| 150 |
-
|
|
|
|
| 151 |
# Normalize image for display
|
| 152 |
-
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
|
|
|
|
| 153 |
|
| 154 |
# Get attention from the last layer (or you can average across layers)
|
| 155 |
last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
|
|
@@ -166,7 +170,7 @@ class AttentionMapCallback(Callback):
|
|
| 166 |
# Calculate grid size
|
| 167 |
H, W = img_np.shape[:2]
|
| 168 |
grid_h, grid_w = H // patch_size, W // patch_size
|
| 169 |
-
print(grid_h, grid_w)
|
| 170 |
# Reshape attention to spatial grid
|
| 171 |
attention_map = cls_attention.reshape(grid_h, grid_w)
|
| 172 |
|
|
@@ -174,7 +178,7 @@ class AttentionMapCallback(Callback):
|
|
| 174 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 175 |
|
| 176 |
# Plot 1: Original image
|
| 177 |
-
axes[0].imshow(img_np[::
|
| 178 |
axes[0].set_title(f'Original Image (Epoch {epoch})')
|
| 179 |
axes[0].axis('off')
|
| 180 |
|
|
@@ -185,9 +189,10 @@ class AttentionMapCallback(Callback):
|
|
| 185 |
plt.colorbar(im, ax=axes[1])
|
| 186 |
|
| 187 |
# Plot 3: Overlay attention on image
|
| 188 |
-
axes[2].imshow(img_np[::
|
| 189 |
|
| 190 |
# Overlay attention as colored patches
|
|
|
|
| 191 |
for i in range(grid_h):
|
| 192 |
for j in range(grid_w):
|
| 193 |
attention_val = attention_map[i, j].item()
|
|
@@ -197,7 +202,7 @@ class AttentionMapCallback(Callback):
|
|
| 197 |
patch_size, patch_size,
|
| 198 |
linewidth=0,
|
| 199 |
facecolor='red',
|
| 200 |
-
alpha=attention_val *
|
| 201 |
)
|
| 202 |
axes[2].add_patch(rect)
|
| 203 |
|
|
@@ -213,3 +218,62 @@ class AttentionMapCallback(Callback):
|
|
| 213 |
# plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
|
| 214 |
# dpi=150, bbox_inches='tight')
|
| 215 |
# plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
class AttentionMapCallback(Callback):
|
| 90 |
+
def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps"):
|
| 91 |
"""
|
| 92 |
Callback to visualize attention maps during training.
|
| 93 |
|
|
|
|
| 125 |
|
| 126 |
# Visualize attention for each sample
|
| 127 |
for sample_idx in range(min(self.num_samples, imgs.size(0))):
|
| 128 |
+
|
| 129 |
+
map = self._plot_attention_map(
|
| 130 |
imgs[sample_idx],
|
| 131 |
attention_weights,
|
| 132 |
sample_idx,
|
| 133 |
trainer.current_epoch,
|
| 134 |
pl_module.model.patch_size
|
| 135 |
)
|
| 136 |
+
trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
|
| 137 |
+
plt.close(map)
|
| 138 |
|
| 139 |
def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
|
| 140 |
"""
|
|
|
|
| 148 |
patch_size: Size of patches
|
| 149 |
"""
|
| 150 |
# Convert image to numpy for plotting
|
|
|
|
| 151 |
img_np = image.cpu().numpy()
|
| 152 |
+
# Transpose from [C, H, W] to [H, W, C]
|
| 153 |
+
|
| 154 |
# Normalize image for display
|
| 155 |
+
#img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
|
| 156 |
+
|
| 157 |
|
| 158 |
# Get attention from the last layer (or you can average across layers)
|
| 159 |
last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
|
|
|
|
| 170 |
# Calculate grid size
|
| 171 |
H, W = img_np.shape[:2]
|
| 172 |
grid_h, grid_w = H // patch_size, W // patch_size
|
| 173 |
+
#print(grid_h, grid_w)
|
| 174 |
# Reshape attention to spatial grid
|
| 175 |
attention_map = cls_attention.reshape(grid_h, grid_w)
|
| 176 |
|
|
|
|
| 178 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 179 |
|
| 180 |
# Plot 1: Original image
|
| 181 |
+
axes[0].imshow(img_np[:, :, :3]) # only first 3 channels if more than 3
|
| 182 |
axes[0].set_title(f'Original Image (Epoch {epoch})')
|
| 183 |
axes[0].axis('off')
|
| 184 |
|
|
|
|
| 189 |
plt.colorbar(im, ax=axes[1])
|
| 190 |
|
| 191 |
# Plot 3: Overlay attention on image
|
| 192 |
+
axes[2].imshow(img_np[:, :, :3])
|
| 193 |
|
| 194 |
# Overlay attention as colored patches
|
| 195 |
+
max_attention = attention_map.max().numpy()
|
| 196 |
for i in range(grid_h):
|
| 197 |
for j in range(grid_w):
|
| 198 |
attention_val = attention_map[i, j].item()
|
|
|
|
| 202 |
patch_size, patch_size,
|
| 203 |
linewidth=0,
|
| 204 |
facecolor='red',
|
| 205 |
+
alpha=(attention_val/max_attention) * .9
|
| 206 |
)
|
| 207 |
axes[2].add_patch(rect)
|
| 208 |
|
|
|
|
| 218 |
# plt.savefig(f'{self.save_dir}/attention_epoch_{epoch}_sample_{sample_idx}.png',
|
| 219 |
# dpi=150, bbox_inches='tight')
|
| 220 |
# plt.close()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class MultiHeadAttentionCallback(AttentionMapCallback):
    """Extended callback that visualizes individual attention heads.

    In addition to the averaged CLS-attention figure drawn by
    ``AttentionMapCallback``, this renders one heatmap per attention head
    of the last transformer layer and saves the grid to ``self.save_dir``.
    """

    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
        """Plot the averaged attention map, then the per-head maps.

        Returns whatever the parent implementation returns (the averaged
        attention figure), so callers that log and close the returned
        figure keep working. The original override returned ``None``,
        which broke the caller's ``wandb.Image(map)`` / ``plt.close(map)``.
        """
        # Parent draws (and returns) the averaged-attention figure.
        fig = super()._plot_attention_map(
            image, attention_weights, sample_idx, epoch, patch_size
        )

        # Additionally render one heatmap per head.
        self._plot_individual_heads(
            image, attention_weights, sample_idx, epoch, patch_size
        )
        return fig

    def _plot_individual_heads(self, image, attention_weights, sample_idx, epoch, patch_size):
        """Plot CLS-to-patch attention heatmaps for each head of the last layer.

        Args:
            image: Image tensor; assumed [H, W, ...] after ``.cpu().numpy()``
                (only the first two dims are used) — TODO confirm layout.
            attention_weights: Sequence of per-layer attention tensors; the
                last entry is [B, num_heads, seq_len, seq_len].
            sample_idx: Index of the sample within the batch.
            epoch: Current training epoch (used in the output filename).
            patch_size: Side length of one ViT patch in pixels.
        """
        img_np = image.cpu().numpy()

        # Attention of the last layer for this sample: [num_heads, seq_len, seq_len].
        last_layer_attention = attention_weights[-1][sample_idx]
        num_heads = last_layer_attention.size(0)

        # Patch-grid size derived from the spatial dims.
        H, W = img_np.shape[:2]
        grid_h, grid_w = H // patch_size, W // patch_size

        # Subplot grid: up to 4 heads per row.
        cols = min(4, num_heads)
        rows = (num_heads + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
        # plt.subplots returns a bare Axes, a 1-D array, or a 2-D array
        # depending on (rows, cols). Flatten to a plain list so indexing is
        # uniform; the original reshape(1, -1) + axes[col] path raised
        # IndexError for a single row with more than one head.
        if num_heads == 1:
            flat_axes = [axes]
        else:
            flat_axes = list(axes.flat)

        for head_idx in range(num_heads):
            # CLS-token attention to the patch tokens (drop the CLS column).
            head_attention = last_layer_attention[head_idx, 0, 1:].cpu()
            attention_map = head_attention.reshape(grid_h, grid_w)

            ax = flat_axes[head_idx]
            im = ax.imshow(attention_map.numpy(), cmap='hot', interpolation='nearest')
            ax.set_title(f'Head {head_idx}')
            ax.axis('off')
            plt.colorbar(im, ax=ax)

        # Hide unused subplots.
        for idx in range(num_heads, rows * cols):
            flat_axes[idx].axis('off')

        plt.tight_layout()
        plt.savefig(f'{self.save_dir}/heads_epoch_{epoch}_sample_{sample_idx}.png',
                    dpi=150, bbox_inches='tight')
        plt.close()
|
flaring/MEGS_AI_baseline/config.yaml
CHANGED
|
@@ -29,10 +29,10 @@ vit:
|
|
| 29 |
patch_size: 16
|
| 30 |
num_patches: 262144
|
| 31 |
hidden_dim: 512
|
| 32 |
-
num_heads:
|
| 33 |
num_layers: 6
|
| 34 |
dropout: 0.1
|
| 35 |
-
lr: .
|
| 36 |
|
| 37 |
# Data paths (automatically constructed from base directories)
|
| 38 |
data:
|
|
|
|
| 29 |
patch_size: 16
|
| 30 |
num_patches: 262144
|
| 31 |
hidden_dim: 512
|
| 32 |
+
num_heads: 4
|
| 33 |
num_layers: 6
|
| 34 |
dropout: 0.1
|
| 35 |
+
lr: 0.0001
|
| 36 |
|
| 37 |
# Data paths (automatically constructed from base directories)
|
| 38 |
data:
|