VLAwithVariousSpeed / scripts /run_speed_embedding_ablation.py
Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
16.7 kB
#!/usr/bin/env python3
"""Run the 0.75/1.0/1.25/1.5 speed-embedding ablation.
This one-click runner fixes the data/speed setup and compares only how speed is
fed to the model:
1. text prompt
2. scalar modulation
3. soft prompt, P=8
It uses online sliding chunks and source 1.0x norm stats.
"""
from __future__ import annotations
import argparse
import dataclasses
import json
import os
import shlex
import subprocess
import sys
import time
from pathlib import Path
# Speed multipliers covered by the ablation. They are passed to the norm-stats
# and training commands and are the default for --eval-speeds.
SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
# Default value of --asset-id, i.e. the sub-directory that holds norm_stats.json.
DEFAULT_ASSET_ID = "online_sliding_speed_embed_0p75_1p0_1p25_1p5_pi05"
@dataclasses.dataclass(frozen=True)
class Experiment:
    """Immutable description of one arm of the ablation.

    An arm pairs a training config with a run name and the extra
    ``train_pytorch.py`` flags that select how speed is fed to the model.
    """

    # Short label used for --only selection, banners and log sub-directories.
    name: str
    # Training config name; also the sub-directory under assets/ and checkpoints/.
    train_config: str
    # Run name passed as --exp-name; also the checkpoint/results sub-directory.
    exp_name: str
    # Arm-specific flags appended to the shared training command line.
    train_args: tuple[str, ...]
# The ablation arms in the order they are processed. They differ only in the
# training config, the run name and the flags that choose how speed is injected.
EXPERIMENTS: tuple[Experiment, ...] = (
    Experiment(
        name="text",
        train_config="pi05_libero_speed_embed_text",
        exp_name="pi05_online_sliding_speed_embed_text_bs512_lr1e4",
        train_args=("--data.speed-integration", "text"),
    ),
    Experiment(
        name="modulation",
        train_config="pi05_libero_speed_embed_modulation",
        exp_name="pi05_online_sliding_speed_embed_modulation_bs512_lr1e4",
        train_args=("--data.speed-integration", "modulation", "--model.speed-modulation"),
    ),
    Experiment(
        name="soft_prompt",
        train_config="pi05_libero_speed_embed_softprompt_p8",
        exp_name="pi05_online_sliding_speed_embed_softprompt_p8_bs512_lr1e4",
        train_args=(
            "--data.speed-integration",
            "soft_prompt",
            "--model.soft-prompt-p",
            "8",
            "--model.soft-prompt-speeds",
            # One token per supported speed, formatted like every other speed flag.
            *(format(speed, "g") for speed in SPEEDS),
        ),
    ),
)
def _speed_args() -> list[str]:
    """Return every entry of ``SPEEDS`` as a compact ``%g``-formatted CLI token."""
    return [format(speed, "g") for speed in SPEEDS]
def _shell(cmd: list[str]) -> str:
    """Render *cmd* as one shell-quoted, space-separated line for display."""
    return shlex.join(cmd)
def _base_env(args: argparse.Namespace) -> dict[str, str]:
    """Build the environment shared by all child processes.

    Starts from a copy of the current environment, supplies
    ``WANDB__SERVICE_WAIT`` from ``args.wandb_service_wait`` when it is not
    already set, and removes the W&B API-key variables unless
    ``args.keep_wandb_env`` is true. ``os.environ`` itself is not modified.
    """
    child_env = dict(os.environ)
    if "WANDB__SERVICE_WAIT" not in child_env:
        child_env["WANDB__SERVICE_WAIT"] = str(args.wandb_service_wait)
    if args.keep_wandb_env:
        return child_env
    for secret_var in ("WANDB_API_KEY", "WANDB_API_KEY_FILE"):
        child_env.pop(secret_var, None)
    return child_env
def _run(cmd: list[str], *, cwd: Path, env: dict[str, str], dry_run: bool) -> None:
    """Echo *cmd*, then execute it in *cwd* with *env* unless *dry_run* is set.

    A non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    print(_shell(cmd), flush=True)
    if not dry_run:
        subprocess.run(cmd, cwd=cwd, env=env, check=True)
def _run_with_env(prefix_env: dict[str, str], cmd: list[str], *, cwd: Path, env: dict[str, str], dry_run: bool) -> None:
    """Echo and run *cmd* with *prefix_env* layered on top of *env*.

    The echoed line shows the overrides as ``KEY=value`` prefixes in front of
    the command. Nothing is executed when *dry_run* is set; otherwise a
    non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    assignments = " ".join(f"{key}={shlex.quote(value)}" for key, value in prefix_env.items())
    print(f"{assignments} {_shell(cmd)}", flush=True)
    if not dry_run:
        merged_env = dict(env)
        # Overrides win over the inherited values.
        merged_env.update(prefix_env)
        subprocess.run(cmd, cwd=cwd, env=merged_env, check=True)
def _norm_stats_path(project_root: Path, train_config: str, asset_id: str) -> Path:
    """Return ``<project_root>/assets/<train_config>/<asset_id>/norm_stats.json``."""
    return project_root.joinpath("assets", train_config, asset_id, "norm_stats.json")
def _checkpoint_dir(project_root: Path, train_config: str, exp_name: str) -> Path:
    """Return ``<project_root>/checkpoints/<train_config>/<exp_name>``."""
    return project_root.joinpath("checkpoints", train_config, exp_name)
def _latest_checkpoint_step_dir(
    project_root: Path,
    train_config: str,
    exp_name: str,
    ckpt_step: int | None,
    fallback_step: int,
) -> Path:
    """Pick the checkpoint step directory to evaluate.

    An explicit *ckpt_step* always wins. Otherwise the sub-directory of the
    experiment's checkpoint root with the largest all-digit name is returned.
    When the root is missing or holds no such directory, the (possibly
    non-existent) ``<root>/<fallback_step>`` path is returned instead.
    """
    run_root = _checkpoint_dir(project_root, train_config, exp_name)
    if ckpt_step is not None:
        return run_root / str(ckpt_step)

    fallback = run_root / str(fallback_step)
    if not run_root.exists():
        return fallback

    step_dirs = [entry for entry in run_root.iterdir() if entry.is_dir() and entry.name.isdigit()]
    if not step_dirs:
        return fallback
    # Order by numeric step first, then by path, so ties resolve deterministically.
    return max(step_dirs, key=lambda entry: (int(entry.name), entry))
def _speed_tag(speed: float) -> str:
    """Return a path-friendly tag for *speed*, e.g. ``1.25`` -> ``"1p25x"``."""
    compact = format(speed, "g")
    return compact.replace(".", "p") + "x"
def _cuda_devices(args: argparse.Namespace) -> list[str]:
    """Return one CUDA device id per GPU.

    Uses the comma-separated ``args.cuda_devices`` when given (blank entries
    are ignored), otherwise ``"0" .. str(num_gpus - 1)``.

    Raises:
        SystemExit: if the number of devices differs from ``args.num_gpus``.
    """
    if args.cuda_devices:
        stripped = (item.strip() for item in args.cuda_devices.split(","))
        devices = [item for item in stripped if item]
    else:
        devices = list(map(str, range(args.num_gpus)))
    if len(devices) == args.num_gpus:
        return devices
    raise SystemExit(f"expected {args.num_gpus} CUDA devices, got {devices}")
def _norm_cmd(args: argparse.Namespace, exp: Experiment) -> list[str]:
    """Build the ``compute_norm_stats.py`` command line for *exp*.

    The script is launched with the interpreter that is running this runner.
    """
    script = [sys.executable, "scripts/compute_norm_stats.py"]
    options = [
        "--config-name",
        exp.train_config,
        "--repo-id",
        str(args.data_root),
        "--asset-id",
        args.asset_id,
        "--online-sliding-chunks",
        "--online-sliding-speeds",
    ]
    return script + options + _speed_args()
def _train_cmd(args: argparse.Namespace, exp: Experiment, log_dir: Path) -> list[str]:
    """Build the ``uv run torchrun ... scripts/train_pytorch.py`` command for *exp*.

    The command is assembled from the launcher options, the shared training
    hyper-parameters, the shared data/model options and finally the
    arm-specific ``exp.train_args``. Optional flags (W&B off, overwrite/resume,
    user-supplied extras) are appended afterwards, in that order.
    """
    speeds = _speed_args()

    launcher = [
        "uv",
        "run",
        "torchrun",
        "--standalone",
        "--nnodes=1",
        f"--nproc_per_node={args.num_gpus}",
        "--log-dir",
        str(log_dir / exp.name),
        "--redirects",
        "3",
        "--tee",
        "3",
    ]
    training = [
        "scripts/train_pytorch.py",
        exp.train_config,
        "--exp-name",
        exp.exp_name,
        "--pytorch-weight-path",
        str(args.pi05_base),
        "--batch-size",
        str(args.batch_size),
        "--num-workers",
        str(args.num_workers),
        "--num-train-steps",
        str(args.num_train_steps),
        "--log-interval",
        str(args.log_interval),
        "--save-interval",
        str(args.save_interval),
        # Peak and decay learning rates both receive --lr.
        "--lr-schedule.peak-lr",
        str(args.lr),
        "--lr-schedule.decay-lr",
        str(args.lr),
        "--eval-speed-set",
        *speeds,
    ]
    data_and_model = [
        "--data.repo-id",
        str(args.data_root),
        "--data.assets.asset-id",
        args.asset_id,
        "--data.online-sliding-chunks",
        "--data.online-sliding-speeds",
        *speeds,
        "--model.pytorch-compile-mode",
        args.compile_mode,
    ]
    cmd = [*launcher, *training, *data_and_model, *exp.train_args]

    if args.no_wandb:
        cmd.append("--no-wandb-enabled")

    # Only these two train modes map to a trainer flag; the others are handled by the caller.
    mode_flag = {"overwrite": "--overwrite", "resume": "--resume"}.get(args.train_mode)
    if mode_flag is not None:
        cmd.append(mode_flag)

    for extra in args.extra_train_arg:
        # Each --extra-train-arg value may itself hold several shell-style tokens.
        cmd.extend(shlex.split(extra))
    return cmd
def _serve_cmd(args: argparse.Namespace, exp: Experiment, ckpt_dir: Path, port: int) -> list[str]:
    """Build the ``serve_policy.py`` command serving *ckpt_dir* for *exp* on *port*.

    *args* is accepted for signature parity with the other command builders
    and is not read here.
    """
    launcher = ["uv", "run", "python", "scripts/serve_policy.py", "policy:checkpoint"]
    options = {
        "--policy.config": exp.train_config,
        "--policy.dir": str(ckpt_dir),
        "--port": str(port),
    }
    cmd = list(launcher)
    for flag, value in options.items():
        cmd += [flag, value]
    return cmd
def _terminate_servers(servers: list[subprocess.Popen]) -> None:
    """Stop every process in *servers*, escalating from terminate to kill.

    Running processes are asked to terminate and share a single 30-second
    grace period. Any process still alive when its wait times out is killed,
    after which each process is waited on for up to 5 more seconds. A process
    that still has not exited by then is left alone.
    """
    for server in servers:
        if server.poll() is not None:
            continue
        server.terminate()

    give_up_at = time.time() + 30
    for server in servers:
        budget = max(0.0, give_up_at - time.time())
        try:
            server.wait(timeout=budget)
        except subprocess.TimeoutExpired:
            server.kill()

    for server in servers:
        try:
            server.wait(timeout=5)
        except subprocess.TimeoutExpired:
            continue
def _start_servers(args: argparse.Namespace, exp: Experiment, ckpt_dir: Path, env: dict[str, str]) -> list[subprocess.Popen]:
    """Launch one policy server per CUDA device and wait for them to load.

    Server ``rank`` listens on ``args.base_port + rank``, is restricted to its
    own device through ``CUDA_VISIBLE_DEVICES`` and writes stdout/stderr to
    ``<args.log_dir>/servers/<exp.name>/gpu<rank>.log``. In dry-run mode the
    commands are only printed: no directory is created and nothing is started.

    Args:
        args: Parsed CLI options (``log_dir``, ``base_port``, ``dry_run``,
            ``project_root``, ``server_wait_seconds`` and the device options).
        exp: Experiment whose checkpoint is served.
        ckpt_dir: Checkpoint step directory handed to the server.
        env: Base environment for the server processes.

    Returns:
        The started processes in rank order; empty in dry-run mode.

    Raises:
        SystemExit: if the device list does not match ``args.num_gpus`` or a
            server exits while the runner is waiting for start-up.
    """
    devices = _cuda_devices(args)
    server_log_dir = args.log_dir / "servers" / exp.name
    if not args.dry_run:
        # A dry run must not leave directories behind.
        server_log_dir.mkdir(parents=True, exist_ok=True)

    servers: list[subprocess.Popen] = []
    try:
        for rank, device in enumerate(devices):
            port = args.base_port + rank
            cmd = _serve_cmd(args, exp, ckpt_dir, port)
            log_path = server_log_dir / f"gpu{rank}.log"
            print(f"CUDA_VISIBLE_DEVICES={device} {_shell(cmd)} > {log_path} 2>&1 &", flush=True)
            if args.dry_run:
                continue
            server_env = {**env, "CUDA_VISIBLE_DEVICES": device}
            # The child inherits its own copy of the descriptor, so the parent's
            # handle can be closed as soon as the process has been spawned.
            with log_path.open("w") as log_file:
                servers.append(
                    subprocess.Popen(
                        cmd,
                        cwd=args.project_root,
                        env=server_env,
                        stdout=log_file,
                        stderr=subprocess.STDOUT,
                    )
                )

        if not args.dry_run:
            print(f"Waiting {args.server_wait_seconds}s for policy servers to load...", flush=True)
            time.sleep(args.server_wait_seconds)
            # Servers are appended in rank order, so the list index is the rank.
            exited = [rank for rank, proc in enumerate(servers) if proc.poll() is not None]
            if exited:
                raise SystemExit(
                    f"policy servers exited during startup (ranks {exited}); see logs in {server_log_dir}"
                )
    except BaseException:
        # The caller only gets the handles on success, so anything already
        # started (including on Ctrl-C during the wait) must be stopped here.
        _terminate_servers(servers)
        raise
    return servers
def _eval_experiment(args: argparse.Namespace, exp: Experiment, env: dict[str, str]) -> None:
    """Serve the chosen checkpoint of *exp* and run the eval script once per speed.

    The checkpoint is ``args.ckpt_step`` when given, otherwise the latest step
    directory (falling back to ``num_train_steps - 1``). Each eval run receives
    its settings through environment variables and its own results directory.
    Started policy servers are always stopped afterwards.

    Raises:
        SystemExit: if the checkpoint directory is missing outside dry-run mode.
    """
    ckpt_dir = _latest_checkpoint_step_dir(
        args.project_root,
        exp.train_config,
        exp.exp_name,
        args.ckpt_step,
        args.num_train_steps - 1,
    )
    if not (args.dry_run or ckpt_dir.exists()):
        raise SystemExit(f"checkpoint for eval does not exist: {ckpt_dir}")

    print(f"\n========== eval: {exp.name} ckpt={ckpt_dir} ==========")
    results_root = args.results_dir / exp.exp_name
    eval_script = ["./scripts/eval_libero_8gpu.sh"]

    servers = _start_servers(args, exp, ckpt_dir, env)
    try:
        for speed in args.eval_speeds:
            # Settings handed to the eval script via its environment.
            overrides = {
                "SPEED": f"{speed:g}",
                "BASE_PORT": str(args.base_port),
                "HOST": args.host,
                "NUM_TRIALS": str(args.num_trials),
                "SAVE_VIDEOS": "1" if args.save_videos else "0",
                "PYTHON_CMD": "uv run python",
                "RESULTS_DIR": str(results_root / f"speed_{_speed_tag(speed)}"),
            }
            _run_with_env(overrides, eval_script, cwd=args.project_root, env=env, dry_run=args.dry_run)
    finally:
        if servers:
            print(f"Stopping policy servers for {exp.name}...", flush=True)
            _terminate_servers(servers)
def _write_manifest(project_root: Path, log_dir: Path, args: argparse.Namespace, experiments: tuple[Experiment, ...]) -> None:
    """Record the run configuration as JSON inside *log_dir*.

    Writes ``speed_embedding_ablation_manifest.json`` (creating *log_dir* if
    needed) with the speeds, key hyper-parameters, input paths and selected
    experiments. Does nothing in dry-run mode. *project_root* is accepted but
    not used.
    """
    if args.dry_run:
        return

    log_dir.mkdir(parents=True, exist_ok=True)
    selected = [dataclasses.asdict(exp) for exp in experiments]
    manifest = {
        "speeds": SPEEDS,
        "asset_id": args.asset_id,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "eval_speeds": args.eval_speeds,
        "data_root": str(args.data_root),
        "pi05_base": str(args.pi05_base),
        "experiments": selected,
    }
    payload = json.dumps(manifest, indent=2) + "\n"
    target = log_dir / "speed_embedding_ablation_manifest.json"
    target.write_text(payload)
def parse_args() -> argparse.Namespace:
    """Define the runner's command-line interface and parse ``sys.argv``."""
    cli = argparse.ArgumentParser(description=__doc__)
    add = cli.add_argument

    # Inputs and experiment selection.
    add("--project-root", type=Path, default=Path.cwd(), help="VLAwithVariousSpeed repo root.")
    add("--data-root", type=Path, required=True, help="Source LeRobot/LIBERO dataset root.")
    add("--pi05-base", type=Path, required=True, help="Path to PI05 base weights directory.")
    add("--asset-id", default=DEFAULT_ASSET_ID)
    add("--only", default=None, help="Comma-separated subset: text,modulation,soft_prompt.")

    # Pipeline control.
    add("--stage", choices=("all", "norm", "train", "eval"), default="all")
    add("--force-norm", action="store_true")
    add(
        "--train-mode",
        choices=("overwrite", "resume", "skip-existing", "fail-if-exists"),
        default="overwrite",
    )
    add("--dry-run", action="store_true")

    # Hardware and training hyper-parameters.
    add("--num-gpus", type=int, default=8)
    add("--cuda-devices", default=None, help="CUDA_VISIBLE_DEVICES value. Default: 0..num_gpus-1.")
    add("--batch-size", type=int, default=512)
    add("--lr", type=float, default=1e-4)
    add("--num-workers", type=int, default=2)
    add("--num-train-steps", type=int, default=30_000)
    add("--log-interval", type=int, default=100)
    add("--save-interval", type=int, default=1000)
    add("--compile-mode", default="None")
    add("--log-dir", type=Path, default=Path("logs/speed_embedding_ablation"))

    # Evaluation.
    add("--eval-speeds", type=float, nargs="+", default=list(SPEEDS))
    add("--results-dir", type=Path, default=Path("results/speed_embedding_ablation"))
    add("--ckpt-step", type=int, default=None, help="Checkpoint step to evaluate. Default: latest step.")
    add("--base-port", type=int, default=8000)
    add("--host", default="0.0.0.0")
    add("--num-trials", type=int, default=50)
    add("--save-videos", action="store_true")
    add("--server-wait-seconds", type=int, default=120)

    # Weights & Biases handling and pass-through training flags.
    add("--no-wandb", action="store_true")
    add("--keep-wandb-env", action="store_true")
    add("--wandb-service-wait", type=int, default=300)
    add(
        "--extra-train-arg",
        action="append",
        default=[],
        help="Extra argument appended to train_pytorch.py. Repeat for multiple args.",
    )
    return cli.parse_args()
def _select_experiments(args: argparse.Namespace) -> tuple[Experiment, ...]:
    """Return the experiments chosen by ``--only``, in ``EXPERIMENTS`` order.

    Without ``--only`` every experiment is returned. Names are comma-separated;
    surrounding whitespace and blank entries are ignored.

    Raises:
        SystemExit: if ``--only`` names an unknown experiment, or names none at
            all (e.g. ``--only ""``), which would otherwise silently run nothing.
    """
    if args.only is None:
        return EXPERIMENTS

    wanted = {name.strip() for name in args.only.split(",") if name.strip()}
    known = {exp.name for exp in EXPERIMENTS}
    if not wanted:
        raise SystemExit(f"--only selected no experiments; known={sorted(known)}")
    unknown = wanted - known
    if unknown:
        raise SystemExit(f"unknown experiments: {sorted(unknown)}; known={sorted(known)}")
    return tuple(exp for exp in EXPERIMENTS if exp.name in wanted)
def main() -> None:
    """Run the requested stages (norm stats, training, evaluation) for each experiment.

    Paths are resolved first (relative ``--log-dir`` / ``--results-dir`` are
    taken relative to the project root) and the inputs are validated before any
    work starts, so configuration mistakes surface immediately instead of after
    a long training run. In dry-run mode commands are printed only and the
    runner's own log directory and manifest are not created.

    Raises:
        SystemExit: on an invalid project root, missing inputs, inconsistent
            GPU/batch settings, missing norm stats or checkpoints, or an
            existing checkpoint with ``--train-mode fail-if-exists``.
    """
    args = parse_args()
    project_root = args.project_root.resolve()
    args.project_root = project_root
    args.data_root = args.data_root.resolve()
    args.pi05_base = args.pi05_base.resolve()
    # Joining an absolute path onto project_root yields that absolute path, so
    # one expression covers both the relative and the absolute case.
    args.log_dir = (project_root / args.log_dir).resolve()
    args.results_dir = (project_root / args.results_dir).resolve()

    if not (project_root / "scripts" / "train_pytorch.py").exists():
        raise SystemExit(f"project root does not look valid: {project_root}")
    if not args.data_root.exists():
        raise SystemExit(f"data root does not exist: {args.data_root}")
    if not args.pi05_base.exists():
        raise SystemExit(f"pi05 base path does not exist: {args.pi05_base}")
    if args.num_gpus < 1:
        # Also keeps the divisibility check below from dividing by zero.
        raise SystemExit(f"--num-gpus must be a positive integer, got {args.num_gpus}.")
    if args.batch_size % args.num_gpus != 0:
        raise SystemExit(f"--batch-size ({args.batch_size}) must be divisible by --num-gpus ({args.num_gpus}).")
    if args.stage in ("all", "eval"):
        # Eval starts one server per device; report a device-count mismatch now
        # rather than after norm stats and training have already run.
        _cuda_devices(args)

    experiments = _select_experiments(args)
    if not args.dry_run:
        args.log_dir.mkdir(parents=True, exist_ok=True)
    _write_manifest(project_root, args.log_dir, args, experiments)

    env = _base_env(args)
    env["CUDA_VISIBLE_DEVICES"] = args.cuda_devices or ",".join(str(i) for i in range(args.num_gpus))

    print("Speed embedding ablation runner")
    print(f" project_root = {project_root}")
    print(f" data_root = {args.data_root}")
    print(f" pi05_base = {args.pi05_base}")
    print(f" speeds = {SPEEDS}")
    print(f" asset_id = {args.asset_id}")
    print(f" batch_size = {args.batch_size}")
    print(f" lr = {args.lr}")
    print(f" stage = {args.stage}")
    print(f" train_mode = {args.train_mode}")
    print(f" eval_speeds = {args.eval_speeds}")
    print(f" results_dir = {args.results_dir}")
    print(f" experiments = {[exp.name for exp in experiments]}")
    print()

    if args.stage in ("all", "norm"):
        for exp in experiments:
            stats_path = _norm_stats_path(project_root, exp.train_config, args.asset_id)
            if stats_path.exists() and not args.force_norm:
                print(f"[skip norm] {exp.name}: {stats_path}")
            else:
                print(f"\n========== norm: {exp.name} source 1.0x stats for online sliding ==========")
                _run(_norm_cmd(args, exp), cwd=project_root, env=env, dry_run=args.dry_run)

    if args.stage in ("all", "train"):
        for exp in experiments:
            stats_path = _norm_stats_path(project_root, exp.train_config, args.asset_id)
            if not args.dry_run and not stats_path.exists():
                raise SystemExit(f"missing norm stats for {exp.name}: {stats_path}")
            ckpt_dir = _checkpoint_dir(project_root, exp.train_config, exp.exp_name)
            if args.train_mode == "skip-existing" and ckpt_dir.exists():
                print(f"[skip train] {exp.name}: {ckpt_dir}")
                continue
            if args.train_mode == "fail-if-exists" and ckpt_dir.exists():
                raise SystemExit(f"checkpoint exists for {exp.name}: {ckpt_dir}")
            print(f"\n========== train: {exp.name} ==========")
            _run(_train_cmd(args, exp, args.log_dir / "torchrun"), cwd=project_root, env=env, dry_run=args.dry_run)

    if args.stage in ("all", "eval"):
        for exp in experiments:
            _eval_experiment(args, exp, env)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()