|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Build simclr models."""
|
| from typing import Optional
|
| from absl import logging
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
# Shorthand for the Keras layers namespace used throughout this module.
layers = tf_keras.layers



# Training modes: in `pretrain` the model runs the SimCLR pretext task on
# paired augmented views; in `finetune` it trains the supervised head.
PRETRAIN = 'pretrain'

FINETUNE = 'finetune'



# Keys of the output dict returned by SimCLRModel.call().
PROJECTION_OUTPUT_KEY = 'projection_outputs'

SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
|
|
|
|
|
class SimCLRModel(tf_keras.Model):
  """A classification model based on the SimCLR framework.

  In `pretrain` mode (while training), the input is expected to carry two
  augmented views concatenated along the channel axis; they are split and
  re-stacked along the batch axis before the backbone forward pass. In all
  other cases the input is a single view. The backbone's deepest endpoint
  is globally average-pooled and fed to the projection head; an optional
  supervised head produces classification outputs (with its gradient
  stopped at the head input during pretraining).
  """

  def __init__(self,
               backbone: tf_keras.models.Model,
               projection_head: tf_keras.layers.Layer,
               supervised_head: Optional[tf_keras.layers.Layer] = None,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               mode: str = PRETRAIN,
               backbone_trainable: bool = True,
               **kwargs):
    """A classification model based on SimCLR framework.

    Args:
      backbone: a backbone network.
      projection_head: a projection head network.
      supervised_head: a head network for supervised learning, e.g.
        classification head.
      input_specs: `tf_keras.layers.InputSpec` specs of the input tensor.
      mode: `str` indicates mode of training to be executed.
      backbone_trainable: `bool` whether the backbone is trainable or not.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Kept verbatim so get_config()/from_config() round-trip the model.
    self._config_dict = {
        'backbone': backbone,
        'projection_head': projection_head,
        'supervised_head': supervised_head,
        'input_specs': input_specs,
        'mode': mode,
        'backbone_trainable': backbone_trainable,
    }
    self._input_specs = input_specs
    self._backbone = backbone
    self._projection_head = projection_head
    self._supervised_head = supervised_head
    self._mode = mode
    self._backbone_trainable = backbone_trainable

    # Stateless pooling layer. Built once here rather than inside `call`
    # so a new layer object is not allocated on every forward pass
    # (Keras subclassing convention: create layers in `__init__`).
    self._global_pool = layers.GlobalAveragePooling2D()

    self._backbone.trainable = backbone_trainable

  def call(self, inputs, training=None, **kwargs):
    """Runs a forward pass.

    Args:
      inputs: image tensor. In `pretrain` mode while training, two
        augmented views concatenated along the channel axis — TODO confirm
        against the input pipeline.
      training: `bool` or None; forwarded to the backbone and heads.
      **kwargs: unused; accepted for Keras call-signature compatibility.

    Returns:
      A dict with keys `PROJECTION_OUTPUT_KEY` (projection-head outputs)
      and `SUPERVISED_OUTPUT_KEY` (supervised-head outputs, or None when
      no supervised head is configured).
    """
    model_outputs = {}

    if training and self._mode == PRETRAIN:
      # Split the two views off the channel axis and stack them along the
      # batch axis so the backbone processes both in a single pass.
      num_transforms = 2
      features_list = tf.split(
          inputs, num_or_size_splits=num_transforms, axis=-1)
      features = tf.concat(features_list, 0)
    else:
      features = inputs

    # Backbone forward pass; the backbone only sees `training=True` when
    # it is actually trainable.
    endpoints = self._backbone(
        features, training=training and self._backbone_trainable)
    # NOTE(review): `max` over endpoint keys is lexicographic for string
    # keys (e.g. '9' > '10') — confirm the backbone's key format keeps
    # this selecting the deepest endpoint.
    features = endpoints[max(endpoints.keys())]
    projection_inputs = self._global_pool(features)

    projection_outputs, supervised_inputs = self._projection_head(
        projection_inputs, training=training)

    if self._supervised_head is not None:
      if self._mode == PRETRAIN:
        logging.info('Ignoring gradient from supervised outputs !')
        # During pretraining the supervised loss must not update the
        # backbone/projection head, so the gradient is stopped at the
        # supervised head's input.
        supervised_outputs = self._supervised_head(
            tf.stop_gradient(supervised_inputs), training=training)
      else:
        supervised_outputs = self._supervised_head(
            supervised_inputs, training=training)
    else:
      supervised_outputs = None

    model_outputs.update({
        PROJECTION_OUTPUT_KEY: projection_outputs,
        SUPERVISED_OUTPUT_KEY: supervised_outputs
    })

    return model_outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    if self._supervised_head is not None:
      items = dict(
          backbone=self.backbone,
          projection_head=self.projection_head,
          supervised_head=self.supervised_head)
    else:
      items = dict(backbone=self.backbone, projection_head=self.projection_head)
    return items

  @property
  def backbone(self):
    """The backbone network."""
    return self._backbone

  @property
  def projection_head(self):
    """The projection head network."""
    return self._projection_head

  @property
  def supervised_head(self):
    """The optional supervised head network (may be None)."""
    return self._supervised_head

  @property
  def mode(self):
    """Current training mode, `PRETRAIN` or `FINETUNE`."""
    return self._mode

  @mode.setter
  def mode(self, value):
    self._mode = value

  @property
  def backbone_trainable(self):
    """Whether the backbone's weights are updated during training."""
    return self._backbone_trainable

  @backbone_trainable.setter
  def backbone_trainable(self, value):
    # Keep the flag and the backbone's actual trainable state in sync.
    self._backbone_trainable = value
    self._backbone.trainable = value

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from a `get_config()` dict."""
    return cls(**config)
|
|
|