Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.append("/Users/puyuan/code/LightZero/") | |
| from functools import partial | |
| import torch | |
| from ding.config import compile_config | |
| from ding.envs import create_env_manager | |
| from ding.envs import get_vec_env_setting | |
| from ding.policy import create_policy | |
| from ding.utils import set_pkg_seed | |
| from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import main_config, create_config | |
| import numpy as np | |
class Agent:
    """Gomoku AlphaZero agent that wraps a policy for single-observation inference.

    Construction reconfigures the imported ``main_config`` / ``create_config``
    objects in place (evaluation-only settings, Python MCTS, one evaluator
    env), compiles them, and builds the policy. ``compute_action`` then maps
    one observation to one action through the policy's eval mode.
    """

    # Key under which the single observation is submitted to the policy.
    _ENV_ID = '0'

    def __init__(self, seed=0, model_path=None):
        """Build the policy and optionally load pretrained weights.

        Args:
            seed: Seed forwarded to ``compile_config``, the env managers and
                ``set_pkg_seed``.
            model_path: Optional checkpoint path (for example
                ``'./ckpt/ckpt_best.pth.tar'``). When ``None`` (the default)
                no checkpoint is loaded and the policy keeps whatever weights
                ``create_policy`` gave it.
        """
        # NOTE: these assignments mutate the module-level config objects that
        # were imported, so they are visible to any other user of them.
        # Set agent_vs_human to True to play against the agent yourself.
        main_config.env.agent_vs_human = False
        # 'image_realtime_mode' is the alternative render mode noted by the
        # original author.
        main_config.env.render_mode = 'image_savefile_mode'
        main_config.env.replay_path = './video'
        create_config.env_manager.type = 'base'
        main_config.env.alphazero_mcts_ctree = False
        main_config.policy.mcts_ctree = False
        main_config.env.evaluator_env_num = 1
        main_config.env.n_evaluator_episode = 1

        # Only select CUDA when it was requested and is actually available.
        use_cuda = main_config.policy.cuda and torch.cuda.is_available()
        main_config.policy.device = 'cuda' if use_cuda else 'cpu'

        cfg = compile_config(
            main_config,
            seed=seed,
            env=None,
            auto=True,
            create_cfg=create_config,
            save_cfg=True,
        )

        # Create main components: env, policy.
        # NOTE(review): the env managers are not used after seeding; they are
        # kept because building/seeding them may have side effects. Confirm
        # whether they can be dropped.
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(
            cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
        )
        evaluator_env = create_env_manager(
            cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]
        )
        collector_env.seed(cfg.seed)
        evaluator_env.seed(cfg.seed, dynamic_seed=False)
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        self.policy = create_policy(
            cfg.policy, model=None, enable_field=['learn', 'collect', 'eval']
        )

        # Load the pretrained model, if one was given.
        if model_path is not None:
            state_dict = torch.load(model_path, map_location=cfg.policy.device)
            self.policy.learn_mode.load_state_dict(state_dict)

    def compute_action(self, obs):
        """Return the policy's action for a single observation.

        Args:
            obs: Observation as returned by the environment's ``step``.

        Returns:
            The ``'action'`` entry of the eval-mode output for this observation.
        """
        policy_output = self.policy.eval_mode.forward({self._ENV_ID: obs})
        return policy_output[self._ENV_ID]['action']
def action_to_coords(action, board_size):
    """Convert a flat board index into ``{'i': row, 'j': column}``.

    Args:
        action: Flat action index (row * board_size + column).
        board_size: Number of cells per board row.

    Returns:
        A dict with integer row ``'i'`` and column ``'j'``.
    """
    row, col = divmod(int(action), board_size)
    return {'i': row, 'j': col}


if __name__ == '__main__':
    from easydict import EasyDict
    from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv

    cfg = EasyDict(
        prob_random_agent=0,
        board_size=15,
        battle_mode='self_play_mode',  # NOTE
        channel_last=False,
        scale=False,
        agent_vs_human=False,
        bot_action_type='v1',  # {'v0', 'v1', 'alpha_beta_pruning'}
        prob_random_action_in_bot=0.,
        check_action_to_connect4_in_bot_v0=False,
        render_mode='state_realtime_mode',
        replay_path=None,
        screen_scaling=9,
        alphazero_mcts_ctree=False,
    )
    env = GomokuEnv(cfg)
    env.reset()
    agent = Agent()

    done = False
    while not done:
        # Advance the environment with a random move for the other side.
        observation, reward, done, info = env.step(env.random_action())
        if done:
            break

        # The game is still running: ask the agent for its move and apply it.
        agent_action = agent.compute_action(observation)
        # Keep the updated `done` so the loop stops if this move ends the
        # game, instead of stepping a finished environment.
        _, _, done, _ = env.step(agent_action)

        # Report the move both as a flat index and as board coordinates.
        print('orig bot action: {}'.format(agent_action))
        print('bot action: {}'.format(action_to_coords(agent_action, cfg.board_size)))