Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2021 The Deeplab2 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Implements Axial-Attention layers proposed in Axial-DeepLab. | |
| Axial-Attention factorizes 2D self-attention into two 1D self-attentions, so | |
| that it can be applied on large inputs. Axial-Attention is typically used to | |
| replace 3x3 convolutions in a bottleneck residual block. | |
| [1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, | |
| ECCV 2020 Spotlight. | |
| Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, | |
| Liang-Chieh Chen. | |
| """ | |
| import numpy as np | |
| import tensorflow as tf | |
| from deeplab2.model import utils | |
| from deeplab2.model.layers import activations | |
| from deeplab2.model.layers import positional_encodings | |
class AxialAttention(tf.keras.layers.Layer):
  """An axial-attention layer that attends along a single 1D axis."""

  def __init__(self,
               query_shape=129,
               memory_flange=32,
               total_key_depth=512,
               total_value_depth=1024,
               num_heads=8,
               name='axial_attention',
               use_query_rpe_similarity=True,
               use_key_rpe_similarity=True,
               use_content_similarity=True,
               retrieve_value_rpe=True,
               retrieve_value_content=True,
               initialization_std_for_query_key_rpe=1.0,
               initialization_std_for_value_rpe=1.0,
               self_attention_activation='softmax',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes an axial-attention layer.

    This function is designed to support both global and local axial-attention
    in a unified way. If query_shape is larger than the length of input, a
    global attention is applied. If query_shape is smaller than the length of
    input, a local attention is applied. In this case, the input is divided
    into blocks of length query_shape, padded by memory_flange on both sides.
    Then, local attention is applied within each query block. The choice of
    query_shape does not affect the output value but affects computation
    efficiency and memory usage. In general, use global attention (large
    query_shape) if possible. Local axial-attention has not been supported yet.

    The layer's call() produces a [batch, length, total_value_depth] tensor.

    Args:
      query_shape: An integer, the block size for local axial attention.
        Defaults to 129 since 129 is usually the largest feature map where we
        do global attention (1025 with stride 8, or 2049 with stride 16).
      memory_flange: An integer, the memory flange padded to each query block
        in local attention. It has no effect in global attention. Defaults to
        32, which is equivalent to a span of 65 in the Axial-DeepLab paper --
        a pixel can see 32 pixels on its left and 32 pixels on its right.
      total_key_depth: An integer, the total depth of keys, which is also the
        depth of queries and the depth of key (query) positional encodings.
      total_value_depth: An integer, the total depth of the values, which is
        also the depth of value positional encodings.
      num_heads: An integer, the number of heads in multi-head attention.
      name: A string, the name of this axial attention layer.
      use_query_rpe_similarity: A boolean, whether to use the attention
        similarity between the queries and the relative positional encodings.
      use_key_rpe_similarity: A boolean, whether to use the attention
        similarity between the keys and the relative positional encodings.
      use_content_similarity: A boolean, whether to use the content similarity
        between the queries and the keys.
      retrieve_value_rpe: A boolean, whether to retrieve the relative
        positional encodings of the values.
      retrieve_value_content: A boolean, whether to retrieve the content of
        the values.
      initialization_std_for_query_key_rpe: A float, the initialization std
        for the relative positional encodings of the queries and keys.
      initialization_std_for_value_rpe: A float, the initialization std for
        the relative positional encodings of the values.
      self_attention_activation: A string, type of activation function for
        self-attention. Support 'sigmoid' and 'softmax'.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If none of the three similarities (use_query_rpe_similarity,
        use_key_rpe_similarity, use_content_similarity) is used.
      ValueError: If neither of value content or value rpe is retrieved.
      ValueError: If self_attention_activation is not supported.
      ValueError: If total_key_depth is not divisible by num_heads.
      ValueError: If total_value_depth is not divisible by num_heads.
    """
    # Validate the attention similarity choices.
    if not any([
        use_content_similarity, use_key_rpe_similarity, use_query_rpe_similarity
    ]):
      raise ValueError(
          'Should use at least one similarity to compute attention.')

    # Validate the retrieve value choices.
    if not retrieve_value_content and not retrieve_value_rpe:
      raise ValueError('Should retrieve at least one of content or rpe.')

    if total_key_depth % num_heads:
      raise ValueError('Total_key_depth should be divisible by num_heads.')

    if total_value_depth % num_heads:
      raise ValueError('Total_value_depth should be divisible by num_heads.')

    super(AxialAttention, self).__init__(name=name)
    self._query_shape = query_shape
    self._memory_flange = memory_flange
    self._total_key_depth = total_key_depth
    self._total_value_depth = total_value_depth
    self._num_heads = num_heads
    self._use_query_rpe_similarity = use_query_rpe_similarity
    self._use_key_rpe_similarity = use_key_rpe_similarity
    self._use_content_similarity = use_content_similarity
    self._retrieve_value_rpe = retrieve_value_rpe
    self._retrieve_value_content = retrieve_value_content
    self._initialization_std_for_query_key_rpe = (
        initialization_std_for_query_key_rpe)
    self._initialization_std_for_value_rpe = initialization_std_for_value_rpe
    self._self_attention_activation = self_attention_activation
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

    self._batch_norm_qkv = bn_layer(axis=-1, name='batch_norm_qkv')
    # The similarity logits are stacked into a [term, batch, head, length,
    # memory] tensor before this norm, so axis=[0, 2] gives every
    # (similarity-term, head) pair its own normalization parameters.
    self._batch_norm_similarity = bn_layer(
        axis=[0, 2], name='batch_norm_similarity')
    # The retrieved outputs are stacked into a [term, batch, head, length,
    # depth] tensor, so axis=[0, 2, 4] normalizes each (retrieved-term, head,
    # channel) combination independently.
    self._batch_norm_retrieved_output = bn_layer(
        axis=[0, 2, 4], name='batch_norm_retrieved_output')

    self._key_depth_per_head = total_key_depth // num_heads
    self._attention_activate_fn = activations.get_activation(
        self_attention_activation)

  def build(self, input_shape):
    """Builds axial-attention layer weights.

    Args:
      input_shape: An integer list of length 3, the shape of the input tensor.

    Raises:
      NotImplementedError: Local axial-attention has not been implemented. It
        is triggered if query_shape is less than input_shape.
    """
    # Apply global attention if query_shape is larger than input_shape[1].
    if self._query_shape >= input_shape[1]:
      # Clamp the query span to the actual input length; no flange is needed
      # because every position already attends to the whole axis.
      self._query_shape = input_shape[1]
      self._memory_flange = 0
    else:
      raise NotImplementedError('Local axial attention has not been '
                                'implemented yet.')
    self._memory_shape = self._query_shape + 2 * self._memory_flange

    # Compute query key value with one convolution and an optional batch norm.
    # The initialization std is standard transformer initialization (without
    # batch norm), as used in SASA and ViT. In our case, we use batch norm by
    # default, so it does not require careful tuning. If one wants to remove
    # all batch norms in axial attention, this standard initialization should
    # still be good, but a more careful initialization is encouraged.
    self.qkv_kernel = self.add_weight(
        name='qkv_kernel',
        shape=[input_shape[-1],
               self._total_key_depth * 2 + self._total_value_depth],
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=input_shape[-1]**-0.5),
        regularizer=tf.keras.regularizers.l2(self._conv_kernel_weight_decay))

    # Relative positional encodings are only instantiated for the similarity
    # and retrieval terms that are actually enabled.
    if self._use_query_rpe_similarity:
      self._query_rpe = positional_encodings.RelativePositionalEncoding(
          self._query_shape,
          self._memory_shape,
          self._key_depth_per_head,
          self._num_heads,
          'query_rpe',
          initialization_std=self._initialization_std_for_query_key_rpe,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._use_key_rpe_similarity:
      self._key_rpe = positional_encodings.RelativePositionalEncoding(
          self._query_shape,
          self._memory_shape,
          self._key_depth_per_head,
          self._num_heads,
          'key_rpe',
          initialization_std=self._initialization_std_for_query_key_rpe,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._retrieve_value_rpe:
      self._value_rpe = positional_encodings.RelativePositionalEncoding(
          self._query_shape,
          self._memory_shape,
          self._total_value_depth // self._num_heads,
          self._num_heads,
          'value_rpe',
          initialization_std=self._initialization_std_for_value_rpe,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, input_tensor, training=False):
    """Performs a forward pass.

    Args:
      input_tensor: An input [batch, length, channel] tensor.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, length, total_value_depth] tensor.
    """
    # Alternatively, the einsum can be implemented as a 1x1 convolution.
    # However, it is not obvious which implementation is more efficient
    # (without careful benchmarking), so we use einsum for its flexibility and
    # consistency with other parts of the function.
    query_key_value = tf.einsum(
        'nlc,cd->nld', input_tensor, self.qkv_kernel, name='compute_qkv')
    query_key_value = self._batch_norm_qkv(query_key_value, training=training)

    # Split query key value along the channel axis.
    query, key, value = tf.split(
        query_key_value,
        [self._total_key_depth, self._total_key_depth, self._total_value_depth],
        axis=-1)

    # Reshape the query, key, and value into per-head tensors. Query and key
    # become [batch, heads, length, key_depth_per_head]; value stays
    # head-second as [batch, memory, heads, value_depth_per_head], which the
    # 'bmhd' einsum subscript below accounts for.
    query = tf.reshape(query, [-1, self._query_shape, self._num_heads,
                               self._key_depth_per_head])
    query = tf.transpose(a=query, perm=[0, 2, 1, 3])
    # np.prod is a no-op here since _memory_shape is a scalar.
    key = tf.reshape(key, [-1, np.prod(self._memory_shape), self._num_heads,
                           self._key_depth_per_head])
    key = tf.transpose(a=key, perm=[0, 2, 1, 3])
    value = tf.reshape(value, [-1, np.prod(self._memory_shape), self._num_heads,
                               self._total_value_depth // self._num_heads])

    # Gather all similarity logits into a list; each entry is a
    # [batch, heads, length, memory] tensor.
    similarity_logits = []

    # Compute the content similarity term: q * k.
    if self._use_content_similarity:
      content_similarity = tf.einsum(
          'bhld,bhmd->bhlm', query, key, name='content_similarity')
      similarity_logits.append(content_similarity)

    # Compute the query rpe similarity term: q * rpe.
    if self._use_query_rpe_similarity:
      query_rpe = self._query_rpe(None)
      query_rpe_similarity = tf.einsum(
          'bhld,hlmd->bhlm', query, query_rpe, name='query_rpe_similarity')
      similarity_logits.append(query_rpe_similarity)

    # Compute the key rpe similarity term: k * rpe.
    if self._use_key_rpe_similarity:
      key_rpe = self._key_rpe(None)
      key_rpe_similarity = tf.einsum(
          'bhmd,hlmd->bhlm', key, key_rpe, name='key_rpe_similarity')
      similarity_logits.append(key_rpe_similarity)

    # Apply an optional batch norm to the similarities and sum them.
    similarity_logits = tf.stack(similarity_logits)
    similarity_logits = self._batch_norm_similarity(similarity_logits,
                                                    training=training)
    similarity_logits = tf.reduce_sum(input_tensor=similarity_logits, axis=0)

    # Apply an attention activation function, e.g. softmax.
    weights = self._attention_activate_fn(similarity_logits)

    # Gather retrieved values or rpes into a list; each entry is a
    # [batch, heads, length, value_depth_per_head] tensor.
    retrieve_list = []

    # Retrieve the content of the attended value.
    if self._retrieve_value_content:
      retrieved_content = tf.einsum(
          'bhlm,bmhd->bhld', weights, value, name='retrieve_value_content')
      retrieve_list.append(retrieved_content)

    # Retrieve the relative position of the attended value.
    if self._retrieve_value_rpe:
      value_rpe = self._value_rpe(None)
      retrieved_rpe = tf.einsum(
          'bhlm,hlmd->bhld', weights, value_rpe, name='retrieve_value_rpe')
      retrieve_list.append(retrieved_rpe)

    # Apply batch norms to retrieved contents and rpes respectively.
    retrieved_output = tf.stack(retrieve_list)
    retrieved_output = self._batch_norm_retrieved_output(retrieved_output,
                                                         training=training)
    # Additive contents and rpes.
    retrieved_output = tf.reduce_sum(input_tensor=retrieved_output, axis=0)

    # Combine the heads by transposing and reshaping the tensor.
    retrieved_output = utils.transpose_and_reshape_for_attention_operation(
        retrieved_output)
    return retrieved_output
class AxialAttention2D(tf.keras.layers.Layer):
  """Sequentially applies height-axis and width-axis axial-attention."""

  def __init__(self,
               strides=1,
               filters=512,
               name='attention',
               key_expansion=1,
               value_expansion=2,
               query_shape=(129, 129),
               memory_flange=(32, 32),
               **kwargs):
    """Initializes an AxialAttention2D layer.

    Args:
      strides: An integer, the stride for the output, usually 1 or 2.
      filters: An integer, the base number of channels for the layer.
      name: A string, the name of the attention layer.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      query_shape: A pair of integers, the maximum query shapes for the
        height axis and the width axis respectively.
      memory_flange: A pair of integers, the memory flanges for the height
        axis and the width axis respectively.
      **kwargs: Keyword arguments passed through to both underlying
        AxialAttention layers.
    """
    super(AxialAttention2D, self).__init__(name=name)
    self._strides = strides
    self._total_key_depth = int(round(filters * key_expansion))
    self._total_value_depth = int(round(filters * value_expansion))
    # The two 1D layers share every hyper-parameter except their
    # axis-specific query shapes and memory flanges.
    self._height_axis = AxialAttention(
        total_key_depth=self._total_key_depth,
        total_value_depth=self._total_value_depth,
        query_shape=query_shape[0],
        memory_flange=memory_flange[0],
        name='height_axis',
        **kwargs)
    self._width_axis = AxialAttention(
        total_key_depth=self._total_key_depth,
        total_value_depth=self._total_value_depth,
        query_shape=query_shape[1],
        memory_flange=memory_flange[1],
        name='width_axis',
        **kwargs)

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: An input [batch, height, width, channel] tensor.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, strided_height, strided_width,
        filters * value_expansion] tensor.
    """
    input_shape = inputs.get_shape().as_list()
    height, width, channel = input_shape[1], input_shape[2], input_shape[3]
    # Fold the width axis into the batch dimension so the height-axis
    # attention runs independently on every column.
    columns = tf.transpose(a=inputs, perm=[0, 2, 1, 3])
    columns = tf.reshape(columns, [-1, height, channel])
    attended = self._height_axis(columns, training=training)
    # Restore the 4D [batch, height, width, depth] layout.
    attended = tf.reshape(
        attended, [-1, width, height, self._total_value_depth])
    attended = tf.transpose(a=attended, perm=[0, 2, 1, 3])
    # Subsample the height axis before the width-axis attention.
    if self._strides > 1:
      attended = attended[:, ::self._strides, :, :]
    strided_height = attended.get_shape().as_list()[1]
    # Fold the (strided) height axis into the batch dimension so the
    # width-axis attention runs independently on every row.
    rows = tf.reshape(attended, [-1, width, self._total_value_depth])
    rows = self._width_axis(rows, training=training)
    outputs = tf.reshape(
        rows, [-1, strided_height, width, self._total_value_depth])
    # Subsample the width axis.
    if self._strides > 1:
      outputs = outputs[:, :, ::self._strides, :]
    return outputs
class GlobalAttention2D(tf.keras.layers.Layer):
  """A 2D global attention layer."""

  def __init__(self,
               strides=1,
               filters=512,
               name='attention',
               key_expansion=1,
               value_expansion=2,
               query_shape=(129, 129),
               memory_flange=(32, 32),
               double_global_attention=False,
               **kwargs):
    """Initializes a GlobalAttention2D layer.

    Args:
      strides: An integer, the stride for the output, usually 1 or 2.
      filters: An integer, the base number of channels for the layer.
      name: A string, the name of the attention layer.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      query_shape: Accepted for interface compatibility with AxialAttention2D
        but unused; the attention span is always the whole flattened input.
      memory_flange: Accepted for interface compatibility but unused.
      double_global_attention: A boolean, whether to use two global attention
        layers. Two global attention layers match the parameter count to a
        sequentially applied height and width axial attention layer.
      **kwargs: Keyword arguments forwarded to the underlying AxialAttention
        layers. Relative positional encoding options must not be enabled.

    Raises:
      ValueError: If relative positional encoding is enforced in kwargs.
    """
    rpe_options = ('use_query_rpe_similarity',
                   'use_key_rpe_similarity',
                   'retrieve_value_rpe')
    if any(kwargs.get(option, False) for option in rpe_options):
      raise ValueError('GlobalAttention2D does not support relative '
                       'positional encodings.')
    super(GlobalAttention2D, self).__init__(name=name)
    self._strides = strides
    self._double_global_attention = double_global_attention
    self._total_key_depth = int(round(filters * key_expansion))
    self._total_value_depth = int(round(filters * value_expansion))
    # Force content-only attention: global attention does not support
    # relative positional encodings.
    for option in rpe_options:
      kwargs[option] = False
    self._kwargs = kwargs

  def build(self, input_shape):
    """Builds global attention layers according to the 4D input_shape."""
    height, width = input_shape[1], input_shape[2]

    def make_global_layer(layer_name):
      # A query_shape of height * width makes the 1D axial attention span
      # the entire flattened input, i.e. content-only global 2D attention.
      return AxialAttention(
          total_key_depth=self._total_key_depth,
          total_value_depth=self._total_value_depth,
          query_shape=height * width,
          memory_flange=0,
          name=layer_name,
          **self._kwargs)

    self._global = make_global_layer('global')
    # Optionally stack a second global layer so the residual block has a
    # similar number of layers and parameters as its axial-attention
    # counterpart.
    if self._double_global_attention:
      self._global2 = make_global_layer('global2')

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: An input [batch, height, width, channel] tensor.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, strided_height, strided_width,
        filters * value_expansion] tensor.
    """
    input_shape = inputs.get_shape().as_list()
    height, width, channel = input_shape[1], input_shape[2], input_shape[3]
    # Flatten the two spatial axes so 2D global attention reduces to 1D
    # axial attention over height * width positions.
    flattened = tf.reshape(inputs, [-1, height * width, channel])
    flattened = self._global(flattened, training=training)
    if self._double_global_attention:
      flattened = self._global2(flattened, training=training)
    outputs = tf.reshape(
        flattened, [-1, height, width, self._total_value_depth])
    # Subsample both spatial axes if a stride is requested.
    if self._strides > 1:
      outputs = outputs[:, ::self._strides, ::self._strides, :]
    return outputs