vehicle-damage-classifier / src /step3_show_all_classes.py
efnanaladagg's picture
Clean push
6f6eb85
raw
history blame contribute delete
999 Bytes
from datasets import load_dataset
import matplotlib.pyplot as plt
def main(dataset_name="DrBimmer/comprehensive-car-damage", split="train", ncols=3):
    """Show one example image for every class in a labelled image dataset.

    Loads ``split`` of the Hugging Face dataset ``dataset_name``, picks the
    first sample of each label (in order of first appearance), and displays
    those images in a matplotlib grid titled with the class names.

    Args:
        dataset_name: Hugging Face Hub dataset identifier. The default is the
            car-damage dataset this script was written for.
        split: Dataset split to draw samples from.
        ncols: Maximum number of grid columns; rows are added as needed so
            that any number of classes fits.
    """
    import math

    train_ds = load_dataset(dataset_name, split=split)
    # Assumes "label" is a ClassLabel feature whose .names maps ids to names.
    label_names = train_ds.features["label"].names

    # Scan only the label column: iterating whole rows would decode every
    # image visited while searching for classes not yet seen.
    first_index = {}  # label id -> row index; dict keeps first-appearance order
    for row, label in enumerate(train_ds["label"]):
        if label not in first_index:
            first_index[label] = row
            if len(first_index) == len(label_names):
                break  # every class found; no need to read further

    if not first_index:
        print(f"No samples found in split {split!r} of {dataset_name!r}.")
        return

    n_classes = len(first_index)
    # Size the grid from the data instead of a fixed 2x3 layout, which made
    # plt.subplot fail as soon as there were more than six classes.
    ncols = max(1, min(ncols, n_classes))
    nrows = math.ceil(n_classes / ncols)
    # 4x4 inches per cell reproduces the original (12, 8) figure for 2x3.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False
    )
    for ax in axes.flat:
        ax.axis("off")  # also blanks any unused trailing cells
    for ax, (label, row) in zip(axes.flat, first_index.items()):
        ax.imshow(train_ds[row]["image"])  # decode only the chosen sample
        ax.set_title(label_names[label])
    fig.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Run the visualisation only when executed as a script, not on import.
    main()
# This script loads the "comprehensive-car-damage" dataset,
# iterates through the training set to find and display one image for each damage class
# using matplotlib in a grid layout.