Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import confusion_matrix | |
| import torch | |
def to_display_image(img_tensor, mean, std):
    """Undo per-channel normalization and convert a channels-first tensor to HWC.

    Args:
        img_tensor: Image tensor laid out channels-first, ``(C, H, W)``. It may
            live on any device and may require grad.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.

    Returns:
        A NumPy array of shape ``(H, W, C)`` with values clipped to ``[0, 1]``.

    The input tensor is never modified.
    """
    # .detach(): Tensor.numpy() refuses tensors that require grad.
    # All arithmetic below is out-of-place: for a CPU tensor, .cpu() returns the
    # same tensor and .numpy() shares its memory, so writing into `img` in place
    # would silently alter the caller's batch.
    img = img_tensor.detach().cpu().numpy()
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float32)
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    mean_arr = np.asarray(mean, dtype=img.dtype).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=img.dtype).reshape(-1, 1, 1)
    img = np.clip(img * std_arr + mean_arr, 0.0, 1.0)
    return np.transpose(img, (1, 2, 0))
def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
    """Plot a 4-column grid of images titled with ground-truth and predicted labels.

    Titles are green when the prediction matches the label and red otherwise.
    The figure is passed to ``logger.report_matplotlib_figure`` (title
    ``"sample_predictions"``, series ``"test"``, iteration 0) and then closed,
    even if reporting raises.

    Args:
        images: Indexable batch of normalized channels-first image tensors.
        labels: Ground-truth class indices aligned with ``images``.
        preds: Predicted class indices aligned with ``images``.
        logger: Object exposing ``report_matplotlib_figure``; presumably a
            ClearML ``Logger`` — confirm against the caller.
        class_names: Sequence mapping a class index to its display name.
        mean: Per-channel normalization means, forwarded to ``to_display_image``.
        std: Per-channel normalization stds, forwarded to ``to_display_image``.
        num_images: Maximum number of images to draw; clamped to ``len(images)``.
    """
    n_shown = min(num_images, len(images))
    if n_shown <= 0:
        # plt.subplots() rejects zero rows; there is nothing to report.
        return
    ncols = 4
    nrows = int(np.ceil(n_shown / ncols))
    # squeeze=False guarantees a 2-D Axes array regardless of nrows.
    fig, axs = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    try:
        for i, ax in enumerate(axs.flat):
            ax.axis("off")
            # Compare against the clamped count, not len(images): otherwise the
            # spare cells of the last row would show images past the limit.
            if i >= n_shown:
                continue
            # int() makes tensor / NumPy scalars usable as list indices and
            # turns the comparison below into a plain bool.
            lbl = int(labels[i])
            pr = int(preds[i])
            ax.imshow(to_display_image(images[i], mean, std))
            ax.set_title(
                f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}",
                fontsize=16,
                color="green" if lbl == pr else "red",
            )
        fig.tight_layout()
        logger.report_matplotlib_figure(
            title="sample_predictions", series="test", figure=fig, iteration=0
        )
    finally:
        plt.close(fig)
def plot_cfm(labels, preds, logger, class_names, num_classes):
    """Report a row-normalized confusion matrix and, if any, an errors-only matrix.

    Two figures may be sent via ``logger.report_matplotlib_figure`` (series
    ``"test"``, iteration 0):

    * ``"normalized_confusion_matrix"``: each row divided by its total, so cell
      ``(i, j)`` is the fraction of class-``i`` samples predicted as ``j``. Rows
      with no samples are all zeros.
    * ``"errors_only_confusion_matrix"``: raw counts with the diagonal zeroed.
      It is only produced when at least one sample is misclassified; otherwise
      ``"No misclassifications"`` is printed.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices, aligned with ``labels``.
        logger: Object exposing ``report_matplotlib_figure``; presumably a
            ClearML ``Logger`` — confirm against the caller.
        class_names: Display names, one per class index in ``range(num_classes)``.
        num_classes: Number of classes; fixes the matrix size even when some
            classes never appear.
    """
    cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))

    # Divide only where the row total is non-zero: empty rows stay 0 instead of
    # producing a divide-by-zero warning and NaNs.
    row_totals = cfm.sum(axis=1, keepdims=True)
    cfm_norm = np.divide(
        cfm, row_totals, out=np.zeros(cfm.shape, dtype=float), where=row_totals != 0
    )
    _report_matrix_figure(
        logger,
        cfm_norm,
        class_names,
        num_classes,
        report_title="normalized_confusion_matrix",
        axes_title="Confusion matrix (Normalized)",
        cbar_label="Fraction of samples",
        value_fmt=lambda v: f"{v:.2f}",
        figsize=(16, 16),
    )

    cfm_errors = cfm.copy()
    np.fill_diagonal(cfm_errors, 0)
    if cfm_errors.max() > 0:
        _report_matrix_figure(
            logger,
            cfm_errors,
            class_names,
            num_classes,
            report_title="errors_only_confusion_matrix",
            axes_title="Confusion matrix (errors only)",
            cbar_label="Number of misclassified samples",
            value_fmt=str,
            figsize=(18, 18),
        )
    else:
        print("No misclassifications")


def _report_matrix_figure(logger, matrix, class_names, num_classes, *,
                          report_title, axes_title, cbar_label, value_fmt, figsize):
    """Draw an annotated class-by-class heatmap, report it, and always close it."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = ax.imshow(matrix, interpolation="nearest", cmap="Blues")
        # ax= must be a keyword: the second positional parameter of
        # Figure.colorbar is `cax`, which would draw the bar over the heatmap.
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.set_ylabel(cbar_label, rotation=90)

        ticks = np.arange(num_classes)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(class_names, rotation=90, fontsize=8)
        ax.set_yticklabels(class_names, fontsize=8)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Ground Truth")
        ax.set_title(axes_title)

        # White text on cells darker than half the maximum, black elsewhere.
        threshold = matrix.max() / 2.0
        for (i, j), value in np.ndenumerate(matrix):
            if value == 0:
                continue
            # imshow puts column j on x and row i on y.
            ax.text(j, i, value_fmt(value), ha="center", va="center",
                    fontsize=5, color="white" if value > threshold else "black")

        fig.tight_layout()
        logger.report_matplotlib_figure(
            title=report_title, series="test", figure=fig, iteration=0
        )
    finally:
        plt.close(fig)