| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| from abc import ABC, abstractmethod |
| from dataclasses import asdict, dataclass, is_dataclass |
| from enum import Enum |
| import io |
| from typing import Any |
|
|
| import msgpack |
| import numpy as np |
| import zmq |
|
|
|
|
def to_json_serializable(obj: Any) -> Any:
    """
    Recursively convert an object into a JSON-serializable structure.

    Conversion rules, applied in this order:

    - ``Enum`` members (including mixed-in enums such as ``class E(str, Enum)``
      or ``IntEnum``) become their member ``name``.
    - Dataclass instances become dicts (via ``dataclasses.asdict``) whose values
      are converted recursively.
    - ``np.ndarray`` becomes a (nested) list via ``tolist()``.
    - NumPy integer / floating / bool scalars become ``int`` / ``float`` / ``bool``.
    - ``dict`` values are converted recursively; keys are kept as they are.
    - ``list``, ``tuple`` and ``set`` become lists of converted items
      (set iteration order is whatever the set yields).
    - ``str``, ``int``, ``float``, ``bool`` and ``None`` are returned unchanged.
    - Anything else falls back to ``str(obj)``.

    Args:
        obj: Object to convert (can be dataclass, numpy array, dict, list, etc.)

    Returns:
        JSON-serializable representation of the object
    """
    # Enum must be tested before the primitive types: a member of a mixed-in enum
    # is also an instance of str/int and would otherwise be returned as the raw
    # member instead of its name, unlike plain Enum members.
    if isinstance(obj, Enum):
        return obj.name
    # is_dataclass() is also true for dataclass *types*; only instances are converted.
    if is_dataclass(obj) and not isinstance(obj, type):
        return to_json_serializable(asdict(obj))
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, dict):
        return {key: to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [to_json_serializable(item) for item in obj]
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    # Last resort for unknown types: use the string representation.
    return str(obj)
|
|
|
|
class MessageType(Enum):
    """Closed set of message kinds; each member's value is its lowercase string tag."""

    # Episode lifecycle markers
    START_OF_EPISODE = "start_of_episode"
    END_OF_EPISODE = "end_of_episode"
    EPISODE_STEP = "episode_step"
    # Payload kinds
    IMAGE = "image"
    TEXT = "text"
|
|
|
|
class ActionRepresentation(Enum):
    """Enumerates the supported action representations (relative, delta, absolute)."""

    RELATIVE = "relative"
    DELTA = "delta"
    ABSOLUTE = "absolute"
|
|
|
|
class ActionType(Enum):
    """Enumerates the two action categories: ``EEF`` and ``NON_EEF``."""

    EEF = "eef"
    NON_EEF = "non_eef"
|
|
|
|
class ActionFormat(Enum):
    """Enumerates the supported action layouts, identified by their string tag."""

    DEFAULT = "default"
    XYZ_ROT6D = "xyz+rot6d"
    XYZ_ROTVEC = "xyz+rotvec"
|
|
|
|
@dataclass
class ActionConfig:
    """Describes a single action entry: its representation, type and format.

    ``state_key`` is optional and defaults to ``None``.
    """

    # Required fields (positional order is part of the constructor signature).
    rep: ActionRepresentation
    type: ActionType
    format: ActionFormat
    # Optional field.
    state_key: str | None = None
|
|
|
|
@dataclass
class ModalityConfig:
    """Configuration for a modality defining how data should be sampled and loaded.

    This class specifies which indices to sample relative to a base index and which
    keys to load for a particular modality (e.g., video, state, action).

    When ``action_configs`` is given it must contain exactly one entry per
    modality key. Entries supplied as plain dicts are converted to
    ``ActionConfig`` instances on construction.
    """

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""
    sin_cos_embedding_keys: list[str] | None = None
    """Optional list of keys to apply sin/cos encoding. If None or empty, use min/max normalization for all keys."""
    mean_std_embedding_keys: list[str] | None = None
    """Optional list of keys to apply mean/std normalization. If None or empty, use min/max normalization for all keys."""
    action_configs: list[ActionConfig] | None = None
    """Optional per-key action configs (one per modality key); dict entries are parsed into ActionConfig."""

    def __post_init__(self):
        """Validate ``action_configs`` and parse dict entries into ``ActionConfig``.

        Dict entries must provide ``"rep"``, ``"type"`` and ``"format"`` as enum
        member *names* (they are looked up with ``EnumClass[name]``) and may
        provide ``"state_key"``. Non-dict entries are kept unchanged.

        Raises:
            AssertionError: If the number of action configs differs from the
                number of modality keys.
            KeyError: If a dict entry lacks a required key or names an unknown
                enum member.
        """
        if self.action_configs is None:
            return

        num_configs = len(self.action_configs)
        num_keys = len(self.modality_keys)
        if num_configs != num_keys:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(
                f"Number of action configs ({num_configs}) must match number of modality keys ({num_keys})"
            )

        parsed_action_configs = []
        for action_config in self.action_configs:
            if isinstance(action_config, dict):
                action_config = ActionConfig(
                    rep=ActionRepresentation[action_config["rep"]],
                    type=ActionType[action_config["type"]],
                    format=ActionFormat[action_config["format"]],
                    state_key=action_config.get("state_key"),
                )
            parsed_action_configs.append(action_config)
        self.action_configs = parsed_action_configs
|
|
|
|
class MsgSerializer:
    """msgpack-based (de)serializer with hooks for ``ModalityConfig`` and numpy arrays.

    Custom objects are packed as tagged dicts:

    - ``ModalityConfig`` -> ``{"__ModalityConfig_class__": True, "as_json": ...}``
    - ``np.ndarray``     -> ``{"__ndarray_class__": True, "as_npy": <.npy bytes>}``
    """

    @staticmethod
    def to_bytes(data: Any) -> bytes:
        """Pack ``data`` with msgpack, using ``encode_custom_classes`` as the ``default`` hook."""
        encoder = MsgSerializer.encode_custom_classes
        return msgpack.packb(data, default=encoder)

    @staticmethod
    def from_bytes(data: bytes) -> Any:
        """Unpack msgpack ``data``, using ``decode_custom_classes`` as the ``object_hook``."""
        decoder = MsgSerializer.decode_custom_classes
        return msgpack.unpackb(data, object_hook=decoder)

    @staticmethod
    def decode_custom_classes(obj):
        """Rebuild a tagged dict into its original object; return anything else as is."""
        if isinstance(obj, dict):
            if "__ModalityConfig_class__" in obj:
                return ModalityConfig(**obj["as_json"])
            if "__ndarray_class__" in obj:
                # Arrays travel in .npy format; pickled payloads are refused.
                npy_stream = io.BytesIO(obj["as_npy"])
                return np.load(npy_stream, allow_pickle=False)
        return obj

    @staticmethod
    def encode_custom_classes(obj):
        """Turn a ``ModalityConfig`` or ``np.ndarray`` into a tagged dict; return anything else as is."""
        if isinstance(obj, ModalityConfig):
            payload = to_json_serializable(obj)
            return {"__ModalityConfig_class__": True, "as_json": payload}
        if isinstance(obj, np.ndarray):
            # Serialize in .npy format without pickle support.
            npy_stream = io.BytesIO()
            np.save(npy_stream, obj, allow_pickle=False)
            return {"__ndarray_class__": True, "as_npy": npy_stream.getvalue()}
        return obj
|
|
|
|
class BasePolicy(ABC):
    """Interface shared by all robotic control policies.

    Concrete policies provide the four abstract hooks below; callers use
    ``get_action()``, which wraps ``_get_action()`` with optional validation.

    Subclasses must implement:
        - check_observation(): Validate observation format
        - check_action(): Validate action format
        - _get_action(): Core action computation logic
        - reset(): Reset policy to initial state
    """

    def __init__(self, *, strict: bool = True):
        # When true, get_action() validates both the observation and the action.
        self.strict = strict

    @abstractmethod
    def check_observation(self, observation: dict[str, Any]) -> None:
        """Validate an observation.

        Args:
            observation: Dictionary containing the current state/observation of the environment

        Raises:
            AssertionError: If the observation is invalid.
        """

    @abstractmethod
    def check_action(self, action: dict[str, Any]) -> None:
        """Validate an action.

        Args:
            action: Dictionary containing the action to be executed

        Raises:
            AssertionError: If the action is invalid.
        """

    @abstractmethod
    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Compute the next action for ``observation`` (policy-specific logic).

        Validation is not done here; it is performed by the public
        ``get_action()`` wrapper.

        Args:
            observation: Dictionary containing the current state/observation
            options: Optional configuration dict for action computation

        Returns:
            Tuple of (action, info):
                - action: Dictionary containing the action to be executed
                - info: Dictionary containing additional metadata (e.g., confidence scores)
        """

    def get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Public entry point: compute the next action, validating when ``strict`` is set.

        In strict mode the observation is checked before, and the produced
        action after, the call to ``_get_action()``.

        Args:
            observation: Dictionary containing the current state/observation
            options: Optional configuration dict for action computation

        Returns:
            Tuple of (action, info):
                - action: Dictionary containing the validated action
                - info: Dictionary containing additional metadata

        Raises:
            AssertionError/ValueError: If observation or action validation fails
        """
        if self.strict:
            self.check_observation(observation)

        action, info = self._get_action(observation, options)

        # ``strict`` is read again on purpose, mirroring the pre-check above.
        if self.strict:
            self.check_action(action)

        return action, info

    @abstractmethod
    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the policy to its initial state.

        Args:
            options: Dictionary containing the options for the reset

        Returns:
            Dictionary containing the info after resetting the policy
        """
|
|
|
|
class PolicyClient(BasePolicy):
    """Policy that forwards every call to a remote policy server over a ZeroMQ REQ socket.

    Requests are dicts of the form ``{"endpoint": ..., "data": ..., "api_token": ...}``
    serialized with ``MsgSerializer``; ``data`` and ``api_token`` are only included
    when applicable.

    Args:
        host: Server host name or address.
        port: Server TCP port.
        timeout_ms: Send/receive timeout in milliseconds applied to the socket.
            Per ZeroMQ semantics ``-1`` waits indefinitely; ``None`` leaves the
            socket defaults untouched.
        api_token: Optional token attached to every request when truthy.
        strict: Whether ``get_action()`` validates observations/actions. The
            client does not implement the checks, so strict mode raises
            ``NotImplementedError`` unless a subclass provides them.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5555,
        timeout_ms: int = 15000,
        api_token: str | None = None,
        strict: bool = False,
    ):
        super().__init__(strict=strict)
        self.context = zmq.Context()
        self.host = host
        self.port = port
        self.timeout_ms = timeout_ms
        self.api_token = api_token
        self._init_socket()

    def _init_socket(self):
        """Initialize or reinitialize the socket with current settings"""
        # Close the socket being replaced; otherwise it leaks and can make
        # ``context.term()`` block on its pending messages.
        stale = getattr(self, "socket", None)
        if stale is not None and not stale.closed:
            stale.close(linger=0)

        sock = self.context.socket(zmq.REQ)
        # Do not hold unsent messages on close, so teardown cannot hang.
        sock.setsockopt(zmq.LINGER, 0)
        if self.timeout_ms is not None:
            # Without these, send()/recv() block forever when the server is
            # unreachable and ping() could never report failure.
            sock.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
            sock.setsockopt(zmq.SNDTIMEO, self.timeout_ms)
        sock.connect(f"tcp://{self.host}:{self.port}")
        self.socket = sock

    def ping(self) -> bool:
        """Return True if the server answers the ``ping`` endpoint, False on a ZeroMQ error.

        Server-reported errors (``RuntimeError``) are not swallowed.
        """
        try:
            self.call_endpoint("ping", requires_input=False)
        except zmq.error.ZMQError:
            # call_endpoint() has already replaced the broken socket.
            return False
        return True

    def kill_server(self):
        """
        Kill the server.
        """
        self.call_endpoint("kill", requires_input=False)

    def call_endpoint(
        self, endpoint: str, data: dict | None = None, requires_input: bool = True
    ) -> Any:
        """
        Call an endpoint on the server.

        Args:
            endpoint: The name of the endpoint.
            data: The input data for the endpoint.
            requires_input: Whether the endpoint requires input data.

        Returns:
            The deserialized server response.

        Raises:
            RuntimeError: If the server replies with ``b"ERROR"`` or with a dict
                containing an ``"error"`` key.
            zmq.error.ZMQError: If sending or receiving fails (e.g. timeout). The
                socket is recreated before the error propagates.
        """
        request: dict = {"endpoint": endpoint}
        if requires_input:
            request["data"] = data
        if self.api_token:
            request["api_token"] = self.api_token

        # Serialize outside the try block: an encoding failure leaves the socket intact.
        payload = MsgSerializer.to_bytes(request)
        try:
            self.socket.send(payload)
            message = self.socket.recv()
        except zmq.error.ZMQError:
            # A REQ socket whose send/recv cycle was interrupted cannot be
            # reused; replace it so that later calls can succeed.
            self._init_socket()
            raise

        if message == b"ERROR":
            raise RuntimeError("Server error. Make sure we are running the correct policy server.")
        response = MsgSerializer.from_bytes(message)

        if isinstance(response, dict) and "error" in response:
            raise RuntimeError(f"Server error: {response['error']}")
        return response

    def close(self):
        """Close the socket and terminate the context; safe to call more than once."""
        # getattr guards cover a partially constructed instance (e.g. __init__ raised).
        sock = getattr(self, "socket", None)
        if sock is not None and not sock.closed:
            sock.close(linger=0)
        ctx = getattr(self, "context", None)
        if ctx is not None and not ctx.closed:
            ctx.term()

    def __del__(self):
        """Cleanup resources on destruction"""
        self.close()

    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Request an action from the server's ``get_action`` endpoint and return it as a tuple."""
        response = self.call_endpoint(
            "get_action", {"observation": observation, "options": options}
        )
        return tuple(response)

    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Call the server's ``reset`` endpoint and return its response."""
        return self.call_endpoint("reset", {"options": options})

    def get_modality_config(self) -> dict[str, ModalityConfig]:
        """Call the server's ``get_modality_config`` endpoint and return its response."""
        return self.call_endpoint("get_modality_config", requires_input=False)

    def check_observation(self, observation: dict[str, Any]) -> None:
        """Not implemented for the client; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "check_observation is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
        )

    def check_action(self, action: dict[str, Any]) -> None:
        """Not implemented for the client; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "check_action is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
        )
|
|