import sys
import warnings

# Suppress Python-level warnings unless the user explicitly enabled them
# with -W on the command line (sys.warnoptions is non-empty in that case).
# NOTE(review): the original imported `warnings` only inside this `if`, so
# the `catch_warnings()` below raised NameError whenever -W was supplied;
# importing it unconditionally fixes that.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

# Import TensorFlow with its import-time warning spam (FutureWarning /
# RuntimeWarning from deprecated aliases, etc.) silenced.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.simplefilter("ignore")
    import tensorflow as tf

# Quiet TensorFlow's own loggers.
# NOTE(review): the original called tf.get_logger().setLevel('INFO'), which
# contradicts the ERROR verbosity set on the previous line; 'ERROR' matches
# the evident intent of silencing TF output.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)

# Keep warnings suppressed for the remainder of the run, not only during
# the TensorFlow import above.
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

from baselines.ppo2 import sppo2
from baselines.common.models import build_skill_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
    VecExtractDictObs,
    VecMonitor,
    VecNormalize
)
from baselines import logger
from mpi4py import MPI
import argparse
from train_procgen.utils import *
from train_procgen.constants import EASY_GAME_RANGES
parser = argparse.ArgumentParser(description='Process procgen training arguments.')

# Table-driven CLI registration: (flag, add_argument keyword arguments),
# one entry per training option.
_ARG_SPECS = [
    ('--env_name', dict(type=str, default='coinrun')),
    ('--num_envs', dict(type=int, default=64)),
    ('--distribution_mode', dict(type=str, default='easy',
                                 choices=["easy", "hard", "exploration", "memory", "extreme"])),
    ('--num_levels', dict(type=int, default=0)),
    ('--start_level', dict(type=int, default=0)),
    ('--test_worker_interval', dict(type=int, default=0)),
    ('--timesteps_per_proc', dict(type=int, default=25_000_000)),
    ('--rand_seed', dict(type=int, default=2022)),
    ('--num_embeddings', dict(type=int, default=8)),
    ('--beta', dict(type=float, default=0.0)),
    ('--alpha1', dict(type=float, default=20)),
    ('--alpha2', dict(type=float, default=20)),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()
def train_fn(env_name, num_envs, distribution_mode, num_levels, start_level, timesteps_per_proc,
             num_embeddings, beta, alpha1, alpha2,
             is_test_worker=False,
             log_dir='./train_procgen/checkpoints', comm=None, rand_seed=None):
    """Train a skill-conditioned PPO (sppo2) agent on a single Procgen env.

    Builds the vectorized Procgen environment, a TF session, and a skill
    IMPALA-CNN policy network, then runs ``sppo2.learn`` for
    ``timesteps_per_proc`` timesteps on this MPI process.

    Args:
        env_name: Procgen environment id (e.g. 'coinrun').
        num_envs: Number of parallel environment copies.
        distribution_mode: Procgen difficulty mode ('easy', 'hard', ...).
        num_levels: Number of training levels (forced to 0 — unlimited —
            for test workers).
        start_level: First level seed.
        timesteps_per_proc: Total environment steps for this process.
        num_embeddings: Size of the skill embedding codebook.
        beta: Commitment-loss weight passed to the CNN builder.
        alpha1, alpha2: Extra sppo2 loss coefficients.
        is_test_worker: If True, this rank evaluates instead of training
            (zero gradient weight, unlimited levels).
        log_dir: Directory for CSV/stdout logs; None disables logging setup.
        comm: MPI communicator; defaults to ``MPI.COMM_WORLD`` when None.
        rand_seed: Seed for the environment and network initialization.
    """
    # Fixed PPO hyper-parameters (standard Procgen-baseline configuration).
    learning_rate = 5e-4
    ent_coef = .01
    gamma = .999
    lam = .95
    nsteps = 256
    nminibatches = 8
    ppo_epochs = 3
    clip_range = .2
    use_vf_clipping = True

    # Fall back to the world communicator so the function is usable when the
    # caller does not supply one (the original raised AttributeError on
    # ``comm.Split`` below whenever comm was None and log_dir was set).
    if comm is None:
        comm = MPI.COMM_WORLD

    # Test workers contribute no gradients and train on unlimited levels.
    mpi_rank_weight = 0 if is_test_worker else 1
    num_levels = 0 if is_test_worker else num_levels

    if log_dir is not None:
        # Separate logging communicator per train/test group; only the
        # group's rank-0 process writes csv/stdout output.
        log_comm = comm.Split(1 if is_test_worker else 0, 0)
        format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
        logger.configure(comm=log_comm, dir=log_dir, format_strs=format_strs)

    logger.info("creating environment")
    venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels, start_level=start_level,
                      distribution_mode=distribution_mode, rand_seed=rand_seed)
    venv = VecExtractDictObs(venv, "rgb")  # observations are the raw RGB frames
    venv = VecMonitor(
        venv=venv, filename=None, keep_buf=100,
    )
    venv = VecNormalize(venv=venv, ob=False)  # normalize returns only, not observations

    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    sess = tf.compat.v1.Session(config=config)
    # Enter the session context for the lifetime of training (never exited;
    # the process terminates after learn() returns).
    sess.__enter__()

    # Policy network builder: skill IMPALA-CNN with a discrete codebook.
    conv_fn = lambda x: build_skill_impala_cnn(x, depths=[16, 32, 32], emb_dim=256, num_embeddings=num_embeddings,
                                               beta=beta, seed=rand_seed)

    logger.info("training")
    sppo2.learn(
        env=venv,
        network=conv_fn,
        total_timesteps=timesteps_per_proc,
        save_interval=0,
        nsteps=nsteps,
        nminibatches=nminibatches,
        lam=lam,
        gamma=gamma,
        noptepochs=ppo_epochs,
        log_interval=1,
        ent_coef=ent_coef,
        mpi_rank_weight=mpi_rank_weight,
        clip_vf=use_vf_clipping,
        comm=comm,
        lr=learning_rate,
        cliprange=clip_range,
        update_fn=None,
        init_fn=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        seed=rand_seed,
        num_embeddings=num_embeddings,
        # Normalize rewards against the game's known best score.
        highest_score=EASY_GAME_RANGES[env_name][1],
        alpha1=alpha1,
        alpha2=alpha2,
    )
def main():
    """Entry point: derive test-worker status from the MPI rank, build the
    checkpoint directory name, and launch training with the parsed CLI args."""
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Every `test_worker_interval`-th rank (the last one in each group of
    # that size) is designated a test worker; interval 0 disables this.
    interval = args.test_worker_interval
    is_test_worker = interval > 0 and rank % interval == interval - 1

    saved_dir = (
        f'./train_procgen/checkpoints/sppo-{args.env_name}_{args.distribution_mode}'
        f'_{args.num_levels}_{args.start_level}_{args.rand_seed}'
    )

    train_fn(args.env_name,
             args.num_envs,
             args.distribution_mode,
             args.num_levels,
             args.start_level,
             args.timesteps_per_proc,
             is_test_worker=is_test_worker,
             comm=comm,
             log_dir=saved_dir,
             rand_seed=args.rand_seed,
             num_embeddings=args.num_embeddings,
             beta=args.beta,
             alpha1=args.alpha1,
             alpha2=args.alpha2,
             )


if __name__ == '__main__':
    main()