File size: 3,932 Bytes
83be575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch

def to_display_image(img_tensor, mean, std):
    """Convert a normalized CHW image tensor to a displayable HWC numpy array.

    Args:
        img_tensor: torch tensor of shape (C, H, W), normalized per-channel
            as ``(x - mean) / std`` (the usual torchvision Normalize).
        mean: per-channel means used during normalization (len >= C).
        std: per-channel stds used during normalization (len >= C).

    Returns:
        numpy array of shape (H, W, C), de-normalized and clipped to [0, 1].
    """
    # Fix: Tensor.numpy() on a CPU tensor SHARES memory with the tensor, so the
    # in-place de-normalization below used to corrupt the caller's data; copy()
    # breaks the aliasing. detach() also makes this safe for grad-tracking
    # tensors. Iterating img.shape[0] generalizes the hard-coded 3 channels.
    img = img_tensor.detach().cpu().numpy().copy()
    for c in range(img.shape[0]):
        img[c] = img[c] * std[c] + mean[c]
    img = np.clip(img, 0.0, 1.0)
    # CHW -> HWC, the layout matplotlib's imshow expects.
    return np.transpose(img, (1, 2, 0))

def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
    """Report a grid of up to ``num_images`` sample images with predictions.

    Each cell shows the de-normalized image with a title of the form
    "Label: X / Prediction: Y", colored green on a correct prediction and
    red on a mistake. The figure is sent to ``logger`` and then closed.

    Args:
        images: sequence of normalized (C, H, W) image tensors.
        labels: sequence of ground-truth class indices, parallel to images.
        preds: sequence of predicted class indices, parallel to images.
        logger: reporting object with ``report_matplotlib_figure``
            (e.g. a ClearML Logger) — NOTE(review): assumed, confirm caller.
        class_names: sequence mapping class index -> display name.
        mean, std: per-channel normalization stats (see to_display_image).
        num_images: maximum number of images to display.
    """
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return  # nothing to plot; avoids creating a zero-height figure
    rows = int(np.ceil(num_images / 4))
    fig, axs = plt.subplots(rows, 4, figsize=(24, 6 * rows))
    axs = np.atleast_1d(axs).flatten()

    for i, ax in enumerate(axs):
        ax.axis("off")
        # Fix: cap at num_images, not len(images) — the original filled every
        # axis in the grid, so it could display more images than requested
        # whenever len(images) exceeded num_images.
        if i >= num_images:
            continue

        img = to_display_image(images[i], mean, std)
        lbl = labels[i]
        pr = preds[i]

        ax.imshow(img)
        title = f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}"
        colour = "green" if lbl == pr else "red"
        ax.set_title(title, fontsize=16, color=colour)

    fig.tight_layout()
    logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)
    plt.close(fig)
    
def plot_cfm(labels, preds, logger, class_names, num_classes):
    """Report a row-normalized confusion matrix and an errors-only matrix.

    Args:
        labels: iterable of ground-truth class indices.
        preds: iterable of predicted class indices, parallel to labels.
        logger: reporting object with ``report_matplotlib_figure``
            (e.g. a ClearML Logger) — NOTE(review): assumed, confirm caller.
        class_names: sequence of display names, length ``num_classes``.
        num_classes: total number of classes (fixes the matrix size even when
            some classes are absent from labels/preds).
    """
    cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))
    # Normalize per true-class row; a class with zero samples divides to NaN,
    # which nan_to_num maps back to 0 so empty classes render as blank rows.
    cfm_norm = cfm / cfm.sum(axis=1, keepdims=True)
    cfm_norm = np.nan_to_num(cfm_norm)

    fig, ax = plt.subplots(figsize=(16, 16))
    im = ax.imshow(cfm_norm, interpolation="nearest", cmap="Blues")
    # Fix: Figure.colorbar's second positional parameter is `cax` (the axes to
    # draw INTO), so the original `fig.colorbar(im, ax)` painted the colorbar
    # over the matrix plot — and then tried to add it a second time. Create it
    # once, attached beside the plot via the `ax=` keyword.
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Fraction of sample", rotation=90)
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(class_names, rotation=90, fontsize=8)
    ax.set_yticklabels(class_names, fontsize=8)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Ground Truth")
    ax.set_title("Confusion matrix (Normalized)")

    # Annotate non-zero cells; flip text to white on dark cells for contrast.
    threshold = cfm_norm.max() / 2.0
    for i in range(num_classes):
        for j in range(num_classes):
            value = cfm_norm[i, j]
            if value == 0:
                continue
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    fontsize=5, color="white" if value > threshold else "black")

    fig.tight_layout()
    logger.report_matplotlib_figure(title="normalized_confusion_matrix", series="test", figure=fig, iteration=0)
    plt.close(fig)

    # Second figure: raw misclassification counts with the diagonal zeroed,
    # so off-diagonal structure isn't drowned out by correct predictions.
    cfm_errors = cfm.copy()
    np.fill_diagonal(cfm_errors, 0)
    if cfm_errors.max() > 0:
        fig_err, ax_err = plt.subplots(figsize=(18, 18))
        im_err = ax_err.imshow(cfm_errors, interpolation="nearest", cmap=plt.cm.Blues)
        cbar_err = fig_err.colorbar(im_err, ax=ax_err)
        cbar_err.ax.set_ylabel("Number of misclassified samples", rotation=90)
        ax_err.set_title("Confusion matrix (errors only)")
        ax_err.set_xlabel("Predicted")
        ax_err.set_ylabel("Ground Truth")
        ax_err.set_xticks(np.arange(len(class_names)))
        ax_err.set_yticks(np.arange(len(class_names)))
        ax_err.set_xticklabels(class_names, rotation=90, fontsize=8)
        ax_err.set_yticklabels(class_names, fontsize=8)

        threshold = cfm_errors.max() / 2.0
        for i in range(num_classes):
            for j in range(num_classes):
                value = cfm_errors[i, j]
                if value == 0:
                    continue
                ax_err.text(j, i, str(value), ha="center", va="center",
                            fontsize=5, color="white" if value > threshold else "black")

        fig_err.tight_layout()
        logger.report_matplotlib_figure(title="errors_only_confusion_matrix", series="test", figure=fig_err, iteration=0)
        plt.close(fig_err)
    else:
        print("No misclassifications")