|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Provides flags that are common to scripts.
|
|
|
| Common flags from train/eval/vis/export_model.py are collected in this script.
|
| """
|
| import collections
|
| import copy
|
| import json
|
| import tensorflow as tf
|
|
|
# Short alias for the TensorFlow flags module; every flag below is registered
# through it.
flags = tf.app.flags

# Flags controlling how input images are resized.
flags.DEFINE_integer(
    'min_resize_value', None, 'Desired size of the smaller image side.')

flags.DEFINE_integer(
    'max_resize_value', None,
    'Maximum allowed size of the larger image side.')

flags.DEFINE_integer(
    'resize_factor', None,
    'Resized dimensions are multiple of factor plus one.')

flags.DEFINE_boolean(
    'keep_aspect_ratio', True, 'Keep aspect ratio after resizing or not.')

# Kernel size of the convolution that produces the logits.
flags.DEFINE_integer(
    'logits_kernel_size', 1,
    'The kernel size for the convolutional kernel that generates logits.')
|
|
|
|
|
|
|
|
|
|
|
# Backbone selection and ASPP (image pooling) related flags.
flags.DEFINE_string('model_variant', 'mobilenet_v2', 'DeepLab model variant.')

flags.DEFINE_multi_float('image_pyramid', None,
                         'Input scales for multi-scale feature extraction.')

flags.DEFINE_boolean('add_image_level_feature', True,
                     'Add image level feature.')

# NOTE: the help text is built from adjacent string literals, so each
# fragment that continues a sentence must end with a space; otherwise words
# are fused together in the --help output.
flags.DEFINE_list(
    'image_pooling_crop_size', None,
    'Image pooling crop size [height, width] used in the ASPP module. When '
    'value is None, the model performs image pooling with "crop_size". This '
    'flag is useful when one likes to use different image pooling sizes.')

flags.DEFINE_list(
    'image_pooling_stride', '1,1',
    'Image pooling stride [height, width] used in the ASPP image pooling. ')

flags.DEFINE_boolean('aspp_with_batch_norm', True,
                     'Use batch norm parameters for ASPP or not.')

flags.DEFINE_boolean('aspp_with_separable_conv', True,
                     'Use separable convolution for ASPP or not.')
|
|
|
|
|
|
|
# Backbone-specific flags; the help texts name ResNet (multi_grid) and
# MobileNet (depth_multiplier, divisible_by) as the affected backbones.
flags.DEFINE_multi_integer(
    'multi_grid', None, 'Employ a hierarchy of atrous rates for ResNet.')

flags.DEFINE_float(
    'depth_multiplier', 1.0,
    'Multiplier for the depth (number of channels) for all convolution ops '
    'used in MobileNet.')

flags.DEFINE_integer(
    'divisible_by', None,
    'An integer that ensures the layer # channels are divisible by this '
    'value. Used in MobileNet.')
|
|
|
|
|
|
|
# Values arrive as a list of strings (DEFINE_list); ModelOptions converts them
# to ints and requires them to be in descending order.
# NOTE: the help text is built from adjacent string literals, so sentence
# boundaries need an explicit separating space.
flags.DEFINE_list(
    'decoder_output_stride', None,
    'Comma-separated list of strings with the number specifying output '
    'stride of low-level features at each network level. Current semantic '
    'segmentation implementation assumes at most one output stride (i.e., '
    'either None or a list with only one element).')
|
|
|
# Decoder and prediction flags.
flags.DEFINE_boolean(
    'decoder_use_separable_conv', True,
    'Employ separable convolution for decoder or not.')

flags.DEFINE_enum(
    'merge_method', 'max', ['max', 'avg'],
    'Scheme to merge multi scale features.')

flags.DEFINE_boolean(
    'prediction_with_upsampled_logits', True,
    'When performing prediction, there are two options: (1) bilinear '
    'upsampling the logits followed by softmax, or (2) softmax followed by '
    'bilinear upsampling.')

# An empty path (the default) means no dense prediction cell config is
# loaded by ModelOptions.
flags.DEFINE_string(
    'dense_prediction_cell_json', '',
    'A JSON file that specifies the dense prediction cell.')

# Flags for NAS model variants; ModelOptions bundles these three into its
# nas_architecture_options dictionary.
flags.DEFINE_integer(
    'nas_stem_output_num_conv_filters', 20,
    'Number of filters of the stem output tensor in NAS models.')

flags.DEFINE_bool(
    'nas_use_classification_head', False,
    'Use image classification head for NAS model variants.')

flags.DEFINE_bool(
    'nas_remove_os32_stride', False,
    'Remove the stride in the output stride 32 branch.')

flags.DEFINE_bool(
    'use_bounded_activation', False,
    'Whether or not to use bounded activations. Bounded activations better '
    'lend themselves to quantized inference.')
|
|
|
# Additional ASPP options.
flags.DEFINE_boolean(
    'aspp_with_concat_projection', True, 'ASPP with concat projection.')

flags.DEFINE_boolean(
    'aspp_with_squeeze_and_excitation', False,
    'ASPP with squeeze and excitation.')

flags.DEFINE_integer('aspp_convs_filters', 256, 'ASPP convolution filters.')

# Additional decoder options.
flags.DEFINE_boolean(
    'decoder_use_sum_merge', False, 'Decoder uses simply sum merge.')

flags.DEFINE_integer('decoder_filters', 256, 'Decoder filters.')

flags.DEFINE_boolean(
    'decoder_output_is_logits', False,
    'Use decoder output as logits or not.')

flags.DEFINE_boolean('image_se_uses_qsigmoid', False, 'Use q-sigmoid.')

# When left unset (None), ModelOptions substitutes the scalar weight 1.0.
flags.DEFINE_multi_float(
    'label_weights', None,
    'A list of label weights, each element represents the weight for the label '
    'of its index, for example, label_weights = [0.1, 0.5] means the weight '
    'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
    'the labels have the same weight 1.0.')

flags.DEFINE_float('batch_norm_decay', 0.9997, 'Batchnorm decay.')

# Global flag values object; ModelOptions reads its FLAGS-derived fields
# from here.
FLAGS = flags.FLAGS
|
|
|
|
|
|
|
|
|
# Name of the model output type. The ModelOptions docstring uses the same
# key in its example: outputs_to_num_classes['semantic'].
OUTPUT_TYPE = 'semantic'

# String identifiers for per-sample items -- presumably used as dictionary
# keys by the data pipeline; their consumers are not visible in this file.
LABELS_CLASS = 'labels_class'
IMAGE = 'image'
HEIGHT = 'height'
WIDTH = 'width'
IMAGE_NAME = 'image_name'
LABEL = 'label'
ORIGINAL_IMAGE = 'original_image'

# Name of the test split -- presumably compared against dataset split names
# elsewhere; verify against callers.
TEST_SET = 'test'
|
|
|
|
|
class ModelOptions(
        collections.namedtuple('ModelOptions', [
            'outputs_to_num_classes',
            'crop_size',
            'atrous_rates',
            'output_stride',
            'preprocessed_images_dtype',
            'merge_method',
            'add_image_level_feature',
            'image_pooling_crop_size',
            'image_pooling_stride',
            'aspp_with_batch_norm',
            'aspp_with_separable_conv',
            'multi_grid',
            'decoder_output_stride',
            'decoder_use_separable_conv',
            'logits_kernel_size',
            'model_variant',
            'depth_multiplier',
            'divisible_by',
            'prediction_with_upsampled_logits',
            'dense_prediction_cell_config',
            'nas_architecture_options',
            'use_bounded_activation',
            'aspp_with_concat_projection',
            'aspp_with_squeeze_and_excitation',
            'aspp_convs_filters',
            'decoder_use_sum_merge',
            'decoder_filters',
            'decoder_output_is_logits',
            'image_se_uses_qsigmoid',
            'label_weights',
            'sync_batch_norm_method',
            'batch_norm_decay',
        ])):
    """Immutable class to hold model options.

    The first five fields are supplied by the caller of the constructor. All
    remaining fields are populated from the command-line FLAGS when an
    instance is constructed; they can afterwards only be changed by creating
    a modified copy with the namedtuple `_replace()` method.
    """

    # No per-instance __dict__: the namedtuple fields are the only state.
    __slots__ = ()

    def __new__(cls,
                outputs_to_num_classes,
                crop_size=None,
                atrous_rates=None,
                output_stride=8,
                preprocessed_images_dtype=tf.float32):
        """Constructor to set default values.

        Args:
          outputs_to_num_classes: A dictionary from output type to the number
            of classes. For example, for the task of semantic segmentation
            with 21 semantic classes, we would have
            outputs_to_num_classes['semantic'] = 21.
          crop_size: A tuple [crop_height, crop_width].
          atrous_rates: A list of atrous convolution rates for ASPP.
          output_stride: The ratio of input to output spatial resolution.
          preprocessed_images_dtype: The type after the preprocessing
            function.

        Returns:
          A new ModelOptions instance.

        Raises:
          ValueError: If FLAGS.decoder_output_stride is not sorted in
            descending order.
        """
        # Load the dense prediction cell config only when a path was given.
        dense_prediction_cell_config = None
        if FLAGS.dense_prediction_cell_json:
            with tf.gfile.Open(FLAGS.dense_prediction_cell_json, 'r') as f:
                dense_prediction_cell_config = json.load(f)

        # DEFINE_list flags hold strings; convert them to ints here.
        decoder_output_stride = None
        if FLAGS.decoder_output_stride:
            decoder_output_stride = [
                int(x) for x in FLAGS.decoder_output_stride]
            if (sorted(decoder_output_stride, reverse=True) !=
                    decoder_output_stride):
                raise ValueError('Decoder output stride need to be sorted in '
                                 'the descending order.')

        image_pooling_crop_size = None
        if FLAGS.image_pooling_crop_size:
            image_pooling_crop_size = [
                int(x) for x in FLAGS.image_pooling_crop_size]

        # Fall back to a stride of [1, 1] when the flag is empty.
        image_pooling_stride = [1, 1]
        if FLAGS.image_pooling_stride:
            image_pooling_stride = [
                int(x) for x in FLAGS.image_pooling_stride]

        # An unset flag is replaced by the scalar weight 1.0.
        label_weights = FLAGS.label_weights
        if label_weights is None:
            label_weights = 1.0

        nas_architecture_options = {
            'nas_stem_output_num_conv_filters': (
                FLAGS.nas_stem_output_num_conv_filters),
            'nas_use_classification_head': FLAGS.nas_use_classification_head,
            'nas_remove_os32_stride': FLAGS.nas_remove_os32_stride,
        }

        # Keyword arguments keep the value-to-field mapping explicit.
        return super(ModelOptions, cls).__new__(
            cls,
            outputs_to_num_classes=outputs_to_num_classes,
            crop_size=crop_size,
            atrous_rates=atrous_rates,
            output_stride=output_stride,
            preprocessed_images_dtype=preprocessed_images_dtype,
            merge_method=FLAGS.merge_method,
            add_image_level_feature=FLAGS.add_image_level_feature,
            image_pooling_crop_size=image_pooling_crop_size,
            image_pooling_stride=image_pooling_stride,
            aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
            aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
            multi_grid=FLAGS.multi_grid,
            decoder_output_stride=decoder_output_stride,
            decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
            logits_kernel_size=FLAGS.logits_kernel_size,
            model_variant=FLAGS.model_variant,
            depth_multiplier=FLAGS.depth_multiplier,
            divisible_by=FLAGS.divisible_by,
            prediction_with_upsampled_logits=(
                FLAGS.prediction_with_upsampled_logits),
            dense_prediction_cell_config=dense_prediction_cell_config,
            nas_architecture_options=nas_architecture_options,
            use_bounded_activation=FLAGS.use_bounded_activation,
            aspp_with_concat_projection=FLAGS.aspp_with_concat_projection,
            aspp_with_squeeze_and_excitation=(
                FLAGS.aspp_with_squeeze_and_excitation),
            aspp_convs_filters=FLAGS.aspp_convs_filters,
            decoder_use_sum_merge=FLAGS.decoder_use_sum_merge,
            decoder_filters=FLAGS.decoder_filters,
            decoder_output_is_logits=FLAGS.decoder_output_is_logits,
            image_se_uses_qsigmoid=FLAGS.image_se_uses_qsigmoid,
            label_weights=label_weights,
            # This is the string 'None', not the None object; the value is
            # hard-coded here rather than read from a flag.
            sync_batch_norm_method='None',
            batch_norm_decay=FLAGS.batch_norm_decay)

    def __deepcopy__(self, memo):
        """Returns a copy that keeps every field value of this instance.

        `outputs_to_num_classes` is deep-copied so the copy's dictionary can
        be modified independently. All other fields are carried over by
        reference exactly as stored on this instance.

        The copy is built with `_replace()` rather than by calling the
        constructor: the constructor re-reads every FLAGS-derived field
        (discarding values previously overridden through `_replace()`) and
        reopens the dense prediction cell JSON file.

        Args:
          memo: The memo dictionary supplied by `copy.deepcopy`; forwarded to
            the nested deepcopy call.

        Returns:
          A new ModelOptions instance.
        """
        return self._replace(
            outputs_to_num_classes=copy.deepcopy(
                self.outputs_to_num_classes, memo))
|
|
|