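# project_naka / code / robosuite_sim.py
# (Hugging Face Hub page residue) ilessio-aiflowlab's picture
# (Hugging Face Hub page residue) Upload folder using huggingface_hub
# (Hugging Face Hub page residue) 665e529 verified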
"""RoboSuite simulator — real MuJoCo physics with OSC arm control.
Uses RoboSuite's default OSC_POSE controller for precise end-effector
control. Provides goto_pose, gripper, IK, and blocking motion.
"""
from __future__ import annotations
import os
from typing import Any
import numpy as np
from anima_naka.gym.base_env import BaseEnv
# Use the EGL backend for MuJoCo rendering unless the caller already chose one.
if "MUJOCO_GL" not in os.environ:
    os.environ["MUJOCO_GL"] = "egl"

# Simulator name -> (RoboSuite environment name, extra keyword arguments
# forwarded to ``suite.make``).
_ROBOSUITE_ENVS = dict(
    robosuite_cube_lift=("Lift", {}),
    robosuite_cube_stack=("Stack", {}),
    robosuite_spill_wipe=("Wipe", {}),
    robosuite_peg_insertion=("NutAssemblySquare", {}),
    robosuite_cube_restack=("TwoArmLift", {"env_configuration": "parallel"}),
    robosuite_two_arm_lift=("TwoArmLift", {"env_configuration": "parallel"}),
    robosuite_two_arm_handover=(
        "TwoArmHandover",
        {"env_configuration": "parallel"},
    ),
)

# Panda joint limits (radians), one entry per arm joint.
_JOINT_LOWER = np.array(
    (-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973)
)
_JOINT_UPPER = np.array(
    (2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973)
)
class RoboSuiteSim(BaseEnv):
"""RoboSuite MuJoCo env with OSC pose control.
Default controller: OSC_POSE (delta input, 6 pose dims + 1 gripper).
Actions: [dx, dy, dz, dax, day, daz, gripper] where gripper: -1=open, 1=close.
"""
def __init__(
    self,
    sim_name: str = "robosuite_cube_lift",
    bimanual: bool = False,
    render_size: int = 640,
):
    """Create the wrapped RoboSuite environment with one offscreen camera.

    Args:
        sim_name: Key into ``_ROBOSUITE_ENVS``. Unknown names fall back to
            the ``"Lift"`` environment with no extra keyword arguments.
        bimanual: When true, two ``"Panda"`` robots are requested instead
            of one.
        render_size: Height and width (pixels) of the square camera image.
    """
    # Function-scope import: robosuite is only needed once an instance is built.
    import robosuite as suite

    env_name, kwargs = _ROBOSUITE_ENVS.get(sim_name, ("Lift", {}))
    robots = ["Panda", "Panda"] if bimanual else "Panda"
    self._camera_name = "frontview"
    self._env = suite.make(
        env_name,
        robots=robots,
        has_renderer=False,
        has_offscreen_renderer=True,
        use_camera_obs=True,
        camera_names=[self._camera_name],
        camera_heights=render_size,
        camera_widths=render_size,
        # Request RGB-D: without this no "<camera>_depth" observation is
        # produced and get_observation() can only return its all-ones
        # placeholder depth map.
        camera_depths=True,
        **kwargs,
    )
    self._render_size = render_size
    self._bimanual = bimanual
    self.step_count = 0
    self.max_steps = 1500
    self._sim_step_count = 0
    self._gripper_cmd = -1.0  # -1=open, 1=close
# ---- Reset ----
def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
    """Reset counters, gripper command and the wrapped environment.

    Args:
        seed: When given, seeds NumPy's global random generator before the
            environment is reset.

    Returns:
        ``(observation, info)`` where ``info`` is always an empty dict.
    """
    self.step_count = self._sim_step_count = 0
    self._gripper_cmd = -1.0  # start every episode with the gripper open
    if seed is not None:
        np.random.seed(seed)
    self._env.reset()
    observation = self.get_observation()
    info: dict[str, Any] = {}
    return observation, info
# ---- Observation ----
def get_observation(self) -> dict[str, Any]:
    """Return the current camera observation.

    The image and depth keys are derived from ``self._camera_name`` so they
    always refer to the same camera as the intrinsics and pose entries.

    Returns:
        A dict with the single key ``"robot0_robotview"`` holding
        ``images`` (``rgb`` and ``depth``), ``intrinsics`` and ``pose_mat``.
        When the environment provides no image, a black ``uint8`` image is
        used; when it provides no depth, an all-ones ``float32`` map is used.
    """
    obs = self._env._get_observations()
    camera = self._camera_name
    size = self._render_size

    rgb = obs.get(f"{camera}_image")
    if rgb is None:
        rgb = np.zeros((size, size, 3), dtype=np.uint8)
    if rgb.ndim == 3:
        # Reverse the row order (first axis) of the rendered image.
        rgb = rgb[::-1]

    depth_raw = obs.get(f"{camera}_depth")
    if depth_raw is None:
        depth = np.ones((size, size), dtype=np.float32)
    else:
        if depth_raw.ndim == 3:
            # Drop the trailing channel axis.
            depth_raw = depth_raw[:, :, 0]
        # Same row reversal as the RGB image, then convert to metric depth.
        depth = self._depth_to_meters(depth_raw[::-1])

    return {
        "robot0_robotview": {
            "images": {"rgb": rgb, "depth": depth},
            "intrinsics": self._get_camera_intrinsics(camera),
            "pose_mat": self._get_camera_pose_matrix(camera),
        }
    }
def _depth_to_meters(self, depth_buffer: np.ndarray) -> np.ndarray:
    """Convert a raw depth buffer into a ``float32`` depth map.

    RoboSuite's ``get_real_depth_map`` is used when it can be imported.
    Otherwise the conversion is done inline from the model's near/far
    clipping values scaled by the model extent.

    Args:
        depth_buffer: Depth values as delivered by the renderer
            (presumably normalised to [0, 1] — TODO confirm).
    """
    sim = self._env.sim
    try:
        from robosuite.utils.camera_utils import get_real_depth_map

        return get_real_depth_map(sim, depth_buffer).astype(np.float32)
    except ImportError:
        model = sim.model
        scale = model.stat.extent
        z_near = model.vis.map.znear * scale
        z_far = model.vis.map.zfar * scale
        metric = z_near / (1.0 - depth_buffer * (1.0 - z_near / z_far))
        return metric.astype(np.float32)
def _get_camera_intrinsics(self, camera_name: str) -> np.ndarray:
    """Build the 3x3 pinhole intrinsics matrix for the square render.

    The vertical field of view is read from the model; if that lookup fails
    for any reason, 45 degrees is assumed. The principal point is the image
    centre and the same focal length is used on both axes.
    """
    try:
        model = self._env.sim.model
        cam_id = model.camera_name2id(camera_name)
        fovy = model.cam_fovy[cam_id]
    except Exception:
        fovy = 45.0

    size = self._render_size
    centre = size / 2.0
    focal = 0.5 * size / np.tan(np.radians(fovy / 2.0))

    intrinsics = np.zeros((3, 3), dtype=np.float64)
    intrinsics[0, 0] = intrinsics[1, 1] = focal
    intrinsics[0, 2] = intrinsics[1, 2] = centre
    intrinsics[2, 2] = 1
    return intrinsics
def _get_camera_pose_matrix(self, camera_name: str) -> np.ndarray:
    """Return the camera pose as a 4x4 homogeneous matrix.

    The rotation block comes from ``get_camera_xmat`` and the translation
    column from ``get_camera_xpos``. Any failure yields the identity matrix.
    """
    try:
        data = self._env.sim.data
        pose = np.eye(4, dtype=np.float64)
        rotation = data.get_camera_xmat(camera_name).reshape(3, 3)
        pose[:3, :3] = rotation
        translation = data.get_camera_xpos(camera_name)
        pose[:3, 3] = translation
    except Exception:
        # Best effort: never hand back a partially filled matrix.
        pose = np.eye(4, dtype=np.float64)
    return pose
# ---- Reward ----
def compute_reward(self) -> float:
    """Return the wrapped environment's reward as a Python float.

    ``reward()`` is tried without arguments first. If that raises
    ``TypeError`` it is called again with an all-zero action of length
    ``action_dim``.
    """
    env = self._env
    try:
        return float(env.reward())
    except TypeError:
        # Some envs require an action argument.
        zero_action = np.zeros(env.action_dim)
        return float(env.reward(zero_action))
# ---- Object queries ----
def get_object_position(self, name: str) -> np.ndarray:
    """Get the 3D position of a named object via fuzzy body-name matching.

    The query is lower-cased and split on spaces/underscores; each body is
    scored by how many query tokens occur in its lower-cased name, and the
    first body with the highest score wins
    ('cube' matches 'cube_main', 'red cube' matches 'cube_main', etc.).

    Infrastructure bodies are skipped. The exclusion fragments cover every
    robot and mount (``robot0``, ``robot1``, ``mount0``, ``fixed_mount0``,
    ...) rather than only the first robot, so a bimanual scene cannot
    return a robot part for queries such as "base".

    Args:
        name: Free-form object name.

    Returns:
        A copy of the matched body's position, or ``np.zeros(3)`` when no
        body matches any token.
    """
    sim = self._env.sim
    tokens = name.lower().replace("_", " ").split()
    excluded = ("world", "table", "robot", "mount", "gripper", "link")

    best_match = None
    best_score = 0
    for body_name in sim.model.body_names:
        if not body_name:
            continue
        lowered = body_name.lower()
        # Skip infrastructure bodies.
        if any(fragment in lowered for fragment in excluded):
            continue
        score = sum(1 for token in tokens if token in lowered)
        if score > best_score:
            best_score = score
            best_match = body_name

    if best_match is None:
        return np.zeros(3)
    body_id = sim.model.body_name2id(best_match)
    return sim.data.body_xpos[body_id].copy()
def get_table_height(self) -> float:
    """Return the z coordinate of the ``"table_main"`` body.

    Falls back to 0.82 when the body cannot be looked up for any reason.
    """
    try:
        sim = self._env.sim
        table_id = sim.model.body_name2id("table_main")
        height = float(sim.data.body_xpos[table_id][2])
    except Exception:
        height = 0.82
    return height
# ---- EE state ----
def _get_ee_pos(self) -> np.ndarray:
    """Return a float64 copy of the ``robot0_eef_pos`` observation."""
    observations = self._env._get_observations()
    eef_position = observations["robot0_eef_pos"]
    return np.array(eef_position, dtype=np.float64)
def _get_ee_quat_xyzw(self) -> np.ndarray:
    """Return a float64 copy of the ``robot0_eef_quat`` observation (xyzw)."""
    observations = self._env._get_observations()
    eef_quaternion = observations["robot0_eef_quat"]
    return np.array(eef_quaternion, dtype=np.float64)
def get_joint_positions(self) -> np.ndarray:
    """Return the first seven ``robot0_joint_pos`` values as float64."""
    observations = self._env._get_observations()
    arm_joints = observations["robot0_joint_pos"][:7]
    return np.array(arm_joints, dtype=np.float64)
# ---- Low-level stepping ----
def _step_action(self, action: np.ndarray) -> None:
    """Advance the environment one step with ``action``.

    The action is fitted to the environment's ``action_dim`` first: extra
    trailing entries are dropped and missing ones are zero-padded. The
    internal step counter is incremented afterwards.
    """
    expected = self._env.action_dim
    shortfall = expected - len(action)
    if shortfall < 0:
        action = action[:expected]
    elif shortfall > 0:
        action = np.pad(action, (0, shortfall))
    self._env.step(action)
    self._sim_step_count += 1
def _step_zero(self) -> None:
    """Step with zero pose delta while holding the current gripper command.

    The gripper command is written to robot0's gripper slot (index 6 of the
    ``[dx, dy, dz, dax, day, daz, gripper]`` layout) so that it agrees with
    the 7-element actions built by ``_move_to_position``. With a single arm
    this is the last element, exactly as before; with a longer action
    vector (e.g. two arms) the remaining entries stay zero instead of the
    command landing on the last robot's gripper.
    """
    action_dim = self._env.action_dim
    action = np.zeros(action_dim)
    if action_dim > 0:
        # Clamp for controllers whose action vector is shorter than 7.
        gripper_index = min(6, action_dim - 1)
        action[gripper_index] = self._gripper_cmd
    self._step_action(action)
# ---- Gripper ----
def open_gripper(self) -> None:
    """Set the gripper command to open (-1) and hold it for 30 zero-motion steps."""
    self._gripper_cmd = -1.0
    steps_remaining = 30
    while steps_remaining > 0:
        self._step_zero()
        steps_remaining -= 1
def close_gripper(self) -> None:
    """Set the gripper command to close (+1) and hold it for 30 zero-motion steps."""
    self._gripper_cmd = 1.0
    steps_remaining = 30
    while steps_remaining > 0:
        self._step_zero()
        steps_remaining -= 1
# ---- Motion control ----
def goto_pose(
    self,
    position: np.ndarray,
    quaternion_wxyz: np.ndarray,
    z_approach: float = 0.0,
) -> None:
    """Drive the end effector to a target pose, optionally via a point above it.

    Args:
        position: (3,) target xyz.
        quaternion_wxyz: (4,) target orientation [w, x, y, z], forwarded
            unchanged to ``_move_to_position``.
        z_approach: If > 0, first move to the target raised by this amount
            along z (150-step budget), then to the target (200-step budget).
    """
    goal = np.asarray(position, dtype=np.float64)
    orientation = np.asarray(quaternion_wxyz, dtype=np.float64)

    # (waypoint, step budget) pairs, visited in order.
    waypoints = []
    if z_approach > 0:
        above = goal.copy()
        above[2] += z_approach
        waypoints.append((above, 150))
    waypoints.append((goal, 200))

    for waypoint, budget in waypoints:
        self._move_to_position(waypoint, orientation, max_steps=budget)
def _move_to_position(
    self,
    target_pos: np.ndarray,
    target_quat_wxyz: np.ndarray,
    max_steps: int = 200,
    pos_tol: float = 0.01,
) -> None:
    """Proportional position loop: step until the EE is within ``pos_tol``.

    Each iteration sends the position error scaled by a fixed gain and
    clipped to [-1, 1], a zero orientation delta, and the current gripper
    command. ``target_quat_wxyz`` is accepted but not used: the orientation
    delta is always zero.

    Args:
        target_pos: Target xyz.
        target_quat_wxyz: Unused target orientation.
        max_steps: Maximum number of steps to send.
        pos_tol: Distance below which the target counts as reached.
    """
    gain = 10.0  # Scale factor for delta commands (OSC expects [-1, 1])
    for _ in range(max_steps):
        offset = target_pos - self._get_ee_pos()
        if np.linalg.norm(offset) < pos_tol:
            return
        translation_cmd = np.clip(offset * gain, -1.0, 1.0)
        rotation_cmd = np.zeros(3)  # zero delta: keep the current orientation
        gripper_cmd = [self._gripper_cmd]
        self._step_action(
            np.concatenate([translation_cmd, rotation_cmd, gripper_cmd])
        )
def move_to_joints(self, joints: np.ndarray) -> None:
    """Move toward a target joint configuration using OSC (approximate).

    The controller takes pose deltas, so joints cannot be set directly.
    Instead the end-effector position for ``joints`` is computed with
    forward kinematics on the live MuJoCo data, the original ``qpos`` is
    restored, and ``_move_to_position`` drives the EE to that position
    while keeping the current orientation.

    Args:
        joints: Seven joint angles. Any other size raises ``ValueError``
            from the reshape.

    If anything in the FK/motion path fails (including MuJoCo being
    unavailable), the method falls back to 50 zero-motion steps.
    """
    target = np.asarray(joints, dtype=np.float64).reshape(7)
    try:
        import mujoco

        model = self._env.sim.model._model
        data = self._env.sim.data._data

        # mj_name2id reports an unknown name as -1 rather than raising, so
        # the id has to be validated before it is used as an index.
        site_id = -1
        site_name = self._find_ee_site()
        if site_name:
            site_id = mujoco.mj_name2id(
                model, mujoco.mjtObj.mjOBJ_SITE, site_name
            )

        if site_id >= 0:
            qpos_save = data.qpos.copy()
            try:
                # Temporarily set the arm joints to compute FK.
                data.qpos[:7] = target
                mujoco.mj_forward(model, data)
                target_pos = data.site_xpos[site_id].copy()
            finally:
                # Always put the simulation back, even if FK raised, so the
                # fallback below never steps a corrupted state.
                data.qpos[:] = qpos_save
                mujoco.mj_forward(model, data)
        else:
            # No usable EE site: target the current position (no motion).
            target_pos = self._get_ee_pos()

        # Move to the computed position using OSC, keeping the orientation.
        ee_quat = self._get_ee_quat_xyzw()
        wxyz = np.array([ee_quat[3], ee_quat[0], ee_quat[1], ee_quat[2]])
        self._move_to_position(target_pos, wxyz, max_steps=150)
    except Exception:
        # Best-effort fallback: just hold position for a while.
        for _ in range(50):
            self._step_zero()
def solve_ik(
    self, position: np.ndarray, quaternion_wxyz: np.ndarray
) -> np.ndarray:
    """Solve position-only IK with a damped least-squares Jacobian iteration.

    The solver works directly on the live MuJoCo data: it perturbs the
    first seven ``qpos`` entries for up to 50 iterations (stopping once the
    site is within 1 mm of the target) and then restores ``qpos``/``qvel``.
    The restore runs in a ``finally`` so a failure inside the loop cannot
    leave the simulation in a perturbed state. Every update is clamped to
    the Panda joint limits defined at module level.

    Args:
        position: Target xyz for the end-effector site.
        quaternion_wxyz: Accepted for interface compatibility; orientation
            is not part of the solve.

    Returns:
        Seven joint angles. If no end-effector site is available or the
        solve fails for any reason, the current joint positions are
        returned instead.
    """
    target_pos = np.asarray(position, dtype=np.float64)
    try:
        import mujoco

        model = self._env.sim.model._model
        data = self._env.sim.data._data
        site_name = self._find_ee_site()
        if not site_name:
            return self.get_joint_positions()
        site_id = mujoco.mj_name2id(
            model, mujoco.mjtObj.mjOBJ_SITE, site_name
        )
        if site_id < 0:
            # mj_name2id signals "not found" with -1; indexing with it
            # would silently read the last site.
            return self.get_joint_positions()

        qpos_save = data.qpos.copy()
        qvel_save = data.qvel.copy()
        nq = 7
        jacp = np.zeros((3, model.nv))
        damping = 0.05
        try:
            for _ in range(50):
                mujoco.mj_forward(model, data)
                pos_err = target_pos - data.site_xpos[site_id]
                if np.linalg.norm(pos_err) < 1e-3:
                    break
                mujoco.mj_jacSite(model, data, jacp, None, site_id)
                J = jacp[:, :nq]
                # Damped least squares: dq = J^T (J J^T + lambda^2 I)^-1 e
                JJT = J @ J.T + damping**2 * np.eye(3)
                dq = J.T @ np.linalg.solve(JJT, pos_err)
                data.qpos[:nq] = np.clip(
                    data.qpos[:nq] + dq, _JOINT_LOWER, _JOINT_UPPER
                )
            result = data.qpos[:nq].copy()
        finally:
            data.qpos[:] = qpos_save
            data.qvel[:] = qvel_save
            mujoco.mj_forward(model, data)
        return result
    except Exception:
        return self.get_joint_positions()
def _find_ee_site(self) -> str | None:
    """Find the end-effector site name in the MuJoCo model.

    A short list of known names is tried first; after that, the first site
    whose name contains "grip" is used.

    Returns:
        The site name, or ``None`` when the model has no matching site.
    """
    import mujoco

    model = self._env.sim.model._model
    site_type = mujoco.mjtObj.mjOBJ_SITE
    candidates = ("grip_site", "gripper0_grip_site", "robot0_grip_site", "ee_site")
    for name in candidates:
        try:
            site_id = mujoco.mj_name2id(model, site_type, name)
        except Exception:
            continue
        # mj_name2id returns -1 for an unknown name instead of raising, so
        # the id must be checked; otherwise the first candidate would always
        # be reported as found.
        if site_id >= 0:
            return name

    for index in range(model.nsite):
        site_name = mujoco.mj_id2name(model, site_type, index)
        if site_name and "grip" in site_name.lower():
            return site_name
    return None
def sample_grasp_pose(self, object_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Return a grasp pose ``(position, quaternion_wxyz)`` for an object.

    The position is the object position raised by 3 mm along z; the
    orientation is the end effector's current orientation reordered from
    xyzw to wxyz. If the object position is (close to) all zeros, which is
    what ``get_object_position`` returns for an unknown object, a zero
    position and the identity quaternion are returned.
    """
    centre = self.get_object_position(object_name)
    if np.allclose(centre, 0.0):
        return np.zeros(3), np.array([1.0, 0.0, 0.0, 0.0])

    grasp_pos = centre.copy()
    grasp_pos[2] += 0.003  # Slight offset above object center

    # Reuse the current EE orientation, converting xyzw -> wxyz.
    quat_xyzw = self._get_ee_quat_xyzw()
    grasp_quat_wxyz = np.array([quat_xyzw[i] for i in (3, 0, 1, 2)])
    return grasp_pos, grasp_quat_wxyz
def get_ground_truth_masks(self, text_prompt: str) -> list[dict]:
    """S1 tier: segmentation from MuJoCo.

    Looks for an instance-segmentation observation, first under the active
    camera's key (``"<camera>_segmentation_instance"``) and then under the
    legacy ``"robot0_robotview_segmentation_instance"`` key. If one is
    present, every pixel with a value > 0 is foreground and a single mask
    with its bounding box is returned with score 1.0.

    If no segmentation observation exists, a placeholder mask covering the
    central half of the image is returned with score 0.5.

    Args:
        text_prompt: Currently unused; the same mask is returned for every
            prompt.

    Returns:
        A one-element list of ``{"mask", "box", "score"}`` dicts, where
        ``box`` is ``[x_min, y_min, x_max, y_max]``.
    """
    obs = self._env._get_observations()
    seg_keys = (
        f"{self._camera_name}_segmentation_instance",
        "robot0_robotview_segmentation_instance",
    )
    seg = next((obs[key] for key in seg_keys if key in obs), None)

    if seg is not None:
        # Reverse the row order, matching the flip applied to RGB/depth.
        mask = seg[::-1] > 0
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        ys, xs = np.where(mask)
        if len(xs) > 0:
            box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
        else:
            box = [0, 0, 0, 0]
        return [{"mask": mask, "box": box, "score": 1.0}]

    h = w = self._render_size
    mask = np.zeros((h, w), dtype=bool)
    mask[h // 4: 3 * h // 4, w // 4: 3 * w // 4] = True
    return [{"mask": mask, "box": [w // 4, h // 4, 3 * w // 4, 3 * h // 4], "score": 0.5}]
# ---- Cleanup ----
def close(self) -> None:
    """Close the wrapped environment if one was created."""
    # getattr guards against a partially constructed instance.
    env = getattr(self, "_env", None)
    if env is None:
        return
    env.close()