from typing import Any, Dict, List, Optional, Tuple, Union, cast

import attr
import numpy as np

from mlagents.torch_utils import torch, nn, default_device
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents.trainers.torch_entities.networks import ValueNetwork, Actor, Critic
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.settings import (
    TrainerSettings,
    OffPolicyHyperparamSettings,
    ScheduleType,
    NetworkSettings,
)
from mlagents_envs.base_env import ActionSpec, ObservationSpec
from mlagents_envs.timers import timed

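# Hyperparameters specific to the DQN trainer; everything not listed here
# (buffer size, batch size, learning rate, etc.) is inherited from
# OffPolicyHyperparamSettings.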
@attr.s(auto_attribs=True)
class DQNSettings(OffPolicyHyperparamSettings):
    gamma: float = 0.99
    exploration_schedule: ScheduleType = ScheduleType.LINEAR
    exploration_initial_eps: float = 0.1
    exploration_final_eps: float = 0.05
    target_update_interval: int = 10000
    tau: float = 0.005
    steps_per_update: float = 1
    save_replay_buffer: bool = False
    reward_signal_steps_per_update: float = attr.ib()

    @reward_signal_steps_per_update.default
    def _reward_signal_steps_per_update_default(self):
        return self.steps_per_update

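# Optimizer for the DQN policy. The online Q-network is the policy's actor
# (a QNetwork, defined below); this class owns the Adam optimizer and the
# target network, which tracks the online network through soft (Polyak) updates.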
class DQNOptimizer(TorchOptimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__(policy, trainer_settings)

        # Only the online network (the policy's actor) is trained directly.
        params = list(self.policy.actor.parameters())
        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
        self.stream_names = list(self.reward_signals.keys())
        self.gammas = [_val.gamma for _val in trainer_settings.reward_signals.values()]
        self.use_dones_in_backup = {
            name: int(not self.reward_signals[name].ignore_done)
            for name in self.stream_names
        }

        self.hyperparameters: DQNSettings = cast(
            DQNSettings, trainer_settings.hyperparameters
        )
        self.tau = self.hyperparameters.tau
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )

        # Epsilon for epsilon-greedy exploration decays over a fixed horizon
        # of 20000 steps rather than trainer_settings.max_steps.
        self.decay_exploration_rate = ModelUtils.DecayedValue(
            self.hyperparameters.exploration_schedule,
            self.hyperparameters.exploration_initial_eps,
            self.hyperparameters.exploration_final_eps,
            20000,
        )

        # The target network starts as an exact copy of the online network
        # (a soft update with tau=1.0 copies the weights verbatim).
        self.q_net_target = QNetwork(
            stream_names=self.stream_names,
            observation_specs=policy.behavior_spec.observation_specs,
            network_settings=policy.network_settings,
            action_spec=policy.behavior_spec.action_spec,
        )
        ModelUtils.soft_update(self.policy.actor, self.q_net_target, 1.0)

        self.q_net_target.to(default_device())

    @property
    def critic(self):
        return self.q_net_target

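    # One gradient step on a replay batch: build one-step TD targets per
    # reward stream, minimize the Huber loss, then soft-update the target net.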
    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        exp_rate = self.decay_exploration_rate.get_value(self.policy.get_current_step())
        self.policy.actor.exploration_rate = exp_rate
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.rewards_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        actions = AgentAction.from_buffer(batch)

        dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

        current_q_values, _ = self.policy.actor.critic_pass(
            current_obs, sequence_length=self.policy.sequence_length
        )

        qloss = []
        with torch.no_grad():
            next_q_values_list, _ = self.q_net_target.critic_pass(
                next_obs, sequence_length=self.policy.sequence_length
            )
            # Greedy actions for the *next* observations, chosen from the
            # target network's Q-values (standard DQN target).
            greedy_actions = self.q_net_target.get_greedy_action(next_q_values_list)
        for name_i, name in enumerate(rewards.keys()):
            with torch.no_grad():
                next_q_values = torch.gather(
                    next_q_values_list[name], dim=1, index=greedy_actions
                ).squeeze()
                # One-step TD target: r + gamma * Q_target(s', a*); the done
                # mask stops bootstrapping for streams that honor episode ends.
                target_q_values = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[name_i]
                    * next_q_values
                )
                target_q_values = target_q_values.reshape(-1, 1)
            curr_q = torch.gather(
                current_q_values[name], dim=1, index=actions.discrete_tensor
            )
            qloss.append(torch.nn.functional.smooth_l1_loss(curr_q, target_q_values))

        loss = torch.mean(torch.stack(qloss))
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Polyak-average the online weights into the target network.
        ModelUtils.soft_update(self.policy.actor, self.q_net_target, self.tau)
        update_stats = {
            "Losses/Value Loss": loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/epsilon": exp_rate,
        }

        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))
        return update_stats

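    # The modules returned here are what gets saved to and restored from
    # checkpoints.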
    def get_modules(self):
        modules = {
            "Optimizer:value_optimizer": self.optimizer,
            "Optimizer:critic": self.critic,
        }
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules

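# Q-network that serves as both the policy's Actor (epsilon-greedy action
# selection) and its Critic (one set of Q-values per reward stream), so a
# single module can be trained, executed, and exported.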
class QNetwork(nn.Module, Actor, Critic):
    MODEL_EXPORT_VERSION = 3

    def __init__(
        self,
        stream_names: List[str],
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        exploration_initial_eps: float = 1.0,
    ):
        nn.Module.__init__(self)
        self.exploration_rate = exploration_initial_eps
        output_act_size = max(sum(action_spec.discrete_branches), 1)
        self.network_body = ValueNetwork(
            stream_names,
            observation_specs,
            network_settings,
            outputs_per_stream=output_act_size,
        )

        # Constant tensors exported alongside the model for ONNX inference.
        self.action_spec = action_spec
        self.version_number = torch.nn.Parameter(
            torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False
        )
        self.is_continuous_int_deprecated = torch.nn.Parameter(
            torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False
        )
        self.continuous_act_size_vector = torch.nn.Parameter(
            torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
        )
        self.discrete_act_size_vector = torch.nn.Parameter(
            torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
        )
        self.act_size_vector_deprecated = torch.nn.Parameter(
            torch.Tensor(
                [
                    self.action_spec.continuous_size
                    + sum(self.action_spec.discrete_branches)
                ]
            ),
            requires_grad=False,
        )
        self.memory_size_vector = torch.nn.Parameter(
            torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False
        )

    def update_normalization(self, buffer: AgentBuffer) -> None:
        self.network_body.update_normalization(buffer)

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        value_outputs, critic_mem_out = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        return value_outputs, critic_mem_out

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

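    # forward() is only exercised when exporting the model to ONNX; its return
    # values become the output tensors of the exported graph.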
    def forward(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Union[int, torch.Tensor], ...]:
        out_vals, memories = self.critic_pass(inputs, memories, sequence_length)

        export_out = [self.version_number, self.memory_size_vector]

        # The greedy action doubles as both the sampled and the deterministic
        # output, since a Q-network has no action distribution to sample from.
        disc_action_out = self.get_greedy_action(out_vals)
        deterministic_disc_action_out = self.get_greedy_action(out_vals)
        export_out += [
            disc_action_out,
            self.discrete_act_size_vector,
            deterministic_disc_action_out,
        ]
        return tuple(export_out)

    def get_random_action(self, inputs) -> torch.Tensor:
        # Uniform random action; only a single discrete branch is supported.
        # The batch size comes from the first observation tensor, not from the
        # number of observation streams.
        action_out = torch.randint(
            0, self.action_spec.discrete_branches[0], (inputs[0].shape[0], 1)
        )
        return action_out

    @staticmethod
    def get_greedy_action(q_values) -> torch.Tensor:
        # Sum Q-values across reward streams, then take the argmax action.
        all_q = torch.cat([val.unsqueeze(0) for val in q_values.values()])
        return torch.argmax(all_q.sum(dim=0), dim=1, keepdim=True)

    def get_action_and_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        deterministic=False,
    ) -> Tuple[AgentAction, Dict[str, Any], torch.Tensor]:
        run_out = {}
        # Epsilon-greedy: with probability exploration_rate, act randomly for
        # the whole batch; otherwise take the greedy action per agent.
        if not deterministic and np.random.rand() < self.exploration_rate:
            action_out = self.get_random_action(inputs)
        else:
            out_vals, _ = self.critic_pass(inputs, memories, sequence_length)
            action_out = self.get_greedy_action(out_vals)
        action_out = AgentAction(None, [action_out])
        run_out["env_action"] = action_out.to_action_tuple()
        return action_out, run_out, torch.Tensor([])