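# NOTE: the following lines are page residue captured with this file; kept as
# comments so the module remains valid Python.
# kinitro / agent_server.py
# tarantula11's picture
# Upload submission from kinitro-agent-template
# d742a8e verified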
#!/usr/bin/env python3
"""
Cap'n Proto RPC Server for Agent Interface
Receives observation as Agent.Tensor (no pickle).
"""
import asyncio
import logging
import os
import capnp
import numpy as np
import torch
# Compile the Cap'n Proto schema that sits next to this file. The generated
# module provides the Agent interface and Tensor struct used below.
schema_file = os.path.join(os.path.split(__file__)[0], "agent.capnp")
agent_capnp = capnp.load(schema_file)
logger = logging.getLogger(__name__)

# Network defaults: listen on loopback only.
DEFAULT_RPC_ADDRESS = "127.0.0.1"
DEFAULT_RPC_PORT = 8000

# Per-message traversal limit handed to capnp.TwoPartyServer, measured in
# Cap'n Proto words (8 bytes each): 100 Mi words. Keep in sync with the
# client and tune for the expected observation size.
_TRAVERSAL_WORDS = 100 << 20
class AgentServer(agent_capnp.Agent.Server):
    """Cap'n Proto server implementation for AgentInterface.

    Wraps a user-supplied ``agent`` object and exposes its ``act`` and
    ``reset`` methods, plus a ``ping`` health check, over RPC. Observations
    and actions travel as raw-bytes tensors (``data`` / ``shape`` /
    ``dtype``), never pickled.
    """

    def __init__(self, agent):
        self.agent = agent
        self.logger = logging.getLogger(__name__)
        self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)

    async def act(self, obs, **kwargs):
        """Handle the ``act`` RPC.

        Args:
            obs: Incoming tensor struct with ``data`` (raw bytes), ``shape``
                (sequence of ints) and ``dtype`` (a string accepted by
                ``np.dtype``).

        Returns:
            A new ``agent_capnp.Tensor`` message holding the agent's action.

        Raises:
            Any exception from decoding the observation, from the agent, or
            from encoding the action. It is logged and then re-raised.
        """
        try:
            byte_len = len(obs.data) if obs and obs.data is not None else 0
            self.logger.debug(
                "Server.act invoked; incoming obs bytes=%d shape=%s dtype=%s",
                byte_len,
                list(obs.shape) if obs else None,
                obs.dtype if obs else None,
            )
            obs_np = self._decode_observation(obs)
            # The user's agent.act is synchronous and expects an ndarray; while
            # it runs, this coroutine (and the event loop) is blocked.
            action = self.agent.act(obs_np)
            return self._encode_action(action)
        except Exception:
            self.logger.exception("Exception in AgentServer.act")
            raise

    @staticmethod
    def _decode_observation(obs):
        """Rebuild a writable, self-owned numpy array from a wire tensor."""
        flat = np.frombuffer(obs.data, dtype=np.dtype(obs.dtype))
        # np.frombuffer shares memory with obs.data and is read-only whenever
        # that buffer is immutable (e.g. bytes). Copy so agents can modify the
        # observation in place or pass it to torch.from_numpy without the
        # non-writable-array warning, and so it does not alias message memory.
        return flat.reshape(tuple(obs.shape)).copy()

    @staticmethod
    def _encode_action(action):
        """Pack an agent action (torch tensor or array-like) into a Tensor message."""
        if isinstance(action, torch.Tensor):
            action_np = action.detach().cpu().numpy()
        else:
            action_np = np.asarray(action)
        response = agent_capnp.Tensor.new_message()
        # tobytes() emits C-order bytes even for non-contiguous arrays, i.e. the
        # same layout _decode_observation assumes when it reshapes.
        response.data = action_np.tobytes()
        response.shape = [int(s) for s in action_np.shape]
        response.dtype = str(action_np.dtype)
        return response

    async def reset(self, **kwargs):
        """Handle the ``reset`` RPC by calling ``agent.reset()``; errors are logged and re-raised."""
        try:
            self.agent.reset()
        except Exception:
            self.logger.exception("Error in reset")
            raise

    async def ping(self, message, **kwargs):
        """Health check: log the incoming message and reply with ``"pong"``."""
        self.logger.info("Ping received: %s", message)
        return "pong"
async def serve(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Accept RPC connections on ``address:port`` and bootstrap ``agent`` for each.

    Every accepted stream gets its own ``AgentServer`` instance, and all of
    them wrap the same ``agent`` object. A failure in one connection is
    logged and does not stop the listener. Errors raised while serving are
    logged rather than propagated. Failing to create the listener does
    propagate.
    """

    async def _handle_client(stream):
        # One TwoPartyServer per connection; it lives until the peer disconnects.
        try:
            rpc = capnp.TwoPartyServer(
                stream,
                bootstrap=AgentServer(agent),
                traversal_limit_in_words=_TRAVERSAL_WORDS,
            )
            await rpc.on_disconnect()
        except Exception:
            logger.exception("Error handling connection")

    listener = await capnp.AsyncIoStream.create_server(_handle_client, address, port)
    logger.info("Agent RPC server listening on %s:%d", address, port)
    try:
        async with listener:
            await listener.serve_forever()
    except Exception:
        logger.exception("Server error")
    finally:
        logger.info("Server shutting down")
def start_server(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Block the calling thread running ``serve`` inside ``capnp.kj_loop()``.

    A KeyboardInterrupt (Ctrl-C) is caught and logged instead of propagating.
    """

    async def _main():
        # serve() runs with the KJ loop context active, as pycapnp's asyncio
        # integration expects.
        async with capnp.kj_loop():
            await serve(agent, address, port)

    try:
        asyncio.run(_main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
def run_server_in_process(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Block the calling thread running ``serve`` inside ``capnp.kj_loop()``.

    Unlike ``start_server``, nothing is caught here: every exception,
    including KeyboardInterrupt, propagates to the caller. Presumably meant as
    a child-process target (TODO confirm with callers).
    """

    async def _main():
        async with capnp.kj_loop():
            await serve(agent, address, port)

    asyncio.run(_main())