Spaces:
Running
Running
| from typing import Dict, Any | |
| import torch | |
| from ding.torch_utils import to_device | |
| from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, dist_1step_td_data, dist_1step_td_error | |
| from ding.policy import RainbowDQNPolicy | |
| from ding.utils import POLICY_REGISTRY | |
| from ding.policy.common_utils import default_preprocess_learn | |
class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
    r"""
    Overview:
        Multi-discrete action space Rainbow DQN algorithms.
    """

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode: preprocess the data, compute the
            distributional n-step TD loss and run one optimizer step on the learner model.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss and priority
                - cur_lr (:obj:`float`): current learning rate
                - total_loss (:obj:`float`): the calculated loss
                - priority (:obj:`list`): the priority of samples
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Rainbow forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # reset noise of noisenet for both main model and target model
        self._reset_noise(self._learn_model)
        self._reset_noise(self._target_model)
        q_dist = self._learn_model.forward(data['obs'])['distribution']
        with torch.no_grad():
            target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
            # Noise is reset again before the learner model picks the next-step action.
            self._reset_noise(self._learn_model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
        # 'value_gamma' is optional in data; None is forwarded when it is absent.
        value_gamma = data.get('value_gamma', None)
        if isinstance(q_dist, torch.Tensor):
            # Single action head: the model output is one tensor.
            td_data = dist_nstep_td_data(
                q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
            )
            loss, td_error_per_sample = dist_nstep_td_error(
                td_data,
                self._gamma,
                self._v_min,
                self._v_max,
                self._n_atom,
                nstep=self._nstep,
                value_gamma=value_gamma
            )
            # With one head there is nothing to average; bind the name that the
            # return statement reads so this branch no longer raises NameError.
            td_error_per_sample_mean = td_error_per_sample
        else:
            # Multiple action heads: the model output is a sequence with one entry per
            # head, indexed in parallel with data['action'] and target_q_action.
            losses = []
            td_error_per_samples = []
            for i, head_q_dist in enumerate(q_dist):
                td_data = dist_nstep_td_data(
                    head_q_dist, target_q_dist[i], data['action'][i], target_q_action[i], data['reward'],
                    data['done'], data['weight']
                )
                td_loss, td_error_per_sample = dist_nstep_td_error(
                    td_data,
                    self._gamma,
                    self._v_min,
                    self._v_max,
                    self._n_atom,
                    nstep=self._nstep,
                    value_gamma=value_gamma
                )
                losses.append(td_loss)
                td_error_per_samples.append(td_error_per_sample)
            # Average over heads; the 1e-8 term keeps the denominator non-zero.
            loss = sum(losses) / (len(losses) + 1e-8)
            td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
        # ====================
        # Rainbow update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample_mean.abs().tolist(),
        }