|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Multi-task image multi-taskSimCLR model definition."""
|
| from typing import Dict, Text
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling.multitask import base_model
|
| from official.projects.simclr.configs import multitask_config as simclr_multitask_config
|
| from official.projects.simclr.heads import simclr_head
|
| from official.projects.simclr.modeling import simclr_model
|
| from official.vision.modeling import backbones
|
|
|
# Output-name constants for the projection head and the supervised head.
# NOTE(review): neither is referenced in this module; presumably they are used
# as output-dict keys by the SimCLR task/model code — confirm with the callers.
PROJECTION_OUTPUT_KEY, SUPERVISED_OUTPUT_KEY = (
    'projection_outputs',
    'supervised_outputs',
)
|
|
|
|
|
class SimCLRMTModel(base_model.MultiTaskBaseModel):
  """A multi-task SimCLR model that does both pretrain and finetune.

  A single backbone and a single projection head are built once in
  `__init__`. They are shared by every per-task `simclr_model.SimCLRModel`
  created in `_instantiate_sub_tasks`. Each task may additionally own its
  own supervised classification head.
  """

  def __init__(self, config: simclr_multitask_config.SimCLRMTModelConfig,
               **kwargs):
    """Builds the shared backbone and projection head.

    Args:
      config: The multi-task SimCLR model config. This method reads
        `input_size`, `l2_weight_decay`, `backbone`, `norm_activation` and
        `projection_head`. The other methods read `heads`,
        `backbone_trainable`, `init_checkpoint` and `init_checkpoint_modules`.
      **kwargs: Forwarded unchanged to `base_model.MultiTaskBaseModel`.
    """
    # NOTE(review): all attributes used by `_instantiate_sub_tasks` are set
    # before `super().__init__()`. Presumably the base class builds the
    # sub-tasks from its constructor, so keep this ordering — confirm.
    self._config = config

    # `list(...)` lets `input_size` be any sequence (e.g. a tuple), not only
    # a list; the leading None is the batch dimension.
    self._input_specs = tf_keras.layers.InputSpec(
        shape=[None] + list(config.input_size))

    # tf_keras' L2 penalty is `l2 * sum(w ** 2)`. Halving the factor yields
    # the conventional `0.5 * wd * sum(w ** 2)` weight-decay term. A falsy
    # decay (0 / None) disables regularization entirely.
    l2_weight_decay = config.l2_weight_decay
    self._l2_regularizer = (
        tf_keras.regularizers.l2(l2_weight_decay / 2.0)
        if l2_weight_decay else None)

    self._backbone = backbones.factory.build_backbone(
        input_specs=self._input_specs,
        backbone_config=config.backbone,
        norm_activation_config=config.norm_activation,
        l2_regularizer=self._l2_regularizer)

    norm_activation_config = config.norm_activation
    projection_head_config = config.projection_head
    self._projection_head = simclr_head.ProjectionHead(
        proj_output_dim=projection_head_config.proj_output_dim,
        num_proj_layers=projection_head_config.num_proj_layers,
        ft_proj_idx=projection_head_config.ft_proj_idx,
        kernel_regularizer=self._l2_regularizer,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon)

    super().__init__(**kwargs)

  def _build_supervised_head(self, supervised_head_config):
    """Returns a `ClassificationHead` for the config, or None when unset."""
    if not supervised_head_config:
      return None
    # Zero initialization is opt-in via config; otherwise random-uniform.
    kernel_initializer = (
        'zeros' if supervised_head_config.zero_init else 'random_uniform')
    return simclr_head.ClassificationHead(
        num_classes=supervised_head_config.num_classes,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=self._l2_regularizer)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf_keras.Model]:
    """Creates one `SimCLRModel` per entry in `config.heads`.

    Every sub-task model reuses the backbone and projection head built in
    `__init__`. Only the optional supervised head is created per task.

    Returns:
      A dict mapping each head config's `task_name` to its `SimCLRModel`.
    """
    tasks = {}
    for model_config in self._config.heads:
      tasks[model_config.task_name] = simclr_model.SimCLRModel(
          input_specs=self._input_specs,
          backbone=self._backbone,
          projection_head=self._projection_head,
          supervised_head=self._build_supervised_head(
              model_config.supervised_head),
          mode=model_config.mode,
          backbone_trainable=self._config.backbone_trainable)
    return tasks

  def initialize(self):
    """Loads the multi-task SimCLR model with a pretrained checkpoint.

    Does nothing when `config.init_checkpoint` is empty or None, or when it
    names a directory that contains no checkpoint. Otherwise restores the
    modules selected by `config.init_checkpoint_modules`.

    Raises:
      ValueError: If `init_checkpoint_modules` is neither 'backbone' nor
        'backbone_projection'. This is only checked once a checkpoint path
        has been resolved.
    """
    ckpt_dir_or_file = self._config.init_checkpoint
    # Bail out before touching the filesystem when no checkpoint is
    # configured. This also avoids handing None to `tf.io.gfile.isdir`.
    if not ckpt_dir_or_file:
      return
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      # `latest_checkpoint` returns None if the directory has no checkpoint.
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
      if not ckpt_dir_or_file:
        return

    init_modules = self._config.init_checkpoint_modules
    logging.info('Loading pretrained %s', init_modules)
    if init_modules == 'backbone':
      pretrained_items = dict(backbone=self._backbone)
    elif init_modules == 'backbone_projection':
      pretrained_items = dict(
          backbone=self._backbone, projection_head=self._projection_head)
    else:
      raise ValueError(
          "Only 'backbone_projection' or 'backbone' can be used to "
          'initialize the model.')

    ckpt = tf.train.Checkpoint(**pretrained_items)
    status = ckpt.read(ckpt_dir_or_file)
    # `expect_partial` silences warnings about checkpoint values that are not
    # restored here. `assert_existing_objects_matched` still fails if an
    # existing object of the selected modules has no entry in the checkpoint.
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed.

    These are the shared backbone and projection head, under the same names
    (`backbone`, `projection_head`) that `initialize` restores them from.
    """
    return dict(backbone=self._backbone, projection_head=self._projection_head)
|
|
|