File size: 1,923 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import matplotlib.pyplot as plt

def plot_attention_map(attention_map: torch.Tensor, title, x_label="X", y_label="Y", save_path=None, plot_type="default"):
        """ Plots an attention map using matplotlib.pyplot 
                Arguments:
                        attention_map: Tensor - The attention map to plot. Shape: (H, W)
                        title: str - The title of the plot
                        x_label: str (optional) - The x-axis label
                        y_label: str (optional) - The y-axis label
                        save_path: str (optional) - The path to save the plot
                        plot_type: str (optional) - The type of plot to create. Options: 'default', 'num', or any matplotlib colormap name. https://matplotlib.org/stable/gallery/color/colormap_reference.html
                Returns:
                        None
        """

        # Convert attention map to numpy array
        attention_map = attention_map.detach().cpu().numpy()

        # Create figure and axis
        fig, ax = plt.subplots()

        cmap_name = 'viridis'
        match plot_type:
                case 'default':
                        cmap_name = 'viridis'
                case 'num':
                        cmap_name = 'tab20c'
                case _:
                        cmap_name = plot_type 
        # Plot the attention map
        ax.imshow(attention_map, cmap=cmap_name, interpolation='nearest')
        if plot_type == 'num':
                elements = list(set(attention_map.flatten()))
                labels = [f"{x}" for x in elements]
                fig.legend(elements, labels, loc='lower left')

        # Set title and labels
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # Save the plot if save_path is provided
        if save_path:
                plt.savefig(save_path)
        
        plt.close(fig)