| from typing import Any, Dict |
| import logging |
| import traceback |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.responses import JSONResponse |
| import uvicorn |
| import json_numpy |
| import msgpack |
| import msgpack_numpy as m |
| from abc import ABC, abstractmethod |
# msgpack_numpy.patch() monkey-patches msgpack (per msgpack_numpy's docs) so that
# the msgpack.packb/unpackb calls in ModelServer's WebSocket endpoint can carry
# numpy arrays. It is called at import time so the patch is in place before any
# packing/unpacking happens.
m.patch()
|
|
class ModelServer(ABC):
    """Base class that serves a model's ``inference_api`` over HTTP and WebSocket.

    Subclasses implement :meth:`inference_api`. :meth:`run` builds a FastAPI app
    with two ``/act`` endpoints and serves it with uvicorn:

    * ``POST /act`` -- JSON body. String values are decoded with ``json_numpy``
      when possible. The result is returned as ``{"action": ...}``; on failure
      the response is ``{"error": "Request failed"}`` with status 400.
    * ``WS /act`` -- msgpack-framed loop. After accepting, the server sends
      ``{"type": "welcome", "ok": True}``. It then answers every binary message
      with ``{"action": ...}`` or ``{"error": ...}``.
    """

    def __init__(self):
        # Built lazily by _build_app(); stays None until run() is called.
        self.app: FastAPI | None = None

    @abstractmethod
    def inference_api(self, payload: Dict[str, Any], **kwargs) -> Any:
        """
        Abstract method for model inference API.

        Parameters
        ----------
        payload : Dict[str, Any]
            The input payload for inference. Over HTTP it is the JSON body with
            json_numpy-decodable string values decoded. Over WebSocket it is
            the msgpack-decoded message.
        **kwargs
            The keyword arguments given to :meth:`run`, forwarded unchanged.

        Returns
        -------
        Any
            The inference result, sent to the client under the ``"action"`` key.
            The WebSocket endpoint msgpack-packs it as-is. The HTTP endpoint
            converts anything with a ``tolist()`` method (recursing into dicts,
            lists and tuples) into plain Python values first.
        """

    @staticmethod
    def _to_jsonable(obj: Any) -> Any:
        """Recursively convert array-likes (objects with ``tolist``) to plain Python values.

        Dicts keep their keys, and lists/tuples become lists. Any other value is
        returned unchanged.
        """
        if hasattr(obj, "tolist"):
            return obj.tolist()
        if isinstance(obj, dict):
            return {key: ModelServer._to_jsonable(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [ModelServer._to_jsonable(value) for value in obj]
        return obj

    @staticmethod
    def _decode_json_numpy_fields(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *payload* with json_numpy-encoded str/bytes values decoded.

        Values that fail to decode are kept unchanged (best effort).
        """
        decoded: Dict[str, Any] = {}
        for key, value in payload.items():
            if isinstance(value, (str, bytes)):
                try:
                    value = json_numpy.loads(value)
                except Exception:
                    # Deliberate best effort: plain (non-JSON) strings must
                    # pass through to inference_api untouched.
                    pass
            decoded[key] = value
        return decoded

    def _build_app(self, **infer_kwargs):
        """
        Build the FastAPI app with the HTTP and WebSocket ``/act`` endpoints.

        ``infer_kwargs`` are passed to every ``inference_api`` call. The app is
        built only once: if ``self.app`` is already set this is a no-op, and the
        kwargs captured by the first build stay in effect.
        """
        # getattr: tolerate subclasses whose __init__ skips super().__init__().
        if getattr(self, "app", None) is not None:
            return
        app = FastAPI()

        @app.post("/act")
        def act(payload: Dict[str, Any]):
            # Plain `def` handler: per FastAPI's docs, it runs in a worker
            # thread, so blocking inference does not stall the event loop.
            try:
                action = self.inference_api(
                    self._decode_json_numpy_fields(payload), **infer_kwargs
                )
                return JSONResponse({"action": self._to_jsonable(action)})
            except Exception:
                logging.exception("HTTP /act request failed")
                return JSONResponse({"error": "Request failed"}, status_code=400)

        @app.websocket("/act")
        async def websocket_endpoint(websocket: WebSocket):
            await websocket.accept()
            await websocket.send_bytes(
                msgpack.packb({"type": "welcome", "ok": True}, use_bin_type=True)
            )
            try:
                while True:
                    data = await websocket.receive_bytes()
                    try:
                        payload = msgpack.unpackb(data, raw=False)
                    except Exception:
                        # A malformed frame is a client error: report it and keep
                        # the connection open instead of dropping it silently.
                        logging.warning("WS /act received an undecodable message", exc_info=True)
                        await websocket.send_bytes(
                            msgpack.packb({"error": "Malformed msgpack payload"}, use_bin_type=True)
                        )
                        continue

                    # NOTE(review): inference runs synchronously on the event-loop
                    # thread, so a slow call stalls other connections while it
                    # runs. Moving it to a worker thread would change which thread
                    # the model executes on, so it is deliberately left as-is.
                    try:
                        action_pred = self.inference_api(payload, **infer_kwargs)
                    except Exception as e:
                        logging.exception("WS /act inference failed")
                        response = {"error": f"Inference failed: {e}"}
                    else:
                        response = {"action": action_pred}
                    await websocket.send_bytes(msgpack.packb(response, use_bin_type=True))
            except WebSocketDisconnect:
                logging.info("WS disconnected")
            except Exception:
                logging.exception("WS /act handler failed")

        self.app = app

    def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
        """
        Launch the FastAPI service with uvicorn; this call blocks while serving.

        Parameters
        ----------
        host : str
            Interface uvicorn binds to.
        port : int
            Port uvicorn binds to.
        **kwargs
            Forwarded to every ``inference_api`` call. They are ignored if the
            app was already built by an earlier call.
        """
        logging.info("XVLAServer listening on http://%s:%s/act", host, port)
        logging.info("XVLAServer listening on ws://%s:%s/act", host, port)
        self._build_app(**kwargs)
        assert self.app is not None
        uvicorn.run(
            self.app,
            host=host,
            port=port,
            log_level="info",
            # WebSocket keepalive: ping every 20 s, give up after 20 s without a pong.
            ws_ping_interval=20,
            ws_ping_timeout=20,
        )
| |