# unitree-g1-mujoco: sim/base_sim.py
# Last change: "Update sim/base_sim.py" by nepyope (commit 638fd14).
import argparse
import pathlib
from pathlib import Path
import threading
from threading import Lock, Thread
from typing import Dict
import mujoco
import mujoco.viewer
import numpy as np
# rclpy (ROS 2) is optional; HAS_RCLPY records whether it could be imported.
# NOTE(review): within this file the flag only gates creating/shutting down the ROS 2
# node; image publishing here goes through ImagePublishProcess (ZMQ).
try:
    import rclpy
except ImportError:
    HAS_RCLPY = False
    print("Warning: rclpy not found. Camera image publishing will be disabled.")
else:
    HAS_RCLPY = True
from unitree_sdk2py.core.channel import ChannelFactoryInitialize
import yaml
import os
from .image_publish_utils import ImagePublishProcess
from .metric_utils import check_contact, check_height
from .sim_utilts import get_subtree_body_names
from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
# Absolute path of the directory two levels above this file (the parent of this file's package directory).
GR00T_WBC_ROOT = Path(__file__).resolve().parents[1]
class DefaultEnv:
    """Base environment: loads the MuJoCo scene, applies PD motor commands and steps the sim.

    Subclasses override ``update_reward`` / ``get_privileged_obs`` (and may override
    ``init_scene``) to build task-specific environments.
    """

    # No elastic band unless ``init_scene`` creates one (ENABLE_ELASTIC_BAND). Declared at
    # class level so ``handle_keyboard_button`` also works when the band is disabled.
    elastic_band = None

    # Joint 0 is assumed to be the floating-base free joint (7 qpos / 6 qvel entries) and
    # every other joint 1-DoF, so MuJoCo joint j lives at qpos[j + 6] and qvel[j + 5].
    _QPOS_SHIFT = 7 - 1
    _QVEL_SHIFT = 6 - 1

    # Name fragments that mark a joint as a body (non-hand) joint.
    _BODY_JOINT_KEYWORDS = ("hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist")

    # Arrow key -> (forward, left) velocity kick expressed in the base frame.
    _PERTURBATION_DIRECTIONS = {
        "up": (1.0, 0.0),
        "down": (-1.0, 0.0),
        "left": (0.0, 1.0),
        "right": (0.0, -1.0),
    }

    def __init__(
        self,
        config: Dict[str, any],
        env_name: str = "default",
        camera_configs: Dict[str, any] = None,
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the scene and, optionally, the viewer and off-screen renderers.

        Args:
            config: Simulator configuration (reads e.g. ``ROBOT_SCENE``, ``NUM_JOINTS``,
                ``NUM_HAND_JOINTS``, ``SIMULATE_DT``, ``motor_effort_limit_list``,
                ``ENABLE_ELASTIC_BAND`` and optionally ``IMAGE_DT``).
            env_name: Name recorded on the instance.
            camera_configs: Camera name -> ``{"height", "width"[, "params"]}``. The
                caller's dict is copied and never mutated.
            onscreen: Open a passive MuJoCo viewer window.
            offscreen: Create a ``mujoco.Renderer`` for every configured camera.
        """
        # Copy so the scene-specific camera added below never leaks into the caller's dict.
        camera_configs = dict(camera_configs) if camera_configs else {}
        # global_view is only set up for this specific scene for now.
        if config["ROBOT_SCENE"] == "gr00t_wbc/control/robot_model/model_data/g1/scene_29dof.xml":
            camera_configs["global_view"] = {
                "height": 400,
                "width": 400,
            }

        self.config = config
        self.env_name = env_name
        self.num_body_dof = self.config["NUM_JOINTS"]
        self.num_hand_dof = self.config["NUM_HAND_JOINTS"]
        self.sim_dt = self.config["SIMULATE_DT"]
        self.obs = None
        # One torque slot per body joint plus NUM_HAND_JOINTS per hand.
        self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2)
        self.torque_limit = np.array(self.config["motor_effort_limit_list"])
        self.camera_configs = camera_configs
        if len(camera_configs) > 0:
            print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")

        # Guards last_reward (see update_reward / get_reward).
        self.reward_lock = Lock()
        # Set later by the simulator through set_unitree_bridge().
        self.unitree_bridge = None
        self.onscreen = onscreen
        self.init_scene()
        self.last_reward = 0

        self.offscreen = offscreen
        # Always present so update_render_caches() works without off-screen rendering.
        self.renderers = {}
        if self.offscreen:
            self.init_renderers()
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        # Started on demand by start_image_publish_subprocess().
        self.image_publish_process = None

    def init_scene(self):
        """Load ``ROBOT_SCENE``, set up the optional elastic band and viewer, and index joints.

        Sets ``mj_model``, ``mj_data``, ``torso_index``, ``root_body``, ``viewer`` and the
        integer MuJoCo joint-id arrays ``body_joint_index``, ``left_hand_index`` and
        ``right_hand_index``.
        """
        assets_root = Path(__file__).parent.parent
        self.mj_model = mujoco.MjModel.from_xml_path(str(assets_root / self.config["ROBOT_SCENE"]))
        self.mj_data = mujoco.MjData(self.mj_model)
        self.mj_model.opt.timestep = self.sim_dt
        self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.root_body = "pelvis"

        key_callback = None
        if self.config["ENABLE_ELASTIC_BAND"]:
            self.elastic_band = ElasticBand()
            self.band_attached_link = self._band_attachment_body_id()
            # Forward viewer key presses to the band.
            key_callback = self.elastic_band.MujuocoKeyCallback

        if self.onscreen:
            self.viewer = mujoco.viewer.launch_passive(
                self.mj_model,
                self.mj_data,
                key_callback=key_callback,
                show_left_ui=False,
                show_right_ui=False,
            )
            self.viewer.cam.azimuth = 120  # horizontal rotation, degrees
            self.viewer.cam.elevation = -30  # vertical tilt, degrees
            self.viewer.cam.distance = 2.0  # distance from camera to target
            self.viewer.cam.lookat = np.array([0, 0, 0.5])  # point the camera looks at
        else:
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self.viewer = None

        # Note that the actuator order is the same as the joint order in the mujoco model.
        joint_names = [self.mj_model.joint(i).name for i in range(self.mj_model.njnt)]
        (
            self.body_joint_index,
            self.left_hand_index,
            self.right_hand_index,
        ) = self._classify_joints(joint_names)

        assert len(self.body_joint_index) == self.config["NUM_JOINTS"], \
            f"Expected {self.config['NUM_JOINTS']} body joints, got {len(self.body_joint_index)}"
        # Hand joints are optional (some models don't have hands).
        # NOTE(review): the warning below does not change num_hand_dof; callers still
        # size buffers from NUM_HAND_JOINTS.
        if self.config.get("NUM_HAND_JOINTS", 0) > 0:
            expected_hands = self.config["NUM_HAND_JOINTS"]
            if len(self.left_hand_index) != expected_hands or len(self.right_hand_index) != expected_hands:
                print(f"Warning: Expected {expected_hands} hand joints, got left={len(self.left_hand_index)}, right={len(self.right_hand_index)}")
                print("Continuing without hands...")

    def _band_attachment_body_id(self) -> int:
        """Return the id of the body the elastic band pulls on, chosen from ROBOT_TYPE."""
        robot_type = self.config["ROBOT_TYPE"]
        if "g1" in robot_type:
            body_name = "pelvis" if self.config["enable_waist"] else "torso_link"
        elif "h1" in robot_type:
            body_name = "torso_link"
        else:
            body_name = "base_link"
        return self.mj_model.body(body_name).id

    @classmethod
    def _classify_joints(cls, joint_names):
        """Split joint ids into (body, left-hand, right-hand) integer arrays by joint name.

        A name matching a body keyword counts as a body joint even if it also mentions a
        hand; names matching nothing (e.g. the floating base) are ignored.
        """
        body, left, right = [], [], []
        for joint_id, name in enumerate(joint_names):
            if any(keyword in name for keyword in cls._BODY_JOINT_KEYWORDS):
                body.append(joint_id)
            elif "left_hand" in name:
                left.append(joint_id)
            elif "right_hand" in name:
                right.append(joint_id)
        # dtype=int keeps empty arrays usable as fancy indices (empty float arrays are not).
        return (
            np.array(body, dtype=int),
            np.array(left, dtype=int),
            np.array(right, dtype=int),
        )

    def init_renderers(self):
        """Create one off-screen ``mujoco.Renderer`` per configured camera."""
        self.renderers = {
            camera_name: mujoco.Renderer(
                self.mj_model, height=camera_config["height"], width=camera_config["width"]
            )
            for camera_name, camera_config in self.camera_configs.items()
        }

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the ZMQ image-publishing subprocess for the configured cameras.

        Args:
            start_method: multiprocessing start method used when the config has no
                ``MP_START_METHOD`` entry (the config value takes precedence).
            camera_port: ZMQ port to publish images on.
        """
        if len(self.camera_configs) == 0:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return
        start_method = self.config.get("MP_START_METHOD", start_method)
        self.image_publish_process = ImagePublishProcess(
            camera_configs=self.camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.config.get("verbose", False),
        )
        self.image_publish_process.start_process()
        print(f"✓ Started image publishing subprocess on ZMQ port {camera_port}")

    @staticmethod
    def _pd_torque(cmd, q, dq):
        """PD law with feed-forward: tau + kp * (q_des - q) + kd * (dq_des - dq)."""
        return cmd.tau + cmd.kp * (cmd.q - q) + cmd.kd * (cmd.dq - dq)

    def compute_body_torques(self) -> np.ndarray:
        """Return body joint torques from the latest low-level command (zeros without one).

        Joint state comes from ``sensordata`` when the bridge sets ``use_sensor``
        (positions first, then velocities), otherwise from ``qpos``/``qvel``.
        """
        body_torques = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is None or not bridge.low_cmd:
            return body_torques
        num_motors = bridge.num_body_motor
        for i in range(num_motors):
            if bridge.use_sensor:
                q = self.mj_data.sensordata[i]
                dq = self.mj_data.sensordata[i + num_motors]
            else:
                joint = self.body_joint_index[i]
                q = self.mj_data.qpos[joint + self._QPOS_SHIFT]
                dq = self.mj_data.qvel[joint + self._QVEL_SHIFT]
            body_torques[i] = self._pd_torque(bridge.low_cmd.motor_cmd[i], q, dq)
        return body_torques

    def compute_hand_torques(self) -> np.ndarray:
        """Return ``[left, right]`` hand torques from the latest hand commands (zeros if none)."""
        left_hand_torques = np.zeros(self.num_hand_dof)
        right_hand_torques = np.zeros(self.num_hand_dof)
        bridge = self.unitree_bridge
        if self.num_hand_dof > 0 and bridge is not None and bridge.low_cmd:
            qpos, qvel = self.mj_data.qpos, self.mj_data.qvel
            for i in range(bridge.num_hand_motor):
                left_joint = self.left_hand_index[i]
                right_joint = self.right_hand_index[i]
                left_hand_torques[i] = self._pd_torque(
                    bridge.left_hand_cmd.motor_cmd[i],
                    qpos[left_joint + self._QPOS_SHIFT],
                    qvel[left_joint + self._QVEL_SHIFT],
                )
                right_hand_torques[i] = self._pd_torque(
                    bridge.right_hand_cmd.motor_cmd[i],
                    qpos[right_joint + self._QPOS_SHIFT],
                    qvel[right_joint + self._QVEL_SHIFT],
                )
        return np.concatenate((left_hand_torques, right_hand_torques))

    def compute_body_qpos(self) -> np.ndarray:
        """Return commanded body joint positions (zeros without a command)."""
        body_qpos = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_body_motor):
                body_qpos[i] = bridge.low_cmd.motor_cmd[i].q
        return body_qpos

    def compute_hand_qpos(self) -> np.ndarray:
        """Return commanded ``[left, right]`` hand joint positions (zeros without a command)."""
        hand_qpos = np.zeros(self.num_hand_dof * 2)
        bridge = self.unitree_bridge
        if self.num_hand_dof > 0 and bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_hand_motor):
                hand_qpos[i] = bridge.left_hand_cmd.motor_cmd[i].q
                hand_qpos[i + self.num_hand_dof] = bridge.right_hand_cmd.motor_cmd[i].q
        return hand_qpos

    def prepare_obs(self) -> Dict[str, any]:
        """Build the observation dict from the current MuJoCo state.

        Slices such as ``floating_base_pose`` are views into ``mj_data`` and change as
        the simulation advances.
        """
        data = self.mj_data
        qpos_shift, qvel_shift = self._QPOS_SHIFT, self._QVEL_SHIFT
        # NOTE(review): actuator j - 1 is assumed to drive joint j; with FREE_BASE (6 extra
        # leading ctrl entries in sim_step) this mapping may be off - confirm.
        obs = {
            "floating_base_pose": data.qpos[:7],
            "floating_base_vel": data.qvel[:6],
            "floating_base_acc": data.qacc[:6],
            "secondary_imu_quat": data.xquat[self.torso_index],
            "secondary_imu_vel": data.cvel[self.torso_index],
            "body_q": data.qpos[self.body_joint_index + qpos_shift],
            "body_dq": data.qvel[self.body_joint_index + qvel_shift],
            "body_ddq": data.qacc[self.body_joint_index + qvel_shift],
            "body_tau_est": data.actuator_force[self.body_joint_index - 1],
        }
        if self.num_hand_dof > 0:
            for side, index in (("left", self.left_hand_index), ("right", self.right_hand_index)):
                obs[f"{side}_hand_q"] = data.qpos[index + qpos_shift]
                obs[f"{side}_hand_dq"] = data.qvel[index + qvel_shift]
                obs[f"{side}_hand_ddq"] = data.qacc[index + qvel_shift]
                obs[f"{side}_hand_tau_est"] = data.actuator_force[index - 1]
        obs["time"] = data.time
        return obs

    def _publish_state(self, obs):
        """Publish low-level state (and the wireless controller, if any) over the bridge."""
        if self.unitree_bridge is None:
            return
        self.unitree_bridge.PublishLowState(obs)
        if self.unitree_bridge.joystick:
            self.unitree_bridge.PublishWirelessController()

    def _apply_elastic_band(self):
        """Apply (or clear) the elastic-band wrench on ``band_attached_link``, if configured."""
        if not self.config["ENABLE_ELASTIC_BAND"]:
            return
        link = self.band_attached_link
        if not self.elastic_band.enable:
            # Explicitly reset the force when the band is not enabled.
            self.mj_data.xfrc_applied[link] = np.zeros(6)
            return
        # pose = [position(3), quaternion wxyz(4), linear vel(3), angular vel(3)] in world.
        pose = np.zeros(13)
        pose[0:3] = self.mj_data.xpos[link]
        pose[3:7] = self.mj_data.xquat[link]
        velocity = np.zeros(6)
        mujoco.mj_objectVelocity(
            self.mj_model,
            self.mj_data,
            mujoco.mjtObj.mjOBJ_BODY,
            link,
            velocity,
            0,  # 0: world-aligned frame
        )
        # MuJoCo returns [angular, linear]; the band expects [linear, angular].
        pose[7:10] = velocity[3:6]
        pose[10:13] = velocity[0:3]
        self.mj_data.xfrc_applied[link] = self.elastic_band.Advance(pose)

    def sim_step(self):
        """Publish state, apply band and PD torques (clipped to limits), then ``mj_step``."""
        self.obs = self.prepare_obs()
        self._publish_state(self.obs)
        self._apply_elastic_band()

        body_torques = self.compute_body_torques()
        hand_torques = self.compute_hand_torques()
        self.torques[self.body_joint_index - 1] = body_torques
        if self.num_hand_dof > 0:
            self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof]
            self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :]
        self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit)

        if self.config["FREE_BASE"]:
            self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques))
        else:
            self.mj_data.ctrl = self.torques
        mujoco.mj_step(self.mj_model, self.mj_data)

    def kinematics_step(self):
        """Kinematics only: write commanded joint positions into qpos (for debugging)."""
        if self.unitree_bridge is not None:
            self._publish_state(self.prepare_obs())
        self._apply_elastic_band()

        body_qpos = self.compute_body_qpos()  # (num_body_dof,)
        self.mj_data.qpos[self.body_joint_index + self._QPOS_SHIFT] = body_qpos
        if self.num_hand_dof > 0:
            hand_qpos = self.compute_hand_qpos()  # (num_hand_dof * 2,)
            self.mj_data.qpos[self.left_hand_index + self._QPOS_SHIFT] = hand_qpos[: self.num_hand_dof]
            self.mj_data.qpos[self.right_hand_index + self._QPOS_SHIFT] = hand_qpos[self.num_hand_dof :]
        mujoco.mj_kinematics(self.mj_model, self.mj_data)
        mujoco.mj_comPos(self.mj_model, self.mj_data)

    def apply_perturbation(self, key):
        """Add a unit base-frame velocity kick for an arrow key (no kick for other keys)."""
        forward, left = self._PERTURBATION_DIRECTIONS.get(key, (0.0, 0.0))
        vel_body = np.array([forward, left, 0.0])
        vel_world = np.zeros(3)
        base_quat = self.mj_data.qpos[3:7]  # [w, x, y, z]
        mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat)
        # Only the horizontal world-frame components are applied.
        self.mj_data.qvel[0] += vel_world[0]
        self.mj_data.qvel[1] += vel_world[1]
        mujoco.mj_forward(self.mj_model, self.mj_data)

    def update_viewer(self):
        """Sync the passive viewer with the simulation state, if a viewer exists."""
        if self.viewer is not None:
            self.viewer.sync()

    def update_viewer_camera(self):
        """Toggle the viewer camera between tracking and free mode."""
        if self.viewer is not None:
            if self.viewer.cam.type == mujoco.mjtCamera.mjCAMERA_TRACKING:
                self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
            else:
                self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING

    def update_reward(self):
        """Calculate reward. Should be implemented by subclasses (base value is 0)."""
        with self.reward_lock:
            self.last_reward = 0

    def get_reward(self):
        """Return the last calculated reward under ``reward_lock``."""
        with self.reward_lock:
            return self.last_reward

    def set_unitree_bridge(self, unitree_bridge):
        """Set the unitree bridge from the simulator."""
        self.unitree_bridge = unitree_bridge

    def get_privileged_obs(self):
        """Get privileged observation. Should be implemented by subclasses."""
        return {}

    def update_render_caches(self):
        """Render every camera that has a renderer and push images to the publisher.

        Returns:
            Dict mapping ``"<camera>_image"`` to the rendered image. Cameras without a
            renderer (e.g. ``offscreen=False``) are skipped.
        """
        render_caches = {}
        for camera_name, camera_config in self.camera_configs.items():
            renderer = self.renderers.get(camera_name)
            if renderer is None:
                continue
            renderer.update_scene(self.mj_data, camera=camera_config.get("params", camera_name))
            render_caches[camera_name + "_image"] = renderer.render()
        if render_caches and self.image_publish_process is not None:
            self.image_publish_process.update_shared_memory(render_caches)
        return render_caches

    def handle_keyboard_button(self, key):
        """Dispatch a key: band control, backspace reset, ``v`` camera toggle, arrow kicks."""
        if self.elastic_band is not None:
            self.elastic_band.handle_keyboard_button(key)
        if key == "backspace":
            self.reset()
        if key == "v":
            self.update_viewer_camera()
        if key in self._PERTURBATION_DIRECTIONS:
            self.apply_perturbation(key)

    def check_fall(self):
        """Set ``self.fall`` when base height (qpos[2]) is below 0.2 and reset if so."""
        self.fall = bool(self.mj_data.qpos[2] < 0.2)
        if self.fall:
            print(f"Warning: Robot has fallen, height: {self.mj_data.qpos[2]:.3f} m")
            self.reset()

    def check_self_collision(self):
        """Return whether bodies under ``root_body`` are in contact with each other."""
        robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id)
        self_collision, contact_bodies = check_contact(
            self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True
        )
        if self_collision:
            print(f"Warning: Self-collision detected: {contact_bodies}")
        return self_collision

    def reset(self):
        """Reset MuJoCo data to model defaults and recompute derived quantities."""
        mujoco.mj_resetData(self.mj_model, self.mj_data)
        # Refresh xpos/xquat/cvel so observations read before the next step match the reset qpos.
        mujoco.mj_forward(self.mj_model, self.mj_data)
class CubeEnv(DefaultEnv):
    """Pick-and-place task: grasp the cube with the right hand and lift it."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Work on a shallow copy so the caller's config keeps its own ROBOT_SCENE.
        cube_config = config.copy()
        cube_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
        super().__init__(
            cube_config,
            env_name="cube",
            camera_configs={},
            onscreen=onscreen,
            offscreen=offscreen,
        )

    def update_reward(self):
        """Store ``right-hand/cube contact & cube height check`` in ``last_reward``."""
        finger_links = [
            f"right_hand_{link}"
            for link in ("thumb_2_link", "middle_1_link", "index_1_link")
        ]
        touching_cube = check_contact(self.mj_model, self.mj_data, finger_links, "cube_body")
        # NOTE(review): 0.85 / 2.0 presumably bound the cube height - confirm in check_height.
        cube_raised = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)
        with self.reward_lock:
            self.last_reward = touching_cube & cube_raised
class BoxEnv(DefaultEnv):
    """Bimanual task: hold the box with both hands and lift it."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Override the robot scene on a copy so the caller's config is untouched.
        config = config.copy()
        config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
        super().__init__(config, "box", {}, onscreen, offscreen)

    def update_reward(self):
        """Compute the reward and store it in ``last_reward``.

        Overrides ``DefaultEnv.update_reward`` (the hook the simulation loop calls), so
        this task's reward is actually evaluated. The reward is
        ``left-hand contact & right-hand contact & box height check``.
        """
        left_hand_body = [
            "left_hand_thumb_2_link",
            "left_hand_middle_1_link",
            "left_hand_index_1_link",
        ]
        right_hand_body = [
            "right_hand_thumb_2_link",
            "right_hand_middle_1_link",
            "right_hand_index_1_link",
        ]
        gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body")
        gripper_box_contact &= check_contact(
            self.mj_model, self.mj_data, right_hand_body, "box_body"
        )
        # NOTE(review): 0.92 / 2.0 presumably bound the box height - confirm in check_height.
        box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
        # Debug output only when requested: this now runs at the reward rate.
        if self.config.get("verbose", False):
            print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
        with self.reward_lock:
            self.last_reward = gripper_box_contact & box_lifted

    def reward(self):
        """Recompute the reward and return it (kept for existing callers of ``reward()``)."""
        self.update_reward()
        with self.reward_lock:
            return self.last_reward
class BottleEnv(DefaultEnv):
    """Bottle pick-and-place scene with an ``egoview`` camera."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Shallow copy so the caller's config keeps its own ROBOT_SCENE.
        bottle_config = config.copy()
        bottle_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
        super().__init__(
            bottle_config,
            "cylinder",
            {"egoview": {"height": 400, "width": 400}},
            onscreen,
            offscreen,
        )
        self.bottle_body = self.mj_model.body("bottle_body")
        self.bottle_geom = self.mj_model.geom("bottle")
        if self.viewer is None:
            return
        # Lock the interactive viewer onto the scene's egoview camera.
        self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
        self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id

    def update_reward(self):
        """No reward is computed for this task yet; ``last_reward`` is left unchanged."""
        return None

    def get_privileged_obs(self):
        """Return the bottle body's world position and quaternion."""
        bottle_id = self.bottle_body.id
        return {
            "bottle_pos": self.mj_data.xpos[bottle_id],
            "bottle_quat": self.mj_data.xquat[bottle_id],
        }
class BaseSimulator:
    """Owns a simulation environment plus the Unitree DDS bridge and runs the main loop."""

    def __init__(self, config: Dict[str, any], env_name: str = "default", **kwargs):
        """Create the environment named ``env_name`` and wire up ROS 2 (optional) and DDS.

        Args:
            config: Simulator configuration (``SIMULATE_DT``, ``DOMAIN_ID``, ...).
            env_name: One of ``"default"``, ``"pnp_cube"``, ``"lift_box"``, ``"pnp_bottle"``.
            **kwargs: Forwarded to the environment constructor (e.g. ``onscreen``).

        Raises:
            ValueError: If ``env_name`` is not recognised.
        """
        self.config = config
        self.env_name = env_name

        # Optional ROS 2 node (only when rclpy could be imported).
        self.node = None
        self.thread = None
        if HAS_RCLPY:
            if not rclpy.ok():
                rclpy.init()
                self.node = rclpy.create_node("sim_mujoco")
                self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
                self.thread.start()
            else:
                # Reuse the first node of the already-initialised global executor.
                # NOTE(review): raises IndexError if that executor has no nodes.
                executor = rclpy.get_global_executor()
                self.node = executor.get_nodes()[0]

        # Update periods in seconds.
        self.sim_dt = self.config["SIMULATE_DT"]
        self.reward_dt = self.config.get("REWARD_DT", 0.02)
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.viewer_dt = self.config.get("VIEWER_DT", 0.02)

        if env_name == "default":
            self.sim_env = DefaultEnv(config, env_name, **kwargs)
        elif env_name == "pnp_cube":
            self.sim_env = CubeEnv(config, **kwargs)
        elif env_name == "lift_box":
            self.sim_env = BoxEnv(config, **kwargs)
        elif env_name == "pnp_bottle":
            self.sim_env = BottleEnv(config, **kwargs)
        else:
            raise ValueError(f"Invalid environment name: {env_name}")

        # Initialise DDS; any failure (e.g. already initialised by the caller) is only reported.
        try:
            if self.config.get("INTERFACE", None):
                ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"])
            else:
                ChannelFactoryInitialize(self.config["DOMAIN_ID"])
        except Exception as e:
            print(f"Note: Channel factory initialization attempt: {e}")

        self.init_unitree_bridge()
        self.sim_env.set_unitree_bridge(self.unitree_bridge)
        self.init_subscriber()
        self.init_publisher()
        self.sim_thread = None

    def start_as_thread(self):
        """Run ``start()`` in a new (non-daemon) thread."""
        self.sim_thread = Thread(target=self.start)
        self.sim_thread.start()

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the environment's image publish subprocess."""
        self.sim_env.start_image_publish_subprocess(start_method, camera_port)

    def init_subscriber(self):
        """Initialize subscribers. Can be overridden by subclasses."""
        pass

    def init_publisher(self):
        """Initialize publishers. Can be overridden by subclasses."""
        pass

    def init_unitree_bridge(self):
        """Create the Unitree SDK bridge and, if configured, attach a joystick."""
        self.unitree_bridge = UnitreeSdk2Bridge(self.config)
        if self.config["USE_JOYSTICK"]:
            self.unitree_bridge.SetupJoystick(
                device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"]
            )

    @staticmethod
    def _steps_per_update(period: float, sim_dt: float) -> int:
        """Number of sim steps between updates that should run every ``period`` seconds.

        Never returns 0 (``sim_cnt % 0`` would raise), and a small epsilon stops ratios
        that land just below an integer through float rounding from truncating one short.
        """
        return max(1, int(period / sim_dt + 1e-6))

    def start(self):
        """Run the main loop until the viewer closes (indefinitely when there is no viewer).

        Steps the environment every ``sim_dt`` and refreshes viewer, reward and render
        caches at their own rates; always calls ``close()`` on exit.
        """
        import time

        env = self.sim_env
        sim_cnt = 0
        last_time = time.time()
        print(f"Starting simulation loop. Viewer: {env.viewer is not None}")
        try:
            viewer_every = self._steps_per_update(self.viewer_dt, self.sim_dt)
            reward_every = self._steps_per_update(self.reward_dt, self.sim_dt)
            image_every = self._steps_per_update(self.image_dt, self.sim_dt)
            while (env.viewer and env.viewer.is_running()) or env.viewer is None:
                env.sim_step()
                if sim_cnt % viewer_every == 0:
                    env.update_viewer()
                if sim_cnt % reward_every == 0:
                    env.update_reward()
                if sim_cnt % image_every == 0:
                    env.update_render_caches()

                # Simple wall-clock pacing (no ROS rate object).
                elapsed = time.time() - last_time
                sleep_time = max(0, self.sim_dt - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time)
                last_time = time.time()
                sim_cnt += 1
            print(f"Loop exited. Viewer running: {env.viewer.is_running() if env.viewer else 'No viewer'}")
        except KeyboardInterrupt:
            # User pressed Ctrl+C - exit cleanly.
            print("Keyboard interrupt received")
        except Exception as e:
            print(f"Exception in simulation loop: {e}")
            import traceback

            traceback.print_exc()
        finally:
            self.close()

    def __del__(self):
        """Clean up resources when the simulator is garbage-collected."""
        self.close()

    def reset(self):
        """Reset the simulation. Can be overridden by subclasses."""
        self.sim_env.reset()

    def close(self):
        """Close the viewer and shut down ROS 2. Safe on a partially constructed instance."""
        try:
            viewer = getattr(getattr(self, "sim_env", None), "viewer", None)
            if viewer is not None:
                viewer.close()
            if HAS_RCLPY and rclpy.ok():
                rclpy.shutdown()
        except Exception as e:
            print(f"Warning during close: {e}")

    def get_privileged_obs(self):
        """Return the environment's privileged observation."""
        obs = self.sim_env.get_privileged_obs()
        # TODO: add ros2 topic to get privileged obs
        return obs

    def handle_keyboard_button(self, key):
        """Forward a key press to the environment (default env only)."""
        if self.env_name == "default":
            self.sim_env.handle_keyboard_button(key)
if __name__ == "__main__":
    # Command line: a single --config pointing at the simulator YAML.
    cli = argparse.ArgumentParser(description="Robot")
    cli.add_argument(
        "--config",
        type=str,
        default="./gr00t_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml",
        help="config file",
    )
    cli_args = cli.parse_args()

    # NOTE(review): yaml.load with FullLoader can build some Python objects;
    # yaml.safe_load is the conservative choice if config files may be untrusted.
    with open(cli_args.config, "r") as config_file:
        sim_config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Initialise DDS, on the configured network interface when one is given.
    if not sim_config.get("INTERFACE", None):
        ChannelFactoryInitialize(sim_config["DOMAIN_ID"])
    else:
        ChannelFactoryInitialize(sim_config["DOMAIN_ID"], sim_config["INTERFACE"])

    simulator = BaseSimulator(sim_config)
    simulator.start_as_thread()