# Spaces: Runtime error
"""Replay a robomimic demonstration through a GenieSimulator and save a video.

One demonstration's recorded actions and agent-view frames are loaded from the
robomimic "lift/ph" image HDF5 dataset. A ``ReplaySimulator`` is built from the
frames and a ``ReplayPolicy`` from the actions. Their prompts seed the learned
``GenieSimulator``, which is then stepped ``len(replay_policy)`` times inside an
offscreen ``InteractiveDigitalWorld``. The rendered rollout (genie image next to
ground-truth image) is saved as an MP4.
"""
import numpy as np
import h5py

from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator, ReplaySimulator
from sim.policy import ReplayPolicy

DATASET_PATH = 'data/robomimic_datasets/robomimic_raw/datasets/lift/ph/image.hdf5'


def load_demo(dataset_path, demo_idx):
    """Read one demonstration's actions and agent-view frames into memory.

    Args:
        dataset_path: Path to the robomimic HDF5 file.
        demo_idx: Index selecting the ``data/demo_{demo_idx}`` group.

    Returns:
        ``(actions, frames)``: the ``actions`` dataset cast to float32 and the
        ``obs/agentview_image`` dataset cast to uint8, both as numpy arrays.
        They are fully materialized with ``[:]``, so the file can be closed
        before they are used.
    """
    with h5py.File(dataset_path, 'r') as f:
        demo = f['data'][f'demo_{demo_idx}']
        actions = demo['actions'][:].astype(np.float32)
        frames = demo['obs']['agentview_image'][:].astype(np.uint8)  # NOTE: possible re-render
    return actions, frames


def main(demo_idx=120, prompt_horizon=11, action_stride=1,
         dataset_path=DATASET_PATH, save_path='test.mp4'):
    """Roll out a GenieSimulator against a replayed demo and save the video.

    Args:
        demo_idx: Which ``demo_{demo_idx}`` to replay from the dataset.
        prompt_horizon: Prompt length passed to the replay simulator, the
            replay policy and the GenieSimulator.
        action_stride: Action stride passed to the replay policy and the
            GenieSimulator.
        dataset_path: Path to the robomimic HDF5 dataset.
        save_path: Output video path (written as MP4, not GIF).

    Raises:
        ValueError: If the replay policy and replay simulator report different
            lengths.
    """
    actions, frames = load_demo(dataset_path, demo_idx)

    replay_simulator = ReplaySimulator(frames=frames, prompt_horizon=prompt_horizon)
    replay_policy = ReplayPolicy(actions=actions, prompt_horizon=prompt_horizon, action_stride=action_stride)
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if len(replay_policy) != len(replay_simulator):
        raise ValueError(
            f'replay policy length ({len(replay_policy)}) does not match '
            f'replay simulator length ({len(replay_simulator)})'
        )

    genie_simulator = GenieSimulator(
        # Alternative (discrete-token) configuration, kept for reference:
        # image_encoder_type="magvit",
        # image_encoder_ckpt="data/magvit2.ckpt",
        # quantize=True,
        # backbone_type="stmaskgit",
        # backbone_ckpt='data/serious_robomimic_d256/step_86500',
        # # backbone_ckpt="data/genie_lang/step_5",
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic",
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='robomimic',
        physics_simulator=replay_simulator,
        compute_psnr=True,
        compute_delta_psnr=True,
        allow_external_prompt=True,
    )

    # Use whatever the current replay state is as the initial state.
    image_prompt = replay_simulator.prompt()
    action_prompt = replay_policy.prompt()
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=replay_policy,
        offscreen=True,
        window_size=(512 * 2, 512),  # [genie image | GT image] side-by-side
    )
    try:
        for _ in range(len(replay_policy)):
            playground.step()
        playground.save_video(save_path=save_path, as_gif=False)
    finally:
        # Release the playground even if a step or the video write fails.
        playground.close()


if __name__ == '__main__':
    main()