File size: 2,352 Bytes
0a7036f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
import torch.distributed as dist

from .logging import logger
from .Simple_Remote_Infer.deploy.websocket_policy_server import WebsocketPolicyServer


class DistributedModelWrapper:
    """Expose a plain ``infer(obs)`` interface backed by :func:`distributed_infer`.

    ``run_async_server_mode`` hands an instance of this class to the websocket
    server, so every request the server forwards to ``infer`` goes through
    ``distributed_infer``. When ``torch.distributed`` is initialized, that
    call first broadcasts the request to the worker ranks and then runs the
    model locally.

    Attributes:
        model: The wrapped object; it must provide ``infer(obs)``.
        local_rank: Rank value forwarded to ``distributed_infer`` on each call.
    """

    def __init__(self, model, local_rank):
        self.model, self.local_rank = model, local_rank

    def infer(self, obs):
        """Delegate ``obs`` to :func:`distributed_infer` and return its result."""
        return distributed_infer(
            model=self.model,
            obs=obs,
            local_rank=self.local_rank,
        )


def distributed_infer(model, obs, local_rank):
    """Run ``model.infer(obs)`` on rank 0 and have the worker ranks mirror it.

    If ``torch.distributed`` is not initialized, this is a plain
    ``model.infer(obs)`` call. Otherwise it must be invoked on global rank 0,
    and it proceeds in three steps:

    1. It broadcasts the command ``1`` so that ranks blocked in
       ``worker_loop`` wake up.
    2. It broadcasts ``obs`` to those ranks.
    3. It calls ``model.infer(obs)`` locally.

    Args:
        model: Object exposing ``infer(obs)``.
        obs: Observation to run inference on. ``dist.broadcast_object_list``
            pickles it, so it must be picklable when running distributed.
        local_rank: Rank of the calling process; must be 0 (and equal to the
            global rank) when running distributed.

    Returns:
        Whatever ``model.infer(obs)`` returns on this rank.

    Raises:
        RuntimeError: If called while distributed from any process other than
            global rank 0 with ``local_rank == 0``.
    """
    if not dist.is_initialized():
        return model.infer(obs)

    rank = dist.get_rank()
    # Both broadcasts below use src=0, so only global rank 0 may drive them;
    # any other caller would pair with the wrong collectives and deadlock.
    # Raise explicitly (not ``assert``) so the guard survives ``python -O``.
    if rank != 0 or rank != local_rank:
        raise RuntimeError(
            "distributed_infer can only run on rank 0 "
            f"(got rank={rank}, local_rank={local_rank})")

    # One-element int64 command, the same shape/dtype as the receive buffer
    # allocated in worker_loop; 1 means "an observation follows".
    cmd = torch.tensor([1],
                       dtype=torch.int64,
                       device='cuda' if torch.cuda.is_available() else 'cpu')
    dist.broadcast(cmd, src=0)

    dist.broadcast_object_list([obs], src=0)

    return model.infer(obs)


def worker_loop(model, local_rank):
    """Handle rank 0's broadcast commands on a worker rank until told to stop.

    Each iteration blocks in ``dist.broadcast`` (src=0) for a one-element
    int64 command:

    * ``1``: receive ``obs`` via ``dist.broadcast_object_list`` and call
      ``model.infer(obs)``. The result is discarded. Presumably the call
      exists so this rank takes part in whatever collectives
      ``model.infer`` issues; TODO confirm against the model.
    * ``-1``: leave the loop and return.
    * Anything else: log a warning and wait for the next command.

    Requires ``torch.distributed`` to be initialized (``dist.get_rank()`` is
    called unconditionally).

    Args:
        model: Object exposing ``infer(obs)``.
        local_rank: Not read here; the global rank from ``dist.get_rank()``
            is used for logging. Kept for interface compatibility.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rank = dist.get_rank()

    while True:
        # Receive buffer; the broadcast from rank 0 overwrites it in place.
        cmd = torch.zeros(1, dtype=torch.int64, device=device)
        dist.broadcast(cmd, src=0)
        cmd_val = cmd.item()

        if cmd_val == -1:
            break
        if cmd_val == 1:
            obj_list = [None]
            dist.broadcast_object_list(obj_list, src=0)
            model.infer(obj_list[0])
        else:
            # Previously ignored silently; surface it so protocol mismatches
            # between rank 0 and the workers are visible in the logs.
            logger.warning(
                f"[worker_loop] Rank {rank} ignoring unknown command {cmd_val}.")

    logger.info(f"[worker_loop] Rank {rank} exiting.")


def run_async_server_mode(model, local_rank, host, port):
    """Serve ``model`` over a websocket on local rank 0; run workers elsewhere.

    On ``local_rank == 0``, this wraps the model in ``DistributedModelWrapper``
    and blocks in ``WebsocketPolicyServer.serve_forever()``. If that call
    returns and ``torch.distributed`` is initialized, it broadcasts the stop
    command ``-1`` so the worker ranks leave ``worker_loop``.

    On every other local rank, it runs ``worker_loop``. A
    ``KeyboardInterrupt`` there is logged instead of propagated.

    Args:
        model: Object exposing ``infer(obs)``.
        local_rank: This process's local rank; selects server vs. worker role.
        host: Host address passed to ``WebsocketPolicyServer``.
        port: Port passed to ``WebsocketPolicyServer``.
    """
    logger.info("Running in ASYNC SERVER mode")

    if local_rank != 0:
        # Worker role: block here answering rank 0's broadcasts.
        try:
            worker_loop(model, local_rank)
        except KeyboardInterrupt:
            logger.info(f"Rank {local_rank}: Shutting down")
        return

    wrapped = DistributedModelWrapper(model, local_rank=local_rank)
    server = WebsocketPolicyServer(wrapped, host=host, port=port)
    server.serve_forever()

    # Reached only when serve_forever() returns normally.
    # NOTE(review): if serve_forever() raises, no stop command is broadcast
    # and the worker ranks stay blocked in worker_loop; confirm acceptable.
    if not dist.is_initialized():
        return
    stop_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    stop_cmd = torch.tensor(-1, dtype=torch.int64, device=stop_device)
    dist.broadcast(stop_cmd, src=0)