""" Helpers for scripts like run_atari.py. """ import os try: from mpi4py import MPI except ImportError: MPI = None import gym from gym.wrappers import FlattenObservation, FilterObservation from baselines import logger from baselines.bench import Monitor from baselines.common import set_global_seeds from baselines.common.atari_wrappers import make_atari, wrap_deepmind from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from baselines.common.vec_env.dummy_vec_env import DummyVecEnv from baselines.common import retro_wrappers from baselines.common.wrappers import ClipActionsWrapper def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, env_kwargs=None, start_index=0, reward_scale=1.0, flatten_dict_observations=True, gamestate=None, initializer=None, force_dummy=False): """ Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo. """ wrapper_kwargs = wrapper_kwargs or {} env_kwargs = env_kwargs or {} mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 seed = seed + 10000 * mpi_rank if seed is not None else None logger_dir = logger.get_dir() def make_thunk(rank, initializer=None): return lambda: make_env( env_id=env_id, env_type=env_type, mpi_rank=mpi_rank, subrank=rank, seed=seed, reward_scale=reward_scale, gamestate=gamestate, flatten_dict_observations=flatten_dict_observations, wrapper_kwargs=wrapper_kwargs, env_kwargs=env_kwargs, logger_dir=logger_dir, initializer=initializer ) set_global_seeds(seed) if not force_dummy and num_env > 1: return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)]) else: return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)]) def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None): if initializer is not None: initializer(mpi_rank=mpi_rank, subrank=subrank) wrapper_kwargs = wrapper_kwargs or {} env_kwargs = env_kwargs or {} if ':' in env_id: import re import importlib module_name = re.sub(':.*','',env_id) env_id = re.sub('.*:', '', env_id) importlib.import_module(module_name) if env_type == 'atari': env = make_atari(env_id) elif env_type == 'retro': import retro gamestate = gamestate or retro.State.DEFAULT env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate) else: env = gym.make(env_id, **env_kwargs) if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): env = FlattenObservation(env) env.seed(seed + subrank if seed is not None else None) env = Monitor(env, logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)), allow_early_resets=True) if env_type == 'atari': env = wrap_deepmind(env, **wrapper_kwargs) elif env_type == 'retro': if 'frame_stack' not in wrapper_kwargs: wrapper_kwargs['frame_stack'] = 1 env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs) if isinstance(env.action_space, gym.spaces.Box): env = ClipActionsWrapper(env) if reward_scale != 1: env = retro_wrappers.RewardScaler(env, reward_scale) return env def make_mujoco_env(env_id, seed, reward_scale=1.0): """ Create a wrapped, monitored gym.Env for MuJoCo. """ rank = MPI.COMM_WORLD.Get_rank() myseed = seed + 1000 * rank if seed is not None else None set_global_seeds(myseed) env = gym.make(env_id) logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank)) env = Monitor(env, logger_path, allow_early_resets=True) env.seed(seed) if reward_scale != 1.0: from baselines.common.retro_wrappers import RewardScaler env = RewardScaler(env, reward_scale) return env def make_robotics_env(env_id, seed, rank=0): """ Create a wrapped, monitored gym.Env for MuJoCo. """ set_global_seeds(seed) env = gym.make(env_id) env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal'])) env = Monitor( env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), info_keywords=('is_success',)) env.seed(seed) return env def arg_parser(): """ Create an empty argparse.ArgumentParser. """ import argparse return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) def atari_arg_parser(): """ Create an argparse.ArgumentParser for run_atari.py. """ print('Obsolete - use common_arg_parser instead') return common_arg_parser() def mujoco_arg_parser(): print('Obsolete - use common_arg_parser instead') return common_arg_parser() def common_arg_parser(): """ Create an argparse.ArgumentParser for run_mujoco.py. """ parser = arg_parser() parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2') parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str) parser.add_argument('--seed', help='RNG seed', type=int, default=None) parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2') parser.add_argument('--num_timesteps', type=float, default=1e6), parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None) parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None) parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int) parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float) parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str) parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int) parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int) parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str) parser.add_argument('--play', default=False, action='store_true') return parser def robotics_arg_parser(): """ Create an argparse.ArgumentParser for run_mujoco.py. """ parser = arg_parser() parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0') parser.add_argument('--seed', help='RNG seed', type=int, default=None) parser.add_argument('--num-timesteps', type=int, default=int(1e6)) return parser def parse_unknown_args(args): """ Parse arguments not consumed by arg parser into a dictionary """ retval = {} preceded_by_key = False for arg in args: if arg.startswith('--'): if '=' in arg: key = arg.split('=')[0][2:] value = arg.split('=')[1] retval[key] = value else: key = arg[2:] preceded_by_key = True elif preceded_by_key: retval[key] = arg preceded_by_key = False return retval