File size: 2,352 Bytes
0a7036f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
import torch.distributed as dist
from .logging import logger
from .Simple_Remote_Infer.deploy.websocket_policy_server import WebsocketPolicyServer
class DistributedModelWrapper:
    """Bind a model and a local rank behind a single ``infer(obs)`` method.

    Each ``infer`` call is forwarded to ``distributed_infer`` together with
    the model and local rank captured at construction time.
    """

    def __init__(self, model, local_rank):
        # Both values are stored as given; nothing is validated or copied.
        self.local_rank = local_rank
        self.model = model

    def infer(self, obs):
        """Return ``distributed_infer`` applied to the stored model and rank."""
        return distributed_infer(
            model=self.model,
            obs=obs,
            local_rank=self.local_rank,
        )
def distributed_infer(model, obs, local_rank):
    """Run ``model.infer(obs)`` on rank 0 and mirror the call to the workers.

    When ``torch.distributed`` is not initialized this is a plain
    ``model.infer(obs)``. Otherwise the caller must be rank 0, because both
    broadcasts below use ``src=0``: a command value of ``1`` is broadcast
    first, followed by ``obs`` via ``broadcast_object_list``. That is the
    sequence ``worker_loop`` waits for before calling ``model.infer`` itself.

    Args:
        model: Object exposing ``infer(obs)``.
        obs: Observation handed to ``model.infer``. It is shipped with
            ``broadcast_object_list``, so it has to be picklable.
        local_rank: Local rank of the calling process; must be 0 when the
            process group is initialized.

    Returns:
        Whatever ``model.infer(obs)`` returns on this rank.

    Raises:
        RuntimeError: If the process group is initialized and the caller is
            not rank 0 (or ``local_rank`` is not 0).
    """
    if not dist.is_initialized():
        return model.infer(obs)

    rank = dist.get_rank()
    # The broadcasts are hard-wired to src=0, so only rank 0 may drive them.
    # Comparing rank with local_rank alone lets e.g. rank 1 / local_rank 1
    # through, which would then take part in the collective as a receiver.
    # An explicit raise is used because ``assert`` is removed under ``-O``.
    if rank != 0 or local_rank != 0:
        raise RuntimeError(
            "distributed_infer can only run on rank 0 "
            f"(got rank={rank}, local_rank={local_rank})")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # One-element tensor, shaped like the buffer worker_loop receives into.
    cmd = torch.ones(1, dtype=torch.int64, device=device)
    dist.broadcast(cmd, src=0)

    # Ship the observation so every worker runs infer on the same input.
    obj_list = [obs]
    dist.broadcast_object_list(obj_list, src=0)

    return model.infer(obs)
def worker_loop(model, local_rank):
    """Follow commands broadcast from rank 0 until a stop command arrives.

    Each iteration receives a one-element int64 command tensor broadcast
    from ``src=0`` and acts on its value:

    * ``-1``: leave the loop.
    * ``1``: receive one object via ``broadcast_object_list`` and call
      ``model.infer`` on it; the return value is discarded.
    * anything else: ignored, and the loop waits for the next command.

    Args:
        model: Object exposing ``infer(obs)``.
        local_rank: Not used inside this function.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rank = dist.get_rank()

    while True:
        # A fresh receive buffer per round; broadcast fills it in place.
        signal = torch.zeros(1, dtype=torch.int64, device=device)
        dist.broadcast(signal, src=0)
        code = signal.item()

        if code == -1:
            break
        if code != 1:
            # Unknown command: skip it silently and wait for the next one.
            continue

        payload = [None]
        dist.broadcast_object_list(payload, src=0)
        model.infer(payload[0])

    logger.info(f"[worker_loop] Rank {rank} exiting.")
def run_async_server_mode(model, local_rank, host, port):
    """Start the websocket server on local rank 0 and worker loops elsewhere.

    Local rank 0 wraps ``model`` in a ``DistributedModelWrapper``, hands it to
    a ``WebsocketPolicyServer`` bound to ``host``/``port`` and blocks in
    ``serve_forever()``. Once that call returns, and if ``torch.distributed``
    is initialized, it broadcasts ``-1`` from ``src=0``, the value
    ``worker_loop`` treats as its exit command.

    Every other local rank runs ``worker_loop``; a ``KeyboardInterrupt``
    raised there is logged and swallowed.

    Args:
        model: Model object passed to the wrapper or to ``worker_loop``.
        local_rank: Local rank of this process; 0 selects the server role.
        host: Forwarded to ``WebsocketPolicyServer`` as ``host``.
        port: Forwarded to ``WebsocketPolicyServer`` as ``port``.
    """
    logger.info("Running in ASYNC SERVER mode")

    if local_rank != 0:
        try:
            worker_loop(model, local_rank)
        except KeyboardInterrupt:
            logger.info(f"Rank {local_rank}: Shutting down")
        return

    server = WebsocketPolicyServer(
        DistributedModelWrapper(model, local_rank=local_rank),
        host=host,
        port=port,
    )
    server.serve_forever()

    # NOTE(review): this point is only reached when serve_forever() returns
    # normally; if it raises, no stop command is broadcast.
    if not dist.is_initialized():
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    stop = torch.tensor(-1, dtype=torch.int64, device=device)
    dist.broadcast(stop, src=0)
|