"""Show one example image per damage class from the car-damage dataset.

Loads the "comprehensive-car-damage" dataset, finds the first sample of each
damage class in the training split, and displays those images in a matplotlib
grid titled with the class names.
"""

import math

import matplotlib.pyplot as plt
from datasets import load_dataset


def _first_index_per_label(labels, num_classes):
    """Map each label to the index of its first occurrence in *labels*.

    Scanning stops as soon as *num_classes* distinct labels have been seen.
    The returned dict preserves first-appearance order.
    """
    first_index = {}
    for index, label in enumerate(labels):
        if label not in first_index:
            first_index[label] = index
            if len(first_index) == num_classes:
                break
    return first_index


def _grid_shape(count, max_cols):
    """Return (rows, cols) for a grid of *count* panels at most *max_cols* wide."""
    cols = max(1, min(count, max_cols))
    rows = max(1, math.ceil(count / cols))
    return rows, cols


def main(
    dataset_name="DrBimmer/comprehensive-car-damage",
    split="train",
    max_cols=3,
):
    """Display the first image of every class found in *split*.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Name of the split to sample from.
        max_cols: Maximum number of panels per grid row.

    Raises:
        ValueError: If *max_cols* is smaller than 1.
    """
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols!r}")

    ds = load_dataset(dataset_name)
    split_ds = ds[split]
    label_names = split_ds.features["label"].names

    # Read only the label column: iterating whole rows would also load the
    # image of every sample just to inspect its label.
    first_index = _first_index_per_label(split_ds["label"], len(label_names))
    if not first_index:
        print(f"No samples found in split {split!r}; nothing to display.")
        return

    # Fetch only the one row chosen for each class.
    images = [split_ds[index]["image"] for index in first_index.values()]
    titles = [label_names[label] for label in first_index]

    # Size the grid from the number of classes instead of a fixed 2x3 layout,
    # which made plt.subplot raise once there were more than 6 classes.
    rows, cols = _grid_shape(len(images), max_cols)
    plt.figure(figsize=(4 * cols, 4 * rows))
    for position, (img, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(rows, cols, position)
        # cmap only affects single-channel images (ignored for RGB/RGBA);
        # it keeps grayscale photos from being drawn in false colour.
        plt.imshow(img, cmap="gray")
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()