2toINF's picture
Upload ckpt-200000 (X-VLA generalist)
dd37dbc verified
from typing import Any, Dict
import logging
import traceback
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import uvicorn
import json_numpy
import msgpack
import msgpack_numpy as m
from abc import ABC, abstractmethod
m.patch()
class ModelServer(ABC):
def __init__(self):
self.app: FastAPI | None = None
@abstractmethod
def inference_api(self, payload: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
Abstract method for model inference API.
Parameters
----------
payload : Dict[str, Any]
The input payload for inference.
Returns
-------
Dict[str, Any]
The inference result.
"""
pass
def _build_app(self, **infer_kwargs):
"""
Minimal FastAPI app for XVLA inference.
kwargs are passed to inference_api.
"""
if self.app is not None: return
app = FastAPI()
# ODL VERSION With Json Response
@app.post("/act")
def act(payload: Dict[str, Any]):
try:
for key, value in payload.items():
if isinstance(value, (str, bytes)):
try: payload[key] = json_numpy.loads(value)
except Exception: pass
action = self.inference_api(payload, **infer_kwargs)
return JSONResponse({"action": action.tolist()})
except Exception:
logging.error(traceback.format_exc())
return JSONResponse({"error": "Request failed"}, status_code=400)
@app.websocket("/act")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_bytes(msgpack.packb({"type": "welcome", "ok": True},
use_bin_type=True))
try:
while True:
data = await websocket.receive_bytes()
payload = msgpack.unpackb(data, raw=False)
try: action_pred = self.inference_api(payload, **infer_kwargs)
except Exception as e:
logging.error(traceback.format_exc())
response = {"error": f"Inference failed: {e}"}
await websocket.send_bytes(msgpack.packb(response, use_bin_type=True))
continue
# 4. Pack & Send Response
response = {"action": action_pred}
await websocket.send_bytes(msgpack.packb(response, use_bin_type=True))
except WebSocketDisconnect:
logging.info("WS disconnected")
except Exception:
logging.error(traceback.format_exc())
self.app = app
def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
"""
Launch the FastAPI service.
"""
logging.info(f"๐Ÿš€ XVLAServer listening on http://{host}:{port}/act")
logging.info(f"๐Ÿš€ XVLAServer listening on ws://{host}:{port}/act")
self._build_app(**kwargs)
assert self.app is not None
uvicorn.run(self.app,
host=host,
port=port,
log_level="info",
ws_ping_interval=20,
ws_ping_timeout=20)