#!/usr/bin/env python3
"""Run a sweep of variable-speed ablation experiments end-to-end.

For each ablation: build dataset (multi-process) -> compute norm stats -> train.
All three stages share one TrainConfig name; per-ablation differences are
applied via CLI overrides on ``--data.repo_id`` and ``--data.assets.asset_id``,
and via ``--repo-id`` / ``--asset-id`` for compute_norm_stats.

Edit the ABLATIONS table below to define your groups. Use:
  - ``--only NAME[,NAME]`` to run a subset
  - ``--skip-build / --skip-norm-stats / --skip-train`` to bypass stages
  - ``--dry-run`` to print commands without executing them

Example:
    uv run python scripts/run_ablations.py \\
        --src $SRC --out-root $ROOT \\
        --train-config pi05_libero_various_speed_all \\
        --base-asset-id libero_various_speed_all_pi05 \\
        --exp-prefix pi05_ablation \\
        --num-train-steps 30000 \\
        --build-num-workers 16 --train-num-workers 8 \\
        --only g2_coarse,g3a_step025 --dry-run
"""

from __future__ import annotations

import argparse
import dataclasses
import shlex
import subprocess
import sys
from pathlib import Path


@dataclasses.dataclass(frozen=True)
class Ablation:
    """One row of the ablation table: a name, a speed set and train overrides."""

    name: str
    speeds: tuple[float, ...]
    # How speed is fed to the model: "text" | "modulation" | "soft_prompt" | "auto".
    # See LeRobotVariousSpeedLiberoDataConfig.speed_integration for semantics.
    speed_integration: str = "auto"
    # Extra args appended verbatim to the train_pytorch.py invocation. Use this
    # for per-ablation model overrides (e.g., speed_modulation for modulation).
    extra_train_args: tuple[str, ...] = ()
    # When set, multiple ablations sharing the same key share one norm_stats
    # file (and thus one ``asset_id``). Norm stats only depend on the dataset's
    # state/actions, so a P-length sweep over the same speed set should reuse
    # one set of stats. Default ``None`` -> isolate per ablation (legacy).
    shared_norm_key: str | None = None


# Speeds shared by the speed-integration and soft-prompt-P sweeps (all train
# on the same dataset; only the model conditioning differs).
_SPEED_INT_SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
_SOFT_PROMPT_P_VALUES: tuple[int, ...] = (1, 4, 8, 16, 32)


def _soft_prompt_ablation(p: int) -> Ablation:
    """Return the soft-prompt ablation with prompt length ``p``."""
    # tyro requires tuple[float, ...] CLI args to be space-separated argv
    # elements, NOT comma-joined. Mirror the --eval-speed-set pattern used
    # by train_cmd below.
    return Ablation(
        f"softprompt_p{p}",
        _SPEED_INT_SPEEDS,
        speed_integration="soft_prompt",
        extra_train_args=(
            f"--model.soft-prompt-p={p}",
            "--model.soft-prompt-speeds",
            *(f"{s:g}" for s in _SPEED_INT_SPEEDS),
        ),
        # All P-variants share one norm_stats: norm depends only on dataset.
        shared_norm_key="softprompt_shared",
    )


# Edit this table to define your ablation groups. Names show up in dataset
# directory names, asset_id suffixes, and exp_name suffixes, so keep them short
# and unique. Multiple ablations with identical ``speeds`` share one built
# dataset; ablations with the same ``shared_norm_key`` share norm_stats.
ABLATIONS: tuple[Ablation, ...] = (
    # Speed-set sweep: range/step-size effects, all use textual prompt.
    Ablation("g1_baseline", (1.0,), speed_integration="text"),
    Ablation("g2_coarse", (0.5, 1.0, 1.5, 2.0), speed_integration="text"),
    Ablation("g3a_step025", (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0), speed_integration="text"),
    Ablation("g4_narrow", (0.75, 1.0, 1.25, 1.5), speed_integration="text"),
    Ablation("g5_extreme", (0.25, 0.5, 1.0, 2.0, 4.0), speed_integration="text"),
    # Speed-integration sweep: same data, three conditioning strategies. The
    # soft_prompt arm reuses ``softprompt_p8`` from the P-sweep below.
    Ablation("speedint_text", _SPEED_INT_SPEEDS, speed_integration="text"),
    Ablation(
        "speedint_modulation",
        _SPEED_INT_SPEEDS,
        speed_integration="modulation",
        extra_train_args=("--model.speed-modulation=True",),
    ),
    # Soft-prompt P-length sweep. P=8 doubles as the soft_prompt arm of the
    # speed-integration sweep above. All five share one dataset and one
    # norm_stats file (shared_norm_key="softprompt_shared").
    *(_soft_prompt_ablation(p) for p in _SOFT_PROMPT_P_VALUES),
)


def _speed_token(speeds: tuple[float, ...]) -> str:
    """Encode a speed set as a path-safe token, e.g. ``(0.5, 1.0)`` -> ``0p5_1``."""
    return "_".join(f"{s:g}".replace(".", "p") for s in speeds)


def _dataset_dir(out_root: Path, ablation: Ablation, run_tag: str) -> Path:
    """Return the dataset directory for ``ablation``; keyed by speeds, not name."""
    return out_root / f"libero_speed_{_speed_token(ablation.speeds)}_{run_tag}"


def _asset_id(base_asset_id: str, ablation: Ablation) -> str:
    """Return the norm-stats asset id, honouring ``shared_norm_key`` when set."""
    suffix = ablation.shared_norm_key or ablation.name
    return f"{base_asset_id}_{suffix}"


def _exp_name(prefix: str, ablation: Ablation) -> str:
    """Return the experiment name passed to the training script."""
    return f"{prefix}_{ablation.name}"


def _parse_only(raw: str | None) -> set[str] | None:
    """Parse the ``--only`` value into a set of names.

    Returns ``None`` when no filter was given. Empty tokens (from a trailing
    comma or stray whitespace such as ``"a, ,b,"``) are dropped so they are not
    later reported as an unknown ablation named ``''``.
    """
    if not raw:
        return None
    return {token.strip() for token in raw.split(",") if token.strip()}


def _run(cmd: list[str], dry_run: bool) -> int:
    """Echo ``cmd`` shell-quoted, then execute it unless ``dry_run``.

    Returns the child's return code, or 0 for a dry run.
    """
    # Flush so the echoed command (and any banner printed before it) reaches
    # the log ahead of the child's own output when stdout is a pipe or file.
    print(f"$ {shlex.join(cmd)}", flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)


def build_cmd(
    ablation: Ablation,
    src: Path,
    out_root: Path,
    run_tag: str,
    num_workers: int,
    clean_eps: float,
) -> tuple[Path, list[str]]:
    """Return ``(dataset_dir, argv)`` for the dataset build stage.

    ``clean_eps`` is passed as both the translation and rotation cleaning
    threshold.
    """
    out_dir = _dataset_dir(out_root, ablation, run_tag)
    cmd = [
        "uv",
        "run",
        "python",
        "scripts/build_libero_speed_dataset_mp.py",
        "--src",
        str(src),
        "--dst",
        str(out_dir),
        "--speeds",
        *(f"{s:g}" for s in ablation.speeds),
        "--clean-transl-eps",
        str(clean_eps),
        "--clean-rot-eps",
        str(clean_eps),
        "--min-segment-len",
        "1",
        "--num-workers",
        str(num_workers),
        "--overwrite",
    ]
    return out_dir, cmd


def norm_stats_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
) -> list[str]:
    """Return the argv for the norm-stats stage on ``dataset_dir``."""
    return [
        "uv",
        "run",
        "python",
        "scripts/compute_norm_stats.py",
        train_config_name,
        "--repo-id",
        str(dataset_dir),
        "--asset-id",
        _asset_id(base_asset_id, ablation),
    ]


def train_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
    exp_prefix: str,
    num_workers: int,
    num_train_steps: int,
) -> list[str]:
    """Return the argv for the training stage.

    The ablation's ``extra_train_args`` are appended verbatim after the
    common arguments.
    """
    cmd = [
        "uv",
        "run",
        "python",
        "scripts/train_pytorch.py",
        train_config_name,
        f"--exp-name={_exp_name(exp_prefix, ablation)}",
        f"--data.repo-id={dataset_dir}",
        f"--data.assets.asset-id={_asset_id(base_asset_id, ablation)}",
        f"--data.speed-integration={ablation.speed_integration}",
        f"--num-workers={num_workers}",
        f"--num-train-steps={num_train_steps}",
        "--overwrite",
        "--eval-speed-set",
        *(f"{s:g}" for s in ablation.speeds),
    ]
    cmd.extend(ablation.extra_train_args)
    return cmd


def parse_args() -> argparse.Namespace:
    """Parse the command line; the module docstring is the ``--help`` text."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--src", required=True, help="Source LeRobot dataset root")
    p.add_argument("--out-root", required=True, help="Root directory under which built datasets land")
    p.add_argument("--train-config", default="pi05_libero_various_speed_all")
    p.add_argument("--base-asset-id", default="libero_various_speed_all_pi05")
    p.add_argument("--exp-prefix", default="ablation")
    p.add_argument("--run-tag", default="ablation")
    p.add_argument(
        "--only",
        default=None,
        help="Comma-separated list of ablation names. Default = all in ABLATIONS.",
    )
    p.add_argument("--skip-build", action="store_true")
    p.add_argument("--skip-norm-stats", action="store_true")
    p.add_argument("--skip-train", action="store_true")
    p.add_argument("--build-num-workers", type=int, default=16)
    p.add_argument("--train-num-workers", type=int, default=8)
    p.add_argument("--num-train-steps", type=int, default=30_000)
    # Default 0.0 disables near-zero action cleaning (LIBERO is clean enough).
    p.add_argument("--clean-eps", type=float, default=0.0)
    p.add_argument("--dry-run", action="store_true")
    return p.parse_args()


def main() -> None:
    """Run build -> norm-stats -> train for every selected ablation.

    Exits via ``sys.exit`` with a message when the selection is empty, when
    ``--only`` names an unknown ablation, or when a stage returns non-zero.
    Within one run, a dataset is built once per distinct speed set and norm
    stats are computed once per distinct asset id.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()

    only = _parse_only(args.only)
    # ``only is None`` means "no filter"; an empty set (e.g. ``--only ,``)
    # selects nothing and is reported below instead of running everything.
    selected = [a for a in ABLATIONS if only is None or a.name in only]
    if not selected:
        sys.exit(
            f"No ablations selected. Names in --only must be one of: {[a.name for a in ABLATIONS]}"
        )
    if only:
        unknown = only - {a.name for a in ABLATIONS}
        if unknown:
            sys.exit(f"Unknown ablation names: {sorted(unknown)}")

    print(f"selected = {[a.name for a in selected]}")
    print(f"src = {src}")
    print(f"out_root = {out_root}")
    print(f"train_cfg = {args.train_config}")
    print(f"asset_base = {args.base_asset_id}")
    print(f"exp_prefix = {args.exp_prefix}")
    print()

    built_speeds: set[str] = set()
    norm_done_asset_ids: set[str] = set()

    for ab in selected:
        print(f"\n========== ablation: {ab.name} speeds={ab.speeds} speed_integration={ab.speed_integration} ==========")
        ds_dir = _dataset_dir(out_root, ab, args.run_tag)
        speed_tok = _speed_token(ab.speeds)

        if not args.skip_build:
            if speed_tok in built_speeds:
                print(f"[skip build, dataset already built in this run] dataset_dir = {ds_dir}")
            else:
                _, cmd = build_cmd(ab, src, out_root, args.run_tag, args.build_num_workers, args.clean_eps)
                rc = _run(cmd, args.dry_run)
                if rc != 0:
                    sys.exit(f"build failed for {ab.name} (rc={rc})")
                built_speeds.add(speed_tok)
        else:
            print(f"[skip build] dataset_dir = {ds_dir}")

        if not args.skip_norm_stats:
            asset_id = _asset_id(args.base_asset_id, ab)
            if asset_id in norm_done_asset_ids:
                print(f"[skip norm-stats, already computed in this run for asset_id={asset_id}]")
            else:
                cmd = norm_stats_cmd(ab, ds_dir, args.train_config, args.base_asset_id)
                rc = _run(cmd, args.dry_run)
                if rc != 0:
                    sys.exit(f"norm-stats failed for {ab.name} (rc={rc})")
                norm_done_asset_ids.add(asset_id)

        if not args.skip_train:
            cmd = train_cmd(
                ab,
                ds_dir,
                args.train_config,
                args.base_asset_id,
                args.exp_prefix,
                args.train_num_workers,
                args.num_train_steps,
            )
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"train failed for {ab.name} (rc={rc})")


if __name__ == "__main__":
    main()