|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Provides DeepLab model definition and helper functions. |
|
|
|
|
|
DeepLab is a deep learning system for semantic image segmentation with |
|
|
the following features: |
|
|
|
|
|
(1) Atrous convolution to explicitly control the resolution at which |
|
|
feature responses are computed within Deep Convolutional Neural Networks. |
|
|
|
|
|
(2) Atrous spatial pyramid pooling (ASPP) to robustly segment objects at |
|
|
multiple scales with filters at multiple sampling rates and effective |
|
|
fields-of-views. |
|
|
|
|
|
(3) ASPP module augmented with image-level feature and batch normalization. |
|
|
|
|
|
(4) A simple yet effective decoder module to recover the object boundaries. |
|
|
|
|
|
See the following papers for more details: |
|
|
|
|
|
"Encoder-Decoder with Atrous Separable Convolution for Semantic Image |
|
|
Segmentation" |
|
|
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam. |
|
|
(https://arxiv.org/abs/1802.02611) |
|
|
|
|
|
"Rethinking Atrous Convolution for Semantic Image Segmentation," |
|
|
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam |
|
|
(https://arxiv.org/abs/1706.05587) |
|
|
|
|
|
"DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, |
|
|
Atrous Convolution, and Fully Connected CRFs", |
|
|
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy, |
|
|
Alan L Yuille (* equal contribution) |
|
|
(https://arxiv.org/abs/1606.00915) |
|
|
|
|
|
"Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected |
|
|
CRFs" |
|
|
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy, |
|
|
Alan L. Yuille (* equal contribution) |
|
|
(https://arxiv.org/abs/1412.7062) |
|
|
""" |
|
|
import tensorflow as tf |
|
|
from tensorflow.contrib import slim as contrib_slim |
|
|
from deeplab.core import dense_prediction_cell |
|
|
from deeplab.core import feature_extractor |
|
|
from deeplab.core import utils |
|
|
|
|
|
slim = contrib_slim |
|
|
|
|
|
LOGITS_SCOPE_NAME = 'logits' |
|
|
MERGED_LOGITS_SCOPE = 'merged_logits' |
|
|
IMAGE_POOLING_SCOPE = 'image_pooling' |
|
|
ASPP_SCOPE = 'aspp' |
|
|
CONCAT_PROJECTION_SCOPE = 'concat_projection' |
|
|
DECODER_SCOPE = 'decoder' |
|
|
META_ARCHITECTURE_SCOPE = 'meta_architecture' |
|
|
|
|
|
PROB_SUFFIX = '_prob' |
|
|
|
|
|
_resize_bilinear = utils.resize_bilinear |
|
|
scale_dimension = utils.scale_dimension |
|
|
split_separable_conv2d = utils.split_separable_conv2d |
|
|
|
|
|
|
|
|
def get_extra_layer_scopes(last_layers_contain_logits_only=False):
  """Returns the variable scopes that hold the extra (non-backbone) layers.

  Args:
    last_layers_contain_logits_only: Boolean. When True, only the logits
      scope is considered the last layer (i.e., the ASPP module, decoder
      module and so on are excluded).

  Returns:
    A list of scope name strings for the extra layers.
  """
  if last_layers_contain_logits_only:
    return [LOGITS_SCOPE_NAME]
  return [
      LOGITS_SCOPE_NAME,
      IMAGE_POOLING_SCOPE,
      ASPP_SCOPE,
      CONCAT_PROJECTION_SCOPE,
      DECODER_SCOPE,
      META_ARCHITECTURE_SCOPE,
  ]
|
|
|
|
|
|
|
|
def predict_labels_multi_scale(images,
                               model_options,
                               eval_scales=(1.0,),
                               add_flipped_images=False):
  """Predicts segmentation labels, averaging over scales (and optional flips).

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    eval_scales: The scales to resize images for evaluation.
    add_flipped_images: Add flipped images for evaluation or not.

  Returns:
    A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
  """
  # One list of per-scale probability maps per output type; filled below and
  # reduced to a single prediction at the end.
  outputs_to_predictions = {
      output: []
      for output in model_options.outputs_to_num_classes
  }

  for i, image_scale in enumerate(eval_scales):
    # Reuse model variables for every scale after the first one.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True if i else None):
      outputs_to_scales_to_logits = multi_scale_logits(
          images,
          model_options=model_options,
          image_pyramid=[image_scale],
          is_training=False,
          fine_tune_batch_norm=False)

    if add_flipped_images:
      # Run the (shared) model again on images flipped along the width axis.
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        outputs_to_scales_to_logits_reversed = multi_scale_logits(
            tf.reverse_v2(images, [2]),
            model_options=model_options,
            image_pyramid=[image_scale],
            is_training=False,
            fine_tune_batch_norm=False)

    for output in sorted(outputs_to_scales_to_logits):
      scales_to_logits = outputs_to_scales_to_logits[output]
      # Resize merged logits back to the input image resolution before
      # converting them to probabilities.
      logits = _resize_bilinear(
          scales_to_logits[MERGED_LOGITS_SCOPE],
          tf.shape(images)[1:3],
          scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
      outputs_to_predictions[output].append(
          tf.expand_dims(tf.nn.softmax(logits), 4))

      if add_flipped_images:
        scales_to_logits_reversed = (
            outputs_to_scales_to_logits_reversed[output])
        # Flip the logits back so they align with the unflipped images.
        logits_reversed = _resize_bilinear(
            tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]),
            tf.shape(images)[1:3],
            scales_to_logits_reversed[MERGED_LOGITS_SCOPE].dtype)
        outputs_to_predictions[output].append(
            tf.expand_dims(tf.nn.softmax(logits_reversed), 4))

  for output in sorted(outputs_to_predictions):
    predictions = outputs_to_predictions[output]
    # Compute average prediction across different scales and flipped images.
    predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
    outputs_to_predictions[output] = tf.argmax(predictions, 3)
    # NOTE(review): `predictions` here are already averaged softmax
    # probabilities; applying softmax again re-normalizes them rather than
    # returning the average directly — confirm this is intended.
    outputs_to_predictions[output + PROB_SUFFIX] = tf.nn.softmax(predictions)

  return outputs_to_predictions
|
|
|
|
|
|
|
|
def predict_labels(images, model_options, image_pyramid=None):
  """Predicts segmentation labels.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.

  Returns:
    A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
  """
  outputs_to_scales_to_logits = multi_scale_logits(
      images,
      model_options=model_options,
      image_pyramid=image_pyramid,
      is_training=False,
      fine_tune_batch_norm=False)

  predictions = {}
  for output in sorted(outputs_to_scales_to_logits):
    scales_to_logits = outputs_to_scales_to_logits[output]
    logits = scales_to_logits[MERGED_LOGITS_SCOPE]
    # Two ways to obtain the final predictions:
    # (1) bilinearly upsample the logits to image resolution, then argmax;
    # (2) argmax on the low-resolution logits, then nearest-neighbor upsample
    #     the label map (cheaper, may show blocking artifacts).
    if model_options.prediction_with_upsampled_logits:
      logits = _resize_bilinear(logits,
                                tf.shape(images)[1:3],
                                scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
      predictions[output] = tf.argmax(logits, 3)
      predictions[output + PROB_SUFFIX] = tf.nn.softmax(logits)
    else:
      argmax_results = tf.argmax(logits, 3)
      # resize_nearest_neighbor needs a 4-D tensor, hence expand/squeeze.
      argmax_results = tf.image.resize_nearest_neighbor(
          tf.expand_dims(argmax_results, 3),
          tf.shape(images)[1:3],
          align_corners=True,
          name='resize_prediction')
      predictions[output] = tf.squeeze(argmax_results, 3)
      # Probabilities are still upsampled bilinearly (smooth interpolation).
      predictions[output + PROB_SUFFIX] = tf.image.resize_bilinear(
          tf.nn.softmax(logits),
          tf.shape(images)[1:3],
          align_corners=True,
          name='resize_prob')
  return predictions
|
|
|
|
|
|
|
|
def multi_scale_logits(images,
                       model_options,
                       image_pyramid,
                       weight_decay=0.0001,
                       is_training=False,
                       fine_tune_batch_norm=False,
                       nas_training_hyper_parameters=None):
  """Gets the logits for multi-scale inputs.

  The returned logits are all downsampled (due to max-pooling layers)
  for both training and evaluation.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.
    weight_decay: The weight decay for model variables.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
      semantic prediction) to a dictionary of multi-scale logits names to
      logits. For each output_type, the dictionary has keys which
      correspond to the scales and values which correspond to the logits.
      For example, if `scales` equals [1.0, 1.5], then the keys would
      include 'merged_logits', 'logits_1.00' and 'logits_1.50'.

  Raises:
    ValueError: If model_options doesn't specify crop_size and its
      add_image_level_feature = True, since add_image_level_feature requires
      crop_size information.
  """
  # Setup default values.
  if not image_pyramid:
    image_pyramid = [1.0]
  # When crop_size is unspecified, fall back to the dynamic image shape.
  crop_height = (
      model_options.crop_size[0]
      if model_options.crop_size else tf.shape(images)[1])
  crop_width = (
      model_options.crop_size[1]
      if model_options.crop_size else tf.shape(images)[2])
  if model_options.image_pooling_crop_size:
    image_pooling_crop_height = model_options.image_pooling_crop_size[0]
    image_pooling_crop_width = model_options.image_pooling_crop_size[1]

  # The output size of the logits is governed by the finest stride in use:
  # the decoder's (if any), otherwise the backbone output stride.
  if model_options.decoder_output_stride:
    logits_output_stride = min(model_options.decoder_output_stride)
  else:
    logits_output_stride = model_options.output_stride

  # Size all per-scale logits to the largest pyramid scale (at least 1.0) so
  # they can be merged element-wise below.
  logits_height = scale_dimension(
      crop_height,
      max(1.0, max(image_pyramid)) / logits_output_stride)
  logits_width = scale_dimension(
      crop_width,
      max(1.0, max(image_pyramid)) / logits_output_stride)

  # Compute the logits for each scale in the image pyramid.
  outputs_to_scales_to_logits = {
      k: {}
      for k in model_options.outputs_to_num_classes
  }

  num_channels = images.get_shape().as_list()[-1]

  for image_scale in image_pyramid:
    if image_scale != 1.0:
      scaled_height = scale_dimension(crop_height, image_scale)
      scaled_width = scale_dimension(crop_width, image_scale)
      scaled_crop_size = [scaled_height, scaled_width]
      scaled_images = _resize_bilinear(images, scaled_crop_size, images.dtype)
      if model_options.crop_size:
        scaled_images.set_shape(
            [None, scaled_height, scaled_width, num_channels])
      # Adjust image_pooling_crop_size accordingly.
      scaled_image_pooling_crop_size = None
      if model_options.image_pooling_crop_size:
        scaled_image_pooling_crop_size = [
            scale_dimension(image_pooling_crop_height, image_scale),
            scale_dimension(image_pooling_crop_width, image_scale)]
    else:
      # Scale 1.0: use the inputs and options unchanged.
      scaled_crop_size = model_options.crop_size
      scaled_images = images
      scaled_image_pooling_crop_size = model_options.image_pooling_crop_size

    # model_options is a namedtuple-like object; _replace produces a copy
    # with the scaled crop sizes for this pyramid level.
    updated_options = model_options._replace(
        crop_size=scaled_crop_size,
        image_pooling_crop_size=scaled_image_pooling_crop_size)
    outputs_to_logits = _get_logits(
        scaled_images,
        updated_options,
        weight_decay=weight_decay,
        reuse=tf.AUTO_REUSE,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm,
        nas_training_hyper_parameters=nas_training_hyper_parameters)

    # Resize the logits to have the same dimension before merging.
    for output in sorted(outputs_to_logits):
      outputs_to_logits[output] = _resize_bilinear(
          outputs_to_logits[output], [logits_height, logits_width],
          outputs_to_logits[output].dtype)

    # Return early when there is only one input scale: no merging needed.
    if len(image_pyramid) == 1:
      for output in sorted(model_options.outputs_to_num_classes):
        outputs_to_scales_to_logits[output][
            MERGED_LOGITS_SCOPE] = outputs_to_logits[output]
      return outputs_to_scales_to_logits

    # Save logits to the output map under a per-scale key.
    for output in sorted(model_options.outputs_to_num_classes):
      outputs_to_scales_to_logits[output][
          'logits_%.2f' % image_scale] = outputs_to_logits[output]

  # Merge the logits from all the multi-scale inputs.
  for output in sorted(model_options.outputs_to_num_classes):
    # Stack per-scale logits along a new axis 4, then reduce over it.
    all_logits = [
        tf.expand_dims(logits, axis=4)
        for logits in outputs_to_scales_to_logits[output].values()
    ]
    all_logits = tf.concat(all_logits, 4)
    merge_fn = (
        tf.reduce_max
        if model_options.merge_method == 'max' else tf.reduce_mean)
    outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = merge_fn(
        all_logits, axis=4)

  return outputs_to_scales_to_logits
|
|
|
|
|
|
|
|
def extract_features(images,
                     model_options,
                     weight_decay=0.0001,
                     reuse=None,
                     is_training=False,
                     fine_tune_batch_norm=False,
                     nas_training_hyper_parameters=None):
  """Extracts features by the particular model_variant.

  Runs the backbone feature extractor and then, if ASPP is enabled, either a
  configured dense prediction cell or the DeepLabv3 ASPP module on top.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    concat_logits: A tensor of size [batch, feature_height, feature_width,
      feature_channels], where feature_height/feature_width are determined by
      the images height/width and output_stride.
    end_points: A dictionary from components of the network to the corresponding
      activation.
  """
  features, end_points = feature_extractor.extract_features(
      images,
      output_stride=model_options.output_stride,
      multi_grid=model_options.multi_grid,
      model_variant=model_options.model_variant,
      depth_multiplier=model_options.depth_multiplier,
      divisible_by=model_options.divisible_by,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      preprocessed_images_dtype=model_options.preprocessed_images_dtype,
      fine_tune_batch_norm=fine_tune_batch_norm,
      nas_architecture_options=model_options.nas_architecture_options,
      nas_training_hyper_parameters=nas_training_hyper_parameters,
      use_bounded_activation=model_options.use_bounded_activation)

  if not model_options.aspp_with_batch_norm:
    # No ASPP: return the raw backbone features.
    return features, end_points
  else:
    if model_options.dense_prediction_cell_config is not None:
      # A searched dense prediction cell replaces the hand-designed ASPP.
      tf.logging.info('Using dense prediction cell config.')
      dense_prediction_layer = dense_prediction_cell.DensePredictionCell(
          config=model_options.dense_prediction_cell_config,
          hparams={
              # Atrous rates in the cell config assume output_stride 16;
              # rescale them for the actual output stride.
              'conv_rate_multiplier': 16 // model_options.output_stride,
          })
      concat_logits = dense_prediction_layer.build_cell(
          features,
          output_stride=model_options.output_stride,
          crop_size=model_options.crop_size,
          image_pooling_crop_size=model_options.image_pooling_crop_size,
          weight_decay=weight_decay,
          reuse=reuse,
          is_training=is_training,
          fine_tune_batch_norm=fine_tune_batch_norm)
      return concat_logits, end_points
    else:
      # The following code employs the DeepLabv3 ASPP module: one 1x1 conv,
      # several atrous 3x3 convs, and optionally an image-level feature.
      batch_norm_params = utils.get_batch_norm_params(
          decay=0.9997,
          epsilon=1e-5,
          scale=True,
          is_training=(is_training and fine_tune_batch_norm),
          sync_batch_norm_method=model_options.sync_batch_norm_method)
      batch_norm = utils.get_batch_norm_fn(
          model_options.sync_batch_norm_method)
      # relu6 bounds activations, which quantizes better.
      activation_fn = (
          tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
      with slim.arg_scope(
          [slim.conv2d, slim.separable_conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          padding='SAME',
          stride=1,
          reuse=reuse):
        with slim.arg_scope([batch_norm], **batch_norm_params):
          depth = model_options.aspp_convs_filters
          branch_logits = []

          if model_options.add_image_level_feature:
            if model_options.crop_size is not None:
              image_pooling_crop_size = model_options.image_pooling_crop_size
              # If image_pooling_crop_size is not specified, fall back to
              # crop_size.
              if image_pooling_crop_size is None:
                image_pooling_crop_size = model_options.crop_size
              pool_height = scale_dimension(
                  image_pooling_crop_size[0],
                  1. / model_options.output_stride)
              pool_width = scale_dimension(
                  image_pooling_crop_size[1],
                  1. / model_options.output_stride)
              image_feature = slim.avg_pool2d(
                  features, [pool_height, pool_width],
                  model_options.image_pooling_stride, padding='VALID')
              resize_height = scale_dimension(
                  model_options.crop_size[0],
                  1. / model_options.output_stride)
              resize_width = scale_dimension(
                  model_options.crop_size[1],
                  1. / model_options.output_stride)
            else:
              # If crop_size is None, pool globally over the spatial dims.
              pool_height = tf.shape(features)[1]
              pool_width = tf.shape(features)[2]
              image_feature = tf.reduce_mean(
                  features, axis=[1, 2], keepdims=True)
              resize_height = pool_height
              resize_width = pool_width
            image_feature_activation_fn = tf.nn.relu
            image_feature_normalizer_fn = batch_norm
            if model_options.aspp_with_squeeze_and_excitation:
              # Squeeze-and-excitation gates use sigmoid (or quantized
              # sigmoid) and no normalization.
              image_feature_activation_fn = tf.nn.sigmoid
              if model_options.image_se_uses_qsigmoid:
                image_feature_activation_fn = utils.q_sigmoid
              image_feature_normalizer_fn = None
            image_feature = slim.conv2d(
                image_feature, depth, 1,
                activation_fn=image_feature_activation_fn,
                normalizer_fn=image_feature_normalizer_fn,
                scope=IMAGE_POOLING_SCOPE)
            image_feature = _resize_bilinear(
                image_feature,
                [resize_height, resize_width],
                image_feature.dtype)
            # Static shape info can only be set for non-Tensor dimensions.
            if isinstance(resize_height, tf.Tensor):
              resize_height = None
            if isinstance(resize_width, tf.Tensor):
              resize_width = None
            image_feature.set_shape([None, resize_height, resize_width, depth])
            if not model_options.aspp_with_squeeze_and_excitation:
              branch_logits.append(image_feature)

          # Employ a 1x1 convolution branch.
          branch_logits.append(slim.conv2d(features, depth, 1,
                                           scope=ASPP_SCOPE + str(0)))

          if model_options.atrous_rates:
            # Employ 3x3 convolutions with different atrous rates
            # (branch indices start at 1; 0 is the 1x1 branch above).
            for i, rate in enumerate(model_options.atrous_rates, 1):
              scope = ASPP_SCOPE + str(i)
              if model_options.aspp_with_separable_conv:
                aspp_features = split_separable_conv2d(
                    features,
                    filters=depth,
                    rate=rate,
                    weight_decay=weight_decay,
                    scope=scope)
              else:
                aspp_features = slim.conv2d(
                    features, depth, 3, rate=rate, scope=scope)
              branch_logits.append(aspp_features)

          # Merge branch logits.
          concat_logits = tf.concat(branch_logits, 3)
          if model_options.aspp_with_concat_projection:
            concat_logits = slim.conv2d(
                concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
            concat_logits = slim.dropout(
                concat_logits,
                keep_prob=0.9,
                is_training=is_training,
                scope=CONCAT_PROJECTION_SCOPE + '_dropout')
          if (model_options.add_image_level_feature and
              model_options.aspp_with_squeeze_and_excitation):
            # Gate the merged features with the squeeze-and-excitation map.
            concat_logits *= image_feature

          return concat_logits, end_points
|
|
|
|
|
|
|
|
def _get_logits(images,
                model_options,
                weight_decay=0.0001,
                reuse=None,
                is_training=False,
                fine_tune_batch_norm=False,
                nas_training_hyper_parameters=None):
  """Gets the logits by atrous/image spatial pyramid pooling.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    outputs_to_logits: A map from output_type to logits.
  """
  features, end_points = extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm,
      nas_training_hyper_parameters=nas_training_hyper_parameters)

  if model_options.decoder_output_stride:
    crop_size = model_options.crop_size
    if crop_size is None:
      # Decoder needs a crop size; derive it from the dynamic image shape.
      crop_size = [tf.shape(images)[1], tf.shape(images)[2]]
    # NOTE(review): sync_batch_norm_method is not forwarded here, so the
    # decoder always uses its 'None' default even when extract_features uses
    # model_options.sync_batch_norm_method — confirm this is intended.
    features = refine_by_decoder(
        features,
        end_points,
        crop_size=crop_size,
        decoder_output_stride=model_options.decoder_output_stride,
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        decoder_use_sum_merge=model_options.decoder_use_sum_merge,
        decoder_filters=model_options.decoder_filters,
        decoder_output_is_logits=model_options.decoder_output_is_logits,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm,
        use_bounded_activation=model_options.use_bounded_activation)

  outputs_to_logits = {}
  for output in sorted(model_options.outputs_to_num_classes):
    if model_options.decoder_output_is_logits:
      # The decoder already produced logits; just name them per output.
      outputs_to_logits[output] = tf.identity(features,
                                              name=output)
    else:
      outputs_to_logits[output] = get_branch_logits(
          features,
          model_options.outputs_to_num_classes[output],
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse,
          scope_suffix=output)

  return outputs_to_logits
|
|
|
|
|
|
|
|
def refine_by_decoder(features,
                      end_points,
                      crop_size=None,
                      decoder_output_stride=None,
                      decoder_use_separable_conv=False,
                      decoder_use_sum_merge=False,
                      decoder_filters=256,
                      decoder_output_is_logits=False,
                      model_variant=None,
                      weight_decay=0.0001,
                      reuse=None,
                      is_training=False,
                      fine_tune_batch_norm=False,
                      use_bounded_activation=False,
                      sync_batch_norm_method='None'):
  """Adds the decoder to obtain sharper segmentation results.

  Args:
    features: A tensor of size [batch, features_height, features_width,
      features_channels].
    end_points: A dictionary from components of the network to the corresponding
      activation.
    crop_size: A tuple [crop_height, crop_width] specifying whole patch crop
      size.
    decoder_output_stride: A list of integers specifying the output stride of
      low-level features used in the decoder module.
    decoder_use_separable_conv: Employ separable convolution for decoder or not.
    decoder_use_sum_merge: Boolean, decoder uses simple sum merge or not.
    decoder_filters: Integer, decoder filter size.
    decoder_output_is_logits: Boolean, using decoder output as logits or not.
    model_variant: Model variant for feature extraction.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference.
    sync_batch_norm_method: String, method used to sync batch norm. Currently
      only support `None` (no sync batch norm) and `tpu` (use tpu code to
      sync batch norm).

  Returns:
    Decoder output with size [batch, decoder_height, decoder_width,
      decoder_channels].

  Raises:
    ValueError: If crop_size is None.
  """
  if crop_size is None:
    raise ValueError('crop_size must be provided when using decoder.')
  batch_norm_params = utils.get_batch_norm_params(
      decay=0.9997,
      epsilon=1e-5,
      scale=True,
      is_training=(is_training and fine_tune_batch_norm),
      sync_batch_norm_method=sync_batch_norm_method)
  batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
  decoder_depth = decoder_filters
  projected_filters = 48
  if decoder_use_sum_merge:
    # Sum merge requires both operands to have the same channel count.
    projected_filters = decoder_filters
  if decoder_output_is_logits:
    # The decoder output is used directly as logits: no activation, no
    # normalization, 1x1 kernel, and no separable convolution.
    activation_fn = None
    normalizer_fn = None
    conv2d_kernel = 1
    decoder_use_separable_conv = False
  else:
    # Regular decoder settings.
    activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
    normalizer_fn = batch_norm
    conv2d_kernel = 3
  with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      padding='SAME',
      stride=1,
      reuse=reuse):
    with slim.arg_scope([batch_norm], **batch_norm_params):
      with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]):
        decoder_features = features
        decoder_stage = 0
        scope_suffix = ''
        # One decoder stage per requested output stride, each refining the
        # running `decoder_features` with low-level backbone features.
        for output_stride in decoder_output_stride:
          feature_list = feature_extractor.networks_to_feature_maps[
              model_variant][
                  feature_extractor.DECODER_END_POINTS][output_stride]
          # Keep the first stage's scope names unchanged so existing
          # checkpoints keep loading; later stages get a numeric suffix.
          if decoder_stage:
            scope_suffix = '_{}'.format(decoder_stage)
          for i, name in enumerate(feature_list):
            decoder_features_list = [decoder_features]
            # MobileNet and NAS variants expose end points under the bare
            # name; other backbones are namespaced by their scope.
            if ('mobilenet' in model_variant or
                model_variant.startswith('mnas') or
                model_variant.startswith('nas')):
              feature_name = name
            else:
              feature_name = '{}/{}'.format(
                  feature_extractor.name_scope[model_variant], name)
            # Project low-level features to a smaller channel count before
            # merging.
            decoder_features_list.append(
                slim.conv2d(
                    end_points[feature_name],
                    projected_filters,
                    1,
                    scope='feature_projection' + str(i) + scope_suffix))

            decoder_height = scale_dimension(crop_size[0], 1.0 / output_stride)
            decoder_width = scale_dimension(crop_size[1], 1.0 / output_stride)

            # Resize all features to the stage's spatial size and restore
            # whatever static shape information is available.
            for j, feature in enumerate(decoder_features_list):
              decoder_features_list[j] = _resize_bilinear(
                  feature, [decoder_height, decoder_width], feature.dtype)
              h = (None if isinstance(decoder_height, tf.Tensor)
                   else decoder_height)
              w = (None if isinstance(decoder_width, tf.Tensor)
                   else decoder_width)
              decoder_features_list[j].set_shape([None, h, w, None])
            if decoder_use_sum_merge:
              decoder_features = _decoder_with_sum_merge(
                  decoder_features_list,
                  decoder_depth,
                  conv2d_kernel=conv2d_kernel,
                  decoder_use_separable_conv=decoder_use_separable_conv,
                  weight_decay=weight_decay,
                  scope_suffix=scope_suffix)
            else:
              if not decoder_use_separable_conv:
                # Plain convolutions additionally encode the feature index
                # in the scope name.
                scope_suffix = str(i) + scope_suffix
              decoder_features = _decoder_with_concat_merge(
                  decoder_features_list,
                  decoder_depth,
                  decoder_use_separable_conv=decoder_use_separable_conv,
                  weight_decay=weight_decay,
                  scope_suffix=scope_suffix)
          decoder_stage += 1
        return decoder_features
|
|
|
|
|
|
|
|
def _decoder_with_sum_merge(decoder_features_list,
                            decoder_depth,
                            conv2d_kernel=3,
                            decoder_use_separable_conv=True,
                            weight_decay=0.0001,
                            scope_suffix=''):
  """Merges exactly two decoder feature maps by element-wise addition.

  The first feature map is passed through a single convolution (separable or
  plain, depending on `decoder_use_separable_conv`) and the result is added
  to the second feature map.

  Args:
    decoder_features_list: A list of decoder features.
    decoder_depth: Integer, the filters used in the convolution.
    conv2d_kernel: Integer, the convolution kernel size.
    decoder_use_separable_conv: Boolean, use separable conv or not.
    weight_decay: Weight decay for the model variables.
    scope_suffix: String, used in the scope suffix.

  Returns:
    decoder features merged with sum.

  Raises:
    RuntimeError: If decoder_features_list have length not equal to 2.
  """
  if len(decoder_features_list) != 2:
    raise RuntimeError('Expect decoder_features has length 2.')
  primary, residual = decoder_features_list
  # Only one convolution is applied before the sum.
  if decoder_use_separable_conv:
    convolved = split_separable_conv2d(
        primary,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_split_sep_conv0' + scope_suffix)
  else:
    convolved = slim.conv2d(
        primary,
        decoder_depth,
        conv2d_kernel,
        scope='decoder_conv0' + scope_suffix)
  return convolved + residual
|
|
|
|
|
|
|
|
def _decoder_with_concat_merge(decoder_features_list,
                               decoder_depth,
                               decoder_use_separable_conv=True,
                               weight_decay=0.0001,
                               scope_suffix=''):
  """Merges decoder features by concatenation followed by two convolutions.

  The input feature maps are concatenated along the channel axis, then
  smoothed by two convolutions (separable or plain). This is the decoder
  module proposed in the DeepLabv3+ paper.

  Args:
    decoder_features_list: A list of decoder features.
    decoder_depth: Integer, the filters used in the convolution.
    decoder_use_separable_conv: Boolean, use separable conv or not.
    weight_decay: Weight decay for the model variables.
    scope_suffix: String, used in the scope suffix.

  Returns:
    decoder features merged with concatenation.
  """
  merged = tf.concat(decoder_features_list, 3)
  if decoder_use_separable_conv:
    refined = split_separable_conv2d(
        merged,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_conv0' + scope_suffix)
    refined = split_separable_conv2d(
        refined,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_conv1' + scope_suffix)
    return refined
  # Plain 3x3 convolutions, repeated twice under a shared scope.
  return slim.repeat(
      merged,
      2,
      slim.conv2d,
      decoder_depth,
      3,
      scope='decoder_conv' + scope_suffix)
|
|
|
|
|
|
|
|
def get_branch_logits(features,
                      num_classes,
                      atrous_rates=None,
                      aspp_with_batch_norm=False,
                      kernel_size=1,
                      weight_decay=0.0001,
                      reuse=None,
                      scope_suffix=''):
  """Gets the logits from each model's branch.

  When atrous spatial pyramid pooling is employed, the model branches out in
  the last layer — one convolution per atrous rate — and the branches are
  sum-merged to form the final logits.

  Args:
    features: A float tensor of shape [batch, height, width, channels].
    num_classes: Number of classes to predict.
    atrous_rates: A list of atrous convolution rates for last layer.
    aspp_with_batch_norm: Use batch normalization layers for ASPP.
    kernel_size: Kernel size for convolution.
    weight_decay: Weight decay for the model variables.
    reuse: Reuse model variables or not.
    scope_suffix: Scope suffix for the model variables.

  Returns:
    Merged logits with shape [batch, height, width, num_classes].

  Raises:
    ValueError: Upon invalid input kernel_size value.
  """
  # When ASPP (with batch norm) already captured context, the logits layer
  # collapses to a single 1x1 convolution; any other kernel size is invalid.
  if aspp_with_batch_norm or atrous_rates is None:
    if kernel_size != 1:
      raise ValueError('Kernel size must be 1 when atrous_rates is None or '
                       'using aspp_with_batch_norm. Gets %d.' % kernel_size)
    atrous_rates = [1]

  with slim.arg_scope(
      [slim.conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
      reuse=reuse):
    with tf.variable_scope(LOGITS_SCOPE_NAME, LOGITS_SCOPE_NAME, [features]):
      # One linear (no activation/normalization) conv branch per atrous rate;
      # the first branch keeps the bare suffix for checkpoint compatibility.
      branch_logits = [
          slim.conv2d(
              features,
              num_classes,
              kernel_size=kernel_size,
              rate=rate,
              activation_fn=None,
              normalizer_fn=None,
              scope=scope_suffix + '_%d' % i if i else scope_suffix)
          for i, rate in enumerate(atrous_rates)
      ]
      return tf.add_n(branch_logits)
|
|
|