| """Utils for MAE."""
|
|
|
| import math
|
| import tensorflow as tf, tf_keras
|
| from official.modeling import tf_utils
|
|
|
|
|
|
|
def position_embedding_sine(attention_mask,
                            num_pos_features=256,
                            temperature=10000.,
                            normalize=True,
                            scale=2 * math.pi):
  """Sine-based positional embeddings for 2D images.

  Args:
    attention_mask: a `bool` Tensor of shape [batch_size, height, width]
      specifying the size of the input image to the Transformer and which
      elements are padded.
    num_pos_features: an `int` specifying the number of positional features;
      it should equal the hidden size of the Transformer network.
    temperature: a `float` specifying the temperature of the positional
      embedding. Any type that converts to a `float` is also accepted.
    normalize: a `bool` determining whether the positional embeddings should
      be normalized to [0, scale] before the sine and cosine functions are
      applied.
    scale: a `float` applied to the embeddings before the sine and cosine
      functions; only used when `normalize` is True.

  Returns:
    embeddings: a `float` Tensor of shape [batch_size, height, width,
      num_pos_features] containing the positional embeddings based on sine
      features.
  """
  if num_pos_features % 2 != 0:
    raise ValueError(
        "Number of embedding features (num_pos_features) must be even when "
        "column and row embeddings are concatenated.")
  num_pos_features = num_pos_features // 2

  attention_mask = tf.cast(attention_mask, tf.float32)
  # Cumulative sums over the 0/1 mask give each valid (unpadded) pixel its
  # 1-based row and column index; padded positions do not advance the count.
  row_embedding = tf.cumsum(attention_mask, 1)
  col_embedding = tf.cumsum(attention_mask, 2)

  if normalize:
    eps = 1e-6
    row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
    col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale

  # Frequency schedule of the standard Transformer sinusoidal embedding:
  # temperature**(2 * floor(i / 2) / num_pos_features).
  dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
  dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)

  # Apply sine to the even-indexed frequencies and cosine to the odd-indexed
  # ones, stacking them on a new trailing axis so they interleave on reshape.
  pos_row = tf.expand_dims(row_embedding, -1) / dim_t
  pos_col = tf.expand_dims(col_embedding, -1) / dim_t
  pos_row = tf.stack(
      [tf.sin(pos_row[:, :, :, 0::2]),
       tf.cos(pos_row[:, :, :, 1::2])], axis=4)
  pos_col = tf.stack(
      [tf.sin(pos_col[:, :, :, 0::2]),
       tf.cos(pos_col[:, :, :, 1::2])], axis=4)

  # Flatten the (feature, sin/cos) pair axis, then concatenate the row and
  # column embeddings along the channel dimension.
  final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
  pos_row = tf.reshape(pos_row, final_shape)
  pos_col = tf.reshape(pos_col, final_shape)
  output = tf.concat([pos_row, pos_col], -1)

  embeddings = tf.cast(output, tf.float32)
  return embeddings
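

# A minimal usage sketch: builds embeddings for a batch of one 4x6 image
# whose last two columns are padding. The mask shape and values here are
# illustrative assumptions, not fixtures from this module.
if __name__ == "__main__":
  mask = tf.concat(
      [tf.ones([1, 4, 4], tf.bool), tf.zeros([1, 4, 2], tf.bool)], axis=2)
  embeddings = position_embedding_sine(mask, num_pos_features=256)
  # Expected shape: (1, 4, 6, 256) -- 128 row features and 128 column
  # features concatenated along the channel axis.
  print(embeddings.shape)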