VLAwithVariousSpeed / scripts /eval_libero_speed.py
Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
12.2 kB
#!/usr/bin/env python3
"""LIBERO eval client with speed conditioning, partial-suite support, and step-count tracking.
Connects to a websocket policy server (see ``scripts/serve_policy.py``), runs
rollouts for a chosen task suite (or a subset of task IDs within that suite),
records per-episode success and step count, and prints a summary plus an
optional JSON file.
Designed to be invoked by ``scripts/eval_libero_8gpu.sh``, which fans out the
work across 8 GPUs (3 short suites + libero_10 split 5 ways).
Example:
uv run python scripts/eval_libero_speed.py \\
--task-suite-name libero_spatial \\
--host 0.0.0.0 --port 8000 \\
--speed 1.0 \\
--video-out-path videos/spatial_1x \\
--results-json results/spatial_1x.json
"""
from __future__ import annotations
import collections
import dataclasses
import json
import logging
import math
import pathlib
import time
import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256
# Per-suite max-steps budget. Successful rollouts terminate as soon as the env
# returns ``done=True``; failures are forced to run to ``max_steps``.
_MAX_STEPS = {
"libero_spatial": 220,
"libero_object": 280,
"libero_goal": 300,
"libero_10": 520,
"libero_90": 400,
}
@dataclasses.dataclass
class Args:
    """Command-line options for the speed-conditioned LIBERO eval client (parsed by tyro)."""

    # Server
    # Host and port of the websocket policy server queried for action chunks.
    host: str = "0.0.0.0"
    port: int = 8000
    # Suite / partition
    # LIBERO suite to evaluate; must also be a key of _MAX_STEPS.
    task_suite_name: str = "libero_spatial"
    # Comma-separated task IDs to evaluate; "all" means all tasks in the suite.
    # Inclusive "lo-hi" ranges are accepted too, e.g. "0,3,5-7".
    task_ids: str = "all"
    # Episodes per task; episode i starts from the suite's i-th initial state.
    num_trials_per_task: int = 50
    # Speed conditioning (passed through to the model). 1.0 = original speed.
    speed: float = 1.0
    # Rollout
    # Number of actions taken from each returned chunk before re-querying the policy.
    replan_steps: int = 5
    # Side length (pixels) the camera images are resized/padded to before sending.
    resize_size: int = 224
    # Dummy-action steps at the start of each episode; not counted as policy steps.
    num_steps_wait: int = 10
    # Seeds NumPy's global RNG and each LIBERO env.
    seed: int = 7
    # Outputs
    # Directory for per-episode rollout videos (created if missing).
    video_out_path: str = "videos/libero_eval"
    # Optional path of a JSON file receiving the summary and per-episode records.
    results_json: str | None = None
    # Per-episode video can balloon disk usage; flip off if not needed.
    save_videos: bool = True
    # Identifier reported in the summary (e.g., the GPU id this client was
    # launched against). Also prefixes video filenames and is recorded in the
    # results JSON, so give concurrent clients that share a video directory
    # distinct ranks.
    rank: int = 0
def _parse_task_ids(spec: str, n_total: int) -> list[int]:
if spec.strip().lower() == "all":
return list(range(n_total))
out: list[int] = []
for part in spec.split(","):
part = part.strip()
if not part:
continue
if "-" in part:
lo, hi = part.split("-", 1)
out.extend(range(int(lo), int(hi) + 1))
else:
out.append(int(part))
bad = [i for i in out if i < 0 or i >= n_total]
if bad:
raise ValueError(f"task_ids out of range [0, {n_total}): {bad}")
return sorted(set(out))
def _get_libero_env(task, resolution: int, seed: int):
    """Build a seeded offscreen LIBERO env for ``task``.

    The BDDL problem file is located under LIBERO's ``bddl_files`` directory as
    ``<problem_folder>/<bddl_file>``. Both cameras render at
    ``resolution`` x ``resolution``.

    Returns:
        ``(env, task_description)``, where ``task_description`` is ``task.language``.
    """
    description = task.language
    bddl_root = pathlib.Path(get_libero_path("bddl_files"))
    bddl_path = bddl_root / task.problem_folder / task.bddl_file
    render_env = OffScreenRenderEnv(
        bddl_file_name=bddl_path,
        camera_heights=resolution,
        camera_widths=resolution,
    )
    render_env.seed(seed)
    return render_env, description
def _quat2axisangle(quat):
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def _speed_label(speed: float) -> str:
text = f"{speed:g}".replace(".", "p")
return f"{text}x"
def _summary_string(speed: float, suite: str, rank: int, episodes: list[dict]) -> str:
n = len(episodes)
if n == 0:
return f"[rank={rank}] {suite} speed={speed:g}x no episodes"
successes = [e for e in episodes if e["success"]]
failures = [e for e in episodes if not e["success"]]
succ_steps = [e["steps"] for e in successes]
all_steps = [e["steps"] for e in episodes]
sr = len(successes) / n
succ_mean = float(np.mean(succ_steps)) if succ_steps else float("nan")
succ_median = float(np.median(succ_steps)) if succ_steps else float("nan")
all_mean = float(np.mean(all_steps))
fail_mean = float(np.mean([e["steps"] for e in failures])) if failures else float("nan")
return (
f"[rank={rank}] {suite} speed={speed:g}x "
f"success={len(successes)}/{n} ({sr * 100:.1f}%) "
f"mean_steps_success={succ_mean:.1f} median={succ_median:.1f} "
f"mean_steps_failure={fail_mean:.1f} "
f"mean_steps_all={all_mean:.1f}"
)
def _prep_image(frame: np.ndarray, size: int) -> np.ndarray:
    """Rotate a raw LIBERO camera frame by 180 degrees, then resize-with-pad it to ``size`` x ``size``."""
    # [::-1, ::-1] flips both image axes (a 180-degree rotation); ascontiguousarray
    # materialises the negative-stride view before it is handed to image_tools.
    rotated = np.ascontiguousarray(frame[::-1, ::-1])
    return image_tools.convert_to_uint8(image_tools.resize_with_pad(rotated, size, size))


def _run_episode(
    env,
    client,
    init_state,
    task_description: str,
    max_steps: int,
    args: Args,
    speed_label: str,
    task_id: int,
    episode_idx: int,
) -> tuple[bool, int, list[np.ndarray]]:
    """Roll out one episode and return ``(success, policy_steps_executed, replay_frames)``.

    The first ``args.num_steps_wait`` steps send ``LIBERO_DUMMY_ACTION`` and are not
    counted. After that, at most ``max_steps`` policy actions are executed, and the
    server is re-queried whenever the ``replan_steps``-long action plan runs out.
    Success means the env returned ``done=True``. Any exception raised while
    stepping or querying the policy is logged with its traceback and ends the
    episode as a failure, so one bad rollout does not abort the whole sweep.
    """
    env.reset()
    obs = env.set_init_state(init_state)
    action_plan: collections.deque = collections.deque()
    replay_images: list[np.ndarray] = []
    policy_steps_executed = 0
    t = 0
    try:
        while t < max_steps + args.num_steps_wait:
            if t < args.num_steps_wait:
                # Settle phase: send dummy actions before the policy takes over.
                obs, _, _, _ = env.step(LIBERO_DUMMY_ACTION)
                t += 1
                continue
            img = _prep_image(obs["agentview_image"], args.resize_size)
            wrist_img = _prep_image(obs["robot0_eye_in_hand_image"], args.resize_size)
            replay_images.append(img)
            if not action_plan:
                element = {
                    "observation/image": img,
                    "observation/wrist_image": wrist_img,
                    "observation/state": np.concatenate(
                        (
                            obs["robot0_eef_pos"],
                            _quat2axisangle(obs["robot0_eef_quat"]),
                            obs["robot0_gripper_qpos"],
                        )
                    ),
                    "prompt": str(task_description),
                    # Pass speed downstream; the data + model decide how
                    # to consume it (text prompt / scalar modulation /
                    # soft_prompt anchor lookup).
                    # Shape (1,) so that policy.infer's [None, ...] promotes
                    # it to (B=1, 1), matching Observation.speed: Float[*b 1].
                    "speed": np.array([args.speed], dtype=np.float32),
                    "speed_label": speed_label,
                }
                action_chunk = client.infer(element)["actions"]
                if len(action_chunk) < args.replan_steps:
                    raise RuntimeError(
                        f"replan_steps={args.replan_steps} but policy returned "
                        f"{len(action_chunk)} actions"
                    )
                # Execute only the first replan_steps actions, then re-query.
                action_plan.extend(action_chunk[: args.replan_steps])
            action = action_plan.popleft()
            obs, _, env_done, _ = env.step(action.tolist())
            policy_steps_executed += 1
            if env_done:
                return True, policy_steps_executed, replay_images
            t += 1
    except Exception as e:  # noqa: BLE001 - best-effort: score the episode as a failure
        logging.exception(f"[rank={args.rank}] task={task_id} ep={episode_idx} caught: {e}")
    return False, policy_steps_executed, replay_images


def _save_rollout_video(
    args: Args,
    speed_label: str,
    task_id: int,
    episode_idx: int,
    success: bool,
    frames: list[np.ndarray],
) -> None:
    """Write an episode's agent-view frames as a 10 fps mp4 under ``args.video_out_path``."""
    if not frames:
        # An episode that errored during the settle phase has no frames, and an
        # empty frame list may make imageio fail, so skip it.
        logging.warning(f"[rank={args.rank}] task={task_id} ep={episode_idx}: no frames, video skipped")
        return
    suffix = "success" if success else "failure"
    vid_path = (
        pathlib.Path(args.video_out_path)
        / f"rank{args.rank}_{args.task_suite_name}_speed{speed_label}"
        f"_t{task_id:02d}_e{episode_idx:02d}_{suffix}.mp4"
    )
    imageio.mimwrite(vid_path, [np.asarray(x) for x in frames], fps=10)


def _write_results_json(
    args: Args,
    task_ids: list[int],
    speed_label: str,
    episodes: list[dict],
    elapsed: float,
    summary_str: str,
) -> None:
    """Write the run summary plus per-episode records to ``args.results_json``."""
    out_path = pathlib.Path(args.results_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    successes = [e for e in episodes if e["success"]]
    failures = [e for e in episodes if not e["success"]]
    summary = {
        "rank": args.rank,
        "suite": args.task_suite_name,
        "task_ids": task_ids,
        "num_trials_per_task": args.num_trials_per_task,
        "speed": float(args.speed),
        "speed_label": speed_label,
        "n_episodes": len(episodes),
        "n_success": len(successes),
        "n_failure": len(failures),
        "success_rate": len(successes) / max(len(episodes), 1),
        "mean_steps_success": float(np.mean([e["steps"] for e in successes])) if successes else None,
        "median_steps_success": float(np.median([e["steps"] for e in successes])) if successes else None,
        "mean_steps_failure": float(np.mean([e["steps"] for e in failures])) if failures else None,
        "mean_steps_all": float(np.mean([e["steps"] for e in episodes])) if episodes else None,
        "elapsed_seconds": elapsed,
        "summary_line": summary_str,
    }
    with out_path.open("w") as f:
        json.dump({"summary": summary, "episodes": episodes}, f, indent=2)
        f.write("\n")
    print(f"[rank={args.rank}] wrote {out_path}")


def eval_libero(args: Args) -> int:
    """Run speed-conditioned LIBERO rollouts against a websocket policy server.

    For each selected task, one env is built and closed again after its episodes.
    ``args.num_trials_per_task`` episodes are run, with episode ``i`` starting from
    the suite's ``i``-th initial state. A one-line summary and the elapsed time
    are printed. If ``args.results_json`` is set, the summary and per-episode
    records are also written there.

    Returns:
        ``0``, intended as the process exit code.

    Raises:
        ValueError: for a suite missing from ``_MAX_STEPS``, ``replan_steps < 1``,
            an empty or invalid task selection, or a task with fewer initial
            states than ``num_trials_per_task``.
    """
    np.random.seed(args.seed)
    # Validate config before touching the benchmark registry or the server, so an
    # unknown suite produces this message instead of a KeyError from the registry.
    if args.task_suite_name not in _MAX_STEPS:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")
    if args.replan_steps < 1:
        # With an empty action plan every episode would fail on popleft().
        raise ValueError(f"replan_steps must be >= 1, got {args.replan_steps}")
    max_steps = _MAX_STEPS[args.task_suite_name]

    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    task_ids = _parse_task_ids(args.task_ids, task_suite.n_tasks)
    if not task_ids:
        raise ValueError(f"No tasks selected for {args.task_suite_name} (task_ids='{args.task_ids}')")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
    logging.info(
        f"[rank={args.rank}] suite={args.task_suite_name} task_ids={task_ids} "
        f"n_trials={args.num_trials_per_task} speed={args.speed:g} max_steps={max_steps}"
    )
    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
    speed_label = _speed_label(args.speed)
    episodes: list[dict] = []
    t_start = time.time()
    for task_id in tqdm.tqdm(task_ids, desc=f"rank={args.rank}/{args.task_suite_name}"):
        task = task_suite.get_task(task_id)
        initial_states = task_suite.get_task_init_states(task_id)
        # Fail before running the task rather than with an IndexError part-way
        # through it (assumes initial_states supports len(), as it supports indexing).
        if args.num_trials_per_task > len(initial_states):
            raise ValueError(
                f"num_trials_per_task={args.num_trials_per_task} exceeds the "
                f"{len(initial_states)} initial states of task {task_id}"
            )
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
        try:
            for episode_idx in range(args.num_trials_per_task):
                success, steps, replay_images = _run_episode(
                    env,
                    client,
                    initial_states[episode_idx],
                    task_description,
                    max_steps,
                    args,
                    speed_label,
                    task_id,
                    episode_idx,
                )
                episodes.append(
                    {
                        "task_id": int(task_id),
                        "task_description": str(task_description),
                        "episode_idx": int(episode_idx),
                        "success": bool(success),
                        "steps": int(steps),
                        "max_steps": int(max_steps),
                        "wait_steps": int(args.num_steps_wait),
                        "speed": float(args.speed),
                        "suite": str(args.task_suite_name),
                    }
                )
                if args.save_videos:
                    _save_rollout_video(args, speed_label, task_id, episode_idx, success, replay_images)
        finally:
            # Release the env's offscreen renderer; previously one env leaked per task.
            env.close()

    elapsed = time.time() - t_start
    summary_str = _summary_string(args.speed, args.task_suite_name, args.rank, episodes)
    print(summary_str)
    print(f"[rank={args.rank}] elapsed={elapsed:.1f}s")
    if args.results_json:
        _write_results_json(args, task_ids, speed_label, episodes, elapsed, summary_str)
    return 0
if __name__ == "__main__":
    # Script entry point: configure timestamped INFO logging, parse Args from the
    # command line, run the eval, and use its return value as the exit status.
    logging.basicConfig(
        format="%(asctime)s %(message)s",
        datefmt="%H:%M:%S",
        level=logging.INFO,
    )
    cli_args = tyro.cli(Args)
    exit_status = eval_libero(cli_args)
    raise SystemExit(exit_status)