Instructions to use lcccluck/mujoco-lerobot-train with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use lcccluck/mujoco-lerobot-train with LeRobot:
- Notebooks
- Google Colab
- Kaggle
from __future__ import annotations

import json
import os
import shlex
import shutil
import subprocess
import sys
import time
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path

import mujoco
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
# The interactive passive viewer is optional and its import is guarded; importing it
# may fail in some environments (presumably missing GUI/GL support — not visible here).
# Keep the failure instead of discarding it so it can be inspected when a viewer is
# requested; headless runs are unaffected.
VIEWER_IMPORT_ERROR: Exception | None = None
try:
    import mujoco.viewer  # noqa: F401  -- makes `mujoco.viewer` available as an attribute
except Exception as exc:  # deliberately broad: any failure only disables the viewer
    VIEWER_IMPORT_ERROR = exc
# Feature names for the 6-D joint vector, in the order used by both
# observation.state and action.
JOINT_NAMES = [
    f"{joint}.pos"
    for joint in (
        "shoulder_pan",
        "shoulder_lift",
        "elbow_flex",
        "wrist_flex",
        "wrist_roll",
        "gripper",
    )
]
# Start pose: five arm joint angles in degrees, then the gripper opening on a 0-100 scale.
HOME_STATE = np.array([0.0, -35.0, 75.0, -40.0, 0.0, 70.0], dtype=np.float64)
# Height (m) of a resting cube's centre; equals the cube's half-size in the MJCF scene.
TABLE_Z = 0.025
# Offset (m) from the TCP site to the carried cube's centre while it is attached.
ATTACH_OFFSET = np.array([0.0, 0.0, -0.055], dtype=np.float64)
# Gripper-command hysteresis: grasp below GRIP_CLOSE_THRESHOLD, release above GRIP_OPEN_THRESHOLD.
GRIP_CLOSE_THRESHOLD = 25.0
GRIP_OPEN_THRESHOLD = 40.0
# Max TCP distance (m) from the carrying position that still allows a grasp.
GRASP_DISTANCE_THRESHOLD = 0.045
# XY tolerance (m) between a released cube and the goal to count as success.
SUCCESS_XY_THRESHOLD = 0.05
@dataclass
class DatasetConfig:
    """Settings for scripted dataset collection (the ``dataset`` section of the config).

    Must be a dataclass: it is constructed with keyword arguments, which a plain
    annotated class does not accept.
    """

    repo_id: str  # dataset repo id passed to LeRobotDataset.create
    root: Path  # local dataset directory
    episodes: int  # number of scripted episodes to record
    fps: int  # dataset frame rate; also used to pace the viewer
    width: int  # rendered image width in pixels
    height: int  # rendered image height in pixels
    image_key: str  # suffix of the "observation.images.<key>" feature
    task: str  # task string stored with every frame
    robot_type: str  # robot type recorded in the dataset metadata
    base_height: float  # z (m) of the robot base body in the MJCF scene
    seed: int  # RNG seed for cube/goal sampling
    overwrite: bool  # delete an existing root instead of raising FileExistsError
    show_viewer: bool  # open the MuJoCo passive viewer while collecting
@dataclass
class VizConfig:
    """Settings for dataset visualisation (the ``viz`` section of the config)."""

    episode_index: int  # episode passed to lerobot-dataset-viz --episode-index
@dataclass
class TrainConfig:
    """Settings forwarded to the ``lerobot-train`` CLI (the ``train`` section of the config)."""

    output_dir: Path  # training output dir; checkpoints are looked up under it
    job_name: str
    device: str  # passed as --policy.device
    batch_size: int
    num_workers: int
    steps: int
    save_freq: int
    log_freq: int
    eval_freq: int
    seed: int
    wandb_enable: bool
    resume: bool
@dataclass
class EvalConfig:
    """Settings for policy evaluation (the ``eval`` section of the config)."""

    policy_root: Path | None  # pretrained_model dir; None -> pick the latest checkpoint
    episodes: int  # number of evaluation rollouts
    max_steps: int  # control steps per rollout
    device: str  # "auto", "cpu", "cuda", "mps", ...
    base_height: float  # z (m) of the robot base body in the MJCF scene
    seed: int  # RNG seed for cube/goal sampling
    show_viewer: bool  # open the MuJoCo passive viewer during rollouts
    step_sleep: float  # optional per-step delay in seconds
    task: str  # task string fed to the policy
    robot_type: str  # robot type fed to the policy
@dataclass
class AppConfig:
    """Whole application configuration: one typed object per config-file section."""

    config_path: Path  # resolved path of the config file that was loaded
    dataset: DatasetConfig
    viz: VizConfig
    train: TrainConfig
    eval: EvalConfig
def resolve_path(base_dir: Path, value: str) -> Path:
    """Turn a config path string into a ``Path``.

    ``~`` is expanded first. Absolute results are returned unchanged; relative
    ones are joined onto *base_dir* and resolved.
    """
    candidate = Path(value).expanduser()
    return candidate if candidate.is_absolute() else (base_dir / candidate).resolve()
def _config_bool(value: object, key: str) -> bool:
    """Interpret a config value as a boolean.

    ``bool()`` alone treats every non-empty string as true, so an entry written
    as ``"false"`` would enable the option (for ``dataset.overwrite`` that means
    deleting an existing dataset). JSON booleans, ``null`` and numbers keep their
    usual truthiness; strings are parsed explicitly.

    Raises:
        ValueError: if *value* cannot be interpreted as a boolean.
    """
    if value is None or isinstance(value, (bool, int, float)):
        return bool(value)
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1", "yes", "on"):
            return True
        if text in ("false", "0", "no", "off", ""):
            return False
    raise ValueError(f"config key {key!r} must be a boolean, got {value!r}")


def load_config(config_path: str | Path | None = None) -> AppConfig:
    """Load the JSON config file into an :class:`AppConfig`.

    Args:
        config_path: path to the config file; defaults to ``config.json`` next to
            this script.

    Relative paths inside the file (``dataset.root``, ``train.output_dir``,
    ``eval.policy_root``) are resolved against the config file's directory. An
    empty, null or absent ``eval.policy_root`` yields ``None``.

    Raises:
        KeyError: if a required key is missing.
        ValueError / TypeError: if a numeric or boolean field cannot be parsed.
    """
    path = Path(config_path or Path(__file__).with_name("config.json")).expanduser().resolve()
    # Explicit encoding: the platform default is not UTF-8 everywhere (e.g. Windows).
    raw = json.loads(path.read_text(encoding="utf-8"))
    base_dir = path.parent
    ds = raw["dataset"]
    tr = raw["train"]
    ev = raw["eval"]

    dataset = DatasetConfig(
        repo_id=ds["repo_id"],
        root=resolve_path(base_dir, ds["root"]),
        episodes=int(ds["episodes"]),
        fps=int(ds["fps"]),
        width=int(ds["width"]),
        height=int(ds["height"]),
        image_key=ds["image_key"],
        task=ds["task"],
        robot_type=ds["robot_type"],
        base_height=float(ds["base_height"]),
        seed=int(ds["seed"]),
        overwrite=_config_bool(ds["overwrite"], "dataset.overwrite"),
        show_viewer=_config_bool(ds["show_viewer"], "dataset.show_viewer"),
    )
    viz = VizConfig(episode_index=int(raw["viz"]["episode_index"]))
    train = TrainConfig(
        output_dir=resolve_path(base_dir, tr["output_dir"]),
        job_name=tr["job_name"],
        device=tr["device"],
        batch_size=int(tr["batch_size"]),
        num_workers=int(tr["num_workers"]),
        steps=int(tr["steps"]),
        save_freq=int(tr["save_freq"]),
        log_freq=int(tr["log_freq"]),
        eval_freq=int(tr["eval_freq"]),
        seed=int(tr["seed"]),
        wandb_enable=_config_bool(tr["wandb_enable"], "train.wandb_enable"),
        resume=_config_bool(tr["resume"], "train.resume"),
    )
    policy_root_value = ev.get("policy_root")
    eval_cfg = EvalConfig(
        policy_root=resolve_path(base_dir, policy_root_value) if policy_root_value else None,
        episodes=int(ev["episodes"]),
        max_steps=int(ev["max_steps"]),
        device=ev["device"],
        base_height=float(ev["base_height"]),
        seed=int(ev["seed"]),
        show_viewer=_config_bool(ev["show_viewer"], "eval.show_viewer"),
        step_sleep=float(ev["step_sleep"]),
        task=ev["task"],
        robot_type=ev["robot_type"],
    )
    return AppConfig(config_path=path, dataset=dataset, viz=viz, train=train, eval=eval_cfg)
def lerobot_env() -> dict[str, str]:
    """Return an independent snapshot of the current environment for LeRobot subprocesses."""
    snapshot = dict(os.environ)
    return snapshot
def maybe_reexec_with_mjpython(show_viewer: bool, env_key: str) -> None:
    """Re-run the current command under ``mjpython`` when the macOS viewer is needed.

    On macOS, MuJoCo's passive viewer has to be launched from the ``mjpython``
    launcher, so the process replaces itself with ``mjpython <same argv>``. It is a
    no-op when no viewer is requested, off macOS, when already running under
    ``mjpython``, or when *env_key* is ``"1"`` (set on the re-executed process to
    prevent an exec loop).

    Raises:
        RuntimeError: if ``mjpython`` is not found next to the current interpreter.
    """
    if not show_viewer or sys.platform != "darwin":
        return
    if os.environ.get(env_key) == "1":
        return
    interpreter = Path(sys.executable)
    if interpreter.name == "mjpython":
        return
    mjpython_path = interpreter.with_name("mjpython")
    if not mjpython_path.exists():
        raise RuntimeError(f"`mjpython` not found next to {sys.executable}")
    env = os.environ.copy()
    env[env_key] = "1"
    # os.exec* replaces the process without flushing Python's stdio buffers, so any
    # pending output would be lost (train_policy flushes before exec for the same reason).
    sys.stdout.flush()
    sys.stderr.flush()
    os.execve(str(mjpython_path), [str(mjpython_path), *sys.argv], env)
def bool_cli(value: bool) -> str:
    """Render the truthiness of *value* as the lowercase ``true``/``false`` CLI literal."""
    return str(bool(value)).lower()
def parse_device(device: str) -> torch.device:
    """Map a config device string to a ``torch.device``.

    ``"auto"`` picks CUDA if available, else MPS if available, else CPU; any other
    string is handed to ``torch.device`` unchanged.
    """
    if device != "auto":
        return torch.device(device)
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)
def build_xml(base_height: float) -> str:
    """Return the MJCF source for the pick-and-place scene.

    The scene holds:
    - a floor plane;
    - a translucent ``goal_site`` box marker;
    - a free-floating 5 cm ``cube``;
    - a 5-joint arm (pan, lift, elbow, wrist flex, wrist roll) with two slide
      fingers and a ``tcp`` site.

    Only the z position of the arm's ``base`` body is parameterised, formatted
    with three decimals. Angles are in radians (``<compiler angle="radian"/>``).
    """
    base_z = f"{base_height:.3f}"
    return f"""
<mujoco model="mujoco_pickplace_minimal">
<compiler angle="radian"/>
<option timestep="0.01" gravity="0 0 -9.81"/>
<visual>
<headlight ambient="0.7 0.7 0.7" diffuse="0.7 0.7 0.7" specular="0.1 0.1 0.1"/>
<global offwidth="1024" offheight="1024"/>
</visual>
<asset>
<texture name="grid" type="2d" builtin="checker" rgb1="0.28 0.31 0.35" rgb2="0.18 0.20 0.24" width="256" height="256"/>
<material name="floor" texture="grid" texrepeat="6 6" reflectance="0.05"/>
<material name="arm" rgba="0.88 0.66 0.22 1"/>
<material name="metal" rgba="0.24 0.24 0.28 1"/>
<material name="finger" rgba="0.75 0.77 0.80 1"/>
<material name="cube" rgba="0.91 0.32 0.20 1"/>
</asset>
<worldbody>
<light pos="0 0 2.4" dir="0 0 -1"/>
<geom name="floor" type="plane" size="2 2 0.05" material="floor"/>
<site name="goal_site" pos="0.18 0.18 0.025" size="0.03 0.03 0.003" type="box" rgba="0.2 0.8 0.3 0.55"/>
<body name="cube" pos="0.22 -0.10 0.025">
<freejoint name="cube_freejoint"/>
<geom type="box" size="0.025 0.025 0.025" material="cube" mass="0.05"/>
</body>
<body name="base" pos="0 0 {base_z}">
<geom type="cylinder" size="0.09 0.05" material="metal"/>
<body name="shoulder_pan_link" pos="0 0 0.05">
<joint name="shoulder_pan" type="hinge" axis="0 0 1" range="-3.14 3.14"/>
<geom type="capsule" fromto="0 0 0 0 0 0.10" size="0.035" material="arm"/>
<body name="shoulder_lift_link" pos="0 0 0.10">
<joint name="shoulder_lift" type="hinge" axis="0 1 0" range="-2.7 2.7"/>
<geom type="capsule" fromto="0 0 0 0.18 0 0" size="0.03" material="arm"/>
<body name="elbow_flex_link" pos="0.18 0 0">
<joint name="elbow_flex" type="hinge" axis="0 1 0" range="-2.7 2.7"/>
<geom type="capsule" fromto="0 0 0 0.16 0 0" size="0.026" material="arm"/>
<body name="wrist_flex_link" pos="0.16 0 0">
<joint name="wrist_flex" type="hinge" axis="0 1 0" range="-2.7 2.7"/>
<geom type="capsule" fromto="0 0 0 0.10 0 0" size="0.021" material="arm"/>
<body name="wrist_roll_link" pos="0.10 0 0">
<joint name="wrist_roll" type="hinge" axis="1 0 0" range="-3.14 3.14"/>
<geom type="capsule" fromto="0 0 0 0.06 0 0" size="0.017" material="arm"/>
<body name="tcp_body" pos="0.06 0 0">
<geom type="box" size="0.02 0.014 0.014" material="metal"/>
<site name="tcp" pos="0.035 0 0" size="0.008" rgba="0.2 0.85 0.35 1"/>
<body name="left_finger" pos="0.012 0.015 0">
<joint name="left_finger_slide" type="slide" axis="0 1 0" range="0 0.025"/>
<geom type="box" pos="0.025 0.012 0" size="0.025 0.004 0.006" material="finger"/>
</body>
<body name="right_finger" pos="0.012 -0.015 0">
<joint name="right_finger_slide" type="slide" axis="0 -1 0" range="0 0.025"/>
<geom type="box" pos="0.025 -0.012 0" size="0.025 0.004 0.006" material="finger"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</worldbody>
</mujoco>
"""
def prepare_model(base_height: float) -> tuple[mujoco.MjModel, mujoco.MjData, dict[str, int]]:
    """Compile the scene and look up the MuJoCo ids of every named element used later.

    Returns:
        ``(model, data, ids)`` where *ids* maps short keys (``"tcp"``,
        ``"cube_joint"``, ``"shoulder_pan"``, ...) to site/joint ids.

    Raises:
        ValueError: if a name is missing from the compiled model. ``mj_name2id``
            returns -1 in that case, and -1 would otherwise silently index the
            *last* joint/site in later array lookups.
    """
    model = mujoco.MjModel.from_xml_string(build_xml(base_height))
    data = mujoco.MjData(model)
    site = mujoco.mjtObj.mjOBJ_SITE
    joint = mujoco.mjtObj.mjOBJ_JOINT
    lookups = {
        "tcp": (site, "tcp"),
        "goal_site": (site, "goal_site"),
        "cube_joint": (joint, "cube_freejoint"),
        "shoulder_pan": (joint, "shoulder_pan"),
        "shoulder_lift": (joint, "shoulder_lift"),
        "elbow_flex": (joint, "elbow_flex"),
        "wrist_flex": (joint, "wrist_flex"),
        "wrist_roll": (joint, "wrist_roll"),
        "left_finger": (joint, "left_finger_slide"),
        "right_finger": (joint, "right_finger_slide"),
    }
    ids: dict[str, int] = {}
    missing: list[str] = []
    for key, (obj_type, name) in lookups.items():
        obj_id = mujoco.mj_name2id(model, obj_type, name)
        if obj_id < 0:
            missing.append(name)
        ids[key] = obj_id
    if missing:
        raise ValueError(f"MuJoCo model is missing named elements: {missing}")
    return model, data, ids
def set_arm_state(model: mujoco.MjModel, data: mujoco.MjData, ids: dict[str, int], joint_state: np.ndarray) -> None:
    """Write a 6-D joint command into ``data.qpos`` (no forward pass).

    Entries 0-4 are arm joint angles in degrees (converted to radians). Entry 5 is
    the gripper opening on a 0-100 scale: it is clipped and mapped linearly to a
    finger slide of 0.003-0.025 m, applied to both fingers.
    """
    arm_keys = ("shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll")
    for position, key in enumerate(arm_keys):
        data.qpos[model.jnt_qposadr[ids[key]]] = np.deg2rad(joint_state[position])
    opening = float(np.clip(joint_state[5] / 100.0, 0.0, 1.0))
    slide = 0.003 + 0.022 * opening
    for key in ("left_finger", "right_finger"):
        data.qpos[model.jnt_qposadr[ids[key]]] = slide
def set_cube_pose(model: mujoco.MjModel, data: mujoco.MjData, ids: dict[str, int], pos: np.ndarray) -> None:
    """Place the cube's free joint at ``pos[:3]`` with an identity quaternion (w, x, y, z)."""
    start = model.jnt_qposadr[ids["cube_joint"]]
    pose = np.empty(7, dtype=np.float64)
    pose[:3] = pos[0], pos[1], pos[2]
    pose[3:] = (1.0, 0.0, 0.0, 0.0)
    data.qpos[start : start + 7] = pose
def render_frame(renderer: mujoco.Renderer, data: mujoco.MjData, base_height: float) -> np.ndarray:
    """Render the current scene from the fixed free camera used for dataset and eval images.

    The camera looks at (0.18, 0, base_height + 0.08) from 1.0 m away, at azimuth
    140 degrees and elevation -45 degrees.
    """
    cam = mujoco.MjvCamera()
    cam.type = mujoco.mjtCamera.mjCAMERA_FREE
    cam.lookat[:] = np.array([0.18, 0.0, base_height + 0.08], dtype=np.float64)
    cam.distance, cam.azimuth, cam.elevation = 1.0, 140.0, -45.0
    renderer.update_scene(data, camera=cam)
    frame = renderer.render()
    return frame
def configure_viewer_camera(viewer, base_height: float) -> None:
    """Point the interactive viewer's camera at the workspace in front of the arm."""
    cam = viewer.cam
    cam.lookat[:] = np.array([0.18, 0.0, base_height + 0.10], dtype=np.float64)
    cam.distance = 1.15
    cam.azimuth = 135.0
    cam.elevation = -32.0
def sample_positions(rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Sample a cube start and a goal on the table whose XY positions are >= 0.14 m apart.

    Rejection sampling: each attempt draws cube x, cube y, goal x, goal y (in that
    order) from fixed uniform ranges, and repeats until the separation holds.
    """
    min_separation = 0.14
    while True:
        cube_xy = (rng.uniform(0.18, 0.28), rng.uniform(-0.12, 0.12))
        goal_xy = (rng.uniform(0.10, 0.25), rng.uniform(-0.18, 0.18))
        obj = np.array([*cube_xy, TABLE_Z], dtype=np.float64)
        goal = np.array([*goal_xy, TABLE_Z], dtype=np.float64)
        if np.linalg.norm(obj[:2] - goal[:2]) >= min_separation:
            return obj, goal
def solve_ik(target_xyz: np.ndarray, shoulder_height: float = 0.50) -> np.ndarray:
    """Analytic inverse kinematics for the arm with the tool pointing straight down.

    Args:
        target_xyz: desired world position (m) of the ``tcp`` site.
        shoulder_height: world z (m) of the ``shoulder_lift`` joint. In the MJCF
            scene this joint sits 0.05 + 0.10 = 0.15 m above the base body, so the
            default 0.50 corresponds to ``base_height = 0.35``. Pass
            ``base_height + 0.15`` for other base heights; previously this value was
            hard-coded and silently ignored the configured base height.

    Returns:
        Five joint angles in degrees: pan, lift, elbow, wrist flex, and wrist roll
        (always 0). Out-of-reach targets are not rejected: the elbow cosine is
        clipped, giving a fully straight or fully folded elbow.
    """
    link_1 = 0.18  # shoulder_lift -> elbow_flex (m)
    link_2 = 0.16  # elbow_flex -> wrist_flex (m)
    tool = 0.195  # wrist_flex -> tcp site: 0.10 + 0.06 + 0.035 (m)
    desired_pitch = -np.pi / 2.0  # tool axis pointing straight down
    shoulder_origin = np.array([0.0, 0.0, shoulder_height], dtype=np.float64)
    rel = np.asarray(target_xyz, dtype=np.float64) - shoulder_origin
    yaw = np.arctan2(rel[1], rel[0])
    radial = np.hypot(rel[0], rel[1])
    # Planar 2-link problem for the wrist_flex joint, which sits one tool length
    # back from the target along the desired pitch.
    wrist_x = radial - tool * np.cos(desired_pitch)
    wrist_z = rel[2] - tool * np.sin(desired_pitch)
    dist_sq = wrist_x**2 + wrist_z**2
    cos_elbow = np.clip((dist_sq - link_1**2 - link_2**2) / (2.0 * link_1 * link_2), -1.0, 1.0)
    elbow = np.arccos(cos_elbow)
    shoulder = np.arctan2(wrist_z, wrist_x) - np.arctan2(link_2 * np.sin(elbow), link_1 + link_2 * np.cos(elbow))
    wrist = desired_pitch - shoulder - elbow
    # The pitch joints rotate about +y, where a positive angle tips the link toward
    # -z; the planar solution above uses the z-up convention, hence the sign flips.
    return np.rad2deg(np.array([yaw, -shoulder, -elbow, -wrist, 0.0], dtype=np.float64))
def interpolate(start: np.ndarray, end: np.ndarray, steps: int) -> list[np.ndarray]:
    """Return *steps* evenly spaced float64 points from *start* to *end*, both inclusive.

    With ``steps <= 1`` the result is just ``[end]``.
    """
    if steps > 1:
        delta = end - start
        last_index = steps - 1
        return [(start + delta * (index / last_index)).astype(np.float64) for index in range(steps)]
    return [end.astype(np.float64)]
def make_episode_trajectory(obj: np.ndarray, goal: np.ndarray) -> list[tuple[np.ndarray, bool, np.ndarray]]:
    """Script one pick-and-place episode as ``(joint_state, attached, cube_anchor)`` frames.

    The waypoints are IK targets at fixed heights (m) above the cube (*obj*) or the
    *goal*; the gripper command is 70 (open) or 8 (closed):
    home -> above cube (+0.12) -> grasp (+0.045) -> close -> lift (+0.18)
    -> above goal (+0.16) -> place (+0.05) -> open -> retreat (+0.18).

    ``attached`` marks frames during which the cube is carried. It stays True while
    the gripper opens at the goal. ``cube_anchor`` is the cube's resting position
    for that segment.
    """
    open_grip = 70.0
    closed_grip = 8.0

    def waypoint(anchor: np.ndarray, height: float, grip: float) -> np.ndarray:
        return np.concatenate([solve_ik(anchor + np.array([0.0, 0.0, height])), [grip]])

    above_obj = waypoint(obj, 0.12, open_grip)
    grasp = waypoint(obj, 0.045, open_grip)
    close = grasp.copy()
    close[-1] = closed_grip
    lift = waypoint(obj, 0.18, closed_grip)
    above_goal = waypoint(goal, 0.16, closed_grip)
    place = waypoint(goal, 0.05, closed_grip)
    open_at_goal = place.copy()
    open_at_goal[-1] = open_grip
    retreat = waypoint(goal, 0.18, open_grip)

    # (from, to, frame count, cube attached?, cube resting anchor)
    segments = (
        (HOME_STATE, above_obj, 18, False, obj),
        (above_obj, grasp, 12, False, obj),
        (grasp, close, 8, False, obj),
        (close, lift, 14, True, obj),
        (lift, above_goal, 20, True, obj),
        (above_goal, place, 12, True, obj),
        (place, open_at_goal, 8, True, goal),
        (open_at_goal, retreat, 12, False, goal),
    )
    return [
        (state, attached, anchor.copy())
        for begin, finish, count, attached, anchor in segments
        for state in interpolate(begin, finish, count)
    ]
def dataset_features(image_key: str, width: int, height: int) -> dict:
    """Build the LeRobot feature spec.

    The spec holds a 6-D float32 ``action`` and ``observation.state`` (named by
    JOINT_NAMES) plus one (height, width, 3) image stored as
    ``observation.images.<image_key>``.
    """
    joint_vector = {"dtype": "float32", "shape": (6,), "names": JOINT_NAMES}
    image_spec = {
        "dtype": "image",
        "shape": (height, width, 3),
        "names": ["height", "width", "channels"],
    }
    return {
        "action": dict(joint_vector),
        "observation.state": dict(joint_vector),
        f"observation.images.{image_key}": image_spec,
    }
def collect_dataset(cfg: AppConfig) -> None:
    """Record ``cfg.dataset.episodes`` scripted pick-and-place episodes as a LeRobot dataset.

    Each stored frame pairs ``observation.state`` (the joint command applied on
    the previous frame) and the image rendered from that same pre-action scene
    with ``action`` (the command applied on this frame). Rendering before the
    action is applied keeps image and state consistent, matching how
    ``eval_policy`` builds observations. Rendering after it would let the image
    already show the commanded pose.

    The cube is moved kinematically:
    - it rests at its anchor before the first attached frame;
    - it follows the TCP (plus ATTACH_OFFSET) while attached;
    - it sits at the goal afterwards.

    Raises:
        FileExistsError: if the dataset root exists and ``overwrite`` is false.
    """
    ds_cfg = cfg.dataset
    if ds_cfg.root.exists():
        if not ds_cfg.overwrite:
            raise FileExistsError(f"Dataset root already exists: {cfg.dataset.root}")
        shutil.rmtree(ds_cfg.root)
    image_feature = f"observation.images.{ds_cfg.image_key}"
    rng = np.random.default_rng(ds_cfg.seed)
    dt = 1.0 / ds_cfg.fps

    # ExitStack runs every cleanup step even if setup or another cleanup raises:
    # the viewer is closed, the renderer released, and the dataset finalized.
    with ExitStack() as stack:
        dataset = LeRobotDataset.create(
            repo_id=ds_cfg.repo_id,
            root=ds_cfg.root,
            fps=ds_cfg.fps,
            features=dataset_features(ds_cfg.image_key, ds_cfg.width, ds_cfg.height),
            robot_type=ds_cfg.robot_type,
            use_videos=False,
        )
        stack.callback(dataset.finalize)
        model, data, ids = prepare_model(ds_cfg.base_height)
        renderer = mujoco.Renderer(model, height=ds_cfg.height, width=ds_cfg.width)
        stack.callback(renderer.close)
        viewer = None
        if ds_cfg.show_viewer:
            viewer = stack.enter_context(mujoco.viewer.launch_passive(model, data))
            configure_viewer_camera(viewer, ds_cfg.base_height)

        for _ in range(ds_cfg.episodes):
            obj_pos, goal_pos = sample_positions(rng)
            model.site_pos[ids["goal_site"]] = goal_pos
            frames = make_episode_trajectory(obj_pos, goal_pos)

            # Pose the scene at the episode start so the first image matches the first state.
            first_state, _, first_anchor = frames[0]
            prev_state = first_state.copy()
            set_arm_state(model, data, ids, prev_state)
            set_cube_pose(model, data, ids, first_anchor)
            mujoco.mj_forward(model, data)

            attached_once = False
            for state, attached, cube_anchor in frames:
                start = time.perf_counter()
                # Observation: the scene before this frame's action is applied.
                image = render_frame(renderer, data, ds_cfg.base_height)
                dataset.add_frame(
                    {
                        "observation.state": prev_state.astype(np.float32),
                        "action": state.astype(np.float32),
                        image_feature: image,
                        "task": ds_cfg.task,
                    }
                )

                # Apply the action: move the arm, then place the cube relative to it.
                attached_once = attached_once or attached
                set_arm_state(model, data, ids, state)
                mujoco.mj_forward(model, data)
                if attached:
                    cube_pos = data.site_xpos[ids["tcp"]].copy() + ATTACH_OFFSET
                elif attached_once:
                    cube_pos = goal_pos.copy()
                else:
                    cube_pos = cube_anchor.copy()
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)
                prev_state = state.copy()

                if viewer is not None:
                    viewer.sync()
                    if not viewer.is_running():
                        viewer = None  # window closed: keep collecting headless
                    else:
                        # Pace playback to the dataset fps while the viewer is open.
                        elapsed = time.perf_counter() - start
                        if elapsed < dt:
                            time.sleep(dt - elapsed)
            dataset.save_episode()
def run_subprocess(cmd: list[str]) -> None:
    """Run *cmd* (an argv list, no shell) in the LeRobot environment.

    Raises ``subprocess.CalledProcessError`` on a non-zero exit.
    """
    environment = lerobot_env()
    subprocess.run(cmd, env=environment, check=True)
def viz_dataset(cfg: AppConfig) -> None:
    """Open episode ``cfg.viz.episode_index`` of the local dataset with ``lerobot-dataset-viz``."""
    options = (
        ("--repo-id", cfg.dataset.repo_id),
        ("--root", str(cfg.dataset.root)),
        ("--episode-index", str(cfg.viz.episode_index)),
    )
    cmd = ["lerobot-dataset-viz"]
    for flag, value in options:
        cmd.extend((flag, value))
    run_subprocess(cmd)
def train_policy(cfg: AppConfig) -> None:
    """Replace this process with ``lerobot-train`` training an ACT policy on the local dataset.

    The full command line is printed first. On success this function does not
    return.
    """
    train = cfg.train
    options = {
        "dataset.repo_id": cfg.dataset.repo_id,
        "dataset.root": cfg.dataset.root,
        "policy.type": "act",
        "policy.push_to_hub": "false",
        "output_dir": train.output_dir,
        "job_name": train.job_name,
        "policy.device": train.device,
        "batch_size": train.batch_size,
        "num_workers": train.num_workers,
        "steps": train.steps,
        "save_freq": train.save_freq,
        "log_freq": train.log_freq,
        "eval_freq": train.eval_freq,
        "seed": train.seed,
        "wandb.enable": bool_cli(train.wandb_enable),
        "resume": bool_cli(train.resume),
    }
    cmd = ["lerobot-train", *(f"--{name}={value}" for name, value in options.items())]
    print("command=" + shlex.join(cmd))
    # os.exec* does not flush Python's stdio buffers; flush so the printed command survives.
    sys.stdout.flush()
    sys.stderr.flush()
    os.execvpe(cmd[0], cmd, lerobot_env())
def _checkpoint_sort_key(pretrained_model: Path) -> tuple[int, int, str]:
    """Sort key for ``checkpoints/<name>/pretrained_model`` paths.

    Numeric step names sort by value; other names (e.g. ``last``) sort after them,
    alphabetically.
    """
    name = pretrained_model.parent.name
    if name.isdigit():
        return (0, int(name), name)
    return (1, 0, name)


def resolve_policy_root(cfg: AppConfig) -> Path:
    """Return the policy directory to evaluate.

    Uses ``cfg.eval.policy_root`` when set. Otherwise it picks the newest
    ``checkpoints/<step>/pretrained_model`` under ``cfg.train.output_dir``. Step
    directories are compared numerically, so ``1000`` beats ``500`` even without
    zero padding, which a plain lexicographic sort gets wrong. Non-numeric names
    such as ``last`` rank after all numeric ones.

    Raises:
        FileNotFoundError: if no checkpoint exists.
    """
    if cfg.eval.policy_root is not None:
        return cfg.eval.policy_root
    candidates = sorted(
        cfg.train.output_dir.glob("checkpoints/*/pretrained_model"),
        key=_checkpoint_sort_key,
    )
    if not candidates:
        raise FileNotFoundError(f"No checkpoint found under {cfg.train.output_dir / 'checkpoints'}")
    return candidates[-1]
def load_policy(policy_root: Path, device: torch.device):
    """Load a pretrained LeRobot policy and its pre/post-processors from *policy_root*.

    Returns ``(policy, preprocessor, postprocessor, policy_cfg)``. Only local files
    are used.
    """
    # NOTE(review): device.type drops any index (torch.device("cuda:1").type == "cuda"),
    # while prepare_policy_input moves tensors to the full device; confirm whether
    # indexed devices need to be supported.
    device_name = device.type
    policy_cfg = PreTrainedConfig.from_pretrained(policy_root, cli_overrides=[f"--device={device_name}"])
    policy_class = get_policy_class(policy_cfg.type)
    policy = policy_class.from_pretrained(policy_root, config=policy_cfg, local_files_only=True)
    overrides = {"device_processor": {"device": device_name}}
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg,
        pretrained_path=str(policy_root),
        preprocessor_overrides=overrides,
        postprocessor_overrides=overrides,
    )
    return policy, preprocessor, postprocessor, policy_cfg
def get_policy_observation_keys(policy_cfg: PreTrainedConfig) -> tuple[str, str]:
    """Return ``(state_key, image_key)``: the first STATE and first VISUAL input feature keys.

    Raises:
        ValueError: if the policy lacks either kind of input.
    """
    features = policy_cfg.input_features
    state_key = next((key for key, feat in features.items() if feat.type is FeatureType.STATE), None)
    image_key = next((key for key, feat in features.items() if feat.type is FeatureType.VISUAL), None)
    if state_key is None or image_key is None:
        raise ValueError("Policy must contain one state input and one image input.")
    return state_key, image_key
def prepare_policy_input(
    state: np.ndarray,
    image: np.ndarray,
    state_key: str,
    image_key: str,
    task: str,
    robot_type: str,
    device: torch.device,
) -> dict:
    """Build a batch-of-one observation dict for the policy preprocessor.

    The state becomes a (1, D) float32 tensor. The image is taken as HWC with
    0-255 pixel values and becomes a contiguous (1, C, H, W) float32 tensor in
    [0, 1]. Both are moved to *device*; *task* and *robot_type* pass through
    unchanged.
    """
    state_tensor = torch.from_numpy(state.astype(np.float32)).unsqueeze(0)
    pixels = image.astype(np.float32) / 255.0
    chw = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()
    image_tensor = chw.unsqueeze(0)
    return {
        state_key: state_tensor.to(device),
        image_key: image_tensor.to(device),
        "task": task,
        "robot_type": robot_type,
    }
def action_limits(model: mujoco.MjModel, ids: dict[str, int]) -> tuple[np.ndarray, np.ndarray]:
    """Return per-dimension ``(lower, upper)`` bounds for the 6-D action.

    The five arm joints use their MJCF ranges converted to degrees; the gripper
    uses [0, 100].
    """
    arm_joints = ("shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll")
    ranges_rad = np.array([model.jnt_range[ids[name]] for name in arm_joints], dtype=np.float64)
    ranges_deg = np.rad2deg(ranges_rad)
    lower = np.append(ranges_deg[:, 0], 0.0)
    upper = np.append(ranges_deg[:, 1], 100.0)
    return lower, upper
def update_cube_attachment(
    attached: bool,
    cube_pos: np.ndarray,
    tcp_pos: np.ndarray,
    action: np.ndarray,
    goal_pos: np.ndarray,
) -> tuple[bool, np.ndarray]:
    """Advance the kinematic grasp model by one control step; return ``(attached, cube_pos)``.

    While attached:
    - the cube rides at ``tcp_pos + ATTACH_OFFSET``;
    - once the gripper command (``action[5]``) exceeds GRIP_OPEN_THRESHOLD, the
      cube is dropped onto the table below the TCP, snapping onto the goal when
      within SUCCESS_XY_THRESHOLD of it in XY.

    While free:
    - the cube is grasped when the command is below GRIP_CLOSE_THRESHOLD and the
      TCP is within GRASP_DISTANCE_THRESHOLD of the position from which it would
      carry the cube.
    """
    grip = float(action[5])
    if attached:
        if grip > GRIP_OPEN_THRESHOLD:
            drop = np.array([tcp_pos[0], tcp_pos[1], TABLE_Z], dtype=np.float64)
            if np.linalg.norm(drop[:2] - goal_pos[:2]) < SUCCESS_XY_THRESHOLD:
                drop[:2] = goal_pos[:2]
            return False, drop
        return True, tcp_pos + ATTACH_OFFSET
    carry_tcp = cube_pos - ATTACH_OFFSET  # where the TCP sits when carrying this cube
    close_enough = np.linalg.norm(tcp_pos - carry_tcp) < GRASP_DISTANCE_THRESHOLD
    if grip < GRIP_CLOSE_THRESHOLD and close_enough:
        return True, tcp_pos + ATTACH_OFFSET
    return False, cube_pos
def check_success(attached: bool, cube_pos: np.ndarray, goal_pos: np.ndarray) -> bool:
    """Return True when the cube is released, resting at TABLE_Z, and within
    SUCCESS_XY_THRESHOLD of the goal in XY."""
    if attached:
        return False
    near_goal = np.linalg.norm(cube_pos[:2] - goal_pos[:2]) < SUCCESS_XY_THRESHOLD
    on_table = abs(cube_pos[2] - TABLE_Z) < 1e-3
    # Comparisons on NumPy scalars yield numpy.bool_; convert to honour the `-> bool` contract.
    return bool(near_goal and on_table)
def eval_policy(cfg: AppConfig) -> None:
    """Roll out a trained policy in the kinematic pick-and-place scene and print the success rate.

    Each control step:
    1. Observe the current scene.
    2. Query the policy.
    3. Clip the predicted joint command to the joint ranges and teleport the arm.
    4. Update the cube via ``update_cube_attachment``.

    Only ``mj_forward`` is called; no physics is stepped. Per-episode results and
    a summary are printed to stdout. With ``episodes == 0`` the reported success
    rate is 0.0 instead of raising ZeroDivisionError.
    """
    policy_root = resolve_policy_root(cfg)
    device = parse_device(cfg.eval.device)
    policy, preprocessor, postprocessor, policy_cfg = load_policy(policy_root, device)
    state_key, image_key = get_policy_observation_keys(policy_cfg)
    # Render at the training resolution; the visual feature shape is indexed as
    # (C, H, W), matching the CHW tensors built by prepare_policy_input.
    image_shape = policy_cfg.input_features[image_key].shape
    height = int(image_shape[1])
    width = int(image_shape[2])
    episodes = cfg.eval.episodes
    successes = 0

    # ExitStack guarantees the renderer and viewer are released even if setup fails.
    with ExitStack() as stack:
        model, data, ids = prepare_model(cfg.eval.base_height)
        renderer = mujoco.Renderer(model, height=height, width=width)
        stack.callback(renderer.close)
        lower, upper = action_limits(model, ids)
        rng = np.random.default_rng(cfg.eval.seed)
        viewer = None
        if cfg.eval.show_viewer:
            viewer = stack.enter_context(mujoco.viewer.launch_passive(model, data))
            configure_viewer_camera(viewer, cfg.eval.base_height)

        for episode_index in range(episodes):
            # Reset any per-episode state held by the policy and its processors.
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()
            object_pos, goal_pos = sample_positions(rng)
            model.site_pos[ids["goal_site"]] = goal_pos
            current_state = HOME_STATE.copy()
            attached = False
            cube_pos = object_pos.copy()
            success = False
            for _ in range(cfg.eval.max_steps):
                # Observe: render the scene at the current state.
                set_arm_state(model, data, ids, current_state)
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)
                image = render_frame(renderer, data, cfg.eval.base_height)
                obs = prepare_policy_input(
                    state=current_state,
                    image=image,
                    state_key=state_key,
                    image_key=image_key,
                    task=cfg.eval.task,
                    robot_type=cfg.eval.robot_type,
                    device=device,
                )
                with torch.inference_mode():
                    processed = preprocessor(obs)
                    action = policy.select_action(processed)
                    action = postprocessor(action)

                # Act: teleport the arm to the clipped command, then update the cube.
                current_state = action.squeeze(0).detach().to("cpu").numpy().astype(np.float64)
                current_state = np.clip(current_state, lower, upper)
                set_arm_state(model, data, ids, current_state)
                mujoco.mj_forward(model, data)
                tcp_pos = data.site_xpos[ids["tcp"]].copy()
                attached, cube_pos = update_cube_attachment(attached, cube_pos, tcp_pos, current_state, goal_pos)
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)

                if viewer is not None:
                    viewer.sync()
                    if not viewer.is_running():
                        viewer = None  # window closed: finish the evaluation headless
                    elif cfg.eval.step_sleep > 0:
                        # Slow the rollout down so it can be followed in the viewer.
                        time.sleep(cfg.eval.step_sleep)
                if check_success(attached, cube_pos, goal_pos):
                    success = True
                    break
            successes += int(success)
            print(f"episode={episode_index} success={int(success)} cube={cube_pos.round(4).tolist()} goal={goal_pos.round(4).tolist()}")

        success_rate = successes / episodes if episodes > 0 else 0.0
        print(f"policy_root={policy_root}")
        print(f"episodes={cfg.eval.episodes}")
        print(f"successes={successes}")
        print(f"success_rate={success_rate:.3f}")