VLAwithVariousSpeed / scripts /run_ablations.py
Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
10.9 kB
#!/usr/bin/env python3
"""Run a sweep of variable-speed ablation experiments end-to-end.
For each ablation: build dataset (multi-process) -> compute norm stats -> train.
All three stages share one TrainConfig name; per-ablation differences are
applied via CLI overrides on ``--data.repo-id``, ``--data.assets.asset-id`` and
``--data.speed-integration`` for training, and via ``--repo-id`` / ``--asset-id``
for compute_norm_stats.
Edit the ABLATIONS table below to define your groups. Use:
- ``--only NAME[,NAME]`` to run a subset
- ``--skip-build / --skip-norm-stats / --skip-train`` to bypass stages
- ``--dry-run`` to print commands without executing them
Example:
uv run python scripts/run_ablations.py \\
--src $SRC --out-root $ROOT \\
--train-config pi05_libero_various_speed_all \\
--base-asset-id libero_various_speed_all_pi05 \\
--exp-prefix pi05_ablation \\
--num-train-steps 30000 \\
--build-num-workers 16 --train-num-workers 8 \\
--only g2_coarse,g3a_step025 --dry-run
"""
from __future__ import annotations
import argparse
import dataclasses
import shlex
import subprocess
import sys
from pathlib import Path
@dataclasses.dataclass(frozen=True)
class Ablation:
    """One row of the ablation sweep: a speed set plus how speed reaches the model.

    Instances are immutable. Construction fails fast on values that would
    otherwise only surface after the build and norm-stats stages have run.

    Raises:
        ValueError: if ``name`` is empty, ``speeds`` is empty, or
            ``speed_integration`` is not one of the documented modes.
    """

    # Identifier used for the exp_name suffix and, unless ``shared_norm_key``
    # is set, for the asset_id suffix.
    name: str
    # Speed values the dataset is built with; also passed to training as the
    # eval speed set.
    speeds: tuple[float, ...]
    # How speed is fed to the model: "text" | "modulation" | "soft_prompt" | "auto".
    # See LeRobotVariousSpeedLiberoDataConfig.speed_integration for semantics.
    speed_integration: str = "auto"
    # Extra args appended verbatim to the train_pytorch.py invocation. Use this
    # for per-ablation model overrides (e.g., speed_modulation for modulation).
    extra_train_args: tuple[str, ...] = ()
    # When set, multiple ablations sharing the same key share one norm_stats
    # file (and thus one ``asset_id``). Norm stats only depend on the dataset's
    # state/actions, so a P-length sweep over the same speed set should reuse
    # one set of stats. Default ``None`` -> isolate per ablation (legacy).
    shared_norm_key: str | None = None

    def __post_init__(self) -> None:
        """Reject rows that would produce malformed downstream commands."""
        if not self.name:
            raise ValueError("Ablation.name must be a non-empty string")
        # An empty speed set yields an empty directory token and an empty
        # ``--speeds`` argument list for the build script.
        if not self.speeds:
            raise ValueError(f"Ablation {self.name!r}: speeds must not be empty")
        # Keep in sync with the modes listed on ``speed_integration`` above.
        allowed = ("text", "modulation", "soft_prompt", "auto")
        if self.speed_integration not in allowed:
            raise ValueError(
                f"Ablation {self.name!r}: speed_integration must be one of "
                f"{allowed}, got {self.speed_integration!r}"
            )
# Speeds shared by the speed-integration and soft-prompt-P sweeps (all train
# on the same dataset; only the model conditioning differs).
_SPEED_INT_SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
# Values of P swept by the soft-prompt ablations: each value becomes one
# ``softprompt_p{P}`` entry and is passed to training as ``--model.soft-prompt-p``.
_SOFT_PROMPT_P_VALUES: tuple[int, ...] = (1, 4, 8, 16, 32)
def _soft_prompt_ablation(p: int) -> Ablation:
    """Build the soft-prompt ablation row for prompt length ``p``.

    The row trains on ``_SPEED_INT_SPEEDS`` and is named ``softprompt_p{p}``.
    """
    # tyro wants a tuple[float, ...] option as separate argv elements after
    # the flag (never one comma-joined string), the same way train_cmd
    # passes --eval-speed-set.
    speed_argv = tuple(format(speed, "g") for speed in _SPEED_INT_SPEEDS)
    model_overrides = (
        f"--model.soft-prompt-p={p}",
        "--model.soft-prompt-speeds",
    ) + speed_argv
    return Ablation(
        name=f"softprompt_p{p}",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="soft_prompt",
        extra_train_args=model_overrides,
        # Norm stats depend only on the dataset, so every P variant reuses
        # the same asset.
        shared_norm_key="softprompt_shared",
    )
# Edit this table to define your ablation groups. Names show up in asset_id
# suffixes (unless ``shared_norm_key`` is set) and in exp_name suffixes, so
# keep them short and unique. Dataset directories are keyed by the speed set,
# so ablations with identical ``speeds`` share one built dataset; ablations
# with the same ``shared_norm_key`` share norm_stats.
ABLATIONS: tuple[Ablation, ...] = (
    # Speed-set sweep: range/step-size effects, all use textual prompt.
    Ablation(name="g1_baseline", speeds=(1.0,), speed_integration="text"),
    Ablation(name="g2_coarse", speeds=(0.5, 1.0, 1.5, 2.0), speed_integration="text"),
    Ablation(
        name="g3a_step025",
        speeds=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0),
        speed_integration="text",
    ),
    Ablation(name="g4_narrow", speeds=(0.75, 1.0, 1.25, 1.5), speed_integration="text"),
    Ablation(name="g5_extreme", speeds=(0.25, 0.5, 1.0, 2.0, 4.0), speed_integration="text"),
    # Speed-integration sweep: same data, three conditioning strategies. The
    # soft_prompt arm reuses ``softprompt_p8`` from the P-sweep below.
    Ablation(name="speedint_text", speeds=_SPEED_INT_SPEEDS, speed_integration="text"),
    Ablation(
        name="speedint_modulation",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="modulation",
        extra_train_args=("--model.speed-modulation=True",),
    ),
) + tuple(
    # Soft-prompt P-length sweep. P=8 doubles as the soft_prompt arm of the
    # speed-integration sweep above. All five share one dataset and one
    # norm_stats file (shared_norm_key="softprompt_shared").
    map(_soft_prompt_ablation, _SOFT_PROMPT_P_VALUES)
)
def _speed_token(speeds: tuple[float, ...]) -> str:
    """Return a compact token for a speed set, e.g. ``(0.5, 1.0)`` -> ``"0p5_1"``.

    Each speed is rendered with the ``g`` format, its decimal point is
    replaced by ``p``, and the pieces are joined with underscores.
    """
    pieces = []
    for speed in speeds:
        rendered = format(speed, "g")
        pieces.append(rendered.replace(".", "p"))
    return "_".join(pieces)
def _dataset_dir(out_root: Path, ablation: Ablation, run_tag: str) -> Path:
    """Return the dataset directory for ``ablation`` under ``out_root``.

    The leaf name depends only on the ablation's speed set and ``run_tag``,
    so ablations with identical speeds resolve to the same directory.
    """
    token = _speed_token(ablation.speeds)
    leaf = "libero_speed_{}_{}".format(token, run_tag)
    return out_root.joinpath(leaf)
def _asset_id(base_asset_id: str, ablation: Ablation) -> str:
    """Return ``{base_asset_id}_{suffix}`` for the ablation's norm-stats asset.

    The suffix is ``shared_norm_key`` when that is set (truthy), otherwise
    the ablation's own name.
    """
    if ablation.shared_norm_key:
        suffix = ablation.shared_norm_key
    else:
        suffix = ablation.name
    return "{}_{}".format(base_asset_id, suffix)
def _exp_name(prefix: str, ablation: Ablation) -> str:
    """Return the experiment name: ``prefix`` and the ablation name joined by ``_``."""
    return "{}_{}".format(prefix, ablation.name)
def _run(cmd: list[str], dry_run: bool) -> int:
    """Echo ``cmd`` in shell-quoted form, then execute it unless ``dry_run``.

    Args:
        cmd: Argument vector to run (no shell is involved).
        dry_run: When true, only print the command.

    Returns:
        ``0`` in dry-run mode, otherwise the return code of the child process.
    """
    # flush=True: when stdout is a pipe or file it is block-buffered, and the
    # child writes to the same descriptor directly. Without a flush the echo
    # (and any earlier banner lines) would land after the child's output.
    print(f"$ {' '.join(shlex.quote(c) for c in cmd)}", flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)
def build_cmd(
    ablation: Ablation,
    src: Path,
    out_root: Path,
    run_tag: str,
    num_workers: int,
    clean_eps: float,
) -> tuple[Path, list[str]]:
    """Assemble the dataset-build command for one ablation.

    Returns:
        ``(out_dir, cmd)``: the destination dataset directory and the argv
        list that invokes ``scripts/build_libero_speed_dataset_mp.py``.
    """
    out_dir = _dataset_dir(out_root, ablation, run_tag)
    # The same epsilon is used for the translation and rotation cleaning flags.
    eps = str(clean_eps)

    cmd = ["uv", "run", "python", "scripts/build_libero_speed_dataset_mp.py"]
    cmd += ["--src", str(src)]
    cmd += ["--dst", str(out_dir)]
    # Each speed is its own argv element following the flag.
    cmd.append("--speeds")
    cmd.extend(format(speed, "g") for speed in ablation.speeds)
    cmd += ["--clean-transl-eps", eps]
    cmd += ["--clean-rot-eps", eps]
    cmd += ["--min-segment-len", "1"]
    cmd += ["--num-workers", str(num_workers)]
    cmd.append("--overwrite")
    return out_dir, cmd
def norm_stats_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
) -> list[str]:
    """Assemble the argv list that runs ``scripts/compute_norm_stats.py``.

    The dataset directory is passed as ``--repo-id`` and the ablation's
    asset id as ``--asset-id``.
    """
    asset_id = _asset_id(base_asset_id, ablation)
    launcher = ["uv", "run", "python", "scripts/compute_norm_stats.py"]
    return [
        *launcher,
        train_config_name,
        "--repo-id",
        str(dataset_dir),
        "--asset-id",
        asset_id,
    ]
def train_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
    exp_prefix: str,
    num_workers: int,
    num_train_steps: int,
) -> list[str]:
    """Assemble the argv list that runs ``scripts/train_pytorch.py``.

    Per-ablation settings are passed as ``--flag=value`` overrides, followed
    by ``--overwrite``, the eval speed set, and finally the ablation's
    ``extra_train_args`` verbatim.
    """
    overrides = (
        ("--exp-name", _exp_name(exp_prefix, ablation)),
        ("--data.repo-id", dataset_dir),
        ("--data.assets.asset-id", _asset_id(base_asset_id, ablation)),
        ("--data.speed-integration", ablation.speed_integration),
        ("--num-workers", num_workers),
        ("--num-train-steps", num_train_steps),
    )

    cmd = ["uv", "run", "python", "scripts/train_pytorch.py", train_config_name]
    cmd.extend(f"{flag}={value}" for flag, value in overrides)
    cmd.append("--overwrite")
    # Speeds follow the flag as separate argv elements.
    cmd.append("--eval-speed-set")
    cmd.extend(format(speed, "g") for speed in ablation.speeds)
    cmd.extend(ablation.extra_train_args)
    return cmd
def parse_args() -> argparse.Namespace:
    """Parse the command line of the ablation runner.

    Returns:
        The parsed namespace. ``--src`` and ``--out-root`` are required; every
        other option has a default.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains a bullet list and a multi-line example
        # command; the default formatter would reflow them into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--src", required=True, help="Source LeRobot dataset root")
    p.add_argument("--out-root", required=True, help="Root directory under which built datasets land")
    p.add_argument(
        "--train-config",
        default="pi05_libero_various_speed_all",
        help="TrainConfig name passed to compute_norm_stats.py and train_pytorch.py",
    )
    p.add_argument(
        "--base-asset-id",
        default="libero_various_speed_all_pi05",
        help="Prefix of every asset id; the ablation suffix is appended after an underscore",
    )
    p.add_argument(
        "--exp-prefix",
        default="ablation",
        help="Prefix of every experiment name; the ablation name is appended after an underscore",
    )
    p.add_argument(
        "--run-tag",
        default="ablation",
        help="Trailing component of built dataset directory names",
    )
    p.add_argument(
        "--only",
        default=None,
        help="Comma-separated list of ablation names. Default = all in ABLATIONS.",
    )
    p.add_argument("--skip-build", action="store_true", help="Do not run the dataset build stage")
    p.add_argument("--skip-norm-stats", action="store_true", help="Do not run the norm-stats stage")
    p.add_argument("--skip-train", action="store_true", help="Do not run the training stage")
    p.add_argument("--build-num-workers", type=int, default=16, help="Worker count for the build script")
    p.add_argument("--train-num-workers", type=int, default=8, help="Worker count for the training script")
    p.add_argument("--num-train-steps", type=int, default=30_000, help="Number of training steps")
    # Default 0.0 disables near-zero action cleaning (LIBERO is clean enough).
    p.add_argument(
        "--clean-eps",
        type=float,
        default=0.0,
        help="Epsilon passed to the build script for both translation and rotation cleaning",
    )
    p.add_argument("--dry-run", action="store_true", help="Print commands without executing them")
    return p.parse_args()
def _select_ablations(only_arg: str | None) -> list[Ablation]:
    """Return the ablations named in ``--only`` (table order), or all of them."""
    if not only_arg:
        return list(ABLATIONS)
    # Drop empty tokens so a stray comma ("a,b,") is not reported as the
    # unknown name ''.
    requested = {tok.strip() for tok in only_arg.split(",")} - {""}
    chosen = [a for a in ABLATIONS if a.name in requested]
    if not chosen:
        sys.exit(
            f"No ablations selected. Names in --only must be one of: {[a.name for a in ABLATIONS]}"
        )
    unknown = requested - {a.name for a in ABLATIONS}
    if unknown:
        sys.exit(f"Unknown ablation names: {sorted(unknown)}")
    return chosen


def _check_shared_norm_stats(
    selected: list[Ablation], out_root: Path, run_tag: str, base_asset_id: str
) -> None:
    """Exit if two selected ablations map one asset_id onto different datasets."""
    first_user: dict[str, tuple[Path, str]] = {}
    for ab in selected:
        asset_id = _asset_id(base_asset_id, ab)
        ds_dir = _dataset_dir(out_root, ab, run_tag)
        first_dir, first_name = first_user.setdefault(asset_id, (ds_dir, ab.name))
        if first_dir != ds_dir:
            sys.exit(
                f"Ablations {first_name!r} and {ab.name!r} share asset_id={asset_id} "
                f"but use different datasets ({first_dir} vs {ds_dir}); "
                "norm stats cannot be shared across datasets."
            )


def _run_stage(stage: str, ablation_name: str, cmd: list[str], dry_run: bool) -> None:
    """Run one pipeline stage and abort the whole sweep on a non-zero return code."""
    rc = _run(cmd, dry_run)
    if rc != 0:
        sys.exit(f"{stage} failed for {ablation_name} (rc={rc})")


def main() -> None:
    """Run build -> norm-stats -> train for every selected ablation, in order.

    Within one invocation a dataset is built once per distinct speed set and
    norm stats are computed once per distinct asset_id. The process exits
    with an error message as soon as any stage returns non-zero.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()
    selected = _select_ablations(args.only)
    # The per-asset_id dedupe below reuses the stats of the first ablation
    # with a given asset_id, so fail up front if that would cross datasets.
    _check_shared_norm_stats(selected, out_root, args.run_tag, args.base_asset_id)

    print(f"selected = {[a.name for a in selected]}")
    print(f"src = {src}")
    print(f"out_root = {out_root}")
    print(f"train_cfg = {args.train_config}")
    print(f"asset_base = {args.base_asset_id}")
    print(f"exp_prefix = {args.exp_prefix}")
    print()

    built_speeds: set[str] = set()
    norm_done_asset_ids: set[str] = set()
    for ab in selected:
        print(f"\n========== ablation: {ab.name} speeds={ab.speeds} speed_integration={ab.speed_integration} ==========")
        ds_dir = _dataset_dir(out_root, ab, args.run_tag)
        speed_tok = _speed_token(ab.speeds)

        if not args.skip_build:
            if speed_tok in built_speeds:
                print(f"[skip build, dataset already built in this run] dataset_dir = {ds_dir}")
            else:
                _, cmd = build_cmd(ab, src, out_root, args.run_tag, args.build_num_workers, args.clean_eps)
                _run_stage("build", ab.name, cmd, args.dry_run)
                built_speeds.add(speed_tok)
        else:
            print(f"[skip build] dataset_dir = {ds_dir}")

        if not args.skip_norm_stats:
            asset_id = _asset_id(args.base_asset_id, ab)
            if asset_id in norm_done_asset_ids:
                print(f"[skip norm-stats, already computed in this run for asset_id={asset_id}]")
            else:
                cmd = norm_stats_cmd(ab, ds_dir, args.train_config, args.base_asset_id)
                _run_stage("norm-stats", ab.name, cmd, args.dry_run)
                norm_done_asset_ids.add(asset_id)

        if not args.skip_train:
            cmd = train_cmd(
                ab,
                ds_dir,
                args.train_config,
                args.base_asset_id,
                args.exp_prefix,
                args.train_num_workers,
                args.num_train_steps,
            )
            _run_stage("train", ab.name, cmd, args.dry_run)


if __name__ == "__main__":
    main()