|
|
import math |
|
|
import sys |
|
|
|
|
|
# Import unconditionally: `warnings` is used below even when the user passed -W options
# (a non-empty sys.warnoptions), which previously left the name unbound -> NameError.
import warnings

# Respect explicit -W options for the initial default; otherwise ignore everything.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

# Filters set inside catch_warnings() are restored on exit, so these only cover the
# (noisy) TensorFlow import itself.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.simplefilter("ignore")
    import tensorflow as tf

# Keep TensorFlow's Python logger at ERROR. The previous setLevel('INFO') immediately
# undid the set_verbosity(ERROR) call on the line above it.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)

# Process-wide filters that stay active for the rest of the script.
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
|
|
|
|
|
from baselines.ppo2 import sppo2
from baselines.common.models import build_skill_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
    VecExtractDictObs,
    VecMonitor,
    VecNormalize
)
from baselines import logger
from mpi4py import MPI
import argparse
import os
import matplotlib.colors as c
import gym
import numpy as np
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from baselines.common import set_global_seeds
# NOTE: import order kept as-is on purpose: the star import below may rebind names,
# so moving it relative to the other imports could change which object a name refers to.
from train_procgen.utils import *
from copy import deepcopy
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser for this script from a flag -> options table."""
    arg_specs = (
        ('--env_name', {'type': str, 'default': 'coinrun'}),
        ('--num_envs', {'type': int, 'default': 64}),
        ('--distribution_mode', {
            'type': str,
            'default': 'easy',
            'choices': ["easy", "hard", "exploration", "memory", "extreme"],
        }),
        ('--num_levels', {'type': int, 'default': 0}),
        ('--start_level', {'type': int, 'default': 0}),
        ('--test_worker_interval', {'type': int, 'default': 0}),
        ('--timesteps_per_proc', {'type': int, 'default': 25_000_000}),
        ('--rand_seed', {'type': int, 'default': 2021}),
        ('--num_embeddings', {'type': int, 'default': 8}),
        ('--beta', {'type': float, 'default': 0.0}),
        ('--alpha1', {'type': float, 'default': 20}),
        ('--alpha2', {'type': float, 'default': 20}),
        ('--total_states', {
            'type': int,
            'default': 5000,
            'help': 'choose how many states you want to have',
        }),
    )
    new_parser = argparse.ArgumentParser(description='Process procgen enjoying arguments.')
    for flag, options in arg_specs:
        new_parser.add_argument(flag, **options)
    return new_parser


# Parsed at import time; main() and the helpers read the module-level `args`.
parser = _build_arg_parser()
args = parser.parse_args()
|
|
|
|
|
|
|
|
def isEmpty(path):
    """Tell whether *path* holds no directory entries.

    Returns ``True`` when *path* does not exist, when it is a regular file, or when it is
    a directory with nothing inside it; returns ``False`` only for a non-empty directory.
    """
    is_directory_like = os.path.exists(path) and not os.path.isfile(path)
    if not is_directory_like:
        return True
    return len(os.listdir(path)) == 0
|
|
|
|
|
|
|
|
def load_model(env_name, num_envs, distribution_mode, num_levels, start_level,
               num_embeddings, beta, alpha1, alpha2,
               log_dir='./checkpoints', comm=None, rand_seed=None):
    """Build a procgen vec-env and restore a trained sppo2 model from ``log_dir``.

    Args:
        env_name: procgen game name (e.g. ``'coinrun'``).
        num_envs: number of environments in the ProcgenEnv vec-env.
        distribution_mode: procgen distribution mode forwarded to ProcgenEnv.
        num_levels: forwarded to ProcgenEnv unchanged.
        start_level: base level; the env is created at ``start_level + 200000``.
        num_embeddings: number of skill embeddings for ``build_skill_impala_cnn``.
        beta: forwarded to ``build_skill_impala_cnn``.
        alpha1, alpha2: forwarded to ``sppo2.load``.
        log_dir: checkpoint path, passed to ``sppo2.load`` as ``load_path``.
        comm: MPI communicator forwarded to ``sppo2.load``.
        rand_seed: seed used for the env, the CNN builder and the model.

    Returns:
        Whatever ``sppo2.load`` returns (the restored model).
    """
    ent_coef = .01
    nsteps = 256
    nminibatches = 8
    mpi_rank_weight = 1
    # Offset added to start_level for the env built here. Presumably this keeps these
    # levels disjoint from the training levels (TODO confirm against the training script).
    level_offset = 200000

    logger.info("creating environment")
    venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels,
                      start_level=start_level + level_offset,
                      distribution_mode=distribution_mode, rand_seed=rand_seed)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)

    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=config)
    # Entered and never exited. Presumably this is so the session stays the default one
    # for the graph that sppo2.load builds below.
    sess.__enter__()

    conv_fn = lambda x: build_skill_impala_cnn(x, depths=[16, 32, 32], emb_dim=256,
                                               num_embeddings=num_embeddings,
                                               beta=beta, seed=rand_seed)

    logger.info("loading the model")
    model = sppo2.load(
        env=venv,
        network=conv_fn,
        nsteps=nsteps,
        nminibatches=nminibatches,
        ent_coef=ent_coef,
        mpi_rank_weight=mpi_rank_weight,
        comm=comm,
        vf_coef=0.5,
        max_grad_norm=0.5,
        seed=rand_seed,
        load_path=log_dir,
        alpha1=alpha1,
        alpha2=alpha2,
    )
    return model
|
|
|
|
|
|
|
|
def _cluster_dir(cluster_path, cluster_id):
    """Return the folder holding saved frames for skill cluster *cluster_id*."""
    return os.path.join(cluster_path, 'cluster_' + str(cluster_id))


def _collect_cluster_frames(model, cluster_path, opts):
    """Roll out *model* in one procgen env and save rendered frames into per-skill folders.

    On each step the rendered RGB frame is saved as ``<step>_<sl_x>_<sl_y>.png`` in the
    folder of the skill index returned by ``model.skill_step``. This happens only while
    that skill's running counter is below ``0.2 * opts.total_states``.

    Bookkeeping rules:
      * Per-skill counters are incremented on every env step, not only when a frame is
        saved, and decremented when saved frames are deleted.
      * When a run of consecutive frames for one skill is broken, the run is deleted if
        it held more than 20 frames.
      * At episode end, if the number of frames saved in that episode exceeds
        ``episode_max_steps``, the episode's frames are deleted and its counts rolled back.
      * Collection stops once the committed counters sum to at least
        ``opts.total_states``.

    Args:
        model: object whose ``skill_step(obs)`` returns the 9 values unpacked below.
        cluster_path: root folder; the ``cluster_<i>`` sub-folders must already exist.
        opts: parsed CLI namespace. Reads env_name, num_levels, start_level,
            distribution_mode, rand_seed, num_embeddings and total_states.
    """
    set_global_seeds(opts.rand_seed)
    env_id = "procgen:procgen-" + opts.env_name + "-v0"
    env = gym.make(env_id, num_levels=opts.num_levels, start_level=opts.start_level,
                   distribution_mode=opts.distribution_mode, rand_seed=opts.rand_seed,
                   render_mode="human")

    num_clusters = opts.num_embeddings
    total_states = opts.total_states
    upper_frames = total_states * 0.2  # per-skill cap on the running counter
    max_run_length = 20
    # NOTE: this is compared against the number of frames *saved* in an episode,
    # not the number of env steps.
    episode_max_steps = math.inf if opts.env_name in ['starpilot', 'fruitbot'] else 200

    counters = np.zeros(num_clusters)          # committed counts (finished episodes)
    episode_counters = np.zeros(num_clusters)  # committed + running-episode counts
    episode_files = []
    sequence_files = [[] for _ in range(num_clusters)]
    sequence_indices = [[] for _ in range(num_clusters)]
    step = 0

    # try/finally so the env is closed even if a step or a file operation raises.
    try:
        obs = env.reset()
        while True:
            frame = Image.fromarray(env.render(mode="rgb_array"))
            (a, _v, _pure_latent, _vq_latent, _pure_vq_latent, _vq_embeddings,
             encoding_indices, sl, _lat) = model.skill_step(obs)
            action = np.squeeze(a)
            skill = encoding_indices[0][0]

            if episode_counters[skill] < upper_frames:
                step_no = step + 1
                file_path = os.path.join(
                    _cluster_dir(cluster_path, skill),
                    str(step_no) + '_' + str(sl[0][0]) + '_' + str(sl[0][1]) + '.png')
                frame.save(file_path, dpi=(5, 5))
                episode_files.append(file_path)

                run_files = sequence_files[skill]
                run_steps = sequence_indices[skill]
                if run_files and step_no - int(run_steps[-1]) != 1:
                    # The consecutive run for this skill was broken: drop it if it was too long.
                    if len(run_files) > max_run_length:
                        for f in run_files:
                            if os.path.isfile(f):
                                os.remove(f)
                                episode_counters[skill] -= 1
                    run_files = sequence_files[skill] = []
                    run_steps = sequence_indices[skill] = []
                run_files.append(file_path)
                run_steps.append(step_no)

            obs, _reward, done, _info = env.step(action)
            episode_counters[skill] += 1
            step += 1

            if done:
                if len(episode_files) > episode_max_steps:
                    # Discard the whole episode: delete its frames, forget them in the run
                    # bookkeeping, and roll the counters back to the last commit.
                    for f in episode_files:
                        if os.path.isfile(f):
                            os.remove(f)
                        for i in range(num_clusters):
                            if f in sequence_files[i]:
                                sequence_indices[i].pop(sequence_files[i].index(f))
                                sequence_files[i].remove(f)
                    episode_counters = counters.copy()
                counters = episode_counters.copy()
                episode_files = []
                if np.sum(counters) >= total_states:
                    break
    finally:
        env.close()


def _load_cluster_frames(cluster_path, num_clusters):
    """Read saved frames back from the ``cluster_<i>`` folders under *cluster_path*.

    File names must look like ``<step>_<sl_x>_<sl_y>.png``; the last two fields are
    parsed as floats. Folders that do not exist are skipped.

    Returns:
        ``(images, skill_latents, cluster_indices)`` as numpy arrays, in ``os.listdir``
        order within each cluster.
    """
    images, skill_latents, cluster_indices = [], [], []
    for i in range(num_clusters):
        folder = _cluster_dir(cluster_path, i)
        if not os.path.isdir(folder):
            continue
        for name in os.listdir(folder):
            if not name.endswith('.png'):
                continue
            _, sl_x, sl_y = name[:-4].split('_')
            images.append(plt.imread(os.path.join(folder, name)))
            skill_latents.append([float(sl_x), float(sl_y)])
            cluster_indices.append(i)
    return np.array(images), np.array(skill_latents), np.array(cluster_indices)


def _plot_fdr_space(images, skill_latents, cluster_indices, num_clusters):
    """Scatter skill latents coloured by cluster and show the matching frame on hover."""
    fig, ax = plt.subplots()
    fig.set_figheight(6)
    fig.set_figwidth(8)

    # Fixed seed so every cluster gets the same colour from run to run.
    np.random.seed(0)
    color_set = [(np.random.rand(), np.random.rand(), np.random.rand()) for _ in range(num_clusters)]
    palette = np.array([c.to_hex(rgb) for rgb in color_set])

    plt.scatter(skill_latents[:, 0], skill_latents[:, 1], s=10, c=palette[cluster_indices], alpha=0.5)
    # Invisible line over the same points, used only for hit-testing in hover().
    line, = plt.plot(skill_latents[:, 0], skill_latents[:, 1], ls="")

    plt.title('The FDR space')
    plt.xlabel('FDR x')
    plt.ylabel('FDR y')

    thumb = OffsetImage(images[0, :, :, :], zoom=0.2)
    xybox = (100., 100.)
    ab = AnnotationBbox(thumb, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        hit, details = line.contains(event)
        if hit:
            # With several overlapping points, show the first one.
            ind = details["ind"][0]
            w, h = fig.get_size_inches() * fig.dpi
            # Flip the thumbnail offset so it stays inside the figure.
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            ab.set_visible(True)
            ab.xy = (skill_latents[ind, 0], skill_latents[ind, 1])
            thumb.set_data(images[ind, :, :, :])
        else:
            ab.set_visible(False)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()


def main():
    """Generate (if needed) and interactively plot per-skill frame clusters in FDR space.

    If ``./train_procgen/figures/<env>_cluster_images`` is missing or empty, the sppo2
    checkpoint selected by the CLI arguments is loaded and frames are collected into
    per-skill folders. The saved frames are then read back and shown as a scatter plot
    in which hovering a point reveals its frame.
    """
    # Applied unconditionally so the plot looks the same whether or not frames were regenerated.
    plt.rcParams.update({'font.size': 15})
    cluster_path = './train_procgen/figures/' + args.env_name + '_cluster_images'

    if isEmpty(cluster_path):
        comm = MPI.COMM_WORLD
        saved_dir = ('./train_procgen/checkpoints/sppo-' + args.env_name + '_'
                     + args.distribution_mode + '_' + str(args.num_levels) + '_'
                     + str(args.start_level) + '_' + str(args.rand_seed))
        saved_dir += '/checkpoints/' + str(args.timesteps_per_proc)

        model = load_model(args.env_name,
                           1,
                           args.distribution_mode,
                           args.num_levels,
                           args.start_level,
                           args.num_embeddings,
                           comm=comm,
                           log_dir=saved_dir,
                           rand_seed=args.rand_seed,
                           beta=args.beta,
                           alpha1=args.alpha1,
                           alpha2=args.alpha2,
                           )

        for i in range(args.num_embeddings):
            os.makedirs(_cluster_dir(cluster_path, i), exist_ok=True)

        _collect_cluster_frames(model, cluster_path, args)

    images, skill_latents, cluster_indices = _load_cluster_frames(cluster_path, args.num_embeddings)
    if len(images) == 0:
        # For example, an earlier run crashed after creating the folders but before saving frames.
        print('No cluster images found in ' + cluster_path + '; nothing to plot.')
        return
    _plot_fdr_space(images, skill_latents, cluster_indices, args.num_embeddings)
|
|
|
|
|
|
|
|
# Script entry point. Note that CLI arguments are already parsed at import time
# (module-level `args`), so importing this module also parses sys.argv.
if __name__ == '__main__':
    main()
|
|
|