| |
| """Run a sweep of variable-speed ablation experiments end-to-end. |
| |
| For each ablation: build dataset (multi-process) -> compute norm stats -> train. |
| The norm-stats and train stages share one TrainConfig name; per-ablation |
| differences are applied via CLI overrides: ``--data.repo-id``, |
| ``--data.assets.asset-id`` and ``--data.speed-integration`` for training, |
| and ``--repo-id`` / ``--asset-id`` for compute_norm_stats. |
| |
| Edit the ABLATIONS table below to define your groups. Use: |
| - ``--only NAME[,NAME]`` to run a subset |
| - ``--skip-build / --skip-norm-stats / --skip-train`` to bypass stages |
| - ``--dry-run`` to print commands without executing them |
| |
| Example: |
| |
| uv run python scripts/run_ablations.py \\ |
| --src $SRC --out-root $ROOT \\ |
| --train-config pi05_libero_various_speed_all \\ |
| --base-asset-id libero_various_speed_all_pi05 \\ |
| --exp-prefix pi05_ablation \\ |
| --num-train-steps 30000 \\ |
| --build-num-workers 16 --train-num-workers 8 \\ |
| --only g2_coarse,g3a_step025 --dry-run |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import dataclasses |
| import shlex |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Ablation:
    """Immutable description of one ablation run in the sweep."""

    # Unique identifier: matched against ``--only``, appended to the
    # experiment name, and (unless ``shared_norm_key`` is set) to the asset id.
    name: str

    # Speed values forwarded to the dataset builder (``--speeds``) and to
    # training (``--eval-speed-set``); they also determine the dataset
    # directory name.
    speeds: tuple[float, ...]

    # Forwarded verbatim to training as ``--data.speed-integration=<value>``.
    speed_integration: str = "auto"

    # Extra CLI tokens appended to the end of the training command.
    extra_train_args: tuple[str, ...] = ()

    # When set, replaces ``name`` as the asset-id suffix, so every ablation
    # carrying the same key resolves to the same norm-stats asset id.
    shared_norm_key: str | None = None
|
|
|
|
| |
| |
# Speed set shared by the ``speedint_*`` and ``softprompt_*`` ablations, so
# those runs differ only in how speed is integrated, not in the data.
_SPEED_INT_SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
# Values swept for ``--model.soft-prompt-p``; one soft-prompt ablation is
# generated per entry.
_SOFT_PROMPT_P_VALUES: tuple[int, ...] = (1, 4, 8, 16, 32)
|
|
|
|
def _soft_prompt_ablation(p: int) -> Ablation:
    """Build the soft-prompt ablation for a given ``p``.

    The result is named ``softprompt_p<p>``, uses ``_SPEED_INT_SPEEDS`` and
    passes ``p`` plus the speed list to training through
    ``--model.soft-prompt-p`` / ``--model.soft-prompt-speeds``.
    """
    speed_tokens = [f"{s:g}" for s in _SPEED_INT_SPEEDS]
    train_overrides = (
        f"--model.soft-prompt-p={p}",
        "--model.soft-prompt-speeds",
        *speed_tokens,
    )
    return Ablation(
        name=f"softprompt_p{p}",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="soft_prompt",
        extra_train_args=train_overrides,
        # Every soft-prompt variant resolves to the same norm-stats asset id.
        shared_norm_key="softprompt_shared",
    )
|
|
|
|
| |
| |
| |
| |
# Sweep definition. Order matters: ablations run in the order listed here.
ABLATIONS: tuple[Ablation, ...] = (
    # --- Speed-set sweep: the set of speeds varies, integration is "text".
    Ablation(name="g1_baseline", speeds=(1.0,), speed_integration="text"),
    Ablation(name="g2_coarse", speeds=(0.5, 1.0, 1.5, 2.0), speed_integration="text"),
    Ablation(
        name="g3a_step025",
        speeds=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0),
        speed_integration="text",
    ),
    # Same speeds as _SPEED_INT_SPEEDS, so it maps to the same dataset
    # directory as the speedint_* / softprompt_* entries below.
    Ablation(name="g4_narrow", speeds=(0.75, 1.0, 1.25, 1.5), speed_integration="text"),
    Ablation(name="g5_extreme", speeds=(0.25, 0.5, 1.0, 2.0, 4.0), speed_integration="text"),
    # --- Speed-integration sweep: fixed speed set, integration mode varies.
    Ablation(name="speedint_text", speeds=_SPEED_INT_SPEEDS, speed_integration="text"),
    Ablation(
        name="speedint_modulation",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="modulation",
        extra_train_args=("--model.speed-modulation=True",),
    ),
    # --- Soft-prompt variants: one entry per value in _SOFT_PROMPT_P_VALUES.
    *map(_soft_prompt_ablation, _SOFT_PROMPT_P_VALUES),
)
|
|
|
|
def _speed_token(speeds: tuple[float, ...]) -> str:
    """Encode a speed tuple as a filesystem-friendly token.

    Each speed is rendered with the ``g`` format, its decimal point is
    replaced by ``p``, and the pieces are joined with underscores, e.g.
    ``(0.5, 1.0)`` -> ``"0p5_1"``.
    """
    pieces = []
    for speed in speeds:
        rendered = format(speed, "g")
        pieces.append(rendered.replace(".", "p"))
    return "_".join(pieces)
|
|
|
|
def _dataset_dir(out_root: Path, ablation: Ablation, run_tag: str) -> Path:
    """Return the directory of the dataset built for ``ablation``.

    The name depends only on the ablation's speeds and ``run_tag``, so
    ablations with identical speed sets map to the same directory.
    """
    token = _speed_token(ablation.speeds)
    return out_root.joinpath(f"libero_speed_{token}_{run_tag}")
|
|
|
|
def _asset_id(base_asset_id: str, ablation: Ablation) -> str:
    """Return the norm-stats asset id for ``ablation``.

    The suffix is the ablation's ``shared_norm_key`` when that is set to a
    non-empty value, otherwise the ablation's own ``name``.
    """
    shared = ablation.shared_norm_key
    tail = shared if shared else ablation.name
    return "{}_{}".format(base_asset_id, tail)
|
|
|
|
def _exp_name(prefix: str, ablation: Ablation) -> str:
    """Return the training experiment name: ``<prefix>_<ablation name>``."""
    return "{}_{}".format(prefix, ablation.name)
|
|
|
|
def _run(cmd: list[str], dry_run: bool) -> int:
    """Echo ``cmd`` in shell-quoted form and, unless ``dry_run``, execute it.

    Args:
        cmd: Argument vector to run (no shell is involved).
        dry_run: When true, only print the command and report success.

    Returns:
        ``0`` for a dry run, otherwise the child's exit code.
    """
    # Flush explicitly: when stdout is a pipe or file it is block-buffered,
    # and the child writes to the same descriptor directly. Without the flush
    # the echoed command can show up after the child's output in the log.
    print(f"$ {' '.join(shlex.quote(c) for c in cmd)}", flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)
|
|
|
|
def build_cmd(
    ablation: Ablation,
    src: Path,
    out_root: Path,
    run_tag: str,
    num_workers: int,
    clean_eps: float,
) -> tuple[Path, list[str]]:
    """Assemble the dataset-build command for ``ablation``.

    Returns:
        ``(dataset_dir, argv)`` where ``dataset_dir`` is the ``--dst`` target
        and ``argv`` invokes ``scripts/build_libero_speed_dataset_mp.py``.
    """
    dst = _dataset_dir(out_root, ablation, run_tag)
    # A single epsilon is applied to both the translation and rotation cleaners.
    eps = str(clean_eps)

    argv = ["uv", "run", "python", "scripts/build_libero_speed_dataset_mp.py"]
    argv += ["--src", str(src)]
    argv += ["--dst", str(dst)]
    argv += ["--speeds", *[f"{s:g}" for s in ablation.speeds]]
    argv += ["--clean-transl-eps", eps]
    argv += ["--clean-rot-eps", eps]
    argv += ["--min-segment-len", "1"]
    argv += ["--num-workers", str(num_workers)]
    argv.append("--overwrite")
    return dst, argv
|
|
|
|
def norm_stats_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
) -> list[str]:
    """Assemble the ``scripts/compute_norm_stats.py`` command for ``ablation``.

    The train config name is passed positionally; the dataset directory and
    the ablation's asset id go in via ``--repo-id`` / ``--asset-id``.
    """
    asset = _asset_id(base_asset_id, ablation)
    argv = ["uv", "run", "python", "scripts/compute_norm_stats.py"]
    argv.append(train_config_name)
    argv.extend(("--repo-id", str(dataset_dir)))
    argv.extend(("--asset-id", asset))
    return argv
|
|
|
|
def train_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
    exp_prefix: str,
    num_workers: int,
    num_train_steps: int,
) -> list[str]:
    """Assemble the ``scripts/train_pytorch.py`` command for ``ablation``.

    Per-ablation settings are passed as ``--flag=value`` overrides, followed
    by ``--overwrite``, the ``--eval-speed-set`` list, and finally the
    ablation's ``extra_train_args``.
    """
    # Insertion order is the order the flags appear on the command line.
    overrides = {
        "exp-name": _exp_name(exp_prefix, ablation),
        "data.repo-id": dataset_dir,
        "data.assets.asset-id": _asset_id(base_asset_id, ablation),
        "data.speed-integration": ablation.speed_integration,
        "num-workers": num_workers,
        "num-train-steps": num_train_steps,
    }

    argv = ["uv", "run", "python", "scripts/train_pytorch.py", train_config_name]
    argv += [f"--{flag}={value}" for flag, value in overrides.items()]
    argv += ["--overwrite", "--eval-speed-set"]
    argv += [f"{s:g}" for s in ablation.speeds]
    argv += ablation.extra_train_args
    return argv
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the sweep driver's command-line options from ``sys.argv``.

    Returns:
        The populated namespace; ``--src`` and ``--out-root`` are required,
        every other option has a default.
    """
    # The module docstring carries a bullet list and a multi-line example;
    # the default formatter would re-wrap it into one paragraph, so keep the
    # description verbatim.
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--src", required=True, help="Source LeRobot dataset root")
    p.add_argument("--out-root", required=True, help="Root directory under which built datasets land")
    p.add_argument(
        "--train-config",
        default="pi05_libero_various_speed_all",
        help="TrainConfig name passed to the norm-stats and train scripts",
    )
    p.add_argument(
        "--base-asset-id",
        default="libero_various_speed_all_pi05",
        help="Prefix of the per-ablation asset id",
    )
    p.add_argument("--exp-prefix", default="ablation", help="Prefix of the per-ablation experiment name")
    p.add_argument("--run-tag", default="ablation", help="Suffix of the built dataset directory name")
    p.add_argument(
        "--only",
        default=None,
        help="Comma-separated list of ablation names. Default = all in ABLATIONS.",
    )
    p.add_argument("--skip-build", action="store_true", help="Do not run the dataset build stage")
    p.add_argument("--skip-norm-stats", action="store_true", help="Do not run the norm-stats stage")
    p.add_argument("--skip-train", action="store_true", help="Do not run the training stage")
    p.add_argument("--build-num-workers", type=int, default=16, help="Worker count for the dataset build")
    p.add_argument("--train-num-workers", type=int, default=8, help="Worker count passed to training")
    p.add_argument("--num-train-steps", type=int, default=30_000, help="Training steps per ablation")
    # One value feeds both --clean-transl-eps and --clean-rot-eps of the builder.
    p.add_argument(
        "--clean-eps",
        type=float,
        default=0.0,
        help="Epsilon passed to the builder as both --clean-transl-eps and --clean-rot-eps",
    )
    p.add_argument("--dry-run", action="store_true", help="Print commands without executing them")
    return p.parse_args()
|
|
|
|
def main() -> None:
    """Run build -> norm-stats -> train for every selected ablation.

    Stages can be disabled with the ``--skip-*`` flags. Within one invocation
    a dataset is built at most once per distinct speed set and norm stats are
    computed at most once per asset id. The process exits with an error
    message as soon as any stage returns a non-zero exit code.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()

    # Resolve --only. Empty entries (e.g. from a trailing comma) are dropped
    # so they are not reported as unknown ablation names.
    only: set[str] | None = None
    if args.only:
        only = {name.strip() for name in args.only.split(",")} - {""}
    selected = [a for a in ABLATIONS if only is None or a.name in only]
    if not selected:
        sys.exit(
            f"No ablations selected. Names in --only must be one of: {[a.name for a in ABLATIONS]}"
        )
    if only:
        unknown = only - {a.name for a in ABLATIONS}
        if unknown:
            sys.exit(f"Unknown ablation names: {sorted(unknown)}")

    print(f"selected = {[a.name for a in selected]}")
    print(f"src = {src}")
    print(f"out_root = {out_root}")
    print(f"train_cfg = {args.train_config}")
    print(f"asset_base = {args.base_asset_id}")
    print(f"exp_prefix = {args.exp_prefix}")
    print()

    # Speed tokens whose dataset was already built during this invocation.
    built_speeds: set[str] = set()
    # asset id -> dataset directory its norm stats were computed from.
    norm_done: dict[str, Path] = {}
    for ab in selected:
        print(f"\n========== ablation: {ab.name} speeds={ab.speeds} speed_integration={ab.speed_integration} ==========")
        ds_dir = _dataset_dir(out_root, ab, args.run_tag)
        speed_tok = _speed_token(ab.speeds)

        # Stage 1: dataset build (deduplicated by speed set).
        if args.skip_build:
            print(f"[skip build] dataset_dir = {ds_dir}")
        elif speed_tok in built_speeds:
            print(f"[skip build, dataset already built in this run] dataset_dir = {ds_dir}")
        else:
            # Use the path the build command actually targets.
            ds_dir, cmd = build_cmd(ab, src, out_root, args.run_tag, args.build_num_workers, args.clean_eps)
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"build failed for {ab.name} (rc={rc})")
            built_speeds.add(speed_tok)

        # Stage 2: norm stats (deduplicated by asset id).
        if not args.skip_norm_stats:
            asset_id = _asset_id(args.base_asset_id, ab)
            stats_source = norm_done.get(asset_id)
            if stats_source is None:
                cmd = norm_stats_cmd(ab, ds_dir, args.train_config, args.base_asset_id)
                rc = _run(cmd, args.dry_run)
                if rc != 0:
                    sys.exit(f"norm-stats failed for {ab.name} (rc={rc})")
                norm_done[asset_id] = ds_dir
            else:
                if stats_source != ds_dir:
                    # A shared asset id across different speed sets means the
                    # stats were computed on a different dataset than the one
                    # this ablation trains on; surface it instead of hiding it.
                    print(
                        f"[warning] norm stats for asset_id={asset_id} were computed from "
                        f"{stats_source}, but {ab.name} uses {ds_dir}",
                        file=sys.stderr,
                    )
                print(f"[skip norm-stats, already computed in this run for asset_id={asset_id}]")

        # Stage 3: training.
        if not args.skip_train:
            cmd = train_cmd(
                ab,
                ds_dir,
                args.train_config,
                args.base_asset_id,
                args.exp_prefix,
                args.train_num_workers,
                args.num_train_steps,
            )
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"train failed for {ab.name} (rc={rc})")


if __name__ == "__main__":
    main()
|
|