import time

import numpy as np
from termcolor import colored

from .base import BasePolicy


class ArmControlPolicy(BasePolicy):
    """Locomotion policy whose arm joint targets can be overridden.

    On top of ``BasePolicy`` this adds:
      * velocity / stand commands driven by keyboard or joystick,
      * optional sin/cos gait-phase observations,
      * a linear interpolation of the arm (``upper_dof_indices``) joint targets
        from the measured arm pose toward ``target_arm_positions`` over
        ``arm_interp_duration`` seconds. The result overrides the arm joints of
        every outgoing command.
    """

    def __init__(self, config):
        super().__init__(config)
        # True while both velocity commands are ~zero and the gait phase is
        # pinned; used by update_phase_time() to reset the phase on restart.
        self.is_standing = False
        # Arm joint targets written into q_target each step (None = no override).
        self.custom_arm_positions = None
        # Not read or written elsewhere in this class.
        # NOTE(review): presumably set by external code — confirm or remove.
        self.pending_arm_positions = None
        # Arm pose the interpolation moves toward; interpolation is skipped
        # while this is None.
        self.target_arm_positions = None
        # Interpolated arm pose at the most recent control step.
        self.current_arm_positions = None
        # Arm pose and monotonic timestamp at which the interpolation began.
        self.arm_interp_start_positions = None
        self.arm_interp_duration = 5.0  # seconds
        self.arm_interp_start_time = None

    def get_current_obs_buffer_dict(self, robot_state_data):
        """Extend the base observation buffer with command and gait-phase entries.

        Args:
            robot_state_data: Robot state as returned by the interface; passed
                through to ``BasePolicy.get_current_obs_buffer_dict``.

        Returns:
            The base observation dict plus ``actions``, ``command_lin_vel``,
            ``command_ang_vel``, ``command_stand``, and ``sin_phase`` /
            ``cos_phase`` when those names appear in ``obs_dict["actor_obs"]``.
        """
        current_obs_buffer_dict = super().get_current_obs_buffer_dict(robot_state_data)
        current_obs_buffer_dict["actions"] = self.last_policy_action
        current_obs_buffer_dict["command_lin_vel"] = self.lin_vel_command
        current_obs_buffer_dict["command_ang_vel"] = self.ang_vel_command
        current_obs_buffer_dict["command_stand"] = self.stand_command

        # Phase observations are only produced when the actor is configured to use them.
        actor_obs = self.obs_dict.get("actor_obs", [])
        if "sin_phase" in actor_obs:
            current_obs_buffer_dict["sin_phase"] = self._get_obs_sin_phase()
        if "cos_phase" in actor_obs:
            current_obs_buffer_dict["cos_phase"] = self._get_obs_cos_phase()

        return current_obs_buffer_dict

    def _get_obs_sin_phase(self):
        """Return sin of row 0 of the gait phase, wrapped in an extra leading axis."""
        return np.array([np.sin(self.phase[0, :])])

    def _get_obs_cos_phase(self):
        """Return cos of row 0 of the gait phase, wrapped in an extra leading axis."""
        return np.array([np.cos(self.phase[0, :])])

    def update_phase_time(self):
        """Advance the gait phase by ``phase_dt`` and wrap it into [-pi, pi).

        When both velocity commands are below 0.01 in norm, both phase entries
        are pinned to pi (feet in the same phase). When motion resumes after
        such a stand, the phase restarts from ``[[0, pi]]``.
        """
        phase_tp1 = self.phase + self.phase_dt
        self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi

        if np.linalg.norm(self.lin_vel_command[0]) < 0.01 and np.linalg.norm(self.ang_vel_command[0]) < 0.01:
            # Robot should stand still - set both feet to same phase
            self.phase[0, :] = np.pi * np.ones(2)
            self.is_standing = True
        elif self.is_standing:
            # When the robot starts to move, reset the phase to initial state
            self.phase = np.array([[0.0, np.pi]])
            self.is_standing = False

    def handle_keyboard_button(self, keycode):
        """Handle keyboard button presses for locomotion.

        The parent handler runs first for common commands. Then W/S/A/D adjust
        linear velocity, Q/E adjust angular velocity, '=' toggles walk/stance
        and Z zeroes all velocities. The control status is printed after every key.
        """
        # Call parent handler for common commands
        super().handle_keyboard_button(keycode)

        # Locomotion-specific commands
        if keycode in ["w", "s", "a", "d"]:
            self._handle_velocity_control(keycode)
        elif keycode in ["q", "e"]:
            self._handle_angular_velocity_control(keycode)
        elif keycode == "=":
            self._handle_stand_command()
        elif keycode == "z":
            self._handle_zero_velocity()

        self._print_control_status()

    def handle_joystick_button(self, cur_key):
        """Handle joystick button presses for locomotion.

        After the parent handler, "start" toggles walk/stance and "L2" zeroes
        all velocities.
        """
        # Call parent handler for common commands
        super().handle_joystick_button(cur_key)

        # Locomotion-specific commands
        if cur_key == "start":
            self._handle_stand_command()
        elif cur_key == "L2":
            self._handle_zero_velocity()

    def _handle_velocity_control(self, keycode):
        """Nudge the linear velocity command by 0.1 (W/S: x, A/D: y).

        Ignored unless in walk mode (``stand_command[0, 0]`` non-zero).
        """
        if not self.stand_command[0, 0]:
            return

        if keycode == "w":
            self.lin_vel_command[0, 0] += 0.1
        elif keycode == "s":
            self.lin_vel_command[0, 0] -= 0.1
        elif keycode == "a":
            self.lin_vel_command[0, 1] += 0.1
        elif keycode == "d":
            self.lin_vel_command[0, 1] -= 0.1

    def _handle_angular_velocity_control(self, keycode):
        """Nudge the angular velocity command by 0.1 (Q: decrease, E: increase).

        NOTE(review): unlike _handle_velocity_control this is not gated on walk
        mode, so Q/E change the command even in stance — confirm intended.
        """
        if keycode == "q":
            self.ang_vel_command[0, 0] -= 0.1
        elif keycode == "e":
            self.ang_vel_command[0, 0] += 0.1

    def _handle_stand_command(self):
        """Toggle between stance (0) and walk (1).

        Entering stance zeroes all velocity commands. Entering walk resets the
        base-height command to ``desired_base_height``.
        """
        self.stand_command[0, 0] = 1 - self.stand_command[0, 0]
        if self.stand_command[0, 0] == 0:
            self.ang_vel_command[0, 0] = 0.0
            self.lin_vel_command[0, 0] = 0.0
            self.lin_vel_command[0, 1] = 0.0
            self.logger.info(colored("Stance command", "blue"))
        else:
            self.base_height_command[0, 0] = self.desired_base_height
            self.logger.info(colored("Walk command", "blue"))

    def _handle_zero_velocity(self):
        """Zero the linear (x, y) and angular velocity commands."""
        self.ang_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 0] = 0.0
        self.lin_vel_command[0, 1] = 0.0
        self.logger.info(colored("Velocities set to zero", "blue"))

    def _print_control_status(self):
        """Print the parent status, then the current commands, mode and key help."""
        super()._print_control_status()

        # Extract values for better formatting
        lin_vel_x = self.lin_vel_command[0, 0]
        lin_vel_y = self.lin_vel_command[0, 1]
        ang_vel_z = self.ang_vel_command[0, 0]
        is_walking = self.stand_command[0, 0] == 1

        # Print with clear labels and units
        mode = "Walking" if is_walking else "Standing"
        status = "✓ applied" if is_walking else "✗ not applied"

        print(f"Linear velocity: x={lin_vel_x:+.2f} m/s, y={lin_vel_y:+.2f} m/s")
        print(f"Angular velocity: {ang_vel_z:+.2f} rad/s")
        print(f"Mode: {mode} ({status})")
        # Z is handled by handle_keyboard_button, so list it in the help.
        print("💡 Terminal keys: W/A/S/D (lin) | Q/E (ang) | Z (zero vel) | = (toggle mode)")
        print("🎬 MuJoCo keys (in simulator only): 7/8 (band) | 9 (toggle) | BACKSPACE (reset)")

    def _update_arm_interpolation(self, robot_state_data):
        """Linearly interpolate arm positions to target over ``arm_interp_duration``.

        On the first call after ``current_arm_positions`` is cleared, the start
        pose is taken from the measured joints. This is row 0, columns
        ``7:7 + num_dofs`` of ``robot_state_data``, restricted to
        ``upper_dof_indices``. The result is stored in both
        ``current_arm_positions`` and ``custom_arm_positions``.
        Does nothing while ``target_arm_positions`` is None.
        """
        if self.target_arm_positions is None:
            return

        needs_start = self.current_arm_positions is None
        if needs_start:
            current_q = robot_state_data[0, 7 : 7 + self.num_dofs]
            self.current_arm_positions = current_q[self.upper_dof_indices].astype(np.float32)

        # Also (re)start if the start state is missing, e.g. current_arm_positions
        # was assigned directly. Otherwise the subtraction below would hit None.
        if needs_start or self.arm_interp_start_time is None or self.arm_interp_start_positions is None:
            self.arm_interp_start_positions = np.asarray(self.current_arm_positions, dtype=np.float32).copy()
            self.arm_interp_start_time = time.monotonic()

        elapsed = time.monotonic() - self.arm_interp_start_time
        duration = self.arm_interp_duration
        # A non-positive duration means "jump straight to the target" rather
        # than dividing by zero.
        progress = np.clip(elapsed / duration, 0.0, 1.0) if duration > 0 else np.float64(1.0)

        self.current_arm_positions = (
            1.0 - progress
        ) * self.arm_interp_start_positions + progress * self.target_arm_positions
        self.custom_arm_positions = self.current_arm_positions

    def _apply_custom_arm_positions(self, q_target):
        """Return ``q_target`` with the arm joints replaced by ``custom_arm_positions``.

        Args:
            q_target: Joint-target array indexed as ``q_target[0, dof]``.

        Returns:
            ``q_target`` unchanged when no custom arm positions are set.
            Otherwise, a copy with the row-0 entries at ``upper_dof_indices``
            overwritten. The input is copied first so caller-owned arrays are
            never modified in place, e.g. a view into the robot state or a
            manual command's ``"q"`` array.
        """
        arm_positions = getattr(self, "custom_arm_positions", None)
        if arm_positions is None:
            return q_target

        q_target = np.array(q_target, copy=True)
        for i, idx in enumerate(self.upper_dof_indices):
            q_target[0, idx] = arm_positions[i]
        return q_target

    def policy_action(self):
        """Run one control step and send the joint command to the robot.

        The joint-target source is chosen from flags sampled once per step:
          * ``get_ready_state``: ``get_init_target(robot_state_data)``;
          * ``use_policy_action`` False: ``_get_manual_command()`` (with optional
            kp/kd overrides), or the measured joint positions if it returns None;
          * otherwise: ``rl_inference()`` output plus ``default_dof_angles``.
        The arm joints are then overridden by the custom arm interpolation, and
        the command is published through ``interface.send_low_command``.
        """
        kp_override = None
        kd_override = None

        # Sample the mode flags once so every stage takes the same branch;
        # re-reading them per stage could leave q_target unassigned if they
        # changed between stages.
        get_ready = self.get_ready_state
        use_policy = self.use_policy_action
        run_policy = use_policy and not get_ready

        # Stage 1: Read State
        with self.latency_tracker.measure("read_state"):
            robot_state_data = self.interface.get_low_state()

        # Stage 2: Pre-processing
        with self.latency_tracker.measure("preprocessing"):
            # Determine target joint positions
            if get_ready:
                q_target = self.get_init_target(robot_state_data)
                # NOTE(review): the meaning of the 500 cap is defined outside this file.
                self.init_count = min(self.init_count, 500)
            elif not use_policy:
                manual_cmd = self._get_manual_command(robot_state_data)
                if manual_cmd is not None:
                    q_target = manual_cmd["q"]
                    kp_override = manual_cmd.get("kp")
                    kd_override = manual_cmd.get("kd")
                else:
                    # Hold the measured pose. Copy: a plain slice is a view, and the
                    # arm override below would otherwise write into robot_state_data,
                    # whose joint slice is also sent to send_low_command.
                    q_target = robot_state_data[:, 7 : 7 + self.num_dofs].copy()
            # Policy path: q_target is produced in post-processing.

        # Stage 3: Inference
        if run_policy:
            with self.latency_tracker.measure("inference"):
                scaled_policy_action = self.rl_inference(robot_state_data)

        # Stage 4: Post-processing
        with self.latency_tracker.measure("postprocessing"):
            if run_policy:
                if scaled_policy_action.shape[1] != self.num_dofs:
                    if not self.upper_body_controller:
                        # Left-pad the action with zeros up to num_dofs.
                        scaled_policy_action = np.concatenate(
                            [np.zeros((1, self.num_dofs - scaled_policy_action.shape[1])), scaled_policy_action],
                            axis=1,
                        )
                    else:
                        raise NotImplementedError("Upper body controller not implemented")
                q_target = scaled_policy_action + self.default_dof_angles

            # Apply custom arm positions (overrides the computed arm joint targets)
            self._update_arm_interpolation(robot_state_data)
            q_target = self._apply_custom_arm_positions(q_target)

            # Prepare command (reuse pre-allocated arrays)
            self.cmd_q[:] = q_target

        # Stage 5: Action Pub
        with self.latency_tracker.measure("action_pub"):
            self.interface.send_low_command(
                self.cmd_q,
                self.cmd_dq,
                self.cmd_tau,
                robot_state_data[0, 7 : 7 + self.num_dofs],
                kp_override=kp_override,
                kd_override=kd_override,
            )