| import dataclasses | |
| import logging | |
| import pathlib | |
| import env as _env | |
| from openpi_client import action_chunk_broker | |
| from openpi_client import websocket_client_policy as _websocket_client_policy | |
| from openpi_client.runtime import runtime as _runtime | |
| from openpi_client.runtime.agents import policy_agent as _policy_agent | |
| import saver as _saver | |
| import tyro | |
@dataclasses.dataclass
class Args:
    """Options for the ALOHA sim policy client.

    This must be a dataclass: ``main`` receives an ``Args`` instance via
    ``tyro.cli(main)``, and each field needs to be a constructor parameter
    so it can be overridden. Without the decorator, ``Args()`` accepts no
    arguments and every value is stuck at its class-level default.
    """

    # Directory handed to the video-saver subscriber.
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

    # Environment id and seed forwarded to the ALOHA sim environment
    # (presumably a gymnasium-registered id — confirm in env module).
    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    # Forwarded to ActionChunkBroker as its action_horizon.
    action_horizon: int = 10

    # Address of the websocket policy server the client connects to.
    host: str = "0.0.0.0"
    port: int = 8000

    # NOTE(review): not read by main() in this file — confirm whether it is
    # still used anywhere or can be removed.
    display: bool = False
def main(args: Args) -> None:
    """Wire the ALOHA sim environment to a remote policy and run the loop.

    Components are created in this order: the sim environment, the
    websocket policy client, the action-chunk broker wrapping it, the
    policy agent, and the video saver. They are then handed to a
    ``Runtime`` (``max_hz=50``), whose ``run()`` is called.
    """
    sim_env = _env.AlohaSimEnvironment(
        task=args.task,
        seed=args.seed,
    )

    # Client for the policy server listening at args.host:args.port.
    remote = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )

    # NOTE(review): presumably the broker serves a predicted chunk over
    # `action_horizon` steps before querying the server again — confirm.
    broker = action_chunk_broker.ActionChunkBroker(
        policy=remote,
        action_horizon=args.action_horizon,
    )
    policy_agent = _policy_agent.PolicyAgent(policy=broker)

    recorder = _saver.VideoSaver(args.out_dir)

    loop = _runtime.Runtime(
        environment=sim_env,
        agent=policy_agent,
        subscribers=[recorder],
        max_hz=50,
    )
    loop.run()
if __name__ == "__main__":
    # force=True drops any root-logger handlers installed earlier (e.g. by
    # an imported library) so this INFO-level configuration takes effect.
    logging.basicConfig(force=True, level=logging.INFO)
    # tyro builds the CLI from main's signature and invokes it.
    tyro.cli(main)