Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import imageio | |
| import click | |
| from sim.main import InteractiveDigitalWorld | |
| from sim.simulator import GenieSimulator | |
| from sim.policy import TeleopPlanarQuadDirectionalPolicy | |
| import cv2 | |
| """ | |
| Example usage with the MaskGIT (stmaskgit) backbone: | |
| python -m sim.example.genie_langtable_teleop --image_encoder_type magvit --image_encoder_ckpt data/magvit2.ckpt \ | |
| --quantize True --prompt_horizon 8 --backbone_type stmaskgit --backbone_ckpt data/maskgit_ckpt/langtable | |
| """ | |
def main(
    image_encoder_type,
    image_encoder_ckpt,
    quantize,
    backbone_type,
    backbone_ckpt,
    prompt_horizon,
    action_stride,
    video_save_path,
    scene_id,
    live,
    num_steps=20,
):
    """Run a teleoperated rollout of the Genie simulator in the Language Table domain.

    The simulator is prompted with one scene frame repeated over the simulator's
    prompt horizon, paired with all-zero actions. It is then stepped
    ``num_steps`` times under a ``TeleopPlanarQuadDirectionalPolicy``, and the
    rollout is saved as a video.

    Args:
        image_encoder_type: Forwarded to ``GenieSimulator(image_encoder_type=...)``.
        image_encoder_ckpt: Forwarded to ``GenieSimulator(image_encoder_ckpt=...)``.
        quantize: Forwarded to ``GenieSimulator(quantize=...)``.
        backbone_type: Forwarded to ``GenieSimulator(backbone_type=...)``.
        backbone_ckpt: Forwarded to ``GenieSimulator(backbone_ckpt=...)``.
        prompt_horizon: Forwarded to ``GenieSimulator(prompt_horizon=...)``.
        action_stride: Forwarded to ``GenieSimulator(action_stride=...)``.
        video_save_path: Destination passed to ``save_video`` (``as_gif=False``).
        scene_id: Selects the prompt frame
            ``sim/assets/langtable_prompt/frame_{scene_id:02d}.png``.
        live: When true the playground is created on-screen, otherwise offscreen.
        num_steps: Number of ``playground.step()`` calls (default 20).
    """
    # Read the arrow sprite once up front instead of re-reading it from disk
    # for every frame handed to the post-processor.
    arrow_sprite = imageio.imread('sim/assets/arrow.jpg')

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Draw the direction of ``action`` as an arrow in the top-left of ``image``.

        ``image`` is modified in place and also returned. Only ``action[0]``
        (the first entry along the stride axis) is visualized.
        """
        step = action[0]  # remove `stride` dimension
        # Only single-axis actions are supported: at most one component may be non-zero.
        assert step[0] * step[1] == 0
        if step[0] > 0:  # `s`
            arrow = np.flipud(arrow_sprite)
        elif step[1] < 0:  # `a`
            arrow = np.rot90(arrow_sprite)
        elif step[1] > 0:  # `d`
            arrow = np.rot90(arrow_sprite, -1)
        else:  # `w` -- the sprite as loaded; also drawn for an all-zero action
            arrow = arrow_sprite
        # Clip the sprite to the frame so a frame smaller than the sprite does
        # not raise a broadcasting error on assignment.
        rows = min(arrow.shape[0], image.shape[0])
        cols = min(arrow.shape[1], image.shape[1])
        image[:rows, :cols] = arrow[:rows, :cols]
        return image

    genie_simulator = GenieSimulator(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=quantize,
        backbone_type=backbone_type,
        backbone_ckpt=backbone_ckpt,
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Use the chosen scene frame as a static initial state: repeat it across the
    # prompt horizon and pair it with all-zero actions of shape
    # (prompt_horizon, action_stride, 2).
    current_image = imageio.imread(f'sim/assets/langtable_prompt/frame_{scene_id:02d}.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    teleop_policy = TeleopPlanarQuadDirectionalPolicy(increment=0.05)
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=teleop_policy,
        offscreen=not live,
        window_size=(512, 512),
    )
    # Always call close(), even if a step or the video export raises.
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=video_save_path, as_gif=False)
    finally:
        playground.close()
# Defaults for the encoder/backbone options mirror the MaskGIT usage example at
# the top of this file.
# NOTE(review): the action_stride, video_save_path and scene_id defaults are
# assumptions -- confirm against the intended workflow.
@click.command()
@click.option('--image_encoder_type', type=str, default='magvit', show_default=True,
              help='Image encoder type passed to GenieSimulator.')
@click.option('--image_encoder_ckpt', type=str, default='data/magvit2.ckpt', show_default=True,
              help='Image encoder checkpoint path.')
@click.option('--quantize', type=bool, default=True, show_default=True,
              help='Forwarded to GenieSimulator(quantize=...).')
@click.option('--backbone_type', type=str, default='stmaskgit', show_default=True,
              help='Dynamics backbone type passed to GenieSimulator.')
@click.option('--backbone_ckpt', type=str, default='data/maskgit_ckpt/langtable', show_default=True,
              help='Backbone checkpoint path.')
@click.option('--prompt_horizon', type=int, default=8, show_default=True,
              help='Number of prompt frames.')
@click.option('--action_stride', type=int, default=1, show_default=True,
              help='Action stride passed to GenieSimulator.')
@click.option('--video_save_path', type=str, default='langtable_teleop.mp4', show_default=True,
              help='Where the recorded rollout video is written.')
@click.option('--scene_id', type=int, default=0, show_default=True,
              help='Index of sim/assets/langtable_prompt/frame_XX.png used as the prompt.')
@click.option('--live', is_flag=True, default=False,
              help='Open an on-screen window instead of running offscreen.')
def cli(**kwargs):
    """Parse command-line options and forward them to ``main``.

    ``main`` takes its whole configuration as required parameters, so calling
    it without arguments raises ``TypeError``. This command supplies them from
    the command line instead.
    """
    main(**kwargs)


if __name__ == '__main__':
    cli()