import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch


def to_display_image(img_tensor, mean, std):
    """Convert a normalized, channel-first image tensor into a displayable array.

    Undoes per-channel normalization (``x * std[c] + mean[c]``), clips the
    result to ``[0, 1]`` and moves the channel axis last so the array can be
    passed to ``Axes.imshow``.

    Args:
        img_tensor: Image tensor indexed as ``img[channel]``; presumably shaped
            ``(C, H, W)`` -- the final transpose assumes exactly three axes.
        mean: Per-channel means used at normalization time.
        std: Per-channel standard deviations used at normalization time.

    Returns:
        A NumPy array shaped ``(H, W, C)``, or ``(H, W)`` for single-channel
        input (``imshow`` rejects a trailing channel axis of size 1).
        The input tensor is never modified.
    """
    # ``.numpy()`` on a CPU tensor shares memory with it, so without an
    # explicit copy the de-normalization below would overwrite the caller's
    # tensor in place. ``detach`` allows tensors that require grad.
    img = img_tensor.detach().cpu().numpy().copy()
    # Only touch channels that both exist in the image and have stats.
    for c, m, s in zip(range(img.shape[0]), mean, std):
        img[c] = img[c] * s + m
    img = np.clip(img, 0.0, 1.0)
    img = np.transpose(img, (1, 2, 0))
    if img.shape[-1] == 1:
        img = img[..., 0]
    return img


def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
    """Plot a 4-column grid of images titled with label vs. prediction and log it.

    Titles are green when ``labels[i] == preds[i]`` and red otherwise. The
    figure is sent via ``logger.report_matplotlib_figure`` (presumably a
    ClearML ``Logger`` -- confirm against caller) and then closed.

    Args:
        images: Sequence of image tensors accepted by ``to_display_image``.
        labels: Ground-truth class indices into ``class_names``.
        preds: Predicted class indices into ``class_names``.
        logger: Object exposing ``report_matplotlib_figure``.
        class_names: Indexable mapping from class index to display name.
        mean: Per-channel normalization means (see ``to_display_image``).
        std: Per-channel normalization stds (see ``to_display_image``).
        num_images: Maximum number of images to show; capped at ``len(images)``.
            Nothing is plotted or logged when this ends up ``<= 0``.
    """
    num_images = min(num_images, len(images))
    if num_images <= 0:
        # ``plt.subplots`` rejects a zero-row grid, and there is nothing to show.
        return
    rows = int(np.ceil(num_images / 4))
    fig, axs = plt.subplots(rows, 4, figsize=(24, 6 * rows))
    try:
        for i, ax in enumerate(axs.flatten()):
            ax.axis("off")
            # Cells past ``num_images`` (the tail of the last row) stay blank.
            if i >= num_images:
                continue
            lbl = labels[i]
            pr = preds[i]
            ax.imshow(to_display_image(images[i], mean, std))
            title = f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}"
            colour = "green" if lbl == pr else "red"
            ax.set_title(title, fontsize=16, color=colour)
        fig.tight_layout()
        logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)
    finally:
        plt.close(fig)


def _report_matrix_figure(logger, report_title, matrix, class_names, num_classes,
                          *, title, cbar_label, fmt, figsize):
    """Draw ``matrix`` as an annotated heatmap, log it, and always close the figure."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = ax.imshow(matrix, interpolation="nearest", cmap="Blues")
        # ``ax`` must be passed by keyword: the second positional parameter of
        # ``Figure.colorbar`` is ``cax``, which would draw the colorbar *into*
        # the heatmap axes instead of beside it.
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.set_ylabel(cbar_label, rotation=90)

        ticks = np.arange(num_classes)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(class_names, rotation=90, fontsize=8)
        ax.set_yticklabels(class_names, fontsize=8)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Ground Truth")
        ax.set_title(title)

        # White text on dark cells, black on light ones; zero cells stay blank.
        threshold = matrix.max() / 2.0
        for i, j in zip(*np.nonzero(matrix)):
            value = matrix[i, j]
            ax.text(j, i, fmt(value), ha="center", va="center", fontsize=5,
                    color="white" if value > threshold else "black")

        fig.tight_layout()
        logger.report_matplotlib_figure(title=report_title, series="test",
                                        figure=fig, iteration=0)
    finally:
        plt.close(fig)


def plot_cfm(labels, preds, logger, class_names, num_classes):
    """Log a row-normalized confusion matrix and, if any, an errors-only matrix.

    Args:
        labels: Ground-truth class indices, as accepted by
            ``sklearn.metrics.confusion_matrix``.
        preds: Predicted class indices, same form as ``labels``.
        logger: Object exposing ``report_matplotlib_figure`` (presumably a
            ClearML ``Logger`` -- confirm against caller).
        class_names: Tick labels; expected to have ``num_classes`` entries.
        num_classes: Number of classes; the matrix covers indices
            ``0 .. num_classes - 1``.

    The errors-only matrix (diagonal zeroed, raw counts) is logged only when
    there is at least one misclassification; otherwise a message is printed.
    """
    cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))

    # Row-normalize so each ground-truth row sums to 1. Classes with no
    # samples get an all-zero row instead of 0/0 = NaN (and no RuntimeWarning).
    row_sums = cfm.sum(axis=1, keepdims=True)
    cfm_norm = np.divide(cfm, row_sums, out=np.zeros(cfm.shape, dtype=float),
                         where=row_sums != 0)
    _report_matrix_figure(
        logger, "normalized_confusion_matrix", cfm_norm, class_names, num_classes,
        title="Confusion matrix (Normalized)",
        cbar_label="Fraction of samples",
        fmt=lambda v: f"{v:.2f}",
        figsize=(16, 16),
    )

    cfm_errors = cfm.copy()
    np.fill_diagonal(cfm_errors, 0)
    if cfm_errors.max() <= 0:
        print("No misclassifications")
        return
    _report_matrix_figure(
        logger, "errors_only_confusion_matrix", cfm_errors, class_names, num_classes,
        title="Confusion matrix (errors only)",
        cbar_label="Number of misclassified samples",
        fmt=str,
        figsize=(18, 18),
    )