unity_env / client.py
Crashbandicoote2's picture
Upload folder using huggingface_hub
0f53490 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Unity ML-Agents Environment Client.
This module provides the client for connecting to a Unity ML-Agents
Environment server via WebSocket for persistent sessions.
"""
from typing import Any, Dict, List, Optional
# Support multiple import scenarios
try:
# In-repo imports (when running from OpenEnv repository root)
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import UnityAction, UnityObservation, UnityState
except ImportError:
# openenv from pip
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
try:
# Direct execution from envs/unity_env/ directory
from models import UnityAction, UnityObservation, UnityState
except ImportError:
try:
# Package installed as unity_env
from unity_env.models import UnityAction, UnityObservation, UnityState
except ImportError:
# Running from OpenEnv root with envs prefix
from envs.unity_env.models import UnityAction, UnityObservation, UnityState
class UnityEnv(EnvClient[UnityAction, UnityObservation, UnityState]):
    """
    Client for Unity ML-Agents environments.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Note: Unity environments can take 30-60+ seconds to initialize on first reset
    (downloading binaries, starting Unity process). The client is therefore
    configured with a longer message timeout (``message_timeout_s``) and longer
    WebSocket keepalive timeouts (``_PING_TIMEOUT_S``, ``_PING_INTERVAL_S``,
    ``_CLOSE_TIMEOUT_S``) to handle this.

    Supported Unity Environments:
        - PushBlock: Push a block to a goal (discrete actions: 7)
        - 3DBall: Balance a ball on a platform (continuous actions: 2)
        - 3DBallHard: Harder version of 3DBall
        - GridWorld: Navigate a grid to find goals
        - Basic: Simple movement task
        - And more from the ML-Agents registry

    Example:
        >>> # Connect to a running server
        >>> with UnityEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(f"Vector obs: {len(result.observation.vector_observations)} dims")
        ...
        ...     # Take action (PushBlock: 1=forward)
        ...     result = client.step(UnityAction(discrete_actions=[1]))
        ...     print(f"Reward: {result.reward}")

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = UnityEnv.from_docker_image("unity-env:latest")
        >>> try:
        ...     result = client.reset(env_id="3DBall")
        ...     result = client.step(UnityAction(continuous_actions=[0.5, -0.3]))
        ... finally:
        ...     client.close()

    Example switching environments:
        >>> client = UnityEnv(base_url="http://localhost:8000")
        >>> # Start with PushBlock
        >>> result = client.reset(env_id="PushBlock")
        >>> # ... train on PushBlock ...
        >>> # Switch to 3DBall
        >>> result = client.reset(env_id="3DBall")
        >>> # ... train on 3DBall ...
    """

    # WebSocket keepalive settings used by ``connect``. They are class
    # attributes so a subclass (or a single instance) can tune them without
    # re-implementing ``connect``. Values are in seconds.
    _PING_TIMEOUT_S: float = 120  # 2 minutes for slow Unity initialization
    _PING_INTERVAL_S: float = 30  # send pings every 30 seconds
    _CLOSE_TIMEOUT_S: float = 30  # passed to ws_connect as close_timeout

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 180.0,  # 3 minutes for slow Unity initialization
        provider: Optional[Any] = None,
    ):
        """
        Initialize the Unity environment client.

        Uses a longer default message timeout than the base EnvClient because
        Unity environments can take 30-60+ seconds to initialize on first reset.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing the WebSocket connection.
            message_timeout_s: Timeout for receiving responses (default 3 min for Unity).
            provider: Optional container/runtime provider for lifecycle management.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
        )

    def connect(self) -> "UnityEnv":
        """
        Establish the WebSocket connection to the server.

        Overrides the default connection to use longer keepalive timeouts,
        since Unity environments can take 30-60+ seconds to initialize.
        Calling this on a client that already holds a connection is a no-op.

        Returns:
            self, for method chaining.

        Raises:
            ConnectionError: If the connection cannot be established.
        """
        from websockets.sync.client import connect as ws_connect

        if self._ws is not None:
            return self

        try:
            # Environment initialization can block the server for a while,
            # so use a long ping timeout (_PING_TIMEOUT_S, 120s by default)
            # to avoid the connection being treated as dead in the meantime.
            self._ws = ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                ping_timeout=self._PING_TIMEOUT_S,
                ping_interval=self._PING_INTERVAL_S,
                close_timeout=self._CLOSE_TIMEOUT_S,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        return self

    def _step_payload(self, action: UnityAction) -> Dict:
        """
        Convert a UnityAction into the JSON payload for a step request.

        Action fields that are None are omitted, and metadata is only included
        when it is non-empty.

        Args:
            action: UnityAction instance.

        Returns:
            Dictionary representation suitable for JSON encoding.
        """
        payload: Dict[str, Any] = {}
        if action.discrete_actions is not None:
            payload["discrete_actions"] = action.discrete_actions
        if action.continuous_actions is not None:
            payload["continuous_actions"] = action.continuous_actions
        if action.metadata:
            payload["metadata"] = action.metadata
        return payload

    def _parse_result(self, payload: Dict) -> StepResult[UnityObservation]:
        """
        Parse a server response into a StepResult[UnityObservation].

        Missing observation fields fall back to empty defaults. The top-level
        ``reward`` and ``done`` values are copied onto both the observation and
        the returned StepResult.

        Args:
            payload: JSON response from the server.

        Returns:
            StepResult wrapping a UnityObservation.
        """
        # ``or {}`` also tolerates an explicit ``"observation": null``, which
        # would otherwise fail on the ``.get`` calls below.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = UnityObservation(
            vector_observations=obs_data.get("vector_observations", []),
            visual_observations=obs_data.get("visual_observations"),
            behavior_name=obs_data.get("behavior_name", ""),
            action_spec_info=obs_data.get("action_spec_info", {}),
            observation_spec_info=obs_data.get("observation_spec_info", {}),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> UnityState:
        """
        Parse a server response into a UnityState object.

        Args:
            payload: JSON response from the /state endpoint.

        Returns:
            UnityState with environment information; missing fields fall back
            to empty/zero defaults (``episode_id`` falls back to None).
        """
        return UnityState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            env_id=payload.get("env_id", ""),
            behavior_name=payload.get("behavior_name", ""),
            action_spec=payload.get("action_spec", {}),
            observation_spec=payload.get("observation_spec", {}),
            available_envs=payload.get("available_envs", []),
        )

    def reset(
        self,
        env_id: Optional[str] = None,
        include_visual: bool = False,
        **kwargs,
    ) -> StepResult[UnityObservation]:
        """
        Reset the environment.

        Args:
            env_id: Optionally switch to a different Unity environment.
                Available: PushBlock, 3DBall, 3DBallHard, GridWorld, Basic.
                Omitted from the request when None.
            include_visual: If True, include visual observations in the
                response. Always sent to the server.
            **kwargs: Additional arguments passed to the server.

        Returns:
            StepResult with the initial observation.
        """
        reset_kwargs = dict(kwargs)
        if env_id is not None:
            reset_kwargs["env_id"] = env_id
        reset_kwargs["include_visual"] = include_visual
        return super().reset(**reset_kwargs)

    @staticmethod
    def available_environments() -> List[str]:
        """
        List commonly available Unity environments.

        Note: The actual list may vary based on the ML-Agents registry version.
        Use ``state.available_envs`` after connecting for the authoritative list.

        Returns:
            List of environment identifiers (a new list on every call).
        """
        return [
            "PushBlock",
            "3DBall",
            "3DBallHard",
            "GridWorld",
            "Basic",
            "VisualPushBlock",
        ]