# nova-sim/robots/spot/spot_env.py
# Last commit 6487de0 (gpue): Enhance Spot robot control by integrating PyMPC and MPC Gait controllers
"""Gymnasium environment for Boston Dynamics Spot in MuJoCo."""
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import mujoco
import os
import importlib.util
# Controllers are loaded by file path with importlib, under unique module
# names, to avoid a module namespace collision.
_spot_dir = os.path.dirname(os.path.abspath(__file__))


def _load_controller_module(module_name, filename):
    """Load ``controllers/<filename>`` (next to this file) as *module_name*.

    The module is executed from its source file and returned; it is not
    inserted into ``sys.modules``.

    Args:
        module_name: Name given to the loaded module object.
        filename: File name inside the ``controllers`` directory.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
    """
    path = os.path.join(_spot_dir, "controllers", filename)
    spec = importlib.util.spec_from_file_location(module_name, path)
    # spec_from_file_location may return None, and a spec may carry no
    # loader; fail with a readable error instead of an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load controller module {module_name!r} from {path}"
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# MPC Gait Controller (default)
_mpc_gait_module = _load_controller_module("spot_mpc_gait", "mpc_gait.py")
MPCGaitController = _mpc_gait_module.MPCGaitController

# Trot Gait Controller (fallback)
_trot_gait_module = _load_controller_module("spot_trot_gait", "trot_gait.py")
TrotGaitController = _trot_gait_module.TrotGaitController

# Quadruped-PyMPC Controller (full MPC)
_pympc_module = _load_controller_module(
    "spot_pympc", "quadruped_pympc_controller.py"
)
QuadrupedPyMPCController = _pympc_module.QuadrupedPyMPCController
PYMPC_AVAILABLE = _pympc_module.PYMPC_AVAILABLE
class SpotEnv(gym.Env):
    """Gymnasium environment for the Boston Dynamics Spot quadruped in MuJoCo.

    Spot has 12 actuated joints (3 per leg x 4 legs). Actions are target
    joint angles written to the model's actuators. Observations are the
    concatenation of base position, base quaternion, base linear and angular
    velocity, and the remaining joint positions and velocities.

    Two stepping modes are provided:

    * ``step(action)`` applies an explicit action and returns the usual
      Gymnasium 5-tuple (for RL training).
    * ``step_with_controller(dt)`` asks the built-in gait controller for the
      action (for visualization/demo) and returns only the observation.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    # Joint names for Spot (12 actuated joints).
    # Each leg has: hx (hip roll), hy (hip pitch), kn (knee).
    JOINT_NAMES = [
        # Front Left
        "fl_hx", "fl_hy", "fl_kn",
        # Front Right
        "fr_hx", "fr_hy", "fr_kn",
        # Hind Left
        "hl_hx", "hl_hy", "hl_kn",
        # Hind Right
        "hr_hx", "hr_hy", "hr_kn",
    ]

    # Standing pose - values taken from the model's "home" keyframe.
    DEFAULT_STANDING_POSE = np.array([
        0.0, 1.04, -1.8,  # FL: hx, hy, kn
        0.0, 1.04, -1.8,  # FR
        0.0, 1.04, -1.8,  # HL
        0.0, 1.04, -1.8,  # HR
    ], dtype=np.float32)

    # Base height (qpos[2]) used for the standing pose and as the reward cap.
    STANDING_HEIGHT = 0.46
    # Base height below which the robot is considered fallen.
    FALLEN_HEIGHT = 0.2

    def __init__(self, render_mode=None, width=1280, height=720, controller_type='mpc_gait'):
        """Initialize the Spot environment.

        Args:
            render_mode: 'human' or 'rgb_array'. Only 'rgb_array' produces
                frames from ``render()``.
            width: Render width (also set as the offscreen framebuffer width).
            height: Render height (also set as the offscreen framebuffer height).
            controller_type: 'mpc_gait' (default), 'pympc' (full MPC), or
                'trot'. If 'pympc' is requested but ``PYMPC_AVAILABLE`` is
                false, or the value is not recognized, the MPC gait
                controller is used and a warning is emitted.
        """
        super().__init__()

        # Load model from local assets
        spot_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(spot_dir, "model", "scene.xml")
        self.model = mujoco.MjModel.from_xml_path(model_path)

        # Override framebuffer size for rendering at higher resolution
        self.model.vis.global_.offwidth = width
        self.model.vis.global_.offheight = height
        self.data = mujoco.MjData(self.model)

        # Number of actuators (12 position actuators)
        self.num_actuators = self.model.nu

        # Action space: target positions for all actuators.
        # The Spot model uses position actuators, so actions are target joint angles.
        self.action_space = spaces.Box(
            low=-3.14, high=3.14,
            shape=(self.num_actuators,),
            dtype=np.float32,
        )

        # Observation space:
        # - Base position (3) and orientation quaternion (4)
        # - Base linear velocity (3) and angular velocity (3)
        # - Joint positions (12) and velocities (12)
        # Total: 3 + 4 + 3 + 3 + 12 + 12 = 37
        # NOTE(review): assumes the model has only the floating base plus the
        # 12 leg joints; _get_obs() returns whatever qpos/qvel contain.
        obs_dim = 37
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(obs_dim,),
            dtype=np.float32,
        )

        self.render_mode = render_mode
        self.width = width
        self.height = height
        self.renderer = None  # created lazily on the first rgb_array render
        self.steps = 0
        self.max_steps = 1000

        # Select controller based on type. The requested type is stored
        # as-is, even when a fallback controller ends up being used.
        self.controller_type = controller_type
        if controller_type == 'pympc' and PYMPC_AVAILABLE:
            self.controller = QuadrupedPyMPCController(
                num_joints=self.num_actuators,
                model=self.model,
                data=self.data,
            )
        elif controller_type == 'trot':
            self.controller = TrotGaitController(self.num_actuators)
        else:
            if controller_type != 'mpc_gait':
                # Make the fallback visible instead of silently swapping
                # in a different controller than the caller asked for.
                import warnings
                warnings.warn(
                    f"SpotEnv: controller_type {controller_type!r} is "
                    "unavailable or unknown; falling back to 'mpc_gait'.",
                    stacklevel=2,
                )
            # Default: MPC Gait (feedback-based, no external deps)
            self.controller = MPCGaitController(self.num_actuators)

    def set_command(self, vx: float = 0.0, vy: float = 0.0, vyaw: float = 0.0):
        """Forward a velocity command (vx, vy, vyaw) to the active controller."""
        if self.controller:
            self.controller.set_command(vx, vy, vyaw)

    def get_command(self):
        """Return the controller's current velocity command.

        Falls back to a zero float32 array of length 3 when no controller
        is set.
        """
        if self.controller:
            return self.controller.get_command()
        return np.array([0.0, 0.0, 0.0], dtype=np.float32)

    def _get_obs(self):
        """Build the float32 observation vector from the simulation state.

        Layout: qpos[0:3], qpos[3:7], qvel[0:3], qvel[3:6], qpos[7:], qvel[6:].
        """
        qpos = self.data.qpos
        qvel = self.data.qvel

        # Base position and orientation (from floating base joint)
        base_pos = qpos[:3].copy()
        base_quat = qpos[3:7].copy()

        # Base velocities
        base_lin_vel = qvel[:3].copy()
        base_ang_vel = qvel[3:6].copy()

        # Joint positions and velocities (skip floating base: 7 qpos, 6 qvel)
        joint_pos = qpos[7:].copy()
        joint_vel = qvel[6:].copy()

        return np.concatenate([
            base_pos, base_quat,
            base_lin_vel, base_ang_vel,
            joint_pos, joint_vel,
        ]).astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset the simulation to the standing pose.

        Args:
            seed: Passed to ``gym.Env.reset`` for RNG seeding.
            options: Accepted for API compatibility; not used.

        Returns:
            Tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        mujoco.mj_resetData(self.model, self.data)

        # The standing pose is written explicitly here (it is not loaded
        # from the keyframe at runtime): base height first, then the joints.
        self.data.qpos[2] = self.STANDING_HEIGHT
        self.data.qpos[7:] = self.DEFAULT_STANDING_POSE.copy()

        # Position actuators: hold the same pose as the control target.
        self.data.ctrl[:] = self.DEFAULT_STANDING_POSE.copy()
        mujoco.mj_forward(self.model, self.data)
        self.steps = 0

        # Reset controller
        if self.controller:
            self.controller.reset()

        observation = self._get_obs()
        return observation, {}

    def step(self, action):
        """Step with an explicit action (for RL training).

        Args:
            action: Target joint angles, written directly to ``data.ctrl``.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``reward`` is a Python float, ``terminated`` and ``truncated``
            are Python bools. ``info`` holds ``"base_height"`` and
            ``"upright"``.
        """
        # Apply action directly to position actuators
        self.data.ctrl[:] = action

        # Step simulation
        mujoco.mj_step(self.model, self.data)
        self.steps += 1
        observation = self._get_obs()

        # Simple reward: stay upright and alive.
        # Values are converted to plain Python scalars so the returned
        # reward/terminated match the types the Gymnasium step API expects.
        base_height = float(self.data.qpos[2])
        # NOTE(review): this is the squared first quaternion component
        # (qpos[3]); it also decreases under pure yaw rotation, not only tilt.
        upright_reward = float(self.data.qpos[3] ** 2)
        height_reward = min(base_height, self.STANDING_HEIGHT) / self.STANDING_HEIGHT
        reward = height_reward + upright_reward

        # Termination: robot fell
        terminated = bool(base_height < self.FALLEN_HEIGHT)
        truncated = bool(self.steps >= self.max_steps)

        info = {
            "base_height": base_height,
            "upright": upright_reward,
        }
        return observation, reward, terminated, truncated, info

    def step_with_controller(self, dt: float = 0.002):
        """Step using the active controller (for visualization/demo).

        Args:
            dt: Time increment passed to ``controller.step``.

        Returns:
            The observation after the simulation step. No reward or
            termination flags are computed here.
        """
        observation = self._get_obs()

        # Get action from controller
        if self.controller:
            action = self.controller.compute_action(observation, self.data)
            self.controller.step(dt)
        else:
            # Default to standing pose
            action = self.DEFAULT_STANDING_POSE.copy()

        # Apply action to position actuators
        self.data.ctrl[:] = action

        # Step simulation
        mujoco.mj_step(self.model, self.data)
        self.steps += 1
        return self._get_obs()

    def render(self):
        """Render the current state.

        Returns:
            The rendered frame when ``render_mode == "rgb_array"``; ``None``
            for every other mode (including "human").
        """
        if self.render_mode == "rgb_array":
            if self.renderer is None:
                # Created on first use so that environments that never
                # render do not allocate a renderer.
                self.renderer = mujoco.Renderer(
                    self.model, height=self.height, width=self.width
                )
            self.renderer.update_scene(self.data)
            return self.renderer.render()
        return None

    def close(self):
        """Release the renderer, if one was created."""
        if self.renderer:
            self.renderer.close()
            self.renderer = None