Spaces:
Running
Running
| from typing import Union, Dict | |
| import uuid | |
| import copy | |
| import os | |
| import os.path as osp | |
| from abc import abstractmethod | |
| from easydict import EasyDict | |
| from tabulate import tabulate | |
| from ding.league.player import ActivePlayer, HistoricalPlayer, create_player | |
| from ding.league.shared_payoff import create_payoff | |
| from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY, \ | |
| deep_merge_dicts | |
| from .metric import LeagueMetricEnv | |
class BaseLeague:
    """
    Overview:
        League, proposed by Google Deepmind AlphaStar. Can manage multiple players in one league.
        This base class owns the player lists, the shared payoff and the rating environment; the
        job-selection, mutation and update policies are left to subclasses (``_get_job_info``,
        ``_mutate_player``, ``_update_player``).
    Interface:
        default_config, get_job_info, judge_snapshot, update_active_player, finish_job, \
        get_player_by_id, save_checkpoint, player_rank

    .. note::
        In ``__init__`` method, the league also initializes its players (in ``_init_players`` method).
    """

    @classmethod
    def default_config(cls: type) -> 'EasyDict':
        """
        Overview:
            Build the default config of this league class.
        Returns:
            - cfg (:obj:`EasyDict`): A deep copy of ``cls.config`` with an extra ``cfg_type`` field \
                set to ``<class name> + 'Dict'``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        league_type='base',
        import_names=["ding.league.base_league"],
        # ---player----
        # "player_category" is just a name. Depends on the env.
        # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
        player_category=['default'],
        # Support different types of active players for solo and battle league.
        # For solo league, supports ['solo_active_player'].
        # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
        # active_players=dict(),
        # "use_pretrain" means whether to use pretrain model to initialize active player.
        use_pretrain=False,
        # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
        # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
        # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
        # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
        use_pretrain_init_historical=False,
        pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
        # ---payoff---
        payoff=dict(
            # Supports ['battle']
            type='battle',
            decay=0.99,
            min_win_rate_games=8,
        ),
        metric=dict(
            mu=0,
            sigma=25 / 3,
            beta=25 / 3 / 2,
            tau=0.0,
            draw_probability=0.02,
        ),
    )

    def __init__(self, cfg: 'EasyDict') -> None:
        """
        Overview:
            Initialization method. Merges ``cfg`` over the class default config, prepares the policy
            directory, creates the shared payoff and the rating environment, then creates all players.
        Arguments:
            - cfg (:obj:`EasyDict`): League config. Must provide ``path_policy``.
        """
        self.cfg = deep_merge_dicts(self.default_config(), cfg)
        self.path_policy = cfg.path_policy
        # ``makedirs(exist_ok=True)`` also creates missing parent directories and does not fail when
        # another process creates the directory between an existence check and the creation.
        os.makedirs(self.path_policy, exist_ok=True)

        self.league_uid = str(uuid.uuid1())
        # TODO dict players
        self.active_players = []
        self.historical_players = []
        self.player_path = "./league"
        self.payoff = create_payoff(self.cfg.payoff)

        metric_cfg = self.cfg.metric
        # NOTE(review): the previous call passed ``tau`` and ``draw_probability`` as the 3rd/4th positional
        # arguments and never used ``metric.beta``. Assuming ``LeagueMetricEnv`` follows the TrueSkill
        # ``(mu, sigma, beta, tau, draw_probability)`` parameter order (confirm against ``.metric``), that
        # shifted both values by one slot. Keywords make the mapping explicit.
        metric_kwargs = dict(
            mu=metric_cfg.mu,
            sigma=metric_cfg.sigma,
            tau=metric_cfg.tau,
            draw_probability=metric_cfg.draw_probability,
        )
        # ``beta`` is only forwarded when configured, so user configs without it keep working.
        if 'beta' in metric_cfg:
            metric_kwargs['beta'] = metric_cfg.beta
        self.metric_env = LeagueMetricEnv(**metric_kwargs)

        self._active_players_lock = LockContext(type_=LockContextType.THREAD_LOCK)
        self._init_players()

    def _init_players(self) -> None:
        """
        Overview:
            Initialize players (active & historical) in the league.
        """
        # Add different types of active players for each player category, according to ``cfg.active_players``.
        for cate in self.cfg.player_category:  # Player's category (Depends on the env)
            for k, n in self.cfg.active_players.items():  # Active player's type
                for i in range(n):  # This type's active player number
                    name = '{}_{}_{}'.format(k, cate, i)
                    ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
                    player = create_player(
                        self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0, self.metric_env.create_rating()
                    )
                    if self.cfg.use_pretrain:
                        # Seed the player's checkpoint with the pretrained one of its category.
                        self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
                    self.active_players.append(player)
                    self.payoff.add_player(player)

        # Add pretrain player as the initial HistoricalPlayer for each player category.
        if self.cfg.use_pretrain_init_historical:
            # The config key containing 'main_player' does not depend on the category: resolve it once.
            main_player_name = [k for k in self.cfg.keys() if 'main_player' in k]
            assert len(main_player_name) == 1, main_player_name
            main_player_name = main_player_name[0]
            for cate in self.cfg.player_category:
                name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate)
                parent_name = '{}_{}_0'.format(main_player_name, cate)
                hp = HistoricalPlayer(
                    self.cfg.get(main_player_name),
                    cate,
                    self.payoff,
                    self.cfg.pretrain_checkpoint_path[cate],
                    name,
                    0,
                    self.metric_env.create_rating(),
                    parent_id=parent_name
                )
                self.historical_players.append(hp)
                self.payoff.add_player(hp)

        # Save active players' ``player_id`` & ``player_ckpt``.
        self.active_players_ids = [p.player_id for p in self.active_players]
        self.active_players_ckpts = [p.checkpoint_path for p in self.active_players]
        # Validate active players are unique by ``player_id``.
        assert len(self.active_players_ids) == len(set(self.active_players_ids))

    def get_job_info(self, player_id: str = None, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get info dict of the job which is to be launched to an active player.
        Arguments:
            - player_id (:obj:`str`): The active player's id. Defaults to the first active player when ``None``.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info.
        ReturnsKeys:
            - necessary: ``launch_player`` (the active player)
        """
        if player_id is None:
            player_id = self.active_players_ids[0]
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            job_info = self._get_job_info(player, eval_flag)
            # The subclass hook must report the player the job was created for.
            assert 'launch_player' in job_info.keys() and job_info['launch_player'] == player.player_id
        return job_info

    @abstractmethod
    def _get_job_info(self, player: 'ActivePlayer', eval_flag: bool = False) -> dict:
        """
        Overview:
            Real `get_job` method. Called by ``get_job_info``. Must be implemented by subclasses.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player to be launched a job.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info. Should include keys ['launch_player'].
        """
        raise NotImplementedError

    def judge_snapshot(self, player_id: str, force: bool = False) -> bool:
        """
        Overview:
            Judge whether a player is trained enough for snapshot. If yes, call player's ``snapshot``, create a
            historical player(prepare the checkpoint and add it to the shared payoff), then mutate it, and return True.
            Otherwise, return False.
        Arguments:
            - player_id (:obj:`str`): The active player's id.
            - force (:obj:`bool`): Snapshot regardless of ``player.is_trained_enough()``.
        Returns:
            - snapshot_or_not (:obj:`bool`): Whether the active player is snapshotted.
        """
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            if force or player.is_trained_enough():
                # Snapshot
                hp = player.snapshot(self.metric_env)
                self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
                self.historical_players.append(hp)
                self.payoff.add_player(hp)
                # Mutate
                self._mutate_player(player)
                return True
            else:
                return False

    @abstractmethod
    def _mutate_player(self, player: 'ActivePlayer') -> None:
        """
        Overview:
            Players have the probability to mutate, e.g. Reset network parameters.
            Called by ``self.judge_snapshot``. Must be implemented by subclasses.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that may mutate.
        """
        raise NotImplementedError

    def update_active_player(self, player_info: dict) -> None:
        """
        Overview:
            Update an active player's info. A ``ValueError`` (e.g. an unknown ``player_id``) is printed
            and swallowed instead of being raised.
        Arguments:
            - player_info (:obj:`dict`): Info dict of the player which is to be updated.
        ArgumentsKeys:
            - necessary: `player_id`, `train_iteration`
        """
        try:
            idx = self.active_players_ids.index(player_info['player_id'])
            player = self.active_players[idx]
            return self._update_player(player, player_info)
        except ValueError as e:
            # Best-effort by design: report and carry on.
            print(e)

    @abstractmethod
    def _update_player(self, player: 'ActivePlayer', player_info: dict) -> None:
        """
        Overview:
            Update an active player. Called by ``self.update_active_player``. Must be implemented by subclasses.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that will be updated.
            - player_info (:obj:`dict`): Info dict of the active player which is to be updated.
        """
        raise NotImplementedError

    def finish_job(self, job_info: dict) -> None:
        """
        Overview:
            Finish current job. Update shared payoff to record the game results. For evaluation jobs
            (truthy ``eval_flag``), also update both players' ratings through ``metric_env.rate_1vs1``.
        Arguments:
            - job_info (:obj:`dict`): A dict containing job result information.
        ArgumentsKeys:
            - optional: ``eval_flag``
            - necessary when ``eval_flag`` is truthy: ``player_id`` (home id, away id), ``result``
        """
        # TODO(nyz) more fine-grained job info
        self.payoff.update(job_info)
        if not job_info.get('eval_flag'):
            return

        home_id, away_id = job_info['player_id']
        home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id)
        results = job_info['result']
        # Results may come as a list of lists; flatten one level. The emptiness check avoids an
        # ``IndexError`` on ``results[0]`` when a job reports no games.
        if results and isinstance(results[0], list):
            results = [single for group in results for single in group]
        home_player.rating, away_player.rating = self.metric_env.rate_1vs1(
            home_player.rating, away_player.rating, result=results
        )

    def get_player_by_id(self, player_id: str) -> 'Player':  # noqa
        """
        Overview:
            Look up a player by id. Ids containing ``'historical'`` are searched among the historical
            players, all other ids among the active players.
        Arguments:
            - player_id (:obj:`str`): The id of the wanted player.
        Returns:
            - player (:obj:`Player`): The first player whose ``player_id`` equals ``player_id``.
        Raises:
            - IndexError: No player with this id exists in the searched list.
        """
        candidates = self.historical_players if 'historical' in player_id else self.active_players
        for p in candidates:
            if p.player_id == player_id:
                return p
        # ``IndexError`` is kept for backward compatibility with the former ``[...][0]`` lookup.
        raise IndexError("no player with player_id '{}' in the league".format(player_id))

    @staticmethod
    def save_checkpoint(src_checkpoint, dst_checkpoint) -> None:
        '''
        Overview:
            Copy a checkpoint from path ``src_checkpoint`` to path ``dst_checkpoint``.
        Arguments:
            - src_checkpoint (:obj:`str`): Source checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
            - dst_checkpoint (:obj:`str`): Destination checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
        '''
        checkpoint = read_file(src_checkpoint)
        save_file(dst_checkpoint, checkpoint)

    def player_rank(self, string: bool = False) -> Union[str, Dict[str, float]]:
        """
        Overview:
            Collect ``rating.exposure`` of every active and historical player.
        Arguments:
            - string (:obj:`bool`): Return a printable table instead of a dict.
        Returns:
            - rank (:obj:`Union[str, Dict[str, float]]`): ``{player_id: exposure}``, or, when ``string`` is \
                True, a newline-prefixed pipe-format table with two-decimal values.
        """
        rank = {p.player_id: p.rating.exposure for p in self.active_players + self.historical_players}
        if not string:
            return rank
        headers = ["Player ID", "Rank (TrueSkill)"]
        data = [[k, "{:.2f}".format(v)] for k, v in rank.items()]
        return "\n" + tabulate(data, headers=headers, tablefmt='pipe')
def create_league(cfg: EasyDict, *args) -> BaseLeague:
    """
    Overview:
        Factory for league instances. First imports the modules listed in ``cfg.import_names`` (if any),
        then asks ``LEAGUE_REGISTRY`` to build the league stored under the key ``cfg.league_type``.
        A derived league therefore has to be registered before ``create_league`` can return an
        instance of it.
    Arguments:
        - cfg (:obj:`EasyDict`): League config. Reads ``league_type`` (required) and ``import_names`` \
            (optional, defaults to an empty list). The whole ``cfg`` is passed on as keyword ``cfg``.
        - args: Extra positional arguments passed through to ``LEAGUE_REGISTRY.build``.
    Returns:
        - league (:obj:`BaseLeague`): The object returned by ``LEAGUE_REGISTRY.build``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    league_type = cfg.league_type
    return LEAGUE_REGISTRY.build(league_type, *args, cfg=cfg)