import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

def plot_heatmap(cm, saveToFile=None, annot=True, fmt="d", cmap="Blues", xticklabels=None, yticklabels=None):
    """
    Plots a heatmap of the confusion matrix.

    Parameters:
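        cm (list of lists or array-like): The confusion matrix.
        saveToFile (str): Path to save the figure to before it is shown. Default is None (the figure is not saved).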
        annot (bool): Whether to annotate the heatmap with the cell values. Default is True.
        fmt (str): The format specifier for cell value annotations. Default is "d" (integer).
        cmap (str): The colormap for the heatmap. Default is "Blues".
        xticklabels (list): Labels for the x-axis ticks. Default is None.
        yticklabels (list): Labels for the y-axis ticks. Default is None.

    Returns:
        None
    """
    
    # Convert the confusion matrix to a NumPy array
    cm = np.array(cm)

    # Create a figure and axis for the heatmap
    fig, ax = plt.subplots()

    # Plot the heatmap
    im = ax.imshow(cm, cmap=cmap)
    
    # Display cell values as annotations
    if annot:
        # Normalize the colormap to get values between 0 and 1
        norm = Normalize(vmin=cm.min(), vmax=cm.max())
        for i in range(len(cm)):
            for j in range(len(cm[i])):
                value = cm[i, j]
                # Determine text color based on cell value
                text_color = 'white' if norm(value) > 0.5 else 'black'
                ax.text(j, i, format(value, fmt), ha="center", va="center", color=text_color)

    # Set x-axis and y-axis ticks and labels
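    # Compare against None so that array-like labels (e.g. NumPy arrays) also work
    if xticklabels is not None:
        ax.set_xticks(np.arange(len(xticklabels)))
        ax.set_xticklabels(xticklabels)
    if yticklabels is not None:
        ax.set_yticks(np.arange(len(yticklabels)))
        ax.set_yticklabels(yticklabels)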

    # Set labels and title
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix Heatmap")

    # Add a colorbar
    fig.colorbar(im, ax=ax)

    # Save the figure if a path was given, then show the plot
    if saveToFile is not None:
        plt.savefig(saveToFile)

    plt.show()
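

# Illustrative sketch, not part of the original module: one possible way to
# build the `cm` argument from sequences of true and predicted labels. The
# helper name `build_confusion_matrix` and its parameters are assumptions made
# for this example only.
def build_confusion_matrix(y_true, y_pred, labels):
    """Count (true, predicted) label pairs into a square list of lists."""
    # Map each label to its row/column position
    index = {label: k for k, label in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        # Rows are true labels, columns are predicted labels
        cm[index[t]][index[p]] += 1
    return cm


# Possible usage, with made-up data:
#     labels = ["cat", "dog"]
#     cm = build_confusion_matrix(["cat", "dog", "dog"], ["cat", "cat", "dog"], labels)
#     plot_heatmap(cm, xticklabels=labels, yticklabels=labels)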