| |
| """Run the 0.75/1.0/1.25/1.5 speed-embedding ablation. |
| |
| This one-click runner fixes the data/speed setup and compares only how speed is |
| fed to the model: |
| |
| 1. text prompt |
| 2. scalar modulation |
| 3. soft prompt, P=8 |
| |
| It uses online sliding chunks and source 1.0x norm stats. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import dataclasses |
| import json |
| import os |
| import shlex |
| import subprocess |
| import sys |
| import time |
| from pathlib import Path |
|
|
|
|
# Speed multipliers used throughout this runner: they are formatted into the
# norm-stats and training command lines and are the default for --eval-speeds.
SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)

# Default value of --asset-id; also the directory component under which
# norm_stats.json is looked up.
DEFAULT_ASSET_ID = "online_sliding_speed_embed_0p75_1p0_1p25_1p5_pi05"
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Experiment:
    """Immutable description of one ablation arm.

    Instances are serialized into the run manifest with ``dataclasses.asdict``.
    """

    # Short key used for --only selection and for per-experiment log folders.
    name: str
    # Config name handed to the norm-stats, training, and serving scripts.
    train_config: str
    # Passed to training as --exp-name; also names the checkpoint and results folders.
    exp_name: str
    # Extra command-line tokens appended to the training command for this arm.
    train_args: tuple[str, ...]
|
|
|
|
# The ablation arms, in the order they are normalized, trained, and evaluated.
# Each arm differs only in its train config, experiment name, and extra train args.
EXPERIMENTS: tuple[Experiment, ...] = (
    Experiment(
        name="text",
        train_config="pi05_libero_speed_embed_text",
        exp_name="pi05_online_sliding_speed_embed_text_bs512_lr1e4",
        train_args=("--data.speed-integration", "text"),
    ),
    Experiment(
        name="modulation",
        train_config="pi05_libero_speed_embed_modulation",
        exp_name="pi05_online_sliding_speed_embed_modulation_bs512_lr1e4",
        train_args=("--data.speed-integration", "modulation", "--model.speed-modulation"),
    ),
    Experiment(
        name="soft_prompt",
        train_config="pi05_libero_speed_embed_softprompt_p8",
        exp_name="pi05_online_sliding_speed_embed_softprompt_p8_bs512_lr1e4",
        train_args=(
            "--data.speed-integration",
            "soft_prompt",
            "--model.soft-prompt-p",
            "8",
            "--model.soft-prompt-speeds",
            # One token per entry of SPEEDS, formatted with the "g" presentation type.
            *(format(speed, "g") for speed in SPEEDS),
        ),
    ),
)
|
|
|
|
def _speed_args() -> list[str]:
    """Return each entry of ``SPEEDS`` formatted with ``g``, ready to use as CLI tokens."""
    return [format(value, "g") for value in SPEEDS]
|
|
|
|
def _shell(cmd: list[str]) -> str:
    """Render an argv list as a single shell-quoted string (used for display)."""
    quoted_parts = map(shlex.quote, cmd)
    return " ".join(quoted_parts)
|
|
|
|
def _base_env(args: argparse.Namespace) -> dict[str, str]:
    """Build the environment handed to child processes.

    Starts from a copy of the current process environment, fills in
    ``WANDB__SERVICE_WAIT`` from ``args.wandb_service_wait`` only when it is
    not already set, and removes ``WANDB_API_KEY`` / ``WANDB_API_KEY_FILE``
    unless ``args.keep_wandb_env`` is true.
    """
    child_env = dict(os.environ)
    if "WANDB__SERVICE_WAIT" not in child_env:
        child_env["WANDB__SERVICE_WAIT"] = str(args.wandb_service_wait)
    if args.keep_wandb_env:
        return child_env
    for key in ("WANDB_API_KEY", "WANDB_API_KEY_FILE"):
        child_env.pop(key, None)
    return child_env
|
|
|
|
def _run(cmd: list[str], *, cwd: Path, env: dict[str, str], dry_run: bool) -> None:
    """Echo ``cmd`` and, unless ``dry_run`` is set, execute it.

    The command runs in ``cwd`` with environment ``env``; a non-zero exit
    raises ``subprocess.CalledProcessError`` because ``check=True`` is used.
    """
    print(_shell(cmd), flush=True)
    if not dry_run:
        subprocess.run(cmd, cwd=cwd, env=env, check=True)
|
|
|
|
def _run_with_env(prefix_env: dict[str, str], cmd: list[str], *, cwd: Path, env: dict[str, str], dry_run: bool) -> None:
    """Echo and run ``cmd`` with ``prefix_env`` layered on top of ``env``.

    The echoed line shows the overrides as ``KEY=value`` assignments in front
    of the command. Nothing is executed when ``dry_run`` is set. On key
    collisions ``prefix_env`` wins over ``env``.
    """
    assignments = [f"{key}={shlex.quote(value)}" for key, value in prefix_env.items()]
    display = " ".join(assignments)
    print(f"{display} {_shell(cmd)}", flush=True)
    if not dry_run:
        subprocess.run(cmd, cwd=cwd, env={**env, **prefix_env}, check=True)
|
|
|
|
def _norm_stats_path(project_root: Path, train_config: str, asset_id: str) -> Path:
    """Return ``<project_root>/assets/<train_config>/<asset_id>/norm_stats.json``."""
    return project_root.joinpath("assets", train_config, asset_id, "norm_stats.json")
|
|
|
|
def _checkpoint_dir(project_root: Path, train_config: str, exp_name: str) -> Path:
    """Return ``<project_root>/checkpoints/<train_config>/<exp_name>``."""
    return project_root.joinpath("checkpoints", train_config, exp_name)
|
|
|
|
def _latest_checkpoint_step_dir(
    project_root: Path,
    train_config: str,
    exp_name: str,
    ckpt_step: int | None,
    fallback_step: int,
) -> Path:
    """Pick the checkpoint step directory to evaluate.

    Resolution order:
      1. ``ckpt_step`` when given (returned without checking that it exists);
      2. the all-digit subdirectory of the experiment's checkpoint folder
         with the largest numeric name;
      3. ``fallback_step`` when the folder is missing or has no such subdirectory.
    """
    root = _checkpoint_dir(project_root, train_config, exp_name)
    if ckpt_step is not None:
        return root / str(ckpt_step)

    fallback = root / str(fallback_step)
    if not root.exists():
        return fallback

    # (step, path) pairs so the comparison is numeric rather than lexicographic.
    step_dirs = [
        (int(entry.name), entry)
        for entry in root.iterdir()
        if entry.is_dir() and entry.name.isdigit()
    ]
    if not step_dirs:
        return fallback
    return max(step_dirs)[1]
|
|
|
|
def _speed_tag(speed: float) -> str:
    """Return a path-friendly tag for ``speed``: ``g``-formatted, ``.`` replaced by ``p``, suffixed with ``x``."""
    text = format(speed, "g")
    return f"{text.replace('.', 'p')}x"
|
|
|
|
def _cuda_devices(args: argparse.Namespace) -> list[str]:
    """Return one CUDA device id string per GPU.

    Uses the comma-separated ``args.cuda_devices`` when it is set (blank
    entries and surrounding whitespace are dropped), otherwise
    ``"0" .. str(num_gpus - 1)``.

    Raises:
        SystemExit: if the number of devices differs from ``args.num_gpus``.
    """
    raw = args.cuda_devices
    if raw:
        stripped = (piece.strip() for piece in raw.split(","))
        devices = [token for token in stripped if token]
    else:
        devices = list(map(str, range(args.num_gpus)))

    if len(devices) == args.num_gpus:
        return devices
    raise SystemExit(f"expected {args.num_gpus} CUDA devices, got {devices}")
|
|
|
|
def _norm_cmd(args: argparse.Namespace, exp: Experiment) -> list[str]:
    """Build the argv that runs ``scripts/compute_norm_stats.py`` for ``exp``.

    The script is launched with the current interpreter (``sys.executable``)
    and receives the experiment's config name, the dataset root, the asset
    id, and the online-sliding flags with every configured speed.
    """
    cmd = [sys.executable, "scripts/compute_norm_stats.py"]
    cmd += ["--config-name", exp.train_config]
    cmd += ["--repo-id", str(args.data_root)]
    cmd += ["--asset-id", args.asset_id]
    cmd += ["--online-sliding-chunks", "--online-sliding-speeds", *_speed_args()]
    return cmd
|
|
|
|
def _train_cmd(args: argparse.Namespace, exp: Experiment, log_dir: Path) -> list[str]:
    """Build the ``uv run torchrun`` argv that trains ``exp``.

    The command is assembled from four parts, in this order: the torchrun
    launcher options, the shared training options, the experiment-specific
    ``exp.train_args``, and finally the optional flags (wandb, train mode,
    user-supplied extras).

    Args:
        args: Parsed CLI namespace.
        exp: Experiment to train.
        log_dir: Parent directory for torchrun logs; a subfolder named after
            the experiment is passed as ``--log-dir``.
    """
    speeds = _speed_args()

    launcher = [
        "uv", "run", "torchrun",
        "--standalone",
        "--nnodes=1",
        f"--nproc_per_node={args.num_gpus}",
        "--log-dir", str(log_dir / exp.name),
        "--redirects", "3",
        "--tee", "3",
    ]
    training = [
        "scripts/train_pytorch.py",
        exp.train_config,
        "--exp-name", exp.exp_name,
        "--pytorch-weight-path", str(args.pi05_base),
        "--batch-size", str(args.batch_size),
        "--num-workers", str(args.num_workers),
        "--num-train-steps", str(args.num_train_steps),
        "--log-interval", str(args.log_interval),
        "--save-interval", str(args.save_interval),
        # Peak and decay LR both receive args.lr.
        "--lr-schedule.peak-lr", str(args.lr),
        "--lr-schedule.decay-lr", str(args.lr),
        "--eval-speed-set", *speeds,
        "--data.repo-id", str(args.data_root),
        "--data.assets.asset-id", args.asset_id,
        "--data.online-sliding-chunks",
        "--data.online-sliding-speeds", *speeds,
        "--model.pytorch-compile-mode", args.compile_mode,
    ]
    cmd = [*launcher, *training, *exp.train_args]

    if args.no_wandb:
        cmd.append("--no-wandb-enabled")

    # Only these two modes map to a training flag; the remaining modes add nothing here.
    mode_flag = {"overwrite": "--overwrite", "resume": "--resume"}.get(args.train_mode)
    if mode_flag is not None:
        cmd.append(mode_flag)

    # Each --extra-train-arg value may itself hold several shell-style tokens.
    for extra in args.extra_train_arg:
        cmd += shlex.split(extra)
    return cmd
|
|
|
|
def _serve_cmd(args: argparse.Namespace, exp: Experiment, ckpt_dir: Path, port: int) -> list[str]:
    """Build the argv that serves the checkpoint in ``ckpt_dir`` on ``port``.

    ``args`` is accepted for signature parity with the other command
    builders but is not read here.
    """
    launcher = ["uv", "run", "python", "scripts/serve_policy.py", "policy:checkpoint"]
    policy = ["--policy.config", exp.train_config, "--policy.dir", str(ckpt_dir)]
    return launcher + policy + ["--port", str(port)]
|
|
|
|
def _terminate_servers(servers: list[subprocess.Popen]) -> None:
    """Stop every process in ``servers``, escalating from terminate to kill.

    Phase 1 sends ``terminate()`` to each process that is still running.
    Phase 2 waits for all of them against one shared 30-second deadline and
    kills any process that has not exited by then. Phase 3 gives each process
    up to 5 more seconds to be reaped; a timeout there is ignored.
    """
    for server in servers:
        if server.poll() is None:
            server.terminate()

    # One deadline for the whole group, so the total wait stays bounded.
    deadline = time.time() + 30
    for server in servers:
        budget = max(0.0, deadline - time.time())
        try:
            server.wait(timeout=budget)
        except subprocess.TimeoutExpired:
            server.kill()

    for server in servers:
        try:
            server.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Best effort: nothing more to do for a process that survives a kill.
            pass
|
|
|
|
def _start_servers(args: argparse.Namespace, exp: Experiment, ckpt_dir: Path, env: dict[str, str]) -> list[subprocess.Popen]:
    """Launch one policy server per CUDA device and wait for them to load.

    Server ``rank`` listens on ``args.base_port + rank``, sees only its own
    device through ``CUDA_VISIBLE_DEVICES``, and writes stdout/stderr to
    ``<log_dir>/servers/<exp.name>/gpu<rank>.log``. After launching, the
    function sleeps ``args.server_wait_seconds``.

    In dry-run mode the commands are only printed: nothing is started, no
    directory is created, and an empty list is returned.

    Args:
        args: Parsed CLI namespace.
        exp: Experiment whose checkpoint is served.
        ckpt_dir: Checkpoint step directory passed to the server.
        env: Base environment for the server processes.

    Returns:
        The started server processes, in device order.

    Raises:
        SystemExit: from ``_cuda_devices`` when the device count is wrong.
        Any exception raised while launching or waiting is re-raised after
        the servers started so far have been terminated.
    """
    devices = _cuda_devices(args)
    server_log_dir = args.log_dir / "servers" / exp.name
    if not args.dry_run:
        server_log_dir.mkdir(parents=True, exist_ok=True)

    servers: list[subprocess.Popen] = []
    try:
        for rank, device in enumerate(devices):
            port = args.base_port + rank
            cmd = _serve_cmd(args, exp, ckpt_dir, port)
            log_path = server_log_dir / f"gpu{rank}.log"
            print(f"CUDA_VISIBLE_DEVICES={device} {_shell(cmd)} > {log_path} 2>&1 &")
            if args.dry_run:
                continue
            server_env = {**env, "CUDA_VISIBLE_DEVICES": device}
            # The child gets its own copy of the descriptor, so the parent's
            # handle can be closed as soon as Popen returns.
            with log_path.open("w") as log_file:
                servers.append(
                    subprocess.Popen(
                        cmd,
                        cwd=args.project_root,
                        env=server_env,
                        stdout=log_file,
                        stderr=subprocess.STDOUT,
                    )
                )
        if not args.dry_run:
            print(f"Waiting {args.server_wait_seconds}s for policy servers to load...", flush=True)
            time.sleep(args.server_wait_seconds)
    except BaseException:
        # The caller only owns the servers once this function returns, so a
        # failed launch or an interrupted wait must clean up here.
        _terminate_servers(servers)
        raise
    return servers
|
|
|
|
def _eval_experiment(args: argparse.Namespace, exp: Experiment, env: dict[str, str]) -> None:
    """Evaluate ``exp`` at every speed in ``args.eval_speeds``.

    Resolves the checkpoint step directory (falling back to step
    ``num_train_steps - 1``), starts the policy servers, then runs
    ``./scripts/eval_libero_8gpu.sh`` once per speed with the settings passed
    through environment variables. The servers are stopped afterwards even
    if an evaluation run fails.

    Raises:
        SystemExit: if the checkpoint directory is missing (outside dry-run).
    """
    ckpt_dir = _latest_checkpoint_step_dir(
        args.project_root,
        exp.train_config,
        exp.exp_name,
        args.ckpt_step,
        args.num_train_steps - 1,
    )
    if not (args.dry_run or ckpt_dir.exists()):
        raise SystemExit(f"checkpoint for eval does not exist: {ckpt_dir}")

    print(f"\n========== eval: {exp.name} ckpt={ckpt_dir} ==========")
    servers = _start_servers(args, exp, ckpt_dir, env)
    try:
        for speed in args.eval_speeds:
            results_dir = args.results_dir / exp.exp_name / f"speed_{_speed_tag(speed)}"
            overrides = {
                "SPEED": f"{speed:g}",
                "BASE_PORT": str(args.base_port),
                "HOST": args.host,
                "NUM_TRIALS": str(args.num_trials),
                "SAVE_VIDEOS": "1" if args.save_videos else "0",
                "PYTHON_CMD": "uv run python",
                "RESULTS_DIR": str(results_dir),
            }
            _run_with_env(
                overrides,
                ["./scripts/eval_libero_8gpu.sh"],
                cwd=args.project_root,
                env=env,
                dry_run=args.dry_run,
            )
    finally:
        # The list is empty in dry-run mode, so there is nothing to stop then.
        if servers:
            print(f"Stopping policy servers for {exp.name}...", flush=True)
            _terminate_servers(servers)
|
|
|
|
def _write_manifest(project_root: Path, log_dir: Path, args: argparse.Namespace, experiments: tuple[Experiment, ...]) -> None:
    """Write ``speed_embedding_ablation_manifest.json`` into ``log_dir``.

    The manifest records the speed set, the main hyperparameters, the input
    paths, and the selected experiments. Nothing is written in dry-run mode.
    ``project_root`` is accepted but not used.
    """
    if args.dry_run:
        return

    log_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = log_dir / "speed_embedding_ablation_manifest.json"
    payload = {
        "speeds": SPEEDS,
        "asset_id": args.asset_id,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "eval_speeds": args.eval_speeds,
        "data_root": str(args.data_root),
        "pi05_base": str(args.pi05_base),
        "experiments": list(map(dataclasses.asdict, experiments)),
    }
    serialized = json.dumps(payload, indent=2)
    manifest_path.write_text(serialized + "\n")
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the runner's command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior existing callers
            rely on; passing a list lets the parser be driven programmatically.

    Returns:
        The parsed namespace. ``--data-root`` and ``--pi05-base`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Inputs, experiment selection, and run mode.
    parser.add_argument("--project-root", type=Path, default=Path.cwd(), help="VLAwithVariousSpeed repo root.")
    parser.add_argument("--data-root", type=Path, required=True, help="Source LeRobot/LIBERO dataset root.")
    parser.add_argument("--pi05-base", type=Path, required=True, help="Path to PI05 base weights directory.")
    parser.add_argument("--asset-id", default=DEFAULT_ASSET_ID)
    parser.add_argument("--only", default=None, help="Comma-separated subset: text,modulation,soft_prompt.")
    parser.add_argument("--stage", choices=("all", "norm", "train", "eval"), default="all")
    parser.add_argument("--force-norm", action="store_true")
    parser.add_argument(
        "--train-mode",
        choices=("overwrite", "resume", "skip-existing", "fail-if-exists"),
        default="overwrite",
    )
    parser.add_argument("--dry-run", action="store_true")

    # Training resources and hyperparameters.
    parser.add_argument("--num-gpus", type=int, default=8)
    parser.add_argument("--cuda-devices", default=None, help="CUDA_VISIBLE_DEVICES value. Default: 0..num_gpus-1.")
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--num-train-steps", type=int, default=30_000)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--save-interval", type=int, default=1000)
    # Note: the default is the string "None", forwarded verbatim to the training command.
    parser.add_argument("--compile-mode", default="None")
    parser.add_argument("--log-dir", type=Path, default=Path("logs/speed_embedding_ablation"))

    # Evaluation.
    parser.add_argument("--eval-speeds", type=float, nargs="+", default=list(SPEEDS))
    parser.add_argument("--results-dir", type=Path, default=Path("results/speed_embedding_ablation"))
    parser.add_argument("--ckpt-step", type=int, default=None, help="Checkpoint step to evaluate. Default: latest step.")
    parser.add_argument("--base-port", type=int, default=8000)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--num-trials", type=int, default=50)
    parser.add_argument("--save-videos", action="store_true")
    parser.add_argument("--server-wait-seconds", type=int, default=120)

    # Weights & Biases handling and pass-through training arguments.
    parser.add_argument("--no-wandb", action="store_true")
    parser.add_argument("--keep-wandb-env", action="store_true")
    parser.add_argument("--wandb-service-wait", type=int, default=300)
    parser.add_argument(
        "--extra-train-arg",
        action="append",
        default=[],
        help="Extra argument appended to train_pytorch.py. Repeat for multiple args.",
    )
    return parser.parse_args(argv)
|
|
|
|
def _select_experiments(args: argparse.Namespace) -> tuple[Experiment, ...]:
    """Return the experiments chosen by ``--only``, in ``EXPERIMENTS`` order.

    With ``args.only`` unset, every experiment is returned. Otherwise the
    value is split on commas, blank entries are ignored, and the names are
    matched against the known experiments.

    Raises:
        SystemExit: if any requested name is not a known experiment.
    """
    if args.only is None:
        return EXPERIMENTS

    requested = set(filter(None, (part.strip() for part in args.only.split(","))))
    known = {candidate.name for candidate in EXPERIMENTS}
    unknown = requested.difference(known)
    if unknown:
        raise SystemExit(f"unknown experiments: {sorted(unknown)}; known={sorted(known)}")

    selected = [candidate for candidate in EXPERIMENTS if candidate.name in requested]
    return tuple(selected)
|
|
|
|
def main() -> None:
    """Run the selected stages (norm stats, training, evaluation) for each experiment.

    Steps:
      1. Parse arguments and resolve all paths (relative ``--log-dir`` and
         ``--results-dir`` are taken relative to the project root).
      2. Validate the project root, input paths, and GPU/batch settings.
      3. Write the run manifest and print a summary.
      4. Execute the stages selected by ``--stage`` in the order
         norm -> train -> eval.

    Raises:
        SystemExit: on invalid inputs, missing norm stats before training,
            or an existing checkpoint under ``--train-mode fail-if-exists``.
    """
    args = parse_args()
    project_root = args.project_root.resolve()
    args.project_root = project_root
    args.data_root = args.data_root.resolve()
    args.pi05_base = args.pi05_base.resolve()
    # Joining an absolute path onto project_root yields that absolute path
    # unchanged, so a single expression covers both relative and absolute input.
    args.log_dir = (project_root / args.log_dir).resolve()
    args.results_dir = (project_root / args.results_dir).resolve()

    if not (project_root / "scripts" / "train_pytorch.py").exists():
        raise SystemExit(f"project root does not look valid: {project_root}")
    if not args.data_root.exists():
        raise SystemExit(f"data root does not exist: {args.data_root}")
    if not args.pi05_base.exists():
        raise SystemExit(f"pi05 base path does not exist: {args.pi05_base}")
    # Checked before the modulo below: zero would raise ZeroDivisionError and
    # a negative count would pass it and produce an empty device list.
    if args.num_gpus < 1:
        raise SystemExit(f"--num-gpus must be at least 1, got {args.num_gpus}.")
    if args.batch_size % args.num_gpus != 0:
        raise SystemExit(f"--batch-size ({args.batch_size}) must be divisible by --num-gpus ({args.num_gpus}).")

    experiments = _select_experiments(args)
    # A dry run only prints commands, so it should not create directories.
    if not args.dry_run:
        args.log_dir.mkdir(parents=True, exist_ok=True)
    _write_manifest(project_root, args.log_dir, args, experiments)

    env = _base_env(args)
    env["CUDA_VISIBLE_DEVICES"] = args.cuda_devices or ",".join(str(i) for i in range(args.num_gpus))

    print("Speed embedding ablation runner")
    print(f" project_root = {project_root}")
    print(f" data_root = {args.data_root}")
    print(f" pi05_base = {args.pi05_base}")
    print(f" speeds = {SPEEDS}")
    print(f" asset_id = {args.asset_id}")
    print(f" batch_size = {args.batch_size}")
    print(f" lr = {args.lr}")
    print(f" stage = {args.stage}")
    print(f" train_mode = {args.train_mode}")
    print(f" eval_speeds = {args.eval_speeds}")
    print(f" results_dir = {args.results_dir}")
    print(f" experiments = {[exp.name for exp in experiments]}")
    print()

    if args.stage in ("all", "norm"):
        for exp in experiments:
            stats_path = _norm_stats_path(project_root, exp.train_config, args.asset_id)
            if stats_path.exists() and not args.force_norm:
                print(f"[skip norm] {exp.name}: {stats_path}")
            else:
                print(f"\n========== norm: {exp.name} source 1.0x stats for online sliding ==========")
                _run(_norm_cmd(args, exp), cwd=project_root, env=env, dry_run=args.dry_run)

    if args.stage in ("all", "train"):
        for exp in experiments:
            stats_path = _norm_stats_path(project_root, exp.train_config, args.asset_id)
            # In a dry run the norm stage did not execute, so the stats may be absent.
            if not args.dry_run and not stats_path.exists():
                raise SystemExit(f"missing norm stats for {exp.name}: {stats_path}")
            ckpt_dir = _checkpoint_dir(project_root, exp.train_config, exp.exp_name)
            if args.train_mode == "skip-existing" and ckpt_dir.exists():
                print(f"[skip train] {exp.name}: {ckpt_dir}")
                continue
            if args.train_mode == "fail-if-exists" and ckpt_dir.exists():
                raise SystemExit(f"checkpoint exists for {exp.name}: {ckpt_dir}")

            print(f"\n========== train: {exp.name} ==========")
            _run(_train_cmd(args, exp, args.log_dir / "torchrun"), cwd=project_root, env=env, dry_run=args.dry_run)

    if args.stage in ("all", "eval"):
        for exp in experiments:
            _eval_experiment(args, exp, env)
|
|
|
|
# Run the ablation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|