"""Generate a stable train/val/test split manifest for the fota_unlabeled parquets. Each sample is identified by `(parquet_filename, row_index)`. We hash that identity to assign a stable bucket independent of file order, so adding / reordering shards never reshuffles existing samples (only new files get fresh assignments). Run: python tactile_vae/dataset/make_splits.py Default writes to `tactile_vae/dataset/splits.json` and looks like: { "seed": 42, "ratios": {"train": 0.9, "val": 0.05, "test": 0.05}, "counts": {...}, "total": ..., "files": ["train-00000-of-00008.parquet", ...], "splits": { "train": {"train-00000-of-00008.parquet": [0, 3, 4, ...], ...}, "val": {...}, "test": {...} } } Per-file row lists are sorted ascending so they're easy to diff across runs and cheap to bisect / slice from a parquet reader. """ from __future__ import annotations import argparse import hashlib import json import struct import sys import time from pathlib import Path import numpy as np import pyarrow.parquet as pq _REPO_ROOT = Path(__file__).resolve().parents[2] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from tactile_vae.dataset.dataset import DEFAULT_DATA_ROOT, DEFAULT_FILE_GLOB DEFAULT_OUTPUT = Path(__file__).with_name("splits.json") DEFAULT_RATIOS = (0.9, 0.05, 0.05) DEFAULT_SEED = 42 SPLIT_NAMES = ("train", "val", "test") def _bucket_for(filename: str, row_idx: int, seed: int) -> float: """Map a (filename, row) pair to a deterministic float in [0, 1). Uses BLAKE2b for a stable cross-run hash that doesn't depend on Python's PYTHONHASHSEED. Adding new files leaves existing samples in their bucket. """ h = hashlib.blake2b(digest_size=8) h.update(struct.pack(" np.ndarray: """Return a uint8 array of length `n_rows` with bucket {0,1,2} per row.""" # Vectorized blake2b is impractical, but for n ≤ ~100k per file the # python loop is still fast (<1s/file) and keeps the bucketing portable. out = np.empty(n_rows, dtype=np.uint8) train_cut, val_cut = cutoffs for i in range(n_rows): u = _bucket_for(filename, i, seed) if u < train_cut: out[i] = 0 # train elif u < val_cut: out[i] = 1 # val else: out[i] = 2 # test return out def make_splits( root: Path = DEFAULT_DATA_ROOT, file_glob: str = DEFAULT_FILE_GLOB, ratios: tuple[float, float, float] = DEFAULT_RATIOS, seed: int = DEFAULT_SEED, ) -> dict: if not (len(ratios) == 3 and abs(sum(ratios) - 1.0) < 1e-6): raise ValueError(f"ratios must sum to 1.0; got {ratios} (sum={sum(ratios)})") train_r, val_r, _ = ratios cutoffs = (train_r, train_r + val_r) files = sorted(p for p in root.glob(file_glob) if ".raw." not in p.name) if not files: raise FileNotFoundError(f"no parquet shards under {root} matching {file_glob!r}") per_split: dict[str, dict[str, list[int]]] = {n: {} for n in SPLIT_NAMES} counts = {n: 0 for n in SPLIT_NAMES} total = 0 t0 = time.time() for p in files: n = pq.ParquetFile(p).metadata.num_rows total += n assignments = _assign(p.name, n, seed=seed, cutoffs=cutoffs) for split_idx, split_name in enumerate(SPLIT_NAMES): rows = np.flatnonzero(assignments == split_idx).tolist() per_split[split_name][p.name] = rows counts[split_name] += len(rows) print(f" {p.name}: rows={n:,} " f"train={int((assignments==0).sum()):,} " f"val={int((assignments==1).sum()):,} " f"test={int((assignments==2).sum()):,}", flush=True) dt = time.time() - t0 print(f"\nAssignment done in {dt:.1f}s") for n in SPLIT_NAMES: print(f" {n:5s}: {counts[n]:>8,} ({100*counts[n]/total:.2f}%)") print(f" total: {total:>8,}") return { "seed": seed, "ratios": {"train": ratios[0], "val": ratios[1], "test": ratios[2]}, "counts": counts, "total": total, "source_root": str(root), "source_glob": file_glob, "files": [p.name for p in files], "split_names": list(SPLIT_NAMES), "splits": per_split, } def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0]) p.add_argument("--root", type=Path, default=DEFAULT_DATA_ROOT) p.add_argument("--glob", default=DEFAULT_FILE_GLOB) p.add_argument("--output", type=Path, default=DEFAULT_OUTPUT) p.add_argument("--seed", type=int, default=DEFAULT_SEED) p.add_argument("--train", type=float, default=DEFAULT_RATIOS[0]) p.add_argument("--val", type=float, default=DEFAULT_RATIOS[1]) p.add_argument("--test", type=float, default=DEFAULT_RATIOS[2]) return p.parse_args() def main() -> int: args = parse_args() manifest = make_splits( root=args.root, file_glob=args.glob, ratios=(args.train, args.val, args.test), seed=args.seed, ) args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w") as f: json.dump(manifest, f, separators=(",", ":")) size_mb = args.output.stat().st_size / 1e6 print(f"\nwrote {args.output} ({size_mb:.2f} MB)") return 0 if __name__ == "__main__": sys.exit(main())