|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Builds the Wide-ResNet Model."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import custom_ops as ops
|
| import numpy as np
|
| import tensorflow as tf
|
|
|
|
|
|
|
def residual_block(
    x, in_filter, out_filter, stride, activate_before_residual=False):
  """Pre-activation residual block: two BN->ReLU->3x3 Conv stages plus skip.

  Args:
    x: Input tensor (output of the previous layer in the model).
    in_filter: Number of channels in `x`.
    out_filter: Number of channels produced by this block.
    stride: Integer stride applied by the first convolution (and to the
      shortcut when the channel counts differ).
    activate_before_residual: If True, the initial BN->ReLU is applied
      before the shortcut branches off, so both paths share it.

  Returns:
    A Tensor holding the sum of the shortcut and the transformed branch.
  """
  if activate_before_residual:
    # Shared pre-activation: the shortcut sees the BN->ReLU output too.
    with tf.variable_scope('shared_activation'):
      x = tf.nn.relu(ops.batch_norm(x, scope='init_bn'))
    shortcut = x
    out = x
  else:
    shortcut = x
    # Pre-activation applied to the residual branch only.
    with tf.variable_scope('residual_only_activation'):
      out = tf.nn.relu(ops.batch_norm(x, scope='init_bn'))

  with tf.variable_scope('sub1'):
    out = ops.conv2d(out, out_filter, 3, stride=stride, scope='conv1')

  with tf.variable_scope('sub2'):
    out = ops.batch_norm(out, scope='bn2')
    out = tf.nn.relu(out)
    out = ops.conv2d(out, out_filter, 3, stride=1, scope='conv2')

  with tf.variable_scope('sub_add'):
    if in_filter != out_filter:
      # Match the shortcut's spatial size and channel count before adding.
      shortcut = ops.avg_pool(shortcut, stride, stride)
      shortcut = ops.zero_pad(shortcut, in_filter, out_filter)
    return shortcut + out
|
|
|
|
|
| def _res_add(in_filter, out_filter, stride, x, orig_x):
|
| """Adds `x` with `orig_x`, both of which are layers in the model.
|
|
|
| Args:
|
| in_filter: Number of filters in `orig_x`.
|
| out_filter: Number of filters in `x`.
|
| stride: Integer specifying the stide that should be applied `orig_x`.
|
| x: Tensor that is the output of the previous layer.
|
| orig_x: Tensor that is the output of an earlier layer in the network.
|
|
|
| Returns:
|
| A Tensor that is the result of `x` and `orig_x` being added after
|
| zero padding and striding are applied to `orig_x` to get the shapes
|
| to match.
|
| """
|
| if in_filter != out_filter:
|
| orig_x = ops.avg_pool(orig_x, stride, stride)
|
| orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
|
| x = x + orig_x
|
| orig_x = x
|
| return x, orig_x
|
|
|
|
|
def build_wrn_model(images, num_classes, wrn_size, num_blocks_per_resnet=4):
  """Builds the WRN model.

  Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.

  Args:
    images: Tensor of images that will be fed into the Wide ResNet Model.
    num_classes: Number of classes that the model needs to predict.
    wrn_size: Parameter that scales the number of filters in the Wide ResNet
      model.
    num_blocks_per_resnet: Number of residual blocks in each of the three
      resolution stages. The default of 4 reproduces the original
      configuration of this file.

  Returns:
    The logits of the Wide ResNet model.
  """
  kernel_size = wrn_size
  filter_size = 3
  # Channel widths: stem capped at 16, then k, 2k, 4k for the three stages.
  filters = [
      min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
  ]
  # The first block of each stage downsamples, except the first stage.
  strides = [1, 2, 2]

  # Run the initial convolution on the raw images.
  with tf.variable_scope('init'):
    x = images
    output_filters = filters[0]
    x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')

  first_x = x  # Skip connection from the stem output.
  orig_x = x  # Skip connection from the start of the current stage.

  for block_num in range(1, 4):
    with tf.variable_scope('unit_{}_0'.format(block_num)):
      # Only the very first residual block shares its BN->ReLU activation
      # with the shortcut path.
      activate_before_residual = True if block_num == 1 else False
      x = residual_block(
          x,
          filters[block_num - 1],
          filters[block_num],
          strides[block_num - 1],
          activate_before_residual=activate_before_residual)
    # Remaining blocks of the stage keep the filter count and stride 1.
    for i in range(1, num_blocks_per_resnet):
      with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
        x = residual_block(
            x,
            filters[block_num],
            filters[block_num],
            1,
            activate_before_residual=False)
    # Stage-level skip connection from the start of this stage.
    x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
                         strides[block_num - 1], x, orig_x)
  # Global skip connection from the stem; the total downsampling applied
  # since the stem is the product of all stage strides.
  final_stride_val = np.prod(strides)
  x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)
  with tf.variable_scope('unit_last'):
    x = ops.batch_norm(x, scope='final_bn')
    x = tf.nn.relu(x)
    x = ops.global_avg_pool(x)
    logits = ops.fc(x, num_classes)
  return logits
|
|
|