|
|
|
|
|
|
|
|
|
|
| """Tensorflow Layer modules compatible with pytorch."""
|
|
|
| import tensorflow as tf
|
|
|
|
|
class TFReflectionPad1d(tf.keras.layers.Layer):
    """Reflection padding layer acting on the time axis of a 4-D tensor."""

    def __init__(self, padding_size):
        """Store the amount of padding to apply.

        Args:
            padding_size (int): Number of frames added to each side of
                the time axis.

        """
        super(TFReflectionPad1d, self).__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Pad the input by reflection along axis 1.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        amount = self.padding_size
        # Only axis 1 (time) is padded; batch, dummy and channel axes are untouched.
        paddings = [
            [0, 0],
            [amount, amount],
            [0, 0],
            [0, 0],
        ]
        return tf.pad(x, paddings, mode="REFLECT")
|
|
|
|
|
class TFConvTranspose1d(tf.keras.layers.Layer):
    """1-D transposed convolution built on top of ``Conv2DTranspose``."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Build the wrapped ``Conv2DTranspose`` layer.

        Args:
            channels (int): Number of output channels (filters).
            kernel_size (int): Kernel size along the time axis.
            stride (int): Stride width along the time axis.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFConvTranspose1d, self).__init__()
        # The second spatial axis is a dummy axis, so kernel and stride are 1 there.
        kernel_shape = (kernel_size, 1)
        stride_shape = (stride, 1)
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=kernel_shape,
            strides=stride_shape,
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Apply the transposed convolution.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)
|
|
|
|
|
class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size. Must be odd for the residual
                branch to keep the same length as the shortcut branch.
            channels (int): Number of channels.
            dilation (int): Dilation size.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name
                (looked up in ``tf.keras.layers``).
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Padding type ("same" or "valid"). Currently unused:
                the dilated convolution always runs with "valid" padding
                after an explicit reflection pad.

        """
        super(TFResidualStack, self).__init__()
        # A "valid" dilated convolution shortens the time axis by
        # dilation * (kernel_size - 1) frames, so pad half of that on each
        # side. For kernel_size == 3 this equals ``dilation``.
        pad_size = (kernel_size - 1) // 2 * dilation
        activation_cls = getattr(tf.keras.layers, nonlinear_activation)
        self.block = [
            activation_cls(**nonlinear_activation_params),
            TFReflectionPad1d(pad_size),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            activation_cls(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias),
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        # Residual branch: run the input through every layer of the block in order.
        residual = x
        for layer in self.block:
            residual = layer(residual)
        # Shortcut branch: 1x1 convolution applied to the untouched input.
        shortcut = self.shortcut(x)
        return shortcut + residual
|
|
|