| """Trainer for coordinating single or multi-replica training.
|
|
|
| Main point of entry for running models. Specifies most of
|
| the parameters used by different algorithms.
|
| """
|
|
|
import tensorflow as tf
import numpy as np
import random
import os
import pickle

from six.moves import xrange
import controller
import model
import policy
import baseline
import objective
import full_episode_objective
import trust_region
import optimizers
import replay_buffer
import expert_paths
import gym_wrapper
import env_spec

app = tf.app
flags = tf.flags
logging = tf.logging
gfile = tf.gfile

FLAGS = flags.FLAGS

flags.DEFINE_string('env', 'Copy-v0', 'environment name')
flags.DEFINE_integer('batch_size', 100, 'batch size')
flags.DEFINE_integer('replay_batch_size', None,
                     'replay batch size; defaults to batch_size')
flags.DEFINE_integer('num_samples', 1,
                     'number of samples from each random seed initialization')
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on')
flags.DEFINE_integer('cutoff_agent', 0,
                     'number of steps at which to cut off the agent. '
                     'Defaults to always cutoff')
flags.DEFINE_integer('num_steps', 100000, 'number of training steps')
flags.DEFINE_integer('validation_frequency', 100,
                     'every so many steps, output some stats')

flags.DEFINE_float('target_network_lag', 0.95,
                   'Exponential decay on the online network that yields the '
                   'target network.')
flags.DEFINE_string('sample_from', 'online',
                    'Sample actions from "online" network or "target" network')

flags.DEFINE_string('objective', 'pcl',
                    'pcl/upcl/a3c/trpo/reinforce/urex')
flags.DEFINE_bool('trust_region_p', False,
                  'use trust region for policy optimization')
flags.DEFINE_string('value_opt', None,
                    'leave as None to optimize it along with policy '
                    '(using critic_weight). Otherwise set to '
                    '"best_fit" (least squares regression), "lbfgs", or "grad"')
flags.DEFINE_float('max_divergence', 0.01,
                   'max divergence (i.e. KL) to allow during '
                   'trust region optimization')

flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
flags.DEFINE_float('clip_norm', 5.0, 'clip norm')
flags.DEFINE_float('clip_adv', 0.0,
                   'Clip advantages at this value. '
                   'Leave as 0 to not clip at all.')
flags.DEFINE_float('critic_weight', 0.1, 'critic weight')
flags.DEFINE_float('tau', 0.1,
                   'entropy regularizer. '
                   'If using decaying tau, this is the final value.')
flags.DEFINE_float('tau_decay', None,
                   'decay tau by this much every 100 steps')
flags.DEFINE_float('tau_start', 0.1,
                   'start tau at this value')
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.')
flags.DEFINE_bool('update_eps_lambda', False,
                  'Update lambda automatically based on last 100 episodes.')
flags.DEFINE_float('gamma', 1.0, 'discount')
flags.DEFINE_integer('rollout', 10, 'rollout length')
flags.DEFINE_bool('use_target_values', False,
                  'use target network for value estimates')
flags.DEFINE_bool('fixed_std', True,
                  'fix the std in Gaussian distributions')
flags.DEFINE_bool('input_prev_actions', True,
                  'input previous actions to policy network')
flags.DEFINE_bool('recurrent', True,
                  'use recurrent connections')
flags.DEFINE_bool('input_time_step', False,
                  'input time step into value calculations')

flags.DEFINE_bool('use_online_batch', True,
                  'train on batches as they are sampled')
flags.DEFINE_bool('batch_by_steps', False,
                  'ensure each training batch has batch_size * max_step steps')
flags.DEFINE_bool('unify_episodes', False,
                  'Make sure replay buffer holds entire episodes, '
                  'even across distinct sampling steps')
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size')
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param')
flags.DEFINE_integer('replay_buffer_freq', 0,
                     'replay buffer frequency (only supports -1/0/1)')
flags.DEFINE_string('eviction', 'rand',
                    'how to evict from replay buffer: rand/rank/fifo')
flags.DEFINE_string('prioritize_by', 'rewards',
                    'Prioritize replay buffer by "rewards" or "step"')
flags.DEFINE_integer('num_expert_paths', 0,
                     'number of expert paths to seed replay buffer with')

flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim')
flags.DEFINE_integer('value_hidden_layers', 0,
                     'number of hidden layers in value estimate')
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow')

flags.DEFINE_string('save_trajectories_dir', None,
                    'directory to save trajectories to, if desired')
flags.DEFINE_string('load_trajectories_file', None,
                    'file to load expert trajectories from')

flags.DEFINE_bool('supervisor', False, 'use supervisor training')
flags.DEFINE_integer('task_id', 0, 'task id')
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks')
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used')
flags.DEFINE_string('master', 'local', 'name of master')
flags.DEFINE_string('save_dir', '', 'directory to save model to')
flags.DEFINE_string('load_path', '',
                    'path of saved model to load (if none in save_dir)')

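# Example invocation (a minimal sketch; assumes this module is saved as
# trainer.py, and the flag values below are illustrative, not recommended):
#   python trainer.py --env=Copy-v0 --objective=pcl --batch_size=100 \
#     --num_steps=100000 --save_dir=/tmp/pcl_copy
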
class Trainer(object):
  """Coordinates single or multi-replica training."""

  def __init__(self):
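    """Read hyperparameters from FLAGS and construct the environments."""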
    self.batch_size = FLAGS.batch_size
    self.replay_batch_size = FLAGS.replay_batch_size
    if self.replay_batch_size is None:
      self.replay_batch_size = self.batch_size
    self.num_samples = FLAGS.num_samples

    self.env_str = FLAGS.env
    self.env = gym_wrapper.GymWrapper(self.env_str,
                                      distinct=FLAGS.batch_size // self.num_samples,
                                      count=self.num_samples)
    self.eval_env = gym_wrapper.GymWrapper(
        self.env_str,
        distinct=FLAGS.batch_size // self.num_samples,
        count=self.num_samples)
    self.env_spec = env_spec.EnvSpec(self.env.get_one())

    self.max_step = FLAGS.max_step
    self.cutoff_agent = FLAGS.cutoff_agent
    self.num_steps = FLAGS.num_steps
    self.validation_frequency = FLAGS.validation_frequency

    self.target_network_lag = FLAGS.target_network_lag
    self.sample_from = FLAGS.sample_from
    assert self.sample_from in ['online', 'target']

    self.critic_weight = FLAGS.critic_weight
    self.objective = FLAGS.objective
    self.trust_region_p = FLAGS.trust_region_p
    self.value_opt = FLAGS.value_opt
    assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
    assert self.objective != 'trpo' or self.trust_region_p
    assert self.value_opt is None or self.value_opt == 'None' or \
        self.critic_weight == 0.0
    self.max_divergence = FLAGS.max_divergence

    self.learning_rate = FLAGS.learning_rate
    self.clip_norm = FLAGS.clip_norm
    self.clip_adv = FLAGS.clip_adv
    self.tau = FLAGS.tau
    self.tau_decay = FLAGS.tau_decay
    self.tau_start = FLAGS.tau_start
    self.eps_lambda = FLAGS.eps_lambda
    self.update_eps_lambda = FLAGS.update_eps_lambda
    self.gamma = FLAGS.gamma
    self.rollout = FLAGS.rollout
    self.use_target_values = FLAGS.use_target_values
    self.fixed_std = FLAGS.fixed_std
    self.input_prev_actions = FLAGS.input_prev_actions
    self.recurrent = FLAGS.recurrent
    assert not self.trust_region_p or not self.recurrent
    self.input_time_step = FLAGS.input_time_step
    assert not self.input_time_step or (self.cutoff_agent <= self.max_step)

    self.use_online_batch = FLAGS.use_online_batch
    self.batch_by_steps = FLAGS.batch_by_steps
    self.unify_episodes = FLAGS.unify_episodes
    if self.unify_episodes:
      assert self.batch_size == 1

    self.replay_buffer_size = FLAGS.replay_buffer_size
    self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
    self.replay_buffer_freq = FLAGS.replay_buffer_freq
    assert self.replay_buffer_freq in [-1, 0, 1]
    self.eviction = FLAGS.eviction
    self.prioritize_by = FLAGS.prioritize_by
    assert self.prioritize_by in ['rewards', 'step']
    self.num_expert_paths = FLAGS.num_expert_paths

    self.internal_dim = FLAGS.internal_dim
    self.value_hidden_layers = FLAGS.value_hidden_layers
    self.tf_seed = FLAGS.tf_seed

    self.save_trajectories_dir = FLAGS.save_trajectories_dir
    self.save_trajectories_file = (
        os.path.join(
            self.save_trajectories_dir, self.env_str.replace('-', '_'))
        if self.save_trajectories_dir else None)
    self.load_trajectories_file = FLAGS.load_trajectories_file

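    # Snapshot all non-callable, non-dunder attributes set above so the full
    # hyperparameter configuration can be logged in one place.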
    self.hparams = dict((attr, getattr(self, attr))
                        for attr in dir(self)
                        if not attr.startswith('__') and
                        not callable(getattr(self, attr)))

  def hparams_string(self):
    return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items()))

  def get_objective(self):
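    """Construct the training objective selected by FLAGS.objective."""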
    tau = self.tau
    if self.tau_decay is not None:
      assert self.tau_start >= self.tau
      tau = tf.maximum(
          tf.train.exponential_decay(
              self.tau_start, self.global_step, 100, self.tau_decay),
          self.tau)

    if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']:
      cls = (objective.PCL if self.objective in ['pcl', 'upcl'] else
             objective.TRPO if self.objective == 'trpo' else
             objective.ActorCritic)
      policy_weight = 1.0

      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 policy_weight=policy_weight,
                 critic_weight=self.critic_weight,
                 tau=tau, gamma=self.gamma, rollout=self.rollout,
                 eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
                 use_target_values=self.use_target_values)
    elif self.objective in ['reinforce', 'urex']:
      cls = (full_episode_objective.Reinforce
             if self.objective == 'reinforce' else
             full_episode_objective.UREX)
      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 num_samples=self.num_samples,
                 tau=tau, bonus_weight=1.0)
    else:
      assert False, 'Unknown objective %s' % self.objective

  def get_policy(self):
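    """Construct the policy network (recurrent or feed-forward MLP)."""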
    if self.recurrent:
      cls = policy.Policy
    else:
      cls = policy.MLPPolicy
    return cls(self.env_spec, self.internal_dim,
               fixed_std=self.fixed_std,
               recurrent=self.recurrent,
               input_prev_actions=self.input_prev_actions)

  def get_baseline(self):
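    """Construct the value baseline network."""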
    cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else
           baseline.Baseline)
    return cls(self.env_spec, self.internal_dim,
               input_prev_actions=self.input_prev_actions,
               input_time_step=self.input_time_step,
               input_policy_state=self.recurrent,
               n_hidden_layers=self.value_hidden_layers,
               hidden_dim=self.internal_dim,
               tau=self.tau)

  def get_trust_region_p_opt(self):
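    """Return a trust region optimizer for the policy, or None if disabled."""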
    if self.trust_region_p:
      return trust_region.TrustRegionOptimization(
          max_divergence=self.max_divergence)
    else:
      return None

  def get_value_opt(self):
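    """Return the standalone value optimizer selected by FLAGS.value_opt."""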
    if self.value_opt == 'grad':
      return optimizers.GradOptimization(
          learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05)
    elif self.value_opt == 'lbfgs':
      return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1)
    elif self.value_opt == 'best_fit':
      return optimizers.BestFitOptimization(mix_frac=1.0)
    else:
      return None

  def get_model(self):
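    """Construct the model wrapping policy, baseline, and objective."""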
    cls = model.Model
    return cls(self.env_spec, self.global_step,
               target_network_lag=self.target_network_lag,
               sample_from=self.sample_from,
               get_policy=self.get_policy,
               get_baseline=self.get_baseline,
               get_objective=self.get_objective,
               get_trust_region_p_opt=self.get_trust_region_p_opt,
               get_value_opt=self.get_value_opt)

  def get_replay_buffer(self):
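    """Construct a prioritized replay buffer, or None if disabled."""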
    if self.replay_buffer_freq <= 0:
      return None
    else:
      assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % (
          self.objective)
      cls = replay_buffer.PrioritizedReplayBuffer
      return cls(self.replay_buffer_size,
                 alpha=self.replay_buffer_alpha,
                 eviction_strategy=self.eviction)

  def get_buffer_seeds(self):
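    """Sample expert trajectories used to seed the replay buffer."""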
    return expert_paths.sample_expert_paths(
        self.num_expert_paths, self.env_str, self.env_spec,
        load_trajectories_file=self.load_trajectories_file)

  def get_controller(self, env):
    """Get controller."""
    cls = controller.Controller
    return cls(env, self.env_spec, self.internal_dim,
               use_online_batch=self.use_online_batch,
               batch_by_steps=self.batch_by_steps,
               unify_episodes=self.unify_episodes,
               replay_batch_size=self.replay_batch_size,
               max_step=self.max_step,
               cutoff_agent=self.cutoff_agent,
               save_trajectories_file=self.save_trajectories_file,
               use_trust_region=self.trust_region_p,
               use_value_opt=self.value_opt not in [None, 'None'],
               update_eps_lambda=self.update_eps_lambda,
               prioritize_by=self.prioritize_by,
               get_model=self.get_model,
               get_replay_buffer=self.get_replay_buffer,
               get_buffer_seeds=self.get_buffer_seeds)

  def do_before_step(self, step):
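    """Hook called before each training step; subclasses may override."""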
    pass

  def run(self):
    """Run training."""
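    # Only the chief (task 0, or any process when not using a supervisor)
    # writes summaries, logs validation stats, and saves the final model.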
    is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor
    sv = None

    def init_fn(sess, saver):
      ckpt = None
      if FLAGS.save_dir and sv is None:
        load_dir = FLAGS.save_dir
        ckpt = tf.train.get_checkpoint_state(load_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info('restoring from %s', ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
      elif FLAGS.load_path:
        logging.info('restoring from %s', FLAGS.load_path)
        saver.restore(sess, FLAGS.load_path)

    if FLAGS.supervisor:
      with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        tf.set_random_seed(FLAGS.tf_seed)
        self.controller = self.get_controller(self.env)
        self.model = self.controller.model
        self.controller.setup()
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          self.eval_controller = self.get_controller(self.eval_env)
          self.eval_controller.setup(train=False)

        saver = tf.train.Saver(max_to_keep=10)
        step = self.model.global_step
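        # The Supervisor manages checkpointing, summary writing, and session
        # recovery for (possibly distributed) training.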
        sv = tf.Supervisor(logdir=FLAGS.save_dir,
                           is_chief=is_chief,
                           saver=saver,
                           save_model_secs=600,
                           summary_op=None,
                           save_summaries_secs=60,
                           global_step=step,
                           init_fn=lambda sess: init_fn(sess, saver))
        sess = sv.PrepareSession(FLAGS.master)
    else:
      tf.set_random_seed(FLAGS.tf_seed)
      self.global_step = tf.contrib.framework.get_or_create_global_step()
      self.controller = self.get_controller(self.env)
      self.model = self.controller.model
      self.controller.setup()
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        self.eval_controller = self.get_controller(self.eval_env)
        self.eval_controller.setup(train=False)

      saver = tf.train.Saver(max_to_keep=10)
      sess = tf.Session()
      sess.run(tf.initialize_all_variables())
      init_fn(sess, saver)

    self.sv = sv
    self.sess = sess

    logging.info('hparams:\n%s', self.hparams_string())

    model_step = sess.run(self.model.global_step)
    if model_step >= self.num_steps:
      logging.info('training has reached final step')
      return

    losses = []
    rewards = []
    all_ep_rewards = []
    for step in xrange(1 + self.num_steps):

      if sv is not None and sv.ShouldStop():
        logging.info('stopping supervisor')
        break

      self.do_before_step(step)

      (loss, summary,
       total_rewards, episode_rewards) = self.controller.train(sess)
      _, greedy_episode_rewards = self.eval_controller.eval(sess)
      self.controller.greedy_episode_rewards = greedy_episode_rewards
      losses.append(loss)
      rewards.append(total_rewards)
      all_ep_rewards.extend(episode_rewards)

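      # On roughly 10% of steps, export the training summary (chief only,
      # and only when a supervisor with a summary writer is present).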
      if (random.random() < 0.1 and summary and episode_rewards and
          is_chief and sv and sv._summary_writer):
        sv.summary_computed(sess, summary)

      model_step = sess.run(self.model.global_step)
      if is_chief and step % self.validation_frequency == 0:
        logging.info('at training step %d, model step %d: '
                     'avg loss %f, avg reward %f, '
                     'episode rewards: %f, greedy rewards: %f',
                     step, model_step,
                     np.mean(losses), np.mean(rewards),
                     np.mean(all_ep_rewards),
                     np.mean(greedy_episode_rewards))

        losses = []
        rewards = []
        all_ep_rewards = []

      if model_step >= self.num_steps:
        logging.info('training has reached final step')
        break

    if is_chief and sv is not None:
      logging.info('saving final model to %s', sv.save_path)
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  trainer = Trainer()
  trainer.run()

if __name__ == '__main__':
  app.run()