| r"""Script for training an RL agent using the UVF algorithm.
|
|
|
| To run locally: See run_train.py
|
| """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import tensorflow as tf
slim = tf.contrib.slim

import gin.tf

import train_utils
import agent as agent_
from agents import circular_buffer
from utils import utils as uvf_utils
from environments import create_maze_env

flags = tf.app.flags

FLAGS = flags.FLAGS
flags.DEFINE_string('goal_sample_strategy', 'sample',
                    'None, sample, FuN')

LOAD_PATH = None


def collect_experience(tf_env, agent, meta_agent, state_preprocess,
                       replay_buffer, meta_replay_buffer,
                       action_fn, meta_action_fn,
                       environment_steps, num_episodes, num_resets,
                       episode_rewards, episode_meta_rewards,
                       store_context,
                       disable_agent_reset):
  """Collect experience in a tf_env into a replay_buffer using action_fn.

  Args:
    tf_env: A TFEnvironment.
    agent: A UVF agent.
    meta_agent: A Meta Agent.
    state_preprocess: A state preprocessor mapping observations to
      representations.
    replay_buffer: A Replay buffer to collect experience in.
    meta_replay_buffer: A Replay buffer to collect meta agent experience in.
    action_fn: A function to produce actions given current state.
    meta_action_fn: A function to produce meta actions given current state.
    environment_steps: A variable to count the number of steps in the tf_env.
    num_episodes: A variable to count the number of episodes.
    num_resets: A variable to count the number of resets.
    episode_rewards: A variable holding recent per-episode rewards.
    episode_meta_rewards: A variable holding recent per-episode meta rewards.
    store_context: A boolean indicating whether to store the agent contexts
      alongside each transition in the replay buffers.
    disable_agent_reset: A boolean that disables agent from resetting.

  Returns:
    A collect_experience_op that executes one environment step and stores the
    resulting transitions into the replay buffers.
  """
  tf_env.start_collect()
  state = tf_env.current_obs()
  state_repr = state_preprocess(state)
  action = action_fn(state, context=None)

  with tf.control_dependencies([state]):
    transition_type, reward, discount = tf_env.step(action)

  def increment_step():
    return environment_steps.assign_add(1)

  def increment_episode():
    return num_episodes.assign_add(1)

  def increment_reset():
    return num_resets.assign_add(1)

  def update_episode_rewards(context_reward, meta_reward, reset):
    new_episode_rewards = tf.concat(
        [episode_rewards[:1] + context_reward, episode_rewards[1:]], 0)
    new_episode_meta_rewards = tf.concat(
        [episode_meta_rewards[:1] + meta_reward,
         episode_meta_rewards[1:]], 0)
    return tf.group(
        episode_rewards.assign(
            tf.cond(reset,
                    lambda: tf.concat([[0.], episode_rewards[:-1]], 0),
                    lambda: new_episode_rewards)),
        episode_meta_rewards.assign(
            tf.cond(reset,
                    lambda: tf.concat([[0.], episode_meta_rewards[:-1]], 0),
                    lambda: new_episode_meta_rewards)))

  def no_op_int():
    return tf.constant(0, dtype=tf.int64)

  step_cond = agent.step_cond_fn(state, action,
                                 transition_type,
                                 environment_steps, num_episodes)
  reset_episode_cond = agent.reset_episode_cond_fn(
      state, action,
      transition_type, environment_steps, num_episodes)
  reset_env_cond = agent.reset_env_cond_fn(state, action,
                                           transition_type,
                                           environment_steps, num_episodes)

  increment_step_op = tf.cond(step_cond, increment_step, no_op_int)
  increment_episode_op = tf.cond(reset_episode_cond, increment_episode,
                                 no_op_int)
  increment_reset_op = tf.cond(reset_env_cond, increment_reset, no_op_int)
  increment_op = tf.group(increment_step_op, increment_episode_op,
                          increment_reset_op)

  with tf.control_dependencies([increment_op, reward, discount]):
    next_state = tf_env.current_obs()
    next_state_repr = state_preprocess(next_state)
    next_reset_episode_cond = tf.logical_or(
        agent.reset_episode_cond_fn(
            state, action,
            transition_type, environment_steps, num_episodes),
        tf.equal(discount, 0.0))

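  # Snapshot the current agent/meta-agent contexts. Adding tf.zeros_like
  # forces a fresh tensor copy of each context variable's current value, so
  # later in-graph assignments do not change what gets written to the replay
  # buffers.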
  if store_context:
    context = [tf.identity(var) + tf.zeros_like(var) for var in agent.context_vars]
    meta_context = [tf.identity(var) + tf.zeros_like(var) for var in meta_agent.context_vars]
  else:
    context = []
    meta_context = []
  with tf.control_dependencies([next_state] + context + meta_context):
    if disable_agent_reset:
      collect_experience_ops = [tf.no_op()]
    else:
      collect_experience_ops = agent.cond_begin_episode_op(
          tf.logical_not(reset_episode_cond),
          [state, action, reward, next_state,
           state_repr, next_state_repr],
          mode='explore', meta_action_fn=meta_action_fn)
      context_reward, meta_reward = collect_experience_ops
      collect_experience_ops = list(collect_experience_ops)
      collect_experience_ops.append(
          update_episode_rewards(tf.reduce_sum(context_reward), meta_reward,
                                 reset_episode_cond))

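  # Assemble the low-level transition and the meta-transition. The meta
  # transition covers one meta-action period: the state at the start of the
  # period, the meta-action (the low-level agent's context), the accumulated
  # scaled meta-reward, and the sequence of low-level states and actions
  # observed since the last meta-action.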
  meta_action_every_n = agent.tf_context.meta_action_every_n
  with tf.control_dependencies(collect_experience_ops):
    transition = [state, action, reward, discount, next_state]

    meta_action = tf.to_float(
        tf.concat(context, -1))

    meta_end = tf.logical_and(
        tf.equal(agent.tf_context.t % meta_action_every_n, 1),
        agent.tf_context.t > 1)
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      states_var = tf.get_variable('states_var',
                                   [meta_action_every_n, state.shape[-1]],
                                   state.dtype)
      actions_var = tf.get_variable('actions_var',
                                    [meta_action_every_n, action.shape[-1]],
                                    action.dtype)
      state_var = tf.get_variable('state_var', state.shape, state.dtype)
      reward_var = tf.get_variable('reward_var', reward.shape, reward.dtype)
      meta_action_var = tf.get_variable('meta_action_var',
                                        meta_action.shape, meta_action.dtype)
      meta_context_var = [
          tf.get_variable('meta_context_var%d' % idx,
                          meta_context[idx].shape, meta_context[idx].dtype)
          for idx in range(len(meta_context))]

    actions_var_upd = tf.scatter_update(
        actions_var, (agent.tf_context.t - 2) % meta_action_every_n, action)
    with tf.control_dependencies([actions_var_upd]):
      actions = tf.identity(actions_var) + tf.zeros_like(actions_var)
      meta_reward = tf.identity(meta_reward) + tf.zeros_like(meta_reward)
      meta_reward = tf.reshape(meta_reward, reward.shape)

    reward = 0.1 * meta_reward
    meta_transition = [state_var, meta_action_var,
                       reward_var + reward,
                       discount * (1 - tf.to_float(next_reset_episode_cond)),
                       next_state]
    meta_transition.extend([states_var, actions])
    if store_context:
      transition += context + list(agent.context_vars)
      meta_transition += meta_context_var + list(meta_agent.context_vars)

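    # Add the meta-transition only when the meta-action period ends or the
    # episode terminates.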
    meta_step_cond = tf.squeeze(tf.logical_and(step_cond, tf.logical_or(next_reset_episode_cond, meta_end)))

    collect_experience_op = tf.group(
        replay_buffer.maybe_add(transition, step_cond),
        meta_replay_buffer.maybe_add(meta_transition, meta_step_cond),
    )

  with tf.control_dependencies([collect_experience_op]):
    collect_experience_op = tf.cond(reset_env_cond,
                                    tf_env.reset,
                                    tf_env.current_time_step)

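  # At each meta-action boundary, reset the accumulators: record the new start
  # state, zero the accumulated reward, and latch the current meta-action and
  # meta-context for the next meta-transition.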
  meta_period = tf.equal(agent.tf_context.t % meta_action_every_n, 1)
  states_var_upd = tf.scatter_update(
      states_var, (agent.tf_context.t - 1) % meta_action_every_n,
      next_state)
  state_var_upd = tf.assign(
      state_var,
      tf.cond(meta_period, lambda: next_state, lambda: state_var))
  reward_var_upd = tf.assign(
      reward_var,
      tf.cond(meta_period,
              lambda: tf.zeros_like(reward_var),
              lambda: reward_var + reward))
  meta_action = tf.to_float(tf.concat(agent.context_vars, -1))
  meta_action_var_upd = tf.assign(
      meta_action_var,
      tf.cond(meta_period, lambda: meta_action, lambda: meta_action_var))
  meta_context_var_upd = [
      tf.assign(
          meta_context_var[idx],
          tf.cond(meta_period,
                  lambda: meta_agent.context_vars[idx],
                  lambda: meta_context_var[idx]))
      for idx in range(len(meta_context))]

  return tf.group(
      collect_experience_op,
      states_var_upd,
      state_var_upd,
      reward_var_upd,
      meta_action_var_upd,
      *meta_context_var_upd)


def sample_best_meta_actions(state_reprs, next_state_reprs, prev_meta_actions,
                             low_states, low_actions, low_state_reprs,
                             inverse_dynamics, uvf_agent, k=10):
  """Return meta-actions which approximately maximize low-level log-probs."""
  sampled_actions = inverse_dynamics.sample(state_reprs, next_state_reprs, k, prev_meta_actions)
  sampled_actions = tf.stop_gradient(sampled_actions)
  sampled_log_probs = tf.reshape(uvf_agent.log_probs(
      tf.tile(low_states, [k, 1, 1]),
      tf.tile(low_actions, [k, 1, 1]),
      tf.tile(low_state_reprs, [k, 1, 1]),
      [tf.reshape(sampled_actions, [-1, sampled_actions.shape[-1]])]),
                                 [k, low_states.shape[0],
                                  low_states.shape[1], -1])
  fitness = tf.reduce_sum(sampled_log_probs, [2, 3])
  best_actions = tf.argmax(fitness, 0)
  actions = tf.gather_nd(
      sampled_actions,
      tf.stack([best_actions,
                tf.range(prev_meta_actions.shape[0], dtype=tf.int64)], -1))
  return actions


@gin.configurable
def train_uvf(train_dir,
              environment=None,
              num_bin_actions=3,
              agent_class=None,
              meta_agent_class=None,
              state_preprocess_class=None,
              inverse_dynamics_class=None,
              exp_action_wrapper=None,
              replay_buffer=None,
              meta_replay_buffer=None,
              replay_num_steps=1,
              meta_replay_num_steps=1,
              critic_optimizer=None,
              actor_optimizer=None,
              meta_critic_optimizer=None,
              meta_actor_optimizer=None,
              repr_optimizer=None,
              relabel_contexts=False,
              meta_relabel_contexts=False,
              batch_size=64,
              repeat_size=0,
              num_episodes_train=2000,
              initial_episodes=2,
              initial_steps=None,
              num_updates_per_observation=1,
              num_collect_per_update=1,
              num_collect_per_meta_update=1,
              gamma=1.0,
              meta_gamma=1.0,
              reward_scale_factor=1.0,
              target_update_period=1,
              should_stop_early=None,
              clip_gradient_norm=0.0,
              summarize_gradients=False,
              debug_summaries=False,
              log_every_n_steps=100,
              prefetch_queue_capacity=2,
              policy_save_dir='policy',
              save_policy_every_n_steps=1000,
              save_policy_interval_secs=0,
              replay_context_ratio=0.0,
              next_state_as_context_ratio=0.0,
              state_index=0,
              zero_timer_ratio=0.0,
              timer_index=-1,
              debug=False,
              max_policies_to_save=None,
              max_steps_per_episode=None,
              load_path=LOAD_PATH):
  """Train an agent."""
  tf_env = create_maze_env.TFPyEnvironment(environment)
  observation_spec = [tf_env.observation_spec()]
  action_spec = [tf_env.action_spec()]

  max_steps_per_episode = max_steps_per_episode or tf_env.pyenv.max_episode_steps

  assert max_steps_per_episode, 'max_steps_per_episode needs to be set'

  if initial_steps is None:
    initial_steps = initial_episodes * max_steps_per_episode

  if agent_class.ACTION_TYPE == 'discrete':
    assert False
  else:
    assert agent_class.ACTION_TYPE == 'continuous'

  assert agent_class.ACTION_TYPE == meta_agent_class.ACTION_TYPE
  with tf.variable_scope('meta_agent'):
    meta_agent = meta_agent_class(
        observation_spec,
        action_spec,
        tf_env,
        debug_summaries=debug_summaries)
    meta_agent.set_replay(replay=meta_replay_buffer)

  with tf.variable_scope('uvf_agent'):
    uvf_agent = agent_class(
        observation_spec,
        action_spec,
        tf_env,
        debug_summaries=debug_summaries)
    uvf_agent.set_meta_agent(agent=meta_agent)
    uvf_agent.set_replay(replay=replay_buffer)

  with tf.variable_scope('state_preprocess'):
    state_preprocess = state_preprocess_class()

  with tf.variable_scope('inverse_dynamics'):
    inverse_dynamics = inverse_dynamics_class(
        meta_agent.sub_context_as_action_specs[0])

  global_step = tf.contrib.framework.get_or_create_global_step()
  num_episodes = tf.Variable(0, dtype=tf.int64, name='num_episodes')
  num_resets = tf.Variable(0, dtype=tf.int64, name='num_resets')
  num_updates = tf.Variable(0, dtype=tf.int64, name='num_updates')
  num_meta_updates = tf.Variable(0, dtype=tf.int64, name='num_meta_updates')
  episode_rewards = tf.Variable([0.] * 100, name='episode_rewards')
  episode_meta_rewards = tf.Variable([0.] * 100, name='episode_meta_rewards')

  train_utils.create_counter_summaries([
      ('environment_steps', global_step),
      ('num_episodes', num_episodes),
      ('num_resets', num_resets),
      ('num_updates', num_updates),
      ('num_meta_updates', num_meta_updates),
      ('replay_buffer_adds', replay_buffer.get_num_adds()),
      ('meta_replay_buffer_adds', meta_replay_buffer.get_num_adds()),
  ])

  tf.summary.scalar('avg_episode_rewards',
                    tf.reduce_mean(episode_rewards[1:]))
  tf.summary.scalar('avg_episode_meta_rewards',
                    tf.reduce_mean(episode_meta_rewards[1:]))
  tf.summary.histogram('episode_rewards', episode_rewards[1:])
  tf.summary.histogram('episode_meta_rewards', episode_meta_rewards[1:])

  action_fn = uvf_agent.action
  action_fn = uvf_agent.add_noise_fn(action_fn, global_step=None)
  meta_action_fn = meta_agent.action
  meta_action_fn = meta_agent.add_noise_fn(meta_action_fn, global_step=None)
  meta_actions_fn = meta_agent.actions
  meta_actions_fn = meta_agent.add_noise_fn(meta_actions_fn, global_step=None)

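  # Two structurally identical experience-collection ops are built: the first
  # is run repeatedly before training to seed the replay buffers, the second
  # is run as part of every training step.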
  init_collect_experience_op = collect_experience(
      tf_env,
      uvf_agent,
      meta_agent,
      state_preprocess,
      replay_buffer,
      meta_replay_buffer,
      action_fn,
      meta_action_fn,
      environment_steps=global_step,
      num_episodes=num_episodes,
      num_resets=num_resets,
      episode_rewards=episode_rewards,
      episode_meta_rewards=episode_meta_rewards,
      store_context=True,
      disable_agent_reset=False,
  )

  collect_experience_op = collect_experience(
      tf_env,
      uvf_agent,
      meta_agent,
      state_preprocess,
      replay_buffer,
      meta_replay_buffer,
      action_fn,
      meta_action_fn,
      environment_steps=global_step,
      num_episodes=num_episodes,
      num_resets=num_resets,
      episode_rewards=episode_rewards,
      episode_meta_rewards=episode_meta_rewards,
      store_context=True,
      disable_agent_reset=False,
  )

  train_op_list = []
  repr_train_op = tf.constant(0.0)
  for mode in ['meta', 'nometa']:
    if mode == 'meta':
      agent = meta_agent
      buff = meta_replay_buffer
      critic_opt = meta_critic_optimizer
      actor_opt = meta_actor_optimizer
      relabel = meta_relabel_contexts
      num_steps = meta_replay_num_steps
      my_gamma = meta_gamma
      n_updates = num_meta_updates
    else:
      agent = uvf_agent
      buff = replay_buffer
      critic_opt = critic_optimizer
      actor_opt = actor_optimizer
      relabel = relabel_contexts
      num_steps = replay_num_steps
      my_gamma = gamma
      n_updates = num_updates

    with tf.name_scope(mode):
      batch = buff.get_random_batch(batch_size, num_steps=num_steps)
      states, actions, rewards, discounts, next_states = batch[:5]
      with tf.name_scope('Reward'):
        tf.summary.scalar('average_step_reward', tf.reduce_mean(rewards))
      rewards *= reward_scale_factor
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [states, actions, rewards, discounts, next_states] + batch[5:],
          capacity=prefetch_queue_capacity,
          name='batch_queue')

      batch_dequeue = batch_queue.dequeue()
      if repeat_size > 0:
        batch_dequeue = [
            tf.tile(batch, (repeat_size+1,) + (1,) * (batch.shape.ndims - 1))
            for batch in batch_dequeue
        ]
        batch_size *= (repeat_size + 1)
      states, actions, rewards, discounts, next_states = batch_dequeue[:5]
      if mode == 'meta':
        low_states = batch_dequeue[5]
        low_actions = batch_dequeue[6]
        low_state_reprs = state_preprocess(low_states)
      state_reprs = state_preprocess(states)
      next_state_reprs = state_preprocess(next_states)

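      # For the meta-agent, optionally re-label the stored meta-actions
      # (goals) according to FLAGS.goal_sample_strategy before computing the
      # losses below.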
      if mode == 'meta':
        prev_actions = actions
        if FLAGS.goal_sample_strategy == 'None':
          pass
        elif FLAGS.goal_sample_strategy == 'FuN':
          actions = inverse_dynamics.sample(state_reprs, next_state_reprs, 1, prev_actions, sc=0.1)
          actions = tf.stop_gradient(actions)
        elif FLAGS.goal_sample_strategy == 'sample':
          actions = sample_best_meta_actions(state_reprs, next_state_reprs, prev_actions,
                                             low_states, low_actions, low_state_reprs,
                                             inverse_dynamics, uvf_agent, k=10)
        else:
          assert False

      if state_preprocess.trainable and mode == 'meta':
        repr_loss, _, _ = state_preprocess.loss(states, next_states, low_actions, low_states)
        repr_train_op = slim.learning.create_train_op(
            repr_loss,
            repr_optimizer,
            global_step=None,
            update_ops=None,
            summarize_gradients=summarize_gradients,
            clip_gradient_norm=clip_gradient_norm,
            variables_to_train=state_preprocess.get_trainable_vars(),)

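      # Sample new training contexts; when relabeling is disabled, fall back
      # to the contexts that were stored with the transition in the replay
      # buffer.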
      contexts, next_contexts = agent.sample_contexts(
          mode='train', batch_size=batch_size,
          state=states, next_state=next_states,
      )
      if not relabel:
        contexts, next_contexts = (
            batch_dequeue[-2*len(contexts):-1*len(contexts)],
            batch_dequeue[-1*len(contexts):])

      merged_states = agent.merged_states(states, contexts)
      merged_next_states = agent.merged_states(next_states, next_contexts)
      if mode == 'nometa':
        context_rewards, context_discounts = agent.compute_rewards(
            'train', state_reprs, actions, rewards, next_state_reprs, contexts)
      elif mode == 'meta':
        _, context_discounts = agent.compute_rewards(
            'train', states, actions, rewards, next_states, contexts)
        context_rewards = rewards

      if agent.gamma_index is not None:
        context_discounts *= tf.cast(
            tf.reshape(contexts[agent.gamma_index], (-1,)),
            dtype=context_discounts.dtype)
      else:
        context_discounts *= my_gamma

      critic_loss = agent.critic_loss(merged_states, actions,
                                      context_rewards, context_discounts,
                                      merged_next_states)

      critic_loss = tf.reduce_mean(critic_loss)

      actor_loss = agent.actor_loss(merged_states, actions,
                                    context_rewards, context_discounts,
                                    merged_next_states)
      actor_loss *= tf.to_float(
          tf.equal(n_updates % target_update_period, 0))

      critic_train_op = slim.learning.create_train_op(
          critic_loss,
          critic_opt,
          global_step=n_updates,
          update_ops=None,
          summarize_gradients=summarize_gradients,
          clip_gradient_norm=clip_gradient_norm,
          variables_to_train=agent.get_trainable_critic_vars(),)
      critic_train_op = uvf_utils.tf_print(
          critic_train_op, [critic_train_op],
          message='critic_loss',
          print_freq=1000,
          name='critic_loss')
      train_op_list.append(critic_train_op)
      if actor_loss is not None:
        actor_train_op = slim.learning.create_train_op(
            actor_loss,
            actor_opt,
            global_step=None,
            update_ops=None,
            summarize_gradients=summarize_gradients,
            clip_gradient_norm=clip_gradient_norm,
            variables_to_train=agent.get_trainable_actor_vars(),)
        actor_train_op = uvf_utils.tf_print(
            actor_train_op, [actor_train_op],
            message='actor_loss',
            print_freq=1000,
            name='actor_loss')
        train_op_list.append(actor_train_op)

  assert len(train_op_list) == 4

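  # train_op_list holds [meta critic, meta actor, low-level critic, low-level
  # actor] ops. The low-level pair is grouped into train_op and the meta pair
  # into meta_train_op, each followed by a periodic target-network update.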
  with tf.control_dependencies(train_op_list[2:]):
    update_targets_op = uvf_utils.periodically(
        uvf_agent.update_targets, target_update_period, 'update_targets')
  if meta_agent is not None:
    with tf.control_dependencies(train_op_list[:2]):
      update_meta_targets_op = uvf_utils.periodically(
          meta_agent.update_targets, target_update_period, 'update_targets')

  assert_op = tf.Assert(
      tf.less_equal(global_step, 200 + num_episodes_train * max_steps_per_episode),
      [global_step])
  with tf.control_dependencies([update_targets_op, assert_op]):
    train_op = tf.add_n(train_op_list[2:], name='post_update_targets')
    train_op += repr_train_op
  with tf.control_dependencies([update_meta_targets_op, assert_op]):
    meta_train_op = tf.add_n(train_op_list[:2],
                             name='post_update_meta_targets')

  if debug_summaries:
    train_utils.gen_debug_batch_summaries(batch)
    slim.summaries.add_histogram_summaries(
        uvf_agent.get_trainable_critic_vars(), 'critic_vars')
    slim.summaries.add_histogram_summaries(
        uvf_agent.get_trainable_actor_vars(), 'actor_vars')

  train_ops = train_utils.TrainOps(train_op, meta_train_op,
                                   collect_experience_op)

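  # Savers: policy_saver checkpoints everything needed to run the learned
  # policies (actors, critics, counters, and context variables), while
  # lowlevel_saver holds only the low-level actor/critic and state
  # preprocessor so they can be restored later via load_path.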
  policy_save_path = os.path.join(train_dir, policy_save_dir, 'model.ckpt')
  policy_vars = uvf_agent.get_actor_vars() + meta_agent.get_actor_vars() + [
      global_step, num_episodes, num_resets
  ] + list(uvf_agent.context_vars) + list(meta_agent.context_vars) + state_preprocess.get_trainable_vars()
  policy_vars += uvf_agent.get_trainable_critic_vars() + meta_agent.get_trainable_critic_vars()
  policy_saver = tf.train.Saver(
      policy_vars, max_to_keep=max_policies_to_save, sharded=False)

  lowlevel_vars = (uvf_agent.get_actor_vars() +
                   uvf_agent.get_trainable_critic_vars() +
                   state_preprocess.get_trainable_vars())
  lowlevel_saver = tf.train.Saver(lowlevel_vars)

  def policy_save_fn(sess):
    policy_saver.save(
        sess, policy_save_path, global_step=global_step, write_meta_graph=False)
    if save_policy_interval_secs > 0:
      tf.logging.info(
          'Waiting %d secs after saving policy.' % save_policy_interval_secs)
      time.sleep(save_policy_interval_secs)

  train_step_fn = train_utils.TrainStep(
      max_number_of_steps=num_episodes_train * max_steps_per_episode + 100,
      num_updates_per_observation=num_updates_per_observation,
      num_collect_per_update=num_collect_per_update,
      num_collect_per_meta_update=num_collect_per_meta_update,
      log_every_n_steps=log_every_n_steps,
      policy_save_fn=policy_save_fn,
      save_policy_every_n_steps=save_policy_every_n_steps,
      should_stop_early=should_stop_early).train_step

  local_init_op = tf.local_variables_initializer()
  init_targets_op = tf.group(uvf_agent.update_targets(1.0),
                             meta_agent.update_targets(1.0))

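  # Initialization: run local initializers, copy the online networks to the
  # target networks, optionally restore pre-trained low-level weights from
  # load_path, then collect initial_steps of experience to seed the replay
  # buffers before any training updates.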
  def initialize_training_fn(sess):
    """Initialize training function."""
    sess.run(local_init_op)
    sess.run(init_targets_op)
    if load_path:
      tf.logging.info('Restoring low-level from %s' % load_path)
      lowlevel_saver.restore(sess, load_path)
    global_step_value = sess.run(global_step)
    assert global_step_value == 0, 'Global step should be zero.'
    collect_experience_call = sess.make_callable(
        init_collect_experience_op)

    for _ in range(initial_steps):
      collect_experience_call()

  train_saver = tf.train.Saver(max_to_keep=2, sharded=True)
  tf.logging.info('train dir: %s', train_dir)
  return slim.learning.train(
      train_ops,
      train_dir,
      train_step_fn=train_step_fn,
      save_interval_secs=FLAGS.save_interval_secs,
      saver=train_saver,
      log_every_n_steps=0,
      global_step=global_step,
      master="",
      is_chief=(FLAGS.task == 0),
      save_summaries_secs=FLAGS.save_summaries_secs,
      init_fn=initialize_training_fn)