| """A UVF agent.
|
| """
|
|
|
| import tensorflow as tf
|
| import gin.tf
|
| from agents import ddpg_agent
|
|
|
| import cond_fn
|
| from utils import utils as uvf_utils
|
| from context import gin_imports
|
|
|
| slim = tf.contrib.slim
|
|
|
|
|
| @gin.configurable
|
| class UvfAgentCore(object):
|
| """Defines basic functions for UVF agent. Must be inherited with an RL agent.
|
|
|
| Used as lower-level agent.
|
| """
|
|
|
| def __init__(self,
|
| observation_spec,
|
| action_spec,
|
| tf_env,
|
| tf_context,
|
| step_cond_fn=cond_fn.env_transition,
|
| reset_episode_cond_fn=cond_fn.env_restart,
|
| reset_env_cond_fn=cond_fn.false_fn,
|
| metrics=None,
|
| **base_agent_kwargs):
|
| """Constructs a UVF agent.
|
|
|
| Args:
|
| observation_spec: A TensorSpec defining the observations.
|
| action_spec: A BoundedTensorSpec defining the actions.
|
| tf_env: A Tensorflow environment object.
|
| tf_context: A Context class.
|
| step_cond_fn: A function indicating whether to increment the number of steps.
|
| reset_episode_cond_fn: A function indicating whether to restart the
|
| episode, resampling the context.
|
| reset_env_cond_fn: A function indicating whether to perform a manual reset
|
| of the environment.
|
| metrics: A list of functions that evaluate metrics of the agent.
|
| **base_agent_kwargs: A dictionary of parameters for base RL Agent.
|
| Raises:
|
| ValueError: If 'dqda_clipping' is < 0.
|
| """
|
| self._step_cond_fn = step_cond_fn
|
| self._reset_episode_cond_fn = reset_episode_cond_fn
|
| self._reset_env_cond_fn = reset_env_cond_fn
|
| self.metrics = metrics
|
|
|
|
|
| self.tf_context = tf_context(tf_env=tf_env)
|
| self.set_replay = self.tf_context.set_replay
|
| self.sample_contexts = self.tf_context.sample_contexts
|
| self.compute_rewards = self.tf_context.compute_rewards
|
| self.gamma_index = self.tf_context.gamma_index
|
| self.context_specs = self.tf_context.context_specs
|
| self.context_as_action_specs = self.tf_context.context_as_action_specs
|
| self.init_context_vars = self.tf_context.create_vars
|
|
|
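| # The base RL agent acts on a "merged" observation: the environment
|
| # observation concatenated with the current context (e.g. a goal).
|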
| self.env_observation_spec = observation_spec[0]
|
| merged_observation_spec = (uvf_utils.merge_specs(
|
| (self.env_observation_spec,) + self.context_specs),)
|
| self._context_vars = dict()
|
| self._action_vars = dict()
|
|
|
| self.BASE_AGENT_CLASS.__init__(
|
| self,
|
| observation_spec=merged_observation_spec,
|
| action_spec=action_spec,
|
| **base_agent_kwargs
|
| )
|
|
|
| def set_meta_agent(self, agent=None):
|
| self._meta_agent = agent
|
|
|
| @property
|
| def meta_agent(self):
|
| return self._meta_agent
|
|
|
| def actor_loss(self, states, actions, rewards, discounts,
|
| next_states):
|
| """Returns the next action for the state.
|
|
|
| Args:
|
| state: A [num_state_dims] tensor representing a state.
|
| context: A list of [num_context_dims] tensor representing a context.
|
| Returns:
|
| A [num_action_dims] tensor representing the action.
|
| """
|
| return self.BASE_AGENT_CLASS.actor_loss(self, states)
|
|
|
| def action(self, state, context=None):
|
| """Returns the next action for the state.
|
|
|
| Args:
|
| state: A [num_state_dims] tensor representing a state.
|
| context: A list of [num_context_dims] tensors representing the context.
|
| Returns:
|
| A [num_action_dims] tensor representing the action.
|
| """
|
| merged_state = self.merged_state(state, context)
|
| return self.BASE_AGENT_CLASS.action(self, merged_state)
|
|
|
| def actions(self, state, context=None):
|
| """Returns the next action for the state.
|
|
|
| Args:
|
| state: A [-1, num_state_dims] tensor representing a state.
|
| context: A list of [-1, num_context_dims] tensor representing a context.
|
| Returns:
|
| A [-1, num_action_dims] tensor representing the action.
|
| """
|
| merged_states = self.merged_states(state, context)
|
| return self.BASE_AGENT_CLASS.actor_net(self, merged_states)
|
|
|
| def log_probs(self, states, actions, state_reprs, contexts=None):
|
| assert contexts is not None
|
| batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
|
| contexts = self.tf_context.context_multi_transition_fn(
|
| contexts, states=tf.to_float(state_reprs))
|
|
|
| flat_states = tf.reshape(states,
|
| [batch_dims[0] * batch_dims[1], states.shape[-1]])
|
| flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
|
| [batch_dims[0] * batch_dims[1], context.shape[-1]])
|
| for context in contexts]
|
| flat_pred_actions = self.actions(flat_states, flat_contexts)
|
| pred_actions = tf.reshape(flat_pred_actions,
|
| batch_dims + [flat_pred_actions.shape[-1]])
|
|
|
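| # A squared-error surrogate for log-probability: the negative squared
|
| # distance between predicted and observed actions, normalized per
|
| # dimension by half the action range so dimensions contribute comparably.
|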
| error = tf.square(actions - pred_actions)
|
| spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
|
| normalized_error = tf.cast(error, tf.float64) / tf.constant(spec_range) ** 2
|
| return -normalized_error
|
|
|
| @gin.configurable('uvf_add_noise_fn')
|
| def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
|
| clip=True, global_step=None):
|
| """Returns the action_fn with additive Gaussian noise.
|
|
|
| Args:
|
| action_fn: A callable(`state`, `context`) which returns a
|
| [num_action_dims] tensor representing an action.
|
| stddev: Standard deviation of the additive Gaussian noise.
|
| debug: Print debug messages.
|
| clip: Whether to clip noisy actions to the action spec.
|
| global_step: If not None, a step counter used to anneal the noise stddev.
|
| Returns:
|
| A callable with the same signature as action_fn that returns noisy actions.
|
| """
|
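| # Optionally anneal the exploration noise: decay stddev by a factor of
|
| # 0.8 every 1e6 steps, floored at half its initial value.
|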
| if global_step is not None:
|
| stddev *= tf.maximum(
|
| tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)
|
| def noisy_action_fn(state, context=None):
|
| """Noisy action fn."""
|
| action = action_fn(state, context)
|
| if debug:
|
| action = uvf_utils.tf_print(
|
| action, [action],
|
| message='[add_noise_fn] pre-noise action',
|
| first_n=100)
|
| noise_dist = tf.distributions.Normal(tf.zeros_like(action),
|
| tf.ones_like(action) * stddev)
|
| noise = noise_dist.sample()
|
| action += noise
|
| if debug:
|
| action = uvf_utils.tf_print(
|
| action, [action],
|
| message='[add_noise_fn] post-noise action',
|
| first_n=100)
|
| if clip:
|
| action = uvf_utils.clip_to_spec(action, self._action_spec)
|
| return action
|
| return noisy_action_fn
|
|
|
| def merged_state(self, state, context=None):
|
| """Returns the merged state from the environment state and contexts.
|
|
|
| Args:
|
| state: A [num_state_dims] tensor representing a state.
|
| context: A list of [num_context_dims] tensor representing a context.
|
| If None, use the internal context.
|
| Returns:
|
| A [num_merged_state_dims] tensor representing the merged state.
|
| """
|
| if context is None:
|
| context = list(self.context_vars)
|
| state = tf.concat([state,] + context, axis=-1)
|
| self._validate_states(self._batch_state(state))
|
| return state
|
|
|
| def merged_states(self, states, contexts=None):
|
| """Returns the batch merged state from the batch env state and contexts.
|
|
|
| Args:
|
| states: A [batch_size, num_state_dims] tensor representing a batch
|
| of states.
|
| contexts: A list of [batch_size, num_context_dims] tensor
|
| representing a batch of contexts. If None,
|
| use the internal context.
|
| Returns:
|
| A [batch_size, num_merged_state_dims] tensor representing the batch
|
| of merged states.
|
| """
|
| if contexts is None:
|
| contexts = [tf.tile(tf.expand_dims(context, axis=0),
|
| (tf.shape(states)[0], 1)) for
|
| context in self.context_vars]
|
| states = tf.concat([states,] + contexts, axis=-1)
|
| self._validate_states(states)
|
| return states
|
|
|
| def unmerged_states(self, merged_states):
|
| """Returns the batch state and contexts from the batch merged state.
|
|
|
| Args:
|
| merged_states: A [batch_size, num_merged_state_dims] tensor
|
| representing a batch of merged states.
|
| Returns:
|
| A [batch_size, num_state_dims] tensor and a list of
|
| [batch_size, num_context_dims] tensors representing the batch state
|
| and contexts respectively.
|
| """
|
| self._validate_states(merged_states)
|
| num_state_dims = self.env_observation_spec.shape.as_list()[0]
|
| num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
|
| states = merged_states[:, :num_state_dims]
|
| contexts = []
|
| i = num_state_dims
|
| for num_context_dims in num_context_dims_list:
|
| contexts.append(merged_states[:, i: i+num_context_dims])
|
| i += num_context_dims
|
| return states, contexts
|
|
|
| def sample_random_actions(self, batch_size=1):
|
| """Return random actions.
|
|
|
| Args:
|
| batch_size: Batch size.
|
| Returns:
|
| A [batch_size, num_action_dims] tensor representing the batch of actions.
|
| """
|
| actions = tf.concat(
|
| [
|
| tf.random_uniform(
|
| shape=(batch_size, 1),
|
| minval=self._action_spec.minimum[i],
|
| maxval=self._action_spec.maximum[i])
|
| for i in range(self._action_spec.shape[0].value)
|
| ],
|
| axis=1)
|
| return actions
|
|
|
| def clip_actions(self, actions):
|
| """Clip actions to spec.
|
|
|
| Args:
|
| actions: A [batch_size, num_action_dims] tensor representing
|
| the batch of actions.
|
| Returns:
|
| A [batch_size, num_action_dims] tensor representing the batch
|
| of clipped actions.
|
| """
|
| actions = tf.concat(
|
| [
|
| tf.clip_by_value(
|
| actions[:, i:i+1],
|
| self._action_spec.minimum[i],
|
| self._action_spec.maximum[i])
|
| for i in range(self._action_spec.shape[0].value)
|
| ],
|
| axis=1)
|
| return actions
|
|
|
| def mix_contexts(self, contexts, insert_contexts, indices):
|
| """Mix two contexts based on indices.
|
|
|
| Args:
|
| contexts: A list of [batch_size, num_context_dims] tensor representing
|
| the batch of contexts.
|
| insert_contexts: A list of [batch_size, num_context_dims] tensor
|
| representing the batch of contexts to be inserted.
|
| indices: A list of a list of integers denoting indices to replace.
|
| Returns:
|
| A list of resulting contexts.
|
| """
|
| if indices is None: indices = [[]] * len(contexts)
|
| assert len(contexts) == len(indices)
|
| assert all([spec.shape.ndims == 1 for spec in self.context_specs])
|
| mix_contexts = []
|
| for contexts_, insert_contexts_, indices_, spec in zip(
|
| contexts, insert_contexts, indices, self.context_specs):
|
| mix_contexts.append(
|
| tf.concat(
|
| [
|
| insert_contexts_[:, i:i + 1] if i in indices_ else
|
| contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
|
| ],
|
| axis=1))
|
| return mix_contexts
|
|
|
| def begin_episode_ops(self, mode, action_fn=None, state=None):
|
| """Returns ops that reset agent at beginning of episodes.
|
|
|
| Args:
|
| mode: a string representing the mode=[train, explore, eval].
|
| action_fn: A callable returning an action for a state, forwarded to the
|
| context reset.
|
| state: The current state tensor, forwarded to the context reset.
|
| Returns:
|
| A list of ops.
|
| """
|
| all_ops = []
|
| for _, action_var in sorted(self._action_vars.items()):
|
| sample_action = self.sample_random_actions(1)[0]
|
| all_ops.append(tf.assign(action_var, sample_action))
|
| all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
|
| action_fn=action_fn, state=state)
|
| return all_ops
|
|
|
| def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
|
| """Returns op that resets agent at beginning of episodes.
|
|
|
| A new episode is begun if the cond op evaluates to `False`.
|
|
|
| Args:
|
| cond: a Boolean tensor variable.
|
| input_vars: A list of tensor variables.
|
| mode: a string representing the mode=[train, explore, eval].
|
| meta_action_fn: A callable returning the meta-agent's action for a state,
|
| used when stepping or resetting the context.
|
| Returns:
|
| Conditional begin op.
|
| """
|
| (state, action, reward, next_state,
|
| state_repr, next_state_repr) = input_vars
|
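| # When the episode continues, compute the context (low-level) and meta
|
| # rewards for this transition and step the context state forward.
|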
| def continue_fn():
|
| """Continue op fn."""
|
| items = [state, action, reward, next_state,
|
| state_repr, next_state_repr] + list(self.context_vars)
|
| batch_items = [tf.expand_dims(item, 0) for item in items]
|
| (states, actions, rewards, next_states,
|
| state_reprs, next_state_reprs) = batch_items[:6]
|
| context_reward = self.compute_rewards(
|
| mode, state_reprs, actions, rewards, next_state_reprs,
|
| batch_items[6:])[0][0]
|
| context_reward = tf.cast(context_reward, dtype=reward.dtype)
|
| if self.meta_agent is not None:
|
| meta_action = tf.concat(self.context_vars, -1)
|
| items = [state, meta_action, reward, next_state,
|
| state_repr, next_state_repr] + list(self.meta_agent.context_vars)
|
| batch_items = [tf.expand_dims(item, 0) for item in items]
|
| (states, meta_actions, rewards, next_states,
|
| state_reprs, next_state_reprs) = batch_items[:6]
|
| meta_reward = self.meta_agent.compute_rewards(
|
| mode, states, meta_actions, rewards,
|
| next_states, batch_items[6:])[0][0]
|
| meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
|
| else:
|
| meta_reward = tf.constant(0, dtype=reward.dtype)
|
|
|
| with tf.control_dependencies([context_reward, meta_reward]):
|
| step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
|
| state=state,
|
| next_state=next_state,
|
| state_repr=state_repr,
|
| next_state_repr=next_state_repr,
|
| action_fn=meta_action_fn)
|
| with tf.control_dependencies(step_ops):
|
| context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward])
|
| return context_reward, meta_reward
|
| def begin_episode_fn():
|
| """Begin op fn."""
|
| begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state)
|
| with tf.control_dependencies(begin_ops):
|
| return tf.zeros_like(reward), tf.zeros_like(reward)
|
| with tf.control_dependencies(input_vars):
|
| cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
|
| return cond_begin_episode_op
|
|
|
| def get_env_base_wrapper(self, env_base, **begin_kwargs):
|
| """Create a wrapper around env_base, with agent-specific begin/end_episode.
|
|
|
| Args:
|
| env_base: A python environment base.
|
| **begin_kwargs: Keyword args for begin_episode_ops.
|
| Returns:
|
| An object with begin_episode() and end_episode().
|
| """
|
| begin_ops = self.begin_episode_ops(**begin_kwargs)
|
| return uvf_utils.get_contextual_env_base(env_base, begin_ops)
|
|
|
| def init_action_vars(self, name, i=None):
|
| """Create and return a tensorflow Variable holding an action.
|
|
|
| Args:
|
| name: Name of the variables.
|
| i: Integer id.
|
| Returns:
|
| A [num_action_dims] tensor.
|
| """
|
| if i is not None:
|
| name += '_%d' % i
|
| assert name not in self._action_vars, ('Conflict! %s is already '
|
| 'initialized.') % name
|
| self._action_vars[name] = tf.Variable(
|
| self.sample_random_actions(1)[0], name='%s_action' % (name))
|
| self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
|
| return self._action_vars[name]
|
|
|
| @gin.configurable('uvf_critic_function')
|
| def critic_function(self, critic_vals, states, critic_fn=None):
|
| """Computes q values based on outputs from the critic net.
|
|
|
| Args:
|
| critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
|
| from the critic net.
|
| states: A [batch_size, num_state_dims] tensor representing a batch
|
| of states.
|
| critic_fn: A callable that processes outputs from critic_net and
|
| outputs a [batch_size] tensor representing q values.
|
| Returns:
|
| A tf.float32 [batch_size] tensor representing q values.
|
| """
|
| if critic_fn is not None:
|
| env_states, contexts = self.unmerged_states(states)
|
| critic_vals = critic_fn(critic_vals, env_states, contexts)
|
| critic_vals.shape.assert_has_rank(1)
|
| return critic_vals
|
|
|
| def get_action_vars(self, key):
|
| return self._action_vars[key]
|
|
|
| def get_context_vars(self, key):
|
| return self.tf_context.context_vars[key]
|
|
|
| def step_cond_fn(self, *args):
|
| return self._step_cond_fn(self, *args)
|
|
|
| def reset_episode_cond_fn(self, *args):
|
| return self._reset_episode_cond_fn(self, *args)
|
|
|
| def reset_env_cond_fn(self, *args):
|
| return self._reset_env_cond_fn(self, *args)
|
|
|
| @property
|
| def context_vars(self):
|
| return self.tf_context.vars
|
|
|
|
|
| @gin.configurable
|
| class MetaAgentCore(UvfAgentCore):
|
| """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.
|
|
|
| Used as higher-level agent.
|
| """
|
|
|
| def __init__(self,
|
| observation_spec,
|
| action_spec,
|
| tf_env,
|
| tf_context,
|
| sub_context,
|
| step_cond_fn=cond_fn.env_transition,
|
| reset_episode_cond_fn=cond_fn.env_restart,
|
| reset_env_cond_fn=cond_fn.false_fn,
|
| metrics=None,
|
| actions_reg=0.,
|
| k=2,
|
| **base_agent_kwargs):
|
| """Constructs a Meta agent.
|
|
|
| Args:
|
| observation_spec: A TensorSpec defining the observations.
|
| action_spec: A BoundedTensorSpec defining the actions.
|
| tf_env: A Tensorflow environment object.
|
| tf_context: A Context class for the meta-agent.
|
| sub_context: A Context class for the lower-level agent; its
|
| context-as-action spec defines this agent's action space.
|
| step_cond_fn: A function indicating whether to increment the number of steps.
|
| reset_episode_cond_fn: A function indicating whether to restart the
|
| episode, resampling the context.
|
| reset_env_cond_fn: A function indicating whether to perform a manual reset
|
| of the environment.
|
| metrics: A list of functions that evaluate metrics of the agent.
|
| actions_reg: Coefficient of the L1 regularizer applied to meta-action
|
| dimensions beyond the first `k`.
|
| k: Number of leading meta-action dimensions exempt from the regularizer.
|
| **base_agent_kwargs: A dictionary of parameters for base RL Agent.
|
| Raises:
|
| ValueError: If 'dqda_clipping' is < 0.
|
| """
|
| self._step_cond_fn = step_cond_fn
|
| self._reset_episode_cond_fn = reset_episode_cond_fn
|
| self._reset_env_cond_fn = reset_env_cond_fn
|
| self.metrics = metrics
|
| self._actions_reg = actions_reg
|
| self._k = k
|
|
|
|
|
| self.tf_context = tf_context(tf_env=tf_env)
|
| self.sub_context = sub_context(tf_env=tf_env)
|
| self.set_replay = self.tf_context.set_replay
|
| self.sample_contexts = self.tf_context.sample_contexts
|
| self.compute_rewards = self.tf_context.compute_rewards
|
| self.gamma_index = self.tf_context.gamma_index
|
| self.context_specs = self.tf_context.context_specs
|
| self.context_as_action_specs = self.tf_context.context_as_action_specs
|
| self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
|
| self.init_context_vars = self.tf_context.create_vars
|
|
|
| self.env_observation_spec = observation_spec[0]
|
| merged_observation_spec = (uvf_utils.merge_specs(
|
| (self.env_observation_spec,) + self.context_specs),)
|
| self._context_vars = dict()
|
| self._action_vars = dict()
|
|
|
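| # The meta-agent's action is the context (e.g. goal) it hands to the
|
| # lower-level agent, so its action spec is the sub-agent's
|
| # context-as-action spec.
|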
| assert len(self.context_as_action_specs) == 1
|
| self.BASE_AGENT_CLASS.__init__(
|
| self,
|
| observation_spec=merged_observation_spec,
|
| action_spec=self.sub_context_as_action_specs,
|
| **base_agent_kwargs
|
| )
|
|
|
| @gin.configurable('meta_add_noise_fn')
|
| def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
|
| global_step=None):
|
| noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
|
| action_fn, stddev,
|
| clip=True, global_step=global_step)
|
| return noisy_action_fn
|
|
|
| def actor_loss(self, states, actions, rewards, discounts,
|
| next_states):
|
| """Returns the next action for the state.
|
|
|
| Args:
|
| state: A [num_state_dims] tensor representing a state.
|
| context: A list of [num_context_dims] tensor representing a context.
|
| Returns:
|
| A [num_action_dims] tensor representing the action.
|
| """
|
| actions = self.actor_net(states, stop_gradients=False)
|
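| # L1-penalize meta-action dimensions beyond the first k, encouraging the
|
| # meta-agent to express goals using only the leading k dimensions.
|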
| regularizer = self._actions_reg * tf.reduce_mean(
|
| tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
|
| loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
|
| return regularizer + loss
|
|
|
|
|
| @gin.configurable
|
| class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
|
| """A DDPG agent with UVF.
|
| """
|
| BASE_AGENT_CLASS = ddpg_agent.TD3Agent
|
| ACTION_TYPE = 'continuous'
|
|
|
| def __init__(self, *args, **kwargs):
|
| UvfAgentCore.__init__(self, *args, **kwargs)
|
|
|
|
|
| @gin.configurable
|
| class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
|
| """A DDPG meta-agent.
|
| """
|
| BASE_AGENT_CLASS = ddpg_agent.TD3Agent
|
| ACTION_TYPE = 'continuous'
|
|
|
| def __init__(self, *args, **kwargs):
|
| MetaAgentCore.__init__(self, *args, **kwargs)
|
|
|
|
|
| @gin.configurable()
|
| def state_preprocess_net(
|
| states,
|
| num_output_dims=2,
|
| states_hidden_layers=(100,),
|
| normalizer_fn=None,
|
| activation_fn=tf.nn.relu,
|
| zero_time=True,
|
| images=False):
|
| """Creates a simple feed forward net for embedding states.
|
| """
|
| with slim.arg_scope(
|
| [slim.fully_connected],
|
| activation_fn=activation_fn,
|
| normalizer_fn=normalizer_fn,
|
| weights_initializer=slim.variance_scaling_initializer(
|
| factor=1.0/3.0, mode='FAN_IN', uniform=True)):
|
|
|
| states_shape = tf.shape(states)
|
| states_dtype = states.dtype
|
| states = tf.to_float(states)
|
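| # Mask state dimensions before embedding: with images, zero the first two
|
| # dimensions (presumably absolute x, y position); with zero_time, drop
|
| # the trailing time feature so embeddings are time-invariant.
|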
| if images:
|
| states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
|
| if zero_time:
|
| states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
|
| orig_states = states
|
| embed = states
|
| if states_hidden_layers:
|
| embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
|
| scope='states')
|
|
|
| with slim.arg_scope([slim.fully_connected],
|
| weights_regularizer=None,
|
| weights_initializer=tf.random_uniform_initializer(
|
| minval=-0.003, maxval=0.003)):
|
| embed = slim.fully_connected(embed, num_output_dims,
|
| activation_fn=None,
|
| normalizer_fn=None,
|
| scope='value')
|
|
|
| output = embed
|
| output = tf.cast(output, states_dtype)
|
| return output
|
|
|
|
|
| @gin.configurable()
|
| def action_embed_net(
|
| actions,
|
| states=None,
|
| num_output_dims=2,
|
| hidden_layers=(400, 300),
|
| normalizer_fn=None,
|
| activation_fn=tf.nn.relu,
|
| zero_time=True,
|
| images=False):
|
| """Creates a simple feed forward net for embedding actions.
|
| """
|
| with slim.arg_scope(
|
| [slim.fully_connected],
|
| activation_fn=activation_fn,
|
| normalizer_fn=normalizer_fn,
|
| weights_initializer=slim.variance_scaling_initializer(
|
| factor=1.0/3.0, mode='FAN_IN', uniform=True)):
|
|
|
| actions = tf.to_float(actions)
|
| if states is not None:
|
| if images:
|
| states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
|
| if zero_time:
|
| states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
|
| actions = tf.concat([actions, tf.to_float(states)], -1)
|
|
|
| embed = actions
|
| if hidden_layers:
|
| embed = slim.stack(embed, slim.fully_connected, hidden_layers,
|
| scope='hidden')
|
|
|
| with slim.arg_scope([slim.fully_connected],
|
| weights_regularizer=None,
|
| weights_initializer=tf.random_uniform_initializer(
|
| minval=-0.003, maxval=0.003)):
|
| embed = slim.fully_connected(embed, num_output_dims,
|
| activation_fn=None,
|
| normalizer_fn=None,
|
| scope='value')
|
| if num_output_dims == 1:
|
| return embed[:, 0, ...]
|
| else:
|
| return embed
|
|
|
|
|
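| # Scaled Huber penalty: quadratic for |x| <= kappa, linear beyond;
|
| # dividing by kappa gives the linear branch unit slope.
|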
| def huber(x, kappa=0.1):
|
| return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) +
|
| kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa)
|
| ) / kappa
|
|
|
|
|
| @gin.configurable()
|
| class StatePreprocess(object):
|
| STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
|
| ACTION_EMBED_NET_SCOPE = 'action_embed_net'
|
|
|
| def __init__(self, trainable=False,
|
| state_preprocess_net=lambda states: states,
|
| action_embed_net=lambda actions, *args, **kwargs: actions,
|
| ndims=None):
|
| self.trainable = trainable
|
| self._scope = tf.get_variable_scope().name
|
| self._ndims = ndims
|
| self._state_preprocess_net = tf.make_template(
|
| self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
|
| create_scope_now_=True)
|
| self._action_embed_net = tf.make_template(
|
| self.ACTION_EMBED_NET_SCOPE, action_embed_net,
|
| create_scope_now_=True)
|
|
|
| def __call__(self, states):
|
| batched = states.get_shape().ndims != 1
|
| if not batched:
|
| states = tf.expand_dims(states, 0)
|
| embedded = self._state_preprocess_net(states)
|
| if self._ndims is not None:
|
| embedded = embedded[..., :self._ndims]
|
| if not batched:
|
| return embedded[0]
|
| return embedded
|
|
|
| def loss(self, states, next_states, low_actions, low_states):
|
| batch_size = tf.shape(states)[0]
|
| d = int(low_states.shape[1])
|
|
|
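| # Per batch element, sample an index into the stored low-level state
|
| # sequence with probability proportional to 0.99**i; the final index also
|
| # absorbs the remaining geometric tail mass (the 1 / (1 - 0.99) factor).
|
| # Below, the state one step past the sampled index replaces next_states,
|
| # except when the final index is drawn (the true next state is kept).
|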
| probs = 0.99 ** tf.range(d, dtype=tf.float32)
|
| probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
|
| dtype=tf.float32)
|
| probs /= tf.reduce_sum(probs)
|
| index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
|
| indices = index_dist.sample(batch_size)
|
| batch_size = tf.cast(batch_size, tf.int64)
|
| next_indices = tf.concat(
|
| [tf.range(batch_size, dtype=tf.int64)[:, None],
|
| (1 + indices[:, None]) % d], -1)
|
| new_next_states = tf.where(indices < d - 1,
|
| tf.gather_nd(low_states, next_indices),
|
| next_states)
|
| next_states = new_next_states
|
|
|
| embed1 = tf.to_float(self._state_preprocess_net(states))
|
| embed2 = tf.to_float(self._state_preprocess_net(next_states))
|
| action_embed = self._action_embed_net(
|
| tf.layers.flatten(low_actions), states=states)
|
|
|
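| # Contrastive representation objective: the 'close' term pulls
|
| # embed(s) + embed(a) toward embed(s'), while the 'far' term pushes apart
|
| # mismatched pairs (batch shifted by one), normalized by prior log-probs
|
| # estimated via logsumexp against a persistent buffer (all_embed) of
|
| # recent next-state embeddings.
|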
| tau = 2.0
|
| fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
|
| all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
|
| initializer=tf.zeros_initializer())
|
| upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
|
| with tf.control_dependencies([upd]):
|
| close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
|
| prior_log_probs = tf.reduce_logsumexp(
|
| -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
|
| axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
|
| far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
|
| - tf.stop_gradient(prior_log_probs[1:])))
|
| repr_log_probs = tf.stop_gradient(
|
| -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
|
| return close + far, repr_log_probs, indices
|
|
|
| def get_trainable_vars(self):
|
| return (
|
| slim.get_trainable_variables(
|
| uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) +
|
| slim.get_trainable_variables(
|
| uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))
|
|
|
|
|
| @gin.configurable()
|
| class InverseDynamics(object):
|
| INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'
|
|
|
| def __init__(self, spec):
|
| self._spec = spec
|
|
|
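| # Samples candidate goals for off-policy relabeling: candidates come from
|
| # a Gaussian centered on the observed state change (in goal space); the
|
| # mean and the originally pursued goal are appended and all are clipped
|
| # to the goal spec (a single sample is returned unclipped).
|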
| def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
|
| goal_dim = orig_goals.shape[-1]
|
| spec_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
|
| loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
|
| scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]),
|
| [tf.shape(states)[0], 1])
|
| dist = tf.distributions.Normal(loc, scale)
|
| if num_samples == 1:
|
| return dist.sample()
|
| samples = tf.concat([dist.sample(num_samples - 2),
|
| tf.expand_dims(loc, 0),
|
| tf.expand_dims(orig_goals, 0)], 0)
|
| return uvf_utils.clip_to_spec(samples, self._spec)
|
|
|