"""RoboSuite simulator — real MuJoCo physics with OSC arm control. Uses RoboSuite's default OSC_POSE controller for precise end-effector control. Provides goto_pose, gripper, IK, and blocking motion. """ from __future__ import annotations import os from typing import Any import numpy as np from anima_naka.gym.base_env import BaseEnv os.environ.setdefault("MUJOCO_GL", "egl") _ROBOSUITE_ENVS = { "robosuite_cube_lift": ("Lift", {}), "robosuite_cube_stack": ("Stack", {}), "robosuite_spill_wipe": ("Wipe", {}), "robosuite_peg_insertion": ("NutAssemblySquare", {}), "robosuite_cube_restack": ("TwoArmLift", {"env_configuration": "parallel"}), "robosuite_two_arm_lift": ("TwoArmLift", {"env_configuration": "parallel"}), "robosuite_two_arm_handover": ("TwoArmHandover", {"env_configuration": "parallel"}), } # Panda joint limits (radians) _JOINT_LOWER = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973]) _JOINT_UPPER = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973]) class RoboSuiteSim(BaseEnv): """RoboSuite MuJoCo env with OSC pose control. Default controller: OSC_POSE (delta input, 6 pose dims + 1 gripper). Actions: [dx, dy, dz, dax, day, daz, gripper] where gripper: -1=open, 1=close. """ def __init__( self, sim_name: str = "robosuite_cube_lift", bimanual: bool = False, render_size: int = 640, ): import robosuite as suite env_name, kwargs = _ROBOSUITE_ENVS.get(sim_name, ("Lift", {})) robots = ["Panda", "Panda"] if bimanual else "Panda" self._camera_name = "frontview" self._env = suite.make( env_name, robots=robots, has_renderer=False, has_offscreen_renderer=True, use_camera_obs=True, camera_names=[self._camera_name], camera_heights=render_size, camera_widths=render_size, **kwargs, ) self._render_size = render_size self._bimanual = bimanual self.step_count = 0 self.max_steps = 1500 self._sim_step_count = 0 self._gripper_cmd = -1.0 # -1=open, 1=close # ---- Reset ---- def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]: self.step_count = 0 self._sim_step_count = 0 self._gripper_cmd = -1.0 if seed is not None: np.random.seed(seed) self._env.reset() return self.get_observation(), {} # ---- Observation ---- def get_observation(self) -> dict[str, Any]: obs = self._env._get_observations() rgb = obs.get( "frontview_image", np.zeros((self._render_size, self._render_size, 3), dtype=np.uint8), ) if rgb.ndim == 3: rgb = rgb[::-1] depth_raw = obs.get("frontview_depth", None) if depth_raw is not None: if depth_raw.ndim == 3: depth_raw = depth_raw[:, :, 0] depth = self._depth_to_meters(depth_raw[::-1]) else: depth = np.ones((self._render_size, self._render_size), dtype=np.float32) return { "robot0_robotview": { "images": {"rgb": rgb, "depth": depth}, "intrinsics": self._get_camera_intrinsics(self._camera_name), "pose_mat": self._get_camera_pose_matrix(self._camera_name), } } def _depth_to_meters(self, depth_buffer: np.ndarray) -> np.ndarray: try: from robosuite.utils.camera_utils import get_real_depth_map return get_real_depth_map(self._env.sim, depth_buffer).astype(np.float32) except ImportError: extent = self._env.sim.model.stat.extent near = self._env.sim.model.vis.map.znear * extent far = self._env.sim.model.vis.map.zfar * extent return (near / (1.0 - depth_buffer * (1.0 - near / far))).astype(np.float32) def _get_camera_intrinsics(self, camera_name: str) -> np.ndarray: try: cam_id = self._env.sim.model.camera_name2id(camera_name) fovy = self._env.sim.model.cam_fovy[cam_id] except Exception: fovy = 45.0 f = 0.5 * self._render_size / np.tan(np.radians(fovy / 2.0)) cx = cy = self._render_size / 2.0 return np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]], dtype=np.float64) def _get_camera_pose_matrix(self, camera_name: str) -> np.ndarray: try: sim = self._env.sim mat = np.eye(4, dtype=np.float64) mat[:3, :3] = sim.data.get_camera_xmat(camera_name).reshape(3, 3) mat[:3, 3] = sim.data.get_camera_xpos(camera_name) return mat except Exception: return np.eye(4, dtype=np.float64) # ---- Reward ---- def compute_reward(self) -> float: """Compute reward from RoboSuite env.""" try: return float(self._env.reward()) except TypeError: # Some envs require action arg action = np.zeros(self._env.action_dim) return float(self._env.reward(action)) # ---- Object queries ---- def get_object_position(self, name: str) -> np.ndarray: """Get 3D position of named object. Searches MuJoCo bodies with fuzzy matching: 'cube' matches 'cube_main', 'red cube' matches 'cube_main', etc. Excludes robot/mount/table infrastructure bodies. """ sim = self._env.sim tokens = name.lower().replace("_", " ").split() _EXCLUDE = {"world", "table", "robot0", "fixed_mount", "gripper", "link"} best_match = None best_score = 0 for body_name in sim.model.body_names: bn_lower = body_name.lower() # Skip infrastructure bodies if any(ex in bn_lower for ex in _EXCLUDE): continue score = sum(1 for t in tokens if t in bn_lower) if score > best_score: best_score = score best_match = body_name if best_match and best_score > 0: body_id = sim.model.body_name2id(best_match) return sim.data.body_xpos[body_id].copy() return np.zeros(3) def get_table_height(self) -> float: try: body_id = self._env.sim.model.body_name2id("table_main") return float(self._env.sim.data.body_xpos[body_id][2]) except Exception: return 0.82 # ---- EE state ---- def _get_ee_pos(self) -> np.ndarray: """Get end-effector position from sim.""" obs = self._env._get_observations() return np.array(obs["robot0_eef_pos"], dtype=np.float64) def _get_ee_quat_xyzw(self) -> np.ndarray: """Get end-effector quaternion (xyzw) from sim.""" obs = self._env._get_observations() return np.array(obs["robot0_eef_quat"], dtype=np.float64) def get_joint_positions(self) -> np.ndarray: obs = self._env._get_observations() return np.array(obs["robot0_joint_pos"][:7], dtype=np.float64) # ---- Low-level stepping ---- def _step_action(self, action: np.ndarray) -> None: """Step the simulation with the given action vector.""" action_dim = self._env.action_dim if len(action) > action_dim: action = action[:action_dim] elif len(action) < action_dim: action = np.pad(action, (0, action_dim - len(action))) self._env.step(action) self._sim_step_count += 1 def _step_zero(self) -> None: """Step with zero delta (hold position, maintain gripper).""" action = np.zeros(self._env.action_dim) action[-1] = self._gripper_cmd self._step_action(action) # ---- Gripper ---- def open_gripper(self) -> None: """Open gripper with 30 physics steps.""" self._gripper_cmd = -1.0 for _ in range(30): self._step_zero() def close_gripper(self) -> None: """Close gripper with 30 physics steps.""" self._gripper_cmd = 1.0 for _ in range(30): self._step_zero() # ---- Motion control ---- def goto_pose( self, position: np.ndarray, quaternion_wxyz: np.ndarray, z_approach: float = 0.0, ) -> None: """Move EE to target pose using OSC delta control loop. Args: position: (3,) target xyz in world frame. quaternion_wxyz: (4,) target orientation [w,x,y,z]. z_approach: If > 0, approach from above first. """ pos = np.asarray(position, dtype=np.float64) quat = np.asarray(quaternion_wxyz, dtype=np.float64) if z_approach > 0: approach_pos = pos.copy() approach_pos[2] += z_approach self._move_to_position(approach_pos, quat, max_steps=150) self._move_to_position(pos, quat, max_steps=200) def _move_to_position( self, target_pos: np.ndarray, target_quat_wxyz: np.ndarray, max_steps: int = 200, pos_tol: float = 0.01, ) -> None: """OSC delta control loop: send deltas until EE reaches target.""" gain = 10.0 # Scale factor for delta commands (OSC expects [-1, 1]) for _ in range(max_steps): current_pos = self._get_ee_pos() pos_error = target_pos - current_pos if np.linalg.norm(pos_error) < pos_tol: break # Clip deltas to [-1, 1] range delta_pos = np.clip(pos_error * gain, -1.0, 1.0) # Orientation: zero delta (maintain current orientation) delta_ori = np.zeros(3) action = np.concatenate([delta_pos, delta_ori, [self._gripper_cmd]]) self._step_action(action) def move_to_joints(self, joints: np.ndarray) -> None: """Move toward target joint config using OSC (approximate). Since we use OSC_POSE controller, we can't directly set joints. Instead, we compute the forward kinematics target and use goto_pose. """ # Use MuJoCo to compute FK: what pose do these joints correspond to? target = np.asarray(joints, dtype=np.float64).reshape(7) try: import mujoco model = self._env.sim.model._model data = self._env.sim.data._data qpos_save = data.qpos.copy() # Temporarily set joints to compute FK data.qpos[:7] = target mujoco.mj_forward(model, data) # Get EE position at target joints site_name = self._find_ee_site() if site_name: site_id = mujoco.mj_name2id( model, mujoco.mjtObj.mjOBJ_SITE, site_name ) target_pos = data.site_xpos[site_id].copy() else: target_pos = self._get_ee_pos() # Restore data.qpos[:] = qpos_save mujoco.mj_forward(model, data) # Move to computed position using OSC ee_quat = self._get_ee_quat_xyzw() wxyz = np.array([ee_quat[3], ee_quat[0], ee_quat[1], ee_quat[2]]) self._move_to_position(target_pos, wxyz, max_steps=150) except Exception: # Fallback: just step for _ in range(50): self._step_zero() def solve_ik( self, position: np.ndarray, quaternion_wxyz: np.ndarray ) -> np.ndarray: """Solve IK using MuJoCo Jacobian-based iterative solver.""" target_pos = np.asarray(position, dtype=np.float64) try: import mujoco model = self._env.sim.model._model data = self._env.sim.data._data site_name = self._find_ee_site() if not site_name: return self.get_joint_positions() site_id = mujoco.mj_name2id( model, mujoco.mjtObj.mjOBJ_SITE, site_name ) qpos_save = data.qpos.copy() qvel_save = data.qvel.copy() nq = 7 jacp = np.zeros((3, model.nv)) damping = 0.05 for _ in range(50): mujoco.mj_forward(model, data) current_pos = data.site_xpos[site_id].copy() pos_err = target_pos - current_pos if np.linalg.norm(pos_err) < 1e-3: break mujoco.mj_jacSite(model, data, jacp, None, site_id) J = jacp[:, :nq] JJT = J @ J.T + damping**2 * np.eye(3) dq = J.T @ np.linalg.solve(JJT, pos_err) data.qpos[:nq] += dq result = data.qpos[:nq].copy() data.qpos[:] = qpos_save data.qvel[:] = qvel_save mujoco.mj_forward(model, data) return result except Exception: return self.get_joint_positions() def _find_ee_site(self) -> str | None: """Find end-effector site name in MuJoCo model.""" import mujoco model = self._env.sim.model._model candidates = ["grip_site", "gripper0_grip_site", "robot0_grip_site", "ee_site"] for name in candidates: try: mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SITE, name) return name except Exception: continue for i in range(model.nsite): site_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_SITE, i) if site_name and "grip" in site_name.lower(): return site_name return None def sample_grasp_pose(self, object_name: str) -> tuple[np.ndarray, np.ndarray]: """Get a top-down grasp pose for the named object. Uses the Panda's home EE orientation (natural grasp direction). """ obj_pos = self.get_object_position(object_name) if np.allclose(obj_pos, 0.0): return np.zeros(3), np.array([1.0, 0.0, 0.0, 0.0]) grasp_pos = obj_pos.copy() grasp_pos[2] += 0.003 # Slight offset above object center # Use current EE orientation (Panda home = natural top-down grasp) ee_quat_xyzw = self._get_ee_quat_xyzw() grasp_quat_wxyz = np.array([ ee_quat_xyzw[3], ee_quat_xyzw[0], ee_quat_xyzw[1], ee_quat_xyzw[2], ]) return grasp_pos, grasp_quat_wxyz def get_ground_truth_masks(self, text_prompt: str) -> list[dict]: """S1 tier: segmentation from MuJoCo.""" obs = self._env._get_observations() seg_key = "robot0_robotview_segmentation_instance" if seg_key in obs: seg = obs[seg_key][::-1] mask = seg > 0 if mask.ndim == 3: mask = mask[:, :, 0] ys, xs = np.where(mask) if len(xs) > 0: box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())] else: box = [0, 0, 0, 0] return [{"mask": mask, "box": box, "score": 1.0}] h = w = self._render_size mask = np.zeros((h, w), dtype=bool) mask[h // 4: 3 * h // 4, w // 4: 3 * w // 4] = True return [{"mask": mask, "box": [w // 4, h // 4, 3 * w // 4, 3 * h // 4], "score": 0.5}] # ---- Cleanup ---- def close(self) -> None: if hasattr(self, "_env") and self._env is not None: self._env.close()