| | import datetime |
| | from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set |
| | import cloudpickle |
| | import enum |
| | import time |
| |
|
| | from mlagents_envs.environment import UnityEnvironment |
| | from mlagents_envs.exception import ( |
| | UnityCommunicationException, |
| | UnityTimeOutException, |
| | UnityEnvironmentException, |
| | UnityCommunicatorStoppedException, |
| | ) |
| | from multiprocessing import Process, Pipe, Queue |
| | from multiprocessing.connection import Connection |
| | from queue import Empty as EmptyQueueException |
| | from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec |
| | from mlagents_envs import logging_util |
| | from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult |
| | from mlagents.trainers.settings import TrainerSettings |
| | from mlagents_envs.timers import ( |
| | TimerNode, |
| | timed, |
| | hierarchical_timer, |
| | reset_timers, |
| | get_timer_root, |
| | ) |
| | from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions |
| | from mlagents.trainers.action_info import ActionInfo |
| | from mlagents_envs.side_channel.environment_parameters_channel import ( |
| | EnvironmentParametersChannel, |
| | ) |
| | from mlagents_envs.side_channel.engine_configuration_channel import ( |
| | EngineConfigurationChannel, |
| | EngineConfig, |
| | ) |
| | from mlagents_envs.side_channel.stats_side_channel import ( |
| | EnvironmentStats, |
| | StatsSideChannel, |
| | ) |
| | from mlagents.trainers.training_analytics_side_channel import ( |
| | TrainingAnalyticsSideChannel, |
| | ) |
| | from mlagents_envs.side_channel.side_channel import SideChannel |
| |
|
| |
|
# Module-level logger shared by the manager and worker entry point.
logger = logging_util.get_logger(__name__)
# Seconds close() waits for workers to acknowledge shutdown (CLOSED message)
# before forcefully terminating their processes.
WORKER_SHUTDOWN_TIMEOUT_S = 10
| |
|
| |
|
class EnvironmentCommand(enum.Enum):
    """
    Message types exchanged between SubprocessEnvManager and its worker
    subprocesses. Requests (manager -> worker) and responses (worker ->
    manager) share this one tag set.
    """

    # auto() assigns 1..8 in declaration order, matching the original
    # explicit values exactly.
    STEP = enum.auto()
    BEHAVIOR_SPECS = enum.auto()
    ENVIRONMENT_PARAMETERS = enum.auto()
    RESET = enum.auto()
    CLOSE = enum.auto()
    ENV_EXITED = enum.auto()
    CLOSED = enum.auto()
    TRAINING_STARTED = enum.auto()
| |
|
| |
|
class EnvironmentRequest(NamedTuple):
    """A command sent from the manager (parent) process to a worker process."""

    # Which operation the worker should perform.
    cmd: EnvironmentCommand
    # Command-specific data, e.g. the per-behavior ActionInfo dict for STEP or
    # the environment-parameter config dict for RESET/ENVIRONMENT_PARAMETERS.
    payload: Any = None
| |
|
| |
|
class EnvironmentResponse(NamedTuple):
    """A message sent from a worker process back to the manager."""

    # The command this responds to, or ENV_EXITED / CLOSED notifications.
    cmd: EnvironmentCommand
    # Identifies which worker produced this response.
    worker_id: int
    # Response data: a StepResponse for STEP, behavior specs for
    # BEHAVIOR_SPECS, the exception for ENV_EXITED, None for CLOSED.
    payload: Any
| |
|
| |
|
class StepResponse(NamedTuple):
    """The result of a single environment step, produced by a worker."""

    # Step results for every behavior in the environment.
    all_step_result: AllStepResult
    # The worker's timer tree, to be merged into the main process's timers.
    timer_root: Optional[TimerNode]
    # Stats gathered from the environment's StatsSideChannel since last step.
    environment_stats: EnvironmentStats
| |
|
| |
|
class UnityEnvWorker:
    """
    Parent-side handle for one environment subprocess: owns the Process
    object, the manager's end of the pipe, and the bookkeeping state that
    SubprocessEnvManager reads and writes.
    """

    def __init__(self, process: Process, worker_id: int, conn: Connection):
        self.process = process
        self.worker_id = worker_id
        self.conn = conn
        # Most recent step result received from this worker.
        self.previous_step: EnvironmentStep = EnvironmentStep.empty(worker_id)
        # The actions that produced `previous_step`, keyed by behavior name.
        self.previous_all_action_info: Dict[str, ActionInfo] = {}
        # True while a STEP request is outstanding on this worker.
        self.waiting = False
        # True once the worker has acknowledged shutdown (CLOSED received).
        self.closed = False

    def send(self, cmd: EnvironmentCommand, payload: Any = None) -> None:
        """
        Send a request down the pipe to the worker. A broken pipe is surfaced
        as UnityCommunicationException.
        """
        request = EnvironmentRequest(cmd, payload)
        try:
            self.conn.send(request)
        except (BrokenPipeError, EOFError):
            raise UnityCommunicationException("UnityEnvironment worker: send failed.")

    def recv(self) -> EnvironmentResponse:
        """
        Block for the next response from the worker. If the worker reported
        ENV_EXITED, re-raise the exception it sent; pipe failures are surfaced
        as UnityCommunicationException.
        """
        try:
            message: EnvironmentResponse = self.conn.recv()
            if message.cmd == EnvironmentCommand.ENV_EXITED:
                remote_error: Exception = message.payload
                raise remote_error
            return message
        except (BrokenPipeError, EOFError):
            raise UnityCommunicationException("UnityEnvironment worker: recv failed.")

    def request_close(self):
        """Ask the worker to shut down. Best-effort: pipe errors are only logged."""
        close_request = EnvironmentRequest(EnvironmentCommand.CLOSE)
        try:
            self.conn.send(close_request)
        except (BrokenPipeError, EOFError):
            logger.debug(
                f"UnityEnvWorker {self.worker_id} got exception trying to close."
            )
| |
|
| |
|
def worker(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: str,
    worker_id: int,
    run_options: RunOptions,
    log_level: int = logging_util.INFO,
) -> None:
    """
    Entry point of each environment subprocess. Builds the environment from
    the cloudpickled factory, then serves EnvironmentRequests from
    `parent_conn` until CLOSE is received or an exception occurs.
    STEP results go on the shared `step_queue`; other responses go back on
    the pipe. On any failure the exception is forwarded on BOTH channels as
    ENV_EXITED, and a final CLOSED message is always queued on exit.
    """
    env_factory: Callable[
        [int, List[SideChannel]], UnityEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    env_parameters = EnvironmentParametersChannel()

    # Forward the trainer's engine settings (resolution, timescale, ...) to
    # the Unity player via a side channel.
    engine_config = EngineConfig(
        width=run_options.engine_settings.width,
        height=run_options.engine_settings.height,
        quality_level=run_options.engine_settings.quality_level,
        time_scale=run_options.engine_settings.time_scale,
        target_frame_rate=run_options.engine_settings.target_frame_rate,
        capture_frame_rate=run_options.engine_settings.capture_frame_rate,
    )
    engine_configuration_channel = EngineConfigurationChannel()
    engine_configuration_channel.set_configuration(engine_config)

    stats_channel = StatsSideChannel()
    # Only worker 0 reports training analytics, so events aren't duplicated
    # once per environment.
    training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None
    if worker_id == 0:
        training_analytics_channel = TrainingAnalyticsSideChannel()
    env: Optional[UnityEnvironment] = None

    # The subprocess does not inherit the parent's logging configuration.
    logging_util.set_log_level(log_level)

    def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None:
        # Reply to the manager over the pipe (not the shared step queue).
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _generate_all_results() -> AllStepResult:
        # Collect the current step results for every behavior in the env.
        all_step_result: AllStepResult = {}
        for brain_name in env.behavior_specs:
            all_step_result[brain_name] = env.get_steps(brain_name)
        return all_step_result

    try:
        side_channels = [env_parameters, engine_configuration_channel, stats_channel]
        if training_analytics_channel is not None:
            side_channels.append(training_analytics_channel)

        env = env_factory(worker_id, side_channels)
        if (
            not env.academy_capabilities
            or not env.academy_capabilities.trainingAnalytics
        ):
            # The environment can't process training analytics messages;
            # disable the channel so we never send them.
            training_analytics_channel = None
        if training_analytics_channel:
            training_analytics_channel.environment_initialized(run_options)

        # Request-serving loop: one request in, one response out.
        while True:
            req: EnvironmentRequest = parent_conn.recv()
            if req.cmd == EnvironmentCommand.STEP:
                all_action_info = req.payload
                for brain_name, action_info in all_action_info.items():
                    # Skip behaviors with no agents requesting decisions.
                    if len(action_info.agent_ids) > 0:
                        env.set_actions(brain_name, action_info.env_action)
                env.step()
                all_step_result = _generate_all_results()
                env_stats = stats_channel.get_and_reset_stats()
                # Include this worker's timer tree so the main process can
                # merge it; timers are reset after sending.
                step_response = StepResponse(
                    all_step_result, get_timer_root(), env_stats
                )
                step_queue.put(
                    EnvironmentResponse(
                        EnvironmentCommand.STEP, worker_id, step_response
                    )
                )
                reset_timers()
            elif req.cmd == EnvironmentCommand.BEHAVIOR_SPECS:
                _send_response(EnvironmentCommand.BEHAVIOR_SPECS, env.behavior_specs)
            elif req.cmd == EnvironmentCommand.ENVIRONMENT_PARAMETERS:
                for k, v in req.payload.items():
                    if isinstance(v, ParameterRandomizationSettings):
                        # Sampler settings apply themselves via the channel.
                        v.apply(k, env_parameters)
            elif req.cmd == EnvironmentCommand.TRAINING_STARTED:
                behavior_name, trainer_config = req.payload
                if training_analytics_channel:
                    training_analytics_channel.training_started(
                        behavior_name, trainer_config
                    )
            elif req.cmd == EnvironmentCommand.RESET:
                env.reset()
                all_step_result = _generate_all_results()
                _send_response(EnvironmentCommand.RESET, all_step_result)
            elif req.cmd == EnvironmentCommand.CLOSE:
                break
    except (
        KeyboardInterrupt,
        UnityCommunicationException,
        UnityTimeOutException,
        UnityEnvironmentException,
        UnityCommunicatorStoppedException,
    ) as ex:
        # Expected/recoverable shutdown causes: forward the exception so the
        # manager can decide whether to restart this worker.
        logger.debug(f"UnityEnvironment worker {worker_id}: environment stopping.")
        step_queue.put(
            EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
        )
        _send_response(EnvironmentCommand.ENV_EXITED, ex)
    except Exception as ex:
        # Unexpected failure: log the traceback here (it won't survive the
        # process boundary) and forward the exception on both channels.
        logger.exception(
            f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception."
        )
        step_queue.put(
            EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
        )
        _send_response(EnvironmentCommand.ENV_EXITED, ex)
    finally:
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        if env is not None:
            env.close()
        logger.debug(f"UnityEnvironment worker {worker_id} done.")
        parent_conn.close()
        # CLOSED is the manager's signal (in close() and _drain_step_queue)
        # that this worker has fully shut down.
        step_queue.put(EnvironmentResponse(EnvironmentCommand.CLOSED, worker_id, None))
        step_queue.close()
| |
|
| |
|
class SubprocessEnvManager(EnvManager):
    """
    EnvManager implementation that runs each Unity environment in its own
    subprocess. Requests/responses travel over a per-worker Pipe; step
    results and exit notifications arrive on a shared Queue. Workers that
    crash with a recoverable Unity exception are restarted, subject to
    lifetime and rate-limit quotas from run_options.env_settings.
    """

    def __init__(
        self,
        env_factory: Callable[[int, List[SideChannel]], BaseEnv],
        run_options: RunOptions,
        n_env: int = 1,
    ):
        """
        :param env_factory: Builds an environment given a worker id and the
            side channels to attach. Must be picklable by cloudpickle.
        :param run_options: Full run configuration, forwarded to each worker.
        :param n_env: Number of worker subprocesses to spawn.
        """
        super().__init__()
        self.env_workers: List[UnityEnvWorker] = []
        # Workers push STEP results, ENV_EXITED and CLOSED messages here.
        self.step_queue: Queue = Queue()
        self.workers_alive = 0
        self.env_factory = env_factory
        self.run_options = run_options
        # Last config passed to set_env_parameters; re-applied via reset()
        # after worker restarts.
        self.env_parameters: Optional[Dict] = None
        # Per-worker restart timestamps within the rate-limit window, and
        # lifetime restart counts, both indexed by worker id.
        self.recent_restart_timestamps: List[List[datetime.datetime]] = [
            [] for _ in range(n_env)
        ]
        self.restart_counts: List[int] = [0] * n_env
        for worker_idx in range(n_env):
            self.env_workers.append(
                self.create_worker(
                    worker_idx, self.step_queue, env_factory, run_options
                )
            )
            self.workers_alive += 1

    @staticmethod
    def create_worker(
        worker_id: int,
        step_queue: Queue,
        env_factory: Callable[[int, List[SideChannel]], BaseEnv],
        run_options: RunOptions,
    ) -> UnityEnvWorker:
        """
        Spawn one subprocess running `worker` and return its handle.
        """
        parent_conn, child_conn = Pipe()

        # cloudpickle lets the factory (which may be a closure) cross the
        # process boundary, where the plain pickle used by Process would fail.
        pickled_env_factory = cloudpickle.dumps(env_factory)
        child_process = Process(
            target=worker,
            args=(
                child_conn,
                step_queue,
                pickled_env_factory,
                worker_id,
                run_options,
                logger.level,
            ),
        )
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)

    def _queue_steps(self) -> None:
        # Send a STEP request (with freshly computed actions) to every idle
        # worker and mark it as waiting.
        for env_worker in self.env_workers:
            if not env_worker.waiting:
                env_action_info = self._take_step(env_worker.previous_step)
                env_worker.previous_all_action_info = env_action_info
                env_worker.send(EnvironmentCommand.STEP, env_action_info)
                env_worker.waiting = True

    def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
        """
        Restart the worker that produced `first_failure`, plus any other
        workers found crashed while draining the step queue, then reset all
        environments. Re-raises if a failure is non-recoverable or a worker
        is over its restart quota.
        """
        if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
            return
        # Pause all workers and collect any additional crashes that are
        # already sitting in the queue.
        other_failures: Dict[int, Exception] = self._drain_step_queue()
        failures: Dict[int, Exception] = {
            **{first_failure.worker_id: first_failure.payload},
            **other_failures,
        }
        for worker_id, ex in failures.items():
            self._assert_worker_can_restart(worker_id, ex)
            logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
            self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
            self.restart_counts[worker_id] += 1
            self.env_workers[worker_id] = self.create_worker(
                worker_id, self.step_queue, self.env_factory, self.run_options
            )
        # Restarted workers need the current environment parameters and a
        # reset before they can step again.
        self.reset(self.env_parameters)

    def _drain_step_queue(self) -> Dict[int, Exception]:
        """
        Drains all steps out of the step queue and returns all exceptions from crashed workers.
        This will effectively pause all workers so that they won't do anything until _queue_steps is called.
        """
        all_failures = {}
        workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
        # Give workers up to a minute to flush their outstanding messages.
        deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
        while workers_still_pending and deadline > datetime.datetime.now():
            try:
                while True:
                    step: EnvironmentResponse = self.step_queue.get_nowait()
                    if step.cmd == EnvironmentCommand.ENV_EXITED:
                        # A crashed worker stays pending until its final
                        # CLOSED message (sent from the worker's `finally`
                        # block) arrives and is handled by the else branch.
                        workers_still_pending.add(step.worker_id)
                        all_failures[step.worker_id] = step.payload
                    else:
                        workers_still_pending.remove(step.worker_id)
                        self.env_workers[step.worker_id].waiting = False
            except EmptyQueueException:
                pass
        if deadline < datetime.datetime.now():
            still_waiting = {w.worker_id for w in self.env_workers if w.waiting}
            raise TimeoutError(f"Workers {still_waiting} stuck in waiting state")
        return all_failures

    def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
        """
        Checks if we can recover from an exception from a worker.
        Known Unity exceptions pass silently while restart quota remains;
        otherwise (quota exhausted, or any other exception type) the original
        exception is re-raised.
        """
        if (
            isinstance(exception, UnityCommunicationException)
            or isinstance(exception, UnityTimeOutException)
            or isinstance(exception, UnityEnvironmentException)
            or isinstance(exception, UnityCommunicatorStoppedException)
        ):
            if self._worker_has_restart_quota(worker_id):
                return
            else:
                logger.error(
                    f"Worker {worker_id} exceeded the allowed number of restarts."
                )
                raise exception
        raise exception

    def _worker_has_restart_quota(self, worker_id: int) -> bool:
        """
        True if the worker is under both the lifetime restart limit and the
        windowed rate limit. A configured value of -1 disables either check.
        """
        self._drop_old_restart_timestamps(worker_id)
        max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
        max_limit_check = (
            max_lifetime_restarts == -1
            or self.restart_counts[worker_id] < max_lifetime_restarts
        )

        rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
        rate_limit_check = (
            rate_limit_n == -1
            or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
        )

        return rate_limit_check and max_limit_check

    def _drop_old_restart_timestamps(self, worker_id: int) -> None:
        """
        Drops environment restart timestamps that are outside of the current window.
        """

        def _filter(t: datetime.datetime) -> bool:
            # Keep only timestamps newer than the configured rate-limit period.
            return t > datetime.datetime.now() - datetime.timedelta(
                seconds=self.run_options.env_settings.restarts_rate_limit_period_s
            )

        self.recent_restart_timestamps[worker_id] = list(
            filter(_filter, self.recent_restart_timestamps[worker_id])
        )

    def _step(self) -> List[EnvironmentStep]:
        """
        Queue steps on all idle workers, then wait until at least one step
        result is available; crashed workers are restarted along the way.
        """
        # Queue steps for any workers which aren't already waiting on one.
        self._queue_steps()

        worker_steps: List[EnvironmentResponse] = []
        step_workers: Set[int] = set()
        # Poll with get_nowait so all queued results are drained in one pass.
        while len(worker_steps) < 1:
            try:
                while True:
                    step: EnvironmentResponse = self.step_queue.get_nowait()
                    if step.cmd == EnvironmentCommand.ENV_EXITED:
                        # Restart the crashed worker(s); the reset performed by
                        # _restart_failed_workers invalidates any results
                        # collected so far, so discard them and re-queue.
                        self._restart_failed_workers(step)
                        worker_steps.clear()
                        step_workers.clear()
                        self._queue_steps()
                    elif step.worker_id not in step_workers:
                        # Accept at most one result per worker per call.
                        self.env_workers[step.worker_id].waiting = False
                        worker_steps.append(step)
                        step_workers.add(step.worker_id)
            except EmptyQueueException:
                pass
        step_infos = self._postprocess_steps(worker_steps)
        return step_infos

    def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
        """
        Apply `config` as environment parameters, reset every worker, and
        return each worker's first step after the reset.
        """
        # Drain in-flight step results so no worker is mid-step during reset.
        while any(ew.waiting for ew in self.env_workers):
            if not self.step_queue.empty():
                step = self.step_queue.get_nowait()
                self.env_workers[step.worker_id].waiting = False
        # Send config to environment parameters channels first.
        self.set_env_parameters(config)
        # Enqueue all reset commands before waiting on any response, so the
        # workers reset in parallel.
        for ew in self.env_workers:
            ew.send(EnvironmentCommand.RESET, config)
        # Block until every worker confirms its reset.
        for ew in self.env_workers:
            ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {}, {})
        return list(map(lambda ew: ew.previous_step, self.env_workers))

    def set_env_parameters(self, config: Optional[Dict] = None) -> None:
        """
        Sends environment parameter settings to C# via the
        EnvironmentParametersSideChannel for each worker.
        :param config: Dict of environment parameter keys and values
        """
        self.env_parameters = config
        for ew in self.env_workers:
            ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)

    def on_training_started(
        self, behavior_name: str, trainer_settings: TrainerSettings
    ) -> None:
        """
        Handle training starting for a new behavior type. Forwards the event
        to every worker (only worker 0 reports analytics).
        :param behavior_name: Name of the behavior whose training started.
        :param trainer_settings: Settings of the trainer for that behavior.
        :return:
        """
        for ew in self.env_workers:
            ew.send(
                EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings)
            )

    @property
    def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
        """Behavior specs merged from every worker, queried synchronously."""
        result: Dict[BehaviorName, BehaviorSpec] = {}
        for worker in self.env_workers:
            worker.send(EnvironmentCommand.BEHAVIOR_SPECS)
            result.update(worker.recv().payload)
        return result

    def close(self) -> None:
        """
        Request shutdown from every worker, wait up to
        WORKER_SHUTDOWN_TIMEOUT_S for their CLOSED acknowledgements, then
        forcefully terminate any process that is still alive.
        """
        logger.debug("SubprocessEnvManager closing.")
        for env_worker in self.env_workers:
            env_worker.request_close()
        # Poll the step queue for CLOSED acknowledgements until the deadline.
        deadline = time.time() + WORKER_SHUTDOWN_TIMEOUT_S
        while self.workers_alive > 0 and time.time() < deadline:
            try:
                step: EnvironmentResponse = self.step_queue.get_nowait()
                env_worker = self.env_workers[step.worker_id]
                if step.cmd == EnvironmentCommand.CLOSED and not env_worker.closed:
                    env_worker.closed = True
                    self.workers_alive -= 1
                # Other message types still sitting in the queue are discarded.
            except EmptyQueueException:
                pass
        self.step_queue.close()
        # Fallback for workers that never acknowledged shutdown.
        if self.workers_alive > 0:
            logger.error("SubprocessEnvManager had workers that didn't signal shutdown")
            for env_worker in self.env_workers:
                if not env_worker.closed and env_worker.process.is_alive():
                    env_worker.process.terminate()
                    logger.error(
                        "A SubprocessEnvManager worker did not shut down correctly so it was forcefully terminated."
                    )
        self.step_queue.join_thread()

    def _postprocess_steps(
        self, env_steps: List[EnvironmentResponse]
    ) -> List[EnvironmentStep]:
        """
        Convert raw worker StepResponses into EnvironmentSteps, updating each
        worker's previous_step and merging worker timer trees into this
        process's timers.
        """
        step_infos = []
        timer_nodes = []
        for step in env_steps:
            payload: StepResponse = step.payload
            env_worker = self.env_workers[step.worker_id]
            # Pair the new results with the actions that produced them.
            new_step = EnvironmentStep(
                payload.all_step_result,
                step.worker_id,
                env_worker.previous_all_action_info,
                payload.environment_stats,
            )
            step_infos.append(new_step)
            env_worker.previous_step = new_step

            if payload.timer_root:
                timer_nodes.append(payload.timer_root)

        if timer_nodes:
            # Merge each worker's timer tree under a shared "workers" node,
            # marked parallel since the workers ran concurrently.
            with hierarchical_timer("workers") as main_timer_node:
                for worker_timer_node in timer_nodes:
                    main_timer_node.merge(
                        worker_timer_node, root_name="worker_root", is_parallel=True
                    )

        return step_infos

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[BehaviorName, ActionInfo]:
        """
        Compute actions for every behavior in `last_step` that has a
        registered policy; behaviors without a policy are skipped.
        """
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, step_tuple in last_step.current_all_step_result.items():
            if brain_name in self.policies:
                all_action_info[brain_name] = self.policies[brain_name].get_action(
                    step_tuple[0], last_step.worker_id
                )
        return all_action_info
| |
|