import sys

import click
import numpy as np
import pygame

from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
|
@click.command()
@click.option('-o', '--output', required=True)
@click.option('-rs', '--render_size', default=96, type=int)
@click.option('-hz', '--control_hz', default=10, type=int)
def main(output, render_size, control_hz):
    """
    Collect demonstrations for the Push-T task.

    Usage: python demo_pusht.py -o data/pusht_demo.zarr

    This script is compatible with both Linux and MacOS.
    Hover the mouse close to the blue circle to start.
    Push the T block into the green area.
    The episode terminates automatically once the task succeeds.
    Press "Q" (or close the window) to exit.
    Press "R" to retry.
    Hold "Space" to pause; each pause increments the plan index shown
    in the window title.

    Args:
        output: path of the replay buffer to append episodes to.
        render_size: render resolution passed to the environment.
        control_hz: target control-loop frequency, in iterations per second.
    """
    # Opened in append mode so new demos go after any existing episodes.
    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')

    kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
    env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
    agent = env.teleop_agent()
    clock = pygame.time.Clock()

    # Record episodes until the user quits.
    while True:
        episode = list()

        # Episode k is always collected from seed k; a retried episode is not
        # added to the buffer, so the retry reuses the same seed.
        seed = replay_buffer.n_episodes
        print(f'starting seed {seed}')

        env.seed(seed)
        obs = env.reset()
        # reset() is used here as returning only obs, so fetch info directly
        # (private env method).
        info = env._get_info()
        img = env.render(mode='human')

        retry = False
        pause = False
        done = False
        plan_idx = 0
        pygame.display.set_caption(f'plan_idx:{plan_idx}')

        while not done:
            # Process window / keyboard events.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Closing the window behaves like pressing "Q".
                    sys.exit(0)
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # Each pause starts a new plan segment, shown in the title.
                        plan_idx += 1
                        pygame.display.set_caption(f'plan_idx:{plan_idx}')
                        pause = True
                    elif event.key == pygame.K_r:
                        # Abandon this episode; it is not saved.
                        retry = True
                    elif event.key == pygame.K_q:
                        # sys.exit instead of the site-provided exit(), which
                        # is unavailable under `python -S` / frozen builds.
                        sys.exit(0)
                if event.type == pygame.KEYUP:
                    if event.key == pygame.K_SPACE:
                        pause = False

            if retry:
                break
            if pause:
                # Keep regulating the loop rate while paused; without this the
                # loop polls events as fast as possible and pins a CPU core.
                clock.tick(control_hz)
                continue

            # act is None until the teleop agent engages (mouse near the agent).
            act = agent.act(obs)
            if act is not None:
                # presumably agent xy (2) + block pose (3) — TODO confirm dims
                state = np.concatenate([info['pos_agent'], info['block_pose']])
                # Take the first half of obs, view it as (x, y) pairs and keep
                # the first 9; the rest of obs is dropped (presumably a
                # visibility mask / agent keypoints — TODO confirm).
                keypoint = obs.reshape(2, -1)[0].reshape(-1, 2)[:9]
                # img is the frame rendered before this action is applied.
                data = {
                    'img': img,
                    'state': np.float32(state),
                    'keypoint': np.float32(keypoint),
                    'action': np.float32(act),
                    'n_contacts': np.float32([info['n_contacts']]),
                }
                episode.append(data)

            # Step and render every iteration (also before teleop engages) so
            # the window stays live.
            obs, _, done, info = env.step(act)
            img = env.render(mode='human')

            clock.tick(control_hz)

        if not retry:
            # Stack the per-step records into one array per key.
            data_dict = {
                key: np.stack([step[key] for step in episode])
                for key in episode[0].keys()
            }
            replay_buffer.add_episode(data_dict, compressors='disk')
            print(f'saved seed {seed}')
        else:
            print(f'retry seed {seed}')
|
|
|
|
if __name__ == '__main__':
    # Script entry point: click parses the command-line options, then runs main().
    main()
|
|