# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, is_dataclass from enum import Enum import io from typing import Any import msgpack import numpy as np import zmq def to_json_serializable(obj: Any) -> Any: """ Recursively convert dataclasses and numpy arrays to JSON-serializable format. Args: obj: Object to convert (can be dataclass, numpy array, dict, list, etc.) Returns: JSON-serializable representation of the object """ if is_dataclass(obj) and not isinstance(obj, type): # Convert dataclass to dict, then recursively process the dict return to_json_serializable(asdict(obj)) elif isinstance(obj, np.ndarray): # Convert numpy array to list return obj.tolist() elif isinstance(obj, np.integer): # Convert numpy integers to Python int return int(obj) elif isinstance(obj, np.floating): # Convert numpy floats to Python float return float(obj) elif isinstance(obj, np.bool_): # Convert numpy bool to Python bool return bool(obj) elif isinstance(obj, dict): # Recursively process dictionary values return {key: to_json_serializable(value) for key, value in obj.items()} elif isinstance(obj, (list, tuple)): # Recursively process list/tuple elements return [to_json_serializable(item) for item in obj] elif isinstance(obj, set): # Convert set to list return [to_json_serializable(item) for item in obj] elif isinstance(obj, (str, int, float, bool, type(None))): # Already JSON-serializable return obj elif isinstance(obj, Enum): return obj.name else: # For other types, try to convert to string as fallback # You might want to handle specific types differently return str(obj) class MessageType(Enum): START_OF_EPISODE = "start_of_episode" END_OF_EPISODE = "end_of_episode" EPISODE_STEP = "episode_step" IMAGE = "image" TEXT = "text" class ActionRepresentation(Enum): RELATIVE = "relative" DELTA = "delta" ABSOLUTE = "absolute" class ActionType(Enum): EEF = "eef" NON_EEF = "non_eef" class ActionFormat(Enum): DEFAULT = "default" XYZ_ROT6D = "xyz+rot6d" XYZ_ROTVEC = "xyz+rotvec" @dataclass class ActionConfig: rep: ActionRepresentation type: ActionType format: ActionFormat state_key: str | None = None @dataclass class ModalityConfig: """Configuration for a modality defining how data should be sampled and loaded. This class specifies which indices to sample relative to a base index and which keys to load for a particular modality (e.g., video, state, action). """ delta_indices: list[int] """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices.""" modality_keys: list[str] """The keys to load for the modality in the dataset.""" sin_cos_embedding_keys: list[str] | None = None """Optional list of keys to apply sin/cos encoding. If None or empty, use min/max normalization for all keys.""" mean_std_embedding_keys: list[str] | None = None """Optional list of keys to apply mean/std normalization. If None or empty, use min/max normalization for all keys.""" action_configs: list[ActionConfig] | None = None def __post_init__(self): """Set default values for action-related fields if not specified.""" if self.action_configs is not None: assert len(self.action_configs) == len(self.modality_keys), ( f"Number of action configs ({len(self.action_configs)}) must match number of modality keys ({len(self.modality_keys)})" ) parsed_action_configs = [] for action_config in self.action_configs: if isinstance(action_config, dict): action_config = ActionConfig( rep=ActionRepresentation[action_config["rep"]], type=ActionType[action_config["type"]], format=ActionFormat[action_config["format"]], state_key=action_config.get("state_key", None), ) parsed_action_configs.append(action_config) self.action_configs = parsed_action_configs class MsgSerializer: @staticmethod def to_bytes(data: Any) -> bytes: return msgpack.packb(data, default=MsgSerializer.encode_custom_classes) @staticmethod def from_bytes(data: bytes) -> Any: return msgpack.unpackb(data, object_hook=MsgSerializer.decode_custom_classes) @staticmethod def decode_custom_classes(obj): if not isinstance(obj, dict): return obj if "__ModalityConfig_class__" in obj: return ModalityConfig(**obj["as_json"]) if "__ndarray_class__" in obj: return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False) return obj @staticmethod def encode_custom_classes(obj): if isinstance(obj, ModalityConfig): # Convert to dict and let msgpack recursively handle nested objects return {"__ModalityConfig_class__": True, "as_json": to_json_serializable(obj)} if isinstance(obj, np.ndarray): output = io.BytesIO() np.save(output, obj, allow_pickle=False) return {"__ndarray_class__": True, "as_npy": output.getvalue()} return obj class BasePolicy(ABC): """Abstract base class for robotic control policies. This class defines the interface that all policies must implement, including methods for action computation, input/output validation, and state management. Subclasses must implement: - check_observation(): Validate observation format - check_action(): Validate action format - _get_action(): Core action computation logic - reset(): Reset policy to initial state """ def __init__(self, *, strict: bool = True): self.strict = strict @abstractmethod def check_observation(self, observation: dict[str, Any]) -> None: """Check if the observation is valid. Args: observation: Dictionary containing the current state/observation of the environment Raises: AssertionError: If the observation is invalid. """ pass @abstractmethod def check_action(self, action: dict[str, Any]) -> None: """Check if the action is valid. Args: action: Dictionary containing the action to be executed Raises: AssertionError: If the action is invalid. """ pass @abstractmethod def _get_action( self, observation: dict[str, Any], options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: """Compute and return the next action based on current observation. This method should be overridden by subclasses to implement policy-specific action computation. Input validation is handled by the public get_action() method. Args: observation: Dictionary containing the current state/observation options: Optional configuration dict for action computation Returns: Tuple of (action, info): - action: Dictionary containing the action to be executed - info: Dictionary containing additional metadata (e.g., confidence scores) """ pass def get_action( self, observation: dict[str, Any], options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: """Compute and return the next action based on current observation with validation. This is the main public interface. It validates the observation, calls the internal _get_action(), and validates the resulting action. Args: observation: Dictionary containing the current state/observation options: Optional configuration dict for action computation Returns: Tuple of (action, info): - action: Dictionary containing the validated action - info: Dictionary containing additional metadata Raises: AssertionError/ValueError: If observation or action validation fails """ if self.strict: self.check_observation(observation) action, info = self._get_action(observation, options) if self.strict: self.check_action(action) return action, info @abstractmethod def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: """Reset the policy to its initial state. Args: options: Dictionary containing the options for the reset Returns: Dictionary containing the info after resetting the policy """ pass class PolicyClient(BasePolicy): def __init__( self, host: str = "localhost", port: int = 5555, timeout_ms: int = 15000, api_token: str = None, strict: bool = False, ): super().__init__(strict=strict) self.context = zmq.Context() self.host = host self.port = port self.timeout_ms = timeout_ms self.api_token = api_token self._init_socket() def _init_socket(self): """Initialize or reinitialize the socket with current settings""" self.socket = self.context.socket(zmq.REQ) self.socket.connect(f"tcp://{self.host}:{self.port}") def ping(self) -> bool: try: self.call_endpoint("ping", requires_input=False) return True except zmq.error.ZMQError: self._init_socket() # Recreate socket for next attempt return False def kill_server(self): """ Kill the server. """ self.call_endpoint("kill", requires_input=False) def call_endpoint( self, endpoint: str, data: dict | None = None, requires_input: bool = True ) -> Any: """ Call an endpoint on the server. Args: endpoint: The name of the endpoint. data: The input data for the endpoint. requires_input: Whether the endpoint requires input data. """ request: dict = {"endpoint": endpoint} if requires_input: request["data"] = data if self.api_token: request["api_token"] = self.api_token self.socket.send(MsgSerializer.to_bytes(request)) message = self.socket.recv() if message == b"ERROR": raise RuntimeError("Server error. Make sure we are running the correct policy server.") response = MsgSerializer.from_bytes(message) if isinstance(response, dict) and "error" in response: raise RuntimeError(f"Server error: {response['error']}") return response def __del__(self): """Cleanup resources on destruction""" self.socket.close() self.context.term() def _get_action( self, observation: dict[str, Any], options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: response = self.call_endpoint( "get_action", {"observation": observation, "options": options} ) return tuple(response) # Convert list (from msgpack) to tuple of (action, info) def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: return self.call_endpoint("reset", {"options": options}) def get_modality_config(self) -> dict[str, ModalityConfig]: return self.call_endpoint("get_modality_config", requires_input=False) def check_observation(self, observation: dict[str, Any]) -> None: raise NotImplementedError( "check_observation is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass." ) def check_action(self, action: dict[str, Any]) -> None: raise NotImplementedError( "check_action is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass." )