| import numpy as np |
| import time |
| from termcolor import colored |
|
|
| from .base import BasePolicy |
|
|
|
|
class ArmControlPolicy(BasePolicy):
    """Locomotion policy with velocity commands and interpolated custom arm targets.

    Velocity/stance commands come from keyboard or joystick handlers. Custom arm
    (upper-body) joint positions are set by assigning ``target_arm_positions``;
    every control step the arms are linearly interpolated toward that target over
    ``arm_interp_duration`` seconds and written over the upper-body entries of
    the joint command.
    """

    def __init__(self, config):
        super().__init__(config)
        # True while the gait phase is frozen because commanded velocities are ~0.
        self.is_standing = False
        # Arm positions written into q_target each step (None = arms untouched).
        self.custom_arm_positions = None
        # Not read anywhere in the code shown here; kept for external users.
        self.pending_arm_positions = None
        # Arm goal; assigning a new value starts a new interpolation.
        self.target_arm_positions = None
        # Latest interpolated arm positions (None = not initialised from state yet).
        self.current_arm_positions = None
        # Arm positions at the start of the current interpolation.
        self.arm_interp_start_positions = None
        # Seconds taken to move from the start positions to the target.
        self.arm_interp_duration = 5.0
        # time.monotonic() timestamp at which the current interpolation started.
        self.arm_interp_start_time = None
        # Copy of the target the current interpolation heads to; used to detect
        # that target_arm_positions was changed so a new interpolation starts.
        self._arm_interp_target = None

    def get_current_obs_buffer_dict(self, robot_state_data):
        """Extend the base observation buffer with command and action terms.

        Adds the last policy action and the linear/angular velocity and stand
        commands, plus the gait-phase sin/cos terms when the actor observation
        list requests them.
        """
        current_obs_buffer_dict = super().get_current_obs_buffer_dict(robot_state_data)
        current_obs_buffer_dict["actions"] = self.last_policy_action
        current_obs_buffer_dict["command_lin_vel"] = self.lin_vel_command
        current_obs_buffer_dict["command_ang_vel"] = self.ang_vel_command
        current_obs_buffer_dict["command_stand"] = self.stand_command

        actor_obs = self.obs_dict.get("actor_obs", [])
        if "sin_phase" in actor_obs:
            current_obs_buffer_dict["sin_phase"] = self._get_obs_sin_phase()
        if "cos_phase" in actor_obs:
            current_obs_buffer_dict["cos_phase"] = self._get_obs_cos_phase()

        return current_obs_buffer_dict

    def _get_obs_sin_phase(self):
        """Return sin of the gait phase (row 0 of ``self.phase``) as a 1-row array."""
        return np.array([np.sin(self.phase[0, :])])

    def _get_obs_cos_phase(self):
        """Return cos of the gait phase (row 0 of ``self.phase``) as a 1-row array."""
        return np.array([np.cos(self.phase[0, :])])

    def update_phase_time(self):
        """Advance the gait phase by ``phase_dt``, wrapping it to [-pi, pi).

        When both velocity commands are ~0 the phase is pinned at pi (standing).
        On the first step after standing, the phase restarts at [0, pi].
        """
        phase_tp1 = self.phase + self.phase_dt
        self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
        if np.linalg.norm(self.lin_vel_command[0]) < 0.01 and np.linalg.norm(self.ang_vel_command[0]) < 0.01:
            # No motion commanded: hold both phase entries at pi.
            self.phase[0, :] = np.pi * np.ones(2)
            self.is_standing = True
        elif self.is_standing:
            # Leaving stand: restart the gait with the two entries half a cycle apart.
            self.phase = np.array([[0.0, np.pi]])
            self.is_standing = False

    def handle_keyboard_button(self, keycode):
        """Handle keyboard presses for locomotion, then print the control status.

        W/S and A/D change linear x/y velocity, Q/E change yaw rate, '=' toggles
        stance/walk, and Z zeroes all velocity commands.
        """
        super().handle_keyboard_button(keycode)

        if keycode in ["w", "s", "a", "d"]:
            self._handle_velocity_control(keycode)
        elif keycode in ["q", "e"]:
            self._handle_angular_velocity_control(keycode)
        elif keycode == "=":
            self._handle_stand_command()
        elif keycode == "z":
            self._handle_zero_velocity()

        self._print_control_status()

    def handle_joystick_button(self, cur_key):
        """Handle joystick presses: 'start' toggles stance/walk, 'L2' zeroes velocities."""
        super().handle_joystick_button(cur_key)

        if cur_key == "start":
            self._handle_stand_command()
        elif cur_key == "L2":
            self._handle_zero_velocity()

    def _handle_velocity_control(self, keycode):
        """Adjust the linear velocity command by 0.1 m/s (walk mode only)."""
        if not self.stand_command[0, 0]:
            return

        if keycode == "w":
            self.lin_vel_command[0, 0] += 0.1
        elif keycode == "s":
            self.lin_vel_command[0, 0] -= 0.1
        elif keycode == "a":
            self.lin_vel_command[0, 1] += 0.1
        elif keycode == "d":
            self.lin_vel_command[0, 1] -= 0.1

    def _handle_angular_velocity_control(self, keycode):
        """Adjust the yaw-rate command by 0.1 rad/s (walk mode only).

        This is gated exactly like linear velocity. Entering stance zeroes every
        velocity command, and the status line reports velocities as "not applied"
        while standing. A non-zero yaw command would also keep
        ``update_phase_time`` from freezing the gait phase.
        """
        if not self.stand_command[0, 0]:
            return

        if keycode == "q":
            self.ang_vel_command[0, 0] -= 0.1
        elif keycode == "e":
            self.ang_vel_command[0, 0] += 0.1

    def _zero_velocity_commands(self):
        """Set the linear (x, y) and angular velocity commands to zero."""
        self.ang_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 1] = 0.0

    def _handle_stand_command(self):
        """Toggle between stance (0) and walk (1) modes.

        Entering stance zeroes all velocity commands. Entering walk resets the
        base height command to ``desired_base_height``.
        """
        self.stand_command[0, 0] = 1 - self.stand_command[0, 0]
        if self.stand_command[0, 0] == 0:
            self._zero_velocity_commands()
            self.logger.info(colored("Stance command", "blue"))
        else:
            self.base_height_command[0, 0] = self.desired_base_height
            self.logger.info(colored("Walk command", "blue"))

    def _handle_zero_velocity(self):
        """Zero all velocity commands without changing the stance/walk mode."""
        self._zero_velocity_commands()
        self.logger.info(colored("Velocities set to zero", "blue"))

    def _print_control_status(self):
        """Print the current velocity commands, mode and key help."""
        super()._print_control_status()

        lin_vel_x = self.lin_vel_command[0, 0]
        lin_vel_y = self.lin_vel_command[0, 1]
        ang_vel_z = self.ang_vel_command[0, 0]
        is_walking = self.stand_command[0, 0] == 1

        mode = "Walking" if is_walking else "Standing"
        status = "✓ applied" if is_walking else "✗ not applied"
        print(f"Linear velocity: x={lin_vel_x:+.2f} m/s, y={lin_vel_y:+.2f} m/s")
        print(f"Angular velocity: {ang_vel_z:+.2f} rad/s")
        print(f"Mode: {mode} ({status})")
        print("💡 Terminal keys: W/A/S/D (lin) | Q/E (ang) | = (toggle mode)")
        print("🎬 MuJoCo keys (in simulator only): 7/8 (band) | 9 (toggle) | BACKSPACE (reset)")

    def _update_arm_interpolation(self, robot_state_data):
        """Advance the linear arm interpolation toward ``target_arm_positions``.

        A new interpolation starts from the current arm pose in three cases: the
        first time a target is seen, whenever ``target_arm_positions`` changes,
        and whenever ``current_arm_positions`` has been reset to None. A new
        target is therefore always approached over ``arm_interp_duration``
        seconds instead of being applied in one step. A non-positive duration
        jumps straight to the target. The result is published through
        ``custom_arm_positions``.

        Args:
            robot_state_data: state array from ``interface.get_low_state()``.
                Joint positions are read from ``[0, 7:7 + num_dofs]``.
        """
        if self.target_arm_positions is None:
            return

        target = np.array(self.target_arm_positions, dtype=np.float32)
        previous_target = getattr(self, "_arm_interp_target", None)
        target_changed = previous_target is None or not np.array_equal(target, previous_target)

        if self.current_arm_positions is None or target_changed:
            if self.current_arm_positions is None:
                # No interpolated pose yet: start from the measured arm joints.
                current_q = robot_state_data[0, 7 : 7 + self.num_dofs]
                self.current_arm_positions = current_q[self.upper_dof_indices].astype(np.float32)
            # Restart from wherever the arms currently are so a retarget is smooth.
            self.arm_interp_start_positions = np.array(self.current_arm_positions, dtype=np.float32)
            self.arm_interp_start_time = time.monotonic()
            self._arm_interp_target = target.copy()

        if self.arm_interp_duration > 0:
            elapsed = time.monotonic() - self.arm_interp_start_time
            progress = float(np.clip(elapsed / self.arm_interp_duration, 0.0, 1.0))
        else:
            progress = 1.0

        self.current_arm_positions = (1.0 - progress) * self.arm_interp_start_positions + progress * target
        self.custom_arm_positions = self.current_arm_positions

    def _apply_custom_arm_positions(self, q_target):
        """Overwrite the upper-body entries of ``q_target`` with custom arm positions.

        ``q_target`` row 0 is modified in place and also returned. It is returned
        unchanged when no custom arm positions are set.
        """
        custom = getattr(self, "custom_arm_positions", None)
        if custom is not None:
            q_target[0, np.asarray(self.upper_dof_indices)] = custom
        return q_target

    def policy_action(self):
        """Run one control step: read state, compute joint targets, send the command.

        Target source, in priority order:
          1. ``get_ready_state``: ``get_init_target(robot_state_data)``;
          2. policy disabled: the manual command if available, otherwise hold
             the measured joint positions;
          3. otherwise: RL policy output added to ``default_dof_angles``.
        Custom arm positions, if any, then overwrite the upper-body joints.
        """
        kp_override = None
        kd_override = None

        with self.latency_tracker.measure("read_state"):
            robot_state_data = self.interface.get_low_state()

        # Snapshot the mode flags once so every stage below agrees on the mode.
        # If the flags can be toggled concurrently (e.g. from an input thread --
        # not visible here), re-reading them could leave q_target or
        # scaled_policy_action unbound.
        ready_phase = self.get_ready_state
        run_policy = self.use_policy_action and not ready_phase

        with self.latency_tracker.measure("preprocessing"):
            if ready_phase:
                q_target = self.get_init_target(robot_state_data)
                self.init_count = min(self.init_count, 500)
            elif not run_policy:
                manual_cmd = self._get_manual_command(robot_state_data)
                if manual_cmd is not None:
                    # Copy so custom arm positions are not written into the
                    # manual command's own storage.
                    q_target = np.array(manual_cmd["q"], copy=True)
                    kp_override = manual_cmd.get("kp")
                    kd_override = manual_cmd.get("kd")
                else:
                    # The slice is a view into robot_state_data. Copy it so that
                    # applying custom arm positions does not corrupt the measured
                    # joint positions passed to send_low_command below.
                    q_target = robot_state_data[:, 7 : 7 + self.num_dofs].copy()

        if run_policy:
            with self.latency_tracker.measure("inference"):
                scaled_policy_action = self.rl_inference(robot_state_data)

        with self.latency_tracker.measure("postprocessing"):
            if run_policy:
                if scaled_policy_action.shape[1] != self.num_dofs:
                    if not self.upper_body_controller:
                        # Zero-pad the leading DOFs the policy does not output.
                        scaled_policy_action = np.concatenate(
                            [np.zeros((1, self.num_dofs - scaled_policy_action.shape[1])), scaled_policy_action],
                            axis=1,
                        )
                    else:
                        raise NotImplementedError("Upper body controller not implemented")
                q_target = scaled_policy_action + self.default_dof_angles

            self._update_arm_interpolation(robot_state_data)
            q_target = self._apply_custom_arm_positions(q_target)

        self.cmd_q[:] = q_target

        with self.latency_tracker.measure("action_pub"):
            self.interface.send_low_command(
                self.cmd_q,
                self.cmd_dq,
                self.cmd_tau,
                robot_state_data[0, 7 : 7 + self.num_dofs],
                kp_override=kp_override,
                kd_override=kd_override,
            )
|
|