Spaces:
Running
Running
| from typing import Callable, Optional, List | |
| from collections import namedtuple | |
| import numpy as np | |
| from easydict import EasyDict | |
| from ding.utils import import_module, PLAYER_REGISTRY | |
| from .algorithm import pfsp | |
class Player:
    """
    Overview:
        Base player class, player is the basic member of a league
    Interfaces:
        __init__
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, rating
    """
    _name = "BasePlayer"  # override this variable for sub-class player

    def __init__(
            self,
            cfg: 'EasyDict',  # noqa
            category: str,
            init_payoff: 'BattleSharedPayoff',  # noqa
            checkpoint_path: str,
            player_id: str,
            total_agent_step: int,
            rating: 'PlayerRating',  # noqa
    ) -> None:
        """
        Overview:
            Initialize base player metadata
        Arguments:
            - cfg (:obj:`EasyDict`): Player config dict.
            - category (:obj:`str`): Player category, depending on the game,
                e.g. StarCraft has 3 races ['terran', 'protoss', 'zerg'].
            - init_payoff (:obj:`Union[BattleSharedPayoff, SoloSharedPayoff]`): Payoff shared by all players.
            - checkpoint_path (:obj:`str`): The path to load player checkpoint.
            - player_id (:obj:`str`): Player id in string format.
            - total_agent_step (:obj:`int`): For active player, it should be 0;
                For historical player, it should be parent player's ``_total_agent_step`` when ``snapshot``.
            - rating (:obj:`PlayerRating`): player rating information in total league
        """
        self._cfg = cfg
        self._category = category
        self._payoff = init_payoff
        self._checkpoint_path = checkpoint_path
        assert isinstance(player_id, str)
        self._player_id = player_id
        assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step))
        self._total_agent_step = total_agent_step
        self._rating = rating

    # The accessors below are properties: the rest of the module reads them as attributes
    # (e.g. ``self.rating.mu``), and ``total_agent_step`` / ``rating`` each need a paired
    # setter, otherwise the second ``def`` would simply shadow the getter.

    @property
    def category(self) -> str:
        """Player category given at construction."""
        return self._category

    @property
    def payoff(self) -> 'BattleSharedPayoff':  # noqa
        """Payoff object given at construction."""
        return self._payoff

    @property
    def checkpoint_path(self) -> str:
        """Checkpoint path given at construction."""
        return self._checkpoint_path

    @property
    def player_id(self) -> str:
        """String id of this player."""
        return self._player_id

    @property
    def total_agent_step(self) -> int:
        """Current value of the agent step counter."""
        return self._total_agent_step

    @total_agent_step.setter
    def total_agent_step(self, step: int) -> None:
        self._total_agent_step = step

    @property
    def rating(self) -> 'PlayerRating':  # noqa
        """Rating object currently attached to this player."""
        return self._rating

    @rating.setter
    def rating(self, _rating: 'PlayerRating') -> None:  # noqa
        self._rating = _rating
class HistoricalPlayer(Player):
    """
    Overview:
        Historical player which is snapshotted from an active player, and is fixed with the checkpoint.
        Have a unique attribute ``parent_id``.
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, parent_id
    """
    _name = "HistoricalPlayer"

    def __init__(self, *args, parent_id: str) -> None:
        """
        Overview:
            Initialize ``_parent_id`` additionally
        Arguments:
            - args: Positional arguments forwarded unchanged to ``Player.__init__``.
            - parent_id (:obj:`str`): id of historical player's parent, should be an active player
        """
        super().__init__(*args)
        self._parent_id = parent_id

    @property
    def parent_id(self) -> str:
        """Id of the player this historical player was created from (keyword-only at construction)."""
        return self._parent_id
class ActivePlayer(Player):
    """
    Overview:
        Active player can be updated, or snapshotted to a historical player in the league training.
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job, increment_eval_difficulty
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step
    """
    _name = "ActivePlayer"
    # One entry of the opponent-branch table: branch name plus its selection probability.
    BRANCH = namedtuple("BRANCH", ['name', 'prob'])

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize player metadata, depending on the game
        Note:
            - one_phase_step (:obj:`int`): An active player will be considered trained enough for snapshot
                after two phase steps.
            - last_enough_step (:obj:`int`): Player's last step number that satisfies ``is_trained_enough``.
            - strong_win_rate (:obj:`float`): If win rates between this player and all the opponents are greater
                than this value, this player can be regarded as strong enough to these opponents.
                If also already trained for one phase step, this player can be regarded as trained enough
                for snapshot.
            - branch_probs (:obj:`namedtuple`): A namedtuple of probabilities of selecting different opponent branch.
        """
        # NOTE(review): ``**kwargs`` is accepted but not forwarded to ``Player.__init__``;
        # kept as-is to preserve the existing call contract -- confirm whether this is intended.
        super().__init__(*args)
        self._one_phase_step = int(float(self._cfg.one_phase_step))  # ``one_phase_step`` is like 1e9
        self._last_enough_step = 0
        self._strong_win_rate = self._cfg.strong_win_rate
        assert isinstance(self._cfg.branch_probs, dict)
        self._branch_probs = [self.BRANCH(k, v) for k, v in self._cfg.branch_probs.items()]
        # self._eval_opponent_difficulty = ["WEAK", "MEDIUM", "STRONG"]
        self._eval_opponent_difficulty = ["RULE_BASED"]
        self._eval_opponent_index = 0

    def is_trained_enough(self, select_fn: Optional[Callable] = None) -> bool:
        """
        Overview:
            Judge whether this player is trained enough for further operations(e.g. snapshot, mutate...)
            according to past step count and overall win rates against opponents.
            If yes, set ``self._last_enough_step`` to ``self._total_agent_step`` and return True;
            otherwise return False.
        Arguments:
            - select_fn (:obj:`function`): The function to select opponent players. \
                Defaults to selecting every ``HistoricalPlayer``.
        Returns:
            - flag (:obj:`bool`): Whether this player is trained enough
        """
        if select_fn is None:
            select_fn = lambda x: isinstance(x, HistoricalPlayer)  # noqa
        step_passed = self._total_agent_step - self._last_enough_step
        if step_passed < self._one_phase_step:
            return False
        if step_passed >= 2 * self._one_phase_step:
            # At least two phases have passed, regarded as trained enough regardless of win rate
            self._last_enough_step = self._total_agent_step
            return True
        # Between one and two phases: decide by payoff against the selected opponents
        # (different players have different types of opponent players).
        selected_players = self._get_players(select_fn)
        if not selected_players:  # No such player, therefore no past game
            return False
        win_rates = self._payoff[self, selected_players]
        # Trained enough only if even the worst win rate exceeds ``self._strong_win_rate``
        if win_rates.min() > self._strong_win_rate:
            self._last_enough_step = self._total_agent_step
            return True
        return False

    def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer:  # noqa
        """
        Overview:
            Generate a snapshot historical player from the current player, called in league's ``_snapshot``.
        Argument:
            - metric_env (:obj:`LeagueMetricEnv`): player rating environment, one league one env
        Returns:
            - snapshot_player (:obj:`HistoricalPlayer`): new instantiated historical player

        .. note::
            This method only generates a historical player object, but without saving the checkpoint, which should be
            done by league.
        """
        # Backing fields are read directly so the snapshot does not depend on how the
        # base class exposes them.
        # Insert the current step count before the ``.pth`` suffix of the checkpoint path.
        path = self._checkpoint_path.split('.pth')[0] + '_{}'.format(self._total_agent_step) + '.pth'
        return HistoricalPlayer(
            self._cfg,
            self._category,
            self._payoff,
            path,
            self._player_id + '_{}_historical'.format(int(self._total_agent_step)),
            self._total_agent_step,
            # The new rating is created from the current player's ``mu``.
            metric_env.create_rating(mu=self._rating.mu),
            parent_id=self._player_id
        )

    def mutate(self, info: dict) -> Optional[str]:
        """
        Overview:
            Mutate the current player, called in league's ``_mutate_player``.
            The base implementation does nothing and returns None.
        Arguments:
            - info (:obj:`dict`): related information for the mutation
        Returns:
            - mutation_result (:obj:`str`): if the player does the mutation operation then returns the
                corresponding model path, otherwise returns None
        """
        return None

    def get_job(self, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get a dict containing some info about the job to be launched, e.g. the selected opponent.
        Arguments:
            - eval_flag (:obj:`bool`): Whether to select an opponent for evaluator task.
        Returns:
            - ret (:obj:`dict`): The returned dict. Contains key ['opponent'].
        """
        if eval_flag:
            # eval opponent is a str.
            opponent = self._eval_opponent_difficulty[self._eval_opponent_index]
        else:
            # collect opponent is a Player.
            opponent = self._get_collect_opponent()
        return {
            'opponent': opponent,
        }

    def _get_collect_opponent(self) -> Player:
        """
        Overview:
            Select an opponent according to the player's ``branch_probs``: draw one uniform sample and pick
            the first branch whose cumulative probability exceeds it, then call ``_<name>_branch``.
        Returns:
            - opponent (:obj:`Player`): Selected opponent.
        Raises:
            - ValueError: If no branch is configured, or the probabilities do not sum to (approximately) 1 \
                and the sample falls outside the covered range.
        """
        if not self._branch_probs:
            raise ValueError("branch_probs is empty, cannot select a collect opponent")
        p = np.random.uniform()
        cum_p = 0.
        selected = None
        # Single pass over the running sum instead of re-summing every prefix.
        for branch in self._branch_probs:
            cum_p += branch.prob
            if p < cum_p:
                selected = branch
                break
        if selected is None:
            # ``p`` was not below the total. This is legitimate only when float rounding left
            # the total marginally under 1; otherwise the config is wrong and we say so.
            if not np.isclose(cum_p, 1.):
                raise ValueError("branch_probs should sum to 1, but got {}".format(cum_p))
            selected = next(
                (b for b in reversed(self._branch_probs) if b.prob > 0), self._branch_probs[-1]
            )
        branch_name = '_{}_branch'.format(selected.name)
        return getattr(self, branch_name)()

    def _get_players(self, select_fn: Callable) -> List[Player]:
        """
        Overview:
            Get a list of players in the league (shared_payoff), selected by ``select_fn`` .
        Arguments:
            - select_fn (:obj:`function`): players in the returned list must satisfy this function
        Returns:
            - players (:obj:`list`): a list of players that satisfies ``select_fn``
        """
        return [player for player in self._payoff.players if select_fn(player)]

    def _get_opponent(self, players: list, p: Optional[np.ndarray] = None) -> Player:
        """
        Overview:
            Get one opponent player from list ``players`` according to probability ``p``.
        Arguments:
            - players (:obj:`list`): a list of players that can select opponent from
            - p (:obj:`np.ndarray`): the selection probability of each player, should have the same size as
                ``players``. If you don't need it and set None, it would select uniformly by default.
        Returns:
            - opponent_player (:obj:`Player`): a random chosen opponent player according to probability
        """
        idx = np.random.choice(len(players), p=p)
        return players[idx]

    def increment_eval_difficulty(self) -> bool:
        """
        Overview:
            When evaluating, active player will choose a specific builtin opponent difficulty.
            This method is used to increment the difficulty.
            It is usually called after the easier builtin bot is already been beaten by this player.
        Returns:
            - increment_or_not (:obj:`bool`): True means difficulty is incremented;
                False means difficulty is already the hardest.
        """
        if self._eval_opponent_index < len(self._eval_opponent_difficulty) - 1:
            self._eval_opponent_index += 1
            return True
        return False

    # ``checkpoint_path`` is writable on active players; it must be a property/setter pair,
    # otherwise the second ``def`` shadows the getter that ``snapshot`` style access relies on.
    @property
    def checkpoint_path(self) -> str:
        """Current checkpoint path of this player."""
        return self._checkpoint_path

    @checkpoint_path.setter
    def checkpoint_path(self, path: str) -> None:
        self._checkpoint_path = path
class NaiveSpPlayer(ActivePlayer):
    """
    Overview:
        Active player offering two opponent branches: ``pfsp`` and ``sp``.
    """

    # NOTE(review): ``create_player`` builds players through ``PLAYER_REGISTRY``; no registration of this
    # class is visible in this part of the file -- confirm it is registered where required.

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent, should be a historical player.
            When the payoff lookup yields an empty win-rate array, the player itself is returned instead
            (normal self-play).
        Returns:
            - player (:obj:`HistoricalPlayer`): The selected historical player.
        """
        candidates = self._get_players(lambda player: isinstance(player, HistoricalPlayer))
        rates = self._payoff[self, candidates]
        if rates.shape != (0, ):
            # Weight the candidates with ``pfsp`` (squared weighting) and sample one of them
            selection_probs = pfsp(rates, weighting='squared')
            return self._get_opponent(candidates, selection_probs)
        # Normal self-play if no historical players
        return self

    def _sp_branch(self) -> ActivePlayer:
        """
        Overview:
            Select normal self-play opponent
        Returns:
            - player (:obj:`ActivePlayer`): This player itself.
        """
        return self
def create_player(cfg: EasyDict, player_type: str, *args, **kwargs) -> Player:
    """
    Overview:
        Import the modules named in ``cfg`` and then build the player registered under ``player_type``
        through ``PLAYER_REGISTRY``. A derived player therefore has to be registered before
        ``create_player`` can return an instance of it.
    Arguments:
        - cfg (:obj:`EasyDict`): player config; ``import_names`` is read from it (defaults to an empty list)
        - player_type (:obj:`str`): the type of player to be created
        - args, kwargs: forwarded unchanged to the registry's ``build`` call
    Returns:
        - player (:obj:`Player`): the created new player
    .. note::
        How an unregistered ``player_type`` is reported is decided by the registry
        (presumably a ``KeyError`` -- verify against ``PLAYER_REGISTRY``).
    """
    # Import user-specified modules first (presumably so that their players get registered -- verify)
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    player = PLAYER_REGISTRY.build(player_type, *args, **kwargs)
    return player