| import matplotlib.pyplot as plt | |
| def plot_img_and_mask(img, mask): | |
| classes = mask.shape[0] if len(mask.shape) > 2 else 1 | |
| fig, ax = plt.subplots(1, classes + 1) | |
| ax[0].set_title('Input image') | |
| ax[0].imshow(img) | |
| if classes > 1: | |
| for i in range(classes): | |
| ax[i + 1].set_title(f'Output mask (class {i + 1})') | |
| ax[i + 1].imshow(mask[:, :, i]) | |
| else: | |
| ax[1].set_title(f'Output mask') | |
| ax[1].imshow(mask) | |
| plt.xticks([]), plt.yticks([]) | |
| plt.show() | |