File size: 3,051 Bytes
1be5b40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import asyncio
import http
import logging
import time
import traceback
from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server as _server
import websockets.frames
logger = logging.getLogger(__name__)
class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    On every new connection the server first sends ``metadata`` (msgpack-encoded),
    then loops: receive a msgpack-encoded observation, run ``policy.infer`` on it
    and send the resulting action back with an added ``"server_timing"`` entry.
    Only the policy's ``infer`` method is used.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Store the serving configuration; nothing is bound until ``run``/``serve_forever``.

        Args:
            policy: Policy whose ``infer`` method is called once per received observation.
            host: Interface address to bind to.
            port: TCP port to bind to. NOTE(review): ``None`` is passed through to
                ``websockets`` unchanged — confirm the resulting bind behavior.
            metadata: Sent to every client right after it connects; a falsy value
                (including the default ``None``) becomes ``{}``.
        """
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        # Process-wide side effect: set the websockets library's server logger to INFO.
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Run ``run()`` on a fresh event loop via ``asyncio.run``, blocking the caller."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve connections until the task is cancelled."""
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            # Disable per-message compression and the incoming message size cap.
            # NOTE(review): max_size=None lets any client send arbitrarily large
            # messages — presumably needed for large observations; confirm this
            # server is only reachable by trusted clients.
            compression=None,
            max_size=None,
            # Plain HTTP requests to /healthz are answered before the websocket handshake.
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        """Serve a single client connection until it closes or an unexpected error occurs.

        Protocol: send ``self._metadata`` once, then repeatedly receive an
        observation, call ``self._policy.infer`` on it and send the action dict
        back. Each reply carries ``"server_timing"`` with ``"infer_ms"`` and, from
        the second reply on, ``"prev_total_ms"``: the time from starting to wait
        for the previous observation until the previous reply had been sent.

        On any exception other than a closed connection, the formatted traceback
        is sent to the client, the connection is closed with ``INTERNAL_ERROR``,
        and the original exception is re-raised.
        """
        logger.info("Connection from %s opened", websocket.remote_address)
        packer = msgpack_numpy.Packer()
        await websocket.send(packer.pack(self._metadata))

        prev_total_time = None
        while True:
            try:
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                # NOTE: infer() is called synchronously, so the event loop is blocked
                # for the duration of inference.
                infer_start = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_start

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # We can only record the last total time since we also want to include the send time.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info("Connection from %s closed", websocket.remote_address)
                break
            except Exception:
                # Best effort: report the failure to the client before closing. If the
                # client is already gone, the resulting ConnectionClosed must not mask
                # the original exception re-raised below.
                try:
                    await websocket.send(traceback.format_exc())
                    await websocket.close(
                        code=websockets.frames.CloseCode.INTERNAL_ERROR,
                        reason="Internal server error. Traceback included in previous frame.",
                    )
                except websockets.ConnectionClosed:
                    logger.info(
                        "Connection from %s closed before the error could be reported",
                        websocket.remote_address,
                    )
                # Bare raise re-raises the original exception handled by this clause.
                raise
def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    """Request hook that answers health probes on the ``/healthz`` path.

    Returns an HTTP 200 response with body ``"OK\\n"`` when the request path is
    exactly ``/healthz``; for any other path returns ``None`` so the normal
    request handling continues.
    """
    is_probe = request.path == "/healthz"
    if not is_probe:
        # Not a health probe: hand the request back to normal processing.
        return None
    return connection.respond(http.HTTPStatus.OK, "OK\n")
|