Spaces:
Running
Running
import copy
import os
import random
import shutil

import pytest
import torch
from easydict import EasyDict

from ding.league import create_league
# Default configuration for the one-vs-one league used by the tests in this file.
# It is wrapped in an EasyDict so nested entries can be read as attributes
# (e.g. ``one_vs_one_league_default_config.league.path_policy``).
one_vs_one_league_default_config = EasyDict({
    'league': {
        'league_type': 'one_vs_one',
        'import_names': ["ding.league"],
        # ---player----
        # "player_category" is only a label and depends on the env.
        # In StarCraft, for instance, it could be ['zerg', 'terran', 'protoss'].
        'player_category': ['default'],
        # Solo and battle leagues support different types of active players.
        # Solo league: ['solo_active_player'].
        # Battle league: ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
        'active_players': {
            'naive_sp_player': 1,  # {player_type: player_num}
        },
        'naive_sp_player': {
            # Required keys: ['one_phase_step', 'branch_probs', 'strong_win_rate'].
            # The StarCraft 'main_exploiter' additionally needs ['min_valid_win_rate'].
            'one_phase_step': 10,
            'branch_probs': {
                'pfsp': 0.5,
                'sp': 0.5,
            },
            'strong_win_rate': 0.7,
        },
        # "use_pretrain": whether the active player is initialized from a pretrain model.
        'use_pretrain': False,
        # "use_pretrain_init_historical": whether the historical player is initialized
        # from a pretrain model.
        # "pretrain_checkpoint_path" is the checkpoint used by both "use_pretrain" and
        # "use_pretrain_init_historical". It may be omitted when both flags are False;
        # otherwise it should list a path for every player category.
        'use_pretrain_init_historical': False,
        'pretrain_checkpoint_path': {'default': 'default_cate_pretrain.pth'},
        # ---payoff---
        'payoff': {
            # Supports ['battle']
            'type': 'battle',
            'decay': 0.99,
            'min_win_rate_games': 8,
        },
        'path_policy': './league',
    },
})
def get_random_result():
    """Return a random game outcome: ``"wins"``, ``"losses"`` or ``"draws"``.

    A single uniform sample from ``random.random()`` selects the outcome.

    NOTE(review): with thresholds 1/3 and 1/2 the outcomes are not equally
    likely (wins 1/3, losses 1/6, draws 1/2). Confirm whether the second
    threshold was meant to be 2/3.
    """
    sample = random.random()
    if sample >= 1. / 2:
        return "draws"
    return "wins" if sample < 1. / 3 else "losses"
class TestOneVsOneLeague:
    """Tests for a league created from ``one_vs_one_league_default_config``."""

    def test_naive(self):
        """Exercise snapshotting, job dispatch and job completion for one active player.

        The league's ``path_policy`` directory is removed in ``finally`` so it is
        cleaned up even when an assertion fails.
        """
        league_cfg = one_vs_one_league_default_config.league
        path_policy = league_cfg.path_policy
        try:
            league = create_league(league_cfg)
            assert len(league.active_players) == 1
            assert len(league.historical_players) == 0
            active_player_ids = [p.player_id for p in league.active_players]
            assert set(active_player_ids) == set(league.active_players_ids)
            active_player_id = active_player_ids[0]
            active_player_ckpt = league.active_players[0].checkpoint_path
            # Save a dummy tensor at the checkpoint path (presumably so that the
            # snapshot step below has a checkpoint file to work with).
            torch.save(torch.tensor([1, 2, 3]), active_player_ckpt)

            # judge_snapshot & update_active_player
            assert not league.judge_snapshot(active_player_id)
            player_update_dict = {
                'player_id': active_player_id,
                'train_iteration': league_cfg.naive_sp_player.one_phase_step * 2,
            }
            league.update_active_player(player_update_dict)
            assert league.judge_snapshot(active_player_id)
            historical_player_ids = [p.player_id for p in league.historical_players]
            assert len(historical_player_ids) == 1
            historical_player_id = historical_player_ids[0]

            # get_job_info, eval_flag=False
            # Keep sampling collect jobs until both opponent kinds (the active
            # player itself and the historical snapshot) have been observed.
            # The loop is bounded so a broken sampler fails the test instead
            # of hanging it.
            vs_active = False
            vs_historical = False
            max_attempts = 1000
            for _ in range(max_attempts):
                collect_job_info = league.get_job_info(active_player_id, eval_flag=False)
                assert collect_job_info['agent_num'] == 2
                assert len(collect_job_info['checkpoint_path']) == 2
                assert collect_job_info['launch_player'] == active_player_id
                assert collect_job_info['player_id'][0] == active_player_id
                if collect_job_info['player_active_flag'][1]:
                    assert collect_job_info['player_id'][1] == collect_job_info['player_id'][0]
                    vs_active = True
                else:
                    assert collect_job_info['player_id'][1] == historical_player_id
                    vs_historical = True
                if vs_active and vs_historical:
                    break
            else:
                pytest.fail(
                    "did not observe both an active and a historical opponent "
                    "within {} collect jobs".format(max_attempts)
                )

            # get_job_info, eval_flag=True
            eval_job_info = league.get_job_info(active_player_id, eval_flag=True)
            assert eval_job_info['agent_num'] == 1
            assert len(eval_job_info['checkpoint_path']) == 1
            assert eval_job_info['launch_player'] == active_player_id
            assert eval_job_info['player_id'][0] == active_player_id
            assert len(eval_job_info['player_id']) == 1
            assert len(eval_job_info['player_active_flag']) == 1
            assert eval_job_info['eval_opponent'] in league.active_players[0]._eval_opponent_difficulty

            # finish_job
            episode_num = 5
            env_num = 8
            player_id = [active_player_id, historical_player_id]
            # One random outcome per (episode, env) pair.
            result = [[get_random_result() for __ in range(env_num)] for _ in range(episode_num)]
            payoff_update_info = {
                'launch_player': active_player_id,
                'player_id': player_id,
                'episode_num': episode_num,
                'env_num': env_num,
                'result': result,
            }
            league.finish_job(payoff_update_info)
            wins = sum(outcome == 'wins' for episode in result for outcome in episode)
            games = episode_num * env_num
            # NOTE(review): this comparison is evaluated but never asserted, so it
            # cannot fail the test. It is left un-asserted because this file does
            # not show how the payoff weights draws or applies `decay`; confirm
            # the expected value and then turn this into an `assert`.
            league.payoff[league.active_players[0], league.historical_players[0]] == wins / games
        finally:
            # Synchronous, shell-free equivalent of `rm -rf`: a missing path is ignored.
            shutil.rmtree(path_policy, ignore_errors=True)
        print("Finish!")

    def test_league_info(self):
        """Run several collect jobs and print the payoff and player ranking before and after.

        Uses a deep copy of the default league config with its own
        ``path_policy`` directory, which is removed in ``finally``.
        """
        cfg = copy.deepcopy(one_vs_one_league_default_config.league)
        cfg.path_policy = 'test_league_info'
        try:
            league = create_league(cfg)
            active_player_id = [p.player_id for p in league.active_players][0]
            active_player_ckpt = [p.checkpoint_path for p in league.active_players][0]
            # Save a dummy tensor at the checkpoint path (presumably needed by
            # the forced snapshot below).
            torch.save(torch.tensor([1, 2, 3]), active_player_ckpt)
            assert len(league.active_players) == 1
            assert len(league.historical_players) == 0
            print('\n')
            print(repr(league.payoff))
            print(league.player_rank(string=True))
            league.judge_snapshot(active_player_id, force=True)
            for _ in range(10):
                job = league.get_job_info(active_player_id, eval_flag=False)
                payoff_update_info = {
                    'launch_player': active_player_id,
                    'player_id': job['player_id'],
                    'episode_num': 2,
                    'env_num': 4,
                    'result': [[get_random_result() for __ in range(4)] for _ in range(2)]
                }
                league.finish_job(payoff_update_info)
                # if not self-play
                if job['player_id'][0] != job['player_id'][1]:
                    # Flatten the per-episode outcome lists into one list.
                    win_loss_result = [
                        outcome for episode in payoff_update_info['result'] for outcome in episode
                    ]
                    home = league.get_player_by_id(job['player_id'][0])
                    away = league.get_player_by_id(job['player_id'][1])
                    home.rating, away.rating = league.metric_env.rate_1vs1(
                        home.rating, away.rating, win_loss_result
                    )
            print(repr(league.payoff))
            print(league.player_rank(string=True))
        finally:
            # Synchronous, shell-free equivalent of `rm -rf`: a missing path is ignored.
            shutil.rmtree(cfg.path_policy, ignore_errors=True)
# When executed as a script, run pytest on this file only (by base name)
# with the "-sv" flags.
if __name__ == '__main__':
    this_file = os.path.basename(__file__)
    pytest.main(["-sv", this_file])