# vehicle-damage-classifier / src/step4_make_splits.py
# Author: efnanaladagg
# Commit message: Clean push
# Commit: 6f6eb85
from collections import Counter
from datasets import load_dataset, DatasetDict
from pathlib import Path
# Split configuration.
SEED = 42  # RNG seed passed to train_test_split for a reproducible split
TEST_SIZE = 0.2  # share of the source "train" split held out as "val"

# The output directory name is derived from SEED and TEST_SIZE so that
# changing either constant cannot leave a stale, misleading directory name
# (e.g. SEED = 7 would otherwise still write to "..._seed42_...").
# With the values above this is "..._seed42_test0p2".
OUT_DIR = Path(
    "data/splits/comprehensive-car-damage"
    f"_seed{SEED}_test{str(TEST_SIZE).replace('.', 'p')}"
)
def dist(ds_split, label_column="label"):
    """Return the per-class label distribution of a dataset split.

    Args:
        ds_split: A split where ``ds_split[label_column]`` yields integer class
            ids and ``ds_split.features[label_column].names`` lists the class
            names by id (presumably a ``datasets.ClassLabel`` feature -- TODO
            confirm for other datasets).
        label_column: Name of the label column. Defaults to ``"label"``.

    Returns:
        A list of ``(class_name, count, fraction)`` tuples, one per entry of
        ``names`` and in that order. Classes with no examples get a count of 0.
        ``fraction`` is ``count / len(ds_split)``, or ``0.0`` for an empty
        split. Label ids outside ``range(len(names))`` are not reported.
    """
    counts = Counter(ds_split[label_column])
    names = ds_split.features[label_column].names
    total = len(ds_split)
    # Counter returns 0 for missing keys, so absent classes still get a row.
    return [
        (name, counts[idx], counts[idx] / total if total else 0.0)
        for idx, name in enumerate(names)
    ]
def print_dist(title, ds_split):
    """Print the class distribution of *ds_split* as a small text table.

    Emits a blank line, then a ``"<title> (n=<size>)"`` header, then one row
    per class (as produced by ``dist``) with the class name, its count, and
    its share as a percentage with one decimal place.
    """
    print(f"\n{title} (n={len(ds_split)})")
    for row in dist(ds_split):
        class_name, count, share = row
        percent = share * 100
        print(f"- {class_name:<10}: {count:>4} ({percent:>5.1f}%)")
def main():
    """Build a stratified train/val split of the car-damage dataset and save it.

    Loads the ``train`` split of ``DrBimmer/comprehensive-car-damage``, holds
    out ``TEST_SIZE`` of it as ``val`` (stratified on ``label``, seeded with
    ``SEED``), prints the class distribution of both parts, and writes the
    resulting ``DatasetDict`` to ``OUT_DIR``.
    """
    source = load_dataset("DrBimmer/comprehensive-car-damage")["train"]
    parts = source.train_test_split(
        test_size=TEST_SIZE,
        seed=SEED,
        stratify_by_column="label",
    )

    # train_test_split calls the held-out part "test"; expose it as "val".
    splits = DatasetDict({"train": parts["train"], "val": parts["test"]})

    for heading, key in (("TRAIN", "train"), ("VAL", "val")):
        print_dist(heading, splits[key])

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    splits.save_to_disk(str(OUT_DIR))
    print(f"\nSaved splits to: {OUT_DIR}")
# Run the split only when executed as a script, not when imported.
if __name__ == "__main__":
    main()