| | |
| |
|
| | import contextlib |
| | import dataclasses |
| | import datetime |
| | import faulthandler |
| | import os |
| | import signal |
| | import time |
| | from moviepy.editor import ImageSequenceClip |
| | import numpy as np |
| | from openpi_client import image_tools |
| | from openpi_client import websocket_client_policy |
| | import pandas as pd |
| | from PIL import Image |
| | from droid.robot_env import RobotEnv |
| | import tqdm |
| | import tyro |
| |
|
# Dump Python tracebacks on hard crashes (e.g. segfaults in camera/robot drivers).
faulthandler.enable()

# Control-loop rate in Hz; each rollout step below is padded to 1/15 s.
DROID_CONTROL_FREQUENCY = 15
| |
|
| |
|
| | @dataclasses.dataclass |
| | class Args: |
| | |
| | left_camera_id: str = "<your_camera_id>" |
| | right_camera_id: str = "<your_camera_id>" |
| | wrist_camera_id: str = "<your_camera_id>" |
| |
|
| | |
| | external_camera: str | None = ( |
| | None |
| | ) |
| |
|
| | |
| | max_timesteps: int = 600 |
| | |
| | |
| | open_loop_horizon: int = 8 |
| |
|
| | |
| | remote_host: str = "0.0.0.0" |
| | remote_port: int = ( |
| | 8000 |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Defer Ctrl+C inside the block; re-raise KeyboardInterrupt on exit if one arrived."""
    pending = []  # non-empty once a SIGINT was received inside the block
    previous_handler = signal.getsignal(signal.SIGINT)

    def _record(_signum, _frame):
        pending.append(True)

    signal.signal(signal.SIGINT, _record)
    try:
        yield
    finally:
        # Restore the original handler first, then honor any deferred interrupt.
        signal.signal(signal.SIGINT, previous_handler)
        if pending:
            raise KeyboardInterrupt
| |
|
| |
|
def main(args: Args):
    """Run interactive DROID evaluation rollouts against a remote policy server.

    Repeatedly prompts for a language instruction, runs one rollout of up to
    ``args.max_timesteps`` steps paced at DROID_CONTROL_FREQUENCY, saves a video
    of the external camera, asks the user to score success, and finally writes
    all results to ``results/eval_<timestamp>.csv``.

    Raises:
        ValueError: if ``args.external_camera`` is not "left" or "right".
    """
    # Validate arguments with a real exception; `assert` is stripped under `python -O`.
    if args.external_camera not in ["left", "right"]:
        raise ValueError(
            f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
        )

    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    results = pd.DataFrame(columns=["success", "duration", "video_filename"])

    while True:
        instruction = input("Enter instruction: ")

        # Open-loop chunk state: how many actions of the current chunk were executed.
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        t_step = 0  # keep defined even if the loop body never runs (max_timesteps <= 0)
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Current observation; dump the camera views to disk once for debugging.
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Re-query the policy once the previous chunk is exhausted
                # (open-loop execution of `open_loop_horizon` actions per chunk).
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Don't let Ctrl+C cut the websocket round-trip in half; the
                    # interrupt is re-raised right after and still stops the rollout.
                    with prevent_keyboard_interrupt():
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8), pred_action_chunk.shape

                # Next action of the chunk: 7 joint velocities + 1 gripper position.
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize the gripper command (open/close only).
                if action[-1].item() > 0.5:
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                action = np.clip(action, -1, 1)

                env.step(action)

                # Pace the loop to the robot control frequency.
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        save_filename = "video_" + timestamp
        # Guard against Ctrl+C on the very first step: np.stack([]) would raise.
        if video:
            ImageSequenceClip(list(np.stack(video)), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        success = _ask_success()

        # `DataFrame.append` was removed in pandas 2.0; concatenate a one-row frame instead.
        results = pd.concat(
            [
                results,
                pd.DataFrame([{"success": success, "duration": t_step, "video_filename": save_filename}]),
            ],
            ignore_index=True,
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    results.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")


def _ask_success() -> float:
    """Prompt until the user gives a valid success score; return it in [0, 1].

    Accepts "y" (100%), "n" (0%), or a number 0-100. Re-prompts on anything
    else instead of crashing or silently recording a bad value.
    """
    while True:
        response = input(
            "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
        )
        if response == "y":
            return 1.0
        if response == "n":
            return 0.0
        try:
            score = float(response) / 100
        except ValueError:
            print(f"Invalid input: {response!r} — enter y, n, or a number 0-100.")
            continue
        if 0 <= score <= 1:
            return score
        print(f"Success must be a number in [0, 100] but got: {score * 100}")
| |
|
| |
|
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    """Pull policy inputs out of a raw DROID observation dict.

    Args:
        args: run configuration; the three camera IDs are substring-matched
            against the keys of ``obs_dict["image"]``.
        obs_dict: raw observation — assumes an "image" dict of arrays and a
            "robot_state" dict; TODO confirm against RobotEnv.get_observation().
        save_to_disk: when True, saves a side-by-side PNG of all three views
            to "robot_camera_views.png" for debugging.

    Returns:
        dict with image arrays ("left_image", "right_image", "wrist_image")
        and numpy state vectors ("cartesian_position", "joint_position",
        "gripper_position").

    Raises:
        ValueError: if any of the three expected camera streams is missing.
    """
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Match on camera ID plus the "left" substring — presumably selecting
        # the left lens of each stereo camera; confirm against the driver's keys.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    # Fail loudly with the missing stream names instead of crashing later with
    # an opaque TypeError on None.
    missing = [
        name
        for name, image in [("left", left_image), ("right", right_image), ("wrist", wrist_image)]
        if image is None
    ]
    if missing:
        raise ValueError(
            f"Missing camera stream(s) {missing}; available image keys: {sorted(image_observations)}"
        )

    # Keep only the first three channels (drops an alpha channel if present).
    left_image = left_image[..., :3]
    right_image = right_image[..., :3]
    wrist_image = wrist_image[..., :3]

    # Reverse the channel order (presumably BGR -> RGB; confirm camera format).
    left_image = left_image[..., ::-1]
    right_image = right_image[..., ::-1]
    wrist_image = wrist_image[..., ::-1]

    # Proprioceptive state as numpy arrays.
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # One-shot debug snapshot of all three views, side by side.
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        combined_image = Image.fromarray(combined_image)
        combined_image.save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }
| |
|
| |
|
if __name__ == "__main__":
    # Parse CLI flags into the Args dataclass via tyro, then run the eval loop.
    args: Args = tyro.cli(Args)
    main(args)
| |
|