File size: 3,996 Bytes
8c57b4a
 
 
d742a8e
8c57b4a
 
 
 
 
d742a8e
 
8c57b4a
 
 
 
 
 
 
 
 
d742a8e
 
 
 
 
 
8c57b4a
 
 
 
 
 
 
 
 
 
d742a8e
8c57b4a
d742a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c57b4a
d742a8e
8c57b4a
d742a8e
8c57b4a
d742a8e
 
 
 
 
8c57b4a
d742a8e
 
8c57b4a
 
 
 
 
d742a8e
 
8c57b4a
 
d742a8e
 
 
 
8c57b4a
d742a8e
8c57b4a
 
 
 
d742a8e
 
 
 
 
8c57b4a
d742a8e
 
8c57b4a
 
d742a8e
8c57b4a
 
 
 
d742a8e
 
8c57b4a
 
 
 
d742a8e
8c57b4a
 
 
 
 
 
 
 
 
 
d742a8e
8c57b4a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
"""
Cap'n Proto RPC Server for Agent Interface
Receives observation as Agent.Tensor (no pickle).
"""

import asyncio
import logging
import os

import capnp
import numpy as np
import torch

# Load the schema
# The schema file is looked up next to this module and loaded at import time.
# NOTE(review): presumably importing this module fails if "agent.capnp" is
# missing or invalid -- confirm against capnp.load behaviour.
schema_file = os.path.join(os.path.dirname(__file__), "agent.capnp")
agent_capnp = capnp.load(schema_file)

# Module-level logger shared by the serving helpers below.
logger = logging.getLogger(__name__)

# Default network configuration
# Used as the default address/port arguments of the serving functions.
DEFAULT_RPC_ADDRESS = "127.0.0.1"
DEFAULT_RPC_PORT = 8000

# Handed to capnp.TwoPartyServer as traversal_limit_in_words for every
# accepted connection (100 * 1024 * 1024 words).
_TRAVERSAL_WORDS = 100 * 1024 * 1024  # match client; tune appropriately


class AgentServer(agent_capnp.Agent.Server):
    """Cap'n Proto server implementation for AgentInterface.

    Wraps an agent object and exposes its ``act`` and ``reset`` methods, plus
    a ``ping`` check, as RPC methods. Observations and actions are exchanged
    as Tensor structs made of raw bytes, a shape and a dtype string.
    """

    def __init__(self, agent):
        """Remember the wrapped agent and set up a logger.

        Args:
            agent: Object whose ``act(ndarray)`` and ``reset()`` methods are
                called by the RPC handlers.
        """
        self.agent = agent
        self.logger = logging.getLogger(__name__)
        self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)

    async def act(self, obs, **kwargs):
        """Handle act RPC call. 'obs' is expected to be an Agent.Tensor struct.

        Args:
            obs: Struct with ``.data`` (raw bytes), ``.shape`` and ``.dtype``
                (a string accepted by ``np.dtype``).
            **kwargs: Extra arguments supplied by the RPC layer; unused.

        Returns:
            A new ``agent_capnp.Tensor`` message holding the action's bytes,
            shape and dtype string.

        Raises:
            Exception: Any failure while decoding the observation, running the
                agent or encoding the action is logged and re-raised.
        """
        try:
            # Read every field exactly once: the payload can be large and
            # each attribute access on the struct may materialise it again.
            payload = obs.data
            dims = tuple(obs.shape)
            dtype_name = obs.dtype

            self.logger.debug(
                "Server.act invoked; incoming obs bytes=%d shape=%s dtype=%s",
                len(payload),
                list(dims),
                dtype_name,
            )

            # reconstruct numpy observation
            # NOTE: frombuffer does not copy; when the payload is an immutable
            # bytes object the resulting array is read-only.
            obs_np = np.frombuffer(payload, dtype=np.dtype(dtype_name)).reshape(dims)

            # call the underlying agent synchronously (user's agent.act should accept ndarray)
            action = self.agent.act(obs_np)

            # convert to numpy
            if isinstance(action, torch.Tensor):
                action_np = action.detach().cpu().numpy()
            else:
                # asarray avoids an extra copy when the agent already returns
                # an ndarray; tobytes() below copies anyway.
                action_np = np.asarray(action)

            # Build response Tensor
            response = agent_capnp.Tensor.new_message()
            response.data = action_np.tobytes()
            response.shape = [int(s) for s in action_np.shape]
            response.dtype = str(action_np.dtype)
            return response
        except Exception:
            self.logger.exception("Exception in AgentServer.act")
            raise

    async def reset(self, **kwargs):
        """Handle reset RPC call by forwarding to ``agent.reset()``.

        Raises:
            Exception: Any error from the agent is logged and re-raised.
        """
        try:
            self.agent.reset()
        except Exception:
            self.logger.exception("Error in reset")
            raise

    async def ping(self, message, **kwargs):
        """Handle ping RPC call: log the incoming message and answer "pong"."""
        # Lazy %-formatting, consistent with the other log calls in this file.
        self.logger.info("Ping received: %s", message)
        return "pong"


async def serve(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Listen on ``address:port`` and expose ``agent`` over Cap'n Proto RPC.

    Every accepted stream gets its own ``AgentServer`` wrapping the same
    ``agent``. Exceptions raised while listening are logged rather than
    propagated, and a shutdown message is logged when the function leaves.

    Args:
        agent: Object handed to ``AgentServer`` for each connection.
        address: Interface to bind to.
        port: TCP port to bind to.
    """

    async def _on_client(stream):
        # Errors on one connection are logged here and not re-raised.
        try:
            bootstrap_impl = AgentServer(agent)
            rpc = capnp.TwoPartyServer(
                stream,
                bootstrap=bootstrap_impl,
                traversal_limit_in_words=_TRAVERSAL_WORDS,
            )
            await rpc.on_disconnect()
        except Exception:
            logger.exception("Error handling connection")

    listener = await capnp.AsyncIoStream.create_server(_on_client, address, port)
    logger.info("Agent RPC server listening on %s:%d", address, port)

    try:
        async with listener:
            await listener.serve_forever()
    except Exception:
        logger.exception("Server error")
    finally:
        logger.info("Server shutting down")


def start_server(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Run the agent RPC server, blocking until it stops.

    ``serve`` is awaited inside a ``capnp.kj_loop()`` context on a fresh
    asyncio event loop. A ``KeyboardInterrupt`` is caught and logged instead
    of being propagated.

    Args:
        agent: Agent object to expose.
        address: Interface to bind to.
        port: TCP port to bind to.
    """

    async def _main():
        async with capnp.kj_loop():
            await serve(agent, address, port)

    try:
        asyncio.run(_main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")


def run_server_in_process(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Run the agent RPC server in the calling process, blocking until done.

    ``serve`` is awaited inside a ``capnp.kj_loop()`` context on a fresh
    asyncio event loop. Nothing is caught here: whatever ``asyncio.run``
    raises, ``KeyboardInterrupt`` included, reaches the caller.

    Args:
        agent: Agent object to expose.
        address: Interface to bind to.
        port: TCP port to bind to.
    """

    async def _entry():
        async with capnp.kj_loop():
            await serve(agent, address, port)

    asyncio.run(_entry())