dikdimon's picture
Upload extensions using SD-Hub extension
3dabe4a verified
import torch
import matplotlib.pyplot as plt
def plot_attention_map(attention_map: torch.Tensor, title: str, x_label: str = "X", y_label: str = "Y", save_path=None, plot_type: str = "default") -> None:
    """Plot an attention map with matplotlib and optionally save it to disk.

    Arguments:
        attention_map: Tensor - The attention map to plot. Shape: (H, W)
        title: str - The title of the plot
        x_label: str (optional) - The x-axis label
        y_label: str (optional) - The y-axis label
        save_path: str (optional) - The path to save the plot. The figure is
            closed at the end in every case, so when this is omitted nothing
            is saved or displayed.
        plot_type: str (optional) - The type of plot to create. Options:
            'default' (viridis colormap), 'num' (tab20c colormap plus a legend
            entry for every distinct value in the map), or any matplotlib
            colormap name.
            https://matplotlib.org/stable/gallery/color/colormap_reference.html
    Returns:
        None
    """
    # NumPy has no bfloat16, so .numpy() would raise for it. Upcast only that
    # dtype so integer maps keep integer values (and labels) in 'num' mode.
    if attention_map.dtype == torch.bfloat16:
        attention_map = attention_map.float()
    attention_map = attention_map.detach().cpu().numpy()

    # Named presets map to colormaps; any other value is used as the colormap name.
    cmap_name = {"default": "viridis", "num": "tab20c"}.get(plot_type, plot_type)

    fig, ax = plt.subplots()
    try:
        image = ax.imshow(attention_map, cmap=cmap_name, interpolation="nearest")

        if plot_type == "num":
            # Local imports: only this branch needs them.
            import numpy as np
            from matplotlib.patches import Patch

            # np.unique returns the distinct values sorted (NaNs collapsed),
            # giving a stable legend order.
            values = np.unique(attention_map)
            # Legend handles must be artists, not raw numbers: build one colour
            # swatch per value using the same norm + colormap as the image.
            handles = [Patch(color=image.to_rgba(value), label=f"{value}") for value in values]
            fig.legend(handles=handles, loc="lower left")

        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        if save_path:
            fig.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)