| |
| import torch |
| import torch.distributed as dist |
|
|
| from .logging import logger |
| from .Simple_Remote_Infer.deploy.websocket_policy_server import WebsocketPolicyServer |
|
|
|
|
class DistributedModelWrapper:
    """Adapter that routes ``infer`` calls through ``distributed_infer``.

    Holds a model together with the rank it runs on, so that an object
    expecting a plain ``infer(obs)`` method can drive the distributed
    inference path.
    """

    def __init__(self, model, local_rank):
        # Stored as-is; both values are forwarded unchanged on every infer().
        self.model, self.local_rank = model, local_rank

    def infer(self, obs):
        """Return the result of ``distributed_infer`` for ``obs``."""
        return distributed_infer(
            model=self.model,
            obs=obs,
            local_rank=self.local_rank,
        )
|
|
|
|
def distributed_infer(model, obs, local_rank):
    """Run one inference step and announce it to the other ranks.

    When ``torch.distributed`` is not initialized this is a plain
    ``model.infer(obs)`` call. Otherwise the caller first broadcasts the
    command value ``1`` and then the observation (both with ``src=0``)
    before running ``model.infer`` locally. ``worker_loop`` is the
    receiving side of this two-step protocol.

    Args:
        model: Object exposing ``infer(obs)``.
        obs: Observation handed to ``model.infer``. It is sent with
            ``dist.broadcast_object_list``, so it must be picklable.
        local_rank: Rank the caller expects to be running on; it must
            equal ``dist.get_rank()``.

    Returns:
        Whatever ``model.infer(obs)`` returns on the calling rank.

    Raises:
        AssertionError: If ``dist.get_rank()`` differs from ``local_rank``.
    """
    if not dist.is_initialized():
        return model.infer(obs)

    rank = dist.get_rank()
    # Raised explicitly instead of via ``assert`` so the guard is still
    # enforced when Python runs with -O (which strips assert statements).
    if rank != local_rank:
        raise AssertionError("distributed_infer can only run at(rank 0)")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Shape (1,) matches the receive buffer that worker_loop allocates.
    cmd = torch.tensor([1], dtype=torch.int64, device=device)
    dist.broadcast(cmd, src=0)

    # Second step of the protocol: ship the observation itself.
    obj_list = [obs]
    dist.broadcast_object_list(obj_list, src=0)

    return model.infer(obs)
|
|
|
|
def worker_loop(model, local_rank):
    """Block on commands broadcast from rank 0 until told to stop.

    Each iteration receives a single int64 command via ``dist.broadcast``
    (``src=0``):

    * ``-1`` ends the loop.
    * ``1`` is followed by an object broadcast carrying the observation,
      on which ``model.infer`` is run; the result is discarded.
    * any other value is ignored.

    Args:
        model: Object exposing ``infer(obs)``.
        local_rank: Not used inside this function.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rank = dist.get_rank()

    while True:
        # Fresh one-element buffer for the incoming command.
        signal = torch.zeros(1, dtype=torch.int64, device=device)
        dist.broadcast(signal, src=0)
        code = signal.item()

        if code == -1:
            break
        if code != 1:
            # Unrecognized command: nothing to do, wait for the next one.
            continue

        # Command 1 is always followed by the observation payload.
        payload = [None]
        dist.broadcast_object_list(payload, src=0)
        model.infer(payload[0])

    logger.info(f"[worker_loop] Rank {rank} exiting.")
|
|
|
|
def run_async_server_mode(model, local_rank, host, port):
    """Serve the model over a websocket on local rank 0; other ranks follow.

    Local rank 0 wraps ``model`` in ``DistributedModelWrapper`` and runs a
    ``WebsocketPolicyServer`` until it stops. Every other rank runs
    ``worker_loop``. When the server stops, whether it returned or raised,
    rank 0 broadcasts the command ``-1`` so the workers leave their loop.

    Args:
        model: Object exposing ``infer(obs)``.
        local_rank: Rank of this process; ``0`` selects the server role.
        host: Host passed to ``WebsocketPolicyServer``.
        port: Port passed to ``WebsocketPolicyServer``.
    """
    logger.info("Running in ASYNC SERVER mode")
    if local_rank == 0:
        dist_model = DistributedModelWrapper(model, local_rank=local_rank)
        try:
            model_server = WebsocketPolicyServer(dist_model, host=host, port=port)
            model_server.serve_forever()
        finally:
            # Send the shutdown command even when the server fails or is
            # interrupted; otherwise the workers stay blocked in
            # dist.broadcast waiting for a command that never arrives.
            if dist.is_initialized():
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                # Shape (1,) matches the receive buffer in worker_loop.
                cmd = torch.tensor([-1], dtype=torch.int64, device=device)
                dist.broadcast(cmd, src=0)
    else:
        try:
            worker_loop(model, local_rank)
        except KeyboardInterrupt:
            logger.info(f"Rank {local_rank}: Shutting down")
|
|