tactile-vae / dataset /make_splits.py
WitneyWW's picture
Initial upload of tactile_vae (code, model, config, inference)
3770c94 verified
Raw
History Blame Contribute Delete
5.67 kB
"""Generate a stable train/val/test split manifest for the fota_unlabeled parquets.
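Each sample is identified by `(parquet_filename, row_index)`. We hash that
identity (together with the seed) to assign a stable bucket independent of
file order, so adding or reordering shards never reshuffles existing samples
(only new files get fresh assignments). The identity uses only the shard's
base name, so renaming a shard, changing the seed, or rewriting a shard with
its rows in a different order does reassign those samples.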
Run:
python tactile_vae/dataset/make_splits.py
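By default the manifest is written to `tactile_vae/dataset/splits.json`
(override with `--output`). It is stored as compact JSON; indented here for
readability, it looks like:
    {
      "seed": 42,
      "ratios": {"train": 0.9, "val": 0.05, "test": 0.05},
      "counts": {...}, "total": ...,
      "source_root": "...", "source_glob": "...",
      "files": ["train-00000-of-00008.parquet", ...],
      "split_names": ["train", "val", "test"],
      "splits": {
        "train": {"train-00000-of-00008.parquet": [0, 3, 4, ...], ...},
        "val": {...},
        "test": {...}
      }
    }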
Per-file row lists are sorted ascending so they're easy to diff across runs
and cheap to bisect / slice from a parquet reader.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import struct
import sys
import time
from pathlib import Path
import numpy as np
import pyarrow.parquet as pq
# When this file is run directly as a script
# (`python tactile_vae/dataset/make_splits.py`), Python puts the script's own
# directory on sys.path, not the repository root. Add the root so the absolute
# `tactile_vae...` import below resolves. parents[2] walks up from this file
# past dataset/ and tactile_vae/ to the repo root.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
# This import must stay below the sys.path adjustment above.
from tactile_vae.dataset.dataset import DEFAULT_DATA_ROOT, DEFAULT_FILE_GLOB

# Manifest location when --output is not given: splits.json next to this file.
DEFAULT_OUTPUT = Path(__file__).with_name("splits.json")
# (train, val, test) fractions; make_splits requires them to sum to 1.0.
DEFAULT_RATIOS = (0.9, 0.05, 0.05)
DEFAULT_SEED = 42
# Position i of this tuple is the bucket id _assign emits (0=train, 1=val, 2=test).
SPLIT_NAMES = ("train", "val", "test")
def _bucket_for(filename: str, row_idx: int, seed: int) -> float:
    """Map a ``(filename, row)`` pair to a deterministic float in ``[0, 1)``.

    Uses BLAKE2b for a stable cross-run hash that doesn't depend on Python's
    ``PYTHONHASHSEED``. The hashed message is:

    - the little-endian ``uint32`` seed,
    - the UTF-8 filename,
    - a single zero-byte separator,
    - the little-endian ``int64`` row index.

    Adding new files therefore leaves existing samples in their bucket.

    Args:
        filename: Shard base name (part of the sample identity).
        row_idx: Row index within the shard; must fit in a signed 64-bit int.
        seed: Split seed; must fit in an unsigned 32-bit int.

    Returns:
        A float ``u`` with ``0.0 <= u < 1.0``.

    Raises:
        struct.error: If ``seed`` or ``row_idx`` cannot be packed.
    """
    h = hashlib.blake2b(digest_size=8)
    h.update(struct.pack("<I", seed))
    h.update(filename.encode("utf-8"))
    h.update(b"\x00")
    h.update(struct.pack("<q", row_idx))
    u = int.from_bytes(h.digest(), "little") / 2**64
    # int / int is correctly rounded, so the ~2**10 largest 64-bit digests
    # round up to exactly 1.0. Clamp them to the largest double below 1.0 to
    # honour the half-open contract. Every other digest maps to the same value
    # as before, so existing split assignments are unchanged.
    return min(u, 1.0 - 2.0**-53)
def _assign(filename: str, n_rows: int, seed: int, cutoffs: tuple[float, float]) -> np.ndarray:
    """Bucket every row of one shard.

    Each row index ``0 .. n_rows - 1`` is scored with ``_bucket_for``. The
    score is compared against ``cutoffs = (train_cut, val_cut)``:

    - a score below ``train_cut`` gives 0 (train);
    - otherwise, a score below ``val_cut`` gives 1 (val);
    - anything else gives 2 (test).

    Returns:
        A ``uint8`` array of length ``n_rows`` holding one bucket id per row.
    """
    # BLAKE2b hashing stays a per-row Python loop; only the threshold
    # comparison is done with NumPy boolean masks.
    scores = np.empty(n_rows, dtype=np.float64)
    for row in range(n_rows):
        scores[row] = _bucket_for(filename, row, seed)
    train_cut, val_cut = cutoffs
    buckets = np.full(n_rows, 2, dtype=np.uint8)  # default: test
    buckets[scores < val_cut] = 1                 # val
    buckets[scores < train_cut] = 0               # train (takes precedence over val)
    return buckets
def make_splits(
    root: Path = DEFAULT_DATA_ROOT,
    file_glob: str = DEFAULT_FILE_GLOB,
    ratios: tuple[float, float, float] = DEFAULT_RATIOS,
    seed: int = DEFAULT_SEED,
) -> dict:
    """Assign every row of every matching parquet shard to train/val/test.

    Args:
        root: Directory searched for shards (a ``str`` is also accepted).
        file_glob: Glob pattern, relative to ``root``, selecting the shards.
            Files whose name contains ``".raw."`` are skipped.
        ratios: ``(train, val, test)`` fractions. Each must be non-negative,
            and together they must sum to 1.0 (within 1e-6).
        seed: Hash seed. It must fit in an unsigned 32-bit integer because
            it is packed with ``struct`` format ``"<I"``.

    Returns:
        The manifest dict that ``main`` serialises to JSON.
        ``manifest["splits"][split_name][shard_name]`` is the ascending list
        of row indices of that shard assigned to that split.

    Raises:
        ValueError: On invalid ``ratios`` or ``seed``, or if two matched
            shards share a base name (their sample identities would collide).
        FileNotFoundError: If no shard matches ``file_glob`` under ``root``.
    """
    root = Path(root)
    if not (len(ratios) == 3 and abs(sum(ratios) - 1.0) < 1e-6):
        raise ValueError(f"ratios must sum to 1.0; got {ratios} (sum={sum(ratios)})")
    # A negative ratio can still sum to 1.0 but yields val_cut < train_cut,
    # which silently collapses the split.
    if any(r < 0 for r in ratios):
        raise ValueError(f"ratios must be non-negative; got {ratios}")
    # Fail fast here instead of with a struct.error from the first hash.
    if not 0 <= seed < 2**32:
        raise ValueError(f"seed must be in [0, 2**32); got {seed}")
    train_r, val_r, _ = ratios
    cutoffs = (train_r, train_r + val_r)

    files = sorted(p for p in root.glob(file_glob) if ".raw." not in p.name)
    if not files:
        raise FileNotFoundError(f"no parquet shards under {root} matching {file_glob!r}")
    # Samples are keyed by shard base name only. Two shards with the same name
    # (e.g. in different sub-directories of a recursive glob) would overwrite
    # each other's row lists and be double-counted.
    names = [p.name for p in files]
    if len(set(names)) != len(names):
        dupes = sorted({name for name in names if names.count(name) > 1})
        raise ValueError(f"duplicate shard names under {root}: {dupes}")

    per_split: dict[str, dict[str, list[int]]] = {name: {} for name in SPLIT_NAMES}
    counts = {name: 0 for name in SPLIT_NAMES}
    total = 0
    t0 = time.perf_counter()
    for p in files:
        n = pq.ParquetFile(p).metadata.num_rows
        total += n
        assignments = _assign(p.name, n, seed=seed, cutoffs=cutoffs)
        sizes: dict[str, int] = {}
        for split_idx, split_name in enumerate(SPLIT_NAMES):
            rows = np.flatnonzero(assignments == split_idx).tolist()
            per_split[split_name][p.name] = rows
            counts[split_name] += len(rows)
            sizes[split_name] = len(rows)
        print(f" {p.name}: rows={n:,} "
              f"train={sizes['train']:,} "
              f"val={sizes['val']:,} "
              f"test={sizes['test']:,}",
              flush=True)
    dt = time.perf_counter() - t0
    print(f"\nAssignment done in {dt:.1f}s")
    for name in SPLIT_NAMES:
        # All shards may be empty; avoid dividing by zero in the summary.
        pct = 100 * counts[name] / total if total else 0.0
        print(f" {name:5s}: {counts[name]:>8,} ({pct:.2f}%)")
    print(f" total: {total:>8,}")
    return {
        "seed": seed,
        "ratios": {"train": ratios[0], "val": ratios[1], "test": ratios[2]},
        "counts": counts,
        "total": total,
        "source_root": str(root),
        "source_glob": file_glob,
        "files": names,
        "split_names": list(SPLIT_NAMES),
        "splits": per_split,
    }
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the split generator.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        Namespace with ``root``, ``glob``, ``output``, ``seed``, ``train``,
        ``val`` and ``test`` attributes.
    """
    # Module docstrings are stripped under `python -OO` (__doc__ is None), so
    # guard before taking the first line as the description.
    description = (__doc__ or "").split("\n", 1)[0]
    p = argparse.ArgumentParser(description=description)
    p.add_argument("--root", type=Path, default=DEFAULT_DATA_ROOT,
                   help="directory containing the parquet shards")
    p.add_argument("--glob", default=DEFAULT_FILE_GLOB,
                   help="glob pattern (relative to --root) selecting the shards")
    p.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
                   help="path of the JSON split manifest to write")
    p.add_argument("--seed", type=int, default=DEFAULT_SEED,
                   help="hash seed; must satisfy 0 <= seed < 2**32")
    p.add_argument("--train", type=float, default=DEFAULT_RATIOS[0],
                   help="fraction of rows assigned to the train split")
    p.add_argument("--val", type=float, default=DEFAULT_RATIOS[1],
                   help="fraction of rows assigned to the val split")
    p.add_argument("--test", type=float, default=DEFAULT_RATIOS[2],
                   help="fraction of rows assigned to the test split")
    return p.parse_args(argv)
def main() -> int:
    """Build the split manifest from CLI options and write it as compact JSON.

    The JSON is first written to a sibling ``<output>.tmp`` file and then
    moved over ``--output``. An interrupted or failed run therefore never
    leaves a truncated manifest in place of a previously good one.

    Returns:
        Process exit status (0 on success).
    """
    args = parse_args()
    manifest = make_splits(
        root=args.root,
        file_glob=args.glob,
        ratios=(args.train, args.val, args.test),
        seed=args.seed,
    )
    args.output.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = args.output.with_name(args.output.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(manifest, f, separators=(",", ":"))
        # Same directory, hence same filesystem: on POSIX this rename is atomic.
        tmp_path.replace(args.output)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    size_mb = args.output.stat().st_size / 1e6
    print(f"\nwrote {args.output} ({size_mb:.2f} MB)")
    return 0
if __name__ == "__main__":
    # Raising SystemExit directly is what sys.exit() does; main()'s return
    # value becomes the process exit status.
    raise SystemExit(main())