Spaces:
Runtime error
Runtime error
| """ | |
| Generate data based on the learned policy and physics simulator - mujoco | |
| """ | |
| from sim.simulator import RobomimicSimulator | |
| from sim.policy import DiffusionPolicy | |
| from diffusion_policy.util.pytorch_util import dict_apply | |
| import h5py | |
| import tqdm | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import imageio | |
MAX_STEPS = 100  # maximum number of simulator steps recorded per episode
RES = 84  # side length (pixels) of the square image fed to the diffusion policy
NUM_TRIALS = 200  # number of episodes to roll out
CKPT_PATH = 'data/dp_ckpt/dp_lift_sr0.70.ckpt'
OUTPUT_PATH = 'data/my_robomimic_dataset.hdf5'


def to_policy_obs(image, res=RES):
    """Build the single-step observation dict the policy consumes.

    Resizes ``image`` to ``res`` x ``res`` with OpenCV and moves the last
    (channel) axis to the front.
    """
    # NOTE(review): no dtype conversion or value rescaling happens here, so the
    # pixel range is whatever the simulator returns — confirm it matches what
    # the policy was trained on.
    return {"agentview_image": cv2.resize(image, (res, res)).transpose(2, 0, 1)}


def init_obs_history(latest_obs, n_obs_steps):
    """Return an observation history filled with ``n_obs_steps`` copies of ``latest_obs``.

    Each value gains a new leading (time) axis of length ``n_obs_steps``.
    """
    return {
        key: np.repeat(value[np.newaxis], n_obs_steps, axis=0)
        for key, value in latest_obs.items()
    }


def push_obs(obs_history, latest_obs):
    """Return the history shifted one step back in time with ``latest_obs`` appended.

    The oldest entry is dropped. ``obs_history`` itself is not modified,
    because ``np.roll`` returns a new array.
    """
    shifted = {key: np.roll(value, -1, axis=0) for key, value in obs_history.items()}
    for key, value in latest_obs.items():
        shifted[key][-1] = value
    return shifted


def policy_actions(policy, obs_history):
    """Query ``policy`` with the current history and return its action chunk as a NumPy array."""
    batch = dict_apply(
        obs_history,
        lambda x: torch.from_numpy(x).to(
            device=policy.device, dtype=policy.dtype
        ).unsqueeze(0),  # add a batch axis of size 1
    )
    return policy.generate_action(batch)['action'].squeeze(0).detach().cpu().numpy()


def rollout_episode(env, policy, n_obs_steps, max_steps=MAX_STEPS):
    """Roll out one episode and return the recorded demonstration.

    Returns a dict with two stacked NumPy arrays:
    ``"images"`` holds the frame observed before each executed action, and
    ``"actions"`` holds the executed actions.

    Execution stops as soon as the simulator reports ``done``, or once
    ``max_steps`` actions have been executed. This holds even in the middle
    of an action chunk, so no steps past termination are recorded.
    """
    image = env.reset()
    images, actions = [], []
    obs_history = init_obs_history(to_policy_obs(image), n_obs_steps)
    done = False
    steps = 0
    # Context manager so each per-episode progress bar is closed.
    with tqdm.tqdm(total=max_steps) as pbar:
        while not done and steps < max_steps:
            traj = policy_actions(policy, obs_history)
            for action in traj:
                images.append(image)
                actions.append(action)
                result = env.step(action)
                done = bool(result['done'])
                image = result['pred_next_frame']
                # Push every executed step into the history so the policy sees
                # consecutive frames rather than one frame per action chunk
                # (presumably how its n_obs_steps history was built in
                # training — confirm).
                obs_history = push_obs(obs_history, to_policy_obs(image))
                steps += 1
                pbar.update(1)
                if done or steps >= max_steps:
                    break
    return {"images": np.array(images), "actions": np.array(actions)}


def save_demos(path, demos):
    """Write ``demos`` (demo name -> {dataset name -> array}) to an HDF5 file.

    Resulting layout: ``/data/<demo_name>/<dataset_name>``, for example
    ``/data/demo_0/images`` and ``/data/demo_0/actions``.
    """
    with h5py.File(path, 'w') as f:
        data = f.create_group("data")
        for demo_name, demo_data in demos.items():
            demo = data.create_group(demo_name)
            for key, value in demo_data.items():
                demo.create_dataset(key, data=value)


if __name__ == '__main__':
    env = RobomimicSimulator(env_name='lift')
    policy = DiffusionPolicy(CKPT_PATH)
    n_obs_steps = policy.n_obs_steps
    demos = {
        f"demo_{trial}": rollout_episode(env, policy, n_obs_steps)
        for trial in range(NUM_TRIALS)
    }
    save_demos(OUTPUT_PATH, demos)