| | import sys |
| | import numpy as np |
| | from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union |
| | from collections import defaultdict, Counter |
| | import queue |
| | from mlagents.torch_utils import torch |
| |
|
| | from mlagents_envs.base_env import ( |
| | ActionTuple, |
| | DecisionSteps, |
| | DecisionStep, |
| | TerminalSteps, |
| | TerminalStep, |
| | ) |
| | from mlagents_envs.side_channel.stats_side_channel import ( |
| | StatsAggregationMethod, |
| | EnvironmentStats, |
| | ) |
| | from mlagents.trainers.exception import UnityTrainerException |
| | from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience |
| | from mlagents.trainers.policy import Policy |
| | from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs |
| | from mlagents.trainers.stats import StatsReporter |
| | from mlagents.trainers.behavior_id_utils import ( |
| | get_global_agent_id, |
| | get_global_group_id, |
| | GlobalAgentId, |
| | GlobalGroupId, |
| | ) |
| | from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple |
| | from mlagents.trainers.torch_entities.utils import ModelUtils |
| |
|
| | T = TypeVar("T") |
| |
|
| |
|
| | class AgentProcessor: |
| | """ |
| | AgentProcessor contains a dictionary per-agent trajectory buffers. The buffers are indexed by agent_id. |
| | Buffer also contains an update_buffer that corresponds to the buffer used when updating the model. |
| | One AgentProcessor should be created per agent group. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | policy: Policy, |
| | behavior_id: str, |
| | stats_reporter: StatsReporter, |
| | max_trajectory_length: int = sys.maxsize, |
| | ): |
| | """ |
| | Create an AgentProcessor. |
| | |
| | :param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory |
| | when it is finished. |
| | :param policy: Policy instance associated with this AgentProcessor. |
| | :param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer. |
| | :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer. |
| | """ |
| | self._experience_buffers: Dict[ |
| | GlobalAgentId, List[AgentExperience] |
| | ] = defaultdict(list) |
| | self._last_step_result: Dict[GlobalAgentId, Tuple[DecisionStep, int]] = {} |
| | |
| | |
| | |
| | self._current_group_obs: Dict[ |
| | GlobalGroupId, Dict[GlobalAgentId, List[np.ndarray]] |
| | ] = defaultdict(lambda: defaultdict(list)) |
| | |
| | |
| | |
| | self._group_status: Dict[ |
| | GlobalGroupId, Dict[GlobalAgentId, AgentStatus] |
| | ] = defaultdict(lambda: defaultdict(None)) |
| | |
| | |
| | self._last_take_action_outputs: Dict[GlobalAgentId, ActionInfoOutputs] = {} |
| |
|
| | self._episode_steps: Counter = Counter() |
| | self._episode_rewards: Dict[GlobalAgentId, float] = defaultdict(float) |
| | self._stats_reporter = stats_reporter |
| | self._max_trajectory_length = max_trajectory_length |
| | self._trajectory_queues: List[AgentManagerQueue[Trajectory]] = [] |
| | self._behavior_id = behavior_id |
| |
|
| | |
| | |
| | self.policy = policy |
| |
|
| | def add_experiences( |
| | self, |
| | decision_steps: DecisionSteps, |
| | terminal_steps: TerminalSteps, |
| | worker_id: int, |
| | previous_action: ActionInfo, |
| | ) -> None: |
| | """ |
| | Adds experiences to each agent's experience history. |
| | :param decision_steps: current DecisionSteps. |
| | :param terminal_steps: current TerminalSteps. |
| | :param previous_action: The outputs of the Policy's get_action method. |
| | """ |
| | take_action_outputs = previous_action.outputs |
| | if take_action_outputs: |
| | try: |
| | for _entropy in take_action_outputs["entropy"]: |
| | if isinstance(_entropy, torch.Tensor): |
| | _entropy = ModelUtils.to_numpy(_entropy) |
| | self._stats_reporter.add_stat("Policy/Entropy", _entropy) |
| | except KeyError: |
| | pass |
| |
|
| | |
| | action_global_agent_ids = [ |
| | get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids |
| | ] |
| | for global_id in action_global_agent_ids: |
| | if global_id in self._last_step_result: |
| | self._last_take_action_outputs[global_id] = take_action_outputs |
| |
|
| | |
| | |
| | |
| | for terminal_step in terminal_steps.values(): |
| | self._add_group_status_and_obs(terminal_step, worker_id) |
| | for terminal_step in terminal_steps.values(): |
| | local_id = terminal_step.agent_id |
| | global_id = get_global_agent_id(worker_id, local_id) |
| | self._process_step( |
| | terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id] |
| | ) |
| |
|
| | |
| | |
| | |
| | for ongoing_step in decision_steps.values(): |
| | self._add_group_status_and_obs(ongoing_step, worker_id) |
| | for ongoing_step in decision_steps.values(): |
| | local_id = ongoing_step.agent_id |
| | self._process_step( |
| | ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id] |
| | ) |
| | |
| | |
| | for terminal_step in terminal_steps.values(): |
| | local_id = terminal_step.agent_id |
| | global_id = get_global_agent_id(worker_id, local_id) |
| | self._clear_group_status_and_obs(global_id) |
| |
|
| | for _gid in action_global_agent_ids: |
| | |
| | |
| | if _gid in self._last_step_result: |
| | if "action" in take_action_outputs: |
| | self.policy.save_previous_action( |
| | [_gid], take_action_outputs["action"] |
| | ) |
| |
|
| | def _add_group_status_and_obs( |
| | self, step: Union[TerminalStep, DecisionStep], worker_id: int |
| | ) -> None: |
| | """ |
| | Takes a TerminalStep or DecisionStep and adds the information in it |
| | to self.group_status. This information can then be retrieved |
| | when constructing trajectories to get the status of group mates. Also stores the current |
| | observation into current_group_obs, to be used to get the next group observations |
| | for bootstrapping. |
| | :param step: TerminalStep or DecisionStep |
| | :param worker_id: Worker ID of this particular environment. Used to generate a |
| | global group id. |
| | """ |
| | global_agent_id = get_global_agent_id(worker_id, step.agent_id) |
| | stored_decision_step, idx = self._last_step_result.get( |
| | global_agent_id, (None, None) |
| | ) |
| | stored_take_action_outputs = self._last_take_action_outputs.get( |
| | global_agent_id, None |
| | ) |
| | if stored_decision_step is not None and stored_take_action_outputs is not None: |
| | |
| | |
| | if step.group_id > 0: |
| | global_group_id = get_global_group_id(worker_id, step.group_id) |
| | stored_actions = stored_take_action_outputs["action"] |
| | action_tuple = ActionTuple( |
| | continuous=stored_actions.continuous[idx], |
| | discrete=stored_actions.discrete[idx], |
| | ) |
| | group_status = AgentStatus( |
| | obs=stored_decision_step.obs, |
| | reward=step.reward, |
| | action=action_tuple, |
| | done=isinstance(step, TerminalStep), |
| | ) |
| | self._group_status[global_group_id][global_agent_id] = group_status |
| | self._current_group_obs[global_group_id][global_agent_id] = step.obs |
| |
|
| | def _clear_group_status_and_obs(self, global_id: GlobalAgentId) -> None: |
| | """ |
| | Clears an agent from self._group_status and self._current_group_obs. |
| | """ |
| | self._delete_in_nested_dict(self._current_group_obs, global_id) |
| | self._delete_in_nested_dict(self._group_status, global_id) |
| |
|
| | def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None: |
| | for _manager_id in list(nested_dict.keys()): |
| | _team_group = nested_dict[_manager_id] |
| | self._safe_delete(_team_group, key) |
| | if not _team_group: |
| | self._safe_delete(nested_dict, _manager_id) |
| |
|
| | def _process_step( |
| | self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int |
| | ) -> None: |
| | terminated = isinstance(step, TerminalStep) |
| | global_agent_id = get_global_agent_id(worker_id, step.agent_id) |
| | global_group_id = get_global_group_id(worker_id, step.group_id) |
| | stored_decision_step, idx = self._last_step_result.get( |
| | global_agent_id, (None, None) |
| | ) |
| | stored_take_action_outputs = self._last_take_action_outputs.get( |
| | global_agent_id, None |
| | ) |
| | if not terminated: |
| | |
| | self._last_step_result[global_agent_id] = (step, index) |
| |
|
| | |
| | if stored_decision_step is not None and stored_take_action_outputs is not None: |
| | obs = stored_decision_step.obs |
| | if self.policy.use_recurrent: |
| | memory = self.policy.retrieve_previous_memories([global_agent_id])[0, :] |
| | else: |
| | memory = None |
| | done = terminated |
| | interrupted = step.interrupted if terminated else False |
| | |
| | stored_actions = stored_take_action_outputs["action"] |
| | action_tuple = ActionTuple( |
| | continuous=stored_actions.continuous[idx], |
| | discrete=stored_actions.discrete[idx], |
| | ) |
| | try: |
| | stored_action_probs = stored_take_action_outputs["log_probs"] |
| | if not isinstance(stored_action_probs, LogProbsTuple): |
| | stored_action_probs = stored_action_probs.to_log_probs_tuple() |
| | log_probs_tuple = LogProbsTuple( |
| | continuous=stored_action_probs.continuous[idx], |
| | discrete=stored_action_probs.discrete[idx], |
| | ) |
| | except KeyError: |
| | log_probs_tuple = LogProbsTuple.empty_log_probs() |
| |
|
| | action_mask = stored_decision_step.action_mask |
| | prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :] |
| |
|
| | |
| | group_statuses = [] |
| | for _id, _mate_status in self._group_status[global_group_id].items(): |
| | if _id != global_agent_id: |
| | group_statuses.append(_mate_status) |
| |
|
| | experience = AgentExperience( |
| | obs=obs, |
| | reward=step.reward, |
| | done=done, |
| | action=action_tuple, |
| | action_probs=log_probs_tuple, |
| | action_mask=action_mask, |
| | prev_action=prev_action, |
| | interrupted=interrupted, |
| | memory=memory, |
| | group_status=group_statuses, |
| | group_reward=step.group_reward, |
| | ) |
| | |
| | self._experience_buffers[global_agent_id].append(experience) |
| | self._episode_rewards[global_agent_id] += step.reward |
| | if not terminated: |
| | self._episode_steps[global_agent_id] += 1 |
| |
|
| | |
| | if ( |
| | len(self._experience_buffers[global_agent_id]) |
| | >= self._max_trajectory_length |
| | or terminated |
| | ): |
| | next_obs = step.obs |
| | next_group_obs = [] |
| | for _id, _obs in self._current_group_obs[global_group_id].items(): |
| | if _id != global_agent_id: |
| | next_group_obs.append(_obs) |
| |
|
| | trajectory = Trajectory( |
| | steps=self._experience_buffers[global_agent_id], |
| | agent_id=global_agent_id, |
| | next_obs=next_obs, |
| | next_group_obs=next_group_obs, |
| | behavior_id=self._behavior_id, |
| | ) |
| | for traj_queue in self._trajectory_queues: |
| | traj_queue.put(trajectory) |
| | self._experience_buffers[global_agent_id] = [] |
| | if terminated: |
| | |
| | self._stats_reporter.add_stat( |
| | "Environment/Episode Length", |
| | self._episode_steps.get(global_agent_id, 0), |
| | ) |
| | self._clean_agent_data(global_agent_id) |
| |
|
| | def _clean_agent_data(self, global_id: GlobalAgentId) -> None: |
| | """ |
| | Removes the data for an Agent. |
| | """ |
| | self._safe_delete(self._experience_buffers, global_id) |
| | self._safe_delete(self._last_take_action_outputs, global_id) |
| | self._safe_delete(self._last_step_result, global_id) |
| | self._safe_delete(self._episode_steps, global_id) |
| | self._safe_delete(self._episode_rewards, global_id) |
| | self.policy.remove_previous_action([global_id]) |
| | self.policy.remove_memories([global_id]) |
| |
|
| | def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None: |
| | """ |
| | Safe removes data from a dictionary. If not found, |
| | don't delete. |
| | """ |
| | if key in my_dictionary: |
| | del my_dictionary[key] |
| |
|
| | def publish_trajectory_queue( |
| | self, trajectory_queue: "AgentManagerQueue[Trajectory]" |
| | ) -> None: |
| | """ |
| | Adds a trajectory queue to the list of queues to publish to when this AgentProcessor |
| | assembles a Trajectory |
| | :param trajectory_queue: Trajectory queue to publish to. |
| | """ |
| | self._trajectory_queues.append(trajectory_queue) |
| |
|
| | def end_episode(self) -> None: |
| | """ |
| | Ends the episode, terminating the current trajectory and stopping stats collection for that |
| | episode. Used for forceful reset (e.g. in curriculum or generalization training.) |
| | """ |
| | all_gids = list(self._experience_buffers.keys()) |
| | for _gid in all_gids: |
| | self._clean_agent_data(_gid) |
| |
|
| |
|
| | class AgentManagerQueue(Generic[T]): |
| | """ |
| | Queue used by the AgentManager. Note that we make our own class here because in most implementations |
| | deque is sufficient and faster. However, if we want to switch to multiprocessing, we'll need to change |
| | out this implementation. |
| | """ |
| |
|
| | class Empty(Exception): |
| | """ |
| | Exception for when the queue is empty. |
| | """ |
| |
|
| | pass |
| |
|
| | def __init__(self, behavior_id: str, maxlen: int = 0): |
| | """ |
| | Initializes an AgentManagerQueue. Note that we can give it a behavior_id so that it can be identified |
| | separately from an AgentManager. |
| | """ |
| | self._maxlen: int = maxlen |
| | self._queue: queue.Queue = queue.Queue(maxsize=maxlen) |
| | self._behavior_id = behavior_id |
| |
|
| | @property |
| | def maxlen(self): |
| | """ |
| | The maximum length of the queue. |
| | :return: Maximum length of the queue. |
| | """ |
| | return self._maxlen |
| |
|
| | @property |
| | def behavior_id(self): |
| | """ |
| | The Behavior ID of this queue. |
| | :return: Behavior ID associated with the queue. |
| | """ |
| | return self._behavior_id |
| |
|
| | def qsize(self) -> int: |
| | """ |
| | Returns the approximate size of the queue. Note that values may differ |
| | depending on the underlying queue implementation. |
| | """ |
| | return self._queue.qsize() |
| |
|
| | def empty(self) -> bool: |
| | return self._queue.empty() |
| |
|
| | def get_nowait(self) -> T: |
| | """ |
| | Gets the next item from the queue, throwing an AgentManagerQueue.Empty exception |
| | if the queue is empty. |
| | """ |
| | try: |
| | return self._queue.get_nowait() |
| | except queue.Empty: |
| | raise self.Empty("The AgentManagerQueue is empty.") |
| |
|
| | def put(self, item: T) -> None: |
| | self._queue.put(item) |
| |
|
| |
|
| | class AgentManager(AgentProcessor): |
| | """ |
| | An AgentManager is an AgentProcessor that also holds a single trajectory and policy queue. |
| | Note: this leaves room for adding AgentProcessors that publish multiple trajectory queues. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | policy: Policy, |
| | behavior_id: str, |
| | stats_reporter: StatsReporter, |
| | max_trajectory_length: int = sys.maxsize, |
| | threaded: bool = True, |
| | ): |
| | super().__init__(policy, behavior_id, stats_reporter, max_trajectory_length) |
| | trajectory_queue_len = 20 if threaded else 0 |
| | self.trajectory_queue: AgentManagerQueue[Trajectory] = AgentManagerQueue( |
| | self._behavior_id, maxlen=trajectory_queue_len |
| | ) |
| | |
| | |
| | self.policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue( |
| | self._behavior_id, maxlen=0 |
| | ) |
| | self.publish_trajectory_queue(self.trajectory_queue) |
| |
|
| | def record_environment_stats( |
| | self, env_stats: EnvironmentStats, worker_id: int |
| | ) -> None: |
| | """ |
| | Pass stats from the environment to the StatsReporter. |
| | Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used. |
| | The worker_id is used to determine whether StatsReporter.set_stat should be used. |
| | |
| | :param env_stats: |
| | :param worker_id: |
| | :return: |
| | """ |
| | for stat_name, value_list in env_stats.items(): |
| | for val, agg_type in value_list: |
| | if agg_type == StatsAggregationMethod.AVERAGE: |
| | self._stats_reporter.add_stat(stat_name, val, agg_type) |
| | elif agg_type == StatsAggregationMethod.SUM: |
| | self._stats_reporter.add_stat(stat_name, val, agg_type) |
| | elif agg_type == StatsAggregationMethod.HISTOGRAM: |
| | self._stats_reporter.add_stat(stat_name, val, agg_type) |
| | elif agg_type == StatsAggregationMethod.MOST_RECENT: |
| | |
| | |
| | if worker_id == 0: |
| | self._stats_reporter.set_stat(stat_name, val) |
| | else: |
| | raise UnityTrainerException( |
| | f"Unknown StatsAggregationMethod encountered. {agg_type}" |
| | ) |
| |
|