Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Unity ML-Agents Environment Client. | |
| This module provides the client for connecting to a Unity ML-Agents | |
| Environment server via WebSocket for persistent sessions. | |
| """ | |
| from typing import Any, Dict, List, Optional | |
| # Support multiple import scenarios | |
| try: | |
| # In-repo imports (when running from OpenEnv repository root) | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from .models import UnityAction, UnityObservation, UnityState | |
| except ImportError: | |
| # openenv from pip | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| try: | |
| # Direct execution from envs/unity_env/ directory | |
| from models import UnityAction, UnityObservation, UnityState | |
| except ImportError: | |
| try: | |
| # Package installed as unity_env | |
| from unity_env.models import UnityAction, UnityObservation, UnityState | |
| except ImportError: | |
| # Running from OpenEnv root with envs prefix | |
| from envs.unity_env.models import UnityAction, UnityObservation, UnityState | |
class UnityEnv(EnvClient[UnityAction, UnityObservation, UnityState]):
    """
    Client for Unity ML-Agents environments.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Note: Unity environments can take 30-60+ seconds to initialize on first reset
    (downloading binaries, starting Unity process). The client is configured with
    longer ping timeouts to handle this.

    Supported Unity Environments:
        - PushBlock: Push a block to a goal (discrete actions: 7)
        - 3DBall: Balance a ball on a platform (continuous actions: 2)
        - 3DBallHard: Harder version of 3DBall
        - GridWorld: Navigate a grid to find goals
        - Basic: Simple movement task
        - And more from the ML-Agents registry

    Example:
        >>> # Connect to a running server
        >>> with UnityEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(f"Vector obs: {len(result.observation.vector_observations)} dims")
        ...
        ...     # Take action (PushBlock: 1=forward)
        ...     result = client.step(UnityAction(discrete_actions=[1]))
        ...     print(f"Reward: {result.reward}")

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = UnityEnv.from_docker_image("unity-env:latest")
        >>> try:
        ...     result = client.reset(env_id="3DBall")
        ...     result = client.step(UnityAction(continuous_actions=[0.5, -0.3]))
        ... finally:
        ...     client.close()

    Example switching environments:
        >>> client = UnityEnv(base_url="http://localhost:8000")
        >>> # Start with PushBlock
        >>> result = client.reset(env_id="PushBlock")
        >>> # ... train on PushBlock ...
        >>> # Switch to 3DBall
        >>> result = client.reset(env_id="3DBall")
        >>> # ... train on 3DBall ...
    """

    # WebSocket keepalive settings (seconds) passed to the websockets client in
    # connect(). They are deliberately generous because Unity initialization
    # can keep the server busy for a long time before it answers.
    _UNITY_PING_INTERVAL_S = 30
    _UNITY_PING_TIMEOUT_S = 120
    _UNITY_CLOSE_TIMEOUT_S = 30

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 180.0,  # 3 minutes for slow Unity initialization
        provider: Optional[Any] = None,
    ):
        """
        Initialize Unity environment client.

        Uses a long default message timeout (180 s) because Unity
        environments can take 30-60+ seconds to initialize on first reset.
        All arguments are forwarded unchanged to the base client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing WebSocket connection.
            message_timeout_s: Timeout for receiving responses (default 3 min for Unity).
            provider: Optional container/runtime provider for lifecycle management.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
        )

    def connect(self) -> "UnityEnv":
        """
        Establish WebSocket connection to the server.

        Overrides the default connection to use longer ping timeouts,
        since Unity environments can take 30-60+ seconds to initialize.
        Calling this while a connection is already open is a no-op.

        Returns:
            self for method chaining

        Raises:
            ConnectionError: If connection cannot be established
        """
        if self._ws is not None:
            # Already connected; reuse the existing socket.
            return self

        # Imported lazily so the module can be imported without websockets.
        from websockets.sync.client import connect as ws_connect

        try:
            # Pings go out every _UNITY_PING_INTERVAL_S (30 s) and the peer gets
            # _UNITY_PING_TIMEOUT_S (120 s) to answer, since environment
            # initialization can block the server for a while.
            self._ws = ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                ping_timeout=self._UNITY_PING_TIMEOUT_S,
                ping_interval=self._UNITY_PING_INTERVAL_S,
                close_timeout=self._UNITY_CLOSE_TIMEOUT_S,
            )
        except Exception as e:
            # Any failure while opening the socket is reported uniformly as a
            # ConnectionError, with the original exception chained as the cause.
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        return self

    def _step_payload(self, action: UnityAction) -> Dict[str, Any]:
        """
        Convert UnityAction to JSON payload for step request.

        Only populated fields are sent: ``discrete_actions`` and
        ``continuous_actions`` are included when they are not None, and
        ``metadata`` is included only when it is non-empty.

        Args:
            action: UnityAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        payload: Dict[str, Any] = {}
        if action.discrete_actions is not None:
            payload["discrete_actions"] = action.discrete_actions
        if action.continuous_actions is not None:
            payload["continuous_actions"] = action.continuous_actions
        if action.metadata:
            payload["metadata"] = action.metadata
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[UnityObservation]:
        """
        Parse server response into StepResult[UnityObservation].

        ``reward`` and ``done`` are read from the top level of the payload and
        copied onto both the observation and the StepResult. Missing
        observation fields fall back to empty defaults.

        Args:
            payload: JSON response from server

        Returns:
            StepResult with UnityObservation
        """
        # `or {}` also covers an explicit `"observation": null`, which a plain
        # .get(..., {}) default would pass through as None and crash on below.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = UnityObservation(
            vector_observations=obs_data.get("vector_observations", []),
            visual_observations=obs_data.get("visual_observations"),
            behavior_name=obs_data.get("behavior_name", ""),
            action_spec_info=obs_data.get("action_spec_info", {}),
            observation_spec_info=obs_data.get("observation_spec_info", {}),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> UnityState:
        """
        Parse server response into UnityState object.

        Keys missing from the payload fall back to neutral defaults
        (None, 0, empty string, empty dict or empty list).

        Args:
            payload: JSON response from /state endpoint

        Returns:
            UnityState object with environment information
        """
        return UnityState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            env_id=payload.get("env_id", ""),
            behavior_name=payload.get("behavior_name", ""),
            action_spec=payload.get("action_spec", {}),
            observation_spec=payload.get("observation_spec", {}),
            available_envs=payload.get("available_envs", []),
        )

    def reset(
        self,
        env_id: Optional[str] = None,
        include_visual: bool = False,
        **kwargs,
    ) -> StepResult[UnityObservation]:
        """
        Reset the environment.

        Args:
            env_id: Optionally switch to a different Unity environment.
                Available: PushBlock, 3DBall, 3DBallHard, GridWorld, Basic
            include_visual: If True, include visual observations in response.
            **kwargs: Additional arguments passed to server.

        Returns:
            StepResult with initial observation.
        """
        # Copy so the caller's kwargs mapping is never mutated.
        reset_kwargs = dict(kwargs)
        if env_id is not None:
            # Only sent when given, so the server keeps its current environment otherwise.
            reset_kwargs["env_id"] = env_id
        # Always sent explicitly; the named argument wins over any same-named extra kwarg.
        reset_kwargs["include_visual"] = include_visual
        return super().reset(**reset_kwargs)
def available_environments() -> List[str]:
    """
    Return the identifiers of commonly available Unity environments.

    Note: This is a fixed, hard-coded list; the real set may differ depending
    on the ML-Agents registry version. Use state.available_envs after
    connecting for the authoritative list.

    Returns:
        A new list of environment identifiers on every call.
    """
    common_env_ids = (
        "PushBlock",
        "3DBall",
        "3DBallHard",
        "GridWorld",
        "Basic",
        "VisualPushBlock",
    )
    # Build a fresh list each time so callers can mutate their copy freely.
    return list(common_env_ids)