"""Create a reproducible, label-stratified train/val split of a Hugging Face
dataset, print the per-class distribution of each split, and save the result
to disk."""

from collections import Counter
from pathlib import Path

from datasets import DatasetDict, load_dataset

DATASET_ID = "DrBimmer/comprehensive-car-damage"
LABEL_COLUMN = "label"
SEED = 42
TEST_SIZE = 0.2

# Derive the output directory from the split parameters so its name can never
# drift out of sync with SEED / TEST_SIZE (e.g. TEST_SIZE=0.2 -> "test0p2").
_TEST_TAG = f"{TEST_SIZE:g}".replace(".", "p")
OUT_DIR = Path("data/splits") / (
    f"{DATASET_ID.rsplit('/', 1)[-1]}_seed{SEED}_test{_TEST_TAG}"
)


def dist(ds_split):
    """Return per-class counts and fractions for one dataset split.

    Args:
        ds_split: A dataset split whose ``LABEL_COLUMN`` feature exposes a
            ``.names`` list mapping label ids to class names.

    Returns:
        A list of ``(class_name, count, fraction)`` tuples, one per class in
        label-id order. Classes absent from the split get a count of 0, and
        every fraction is 0.0 when the split is empty.
    """
    counts = Counter(ds_split[LABEL_COLUMN])
    names = ds_split.features[LABEL_COLUMN].names
    total = len(ds_split)

    rows = []
    for label_id, name in enumerate(names):
        n = counts[label_id]  # Counter yields 0 for ids that never occur
        rows.append((name, n, n / total if total else 0.0))
    return rows


def print_dist(title, ds_split):
    """Print a class-distribution table for ``ds_split`` headed by ``title``."""
    rows = dist(ds_split)
    # Keep the original 10-char name column, but widen it for longer names so
    # the count/percentage columns stay aligned.
    width = max([10, *(len(name) for name, _, _ in rows)])

    print(f"\n{title} (n={len(ds_split)})")
    for name, count, frac in rows:
        print(f"- {name:<{width}}: {count:>4} ({frac * 100:>5.1f}%)")


def main():
    """Load the dataset, split it, report class balance, and save the splits."""
    ds = load_dataset(DATASET_ID)

    # Only the "train" split is re-split; any other splits in `ds` are unused.
    # Stratifying on the label column keeps class proportions aligned between
    # the two resulting parts.
    split = ds["train"].train_test_split(
        test_size=TEST_SIZE,
        seed=SEED,
        stratify_by_column=LABEL_COLUMN,
    )

    # train_test_split names its holdout "test"; here it serves as validation.
    splits = DatasetDict({"train": split["train"], "val": split["test"]})

    print_dist("TRAIN", splits["train"])
    print_dist("VAL", splits["val"])

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    splits.save_to_disk(str(OUT_DIR))
    print(f"\nSaved splits to: {OUT_DIR}")


if __name__ == "__main__":
    main()