| """Generate a stable train/val/test split manifest for the fota_unlabeled parquets. |
| |
Each sample is identified by `(parquet_filename, row_index)`. We hash that
identity to assign a stable bucket independent of file order, so adding or
reordering shards never reshuffles existing samples (only new files get
fresh assignments). Renaming a shard, however, changes the identity of every
row in it and therefore reassigns all of them.
| |
| Run: |
| python tactile_vae/dataset/make_splits.py |
| |
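By default the manifest is written to `tactile_vae/dataset/splits.json`. Abridged (the real file also records `source_root`, `source_glob` and `split_names`), it looks like: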
| |
| { |
| "seed": 42, |
| "ratios": {"train": 0.9, "val": 0.05, "test": 0.05}, |
| "counts": {...}, "total": ..., |
| "files": ["train-00000-of-00008.parquet", ...], |
| "splits": { |
| "train": {"train-00000-of-00008.parquet": [0, 3, 4, ...], ...}, |
| "val": {...}, |
| "test": {...} |
| } |
| } |
| |
| Per-file row lists are sorted ascending so they're easy to diff across runs |
| and cheap to bisect / slice from a parquet reader. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import struct |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import pyarrow.parquet as pq |
|
|
# Put the directory three path components above this file on sys.path so the
# absolute `tactile_vae.dataset...` import below resolves when this module is
# executed directly as a script (see the "Run:" line in the module docstring).
# Must run before that import.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    # Prepend rather than append so this checkout wins over any other
    # installed copy of the package.
    sys.path.insert(0, str(_REPO_ROOT))
|
|
| from tactile_vae.dataset.dataset import DEFAULT_DATA_ROOT, DEFAULT_FILE_GLOB |
|
|
# Manifest destination when --output is not given: a sibling of this module.
DEFAULT_OUTPUT = Path(__file__).parent / "splits.json"

# (train, val, test) fractions; make_splits requires them to sum to 1.0.
DEFAULT_RATIOS = (0.9, 0.05, 0.05)

# Mixed into every row hash -- changing it reassigns every sample.
DEFAULT_SEED = 42

# Order is significant: split index i in the per-row assignment array is
# reported under SPLIT_NAMES[i].
SPLIT_NAMES = ("train", "val", "test")
|
|
|
|
def _bucket_for(filename: str, row_idx: int, seed: int) -> float:
    """Map a ``(filename, row)`` pair to a deterministic float in ``[0, 1)``.

    The key is hashed with BLAKE2b (8-byte digest), so the value is stable
    across runs and does not depend on Python's ``PYTHONHASHSEED``. Only
    ``seed``, ``filename`` and ``row_idx`` feed the hash, so adding new files
    leaves existing samples in their bucket.

    Args:
        filename: Shard file name; encoded as UTF-8 into the hash input.
        row_idx: Row index within the shard; packed as a little-endian
            signed 64-bit integer.
        seed: Split seed; packed as a little-endian unsigned 32-bit integer,
            so it must lie in ``[0, 2**32)`` (``struct.error`` otherwise).

    Returns:
        A float ``u`` with ``0.0 <= u < 1.0``.
    """
    h = hashlib.blake2b(digest_size=8)
    h.update(struct.pack("<I", seed))
    h.update(filename.encode("utf-8"))
    # Separator byte between filename and row index. It is part of the hash
    # input, so changing it would reshuffle every existing assignment.
    h.update(b"\x00")
    h.update(struct.pack("<q", row_idx))
    u = int.from_bytes(h.digest(), "little") / 2**64
    # Digests within 2**10 of 2**64 round up to exactly 1.0 in float64. Clamp
    # them to the largest double below 1.0 so the half-open range holds and a
    # cutoff of exactly 1.0 still captures every row. All other values are
    # returned unchanged, so existing assignments are unaffected.
    return min(u, 1.0 - 2.0**-53)
|
|
|
|
def _assign(filename: str, n_rows: int, seed: int, cutoffs: tuple[float, float]) -> np.ndarray:
    """Bucket every row of one shard into a split index.

    Args:
        filename: Shard name, forwarded to ``_bucket_for``.
        n_rows: Number of rows; rows ``0 .. n_rows - 1`` are hashed.
        seed: Hash seed, forwarded to ``_bucket_for``.
        cutoffs: ``(lower, upper)`` thresholds applied to each row's hash.

    Returns:
        A ``uint8`` array of length ``n_rows``: 0 where the hash is below
        ``lower``, otherwise 1 where it is below ``upper``, otherwise 2.
    """
    scores = np.empty(n_rows, dtype=np.float64)
    lower, upper = cutoffs
    for row in range(n_rows):
        scores[row] = _bucket_for(filename, row, seed)
    # The nested where is a first-match chain: the `lower` test takes
    # precedence over the `upper` test, exactly like if/elif/else.
    labels = np.where(scores < lower, 0, np.where(scores < upper, 1, 2))
    return labels.astype(np.uint8)
|
|
|
|
def make_splits(
    root: Path = DEFAULT_DATA_ROOT,
    file_glob: str = DEFAULT_FILE_GLOB,
    ratios: tuple[float, float, float] = DEFAULT_RATIOS,
    seed: int = DEFAULT_SEED,
) -> dict:
    """Assign every row of every shard under ``root`` to train/val/test.

    Shards are the files matching ``file_glob`` under ``root`` whose name does
    not contain ``.raw.``, processed in sorted order. Each row's split is
    derived from a hash of (shard name, row index, ``seed``), so a row's
    split does not depend on which other shards are present. Per-file and
    summary counts are printed to stdout.

    Args:
        root: Directory containing the parquet shards.
        file_glob: Glob pattern, relative to ``root``, selecting the shards.
        ratios: ``(train, val, test)`` fractions. Each must be ``>= 0`` and
            together they must sum to 1.0 (within ``1e-6``).
        seed: Hash seed; must fit in an unsigned 32-bit integer because it
            is packed as one into every row hash.

    Returns:
        The JSON-serialisable manifest dict (layout shown in the module
        docstring).

    Raises:
        ValueError: If ``ratios`` or ``seed`` is invalid.
        FileNotFoundError: If no shard matches ``file_glob`` under ``root``.
    """
    if len(ratios) != 3:
        raise ValueError(f"ratios must have exactly 3 entries (train, val, test); got {ratios}")
    # `not (r >= 0)` / `not (... < tol)` also reject NaN entries.
    if not all(r >= 0 for r in ratios):
        raise ValueError(f"ratios must be non-negative; got {ratios}")
    if not abs(sum(ratios) - 1.0) < 1e-6:
        raise ValueError(f"ratios must sum to 1.0; got {ratios} (sum={sum(ratios)})")
    # Fail fast instead of hitting struct.error on the first row hash.
    if not 0 <= seed < 2**32:
        raise ValueError(f"seed must be in [0, 2**32); got {seed}")
    train_r, val_r, _ = ratios
    # Cumulative thresholds on the per-row hash value in [0, 1).
    cutoffs = (train_r, train_r + val_r)

    files = sorted(p for p in root.glob(file_glob) if ".raw." not in p.name)
    if not files:
        raise FileNotFoundError(f"no parquet shards under {root} matching {file_glob!r}")

    per_split: dict[str, dict[str, list[int]]] = {name: {} for name in SPLIT_NAMES}
    counts = {name: 0 for name in SPLIT_NAMES}
    total = 0
    t0 = time.perf_counter()
    for p in files:
        # Only the footer metadata is needed for the row count.
        n_rows = pq.read_metadata(p).num_rows
        total += n_rows
        assignments = _assign(p.name, n_rows, seed=seed, cutoffs=cutoffs)
        file_counts: dict[str, int] = {}
        for split_idx, split_name in enumerate(SPLIT_NAMES):
            # Ascending row indices, as promised in the module docstring.
            rows = np.flatnonzero(assignments == split_idx).tolist()
            per_split[split_name][p.name] = rows
            counts[split_name] += len(rows)
            file_counts[split_name] = len(rows)
        print(f" {p.name}: rows={n_rows:,} "
              f"train={file_counts['train']:,} "
              f"val={file_counts['val']:,} "
              f"test={file_counts['test']:,}",
              flush=True)
    dt = time.perf_counter() - t0

    print(f"\nAssignment done in {dt:.1f}s")
    for name in SPLIT_NAMES:
        # Shards can exist yet all be empty; avoid dividing by zero.
        pct = 100 * counts[name] / total if total else 0.0
        print(f" {name:5s}: {counts[name]:>8,} ({pct:.2f}%)")
    print(f" total: {total:>8,}")

    return {
        "seed": seed,
        "ratios": {"train": ratios[0], "val": ratios[1], "test": ratios[2]},
        "counts": counts,
        "total": total,
        "source_root": str(root),
        "source_glob": file_glob,
        "files": [p.name for p in files],
        "split_names": list(SPLIT_NAMES),
        "splits": per_split,
    }
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the split generator.

    Args:
        argv: Arguments to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace (``root``, ``glob``, ``output``, ``seed``,
        ``train``, ``val``, ``test``).
    """
    # __doc__ is None when running under `python -OO`; don't crash on it.
    summary = (__doc__ or "").split("\n", 1)[0]
    p = argparse.ArgumentParser(description=summary)
    p.add_argument("--root", type=Path, default=DEFAULT_DATA_ROOT,
                   help="directory containing the parquet shards (default: %(default)s)")
    p.add_argument("--glob", default=DEFAULT_FILE_GLOB,
                   help="glob, relative to --root, selecting shard files; names "
                        "containing '.raw.' are always skipped (default: %(default)s)")
    p.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
                   help="path of the JSON manifest to write (default: %(default)s)")
    p.add_argument("--seed", type=int, default=DEFAULT_SEED,
                   help="hash seed; changing it reassigns every row (default: %(default)s)")
    p.add_argument("--train", type=float, default=DEFAULT_RATIOS[0],
                   help="train fraction (default: %(default)s)")
    p.add_argument("--val", type=float, default=DEFAULT_RATIOS[1],
                   help="validation fraction (default: %(default)s)")
    p.add_argument("--test", type=float, default=DEFAULT_RATIOS[2],
                   help="test fraction; train + val + test must equal 1.0 (default: %(default)s)")
    return p.parse_args(argv)
|
|
|
|
def main() -> int:
    """CLI entry point: build the split manifest and write it as compact JSON.

    The JSON is first written to a sibling ``<output>.tmp`` file and then
    moved over ``--output`` with a single rename. An interrupted or failed
    run therefore never leaves a truncated manifest behind, and never
    destroys a previously written one.

    Returns:
        Process exit status (0 on success).
    """
    args = parse_args()
    manifest = make_splits(
        root=args.root,
        file_glob=args.glob,
        ratios=(args.train, args.val, args.test),
        seed=args.seed,
    )
    output: Path = args.output
    output.parent.mkdir(parents=True, exist_ok=True)
    tmp = output.with_name(output.name + ".tmp")
    try:
        with tmp.open("w", encoding="utf-8") as f:
            json.dump(manifest, f, separators=(",", ":"))
        # Path.replace overwrites an existing target (os.replace semantics).
        tmp.replace(output)
    except BaseException:
        # Don't leave a half-written temp file around; re-raise the original error.
        if tmp.exists():
            tmp.unlink()
        raise
    size_mb = output.stat().st_size / 1e6
    print(f"\nwrote {output} ({size_mb:.2f} MB)")
    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|
|