def plot_sample_images(image_data, labels, categories):
    """Display up to six leading samples in a 2x3 grid, titled by category.

    Each shown image is cast to ``uint8`` before being passed to
    ``plt.imshow``, and its title is ``categories[labels[i]]``. Axes are
    hidden, the layout is tightened, and the figure is shown with
    ``plt.show()``.

    Args:
        image_data: Indexable, sized collection of images. Each element must
            provide ``.astype`` (presumably NumPy arrays of pixel data --
            confirm against the caller).
        labels: Indexable, sized collection where ``labels[i]`` is a valid
            key/index into ``categories``.
        categories: Mapping or sequence from a label to the title to display.

    Returns:
        None. If there are no samples, nothing is drawn.
    """
    n_rows, n_cols = 2, 3

    # Cap the count at what is actually available: the previous fixed
    # ``range(6)`` raised IndexError whenever fewer than six samples were
    # passed in.
    count = min(n_rows * n_cols, len(image_data), len(labels))
    if count == 0:
        return

    plt.figure(figsize=(12, 8))
    for i in range(count):
        # Subplot positions are 1-based.
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(image_data[i].astype('uint8'))
        plt.title(categories[labels[i]])
        plt.axis('off')
    plt.tight_layout()
    plt.show()
# Preview the leading samples together with their category names.
plot_sample_images(
    image_data=image_data,
    labels=labels,
    categories=categories,
)