| | import atexit |
| | from distutils.version import StrictVersion |
| |
|
| | import numpy as np |
| | import os |
| | import subprocess |
| | from typing import Dict, List, Optional, Tuple, Mapping as MappingType |
| |
|
| | import mlagents_envs |
| |
|
| | from mlagents_envs.logging_util import get_logger |
| | from mlagents_envs.side_channel.side_channel import SideChannel |
| | from mlagents_envs.side_channel import DefaultTrainingAnalyticsSideChannel |
| | from mlagents_envs.side_channel.side_channel_manager import SideChannelManager |
| | from mlagents_envs import env_utils |
| |
|
| | from mlagents_envs.base_env import ( |
| | BaseEnv, |
| | DecisionSteps, |
| | TerminalSteps, |
| | BehaviorSpec, |
| | ActionTuple, |
| | BehaviorName, |
| | AgentId, |
| | BehaviorMapping, |
| | ) |
| | from mlagents_envs.timers import timed, hierarchical_timer |
| | from mlagents_envs.exception import ( |
| | UnityEnvironmentException, |
| | UnityActionException, |
| | UnityTimeOutException, |
| | UnityCommunicatorStoppedException, |
| | ) |
| |
|
| | from mlagents_envs.communicator_objects.command_pb2 import STEP, RESET |
| | from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto |
| |
|
| | from mlagents_envs.communicator_objects.unity_rl_input_pb2 import UnityRLInputProto |
| | from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto |
| | from mlagents_envs.communicator_objects.agent_action_pb2 import AgentActionProto |
| | from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto |
| | from mlagents_envs.communicator_objects.capabilities_pb2 import UnityRLCapabilitiesProto |
| | from mlagents_envs.communicator_objects.unity_rl_initialization_input_pb2 import ( |
| | UnityRLInitializationInputProto, |
| | ) |
| |
|
| | from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto |
| |
|
| | from .rpc_communicator import RpcCommunicator |
| | import signal |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class UnityEnvironment(BaseEnv):
    """
    BaseEnv implementation that talks to a Unity environment through a communicator.

    The environment is either an executable launched as a subprocess (when a
    `file_name` is given to the constructor) or an already running Unity Editor
    (when `file_name` is None).
    """

    # Communication API version of this Python package. It is sent to Unity as
    # `communication_version` during initialization and compared against the
    # version Unity reports back.
    API_VERSION = "1.5.0"

    # Base port used when no base_port is given and no executable is specified
    # (file_name is None).
    DEFAULT_EDITOR_PORT = 5004

    # Base port used when no base_port is given and an executable is specified.
    BASE_ENVIRONMENT_PORT = 5005

    # Command line flag placed before the port number in the arguments handed
    # to a launched executable.
    _PORT_COMMAND_LINE_ARG = "--mlagents-port"
| |
|
@staticmethod
def _raise_version_exception(unity_com_ver: str) -> None:
    """
    Raise the error used when the Unity and Python communication versions are incompatible.

    :str unity_com_ver: Communication API version reported by Unity.
    :raises UnityEnvironmentException: Always.
    """
    python_com_ver = UnityEnvironment.API_VERSION
    message = (
        "The communication API version is not compatible between Unity and python. "
        f"Python API: {python_com_ver}, Unity API: {unity_com_ver}.\n "
        "Please find the versions that work best together from our release page.\n"
        "https://github.com/Unity-Technologies/ml-agents/releases"
    )
    raise UnityEnvironmentException(message)
| |
|
@staticmethod
def _check_communication_compatibility(
    unity_com_ver: str, python_api_version: str, unity_package_version: str
) -> bool:
    """
    Decide whether the Unity communication version can work with the Python one.

    :str unity_com_ver: Communication version reported by Unity.
    :str python_api_version: Communication version of this Python package.
    :str unity_package_version: Unity package version, only used in the log message.
    :return: True when the versions are compatible, False otherwise.
    """
    unity_parts = StrictVersion(unity_com_ver).version
    python_parts = StrictVersion(python_api_version).version
    same_major = unity_parts[0] == python_parts[0]

    if unity_parts[0] == 0:
        # While Unity reports a 0.x version, both major and minor must match.
        return same_major and unity_parts[1] == python_parts[1]

    if not same_major:
        # Major versions mismatch.
        return False

    # Major versions match; differences in minor/patch are not reported here.
    logger.info(
        f"Connected to Unity environment with package version {unity_package_version} "
        f"and communication version {unity_com_ver}"
    )
    return True
| |
|
@staticmethod
def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
    """
    Build the capabilities message advertised to Unity, with every listed
    capability flag switched on.
    """
    return UnityRLCapabilitiesProto(
        baseRLCapabilities=True,
        concatenatedPngObservations=True,
        compressedChannelMapping=True,
        hybridActions=True,
        trainingAnalytics=True,
        variableLengthObservation=True,
        multiAgentGroups=True,
    )
| |
|
@staticmethod
def _warn_csharp_base_capabilities(
    caps: UnityRLCapabilitiesProto, unity_package_ver: str, python_package_ver: str
) -> None:
    """
    Log a warning when Unity does not report the base RL capabilities.

    :UnityRLCapabilitiesProto caps: Capabilities reported by the Unity side.
    :str unity_package_ver: Version of the C# package, shown in the warning.
    :str python_package_ver: Version of the Python package, shown in the warning.
    """
    if caps.baseRLCapabilities:
        return
    # The line break after the C# version keeps the "Please find ..." hint from
    # being glued onto the version number.
    logger.warning(
        "WARNING: The Unity process is not running with the expected base Reinforcement Learning"
        " capabilities. Please be sure to upgrade the Unity Package to a version that is compatible with this "
        "python package.\n"
        f"Python package version: {python_package_ver}, C# package version: {unity_package_ver}\n"
        "Please find the versions that work best together from our release page.\n"
        "https://github.com/Unity-Technologies/ml-agents/releases"
    )
| |
|
def __init__(
    self,
    file_name: Optional[str] = None,
    worker_id: int = 0,
    base_port: Optional[int] = None,
    seed: int = 0,
    no_graphics: bool = False,
    timeout_wait: int = 60,
    additional_args: Optional[List[str]] = None,
    side_channels: Optional[List[SideChannel]] = None,
    log_folder: Optional[str] = None,
    num_areas: int = 1,
):
    """
    Starts a new unity environment and establishes a connection with the environment.
    Notice: Currently communication between Unity and Python takes place over an open socket without authentication.
    Ensure that the network where training takes place is secure.

    :string file_name: Name of Unity environment binary. If None, connect to the Unity Editor.
    :int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
        Must be 0 when file_name is None.
    :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
        If no environment is specified (i.e. file_name is None), the DEFAULT_EDITOR_PORT will be used.
    :int seed: Seed sent to Unity in the initialization message.
    :bool no_graphics: Whether to run the Unity simulator in no-graphics mode
    :int timeout_wait: Time (in seconds) to wait for connection from environment.
    :list additional_args: Additional Unity command line arguments
    :list side_channels: Additional side channel for no-rl communication with Unity.
        The list itself is not modified.
    :str log_folder: Optional folder to write the Unity Player log file into. Requires absolute path.
    :int num_areas: Value forwarded to Unity as `num_areas` in the initialization message.
    :raises UnityEnvironmentException: If file_name is None and worker_id is not 0, if the executable
        cannot be launched, or if the communication versions are incompatible.
    :raises UnityTimeOutException: If Unity does not answer the initialization message in time.
    """
    # Validate the arguments before any resource is acquired, so that a bad
    # combination does not leave an open communicator behind.
    if file_name is None and worker_id != 0:
        raise UnityEnvironmentException(
            "If the environment name is None, "
            "the worker-id must be 0 in order to connect with the Editor."
        )

    self._additional_args = additional_args or []
    self._no_graphics = no_graphics
    # If base port is not specified, pick the default depending on whether
    # we are launching an executable or connecting to the Editor.
    if base_port is None:
        base_port = (
            self.BASE_ENVIRONMENT_PORT if file_name else self.DEFAULT_EDITOR_PORT
        )
    self._port = base_port + worker_id
    self._buffer_size = 12000
    # _loaded only becomes True once a process was launched or the Editor is awaited.
    self._loaded = False
    self._process: Optional[subprocess.Popen] = None
    self._timeout_wait: int = timeout_wait
    self._worker_id = worker_id
    self._log_folder = log_folder
    self.academy_capabilities: Optional[UnityRLCapabilitiesProto] = None

    self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
    # Register the exit hook only after every attribute read by _close exists.
    atexit.register(self._close)

    # Work on a copy: appending the default channel to the caller's list would
    # leak it into every other environment built from the same list.
    side_channels = list(side_channels) if side_channels else []
    default_training_side_channel: Optional[
        DefaultTrainingAnalyticsSideChannel
    ] = None
    registered_ids = [channel.channel_id for channel in side_channels]
    if DefaultTrainingAnalyticsSideChannel.CHANNEL_ID not in registered_ids:
        default_training_side_channel = DefaultTrainingAnalyticsSideChannel()
        side_channels.append(default_training_side_channel)
    self._side_channel_manager = SideChannelManager(side_channels)

    if file_name is not None:
        try:
            self._process = env_utils.launch_executable(
                file_name, self._executable_args()
            )
        except UnityEnvironmentException:
            self._close(0)
            raise
    else:
        logger.info(
            f"Listening on port {self._port}. "
            f"Start training by pressing the Play button in the Unity Editor."
        )
    self._loaded = True

    rl_init_parameters_in = UnityRLInitializationInputProto(
        seed=seed,
        communication_version=self.API_VERSION,
        package_version=mlagents_envs.__version__,
        capabilities=UnityEnvironment._get_capabilities_proto(),
        num_areas=num_areas,
    )
    try:
        aca_output = self._send_academy_parameters(rl_init_parameters_in)
        aca_params = aca_output.rl_initialization_output
    except UnityTimeOutException:
        self._close(0)
        raise

    if not UnityEnvironment._check_communication_compatibility(
        aca_params.communication_version,
        UnityEnvironment.API_VERSION,
        aca_params.package_version,
    ):
        self._close(0)
        UnityEnvironment._raise_version_exception(aca_params.communication_version)

    # The warning prints the Python *package* version, so pass that rather
    # than the communication API version.
    UnityEnvironment._warn_csharp_base_capabilities(
        aca_params.capabilities,
        aca_params.package_version,
        mlagents_envs.__version__,
    )

    self._env_state: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {}
    self._env_specs: Dict[str, BehaviorSpec] = {}
    self._env_actions: Dict[str, ActionTuple] = {}
    self._is_first_message = True
    self._update_behavior_specs(aca_output)
    self.academy_capabilities = aca_params.capabilities
    if default_training_side_channel is not None:
        default_training_side_channel.environment_initialized()
| |
|
@staticmethod
def _get_communicator(worker_id, base_port, timeout_wait):
    """
    Create the communicator used to exchange messages with Unity.

    :int worker_id: Worker offset, forwarded to RpcCommunicator.
    :int base_port: Base port, forwarded to RpcCommunicator.
    :int timeout_wait: Connection timeout, forwarded to RpcCommunicator.
    :return: A new RpcCommunicator.
    """
    communicator = RpcCommunicator(worker_id, base_port, timeout_wait)
    return communicator
| |
|
def _executable_args(self) -> List[str]:
    """
    Assemble the command line arguments passed to the Unity executable.

    The order is: graphics flags (if any), the port flag and its value, the
    player log file flag (if any), then the user supplied additional arguments.

    :return: A new list of argument strings.
    """
    graphics_flags = ["-nographics", "-batchmode"] if self._no_graphics else []
    port_flags = [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]

    # Only add our own -logFile when a log folder was given and the user did
    # not already pass a -logfile argument (compared case-insensitively).
    user_set_logfile = any(
        extra.lower() == "-logfile" for extra in self._additional_args
    )
    log_flags: List[str] = []
    if self._log_folder and not user_set_logfile:
        player_log = os.path.join(self._log_folder, f"Player-{self._worker_id}.log")
        log_flags = ["-logFile", player_log]

    return [*graphics_flags, *port_flags, *log_flags, *self._additional_args]
| |
|
def _update_behavior_specs(self, output: UnityOutputProto) -> None:
    """
    Register a BehaviorSpec for every brain announced in the initialization output.

    A spec is only built for brains that have at least one agent info in
    `output.rl_output`, because the first agent info is passed to
    behavior_spec_from_proto together with the brain parameters.

    :UnityOutputProto output: Message received from Unity.
    """
    for params in output.rl_initialization_output.brain_parameters:
        name = params.brain_name
        infos = output.rl_output.agentInfos[name].value
        if not infos:
            continue
        self._env_specs[name] = behavior_spec_from_proto(params, infos[0])
        logger.info(f"Connected new brain: {name}")
| |
|
def _update_state(self, output: UnityRLOutputProto) -> None:
    """
    Refresh the stored decision/terminal steps of every known behavior from `output`.

    Behaviors absent from `output.agentInfos` get empty steps. Afterwards the
    side channel payload of `output` is handed to the side channel manager.

    :UnityRLOutputProto output: RL output received from Unity for the current step.
    """
    for name, spec in self._env_specs.items():
        if name not in output.agentInfos:
            self._env_state[name] = (
                DecisionSteps.empty(spec),
                TerminalSteps.empty(spec),
            )
            continue
        self._env_state[name] = steps_from_proto(output.agentInfos[name].value, spec)
    self._side_channel_manager.process_side_channel_message(output.side_channel)
| |
|
def reset(self) -> None:
    """
    Send a reset command to Unity and refresh specs and state from the reply.

    :raises UnityEnvironmentException: If no environment is loaded.
    :raises UnityCommunicatorStoppedException: If the communicator returns no reply.
    """
    if not self._loaded:
        raise UnityEnvironmentException("No Unity environment is loaded.")

    reply = self._communicator.exchange(
        self._generate_reset_input(), self._poll_process
    )
    if reply is None:
        raise UnityCommunicatorStoppedException("Communicator has exited.")

    self._update_behavior_specs(reply)
    self._update_state(reply.rl_output)
    self._is_first_message = False
    # Actions queued before the reset are discarded.
    self._env_actions.clear()
| |
|
@timed
def step(self) -> None:
    """
    Send the queued actions to Unity and refresh specs and state from the reply.

    The very first call is turned into a reset. Behaviors without a queued
    action get an empty action sized to their current number of decision
    requests (0 when no state is stored for them).

    :raises UnityEnvironmentException: If no environment is loaded.
    :raises UnityCommunicatorStoppedException: If the communicator returns no reply.
    """
    if self._is_first_message:
        return self.reset()
    if not self._loaded:
        raise UnityEnvironmentException("No Unity environment is loaded.")

    # Fill in empty actions for every behavior the caller did not set.
    for name, spec in self._env_specs.items():
        if name in self._env_actions:
            continue
        agent_count = len(self._env_state[name][0]) if name in self._env_state else 0
        self._env_actions[name] = spec.action_spec.empty_action(agent_count)

    request = self._generate_step_input(self._env_actions)
    with hierarchical_timer("communicator.exchange"):
        reply = self._communicator.exchange(request, self._poll_process)
    if reply is None:
        raise UnityCommunicatorStoppedException("Communicator has exited.")

    self._update_behavior_specs(reply)
    self._update_state(reply.rl_output)
    self._env_actions.clear()
| |
|
@property
def behavior_specs(self) -> MappingType[str, BehaviorSpec]:
    """
    Mapping from behavior name to BehaviorSpec, wrapped in a BehaviorMapping.
    """
    spec_view = BehaviorMapping(self._env_specs)
    return spec_view
| |
|
def _assert_behavior_exists(self, behavior_name: str) -> None:
    """
    Ensure `behavior_name` is one of the known behaviors.

    :str behavior_name: Name to look up in the registered behavior specs.
    :raises UnityActionException: If the name is unknown.
    """
    if behavior_name in self._env_specs:
        return
    raise UnityActionException(
        f"The group {behavior_name} does not correspond to an existing "
        "agent group in the environment"
    )
| |
|
def set_actions(self, behavior_name: BehaviorName, action: ActionTuple) -> None:
    """
    Queue `action` for all agents of `behavior_name` for the next step.

    Nothing is queued when no state is stored for the behavior yet. Otherwise
    the action is validated against the behavior's action spec and the number
    of agents that requested a decision, and the validated action is stored.

    :BehaviorName behavior_name: Behavior the action is meant for.
    :ActionTuple action: Actions for every agent of the behavior.
    """
    self._assert_behavior_exists(behavior_name)
    if behavior_name in self._env_state:
        decision_steps = self._env_state[behavior_name][0]
        spec = self._env_specs[behavior_name].action_spec
        self._env_actions[behavior_name] = spec._validate_action(
            action, len(decision_steps), behavior_name
        )
| |
|
def set_action_for_agent(
    self, behavior_name: BehaviorName, agent_id: AgentId, action: ActionTuple
) -> None:
    """
    Queue `action` for a single agent of `behavior_name` for the next step.

    Nothing is queued when no state is stored for the behavior yet. If no
    action is queued for the behavior so far, an empty action sized to the
    number of agents that requested a decision is created first.

    :BehaviorName behavior_name: Behavior the agent belongs to.
    :AgentId agent_id: Id of the agent, as found in the decision steps.
    :ActionTuple action: Action for exactly one agent.
    :raises IndexError: If `agent_id` is not among the agents that requested a decision.
    """
    self._assert_behavior_exists(behavior_name)
    if behavior_name not in self._env_state:
        return
    action_spec = self._env_specs[behavior_name].action_spec
    action = action_spec._validate_action(action, 1, behavior_name)
    decision_steps = self._env_state[behavior_name][0]
    if behavior_name not in self._env_actions:
        self._env_actions[behavior_name] = action_spec.empty_action(
            len(decision_steps)
        )
    # Row of the agent inside the decision steps; no match means the agent
    # did not ask for a decision.
    try:
        index = np.where(decision_steps.agent_id == agent_id)[0][0]
    except IndexError as ie:
        raise IndexError(
            f"agent_id {agent_id} did not request a decision at the previous step"
        ) from ie
    queued = self._env_actions[behavior_name]
    if action_spec.continuous_size > 0:
        queued.continuous[index] = action.continuous[0, :]
    if action_spec.discrete_size > 0:
        queued.discrete[index] = action.discrete[0, :]
| |
|
def get_steps(
    self, behavior_name: BehaviorName
) -> Tuple[DecisionSteps, TerminalSteps]:
    """
    Return the stored (DecisionSteps, TerminalSteps) pair of `behavior_name`.

    :BehaviorName behavior_name: Behavior whose latest steps are requested.
    :return: The pair stored for the behavior by the last state update.
    """
    self._assert_behavior_exists(behavior_name)
    latest_steps = self._env_state[behavior_name]
    return latest_steps
| |
|
def _poll_process(self) -> None:
    """
    Raise if the environment subprocess has exited.

    Does nothing when there is no subprocess or when it is still running.

    :raises UnityEnvironmentException: If the subprocess has a return code,
        with a message describing that return code.
    """
    proc = self._process
    if not proc:
        return
    if proc.poll() is None:
        # Still running.
        return
    raise UnityEnvironmentException(self._returncode_to_env_message(proc.returncode))
| |
|
def close(self):
    """
    Shut the environment down through `_close` using the default timeout.

    :raises UnityEnvironmentException: If no environment is loaded.
    """
    if not self._loaded:
        raise UnityEnvironmentException("No Unity environment is loaded.")
    self._close()
| |
|
def _close(self, timeout: Optional[int] = None) -> None:
    """
    Close the communicator and, if one exists, wait for the environment subprocess.

    The subprocess is killed when it has not exited within the timeout.

    :int timeout: [Optional] Number of seconds to wait for the environment to shut down before
        force-killing it. Defaults to `self._timeout_wait`.
    """
    wait_seconds = self._timeout_wait if timeout is None else timeout
    self._loaded = False
    self._communicator.close()

    proc = self._process
    if proc is None:
        return
    # Give the process a chance to exit on its own before killing it.
    try:
        proc.wait(timeout=wait_seconds)
        logger.debug(self._returncode_to_env_message(proc.returncode))
    except subprocess.TimeoutExpired:
        logger.warning("Environment timed out shutting down. Killing...")
        proc.kill()
    # The process handle is dropped whether it exited or was killed.
    self._process = None
| |
|
@timed
def _generate_step_input(
    self, vector_action: Dict[str, ActionTuple]
) -> UnityInputProto:
    """
    Build the STEP message carrying one AgentActionProto per agent and behavior.

    Continuous values, then discrete values, are written both to the deprecated
    combined field and to their dedicated fields. Behaviors with no agent
    requesting a decision contribute nothing. Pending side channel messages
    are attached to the message.

    :dict vector_action: Queued actions keyed by behavior name.
    :return: The RL input wrapped in a UnityInputProto.
    """
    rl_in = UnityRLInputProto()
    for name, queued in vector_action.items():
        agent_count = len(self._env_state[name][0])
        if agent_count == 0:
            continue
        for row in range(agent_count):
            agent_action = AgentActionProto()
            if queued.continuous is not None:
                agent_action.vector_actions_deprecated.extend(queued.continuous[row])
                agent_action.continuous_actions.extend(queued.continuous[row])
            if queued.discrete is not None:
                agent_action.vector_actions_deprecated.extend(queued.discrete[row])
                agent_action.discrete_actions.extend(queued.discrete[row])
            rl_in.agent_actions[name].value.extend([agent_action])
    rl_in.command = STEP
    rl_in.side_channel = bytes(
        self._side_channel_manager.generate_side_channel_messages()
    )
    return self._wrap_unity_input(rl_in)
| |
|
def _generate_reset_input(self) -> UnityInputProto:
    """
    Build the RESET message, including any pending side channel messages.

    :return: The RL input wrapped in a UnityInputProto.
    """
    reset_request = UnityRLInputProto()
    reset_request.command = RESET
    pending = self._side_channel_manager.generate_side_channel_messages()
    reset_request.side_channel = bytes(pending)
    return self._wrap_unity_input(reset_request)
| |
|
def _send_academy_parameters(
    self, init_parameters: UnityRLInitializationInputProto
) -> UnityOutputProto:
    """
    Send the initialization parameters to Unity through the communicator.

    :UnityRLInitializationInputProto init_parameters: Parameters copied into the request.
    :return: The result of the communicator's `initialize` call.
    """
    request = UnityInputProto()
    request.rl_initialization_input.CopyFrom(init_parameters)
    reply = self._communicator.initialize(request, self._poll_process)
    return reply
| |
|
@staticmethod
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
    """
    Return a new UnityInputProto whose `rl_input` field is a copy of `rl_input`.
    """
    envelope = UnityInputProto()
    envelope.rl_input.CopyFrom(rl_input)
    return envelope
| |
|
@staticmethod
def _returncode_to_signal_name(returncode: int) -> Optional[str]:
    """
    Try to convert return codes into their corresponding signal name.
    E.g. returncode_to_signal_name(-2) -> "SIGINT"

    :int returncode: Return code of the subprocess.
    :return: The signal name, or None when the negated return code is not a known signal.
    """
    try:
        # The return code is negated before the signal lookup.
        return signal.Signals(-returncode).name
    except Exception:
        # Not a signal; the caller simply omits the name.
        return None
| |
|
@staticmethod
def _returncode_to_env_message(returncode: int) -> str:
    """
    Describe a subprocess return code, appending the signal name when one is found.

    :int returncode: Return code of the subprocess.
    :return: Human readable shutdown message.
    """
    name = UnityEnvironment._returncode_to_signal_name(returncode)
    suffix = f" ({name})" if name else ""
    return f"Environment shut down with return code {returncode}{suffix}."
| |
|