File size: 2,727 Bytes
233f6d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
""" Unused for the moment, I believe. """

import tensorflow as tf

def get_shape(tensor):
    """Return the shape of `tensor` as a list, preferring static sizes.

    Every entry of the result is either:
    - an `int`, when that dimension is known statically, or
    - a scalar `tf.Tensor` sliced from `tf.shape(tensor)`, when it is not.

    Args:
        tensor: A `tf.Tensor` to get the shape of.

    Returns:
        A `list` with one entry per dimension of `tensor`.
    """
    static_dims = tensor.shape.as_list()

    # Only fall back to the dynamic shape op when some dimension is unknown.
    if any(dim is None for dim in static_dims):
        dynamic_dims = tf.shape(tensor)
        return [
            dim if dim is not None else dynamic_dims[index]
            for index, dim in enumerate(static_dims)
        ]

    return static_dims


def repeat(tensor, repeats, axis=0):
    """Repeats a `tf.Tensor`'s elements along an axis by custom amounts.

    Equivalent to Numpy's `np.repeat` with a 1D `repeats` argument.
    `tensor` and `repeats` must have the same number of elements along `axis`.

    Args:
        tensor: A `tf.Tensor` to repeat. Its static rank must be known,
            because the static shape is read via `tensor.shape.as_list()`.
        repeats: A 1D sequence of the number of repeats per element.
        axis: An `int` axis to repeat along. Negative values count from the
            last axis. Defaults to 0.

    Returns:
        The `tf.Tensor` with repeated values. Its static size along `axis`
        is left unknown (`None`).
    """
    static_shape = tensor.shape.as_list()

    # Resolve a negative Python-int axis up front: the transpose helpers build
    # their permutation from ranges bounded by `axis`, which only yields a
    # valid permutation for non-negative values.
    if isinstance(axis, int) and axis < 0:
        axis += len(static_shape)

    cumsum = tf.cumsum(repeats)
    range_ = tf.range(cumsum[-1])

    # Output position p is taken from input element i, where i is the number
    # of cumulative-repeat boundaries that are <= p.
    indicator_matrix = tf.cast(tf.expand_dims(range_, 1) >= cumsum, tf.int32)
    # `axis=` replaces the deprecated `reduction_indices=` keyword, which
    # TensorFlow 2.x no longer accepts.
    indices = tf.reduce_sum(indicator_matrix, axis=1)

    # Bring `axis` to position 0 so the gather runs along it, then move the
    # gathered axis back to where it came from.
    shifted_tensor = _axis_to_inside(tensor, axis)
    repeated_shifted_tensor = tf.gather(shifted_tensor, indices)
    repeated_tensor = _inside_to_axis(repeated_shifted_tensor, axis)

    # The repeated dimension's size depends on the values in `repeats`, so
    # mark it unknown while keeping the other static sizes.
    static_shape[axis] = None
    repeated_tensor.set_shape(static_shape)

    return repeated_tensor


def _axis_to_inside(tensor, axis):
    """Transposes `tensor` so that `axis` becomes axis 0.

    The remaining axes keep their relative order. Despite the "inside" in the
    name, the permutation built here places `axis` at position 0.

    Args:
        tensor: A `tf.Tensor` to shift.
        axis: An `int` or `tf.Tensor` that indicates which axis to shift.

    Returns:
        The transposed tensor.
    """
    axis = tf.convert_to_tensor(axis)
    num_axes = tf.rank(tensor)

    # Permutation: [axis, 0, ..., axis - 1, axis + 1, ..., rank - 1].
    axes_before = tf.range(0, limit=axis)
    axes_after = tf.range(axis + 1, limit=num_axes)
    permutation = tf.concat([[axis], axes_before, axes_after], 0)

    return tf.transpose(tensor, perm=permutation)


def _inside_to_axis(tensor, axis):
    """Transposes `tensor` so that its axis 0 ends up at position `axis`.

    The remaining axes keep their relative order. Despite the "inside" in the
    name, the axis being moved is axis 0.

    Args:
        tensor: A `tf.Tensor` to shift.
        axis: An `int` or `tf.Tensor` giving the destination position of
            axis 0.

    Returns:
        The transposed tensor.
    """
    axis = tf.convert_to_tensor(axis)
    num_axes = tf.rank(tensor)
    next_axis = axis + 1

    # Permutation: [1, ..., axis, 0, axis + 1, ..., rank - 1].
    axes_pulled_forward = tf.range(1, limit=next_axis)
    axes_untouched = tf.range(next_axis, limit=num_axes)
    permutation = tf.concat([axes_pulled_forward, [0], axes_untouched], 0)

    return tf.transpose(tensor, perm=permutation)