def plot_sample_images(image_data, labels, categories, num_images=6, n_cols=3):
    """Show a grid of the first few images, each titled with its category name.

    Args:
        image_data: Indexable sequence of images. Each element must support
            ``.astype('uint8')`` (e.g. NumPy arrays) and be displayable by
            ``plt.imshow``.
        labels: Indexable sequence of label values aligned with ``image_data``;
            each is used as a key/index into ``categories``.
        categories: Indexable lookup from a label to its display name.
        num_images: Maximum number of images to show (default 6). Fewer are
            shown if ``image_data`` or ``labels`` is shorter.
        n_cols: Maximum number of grid columns (default 3).

    Raises:
        ValueError: If there is no image/label pair to show, or if ``n_cols``
            is not positive.
    """
    # Clamp to the data actually available so short inputs don't raise
    # IndexError halfway through drawing the figure.
    n_show = min(num_images, len(image_data), len(labels))
    if n_show <= 0:
        raise ValueError("no images to plot")
    if n_cols <= 0:
        raise ValueError("n_cols must be positive")
    n_cols = min(n_cols, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols  # ceiling division
    # 4 inches per grid cell; reproduces the original 12x8 figure for a 2x3 grid.
    plt.figure(figsize=(4 * n_cols, 4 * n_rows))
    for i in range(n_show):
        plt.subplot(n_rows, n_cols, i + 1)
        # NOTE(review): astype('uint8') truncates values; float images scaled to
        # [0, 1] would render almost black. Confirm image_data is in 0-255.
        plt.imshow(image_data[i].astype('uint8'))
        plt.title(categories[labels[i]])
        plt.axis('off')
    plt.tight_layout()
    plt.show()


# Presumably image_data, labels and categories are defined earlier in the
# file/notebook (not visible here) — verify before running standalone.
plot_sample_images(image_data, labels, categories)