VLAwithVariousSpeed / scripts /build_ablation_datasets.py
Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
3.64 kB
#!/usr/bin/env python3
"""Build all variable-speed ablation datasets without running norm stats or training.

Use this on the data-prep machine. The ABLATIONS table is shared with
``run_ablations.py`` -- edit it there.

Usage:
    uv run python scripts/build_ablation_datasets.py \\
        --src $SRC --out-root $ROOT \\
        --num-workers 16

    # only build a subset
    uv run python scripts/build_ablation_datasets.py \\
        --src $SRC --out-root $ROOT \\
        --only g2_coarse,g3a_step025

    # show commands without running
    uv run python scripts/build_ablation_datasets.py \\
        --src $SRC --out-root $ROOT --dry-run
"""
from __future__ import annotations
import argparse
import shlex
import subprocess
import sys
from pathlib import Path
# Put the directory containing this script on sys.path so the sibling module
# run_ablations (which owns the shared ABLATIONS table and build_cmd helper)
# can be imported. This must execute before the import that follows, hence
# the E402 suppression on it.
_SCRIPTS_DIR = Path(__file__).resolve().parent
if str(_SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPTS_DIR))
import run_ablations as _ar  # noqa: E402
def _run(cmd: list[str], dry_run: bool) -> int:
    """Echo ``cmd`` as a shell-quoted line and, unless ``dry_run``, execute it.

    Args:
        cmd: Program and arguments, passed to ``subprocess.call`` as a list
            (no shell is involved; quoting is only for the printed echo).
        dry_run: When true, only print the command and report success.

    Returns:
        The child's exit status, or ``0`` when ``dry_run`` is set.
    """
    # flush=True: the child writes straight to the inherited stdout file
    # descriptor, so without flushing our buffer first the "$ ..." echo can
    # land *after* the child's output when stdout is a pipe or a log file.
    print(f"$ {' '.join(shlex.quote(c) for c in cmd)}", flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)
def parse_args() -> argparse.Namespace:
    """Parse the command line for the ablation dataset builder.

    Returns:
        Namespace with ``src``, ``out_root``, ``run_tag``, ``only``,
        ``num_workers``, ``clean_eps`` and ``dry_run``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring carries multi-line usage examples; the default
        # formatter would re-wrap them into a single unreadable paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--src", required=True, help="Source LeRobot dataset root")
    p.add_argument("--out-root", required=True, help="Root directory under which built datasets land")
    p.add_argument(
        "--run-tag",
        default="ablation",
        help="Tag passed through to run_ablations.build_cmd (default: %(default)s).",
    )
    p.add_argument(
        "--only",
        default=None,
        help="Comma-separated ablation names to build (default = all in ABLATIONS).",
    )
    p.add_argument(
        "--num-workers",
        type=int,
        default=16,
        help="Worker count passed through to run_ablations.build_cmd (default: %(default)s).",
    )
    # Default 0.0 disables near-zero action cleaning (LIBERO is clean enough).
    p.add_argument(
        "--clean-eps",
        type=float,
        default=0.0,
        help="Near-zero action cleaning threshold; 0.0 disables cleaning (default: %(default)s).",
    )
    p.add_argument("--dry-run", action="store_true", help="Print the build commands without running them.")
    return p.parse_args()
def main() -> None:
    """Build the dataset for every selected ablation, once per output directory.

    The ablation table (``ABLATIONS``) and the per-ablation build command
    (``build_cmd``) come from ``run_ablations``. Each command is echoed and run
    through ``_run`` (only echoed under ``--dry-run``). Exits with a message
    when ``--only`` contains an unknown name, when nothing is selected, or on
    the first build that returns a non-zero status.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()

    known = [a.name for a in _ar.ABLATIONS]
    if args.only:
        # Ignore empty entries so a trailing or doubled comma ("a,b," / "a,,b")
        # is not rejected as an unknown ablation named "".
        only = {s.strip() for s in args.only.split(",") if s.strip()}
        unknown = only - set(known)
        if unknown:
            sys.exit(f"Unknown ablation names: {sorted(unknown)}. Valid names: {known}")
        selected = [a for a in _ar.ABLATIONS if a.name in only]
    else:
        selected = list(_ar.ABLATIONS)
    if not selected:
        sys.exit(f"No ablations selected. Names in --only must be one of: {known}")

    # Validate the selection before touching the filesystem, and leave the
    # filesystem alone entirely on a dry run.
    if not args.dry_run:
        out_root.mkdir(parents=True, exist_ok=True)

    print(f"selected = {[a.name for a in selected]}")
    print(f"src      = {src}")
    print(f"out_root = {out_root}")
    print()

    built: list[tuple[str, Path]] = []
    # Several ablations may map to the same dataset directory. Dedup on the
    # directory build_cmd actually returns, so a shared dataset is built once
    # and every distinct dataset is still built.
    built_dirs: set[Path] = set()
    for ab in selected:
        print(f"\n========== ablation: {ab.name}  speeds={ab.speeds} ==========")
        ds_dir, cmd = _ar.build_cmd(ab, src, out_root, args.run_tag, args.num_workers, args.clean_eps)
        key = Path(ds_dir)
        if key in built_dirs:
            print(f"[skip, dataset already built in this run] -> {ds_dir}")
        else:
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"build failed for {ab.name} (rc={rc})")
            built_dirs.add(key)
        # Record every selected ablation (including ones sharing a dataset)
        # so the summary maps each name to its dataset directory.
        built.append((ab.name, ds_dir))

    header = "\n=== Build summary ==="
    if args.dry_run:
        header += " (dry run: nothing was built)"
    print(header)
    for name, path in built:
        print(f"  {name:20s} -> {path}")
    print(
        "\nNext steps: run norm-stats and training via "
        "scripts/run_ablations.py --skip-build (or invoke compute_norm_stats / train_pytorch directly)."
    )
# Script entry point: importing this module defines the helpers but does not
# start a build.
if __name__ == "__main__":
    main()