|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Contains definitions of 3D Residual Networks."""
|
| from typing import Callable, List, Tuple, Optional
|
|
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling import hyperparams
|
| from official.modeling import tf_utils
|
| from official.vision.modeling.backbones import factory
|
| from official.vision.modeling.layers import nn_blocks_3d
|
| from official.vision.modeling.layers import nn_layers
|
|
|
# Module-level alias for the Keras layers namespace; the model code below
# refers to layer classes as `layers.<Name>` (e.g. `layers.Conv3D`).
layers = tf_keras.layers
|
|
|
# Specifications for the supported ResNet depths. Each depth maps to a list
# with one `(block type, filters, block repeats)` tuple per block group. All
# depths share the same block type and filter progression and differ only in
# how many blocks each group repeats, so the table is expanded from a compact
# per-depth list of repeat counts.
RESNET_SPECS = {
    depth: [
        ('bottleneck3d', filters, repeats)
        for filters, repeats in zip((64, 128, 256, 512), block_repeats)
    ]
    for depth, block_repeats in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
        (270, (4, 29, 53, 4)),
        (300, (4, 36, 54, 4)),
        (350, (4, 36, 72, 4)),
    )
}
|
|
|
|
|
@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNet3D(tf_keras.Model):
  """Creates a 3D ResNet family model.

  The model is assembled with the Keras functional API inside `__init__`: a
  Conv3D stem, a 3D max-pool, and one group of residual blocks per entry of
  `RESNET_SPECS[model_id]`. The output of each block group is returned as an
  endpoint keyed by its string level (`'2'` for the first group, `'3'` for the
  second, and so on).
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int, ...]],
      use_self_gating: Optional[List[bool]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` of depth of ResNet backbone model. Must be a key of
        `RESNET_SPECS`.
      temporal_strides: A list of integers that specifies the temporal strides
        for all 3d blocks, one entry per block group.
      temporal_kernel_sizes: A list of tuples that specifies the temporal kernel
        sizes for all 3d blocks in different block groups. Each tuple must hold
        one kernel size per block of its group.
      use_self_gating: A list of booleans to specify applying self-gating module
        or not in each block group. If None (or empty), self-gating is not
        applied.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        stem Conv3D layers; also forwarded to every residual block. Default to
        None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        stem Conv3D layers; also forwarded to every residual block. Default to
        None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      KeyError: If `model_id` is not a key of `RESNET_SPECS`.
      ValueError: If the temporal specs or `use_self_gating` do not match the
        number of block groups, if a group's kernel sizes do not match its
        number of blocks, or if `stem_type` is not supported.
    """
    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Batch norm normalizes over the channel axis, whose position depends on
    # the globally configured image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # The functional graph has to exist before the base `Model` is initialized,
    # because its inputs and outputs are handed to `super().__init__`.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    endpoints = self._build_model(inputs)
    self._output_specs = {
        level: tensor.get_shape() for level, tensor in endpoints.items()
    }

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _build_model(self, inputs):
    """Builds model architecture.

    Args:
      inputs: The symbolic Keras input tensor of the model.

    Returns:
      endpoints: A dictionary of backbone endpoint features, keyed by the
        string level of each block group (`'2'` for the first group).

    Raises:
      KeyError: If `self._model_id` is not a key of `RESNET_SPECS`.
      ValueError: If the temporal specs or the self-gating flags do not cover
        every block group, or if a block type is not supported.
    """
    x = self._build_stem(inputs, stem_type=self._stem_type)

    # With a temporal stride of 1 the pooling window is reduced to a single
    # step along the temporal dimension; otherwise a window of 3 is used.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    x = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')(x)

    resnet_specs = RESNET_SPECS[self._model_id]
    num_groups = len(resnet_specs)
    if (len(self._temporal_strides) != num_groups or
        len(self._temporal_kernel_sizes) != num_groups):
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')
    # Fail with a clear message instead of an IndexError half-way through the
    # build when fewer self-gating flags than block groups were supplied.
    if self._use_self_gating and len(self._use_self_gating) < num_groups:
      raise ValueError(
          '`use_self_gating` should have an entry for each of the {} block '
          'groups, but got {}.'.format(num_groups,
                                       len(self._use_self_gating)))

    endpoints = {}
    for i, (block_type, filters, block_repeats) in enumerate(resnet_specs):
      if block_type == 'bottleneck3d':
        block_fn = nn_blocks_3d.BottleneckBlock3D
      else:
        raise ValueError(f'Block fn `{block_type}` is not supported.')

      # Levels start at 2 for the first block group.
      level = i + 2
      use_self_gating = (
          self._use_self_gating[i] if self._use_self_gating else False)
      x = self._block_group(
          inputs=x,
          filters=filters,
          temporal_kernel_sizes=self._temporal_kernel_sizes[i],
          temporal_strides=self._temporal_strides[i],
          # Only groups after the first one downsample spatially.
          spatial_strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=block_repeats,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, level, 5),
          use_self_gating=use_self_gating,
          name=f'block_group_l{level}')
      endpoints[str(level)] = x

    return endpoints

  def _stem_conv_bn_act(self, inputs, filters, kernel_size, strides):
    """Applies one Conv3D -> batch norm -> activation unit of the stem.

    Args:
      inputs: The tensor to transform.
      filters: An `int` number of output filters of the convolution.
      kernel_size: A list of three ints, the Conv3D kernel size.
      strides: A list of three ints, the Conv3D strides.

    Returns:
      The activated output tensor.
    """
    x = layers.Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(x)
    return tf_utils.get_activation(self._activation)(x)

  def _build_stem(self, inputs, stem_type):
    """Builds stem layer.

    Args:
      inputs: The symbolic Keras input tensor of the model.
      stem_type: A `str`, either `'v0'` (a single 7x7 spatial convolution with
        64 filters) or `'v1'` (three 3x3 spatial convolutions with 32, 32 and
        64 filters).

    Returns:
      The output tensor of the stem.

    Raises:
      ValueError: If `stem_type` is neither `'v0'` nor `'v1'`.
    """
    temporal_kernel = self._stem_conv_temporal_kernel_size
    temporal_stride = self._stem_conv_temporal_stride

    if stem_type == 'v0':
      return self._stem_conv_bn_act(
          inputs, 64, [temporal_kernel, 7, 7], [temporal_stride, 2, 2])

    if stem_type == 'v1':
      # Only the first convolution is strided and uses the temporal kernel;
      # the two following ones keep the resolution.
      x = self._stem_conv_bn_act(
          inputs, 32, [temporal_kernel, 3, 3], [temporal_stride, 2, 2])
      x = self._stem_conv_bn_act(x, 32, [1, 3, 3], [1, 1, 1])
      return self._stem_conv_bn_act(x, 64, [1, 3, 3], [1, 1, 1])

    raise ValueError(f'Stem type {stem_type} not supported.')

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   temporal_kernel_sizes: Tuple[int, ...],
                   temporal_strides: int,
                   spatial_strides: int,
                   block_fn: Callable[
                       ...,
                       tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   use_self_gating: bool = False,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: A `tf.Tensor` holding the features produced by the preceding 3D
        layers (the stem/pool or the previous block group).
      filters: An `int` of number of filters for the first convolution of the
        layer.
      temporal_kernel_sizes: A tuple that specifies the temporal kernel sizes
        for each block in the current group.
      temporal_strides: An `int` of temporal strides for the first convolution
        in this group.
      spatial_strides: An `int` stride to use for the first convolution of the
        layer. If greater than 1, this layer will downsample the input.
      block_fn: A callable that builds one block layer. Defaults to
        `nn_blocks_3d.BottleneckBlock3D`.
      block_repeats: An `int` of number of blocks contained in the layer.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      use_self_gating: A `bool` that specifies whether to apply self-gating
        module or not. It is only passed on to the last block of the group.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.

    Raises:
      ValueError: If `temporal_kernel_sizes` does not contain exactly
        `block_repeats` entries, or if `block_repeats` is smaller than 1.
    """
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in `temporal_kernel_sizes` must equal to `block_repeats`.'
      )
    if block_repeats < 1:
      raise ValueError('`block_repeats` must be at least 1.')

    last_block = block_repeats - 1
    x = inputs
    for i, temporal_kernel_size in enumerate(temporal_kernel_sizes):
      # Only the first block of a group applies the group's strides; every
      # following block uses a stride of 1.
      is_first_block = i == 0
      x = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_size,
          temporal_strides=temporal_strides if is_first_block else 1,
          spatial_strides=spatial_strides if is_first_block else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Self-gating is requested for the final block of the group only.
          use_self_gating=use_self_gating if i == last_block else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)

    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments stored on this model.

    `input_specs` and any extra `**kwargs` given to `__init__` are not part of
    the returned dict, so a model re-created through `from_config` uses the
    default `input_specs`.
    """
    return {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates a model from `get_config` output.

    `custom_objects` is accepted but not used.
    """
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
|
|
|
|
|
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet 3d backbone from a config.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: The backbone config; `backbone_config.get()` supplies the
      ResNet3D settings, including the per-group `block_specs`.
    norm_activation_config: The config that provides the activation name and
      the batch normalization settings.
    l2_regularizer: An optional regularizer, passed to the model as
      `kernel_regularizer`.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()

  # Flatten the per-group block specs into the parallel lists that `ResNet3D`
  # takes as arguments.
  block_specs = list(cfg.block_specs)
  temporal_strides = [spec.temporal_strides for spec in block_specs]
  temporal_kernel_sizes = [spec.temporal_kernel_sizes for spec in block_specs]
  use_self_gating = [spec.use_self_gating for spec in block_specs]

  return ResNet3D(
      input_specs=input_specs,
      model_id=cfg.model_id,
      # Per block-group settings.
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      # Stem settings.
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      # Block settings.
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      # Normalization and activation settings.
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
|
|
|
|
|
@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet-3D-RS backbone from a config.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: The backbone config; `backbone_config.get()` supplies the
      ResNet3D settings, including the per-group `block_specs`.
    norm_activation_config: The config that provides the activation name and
      the batch normalization settings.
    l2_regularizer: An optional regularizer, passed to the model as
      `kernel_regularizer`.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()
  model_id = cfg.model_id

  # Flatten the per-group block specs into the parallel lists that `ResNet3D`
  # takes as arguments.
  block_specs = list(cfg.block_specs)
  temporal_strides = [spec.temporal_strides for spec in block_specs]
  use_self_gating = [spec.use_self_gating for spec in block_specs]
  # The configured kernel sizes of a group are tiled by that group's number of
  # block repeats (the last field of its `RESNET_SPECS` entry).
  temporal_kernel_sizes = [
      list(spec.temporal_kernel_sizes) * RESNET_SPECS[model_id][i][-1]
      for i, spec in enumerate(block_specs)
  ]

  return ResNet3D(
      input_specs=input_specs,
      model_id=model_id,
      # Per block-group settings.
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      # Stem settings.
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      # Block settings.
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      # Normalization and activation settings.
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
|
|
|