from __future__ import annotations

import json
import os
import shlex
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path

import mujoco
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import get_policy_class, make_pre_post_processors

try:
    # The interactive viewer is optional; headless collection and evaluation do not need it.
    import mujoco.viewer
except Exception:
    pass

# Names of the six entries in `observation.state` and `action`.
JOINT_NAMES = [
    "shoulder_pan.pos",
    "shoulder_lift.pos",
    "elbow_flex.pos",
    "wrist_flex.pos",
    "wrist_roll.pos",
    "gripper.pos",
]
# Start pose of every episode: five arm joints in degrees, then the gripper opening on a 0..100 scale.
HOME_STATE = np.array([0.0, -35.0, 75.0, -40.0, 0.0, 70.0], dtype=np.float64)
# World z of the cube when it rests on the table (m).
TABLE_Z = 0.025
# Cube position relative to the `tcp` site while the cube is held (m).
ATTACH_OFFSET = np.array([0.0, 0.0, -0.055], dtype=np.float64)
# Gripper hysteresis for the kinematic grasp: grasp below CLOSE, release above OPEN.
GRIP_CLOSE_THRESHOLD = 25.0
GRIP_OPEN_THRESHOLD = 40.0
# Maximum TCP-to-grasp-point distance at which closing the gripper picks up the cube (m).
GRASP_DISTANCE_THRESHOLD = 0.045
# Maximum XY distance between cube and goal that counts as a successful placement (m).
SUCCESS_XY_THRESHOLD = 0.05


@dataclass
class DatasetConfig:
    repo_id: str
    root: Path
    episodes: int
    fps: int
    width: int
    height: int
    image_key: str
    task: str
    robot_type: str
    base_height: float
    seed: int
    overwrite: bool
    show_viewer: bool


@dataclass
class VizConfig:
    episode_index: int


@dataclass
class TrainConfig:
    output_dir: Path
    job_name: str
    device: str
    batch_size: int
    num_workers: int
    steps: int
    save_freq: int
    log_freq: int
    eval_freq: int
    seed: int
    wandb_enable: bool
    resume: bool


@dataclass
class EvalConfig:
    policy_root: Path | None
    episodes: int
    max_steps: int
    device: str
    base_height: float
    seed: int
    show_viewer: bool
    step_sleep: float
    task: str
    robot_type: str


@dataclass
class AppConfig:
    config_path: Path
    dataset: DatasetConfig
    viz: VizConfig
    train: TrainConfig
    eval: EvalConfig


def resolve_path(base_dir: Path, value: str) -> Path:
    """Expand `~` and resolve relative paths against the config file's directory."""
    path = Path(value).expanduser()
    if path.is_absolute():
        return path
    return (base_dir / path).resolve()


def load_config(config_path: str | Path | None = None) -> AppConfig:
    """Load config.json (by default the one next to this script) into typed dataclasses."""
    path = Path(config_path or Path(__file__).with_name("config.json")).expanduser().resolve()
    raw = json.loads(path.read_text())
    base_dir = path.parent

    dataset = DatasetConfig(
        repo_id=raw["dataset"]["repo_id"],
        root=resolve_path(base_dir, raw["dataset"]["root"]),
        episodes=int(raw["dataset"]["episodes"]),
        fps=int(raw["dataset"]["fps"]),
        width=int(raw["dataset"]["width"]),
        height=int(raw["dataset"]["height"]),
        image_key=raw["dataset"]["image_key"],
        task=raw["dataset"]["task"],
        robot_type=raw["dataset"]["robot_type"],
        base_height=float(raw["dataset"]["base_height"]),
        seed=int(raw["dataset"]["seed"]),
        overwrite=bool(raw["dataset"]["overwrite"]),
        show_viewer=bool(raw["dataset"]["show_viewer"]),
    )
    viz = VizConfig(episode_index=int(raw["viz"]["episode_index"]))
    train = TrainConfig(
        output_dir=resolve_path(base_dir, raw["train"]["output_dir"]),
        job_name=raw["train"]["job_name"],
        device=raw["train"]["device"],
        batch_size=int(raw["train"]["batch_size"]),
        num_workers=int(raw["train"]["num_workers"]),
        steps=int(raw["train"]["steps"]),
        save_freq=int(raw["train"]["save_freq"]),
        log_freq=int(raw["train"]["log_freq"]),
        eval_freq=int(raw["train"]["eval_freq"]),
        seed=int(raw["train"]["seed"]),
        wandb_enable=bool(raw["train"]["wandb_enable"]),
        resume=bool(raw["train"]["resume"]),
    )
    # `policy_root` may be null; evaluation then falls back to the newest training checkpoint.
    policy_root_value = raw["eval"]["policy_root"]
    eval_cfg = EvalConfig(
        policy_root=resolve_path(base_dir, policy_root_value) if policy_root_value else None,
        episodes=int(raw["eval"]["episodes"]),
        max_steps=int(raw["eval"]["max_steps"]),
        device=raw["eval"]["device"],
        base_height=float(raw["eval"]["base_height"]),
        seed=int(raw["eval"]["seed"]),
        show_viewer=bool(raw["eval"]["show_viewer"]),
        step_sleep=float(raw["eval"]["step_sleep"]),
        task=raw["eval"]["task"],
        robot_type=raw["eval"]["robot_type"],
    )
    return AppConfig(config_path=path, dataset=dataset, viz=viz, train=train, eval=eval_cfg)


def lerobot_env() -> dict[str, str]:
    """Environment for LeRobot CLI subprocesses (an unmodified copy of os.environ)."""
    return os.environ.copy()


def maybe_reexec_with_mjpython(show_viewer: bool, env_key: str) -> None:
    """On macOS, re-exec this script under `mjpython`, which MuJoCo's passive viewer requires.

    `env_key` marks the re-executed process so the switch happens at most once.
    """
    if not show_viewer or sys.platform != "darwin":
        return
    if os.environ.get(env_key) == "1":
        return
    if Path(sys.executable).name == "mjpython":
        return
    mjpython_path = Path(sys.executable).with_name("mjpython")
    if not mjpython_path.exists():
        raise RuntimeError(f"`mjpython` not found next to {sys.executable}")
    env = os.environ.copy()
    env[env_key] = "1"
    os.execve(str(mjpython_path), [str(mjpython_path), *sys.argv], env)


def bool_cli(value: bool) -> str:
    """Format a bool as the lowercase literal passed to LeRobot CLI flags."""
    return "true" if value else "false"


def parse_device(device: str) -> torch.device:
    """Map "auto" to CUDA, then Apple MPS, then CPU; pass any other string to torch.device."""
    if device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device)


def build_xml(base_height: float) -> str:
    """Return the MJCF scene (arm, table, cube, goal marker), parameterized by `base_height`.

    prepare_model() looks up the sites `tcp` and `goal_site` and the joints it lists by name.
    """
    return f""" """


def prepare_model(base_height: float) -> tuple[mujoco.MjModel, mujoco.MjData, dict[str, int]]:
    model = mujoco.MjModel.from_xml_string(build_xml(base_height))
    data = mujoco.MjData(model)
    # Cache the ids of the sites and joints that are read or written every frame.
    ids = {
        "tcp": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SITE, "tcp"),
        "goal_site": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SITE, "goal_site"),
        "cube_joint": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "cube_freejoint"),
        "shoulder_pan": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder_pan"),
        "shoulder_lift": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder_lift"),
        "elbow_flex": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "elbow_flex"),
        "wrist_flex": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "wrist_flex"),
        "wrist_roll": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "wrist_roll"),
        "left_finger": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "left_finger_slide"),
        "right_finger": mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "right_finger_slide"),
    }
    return model, data, ids


def set_arm_state(model: mujoco.MjModel, data: mujoco.MjData, ids: dict[str, int], joint_state: np.ndarray) -> None:
    """Write a 6-D state (arm joints in degrees, gripper opening 0..100) into qpos."""
    qpos = data.qpos
    qpos[model.jnt_qposadr[ids["shoulder_pan"]]] = np.deg2rad(joint_state[0])
    qpos[model.jnt_qposadr[ids["shoulder_lift"]]] = np.deg2rad(joint_state[1])
    qpos[model.jnt_qposadr[ids["elbow_flex"]]] = np.deg2rad(joint_state[2])
    qpos[model.jnt_qposadr[ids["wrist_flex"]]] = np.deg2rad(joint_state[3])
    qpos[model.jnt_qposadr[ids["wrist_roll"]]] = np.deg2rad(joint_state[4])
    # Map the gripper opening (0..100) linearly onto a finger slide of 0.003..0.025.
    finger_slide = 0.003 + 0.022 * float(np.clip(joint_state[5] / 100.0, 0.0, 1.0))
    qpos[model.jnt_qposadr[ids["left_finger"]]] = finger_slide
    qpos[model.jnt_qposadr[ids["right_finger"]]] = finger_slide


def set_cube_pose(model: mujoco.MjModel, data: mujoco.MjData, ids: dict[str, int], pos: np.ndarray) -> None:
    adr = model.jnt_qposadr[ids["cube_joint"]]
    # A free joint stores position (x, y, z) followed by a unit quaternion (w, x, y, z); keep the cube upright.
    data.qpos[adr : adr + 7] = np.array([pos[0], pos[1], pos[2], 1.0, 0.0, 0.0, 0.0], dtype=np.float64)


def render_frame(renderer: mujoco.Renderer, data: mujoco.MjData, base_height: float) -> np.ndarray:
    """Render an RGB frame (height x width x 3, uint8) from a fixed free camera."""
    camera = mujoco.MjvCamera()
    camera.type = mujoco.mjtCamera.mjCAMERA_FREE
    camera.lookat[:] = np.array([0.18, 0.0, base_height + 0.08], dtype=np.float64)
    camera.distance = 1.0
    camera.azimuth = 140.0
    camera.elevation = -45.0
    renderer.update_scene(data, camera=camera)
    return renderer.render()


def configure_viewer_camera(viewer, base_height: float) -> None:
    viewer.cam.lookat[:] = np.array([0.18, 0.0, base_height + 0.10], dtype=np.float64)
    viewer.cam.distance = 1.15
    viewer.cam.azimuth = 135.0
    viewer.cam.elevation = -32.0


def sample_positions(rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Rejection-sample object and goal positions on the table at least 0.14 m apart in XY."""
    while True:
        obj = np.array([rng.uniform(0.18, 0.28), rng.uniform(-0.12, 0.12), TABLE_Z], dtype=np.float64)
        goal = np.array([rng.uniform(0.10, 0.25), rng.uniform(-0.18, 0.18), TABLE_Z], dtype=np.float64)
        if np.linalg.norm(obj[:2] - goal[:2]) >= 0.14:
            return obj, goal


def solve_ik(target_xyz: np.ndarray) -> np.ndarray:
    """Closed-form IK for the five arm joints, returned in degrees.

    Treats the arm as a base yaw joint plus a planar shoulder/elbow/wrist chain whose
    tool points straight down (pitch of -90 deg). The shoulder origin is fixed at
    z = 0.50 m and does not depend on `base_height`. The lift, elbow, and wrist angles
    are negated on return, following the model's joint sign convention; wrist roll is 0.
    """
    shoulder_origin = np.array([0.0, 0.0, 0.50], dtype=np.float64)
    link_1 = 0.18
    link_2 = 0.16
    tool = 0.195
    desired_pitch = -np.pi / 2.0

    rel = target_xyz - shoulder_origin
    yaw = np.arctan2(rel[1], rel[0])
    radial = np.hypot(rel[0], rel[1])
    # Step back from the target along the tool axis to get the wrist position in the arm plane.
    wrist_x = radial - tool * np.cos(desired_pitch)
    wrist_z = rel[2] - tool * np.sin(desired_pitch)
    # Two-link IK via the law of cosines; clipping keeps unreachable targets from producing NaNs.
    dist_sq = wrist_x**2 + wrist_z**2
    cos_elbow = np.clip((dist_sq - link_1**2 - link_2**2) / (2.0 * link_1 * link_2), -1.0, 1.0)
    elbow = np.arccos(cos_elbow)
    shoulder = np.arctan2(wrist_z, wrist_x) - np.arctan2(link_2 * np.sin(elbow), link_1 + link_2 * np.cos(elbow))
    # The wrist absorbs the remaining pitch so the tool stays vertical.
    wrist = desired_pitch - shoulder - elbow
    return np.rad2deg(np.array([yaw, -shoulder, -elbow, -wrist, 0.0], dtype=np.float64))


def interpolate(start: np.ndarray, end: np.ndarray, steps: int) -> list[np.ndarray]:
    """Linearly interpolate `steps` states from `start` to `end`, both endpoints included."""
    if steps <= 1:
        return [end.astype(np.float64)]
    return [(start + (end - start) * (i / (steps - 1))).astype(np.float64) for i in range(steps)]


def make_episode_trajectory(obj: np.ndarray, goal: np.ndarray) -> list[tuple[np.ndarray, bool, np.ndarray]]:
    """Scripted pick-and-place demonstration.

    Returns one (joint_state, attached, cube_anchor) tuple per frame: the commanded 6-D
    state, whether the cube follows the TCP, and the cube's rest position when it does not.
    """
    open_grip = 70.0
    closed_grip = 8.0
    # Waypoints: IK arm angles for a point at a fixed height above the object or goal, plus a gripper command.
    above_obj = np.concatenate([solve_ik(obj + np.array([0.0, 0.0, 0.12])), [open_grip]])
    grasp = np.concatenate([solve_ik(obj + np.array([0.0, 0.0, 0.045])), [open_grip]])
    close = grasp.copy()
    close[-1] = closed_grip
    lift = np.concatenate([solve_ik(obj + np.array([0.0, 0.0, 0.18])), [closed_grip]])
    above_goal = np.concatenate([solve_ik(goal + np.array([0.0, 0.0, 0.16])), [closed_grip]])
    place = np.concatenate([solve_ik(goal + np.array([0.0, 0.0, 0.05])), [closed_grip]])
    open_at_goal = place.copy()
    open_at_goal[-1] = open_grip
    retreat = np.concatenate([solve_ik(goal + np.array([0.0, 0.0, 0.18])), [open_grip]])

    # (start, end, number of frames, cube attached, cube rest position)
    segments = [
        (HOME_STATE, above_obj, 18, False, obj),
        (above_obj, grasp, 12, False, obj),
        (grasp, close, 8, False, obj),
        (close, lift, 14, True, obj),
        (lift, above_goal, 20, True, obj),
        (above_goal, place, 12, True, obj),
        (place, open_at_goal, 8, True, goal),
        (open_at_goal, retreat, 12, False, goal),
    ]
    frames: list[tuple[np.ndarray, bool, np.ndarray]] = []
    for start, end, steps, attached, cube_anchor in segments:
        for state in interpolate(start, end, steps):
            frames.append((state, attached, cube_anchor.copy()))
    return frames
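

# Forward-kinematics cross-check for solve_ik(): a minimal sketch added for
# debugging, not part of the original pipeline. `planar_fk` is a hypothetical
# helper. It assumes the same chain solve_ik() encodes: a base yaw joint, a
# planar shoulder/elbow/wrist chain with links of 0.18 m and 0.16 m, a 0.195 m
# tool, the shoulder fixed at z = 0.50 m, and the same sign flips on the lift,
# elbow, and wrist angles. For reachable targets p, planar_fk(solve_ik(p))
# should reproduce p. Comparing it with data.site_xpos[ids["tcp"]] instead
# tests whether the MJCF model actually matches these assumptions (for
# example after changing base_height).
import numpy as np


def planar_fk(
    arm_deg: np.ndarray,
    shoulder_origin: tuple[float, float, float] = (0.0, 0.0, 0.50),
    link_1: float = 0.18,
    link_2: float = 0.16,
    tool: float = 0.195,
) -> np.ndarray:
    """Tool-tip position for [pan, lift, elbow, wrist_flex, ...] in degrees (hypothetical helper)."""
    yaw = np.deg2rad(arm_deg[0])
    # Undo the negation solve_ik() applies to the three planar joints.
    shoulder = -np.deg2rad(arm_deg[1])
    elbow = -np.deg2rad(arm_deg[2])
    wrist = -np.deg2rad(arm_deg[3])
    # Absolute angle of each link in the vertical plane selected by the yaw joint.
    a1 = shoulder
    a2 = shoulder + elbow
    a3 = shoulder + elbow + wrist
    radial = link_1 * np.cos(a1) + link_2 * np.cos(a2) + tool * np.cos(a3)
    height = link_1 * np.sin(a1) + link_2 * np.sin(a2) + tool * np.sin(a3)
    origin = np.asarray(shoulder_origin, dtype=np.float64)
    return origin + np.array([radial * np.cos(yaw), radial * np.sin(yaw), height])


# Usage inside this module (the target below is within reach of the assumed chain):
#     p = np.array([0.22, 0.05, TABLE_Z + 0.045])
#     assert np.allclose(planar_fk(solve_ik(p)), p, atol=1e-6)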
def dataset_features(image_key: str, width: int, height: int) -> dict:
    """LeRobot feature spec: 6-D float32 state and action plus one RGB image stream."""
    return {
        "action": {"dtype": "float32", "shape": (6,), "names": JOINT_NAMES},
        "observation.state": {"dtype": "float32", "shape": (6,), "names": JOINT_NAMES},
        f"observation.images.{image_key}": {
            "dtype": "image",
            "shape": (height, width, 3),
            "names": ["height", "width", "channels"],
        },
    }


def collect_dataset(cfg: AppConfig) -> None:
    """Record scripted demonstrations into a local LeRobot dataset."""
    if cfg.dataset.root.exists():
        if not cfg.dataset.overwrite:
            raise FileExistsError(f"Dataset root already exists: {cfg.dataset.root}")
        shutil.rmtree(cfg.dataset.root)

    dataset = LeRobotDataset.create(
        repo_id=cfg.dataset.repo_id,
        root=cfg.dataset.root,
        fps=cfg.dataset.fps,
        features=dataset_features(cfg.dataset.image_key, cfg.dataset.width, cfg.dataset.height),
        robot_type=cfg.dataset.robot_type,
        use_videos=False,
    )
    model, data, ids = prepare_model(cfg.dataset.base_height)
    renderer = mujoco.Renderer(model, height=cfg.dataset.height, width=cfg.dataset.width)
    rng = np.random.default_rng(cfg.dataset.seed)
    dt = 1.0 / cfg.dataset.fps

    if cfg.dataset.show_viewer:
        viewer_cm = mujoco.viewer.launch_passive(model, data)
    else:
        viewer_cm = None

    try:
        if viewer_cm is not None:
            viewer = viewer_cm.__enter__()
            configure_viewer_camera(viewer, cfg.dataset.base_height)
        else:
            viewer = None

        for _ in range(cfg.dataset.episodes):
            obj_pos, goal_pos = sample_positions(rng)
            model.site_pos[ids["goal_site"]] = goal_pos
            frames = make_episode_trajectory(obj_pos, goal_pos)
            prev_state = frames[0][0].copy()
            attached_once = False

            # Start from the first waypoint with the cube at its spawn position, as eval_policy() does.
            set_arm_state(model, data, ids, prev_state)
            set_cube_pose(model, data, ids, obj_pos)
            mujoco.mj_forward(model, data)

            for state, attached, cube_anchor in frames:
                start = time.perf_counter()

                # Render before applying the action so the image shows the same pose as
                # observation.state (the previous command), matching what eval_policy() feeds the policy.
                image = render_frame(renderer, data, cfg.dataset.base_height)
                dataset.add_frame(
                    {
                        "observation.state": prev_state.astype(np.float32),
                        "action": state.astype(np.float32),
                        f"observation.images.{cfg.dataset.image_key}": image,
                        "task": cfg.dataset.task,
                    }
                )

                # Apply the action kinematically (mj_forward only, no dynamics): pose the arm,
                # then place the cube relative to the updated TCP.
                if attached and not attached_once:
                    attached_once = True
                set_arm_state(model, data, ids, state)
                mujoco.mj_forward(model, data)
                tcp_pos = data.site_xpos[ids["tcp"]].copy()
                if attached:
                    cube_pos = tcp_pos + ATTACH_OFFSET
                elif attached_once:
                    # After release the cube rests exactly on the goal.
                    cube_pos = goal_pos.copy()
                else:
                    cube_pos = cube_anchor.copy()
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)
                prev_state = state.copy()

                if viewer is not None:
                    viewer.sync()
                    if not viewer.is_running():
                        viewer = None
                    else:
                        # Throttle to the dataset frame rate only while the viewer is open.
                        elapsed = time.perf_counter() - start
                        if elapsed < dt:
                            time.sleep(dt - elapsed)

            dataset.save_episode()
    finally:
        renderer.close()
        dataset.finalize()
        if viewer_cm is not None:
            viewer_cm.__exit__(None, None, None)


def run_subprocess(cmd: list[str]) -> None:
    subprocess.run(cmd, check=True, env=lerobot_env())


def viz_dataset(cfg: AppConfig) -> None:
    """Open one recorded episode with LeRobot's `lerobot-dataset-viz` tool."""
    cmd = [
        "lerobot-dataset-viz",
        "--repo-id",
        cfg.dataset.repo_id,
        "--root",
        str(cfg.dataset.root),
        "--episode-index",
        str(cfg.viz.episode_index),
    ]
    run_subprocess(cmd)


def train_policy(cfg: AppConfig) -> None:
    """Replace the current process with `lerobot-train`, training an ACT policy on the local dataset."""
    cmd = [
        "lerobot-train",
        f"--dataset.repo_id={cfg.dataset.repo_id}",
        f"--dataset.root={cfg.dataset.root}",
        "--policy.type=act",
        "--policy.push_to_hub=false",
        f"--output_dir={cfg.train.output_dir}",
        f"--job_name={cfg.train.job_name}",
        f"--policy.device={cfg.train.device}",
        f"--batch_size={cfg.train.batch_size}",
        f"--num_workers={cfg.train.num_workers}",
        f"--steps={cfg.train.steps}",
        f"--save_freq={cfg.train.save_freq}",
        f"--log_freq={cfg.train.log_freq}",
        f"--eval_freq={cfg.train.eval_freq}",
        f"--seed={cfg.train.seed}",
        f"--wandb.enable={bool_cli(cfg.train.wandb_enable)}",
        f"--resume={bool_cli(cfg.train.resume)}",
    ]
    print("command=" + " ".join(shlex.quote(part) for part in cmd))
    # Flush before exec: output still buffered when the process image is replaced would be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    os.execvpe(cmd[0], cmd, lerobot_env())


def resolve_policy_root(cfg: AppConfig) -> Path:
    """Use eval.policy_root if set, otherwise the newest checkpoint under train.output_dir."""
    if cfg.eval.policy_root is not None:
        return cfg.eval.policy_root
    # Lexicographic order: with zero-padded step directory names, the last entry is the newest checkpoint.
    candidates = sorted(cfg.train.output_dir.glob("checkpoints/*/pretrained_model"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint found under {cfg.train.output_dir / 'checkpoints'}")
    return candidates[-1]


def load_policy(policy_root: Path, device: torch.device):
    """Load a pretrained policy and its saved pre/post-processors onto `device`."""
    policy_cfg = PreTrainedConfig.from_pretrained(policy_root, cli_overrides=[f"--device={device.type}"])
    policy_cls = get_policy_class(policy_cfg.type)
    policy = policy_cls.from_pretrained(policy_root, config=policy_cfg, local_files_only=True)
    processor_overrides = {"device_processor": {"device": device.type}}
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg,
        pretrained_path=str(policy_root),
        preprocessor_overrides=processor_overrides,
        postprocessor_overrides=processor_overrides,
    )
    return policy, preprocessor, postprocessor, policy_cfg


def get_policy_observation_keys(policy_cfg: PreTrainedConfig) -> tuple[str, str]:
    """Return the keys of the policy's first STATE input and first VISUAL input."""
    state_key = None
    image_key = None
    for key, feature in policy_cfg.input_features.items():
        if feature.type is FeatureType.STATE and state_key is None:
            state_key = key
        if feature.type is FeatureType.VISUAL and image_key is None:
            image_key = key
    if state_key is None or image_key is None:
        raise ValueError("Policy must contain one state input and one image input.")
    return state_key, image_key


def prepare_policy_input(
    state: np.ndarray,
    image: np.ndarray,
    state_key: str,
    image_key: str,
    task: str,
    robot_type: str,
    device: torch.device,
) -> dict:
    """Build a batch of one: state as (1, 6) float32, image as (1, C, H, W) float32 in [0, 1]."""
    return {
        state_key: torch.from_numpy(state.astype(np.float32)).unsqueeze(0).to(device),
        image_key: (
            torch.from_numpy(image.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
            .unsqueeze(0)
            .to(device)
        ),
        "task": task,
        "robot_type": robot_type,
    }


def action_limits(model: mujoco.MjModel, ids: dict[str, int]) -> tuple[np.ndarray, np.ndarray]:
    """Per-dimension clip bounds: MuJoCo joint ranges in degrees for the arm, 0..100 for the gripper."""
    lower = []
    upper = []
    for joint_name in ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll"]:
        low, high = model.jnt_range[ids[joint_name]]
        lower.append(np.rad2deg(low))
        upper.append(np.rad2deg(high))
    lower.append(0.0)
    upper.append(100.0)
    return np.asarray(lower, dtype=np.float64), np.asarray(upper, dtype=np.float64)


def update_cube_attachment(
    attached: bool,
    cube_pos: np.ndarray,
    tcp_pos: np.ndarray,
    action: np.ndarray,
    goal_pos: np.ndarray,
) -> tuple[bool, np.ndarray]:
    """Kinematic grasp model.

    The cube snaps to the TCP when the gripper closes near it and drops straight onto the
    table when the gripper opens; a drop within SUCCESS_XY_THRESHOLD of the goal snaps onto it.
    """
    grip = float(action[5])
    tcp_to_cube = np.linalg.norm(tcp_pos - (cube_pos - ATTACH_OFFSET))
    if attached:
        if grip > GRIP_OPEN_THRESHOLD:
            released = np.array([tcp_pos[0], tcp_pos[1], TABLE_Z], dtype=np.float64)
            if np.linalg.norm(released[:2] - goal_pos[:2]) < SUCCESS_XY_THRESHOLD:
                released[:2] = goal_pos[:2]
            return False, released
        return True, tcp_pos + ATTACH_OFFSET
    if grip < GRIP_CLOSE_THRESHOLD and tcp_to_cube < GRASP_DISTANCE_THRESHOLD:
        return True, tcp_pos + ATTACH_OFFSET
    return False, cube_pos


def check_success(attached: bool, cube_pos: np.ndarray, goal_pos: np.ndarray) -> bool:
    """Success means the cube was released, rests on the table, and lies within the goal radius."""
    if attached:
        return False
    return bool(
        np.linalg.norm(cube_pos[:2] - goal_pos[:2]) < SUCCESS_XY_THRESHOLD
        and abs(cube_pos[2] - TABLE_Z) < 1e-3
    )


def eval_policy(cfg: AppConfig) -> None:
    """Roll out the trained policy in the kinematic simulation and report the success rate."""
    policy_root = resolve_policy_root(cfg)
    device = parse_device(cfg.eval.device)
    policy, preprocessor, postprocessor, policy_cfg = load_policy(policy_root, device)
    state_key, image_key = get_policy_observation_keys(policy_cfg)
    # Visual input shapes are stored as (channels, height, width).
    image_shape = policy_cfg.input_features[image_key].shape
    height = int(image_shape[1])
    width = int(image_shape[2])

    model, data, ids = prepare_model(cfg.eval.base_height)
    renderer = mujoco.Renderer(model, height=height, width=width)
    lower, upper = action_limits(model, ids)
    rng = np.random.default_rng(cfg.eval.seed)
    successes = 0

    if cfg.eval.show_viewer:
        viewer_cm = mujoco.viewer.launch_passive(model, data)
    else:
        viewer_cm = None

    try:
        if viewer_cm is not None:
            viewer = viewer_cm.__enter__()
            configure_viewer_camera(viewer, cfg.eval.base_height)
        else:
            viewer = None

        for episode_index in range(cfg.eval.episodes):
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()
            object_pos, goal_pos = sample_positions(rng)
            model.site_pos[ids["goal_site"]] = goal_pos
            current_state = HOME_STATE.copy()
            attached = False
            cube_pos = object_pos.copy()
            success = False

            for _ in range(cfg.eval.max_steps):
                # Observe: render the scene at the current state.
                set_arm_state(model, data, ids, current_state)
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)
                image = render_frame(renderer, data, cfg.eval.base_height)
                obs = prepare_policy_input(
                    state=current_state,
                    image=image,
                    state_key=state_key,
                    image_key=image_key,
                    task=cfg.eval.task,
                    robot_type=cfg.eval.robot_type,
                    device=device,
                )
                with torch.inference_mode():
                    processed = preprocessor(obs)
                    action = policy.select_action(processed)
                    action = postprocessor(action)

                # Act: the clipped action becomes the new joint state directly (no dynamics are simulated).
                current_state = action.squeeze(0).detach().to("cpu").numpy().astype(np.float64)
                current_state = np.clip(current_state, lower, upper)
                set_arm_state(model, data, ids, current_state)
                mujoco.mj_forward(model, data)
                tcp_pos = data.site_xpos[ids["tcp"]].copy()
                attached, cube_pos = update_cube_attachment(attached, cube_pos, tcp_pos, current_state, goal_pos)
                set_cube_pose(model, data, ids, cube_pos)
                mujoco.mj_forward(model, data)

                if viewer is not None:
                    viewer.sync()
                    if not viewer.is_running():
                        viewer = None
                    elif cfg.eval.step_sleep > 0:
                        time.sleep(cfg.eval.step_sleep)

                if check_success(attached, cube_pos, goal_pos):
                    success = True
                    break

            successes += int(success)
            print(
                f"episode={episode_index} success={int(success)} "
                f"cube={cube_pos.round(4).tolist()} goal={goal_pos.round(4).tolist()}"
            )

        print(f"policy_root={policy_root}")
        print(f"episodes={cfg.eval.episodes}")
        print(f"successes={successes}")
        print(f"success_rate={successes / max(cfg.eval.episodes, 1):.3f}")
    finally:
        renderer.close()
        if viewer_cm is not None:
            viewer_cm.__exit__(None, None, None)
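

# Config sanity check: a minimal sketch, not part of the original pipeline.
# load_config() indexes raw[...] directly, so a missing key surfaces as a bare
# KeyError naming only the first gap. The hypothetical helper
# missing_config_keys() lists every missing section or key in one pass;
# REQUIRED_CONFIG_KEYS (also hypothetical) mirrors the keys load_config()
# reads. A key that is present but null (for example eval.policy_root)
# counts as present, matching load_config().
REQUIRED_CONFIG_KEYS: dict[str, tuple[str, ...]] = {
    "dataset": (
        "repo_id", "root", "episodes", "fps", "width", "height", "image_key",
        "task", "robot_type", "base_height", "seed", "overwrite", "show_viewer",
    ),
    "viz": ("episode_index",),
    "train": (
        "output_dir", "job_name", "device", "batch_size", "num_workers", "steps",
        "save_freq", "log_freq", "eval_freq", "seed", "wandb_enable", "resume",
    ),
    "eval": (
        "policy_root", "episodes", "max_steps", "device", "base_height", "seed",
        "show_viewer", "step_sleep", "task", "robot_type",
    ),
}


def missing_config_keys(raw: dict) -> list[str]:
    """Return the dotted path of every key load_config() would fail to find (hypothetical helper)."""
    missing: list[str] = []
    for section, keys in REQUIRED_CONFIG_KEYS.items():
        block = raw.get(section)
        if not isinstance(block, dict):
            # The whole section is absent or is not a JSON object.
            missing.append(section)
            continue
        missing.extend(f"{section}.{key}" for key in keys if key not in block)
    return missing


# Usage: report all gaps before calling load_config(), e.g.
#     problems = missing_config_keys(json.loads(Path("config.json").read_text()))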