from abc import ABC, abstractmethod
from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple

from mlagents_envs.base_env import (
    DecisionSteps,
    TerminalSteps,
    BehaviorSpec,
    BehaviorName,
)
from mlagents_envs.logging_util import get_logger
from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats

from mlagents.trainers.policy import Policy
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.settings import TrainerSettings
| |
|
| | AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]] |
| | AllGroupSpec = Dict[BehaviorName, BehaviorSpec] |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class EnvironmentStep(NamedTuple):
    """Immutable record of one environment step from a single worker.

    Bundles the per-behavior (DecisionSteps, TerminalSteps) results with the
    actions that produced them and the stats the environment reported.
    """

    # Per-behavior (DecisionSteps, TerminalSteps) pairs for this step.
    current_all_step_result: AllStepResult
    # Index of the environment worker that produced this step.
    worker_id: int
    # The ActionInfo that was sent to each behavior for this step.
    brain_name_to_action_info: Dict[BehaviorName, ActionInfo]
    # Stats emitted by the environment alongside this step.
    environment_stats: EnvironmentStats

    @property
    def name_behavior_ids(self) -> Iterable[BehaviorName]:
        """All behavior names present in this step's results."""
        return self.current_all_step_result.keys()

    @staticmethod
    def empty(worker_id: int) -> "EnvironmentStep":
        """Build a placeholder step for *worker_id* with no results, actions,
        or stats."""
        return EnvironmentStep(
            current_all_step_result={},
            worker_id=worker_id,
            brain_name_to_action_info={},
            environment_stats={},
        )
| |
|
| |
|
class EnvManager(ABC):
    """Abstract interface between trainers and one or more environments.

    Owns the per-behavior policies and AgentManagers, steps/resets the
    underlying environment(s), and routes step results into the registered
    AgentManagers.
    """

    def __init__(self) -> None:
        # Latest policy registered for each behavior name.
        self.policies: Dict[BehaviorName, Policy] = {}
        # One AgentManager per behavior; receives experiences and env stats.
        self.agent_managers: Dict[BehaviorName, AgentManager] = {}
        # Step infos produced by reset(); consumed by the first get_steps()
        # call so the initial experiences are not lost.
        self.first_step_infos: List[EnvironmentStep] = []

    def set_policy(self, brain_name: BehaviorName, policy: Policy) -> None:
        """
        Register *policy* for *brain_name* and propagate it to the matching
        AgentManager, if one has been registered.
        :param brain_name: Behavior name the policy applies to.
        :param policy: The policy to register.
        """
        self.policies[brain_name] = policy
        if brain_name in self.agent_managers:
            self.agent_managers[brain_name].policy = policy

    def set_agent_manager(
        self, brain_name: BehaviorName, manager: AgentManager
    ) -> None:
        """
        Register the AgentManager that will receive experiences for *brain_name*.
        :param brain_name: Behavior name the manager handles.
        :param manager: The AgentManager to register.
        """
        self.agent_managers[brain_name] = manager

    @abstractmethod
    def _step(self) -> List[EnvironmentStep]:
        """Advance the environment(s) one step and return the new step infos."""
        pass

    @abstractmethod
    def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
        """
        Reset the environment(s) and return the initial step infos.
        :param config: Optional environment parameter overrides.
        """
        pass

    def reset(self, config: Optional[Dict] = None) -> int:
        """
        End all in-flight episodes, reset the environment(s), and stash the
        initial step infos for the next get_steps() call.
        :param config: Optional environment parameter overrides.
        :return: The number of step infos produced by the reset.
        """
        for manager in self.agent_managers.values():
            manager.end_episode()
        # Save the first step infos; they are processed lazily on the next
        # get_steps() call.
        self.first_step_infos = self._reset_env(config)
        return len(self.first_step_infos)

    @abstractmethod
    def set_env_parameters(self, config: Dict = None) -> None:
        """
        Sends environment parameter settings to C# via the
        EnvironmentParametersSideChannel.
        :param config: Dict of environment parameter keys and values
        """
        pass

    def on_training_started(
        self, behavior_name: str, trainer_settings: TrainerSettings
    ) -> None:
        """
        Handle training starting for a new behavior type. Generally nothing is necessary here.
        :param behavior_name: Name of the behavior whose training is starting.
        :param trainer_settings: Settings the trainer will use.
        :return: None
        """
        pass

    @property
    @abstractmethod
    def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
        """Mapping from behavior name to BehaviorSpec for trainable behaviors."""
        pass

    @abstractmethod
    def close(self):
        """Shut down the environment(s) and release associated resources."""
        pass

    def get_steps(self) -> List[EnvironmentStep]:
        """
        Updates the policies, steps the environments, and returns the step information from the environments.
        Calling code should pass the returned EnvironmentSteps to process_steps() after calling this.
        :return: The list of EnvironmentSteps
        """
        # If we have just reset, process the initial EnvironmentSteps first so
        # those experiences are delivered before any new stepping happens.
        if self.first_step_infos:
            self._process_step_infos(self.first_step_infos)
            self.first_step_infos = []
        # Drain each behavior's policy queue and keep only the newest policy.
        for brain_name in self.agent_managers.keys():
            _policy = None
            try:
                # get_nowait() raises AgentManagerQueue.Empty once the queue is
                # drained; the last value read is the most recent policy.
                while True:
                    _policy = self.agent_managers[brain_name].policy_queue.get_nowait()
            except AgentManagerQueue.Empty:
                # Do nothing if the queue was already empty (just polling).
                if _policy is not None:
                    self.set_policy(brain_name, _policy)
        # Step the environments with the freshly-updated policies.
        new_step_infos = self._step()
        return new_step_infos

    def process_steps(self, new_step_infos: List[EnvironmentStep]) -> int:
        """
        Route step infos (from get_steps()) into the AgentManagers.
        :param new_step_infos: The EnvironmentSteps to process.
        :return: The number of step infos processed.
        """
        num_step_infos = self._process_step_infos(new_step_infos)
        return num_step_infos

    def _process_step_infos(self, step_infos: List[EnvironmentStep]) -> int:
        """
        Feed each step info's decision/terminal steps and environment stats to
        the AgentManager registered for that behavior. Behaviors with no
        registered manager are logged and skipped.
        :param step_infos: The EnvironmentSteps to process.
        :return: The number of step infos processed.
        """
        for step_info in step_infos:
            for name_behavior_id in step_info.name_behavior_ids:
                if name_behavior_id not in self.agent_managers:
                    logger.warning(
                        "Agent manager was not created for behavior id {}.".format(
                            name_behavior_id
                        )
                    )
                    continue
                decision_steps, terminal_steps = step_info.current_all_step_result[
                    name_behavior_id
                ]
                self.agent_managers[name_behavior_id].add_experiences(
                    decision_steps,
                    terminal_steps,
                    step_info.worker_id,
                    # Fall back to an empty ActionInfo if no action was recorded
                    # for this behavior on this step.
                    step_info.brain_name_to_action_info.get(
                        name_behavior_id, ActionInfo.empty()
                    ),
                )

                self.agent_managers[name_behavior_id].record_environment_stats(
                    step_info.environment_stats, step_info.worker_id
                )
        return len(step_infos)
| |
|