#!/usr/bin/env python3
"""Run the 0.75/1.0/1.25/1.5 speed-embedding ablation.

This one-click runner fixes the data/speed setup and compares only how speed is
fed to the model:

1. text prompt
2. scalar modulation
3. soft prompt, P=8

It uses online sliding chunks and source 1.0x norm stats.
"""

from __future__ import annotations

import argparse
import dataclasses
import json
import os
import shlex
import subprocess
import sys
import time
from pathlib import Path

SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
DEFAULT_ASSET_ID = "online_sliding_speed_embed_0p75_1p0_1p25_1p5_pi05"


def _speed_args() -> list[str]:
    """Return SPEEDS formatted as compact CLI tokens (e.g. ``1.0`` -> ``"1"``)."""
    return [f"{speed:g}" for speed in SPEEDS]


@dataclasses.dataclass(frozen=True)
class Experiment:
    """One arm of the ablation: a train config plus its extra training flags."""

    # Short identifier; matched against --only and used for log sub-directories.
    name: str
    # Config name passed to the norm/train/serve scripts; also a path component
    # under assets/ and checkpoints/.
    train_config: str
    # Passed to training as --exp-name; also a path component under
    # checkpoints/ and the results directory.
    exp_name: str
    # Extra CLI tokens appended to the training command for this arm.
    train_args: tuple[str, ...]


EXPERIMENTS: tuple[Experiment, ...] = (
    Experiment(
        "text",
        "pi05_libero_speed_embed_text",
        "pi05_online_sliding_speed_embed_text_bs512_lr1e4",
        ("--data.speed-integration", "text"),
    ),
    Experiment(
        "modulation",
        "pi05_libero_speed_embed_modulation",
        "pi05_online_sliding_speed_embed_modulation_bs512_lr1e4",
        ("--data.speed-integration", "modulation", "--model.speed-modulation"),
    ),
    Experiment(
        "soft_prompt",
        "pi05_libero_speed_embed_softprompt_p8",
        "pi05_online_sliding_speed_embed_softprompt_p8_bs512_lr1e4",
        (
            "--data.speed-integration",
            "soft_prompt",
            "--model.soft-prompt-p",
            "8",
            "--model.soft-prompt-speeds",
            *_speed_args(),
        ),
    ),
)


def _shell(cmd: list[str]) -> str:
    """Render ``cmd`` as a copy-pasteable, shell-quoted string (display only)."""
    return " ".join(shlex.quote(part) for part in cmd)


def _base_env(args: argparse.Namespace) -> dict[str, str]:
    """Build the child-process environment from the current one.

    ``WANDB__SERVICE_WAIT`` is only set when not already present. Unless
    ``--keep-wandb-env`` is given, the W&B API key variables are removed.
    """
    env = os.environ.copy()
    env.setdefault("WANDB__SERVICE_WAIT", str(args.wandb_service_wait))
    if not args.keep_wandb_env:
        env.pop("WANDB_API_KEY", None)
        env.pop("WANDB_API_KEY_FILE", None)
    return env


def _run(cmd: list[str], *, cwd: Path, env: dict[str, str], dry_run: bool) -> None:
    """Echo ``cmd`` and, unless ``dry_run``, execute it.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    print(_shell(cmd), flush=True)
    if dry_run:
        return
    subprocess.run(cmd, cwd=cwd, env=env, check=True)


def _run_with_env(
    prefix_env: dict[str, str],
    cmd: list[str],
    *,
    cwd: Path,
    env: dict[str, str],
    dry_run: bool,
) -> None:
    """Like :func:`_run`, with ``prefix_env`` layered on top of ``env``.

    The overrides are echoed as ``KEY=value`` prefixes so the printed line can
    be pasted into a shell.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    display = " ".join(f"{key}={shlex.quote(value)}" for key, value in prefix_env.items())
    print(f"{display} {_shell(cmd)}", flush=True)
    if dry_run:
        return
    run_env = {**env, **prefix_env}
    subprocess.run(cmd, cwd=cwd, env=run_env, check=True)


def _norm_stats_path(project_root: Path, train_config: str, asset_id: str) -> Path:
    """Return ``<root>/assets/<train_config>/<asset_id>/norm_stats.json``."""
    return project_root / "assets" / train_config / asset_id / "norm_stats.json"


def _checkpoint_dir(project_root: Path, train_config: str, exp_name: str) -> Path:
    """Return ``<root>/checkpoints/<train_config>/<exp_name>``."""
    return project_root / "checkpoints" / train_config / exp_name


def _latest_checkpoint_step_dir(
    project_root: Path,
    train_config: str,
    exp_name: str,
    ckpt_step: int | None,
    fallback_step: int,
) -> Path:
    """Pick the checkpoint step directory to evaluate.

    An explicit ``ckpt_step`` always wins. Otherwise the numerically largest
    all-digit sub-directory of the experiment's checkpoint directory is used;
    if that directory is missing or has no such sub-directory, the (possibly
    non-existent) ``fallback_step`` path is returned.
    """
    root = _checkpoint_dir(project_root, train_config, exp_name)
    if ckpt_step is not None:
        return root / str(ckpt_step)
    if not root.exists():
        return root / str(fallback_step)
    step_dirs = [path for path in root.iterdir() if path.is_dir() and path.name.isdigit()]
    if not step_dirs:
        return root / str(fallback_step)
    # Compare numerically ("100" > "9"); the path breaks ties such as "7"/"007".
    return max(step_dirs, key=lambda path: (int(path.name), path))


def _speed_tag(speed: float) -> str:
    """Return a filesystem-friendly tag for ``speed`` (``0.75`` -> ``"0p75x"``)."""
    return f"{speed:g}".replace(".", "p") + "x"


def _cuda_devices(args: argparse.Namespace) -> list[str]:
    """Return one CUDA device id per GPU.

    Uses the comma-separated ``--cuda-devices`` when given, else
    ``0..num_gpus-1``.

    Raises:
        SystemExit: if the number of devices differs from ``--num-gpus``.
    """
    if args.cuda_devices:
        devices = [item.strip() for item in args.cuda_devices.split(",") if item.strip()]
    else:
        devices = [str(i) for i in range(args.num_gpus)]
    if len(devices) != args.num_gpus:
        raise SystemExit(f"expected {args.num_gpus} CUDA devices, got {devices}")
    return devices


def _norm_cmd(args: argparse.Namespace, exp: Experiment) -> list[str]:
    """Build the ``compute_norm_stats.py`` command line for ``exp``."""
    return [
        sys.executable,
        "scripts/compute_norm_stats.py",
        "--config-name",
        exp.train_config,
        "--repo-id",
        str(args.data_root),
        "--asset-id",
        args.asset_id,
        "--online-sliding-chunks",
        "--online-sliding-speeds",
        *_speed_args(),
    ]


def _train_cmd(args: argparse.Namespace, exp: Experiment, log_dir: Path) -> list[str]:
    """Build the ``torchrun`` training command line for ``exp``.

    Peak and decay learning rate are both set to ``--lr``. Each
    ``--extra-train-arg`` value is shell-split and appended last.
    """
    cmd = [
        "uv",
        "run",
        "torchrun",
        "--standalone",
        "--nnodes=1",
        f"--nproc_per_node={args.num_gpus}",
        "--log-dir",
        str(log_dir / exp.name),
        "--redirects",
        "3",
        "--tee",
        "3",
        "scripts/train_pytorch.py",
        exp.train_config,
        "--exp-name",
        exp.exp_name,
        "--pytorch-weight-path",
        str(args.pi05_base),
        "--batch-size",
        str(args.batch_size),
        "--num-workers",
        str(args.num_workers),
        "--num-train-steps",
        str(args.num_train_steps),
        "--log-interval",
        str(args.log_interval),
        "--save-interval",
        str(args.save_interval),
        "--lr-schedule.peak-lr",
        str(args.lr),
        "--lr-schedule.decay-lr",
        str(args.lr),
        "--eval-speed-set",
        *_speed_args(),
        "--data.repo-id",
        str(args.data_root),
        "--data.assets.asset-id",
        args.asset_id,
        "--data.online-sliding-chunks",
        "--data.online-sliding-speeds",
        *_speed_args(),
        "--model.pytorch-compile-mode",
        args.compile_mode,
        *exp.train_args,
    ]
    if args.no_wandb:
        cmd.append("--no-wandb-enabled")
    if args.train_mode == "overwrite":
        cmd.append("--overwrite")
    elif args.train_mode == "resume":
        cmd.append("--resume")
    for extra in args.extra_train_arg:
        cmd.extend(shlex.split(extra))
    return cmd


def _serve_cmd(args: argparse.Namespace, exp: Experiment, ckpt_dir: Path, port: int) -> list[str]:
    """Build the ``serve_policy.py`` command line for one policy server."""
    return [
        "uv",
        "run",
        "python",
        "scripts/serve_policy.py",
        "policy:checkpoint",
        "--policy.config",
        exp.train_config,
        "--policy.dir",
        str(ckpt_dir),
        "--port",
        str(port),
    ]


def _terminate_servers(servers: list[subprocess.Popen]) -> None:
    """Stop all ``servers``: SIGTERM, a shared 30 s grace period, then kill."""
    for proc in servers:
        if proc.poll() is None:
            proc.terminate()

    # One shared deadline so total shutdown time does not scale with the
    # number of servers.
    deadline = time.time() + 30
    for proc in servers:
        remaining = max(0.0, deadline - time.time())
        try:
            proc.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            proc.kill()

    # Reap whatever was killed; deliberately best-effort.
    for proc in servers:
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            pass


def _start_servers(
    args: argparse.Namespace,
    exp: Experiment,
    ckpt_dir: Path,
    env: dict[str, str],
) -> list[subprocess.Popen]:
    """Launch one policy server per CUDA device and wait for them to load.

    Server ``rank`` listens on ``base_port + rank`` and logs to
    ``<log_dir>/servers/<exp.name>/gpu<rank>.log``. In dry-run mode the
    commands are only printed and an empty list is returned.

    If anything goes wrong before this function returns (a launch fails, the
    wait is interrupted, a server dies while loading), every server already
    started is terminated before the error propagates, because the caller has
    no handle to them yet.

    Raises:
        SystemExit: on a device-count mismatch, or if a server exits during
            the startup wait.
        OSError: if a server process cannot be spawned.
    """
    devices = _cuda_devices(args)
    server_log_dir = args.log_dir / "servers" / exp.name
    server_log_dir.mkdir(parents=True, exist_ok=True)

    servers: list[subprocess.Popen] = []
    try:
        for rank, device in enumerate(devices):
            port = args.base_port + rank
            cmd = _serve_cmd(args, exp, ckpt_dir, port)
            log_path = server_log_dir / f"gpu{rank}.log"
            print(f"CUDA_VISIBLE_DEVICES={device} {_shell(cmd)} > {log_path} 2>&1 &")
            if args.dry_run:
                continue
            server_env = {**env, "CUDA_VISIBLE_DEVICES": device}
            # The child inherits its own copy of the descriptor, so the
            # parent's handle can be closed as soon as Popen returns.
            with log_path.open("w") as log_file:
                servers.append(
                    subprocess.Popen(
                        cmd,
                        cwd=args.project_root,
                        env=server_env,
                        stdout=log_file,
                        stderr=subprocess.STDOUT,
                    )
                )

        if not args.dry_run:
            print(f"Waiting {args.server_wait_seconds}s for policy servers to load...", flush=True)
            time.sleep(args.server_wait_seconds)
            # A server that has already exited cannot serve the evaluation;
            # fail now instead of letting the eval clients fail to connect.
            dead = [rank for rank, proc in enumerate(servers) if proc.poll() is not None]
            if dead:
                raise SystemExit(
                    f"policy server(s) for {exp.name} exited during startup "
                    f"(ranks {dead}); see logs in {server_log_dir}"
                )
    except BaseException:
        # Includes KeyboardInterrupt during the wait: do not orphan servers.
        _terminate_servers(servers)
        raise
    return servers


def _eval_experiment(args: argparse.Namespace, exp: Experiment, env: dict[str, str]) -> None:
    """Serve the chosen checkpoint of ``exp`` and evaluate it at each speed.

    Raises:
        SystemExit: if the checkpoint directory does not exist (non dry-run).
    """
    ckpt_dir = _latest_checkpoint_step_dir(
        args.project_root,
        exp.train_config,
        exp.exp_name,
        args.ckpt_step,
        args.num_train_steps - 1,
    )
    if not args.dry_run and not ckpt_dir.exists():
        raise SystemExit(f"checkpoint for eval does not exist: {ckpt_dir}")

    print(f"\n========== eval: {exp.name} ckpt={ckpt_dir} ==========")
    servers = _start_servers(args, exp, ckpt_dir, env)
    try:
        for speed in args.eval_speeds:
            prefix_env = {
                "SPEED": f"{speed:g}",
                "BASE_PORT": str(args.base_port),
                "HOST": args.host,
                "NUM_TRIALS": str(args.num_trials),
                "SAVE_VIDEOS": "1" if args.save_videos else "0",
                "PYTHON_CMD": "uv run python",
                "RESULTS_DIR": str(args.results_dir / exp.exp_name / f"speed_{_speed_tag(speed)}"),
            }
            _run_with_env(
                prefix_env,
                ["./scripts/eval_libero_8gpu.sh"],
                cwd=args.project_root,
                env=env,
                dry_run=args.dry_run,
            )
    finally:
        if servers:
            print(f"Stopping policy servers for {exp.name}...", flush=True)
            _terminate_servers(servers)


def _write_manifest(
    project_root: Path,
    log_dir: Path,
    args: argparse.Namespace,
    experiments: tuple[Experiment, ...],
) -> None:
    """Write a JSON record of this run's settings into ``log_dir``.

    Nothing is written in dry-run mode. ``project_root`` is currently unused
    and kept for signature compatibility.
    """
    if args.dry_run:
        return
    log_dir.mkdir(parents=True, exist_ok=True)
    manifest = {
        "speeds": SPEEDS,
        "asset_id": args.asset_id,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "eval_speeds": args.eval_speeds,
        "data_root": str(args.data_root),
        "pi05_base": str(args.pi05_base),
        "experiments": [dataclasses.asdict(exp) for exp in experiments],
    }
    manifest_path = log_dir / "speed_embedding_ablation_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--project-root", type=Path, default=Path.cwd(), help="VLAwithVariousSpeed repo root.")
    parser.add_argument("--data-root", type=Path, required=True, help="Source LeRobot/LIBERO dataset root.")
    parser.add_argument("--pi05-base", type=Path, required=True, help="Path to PI05 base weights directory.")
    parser.add_argument("--asset-id", default=DEFAULT_ASSET_ID)
    parser.add_argument("--only", default=None, help="Comma-separated subset: text,modulation,soft_prompt.")
    parser.add_argument("--stage", choices=("all", "norm", "train", "eval"), default="all")
    parser.add_argument("--force-norm", action="store_true")
    parser.add_argument(
        "--train-mode",
        choices=("overwrite", "resume", "skip-existing", "fail-if-exists"),
        default="overwrite",
    )
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=8)
    parser.add_argument("--cuda-devices", default=None, help="CUDA_VISIBLE_DEVICES value. Default: 0..num_gpus-1.")
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--num-train-steps", type=int, default=30_000)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--save-interval", type=int, default=1000)
    parser.add_argument("--compile-mode", default="None")
    parser.add_argument("--log-dir", type=Path, default=Path("logs/speed_embedding_ablation"))
    parser.add_argument("--eval-speeds", type=float, nargs="+", default=list(SPEEDS))
    parser.add_argument("--results-dir", type=Path, default=Path("results/speed_embedding_ablation"))
    parser.add_argument("--ckpt-step", type=int, default=None, help="Checkpoint step to evaluate. Default: latest step.")
    parser.add_argument("--base-port", type=int, default=8000)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--num-trials", type=int, default=50)
    parser.add_argument("--save-videos", action="store_true")
    parser.add_argument("--server-wait-seconds", type=int, default=120)
    parser.add_argument("--no-wandb", action="store_true")
    parser.add_argument("--keep-wandb-env", action="store_true")
    parser.add_argument("--wandb-service-wait", type=int, default=300)
    parser.add_argument(
        "--extra-train-arg",
        action="append",
        default=[],
        help="Extra argument appended to train_pytorch.py. Repeat for multiple args.",
    )
    return parser.parse_args(argv)


def _select_experiments(args: argparse.Namespace) -> tuple[Experiment, ...]:
    """Return the experiments named by ``--only``, in EXPERIMENTS order.

    Without ``--only`` every experiment is returned.

    Raises:
        SystemExit: if ``--only`` names an unknown experiment, or contains no
            names at all (e.g. ``--only ","``), which would otherwise make the
            run silently do nothing.
    """
    if args.only is None:
        return EXPERIMENTS
    wanted = {name.strip() for name in args.only.split(",") if name.strip()}
    known = {exp.name for exp in EXPERIMENTS}
    unknown = wanted - known
    if unknown:
        raise SystemExit(f"unknown experiments: {sorted(unknown)}; known={sorted(known)}")
    if not wanted:
        raise SystemExit(f"--only selected no experiments; known={sorted(known)}")
    return tuple(exp for exp in EXPERIMENTS if exp.name in wanted)


def _resolve_paths(args: argparse.Namespace) -> None:
    """Make every path option absolute, in place.

    Relative ``--log-dir`` / ``--results-dir`` are anchored at the project
    root; absolute values are kept as given.
    """
    args.project_root = args.project_root.resolve()
    args.data_root = args.data_root.resolve()
    args.pi05_base = args.pi05_base.resolve()
    # Joining onto an absolute right-hand path yields that path unchanged, so
    # one expression covers both the relative and the absolute case.
    args.log_dir = (args.project_root / args.log_dir).resolve()
    args.results_dir = (args.project_root / args.results_dir).resolve()


def _validate_args(args: argparse.Namespace) -> None:
    """Exit with a clear message when the resolved options are unusable."""
    if not (args.project_root / "scripts" / "train_pytorch.py").exists():
        raise SystemExit(f"project root does not look valid: {args.project_root}")
    if not args.data_root.exists():
        raise SystemExit(f"data root does not exist: {args.data_root}")
    if not args.pi05_base.exists():
        raise SystemExit(f"pi05 base path does not exist: {args.pi05_base}")
    # Checked before the modulo below, which would otherwise raise
    # ZeroDivisionError for --num-gpus 0.
    if args.num_gpus < 1:
        raise SystemExit(f"--num-gpus must be a positive integer, got {args.num_gpus}.")
    if args.batch_size % args.num_gpus != 0:
        raise SystemExit(f"--batch-size ({args.batch_size}) must be divisible by --num-gpus ({args.num_gpus}).")


def _print_banner(args: argparse.Namespace, experiments: tuple[Experiment, ...]) -> None:
    """Print a summary of the effective settings."""
    print("Speed embedding ablation runner")
    print(f" project_root = {args.project_root}")
    print(f" data_root = {args.data_root}")
    print(f" pi05_base = {args.pi05_base}")
    print(f" speeds = {SPEEDS}")
    print(f" asset_id = {args.asset_id}")
    print(f" batch_size = {args.batch_size}")
    print(f" lr = {args.lr}")
    print(f" stage = {args.stage}")
    print(f" train_mode = {args.train_mode}")
    print(f" eval_speeds = {args.eval_speeds}")
    print(f" results_dir = {args.results_dir}")
    print(f" experiments = {[exp.name for exp in experiments]}")
    print()


def _run_norm_stage(args: argparse.Namespace, experiments: tuple[Experiment, ...], env: dict[str, str]) -> None:
    """Compute norm stats per experiment, skipping existing ones unless forced."""
    for exp in experiments:
        stats_path = _norm_stats_path(args.project_root, exp.train_config, args.asset_id)
        if stats_path.exists() and not args.force_norm:
            print(f"[skip norm] {exp.name}: {stats_path}")
            continue
        print(f"\n========== norm: {exp.name} source 1.0x stats for online sliding ==========")
        _run(_norm_cmd(args, exp), cwd=args.project_root, env=env, dry_run=args.dry_run)


def _run_train_stage(args: argparse.Namespace, experiments: tuple[Experiment, ...], env: dict[str, str]) -> None:
    """Train each experiment, honouring ``--train-mode`` for existing runs.

    Raises:
        SystemExit: if norm stats are missing (non dry-run), or a checkpoint
            directory exists under ``--train-mode fail-if-exists``.
    """
    for exp in experiments:
        stats_path = _norm_stats_path(args.project_root, exp.train_config, args.asset_id)
        if not args.dry_run and not stats_path.exists():
            raise SystemExit(f"missing norm stats for {exp.name}: {stats_path}")

        ckpt_dir = _checkpoint_dir(args.project_root, exp.train_config, exp.exp_name)
        if ckpt_dir.exists():
            if args.train_mode == "skip-existing":
                print(f"[skip train] {exp.name}: {ckpt_dir}")
                continue
            if args.train_mode == "fail-if-exists":
                raise SystemExit(f"checkpoint exists for {exp.name}: {ckpt_dir}")

        print(f"\n========== train: {exp.name} ==========")
        _run(
            _train_cmd(args, exp, args.log_dir / "torchrun"),
            cwd=args.project_root,
            env=env,
            dry_run=args.dry_run,
        )


def main(argv: list[str] | None = None) -> None:
    """Run the requested stages (norm, train, eval) for the selected experiments.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses
            ``sys.argv[1:]``.

    Raises:
        SystemExit: on invalid options or missing prerequisites.
        subprocess.CalledProcessError: if a stage command exits non-zero.
    """
    args = parse_args(argv)
    _resolve_paths(args)
    _validate_args(args)

    experiments = _select_experiments(args)
    args.log_dir.mkdir(parents=True, exist_ok=True)
    _write_manifest(args.project_root, args.log_dir, args, experiments)

    env = _base_env(args)
    env["CUDA_VISIBLE_DEVICES"] = args.cuda_devices or ",".join(str(i) for i in range(args.num_gpus))

    _print_banner(args, experiments)

    if args.stage in ("all", "norm"):
        _run_norm_stage(args, experiments, env)
    if args.stage in ("all", "train"):
        _run_train_stage(args, experiments, env)
    if args.stage in ("all", "eval"):
        for exp in experiments:
            _eval_experiment(args, exp, env)


if __name__ == "__main__":
    main()