semantic_rl / train_procgen /hover_clusters.py
leonepson's picture
Upload 254 files
5960497 verified
import math
import sys
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.simplefilter("ignore")
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.get_logger().setLevel('INFO')
tf.autograph.set_verbosity(0)
import warnings
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
from baselines.ppo2 import sppo2
from baselines.common.models import build_skill_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
VecExtractDictObs,
VecMonitor,
VecNormalize
)
from baselines import logger
from mpi4py import MPI
import argparse
import matplotlib.colors as c
import gym
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from baselines.common import set_global_seeds
from train_procgen.utils import *
from copy import deepcopy
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
# Command-line interface for the cluster hover viewer. The parsed namespace is
# stored in the module-level `args`, which main() reads directly.
parser = argparse.ArgumentParser(description='Process procgen enjoying arguments.')
# Procgen game name; also used to name the figures folder and the checkpoint path.
parser.add_argument('--env_name', type=str, default='coinrun')
# NOTE(review): not read by the code visible in this file (main() passes 1 to load_model).
parser.add_argument('--num_envs', type=int, default=64)
parser.add_argument('--distribution_mode', type=str, default='easy',
                    choices=["easy", "hard", "exploration", "memory", "extreme"])
# num_levels / start_level are forwarded to the environments and are part of
# the checkpoint directory name built in main().
parser.add_argument('--num_levels', type=int, default=0)
parser.add_argument('--start_level', type=int, default=0)
# NOTE(review): not read by the code visible in this file.
parser.add_argument('--test_worker_interval', type=int, default=0)
# Used only as the last path component of the checkpoint to load.
parser.add_argument('--timesteps_per_proc', type=int, default=25_000_000)
# Seed for the environments, the network builder and set_global_seeds().
parser.add_argument('--rand_seed', type=int, default=2021)
# Number of skill clusters: one output folder and one plot colour per embedding.
parser.add_argument('--num_embeddings', type=int, default=8)
# beta is forwarded to build_skill_impala_cnn; alpha1/alpha2 to sppo2.load.
parser.add_argument('--beta', type=float, default=0.0)
parser.add_argument('--alpha1', type=float, default=20)
parser.add_argument('--alpha2', type=float, default=20)
# Collection stops once this many states have been counted; each cluster is
# capped at 20% of this value while frames are being saved.
parser.add_argument('--total_states', type=int, default=5000, help='choose how many states you want to have')
# Parsed at module import time (not inside main()).
args = parser.parse_args()
def isEmpty(path):
    """Report whether ``path`` contains nothing worth reusing.

    Returns ``False`` only when ``path`` exists, is not a regular file, and
    ``os.listdir`` finds at least one entry in it. A missing path, a regular
    file and an empty directory are all reported as empty (``True``).
    """
    listable = os.path.exists(path) and not os.path.isfile(path)
    if not listable:
        # Missing paths and plain files are treated the same as an empty folder.
        return True
    return len(os.listdir(path)) == 0
def load_model(env_name, num_envs, distribution_mode, num_levels, start_level,
               num_embeddings, beta, alpha1, alpha2,
               log_dir='./checkpoints', comm=None, rand_seed=None):
    """Build a wrapped procgen vec-env, open a TF session and load an sppo2 model.

    Args:
        env_name, num_envs, distribution_mode, num_levels, rand_seed:
            forwarded to ``ProcgenEnv``.
        start_level: base level index; the environment is created with
            ``start_level + 200000``.
        num_embeddings, beta: forwarded to ``build_skill_impala_cnn``
            (``rand_seed`` is passed as its ``seed``).
        alpha1, alpha2: forwarded to ``sppo2.load``.
        log_dir: passed to ``sppo2.load`` as ``load_path``.
        comm: passed through to ``sppo2.load``.

    Returns:
        Whatever ``sppo2.load`` returns.

    Side effects:
        Enters a ``tf.compat.v1.Session`` that is never exited here.
    """
    # Fixed PPO settings handed to sppo2.load.
    entropy_coef = .01
    steps_per_rollout = 256
    minibatch_count = 8
    rank_weight = 1

    logger.info("creating environment")
    # NOTE(review): the +200000 offset presumably moves evaluation onto levels
    # outside the training range -- confirm against the training script.
    procgen_env = ProcgenEnv(
        num_envs=num_envs,
        env_name=env_name,
        num_levels=num_levels,
        start_level=start_level + 200000,
        distribution_mode=distribution_mode,
        rand_seed=rand_seed,
    )
    rgb_env = VecExtractDictObs(procgen_env, "rgb")
    monitored_env = VecMonitor(venv=rgb_env, filename=None, keep_buf=100)
    wrapped_env = VecNormalize(venv=monitored_env, ob=False)

    logger.info("creating tf session")
    setup_mpi_gpus()
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True  # pylint: disable=E1101
    session = tf.compat.v1.Session(config=session_config)
    session.__enter__()

    def skill_network(x):
        # Network builder closed over the skill hyper-parameters.
        return build_skill_impala_cnn(
            x,
            depths=[16, 32, 32],
            emb_dim=256,
            num_embeddings=num_embeddings,
            beta=beta,
            seed=rand_seed,
        )

    logger.info("loading the model")
    return sppo2.load(
        env=wrapped_env,
        network=skill_network,
        nsteps=steps_per_rollout,
        nminibatches=minibatch_count,
        ent_coef=entropy_coef,
        mpi_rank_weight=rank_weight,
        comm=comm,
        vf_coef=0.5,
        max_grad_norm=0.5,
        seed=rand_seed,
        load_path=log_dir,
        alpha1=alpha1,
        alpha2=alpha2,
    )
def _collect_cluster_images(model, cluster_path):
    """Roll out ``model`` in a procgen env and save frames grouped by cluster.

    Every saved frame goes to
    ``<cluster_path>/cluster_<k>/<step>_<x>_<y>.png`` where ``k`` is
    ``encoding_indices[0][0]`` and ``x`` / ``y`` are ``sl[0][0]`` /
    ``sl[0][1]`` from ``model.skill_step``; ``step`` is a 1-based counter
    over the whole rollout (it is not reset between episodes).

    Two kinds of frames are deleted again to avoid collecting endless loops:
    runs of more than 20 consecutive steps in the same cluster, and whole
    episodes with more saved frames than the per-game limit.

    Reads the module-level ``args``. The per-cluster ``cluster_<k>`` folders
    must already exist.
    """
    max_run_length = 20
    set_global_seeds(args.rand_seed)
    env = gym.make("procgen:procgen-" + args.env_name + "-v0",
                   num_levels=args.num_levels, start_level=args.start_level,
                   distribution_mode=args.distribution_mode,
                   rand_seed=args.rand_seed, render_mode="human")
    obs = env.reset()

    # committed_counts: per-cluster totals from finished, accepted episodes.
    # episode_counts: the same totals plus the episode currently running.
    committed_counts = np.zeros(args.num_embeddings)
    episode_counts = np.zeros(args.num_embeddings)
    step_i = 0
    per_cluster_cap = args.total_states * 0.2
    episode_files = []
    # For each cluster: the files / step numbers of its current consecutive run.
    run_files = [[] for _ in range(args.num_embeddings)]
    run_steps = [[] for _ in range(args.num_embeddings)]

    # Avoid collecting endless loops: most games get a hard cap on the number
    # of frames kept per episode.
    if args.env_name not in ['starpilot', 'fruitbot']:
        episode_max_steps = 200
    else:
        episode_max_steps = math.inf

    try:
        while True:
            frame = Image.fromarray(env.render(mode="rgb_array"))
            # skill_step returns, in order: a, v, pure_latent, vq_latent,
            # pure_vq_latent, vq_embeddings, encoding_indices, sl, lat.
            a, _, _, _, _, _, encoding_indices, sl, _ = model.skill_step(obs)
            action = np.squeeze(a)
            idx = encoding_indices[0][0]

            if episode_counts[idx] < per_cluster_cap:
                frame_path = (cluster_path + '/cluster_' + str(idx) + '/'
                              + str(step_i + 1) + '_' + str(sl[0][0]) + '_'
                              + str(sl[0][1]) + '.png')
                frame.save(frame_path, dpi=(5, 5))
                episode_files.append(frame_path)

                extends_run = (bool(run_files[idx])
                               and step_i + 1 - int(run_steps[idx][-1]) == 1)
                if run_files[idx] and not extends_run:
                    # The previous run in this cluster has ended; drop it from
                    # disk if it was too long, then start a fresh run.
                    # NOTE(review): nesting reconstructed from an
                    # indentation-stripped source -- confirm that the counter
                    # is decremented once per file actually removed.
                    if len(run_files[idx]) > max_run_length:
                        for stale in run_files[idx]:
                            if os.path.isfile(stale):
                                os.remove(stale)
                                episode_counts[idx] -= 1
                    run_files[idx] = []
                    run_steps[idx] = []
                run_files[idx].append(frame_path)
                run_steps[idx].append(step_i + 1)

            obs, _, done, _ = env.step(action)
            # NOTE(review): the step is counted even when the frame was not
            # saved because the cluster had already reached its cap.
            episode_counts[idx] += 1
            step_i += 1
            if not done:
                continue

            if len(episode_files) > episode_max_steps:
                # Over-long episode: delete its frames, forget them in the
                # run bookkeeping and roll the counters back.
                for stale in episode_files:
                    if os.path.isfile(stale):
                        os.remove(stale)
                    for k in range(args.num_embeddings):
                        if stale in run_files[k]:
                            run_steps[k].pop(run_files[k].index(stale))
                            run_files[k].remove(stale)
                episode_counts = committed_counts.copy()
            committed_counts = episode_counts.copy()
            episode_files = []
            if np.sum(committed_counts) >= args.total_states:
                break
    finally:
        # Close the environment even if the rollout fails part-way.
        env.close()


def _read_cluster_images(cluster_path):
    """Load every saved frame with its latent coordinates and cluster index.

    File names are expected to look like ``<step>_<x>_<y>.png``; ``x`` and
    ``y`` are parsed as floats. Cluster folders that do not exist are skipped.

    Returns:
        ``(images, skill_latents, cluster_indices)`` as NumPy arrays with one
        entry per PNG file found.
    """
    images = []
    skill_latents = []
    cluster_indices = []
    for i in range(args.num_embeddings):
        cluster_dir = cluster_path + '/cluster_' + str(i)
        if not os.path.isdir(cluster_dir):
            continue
        for file_name in os.listdir(cluster_dir):
            if not file_name.endswith('png'):
                continue
            _, sl_x, sl_y = file_name[:-4].split('_')
            images.append(plt.imread(cluster_dir + '/' + file_name))
            skill_latents.append([float(sl_x), float(sl_y)])
            cluster_indices.append(i)
    return np.array(images), np.array(skill_latents), np.array(cluster_indices)


def _show_hover_plot(images, skill_latents, cluster_indices):
    """Scatter the latent points and show the matching frame under the mouse.

    Blocks in ``plt.show()`` until the window is closed. ``images`` must hold
    at least one frame.
    """
    fig, ax = plt.subplots()
    fig.set_figheight(6)
    fig.set_figwidth(8)

    # Seeded local generator: same colours as seeding the global NumPy RNG
    # with 0, without disturbing global random state.
    rng = np.random.RandomState(0)
    palette = np.array([
        c.to_hex((rng.rand(), rng.rand(), rng.rand()))
        for _ in range(args.num_embeddings)
    ])
    ax.scatter(skill_latents[:, 0], skill_latents[:, 1], s=10,
               c=palette[cluster_indices], alpha=0.5)
    # Invisible line through the same points; used only for hit-testing.
    line, = ax.plot(skill_latents[:, 0], skill_latents[:, 1], ls="")
    ax.set_title('The FDR space')
    ax.set_xlabel('FDR x')
    ax.set_ylabel('FDR y')

    # Create the annotation box, add it to the axes and keep it hidden until
    # the mouse is over a point.
    im = OffsetImage(images[0, :, :, :], zoom=0.2)
    xybox = (100., 100.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        hit, details = line.contains(event)
        if hit:
            # Several points may lie under the cursor; show the first one.
            ind = details["ind"][0]
            # Flip the box towards the figure centre so that it stays visible
            # when the mouse is in the top or right part of the figure.
            w, h = fig.get_size_inches() * fig.dpi
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            ab.set_visible(True)
            # Anchor the box at the hovered point and swap in its frame.
            ab.xy = (skill_latents[ind, 0], skill_latents[ind, 1])
            im.set_data(images[ind, :, :, :])
        else:
            ab.set_visible(False)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()


def main():
    """Collect skill-cluster frames if needed, then open the hover viewer.

    When ``./train_procgen/figures/<env_name>_cluster_images`` is missing or
    empty, the checkpoint selected by the command-line arguments is loaded and
    rolled out to populate it. Afterwards the saved frames are plotted at
    their latent coordinates, coloured by cluster, with the frame shown when
    hovering over a point.
    """
    cluster_path = './train_procgen/figures/' + args.env_name + '_cluster_images'

    # If empty, we need to collect states first.
    if isEmpty(cluster_path):
        # NOTE(review): the font size is only changed on runs that collect
        # frames, so later runs plot with the default size -- confirm intent.
        plt.rcParams.update({'font.size': 15})
        checkpoint_dir = ('./train_procgen/checkpoints/sppo-' + args.env_name
                          + '_' + args.distribution_mode
                          + '_' + str(args.num_levels)
                          + '_' + str(args.start_level)
                          + '_' + str(args.rand_seed)
                          + '/checkpoints/' + str(args.timesteps_per_proc))
        model = load_model(args.env_name,
                           1,
                           args.distribution_mode,
                           args.num_levels,
                           args.start_level,
                           args.num_embeddings,
                           comm=MPI.COMM_WORLD,
                           log_dir=checkpoint_dir,
                           rand_seed=args.rand_seed,
                           beta=args.beta,
                           alpha1=args.alpha1,
                           alpha2=args.alpha2,
                           )
        # Creates the figures folder and cluster_path as needed.
        for i in range(args.num_embeddings):
            os.makedirs(cluster_path + '/cluster_' + str(i), exist_ok=True)
        _collect_cluster_images(model, cluster_path)

    images, skill_latents, cluster_indices = _read_cluster_images(cluster_path)
    if len(images) == 0:
        # Nothing to plot; indexing the empty arrays below would raise.
        logger.info("no cluster images found under " + cluster_path)
        return
    _show_hover_plot(images, skill_latents, cluster_indices)
# Start the viewer only when this file is executed as a script.
if __name__ == '__main__':
    main()