import torch import matplotlib.pyplot as plt def plot_attention_map(attention_map: torch.Tensor, title, x_label="X", y_label="Y", save_path=None, plot_type="default"): """ Plots an attention map using matplotlib.pyplot Arguments: attention_map: Tensor - The attention map to plot. Shape: (H, W) title: str - The title of the plot x_label: str (optional) - The x-axis label y_label: str (optional) - The y-axis label save_path: str (optional) - The path to save the plot plot_type: str (optional) - The type of plot to create. Options: 'default', 'num', or any matplotlib colormap name. https://matplotlib.org/stable/gallery/color/colormap_reference.html Returns: None """ # Convert attention map to numpy array attention_map = attention_map.detach().cpu().numpy() # Create figure and axis fig, ax = plt.subplots() cmap_name = 'viridis' match plot_type: case 'default': cmap_name = 'viridis' case 'num': cmap_name = 'tab20c' case _: cmap_name = plot_type # Plot the attention map ax.imshow(attention_map, cmap=cmap_name, interpolation='nearest') if plot_type == 'num': elements = list(set(attention_map.flatten())) labels = [f"{x}" for x in elements] fig.legend(elements, labels, loc='lower left') # Set title and labels ax.set_title(title) ax.set_xlabel(x_label) ax.set_ylabel(y_label) # Save the plot if save_path is provided if save_path: plt.savefig(save_path) plt.close(fig)