# ONNX · G1-robot-inspire-hands / example_fixed_arm_policy.py
# Uploaded by daniel-wright-6 with huggingface_hub (commit d15edfa, verified).
import numpy as np
import time
from termcolor import colored
from .base import BasePolicy
class ArmControlPolicy(BasePolicy):
    """Locomotion policy that can additionally hold the arm joints at a custom pose.

    Whole-body joint targets are produced in ``policy_action`` by the ready-state
    init target, a manual command, or RL inference. When ``custom_arm_positions``
    is set, the joints listed in ``self.upper_dof_indices`` are then overridden
    with it. That pose is ramped linearly from the measured arm pose to
    ``target_arm_positions`` over ``arm_interp_duration`` seconds.
    """

    def __init__(self, config):
        super().__init__(config)
        # True while both velocity commands are (near) zero; used to reset the
        # gait phase when motion resumes.
        self.is_standing = False
        # Arm joint positions written over the computed targets (None = no override).
        self.custom_arm_positions = None
        # Not read within this class; presumably set/consumed elsewhere — verify.
        self.pending_arm_positions = None
        # Goal arm pose; arm interpolation is inactive while this is None.
        self.target_arm_positions = None
        # Most recent interpolated arm pose.
        self.current_arm_positions = None
        # Arm pose at the start of the interpolation window.
        self.arm_interp_start_positions = None
        self.arm_interp_duration = 5.0  # seconds
        # time.monotonic() timestamp at which the interpolation window started.
        self.arm_interp_start_time = None

    def get_current_obs_buffer_dict(self, robot_state_data):
        """Return the base observation dict extended with locomotion terms.

        Adds the last policy action and the linear/angular velocity and stand
        commands. Gait-phase features are added only when the ``actor_obs``
        entry of ``self.obs_dict`` lists them.
        """
        current_obs_buffer_dict = super().get_current_obs_buffer_dict(robot_state_data)
        current_obs_buffer_dict["actions"] = self.last_policy_action
        current_obs_buffer_dict["command_lin_vel"] = self.lin_vel_command
        current_obs_buffer_dict["command_ang_vel"] = self.ang_vel_command
        current_obs_buffer_dict["command_stand"] = self.stand_command
        # Add phase observations only if they are configured.
        actor_obs = self.obs_dict.get("actor_obs", [])
        if "sin_phase" in actor_obs:
            current_obs_buffer_dict["sin_phase"] = self._get_obs_sin_phase()
        if "cos_phase" in actor_obs:
            current_obs_buffer_dict["cos_phase"] = self._get_obs_cos_phase()
        return current_obs_buffer_dict

    def _get_obs_sin_phase(self):
        """Return sin of row 0 of the gait phase, wrapped in an extra leading axis."""
        return np.array([np.sin(self.phase[0, :])])

    def _get_obs_cos_phase(self):
        """Return cos of row 0 of the gait phase, wrapped in an extra leading axis."""
        return np.array([np.cos(self.phase[0, :])])

    def update_phase_time(self):
        """Advance the gait phase by ``phase_dt`` and wrap it with ``np.fmod``.

        When both velocity commands are (near) zero, the phase is pinned to pi
        for both entries. On the first update after that with a non-zero
        command, the phase is reset to ``[[0.0, pi]]``.
        """
        phase_tp1 = self.phase + self.phase_dt
        self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
        if np.linalg.norm(self.lin_vel_command[0]) < 0.01 and np.linalg.norm(self.ang_vel_command[0]) < 0.01:
            # Robot should stand still - set both feet to the same phase.
            self.phase[0, :] = np.pi * np.ones(2)
            self.is_standing = True
        elif self.is_standing:
            # When the robot starts to move, reset the phase to its initial state.
            self.phase = np.array([[0.0, np.pi]])
            self.is_standing = False

    def handle_keyboard_button(self, keycode):
        """Handle keyboard presses: common commands first, then locomotion keys.

        The control status is printed after every key press.
        """
        # Call parent handler for common commands.
        super().handle_keyboard_button(keycode)
        # Locomotion-specific commands.
        if keycode in ["w", "s", "a", "d"]:
            self._handle_velocity_control(keycode)
        elif keycode in ["q", "e"]:
            self._handle_angular_velocity_control(keycode)
        elif keycode == "=":
            self._handle_stand_command()
        elif keycode == "z":
            self._handle_zero_velocity()
        self._print_control_status()

    def handle_joystick_button(self, cur_key):
        """Handle joystick presses: ``start`` toggles stand/walk, ``L2`` zeroes velocities."""
        # Call parent handler for common commands.
        super().handle_joystick_button(cur_key)
        # Locomotion-specific commands.
        if cur_key == "start":
            self._handle_stand_command()
        elif cur_key == "L2":
            self._handle_zero_velocity()

    def _handle_velocity_control(self, keycode):
        """Step the x/y linear velocity command by 0.1 (walk mode only)."""
        if not self.stand_command[0, 0]:
            return
        if keycode == "w":
            self.lin_vel_command[0, 0] += 0.1
        elif keycode == "s":
            self.lin_vel_command[0, 0] -= 0.1
        elif keycode == "a":
            self.lin_vel_command[0, 1] += 0.1
        elif keycode == "d":
            self.lin_vel_command[0, 1] -= 0.1

    def _handle_angular_velocity_control(self, keycode):
        """Step the angular velocity command by 0.1 (walk mode only)."""
        # Same guard as the linear handler. Stance mode zeroes this command on
        # entry, and the status line reports velocities as "not applied" there,
        # so accepting yaw input in stance would silently command motion.
        if not self.stand_command[0, 0]:
            return
        if keycode == "q":
            self.ang_vel_command[0, 0] -= 0.1
        elif keycode == "e":
            self.ang_vel_command[0, 0] += 0.1

    def _zero_velocity_commands(self):
        """Set the x/y linear and the angular velocity commands to zero."""
        self.ang_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 1] = 0.0

    def _handle_stand_command(self):
        """Toggle between stance (0) and walk (1) modes.

        Entering stance zeroes all velocity commands. Entering walk resets the
        base height command to ``desired_base_height``.
        """
        self.stand_command[0, 0] = 1 - self.stand_command[0, 0]
        if self.stand_command[0, 0] == 0:
            self._zero_velocity_commands()
            self.logger.info(colored("Stance command", "blue"))
        else:
            self.base_height_command[0, 0] = self.desired_base_height
            self.logger.info(colored("Walk command", "blue"))

    def _handle_zero_velocity(self):
        """Zero all velocity commands without changing the stand/walk mode."""
        self._zero_velocity_commands()
        self.logger.info(colored("Velocities set to zero", "blue"))

    def _print_control_status(self):
        """Print the parent status followed by velocity commands, mode and key help."""
        super()._print_control_status()
        # Extract values for better formatting.
        lin_vel_x = self.lin_vel_command[0, 0]
        lin_vel_y = self.lin_vel_command[0, 1]
        ang_vel_z = self.ang_vel_command[0, 0]
        is_walking = self.stand_command[0, 0] == 1
        # Print with clear labels and units.
        mode = "Walking" if is_walking else "Standing"
        status = "✓ applied" if is_walking else "✗ not applied"
        print(f"Linear velocity: x={lin_vel_x:+.2f} m/s, y={lin_vel_y:+.2f} m/s")
        print(f"Angular velocity: {ang_vel_z:+.2f} rad/s")
        print(f"Mode: {mode} ({status})")
        print("💡 Terminal keys: W/A/S/D (lin) | Q/E (ang) | = (toggle mode)")
        print("🎬 MuJoCo keys (in simulator only): 7/8 (band) | 9 (toggle) | BACKSPACE (reset)")

    def _update_arm_interpolation(self, robot_state_data):
        """Linearly interpolate the arm pose toward ``target_arm_positions``.

        Does nothing while ``target_arm_positions`` is None. On the first call
        with a target, the start pose is read from row 0, columns
        ``7 .. 7 + num_dofs`` of ``robot_state_data`` (selected by
        ``upper_dof_indices``), and the timer is started. The result is stored
        in both ``current_arm_positions`` and ``custom_arm_positions``.

        NOTE(review): the start pose and time are captured only once. If
        ``target_arm_positions`` changes after the window has elapsed, the
        output moves to the new target immediately. Confirm this is intended.
        """
        if self.target_arm_positions is None:
            return
        if self.current_arm_positions is None:
            current_q = robot_state_data[0, 7 : 7 + self.num_dofs]
            self.current_arm_positions = current_q[self.upper_dof_indices].astype(np.float32)
            self.arm_interp_start_positions = self.current_arm_positions.copy()
            self.arm_interp_start_time = time.monotonic()
        if self.arm_interp_duration > 0:
            elapsed = time.monotonic() - self.arm_interp_start_time
            progress = np.clip(elapsed / self.arm_interp_duration, 0.0, 1.0)
        else:
            # A non-positive duration means "go to the target now" rather than
            # raising ZeroDivisionError in the control loop.
            progress = 1.0
        self.current_arm_positions = (
            (1.0 - progress) * self.arm_interp_start_positions
            + progress * self.target_arm_positions
        )
        self.custom_arm_positions = self.current_arm_positions

    def _apply_custom_arm_positions(self, q_target):
        """Return ``q_target`` with row 0's arm joints set to ``custom_arm_positions``.

        When no custom pose is set, ``q_target`` is returned unchanged.
        Otherwise the override is written into a copy. ``q_target`` may be a
        view of the raw robot state (the manual fallback in ``policy_action``
        slices it), and writing through that view would overwrite the measured
        joint positions later passed to ``send_low_command``.
        """
        custom = getattr(self, "custom_arm_positions", None)
        if custom is None:
            return q_target
        q_target = np.array(q_target, copy=True)
        q_target[0, self.upper_dof_indices] = custom
        return q_target

    def policy_action(self):
        """Run one control cycle: read state, compute joint targets, publish them.

        Target source, by precedence:
          1. the init target while ``get_ready_state`` is set;
          2. when the policy is disabled, the manual command (with optional
             kp/kd overrides), or otherwise the measured joint positions;
          3. RL inference plus ``default_dof_angles``.
        The arm joints are then overridden with the interpolated custom pose.
        """
        kp_override = None
        kd_override = None
        # Snapshot the mode flags once so all stages take the same branch. If a
        # flag were toggled mid-cycle (e.g. by an input handler — confirm
        # whether that can happen concurrently), re-reading it per stage could
        # leave q_target or scaled_policy_action unset.
        get_ready_state = self.get_ready_state
        use_policy_action = self.use_policy_action
        run_policy = use_policy_action and not get_ready_state

        # Stage 1: Read State
        with self.latency_tracker.measure("read_state"):
            robot_state_data = self.interface.get_low_state()

        # Stage 2: Pre-processing
        with self.latency_tracker.measure("preprocessing"):
            # Determine target joint positions.
            if get_ready_state:
                q_target = self.get_init_target(robot_state_data)
                self.init_count = min(self.init_count, 500)
            elif not use_policy_action:
                manual_cmd = self._get_manual_command(robot_state_data)
                if manual_cmd is not None:
                    q_target = manual_cmd["q"]
                    kp_override = manual_cmd.get("kp")
                    kd_override = manual_cmd.get("kd")
                else:
                    # Target the currently measured joint positions (a view of
                    # robot_state_data; see _apply_custom_arm_positions).
                    q_target = robot_state_data[:, 7 : 7 + self.num_dofs]
            # Otherwise (RL path) there is nothing to prepare before inference.

        # Stage 3: Inference
        if run_policy:
            with self.latency_tracker.measure("inference"):
                scaled_policy_action = self.rl_inference(robot_state_data)

        # Stage 4: Post-processing
        with self.latency_tracker.measure("postprocessing"):
            if run_policy:
                if scaled_policy_action.shape[1] != self.num_dofs:
                    if not self.upper_body_controller:
                        # Pad the missing leading joints with zero action.
                        scaled_policy_action = np.concatenate(
                            [np.zeros((1, self.num_dofs - scaled_policy_action.shape[1])), scaled_policy_action],
                            axis=1,
                        )
                    else:
                        raise NotImplementedError("Upper body controller not implemented")
                q_target = scaled_policy_action + self.default_dof_angles
            # Apply custom arm positions (overrides the computed arm targets).
            self._update_arm_interpolation(robot_state_data)
            q_target = self._apply_custom_arm_positions(q_target)
            # Prepare command (reuse pre-allocated arrays).
            self.cmd_q[:] = q_target

        # Stage 5: Action Pub
        with self.latency_tracker.measure("action_pub"):
            self.interface.send_low_command(
                self.cmd_q,
                self.cmd_dq,
                self.cmd_tau,
                robot_state_data[0, 7 : 7 + self.num_dofs],
                kp_override=kp_override,
                kd_override=kd_override,
            )