import multiprocessing as mp
import signal
from pathlib import Path
from queue import Empty, Full

import torch
import torch.optim as optim

from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents

LOG_EVERY = 10  # training steps between progress logs
SEND_EVERY = 10  # training steps between parameter pushes to the actor
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
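
# HIL-SERL data flow: the actor process runs the robot and streams transitions to
# the learner through transitions_queue; the learner trains SAC on a mix of online
# and demonstration data and ships updated weights back through parameters_queue.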


def run_learner(
    transitions_queue: mp.Queue,
    parameters_queue: mp.Queue,
    shutdown_event: mp.Event,
    policy_learner: SACPolicy,
    online_buffer: ReplayBuffer,
    offline_buffer: ReplayBuffer,
    lr: float = 3e-4,
    batch_size: int = 32,
    device: torch.device | str = "mps",
):
    """Learner process: trains the SAC policy on transitions streamed from the
    actor and periodically publishes updated parameters for the actor to adopt."""
    policy_learner.train()
    policy_learner.to(device)

    optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
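
    # Note: full SAC implementations typically keep separate optimizers for the
    # actor, critic, and temperature; a single Adam keeps this example compact.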

    print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
    print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")

    training_step = 0

    while not shutdown_event.is_set():
        # Drain freshly collected transitions from the actor.
        try:
            transitions = transitions_queue.get(timeout=0.1)
            for transition in transitions:
                online_buffer.add(**transition)

                # Transitions flagged as human interventions also feed the offline
                # buffer, so demonstration-quality data is replayed more often.
                is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
                if is_intervention:
                    offline_buffer.add(**transition)
                    print(
                        f"[LEARNER] Human intervention detected! Added to offline buffer "
                        f"(now {len(offline_buffer)} transitions)"
                    )
        except Empty:
            pass

        # Begin optimizing once the online buffer holds enough transitions.
        if len(online_buffer) >= policy_learner.config.online_step_before_learning:
            online_batch = online_buffer.sample(batch_size // 2)
            offline_batch = offline_buffer.sample(batch_size // 2)

            # Build a 50/50 batch mixing autonomous experience with
            # demonstration and intervention data.
            batch = {}
            for key in online_batch:
                if key in offline_batch:
                    batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
                else:
                    batch[key] = online_batch[key]

            loss, _ = policy_learner.forward(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_step += 1

            if training_step % LOG_EVERY == 0:
                print(
                    f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
                    f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
                )

            # Periodically publish the latest weights, moved to CPU so the tensors
            # can cross the process boundary regardless of the training device.
            if training_step % SEND_EVERY == 0:
                try:
                    state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
                    parameters_queue.put_nowait(state_dict)
                    print("[LEARNER] Sent updated parameters to actor")
                except Full:
                    # The actor has not consumed the previous update yet; skip this one.
                    pass

    print("[LEARNER] Learner process finished")


def run_actor(
    transitions_queue: mp.Queue,
    parameters_queue: mp.Queue,
    shutdown_event: mp.Event,
    policy_actor: SACPolicy,
    reward_classifier: Classifier,
    env_cfg: HILSerlRobotEnvConfig,
    device: torch.device | str = "mps",
    output_directory: Path | None = None,
):
    """Actor process: interacts with the environment and collects transitions.
    The policy is not trained here; the actor only adopts the most recent
    parameters popped from the parameters queue."""
    policy_actor.eval()
    policy_actor.to(device)

    reward_classifier.eval()
    reward_classifier.to(device)

    # Build the robot environment and the teleoperation device used for interventions.
    env, teleop_device = make_robot_env(env_cfg)

    try:
        for episode in range(MAX_EPISODES):
            if shutdown_event.is_set():
                break

            obs, _info = env.reset()
            episode_reward = 0.0
            step = 0
            episode_transitions = []

            print(f"[ACTOR] Starting episode {episode + 1}")

            while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
                # Adopt the learner's latest parameters whenever any are available.
                try:
                    new_params = parameters_queue.get_nowait()
                    policy_actor.load_state_dict(new_params)
                    print("[ACTOR] Updated policy parameters from learner")
                except Empty:
                    pass

                # Convert the observation to policy format and pick an action.
                policy_obs = make_policy_obs(obs, device=device)
                action_tensor = policy_actor.select_action(policy_obs)
                action = action_tensor.squeeze(0).cpu().numpy()

                next_obs, _env_reward, terminated, truncated, _info = env.step(action)
                done = terminated or truncated

                # Score the new observation with the learned reward classifier
                # instead of relying on an environment reward.
                policy_next_obs = make_policy_obs(next_obs, device=device)
                reward = reward_classifier.predict_reward(policy_next_obs)

                # The classifier signals success: end the episode early.
                if reward >= 1.0 and not done:
                    terminated = True
                    done = True

                # Check whether the human teleoperator intervened on this step.
                is_intervention = False
                if hasattr(teleop_device, "get_teleop_events"):
                    teleop_events = teleop_device.get_teleop_events()
                    is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)

                # Package the transition for the learner.
                transition = {
                    "state": policy_obs,
                    "action": action,
                    "reward": float(reward),  # float() handles plain numbers and 1-element tensors
                    "next_state": policy_next_obs,
                    "done": done,
                    "truncated": truncated,
                    "complementary_info": {
                        "is_intervention": is_intervention,
                    },
                }
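
                # These keys mirror the keyword arguments of ReplayBuffer.add: the
                # learner unpacks each transition with online_buffer.add(**transition).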

                episode_transitions.append(transition)

                episode_reward += float(reward)
                step += 1

                obs = next_obs

                if done:
                    break

            # Ship the full episode to the learner; drop it rather than block
            # the control loop if the queue is full.
            try:
                transitions_queue.put_nowait(episode_transitions)
            except Full:
                print("[ACTOR] Transitions queue is full, dropping episode")

            print(f"[ACTOR] Episode {episode + 1} finished: steps={step}, reward={episode_reward:.2f}")

    except KeyboardInterrupt:
        print("[ACTOR] Interrupted by user")
    finally:
        # Release the hardware and save the latest actor weights.
        if hasattr(env, "robot") and env.robot.is_connected:
            env.robot.disconnect()
        if teleop_device and hasattr(teleop_device, "disconnect"):
            teleop_device.disconnect()
        if output_directory is not None:
            policy_actor.save_pretrained(output_directory)
            print(f"[ACTOR] Latest actor policy saved at: {output_directory}")

    print("[ACTOR] Actor process finished")


def make_policy_obs(obs, device: torch.device | str = "cpu"):
    """Map a gym-style observation dict onto the policy's expected input keys,
    adding a batch dimension and moving each tensor to the target device."""
    return {
        "observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
        **{
            f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
            for k in obs["pixels"]
        },
    }


def main():
    """Coordinate the actor and learner processes."""
    device = "mps"  # Apple Silicon GPU; use "cuda" or "cpu" on other machines
    output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Serial ports of the follower and leader arms.
    follower_port = ...
    leader_port = ...

    # IDs the arms were calibrated with.
    follower_id = ...
    leader_id = ...

    # Pretrained reward classifier (replace <user> with your Hub username).
    reward_classifier_id = "<user>/reward_classifier_hil_serl_example"
    reward_classifier = Classifier.from_pretrained(reward_classifier_id)

    reward_classifier.to(device)
    reward_classifier.eval()

    # Configure the robot, the leader-arm teleop, and the HIL-SERL processor.
    robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
    teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
    processor_cfg = HILSerlProcessorConfig(control_mode="leader")

    env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)

    # Instantiate the env once here, only to read its feature specs.
    env, teleop_device = make_robot_env(env_cfg)

    obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
    action_features = hw_to_dataset_features(env.robot.action_features, "action")
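
    # Assumption: make_robot_env connects to the hardware, and run_actor will
    # connect again in its own process; releasing the devices here avoids a
    # port conflict (mirroring the cleanup in run_actor's finally block).
    if hasattr(env, "robot") and env.robot.is_connected:
        env.robot.disconnect()
    if teleop_device and hasattr(teleop_device, "disconnect"):
        teleop_device.disconnect()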

    # SAC policy configured from the env's observation/action feature specs.
    policy_cfg = SACConfig(
        device=device,
        input_features=obs_features,
        output_features=action_features,
    )

    # Two copies of the same policy: one trained by the learner, one run by the actor.
    policy_actor = SACPolicy(policy_cfg)
    policy_learner = SACPolicy(policy_cfg)

    # Teleoperated demonstrations recorded ahead of time.
    demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
    offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)

    # Online buffer, filled by the actor's autonomous experience.
    online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
    # Offline buffer, seeded with the demonstrations.
    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
    )
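
    # Bounded queues provide backpressure: at most 10 pending episode batches
    # and 2 pending parameter updates before the senders skip a send.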
    transitions_queue = mp.Queue(maxsize=10)
    parameters_queue = mp.Queue(maxsize=2)
    shutdown_event = mp.Event()

    # Shut down both processes gracefully on Ctrl+C / SIGTERM.
    def signal_handler(sig, _frame):
        print(f"\nSignal {sig} received, shutting down...")
        shutdown_event.set()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Spawn the learner and actor processes.
    learner_process = mp.Process(
        target=run_learner,
        args=(
            transitions_queue,
            parameters_queue,
            shutdown_event,
            policy_learner,
            online_replay_buffer,
            offline_replay_buffer,
        ),
        kwargs={"device": device},
    )

    # The actor runs inference on CPU, leaving the accelerator to the learner;
    # output_directory is passed by keyword so it doesn't shift into the device slot.
    actor_process = mp.Process(
        target=run_actor,
        args=(
            transitions_queue,
            parameters_queue,
            shutdown_event,
            policy_actor,
            reward_classifier,
            env_cfg,
        ),
        kwargs={"device": "cpu", "output_directory": output_directory},
    )

    learner_process.start()
    actor_process.start()

    try:
        # The actor exits on its own after MAX_EPISODES; then stop the learner.
        actor_process.join()
        shutdown_event.set()
        learner_process.join(timeout=10)
    except KeyboardInterrupt:
        print("Main process interrupted")
        shutdown_event.set()
        actor_process.join(timeout=5)
        learner_process.join(timeout=10)
    finally:
        if learner_process.is_alive():
            learner_process.terminate()
        if actor_process.is_alive():
            actor_process.terminate()


if __name__ == "__main__":
    main()