|
|
from datasets import load_dataset |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def main():
    """Show one example image for every class of the car-damage train split.

    Loads the ``DrBimmer/comprehensive-car-damage`` dataset, walks its
    ``train`` split until the first sample of each label has been collected
    (or the split is exhausted), and displays the collected images in a
    matplotlib grid, each titled with its class name.
    """
    ds = load_dataset("DrBimmer/comprehensive-car-damage")
    train_ds = ds["train"]
    label_names = train_ds.features["label"].names

    # Label id -> first image seen for that label. Dicts preserve insertion
    # order, so the grid follows the order in which classes first appear.
    first_image = {}
    for sample in train_ds:
        label = sample["label"]
        if label in first_image:
            continue
        first_image[label] = sample["image"]
        # Stop scanning as soon as every known class has an example.
        if len(first_image) == len(label_names):
            break

    # Size the grid from the number of collected images: a fixed 2x3 grid
    # only has room for six classes. Keeping at least two rows preserves the
    # familiar 2x3 layout (and the 12x8 figure) for six or fewer classes.
    n_cols = 3
    n_rows = max(2, (len(first_image) + n_cols - 1) // n_cols)

    plt.figure(figsize=(12, 4 * n_rows))
    # Subplot positions are 1-based.
    for position, (label, img) in enumerate(first_image.items(), start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(img)
        plt.title(label_names[label])
        plt.axis("off")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|