|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Layers for QRNN."""
|
| import tensorflow as tf
|
|
|
| from layers import base_layers
|
| from layers import conv_layers
|
| from layers import dense_layers
|
| from layers import quantization_layers
|
| from tf_ops import tf_custom_ops_py
|
|
|
# Supported pooling modes. Every mode consumes a candidate ("z") gate plus
# the gates named by the string: "f" = forget, "o" = output, "i" = input
# (see QRNNUnidirectionalPooling._qrnn_preprocess).
QUASI_RNN_POOLING_F = "f"

QUASI_RNN_POOLING_FO = "fo"

QUASI_RNN_POOLING_IFO = "ifo"

# Number of gate tensors each pooling mode expects as input
# (the named gates above plus the candidate gate).
_QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP = {

    QUASI_RNN_POOLING_F: 2,

    QUASI_RNN_POOLING_FO: 3,

    QUASI_RNN_POOLING_IFO: 4,

}
|
|
|
|
|
class QRNNUnidirectionalPoolingCore(base_layers.BaseLayer):
  """Create a unidirectional QRNN pooling inner loop.

  Computes the linear recurrence
    state[t] = state[t-1] * multiplier[t] + constant[t]
  over the time axis, left-to-right when `forward` is True and
  right-to-left otherwise.
  """

  def __init__(self, forward=True, **kwargs):
    """Initializes the pooling core.

    Args:
      forward: if True pool left-to-right over time, otherwise right-to-left.
      **kwargs: forwarded to base_layers.BaseLayer.
    """
    self.forward = forward
    super(QRNNUnidirectionalPoolingCore, self).__init__(**kwargs)

  def call(self, multiplier, constant):
    """Runs the recurrence over rank-3 [batch, time, feature] tensors."""
    if self.parameters.mode != base_layers.TFLITE:
      return self._qrnn_pooling(multiplier, constant)
    # TFLITE inference uses a custom op; the third argument encodes the
    # direction (1.0 = forward, 0.0 = backward).
    return tf_custom_ops_py.pooling_op(multiplier, constant,
                                       [1.0 if self.forward else 0.0])

  def _qrnn_pooling(self, multiplier, constant):
    """Pooling step computes the internal states for all timesteps.

    Args:
      multiplier: rank-3 [batch, time, feature] gate tensor.
      constant: rank-3 tensor with the same shape as `multiplier`.

    Returns:
      A [batch, time, feature] tensor of states, in chronological order
      regardless of direction.
    """
    # Both inputs must agree on their (possibly partially unknown) shape.
    assert multiplier.get_shape().as_list() == constant.get_shape().as_list()

    gate_static_shape = multiplier.get_shape().as_list()
    gate_shape = tf.shape(multiplier)

    # The feature dimension must be statically known to build the zero state.
    feature_size = gate_static_shape[2]
    assert feature_size is not None
    # Fall back to the dynamic shape for dimensions unknown at graph time.
    batch_size = gate_static_shape[0] or gate_shape[0]
    max_timestep = gate_static_shape[1] or gate_shape[1]

    # With an unknown number of timesteps use tf.while_loop; otherwise the
    # loop is unrolled statically below.
    dynamic_loop = gate_static_shape[1] is None

    # Transpose to time-major [time, batch, feature] for per-step slicing.
    multiplier_transposed = tf.transpose(multiplier, [1, 0, 2])
    constant_transposed = tf.transpose(constant, [1, 0, 2])

    # Initial state is all zeros.
    state = tf.zeros((batch_size, feature_size), tf.float32)
    if dynamic_loop:

      def _step(index, state, states):
        """One recurrence step; index moves with the pooling direction."""
        m = multiplier_transposed[index, :, :]
        c = constant_transposed[index, :, :]
        new_state = state * m + c
        next_index = index + 1 if self.forward else index - 1
        return next_index, new_state, states.write(index, new_state)

      def _termination(index, state, states):
        del state, states
        return (index < max_timestep) if self.forward else (index >= 0)

      states = tf.TensorArray(tf.float32, size=max_timestep)
      index = 0 if self.forward else max_timestep - 1

      _, state, states = tf.while_loop(_termination, _step,
                                       [index, state, states])
      # TensorArray indices are absolute timesteps, so the stack is already
      # in chronological order for both directions.
      states = states.stack()
    else:
      # Static unroll: one python-level iteration per timestep.
      multiplier_list = tf.unstack(multiplier_transposed)
      constant_list = tf.unstack(constant_transposed)
      states = []

      timesteps = list(range(max_timestep)) if self.forward else reversed(
          list(range(max_timestep)))

      for time in timesteps:
        state = state * multiplier_list[time] + constant_list[time]
        states.append(state)

      # The backward pass appended states in reverse time order; restore
      # chronological order before stacking.
      states = tf.stack(states if self.forward else list(reversed(states)))

    # Back to batch-major [batch, time, feature].
    return tf.transpose(states, [1, 0, 2])
|
|
|
|
|
class QRNNUnidirectionalPooling(base_layers.BaseLayer):
  """Create a unidirectional QRNN pooling."""

  def __init__(self,
               zoneout_probability=0.0,
               forward=True,
               pooling=QUASI_RNN_POOLING_FO,
               output_quantized=True,
               **kwargs):
    """Initializes the pooling layer.

    Args:
      zoneout_probability: probability of a state unit carrying its previous
        value through a timestep during training (zoneout regularization).
      forward: if True pool left-to-right over time, otherwise right-to-left.
      pooling: one of QUASI_RNN_POOLING_{F,FO,IFO}.
      output_quantized: if True the final outputs are quantized.
      **kwargs: forwarded to base_layers.BaseLayer and the pooling core.
    """
    self.zoneout_probability = zoneout_probability
    self.pooling = pooling
    self.forward = forward
    self.output_quantized = output_quantized
    # Validate the pooling mode before it is used as a dictionary key.
    assert pooling in _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP
    # The activation quantizer must exist whenever pooling is IFO:
    # _qrnn_postprocess applies self.qoutputs for IFO regardless of
    # output_quantized. Previously it was only created when
    # output_quantized was True, so pooling=IFO with
    # output_quantized=False raised AttributeError.
    if self.pooling == QUASI_RNN_POOLING_IFO:
      self.qoutputs = quantization_layers.ActivationQuantization()
    self.num_gates = _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP[pooling]
    self.pooling_core = QRNNUnidirectionalPoolingCore(forward=forward, **kwargs)
    super(QRNNUnidirectionalPooling, self).__init__(**kwargs)

  def call(self, gates, mask):
    """Pools the list of gate tensors into QRNN outputs."""
    return self._create_qrnn_pooling_unidirectional(gates, mask)

  def _qrnn_preprocess(self, gates):
    """Preprocess the gate inputs to the pooling layer.

    Args:
      gates: list of `num_gates` rank-3 tensors with identical static shapes.

    Returns:
      A (multiplier, constant, outgate) triple for the recurrence
      state[t] = state[t-1] * multiplier[t] + constant[t]. `outgate` is the
      scalar 1 for F pooling, otherwise the output-gate tensor.
    """
    assert self.num_gates == len(gates)
    dim = lambda tensor, index: tensor.get_shape().as_list()[index]

    # All gate tensors must be rank 3 and agree with gates[0] on every axis.
    for tensor in gates:
      assert len(tensor.get_shape().as_list()) == 3
      for idx in range(3):
        assert dim(gates[0], idx) == dim(tensor, idx)

    if self.pooling == QUASI_RNN_POOLING_F:
      z = self.quantized_tanh(gates[0], tf_only=True)
      f = self.quantized_sigmoid(gates[1], tf_only=True)
      return f, self.qrange_tanh(self.qrange_sigmoid(1 - f) * z), 1
    elif self.pooling == QUASI_RNN_POOLING_FO:
      z = self.quantized_tanh(gates[0], tf_only=True)
      f = self.quantized_sigmoid(gates[1], tf_only=True)
      o = self.quantized_sigmoid(gates[2], tf_only=True)
      return f, self.qrange_tanh(self.qrange_sigmoid(1 - f) * z), o
    else:
      # QUASI_RNN_POOLING_IFO: an independent input gate instead of (1 - f).
      z = self.quantized_tanh(gates[0], tf_only=True)
      i = self.quantized_sigmoid(gates[1], tf_only=True)
      f = self.quantized_sigmoid(gates[2], tf_only=True)
      o = self.quantized_sigmoid(gates[3], tf_only=True)
      return f, self.qrange_tanh(i * z), o

  def _qrnn_postprocess(self, states, multiplier):
    """Postprocess the states and return the output tensors."""
    if self.pooling == QUASI_RNN_POOLING_F:
      return states
    elif self.pooling == QUASI_RNN_POOLING_FO:
      return self.qrange_tanh(states) * multiplier
    else:
      return self.qoutputs(states) * multiplier

  def _qrnn_zoneout(self, multiplier, constant):
    """Zoneout regularization for Quasi RNN.

    With probability `zoneout_probability` a unit keeps its previous state:
    its multiplier is forced to 1 and its constant to 0. Active only in
    TRAIN mode.
    """
    enable_zoneout = self.zoneout_probability > 0.0
    if enable_zoneout and self.parameters.mode == base_layers.TRAIN:
      # floor(U + p) with U ~ Uniform[0, 1) yields 1 with probability p.
      zoneout_mask = tf.random.uniform(tf.shape(multiplier), maxval=1.0)
      zoneout_mask = tf.floor(zoneout_mask + self.zoneout_probability)

      # Zoned-out units carry the previous state unchanged: multiplier is
      # clamped to 1 via maximum with the 0/1 mask, constant zeroed out.
      multiplier = tf.maximum(zoneout_mask, multiplier)
      constant *= (1 - zoneout_mask)
    return multiplier, constant

  def _create_qrnn_pooling_unidirectional(self, gates, mask):
    """Create QRNN Pooling in either forward or backward direction."""
    m1, c1, outgate = self._qrnn_preprocess(gates)

    # Masking is skipped in PREDICT/TFLITE modes.
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      # Padded timesteps use multiplier 1 and constant 0, i.e. the state is
      # carried through untouched.
      m1 = m1 * mask + (1 - mask) * tf.ones_like(m1)
      c1 *= mask

    m1, c1 = self._qrnn_zoneout(m1, c1)

    states = self.pooling_core(m1, c1)

    outputs = self._qrnn_postprocess(states, outgate)

    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      outputs *= mask

    if self.output_quantized:
      if self.pooling in [QUASI_RNN_POOLING_FO, QUASI_RNN_POOLING_F]:
        outputs = self.qrange_tanh(outputs)
      else:
        outputs = self.qoutputs.quantize_using_range(outputs)

    return outputs
|
|
|
|
|
class QRNNUnidirectional(base_layers.BaseLayer):
  """Create a unidirectional QRNN encoder.

  Computes one convolution per gate over causally padded inputs, then runs
  QRNN pooling over the resulting gate tensors.
  """

  def __init__(self,
               kwidth,
               state_size,
               zoneout_probability=0.0,
               forward=True,
               pooling=QUASI_RNN_POOLING_FO,
               output_quantized=True,
               normalization_fn=None,
               **kwargs):
    """Builds the gate convolutions and the pooling layer for one direction.

    Args:
      kwidth: convolution kernel width over the time axis.
      state_size: number of output features per gate.
      zoneout_probability: zoneout regularization probability.
      forward: if True encode left-to-right, otherwise right-to-left.
      pooling: one of QUASI_RNN_POOLING_{F,FO,IFO}.
      output_quantized: if True the pooling outputs are quantized.
      normalization_fn: optional normalization for the gate convolutions.
      **kwargs: forwarded to sub-layers and base_layers.BaseLayer.
    """
    self.forward = forward
    self.kwidth = kwidth
    self.pooling = pooling
    self.state_size = state_size
    assert pooling in _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP.keys()
    self.num_gates = _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP[pooling]
    # One VALID-padding convolution per gate.
    self.gate_layers = [
        conv_layers.EncoderQConvolutionVarLen(
            filters=state_size,
            ksize=kwidth,
            rank=3,
            padding="VALID",
            activation=None,
            normalization_fn=normalization_fn,
            **kwargs) for _ in range(self.num_gates)
    ]
    # Causal padding: pad kwidth - 1 steps on the "past" side only.
    if forward:
      padding = [kwidth - 1, 0]
    else:
      padding = [0, kwidth - 1]
    self.zero_pad = tf.keras.layers.ZeroPadding1D(padding=padding)
    self.qrnn_pooling = QRNNUnidirectionalPooling(
        forward=forward,
        zoneout_probability=zoneout_probability,
        output_quantized=output_quantized,
        pooling=pooling,
        **kwargs)
    super(QRNNUnidirectional, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer=None):
    """Encodes rank-3 `inputs` under a rank-3 `mask`."""
    if inverse_normalizer is None:
      inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    maskr4 = tf.expand_dims(mask, axis=1)
    padded_inputs = self.zero_pad(inputs)
    gates = []
    for gate_layer in self.gate_layers:
      gates.append(gate_layer(padded_inputs, maskr4, inverse_normalizer))
    return self.qrnn_pooling(gates, mask)
|
|
|
|
|
class QRNNUnidirectionalWithBottleneck(base_layers.BaseLayer):
  """Create a unidirectional QRNN encoder with bottlenecks.

  Each gate is computed as dense-down (to `bottleneck_size`), convolution at
  the bottleneck width, then dense-up (to `state_size`), followed by QRNN
  pooling.
  """

  def __init__(self,
               kwidth,
               state_size,
               bottleneck_size,
               zoneout_probability=0.0,
               forward=True,
               pooling=QUASI_RNN_POOLING_FO,
               output_quantized=True,
               **kwargs):
    """Builds the per-gate bottleneck stacks and the pooling layer.

    Args:
      kwidth: convolution kernel width over the time axis.
      state_size: number of output features per gate after projection up.
      bottleneck_size: reduced feature size used for the convolutions.
      zoneout_probability: zoneout regularization probability.
      forward: if True encode left-to-right, otherwise right-to-left.
      pooling: one of QUASI_RNN_POOLING_{F,FO,IFO}.
      output_quantized: if True the pooling outputs are quantized.
      **kwargs: forwarded to sub-layers and base_layers.BaseLayer.
    """
    self.bottleneck_size = bottleneck_size
    # Note: the original assigned self.state_size twice; once is enough.
    self.state_size = state_size
    self.forward = forward
    self.kwidth = kwidth
    self.pooling = pooling
    assert pooling in _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP
    self.num_gates = _QUASI_RNN_POOLING_TO_NUMBER_OF_GATES_MAP[pooling]
    self.qrnn_pooling = QRNNUnidirectionalPooling(
        forward=forward,
        zoneout_probability=zoneout_probability,
        output_quantized=output_quantized,
        pooling=pooling,
        **kwargs)
    # Per gate: project down to the bottleneck, convolve, project back up.
    self.pre_conv_layers = []
    self.gate_layers = []
    self.post_conv_layers = []
    for _ in range(self.num_gates):
      self.pre_conv_layers.append(
          dense_layers.BaseQDense(bottleneck_size, rank=3, **kwargs))
      self.gate_layers.append(
          conv_layers.EncoderQConvolution(
              filters=bottleneck_size,
              ksize=kwidth,
              rank=3,
              padding="SAME",
              normalization_fn=None,
              **kwargs))
      self.post_conv_layers.append(
          dense_layers.BaseQDense(
              state_size, rank=3, activation=None, **kwargs))
    super(QRNNUnidirectionalWithBottleneck, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer=None):
    """Encodes rank-3 `inputs` under a rank-3 `mask`."""
    if inverse_normalizer is None:
      inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    pre_conv_out = [layer(inputs) for layer in self.pre_conv_layers]
    gates = [layer(pre_conv_out[i]) for i, layer in enumerate(self.gate_layers)]
    post_conv_out = [
        layer(gates[i]) for i, layer in enumerate(self.post_conv_layers)
    ]
    return self.qrnn_pooling(post_conv_out, mask)
|
|
|
|
|
class QRNNBidirectional(base_layers.BaseLayer):
  """Create a bidirectional QRNN encoder.

  Runs a forward and a backward unidirectional encoder over the same inputs
  and concatenates their outputs along the feature axis.
  """

  def __init__(self,
               kwidth,
               state_size,
               zoneout_probability=0.0,
               pooling=QUASI_RNN_POOLING_FO,
               bottleneck_size=None,
               normalization_fn=None,
               **kwargs):
    """Builds forward/backward encoders, with or without bottlenecks.

    Args:
      kwidth: convolution kernel width over the time axis.
      state_size: number of output features per direction.
      zoneout_probability: zoneout regularization probability.
      pooling: one of QUASI_RNN_POOLING_{F,FO,IFO}.
      bottleneck_size: if set, use bottlenecked unidirectional encoders.
      normalization_fn: optional normalization; only honored without a
        bottleneck.
      **kwargs: forwarded to sub-layers and base_layers.BaseLayer.
    """
    self.pooling = pooling
    if bottleneck_size is None:

      def _build(direction):
        # Plain unidirectional encoder; pooling concatenation quantizes, so
        # the sub-encoders emit unquantized outputs.
        return QRNNUnidirectional(
            kwidth=kwidth,
            state_size=state_size,
            forward=direction,
            output_quantized=False,
            zoneout_probability=zoneout_probability,
            pooling=pooling,
            normalization_fn=normalization_fn,
            **kwargs)
    else:
      assert normalization_fn is None, (
          "normalization_fn will not take an effect")

      def _build(direction):
        # Bottlenecked variant; normalization_fn is intentionally unused.
        return QRNNUnidirectionalWithBottleneck(
            kwidth=kwidth,
            state_size=state_size,
            bottleneck_size=bottleneck_size,
            forward=direction,
            output_quantized=False,
            zoneout_probability=zoneout_probability,
            pooling=pooling,
            **kwargs)

    self.forward = _build(True)
    self.backward = _build(False)

    self.qconcat = quantization_layers.ConcatQuantization(axis=2, **kwargs)
    super(QRNNBidirectional, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer=None):
    """Encodes rank-3 `inputs` in both directions and concatenates."""
    if inverse_normalizer is None:
      inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
    fwd_outputs = self.forward(inputs, mask, inverse_normalizer)
    bwd_outputs = self.backward(inputs, mask, inverse_normalizer)

    if self.pooling in [QUASI_RNN_POOLING_FO, QUASI_RNN_POOLING_F]:
      # Quantize each direction to the tanh range, concatenate, re-quantize.
      halves = [self.qrange_tanh(fwd_outputs), self.qrange_tanh(bwd_outputs)]
      return self.qrange_tanh(tf.concat(halves, axis=2))
    # IFO pooling: concatenation with a learned quantization range.
    return self.qconcat([fwd_outputs, bwd_outputs])
|
|
|
|
|
class QRNNBidirectionalStack(base_layers.BaseLayer):
  """Create a stack of bidirectional QRNN encoder."""

  def __init__(self,
               num_layers,
               kwidth,
               state_size,
               zoneout_probability=0.0,
               layerwise_decaying_zoneout=True,
               pooling=QUASI_RNN_POOLING_FO,
               bottleneck_size=None,
               normalization_fn=None,
               **kwargs):
    """Builds `num_layers` stacked bidirectional encoders.

    Args:
      num_layers: number of stacked QRNNBidirectional layers.
      kwidth: convolution kernel width over the time axis.
      state_size: number of output features per direction per layer.
      zoneout_probability: base zoneout probability.
      layerwise_decaying_zoneout: if True, layer idx uses
        zoneout_probability ** (idx + 1), decaying with depth.
      pooling: one of QUASI_RNN_POOLING_{F,FO,IFO}.
      bottleneck_size: if set, use bottlenecked unidirectional encoders.
      normalization_fn: optional normalization forwarded to each layer.
      **kwargs: forwarded to sub-layers and base_layers.BaseLayer.
    """

    def _layer_zoneout(idx):
      # Deeper layers get a smaller zoneout probability when decaying.
      if layerwise_decaying_zoneout:
        return zoneout_probability**(idx + 1)
      return zoneout_probability

    self.layers = [
        QRNNBidirectional(
            kwidth=kwidth,
            state_size=state_size,
            zoneout_probability=_layer_zoneout(idx),
            pooling=pooling,
            bottleneck_size=bottleneck_size,
            normalization_fn=normalization_fn,
            **kwargs) for idx in range(num_layers)
    ]
    super(QRNNBidirectionalStack, self).__init__(**kwargs)

  def call(self, inputs, maskr3, inverse_normalizer):
    """Runs the stack over rank-3 `inputs` with rank-3 mask `maskr3`."""
    return self._apply_qrnn_stack(inputs, maskr3, inverse_normalizer)

  def _apply_qrnn_stack(self, inputs, mask3, inverse_normalizer):
    """Feeds the inputs through each bidirectional layer in sequence."""
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      inputs = inputs * mask3
    for qrnn_layer in self.layers:
      inputs = qrnn_layer(inputs, mask3, inverse_normalizer)
    return inputs
|
|
|
|
|
class QRNNBidirectionalStackWithSeqLength(QRNNBidirectionalStack):
  """Bidirectional QRNN stack that derives its mask from sequence lengths."""

  def call(self, inputs, sequence_length):
    """Runs the stack, building the mask from per-example lengths.

    Args:
      inputs: rank-3 [batch, time, feature] tensor.
      sequence_length: per-example valid lengths used to build the mask.

    Returns:
      The stacked encoder outputs.
    """
    timesteps = tf.shape(inputs)[1]
    mask = tf.sequence_mask(sequence_length, timesteps, dtype=tf.float32)
    inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
    maskr3 = tf.expand_dims(mask, 2)
    return self._apply_qrnn_stack(inputs, maskr3, inverse_normalizer)
|
|
|