# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Unity ML-Agents Environment Client.

This module provides the client for connecting to a Unity ML-Agents
Environment server via WebSocket for persistent sessions.
"""

import os
import subprocess
import sys
import threading
import time
from collections import deque
from pathlib import Path
from typing import IO, Any, Deque, Dict, List, Optional, Tuple

# Support multiple import scenarios
try:
    # In-repo imports (when running from OpenEnv repository root)
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient

    from .models import UnityAction, UnityObservation, UnityState
except ImportError:
    # openenv from pip
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient

    try:
        # Direct execution from envs/unity_env/ directory
        from models import UnityAction, UnityObservation, UnityState
    except ImportError:
        try:
            # Package installed as unity_env
            from unity_env.models import UnityAction, UnityObservation, UnityState
        except ImportError:
            # Running from OpenEnv root with envs prefix
            from envs.unity_env.models import UnityAction, UnityObservation, UnityState


# uvicorn import path of the server app when launched from the OpenEnv root.
_REPO_SERVER_APP = "envs.unity_env.server.app:app"
# uvicorn import path of the server app when launched from envs/unity_env/.
_LOCAL_SERVER_APP = "server.app:app"
# How many trailing lines of server output to keep for startup diagnostics.
_SERVER_OUTPUT_TAIL_LINES = 50


def _locate_server_app() -> Tuple[str, Optional[Path]]:
    """
    Work out how to launch the Unity server app for ``UnityEnv.from_direct``.

    Returns:
        ``(app_import_path, working_dir)``. ``working_dir`` is the directory
        uvicorn is started in (and which is prepended to PYTHONPATH), or
        ``None`` if the filesystem could not be inspected.
    """
    try:
        # absolute() (not resolve()) so a relative __file__ still yields
        # absolute cwd/PYTHONPATH entries without following symlinks.
        client_dir = Path(__file__).absolute().parent
        repo_root = client_dir.parent.parent  # OpenEnv root for envs/unity_env/
        if (repo_root / "envs" / "unity_env" / "server" / "app.py").exists():
            return _REPO_SERVER_APP, repo_root
        if (client_dir / "server" / "app.py").exists():
            return _LOCAL_SERVER_APP, client_dir
        return _REPO_SERVER_APP, repo_root
    except OSError:
        return _REPO_SERVER_APP, None


def _drain_output(stream: IO[bytes], sink: Deque[str]) -> None:
    """
    Read a subprocess output stream until EOF, keeping only the latest lines.

    This runs in a background thread. An unread ``subprocess.PIPE`` eventually
    fills the OS pipe buffer, and the child then blocks on its next write.

    Args:
        stream: Binary stdout of the child process; closed when EOF is reached.
        sink: Bounded deque that receives decoded, right-stripped lines.
    """
    try:
        with stream:
            for raw_line in iter(stream.readline, b""):
                sink.append(raw_line.decode("utf-8", errors="replace").rstrip())
    except (OSError, ValueError):
        # The stream broke or was closed underneath us; nothing left to drain.
        pass


class _DirectModeProvider:
    """Provider that manages the embedded server subprocess of ``from_direct``."""

    def __init__(self, process: subprocess.Popen):
        self._process: Optional[subprocess.Popen] = process

    def stop(self) -> None:
        """
        Stop the embedded server.

        Sends SIGTERM, waits up to 10 seconds, then falls back to SIGKILL.
        The process is always reaped. Calling this more than once is a no-op.
        """
        process, self._process = self._process, None
        if process is None:
            return
        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()  # reap so no zombie is left behind


class UnityEnv(EnvClient[UnityAction, UnityObservation, UnityState]):
    """
    Client for Unity ML-Agents environments.

    This client keeps a persistent WebSocket connection to the environment
    server, which allows efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server.

    Note: Unity environments can take 30-60+ seconds to initialize on the
    first reset (downloading binaries, starting the Unity process). The client
    is configured with longer ping timeouts to handle this.

    Supported Unity Environments:
        - PushBlock: Push a block to a goal (discrete actions: 7)
        - 3DBall: Balance a ball on a platform (continuous actions: 2)
        - 3DBallHard: Harder version of 3DBall
        - GridWorld: Navigate a grid to find goals
        - Basic: Simple movement task
        - And more from the ML-Agents registry

    Example:
        >>> # Connect to a running server
        >>> with UnityEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(f"Vector obs: {len(result.observation.vector_observations)} dims")
        ...
        ...     # Take action (PushBlock: 1=forward)
        ...     result = client.step(UnityAction(discrete_actions=[1]))
        ...     print(f"Reward: {result.reward}")

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = UnityEnv.from_docker_image("unity-env:latest")
        >>> try:
        ...     result = client.reset(env_id="3DBall")
        ...     result = client.step(UnityAction(continuous_actions=[0.5, -0.3]))
        ... finally:
        ...     client.close()

    Example switching environments:
        >>> client = UnityEnv(base_url="http://localhost:8000")
        >>> # Start with PushBlock
        >>> result = client.reset(env_id="PushBlock")
        >>> # ... train on PushBlock ...
        >>> # Switch to 3DBall
        >>> result = client.reset(env_id="3DBall")
        >>> # ... train on 3DBall ...
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 180.0,  # 3 minutes for slow Unity initialization
        provider: Optional[Any] = None,
    ):
        """
        Initialize the Unity environment client.

        Uses a longer default message timeout than the base EnvClient, because
        Unity environments can take 30-60+ seconds to initialize on the first
        reset.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing the WebSocket connection.
            message_timeout_s: Timeout for receiving responses (default 3 min for Unity).
            provider: Optional container/runtime provider for lifecycle management.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
        )

    def connect(self) -> "UnityEnv":
        """
        Establish the WebSocket connection to the server.

        Overrides the default connection to use longer ping timeouts, because
        Unity environments can take 30-60+ seconds to initialize. Calling this
        while already connected is a no-op.

        Returns:
            self, for method chaining.

        Raises:
            ConnectionError: If the connection cannot be established.
        """
        from websockets.sync.client import connect as ws_connect

        if self._ws is not None:
            return self

        try:
            # Use a long ping_timeout (120 s) for Unity, because environment
            # initialization can block the server for a while.
            self._ws = ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                ping_timeout=120,  # 2 minutes for slow Unity initialization
                ping_interval=30,  # Send pings every 30 seconds
                close_timeout=30,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e

        return self

    def _step_payload(self, action: UnityAction) -> Dict[str, Any]:
        """
        Convert a UnityAction into the JSON payload of a step request.

        Only the action fields that are set are included. Metadata is included
        only when it is non-empty.

        Args:
            action: UnityAction instance.

        Returns:
            Dictionary representation suitable for JSON encoding.
        """
        payload: Dict[str, Any] = {}

        if action.discrete_actions is not None:
            payload["discrete_actions"] = action.discrete_actions

        if action.continuous_actions is not None:
            payload["continuous_actions"] = action.continuous_actions

        if action.metadata:
            payload["metadata"] = action.metadata

        return payload

    def _parse_result(self, payload: Dict) -> StepResult[UnityObservation]:
        """
        Parse a server response into StepResult[UnityObservation].

        Args:
            payload: JSON response from the server. Top-level ``done`` and
                ``reward`` are copied onto both the observation and the result.

        Returns:
            StepResult with a UnityObservation.
        """
        obs_data = payload.get("observation", {})

        observation = UnityObservation(
            vector_observations=obs_data.get("vector_observations", []),
            visual_observations=obs_data.get("visual_observations"),
            behavior_name=obs_data.get("behavior_name", ""),
            action_spec_info=obs_data.get("action_spec_info", {}),
            observation_spec_info=obs_data.get("observation_spec_info", {}),
            done=payload.get("done", False),
            reward=payload.get("reward"),
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> UnityState:
        """
        Parse a server response into a UnityState object.

        Args:
            payload: JSON response from the /state endpoint.

        Returns:
            UnityState with environment information; missing keys fall back
            to empty/zero defaults.
        """
        return UnityState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            env_id=payload.get("env_id", ""),
            behavior_name=payload.get("behavior_name", ""),
            action_spec=payload.get("action_spec", {}),
            observation_spec=payload.get("observation_spec", {}),
            available_envs=payload.get("available_envs", []),
        )

    def reset(
        self,
        env_id: Optional[str] = None,
        include_visual: bool = False,
        **kwargs,
    ) -> StepResult[UnityObservation]:
        """
        Reset the environment.

        Args:
            env_id: Optionally switch to a different Unity environment.
                Available: PushBlock, 3DBall, 3DBallHard, GridWorld, Basic
            include_visual: If True, include visual observations in the response.
            **kwargs: Additional arguments passed to the server.

        Returns:
            StepResult with the initial observation.
        """
        reset_kwargs = dict(kwargs)

        if env_id is not None:
            reset_kwargs["env_id"] = env_id

        reset_kwargs["include_visual"] = include_visual

        return super().reset(**reset_kwargs)

    @staticmethod
    def available_environments() -> List[str]:
        """
        List commonly available Unity environments.

        Note: The actual list may vary with the ML-Agents registry version.
        Use state.available_envs after connecting for the authoritative list.

        Returns:
            List of environment identifiers.
        """
        return [
            "PushBlock",
            "3DBall",
            "3DBallHard",
            "GridWorld",
            "Basic",
            "VisualPushBlock",
        ]

    @classmethod
    def from_direct(
        cls,
        env_id: str = "PushBlock",
        no_graphics: bool = False,
        width: int = 1280,
        height: int = 720,
        time_scale: float = 1.0,
        quality_level: int = 5,
        port: int = 8765,
        startup_timeout_s: float = 30.0,
    ) -> "UnityEnv":
        """
        Create a Unity environment client with an embedded local server.

        This method starts a local uvicorn server in a subprocess and returns
        a client connected to it. You get the convenience of direct mode while
        keeping the client-server separation. The server's combined
        stdout/stderr is drained in a background thread so it can never block
        on a full pipe; the last lines are included in startup errors.

        Note: The first call will download Unity binaries (~500MB), which may
        take several minutes. Binaries are cached for subsequent runs.

        Args:
            env_id: Default Unity environment to use (PushBlock, 3DBall, etc.)
            no_graphics: If True, run Unity in headless mode (faster for training)
            width: Window width in pixels (default: 1280)
            height: Window height in pixels (default: 720)
            time_scale: Simulation speed multiplier (default: 1.0, use 20.0 for fast training)
            quality_level: Graphics quality 0-5 (default: 5)
            port: Port for the local server (default: 8765)
            startup_timeout_s: Seconds to wait for the server's /health
                endpoint to answer 200 (default: 30.0)

        Returns:
            UnityEnv client connected to the local server.

        Raises:
            RuntimeError: If the server exits or is not healthy within
                ``startup_timeout_s``.

        Example:
            >>> # Quick start with direct mode
            >>> client = UnityEnv.from_direct(no_graphics=True, time_scale=20)
            >>> try:
            ...     result = client.reset(env_id="PushBlock")
            ...     for _ in range(100):
            ...         result = client.step(UnityAction(discrete_actions=[1]))
            ... finally:
            ...     client.close()

            >>> # With custom settings
            >>> client = UnityEnv.from_direct(
            ...     env_id="3DBall",
            ...     no_graphics=True,
            ...     time_scale=20,
            ...     port=9000
            ... )
        """
        import requests  # third-party; imported lazily as in the original client

        server_app, cwd = _locate_server_app()

        # Bypass any proxy for localhost without discarding the user's own
        # NO_PROXY entries, which the server may still need.
        existing_no_proxy = os.environ.get("NO_PROXY") or os.environ.get("no_proxy")
        no_proxy = "localhost,127.0.0.1"
        if existing_no_proxy:
            no_proxy = f"{no_proxy},{existing_no_proxy}"

        # Unity configuration is passed to the server via environment variables.
        env = {
            **os.environ,
            "UNITY_ENV_ID": env_id,
            "UNITY_NO_GRAPHICS": "1" if no_graphics else "0",
            "UNITY_WIDTH": str(width),
            "UNITY_HEIGHT": str(height),
            "UNITY_TIME_SCALE": str(time_scale),
            "UNITY_QUALITY_LEVEL": str(quality_level),
            "NO_PROXY": no_proxy,
            "no_proxy": no_proxy,
        }

        # Put <root>/src and <root> ahead of any existing PYTHONPATH, using the
        # platform's path separator (":" on POSIX, ";" on Windows).
        if cwd is not None:
            pythonpath = [str(cwd / "src"), str(cwd)]
            existing_path = env.get("PYTHONPATH")
            if existing_path:
                pythonpath.append(existing_path)
            env["PYTHONPATH"] = os.pathsep.join(pythonpath)

        cmd = [
            sys.executable,
            "-m",
            "uvicorn",
            server_app,
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
        ]

        server_process = subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=str(cwd) if cwd is not None else None,
        )

        # Keep the pipe drained for the server's whole lifetime (daemon thread).
        output_tail: Deque[str] = deque(maxlen=_SERVER_OUTPUT_TAIL_LINES)
        drain_thread = threading.Thread(
            target=_drain_output,
            args=(server_process.stdout, output_tail),
            name=f"unity-server-output-{port}",
            daemon=True,
        )
        drain_thread.start()

        # Wait for the server to become healthy, giving up early if it exits.
        base_url = f"http://127.0.0.1:{port}"
        healthy = False
        deadline = time.monotonic() + startup_timeout_s
        while time.monotonic() < deadline:
            if server_process.poll() is not None:
                break
            try:
                response = requests.get(
                    f"{base_url}/health",
                    timeout=2,
                    proxies={"http": None, "https": None},
                )
                if response.status_code == 200:
                    healthy = True
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(1)

        if not healthy:
            exit_code = server_process.poll()
            if exit_code is None:
                server_process.kill()
                server_process.wait()  # reap the killed process
            drain_thread.join(timeout=2)

            message = (
                f"Failed to start local Unity server on port {port}. "
                "Check that the port is available and dependencies are installed."
            )
            if exit_code is not None:
                message += f" Server exited with code {exit_code}."
            # Only read the tail once the reader thread is done writing to it.
            if not drain_thread.is_alive() and output_tail:
                message += "\nLast server output:\n" + "\n".join(output_tail)
            raise RuntimeError(message)

        provider = _DirectModeProvider(server_process)
        try:
            return cls(base_url=base_url, provider=provider)
        except Exception:
            # Don't leak the server if the client cannot be constructed.
            provider.stop()
            raise