from collections import namedtuple
from typing import Optional, Callable, Tuple

import torch
import numpy as np

from ding.envs import BaseEnv, BaseEnvManager
from ding.torch_utils import to_tensor, to_item
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank, broadcast_object_list
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
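
# The `ding.*` modules above come from DI-engine, the RL framework that LightZero builds on.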


@SERIAL_EVALUATOR_REGISTRY.register('alphazero')
class AlphaZeroEvaluator(ISerialEvaluator):
    """
    Overview:
        AlphaZero Evaluator.
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
    Property:
        env, policy
    """

    def __init__(
            self,
            eval_freq: int = 1000,
            n_evaluator_episode: int = 3,
            stop_value: float = 1e6,
            env: BaseEnv = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'evaluator',
            env_config=None,
    ) -> None:
        """
        Overview:
            Init the AlphaZero evaluator according to the input arguments.
        Arguments:
            - eval_freq (:obj:`int`): Evaluation frequency in terms of training iterations.
            - n_evaluator_episode (:obj:`int`): The total number of episodes to evaluate.
            - stop_value (:obj:`float`): The reward threshold above which training is regarded as converged.
            - env (:obj:`BaseEnvManager`): The env for the evaluation; the BaseEnvManager object or \
                its derivatives are supported.
            - policy (:obj:`namedtuple`): The api namedtuple of the eval_mode policy to be evaluated.
            - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for scalar summaries; if None, \
                one is built internally (rank 0 only).
            - exp_name (:obj:`str`): Experiment name, which is used to indicate the output directory.
            - instance_name (:obj:`Optional[str]`): Name of this instance.
            - env_config: Config of the environment.
        """
        self._eval_freq = eval_freq
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._end_flag = False
        self._env_config = env_config
        # Logger (Monitor will be initialized in the policy setter).
        # Only the rank 0 process needs the monitor and tb_logger; other ranks only need a
        # text logger for terminal output.
        if get_rank() == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
                )
        else:
            self._logger, self._tb_logger = None, None  # so close() works elegantly on non-zero ranks
        self.reset(policy, env)

        self._timer = EasyTimer()
        self._default_n_episode = n_evaluator_episode
        self._stop_value = stop_value

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's environment. In some cases, we need the evaluator to use the same \
            policy in different environments; reset_env supports that.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the \
            new passed-in environment and launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the evaluator's policy. In some cases, we need the evaluator to work in the same \
            environment but with a different policy; reset_policy supports that.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the new passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The api namedtuple of the eval_mode policy.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
        self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's policy and environment, and use them to collect evaluation data.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the new passed-in \
            environment and launch it.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the new passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The api namedtuple of the eval_mode policy.
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)
        self._max_eval_reward = float("-inf")
        self._last_eval_iter = -1
        self._end_flag = False

    def close(self) -> None:
        """
        Overview:
            Close the evaluator. If end_flag is False, close the environment, then flush \
            and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the evaluator. __del__ is called automatically \
            to destroy the evaluator instance when the evaluator finishes its work.
        """
        self.close()

    def should_eval(self, train_iter: int) -> bool:
        """
        Overview:
            Determine whether to start evaluation: return True if at least eval_freq training \
            iterations have passed since the last evaluation (evaluation is also triggered at \
            train_iter == 0), otherwise return False.
        Arguments:
            - train_iter (:obj:`int`): Current training iteration.
        Returns:
            - should_eval (:obj:`bool`): Whether the evaluator should run now.
        """
        if train_iter == self._last_eval_iter:
            return False
        if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True

    def eval(
            self,
            save_ckpt_fn: Callable = None,
            train_iter: int = -1,
            envstep: int = -1,
            n_episode: Optional[int] = None,
            force_render: bool = False,
    ) -> Tuple[bool, dict]:
        """
        Overview:
            Evaluate the policy and store the best checkpoint, based on whether it reaches the \
            highest historical reward.
        Arguments:
            - save_ckpt_fn (:obj:`Callable`): Checkpoint-saving function, triggered when a new best \
                reward is reached.
            - train_iter (:obj:`int`): Current training iteration.
            - envstep (:obj:`int`): Current env interaction step.
            - n_episode (:obj:`Optional[int]`): Number of evaluation episodes; defaults to \
                n_evaluator_episode.
            - force_render (:obj:`bool`): Whether to force rendering during evaluation (currently unused).
        Returns:
            - stop_flag (:obj:`bool`): Whether this training program can be ended.
            - episode_info (:obj:`dict`): Information of the current evaluation episodes.
        """
        # Evaluation only runs on rank 0; other ranks receive the results via the broadcast below.
        stop_flag, episode_info, return_info = False, None, []
        if get_rank() == 0:
            if n_episode is None:
                n_episode = self._default_n_episode
            assert n_episode is not None, "please indicate eval n_episode"
            envstep_count = 0
            eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
            self._env.reset()
            self._policy.reset()
            with self._timer:
                while not eval_monitor.is_finished():
                    obs = self._env.ready_obs
                    # ==============================================================
                    # policy forward
                    # ==============================================================
                    policy_output = self._policy.forward(obs)
                    actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                    # ==============================================================
                    # Interact with env.
                    # ==============================================================
                    timesteps = self._env.step(actions)
                    timesteps = to_tensor(timesteps, dtype=torch.float32)
                    for env_id, t in timesteps.items():
                        if t.info.get('abnormal', False):
                            # If there is an abnormal timestep, reset all the related variables (including this env).
                            self._policy.reset([env_id])
                            continue
                        if t.done:
                            # Env reset is done by env_manager automatically.
                            self._policy.reset([env_id])
                            reward = t.info['eval_episode_return']
                            saved_info = {'eval_episode_return': t.info['eval_episode_return']}
                            if 'episode_info' in t.info:
                                saved_info.update(t.info['episode_info'])
                            eval_monitor.update_info(env_id, saved_info)
                            eval_monitor.update_reward(env_id, reward)
                            return_info.append(t.info)
                            self._logger.info(
                                "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
                                    env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
                                )
                            )
                        envstep_count += 1
            duration = self._timer.value
            episode_return = eval_monitor.get_episode_return()
            info = {
                'train_iter': train_iter,
                'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
                'episode_count': n_episode,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / n_episode,
                'evaluate_time': duration,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_time_per_episode': duration / n_episode,
                'reward_mean': np.mean(episode_return),
                'reward_std': np.std(episode_return),
                'reward_max': np.max(episode_return),
                'reward_min': np.min(episode_return),
                # 'each_reward': episode_return,
            }
            episode_info = eval_monitor.get_episode_info()
            if episode_info is not None:
                info.update(episode_info)
            self._logger.info(self._logger.get_tabulate_vars_hor(info))
            # self._logger.info(self._logger.get_tabulate_vars(info))
            for k, v in info.items():
                if k in ['train_iter', 'ckpt_name', 'each_reward']:
                    continue
                if not np.isscalar(v):
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
            eval_reward = np.mean(episode_return)
            if eval_reward > self._max_eval_reward:
                if save_ckpt_fn:
                    save_ckpt_fn('ckpt_best.pth.tar')
                self._max_eval_reward = eval_reward
            stop_flag = eval_reward >= self._stop_value and train_iter > 0
            if stop_flag:
                self._logger.info(
                    "[LightZero serial pipeline] " +
                    "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
                    ", so your AlphaZero agent has converged; you can refer to " +
                    "'log/evaluator/evaluator_logger.txt' for details."
                )

        if get_world_size() > 1:
            # Share rank 0's results with all other ranks.
            objects = [stop_flag, episode_info]
            broadcast_object_list(objects, src=0)
            stop_flag, episode_info = objects
        episode_info = to_item(episode_info)
        return stop_flag, episode_info
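

# A minimal usage sketch (illustrative, not part of the original file). It assumes a
# LightZero/DI-engine training setup where `env_manager`, `policy`, `tb_logger`,
# `learner` and `collector` have already been built by the entry pipeline; all of
# those names are hypothetical here, so the sketch is kept as comments.
#
#     evaluator = AlphaZeroEvaluator(
#         eval_freq=1000,
#         n_evaluator_episode=3,
#         stop_value=1.0,
#         env=env_manager,
#         policy=policy.eval_mode,
#         tb_logger=tb_logger,
#         exp_name='my_alphazero_experiment',
#     )
#     while True:
#         if evaluator.should_eval(learner.train_iter):
#             stop_flag, episode_info = evaluator.eval(
#                 learner.save_checkpoint, learner.train_iter, collector.envstep
#             )
#             if stop_flag:
#                 break
#         # ... one collection/training step here ...
#     evaluator.close()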