Spaces:
Running
Running
| import pytest | |
| from ding.envs import DingEnvWrapper | |
| from lzero.envs.wrappers import ActionDiscretizationEnvWrapper, LightZeroEnvWrapper | |
| from easydict import EasyDict | |
| import gym | |
| import numpy as np | |
class TestLightZeroEnvWrapper:
    """Reset-time smoke tests for ``LightZeroEnvWrapper`` behind ``DingEnvWrapper``.

    Each test builds a gym environment, applies the LightZero wrapper chain
    (optionally preceded by ``ActionDiscretizationEnvWrapper``), resets it and
    checks the structure of the first observation dict: ``observation``,
    ``action_mask`` and ``to_play``.
    """

    @staticmethod
    def _make_env(env_cfg):
        """Build the wrapped environment described by ``env_cfg``.

        Args:
            env_cfg: Config exposing ``env_name`` and ``manually_discretization``;
                the whole object is also forwarded to every wrapper.

        Returns:
            A ``DingEnvWrapper`` around ``gym.make(env_cfg.env_name)``.
        """
        env_wrappers = []
        if env_cfg.manually_discretization:
            # Discretization must come first so that LightZeroEnvWrapper
            # sees the already-discretized environment.
            env_wrappers.append(lambda env: ActionDiscretizationEnvWrapper(env, env_cfg))
        env_wrappers.append(lambda env: LightZeroEnvWrapper(env, env_cfg))
        return DingEnvWrapper(gym.make(env_cfg.env_name), cfg={'env_wrapper': env_wrappers})

    def _check_reset(self, env_cfg, obs_shape, legal_action_num=None):
        """Reset the wrapped environment and validate the first observation.

        Args:
            env_cfg: Environment config, see ``_make_env``.
            obs_shape: Expected shape of ``obs['observation']``.
            legal_action_num: Expected sum of ``obs['action_mask']``; ``None``
                means the mask itself is expected to be ``None``.
        """
        lightzero_env = self._make_env(env_cfg)
        try:
            obs = lightzero_env.reset()
            print("obs: ", obs)
            print(lightzero_env.observation_space, lightzero_env.action_space, lightzero_env.reward_space)

            assert isinstance(obs, dict)
            assert isinstance(obs['observation'], np.ndarray) and obs['observation'].shape == obs_shape
            if legal_action_num is None:
                assert obs['action_mask'] is None and obs['to_play'] == -1
            else:
                assert obs['action_mask'].sum() == legal_action_num and obs['to_play'] == -1

            action = lightzero_env.random_action()
            print('random_action: {}, action_space: {}'.format(action.shape, lightzero_env.action_space))
        finally:
            # Release the underlying gym environment even when an assertion
            # above fails, so one failing test does not leak resources.
            lightzero_env.close()

    def test_continuous_pendulum(self):
        """Pendulum-v1 with its native continuous actions: no action mask."""
        env_cfg = EasyDict(
            dict(
                env_name='Pendulum-v1',
                manually_discretization=False,
                continuous=True,
                each_dim_disc_size=None,
                is_train=True,
            )
        )
        self._check_reset(env_cfg, obs_shape=(3, ))

    def test_discretization_pendulum(self):
        """Pendulum-v1 discretized into 11 bins: all 11 actions are legal."""
        env_cfg = EasyDict(
            dict(
                env_name='Pendulum-v1',
                manually_discretization=True,
                continuous=False,
                each_dim_disc_size=11,
                is_train=True,
            )
        )
        self._check_reset(env_cfg, obs_shape=(3, ), legal_action_num=11)

    def test_continuous_bipedalwalker(self):
        """BipedalWalker-v3 with its native continuous actions: no action mask."""
        env_cfg = EasyDict(
            dict(
                env_name='BipedalWalker-v3',
                manually_discretization=False,
                continuous=True,
                each_dim_disc_size=4,
                is_train=True,
            )
        )
        self._check_reset(env_cfg, obs_shape=(24, ))

    def test_discretization_bipedalwalker(self):
        """BipedalWalker-v3 discretized with 4 bins per dimension: 256 legal actions."""
        env_cfg = EasyDict(
            dict(
                env_name='BipedalWalker-v3',
                manually_discretization=True,
                continuous=False,
                each_dim_disc_size=4,
                is_train=True,
            )
        )
        # 256 == 4 ** 4, which matches each_dim_disc_size=4 over a 4-dimensional
        # action space (assumed from the expected value; confirm against the env spec).
        self._check_reset(env_cfg, obs_shape=(24, ), legal_action_num=256)