| # Copyright 2022 Google LLC | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """TF2 layer for extracting image features for the film_net interpolator. | |
| The feature extractor implemented here converts an image pyramid into a pyramid | |
| of deep features. The feature pyramid serves a similar purpose as U-Net | |
| architecture's encoder, but we use a special cascaded architecture described in | |
| Multi-view Image Fusion [1]. | |
| For comprehensiveness, below is a short description of the idea. While the | |
| description is a bit involved, the cascaded feature pyramid can be used just | |
| like any image feature pyramid. | |
| Why cascaded architecture? | |
| ========================== | |
| To understand the concept it is worth reviewing a traditional feature pyramid | |
| first: *A traditional feature pyramid* as in U-net or in many optical flow | |
| networks is built by alternating between convolutions and pooling, starting | |
| from the input image. | |
| It is well known that early features of such architecture correspond to low | |
| level concepts such as edges in the image whereas later layers extract | |
| semantically higher level concepts such as object classes etc. In other words, | |
| the meaning of the filters in each resolution level is different. For problems | |
| such as semantic segmentation and many others this is a desirable property. | |
| However, the asymmetric features preclude sharing weights across resolution | |
| levels in the feature extractor itself and in any subsequent neural networks | |
| that follow. This can be a downside, since optical flow prediction, for | |
| instance, is symmetric across resolution levels. The cascaded feature | |
| architecture addresses this shortcoming. | |
| How is it built? | |
| ================ | |
| The *cascaded* feature pyramid contains feature vectors that have constant | |
| length and meaning on each resolution level, except for a few of the finest | |
| ones. The advantage of this is that the subsequent optical flow layer can learn | |
| synergistically from many resolutions. This means that coarse level prediction can | |
| benefit from finer resolution training examples, which can be useful with | |
| moderately sized datasets to avoid overfitting. | |
| The cascaded feature pyramid is built by extracting shallower subtree pyramids, | |
| each one of them similar to the traditional architecture. Each subtree | |
| pyramid S_i is extracted starting from each resolution level: | |
| image resolution 0 -> S_0 | |
| image resolution 1 -> S_1 | |
| image resolution 2 -> S_2 | |
| ... | |
| If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid | |
| is constructed by concatenating features as follows (assuming subtree depth=3): | |
| lvl | |
| feat_0 = concat( S_0_0 ) | |
| feat_1 = concat( S_1_0 S_0_1 ) | |
| feat_2 = concat( S_2_0 S_1_1 S_0_2 ) | |
| feat_3 = concat( S_3_0 S_2_1 S_1_2 ) | |
| feat_4 = concat( S_4_0 S_3_1 S_2_2 ) | |
| feat_5 = concat( S_5_0 S_4_1 S_3_2 ) | |
| .... | |
| In the above, all levels except feat_0 and feat_1 have the same number of features | |
| with similar semantic meaning. This enables training a single optical flow | |
| predictor module shared by levels 2,3,4,5... . For more details and evaluation | |
| see [1]. | |
| [1] Multi-view Image Fusion, Trinidad et al. 2019 | |
| """ | |
| from typing import List | |
| from . import options | |
| import tensorflow as tf | |
| def _relu(x: tf.Tensor) -> tf.Tensor: | |
| return tf.nn.leaky_relu(x, alpha=0.2) | |
| def _conv(filters: int, name: str): | |
| return tf.keras.layers.Conv2D( | |
| name=name, | |
| filters=filters, | |
| kernel_size=3, | |
| padding='same', | |
| activation=_relu) | |
| class SubTreeExtractor(tf.keras.layers.Layer): | |
| """Extracts a hierarchical set of features from an image. | |
| This is a conventional, hierarchical image feature extractor that extracts | |
| [k, k*2, k*4... ] filters for the image pyramid where k=options.filters. | |
| Each level except the last is followed by average pooling. | |
| Attributes: | |
| name: Name for the layer | |
| config: Options for the fusion_net frame interpolator | |
| """ | |
| def __init__(self, name: str, config: options.Options): | |
| super().__init__(name=name) | |
| k = config.filters | |
| n = config.sub_levels | |
| self.convs = [] | |
| for i in range(n): | |
| self.convs.append( | |
| _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i))) | |
| self.convs.append( | |
| _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i + 1))) | |
| def call(self, image: tf.Tensor, n: int) -> List[tf.Tensor]: | |
| """Extracts a pyramid of features from the image. | |
| Args: | |
| image: tf.Tensor with shape BATCH_SIZE x HEIGHT x WIDTH x CHANNELS. | |
| n: number of pyramid levels to extract. This can be less than or equal to | |
| options.sub_levels given in the __init__. | |
| Returns: | |
| The pyramid of features, starting from the finest level. Each element | |
| contains the output after the last convolution on the corresponding | |
| pyramid level. | |
| """ | |
| head = image | |
| pool = tf.keras.layers.AveragePooling2D( | |
| pool_size=2, strides=2, padding='valid') | |
| pyramid = [] | |
| for i in range(n): | |
| head = self.convs[2*i](head) | |
| head = self.convs[2*i+1](head) | |
| pyramid.append(head) | |
| if i < n-1: | |
| head = pool(head) | |
| return pyramid | |
| class FeatureExtractor(tf.keras.layers.Layer): | |
| """Extracts features from an image pyramid using a cascaded architecture. | |
| Attributes: | |
| name: Name of the layer | |
| config: Options for the fusion_net frame interpolator | |
| """ | |
| def __init__(self, name: str, config: options.Options): | |
| super().__init__(name=name) | |
| self.extract_sublevels = SubTreeExtractor('sub_extractor', config) | |
| self.options = config | |
| def call(self, image_pyramid: List[tf.Tensor]) -> List[tf.Tensor]: | |
| """Extracts a cascaded feature pyramid. | |
| Args: | |
| image_pyramid: Image pyramid as a list, starting from the finest level. | |
| Returns: | |
| A pyramid of cascaded features. | |
| """ | |
| sub_pyramids = [] | |
| for i in range(len(image_pyramid)): | |
| # At each level of the image pyramid, creates a sub_pyramid of features | |
| # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor. | |
| # We use the same instance since we want to share the weights. | |
| # | |
| # However, we cap the depth of the sub_pyramid so we don't create features | |
| # that are beyond the coarsest level of the cascaded feature pyramid we | |
| # want to generate. | |
| capped_sub_levels = min(len(image_pyramid) - i, self.options.sub_levels) | |
| sub_pyramids.append( | |
| self.extract_sublevels(image_pyramid[i], capped_sub_levels)) | |
| # Below we generate the cascades of features on each level of the feature | |
| # pyramid. Assuming sub_levels=3, the layout of the features will be | |
| # as shown in the example in the file documentation above. | |
| feature_pyramid = [] | |
| for i in range(len(image_pyramid)): | |
| features = sub_pyramids[i][0] | |
| for j in range(1, self.options.sub_levels): | |
| if j <= i: | |
| features = tf.concat([features, sub_pyramids[i - j][j]], axis=-1) | |
| feature_pyramid.append(features) | |
| return feature_pyramid | |
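
The concatenation layout described in the module docstring can be checked without TensorFlow. The sketch below is not part of the original file: `cascade_layout` and `cascade_channels` are hypothetical helper names, and the values `num_levels=6`, `sub_levels=3` and `filters=64` are assumptions chosen only for illustration. It mirrors the index logic of `FeatureExtractor.call` (level `i` takes `S_i_0` and then `S_(i-j)_j` for every `j <= i`) and the `k << j` filter counts used in `SubTreeExtractor`.

```python
from typing import List, Tuple


def cascade_layout(num_levels: int,
                   sub_levels: int) -> List[List[Tuple[int, int]]]:
  """Lists, per output level, the (subtree i, subtree level j) pairs used."""
  layout = []
  for level in range(num_levels):
    # S_level_0 is always present; S_(level-j)_j is added while j <= level,
    # which is the same condition as in FeatureExtractor.call.
    layout.append([(level - j, j) for j in range(sub_levels) if j <= level])
  return layout


def cascade_channels(num_levels: int, sub_levels: int,
                     filters: int) -> List[int]:
  """Channel count per output level, given filters << j at subtree level j."""
  return [
      sum(filters << j for _, j in entries)
      for entries in cascade_layout(num_levels, sub_levels)
  ]


if __name__ == '__main__':
  # Assumed example sizes: 6 pyramid levels, subtree depth 3, 64 base filters.
  for level, entries in enumerate(cascade_layout(6, 3)):
    names = ' '.join('S_{}_{}'.format(i, j) for i, j in entries)
    print('feat_{} = concat( {} )'.format(level, names))
  print(cascade_channels(6, 3, 64))
```

Under these assumed sizes, every level from feat_2 onward has the same width, which is the property that lets a single downstream predictor be shared across those levels.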