| |
| """Build all variable-speed ablation datasets without running norm stats or training. |
| |
| Use this on the data-prep machine. The ABLATIONS table is shared with |
| ``run_ablations.py`` -- edit it there. |
| |
| Usage: |
| |
| uv run python scripts/build_ablation_datasets.py \\ |
| --src $SRC --out-root $ROOT \\ |
| --num-workers 16 |
| |
| # only build a subset |
| uv run python scripts/build_ablation_datasets.py \\ |
| --src $SRC --out-root $ROOT \\ |
| --only g2_coarse,g3a_step025 |
| |
| # show commands without running |
| uv run python scripts/build_ablation_datasets.py \\ |
| --src $SRC --out-root $ROOT --dry-run |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import shlex |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
# Directory that contains this script.
_SCRIPTS_DIR = Path(__file__).resolve().parent
# Put that directory at the front of the import path (only once) so the
# sibling ``run_ablations`` module imported below resolves regardless of the
# working directory the script is launched from.
if not sys.path.count(str(_SCRIPTS_DIR)):
    sys.path.insert(0, str(_SCRIPTS_DIR))
|
|
| import run_ablations as _ar |
|
|
|
|
def _run(cmd: list[str], dry_run: bool) -> int:
    """Echo *cmd* as a shell-quoted line, then execute it unless *dry_run*.

    Args:
        cmd: Program and arguments. Handed to ``subprocess.call`` as a list,
            so no shell is involved in running it.
        dry_run: When true, the command is only printed.

    Returns:
        ``0`` for a dry run, otherwise whatever ``subprocess.call`` returns.
    """
    # Quote each argument so the echoed line can be pasted into a shell.
    rendered = " ".join(map(shlex.quote, cmd))
    print(f"$ {rendered}")
    return 0 if dry_run else subprocess.call(cmd)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    The module docstring is used as the ``--help`` description.

    Returns:
        Namespace with ``src`` and ``out_root`` (both required), ``run_tag``
        (default ``"ablation"``), ``only`` (default ``None``),
        ``num_workers`` (int, default 16), ``clean_eps`` (float, default 0.0)
        and ``dry_run`` (bool flag).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # (flag, add_argument keyword options), registered in this order so the
    # generated help output keeps the same ordering.
    option_table = (
        ("--src", {"required": True, "help": "Source LeRobot dataset root"}),
        (
            "--out-root",
            {"required": True, "help": "Root directory under which built datasets land"},
        ),
        ("--run-tag", {"default": "ablation"}),
        (
            "--only",
            {
                "default": None,
                "help": "Comma-separated ablation names to build (default = all in ABLATIONS).",
            },
        ),
        ("--num-workers", {"type": int, "default": 16}),
        ("--clean-eps", {"type": float, "default": 0.0}),
        ("--dry-run", {"action": "store_true"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Build the dataset for every selected ablation, then print a summary.

    Parses the CLI, resolves ``--src`` / ``--out-root``, filters
    ``run_ablations.ABLATIONS`` by ``--only`` and runs the command returned by
    ``run_ablations.build_cmd`` for each selected ablation. An ablation whose
    speed token matches one already handled in this run is not rebuilt.

    Exits through ``sys.exit`` with a message when nothing is selected, when
    ``--only`` names an unknown ablation, or when a build command returns a
    non-zero code.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()
    # NOTE(review): this runs before --dry-run is consulted, so a dry run still
    # creates the output root.
    out_root.mkdir(parents=True, exist_ok=True)

    # Drop empty tokens so a trailing or doubled comma ("a," / "a,,b") is not
    # reported as the unknown ablation name ''.
    only = None
    if args.only:
        stripped = (token.strip() for token in args.only.split(","))
        only = {token for token in stripped if token}

    known_names = [a.name for a in _ar.ABLATIONS]
    selected = [a for a in _ar.ABLATIONS if only is None or a.name in only]
    if not selected:
        sys.exit(
            f"No ablations selected. Names in --only must be one of: "
            f"{known_names}"
        )
    if only:
        unknown = only - set(known_names)
        if unknown:
            sys.exit(f"Unknown ablation names: {sorted(unknown)}")

    print(f"selected = {[a.name for a in selected]}")
    print(f"src = {src}")
    print(f"out_root = {out_root}")
    print()

    built: list[tuple[str, Path]] = []
    built_speeds: set[str] = set()
    for ab in selected:
        print(f"\n========== ablation: {ab.name} speeds={ab.speeds} ==========")
        ds_dir, cmd = _ar.build_cmd(ab, src, out_root, args.run_tag, args.num_workers, args.clean_eps)
        # Ablations with the same speed token are treated as sharing one
        # dataset, so only the first of them triggers a build.
        speed_tok = _ar._speed_token(ab.speeds)
        if speed_tok in built_speeds:
            print(f"[skip, dataset already built in this run] -> {ds_dir}")
        else:
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"build failed for {ab.name} (rc={rc})")
            built_speeds.add(speed_tok)
        # NOTE(review): recorded for every selected ablation, including skipped
        # ones, so the summary maps each name to its dataset -- confirm this
        # placement against the original indentation.
        built.append((ab.name, ds_dir))

    print("\n=== Build summary ===")
    for name, path in built:
        print(f" {name:20s} -> {path}")
    print(
        "\nNext steps: run norm-stats and training via "
        "scripts/run_ablations.py --skip-build (or invoke compute_norm_stats / train_pytorch directly)."
    )
|
|
|
|
# Run the build only when this file is executed as a script, not when it is
# imported.
if __name__ == "__main__":
    main()
|
|