|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Layers for normalization."""
|
| import tensorflow as tf
|
|
|
| from layers import base_layers
|
| from layers import quantization_layers
|
| from tf_ops import tf_custom_ops_py
|
|
|
|
|
class BatchNormalization(base_layers.BaseLayer):
  """Batch normalization that tracks its own moving-average statistics.

  In TRAIN mode the input is normalized with the current batch moments and
  the stored moving averages are updated; in all other modes the stored
  moving averages are used instead.
  """

  def __init__(self, ema_decay=0.999, **kwargs):
    # Decay used for the exponential moving averages of mean and variance.
    self.ema_decay = ema_decay
    super(BatchNormalization, self).__init__(**kwargs)

  def build(self, input_shapes):
    # Reduce over every axis except the trailing feature axis.
    self.reduce_dims = list(range(len(input_shapes) - 1))
    feature_shape = [input_shapes[-1]]
    zeros = tf.keras.initializers.Zeros()
    ones = tf.keras.initializers.Ones()
    # Trainable affine parameters.
    self.offset = self.add_weight(
        "offset", shape=feature_shape, initializer=zeros, trainable=True)
    self.scale = self.add_weight(
        "scale", shape=feature_shape, initializer=ones, trainable=True)
    # Non-trainable moving-average state used outside of training.
    self.mva_mean = self.add_weight(
        "mva_mean", shape=feature_shape, initializer=zeros, trainable=False)
    self.mva_var = self.add_weight(
        "mva_variance", shape=feature_shape, initializer=ones, trainable=False)

  def call(self, inputs):
    batch_mean, batch_var = None, None
    if self.parameters.mode == base_layers.TRAIN:
      # Batch moments are only needed (and only computed) during training.
      batch_mean, batch_var = tf.nn.moments(inputs, self.reduce_dims)
    return self._batch_norm(inputs, batch_mean, batch_var)

  def _batch_norm(self, inputs, mean_mom, var_mom):
    if self.parameters.mode != base_layers.TRAIN:
      # Inference path: normalize with the stored moving averages.
      return tf.nn.batch_normalization(inputs, self.mva_mean, self.mva_var,
                                       self.offset, self.scale, 1e-9)
    # Training path: refresh the moving averages first, then normalize with
    # the batch moments. The control dependency guarantees the updates run.
    updates = [
        self.assign_moving_average(self.mva_mean, mean_mom, self.ema_decay),
        self.assign_moving_average(self.mva_var, var_mom, self.ema_decay),
    ]
    with tf.control_dependencies(updates):
      return tf.nn.batch_normalization(inputs, mean_mom, var_mom, self.offset,
                                       self.scale, 1e-9)
|
|
|
|
|
class VarLenBatchNormalization(BatchNormalization):
  """Batch normalization for variable-length (mask-padded) inputs.

  Moments are accumulated only over valid (masked-in) positions, with the
  caller supplying `inverse_normalizer` (presumably 1 / number of valid
  entries — confirm against callers).
  """

  def __init__(self, rank=2, **kwargs):
    # Only rank-2 (batch, features) and rank-4 inputs are supported.
    assert rank == 2 or rank == 4
    self.rank = rank
    super(VarLenBatchNormalization, self).__init__(**kwargs)

  def _reduce(self, tensor, multiplier):
    # Sum over the reduction axes, then rescale by the caller's multiplier.
    return tf.reduce_sum(tensor, axis=self.reduce_dims) * multiplier

  def call(self, inputs, mask, inverse_normalizer):
    mode = self.parameters.mode
    if mode == base_layers.TRAIN:
      self._assert_rank_and_type(inputs, self.rank)
      self._assert_rank_and_type(mask, self.rank)
      # Zero out padded positions so they do not contribute to the moments.
      masked = mask * inputs
      mean_mom = self._reduce(masked, inverse_normalizer)
      # NOTE(review): this is E[x^2] over valid positions, not the centered
      # variance E[x^2] - E[x]^2; kept as-is since trained checkpoints
      # depend on it — confirm upstream whether this is intentional.
      var_mom = self._reduce(masked * masked, inverse_normalizer)
      return mask * self._batch_norm(masked, mean_mom, var_mom)
    if mode == base_layers.EVAL:
      return mask * self._batch_norm(inputs, None, None)
    # Remaining modes (e.g. TFLITE) skip the mask multiply entirely.
    return self._batch_norm(inputs, None, None)
|
|
|
|
|
class LayerNormalization(base_layers.BaseLayer):
  """Layer normalization over configurable axes, applied after quantization.

  Uses scalar `scale`/`offset` parameters (shape [1]) and dispatches to a
  fused custom op in TFLITE mode.
  """

  def __init__(self, axes=None, **kwargs):
    # Default to normalizing over the last dimension only.
    self.axes = axes or [-1]
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
    self.rank = len(input_shape)
    # Canonicalize negative axes in place; the batch axis (0) is disallowed.
    for i in range(len(self.axes)):
      if self.axes[i] < 0:
        self.axes[i] += self.rank
      assert (self.axes[i] > 0 and self.axes[i] < self.rank)
    self.offset = self.add_weight(
        "offset",
        shape=[1],
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.scale = self.add_weight(
        "scale",
        shape=[1],
        initializer=tf.keras.initializers.Ones(),
        trainable=True)

  def call(self, tensor):
    tensor = self.qactivation(tensor)
    if self.parameters.mode == base_layers.TFLITE:
      # TFLITE mode delegates to a fused custom layer-norm op.
      return tf_custom_ops_py.layer_norm(
          tensor, self.scale, self.offset, axes=self.axes)
    mean, variance = tf.nn.moments(tensor, self.axes, keepdims=True)
    # Standard normalization with a small epsilon for numerical stability.
    normalized = (tensor - mean) / tf.sqrt(variance + 1e-6)
    return normalized * self.scale + self.offset
|
|
|