| """Gymnasium environment for Boston Dynamics Spot in MuJoCo.""" |
import importlib.util
import os
import sys
import warnings

import gymnasium as gym
import mujoco
import numpy as np
from gymnasium import spaces
|
|
| |
# Directory containing this file; controllers and the MuJoCo model are
# located relative to it.
_spot_dir = os.path.dirname(os.path.abspath(__file__))


def _load_controller_module(module_name, filename):
    """Load ``controllers/<filename>`` (next to this file) by path as a module.

    Args:
        module_name: Name the module is created and registered under in
            ``sys.modules``.
        filename: File name inside the ``controllers`` directory.

    Returns:
        The executed module object.

    Raises:
        ImportError: If importlib cannot build a spec/loader for the path.
        Any exception raised while executing the module (e.g.
        FileNotFoundError if the file is missing) propagates unchanged.
    """
    path = os.path.join(_spot_dir, "controllers", filename)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load controller module {module_name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file
    # directly" recipe: code inside the module that looks itself up in
    # sys.modules (dataclasses, pickling, ...) would otherwise fail.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the normal import system: a failed import leaves no entry.
        sys.modules.pop(module_name, None)
        raise
    return module


# Gait controllers, loaded by file path rather than by package import.
_mpc_gait_module = _load_controller_module("spot_mpc_gait", "mpc_gait.py")
MPCGaitController = _mpc_gait_module.MPCGaitController

_trot_gait_module = _load_controller_module("spot_trot_gait", "trot_gait.py")
TrotGaitController = _trot_gait_module.TrotGaitController

_pympc_module = _load_controller_module("spot_pympc", "quadruped_pympc_controller.py")
QuadrupedPyMPCController = _pympc_module.QuadrupedPyMPCController
# Flag exported by the PyMPC wrapper; SpotEnv only uses the full MPC
# controller when this is truthy.
PYMPC_AVAILABLE = _pympc_module.PYMPC_AVAILABLE
|
|
|
|
class SpotEnv(gym.Env):
    """
    Gymnasium environment for the Boston Dynamics Spot quadruped robot in MuJoCo.
    Spot has 12 actuated joints (3 per leg × 4 legs).

    Observation: float32 vector of length ``nq + nv`` (37 for Spot) laid out as
    base position (3), base quaternion w, x, y, z (4), base linear velocity (3),
    base angular velocity (3), joint positions, joint velocities.

    Action: one value per actuator, written unchanged to ``data.ctrl``.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    # Joint names grouped by leg prefix (fl/fr/hl/hr, presumably front-left,
    # front-right, hind-left, hind-right), three joints per leg (hx, hy, kn).
    # NOTE(review): not referenced by this class; presumably listed in
    # actuator/ctrl order -- confirm against the model XML.
    JOINT_NAMES = [
        "fl_hx", "fl_hy", "fl_kn",
        "fr_hx", "fr_hy", "fr_kn",
        "hl_hx", "hl_hy", "hl_kn",
        "hr_hx", "hr_hy", "hr_kn",
    ]

    # Joint angles (one hx, hy, kn triple per leg, same leg order as
    # JOINT_NAMES) written to qpos[7:] and ctrl at reset.
    DEFAULT_STANDING_POSE = np.array([
        0.0, 1.04, -1.8,
        0.0, 1.04, -1.8,
        0.0, 1.04, -1.8,
        0.0, 1.04, -1.8,
    ], dtype=np.float32)

    # Base height (qpos[2]) set at reset; the height reward saturates here.
    STANDING_HEIGHT = 0.46
    # Episodes terminate once the base drops below this height.
    MIN_BASE_HEIGHT = 0.2

    def __init__(self, render_mode=None, width=1280, height=720, controller_type='mpc_gait'):
        """Initialize Spot environment.

        Args:
            render_mode: 'human' or 'rgb_array'. Only 'rgb_array' makes
                render() return frames; any other value makes it return None.
            width: Render width (also set as the model's offscreen width).
            height: Render height (also set as the model's offscreen height).
            controller_type: 'mpc_gait' (default), 'pympc' (full MPC), or 'trot'.
                'pympc' without PyMPC available, or any unrecognized value,
                falls back to the MPC gait controller and emits a warning.
        """
        super().__init__()

        model_path = os.path.join(_spot_dir, "model", "scene.xml")
        self.model = mujoco.MjModel.from_xml_path(model_path)

        # Size the offscreen framebuffer so mujoco.Renderer can render at the
        # requested resolution.
        self.model.vis.global_.offwidth = width
        self.model.vis.global_.offheight = height

        self.data = mujoco.MjData(self.model)

        self.num_actuators = self.model.nu

        # Actions go straight into data.ctrl. reset() writes joint angles into
        # ctrl, so these are presumably position targets, bounded to ~±pi.
        self.action_space = spaces.Box(
            low=-3.14, high=3.14,
            shape=(self.num_actuators,),
            dtype=np.float32,
        )

        # _get_obs() concatenates every qpos entry and every qvel entry, so
        # the size follows the model (19 + 18 = 37 for Spot's free base plus
        # 12 joints) instead of being hard-coded.
        obs_dim = self.model.nq + self.model.nv
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(obs_dim,),
            dtype=np.float32,
        )

        self.render_mode = render_mode
        self.width = width
        self.height = height
        self.renderer = None  # created lazily by render()

        self.steps = 0
        self.max_steps = 1000

        self.controller_type = controller_type
        if controller_type == 'pympc' and PYMPC_AVAILABLE:
            self.controller = QuadrupedPyMPCController(
                num_joints=self.num_actuators,
                model=self.model,
                data=self.data,
            )
        elif controller_type == 'trot':
            self.controller = TrotGaitController(self.num_actuators)
        else:
            # Make the fallback visible instead of silently substituting.
            if controller_type == 'pympc':
                warnings.warn(
                    "PyMPC is not available; falling back to the 'mpc_gait' controller.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            elif controller_type != 'mpc_gait':
                warnings.warn(
                    f"Unknown controller_type {controller_type!r}; "
                    "falling back to the 'mpc_gait' controller.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            self.controller = MPCGaitController(self.num_actuators)

    def set_command(self, vx: float = 0.0, vy: float = 0.0, vyaw: float = 0.0):
        """Forward a velocity command (vx, vy, vyaw) to the active controller."""
        if self.controller is not None:
            self.controller.set_command(vx, vy, vyaw)

    def get_command(self):
        """Return the controller's current velocity command.

        Returns a float32 zero vector of length 3 if there is no controller.
        """
        if self.controller is not None:
            return self.controller.get_command()
        return np.array([0.0, 0.0, 0.0], dtype=np.float32)

    def _get_obs(self):
        """Return the float32 observation vector (see the class docstring).

        Order: qpos[0:3], qpos[3:7], qvel[0:3], qvel[3:6], qpos[7:], qvel[6:].
        """
        qpos = self.data.qpos
        qvel = self.data.qvel
        # np.concatenate copies its inputs, so the slices need no .copy().
        return np.concatenate([
            qpos[:3], qpos[3:7],   # base position, base quaternion
            qvel[:3], qvel[3:6],   # base linear, base angular velocity
            qpos[7:], qvel[6:],    # joint positions, joint velocities
        ]).astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset to the standing pose and return ``(observation, {})``.

        ``options`` is accepted for Gymnasium API compatibility but unused; no
        state randomization is applied here.
        """
        super().reset(seed=seed)

        mujoco.mj_resetData(self.model, self.data)

        # Lift the base to standing height, place the legs in the standing
        # pose, and command that same pose so the actuators hold it.
        # Slice assignment copies values, so the class array is never aliased.
        self.data.qpos[2] = self.STANDING_HEIGHT
        self.data.qpos[7:] = self.DEFAULT_STANDING_POSE
        self.data.ctrl[:] = self.DEFAULT_STANDING_POSE

        # Recompute derived quantities for the new state without stepping.
        mujoco.mj_forward(self.model, self.data)

        self.steps = 0
        if self.controller is not None:
            self.controller.reset()

        return self._get_obs(), {}

    @staticmethod
    def _upright(quat):
        """Return an uprightness score in [0, 1] for a (w, x, y, z) quaternion.

        Uses 1 - 2(x^2 + y^2), the world-z component of the body z-axis: 1
        when upright regardless of heading, 0 when lying on the side, and
        clipped at 0 when inverted. ``w**2`` alone would also drop when the
        robot merely turns (yaw), which the velocity commands allow.
        """
        _, x, y, _ = quat
        return max(0.0, 1.0 - 2.0 * float(x * x + y * y))

    def step(self, action):
        """Step with explicit action (for RL training).

        Writes ``action`` to ``data.ctrl``, advances the simulation by one
        MuJoCo timestep, and returns
        ``(observation, reward, terminated, truncated, info)``.

        reward = height term (base height / STANDING_HEIGHT, capped at 1)
                 + upright term in [0, 1] (see ``_upright``).
        """
        self.data.ctrl[:] = action
        mujoco.mj_step(self.model, self.data)
        self.steps += 1

        observation = self._get_obs()

        base_height = float(self.data.qpos[2])
        upright_reward = self._upright(self.data.qpos[3:7])
        height_reward = min(base_height, self.STANDING_HEIGHT) / self.STANDING_HEIGHT
        reward = height_reward + upright_reward

        # A fall ends the episode; hitting the step budget truncates it.
        terminated = base_height < self.MIN_BASE_HEIGHT
        truncated = self.steps >= self.max_steps

        info = {
            "base_height": base_height,
            "upright": upright_reward,
        }
        return observation, reward, terminated, truncated, info

    def step_with_controller(self, dt: "float | None" = None):
        """Step using the active controller (for visualization/demo).

        Args:
            dt: Time step passed to ``controller.step`` (presumably advancing
                its gait clock). Defaults to ``model.opt.timestep``, the
                simulated time covered by the single ``mj_step`` taken here.

        Returns:
            The observation after the physics step.
        """
        if dt is None:
            dt = self.model.opt.timestep

        observation = self._get_obs()

        if self.controller is not None:
            action = self.controller.compute_action(observation, self.data)
            self.controller.step(dt)
        else:
            # No controller: hold the standing pose.
            action = self.DEFAULT_STANDING_POSE

        self.data.ctrl[:] = action
        mujoco.mj_step(self.model, self.data)
        self.steps += 1

        return self._get_obs()

    def render(self):
        """Return an RGB frame when ``render_mode == "rgb_array"``, else None.

        The renderer is created on first use and reused afterwards. 'human'
        appears in ``metadata`` but this method opens no viewer for it.
        """
        if self.render_mode == "rgb_array":
            if self.renderer is None:
                self.renderer = mujoco.Renderer(self.model, height=self.height, width=self.width)
            self.renderer.update_scene(self.data)
            return self.renderer.render()
        return None

    def close(self):
        """Release the offscreen renderer, if one was created."""
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None
|
|