| """ Unused for the moment, I believe. """ | |
| import tensorflow as tf | |
| def get_shape(tensor): | |
| """Returns the tensor's shape. | |
| Each shape element is either: | |
| - an `int`, when static shape values are available, or | |
| - a `tf.Tensor`, when the shape is dynamic. | |
| Args: | |
| tensor: A `tf.Tensor` to get the shape of. | |
| Returns: | |
| The `list` which contains the tensor's shape. | |
| """ | |
| shape_list = tensor.shape.as_list() | |
| if all(s is not None for s in shape_list): | |
| return shape_list | |
| shape_tensor = tf.shape(tensor) | |
| return [shape_tensor[i] if s is None else s for i, s in | |
| enumerate(shape_list)] | |
| def repeat(tensor, repeats, axis=0): | |
| """Repeats a `tf.Tensor`'s elements along an axis by custom amounts. | |
| Equivalent to Numpy's `np.repeat`. | |
| `tensor and `repeats` must have the same numbers of elements along `axis`. | |
| Args: | |
| tensor: A `tf.Tensor` to repeat. | |
| repeats: A 1D sequence of the number of repeats per element. | |
| axis: An axis to repeat along. Defaults to 0. | |
| name: (string, optional) A name for the operation. | |
| Returns: | |
| The `tf.Tensor` with repeated values. | |
| """ | |
| cumsum = tf.cumsum(repeats) | |
| range_ = tf.range(cumsum[-1]) | |
| indicator_matrix = tf.cast(tf.expand_dims(range_, 1) >= cumsum, tf.int32) | |
| indices = tf.reduce_sum(indicator_matrix, reduction_indices=1) | |
| shifted_tensor = _axis_to_inside(tensor, axis) | |
| repeated_shifted_tensor = tf.gather(shifted_tensor, indices) | |
| repeated_tensor = _inside_to_axis(repeated_shifted_tensor, axis) | |
| shape = tensor.shape.as_list() | |
| shape[axis] = None | |
| repeated_tensor.set_shape(shape) | |
| return repeated_tensor | |
| def _axis_to_inside(tensor, axis): | |
| """Shifts a given axis of a tensor to be the innermost axis. | |
| Args: | |
| tensor: A `tf.Tensor` to shift. | |
| axis: An `int` or `tf.Tensor` that indicates which axis to shift. | |
| Returns: | |
| The shifted tensor. | |
| """ | |
| axis = tf.convert_to_tensor(axis) | |
| rank = tf.rank(tensor) | |
| range0 = tf.range(0, limit=axis) | |
| range1 = tf.range(tf.add(axis, 1), limit=rank) | |
| perm = tf.concat([[axis], range0, range1], 0) | |
| return tf.transpose(tensor, perm=perm) | |
| def _inside_to_axis(tensor, axis): | |
| """Shifts the innermost axis of a tensor to some other axis. | |
| Args: | |
| tensor: A `tf.Tensor` to shift. | |
| axis: An `int` or `tf.Tensor` that indicates which axis to shift. | |
| Returns: | |
| The shifted tensor. | |
| """ | |
| axis = tf.convert_to_tensor(axis) | |
| rank = tf.rank(tensor) | |
| range0 = tf.range(1, limit=axis + 1) | |
| range1 = tf.range(tf.add(axis, 1), limit=rank) | |
| perm = tf.concat([range0, [0], range1], 0) | |
| return tf.transpose(tensor, perm=perm) | |