|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Cell structure used by NAS."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import functools
|
| from six.moves import range
|
| from six.moves import zip
|
| import tensorflow as tf
|
| from tensorflow.contrib import framework as contrib_framework
|
| from tensorflow.contrib import slim as contrib_slim
|
| from deeplab.core import xception as xception_utils
|
| from deeplab.core.utils import resize_bilinear
|
| from deeplab.core.utils import scale_dimension
|
| from tensorflow.contrib.slim.nets import resnet_utils
|
|
|
# Short module-level aliases for the TF1 contrib modules used in this file.
slim = contrib_slim
arg_scope = contrib_framework.arg_scope

# The project's `xception_utils.separable_conv2d_same` with
# `regularize_depthwise=True` pre-bound, so every separable convolution built
# in this file passes that flag.
separable_conv2d_same = functools.partial(
    xception_utils.separable_conv2d_same,
    regularize_depthwise=True,
)
|
|
|
|
|
class NASBaseCell(object):
  """NASNet Cell class that is used as a 'layer' in image architectures.

  Each call builds one cell in three stages:

  1. The two cell inputs (`net` and `prev_layer`) are projected to a common
     number of filters. They become hidden states 0 and 1.
  2. `len(operations) // 2` combination steps run. Each applies one operation
     to each of two selected hidden states and adds the results. The sum is
     appended as a new hidden state.
  3. Every hidden state not flagged in `used_hiddenstates` is concatenated
     along axis 3 to form the cell output.

  For more details about NAS cell, see
  https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
  """

  def __init__(self, num_conv_filters, operations, used_hiddenstates,
               hiddenstate_indices, drop_path_keep_prob, total_num_cells,
               total_training_steps, batch_norm_fn=slim.batch_norm):
    """Init function.

    Args:
      num_conv_filters: The number of filters for each convolution operation,
        before scaling by `filter_scaling` in `__call__`.
      operations: List of operation names performed in the NASNet cell, in
        order. They are consumed in pairs: operations `2*i` and `2*i + 1` are
        the left and right branches of combination step `i`.
      used_hiddenstates: Binary array that signals if the hiddenstate was used
        within the cell. Hidden states whose flag is falsy are concatenated
        together to form the cell output.
      hiddenstate_indices: One index per operation, selecting which hidden
        state feeds that operation. Indices 0 and 1 refer to the two
        (projected) cell inputs.
      drop_path_keep_prob: Float, drop path keep probability.
      total_num_cells: Integer, total number of cells.
      total_training_steps: Integer, total training steps.
      batch_norm_fn: Function, batch norm function. Defaults to
        slim.batch_norm.

    Raises:
      ValueError: If `hiddenstate_indices` and `operations` differ in length,
        or if the number of operations is odd.
    """
    if len(hiddenstate_indices) != len(operations):
      raise ValueError(
          'Number of hiddenstate_indices and operations should be the same.')
    if len(operations) % 2:
      raise ValueError('Number of operations should be even.')
    self._num_conv_filters = num_conv_filters
    self._operations = operations
    self._used_hiddenstates = used_hiddenstates
    self._hiddenstate_indices = hiddenstate_indices
    self._drop_path_keep_prob = drop_path_keep_prob
    self._total_num_cells = total_num_cells
    self._total_training_steps = total_training_steps
    self._batch_norm_fn = batch_norm_fn

  def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num):
    """Runs the conv cell.

    Args:
      net: Input tensor of the cell. Channels are read from axis 3 throughout
        this class (presumably NHWC -- confirm with callers).
      scope: Variable scope name for this cell.
      filter_scaling: Multiplier on `num_conv_filters`. The truncated product
        is the number of filters produced by every operation in the cell.
      stride: Stride for operations whose input is one of the two original
        cell inputs. Operations on intermediate hidden states use stride 1.
      prev_layer: Output of the previous cell, or None to reuse `net`.
      cell_num: Index of this cell, used to scale the drop path keep prob.

    Returns:
      The concatenation (axis 3) of the hidden states not marked as used.
    """
    # Per-call state read by the helper methods below.
    self._cell_num = cell_num
    self._filter_scaling = filter_scaling
    self._filter_size = int(self._num_conv_filters * filter_scaling)

    with tf.variable_scope(scope):
      net = self._cell_base(net, prev_layer)
      for i in range(len(self._operations) // 2):
        with tf.variable_scope('comb_iter_{}'.format(i)):
          left_index = self._hiddenstate_indices[i * 2]
          right_index = self._hiddenstate_indices[i * 2 + 1]
          h1 = net[left_index]
          h2 = net[right_index]
          with tf.variable_scope('left'):
            h1 = self._apply_conv_operation(
                h1, self._operations[i * 2], stride, left_index < 2)
          with tf.variable_scope('right'):
            h2 = self._apply_conv_operation(
                h2, self._operations[i * 2 + 1], stride, right_index < 2)
          with tf.variable_scope('combine'):
            h = h1 + h2
          # The combined state becomes available to later steps.
          net.append(h)

      with tf.variable_scope('cell_output'):
        net = self._combine_unused_states(net)

      return net

  def _cell_base(self, net, prev_layer):
    """Runs the beginning of the conv cell before the chosen ops are run.

    Brings both inputs to `self._filter_size` channels. `prev_layer` is first
    bilinearly resized to `net`'s spatial size if their axis-2 sizes differ.
    It is then passed through relu -> 1x1 conv -> batch norm only if its
    channel count differs. `net` always goes through relu -> 1x1 conv ->
    batch norm.

    Args:
      net: Current input tensor.
      prev_layer: Previous cell output, or None to reuse `net`.

    Returns:
      A Python list `[projected_net, prev_layer]`, i.e. hidden states 0 and 1.
    """
    filter_size = self._filter_size

    if prev_layer is None:
      prev_layer = net
    else:
      if net.shape[2] != prev_layer.shape[2]:
        prev_layer = resize_bilinear(
            prev_layer, tf.shape(net)[1:3], prev_layer.dtype)
      if filter_size != prev_layer.shape[3]:
        prev_layer = tf.nn.relu(prev_layer)
        prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1')
        prev_layer = self._batch_norm_fn(prev_layer, scope='prev_bn')

    net = tf.nn.relu(net)
    net = slim.conv2d(net, filter_size, 1, scope='1x1')
    net = self._batch_norm_fn(net, scope='beginning_bn')
    # A single-way split turns the tensor into a one-element Python list.
    net = tf.split(axis=3, num_or_size_splits=1, value=net)
    net.append(prev_layer)
    return net

  @staticmethod
  def _parse_kernel_size(operation):
    """Returns the integer written immediately before the first 'x'.

    For example, 'separable_3x3_2' -> 3 and 'separable_11x11_2' -> 11. All
    trailing digits are read, so multi-digit kernel sizes parse correctly.

    Raises:
      ValueError: If no digits precede the first 'x'.
    """
    head = operation.split('x')[0]
    num_digits = len(head) - len(head.rstrip('0123456789'))
    if not num_digits:
      raise ValueError(
          'Cannot parse kernel size from operation: {}'.format(operation))
    return int(head[-num_digits:])

  def _apply_conv_operation(self, net, operation, stride,
                            is_from_original_input):
    """Applies the predicted conv operation to net.

    Operation names are parsed as follows:
      * contains 'separable': `<k>x...` kernel size, and the last
        '_'-separated field is the number of stacked layers
        (e.g. 'separable_3x3_2').
      * contains 'atrous': `<k>x...` kernel size (e.g. 'atrous_5x5').
      * exactly 'none': identity, with a 1x1 projection if needed.
      * contains 'pool': first '_' field is 'avg' or 'max', and the last
        '_' field starts with `<k>x` (e.g. 'avg_pool_3x3').

    Args:
      net: Input hidden state tensor.
      operation: Operation name, see above.
      stride: Requested stride. It is forced to 1 when the input is not one
        of the two original cell inputs.
      is_from_original_input: Whether `net` is hidden state 0 or 1.

    Returns:
      Tensor with `self._filter_size` channels. Drop path is applied unless
      the operation is 'none'.

    Raises:
      ValueError: For an unknown operation or pooling type, or a kernel size
        that cannot be parsed.
    """
    if stride > 1 and not is_from_original_input:
      stride = 1
    input_filters = net.shape[3]
    filter_size = self._filter_size
    if 'separable' in operation:
      num_layers = int(operation.split('_')[-1])
      kernel_size = self._parse_kernel_size(operation)
      for layer_num in range(num_layers):
        net = tf.nn.relu(net)
        net = separable_conv2d_same(
            net,
            filter_size,
            kernel_size,
            depth_multiplier=1,
            scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
            stride=stride)
        net = self._batch_norm_fn(
            net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
        # Only the first layer of the stack is strided.
        stride = 1
    elif 'atrous' in operation:
      kernel_size = self._parse_kernel_size(operation)
      net = tf.nn.relu(net)
      if stride == 2:
        # Downsample by bilinear resize to half size, then use an
        # undilated conv.
        scaled_height = scale_dimension(tf.shape(net)[1], 0.5)
        scaled_width = scale_dimension(tf.shape(net)[2], 0.5)
        net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
        net = resnet_utils.conv2d_same(
            net, filter_size, kernel_size, rate=1, stride=1,
            scope='atrous_{0}x{0}'.format(kernel_size))
      else:
        # NOTE(review): strides other than 1 and 2 are ignored in this branch.
        net = resnet_utils.conv2d_same(
            net, filter_size, kernel_size, rate=2, stride=1,
            scope='atrous_{0}x{0}'.format(kernel_size))
      net = self._batch_norm_fn(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
    elif operation == 'none':
      # Identity, unless a stride or channel change forces a 1x1 projection.
      if stride > 1 or (input_filters != filter_size):
        net = tf.nn.relu(net)
        net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
        net = self._batch_norm_fn(net, scope='bn_1')
    elif 'pool' in operation:
      pooling_type = operation.split('_')[0]
      pooling_shape = int(operation.split('_')[-1].split('x')[0])
      if pooling_type == 'avg':
        net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding='SAME')
      elif pooling_type == 'max':
        net = slim.max_pool2d(net, pooling_shape, stride=stride, padding='SAME')
      else:
        raise ValueError(
            'Unimplemented pooling type: {}'.format(pooling_type))
      if input_filters != filter_size:
        net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
        net = self._batch_norm_fn(net, scope='bn_1')
    else:
      raise ValueError('Unimplemented operation: {}'.format(operation))

    if operation != 'none':
      net = self._apply_drop_path(net)
    return net

  def _combine_unused_states(self, net):
    """Concatenates the unused hidden states of the cell.

    Args:
      net: List of hidden state tensors. It is paired element-wise with
        `used_hiddenstates`, and `zip` truncates to the shorter of the two.

    Returns:
      Concatenation along axis 3 of the states whose used-flag is falsy.
    """
    states_to_combine = [
        h for h, is_used in zip(net, self._used_hiddenstates) if not is_used]
    return tf.concat(values=states_to_combine, axis=3)

  @contrib_framework.add_arg_scope
  def _apply_drop_path(self, net):
    """Apply drop_path regularization.

    Does nothing when the configured keep prob is >= 1.0. Otherwise the keep
    prob is lowered in two ways:
      * linearly with cell depth, so the last cell gets the configured value;
      * linearly with training progress, reaching that value once the global
        step reaches `total_training_steps`.

    Whole examples are then zeroed with probability `1 - keep_prob`, and the
    survivors are scaled by `1 / keep_prob`.

    NOTE(review): there is no training-mode check here. Callers presumably
    pass drop_path_keep_prob=1.0 for inference -- confirm.

    Args:
      net: Input tensor. Axis 0 is treated as the example axis.

    Returns:
      The (possibly) drop-path-masked tensor.
    """
    drop_path_keep_prob = self._drop_path_keep_prob
    if drop_path_keep_prob < 1.0:
      # Scale keep prob by layer number.
      assert self._cell_num != -1
      layer_ratio = (self._cell_num + 1) / float(self._total_num_cells)
      drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)

      # Decrease keep prob over time.
      current_step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
      current_ratio = tf.minimum(1.0, current_step / self._total_training_steps)
      drop_path_keep_prob = (1 - current_ratio * (1 - drop_path_keep_prob))

      # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
      noise_shape = [tf.shape(net)[0], 1, 1, 1]
      random_tensor = drop_path_keep_prob
      random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
      binary_tensor = tf.cast(tf.floor(random_tensor), net.dtype)
      keep_prob_inv = tf.cast(1.0 / drop_path_keep_prob, net.dtype)
      net = net * keep_prob_inv * binary_tensor
    return net
|
|
|