|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Build segmentation models."""
|
| from typing import Any, Mapping, Union, Optional, Dict
|
|
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
# Module-level alias for the Keras layers namespace.
# NOTE(review): not referenced by any code visible in this file; it may be
# imported by other modules — confirm before removing.
layers = tf_keras.layers
|
|
|
|
|
@tf_keras.utils.register_keras_serializable(package='Vision')
class SegmentationModel(tf_keras.Model):
  """A Segmentation class model.

  Input images are passed through the backbone first. The decoder network is
  then applied, and finally the segmentation head is applied on the output of
  the decoder network. Layers such as ASPP should be part of the decoder. Any
  feature fusion is done as part of the segmentation head (i.e. DeepLabV3+
  feature fusion is not part of the decoder; instead it is part of the
  segmentation head). This way, different feature fusion techniques can be
  combined with different backbones and decoders.
  """

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: Optional[tf_keras.Model],
               head: tf_keras.layers.Layer,
               mask_scoring_head: Optional[tf_keras.layers.Layer] = None,
               **kwargs):
    """Segmentation initialization function.

    Args:
      backbone: a backbone network.
      decoder: a decoder network, e.g. FPN. May be None, in which case the
        backbone features are passed to the head in place of decoder
        features.
      head: segmentation head. It is called on the tuple
        `(backbone_features, decoder_features)`.
      mask_scoring_head: optional mask scoring head. When set, it is called
        on the logits produced by `head`.
      **kwargs: keyword arguments forwarded to `tf_keras.Model.__init__`.
    """
    super().__init__(**kwargs)
    # Constructor arguments, kept so that `from_config` can rebuild the model.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'mask_scoring_head': mask_scoring_head,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.head = head
    self.mask_scoring_head = mask_scoring_head

  def call(self, inputs: tf.Tensor, training: Optional[bool] = None
           ) -> Dict[str, tf.Tensor]:
    """Runs backbone, optional decoder, head and optional mask scoring head.

    Args:
      inputs: input images fed to the backbone.
      training: Keras training flag. It is not forwarded explicitly to the
        sub-networks in this method. NOTE(review): this relies on Keras
        propagating `training` to nested layers via the call context —
        confirm if `call` is ever invoked directly instead of via `__call__`.

    Returns:
      A dict with key 'logits' holding the head output and, when a mask
      scoring head is configured, key 'mask_scores' holding its output.
    """
    backbone_features = self.backbone(inputs)

    # Without a decoder, the backbone features double as decoder features.
    if self.decoder is not None:
      decoder_features = self.decoder(backbone_features)
    else:
      decoder_features = backbone_features

    # The head sees both feature sets so it can perform feature fusion.
    logits = self.head((backbone_features, decoder_features))
    outputs = {'logits': logits}
    if self.mask_scoring_head is not None:
      outputs['mask_scores'] = self.mask_scoring_head(logits)
    return outputs

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed.

    Always contains 'backbone' and 'head'; 'decoder' and 'mask_scoring_head'
    are included only when they are not None.
    """
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self.mask_scoring_head is not None:
      items.update(mask_scoring_head=self.mask_scoring_head)
    return items

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments used to build this model.

    The values are the sub-network objects themselves, not their serialized
    configs. A shallow copy is returned so callers cannot mutate the model's
    stored configuration through the result.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from a dict produced by `get_config`.

    Args:
      config: mapping of constructor argument names to values.
      custom_objects: accepted for Keras API compatibility; unused here.

    Returns:
      A new instance of `cls`.
    """
    return cls(**config)
|
|
|