# Comparative-Analysis-of-Speech-Synthesis-Models / TensorFlowTTS
# tensorflow_tts/models/fastspeech.py
# -*- coding: utf-8 -*-
# Copyright 2020 The FastSpeech Authors, The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Model modules for FastSpeech."""

import numpy as np
import tensorflow as tf

from tensorflow_tts.models import BaseModel
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal kernel initializer.

    Args:
        initializer_range: float, stddev of the truncated normal distribution.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` instance.
    """
    initializer = tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
    return initializer
def gelu(x):
    """Exact Gaussian Error Linear Unit, computed via the error function."""
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    gated = x * cdf
    return gated
def gelu_new(x):
    """Smoother (tanh-based) approximation of the Gaussian Error Linear Unit."""
    cubic_term = x + 0.044715 * tf.pow(x, 3)
    cdf = 0.5 * (1.0 + tf.tanh(np.sqrt(2 / np.pi) * cubic_term))
    return x * cdf
def swish(x):
    """Swish activation (x * sigmoid(x)), delegated to `tf.nn.swish`."""
    return tf.nn.swish(x)
def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    softplus_x = tf.math.softplus(x)
    return x * tf.math.tanh(softplus_x)
# Map a config string (e.g. `config.hidden_act`) to its activation callable.
ACT2FN = {
    "identity": tf.keras.layers.Activation("linear"),
    "tanh": tf.keras.layers.Activation("tanh"),
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "mish": tf.keras.layers.Activation(mish),
}
class TFEmbedding(tf.keras.layers.Embedding):
    """Faster version of `tf.keras.layers.Embedding`.

    Overrides `call` with a plain `tf.gather` on the embedding matrix.
    The original also carried a no-op `__init__` that merely forwarded to
    `super().__init__`; it has been removed as redundant (behavior identical).
    """

    def call(self, inputs):
        """Look up embeddings for integer ids.

        Args:
            inputs: Tensor of ids (cast to int32 before the lookup).

        Returns:
            Tensor of embeddings with shape inputs.shape + [output_dim].
        """
        ids = tf.cast(inputs, tf.int32)
        return tf.gather(self.embeddings, ids)
class TFFastSpeechEmbeddings(tf.keras.layers.Layer):
    """Construct charactor/phoneme/positional/speaker embeddings."""

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: model config exposing `vocab_size`,
                `encoder_self_attention_params.hidden_size`,
                `initializer_range`, `max_position_embeddings`, `n_speakers`.
        """
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.encoder_self_attention_params.hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        # Fixed (non-trainable) sinusoidal position table. Row 0 is an
        # all-zero pad entry, so valid positions run 1..max_position_embeddings.
        self.position_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            self.hidden_size,
            weights=[
                self._sincos_embedding(
                    self.hidden_size, self.config.max_position_embeddings
                )
            ],
            name="position_embeddings",
            trainable=False,
        )

        if config.n_speakers > 1:
            # Multi-speaker model: learned speaker embedding, projected via
            # Dense + softplus (see `_embedding`) before being added.
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                self.hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=self.hidden_size, name="speaker_fc"
            )

    def build(self, input_shape):
        """Build shared charactor/phoneme embedding layers."""
        with tf.name_scope("charactor_embeddings"):
            self.charactor_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Get charactor embeddings of inputs.

        Args:
            1. charactor, Tensor (int32) shape [batch_size, length].
            2. speaker_id, Tensor (int32) shape [batch_size]

        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, speaker_ids = inputs

        input_shape = tf.shape(input_ids)
        seq_length = input_shape[1]

        # Positions are 1-based; index 0 is the zero pad row of the table.
        position_ids = tf.range(1, seq_length + 1, dtype=tf.int32)[tf.newaxis, :]

        # create embeddings
        inputs_embeds = tf.gather(self.charactor_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # sum embedding
        embeddings = inputs_embeds + tf.cast(position_embeddings, inputs_embeds.dtype)

        if self.config.n_speakers > 1:
            speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast over the time axis.
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            embeddings += extended_speaker_features

        return embeddings

    def _sincos_embedding(
        self, hidden_size, max_positional_embedding,
    ):
        """Build the sinusoidal position table.

        Returns:
            np.ndarray of shape [max_positional_embedding + 1, hidden_size];
            row 0 is zeroed out to act as the padding entry.
        """
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / hidden_size)
                    for i in range(hidden_size)
                ]
                for pos in range(max_positional_embedding + 1)
            ]
        )

        # Even channels get sin, odd channels get cos.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc

    def resize_positional_embeddings(self, new_size):
        """Rebuild the position table to support sequences up to `new_size`."""
        self.position_embeddings = TFEmbedding(
            new_size + 1,
            self.hidden_size,
            weights=[self._sincos_embedding(self.hidden_size, new_size)],
            name="position_embeddings",
            trainable=False,
        )
class TFFastSpeechSelfAttention(tf.keras.layers.Layer):
    """Self attention module for fastspeech."""

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: attention config with `hidden_size`, `num_attention_heads`,
                `attention_head_size`, `output_attentions`,
                `attention_probs_dropout_prob`, `initializer_range`.

        Raises:
            ValueError: if `hidden_size` is not divisible by
                `num_attention_heads`.
        """
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        self.all_head_size = self.num_attention_heads * config.attention_head_size

        # Q/K/V projections, each to [batch, length, all_head_size].
        self.query = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query",
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key",
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value",
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, x, batch_size):
        """Transpose to calculate attention scores.

        Reshapes [batch, length, all_head_size] to
        [batch, num_heads, length, head_size].
        """
        x = tf.reshape(
            x,
            (batch_size, -1, self.num_attention_heads, self.config.attention_head_size),
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        """Call logic.

        Args:
            inputs: [hidden_states, attention_mask]; mask is 1 for real
                tokens, 0 for padding, or None to skip masking.

        Returns:
            Tuple (context_layer,) plus (attention_probs,) when
            `output_attentions` is enabled.
        """
        hidden_states, attention_mask = inputs

        batch_size = tf.shape(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(
            tf.shape(key_layer)[-1], attention_scores.dtype
        )  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # extended_attention_masks for self attention encoder.
            extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
            extended_attention_mask = tf.cast(
                extended_attention_mask, attention_scores.dtype
            )
            # Padded positions get a large negative bias so softmax -> ~0 there.
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
            attention_scores = attention_scores + extended_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)
        context_layer = tf.matmul(attention_probs, value_layer)

        # Merge heads back into a single feature dimension.
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))

        outputs = (
            (context_layer, attention_probs)
            if self.output_attentions
            else (context_layer,)
        )
        return outputs
class TFFastSpeechSelfOutput(tf.keras.layers.Layer):
    """Projection + residual + LayerNorm applied after self-attention."""

    def __init__(self, config, **kwargs):
        """Init projection, normalization and dropout sub-layers."""
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Project the attention output, add the residual, then normalize.

        Args:
            inputs: [attention output, residual input tensor].

        Returns:
            Layer-normalized tensor of shape [batch, length, hidden_size].
        """
        attn_output, residual = inputs
        projected = self.dense(attn_output)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + residual)
class TFFastSpeechAttention(tf.keras.layers.Layer):
    """Self-attention followed by its output projection, with padding masking."""

    def __init__(self, config, **kwargs):
        """Init self-attention and output sub-layers."""
        super().__init__(**kwargs)
        self.self_attention = TFFastSpeechSelfAttention(config, name="self")
        self.dense_output = TFFastSpeechSelfOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run attention and zero out padded positions.

        Args:
            inputs: [input_tensor, attention_mask].

        Returns:
            Tuple of (masked attention output, [attention probs if enabled]).
        """
        input_tensor, attention_mask = inputs

        self_outputs = self.self_attention(
            [input_tensor, attention_mask], training=training
        )
        attention_output = self.dense_output(
            [self_outputs[0], input_tensor], training=training
        )
        mask = tf.cast(tf.expand_dims(attention_mask, 2), attention_output.dtype)
        masked_attention_output = attention_output * mask
        # Forward attention probabilities when the config requests them.
        return (masked_attention_output,) + self_outputs[1:]
class TFFastSpeechIntermediate(tf.keras.layers.Layer):
    """Position-wise part of the FFT block: conv -> activation -> conv."""

    def __init__(self, config, **kwargs):
        """Init the two 1D convolutions and the configured activation."""
        super().__init__(**kwargs)
        self.conv1d_1 = tf.keras.layers.Conv1D(
            config.intermediate_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_1",
        )
        self.conv1d_2 = tf.keras.layers.Conv1D(
            config.hidden_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_2",
        )
        # `hidden_act` is either a key into ACT2FN or already a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, inputs):
        """Apply conv -> activation -> conv, then re-mask padded frames.

        Args:
            inputs: [hidden_states, attention_mask].

        Returns:
            Masked tensor of shape [batch, length, hidden_size].
        """
        hidden_states, attention_mask = inputs

        features = self.conv1d_1(hidden_states)
        features = self.intermediate_act_fn(features)
        features = self.conv1d_2(features)

        mask = tf.cast(tf.expand_dims(attention_mask, 2), features.dtype)
        return features * mask
class TFFastSpeechOutput(tf.keras.layers.Layer):
    """Residual + LayerNorm block applied after the intermediate convs."""

    def __init__(self, config, **kwargs):
        """Init normalization and dropout sub-layers."""
        super().__init__(**kwargs)
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Apply dropout, add the residual, then layer-normalize."""
        hidden_states, residual = inputs
        dropped = self.dropout(hidden_states, training=training)
        return self.LayerNorm(dropped + residual)
class TFFastSpeechLayer(tf.keras.layers.Layer):
    """Single FFT block from the FastSpeech paper: attention + conv feed-forward."""

    def __init__(self, config, **kwargs):
        """Init attention, intermediate and output sub-modules."""
        super().__init__(**kwargs)
        self.attention = TFFastSpeechAttention(config, name="attention")
        self.intermediate = TFFastSpeechIntermediate(config, name="intermediate")
        self.bert_output = TFFastSpeechOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run one FFT block and zero out padded positions.

        Args:
            inputs: [hidden_states, attention_mask].

        Returns:
            Tuple of (masked layer output, [attention probs if enabled]).
        """
        hidden_states, attention_mask = inputs

        attention_outputs = self.attention(
            [hidden_states, attention_mask], training=training
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(
            [attention_output, attention_mask], training=training
        )
        layer_output = self.bert_output(
            [intermediate_output, attention_output], training=training
        )

        mask = tf.cast(tf.expand_dims(attention_mask, 2), layer_output.dtype)
        # Keep attention probabilities (if any) alongside the masked output.
        return (layer_output * mask,) + attention_outputs[1:]
class TFFastSpeechEncoder(tf.keras.layers.Layer):
    """Fast Speech encoder module: a stack of FFT blocks."""

    def __init__(self, config, **kwargs):
        """Init the stack of `TFFastSpeechLayer`s.

        Args:
            config: self-attention config with `num_hidden_layers`,
                `output_attentions` and `output_hidden_states`.
        """
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [
            TFFastSpeechLayer(config, name="layer_._{}".format(i))
            for i in range(config.num_hidden_layers)
        ]

    def call(self, inputs, training=False):
        """Run the FFT stack.

        Args:
            inputs: [hidden_states, attention_mask].

        Returns:
            Tuple of (last hidden states, [all hidden states], [all
            attentions]); optional entries are controlled by config flags.
        """
        hidden_states, attention_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        # Fixed: the original used `for _, layer_module in enumerate(self.layer)`
        # and discarded the index — iterate the layers directly.
        for layer_module in self.layer:
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                [hidden_states, attention_mask], training=training
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
class TFFastSpeechDecoder(TFFastSpeechEncoder):
    """Fast Speech decoder module."""

    def __init__(self, config, **kwargs):
        """Init the decoder FFT stack plus decoder-side embeddings.

        Kwargs:
            is_compatible_encoder: bool (default True). When False, a Dense
                projection maps incoming hidden states to the decoder's
                `hidden_size` before the stack runs.
        """
        self.is_compatible_encoder = kwargs.pop("is_compatible_encoder", True)
        super().__init__(config, **kwargs)
        self.config = config

        # create decoder positional embedding (fixed sinusoidal table; row 0
        # is the zero pad entry).
        self.decoder_positional_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            config.hidden_size,
            weights=[self._sincos_embedding()],
            name="position_embeddings",
            trainable=False,
        )

        if self.is_compatible_encoder is False:
            self.project_compatible_decoder = tf.keras.layers.Dense(
                units=config.hidden_size, name="project_compatible_decoder"
            )

        if config.n_speakers > 1:
            # Multi-speaker model: same speaker-conditioning scheme as the
            # encoder embeddings (embedding -> Dense -> softplus -> add).
            self.decoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.hidden_size, name="speaker_fc"
            )

    def call(self, inputs, training=False):
        """Run the decoder stack.

        Args:
            inputs: [hidden_states, speaker_ids, encoder_mask, decoder_pos].

        Returns:
            Same tuple structure as `TFFastSpeechEncoder.call`.
        """
        hidden_states, speaker_ids, encoder_mask, decoder_pos = inputs

        if self.is_compatible_encoder is False:
            hidden_states = self.project_compatible_decoder(hidden_states)

        # calculate new hidden states.
        hidden_states += tf.cast(
            self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype
        )

        if self.config.n_speakers > 1:
            speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast over the time axis.
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            hidden_states += extended_speaker_features

        return super().call([hidden_states, encoder_mask], training=training)

    def _sincos_embedding(self):
        """Build the decoder's sinusoidal position table; row 0 is zeroed for pad."""
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / self.config.hidden_size)
                    for i in range(self.config.hidden_size)
                ]
                for pos in range(self.config.max_position_embeddings + 1)
            ]
        )

        # Even channels get sin, odd channels get cos.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 style postnet: a conv/batch-norm stack refining mel output."""

    def __init__(self, config, **kwargs):
        """Init conv + batch-norm pairs, dropout, and per-layer activations."""
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_postnet):
            # The last conv maps back to num_mels; earlier ones use the
            # configured postnet filter width.
            is_last = i >= config.n_conv_postnet - 1
            conv = tf.keras.layers.Conv1D(
                filters=config.num_mels if is_last else config.postnet_conv_filters,
                kernel_size=config.postnet_conv_kernel_sizes,
                padding="same",
                name="conv_._{}".format(i),
            )
            batch_norm = tf.keras.layers.BatchNormalization(
                axis=-1, name="batch_norm_._{}".format(i)
            )
            self.conv_batch_norm.append((conv, batch_norm))
        self.dropout = tf.keras.layers.Dropout(
            rate=config.postnet_dropout_rate, name="dropout"
        )
        # tanh on every layer except the final (linear) one.
        self.activation = [tf.nn.tanh] * (config.n_conv_postnet - 1) + [tf.identity]

    def call(self, inputs, training=False):
        """Refine mel frames, masking out padded positions at the end.

        Args:
            inputs: [mel tensor, mask].
        """
        outputs, mask = inputs
        extended_mask = tf.cast(tf.expand_dims(mask, axis=2), outputs.dtype)
        for idx, (conv, bn) in enumerate(self.conv_batch_norm):
            outputs = self.activation[idx](bn(conv(outputs)))
            outputs = self.dropout(outputs, training=training)
        return outputs * extended_mask
class TFFastSpeechDurationPredictor(tf.keras.layers.Layer):
    """FastSpeech duration predictor: conv stack + dense, one value per frame."""

    def __init__(self, config, **kwargs):
        """Init conv/LayerNorm/relu6/dropout stack and the final projection."""
        super().__init__(**kwargs)
        self.conv_layers = []
        for i in range(config.num_duration_conv_layers):
            self.conv_layers.append(
                tf.keras.layers.Conv1D(
                    config.duration_predictor_filters,
                    config.duration_predictor_kernel_sizes,
                    padding="same",
                    name="conv_._{}".format(i),
                )
            )
            self.conv_layers.append(
                tf.keras.layers.LayerNormalization(
                    epsilon=config.layer_norm_eps, name="LayerNorm_._{}".format(i)
                )
            )
            self.conv_layers.append(tf.keras.layers.Activation(tf.nn.relu6))
            self.conv_layers.append(
                tf.keras.layers.Dropout(config.duration_predictor_dropout_probs)
            )
        self.conv_layers_sequence = tf.keras.Sequential(self.conv_layers)
        self.output_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        """Predict a non-negative duration value per encoder frame.

        Args:
            inputs: [encoder_hidden_states, attention_mask].

        Returns:
            Tensor of shape [batch_size, length].
        """
        encoder_hidden_states, attention_mask = inputs
        mask = tf.cast(
            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
        )

        # Zero out padded frames, pass through the conv stack, then project.
        hidden = self.conv_layers_sequence(encoder_hidden_states * mask)
        hidden = self.output_layer(hidden)
        hidden = hidden * mask

        # relu6 keeps the prediction non-negative (and bounded).
        return tf.squeeze(tf.nn.relu6(hidden), -1)
class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
    """FastSpeech lengthregulator module."""

    def __init__(self, config, **kwargs):
        """Init variables.

        Kwargs:
            enable_tflite_convertible: bool (default False). When True, the
                regulator assumes batch_size == 1 and avoids `tf.while_loop`.
        """
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs, training=False):
        """Call logic.

        Args:
            1. encoder_hidden_states, Tensor (float32) shape [batch_size, length, hidden_size]
            2. durations_gt, Tensor (float32/int32) shape [batch_size, length]
        """
        encoder_hidden_states, durations_gt = inputs
        outputs, encoder_masks = self._length_regulator(
            encoder_hidden_states, durations_gt
        )
        return outputs, encoder_masks

    def _length_regulator(self, encoder_hidden_states, durations_gt):
        """Length regulator logic.

        Repeats each encoder frame `durations_gt[b, t]` times along time and
        pads every sequence to the longest expanded length in the batch.
        """
        sum_durations = tf.reduce_sum(durations_gt, axis=-1)  # [batch_size]
        max_durations = tf.reduce_max(sum_durations)

        input_shape = tf.shape(encoder_hidden_states)
        batch_size = input_shape[0]
        hidden_size = input_shape[-1]

        # initialize output hidden states and encoder masking.
        if self.enable_tflite_convertible:
            # There is only 1 batch in inference, so we don't have to use
            # `tf.While` op with 3-D output tensor.
            repeats = durations_gt[0]
            real_length = tf.reduce_sum(repeats)
            pad_size = max_durations - real_length
            # masks : [max_durations]
            masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
            repeat_encoder_hidden_states = tf.repeat(
                encoder_hidden_states[0], repeats=repeats, axis=0
            )
            repeat_encoder_hidden_states = tf.expand_dims(
                tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
            )  # [1, max_durations, hidden_size]

            outputs = repeat_encoder_hidden_states
            encoder_masks = masks
        else:
            # Accumulators start empty; the while loop concatenates one
            # expanded batch item per iteration.
            outputs = tf.zeros(
                shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype
            )
            encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

            def condition(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                # Continue while batch items remain.
                return tf.less(i, batch_size)

            def body(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                # Expand one item, right-pad to max_durations, and append it
                # (with its mask) to the accumulators.
                repeats = durations_gt[i]
                real_length = tf.reduce_sum(repeats)
                pad_size = max_durations - real_length
                masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
                repeat_encoder_hidden_states = tf.repeat(
                    encoder_hidden_states[i], repeats=repeats, axis=0
                )
                repeat_encoder_hidden_states = tf.expand_dims(
                    tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
                )  # [1, max_durations, hidden_size]
                outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
                encoder_masks = tf.concat([encoder_masks, masks], axis=0)
                return [
                    i + 1,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ]

            # initialize iteration i.
            i = tf.constant(0, dtype=tf.int32)
            # shape_invariants relax the first two dims of the accumulators,
            # which grow across iterations.
            _, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
                condition,
                body,
                [
                    i,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ],
                shape_invariants=[
                    i.get_shape(),
                    batch_size.get_shape(),
                    tf.TensorShape(
                        [
                            None,
                            None,
                            self.config.encoder_self_attention_params.hidden_size,
                        ]
                    ),
                    tf.TensorShape([None, None]),
                    encoder_hidden_states.get_shape(),
                    durations_gt.get_shape(),
                    max_durations.get_shape(),
                ],
            )

        return outputs, encoder_masks
class TFFastSpeech(BaseModel):
    """TF Fastspeech module."""

    def __init__(self, config, **kwargs):
        """Init layers for fastspeech.

        Kwargs:
            enable_tflite_convertible: bool (default False); forwarded to the
                length regulator.
        """
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.embeddings = TFFastSpeechEmbeddings(config, name="embeddings")
        self.encoder = TFFastSpeechEncoder(
            config.encoder_self_attention_params, name="encoder"
        )
        self.duration_predictor = TFFastSpeechDurationPredictor(
            config, dtype=tf.float32, name="duration_predictor"
        )
        self.length_regulator = TFFastSpeechLengthRegulator(
            config,
            enable_tflite_convertible=self.enable_tflite_convertible,
            name="length_regulator",
        )
        # Decoder may need a projection when encoder/decoder widths differ.
        self.decoder = TFFastSpeechDecoder(
            config.decoder_self_attention_params,
            is_compatible_encoder=config.encoder_self_attention_params.hidden_size
            == config.decoder_self_attention_params.hidden_size,
            name="decoder",
        )
        self.mel_dense = tf.keras.layers.Dense(
            units=config.num_mels, dtype=tf.float32, name="mel_before"
        )
        self.postnet = TFTacotronPostnet(
            config=config, dtype=tf.float32, name="postnet"
        )

        self.setup_inference_fn()

    def _build(self):
        """Dummy input for building model."""
        # fake inputs
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
        self(input_ids, speaker_ids, duration_gts)

    def resize_positional_embeddings(self, new_size):
        """Resize the encoder position table, then rebuild the model weights."""
        self.embeddings.resize_positional_embeddings(new_size)
        self._build()

    def call(
        self, input_ids, speaker_ids, duration_gts, training=False, **kwargs,
    ):
        """Call logic.

        Args:
            input_ids: Tensor (int32) [batch, length]; 0 is the pad id.
            speaker_ids: Tensor (int32) [batch].
            duration_gts: Tensor (int32) [batch, length] ground-truth durations.

        Returns:
            Tuple (mel_before, mel_after, duration_outputs).
        """
        # Pad id 0 defines the attention mask.
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=training)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=training
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]

        # Teacher-forced expansion with ground-truth durations.
        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_gts], training=training
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=training,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=training) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs)
        return outputs

    def _inference(self, input_ids, speaker_ids, speed_ratios, **kwargs):
        """Call logic.

        Same pipeline as `call`, but durations come from the predictor
        (exp(pred) - 1, scaled by `speed_ratios`, rounded to int).
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=False
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]
        # Invert the log-domain duration target used at training time.
        duration_outputs = tf.math.exp(duration_outputs) - 1.0

        if speed_ratios is None:
            speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)

        speed_ratios = tf.expand_dims(speed_ratios, 1)

        duration_outputs = tf.cast(
            tf.math.round(duration_outputs * speed_ratios), tf.int32
        )

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_outputs], training=False
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=False,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=False) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs)
        return outputs

    def setup_inference_fn(self):
        """Wrap `_inference` in `tf.function`s: one general, one batch-1 for TFLite."""
        self.inference = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="speed_ratios"),
            ],
        )

        self.inference_tflite = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="speed_ratios"),
            ],
        )