|
|
import asyncio |
|
|
import http |
|
|
import logging |
|
|
import time |
|
|
import traceback |
|
|
|
|
|
from openpi_client import base_policy as _base_policy |
|
|
from openpi_client import msgpack_numpy |
|
|
import websockets.asyncio.server as _server |
|
|
import websockets.frames |
|
|
|
|
|
# Module logger; the connection handler reports connection open/close events through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `infer` method.

    Per-connection protocol:
      1. Right after the connection opens, the server sends the msgpack-encoded ``metadata`` dict once.
      2. The client then repeatedly sends msgpack-encoded observations. For each one, the server replies with
         the msgpack-encoded result of ``policy.infer(obs)``, augmented with a ``"server_timing"`` entry.
      3. If handling a request raises, the server sends the formatted traceback as a text frame, closes the
         connection with ``INTERNAL_ERROR``, and re-raises the exception.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Store the serving configuration; no socket is opened until `run`/`serve_forever`.

        Args:
            policy: Policy whose ``infer`` method is called for every observation received.
            host: Interface to bind; the default ``"0.0.0.0"`` listens on all IPv4 interfaces.
            port: Port to bind; passed unchanged to the websockets server (including ``None``).
            metadata: Dict sent to each client immediately after it connects. ``None`` (or any falsy
                value) is replaced by an empty dict.
        """
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        # Set the websockets library's server logger to INFO, so its DEBUG output is suppressed.
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Run the server on a new asyncio event loop, blocking until it stops."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve connections until the task is cancelled."""
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,  # disable per-message compression
            max_size=None,  # no cap on incoming message size (observations may be large)
            process_request=_health_check,  # answers plain HTTP GET /healthz before any handshake
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        """Serve a single client connection.

        Sends the metadata once, then loops: receive an observation, run inference, send the action.
        Returns normally when the client disconnects. Any other exception is reported to the client
        (best effort) and then re-raised.
        """
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        await websocket.send(packer.pack(self._metadata))

        prev_total_time = None
        while True:
            try:
                # Timing starts before `recv`, so the measured total also includes time spent
                # waiting for the client's next request.
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                infer_start = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_start

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # The current request's total is only known after its reply is sent,
                    # so each reply carries the total of the previous request.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                try:
                    await websocket.send(traceback.format_exc())
                    await websocket.close(
                        code=websockets.frames.CloseCode.INTERNAL_ERROR,
                        reason="Internal server error. Traceback included in previous frame.",
                    )
                except websockets.ConnectionClosed:
                    # Reporting is best effort. If the client is already gone, the traceback cannot
                    # be delivered; do not let that secondary error mask the original one below.
                    logger.warning(
                        "Connection from %s closed before the error could be reported",
                        websocket.remote_address,
                    )
                # Bare `raise` re-raises the original exception handled by this outer `except`.
                raise
|
|
|
|
|
|
|
|
def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    """Handle HTTP liveness probes; used as the server's ``process_request`` hook.

    Returns an HTTP 200 response with body ``"OK\\n"`` for requests to ``/healthz``.
    Returns ``None`` for every other path, so the websocket handshake proceeds normally.
    """
    is_probe = request.path == "/healthz"
    if not is_probe:
        return None
    return connection.respond(http.HTTPStatus.OK, "OK\n")
|
|
|