Spaces:
Build error
Build error
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Contains definitions of 3D Residual Networks.""" | |
| from typing import Callable, List, Tuple, Optional | |
| # Import libraries | |
| import tensorflow as tf, tf_keras | |
| from official.modeling import hyperparams | |
| from official.modeling import tf_utils | |
| from official.vision.modeling.backbones import factory | |
| from official.vision.modeling.layers import nn_blocks_3d | |
| from official.vision.modeling.layers import nn_layers | |
| layers = tf_keras.layers | |
# Every supported depth uses four groups of 3D bottleneck blocks with the same
# filter progression; only the number of blocks per group varies with depth.
_STAGE_FILTERS = (64, 128, 256, 512)

# Model depth -> number of blocks in each of the four groups.
_STAGE_REPEATS = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
    200: (3, 24, 36, 3),
    270: (4, 29, 53, 4),
    300: (4, 36, 54, 4),
    350: (4, 36, 72, 4),
}

# Model depth -> list of `(block_fn_name, filters, block_repeats)` tuples, one
# tuple per block group, ordered from the first group to the last.
RESNET_SPECS = {
    depth: [
        ('bottleneck3d', filters, repeats)
        for filters, repeats in zip(_STAGE_FILTERS, stage_repeats)
    ]
    for depth, stage_repeats in _STAGE_REPEATS.items()
}
class ResNet3D(tf_keras.Model):
  """Creates a 3D ResNet family model.

  The network is assembled with the Keras functional API inside `__init__`: a
  convolutional stem, a 3D max-pooling layer, and one group of 3D bottleneck
  blocks per entry of `RESNET_SPECS[model_id]`. The output of every block
  group is exposed as an endpoint keyed by the string level `'2'`, `'3'`, ...
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int, ...]],
      use_self_gating: Optional[List[bool]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` of depth of ResNet backbone model. Must be a key of
        `RESNET_SPECS`.
      temporal_strides: A list of integers that specifies the temporal strides
        for all 3d blocks. Must contain one entry per block group.
      temporal_kernel_sizes: A list of tuples that specifies the temporal kernel
        sizes for all 3d blocks in different block groups. Must contain one
        tuple per block group, and each tuple one entry per block of the group.
      use_self_gating: A list of booleans to specify applying self-gating module
        or not in each block group. If None, self-gating is not applied.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv3D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv3D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If the temporal specs do not match the number of block
        groups, if a block group's kernel sizes do not match its number of
        blocks, or if `stem_type` is not supported.
    """
    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Batch norm normalizes over the channel axis, whose position depends on
    # the globally configured image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # Build ResNet3D backbone. The batch dimension is dropped from the spec
    # because `tf_keras.Input` adds it back.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    endpoints = self._build_model(inputs)
    self._output_specs = {
        level: tensor.get_shape() for level, tensor in endpoints.items()
    }

    super(ResNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _build_model(self, inputs):
    """Builds model architecture.

    Args:
      inputs: the keras input spec.

    Returns:
      endpoints: A dictionary of backbone endpoint features, keyed by the
        string level of each block group (starting at `'2'`).

    Raises:
      ValueError: If `temporal_strides` or `temporal_kernel_sizes` does not
        have one entry per block group, or a block type is not supported.
    """
    # Build stem.
    x = self._build_stem(inputs, stem_type=self._stem_type)

    # With a temporal stride of 1 the pooling window is 1 along time (no
    # temporal pooling); otherwise a window of 3 is used.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    x = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')(x)

    # Build intermediate blocks and endpoints.
    resnet_specs = RESNET_SPECS[self._model_id]
    if (len(self._temporal_strides) != len(resnet_specs) or
        len(self._temporal_kernel_sizes) != len(resnet_specs)):
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')

    endpoints = {}
    for i, (block_name, filters, block_repeats) in enumerate(resnet_specs):
      if block_name != 'bottleneck3d':
        raise ValueError(f'Block fn `{block_name}` is not supported.')
      block_fn = nn_blocks_3d.BottleneckBlock3D

      # Block groups are numbered from level 2.
      level = i + 2
      use_self_gating = (
          self._use_self_gating[i] if self._use_self_gating else False)
      x = self._block_group(
          inputs=x,
          filters=filters,
          temporal_kernel_sizes=self._temporal_kernel_sizes[i],
          temporal_strides=self._temporal_strides[i],
          # The first group keeps the spatial resolution; every later group
          # downsamples by 2.
          spatial_strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=block_repeats,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, level, 5),
          use_self_gating=use_self_gating,
          name=f'block_group_l{level}')
      endpoints[str(level)] = x

    return endpoints

  def _conv_bn_activation(self, inputs, filters, kernel_size, strides):
    """Applies one bias-free Conv3D -> batch norm -> activation stem unit."""
    x = layers.Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(x)
    return tf_utils.get_activation(self._activation)(x)

  def _build_stem(self, inputs, stem_type):
    """Builds stem layer.

    Args:
      inputs: The input tensor of the network.
      stem_type: `'v0'` for a single 7x7 (spatial) convolution with 64 filters,
        or `'v1'` for three stacked 3x3 (spatial) convolutions with 32, 32 and
        64 filters.

    Returns:
      The output tensor of the stem.

    Raises:
      ValueError: If `stem_type` is neither `'v0'` nor `'v1'`.
    """
    temporal_kernel = self._stem_conv_temporal_kernel_size
    temporal_stride = self._stem_conv_temporal_stride

    if stem_type == 'v0':
      return self._conv_bn_activation(
          inputs,
          filters=64,
          kernel_size=[temporal_kernel, 7, 7],
          strides=[temporal_stride, 2, 2])

    if stem_type == 'v1':
      # Only the first convolution strides and has a temporal extent; the
      # following two are purely spatial and keep the resolution.
      x = self._conv_bn_activation(
          inputs,
          filters=32,
          kernel_size=[temporal_kernel, 3, 3],
          strides=[temporal_stride, 2, 2])
      x = self._conv_bn_activation(
          x, filters=32, kernel_size=[1, 3, 3], strides=[1, 1, 1])
      return self._conv_bn_activation(
          x, filters=64, kernel_size=[1, 3, 3], strides=[1, 1, 1])

    raise ValueError(f'Stem type {stem_type} not supported.')

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   temporal_kernel_sizes: Tuple[int, ...],
                   temporal_strides: int,
                   spatial_strides: int,
                   block_fn: Callable[
                       ...,
                       tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   use_self_gating: bool = False,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: A `tf.Tensor` holding the output of the previous stage (the
        pooled stem output or the previous block group).
      filters: An `int` of number of filters for the first convolution of the
        layer.
      temporal_kernel_sizes: A tuple that specifies the temporal kernel sizes
        for each block in the current group.
      temporal_strides: An `int` of temporal strides for the first convolution
        in this group.
      spatial_strides: An `int` stride to use for the first convolution of the
        layer. If greater than 1, this layer will downsample the input.
      block_fn: A callable returning the block layer to stack, e.g.
        `nn_blocks_3d.BottleneckBlock3D`.
      block_repeats: An `int` of number of blocks contained in the layer.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      use_self_gating: A `bool` that specifies whether to apply self-gating
        module or not.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.

    Raises:
      ValueError: If `block_repeats` is smaller than 1 or does not match the
        number of entries in `temporal_kernel_sizes`.
    """
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in `temporal_kernel_sizes` must equal to `block_repeats`.'
      )
    if block_repeats < 1:
      raise ValueError('`block_repeats` must be at least 1.')

    last_block = block_repeats - 1
    x = inputs
    for i, temporal_kernel_size in enumerate(temporal_kernel_sizes):
      is_first_block = i == 0
      x = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_size,
          # Only the first block of the group applies the group's strides; the
          # remaining blocks keep the resolution.
          temporal_strides=temporal_strides if is_first_block else 1,
          spatial_strides=spatial_strides if is_first_block else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Only apply self-gating module in the last block.
          use_self_gating=use_self_gating if i == last_block else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)

    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model.

    `input_specs` is not part of the returned dict, so a model rebuilt through
    `from_config` uses the constructor's default input spec.
    """
    config_dict = {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from the dict produced by `get_config`.

    Args:
      config: A dict of constructor keyword arguments.
      custom_objects: Unused; accepted for compatibility with the Keras
        `from_config` signature.

    Returns:
      A new instance of `cls`.
    """
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
# NOTE(review): the registration key follows the upstream Model Garden naming
# for this backbone; confirm it matches the `type` used by the backbone config.
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet 3d backbone from a config.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: The backbone config; `backbone_config.get()` must return
      the ResNet-3D settings (`model_id`, `block_specs`, stem options,
      `stochastic_depth_drop_rate`, `se_ratio`).
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer used as the kernel regularizer of the
      backbone. Default to None.

  Returns:
    A `ResNet3D` model.
  """
  backbone_cfg = backbone_config.get()

  # Flatten configs before passing to the backbone: one entry per block group.
  block_specs = tuple(backbone_cfg.block_specs)
  temporal_strides = [spec.temporal_strides for spec in block_specs]
  temporal_kernel_sizes = [spec.temporal_kernel_sizes for spec in block_specs]
  use_self_gating = [spec.use_self_gating for spec in block_specs]

  return ResNet3D(
      model_id=backbone_cfg.model_id,
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      input_specs=input_specs,
      stem_type=backbone_cfg.stem_type,
      stem_conv_temporal_kernel_size=(
          backbone_cfg.stem_conv_temporal_kernel_size),
      stem_conv_temporal_stride=backbone_cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=backbone_cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
      se_ratio=backbone_cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
# NOTE(review): the registration key follows the upstream Model Garden naming
# for this backbone; confirm it matches the `type` used by the backbone config.
@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet-3D-RS backbone from a config.

  Unlike `build_resnet3d`, each block spec's `temporal_kernel_sizes` is treated
  as a pattern that is repeated once per block of the corresponding group in
  `RESNET_SPECS[model_id]`, instead of being passed through unchanged.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: The backbone config; `backbone_config.get()` must return
      the ResNet-3D-RS settings (`model_id`, `block_specs`, stem options,
      `stochastic_depth_drop_rate`, `se_ratio`).
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer used as the kernel regularizer of the
      backbone. Default to None.

  Returns:
    A `ResNet3D` model.
  """
  backbone_cfg = backbone_config.get()
  group_specs = RESNET_SPECS[backbone_cfg.model_id]

  # Flatten configs before passing to the backbone: one entry per block group.
  temporal_strides = []
  temporal_kernel_sizes = []
  use_self_gating = []
  for i, block_spec in enumerate(backbone_cfg.block_specs):
    temporal_strides.append(block_spec.temporal_strides)
    use_self_gating.append(block_spec.use_self_gating)
    # The last field of a group spec is its number of blocks; tile the
    # configured kernel sizes that many times.
    block_repeats_i = group_specs[i][-1]
    temporal_kernel_sizes.append(
        list(block_spec.temporal_kernel_sizes) * block_repeats_i)

  return ResNet3D(
      model_id=backbone_cfg.model_id,
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      input_specs=input_specs,
      stem_type=backbone_cfg.stem_type,
      stem_conv_temporal_kernel_size=(
          backbone_cfg.stem_conv_temporal_kernel_size),
      stem_conv_temporal_stride=backbone_cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=backbone_cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
      se_ratio=backbone_cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)