| """RoboSuite simulator — real MuJoCo physics with OSC arm control. |
| |
| Uses RoboSuite's default OSC_POSE controller for precise end-effector |
| control. Provides goto_pose, gripper, IK, and blocking motion. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any |
|
|
| import numpy as np |
|
|
| from anima_naka.gym.base_env import BaseEnv |
|
|
# Default MuJoCo to headless EGL rendering unless the caller already chose a
# backend. MuJoCo reads MUJOCO_GL when it is first imported; robosuite itself
# is only imported lazily inside RoboSuiteSim.__init__.
os.environ.setdefault("MUJOCO_GL", "egl")


# Project sim name -> (robosuite environment name, extra kwargs forwarded to
# suite.make). Unknown sim names fall back to ("Lift", {}) in
# RoboSuiteSim.__init__.
# NOTE(review): "robosuite_cube_restack" maps to TwoArmLift, which looks like a
# placeholder rather than a restacking task -- confirm.
_ROBOSUITE_ENVS = dict(
    robosuite_cube_lift=("Lift", dict()),
    robosuite_cube_stack=("Stack", dict()),
    robosuite_spill_wipe=("Wipe", dict()),
    robosuite_peg_insertion=("NutAssemblySquare", dict()),
    robosuite_cube_restack=("TwoArmLift", dict(env_configuration="parallel")),
    robosuite_two_arm_lift=("TwoArmLift", dict(env_configuration="parallel")),
    robosuite_two_arm_handover=(
        "TwoArmHandover",
        dict(env_configuration="parallel"),
    ),
)

# (lower, upper) position limits, in radians, for Panda arm joints 1..7.
# The values match Franka's published Panda joint limits.
_PANDA_JOINT_LIMITS = (
    (-2.8973, 2.8973),
    (-1.7628, 1.7628),
    (-2.8973, 2.8973),
    (-3.0718, -0.0698),
    (-2.8973, 2.8973),
    (-0.0175, 3.7525),
    (-2.8973, 2.8973),
)
_JOINT_LOWER = np.array([low for low, _ in _PANDA_JOINT_LIMITS])
_JOINT_UPPER = np.array([high for _, high in _PANDA_JOINT_LIMITS])
|
|
|
|
class RoboSuiteSim(BaseEnv):
    """RoboSuite MuJoCo env with OSC pose control.

    Default controller: OSC_POSE (delta input, 6 pose dims + 1 gripper).
    Actions: [dx, dy, dz, dax, day, daz, gripper] where gripper: -1=open, 1=close.

    The motion helpers in this class send translational deltas only (the
    rotational deltas are always zero), and in bimanual mode they drive the
    first arm while the second arm receives zero actions.
    """

    def __init__(
        self,
        sim_name: str = "robosuite_cube_lift",
        bimanual: bool = False,
        render_size: int = 640,
    ):
        """Build the underlying robosuite environment.

        Args:
            sim_name: Key into ``_ROBOSUITE_ENVS``; unknown names fall back to
                the single-arm ``Lift`` task.
            bimanual: Use two Panda arms. Forced on for ``TwoArm*`` tasks,
                which cannot be built with a single arm.
            render_size: Height and width in pixels of the square camera images.
        """
        import robosuite as suite

        env_name, kwargs = _ROBOSUITE_ENVS.get(sim_name, ("Lift", {}))
        # TwoArm* environments need two robots in the "parallel" configuration;
        # passing a single arm would make suite.make fail.
        if env_name.startswith("TwoArm"):
            bimanual = True
        robots = ["Panda", "Panda"] if bimanual else "Panda"

        self._camera_name = "frontview"
        self._env = suite.make(
            env_name,
            robots=robots,
            has_renderer=False,
            has_offscreen_renderer=True,
            use_camera_obs=True,
            camera_names=[self._camera_name],
            camera_heights=render_size,
            camera_widths=render_size,
            # Depth must be requested explicitly; otherwise robosuite never
            # produces "<camera>_depth" and get_observation() can only return
            # its constant placeholder depth map.
            camera_depths=True,
            # Each high-level motion issues many env.step() calls; without this
            # robosuite raises "executing action in terminated episode" once
            # its default horizon is reached.
            ignore_done=True,
            **kwargs,
        )

        self._render_size = render_size
        self._bimanual = bimanual
        # step_count / max_steps are initialised here but not advanced by this
        # class; presumably consumed by BaseEnv or callers.
        self.step_count = 0
        self.max_steps = 1500
        self._sim_step_count = 0  # number of env.step() calls issued
        self._gripper_cmd = -1.0  # latched gripper command (-1 open, 1 close)

    def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset counters, latch the gripper open and reset the robosuite env.

        Args:
            seed: If given, passed to ``np.random.seed`` before the reset
                (presumably so robosuite's randomisation is reproducible).

        Returns:
            ``(observation, info)``; ``info`` is always an empty dict.
        """
        self.step_count = 0
        self._sim_step_count = 0
        self._gripper_cmd = -1.0
        if seed is not None:
            np.random.seed(seed)
        self._env.reset()
        return self.get_observation(), {}

    def get_observation(self) -> dict[str, Any]:
        """Return the current camera observation in this project's format.

        Image rows are reversed (``[::-1]``), presumably to turn robosuite's
        bottom-up OpenGL image layout into a top-down one. Missing images fall
        back to a black RGB frame and a depth map of ones.

        Returns:
            ``{"robot0_robotview": {"images": {"rgb", "depth"},
            "intrinsics": (3, 3), "pose_mat": (4, 4)}}``. The outer key is kept
            for caller compatibility; the data comes from ``self._camera_name``.
        """
        obs = self._env._get_observations()
        size = self._render_size

        rgb = obs.get(
            f"{self._camera_name}_image",
            np.zeros((size, size, 3), dtype=np.uint8),
        )
        if rgb.ndim == 3:
            rgb = rgb[::-1]

        depth_raw = obs.get(f"{self._camera_name}_depth")
        if depth_raw is not None:
            if depth_raw.ndim == 3:
                depth_raw = depth_raw[:, :, 0]
            depth = self._depth_to_meters(depth_raw[::-1])
        else:
            depth = np.ones((size, size), dtype=np.float32)

        return {
            "robot0_robotview": {
                "images": {"rgb": rgb, "depth": depth},
                "intrinsics": self._get_camera_intrinsics(self._camera_name),
                "pose_mat": self._get_camera_pose_matrix(self._camera_name),
            }
        }

    def _depth_to_meters(self, depth_buffer: np.ndarray) -> np.ndarray:
        """Convert a normalised depth buffer to metric depth (float32).

        Uses robosuite's ``get_real_depth_map`` when importable; otherwise
        applies the same near/far linearisation using the model's
        ``znear``/``zfar`` scaled by ``stat.extent``.
        """
        try:
            from robosuite.utils.camera_utils import get_real_depth_map

            return get_real_depth_map(self._env.sim, depth_buffer).astype(np.float32)
        except ImportError:
            extent = self._env.sim.model.stat.extent
            near = self._env.sim.model.vis.map.znear * extent
            far = self._env.sim.model.vis.map.zfar * extent
            return (near / (1.0 - depth_buffer * (1.0 - near / far))).astype(np.float32)

    def _get_camera_intrinsics(self, camera_name: str) -> np.ndarray:
        """Pinhole intrinsics K for the square render from the camera's fovy.

        The focal length is ``0.5 * render_size / tan(fovy / 2)``, with the
        principal point at the image centre. Falls back to a 45 degree fovy if
        the camera lookup fails (best effort).
        """
        try:
            cam_id = self._env.sim.model.camera_name2id(camera_name)
            fovy = self._env.sim.model.cam_fovy[cam_id]
        except Exception:
            fovy = 45.0
        f = 0.5 * self._render_size / np.tan(np.radians(fovy / 2.0))
        cx = cy = self._render_size / 2.0
        return np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]], dtype=np.float64)

    def _get_camera_pose_matrix(self, camera_name: str) -> np.ndarray:
        """World-from-camera 4x4 transform built from MuJoCo's camera xmat/xpos.

        Returns identity if the lookup fails (best effort).

        NOTE(review): this is the raw MuJoCo camera frame (MuJoCo cameras look
        along -z with +y up). robosuite's ``get_camera_extrinsic_matrix``
        additionally applies a diag(1, -1, -1) axis correction for
        OpenCV-style deprojection; confirm which convention consumers expect.
        """
        try:
            sim = self._env.sim
            mat = np.eye(4, dtype=np.float64)
            mat[:3, :3] = sim.data.get_camera_xmat(camera_name).reshape(3, 3)
            mat[:3, 3] = sim.data.get_camera_xpos(camera_name)
            return mat
        except Exception:
            return np.eye(4, dtype=np.float64)

    def compute_reward(self) -> float:
        """Compute reward from RoboSuite env.

        Calls ``env.reward()`` with no arguments; if that signature rejects the
        call (TypeError), retries with a zero action of length ``action_dim``.
        """
        try:
            return float(self._env.reward())
        except TypeError:
            action = np.zeros(self._env.action_dim)
            return float(self._env.reward(action))

    def get_object_position(self, name: str) -> np.ndarray:
        """Get the 3D world position of the body that best matches ``name``.

        The query is lower-cased, underscores become spaces, and it is split
        into tokens. Each body name (lower-cased) scores one point per token it
        contains as a substring, and the first body with the highest positive
        score wins; e.g. 'cube' and 'red cube' both match 'cube_main'. Bodies
        whose names contain world/table/robot/mount/gripper/link markers are
        skipped.

        Returns:
            A copy of the body's ``xpos`` (3,), or ``np.zeros(3)`` if nothing
            matches (sample_grasp_pose treats all-zero as "not found").
        """
        sim = self._env.sim
        tokens = name.lower().replace("_", " ").split()
        # "robot" / "mount" (rather than "robot0" / "fixed_mount") also cover
        # the second arm and its mount in bimanual setups.
        excluded = ("world", "table", "robot", "mount", "gripper", "link")

        best_match, best_score = None, 0
        for body_name in sim.model.body_names:
            bn_lower = body_name.lower()
            if any(ex in bn_lower for ex in excluded):
                continue
            score = sum(token in bn_lower for token in tokens)
            if score > best_score:
                best_match, best_score = body_name, score

        # best_match is only ever assigned together with a positive score.
        if best_match is None:
            return np.zeros(3)
        body_id = sim.model.body_name2id(best_match)
        return sim.data.body_xpos[body_id].copy()

    def get_table_height(self) -> float:
        """Return the z of body "table_main", or 0.82 if that lookup fails.

        NOTE(review): robosuite's TableArena appears to name its body "table",
        so the 0.82 fallback may be what is normally returned -- confirm.
        """
        try:
            body_id = self._env.sim.model.body_name2id("table_main")
            return float(self._env.sim.data.body_xpos[body_id][2])
        except Exception:
            return 0.82

    def _get_ee_pos(self) -> np.ndarray:
        """Get robot 0's end-effector position from the latest observations."""
        obs = self._env._get_observations()
        return np.array(obs["robot0_eef_pos"], dtype=np.float64)

    def _get_ee_quat_xyzw(self) -> np.ndarray:
        """Get robot 0's end-effector quaternion (xyzw) from the latest observations."""
        obs = self._env._get_observations()
        return np.array(obs["robot0_eef_quat"], dtype=np.float64)

    def get_joint_positions(self) -> np.ndarray:
        """Return robot 0's 7 arm joint angles as a float64 array.

        Prefers the ``robot0_joint_pos`` observation. robosuite may register
        that observable as inactive (only its sin/cos variants being active),
        in which case ``_get_observations()`` omits it and this method falls
        back to ``sim.data.qpos[:7]``.
        """
        obs = self._env._get_observations()
        joint_pos = obs.get("robot0_joint_pos")
        if joint_pos is None:
            # NOTE(review): assumes robot 0's arm joints occupy qpos[0:7], the
            # same layout move_to_joints() and solve_ik() rely on.
            joint_pos = self._env.sim.data.qpos[:7]
        return np.array(joint_pos[:7], dtype=np.float64)

    def _step_action(self, action: np.ndarray) -> None:
        """Step the env once, truncating or zero-padding ``action`` to ``action_dim``.

        Shorter actions are zero-padded, so in bimanual mode a 7-element action
        drives the first arm and leaves the second arm's entries at zero.
        """
        action = np.asarray(action, dtype=np.float64)
        action_dim = self._env.action_dim
        if action.shape[0] > action_dim:
            action = action[:action_dim]
        elif action.shape[0] < action_dim:
            action = np.pad(action, (0, action_dim - action.shape[0]))
        self._env.step(action)
        self._sim_step_count += 1

    def _step_zero(self) -> None:
        """Step with zero arm deltas while holding the latched gripper command."""
        # Build a single-arm action and let _step_action pad it, exactly like
        # _move_to_position does. Writing into action[-1] of the full vector
        # would address the *second* arm's gripper in bimanual mode.
        action = np.zeros(7)
        action[6] = self._gripper_cmd
        self._step_action(action)

    def _actuate_gripper(self, command: float, steps: int = 30) -> None:
        """Latch ``command`` as the gripper command and hold it for ``steps`` env steps."""
        self._gripper_cmd = command
        for _ in range(steps):
            self._step_zero()

    def open_gripper(self) -> None:
        """Open the gripper (command -1) over 30 env.step() calls."""
        self._actuate_gripper(-1.0)

    def close_gripper(self) -> None:
        """Close the gripper (command 1) over 30 env.step() calls."""
        self._actuate_gripper(1.0)

    def goto_pose(
        self,
        position: np.ndarray,
        quaternion_wxyz: np.ndarray,
        z_approach: float = 0.0,
    ) -> None:
        """Move EE to ``position`` using the OSC delta control loop.

        Args:
            position: (3,) target xyz in world frame.
            quaternion_wxyz: (4,) target orientation [w,x,y,z]. Forwarded to
                _move_to_position, which currently does not servo orientation.
            z_approach: If > 0, first move to a point this far above
                ``position`` (up to 150 steps), then descend (up to 200 steps).
        """
        pos = np.asarray(position, dtype=np.float64)
        quat = np.asarray(quaternion_wxyz, dtype=np.float64)

        if z_approach > 0:
            approach_pos = pos.copy()
            approach_pos[2] += z_approach
            self._move_to_position(approach_pos, quat, max_steps=150)

        self._move_to_position(pos, quat, max_steps=200)

    def _move_to_position(
        self,
        target_pos: np.ndarray,
        target_quat_wxyz: np.ndarray,
        max_steps: int = 200,
        pos_tol: float = 0.01,
    ) -> None:
        """Proportional OSC delta loop toward ``target_pos``.

        Each iteration sends ``clip(gain * position_error, -1, 1)`` as the
        translational delta with a zero rotational delta and the latched
        gripper command. The loop stops once the error norm is below
        ``pos_tol`` or after ``max_steps`` steps.
        ``target_quat_wxyz`` is accepted for interface symmetry but unused.
        """
        gain = 10.0
        target_pos = np.asarray(target_pos, dtype=np.float64)
        delta_ori = np.zeros(3)  # orientation is not servoed

        for _ in range(max_steps):
            pos_error = target_pos - self._get_ee_pos()
            if np.linalg.norm(pos_error) < pos_tol:
                break
            delta_pos = np.clip(pos_error * gain, -1.0, 1.0)
            action = np.concatenate([delta_pos, delta_ori, [self._gripper_cmd]])
            self._step_action(action)

    def move_to_joints(self, joints: np.ndarray) -> None:
        """Move toward a target joint configuration using OSC (approximate).

        OSC_POSE cannot command joints directly. Instead, the EE position at
        ``joints`` is found by temporarily writing ``qpos[:7]`` and running
        ``mj_forward`` (the state is restored afterwards, even on error). The
        EE is then servoed to that position while keeping its current
        orientation. Only the EE position is matched, not the joint angles.
        On any failure, 50 zero-delta steps are taken instead (best effort).
        """
        target = np.asarray(joints, dtype=np.float64).reshape(7)
        try:
            import mujoco

            model = self._env.sim.model._model
            data = self._env.sim.data._data
            qpos_save = data.qpos.copy()
            try:
                # NOTE(review): assumes robot 0's arm joints are qpos[0:7].
                data.qpos[:7] = target
                mujoco.mj_forward(model, data)
                site_name = self._find_ee_site()
                if site_name:
                    site_id = mujoco.mj_name2id(
                        model, mujoco.mjtObj.mjOBJ_SITE, site_name
                    )
                    target_pos = data.site_xpos[site_id].copy()
                else:
                    # No EE site: the cached observation is the *current* EE
                    # position, so the motion below is effectively a hold.
                    target_pos = self._get_ee_pos()
            finally:
                # Always undo the FK probe so the live sim is never left in
                # the teleported configuration.
                data.qpos[:] = qpos_save
                mujoco.mj_forward(model, data)

            ee_quat = self._get_ee_quat_xyzw()
            wxyz = np.array([ee_quat[3], ee_quat[0], ee_quat[1], ee_quat[2]])
            self._move_to_position(target_pos, wxyz, max_steps=150)
        except Exception:
            for _ in range(50):
                self._step_zero()

    def solve_ik(
        self, position: np.ndarray, quaternion_wxyz: np.ndarray
    ) -> np.ndarray:
        """Position-only IK via damped least squares on the EE site Jacobian.

        Runs up to 50 iterations (damping 0.05, 1 mm tolerance) over the first
        7 DoFs, clipping each iterate to the Panda joint limits. The sim state
        (qpos/qvel) is restored before returning, even on error.
        ``quaternion_wxyz`` is currently ignored.

        Returns:
            (7,) joint angles from the last iterate, which may not have
            converged. On failure, the current joint positions are returned.
        """
        target_pos = np.asarray(position, dtype=np.float64)
        try:
            import mujoco

            model = self._env.sim.model._model
            data = self._env.sim.data._data
            site_name = self._find_ee_site()
            if not site_name:
                return self.get_joint_positions()

            site_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SITE, site_name)
            nq = 7
            jacp = np.zeros((3, model.nv))
            damping = 0.05
            qpos_save = data.qpos.copy()
            qvel_save = data.qvel.copy()
            try:
                for _ in range(50):
                    mujoco.mj_forward(model, data)
                    pos_err = target_pos - data.site_xpos[site_id]
                    if np.linalg.norm(pos_err) < 1e-3:
                        break
                    mujoco.mj_jacSite(model, data, jacp, None, site_id)
                    J = jacp[:, :nq]
                    JJT = J @ J.T + damping**2 * np.eye(3)
                    dq = J.T @ np.linalg.solve(JJT, pos_err)
                    data.qpos[:nq] = np.clip(
                        data.qpos[:nq] + dq, _JOINT_LOWER, _JOINT_UPPER
                    )
                # The copy is taken before the finally block restores qpos.
                return data.qpos[:nq].copy()
            finally:
                data.qpos[:] = qpos_save
                data.qvel[:] = qvel_save
                mujoco.mj_forward(model, data)
        except Exception:
            return self.get_joint_positions()

    def _find_ee_site(self) -> str | None:
        """Find the end-effector site name in the MuJoCo model, or None.

        Known names are tried first. Otherwise, the first site whose name ends
        in "grip_site" is returned, then the first containing "grip".
        """
        import mujoco

        model = self._env.sim.model._model
        site_type = mujoco.mjtObj.mjOBJ_SITE
        candidates = (
            "grip_site",
            "gripper0_grip_site",
            "gripper0_right_grip_site",
            "robot0_grip_site",
            "ee_site",
        )
        for name in candidates:
            # mj_name2id returns -1 (it does not raise) for unknown names.
            if mujoco.mj_name2id(model, site_type, name) >= 0:
                return name

        site_names = [
            mujoco.mj_id2name(model, site_type, i) or "" for i in range(model.nsite)
        ]
        for site_name in site_names:
            if site_name.lower().endswith("grip_site"):
                return site_name
        for site_name in site_names:
            if "grip" in site_name.lower():
                return site_name
        return None

    def sample_grasp_pose(self, object_name: str) -> tuple[np.ndarray, np.ndarray]:
        """Get a grasp pose for the named object.

        Position is the object's position raised by 3 mm. The orientation is
        the current EE orientation converted to wxyz. If the object is not
        found (all-zero position), returns (zeros(3), identity quaternion).
        """
        obj_pos = self.get_object_position(object_name)
        if np.allclose(obj_pos, 0.0):
            return np.zeros(3), np.array([1.0, 0.0, 0.0, 0.0])

        grasp_pos = obj_pos.copy()
        grasp_pos[2] += 0.003

        x, y, z, w = self._get_ee_quat_xyzw()
        return grasp_pos, np.array([w, x, y, z])

    def get_ground_truth_masks(self, text_prompt: str) -> list[dict]:
        """S1 tier: segmentation from MuJoCo.

        Returns a one-element list ``[{"mask", "box", "score"}]``.
        - If an instance-segmentation observation exists, ``mask`` marks every
          pixel with a nonzero instance id, ``box`` is its
          [x_min, y_min, x_max, y_max] ([0, 0, 0, 0] if empty), and ``score``
          is 1.0.
        - Otherwise, a fixed box covering the central half of the image is
          returned with ``score`` 0.5.

        ``text_prompt`` is currently unused.

        NOTE(review): ``__init__`` does not request ``camera_segmentations``,
        so the segmentation branch only fires if the env provides it some
        other way.
        """
        obs = self._env._get_observations()
        seg_keys = (
            f"{self._camera_name}_segmentation_instance",
            "robot0_robotview_segmentation_instance",  # legacy key
        )
        seg = next((obs[key] for key in seg_keys if key in obs), None)
        if seg is not None:
            mask = seg[::-1] > 0
            if mask.ndim == 3:
                mask = mask[:, :, 0]
            ys, xs = np.where(mask)
            if xs.size > 0:
                box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
            else:
                box = [0, 0, 0, 0]
            return [{"mask": mask, "box": box, "score": 1.0}]

        h = w = self._render_size
        mask = np.zeros((h, w), dtype=bool)
        mask[h // 4: 3 * h // 4, w // 4: 3 * w // 4] = True
        return [{"mask": mask, "box": [w // 4, h // 4, 3 * w // 4, 3 * h // 4], "score": 0.5}]

    def close(self) -> None:
        """Close the robosuite env once; repeated calls are no-ops."""
        env = getattr(self, "_env", None)
        if env is not None:
            env.close()
            self._env = None
|
|