| |
| """ |
Cap'n Proto RPC server for the agent interface.

Observations arrive as Tensor messages (raw bytes plus shape and dtype)
rather than as pickled objects.
| """ |
|
|
| import asyncio |
| import logging |
| import os |
|
|
| import capnp |
| import numpy as np |
| import torch |
|
|
| |
# Load the RPC schema file that sits next to this module.
schema_file = os.path.join(os.path.dirname(__file__), "agent.capnp")
agent_capnp = capnp.load(schema_file)

# Module-wide logger.
logger = logging.getLogger(__name__)

# Default bind endpoint for the RPC server (loopback address by default).
DEFAULT_RPC_ADDRESS = "127.0.0.1"
DEFAULT_RPC_PORT = 8000

# Traversal limit passed to TwoPartyServer, in Cap'n Proto words (8 bytes
# each): 100 Mi words, roughly 800 MiB. Presumably sized so large observation
# tensors can be decoded — confirm against expected payload sizes.
_TRAVERSAL_WORDS = 100 << 20
|
|
|
|
class AgentServer(agent_capnp.Agent.Server):
    """Cap'n Proto server implementation for AgentInterface.

    Wraps a local ``agent`` object and exposes its ``act`` and ``reset``
    methods, plus a ``ping`` health check, over RPC. Tensors cross the wire
    as raw bytes with an explicit shape and dtype string, so nothing is
    pickled.
    """

    def __init__(self, agent):
        """Store the wrapped agent.

        Args:
            agent: Object providing ``act(observation)`` and ``reset()``.
        """
        self.agent = agent
        self.logger = logging.getLogger(__name__)
        self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)

    async def act(self, obs, **kwargs):
        """Handle the ``act`` RPC call.

        Args:
            obs: Incoming tensor struct with ``data`` (raw bytes), ``shape``
                and ``dtype`` (a NumPy dtype string) fields.
            **kwargs: Extra arguments supplied by the RPC layer; unused.

        Returns:
            A new ``Tensor`` message holding the agent's action, encoded as
            C-order bytes plus its shape and dtype string.

        Raises:
            Exception: Any failure (bad dtype or shape, agent error, or a
                non-numeric action) is logged and re-raised so the caller
                receives an RPC error.
        """
        try:
            byte_len = len(obs.data) if obs and obs.data is not None else 0
            self.logger.debug(
                "Server.act invoked; incoming obs bytes=%d shape=%s dtype=%s",
                byte_len,
                list(obs.shape) if obs else None,
                obs.dtype if obs else None,
            )

            # np.frombuffer yields a read-only view over the message bytes.
            # Copy it so the agent receives an owned, writable array; in-place
            # preprocessing would otherwise raise, and torch.from_numpy warns.
            obs_np = (
                np.frombuffer(obs.data, dtype=np.dtype(obs.dtype))
                .reshape(tuple(obs.shape))
                .copy()
            )

            action_tensor = self.agent.act(obs_np)

            if isinstance(action_tensor, torch.Tensor):
                action_np = action_tensor.detach().cpu().numpy()
            else:
                # tobytes() below copies anyway, so avoid an extra copy here.
                action_np = np.asarray(action_tensor)

            # Object arrays hold Python references; tobytes() would emit raw
            # pointer values that the client cannot decode.
            if action_np.dtype.hasobject:
                raise TypeError(
                    f"Agent returned a non-numeric action (dtype={action_np.dtype}); "
                    "cannot serialize it as a Tensor"
                )

            response = agent_capnp.Tensor.new_message()
            response.data = action_np.tobytes()
            response.shape = [int(s) for s in action_np.shape]
            response.dtype = str(action_np.dtype)
            return response
        except Exception:
            self.logger.exception("Exception in AgentServer.act")
            raise

    async def reset(self, **kwargs):
        """Handle the ``reset`` RPC call by resetting the wrapped agent.

        Any exception from ``agent.reset()`` is logged and re-raised.
        """
        try:
            self.agent.reset()
        except Exception:
            self.logger.exception("Error in reset")
            raise

    async def ping(self, message, **kwargs):
        """Health check: log the incoming message and reply ``"pong"``."""
        self.logger.info("Ping received: %s", message)
        return "pong"
|
|
|
|
async def serve(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Serve ``agent`` over Cap'n Proto RPC using asyncio.

    Listens on ``address:port``. Every accepted connection gets its own
    ``AgentServer`` bootstrap, and all connections share the same ``agent``
    object. The callers in this module await this inside
    ``capnp.kj_loop()``.

    An exception raised by ``serve_forever`` is logged but not re-raised, so
    this coroutine then returns normally. A failure to create the listening
    server propagates to the caller.

    Args:
        agent: Object exposing ``act(observation)`` and ``reset()``.
        address: Interface to bind to.
        port: TCP port to listen on.
    """

    async def new_connection(stream):
        # Serve one client until it disconnects. Errors are logged here
        # instead of escaping the connection callback.
        try:
            server = capnp.TwoPartyServer(
                stream,
                bootstrap=AgentServer(agent),
                # Custom traversal limit (in words) for incoming messages.
                traversal_limit_in_words=_TRAVERSAL_WORDS,
            )
            await server.on_disconnect()
        except Exception:
            logger.exception("Error handling connection")

    server = await capnp.AsyncIoStream.create_server(new_connection, address, port)
    # %s rather than %d so a port passed as a string cannot break this log call.
    logger.info("Agent RPC server listening on %s:%s", address, port)

    try:
        async with server:
            await server.serve_forever()
    except Exception:
        # Logged but not re-raised; serve() returns normally afterwards.
        logger.exception("Server error")
    finally:
        logger.info("Server shutting down")
|
|
|
|
def start_server(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Run the agent RPC server in the foreground until it stops.

    Starts a fresh asyncio event loop, enters ``capnp.kj_loop()`` and awaits
    ``serve``. A ``KeyboardInterrupt`` (e.g. Ctrl+C) is logged as a normal
    stop instead of propagating.

    Args:
        agent: Object to expose over RPC.
        address: Interface to bind to.
        port: TCP port to listen on.
    """

    async def _serve_under_kj():
        kj_context = capnp.kj_loop()
        async with kj_context:
            await serve(agent, address, port)

    try:
        asyncio.run(_serve_under_kj())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
|
|
|
|
def run_server_in_process(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Blocking entry point that runs ``serve`` inside ``capnp.kj_loop()``.

    Presumably intended as the target of a separate process (e.g. a
    ``multiprocessing.Process``) — confirm with callers. Unlike
    ``start_server``, a ``KeyboardInterrupt`` is not caught here and
    propagates to the caller.

    Args:
        agent: Object to expose over RPC.
        address: Interface to bind to.
        port: TCP port to listen on.
    """

    async def _kj_main():
        kj_context = capnp.kj_loop()
        async with kj_context:
            await serve(agent, address, port)

    asyncio.run(_kj_main())
|
|