# Spaces: Runtime error
# File size: 3,932 Bytes
# 83be575
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
def to_display_image(img_tensor, mean, std):
    """Undo per-channel normalization and return an array ready for plotting.

    Each of the first three channels is mapped back with
    ``channel * std[c] + mean[c]``, the result is clipped to ``[0.0, 1.0]``
    and the axes are reordered with ``(1, 2, 0)`` so the first axis ends up
    last.

    Args:
        img_tensor: Torch tensor with exactly three axes whose first axis
            holds at least three channels (presumably ``(C, H, W)`` --
            confirm against the caller).
        mean: Indexable with the per-channel offsets to add back.
        std: Indexable with the per-channel scales to multiply by.

    Returns:
        A new NumPy array with the original first axis moved to the end.
        ``img_tensor`` itself is never modified.
    """
    # ``Tensor.numpy()`` shares memory with a CPU tensor, so without the copy
    # the in-place loop below would overwrite the caller's data (and a second
    # call on the same image would de-normalize it twice). ``detach()`` lets
    # tensors that track gradients be converted as well.
    img = img_tensor.detach().cpu().numpy().copy()
    for channel in range(3):
        img[channel] = img[channel] * std[channel] + mean[channel]
    img = np.clip(img, 0.0, 1.0)
    return np.transpose(img, (1, 2, 0))
def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
    """Plot a grid of sample images with their labels and predictions, then log it.

    Images are laid out four per row. Each shown image is de-normalized with
    ``to_display_image`` and titled with its label and prediction names; the
    title is green when label and prediction are equal and red otherwise.
    The finished figure is handed to
    ``logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)``
    and then closed.

    Args:
        images: Sized, indexable collection of image tensors accepted by
            ``to_display_image``.
        labels: Indexable of ground-truth class indices, aligned with ``images``.
        preds: Indexable of predicted class indices, aligned with ``images``.
        logger: Object exposing ``report_matplotlib_figure`` (presumably a
            ClearML logger -- confirm against the caller).
        class_names: Indexable mapping a class index to its display name.
        mean: Per-channel means forwarded to ``to_display_image``.
        std: Per-channel standard deviations forwarded to ``to_display_image``.
        num_images: Maximum number of images to show; capped at ``len(images)``.

    Returns:
        None. Nothing is plotted or logged when there are no images to show.
    """
    num_images = max(0, min(num_images, len(images)))
    if num_images == 0:
        # A grid with zero rows makes ``plt.subplots`` raise, and there is
        # nothing to display anyway.
        return

    columns = 4
    rows = int(np.ceil(num_images / columns))
    fig, axs = plt.subplots(rows, columns, figsize=(24, 6 * rows))
    for i, ax in enumerate(axs.flatten()):
        ax.axis("off")
        # Compare against the requested count, not ``len(images)``: the last
        # row may have spare cells, and those must stay blank instead of
        # being filled with images beyond ``num_images``.
        if i >= num_images:
            continue
        lbl = labels[i]
        pr = preds[i]
        ax.imshow(to_display_image(images[i], mean, std))
        title = f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}"
        colour = "green" if lbl == pr else "red"
        ax.set_title(title, fontsize=16, color=colour)
    fig.tight_layout()
    logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)
    plt.close(fig)
def _draw_cfm_heatmap(matrix, class_names, num_classes, figsize, title,
                      colorbar_label, cell_format):
    """Draw ``matrix`` as an annotated heatmap and return the new figure."""
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(matrix, interpolation="nearest", cmap="Blues")
    # ``ax`` has to be passed by keyword: the second positional parameter of
    # ``Figure.colorbar`` is ``cax``, which would draw the colorbar inside
    # the heatmap's own axes.
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel(colorbar_label, rotation=90)

    ticks = range(num_classes)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(class_names, rotation=90, fontsize=8)
    ax.set_yticklabels(class_names, fontsize=8)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Ground Truth")
    ax.set_title(title)

    # Cells above half of the maximum sit on a dark background, so their text
    # is drawn in white; zero cells are left unannotated.
    threshold = matrix.max() / 2.0
    for i, j in zip(*np.nonzero(matrix)):
        value = matrix[i, j]
        ax.text(j, i, cell_format(value), ha="center", va="center",
                fontsize=5, color="white" if value > threshold else "black")
    fig.tight_layout()
    return fig


def plot_cfm(labels, preds, logger, class_names, num_classes):
    """Log a row-normalized confusion matrix and an errors-only confusion matrix.

    The confusion matrix is computed over class indices ``0 .. num_classes - 1``.
    Two figures are reported through ``logger.report_matplotlib_figure`` with
    ``series="test"`` and ``iteration=0``:

    * ``"normalized_confusion_matrix"`` -- every row divided by its sum; rows
      with no samples are shown as all zeros.
    * ``"errors_only_confusion_matrix"`` -- raw counts with the diagonal
      zeroed. It is only produced when at least one off-diagonal count is
      positive; otherwise ``"No misclassifications"`` is printed.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices, aligned with ``labels``.
        logger: Object exposing ``report_matplotlib_figure`` (presumably a
            ClearML logger -- confirm against the caller).
        class_names: Tick labels, one per class index.
        num_classes: Number of classes, i.e. the side length of the matrix.

    Returns:
        None.
    """
    cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))

    # Divide only where the row total is non-zero so that classes without any
    # samples yield zeros directly instead of a divide-by-zero warning and NaNs.
    row_totals = cfm.sum(axis=1, keepdims=True)
    cfm_norm = np.divide(
        cfm,
        row_totals,
        out=np.zeros(cfm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig = _draw_cfm_heatmap(
        cfm_norm,
        class_names,
        num_classes,
        figsize=(16, 16),
        title="Confusion matrix (Normalized)",
        colorbar_label="Fraction of sample",
        cell_format="{:.2f}".format,
    )
    logger.report_matplotlib_figure(title="normalized_confusion_matrix", series="test", figure=fig, iteration=0)
    plt.close(fig)

    # Zero the diagonal so only misclassified counts remain.
    cfm_errors = cfm.copy()
    np.fill_diagonal(cfm_errors, 0)
    if cfm_errors.max() <= 0:
        print("No misclassifications")
        return

    fig_err = _draw_cfm_heatmap(
        cfm_errors,
        class_names,
        num_classes,
        figsize=(18, 18),
        title="Confusion matrix (errors only)",
        colorbar_label="Number of misclassified samples",
        cell_format=str,
    )
    logger.report_matplotlib_figure(title="errors_only_confusion_matrix", series="test", figure=fig_err, iteration=0)
    plt.close(fig_err)
|