|
|
from collections import Counter |
|
|
from datasets import load_dataset, DatasetDict |
|
|
from pathlib import Path |
|
|
|
|
|
# Fixed seed passed to train_test_split so the stratified split is reproducible.
SEED = 42

# Passed as `test_size` to train_test_split; the resulting held-out part is
# saved as the "val" split.
TEST_SIZE = 0.2

# Output directory. Its name encodes SEED and TEST_SIZE (e.g. "..._seed42_test0p2"),
# so it is derived from those constants rather than hard-coded; otherwise changing
# either value would save a split under a misleading directory name.
OUT_DIR = Path(
    "data/splits/comprehensive-car-damage"
    f"_seed{SEED}_test{str(TEST_SIZE).replace('.', 'p')}"
)
|
|
|
|
|
def dist(ds_split):
    """Return per-class counts and proportions for a split's ``label`` column.

    Args:
        ds_split: A dataset split supporting ``ds_split["label"]`` (an iterable
            of integer class ids), ``ds_split.features["label"].names`` (class
            names indexed by id; presumably a ``datasets.ClassLabel`` feature)
            and ``len(ds_split)``.

    Returns:
        A list of ``(name, count, fraction)`` tuples, one per declared class
        name in id order. Classes that never occur get a count of 0.
        ``fraction`` is ``count / len(ds_split)``, or ``0.0`` for an empty split.
    """
    counts = Counter(ds_split["label"])
    names = ds_split.features["label"].names
    total = len(ds_split)
    # Walk every declared class rather than only the observed ones, so absent
    # classes are reported with a zero count. Counter returns 0 for missing keys.
    # NOTE(review): label ids outside range(len(names)) (e.g. -1) are not listed
    # here but still count toward `total`.
    return [
        (name, counts[idx], counts[idx] / total if total else 0.0)
        for idx, name in enumerate(names)
    ]
|
|
|
|
|
def print_dist(title, ds_split):
    """Print a split's size, then one line per class with its count and percentage.

    Args:
        title: Heading printed before the distribution (e.g. "TRAIN").
        ds_split: Split passed through to ``dist`` for the per-class rows.
    """
    row_template = "- {:<10}: {:>4} ({:>5.1f}%)"
    print(f"\n{title} (n={len(ds_split)})")
    for label_name, count, share in dist(ds_split):
        print(row_template.format(label_name, count, share * 100))
|
|
|
|
|
def main():
    """Build and save a stratified train/val split of the car-damage dataset.

    Loads the dataset's "train" split and holds out TEST_SIZE of it (stratified
    on "label", seeded with SEED) as "val". It then prints the class
    distribution of both resulting splits and saves them as a DatasetDict
    under OUT_DIR.
    """
    source = load_dataset("DrBimmer/comprehensive-car-damage")

    # Stratify on the label so class proportions match between train and val.
    holdout = source["train"].train_test_split(
        test_size=TEST_SIZE, seed=SEED, stratify_by_column="label"
    )

    # Rename the held-out "test" part to "val".
    splits = DatasetDict(train=holdout["train"], val=holdout["test"])

    for heading, key in (("TRAIN", "train"), ("VAL", "val")):
        print_dist(heading, splits[key])

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    splits.save_to_disk(str(OUT_DIR))
    print(f"\nSaved splits to: {OUT_DIR}")
|
|
|
|
|
# Build and save the splits only when run as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|