GR00T / examples /DROID /server_client.py
yqi19's picture
add: source files (batch 1)
b60a439 verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import io
from typing import Any
import msgpack
import numpy as np
import zmq
def to_json_serializable(obj: Any) -> Any:
    """Recursively convert ``obj`` into a structure built from JSON-friendly types.

    Conversion rules, applied in this order:
        * dataclass instance -> ``dict`` via ``dataclasses.asdict``, then converted again
        * ``Enum`` member -> the member's ``name``
        * ``np.ndarray`` -> nested ``list`` (``tolist()``)
        * numpy integer / floating / bool scalar -> Python ``int`` / ``float`` / ``bool``
        * ``dict`` -> ``dict`` with converted values (keys are passed through unchanged)
        * ``list`` / ``tuple`` / ``set`` -> ``list`` of converted items
        * ``str`` / ``int`` / ``float`` / ``bool`` / ``None`` -> returned unchanged
        * anything else -> ``str(obj)``

    Args:
        obj: Object to convert (dataclass, numpy value, enum, container, scalar, ...).

    Returns:
        The converted representation of ``obj``.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        # asdict() already flattens nested dataclasses; the second pass converts
        # the remaining leaves (numpy values, enums, ...).
        return to_json_serializable(asdict(obj))
    if isinstance(obj, Enum):
        # Must run before the primitive check below: members of mixin enums such
        # as ``class E(int, Enum)`` are also ``int``/``str`` instances and would
        # otherwise be returned raw instead of by name like every other enum.
        return obj.name
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, dict):
        return {key: to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        # All sequence-like containers become plain lists (set order is whatever
        # iteration yields).
        return [to_json_serializable(item) for item in obj]
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    # Fallback for unknown types: use their string form.
    return str(obj)
class MessageType(Enum):
    """String-valued tags naming the kinds of messages.

    NOTE(review): not referenced elsewhere in the visible part of this file;
    the consumers of these tags are presumably defined in other modules.
    """

    # Episode lifecycle markers.
    START_OF_EPISODE = "start_of_episode"
    END_OF_EPISODE = "end_of_episode"
    EPISODE_STEP = "episode_step"
    # Payload kinds.
    IMAGE = "image"
    TEXT = "text"
class ActionRepresentation(Enum):
    """Representation in which an action is expressed.

    ``ModalityConfig.__post_init__`` looks members up by *name*
    (e.g. ``"RELATIVE"``) when it parses dict-form action configs.
    """

    RELATIVE = "relative"
    DELTA = "delta"
    ABSOLUTE = "absolute"
class ActionType(Enum):
    """Whether an action is an ``eef`` action or not.

    NOTE(review): "eef" presumably stands for end-effector; confirm.
    """

    EEF = "eef"
    NON_EEF = "non_eef"
class ActionFormat(Enum):
    """Layout tag for an action vector.

    NOTE(review): the ``xyz+rot6d`` / ``xyz+rotvec`` values presumably denote a
    position followed by a 6D rotation or a rotation vector; confirm against
    the code that consumes them.
    """

    DEFAULT = "default"
    XYZ_ROT6D = "xyz+rot6d"
    XYZ_ROTVEC = "xyz+rotvec"
@dataclass
class ActionConfig:
    """Per-key description of how an action is encoded.

    ``ModalityConfig`` holds one instance per modality key and can build
    instances from plain dicts carrying the same field names.
    """

    # Representation of the action (relative / delta / absolute).
    rep: ActionRepresentation
    # Action category (eef / non_eef).
    type: ActionType
    # Layout tag of the action vector.
    format: ActionFormat
    # Optional key associated with this action.
    # NOTE(review): presumably names the state entry the action refers to; confirm.
    state_key: str | None = None
@dataclass
class ModalityConfig:
    """Sampling/loading configuration for one modality (e.g., video, state, action).

    It records which indices to sample relative to a base index and which
    dataset keys to load for the modality.
    """

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""
    sin_cos_embedding_keys: list[str] | None = None
    """Optional list of keys to apply sin/cos encoding. If None or empty, use min/max normalization for all keys."""
    mean_std_embedding_keys: list[str] | None = None
    """Optional list of keys to apply mean/std normalization. If None or empty, use min/max normalization for all keys."""
    action_configs: list[ActionConfig] | None = None
    """Optional per-key action configs; when given, there must be exactly one per modality key."""

    def __post_init__(self):
        """Validate ``action_configs`` and turn dict entries into ``ActionConfig`` objects.

        Does nothing when ``action_configs`` is None. Otherwise its length must
        equal ``len(modality_keys)``. Dict entries are converted by looking up
        the ``rep`` / ``type`` / ``format`` enum members by *name*; any other
        entry is kept as-is. The attribute is always replaced by a new list.

        Raises:
            AssertionError: If the number of action configs differs from the
                number of modality keys.
        """
        if self.action_configs is None:
            return

        assert len(self.action_configs) == len(self.modality_keys), (
            f"Number of action configs ({len(self.action_configs)}) must match number of modality keys ({len(self.modality_keys)})"
        )
        self.action_configs = [
            ActionConfig(
                rep=ActionRepresentation[entry["rep"]],
                type=ActionType[entry["type"]],
                format=ActionFormat[entry["format"]],
                state_key=entry.get("state_key"),
            )
            if isinstance(entry, dict)
            else entry
            for entry in self.action_configs
        ]
class MsgSerializer:
    """msgpack (de)serialization with hooks for ``ModalityConfig`` and numpy data.

    Custom objects travel as tagged dicts:
        * ``{"__ndarray_class__": True, "as_npy": <bytes in .npy format>}``
        * ``{"__ModalityConfig_class__": True, "as_json": <dict of fields>}``
    """

    @staticmethod
    def to_bytes(data: Any) -> bytes:
        """Pack ``data`` with msgpack, encoding unsupported objects via ``encode_custom_classes``."""
        return msgpack.packb(data, default=MsgSerializer.encode_custom_classes)

    @staticmethod
    def from_bytes(data: bytes) -> Any:
        """Unpack msgpack ``data``, rebuilding tagged dicts via ``decode_custom_classes``."""
        return msgpack.unpackb(data, object_hook=MsgSerializer.decode_custom_classes)

    @staticmethod
    def decode_custom_classes(obj):
        """Rebuild a custom object from its tagged dict form.

        Args:
            obj: A value produced while unpacking.

        Returns:
            A ``ModalityConfig`` or ``np.ndarray`` when ``obj`` is a dict carrying
            the matching tag key; otherwise ``obj`` unchanged.
        """
        if not isinstance(obj, dict):
            return obj
        if "__ModalityConfig_class__" in obj:
            return ModalityConfig(**obj["as_json"])
        if "__ndarray_class__" in obj:
            # allow_pickle=False: never unpickle data received over the wire.
            return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False)
        return obj

    @staticmethod
    def encode_custom_classes(obj):
        """Turn an object msgpack cannot pack natively into a packable form.

        Args:
            obj: The object msgpack handed to its ``default`` hook.

        Returns:
            * ``np.ndarray`` -> tagged dict holding the array in ``.npy`` format
            * numpy scalar (``np.generic``) -> the Python scalar from ``item()``
            * ``ModalityConfig`` -> tagged dict holding its JSON-style field dict
            * anything else -> ``obj`` unchanged
        """
        if isinstance(obj, np.ndarray):
            buffer = io.BytesIO()
            np.save(buffer, obj, allow_pickle=False)
            return {"__ndarray_class__": True, "as_npy": buffer.getvalue()}
        if isinstance(obj, np.generic):
            # numpy scalars (np.int64, np.float32, np.bool_, ...) are not packable
            # as-is; send the equivalent Python scalar, as to_json_serializable does.
            return obj.item()
        if isinstance(obj, ModalityConfig):
            # Enums inside the config are emitted by name, which is what
            # ModalityConfig.__post_init__ expects when decoding.
            return {"__ModalityConfig_class__": True, "as_json": to_json_serializable(obj)}
        return obj
class BasePolicy(ABC):
    """Interface shared by all robotic control policies.

    Concrete policies provide:
        - check_observation(): validate an observation
        - check_action(): validate an action
        - _get_action(): the actual action computation
        - reset(): return the policy to its initial state

    The public ``get_action()`` wraps ``_get_action()`` with the two checks
    whenever ``strict`` is true.
    """

    def __init__(self, *, strict: bool = True):
        # When true, get_action() validates observations and actions.
        self.strict = strict

    @abstractmethod
    def check_observation(self, observation: dict[str, Any]) -> None:
        """Validate an observation.

        Args:
            observation: Dictionary containing the current state/observation of the environment

        Raises:
            AssertionError: If the observation is invalid.
        """

    @abstractmethod
    def check_action(self, action: dict[str, Any]) -> None:
        """Validate an action.

        Args:
            action: Dictionary containing the action to be executed

        Raises:
            AssertionError: If the action is invalid.
        """

    @abstractmethod
    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Compute the next action for ``observation``.

        Subclasses implement the policy-specific computation here; validation
        is performed by the public get_action() method.

        Args:
            observation: Dictionary containing the current state/observation
            options: Optional configuration dict for action computation

        Returns:
            Tuple of (action, info):
                - action: Dictionary containing the action to be executed
                - info: Dictionary containing additional metadata (e.g., confidence scores)
        """

    def get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Return the next action for ``observation``, validating when strict.

        In strict mode the observation is checked first, then _get_action()
        runs, then the produced action is checked. Outside strict mode only
        _get_action() runs.

        Args:
            observation: Dictionary containing the current state/observation
            options: Optional configuration dict for action computation

        Returns:
            Tuple of (action, info):
                - action: Dictionary containing the validated action
                - info: Dictionary containing additional metadata

        Raises:
            AssertionError/ValueError: If observation or action validation fails
        """
        if self.strict:
            self.check_observation(observation)

        result = self._get_action(observation, options)
        # Unpack so the caller always receives a 2-tuple, whatever sequence type
        # the subclass returned.
        action, info = result

        if self.strict:
            self.check_action(action)
        return action, info

    @abstractmethod
    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the policy to its initial state.

        Args:
            options: Dictionary containing the options for the reset

        Returns:
            Dictionary containing the info after resetting the policy
        """
class PolicyClient(BasePolicy):
    """Policy that forwards every call to a remote policy server over a ZeroMQ REQ socket.

    Requests are msgpack-encoded dicts of the form
    ``{"endpoint": <name>, "data": <payload>, "api_token": <token>}``, where
    ``data`` and ``api_token`` are included only when applicable.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5555,
        timeout_ms: int = 15000,
        api_token: str | None = None,
        strict: bool = False,
    ):
        """Create the ZeroMQ context and connect to the server.

        Args:
            host: Server host name or address.
            port: Server TCP port.
            timeout_ms: Send/receive timeout in milliseconds applied to the socket.
                ``None`` leaves the socket blocking.
            api_token: Optional token attached to every request.
            strict: Whether get_action() validates observations/actions. Defaults
                to False because this client does not implement the checks.
        """
        super().__init__(strict=strict)
        self.context = zmq.Context()
        self.host = host
        self.port = port
        self.timeout_ms = timeout_ms
        self.api_token = api_token
        self._init_socket()

    def _init_socket(self):
        """Initialize or reinitialize the socket with current settings"""
        # Close a previous socket first so re-initialization does not leak it.
        stale = getattr(self, "socket", None)
        if stale is not None:
            stale.close(linger=0)

        sock = self.context.socket(zmq.REQ)
        if self.timeout_ms is not None:
            # Without these options send/recv block forever when the server is
            # unreachable, and the ZMQError handling in ping() can never trigger.
            sock.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
            sock.setsockopt(zmq.SNDTIMEO, self.timeout_ms)
        sock.connect(f"tcp://{self.host}:{self.port}")
        self.socket = sock

    def ping(self) -> bool:
        """Return True if the server answers the ``ping`` endpoint, False on a ZeroMQ error.

        On failure call_endpoint() has already recreated the socket, so the next
        attempt starts from a clean state.
        """
        try:
            self.call_endpoint("ping", requires_input=False)
        except zmq.error.ZMQError:
            return False
        return True

    def kill_server(self):
        """
        Kill the server.
        """
        self.call_endpoint("kill", requires_input=False)

    def call_endpoint(
        self, endpoint: str, data: dict | None = None, requires_input: bool = True
    ) -> Any:
        """
        Call an endpoint on the server.

        Args:
            endpoint: The name of the endpoint.
            data: The input data for the endpoint.
            requires_input: Whether the endpoint requires input data.

        Returns:
            The decoded server response.

        Raises:
            RuntimeError: If the server replies with the raw ``b"ERROR"`` marker
                or with a dict containing an ``"error"`` key.
            zmq.error.ZMQError: If sending or receiving fails (including
                timeouts); the socket is recreated before re-raising.
        """
        request: dict = {"endpoint": endpoint}
        if requires_input:
            request["data"] = data
        if self.api_token:
            request["api_token"] = self.api_token

        payload = MsgSerializer.to_bytes(request)
        try:
            self.socket.send(payload)
            message = self.socket.recv()
        except zmq.error.ZMQError:
            # A REQ socket whose request/reply cycle was interrupted rejects
            # further sends, so replace it before reporting the failure.
            self._init_socket()
            raise

        if message == b"ERROR":
            raise RuntimeError("Server error. Make sure we are running the correct policy server.")
        response = MsgSerializer.from_bytes(message)
        if isinstance(response, dict) and "error" in response:
            raise RuntimeError(f"Server error: {response['error']}")
        return response

    def __del__(self):
        """Cleanup resources on destruction"""
        # Attributes may be missing if __init__ failed part-way through.
        sock = getattr(self, "socket", None)
        if sock is not None:
            # linger=0 drops unsent messages so context.term() cannot block on them.
            sock.close(linger=0)
        context = getattr(self, "context", None)
        if context is not None:
            context.term()

    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Ask the server's ``get_action`` endpoint for the next action."""
        response = self.call_endpoint(
            "get_action", {"observation": observation, "options": options}
        )
        return tuple(response)  # Convert list (from msgpack) to tuple of (action, info)

    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Call the server's ``reset`` endpoint and return its response."""
        return self.call_endpoint("reset", {"options": options})

    def get_modality_config(self) -> dict[str, ModalityConfig]:
        """Return the response of the server's ``get_modality_config`` endpoint."""
        return self.call_endpoint("get_modality_config", requires_input=False)

    def check_observation(self, observation: dict[str, Any]) -> None:
        """Not implemented for the client; always raises NotImplementedError."""
        raise NotImplementedError(
            "check_observation is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
        )

    def check_action(self, action: dict[str, Any]) -> None:
        """Not implemented for the client; always raises NotImplementedError."""
        raise NotImplementedError(
            "check_action is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
        )