from typing import cast

import numpy as np

from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.trainer.on_policy_trainer import OnPolicyTrainer
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.trainer.trainer_utils import get_gae
from mlagents.trainers.policy.torch_policy import TorchPolicy
from .a2c_optimizer import A2COptimizer, A2CSettings
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings

from mlagents.trainers.torch_entities.networks import SimpleActor, SharedActorCritic

logger = get_logger(__name__)

TRAINER_NAME = "a2c"


class A2CTrainer(OnPolicyTrainer):
    """The A2CTrainer is an implementation of the A2C algorithm."""

    def __init__(
        self,
        behavior_name: str,
        reward_buff_cap: int,
        trainer_settings: TrainerSettings,
        training: bool,
        load: bool,
        seed: int,
        artifact_path: str,
    ):
| | """ |
| | Responsible for collecting experiences and training A2C model. |
| | :param behavior_name: The name of the behavior associated with trainer config |
| | :param reward_buff_cap: Max reward history to track in the reward buffer |
| | :param trainer_settings: The parameters for the trainer. |
| | :param training: Whether the trainer is set for training. |
| | :param load: Whether the model should be loaded. |
| | :param seed: The seed the model will be initialized with |
| | :param artifact_path: The directory within which to store artifacts from this trainer. |
| | """ |
        super().__init__(
            behavior_name,
            reward_buff_cap,
            trainer_settings,
            training,
            load,
            seed,
            artifact_path,
        )
        self.hyperparameters: A2CSettings = cast(
            A2CSettings, self.trainer_settings.hyperparameters
        )
        self.shared_critic = self.hyperparameters.shared_critic
        self.policy: TorchPolicy = None  # type: ignore

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        Processing involves calculating value and advantage targets for the model update step.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        super()._process_trajectory(trajectory)
        agent_id = trajectory.agent_id

        agent_buffer_trajectory = trajectory.to_agentbuffer()
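        # Warn if group rewards were received; this trainer does not make use of them.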
        self._warn_if_group_reward(agent_buffer_trajectory)

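        # Update the observation normalization statistics of the actor and critic.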
        if self.is_training:
            self.policy.actor.update_normalization(agent_buffer_trajectory)
            self.optimizer.critic.update_normalization(agent_buffer_trajectory)

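        # Get value estimates for each step of the trajectory and the bootstrap value for its final step.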
        (
            value_estimates,
            value_next,
            value_memories,
        ) = self.optimizer.get_trajectory_value_estimates(
            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.done_reached and not trajectory.interrupted,
        )
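        # If the critic is recurrent, store its memories so they can be reused during the update.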
        if value_memories is not None:
            agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)

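        # Record the value estimates for each reward signal and report their means as stats.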
        for name, v in value_estimates.items():
            agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
                v
            )
            self._stats_reporter.add_stat(
                f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
                np.mean(v),
            )

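        # Evaluate all reward signals for this trajectory.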
        self.collected_rewards["environment"][agent_id] += np.sum(
            agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
        )
        for name, reward_signal in self.optimizer.reward_signals.items():
            evaluate_result = (
                reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength
            )
            agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
                evaluate_result
            )
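            # Accumulate the per-signal reward for end-of-episode reporting.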
            self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

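        # Compute GAE-style advantages and returns for each reward signal.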
        tmp_advantages = []
        tmp_returns = []
        for name in self.optimizer.reward_signals:
            bootstrap_value = value_next[name]

            local_rewards = agent_buffer_trajectory[
                RewardSignalUtil.rewards_key(name)
            ].get_batch()
            local_value_estimates = agent_buffer_trajectory[
                RewardSignalUtil.value_estimates_key(name)
            ].get_batch()

            local_advantage = get_gae(
                rewards=local_rewards,
                value_estimates=local_value_estimates,
                value_next=bootstrap_value,
                gamma=self.optimizer.reward_signals[name].gamma,
                lambd=self.hyperparameters.lambd,
            )
            local_return = local_advantage + local_value_estimates
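            # The return is later used as the target for the value estimate update.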
            agent_buffer_trajectory[RewardSignalUtil.returns_key(name)].set(
                local_return
            )
            agent_buffer_trajectory[RewardSignalUtil.advantage_key(name)].set(
                local_advantage
            )
            tmp_advantages.append(local_advantage)
            tmp_returns.append(local_return)

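        # Average advantages and returns across reward signals to get the global values.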
        global_advantages = list(
            np.mean(np.array(tmp_advantages, dtype=np.float32), axis=0)
        )
        global_returns = list(np.mean(np.array(tmp_returns, dtype=np.float32), axis=0))
        agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages)
        agent_buffer_trajectory[BufferKey.DISCOUNTED_RETURNS].set(global_returns)

        self._append_to_update_buffer(agent_buffer_trajectory)

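        # If this was a terminal trajectory, append stats and reset reward collection.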
        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)

    def create_optimizer(self) -> TorchOptimizer:
        """
        Creates an Optimizer object.
        """
        return A2COptimizer(
            cast(TorchPolicy, self.policy), self.trainer_settings
        )

    def create_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TorchPolicy:
        """
        Creates a policy with a PyTorch backend and A2C hyperparameters.
        :param parsed_behavior_id: The identifiers of the behavior being trained.
        :param behavior_spec: specifications for policy construction
        :return: policy
        """
        actor_cls = SimpleActor
        actor_kwargs = {"conditional_sigma": False, "tanh_squash": False}
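        # With a shared critic, use a SharedActorCritic with one value stream per reward signal.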
        if self.shared_critic:
            reward_signal_configs = self.trainer_settings.reward_signals
            reward_signal_names = [
                key.value for key, _ in reward_signal_configs.items()
            ]
            actor_cls = SharedActorCritic
            actor_kwargs.update({"stream_names": reward_signal_names})

        policy = TorchPolicy(
            self.seed,
            behavior_spec,
            self.trainer_settings.network_settings,
            actor_cls,
            actor_kwargs,
        )
        return policy

    @staticmethod
    def get_settings_type():
        return A2CSettings

    @staticmethod
    def get_trainer_name() -> str:
        return TRAINER_NAME


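# Returns the trainer-type and settings-type mappings used to register this trainer.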
def get_type_and_setting():
    return {A2CTrainer.get_trainer_name(): A2CTrainer}, {
        A2CTrainer.get_trainer_name(): A2CSettings
    }