|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """TensorFlow utility functions.
|
| """
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| from copy import deepcopy
|
| import tensorflow as tf
|
| from tf_agents import specs
|
| from tf_agents.utils import common
|
|
|
# Module-level bookkeeping for tf_print. Each tf_print call site is keyed by
# a fresh integer id (from _tf_print_ids):
#   _tf_print_counts[id]         -> total number of times the printer fired
#   _tf_print_running_sums[id]   -> per-tensor running sums (print_freq mode)
#   _tf_print_running_counts[id] -> calls accumulated in the current window
_tf_print_counts = dict()
_tf_print_running_sums = dict()
_tf_print_running_counts = dict()
_tf_print_ids = 0
|
|
|
|
|
def get_contextual_env_base(env_base, begin_ops=None, end_ops=None):
  """Wrap env_base with additional tf ops.

  Creates a dynamic subclass of env_base's class whose begin_episode /
  end_episode reset the wrapped environment and additionally run the
  given tf ops in the session supplied via set_sess.

  Args:
    env_base: An environment instance to wrap.
    begin_ops: Optional tf ops to run at the beginning of each episode.
    end_ops: Optional tf ops to run at the end of each episode.
  Returns:
    A new environment instance of a dynamically created subclass of
    env_base's class.
  """

  def init(self_, env_base):
    # NOTE(review): the base-class __init__ is deliberately not called;
    # state is copied over from the already-constructed env_base instead.
    self_._env_base = env_base
    attribute_list = ["_render_mode", "_gym_env"]
    for attribute in attribute_list:
      if hasattr(env_base, attribute):
        setattr(self_, attribute, getattr(env_base, attribute))
    if hasattr(env_base, "physics"):
      self_._physics = env_base.physics
    elif hasattr(env_base, "gym"):
      # Gym-backed env: synthesize a minimal physics-like object so callers
      # that expect a .physics.render() API can still grab RGB frames.
      class Physics(object):
        def render(self, *args, **kwargs):
          return env_base.gym.render("rgb_array")
      physics = Physics()
      self_._physics = physics
      self_.physics = physics

  def set_sess(self_, sess):
    # Remember the session used to run begin_ops/end_ops, and forward it to
    # the wrapped env if it also accepts a session.
    self_._sess = sess
    if hasattr(self_._env_base, "set_sess"):
      self_._env_base.set_sess(sess)

  def begin_episode(self_):
    self_._env_base.reset()
    if begin_ops is not None:
      self_._sess.run(begin_ops)

  def end_episode(self_):
    self_._env_base.reset()
    if end_ops is not None:
      self_._sess.run(end_ops)

  # Build the subclass on the fly so the wrapper inherits everything else
  # from env_base's own class.
  return type("ContextualEnvBase", (env_base.__class__,), dict(
      __init__=init,
      set_sess=set_sess,
      begin_episode=begin_episode,
      end_episode=end_episode,
  ))(env_base)
|
|
|
|
|
|
|
def merge_specs(specs_):
  """Merge TensorSpecs.

  Args:
    specs_: List of TensorSpecs to be merged.
  Returns:
    a TensorSpec: a merged TensorSpec.
  """
  first = specs_[0]
  merged_shape = first.shape
  for other in specs_[1:]:
    # All specs must agree on dtype and on every dimension past the first.
    assert merged_shape[1:] == other.shape[1:], "incompatible shapes: %s, %s" % (
        merged_shape, other.shape)
    assert first.dtype == other.dtype, "incompatible dtypes: %s, %s" % (
        first.dtype, other.dtype)
    merged_shape = merge_shapes((merged_shape, other.shape), axis=0)
  return specs.TensorSpec(
      shape=merged_shape,
      dtype=first.dtype,
      name=first.name,
  )
|
|
|
|
|
def merge_shapes(shapes, axis=0):
  """Merge TensorShapes along `axis`.

  Args:
    shapes: Non-empty list of TensorShapes to be merged. All shapes must
      have the same rank, and the dimension at `axis` must be statically
      known in each shape.
    axis: optional, the axis to merge shaped.
  Returns:
    a TensorShape: a merged TensorShape whose `axis` dimension is the sum
    of the inputs' `axis` dimensions. A single-element list returns a copy
    of that shape (previously this case was rejected).
  """
  assert shapes, "shapes must be a non-empty list"
  # Copy so we never mutate the caller's TensorShape dims in place.
  dims = deepcopy(shapes[0].dims)
  for shape in shapes[1:]:
    assert shapes[0].ndims == shape.ndims, "incompatible ranks: %s, %s" % (
        shapes[0], shape)
    dims[axis] += shape.dims[axis]
  return tf.TensorShape(dims=dims)
|
|
|
|
|
def get_all_vars(ignore_scopes=None):
  """Get all tf variables in scope.

  Args:
    ignore_scopes: A list of scope names to ignore.
  Returns:
    A list of all tf variables in scope.
  """
  candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  if ignore_scopes is None:
    return candidates
  # Drop any variable whose name falls under one of the ignored scopes.
  return [
      var for var in candidates
      if not any(var.name.startswith(scope) for scope in ignore_scopes)
  ]
|
|
|
|
|
def clip(tensor, range_=None):
  """Return a tf op which clips tensor according to range_.

  Args:
    tensor: A Tensor to be clipped.
    range_: None, or a tuple representing (minval, maxval)
  Returns:
    A clipped Tensor.
  """
  if range_ is None:
    # No range given: pass the tensor through unchanged.
    return tf.identity(tensor)
  if isinstance(range_, (tuple, list)):
    assert len(range_) == 2
    minval, maxval = range_
    return tf.clip_by_value(tensor, minval, maxval)
  raise NotImplementedError("Unacceptable range input: %r" % range_)
|
|
|
|
|
def clip_to_bounds(value, minimum, maximum):
  """Clips value to be between minimum and maximum.

  Args:
    value: (tensor) value to be clipped.
    minimum: (numpy float array) minimum value to clip to.
    maximum: (numpy float array) maximum value to clip to.
  Returns:
    clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`.
  """
  # Cap at the upper bound first, then raise to the lower bound.
  capped = tf.minimum(value, maximum)
  return tf.maximum(capped, minimum)
|
|
|
|
|
# Public name resolves to the shared tf_agents implementation; the local
# _clip_to_spec definition is shadowed and kept only for reference.
clip_to_spec = common.clip_to_spec
|
def _clip_to_spec(value, spec):
  """Clips value to a given bounded tensor spec.

  Args:
    value: (tensor) value to be clipped.
    spec: (BoundedTensorSpec) spec containing min. and max. values for clipping.
  Returns:
    clipped_value: (tensor) `value` clipped to be compatible with `spec`.
  """
  lower = spec.minimum
  upper = spec.maximum
  return clip_to_bounds(value, lower, upper)
|
|
|
|
|
# Public name resolves to the shared tf_agents implementation; the local
# _join_scope definition is shadowed and kept only for reference.
join_scope = common.join_scope
|
| def _join_scope(parent_scope, child_scope):
|
| """Joins a parent and child scope using `/`, checking for empty/none.
|
|
|
| Args:
|
| parent_scope: (string) parent/prefix scope.
|
| child_scope: (string) child/suffix scope.
|
| Returns:
|
| joined scope: (string) parent and child scopes joined by /.
|
| """
|
| if not parent_scope:
|
| return child_scope
|
| if not child_scope:
|
| return parent_scope
|
| return '/'.join([parent_scope, child_scope])
|
|
|
|
|
def assign_vars(vars_, values):
  """Returns the update ops for assigning a list of vars.

  Args:
    vars_: A list of variables.
    values: A list of tensors representing new values.
  Returns:
    A list of update ops for the variables.
  """
  updates = []
  for var, new_value in zip(vars_, values):
    updates.append(var.assign(new_value))
  return updates
|
|
|
|
|
def identity_vars(vars_):
  """Return the identity ops for a list of tensors.

  Args:
    vars_: A list of tensors.
  Returns:
    A list of identity ops.
  """
  return list(map(tf.identity, vars_))
|
|
|
|
|
def tile(var, batch_size=1):
  """Return tiled tensor.

  Args:
    var: A tensor representing the state.
    batch_size: Batch size.
  Returns:
    A tensor with shape [batch_size,] + var.shape.
  """
  # Prepend a unit axis, then replicate only along that new axis.
  rank = var.get_shape().ndims
  multiples = [batch_size] + [1] * rank
  return tf.tile(tf.expand_dims(var, 0), multiples)
|
|
|
|
|
def batch_list(vars_list):
  """Batch a list of variables.

  Args:
    vars_list: A list of tensor variables.
  Returns:
    A list of tensor variables with additional first dimension.
  """
  batched = []
  for var in vars_list:
    batched.append(tf.expand_dims(var, 0))
  return batched
|
|
|
|
|
def tf_print(op,
             tensors,
             message="",
             first_n=-1,
             name=None,
             sub_messages=None,
             print_freq=-1,
             include_count=True):
  """tf.Print, but to stdout.

  Attaches a py_func that logs `tensors` (via tf.logging.info) as a control
  dependency of `op`.

  Args:
    op: Op to attach the logging to; a tf.identity of it is returned.
    tensors: List of tensors whose values are logged.
    message: Prefix string for each logged line.
    first_n: If >= 0, stop logging after this many invocations.
    name: Unused; overwritten by an internal unique id (see NOTE below).
    sub_messages: Optional per-tensor labels; defaults to the tensor index.
    print_freq: If > 0, log running averages once every `print_freq` calls
      instead of raw values on every call.
    include_count: Whether to append the cumulative invocation count.
  Returns:
    `op` wrapped in tf.identity with the logging op as a control dependency.
  """

  global _tf_print_ids
  # NOTE(review): the `name` argument is deliberately discarded — each call
  # site gets a fresh integer key into the module-level bookkeeping dicts.
  _tf_print_ids += 1
  name = _tf_print_ids
  _tf_print_counts[name] = 0
  if print_freq > 0:
    # One running sum per tensor, plus a call counter for the window.
    _tf_print_running_sums[name] = [0 for _ in tensors]
    _tf_print_running_counts[name] = 0
  def print_message(*xs):
    """print message fn."""
    _tf_print_counts[name] += 1
    if print_freq > 0:
      # Accumulate this call's values into the current averaging window.
      for i, x in enumerate(xs):
        _tf_print_running_sums[name][i] += x
      _tf_print_running_counts[name] += 1
    # Log only when the window is full (or every call if print_freq <= 0),
    # and only while under the first_n limit (if set).
    if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and (
        first_n < 0 or _tf_print_counts[name] <= first_n):
      for i, x in enumerate(xs):
        if print_freq > 0:
          # Report the window average instead of the raw value.
          del x
          x = _tf_print_running_sums[name][i]/_tf_print_running_counts[name]
        if sub_messages is None:
          sub_message = str(i)
        else:
          sub_message = sub_messages[i]
        log_message = "%s, %s" % (message, sub_message)
        if include_count:
          log_message += ", count=%d" % _tf_print_counts[name]
        tf.logging.info("[%s]: %s" % (log_message, x))
      if print_freq > 0:
        # Reset the window after reporting.
        for i, x in enumerate(xs):
          _tf_print_running_sums[name][i] = 0
        _tf_print_running_counts[name] = 0
    # py_func must return something with tensors[0]'s dtype.
    return xs[0]

  print_op = tf.py_func(print_message, tensors, tensors[0].dtype)
  with tf.control_dependencies([print_op]):
    op = tf.identity(op)
  return op
|
|
|
|
|
# Public name resolves to the shared tf_agents implementation; the local
# _periodically definition is shadowed and kept only for reference.
periodically = common.periodically
|
def _periodically(body, period, name='periodically'):
  """Periodically performs a tensorflow op.

  Args:
    body: Zero-argument callable returning the op to run periodically.
    period: Run `body` once every `period` calls of the returned op.
      None or 0 disables it; 1 runs it unconditionally.
    name: Default name for the variable scope holding the counter.
  Returns:
    A tf op that executes `body` every `period` invocations.
  Raises:
    ValueError: If `period` is negative.
  """
  # Check None/0 before comparing, since `None < 0` is invalid.
  if period is None or period == 0:
    return tf.no_op()
  if period < 0:
    raise ValueError("period cannot be less than 0.")
  if period == 1:
    return body()

  with tf.variable_scope(None, default_name=name):
    # Start the counter at `period` so the body fires on the first call.
    counter = tf.get_variable(
        "counter",
        shape=[],
        dtype=tf.int64,
        trainable=False,
        initializer=tf.constant_initializer(period, dtype=tf.int64))

    def _run_and_reset():
      # Execute the body, then reset the counter back to 1.
      with tf.control_dependencies([body()]):
        return counter.assign(1)

    return tf.cond(
        tf.equal(counter, period), _run_and_reset,
        lambda: counter.assign_add(1))
|
|
|
# Re-exported from tf_agents.utils.common for use by this module's callers.
soft_variables_update = common.soft_variables_update
|
|
|