Comparative-Analysis-of-Speech-Synthesis-Models/TensorFlowTTS/tensorflow_tts/models/tacotron2.py
| # -*- coding: utf-8 -*- | |
| # Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tacotron-2 Modules.""" | |
| import collections | |
| import numpy as np | |
| import tensorflow as tf | |
| # TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed, | |
| # uncomment this line. | |
| # from tensorflow_addons.seq2seq import dynamic_decode | |
| from tensorflow_addons.seq2seq import BahdanauAttention, Decoder, Sampler | |
| from tensorflow_tts.utils import dynamic_decode | |
| from tensorflow_tts.models import BaseModel | |
def get_initializer(initializer_range=0.02):
    """Build the truncated-normal weight initializer used throughout the model.

    Args:
        initializer_range: float, standard deviation of the truncated normal.

    Returns:
        A ``tf.keras.initializers.TruncatedNormal`` with
        ``stddev=initializer_range``.
    """
    initializer_cls = tf.keras.initializers.TruncatedNormal
    return initializer_cls(stddev=initializer_range)
def gelu(x):
    """Exact Gaussian Error Linear Unit, ``x * Phi(x)``.

    ``Phi`` is the standard normal CDF, evaluated through ``erf``.
    """
    normal_cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * normal_cdf
def gelu_new(x):
    """Tanh approximation of GELU (smoother than the erf-based form)."""
    scale = np.sqrt(2 / np.pi)
    inner = scale * (x + 0.044715 * tf.pow(x, 3))
    approx_cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * approx_cdf
def swish(x):
    """Swish activation; thin wrapper delegating to ``tf.nn.swish``."""
    swish_op = tf.nn.swish
    return swish_op(x)
def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``."""
    gate = tf.math.tanh(tf.math.softplus(x))
    return x * gate
# Activation name (as used in the model config) -> callable. Most entries are
# Keras ``Activation`` layers; "relu" is the plain Keras activation function.
ACT2FN = dict(
    identity=tf.keras.layers.Activation("linear"),
    tanh=tf.keras.layers.Activation("tanh"),
    gelu=tf.keras.layers.Activation(gelu),
    relu=tf.keras.activations.relu,
    swish=tf.keras.layers.Activation(swish),
    gelu_new=tf.keras.layers.Activation(gelu_new),
    mish=tf.keras.layers.Activation(mish),
)
class TFEmbedding(tf.keras.layers.Embedding):
    """Embedding lookup implemented with ``tf.gather_nd``.

    Construction is inherited unchanged from ``tf.keras.layers.Embedding``.
    """

    def call(self, inputs):
        """Return the embedding rows selected by ``inputs`` (cast to int32)."""
        index_tensor = tf.expand_dims(inputs, -1)
        index_tensor = tf.cast(index_tensor, tf.int32)
        return tf.gather_nd(self.embeddings, index_tensor)
class TFTacotronConvBatchNorm(tf.keras.layers.Layer):
    """Tacotron-2 block: Conv1D -> SyncBatchNorm -> activation -> Dropout."""

    def __init__(
        self, filters, kernel_size, dropout_rate, activation=None, name_idx=None
    ):
        """Init variables.

        Args:
            filters: number of Conv1D output channels.
            kernel_size: Conv1D kernel width ("same" padding keeps the length).
            dropout_rate: dropout probability applied after the activation.
            activation: key into ``ACT2FN``; ``None`` means no activation
                (identity), following the Keras convention.
            name_idx: suffix used to name the sub-layers.
        """
        super().__init__()
        self.conv1d = tf.keras.layers.Conv1D(
            filters,
            kernel_size,
            kernel_initializer=get_initializer(0.02),
            padding="same",
            name="conv_._{}".format(name_idx),
        )
        self.norm = tf.keras.layers.experimental.SyncBatchNormalization(
            axis=-1, name="batch_norm_._{}".format(name_idx)
        )
        self.dropout = tf.keras.layers.Dropout(
            rate=dropout_rate, name="dropout_._{}".format(name_idx)
        )
        # ACT2FN has no ``None`` key, so map the default to the linear entry
        # instead of raising KeyError.
        self.act = ACT2FN["identity" if activation is None else activation]

    def call(self, inputs, training=False):
        """Apply conv, batch norm, activation and dropout in that order.

        Args:
            inputs: channels-last tensor fed to Conv1D.
            training: toggles batch-norm statistics and dropout.
        """
        outputs = self.conv1d(inputs)
        outputs = self.norm(outputs, training=training)
        outputs = self.act(outputs)
        outputs = self.dropout(outputs, training=training)
        return outputs
class TFTacotronEmbeddings(tf.keras.layers.Layer):
    """Character/phoneme embeddings, optionally shifted by a speaker vector."""

    def __init__(self, config, **kwargs):
        """Create embedding sub-layers from ``config``."""
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_hidden_size = config.embedding_hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        is_multi_speaker = config.n_speakers > 1
        if is_multi_speaker:
            self.speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.embedding_hidden_size, name="speaker_fc"
            )

        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.embedding_dropout_prob)

    def build(self, input_shape):
        """Create the shared character/phoneme embedding matrix."""
        with tf.name_scope("character_embeddings"):
            self.character_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Embed a batch of token ids.

        Args:
            inputs: ``[input_ids, speaker_ids]`` where input_ids is int32
                ``[batch_size, length]`` and speaker_ids is int32 ``[batch_size]``.
            training: enables dropout when True.

        Returns:
            float32 tensor ``[batch_size, length, embedding_size]``.
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Look up embeddings, add speaker features, then LayerNorm + dropout."""
        input_ids, speaker_ids = inputs
        embeddings = tf.gather(self.character_embeddings, input_ids)

        if self.config.n_speakers > 1:
            speaker_vectors = self.speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_vectors))
            # Broadcast the per-utterance speaker vector over every time step.
            embeddings += speaker_features[:, tf.newaxis, :]

        normalized = self.LayerNorm(embeddings)
        return self.dropout(normalized, training=training)
class TFTacotronEncoderConvs(tf.keras.layers.Layer):
    """Stack of ``config.n_conv_encoder`` conv/batch-norm blocks."""

    def __init__(self, config, **kwargs):
        """Create one TFTacotronConvBatchNorm per encoder conv layer."""
        super().__init__(**kwargs)
        self.conv_batch_norm = [
            TFTacotronConvBatchNorm(
                filters=config.encoder_conv_filters,
                kernel_size=config.encoder_conv_kernel_sizes,
                activation=config.encoder_conv_activation,
                dropout_rate=config.encoder_conv_dropout_rate,
                name_idx=idx,
            )
            for idx in range(config.n_conv_encoder)
        ]

    def call(self, inputs, training=False):
        """Apply every conv block in sequence."""
        hidden = inputs
        for block in self.conv_batch_norm:
            hidden = block(hidden, training=training)
        return hidden
class TFTacotronEncoder(tf.keras.layers.Layer):
    """Tacotron-2 encoder: embeddings -> conv stack -> BiLSTM (+ speaker bias)."""

    def __init__(self, config, **kwargs):
        """Create encoder sub-layers from ``config``."""
        super().__init__(**kwargs)
        self.embeddings = TFTacotronEmbeddings(config, name="embeddings")
        self.convbn = TFTacotronEncoderConvs(config, name="conv_batch_norm")
        lstm = tf.keras.layers.LSTM(
            units=config.encoder_lstm_units, return_sequences=True
        )
        self.bilstm = tf.keras.layers.Bidirectional(lstm, name="bilstm")

        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="encoder_speaker_embeddings",
            )
            self.encoder_speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
            )

        self.config = config

    def call(self, inputs, training=False):
        """Encode a padded batch.

        Args:
            inputs: ``[input_ids, speaker_ids, input_mask]``. The mask is only
                passed to the BiLSTM; embeddings and convs are not masked.
            training: forwarded to the embedding and conv layers.

        Returns:
            BiLSTM outputs, plus a softplus speaker bias added at every time
            step when ``n_speakers > 1``.
        """
        input_ids, speaker_ids, input_mask = inputs
        embedded = self.embeddings([input_ids, speaker_ids], training=training)
        conv_out = self.convbn(embedded, training=training)
        encoded = self.bilstm(conv_out, mask=input_mask)

        if self.config.n_speakers <= 1:
            return encoded

        speaker_vectors = self.encoder_speaker_embeddings(speaker_ids)
        speaker_bias = tf.math.softplus(self.encoder_speaker_fc(speaker_vectors))
        return encoded + speaker_bias[:, tf.newaxis, :]
class Tacotron2Sampler(Sampler):
    """Tacotron2 sampler for Seq2Seq training and inference.

    Training: feeds ground-truth frames (teacher forcing, mixed with the
    predictions by ``_ratio``) and finishes after the last target frame.
    Inference: feeds back the predicted frame and finishes only when every
    stop-token probability in the batch rounds to 1.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Schedule factor: the input of the next decoder step is
        #   ratio * prev_groundtruth_outputs + (1.0 - ratio) * prev_predicted_outputs.
        # With ratio fixed at 1.0 this is pure teacher forcing.
        self._ratio = tf.constant(1.0, dtype=tf.float32)
        self._reduction_factor = self.config.reduction_factor

    def setup_target(self, targets, mel_lengths):
        """Store ground-truth mel frames for teacher forcing.

        Args:
            targets: ground-truth mels, ``[batch_size, time, n_mels]``.
            mel_lengths: per-utterance mel lengths (stored, not used here).
        """
        self.mel_lengths = mel_lengths
        self.set_batch_size(tf.shape(targets)[0])
        # Keep the last frame of every group of `reduction_factor` frames.
        self.targets = targets[
            :, self._reduction_factor - 1 :: self._reduction_factor, :
        ]
        self.max_lengths = tf.tile([tf.shape(self.targets)[1]], [self._batch_size])

    # The tfa ``Sampler`` contract (and TFTacotronDecoder) read the four
    # members below as attributes, so they must be properties.
    @property
    def batch_size(self):
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    @property
    def reduction_factor(self):
        return self._reduction_factor

    def initialize(self):
        """Return (finished, next_inputs): all-False flags and a zero go-frame."""
        return (
            tf.tile([False], [self._batch_size]),
            tf.tile([[0.0]], [self._batch_size, self.config.n_mels]),
        )

    def sample(self, time, outputs, state):
        """Return dummy zero sample ids (frames are continuous, not sampled)."""
        return tf.tile([0], [self._batch_size])

    def next_inputs(
        self,
        time,
        outputs,
        state,
        sample_ids,
        stop_token_prediction,
        training=False,
        **kwargs,
    ):
        """Return (finished, next_inputs, next_state) for decode step ``time``."""
        if training:
            finished = time + 1 >= self.max_lengths
            next_inputs = (
                self._ratio * self.targets[:, time, :]
                + (1.0 - self._ratio) * outputs[:, -self.config.n_mels :]
            )
            next_state = state
            return (finished, next_inputs, next_state)
        else:
            stop_token_prediction = tf.nn.sigmoid(stop_token_prediction)
            finished = tf.cast(tf.round(stop_token_prediction), tf.bool)
            # A single scalar flag: stop only when the whole batch is done.
            finished = tf.reduce_all(finished)
            # Feed back the last predicted frame of the reduction group.
            next_inputs = outputs[:, -self.config.n_mels :]
            next_state = state
            return (finished, next_inputs, next_state)

    def set_batch_size(self, batch_size):
        self._batch_size = batch_size
class TFTacotronLocationSensitiveAttention(BahdanauAttention):
    """Tacotron-2 location-sensitive attention.

    Bahdanau (additive) attention whose energies also include features
    computed from the previous alignment state by a 1-D convolution.
    """

    def __init__(
        self,
        config,
        memory,
        mask_encoder=True,
        memory_sequence_length=None,
        is_cumulate=True,
    ):
        """Init variables.

        Args:
            config: model config (attention_dim, attention_filters,
                attention_kernel, encoder_lstm_units).
            memory: encoder outputs, or None when ``setup_memory`` is called later.
            mask_encoder: only when exactly ``True`` is
                ``memory_sequence_length`` forwarded to the base class.
            memory_sequence_length: lengths used to mask padded memory steps.
            is_cumulate: accumulate alignments into the state instead of
                replacing it each step.
        """
        if mask_encoder is True:
            memory_length = memory_sequence_length
        else:
            memory_length = None
        super().__init__(
            units=config.attention_dim,
            memory=memory,
            memory_sequence_length=memory_length,
            probability_fn="softmax",
            name="LocationSensitiveAttention",
        )
        self.location_convolution = tf.keras.layers.Conv1D(
            filters=config.attention_filters,
            kernel_size=config.attention_kernel,
            padding="same",
            use_bias=False,
            name="location_conv",
        )
        self.location_layer = tf.keras.layers.Dense(
            units=config.attention_dim, use_bias=False, name="location_layer"
        )
        self.v = tf.keras.layers.Dense(1, use_bias=True, name="scores_attention")
        self.config = config
        self.is_cumulate = is_cumulate
        self.use_window = False

    def setup_window(self, win_front=2, win_back=4):
        """Enable windowed attention around the previous alignment peak.

        Must run after ``setup_memory`` because it reads ``self.keys``.
        """
        self.win_front = tf.constant(win_front, tf.int32)
        self.win_back = tf.constant(win_back, tf.int32)
        positions = tf.expand_dims(tf.range(tf.shape(self.keys)[1]), 0)
        # [batch_size, max_time] grid of memory positions.
        self._indices = tf.tile(positions, [tf.shape(self.keys)[0], 1])
        self.use_window = True

    def _compute_window_mask(self, max_alignments):
        """Return a float mask that is 1.0 outside the attention window.

        Args:
            max_alignments: int32 tensor ``[batch_size]``, previous peak positions.

        Returns:
            float32 tensor ``[batch_size, max_length]``.
        """
        center = tf.expand_dims(max_alignments, 1)  # [batch_size, 1]
        before_window = tf.cast(
            (self._indices < center - self.win_front), tf.float32
        )
        after_window = tf.cast((self._indices > center + self.win_back), tf.float32)
        return before_window + after_window

    def __call__(self, inputs, training=False):
        """Compute one attention step.

        Args:
            inputs: ``[query, state, prev_max_alignments]``: decoder query,
                previous alignment state and previous peak positions.
            training: accepted for call-site compatibility; unused.

        Returns:
            ``(context, alignments, state)``; ``state`` is the cumulative sum
            of alignments when ``is_cumulate`` else the current alignments.
        """
        query, prev_state, prev_max_alignments = inputs
        if self.query_layer:
            query = self.query_layer(query)
        query = tf.expand_dims(query, 1)

        conv_features = self.location_convolution(tf.expand_dims(prev_state, axis=2))
        location_features = self.location_layer(conv_features)
        energy = self._location_sensitive_score(query, location_features, self.keys)

        if self.use_window is True:
            # Push energies outside the window towards -inf before softmax.
            energy = energy + self._compute_window_mask(prev_max_alignments) * -1e20

        alignments = self.probability_fn(energy, prev_state)
        next_state = alignments + prev_state if self.is_cumulate else alignments

        context = tf.reduce_sum(tf.expand_dims(alignments, 2) * self.values, 1)
        return context, alignments, next_state

    def _location_sensitive_score(self, W_query, W_fil, W_keys):
        """Additive energy ``v(tanh(W_keys + W_query + W_fil))`` with last axis squeezed."""
        combined = tf.nn.tanh(W_keys + W_query + W_fil)
        return tf.squeeze(self.v(combined), -1)

    def get_initial_state(self, batch_size, size):
        """Return zero initial alignments of shape ``[batch_size, size]``."""
        state_shape = [batch_size, size]
        return tf.zeros(shape=state_shape, dtype=tf.float32)

    def get_initial_context(self, batch_size):
        """Return a zero context of width ``2 * encoder_lstm_units``."""
        context_dim = self.config.encoder_lstm_units * 2
        return tf.zeros(shape=[batch_size, context_dim], dtype=tf.float32)
class TFTacotronPrenet(tf.keras.layers.Layer):
    """Tacotron-2 decoder pre-net: Dense layers, each followed by dropout."""

    def __init__(self, config, **kwargs):
        """Create ``config.n_prenet_layers`` Dense layers and a shared dropout."""
        super().__init__(**kwargs)
        self.prenet_dense = [
            tf.keras.layers.Dense(
                units=config.prenet_units,
                activation=ACT2FN[config.prenet_activation],
                name="dense_._{}".format(layer_idx),
            )
            for layer_idx in range(config.n_prenet_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(
            rate=config.prenet_dropout_rate, name="dropout"
        )

    def call(self, inputs, training=False):
        """Apply Dense + dropout per layer.

        ``training`` is ignored: dropout always runs with ``training=True``,
        including at inference (presumably the Tacotron-2 pre-net convention
        of keeping dropout on to add output variation).
        """
        hidden = inputs
        for dense in self.prenet_dense:
            hidden = self.dropout(dense(hidden), training=True)
        return hidden
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 post-net: conv/batch-norm stack, tanh except the last layer."""

    def __init__(self, config, **kwargs):
        """Create ``config.n_conv_postnet`` conv blocks; the last one is linear."""
        super().__init__(**kwargs)
        n_layers = config.n_conv_postnet
        self.conv_batch_norm = [
            TFTacotronConvBatchNorm(
                filters=config.postnet_conv_filters,
                kernel_size=config.postnet_conv_kernel_sizes,
                dropout_rate=config.postnet_dropout_rate,
                activation="tanh" if idx < n_layers - 1 else "identity",
                name_idx=idx,
            )
            for idx in range(n_layers)
        ]

    def call(self, inputs, training=False):
        """Apply the conv blocks in sequence."""
        hidden = inputs
        for block in self.conv_batch_norm:
            hidden = block(hidden, training=training)
        return hidden
# Recurrent state carried from one TFTacotronDecoderCell step to the next.
TFTacotronDecoderCellState = collections.namedtuple(
    "TFTacotronDecoderCellState",
    "attention_lstm_state decoder_lstms_state context time state "
    "alignment_history max_alignments",
)
# Per-step decoder output: mel frames, stop-token logits and sample ids.
TFDecoderOutput = collections.namedtuple(
    "TFDecoderOutput", ["mel_output", "token_output", "sample_id"]
)
class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
    """Tacotron-2 custom decoder cell.

    One step: prenet -> attention LSTM -> location-sensitive attention ->
    decoder LSTM stack -> frame projection and stop-token projection.
    """

    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
        """Init variables.

        Args:
            config: model config.
            enable_tflite_convertible: when True, no alignment history is
                recorded (an empty tuple replaces the TensorArray).
            **kwargs: forwarded to the Keras layer constructor.

        Raises:
            ValueError: if ``config.attention_type`` is not "lsa".
        """
        super().__init__(**kwargs)
        self.enable_tflite_convertible = enable_tflite_convertible
        self.prenet = TFTacotronPrenet(config, name="prenet")

        # define lstm cell on decoder.
        # TODO(@dathudeptrai) switch to zone-out lstm.
        self.attention_lstm = tf.keras.layers.LSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell"
        )
        lstm_cells = []
        for i in range(config.n_lstm_decoder):
            lstm_cell = tf.keras.layers.LSTMCell(
                units=config.decoder_lstm_units, name="lstm_cell_._{}".format(i)
            )
            lstm_cells.append(lstm_cell)
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms"
        )

        # define attention layer (memory is attached later via setup_memory).
        if config.attention_type == "lsa":
            self.attention_layer = TFTacotronLocationSensitiveAttention(
                config,
                memory=None,
                mask_encoder=True,
                memory_sequence_length=None,
                is_cumulate=True,
            )
        else:
            raise ValueError("Only lsa (location-sensitive attention) is supported")

        # frame, stop projection layer.
        self.frame_projection = tf.keras.layers.Dense(
            units=config.n_mels * config.reduction_factor, name="frame_projection"
        )
        self.stop_projection = tf.keras.layers.Dense(
            units=config.reduction_factor, name="stop_projection"
        )
        self.config = config

    def set_alignment_size(self, alignment_size):
        """Set the encoder length; call before get_initial_state/state_size."""
        self.alignment_size = alignment_size

    # AbstractRNNCell declares output_size/state_size as properties, and
    # TFTacotronDecoder.output_size reads ``cell.output_size`` as a value.
    @property
    def output_size(self):
        """Return output (mel) size: ``n_mels * reduction_factor``."""
        return self.frame_projection.units

    @property
    def state_size(self):
        """Return hidden state size, one entry per TFTacotronDecoderCellState field."""
        return TFTacotronDecoderCellState(
            attention_lstm_state=self.attention_lstm.state_size,
            decoder_lstms_state=self.decoder_lstms.state_size,
            # Context width equals the encoder output width (see get_initial_state).
            context=self.config.encoder_lstm_units * 2,
            time=tf.TensorShape([]),
            state=self.alignment_size,
            alignment_history=(),
            # One argmax position (scalar) per batch element.
            max_alignments=tf.TensorShape([]),
        )

    def get_initial_state(self, batch_size):
        """Get zero initial states for ``batch_size`` utterances."""
        initial_attention_lstm_cell_states = self.attention_lstm.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        initial_decoder_lstms_cell_states = self.decoder_lstms.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        initial_context = tf.zeros(
            shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32
        )
        initial_state = self.attention_layer.get_initial_state(
            batch_size, size=self.alignment_size
        )
        if self.enable_tflite_convertible:
            initial_alignment_history = ()
        else:
            initial_alignment_history = tf.TensorArray(
                dtype=tf.float32, size=0, dynamic_size=True
            )
        return TFTacotronDecoderCellState(
            attention_lstm_state=initial_attention_lstm_cell_states,
            decoder_lstms_state=initial_decoder_lstms_cell_states,
            time=tf.zeros([], dtype=tf.int32),
            context=initial_context,
            state=initial_state,
            alignment_history=initial_alignment_history,
            max_alignments=tf.zeros([batch_size], dtype=tf.int32),
        )

    def call(self, inputs, states, training=False):
        """Run one decoder step.

        Args:
            inputs: previous mel frame from the sampler, ``[batch_size, n_mels]``.
            states: TFTacotronDecoderCellState from the previous step.
            training: forwarded to the prenet and attention layer.

        Returns:
            ``((decoder_outputs, stop_tokens), new_states)``; decoder_outputs
            is ``[batch_size, n_mels * reduction_factor]`` and stop_tokens
            ``[batch_size, reduction_factor]`` (pre-sigmoid).
        """
        # 1. apply prenet for decoder_input.
        prenet_out = self.prenet(inputs, training=training)

        # 2. attention LSTM on [prenet output, previous context].
        attention_lstm_input = tf.concat([prenet_out, states.context], axis=-1)
        attention_lstm_output, next_attention_lstm_state = self.attention_lstm(
            attention_lstm_input, states.attention_lstm_state
        )

        # 3. compute context, alignment and cumulative alignment.
        context, alignments, state = self.attention_layer(
            [attention_lstm_output, states.state, states.max_alignments],
            training=training,
        )

        # 4. run decoder lstm(s).
        decoder_lstms_input = tf.concat([attention_lstm_output, context], axis=-1)
        decoder_lstms_output, next_decoder_lstms_state = self.decoder_lstms(
            decoder_lstms_input, states.decoder_lstms_state
        )

        # 5. compute frame feature and stop token.
        projection_inputs = tf.concat([decoder_lstms_output, context], axis=-1)
        decoder_outputs = self.frame_projection(projection_inputs)
        stop_inputs = tf.concat([decoder_lstms_output, decoder_outputs], axis=-1)
        stop_tokens = self.stop_projection(stop_inputs)

        # 6. save alignment history for visualization (not in TFLite mode).
        if self.enable_tflite_convertible:
            alignment_history = ()
        else:
            alignment_history = states.alignment_history.write(
                states.time, alignments
            )

        # 7. return new states.
        new_states = TFTacotronDecoderCellState(
            attention_lstm_state=next_attention_lstm_state,
            decoder_lstms_state=next_decoder_lstms_state,
            time=states.time + 1,
            context=context,
            state=state,
            alignment_history=alignment_history,
            max_alignments=tf.argmax(alignments, -1, output_type=tf.int32),
        )
        return (decoder_outputs, stop_tokens), new_states
class TFTacotronDecoder(Decoder):
    """Tacotron-2 Decoder stepping TFTacotronDecoderCell under ``dynamic_decode``."""

    def __init__(
        self,
        decoder_cell,
        decoder_sampler,
        output_layer=None,
        enable_tflite_convertible=False,
    ):
        """Initial variables.

        Args:
            decoder_cell: the TFTacotronDecoderCell to step.
            decoder_sampler: sampler producing next inputs and finished flags.
            output_layer: optional layer applied to the cell's mel outputs.
            enable_tflite_convertible: report a ``[1]`` sample-id shape.
        """
        self.cell = decoder_cell
        self.sampler = decoder_sampler
        self.output_layer = output_layer
        self.enable_tflite_convertible = enable_tflite_convertible

    def setup_decoder_init_state(self, decoder_init_state):
        """Store the cell state that ``initialize`` will return."""
        self.initial_state = decoder_init_state

    def initialize(self, **kwargs):
        """Return ``(finished, first_inputs, initial_state)``."""
        return self.sampler.initialize() + (self.initial_state,)

    # The tfa ``Decoder`` contract and ``dynamic_decode`` read the three
    # members below as attributes, so they must be properties.
    @property
    def output_size(self):
        """Per-step output shapes, as a TFDecoderOutput of TensorShapes."""
        return TFDecoderOutput(
            mel_output=tf.nest.map_structure(
                lambda shape: tf.TensorShape(shape), self.cell.output_size
            ),
            token_output=tf.TensorShape(self.sampler.reduction_factor),
            sample_id=tf.TensorShape([1])
            if self.enable_tflite_convertible
            else self.sampler.sample_ids_shape,  # tf.TensorShape([])
        )

    @property
    def output_dtype(self):
        """Per-step output dtypes, as a TFDecoderOutput."""
        return TFDecoderOutput(tf.float32, tf.float32, self.sampler.sample_ids_dtype)

    @property
    def batch_size(self):
        """Batch size as configured on the sampler."""
        return self.sampler._batch_size

    def step(self, time, inputs, state, training=False):
        """Run one cell step; return (outputs, next_state, next_inputs, finished)."""
        (mel_outputs, stop_tokens), cell_state = self.cell(
            inputs, state, training=training
        )
        if self.output_layer is not None:
            mel_outputs = self.output_layer(mel_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=mel_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time,
            outputs=mel_outputs,
            state=cell_state,
            sample_ids=sample_ids,
            stop_token_prediction=stop_tokens,
            training=training,
        )
        outputs = TFDecoderOutput(mel_outputs, stop_tokens, sample_ids)
        return (outputs, next_state, next_inputs, finished)
class TFTacotron2(BaseModel):
    """Tensorflow tacotron-2 model."""

    def __init__(self, config, **kwargs):
        """Initialize tacotron-2 layers.

        Args:
            config: model config.
            **kwargs: forwarded to ``BaseModel``. ``enable_tflite_convertible``
                (default False) is popped first and configures the decoder.
        """
        enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        # ``self`` is already bound by super(); passing it again would hand
        # the model object to the parent constructor as a stray argument.
        super().__init__(**kwargs)
        self.encoder = TFTacotronEncoder(config, name="encoder")
        self.decoder_cell = TFTacotronDecoderCell(
            config,
            name="decoder_cell",
            enable_tflite_convertible=enable_tflite_convertible,
        )
        self.decoder = TFTacotronDecoder(
            self.decoder_cell,
            Tacotron2Sampler(config),
            enable_tflite_convertible=enable_tflite_convertible,
        )
        self.postnet = TFTacotronPostnet(config, name="post_net")
        self.post_projection = tf.keras.layers.Dense(
            units=config.n_mels, name="residual_projection"
        )

        self.use_window_mask = False
        self.maximum_iterations = 4000
        self.enable_tflite_convertible = enable_tflite_convertible
        self.config = config

    def setup_window(self, win_front, win_back):
        """Enable windowed attention; call only for inference."""
        self.use_window_mask = True
        self.win_front = win_front
        self.win_back = win_back

    def setup_maximum_iterations(self, maximum_iterations):
        """Set the decode-step cap; call only for inference."""
        self.maximum_iterations = maximum_iterations

    def _build(self):
        """Run one dummy teacher-forced pass so every weight gets created."""
        input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])
        input_lengths = np.array([9])
        speaker_ids = np.array([0])
        # Dummy targets must match the configured number of mel bins.
        mel_outputs = np.random.normal(size=(1, 50, self.config.n_mels)).astype(
            np.float32
        )
        mel_lengths = np.array([50])
        self(
            input_ids,
            input_lengths,
            speaker_ids,
            mel_outputs,
            mel_lengths,
            10,
            training=True,
        )

    def _encode_inputs(self, input_ids, input_lengths, speaker_ids, training):
        """Mask padded positions from ``input_lengths`` and run the encoder."""
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )
        return self.encoder([input_ids, speaker_ids, input_mask], training=training)

    def _prepare_decoder(self, encoder_hidden_states, input_lengths, batch_size, window):
        """Set alignment size, initial state, attention memory and optional window.

        The sampler must already be configured by the caller. ``window`` is a
        ``(win_front, win_back)`` tuple, or None to leave windowing unchanged.
        """
        cell = self.decoder.cell
        cell.set_alignment_size(tf.shape(encoder_hidden_states)[1])
        self.decoder.setup_decoder_init_state(cell.get_initial_state(batch_size))
        cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if window is not None:
            # Must run after setup_memory (the window reads attention keys).
            win_front, win_back = window
            cell.attention_layer.setup_window(win_front=win_front, win_back=win_back)

    def _finalize_outputs(
        self,
        frames_prediction,
        stop_token_prediction,
        final_decoder_state,
        batch_size,
        training,
        strip_tflite_padding,
    ):
        """Reshape decoder frames, apply the postnet, and gather alignments.

        Returns:
            ``(decoder_outputs, mel_outputs, stop_token_prediction,
            alignment_history)``; alignment_history is ``()`` when
            ``strip_tflite_padding`` is True.
        """
        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

        residual = self.postnet(decoder_outputs, training=training)
        residual_projection = self.post_projection(residual)
        mel_outputs = decoder_outputs + residual_projection

        if strip_tflite_padding:
            # Drop frames whose summed |value| truncates to 0 after the int cast.
            # NOTE(review): the int32 cast also drops real frames with total
            # magnitude < 1 — confirm this threshold is intended.
            mask = tf.math.not_equal(
                tf.cast(
                    tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
                ),
                0,
            )
            decoder_outputs = tf.expand_dims(
                tf.boolean_mask(decoder_outputs, mask), axis=0
            )
            mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
            alignment_history = ()
        else:
            alignment_history = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0]
            )
        return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history

    def call(
        self,
        input_ids,
        input_lengths,
        speaker_ids,
        mel_gts,
        mel_lengths,
        maximum_iterations=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
        **kwargs,
    ):
        """Teacher-forced forward pass.

        Returns:
            ``(decoder_outputs, mel_outputs, stop_token_prediction,
            alignment_history)``.
        """
        encoder_hidden_states = self._encode_inputs(
            input_ids, input_lengths, speaker_ids, training=training
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        # Ground-truth mels drive the sampler (teacher forcing).
        self.decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
        window = (win_front, win_back) if use_window_mask else None
        self._prepare_decoder(encoder_hidden_states, input_lengths, batch_size, window)

        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=training,
        )
        return self._finalize_outputs(
            frames_prediction,
            stop_token_prediction,
            final_decoder_state,
            batch_size,
            training=training,
            strip_tflite_padding=self.enable_tflite_convertible,
        )

    def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Free-running synthesis; stops on stop tokens or ``maximum_iterations``."""
        encoder_hidden_states = self._encode_inputs(
            input_ids, input_lengths, speaker_ids, training=False
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        self.decoder.sampler.set_batch_size(batch_size)
        window = (self.win_front, self.win_back) if self.use_window_mask else None
        self._prepare_decoder(encoder_hidden_states, input_lengths, batch_size, window)

        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder, maximum_iterations=self.maximum_iterations, training=False
        )
        return self._finalize_outputs(
            frames_prediction,
            stop_token_prediction,
            final_decoder_state,
            batch_size,
            training=False,
            strip_tflite_padding=False,
        )

    def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Free-running synthesis honoring ``enable_tflite_convertible``."""
        encoder_hidden_states = self._encode_inputs(
            input_ids, input_lengths, speaker_ids, training=False
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        self.decoder.sampler.set_batch_size(batch_size)
        window = (self.win_front, self.win_back) if self.use_window_mask else None
        self._prepare_decoder(encoder_hidden_states, input_lengths, batch_size, window)

        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=self.maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=False,
        )
        return self._finalize_outputs(
            frames_prediction,
            stop_token_prediction,
            final_decoder_state,
            batch_size,
            training=False,
            strip_tflite_padding=self.enable_tflite_convertible,
        )