File size: 3,459 Bytes
dd37dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Any, Dict
import logging
import traceback
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import uvicorn
import json_numpy
import msgpack
import msgpack_numpy as m
from abc import ABC, abstractmethod
m.patch()

class ModelServer(ABC):
    """
    Abstract base class that serves a model over HTTP and WebSocket.

    Subclasses implement :meth:`inference_api`. :meth:`run` builds a FastAPI
    application exposing ``/act`` both as an HTTP POST route (JSON response)
    and as a WebSocket route (msgpack frames), then serves it with uvicorn.
    """

    def __init__(self):
        # Created lazily by ``_build_app``; stays ``None`` until then.
        self.app: FastAPI | None = None

    @abstractmethod
    def inference_api(self, payload: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """
        Abstract method for model inference API.

        Parameters
        ----------
        payload : Dict[str, Any]
            The input payload for inference.
        **kwargs
            Extra keyword arguments forwarded unchanged from ``run`` /
            ``_build_app``.

        Returns
        -------
        Dict[str, Any]
            The inference result. The HTTP route converts it to JSON-friendly
            values before responding; the WebSocket route packs it with
            msgpack as-is.
        """
        pass

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """
        Recursively convert an inference result into JSON-friendly values.

        Objects exposing ``tolist()`` (array-like results) are converted with
        it; dicts, lists and tuples are walked recursively; anything else is
        returned unchanged.

        Parameters
        ----------
        value : Any
            The inference result or one of its nested members.

        Returns
        -------
        Any
            The converted value.
        """
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        if isinstance(value, dict):
            return {key: ModelServer._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [ModelServer._to_jsonable(item) for item in value]
        return value

    @staticmethod
    def _decode_json_numpy_fields(payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of ``payload`` with str/bytes values decoded by json_numpy.

        Decoding is best effort: a value that ``json_numpy.loads`` cannot
        parse is kept exactly as it was received.
        """
        decoded: Dict[str, Any] = {}
        for key, value in payload.items():
            if isinstance(value, (str, bytes)):
                try:
                    value = json_numpy.loads(value)
                except Exception:
                    # Deliberate best effort: plain strings that are not
                    # json_numpy documents are passed through untouched.
                    pass
            decoded[key] = value
        return decoded

    def _build_app(self, **infer_kwargs):
        """
        Minimal FastAPI app for XVLA inference.
        kwargs are passed to inference_api.

        Does nothing when the app has already been built, so the
        ``infer_kwargs`` of a later call are ignored.
        """
        if self.app is not None:
            return
        app = FastAPI()

        # Old version: plain HTTP endpoint that answers with a JSON response.
        @app.post("/act")
        def act(payload: Dict[str, Any]):
            try:
                decoded = self._decode_json_numpy_fields(payload)
                action = self.inference_api(decoded, **infer_kwargs)
                # Accept array-like results as well as dicts/lists that
                # contain them, instead of requiring ``.tolist()``.
                return JSONResponse({"action": self._to_jsonable(action)})
            except Exception:
                logging.exception("HTTP /act request failed")
                return JSONResponse({"error": "Request failed"}, status_code=400)

        @app.websocket("/act")
        async def websocket_endpoint(websocket: WebSocket):
            await websocket.accept()
            # Greet the client so it can tell the session is ready.
            await websocket.send_bytes(
                msgpack.packb({"type": "welcome", "ok": True}, use_bin_type=True)
            )
            try:
                while True:
                    data = await websocket.receive_bytes()

                    # A frame that cannot be decoded is reported to the
                    # client and skipped; it does not end the session.
                    try:
                        payload = msgpack.unpackb(data, raw=False)
                    except Exception as e:
                        logging.exception("WS /act received an undecodable frame")
                        response = {"error": f"Invalid payload: {e}"}
                        await websocket.send_bytes(
                            msgpack.packb(response, use_bin_type=True)
                        )
                        continue

                    # NOTE(review): inference_api is called synchronously
                    # inside this coroutine, so the event loop is busy for
                    # the duration of each inference call.
                    try:
                        action_pred = self.inference_api(payload, **infer_kwargs)
                    except Exception as e:
                        logging.exception("WS /act inference failed")
                        response = {"error": f"Inference failed: {e}"}
                        await websocket.send_bytes(
                            msgpack.packb(response, use_bin_type=True)
                        )
                        continue

                    # Pack & send the successful response.
                    response = {"action": action_pred}
                    await websocket.send_bytes(
                        msgpack.packb(response, use_bin_type=True)
                    )
            except WebSocketDisconnect:
                logging.info("WS disconnected")
            except Exception:
                logging.exception("WS /act session aborted")
        self.app = app

    def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
        """
        Launch the FastAPI service.

        Parameters
        ----------
        host : str
            Interface to bind to.
        port : int
            TCP port to listen on.
        **kwargs
            Forwarded to ``_build_app`` and from there to ``inference_api``.

        Raises
        ------
        RuntimeError
            If ``_build_app`` did not create the application.
        """
        self._build_app(**kwargs)
        if self.app is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError("_build_app() did not create the FastAPI app")
        logging.info("🚀 XVLAServer listening on http://%s:%s/act", host, port)
        logging.info("🚀 XVLAServer listening on ws://%s:%s/act", host, port)
        uvicorn.run(
            self.app,
            host=host,
            port=port,
            log_level="info",
            ws_ping_interval=20,
            ws_ping_timeout=20,
        )