| """Implementation of AxialResNet with group norm and weight standardization. |
| |
| Ported from: |
| https://arxiv.org/abs/2003.07853 |
| based on: |
| https://github.com/csrhddlam/axial-deeplab/tree/optimize |
| """ |

from typing import Dict, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from scenic.model_lib.base_models.multilabel_classification_model import MultiLabelClassificationModel
from scenic.model_lib.layers import nn_layers
from scenic.model_lib.layers import nn_ops
from scenic.projects.baselines import bit_resnet


class SelfAttentionWith1DRelativePos(nn.Module):
  """Multi-head dot-product self-attention with relative positional encoding.

  Attributes:
    num_heads: Number of attention heads.
  """

  num_heads: int

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies multi-head dot product self-attention on the input data.

    Args:
      x: input of shape `[bs, len, features]`.

    Returns:
      output of shape `[bs, len, features]`.
    """
    bs, input_len, features = x.shape
    if features % (2 * self.num_heads) != 0:
      raise ValueError(
          f'Inputs feature dimension {features} must be divisible by 2 * '
          f'number of heads {2*self.num_heads}.')
    head_features = features // self.num_heads

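    # Project the input jointly to queries, keys, and values.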
    qkv = nn.Dense(
        features=features * 2,
        kernel_init=nn.initializers.normal(stddev=1.0 / features),
        use_bias=False,
        dtype=jnp.float32)(
            x)

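    # Normalize the joint projection and reshape it to
    # [bs, num_heads, 2 * head_features, input_len].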
    qkv = nn.GroupNorm(epsilon=1e-4, name='gn_qkv')(qkv)
    qkv = qkv.reshape(bs, input_len, self.num_heads,
                      head_features * 2).transpose((0, 2, 3, 1))
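    # Split into query and key ([bs, num_heads, head_features // 2, input_len])
    # and value ([bs, num_heads, head_features, input_len]).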
    query, key, value = jnp.split(
        qkv, [head_features // 2, head_features], axis=2)

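    # Learned relative positional embeddings, indexed by the pairwise 1D
    # relative distances and split into query, key, and value components.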
    length = query.shape[-1]
    relative_emb = self.param(
        'relative_pos_emb', nn.initializers.normal(stddev=1.0 / head_features),
        (head_features * 2, 2 * length - 1), jnp.float32)
    relative_pos_emb = jnp.take(
        relative_emb,
        nn_ops.compute_1d_relative_distance(length, length),
        axis=-1)
    relative_pos_emb_q, relative_pos_emb_k, relative_pos_emb_v = jnp.split(
        relative_pos_emb, [head_features // 2, head_features], axis=0)

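    # Attention logits are the sum of a content term (query-key) and two
    # relative-position terms (query-position and key-position).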
    qr_attn_logits = jnp.einsum('bhdi,dij->bhij', query, relative_pos_emb_k)
    kr_attn_logits = jnp.einsum('bhdi,dij->bhij', key, relative_pos_emb_q)
    qk_attn_logits = jnp.einsum('bhdi,bhdj->bhij', query, key)
    attn_weights = qk_attn_logits + qr_attn_logits + kr_attn_logits

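    # Normalize the logits over the key axis, then aggregate the values and
    # the positional value embeddings with the attention weights.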
    attn_weights = jax.nn.softmax(attn_weights, axis=3).astype(jnp.float32)
    wv = jnp.einsum('bhij,bhdj->bhdi', attn_weights, value)
    wve = jnp.einsum('bhij,dij->bhdi', attn_weights, relative_pos_emb_v)

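    # Merge the heads back into the feature dimension.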
    x = (wv + wve).transpose((0, 3, 1, 2))
    out = x.reshape(bs, input_len, features)
    return nn.GroupNorm(epsilon=1e-4, name='gn_out')(out)


class AxialSelfAttention(nn.Module):
  """Axial self-attention module.

  Attributes:
    attention_axis: Axis along which attention is applied.
    axial_attention_configs: Configurations of the axial attention.
  """

  attention_axis: int
  axial_attention_configs: ml_collections.ConfigDict

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies the axial self-attention module.

    Args:
      x: Input data of shape `[bs, height, width, channel]`.

    Returns:
      Output after axial attention has been applied to it.
    """
    bs, height, width, channel = x.shape
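    # Fold the non-attended spatial axis into the batch dimension so that 1D
    # self-attention runs along the chosen axis.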
    if self.attention_axis == 1:
      x = x.transpose((0, 2, 1, 3)).reshape(bs * width, height, channel)
    elif self.attention_axis == 2:
      x = x.reshape(bs * height, width, channel)
    else:
      raise ValueError('Only attention over rows or columns is supported.')

    x = SelfAttentionWith1DRelativePos(self.axial_attention_configs.num_heads)(
        x)

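    # Restore the original [bs, height, width, channel] layout.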
    if self.attention_axis == 1:
      return x.reshape((bs, width, height, channel)).transpose((0, 2, 1, 3))
    else:
      return x.reshape((bs, height, width, channel))


class AxialResidualUnit(nn.Module):
  """Bottleneck AxialResNet block.

  Attributes:
    nout: Number of output features.
    axial_attention_configs: Configurations of the axial attention.
    strides: Down-sampling stride.
    bottleneck: If True, the block is a bottleneck block.
  """

  nout: int
  axial_attention_configs: ml_collections.ConfigDict
  strides: Tuple[int, ...] = (1, 1)
  bottleneck: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    features = self.nout
    nout = self.nout * 4 if self.bottleneck else self.nout
    needs_projection = x.shape[-1] != nout or self.strides != (1, 1)
    residual = x
    if needs_projection:
      residual = bit_resnet.StdConv(
          nout, (1, 1), self.strides, use_bias=False, name='conv_proj')(
              residual)
      residual = nn.GroupNorm(epsilon=1e-4, name='gn_proj')(residual)

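    # In the bottleneck variant, a 1x1 convolution reduces the channel count
    # before the attention layers.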
    if self.bottleneck:
      x = bit_resnet.StdConv(features, (1, 1), use_bias=False, name='conv1')(x)
      x = nn.GroupNorm(epsilon=1e-4, name='gn1')(x)
      x = nn.relu(x)

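    # Axial attention: first along the height axis, then along the width axis,
    # followed by optional down-sampling via average pooling.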
    x = AxialSelfAttention(
        attention_axis=1, axial_attention_configs=self.axial_attention_configs)(
            x)
    x = AxialSelfAttention(
        attention_axis=2, axial_attention_configs=self.axial_attention_configs)(
            x)
    if self.strides == (2, 2):
      x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
    x = nn.relu(x)

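    # The closing convolution maps to `nout` features; `gn3` is zero-initialized
    # so the residual branch initially contributes nothing (identity block).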
    last_kernel = (1, 1) if self.bottleneck else (3, 3)
    x = bit_resnet.StdConv(nout, last_kernel, use_bias=False, name='conv3')(x)
    x = nn.GroupNorm(
        epsilon=1e-4, name='gn3', scale_init=nn.initializers.zeros)(
            x)
    x = nn.relu(residual + x)
    return x


class AxialResNetStage(nn.Module):
  """ResNet stage: one or more stacked AxialResNet blocks.

  Attributes:
    block_size: Number of residual blocks to stack.
    nout: Number of features.
    axial_attention_configs: Configurations of the axial attention.
    first_stride: Down-sampling stride of the first block.
    bottleneck: If True, the bottleneck block is used.
  """

  block_size: int
  nout: int
  axial_attention_configs: ml_collections.ConfigDict
  first_stride: Tuple[int, ...]
  bottleneck: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    x = AxialResidualUnit(
        self.nout,
        axial_attention_configs=self.axial_attention_configs,
        strides=self.first_stride,
        bottleneck=self.bottleneck,
        name='unit1')(
            x)
    for i in range(1, self.block_size):
      x = AxialResidualUnit(
          self.nout,
          axial_attention_configs=self.axial_attention_configs,
          strides=(1, 1),
          bottleneck=self.bottleneck,
          name=f'unit{i + 1}')(
              x)
    return x


class AxialResNet(nn.Module):
  """Axial ResNet.

  Attributes:
    num_outputs: Num output classes.
    axial_attention_configs: Configurations of the axial attention.
    width_factor: Width multiplier for each of the ResNet stages.
    num_layers: Number of layers (see `BLOCK_SIZE_OPTIONS` for stage
      configurations).
  """

  num_outputs: int
  axial_attention_configs: ml_collections.ConfigDict
  width_factor: int = 1
  num_layers: int = 50

  @nn.compact
  def __call__(
      self,
      x: jnp.ndarray,
      *,
      train: bool = True,
      debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Applies the AxialResNet model to the inputs.

    Args:
      x: Inputs to the model.
      train: Unused.
      debug: Unused.

    Returns:
      Un-normalized logits.
    """
    del train
    del debug
    blocks, bottleneck = bit_resnet.BLOCK_SIZE_OPTIONS[self.num_layers]
    width = int(64 * self.width_factor)

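    # Root block: strided 7x7 stem convolution, group norm, and max pooling.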
    x = bit_resnet.StdConv(
        width, (7, 7), (2, 2), use_bias=False, name='conv_root')(
            x)
    x = nn.GroupNorm(epsilon=1e-4, name='gn_root')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

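    # Stages: the first keeps the resolution, later stages down-sample by 2x.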
    x = AxialResNetStage(
        blocks[0],
        width,
        axial_attention_configs=self.axial_attention_configs,
        first_stride=(1, 1),
        bottleneck=bottleneck,
        name='block1')(
            x)
    for i, block_size in enumerate(blocks[1:], 1):
      x = AxialResNetStage(
          block_size,
          width * 2**i,
          axial_attention_configs=self.axial_attention_configs,
          first_stride=(2, 2),
          bottleneck=bottleneck,
          name=f'block{i + 1}')(
              x)

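    # Head: global average pooling followed by a zero-initialized linear
    # classifier.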
    x = jnp.mean(x, axis=(1, 2))
    x = nn_layers.IdentityLayer(name='pre_logits')(x)
    return nn.Dense(
        self.num_outputs,
        kernel_init=nn.initializers.zeros,
        name='output_projection')(
            x)


class AxialResNetMultiLabelClassificationModel(MultiLabelClassificationModel):
  """Implements the AxialResNet model for multi-label classification."""

  def build_flax_model(self) -> nn.Module:
    return AxialResNet(
        num_outputs=self.dataset_meta_data['num_classes'],
        axial_attention_configs=self.config.axial_attention_configs,
        width_factor=self.config.get('width_factor', 1),
        num_layers=self.config.num_layers,
    )

  def default_flax_model_config(self) -> ml_collections.ConfigDict:
    return ml_collections.ConfigDict(
        dict(
            width_factor=1,
            num_layers=5,
            axial_attention_configs=ml_collections.ConfigDict(
                {'num_heads': 2}),
        ))
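

if __name__ == '__main__':
  # Minimal smoke-test sketch (not part of the original module): the input
  # shape, number of classes, and config values below are illustrative
  # assumptions, chosen only to exercise the model end to end.
  cfg = ml_collections.ConfigDict({'num_heads': 2})
  model = AxialResNet(num_outputs=10, axial_attention_configs=cfg, num_layers=5)
  dummy = jnp.ones((1, 64, 64, 3), jnp.float32)
  params = model.init(jax.random.PRNGKey(0), dummy)
  logits = model.apply(params, dummy)
  print(logits.shape)  # Expected: (1, 10).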