# Spaces status: Runtime error
import functools
from typing import Optional

import imageio
import numpy as np

from sim.main import InteractiveDigitalWorld
from sim.policy import RandomPlanarQuadDirectionalPolicy
from sim.simulator import GenieSimulator
# Sprite overlaid on generated frames (path is relative to the working directory).
ARROW_IMAGE_PATH = 'sim/assets/arrow.jpg'
# Single frame used to seed the simulator's prompt history.
PROMPT_IMAGE_PATH = 'sim/assets/langtable_prompt.png'


@functools.lru_cache(maxsize=8)
def _load_arrow_image(path: str = ARROW_IMAGE_PATH) -> np.ndarray:
    """Load the arrow sprite from ``path`` once and return it as a read-only array.

    Caching avoids re-decoding the JPEG for every rendered frame. The returned
    array is shared between calls, so it is marked non-writeable.
    """
    arrow = np.asarray(imageio.imread(path))
    arrow.setflags(write=False)
    return arrow


def draw_action_arrow_to_image(
    image: np.ndarray,
    action: np.ndarray,
    arrow_image: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Draw an arrow for the commanded direction into the top-left corner of ``image``.

    Args:
        image: Frame to annotate; it is modified in place. NOTE(review): assumed
            to be (H, W, C) with a dtype compatible with the sprite - confirm
            against what GenieSimulator passes to its post_processor.
        action: Action with a leading ``stride`` axis; only ``action[0]`` (two
            components) is used, and at most one of them may be non-zero.
        arrow_image: Optional sprite to draw, in its `w` orientation. Defaults
            to the cached contents of ``ARROW_IMAGE_PATH``.

    Returns:
        ``image`` itself, after drawing.

    Raises:
        ValueError: If both components of ``action[0]`` are non-zero.
    """
    action = action[0]  # remove `stride` dimension
    if action[0] * action[1] != 0:
        # Explicit check instead of `assert`, which `python -O` would strip.
        raise ValueError(f'expected an axis-aligned action, got {action!r}')

    arrow = _load_arrow_image() if arrow_image is None else arrow_image
    # The unmodified sprite is used for `w`; reorient it for the other keys.
    if action[0] > 0:  # `s`
        arrow = np.flipud(arrow)
    elif action[1] < 0:  # `a`
        arrow = np.rot90(arrow)
    elif action[1] > 0:  # `d`
        arrow = np.rot90(arrow, -1)
    # else: `w` (also reached for an all-zero action) keeps the sprite as-is.

    # A 2-D (grayscale) sprite needs a channel axis to broadcast over a colour frame.
    if arrow.ndim == 2 and image.ndim == 3:
        arrow = arrow[..., np.newaxis]
    # Clip to the frame so a sprite larger than the image cannot raise a
    # broadcast error. rot90 swaps height/width, hence computing after rotation.
    height = min(arrow.shape[0], image.shape[0])
    width = min(arrow.shape[1], image.shape[1])
    image[:height, :width] = arrow[:height, :width]
    return image


def main(num_steps: int = 50, save_path: str = 'test.mp4') -> None:
    """Roll out a random policy in the Language-Table Genie simulator and save a video.

    Args:
        num_steps: Number of ``playground.step()`` calls to run.
        save_path: Output path passed to ``playground.save_video`` (non-GIF).
    """
    # Alternative discrete-token configuration kept for reference:
    #   image_encoder_type='magvit', image_encoder_ckpt='data/magvit2.ckpt',
    #   quantize=True, backbone_type='stmaskgit',
    #   backbone_ckpt='data/genie_lang/step_5', prompt_horizon=8.
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/language_table_scratch_mar_dynamics_gpu_8_nodes_4_16g/step_40000",
        prompt_horizon=11,
        action_stride=1,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Use whatever the current state is as the initial state: repeat one frame
    # `prompt_horizon` times, paired with all-zero actions of shape
    # (prompt_horizon, action_stride, 2).
    # NOTE(review): if the PNG carries an alpha channel, the prompt will have 4
    # channels - confirm the image encoder accepts that.
    current_image = imageio.imread(PROMPT_IMAGE_PATH)
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    random_policy = RandomPlanarQuadDirectionalPolicy(increment=0.05)  # as IRASIM
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=random_policy,
        offscreen=True,
        window_size=(512, 512),
    )
    # Release the playground even if a step or the video write fails.
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=save_path, as_gif=False)
    finally:
        playground.close()


if __name__ == '__main__':
    main()