| import logging |
| import os |
| from typing import Any |
|
|
| import torch |
| from coqpit import Coqpit |
| from torch import nn |
| from trainer.logging.base_dash_logger import BaseDashboardLogger |
| from trainer.logging.tensorboard_logger import TensorboardLogger |
|
|
| from TTS.tts.layers.losses import NLLLoss |
| from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils |
| from TTS.tts.layers.overflow.decoder import Decoder |
| from TTS.tts.layers.overflow.neural_hmm import NeuralHMM |
| from TTS.tts.layers.overflow.plotting_utils import ( |
| get_spec_from_most_probable_state, |
| plot_transition_probabilities_to_numpy, |
| ) |
| from TTS.tts.models.base_tts import BaseTTS |
| from TTS.tts.utils.speakers import SpeakerManager |
| from TTS.tts.utils.text.tokenizer import TTSTokenizer |
| from TTS.tts.utils.visual import plot_alignment, plot_spectrogram |
| from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4 |
|
|
# Module-scoped logger; named after this module so log output can be filtered per file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class Overflow(BaseTTS):
    """OverFlow TTS model.

    Paper::
        https://arxiv.org/abs/2211.06892

    Paper abstract::
        Neural HMMs are a type of neural transducer recently proposed for
        sequence-to-sequence modelling in text-to-speech. They combine the best features
        of classic statistical speech synthesis and modern neural TTS, requiring less
        data and fewer training updates, and are less prone to gibberish output caused
        by neural attention failures. In this paper, we combine neural HMM TTS with
        normalising flows for describing the highly non-Gaussian distribution of speech
        acoustics. The result is a powerful, fully probabilistic model of durations and
        acoustics that can be trained using exact maximum likelihood. Compared to
        dominant flow-based acoustic models, our approach integrates autoregression for
        improved modelling of long-range dependences such as utterance-level prosody.
        Experiments show that a system based on our proposal gives more accurate
        pronunciations and better subjective speech quality than comparable methods,
        whilst retaining the original advantages of neural HMMs. Audio examples and code
        are available at https://shivammehta25.github.io/OverFlow/.

    Note:
        - Neural HMMs use flat start initialization, i.e. they compute the mean and
          std of the dataset and the initial transition probability, and use them to
          initialize the model. This benefits the model and helps with faster learning.
          If you change the dataset or want to regenerate the parameters, change
          `force_generate_statistics` and `mel_statistics_parameter_path` accordingly.

        - To enable multi-GPU training, set `use_grad_checkpointing=False` in config.
          This will significantly increase the memory usage. This is because to compute
          the actual data likelihood (not an approximation using MAS/Viterbi) we must use
          all the states at the previous time step during the forward pass to decide the
          probability distribution at the current step, i.e. the difference between the
          forward algorithm and the Viterbi approximation.

    Check :class:`TTS.tts.configs.overflow.OverFlowConfig` for class arguments.
    """

    def __init__(
        self,
        config: "OverFlowConfig",
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)

        # Mirror every config field as an attribute; the sub-module constructors
        # below read their hyper-parameters from ``self`` (e.g. ``self.out_channels``).
        self.config = config
        for key in config:
            setattr(self, key, config[key])

        self.decoder_output_dim = config.out_channels

        self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features)
        self.neural_hmm = NeuralHMM(
            frame_channels=self.out_channels,
            ar_order=self.ar_order,
            deterministic_transition=self.deterministic_transition,
            encoder_dim=self.encoder_in_out_features,
            prenet_type=self.prenet_type,
            prenet_dim=self.prenet_dim,
            prenet_n_layers=self.prenet_n_layers,
            prenet_dropout=self.prenet_dropout,
            prenet_dropout_at_inference=self.prenet_dropout_at_inference,
            memory_rnn_dim=self.memory_rnn_dim,
            outputnet_size=self.outputnet_size,
            flat_start_params=self.flat_start_params,
            std_floor=self.std_floor,
            use_grad_checkpointing=self.use_grad_checkpointing,
        )

        self.decoder = Decoder(
            self.out_channels,
            self.hidden_channels_dec,
            self.kernel_size_dec,
            self.dilation_rate,
            self.num_flow_blocks_dec,
            self.num_block_layers,
            dropout_p=self.dropout_p_dec,
            num_splits=self.num_splits,
            num_squeeze=self.num_squeeze,
            sigmoid_scale=self.sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )

        # Mel normalisation statistics, saved with the model's state dict. They are
        # registered as *float* placeholders: ``load_state_dict`` copies into existing
        # buffers in place, so integer buffers would silently truncate the float
        # statistics stored in a checkpoint. Mean 0 / std 1 mean "not loaded yet"
        # (see ``preprocess_batch``).
        self.register_buffer("mean", torch.tensor(0.0))
        self.register_buffer("std", torch.tensor(1.0))

    def update_mean_std(self, statistics_dict: dict):
        """Overwrite the mel normalisation buffers with dataset statistics.

        Args:
            statistics_dict (dict): Must contain ``"mean"`` and ``"std"`` entries
                (Python scalars or tensors).
        """
        # Keep the buffers' dtype and device; a freshly built tensor would otherwise
        # land on the CPU regardless of where the model lives.
        self.mean.data = torch.as_tensor(statistics_dict["mean"], dtype=self.mean.dtype, device=self.mean.device)
        self.std.data = torch.as_tensor(statistics_dict["std"], dtype=self.std.dtype, device=self.std.device)

    def preprocess_batch(self, text, text_len, mels, mel_len):
        """Normalise ``mels`` with the dataset statistics, loading them lazily if needed.

        If either buffer still holds its placeholder value (mean 0 or std 1), the
        statistics are read from ``mel_statistics_parameter_path`` first.

        Returns:
            tuple: ``(text, text_len, normalised_mels, mel_len)``.
        """
        if self.mean.item() == 0 or self.std.item() == 1:
            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
            self.update_mean_std(statistics_dict)

        mels = self.normalize(mels)
        return text, text_len, mels, mel_len

    def normalize(self, x):
        """Return ``(x - mean) / std`` using the stored mel statistics."""
        return x.sub(self.mean).div(self.std)

    def inverse_normalize(self, x):
        """Undo :meth:`normalize`: return ``x * std + mean``."""
        return x.mul(self.std).add(self.mean)

    def forward(self, text, text_len, mels, mel_len):
        """
        Forward pass for training and computing the log likelihood of a given batch.

        The mels are normalised, mapped to the latent space by the flow decoder and
        scored by the neural HMM; the flow's log-determinant is added to the HMM
        log-probability to obtain the data log-likelihood.

        Shapes:
            text: :math:`[B, T_in]`
            text_len: :math:`[B]`
            mels: :math:`[B, T_out, C]`
            mel_len: :math:`[B]`

        Returns:
            dict: ``log_probs``, ``alignments``, ``transition_vectors`` and ``means``.
        """
        text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len)
        encoder_outputs, encoder_output_len = self.encoder(text, text_len)
        # The decoder expects channels-first input, hence the transpose.
        z, z_lengths, logdet = self.decoder(mels.transpose(1, 2), mel_len)
        log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm(
            encoder_outputs, encoder_output_len, z, z_lengths
        )

        outputs = {
            "log_probs": log_probs + logdet,
            "alignments": fwd_alignments,
            "transition_vectors": transition_vectors,
            "means": means,
        }

        return outputs

    @staticmethod
    def _training_stats(batch):
        """Compute average text/spec lengths and how full the padded batch is."""
        stats = {}
        stats["avg_text_length"] = batch["text_lengths"].float().mean()
        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
        return stats

    def train_step(self, batch: dict, criterion: nn.Module):
        """Run one training step and return ``(outputs, loss_dict)``.

        The summed log-likelihood is divided by the total number of mel frames plus
        text tokens in the batch before being passed to ``criterion``.
        """
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]

        outputs = self.forward(
            text=text_input,
            text_len=text_lengths,
            mels=mel_input,
            mel_len=mel_lengths,
        )
        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))

        loss_dict.update(self._training_stats(batch))
        return outputs, loss_dict

    def _format_aux_input(self, aux_input: dict, default_input_dict):
        """Set missing fields to their default value.

        Args:
            aux_input (dict): Auxiliary inputs supplied by the caller; may be ``None``
                or empty, in which case the defaults are returned unchanged.
            default_input_dict (dict): Call-specific defaults. They are extended with
                ``sampling_temp``, ``max_sampling_time`` and ``duration_threshold``
                taken from the model config.
        """
        default_input_dict = default_input_dict.copy()
        default_input_dict.update(
            {
                "sampling_temp": self.sampling_temp,
                "max_sampling_time": self.max_sampling_time,
                "duration_threshold": self.duration_threshold,
            }
        )
        if aux_input:
            return format_aux_input(default_input_dict, aux_input)
        return default_input_dict

    @torch.inference_mode()
    def inference(
        self,
        text: torch.Tensor,
        aux_input: dict | None = None,
    ):
        """Sampling from the model

        Args:
            text (torch.Tensor): :math:`[B, T_in]`
            aux_input (dict, optional): Optional overrides for ``x_lengths``,
                ``sampling_temp``, ``max_sampling_time`` and ``duration_threshold``.
                Missing or ``None`` entries fall back to defaults: ``x_lengths`` is
                the count of non-zero token ids per row, the others come from the
                model config. Defaults to None (use all defaults).

        Returns:
            outputs: Dictionary containing the following
                - model_outputs (torch.Tensor): de-normalised mel, :math:`[B, T_out, C]`
                - model_outputs_len (torch.Tensor): :math:`[B]`
                - hmm_outputs_len (torch.Tensor): :math:`[B]`
                - alignments (torch.Tensor): padded alignments from the neural HMM.
                - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM.
                - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM.
                plus any further entries returned by the neural HMM's inference.
        """
        default_input_dict = {
            # Token id 0 is treated as padding when inferring the input lengths.
            "x_lengths": torch.sum(text != 0, dim=1),
        }
        aux_input = self._format_aux_input(aux_input, default_input_dict)
        encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"])
        outputs = self.neural_hmm.inference(
            encoder_outputs,
            encoder_output_len,
            sampling_temp=aux_input["sampling_temp"],
            max_sampling_time=aux_input["max_sampling_time"],
            duration_threshold=aux_input["duration_threshold"],
        )

        # Run the flow in reverse to map HMM samples back to (normalised) mel space.
        mels, mel_outputs_len, _ = self.decoder(
            outputs["hmm_outputs"].transpose(1, 2), outputs["hmm_outputs_len"], reverse=True
        )
        mels = self.inverse_normalize(mels.transpose(1, 2))
        outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len})
        outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"])
        return outputs

    @staticmethod
    def get_criterion():
        """Return the negative log-likelihood loss used for training."""
        return NLLLoss()

    @staticmethod
    def init_from_config(config: "OverFlowConfig", samples: list[list] | list[dict] = None):
        """Initiate model from config

        Args:
            config (OverFlowConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        from TTS.utils.audio import AudioProcessor

        ap = AudioProcessor.init_from_config(config)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        return Overflow(new_config, ap, tokenizer, speaker_manager)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str | os.PathLike[Any],
        *,
        eval: bool = False,
        strict: bool = True,
        cache: bool = False,
    ) -> None:
        """Load weights via the base class; in eval mode also call ``decoder.store_inverse()``.

        NOTE(review): ``store_inverse`` presumably precomputes the flow's inverse for
        faster synthesis — confirm against the Decoder implementation.
        """
        super().load_checkpoint(config, checkpoint_path, eval=eval, strict=strict, cache=cache)
        if eval:
            self.decoder.store_inverse()

    def on_init_start(self, trainer):
        """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.

        Statistics are (re)computed from the training dataloader and saved to
        ``mel_statistics_parameter_path`` when that file is missing or
        ``force_generate_statistics`` is set. In both cases the initial transition
        probability is written into ``trainer.config.flat_start_params`` and the
        model, and the mean/std are pushed into the model's buffers.
        """
        if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics:
            dataloader = trainer.get_train_dataloader(
                training_assets=None, samples=trainer.train_samples, verbose=False
            )
            logger.info(
                "Data parameters not found for: %s. Computing mel normalization parameters...",
                trainer.config.mel_statistics_parameter_path,
            )
            data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
                dataloader, trainer.config.out_channels, trainer.config.state_per_phone
            )
            logger.info(
                "Saving data parameters to: %s: value: %s",
                trainer.config.mel_statistics_parameter_path,
                (data_mean, data_std, init_transition_prob),
            )
            # Store plain Python floats so the file can be read with ``weights_only``.
            statistics = {
                "mean": data_mean.item(),
                "std": data_std.item(),
                "init_transition_prob": init_transition_prob.item(),
            }
            torch.save(statistics, trainer.config.mel_statistics_parameter_path)

        else:
            logger.info(
                "Data parameters found for: %s. Loading mel normalization parameters...",
                trainer.config.mel_statistics_parameter_path,
            )
            statistics = torch.load(
                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
            )
            data_mean, data_std, init_transition_prob = (
                statistics["mean"],
                statistics["std"],
                statistics["init_transition_prob"],
            )
            logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))

        # The freshly computed value is a tensor, a loaded one may already be a float.
        trainer.config.flat_start_params["transition_p"] = (
            init_transition_prob.item() if isinstance(init_transition_prob, torch.Tensor) else init_transition_prob
        )
        OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob)
        trainer.model.update_mean_std(statistics)

    def _create_logs(self, batch, outputs):
        """Build dashboard figures and a synthesised audio sample from a batch.

        Returns:
            tuple: ``(figures, {"audios": audio})``.
        """
        alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"]
        means = torch.stack(outputs["means"], dim=1)

        figures = {
            "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)),
            "log_alignment": plot_alignment(
                alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20)
            ),
            "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)),
            "mel_from_most_probable_state": plot_spectrogram(
                get_spec_from_most_probable_state(alignments[0], means[0], self.decoder), fig_size=(12, 3)
            ),
            "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)),
        }

        # Synthesise the last sample of the batch to monitor inference quality.
        logger.info("Synthesising audio from the model...")
        inference_output = self.inference(
            batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
        )
        figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

        states = [p[1] for p in inference_output["input_parameters"][0]]
        transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]]

        # Plot the transition probabilities in chunks of ``chunk_size`` steps so each
        # figure stays readable. Stepping through ``range(0, n, chunk_size)`` avoids
        # emitting a trailing empty chunk when ``n`` is a multiple of ``chunk_size``.
        chunk_size = 200
        num_steps = len(transition_probability_synthesising)
        for i, start in enumerate(range(0, num_steps, chunk_size)):
            end = start + chunk_size
            figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy(
                states[start:end], transition_probability_synthesising[start:end]
            )

        audio = self.ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
        return figures, {"audios": audio}

    def eval_log(
        self,
        batch: dict[str, Any],
        outputs: dict[str, Any] | list[dict[str, Any]],
        logger: BaseDashboardLogger,
        assets: dict[str, Any],
        steps: int,
    ) -> None:
        """Log parameter histograms (TensorBoard only), then defer to the base eval logging."""
        # ``logger`` here is the dashboard logger argument, not the module-level logger.
        if isinstance(logger, TensorboardLogger):
            # Histograms are only supported by the TensorBoard writer.
            for tag, value in self.named_parameters():
                tag = tag.replace(".", "/")
                logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps)
        super().eval_log(batch, outputs, logger, assets, steps)
|
|