Spaces:
Running
Running
| from tabnanny import check | |
| from typing import Any, Callable, List, Tuple | |
| import numpy as np | |
| from collections.abc import Sequence | |
| from easydict import EasyDict | |
| from ding.envs.env import BaseEnv, BaseEnvTimestep | |
| from ding.envs.env.tests import DemoEnv | |
| # from dizoo.atari.envs import AtariEnv | |
def check_space_dtype(env: BaseEnv) -> None:
    """
    Overview:
        Check the dtype declared by the env's observation, action and reward spaces.
        A space whose dtype repr contains ``'float'`` must be exactly ``np.float32``;
        a space whose dtype repr contains ``'int'`` must be exactly ``np.int64``.
        The match is a substring test on ``repr(space.dtype)``, so unsigned integer
        dtypes such as ``uint8`` also fall under the ``'int'`` rule. Dtypes matching
        neither substring are not constrained.
    Arguments:
        - env: The env to inspect. ``reset`` is called on it before the spaces are read.
    Raises:
        - AssertionError: If a float space is not ``np.float32`` or an int space is not ``np.int64``.
    """
    print("== 0. Test obs/act/rew space's dtype")
    # NOTE(review): presumably the spaces are only available after `reset` (lazy init) -- confirm.
    env.reset()

    float_rule = "If float, then must be np.float32, but get {} for {} space"
    int_rule = "If int, then must be np.int64, but get {} for {} space"
    labelled_spaces = (
        ('obs', env.observation_space),
        ('act', env.action_space),
        ('rew', env.reward_space),
    )
    for label, space in labelled_spaces:
        dtype_text = repr(space.dtype)
        if 'float' in dtype_text:
            assert space.dtype == np.float32, float_rule.format(space.dtype, label)
        if 'int' in dtype_text:
            assert space.dtype == np.int64, int_rule.format(space.dtype, label)
| # Util function | |
def check_array_space(ndarray, space, name) -> None:
    """
    Overview:
        Recursively check that a value returned by an env matches its declared space.
        A ``np.ndarray`` leaf must have the space's dtype and shape and lie within
        ``[space.low, space.high]``. A sequence or dict is checked entry by entry
        against ``space[index]`` / ``space[key]``.
    Arguments:
        - ndarray: A ``np.ndarray``, or a (possibly nested) sequence/dict of ``np.ndarray``.
        - space: For an array leaf, an object exposing ``dtype``, ``shape``, ``low`` and ``high``. \
            For a container, an object indexable with the same indices/keys as ``ndarray``.
        - name (str): Label used in the error messages (e.g. ``'obs'``).
    Raises:
        - AssertionError: On a dtype, shape, value-range or container-size mismatch.
        - TypeError: If a leaf is not a ``np.ndarray`` (this includes ``str`` and ``bytes``).
    """
    if isinstance(ndarray, np.ndarray):
        assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
            name, ndarray.dtype, space.dtype
        )
        assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
            name, ndarray.shape, space.shape
        )
        assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
        ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
        return

    # `str`/`bytes` are Sequences too, but they are never containers of arrays: recursing into
    # them walks character by character (without end for `str`). Let them hit the TypeError below.
    is_text = isinstance(ndarray, (str, bytes))
    if not is_text and isinstance(ndarray, (Sequence, dict)):
        # Without a size check, a container with fewer entries than its space (even an empty
        # one) would pass, because only the entries that are present get visited.
        try:
            expected_size = len(space)
        except TypeError:
            expected_size = None  # the space exposes no length, so sizes cannot be compared
        if expected_size is not None:
            assert len(ndarray) == expected_size, "{}'s length is {}, but requires {}".format(
                name, len(ndarray), expected_size
            )

        if isinstance(ndarray, Sequence):
            entries = enumerate(ndarray)
            location = "The following error happens at {}-th index"
        else:
            entries = ndarray.items()
            location = "The following error happens at key {}"
        for position, item in entries:
            try:
                check_array_space(item, space[position], name)
            except AssertionError:
                # Report where in the container the failing leaf lives, then propagate.
                print(location.format(position))
                raise
        return

    raise TypeError(
        "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
    )
def check_reset(env: BaseEnv) -> None:
    """
    Overview:
        Reset the env once and check the returned observation against
        ``env.observation_space`` using ``check_array_space``.
    Arguments:
        - env: The env under test.
    """
    print('== 1. Test reset method')
    # `reset` runs before `observation_space` is read (arguments evaluate left to right).
    check_array_space(env.reset(), env.observation_space, 'obs')
def check_step(env: BaseEnv) -> None:
    """
    Overview:
        Step the env until three episodes have finished. After every step, check the
        observation and reward against ``env.observation_space`` and ``env.reward_space``.
        Whenever an episode ends, require the ``'eval_episode_return'`` key in ``info``
        and reset the env.
    Arguments:
        - env: The env under test. If it has a ``random_action`` method that is used to \
            pick the action, otherwise ``env.action_space.sample()`` is used.
    Raises:
        - AssertionError: If a space check fails or ``info`` lacks ``'eval_episode_return'`` on done.
    """
    finished_episodes = 0
    print('== 2. Test step method')
    env.reset()
    # The action is drawn once here and replayed on every step of the loop below.
    action = env.random_action() if hasattr(env, "random_action") else env.action_space.sample()

    while finished_episodes < 3:
        obs, rew, done, info = env.step(action)
        targets = (
            (obs, env.observation_space, 'obs'),
            (rew, env.reward_space, 'rew'),
        )
        for value, space, label in targets:
            check_array_space(value, space, label)
        if not done:
            continue
        assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
        finished_episodes += 1
        env.reset()
| # Util function | |
def check_different_memory(array1, array2, step_times) -> None:
    """
    Overview:
        Recursively check that two consecutive observations are not the same object.
        Both values must have the same type. ``np.ndarray`` leaves are compared by
        object identity only, so two distinct array objects pass even if they view the
        same buffer. Sequences must have equal length and dicts equal keys; their
        entries are compared pairwise.
    Arguments:
        - array1: Observation of the previous frame (array, or nested sequence/dict of arrays).
        - array2: Observation of the current frame, with the same structure.
        - step_times: Step counter, used only in the error messages.
    Raises:
        - AssertionError: On a type, length or key mismatch, or when two leaves are the same array.
        - TypeError: If a value is neither a ``np.ndarray``, a sequence nor a dict.
    """
    kind1, kind2 = type(array1), type(array2)
    assert kind1 == kind2, (
        "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
            step_times, kind1, kind2
        )
    )

    if isinstance(array1, np.ndarray):
        assert array1 is not array2, (
            "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(step_times)
        )
        return

    if isinstance(array1, Sequence):
        size1, size2 = len(array1), len(array2)
        assert size1 == size2, (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
                step_times, size1, size2
            )
        )
        for idx, (left, right) in enumerate(zip(array1, array2)):
            try:
                check_different_memory(left, right, step_times)
            except AssertionError:
                print("The following error happens at {}-th index".format(idx))
                raise
        return

    if isinstance(array1, dict):
        keys1, keys2 = array1.keys(), array2.keys()
        assert keys1 == keys2, (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have "
            "different dict keys".format(step_times, keys1, keys2)
        )
        for key in array1:
            try:
                check_different_memory(array1[key], array2[key], step_times)
            except AssertionError:
                print("The following error happens at key {}".format(key))
                raise
        return

    raise TypeError(
        "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(kind1, kind2)
    )
def check_obs_deepcopy(env: BaseEnv) -> None:
    """
    Overview:
        Run one episode and, after every step, check with ``check_different_memory``
        that the new observation is not the same object as the previous one.
    Arguments:
        - env: The env under test. If it has a ``random_action`` method that is used to \
            pick the action, otherwise ``env.action_space.sample()`` is used.
    """
    step_count = 0
    print('== 3. Test observation deepcopy')
    previous_obs = env.reset()
    # The action is drawn once and replayed for the whole episode.
    if hasattr(env, "random_action"):
        action = env.random_action()
    else:
        action = env.action_space.sample()

    done = False
    while not done:
        step_count += 1
        current_obs, _, done, _ = env.step(action)
        check_different_memory(previous_obs, current_obs, step_count)
        previous_obs = current_obs
def check_all(env: BaseEnv) -> None:
    """
    Overview:
        Run every env check in order: space dtypes, ``reset``, ``step``, then
        observation deepcopy. The first failing check raises and stops the run.
    Arguments:
        - env: The env under test.
    """
    ordered_checks = (check_space_dtype, check_reset, check_step, check_obs_deepcopy)
    for run_check in ordered_checks:
        run_check(env)
def demonstrate_correct_procedure(env_fn: Callable) -> None:
    """
    Overview:
        Walk through the expected life cycle of an env: construct it, seed it,
        reset it, then step with random actions until three episodes have finished.
        Assertions along the way spell out what the env is expected to guarantee.
    Arguments:
        - env_fn (Callable): Called as ``env_fn({})`` to build the env. The env must provide \
            ``seed``, ``reset``, ``step`` and ``random_action``, and store the seed in ``_seed``.
    Raises:
        - AssertionError: If the env already has ``_env`` before ``reset``, does not keep the seed, \
            or omits ``'eval_episode_return'`` from ``info`` when an episode ends.
    """
    print('== 4. Demonstrate the correct procudures')
    episodes_done = 0

    # Build the env from an empty config.
    env = env_fn({})
    # Lazy init: the wrapped `_env` must not exist until `reset` is called.
    assert not hasattr(env, "_env")

    # Seeding has to happen before the first `reset`.
    env.seed(4)
    assert env._seed == 4

    # The first `reset` is where the real env gets created.
    obs = env.reset()
    while episodes_done < 3:
        # A policy would map `obs` to an action here; `random_action` stands in for it.
        obs, rew, done, info = env.step(env.random_action())
        if not done:
            continue
        assert 'eval_episode_return' in info
        episodes_done += 1
        obs = env.reset()
        # Resetting must keep the seed; only another `seed` call may change it.
        assert env._seed == 4
if __name__ == "__main__":
    # Script entry point. The bare string literal below is a disabled usage example:
    # it is an expression statement with no effect, and `AtariEnv` is not imported in this file.
    '''
    # Moethods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
    # You can replace `AtariEnv` with your own env.
    atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
    check_reset(atari_env)
    check_step(atari_env)
    check_obs_deepcopy(atari_env)
    '''
    # Method `demonstrate_correct_procedure` is to demonstrate the correct procedure to
    # use an env to generate trajectories.
    # You can check whether your env's design is similar to `DemoEnv`
    demonstrate_correct_procedure(DemoEnv)