File size: 4,563 Bytes
5960497 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Warning and log suppression. It is set up here, ahead of the remaining
# third-party imports further down the file.
import sys
import warnings  # imported unconditionally: used below even when -W options are given

if not sys.warnoptions:
    # Install the blanket "ignore" filter only when no -W option was passed
    # on the command line.
    warnings.simplefilter("ignore")

with warnings.catch_warnings():
    # Filters added inside this context are dropped again when it exits; they
    # cover the TensorFlow import and the TensorFlow logging setup below.
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.simplefilter("ignore")
    import tensorflow as tf

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # NOTE(review): if both calls configure the same underlying logger, this
    # 'INFO' level overrides the ERROR verbosity set on the previous line --
    # confirm which level is actually intended.
    tf.get_logger().setLevel('INFO')
    tf.autograph.set_verbosity(0)

# Filters that stay installed after the context manager above has exited.
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
from baselines.ppo2 import ppo2
from baselines.common.models import build_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
VecExtractDictObs,
VecMonitor,
VecNormalize
)
from baselines import logger
from mpi4py import MPI
import argparse
def _build_arg_parser():
    """Create the command-line parser for the procgen training script."""
    cli = argparse.ArgumentParser(description='Process procgen training arguments.')
    cli.add_argument('--env_name', type=str, default='coinrun')
    cli.add_argument('--num_envs', type=int, default=64)
    cli.add_argument(
        '--distribution_mode',
        type=str,
        default='hard',
        choices=["easy", "hard", "exploration", "memory", "extreme"],
    )
    # The remaining options are plain integers; register them in this order.
    int_options = (
        ('--num_levels', 0),
        ('--start_level', 0),
        ('--test_worker_interval', 0),
        ('--timesteps_per_proc', 25_000_000),
        ('--rand_seed', 2022),
    )
    for flag, default_value in int_options:
        cli.add_argument(flag, type=int, default=default_value)
    return cli


# Parsed once at module load; `main` reads the resulting `args` namespace.
parser = _build_arg_parser()
args = parser.parse_args()
def train_fn(env_name, num_envs, distribution_mode, num_levels, start_level, timesteps_per_proc, is_test_worker=False,
             log_dir='./checkpoints', comm=None, rand_seed=None):
    """Build a Procgen vectorized environment and run PPO2 training on it.

    Args:
        env_name: Environment name forwarded to ``ProcgenEnv``.
        num_envs: Number of environments forwarded to ``ProcgenEnv``.
        distribution_mode: Distribution mode forwarded to ``ProcgenEnv``.
        num_levels: Level count forwarded to ``ProcgenEnv``; replaced by 0
            when ``is_test_worker`` is true.
        start_level: Start level forwarded to ``ProcgenEnv``.
        timesteps_per_proc: Passed to ``ppo2.learn`` as ``total_timesteps``.
        is_test_worker: When true, ``mpi_rank_weight`` is 0 (otherwise 1),
            ``num_levels`` is forced to 0, and the logging communicator is
            split with color 1 instead of 0.
        log_dir: Directory handed to ``logger.configure``. ``None`` skips
            logger configuration entirely.
        comm: MPI communicator. It is split to obtain the logging
            communicator and is passed unchanged to ``ppo2.learn``. When it
            is ``None`` the logging split falls back to ``MPI.COMM_WORLD``.
        rand_seed: Seed forwarded to ``ProcgenEnv`` (``rand_seed``) and to
            ``ppo2.learn`` (``seed``).

    Returns:
        None. The value returned by ``ppo2.learn`` is discarded.
    """
    # PPO hyperparameters (fixed for this script).
    learning_rate = 5e-4
    ent_coef = .01
    gamma = .999
    lam = .95
    nsteps = 256
    nminibatches = 8
    ppo_epochs = 3
    clip_range = .2
    use_vf_clipping = True

    mpi_rank_weight = 0 if is_test_worker else 1
    num_levels = 0 if is_test_worker else num_levels

    if log_dir is not None:
        # With the declared defaults (log_dir set, comm=None) calling
        # comm.Split directly would raise AttributeError, so use the world
        # communicator for the logging split when none was supplied.
        parent_comm = MPI.COMM_WORLD if comm is None else comm
        log_comm = parent_comm.Split(1 if is_test_worker else 0, 0)
        # Only rank 0 of each logging communicator writes csv/stdout output.
        format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
        logger.configure(comm=log_comm, dir=log_dir, format_strs=format_strs)

    logger.info("creating environment")
    venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels, start_level=start_level,
                      distribution_mode=distribution_mode, rand_seed=rand_seed)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(
        venv=venv, filename=None, keep_buf=100,
    )
    venv = VecNormalize(venv=venv, ob=False)

    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    sess = tf.compat.v1.Session(config=config)
    # The session is entered here and never exited inside this function.
    sess.__enter__()

    def conv_fn(x):
        """Network builder handed to ppo2.learn as ``network``."""
        return build_impala_cnn(x, depths=[16, 32, 32], emb_size=256)

    logger.info("training")
    ppo2.learn(
        env=venv,
        network=conv_fn,
        total_timesteps=timesteps_per_proc,
        save_interval=0,
        nsteps=nsteps,
        nminibatches=nminibatches,
        lam=lam,
        gamma=gamma,
        noptepochs=ppo_epochs,
        log_interval=1,
        ent_coef=ent_coef,
        mpi_rank_weight=mpi_rank_weight,
        clip_vf=use_vf_clipping,
        comm=comm,
        lr=learning_rate,
        cliprange=clip_range,
        update_fn=None,
        init_fn=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        seed=rand_seed
    )
def main():
    """Decide this rank's role from the CLI arguments and start training."""
    world = MPI.COMM_WORLD
    my_rank = world.Get_rank()

    # With a positive interval, the last rank of every `interval`-sized group
    # of ranks is a test worker; otherwise no rank is.
    interval = args.test_worker_interval
    is_test_worker = interval > 0 and my_rank % interval == interval - 1

    # Run directory name: ppo-<env>_<mode>_<num_levels>_<start_level>_<seed>
    run_tag = '_'.join(
        str(part)
        for part in (
            args.env_name,
            args.distribution_mode,
            args.num_levels,
            args.start_level,
            args.rand_seed,
        )
    )
    saved_dir = './train_procgen/checkpoints/ppo-' + run_tag

    train_fn(
        env_name=args.env_name,
        num_envs=args.num_envs,
        distribution_mode=args.distribution_mode,
        num_levels=args.num_levels,
        start_level=args.start_level,
        timesteps_per_proc=args.timesteps_per_proc,
        is_test_worker=is_test_worker,
        log_dir=saved_dir,
        comm=world,
        rand_seed=args.rand_seed,
    )


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|