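# Source: VLAarchtests/code/reveal_vla_bimanual/eval/run_peract2_task_sweep.py
# Committed by lsnu in e7d8e79 (verified): "2026-03-25 runpod handoff update".
# Runs the PerAct2 bimanual rollout evaluation one task at a time, each in its own subprocess.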
from __future__ import annotations

import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Any

from sim_rlbench.task_splits import PERACT2_BIMANUAL_TASKS
def _run_task(
project_root: Path,
checkpoint: Path,
output_dir: Path,
task_name: str,
*,
plan: bool,
episodes_per_task: int,
episode_length: int,
resolution: int,
device: str,
headless: bool,
chunk_commit_steps: int,
allow_unsupervised_planning: bool,
disable_support_mode_conditioning: bool,
disable_task_conditioning: bool,
no_geometry: bool,
compact_world_model: bool,
no_planner: bool,
) -> dict[str, Any]:
task_dir = output_dir / task_name
task_dir.mkdir(parents=True, exist_ok=True)
command = [
sys.executable,
"-m",
"eval.run_rlbench_rollout_eval",
"--checkpoint",
str(checkpoint),
"--output-dir",
str(task_dir),
"--tasks",
task_name,
"--episodes-per-task",
str(episodes_per_task),
"--episode-length",
str(episode_length),
"--resolution",
str(resolution),
"--device",
device,
"--chunk-commit-steps",
str(chunk_commit_steps),
]
if headless:
command.append("--headless")
if plan:
command.append("--plan")
if no_planner:
command.append("--no-planner")
if allow_unsupervised_planning:
command.append("--allow-unsupervised-planning")
if disable_support_mode_conditioning:
command.append("--disable-support-mode-conditioning")
if disable_task_conditioning:
command.append("--disable-task-conditioning")
if no_geometry:
command.append("--no-geometry")
if compact_world_model:
command.append("--compact-world-model")
completed = subprocess.run(
command,
cwd=project_root,
text=True,
capture_output=True,
check=False,
)
(task_dir / "command.txt").write_text(" ".join(command) + "\n", encoding="utf-8")
(task_dir / "stdout.txt").write_text(completed.stdout, encoding="utf-8")
(task_dir / "stderr.txt").write_text(completed.stderr, encoding="utf-8")
rollout_path = task_dir / "rollout_eval.json"
if rollout_path.exists():
payload = json.loads(rollout_path.read_text(encoding="utf-8"))
task_payload = payload.get("tasks", {}).get(task_name, {})
else:
task_payload = {}
if completed.returncode != 0 and "error" not in task_payload:
task_payload["error"] = f"subprocess_exit_{completed.returncode}"
if "mean_success" not in task_payload:
task_payload["mean_success"] = 0.0
if "mean_return" not in task_payload:
task_payload["mean_return"] = 0.0
task_payload["subprocess_returncode"] = completed.returncode
task_payload["rollout_path"] = str(rollout_path)
return task_payload
def _mode_name(plan: bool) -> str:
return "plan" if plan else "noplan"
def _mode_output_dir(output_root: Path, run_name_prefix: str, plan: bool) -> Path:
    """Return the per-mode sweep directory ``<output_root>/<run_name_prefix>_<mode>_split``.

    ``<mode>`` is the label produced by ``_mode_name`` for *plan*.
    """
    dir_name = "_".join((run_name_prefix, _mode_name(plan), "split"))
    return output_root / dir_name
def _write_summary_markdown(path: Path, payload: dict[str, Any]) -> None:
lines = [
"# PerAct2 13-Task Rollout Sweep",
"",
f"- Checkpoint: `{payload['checkpoint']}`",
f"- Plan requested: `{payload['plan_requested']}`",
f"- Plan applied: `{payload['plan_applied']}`",
f"- Episodes per task: `{payload['episodes_per_task']}`",
f"- Episode length: `{payload['episode_length']}`",
f"- Resolution: `{payload['resolution']}`",
f"- No planner: `{payload['no_planner']}`",
f"- Disable task conditioning: `{payload['disable_task_conditioning']}`",
f"- No geometry: `{payload['no_geometry']}`",
f"- Compact world model: `{payload['compact_world_model']}`",
f"- Task count: `{payload['task_count']}`",
f"- Error tasks: `{payload['error_tasks']}`",
f"- Mean success: `{payload['mean_success']:.3f}`",
"",
"## Per-task",
"",
]
for task_name, task_payload in payload["tasks"].items():
if "error" in task_payload:
lines.append(
f"- `{task_name}`: mean_success={task_payload['mean_success']:.3f}, "
f"mean_return={task_payload['mean_return']:.3f}, "
f"error={task_payload['error']}, "
f"subprocess_returncode={task_payload['subprocess_returncode']}"
)
continue
lines.append(
f"- `{task_name}`: mean_success={task_payload['mean_success']:.3f}, "
f"mean_return={task_payload['mean_return']:.3f}, "
f"path_recoveries={task_payload.get('path_recoveries')}, "
f"noop_fallbacks={task_payload.get('noop_fallbacks')}"
)
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def _run_mode(args: argparse.Namespace, plan: bool) -> Path:
    """Sweep every requested task in one planning mode and write the mode's summaries.

    Each task is evaluated by ``_run_task`` in its own subprocess. The results are
    aggregated into ``rollout_eval.json`` and rendered to ``rollout_eval.md``,
    both inside the directory given by ``_mode_output_dir``.

    Args:
        args: Parsed command-line arguments from ``main``.
        plan: Whether planning is requested (forwarded to each task as ``--plan``).

    Returns:
        Path to the written ``rollout_eval.json`` summary.
    """
    project_root = Path(__file__).resolve().parents[1]
    checkpoint = Path(args.checkpoint).resolve()
    output_dir = _mode_output_dir(Path(args.output_root).resolve(), args.run_name_prefix, plan)
    output_dir.mkdir(parents=True, exist_ok=True)

    summary: dict[str, Any] = {
        "checkpoint": str(checkpoint),
        "plan_requested": plan,
        # NOTE(review): always equal to plan_requested; nothing here checks whether
        # the child actually planned (e.g. under --no-planner).
        "plan_applied": plan,
        "episodes_per_task": args.episodes_per_task,
        "episode_length": args.episode_length,
        "resolution": args.resolution,
        "device": args.device,
        "no_planner": args.no_planner,
        "disable_task_conditioning": args.disable_task_conditioning,
        "no_geometry": args.no_geometry,
        "compact_world_model": args.compact_world_model,
        # The remaining flags are also forwarded to every task. Record them so the
        # summary fully describes how the sweep was run.
        "headless": args.headless,
        "chunk_commit_steps": args.chunk_commit_steps,
        "allow_unsupervised_planning": args.allow_unsupervised_planning,
        "disable_support_mode_conditioning": args.disable_support_mode_conditioning,
        "tasks": {},
        "subprocess_mode": "isolated_per_task",
    }

    # Fall back to the full PerAct2 list when --tasks was given without values.
    # Drop repeated names (order preserved): a duplicate would rerun into the same
    # task directory and overwrite its own summary entry.
    if args.tasks:
        tasks = tuple(dict.fromkeys(args.tasks))
    else:
        tasks = tuple(PERACT2_BIMANUAL_TASKS)
    for task_name in tasks:
        print(f"[peract2-sweep] running task={task_name} plan={plan}", flush=True)
        summary["tasks"][task_name] = _run_task(
            project_root,
            checkpoint,
            output_dir,
            task_name,
            plan=plan,
            episodes_per_task=args.episodes_per_task,
            episode_length=args.episode_length,
            resolution=args.resolution,
            device=args.device,
            headless=args.headless,
            chunk_commit_steps=args.chunk_commit_steps,
            allow_unsupervised_planning=args.allow_unsupervised_planning,
            disable_support_mode_conditioning=args.disable_support_mode_conditioning,
            disable_task_conditioning=args.disable_task_conditioning,
            no_geometry=args.no_geometry,
            compact_world_model=args.compact_world_model,
            no_planner=args.no_planner,
        )

    # Failed tasks carry mean_success=0.0 (set by _run_task), so they pull the
    # sweep mean down rather than being excluded from it.
    task_scores = [float(task_payload["mean_success"]) for task_payload in summary["tasks"].values()]
    summary["task_count"] = len(summary["tasks"])
    summary["error_tasks"] = sorted(
        task_name for task_name, task_payload in summary["tasks"].items() if "error" in task_payload
    )
    summary["mean_success"] = float(sum(task_scores) / len(task_scores)) if task_scores else 0.0

    summary_path = output_dir / "rollout_eval.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    _write_summary_markdown(output_dir / "rollout_eval.md", summary)
    return summary_path
def main() -> None:
    """CLI entry point: run the requested planning modes and print the summary paths.

    The no-plan mode runs first, then the plan mode. Either can be skipped with
    ``--skip-noplan`` / ``--skip-plan``, but not both. The generated
    ``rollout_eval.json`` paths are printed to stdout as JSON.
    """
    parser = argparse.ArgumentParser(
        description="Run the PerAct2 bimanual rollout evaluation task by task in isolated subprocesses.",
    )
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--output-root", required=True)
    parser.add_argument("--run-name-prefix", default="peract2_13_rollout")
    parser.add_argument("--tasks", nargs="*", default=list(PERACT2_BIMANUAL_TASKS))
    parser.add_argument("--episodes-per-task", type=int, default=1)
    parser.add_argument("--episode-length", type=int, default=20)
    parser.add_argument("--resolution", type=int, default=224)
    parser.add_argument("--device", default="cuda")
    # --headless defaults to True, so the store_true flag alone can never turn it
    # off; --no-headless writes to the same dest to provide the opt-out.
    parser.add_argument("--headless", action="store_true", default=True)
    parser.add_argument(
        "--no-headless",
        dest="headless",
        action="store_false",
        help="Do not pass --headless to the per-task rollout evaluator.",
    )
    parser.add_argument("--chunk-commit-steps", type=int, default=4)
    parser.add_argument("--allow-unsupervised-planning", action="store_true")
    parser.add_argument("--disable-support-mode-conditioning", action="store_true")
    parser.add_argument("--disable-task-conditioning", action="store_true")
    parser.add_argument("--no-geometry", action="store_true")
    parser.add_argument("--compact-world-model", action="store_true")
    parser.add_argument("--no-planner", action="store_true")
    parser.add_argument("--skip-noplan", action="store_true")
    parser.add_argument("--skip-plan", action="store_true")
    args = parser.parse_args()
    if args.skip_noplan and args.skip_plan:
        # Previously this silently produced no output and exited 0.
        parser.error("--skip-noplan and --skip-plan together leave nothing to run")

    generated = []
    if not args.skip_noplan:
        generated.append(_run_mode(args, plan=False))
    if not args.skip_plan:
        generated.append(_run_mode(args, plan=True))
    print(json.dumps({"generated": [str(path) for path in generated]}, indent=2))
if __name__ == "__main__":
    main()