| import math |
| import os |
| from dataclasses import dataclass, field, replace |
| from itertools import chain |
| from typing import Dict, List, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torchaudio |
| from coqpit import Coqpit |
| from librosa.filters import mel as librosa_mel_fn |
| from torch import nn |
| from torch.cuda.amp.autocast_mode import autocast |
| from torch.nn import functional as F |
| from torch.utils.data import DataLoader |
| from torch.utils.data.sampler import WeightedRandomSampler |
| from trainer.torch import DistributedSampler, DistributedSamplerWrapper |
| from trainer.trainer_utils import get_optimizer, get_scheduler |
|
|
| from TTS.tts.configs.shared_configs import CharactersConfig |
| from TTS.tts.datasets.dataset import TTSDataset, _parse_sample |
| from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor |
| from TTS.tts.layers.vits.discriminator import VitsDiscriminator |
| from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder |
| from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor |
| from TTS.tts.models.base_tts import BaseTTS |
| from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint |
| from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask |
| from TTS.tts.utils.languages import LanguageManager |
| from TTS.tts.utils.speakers import SpeakerManager |
| from TTS.tts.utils.synthesis import synthesis |
| from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations |
| from TTS.tts.utils.text.tokenizer import TTSTokenizer |
| from TTS.tts.utils.visual import plot_alignment |
| from TTS.utils.io import load_fsspec |
| from TTS.utils.samplers import BucketBatchSampler |
| from TTS.vocoder.models.hifigan_generator import HifiganGenerator |
| from TTS.vocoder.utils.generic_utils import plot_results |
|
|
| ############################## |
| # IO / Feature extraction |
| ############################## |
|
|
| |
| hann_window = {} |
| mel_basis = {} |
|
|
|
|
| @torch.no_grad() |
| def weights_reset(m: nn.Module): |
| # call `reset_parameters` on the module if it defines one |
| reset_parameters = getattr(m, "reset_parameters", None) |
| if callable(reset_parameters): |
| m.reset_parameters() |
|
|
|
|
| def get_module_weights_sum(mdl: nn.Module): |
| dict_sums = {} |
| for name, w in mdl.named_parameters(): |
| if "weight" in name: |
| value = w.data.sum().item() |
| dict_sums[name] = value |
| return dict_sums |
|
|
|
|
| def load_audio(file_path): |
| """Load the audio file normalized in [-1, 1] |
| |
| Return Shapes: |
| - x: :math:`[1, T]` |
| """ |
| x, sr = torchaudio.load(file_path) |
| assert (x > 1).sum() + (x < -1).sum() == 0 |
| return x, sr |
|
|
|
|
| def _amp_to_db(x, C=1, clip_val=1e-5): |
| return torch.log(torch.clamp(x, min=clip_val) * C) |
|
|
|
|
| def _db_to_amp(x, C=1): |
| return torch.exp(x) / C |
|
|
|
|
| def amp_to_db(magnitudes): |
| output = _amp_to_db(magnitudes) |
| return output |
|
|
|
|
| def db_to_amp(magnitudes): |
| output = _db_to_amp(magnitudes) |
| return output |
|
|
|
|
| def wav_to_spec(y, n_fft, hop_length, win_length, center=False): |
| """ |
| Args Shapes: |
| - y : :math:`[B, 1, T]` |
| |
| Return Shapes: |
| - spec : :math:`[B,C,T]` |
| """ |
| y = y.squeeze(1) |
|
|
| if torch.min(y) < -1.0: |
| print("min value is ", torch.min(y)) |
| if torch.max(y) > 1.0: |
| print("max value is ", torch.max(y)) |
|
|
| global hann_window |
| dtype_device = str(y.dtype) + "_" + str(y.device) |
| wnsize_dtype_device = str(win_length) + "_" + dtype_device |
| if wnsize_dtype_device not in hann_window: |
| hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) |
|
|
| y = torch.nn.functional.pad( |
| y.unsqueeze(1), |
| (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), |
| mode="reflect", |
| ) |
| y = y.squeeze(1) |
|
|
| spec = torch.stft( |
| y, |
| n_fft, |
| hop_length=hop_length, |
| win_length=win_length, |
| window=hann_window[wnsize_dtype_device], |
| center=center, |
| pad_mode="reflect", |
| normalized=False, |
| onesided=True, |
| return_complex=False, |
| ) |
|
|
| spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) |
| return spec |
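|
| # Usage sketch (editorial, not part of the module API); the STFT settings below |
| # are assumptions matching the default `VitsAudioConfig`, not requirements. |
| # |
| #     wav = torch.rand(2, 1, 22050) * 2 - 1            # [B, 1, T] in [-1, 1] |
| #     spec = wav_to_spec(wav, n_fft=1024, hop_length=256, win_length=1024) |
| #     spec.shape                                       # [2, 513, ~T // 256] |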
|
|
|
|
| def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): |
| """ |
| Args Shapes: |
| - spec : :math:`[B,C,T]` |
| |
| Return Shapes: |
| - mel : :math:`[B,C,T]` |
| """ |
| global mel_basis |
| dtype_device = str(spec.dtype) + "_" + str(spec.device) |
| fmax_dtype_device = str(fmax) + "_" + dtype_device |
| if fmax_dtype_device not in mel_basis: |
| mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) |
| mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) |
| mel = torch.matmul(mel_basis[fmax_dtype_device], spec) |
| mel = amp_to_db(mel) |
| return mel |
|
|
|
|
| def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): |
| """ |
| Args Shapes: |
| - y : :math:`[B, 1, T]` |
| |
| Return Shapes: |
| - spec : :math:`[B,C,T]` |
| """ |
| y = y.squeeze(1) |
|
|
| if torch.min(y) < -1.0: |
| print("min value is ", torch.min(y)) |
| if torch.max(y) > 1.0: |
| print("max value is ", torch.max(y)) |
|
|
| global mel_basis, hann_window |
| dtype_device = str(y.dtype) + "_" + str(y.device) |
| fmax_dtype_device = str(fmax) + "_" + dtype_device |
| wnsize_dtype_device = str(win_length) + "_" + dtype_device |
| if fmax_dtype_device not in mel_basis: |
| mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) |
| mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) |
| if wnsize_dtype_device not in hann_window: |
| hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) |
|
|
| y = torch.nn.functional.pad( |
| y.unsqueeze(1), |
| (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), |
| mode="reflect", |
| ) |
| y = y.squeeze(1) |
|
|
| spec = torch.stft( |
| y, |
| n_fft, |
| hop_length=hop_length, |
| win_length=win_length, |
| window=hann_window[wnsize_dtype_device], |
| center=center, |
| pad_mode="reflect", |
| normalized=False, |
| onesided=True, |
| return_complex=False, |
| ) |
|
|
| spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) |
| spec = torch.matmul(mel_basis[fmax_dtype_device], spec) |
| spec = amp_to_db(spec) |
| return spec |
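|
| # Editorial note: `wav_to_mel` is equivalent to chaining the helpers above, |
| # `spec_to_mel(wav_to_spec(y, ...), ...)`, while sharing the cached `mel_basis` |
| # and `hann_window`. A minimal sketch with assumed default audio settings: |
| # |
| #     mel_a = wav_to_mel(wav, 1024, 80, 22050, 256, 1024, fmin=0, fmax=None) |
| #     spec = wav_to_spec(wav, 1024, 256, 1024) |
| #     mel_b = spec_to_mel(spec, 1024, 80, 22050, 0, None) |
| #     torch.allclose(mel_a, mel_b)                     # True (up to float precision) |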
|
|
|
|
| ############################## |
| # CONFIGS |
| ############################## |
|
|
|
|
| @dataclass |
| class VitsAudioConfig(Coqpit): |
| fft_size: int = 1024 |
| sample_rate: int = 22050 |
| win_length: int = 1024 |
| hop_length: int = 256 |
| num_mels: int = 80 |
| mel_fmin: int = 0 |
| mel_fmax: int = None |
|
|
|
|
| ############################## |
| # DATASET |
| ############################## |
|
|
|
|
| def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): |
| """Create inverse frequency weights for balancing the dataset. |
| Use `multi_dict` to scale relative weights.""" |
| attr_names_samples = np.array([item[attr_name] for item in items]) |
| unique_attr_names = np.unique(attr_names_samples).tolist() |
| attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] |
| attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) |
| weight_attr = 1.0 / attr_count |
| dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) |
| dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) |
| if multi_dict is not None: |
| # check that all keys given in multi_dict exist in the dataset |
| for k in multi_dict: |
| assert k in unique_attr_names, f"{k} not in {unique_attr_names}" |
| # scale the inverse-frequency weights by the given multipliers |
| multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) |
| dataset_samples_weight *= multiplier_samples |
| return ( |
| torch.from_numpy(dataset_samples_weight).float(), |
| unique_attr_names, |
| np.unique(dataset_samples_weight).tolist(), |
| ) |
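|
| # Usage sketch (editorial): feeding the balancer weights to a sampler. The toy |
| # `items` list below is an assumption for demonstration only. |
| # |
| #     items = [{"speaker_name": "a"}, {"speaker_name": "a"}, {"speaker_name": "b"}] |
| #     weights, names, _ = get_attribute_balancer_weights(items, "speaker_name") |
| #     sampler = WeightedRandomSampler(weights, num_samples=len(weights)) |
| #     # speaker "b" is now drawn about as often as "a" despite having half the items |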
|
|
|
|
| class VitsDataset(TTSDataset): |
| def __init__(self, model_args, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| self.pad_id = self.tokenizer.characters.pad_id |
| self.model_args = model_args |
|
|
| def __getitem__(self, idx): |
| item = self.samples[idx] |
| raw_text = item["text"] |
|
|
| wav, _ = load_audio(item["audio_file"]) |
| if self.model_args.encoder_sample_rate is not None: |
| if wav.size(1) % self.model_args.encoder_sample_rate != 0: |
| wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] |
|
|
| wav_filename = os.path.basename(item["audio_file"]) |
|
|
| token_ids = self.get_token_ids(idx, item["text"]) |
|
|
| # after tokenization the sequence may exceed the configured limits; |
| # fall back to another sample instead of failing |
| if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: |
| self.rescue_item_idx += 1 |
| return self.__getitem__(self.rescue_item_idx) |
|
|
| return { |
| "raw_text": raw_text, |
| "token_ids": token_ids, |
| "token_len": len(token_ids), |
| "wav": wav, |
| "wav_file": wav_filename, |
| "speaker_name": item["speaker_name"], |
| "language_name": item["language"], |
| "audio_unique_name": item["audio_unique_name"], |
| } |
|
|
| @property |
| def lengths(self): |
| lens = [] |
| for item in self.samples: |
| _, wav_file, *_ = _parse_sample(item) |
| audio_len = os.path.getsize(wav_file) / 16 * 8 # approximate number of samples, assuming 16-bit audio |
| lens.append(audio_len) |
| return lens |
|
|
| def collate_fn(self, batch): |
| """ |
| Return Shapes: |
| - tokens: :math:`[B, T]` |
| - token_lens :math:`[B]` |
| - token_rel_lens :math:`[B]` |
| - waveform: :math:`[B, 1, T]` |
| - waveform_lens: :math:`[B]` |
| - waveform_rel_lens: :math:`[B]` |
| - speaker_names: :math:`[B]` |
| - language_names: :math:`[B]` |
| - audiofile_paths: :math:`[B]` |
| - raw_texts: :math:`[B]` |
| - audio_unique_names: :math:`[B]` |
| """ |
| # convert list of dicts to dict of lists |
| B = len(batch) |
| batch = {k: [dic[k] for dic in batch] for k in batch[0]} |
|
|
| _, ids_sorted_decreasing = torch.sort( |
| torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True |
| ) |
|
|
| max_text_len = max([len(x) for x in batch["token_ids"]]) |
| token_lens = torch.LongTensor(batch["token_len"]) |
| token_rel_lens = token_lens / token_lens.max() |
|
|
| wav_lens = [w.shape[1] for w in batch["wav"]] |
| wav_lens = torch.LongTensor(wav_lens) |
| wav_lens_max = torch.max(wav_lens) |
| wav_rel_lens = wav_lens / wav_lens_max |
|
|
| token_padded = torch.LongTensor(B, max_text_len) |
| wav_padded = torch.FloatTensor(B, 1, wav_lens_max) |
| token_padded = token_padded.zero_() + self.pad_id |
| wav_padded = wav_padded.zero_() + self.pad_id |
| for i in range(len(ids_sorted_decreasing)): |
| token_ids = batch["token_ids"][i] |
| token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) |
|
|
| wav = batch["wav"][i] |
| wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) |
|
|
| return { |
| "tokens": token_padded, |
| "token_lens": token_lens, |
| "token_rel_lens": token_rel_lens, |
| "waveform": wav_padded, |
| "waveform_lens": wav_lens, |
| "waveform_rel_lens": wav_rel_lens, |
| "speaker_names": batch["speaker_name"], |
| "language_names": batch["language_name"], |
| "audio_files": batch["wav_file"], |
| "raw_text": batch["raw_text"], |
| "audio_unique_names": batch["audio_unique_name"], |
| } |
|
|
|
|
| ############################## |
| # MODEL DEFINITION |
| ############################## |
|
|
|
|
| @dataclass |
| class VitsArgs(Coqpit): |
| """VITS model arguments. |
| |
| Args: |
| |
| num_chars (int): |
| Number of characters in the vocabulary. Defaults to 100. |
| |
| out_channels (int): |
| Number of output channels of the decoder. Defaults to 513. |
| |
| spec_segment_size (int): |
| Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. |
| |
| hidden_channels (int): |
| Number of hidden channels of the model. Defaults to 192. |
| |
| hidden_channels_ffn_text_encoder (int): |
| Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768. |
| |
| num_heads_text_encoder (int): |
| Number of attention heads of the text encoder transformer. Defaults to 2. |
| |
| num_layers_text_encoder (int): |
| Number of transformer layers in the text encoder. Defaults to 6. |
| |
| kernel_size_text_encoder (int): |
| Kernel size of the text encoder transformer FFN layers. Defaults to 3. |
| |
| dropout_p_text_encoder (float): |
| Dropout rate of the text encoder. Defaults to 0.1. |
| |
| dropout_p_duration_predictor (float): |
| Dropout rate of the duration predictor. Defaults to 0.5. |
| |
| kernel_size_posterior_encoder (int): |
| Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. |
| |
| dilation_rate_posterior_encoder (int): |
| Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. |
| |
| num_layers_posterior_encoder (int): |
| Number of posterior encoder's WaveNet layers. Defaults to 16. |
| |
| kernel_size_flow (int): |
| Kernel size of the Residual Coupling layers of the flow network. Defaults to 5. |
| |
| dilation_rate_flow (int): |
| Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1. |
| |
| num_layers_flow (int): |
| Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4. |
| |
| resblock_type_decoder (str): |
| Type of the residual block in the decoder network. Defaults to "1". |
| |
| resblock_kernel_sizes_decoder (List[int]): |
| Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`. |
| |
| resblock_dilation_sizes_decoder (List[List[int]]): |
| Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`. |
| |
| upsample_rates_decoder (List[int]): |
| Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these |
| values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`. |
| |
| upsample_initial_channel_decoder (int): |
| Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512. |
| |
| upsample_kernel_sizes_decoder (List[int]): |
| Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. |
| |
| periods_multi_period_discriminator (List[int]): |
| Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`. |
| |
| use_sdp (bool): |
| Use Stochastic Duration Predictor. Defaults to True. |
| |
| noise_scale (float): |
| Noise scale used for the sample noise tensor in training. Defaults to 1.0. |
| |
| inference_noise_scale (float): |
| Noise scale used for the sample noise tensor in inference. Defaults to 0.667. |
| |
| length_scale (float): |
| Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1. |
| |
| noise_scale_dp (float): |
| Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0. |
| |
| inference_noise_scale_dp (float): |
| Noise scale for the Stochastic Duration Predictor in inference. Defaults to 1.0. |
| |
| max_inference_len (int): |
| Maximum inference length to limit the memory use. Defaults to None. |
| |
| init_discriminator (bool): |
| Initialize the discriminator network if set True. Set False for inference. Defaults to True. |
| |
| use_spectral_norm_disriminator (bool): |
| Use spectral normalization over weight norm in the discriminator. Defaults to False. |
| |
| use_speaker_embedding (bool): |
| Enable/Disable speaker embedding for multi-speaker models. Defaults to False. |
| |
| num_speakers (int): |
| Number of speakers for the speaker embedding layer. Defaults to 0. |
| |
| speakers_file (str): |
| Path to the speaker mapping file for the Speaker Manager. Defaults to None. |
| |
| speaker_embedding_channels (int): |
| Number of speaker embedding channels. Defaults to 256. |
| |
| use_d_vector_file (bool): |
| Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. |
| |
| d_vector_file (List[str]): |
| List of paths to the files including pre-computed speaker embeddings. Defaults to None. |
| |
| d_vector_dim (int): |
| Number of d-vector channels. Defaults to 0. |
| |
| detach_dp_input (bool): |
| Detach duration predictor's input from the network for stopping the gradients. Defaults to True. |
| |
| use_language_embedding (bool): |
| Enable/Disable language embedding for multilingual models. Defaults to False. |
| |
| embedded_language_dim (int): |
| Number of language embedding channels. Defaults to 4. |
| |
| num_languages (int): |
| Number of languages for the language embedding layer. Defaults to 0. |
| |
| language_ids_file (str): |
| Path to the language mapping file for the Language Manager. Defaults to None. |
| |
| use_speaker_encoder_as_loss (bool): |
| Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. |
| |
| speaker_encoder_config_path (str): |
| Path to the speaker encoder config file, to use for SCL. Defaults to "". |
| |
| speaker_encoder_model_path (str): |
| Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "". |
| |
| condition_dp_on_speaker (bool): |
| Condition the duration predictor on the speaker embedding. Defaults to True. |
| |
| freeze_encoder (bool): |
| Freeze the text encoder weights during training. Defaults to False. |
| |
| freeze_DP (bool): |
| Freeze the duration predictor weights during training. Defaults to False. |
| |
| freeze_PE (bool): |
| Freeze the posterior encoder weights during training. Defaults to False. |
| |
| freeze_flow_decoder (bool): |
| Freeze the flow decoder weights during training. Defaults to False. |
| |
| freeze_waveform_decoder (bool): |
| Freeze the waveform decoder weights during training. Defaults to False. |
| |
| encoder_sample_rate (int): |
| If not None this sample rate will be used for training the Posterior Encoder, |
| flow, text_encoder and duration predictor. The decoder part (vocoder) will be |
| trained with the `config.audio.sample_rate`. Defaults to None. |
| |
| interpolate_z (bool): |
| If `encoder_sample_rate` is not None and this parameter is True, nearest-neighbour interpolation |
| is used to upsample the latent variable z from `encoder_sample_rate` to `config.audio.sample_rate`. |
| If it is False, you need to add extra `upsample_rates_decoder` values to match the shapes. |
| Defaults to True. |
| |
| """ |
|
|
| num_chars: int = 100 |
| out_channels: int = 513 |
| spec_segment_size: int = 32 |
| hidden_channels: int = 192 |
| hidden_channels_ffn_text_encoder: int = 768 |
| num_heads_text_encoder: int = 2 |
| num_layers_text_encoder: int = 6 |
| kernel_size_text_encoder: int = 3 |
| dropout_p_text_encoder: float = 0.1 |
| dropout_p_duration_predictor: float = 0.5 |
| kernel_size_posterior_encoder: int = 5 |
| dilation_rate_posterior_encoder: int = 1 |
| num_layers_posterior_encoder: int = 16 |
| kernel_size_flow: int = 5 |
| dilation_rate_flow: int = 1 |
| num_layers_flow: int = 4 |
| resblock_type_decoder: str = "1" |
| resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) |
| resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) |
| upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) |
| upsample_initial_channel_decoder: int = 512 |
| upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) |
| periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) |
| use_sdp: bool = True |
| noise_scale: float = 1.0 |
| inference_noise_scale: float = 0.667 |
| length_scale: float = 1 |
| noise_scale_dp: float = 1.0 |
| inference_noise_scale_dp: float = 1.0 |
| max_inference_len: int = None |
| init_discriminator: bool = True |
| use_spectral_norm_disriminator: bool = False |
| use_speaker_embedding: bool = False |
| num_speakers: int = 0 |
| speakers_file: str = None |
| d_vector_file: List[str] = None |
| speaker_embedding_channels: int = 256 |
| use_d_vector_file: bool = False |
| d_vector_dim: int = 0 |
| detach_dp_input: bool = True |
| use_language_embedding: bool = False |
| embedded_language_dim: int = 4 |
| num_languages: int = 0 |
| language_ids_file: str = None |
| use_speaker_encoder_as_loss: bool = False |
| speaker_encoder_config_path: str = "" |
| speaker_encoder_model_path: str = "" |
| condition_dp_on_speaker: bool = True |
| freeze_encoder: bool = False |
| freeze_DP: bool = False |
| freeze_PE: bool = False |
| freeze_flow_decoder: bool = False |
| freeze_waveform_decoder: bool = False |
| encoder_sample_rate: int = None |
| interpolate_z: bool = True |
| reinit_DP: bool = False |
| reinit_text_encoder: bool = False |
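|
| # Example (editorial sketch): overriding a few fields when building the model args. |
| # The values shown are arbitrary assumptions, not recommended settings. |
| # |
| #     args = VitsArgs(use_speaker_embedding=True, num_speakers=4, use_sdp=False) |
| #     print(args.hidden_channels, args.num_speakers)   # 192 4 |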
|
|
|
|
| class Vits(BaseTTS): |
| """VITS TTS model |
| |
| Paper:: |
| https://arxiv.org/pdf/2106.06103.pdf |
| |
| Paper Abstract:: |
| Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel |
| sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. |
| In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than |
| current two-stage models. Our method adopts variational inference augmented with normalizing flows and |
| an adversarial training process, which improves the expressive power of generative modeling. We also propose a |
| stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the |
| uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the |
| natural one-to-many relationship in which a text input can be spoken in multiple ways |
| with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) |
| on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly |
| available TTS systems and achieves a MOS comparable to ground truth. |
| |
| Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. |
| |
| Examples: |
| >>> from TTS.tts.configs.vits_config import VitsConfig |
| >>> from TTS.tts.models.vits import Vits |
| >>> config = VitsConfig() |
| >>> model = Vits(config) |
| """ |
|
|
| def __init__( |
| self, |
| config: Coqpit, |
| ap: "AudioProcessor" = None, |
| tokenizer: "TTSTokenizer" = None, |
| speaker_manager: SpeakerManager = None, |
| language_manager: LanguageManager = None, |
| ): |
| super().__init__(config, ap, tokenizer, speaker_manager, language_manager) |
|
|
| self.init_multispeaker(config) |
| self.init_multilingual(config) |
| self.init_upsampling() |
|
|
| self.length_scale = self.args.length_scale |
| self.noise_scale = self.args.noise_scale |
| self.inference_noise_scale = self.args.inference_noise_scale |
| self.inference_noise_scale_dp = self.args.inference_noise_scale_dp |
| self.noise_scale_dp = self.args.noise_scale_dp |
| self.max_inference_len = self.args.max_inference_len |
| self.spec_segment_size = self.args.spec_segment_size |
|
|
| self.text_encoder = TextEncoder( |
| self.args.num_chars, |
| self.args.hidden_channels, |
| self.args.hidden_channels, |
| self.args.hidden_channels_ffn_text_encoder, |
| self.args.num_heads_text_encoder, |
| self.args.num_layers_text_encoder, |
| self.args.kernel_size_text_encoder, |
| self.args.dropout_p_text_encoder, |
| language_emb_dim=self.embedded_language_dim, |
| ) |
|
|
| self.posterior_encoder = PosteriorEncoder( |
| self.args.out_channels, |
| self.args.hidden_channels, |
| self.args.hidden_channels, |
| kernel_size=self.args.kernel_size_posterior_encoder, |
| dilation_rate=self.args.dilation_rate_posterior_encoder, |
| num_layers=self.args.num_layers_posterior_encoder, |
| cond_channels=self.embedded_speaker_dim, |
| ) |
|
|
| self.flow = ResidualCouplingBlocks( |
| self.args.hidden_channels, |
| self.args.hidden_channels, |
| kernel_size=self.args.kernel_size_flow, |
| dilation_rate=self.args.dilation_rate_flow, |
| num_layers=self.args.num_layers_flow, |
| cond_channels=self.embedded_speaker_dim, |
| ) |
|
|
| if self.args.use_sdp: |
| self.duration_predictor = StochasticDurationPredictor( |
| self.args.hidden_channels, |
| 192, |
| 3, |
| self.args.dropout_p_duration_predictor, |
| 4, |
| cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, |
| language_emb_dim=self.embedded_language_dim, |
| ) |
| else: |
| self.duration_predictor = DurationPredictor( |
| self.args.hidden_channels, |
| 256, |
| 3, |
| self.args.dropout_p_duration_predictor, |
| cond_channels=self.embedded_speaker_dim, |
| language_emb_dim=self.embedded_language_dim, |
| ) |
|
|
| self.waveform_decoder = HifiganGenerator( |
| self.args.hidden_channels, |
| 1, |
| self.args.resblock_type_decoder, |
| self.args.resblock_dilation_sizes_decoder, |
| self.args.resblock_kernel_sizes_decoder, |
| self.args.upsample_kernel_sizes_decoder, |
| self.args.upsample_initial_channel_decoder, |
| self.args.upsample_rates_decoder, |
| inference_padding=0, |
| cond_channels=self.embedded_speaker_dim, |
| conv_pre_weight_norm=False, |
| conv_post_weight_norm=False, |
| conv_post_bias=False, |
| ) |
|
|
| if self.args.init_discriminator: |
| self.disc = VitsDiscriminator( |
| periods=self.args.periods_multi_period_discriminator, |
| use_spectral_norm=self.args.use_spectral_norm_disriminator, |
| ) |
|
|
| @property |
| def device(self): |
| return next(self.parameters()).device |
|
|
| def init_multispeaker(self, config: Coqpit): |
| """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer |
| or with external `d_vectors` computed from a speaker encoder model. |
| |
| You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. |
| |
| Args: |
| config (Coqpit): Model configuration. |
| """ |
| self.embedded_speaker_dim = 0 |
| self.num_speakers = self.args.num_speakers |
| self.audio_transform = None |
|
|
| if self.speaker_manager: |
| self.num_speakers = self.speaker_manager.num_speakers |
|
|
| if self.args.use_speaker_embedding: |
| self._init_speaker_embedding() |
|
|
| if self.args.use_d_vector_file: |
| self._init_d_vector() |
|
|
| |
| if self.args.use_speaker_encoder_as_loss: |
| if self.speaker_manager.encoder is None and ( |
| not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path |
| ): |
| raise RuntimeError( |
| " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" |
| ) |
|
|
| self.speaker_manager.encoder.eval() |
| print(" > External Speaker Encoder Loaded !!") |
|
|
| if ( |
| hasattr(self.speaker_manager.encoder, "audio_config") |
| and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"] |
| ): |
| self.audio_transform = torchaudio.transforms.Resample( |
| orig_freq=self.config.audio.sample_rate, |
| new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], |
| ) |
|
|
| def _init_speaker_embedding(self): |
| |
| if self.num_speakers > 0: |
| print(" > initialization of speaker-embedding layers.") |
| self.embedded_speaker_dim = self.args.speaker_embedding_channels |
| self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) |
|
|
| def _init_d_vector(self): |
| |
| if hasattr(self, "emb_g"): |
| raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") |
| self.embedded_speaker_dim = self.args.d_vector_dim |
|
|
| def init_multilingual(self, config: Coqpit): |
| """Initialize multilingual modules of a model. |
| |
| Args: |
| config (Coqpit): Model configuration. |
| """ |
| if self.args.language_ids_file is not None: |
| self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) |
|
|
| if self.args.use_language_embedding and self.language_manager: |
| print(" > initialization of language-embedding layers.") |
| self.num_languages = self.language_manager.num_languages |
| self.embedded_language_dim = self.args.embedded_language_dim |
| self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) |
| torch.nn.init.xavier_uniform_(self.emb_l.weight) |
| else: |
| self.embedded_language_dim = 0 |
|
|
| def init_upsampling(self): |
| """ |
| Initialize upsampling modules of a model. |
| """ |
| if self.args.encoder_sample_rate: |
| self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate |
| self.audio_resampler = torchaudio.transforms.Resample( |
| orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate |
| ) |
|
|
| def on_epoch_start(self, trainer): |
| """Freeze layers at the beginning of an epoch""" |
| self._freeze_layers() |
| |
| if self.args.use_speaker_encoder_as_loss: |
| self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device) |
|
|
| def on_init_end(self, trainer): |
| """Reinit layes if needed""" |
| if self.args.reinit_DP: |
| before_dict = get_module_weights_sum(self.duration_predictor) |
| |
| self.duration_predictor.apply(fn=weights_reset) |
| after_dict = get_module_weights_sum(self.duration_predictor) |
| for key, value in after_dict.items(): |
| if value == before_dict[key]: |
| raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") |
| print(" > Duration Predictor was reinit.") |
|
|
| if self.args.reinit_text_encoder: |
| before_dict = get_module_weights_sum(self.text_encoder) |
| |
| self.text_encoder.apply(fn=weights_reset) |
| after_dict = get_module_weights_sum(self.text_encoder) |
| for key, value in after_dict.items(): |
| if value == before_dict[key]: |
| raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") |
| print(" > Text Encoder was reinit.") |
|
|
| def get_aux_input(self, aux_input: Dict): |
| sid, g, lid, _ = self._set_cond_input(aux_input) |
| return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} |
|
|
| def _freeze_layers(self): |
| if self.args.freeze_encoder: |
| for param in self.text_encoder.parameters(): |
| param.requires_grad = False |
|
|
| if hasattr(self, "emb_l"): |
| for param in self.emb_l.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_PE: |
| for param in self.posterior_encoder.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_DP: |
| for param in self.duration_predictor.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_flow_decoder: |
| for param in self.flow.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_waveform_decoder: |
| for param in self.waveform_decoder.parameters(): |
| param.requires_grad = False |
|
|
| @staticmethod |
| def _set_cond_input(aux_input: Dict): |
| """Set the speaker conditioning input based on the multi-speaker mode.""" |
| sid, g, lid, durations = None, None, None, None |
| if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: |
| sid = aux_input["speaker_ids"] |
| if sid.ndim == 0: |
| sid = sid.unsqueeze_(0) |
| if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: |
| g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) |
| if g.ndim == 2: |
| g = g.unsqueeze_(0) |
|
|
| if "language_ids" in aux_input and aux_input["language_ids"] is not None: |
| lid = aux_input["language_ids"] |
| if lid.ndim == 0: |
| lid = lid.unsqueeze_(0) |
|
|
| if "durations" in aux_input and aux_input["durations"] is not None: |
| durations = aux_input["durations"] |
|
|
| return sid, g, lid, durations |
|
|
| def _set_speaker_input(self, aux_input: Dict): |
| d_vectors = aux_input.get("d_vectors", None) |
| speaker_ids = aux_input.get("speaker_ids", None) |
|
|
| if d_vectors is not None and speaker_ids is not None: |
| raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") |
|
|
| if speaker_ids is not None and not hasattr(self, "emb_g"): |
| raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") |
|
|
| g = speaker_ids if speaker_ids is not None else d_vectors |
| return g |
|
|
| def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): |
| # find the alignment path between text tokens and spectrogram frames (MAS) |
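| # Editorial note on the likelihood terms below: `logp1..logp4` expand the |
| # diagonal-Gaussian log-density of z_p under the prior, |
| #     log N(z; m, s) = -0.5*log(2*pi) - log(s) - 0.5*(z - m)^2 / s^2 |
| #                    = [-0.5*log(2*pi) - log(s)] - 0.5*z^2/s^2 + m*z/s^2 - 0.5*m^2/s^2 |
| # summed over channels, with o_scale = exp(-2*logs_p) = 1/s^2. Splitting the |
| # square lets the z-dependent terms be computed for every (token, frame) pair |
| # with a single einsum before `maximum_path` searches the best alignment. |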
| attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) |
| with torch.no_grad(): |
| o_scale = torch.exp(-2 * logs_p) |
| logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) |
| logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) |
| logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) |
| logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) |
| logp = logp2 + logp3 + logp1 + logp4 |
| attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() |
|
|
| |
| attn_durations = attn.sum(3) |
| if self.args.use_sdp: |
| loss_duration = self.duration_predictor( |
| x.detach() if self.args.detach_dp_input else x, |
| x_mask, |
| attn_durations, |
| g=g.detach() if self.args.detach_dp_input and g is not None else g, |
| lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, |
| ) |
| loss_duration = loss_duration / torch.sum(x_mask) |
| else: |
| attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask |
| log_durations = self.duration_predictor( |
| x.detach() if self.args.detach_dp_input else x, |
| x_mask, |
| g=g.detach() if self.args.detach_dp_input and g is not None else g, |
| lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, |
| ) |
| loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) |
| outputs["loss_duration"] = loss_duration |
| return outputs, attn |
|
|
| def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): |
| spec_segment_size = self.spec_segment_size |
| if self.args.encoder_sample_rate: |
| # recompute the slice indexes and segment size at the decoder sample rate |
| slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids |
| spec_segment_size = spec_segment_size * int(self.interpolate_factor) |
| # interpolate z if needed |
| if self.args.interpolate_z: |
| z = torch.nn.functional.interpolate(z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0) |
| # recompute the output mask for the upsampled length |
| if y_lengths is not None and y_mask is not None: |
| y_mask = ( |
| sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) |
| ) |
|
|
| return z, spec_segment_size, slice_ids, y_mask |
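|
| # Editorial sketch: with an assumed `encoder_sample_rate` of 11025 and |
| # `config.audio.sample_rate` of 22050, `interpolate_factor` is 2.0, so a latent |
| # z of shape [B, C, T] is upsampled to [B, C, 2*T] before the waveform decoder: |
| # |
| #     z = torch.randn(1, 192, 50) |
| #     z_up = torch.nn.functional.interpolate( |
| #         z.unsqueeze(0), scale_factor=[1, 2.0], mode="nearest" |
| #     ).squeeze(0) |
| #     print(z_up.shape)                                # torch.Size([1, 192, 100]) |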
|
|
| def forward( |
| self, |
| x: torch.tensor, |
| x_lengths: torch.tensor, |
| y: torch.tensor, |
| y_lengths: torch.tensor, |
| waveform: torch.tensor, |
| aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, |
| ) -> Dict: |
| """Forward pass of the model. |
| |
| Args: |
| x (torch.tensor): Batch of input character sequence IDs. |
| x_lengths (torch.tensor): Batch of input character sequence lengths. |
| y (torch.tensor): Batch of input spectrograms. |
| y_lengths (torch.tensor): Batch of input spectrogram lengths. |
| waveform (torch.tensor): Batch of ground truth waveforms per sample. |
| aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. |
| Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. |
| |
| Returns: |
| Dict: model outputs keyed by the output name. |
| |
| Shapes: |
| - x: :math:`[B, T_seq]` |
| - x_lengths: :math:`[B]` |
| - y: :math:`[B, C, T_spec]` |
| - y_lengths: :math:`[B]` |
| - waveform: :math:`[B, 1, T_wav]` |
| - d_vectors: :math:`[B, C, 1]` |
| - speaker_ids: :math:`[B]` |
| - language_ids: :math:`[B]` |
| |
| Return Shapes: |
| - model_outputs: :math:`[B, 1, T_wav]` |
| - alignments: :math:`[B, T_seq, T_dec]` |
| - z: :math:`[B, C, T_dec]` |
| - z_p: :math:`[B, C, T_dec]` |
| - m_p: :math:`[B, C, T_dec]` |
| - logs_p: :math:`[B, C, T_dec]` |
| - m_q: :math:`[B, C, T_dec]` |
| - logs_q: :math:`[B, C, T_dec]` |
| - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` |
| - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` |
| - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` |
| """ |
| outputs = {} |
| sid, g, lid, _ = self._set_cond_input(aux_input) |
| |
| if self.args.use_speaker_embedding and sid is not None: |
| g = self.emb_g(sid).unsqueeze(-1) |
|
|
| |
| lang_emb = None |
| if self.args.use_language_embedding and lid is not None: |
| lang_emb = self.emb_l(lid).unsqueeze(-1) |
|
|
| x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) |
|
|
| # posterior encoder |
| z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) |
|
|
| # flow layers |
| z_p = self.flow(z, y_mask, g=g) |
|
|
| # monotonic alignment search (MAS) and duration loss |
| outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) |
|
|
| # expand the prior to the frame level using the alignment |
| m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) |
| logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) |
|
|
| # select a random feature segment for the waveform decoder |
| z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) |
|
|
| # interpolate z if a lower encoder sample rate is used |
| z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) |
|
|
| o = self.waveform_decoder(z_slice, g=g) |
|
|
| wav_seg = segment( |
| waveform, |
| slice_ids * self.config.audio.hop_length, |
| spec_segment_size * self.config.audio.hop_length, |
| pad_short=True, |
| ) |
|
|
| if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: |
| |
| wavs_batch = torch.cat((wav_seg, o), dim=0) |
|
|
| |
| |
| if self.audio_transform is not None: |
| wavs_batch = self.audio_transform(wavs_batch) |
|
|
| pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) |
|
|
| |
| gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) |
| else: |
| gt_spk_emb, syn_spk_emb = None, None |
|
|
| outputs.update( |
| { |
| "model_outputs": o, |
| "alignments": attn.squeeze(1), |
| "m_p": m_p, |
| "logs_p": logs_p, |
| "z": z, |
| "z_p": z_p, |
| "m_q": m_q, |
| "logs_q": logs_q, |
| "waveform_seg": wav_seg, |
| "gt_spk_emb": gt_spk_emb, |
| "syn_spk_emb": syn_spk_emb, |
| "slice_ids": slice_ids, |
| } |
| ) |
| return outputs |
|
|
| @staticmethod |
| def _set_x_lengths(x, aux_input): |
| if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: |
| return aux_input["x_lengths"] |
| return torch.tensor(x.shape[1:2]).to(x.device) |
|
|
| @torch.no_grad() |
| def inference( |
| self, |
| x, |
| aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, |
| ): |
| """ |
| Note: |
| To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1. |
| |
| Shapes: |
| - x: :math:`[B, T_seq]` |
| - x_lengths: :math:`[B]` |
| - d_vectors: :math:`[B, C]` |
| - speaker_ids: :math:`[B]` |
| |
| Return Shapes: |
| - model_outputs: :math:`[B, 1, T_wav]` |
| - alignments: :math:`[B, T_seq, T_dec]` |
| - z: :math:`[B, C, T_dec]` |
| - z_p: :math:`[B, C, T_dec]` |
| - m_p: :math:`[B, C, T_dec]` |
| - logs_p: :math:`[B, C, T_dec]` |
| """ |
| sid, g, lid, durations = self._set_cond_input(aux_input) |
| x_lengths = self._set_x_lengths(x, aux_input) |
|
|
| |
| if self.args.use_speaker_embedding and sid is not None: |
| g = self.emb_g(sid).unsqueeze(-1) |
|
|
| |
| lang_emb = None |
| if self.args.use_language_embedding and lid is not None: |
| lang_emb = self.emb_l(lid).unsqueeze(-1) |
|
|
| x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) |
|
|
| if durations is None: |
| if self.args.use_sdp: |
| logw = self.duration_predictor( |
| x, |
| x_mask, |
| g=g if self.args.condition_dp_on_speaker else None, |
| reverse=True, |
| noise_scale=self.inference_noise_scale_dp, |
| lang_emb=lang_emb, |
| ) |
| else: |
| logw = self.duration_predictor( |
| x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb |
| ) |
| w = torch.exp(logw) * x_mask * self.length_scale |
| else: |
| assert durations.shape[-1] == x.shape[-1] |
| w = durations.unsqueeze(0) |
|
|
| w_ceil = torch.ceil(w) |
| y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
| y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) |
|
|
| attn_mask = x_mask * y_mask.transpose(1, 2) |
| attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) |
|
|
| m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) |
| logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) |
|
|
| z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale |
| z = self.flow(z_p, y_mask, g=g, reverse=True) |
|
|
| # upsample z if a lower encoder sample rate is used |
| z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) |
|
|
| o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) |
|
|
| outputs = { |
| "model_outputs": o, |
| "alignments": attn.squeeze(1), |
| "durations": w_ceil, |
| "z": z, |
| "z_p": z_p, |
| "m_p": m_p, |
| "logs_p": logs_p, |
| "y_mask": y_mask, |
| } |
| return outputs |
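|
| # Usage sketch (editorial, not part of the class API): single-sentence inference |
| # on an already tokenized input. `model` and `token_ids` are assumed to exist; |
| # in practice `token_ids` comes from the model's tokenizer. |
| # |
| #     model.eval() |
| #     token_ids = torch.LongTensor([[10, 4, 27, 9, 3]])        # [1, T_seq] |
| #     out = model.inference(token_ids) |
| #     wav = out["model_outputs"]                               # [1, 1, T_wav] |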
|
|
| @torch.no_grad() |
| def inference_voice_conversion( |
| self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None |
| ): |
| """Inference for voice conversion |
| |
| Args: |
| reference_wav (Tensor): Reference waveform. Tensor of shape [B, T] |
| speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] |
| d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` |
| reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] |
| reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]` |
| """ |
| |
| y = wav_to_spec( |
| reference_wav, |
| self.config.audio.fft_size, |
| self.config.audio.hop_length, |
| self.config.audio.win_length, |
| center=False, |
| ) |
| y_lengths = torch.tensor([y.size(-1)]).to(y.device) |
| speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector |
| speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector |
| wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) |
| return wav |
|
|
| def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): |
| """Forward pass for voice conversion |
| |
| TODO: create an end-point for voice conversion |
| |
| Args: |
| y (Tensor): Reference spectrograms. Tensor of shape [B, C, T] |
| y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] |
| speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] |
| speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] |
| """ |
| assert self.num_speakers > 0, "num_speakers has to be larger than 0." |
| |
| if self.args.use_speaker_embedding and not self.args.use_d_vector_file: |
| g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1) |
| g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1) |
| elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: |
| g_src = F.normalize(speaker_cond_src).unsqueeze(-1) |
| g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) |
| else: |
| raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") |
|
|
| z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src) |
| z_p = self.flow(z, y_mask, g=g_src) |
| z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) |
| o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) |
| return o_hat, y_mask, (z, z_p, z_hat) |
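|
| # Usage sketch (editorial): converting a reference utterance to a target speaker on |
| # a model trained with a speaker embedding layer. The ids and waveform are assumptions. |
| # |
| #     ref_wav = torch.rand(1, 3 * 22050) * 2 - 1               # [B, T] in [-1, 1] |
| #     wav_out = model.inference_voice_conversion( |
| #         ref_wav, speaker_id=2, reference_speaker_id=0 |
| #     )                                                        # [B, 1, T_wav] |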
|
|
| def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: |
| """Perform a single training step. Run the model forward pass and compute losses. |
| |
| Args: |
| batch (Dict): Input tensors. |
| criterion (nn.Module): Loss layer designed for the model. |
| optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks. |
| |
| Returns: |
| Tuple[Dict, Dict]: Model outputs and computed losses. |
| """ |
|
|
| spec_lens = batch["spec_lens"] |
|
|
| if optimizer_idx == 0: |
| tokens = batch["tokens"] |
| token_lengths = batch["token_lens"] |
| spec = batch["spec"] |
|
|
| d_vectors = batch["d_vectors"] |
| speaker_ids = batch["speaker_ids"] |
| language_ids = batch["language_ids"] |
| waveform = batch["waveform"] |
|
|
| |
| outputs = self.forward( |
| tokens, |
| token_lengths, |
| spec, |
| spec_lens, |
| waveform, |
| aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, |
| ) |
|
|
| # cache tensors for the generator pass |
| self.model_outputs_cache = outputs |
|
|
| |
| scores_disc_fake, _, scores_disc_real, _ = self.disc( |
| outputs["model_outputs"].detach(), outputs["waveform_seg"] |
| ) |
|
|
| |
| with autocast(enabled=False): |
| loss_dict = criterion[optimizer_idx]( |
| scores_disc_real, |
| scores_disc_fake, |
| ) |
| return outputs, loss_dict |
|
|
| if optimizer_idx == 1: |
| mel = batch["mel"] |
|
|
| # compute melspec segments |
| with autocast(enabled=False): |
| if self.args.encoder_sample_rate: |
| spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) |
| else: |
| spec_segment_size = self.spec_segment_size |
|
|
| mel_slice = segment( |
| mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True |
| ) |
| mel_slice_hat = wav_to_mel( |
| y=self.model_outputs_cache["model_outputs"].float(), |
| n_fft=self.config.audio.fft_size, |
| sample_rate=self.config.audio.sample_rate, |
| num_mels=self.config.audio.num_mels, |
| hop_length=self.config.audio.hop_length, |
| win_length=self.config.audio.win_length, |
| fmin=self.config.audio.mel_fmin, |
| fmax=self.config.audio.mel_fmax, |
| center=False, |
| ) |
|
|
| # compute discriminator scores and features |
| scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( |
| self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] |
| ) |
|
|
| |
| with autocast(enabled=False): |
| loss_dict = criterion[optimizer_idx]( |
| mel_slice_hat=mel_slice.float(), |
| mel_slice=mel_slice_hat.float(), |
| z_p=self.model_outputs_cache["z_p"].float(), |
| logs_q=self.model_outputs_cache["logs_q"].float(), |
| m_p=self.model_outputs_cache["m_p"].float(), |
| logs_p=self.model_outputs_cache["logs_p"].float(), |
| z_len=spec_lens, |
| scores_disc_fake=scores_disc_fake, |
| feats_disc_fake=feats_disc_fake, |
| feats_disc_real=feats_disc_real, |
| loss_duration=self.model_outputs_cache["loss_duration"], |
| use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, |
| gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], |
| syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], |
| ) |
|
|
| return self.model_outputs_cache, loss_dict |
|
|
| raise ValueError(" [!] Unexpected `optimizer_idx`.") |
|
|
| def _log(self, ap, batch, outputs, name_prefix="train"): |
| y_hat = outputs[1]["model_outputs"] |
| y = outputs[1]["waveform_seg"] |
| figures = plot_results(y_hat, y, ap, name_prefix) |
| sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() |
| audios = {f"{name_prefix}/audio": sample_voice} |
|
|
| alignments = outputs[1]["alignments"] |
| align_img = alignments[0].data.cpu().numpy().T |
|
|
| figures.update( |
| { |
| "alignment": plot_alignment(align_img, output_fig=False), |
| } |
| ) |
| return figures, audios |
|
|
| def train_log( |
| self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int |
| ): |
| """Create visualizations and waveform examples. |
| |
| For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to |
| be projected onto Tensorboard. |
| |
| Args: |
| batch (Dict): Model inputs used at the previous training step. |
| outputs (Dict): Model outputs generated at the previous training step. |
| logger (Logger): Logger instance used to project the figures and audios. |
| assets (Dict): Training assets. |
| steps (int): Current training step. |
| """ |
| figures, audios = self._log(self.ap, batch, outputs, "train") |
| logger.train_figures(steps, figures) |
| logger.train_audios(steps, audios, self.ap.sample_rate) |
|
|
| @torch.no_grad() |
| def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): |
| return self.train_step(batch, criterion, optimizer_idx) |
|
|
| def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: |
| figures, audios = self._log(self.ap, batch, outputs, "eval") |
| logger.eval_figures(steps, figures) |
| logger.eval_audios(steps, audios, self.ap.sample_rate) |
|
|
| def get_aux_input_from_test_sentences(self, sentence_info): |
| if hasattr(self.config, "model_args"): |
| config = self.config.model_args |
| else: |
| config = self.config |
|
|
| |
| text, speaker_name, style_wav, language_name = None, None, None, None |
|
|
| if isinstance(sentence_info, list): |
| if len(sentence_info) == 1: |
| text = sentence_info[0] |
| elif len(sentence_info) == 2: |
| text, speaker_name = sentence_info |
| elif len(sentence_info) == 3: |
| text, speaker_name, style_wav = sentence_info |
| elif len(sentence_info) == 4: |
| text, speaker_name, style_wav, language_name = sentence_info |
| else: |
| text = sentence_info |
|
|
| |
| speaker_id, d_vector, language_id = None, None, None |
| if hasattr(self, "speaker_manager"): |
| if config.use_d_vector_file: |
| if speaker_name is None: |
| d_vector = self.speaker_manager.get_random_embedding() |
| else: |
| d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) |
| elif config.use_speaker_embedding: |
| if speaker_name is None: |
| speaker_id = self.speaker_manager.get_random_id() |
| else: |
| speaker_id = self.speaker_manager.name_to_id[speaker_name] |
|
|
| |
| if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: |
| language_id = self.language_manager.name_to_id[language_name] |
|
|
| return { |
| "text": text, |
| "speaker_id": speaker_id, |
| "style_wav": style_wav, |
| "d_vector": d_vector, |
| "language_id": language_id, |
| "language_name": language_name, |
| } |
|
|
| @torch.no_grad() |
| def test_run(self, assets) -> Tuple[Dict, Dict]: |
| """Generic test run for `tts` models used by `Trainer`. |
| |
| You can override this for a different behaviour. |
| |
| Returns: |
| Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. |
| """ |
| print(" | > Synthesizing test sentences.") |
| test_audios = {} |
| test_figures = {} |
| test_sentences = self.config.test_sentences |
| for idx, s_info in enumerate(test_sentences): |
| aux_inputs = self.get_aux_input_from_test_sentences(s_info) |
| wav, alignment, _, _ = synthesis( |
| self, |
| aux_inputs["text"], |
| self.config, |
| "cuda" in str(next(self.parameters()).device), |
| speaker_id=aux_inputs["speaker_id"], |
| d_vector=aux_inputs["d_vector"], |
| style_wav=aux_inputs["style_wav"], |
| language_id=aux_inputs["language_id"], |
| use_griffin_lim=True, |
| do_trim_silence=False, |
| ).values() |
| test_audios["{}-audio".format(idx)] = wav |
| test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) |
| return {"figures": test_figures, "audios": test_audios} |
|
|
| def test_log( |
| self, outputs: dict, logger: "Logger", assets: dict, steps: int |
| ) -> None: |
| logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) |
| logger.test_figures(steps, outputs["figures"]) |
|
|
| def format_batch(self, batch: Dict) -> Dict: |
| """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" |
| speaker_ids = None |
| language_ids = None |
| d_vectors = None |
|
|
| |
| if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding: |
| speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] |
|
|
| if speaker_ids is not None: |
| speaker_ids = torch.LongTensor(speaker_ids) |
|
|
| |
| if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: |
| d_vector_mapping = self.speaker_manager.embeddings |
| d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] |
| d_vectors = torch.FloatTensor(d_vectors) |
|
|
| |
| if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding: |
| language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]] |
|
|
| if language_ids is not None: |
| language_ids = torch.LongTensor(language_ids) |
|
|
| batch["language_ids"] = language_ids |
| batch["d_vectors"] = d_vectors |
| batch["speaker_ids"] = speaker_ids |
| return batch |
|
|
| def format_batch_on_device(self, batch): |
| """Compute spectrograms on the device.""" |
| ac = self.config.audio |
|
|
| if self.args.encoder_sample_rate: |
| wav = self.audio_resampler(batch["waveform"]) |
| else: |
| wav = batch["waveform"] |
|
|
| |
| batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) |
|
|
| if self.args.encoder_sample_rate: |
| # recompute the spectrogram at the full sample rate for the mel-spectrogram loss |
| spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) |
| # remove extra stft frames if needed |
| if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): |
| spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] |
| else: |
| batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] |
| else: |
| spec_mel = batch["spec"] |
|
|
| batch["mel"] = spec_to_mel( |
| spec=spec_mel, |
| n_fft=ac.fft_size, |
| num_mels=ac.num_mels, |
| sample_rate=ac.sample_rate, |
| fmin=ac.mel_fmin, |
| fmax=ac.mel_fmax, |
| ) |
|
|
| if self.args.encoder_sample_rate: |
| assert batch["spec"].shape[2] == int( |
| batch["mel"].shape[2] / self.interpolate_factor |
| ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" |
| else: |
| assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" |
|
|
| |
| batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() |
| batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() |
|
|
| if self.args.encoder_sample_rate: |
| assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 |
| else: |
| assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 |
|
|
| |
| batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) |
| batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) |
| return batch |
|
|
| def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): |
| weights = None |
| data_items = dataset.samples |
| if getattr(config, "use_weighted_sampler", False): |
| for attr_name, alpha in config.weighted_sampler_attrs.items(): |
| print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") |
| multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) |
| print(multi_dict) |
| weights, attr_names, attr_weights = get_attribute_balancer_weights( |
| attr_name=attr_name, items=data_items, multi_dict=multi_dict |
| ) |
| weights = weights * alpha |
| print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") |
|
|
| |
|
|
| if weights is not None: |
| w_sampler = WeightedRandomSampler(weights, len(weights)) |
| batch_sampler = BucketBatchSampler( |
| w_sampler, |
| data=data_items, |
| batch_size=config.eval_batch_size if is_eval else config.batch_size, |
| sort_key=lambda x: os.path.getsize(x["audio_file"]), |
| drop_last=True, |
| ) |
| else: |
| batch_sampler = None |
| |
| if batch_sampler is None: |
| batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None |
| else: |
| batch_sampler = ( |
| DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler |
| ) |
| return batch_sampler |
|
|
| def get_data_loader( |
| self, |
| config: Coqpit, |
| assets: Dict, |
| is_eval: bool, |
| samples: Union[List[Dict], List[List]], |
| verbose: bool, |
| num_gpus: int, |
| rank: int = None, |
| ) -> "DataLoader": |
| if is_eval and not config.run_eval: |
| loader = None |
| else: |
| |
| dataset = VitsDataset( |
| model_args=self.args, |
| samples=samples, |
| batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, |
| min_text_len=config.min_text_len, |
| max_text_len=config.max_text_len, |
| min_audio_len=config.min_audio_len, |
| max_audio_len=config.max_audio_len, |
| phoneme_cache_path=config.phoneme_cache_path, |
| precompute_num_workers=config.precompute_num_workers, |
| verbose=verbose, |
| tokenizer=self.tokenizer, |
| start_by_longest=config.start_by_longest, |
| ) |
|
|
| |
| if num_gpus > 1: |
| dist.barrier() |
|
|
| |
| dataset.preprocess_samples() |
|
|
| |
| sampler = self.get_sampler(config, dataset, num_gpus) |
| if sampler is None: |
| loader = DataLoader( |
| dataset, |
| batch_size=config.eval_batch_size if is_eval else config.batch_size, |
| shuffle=False, |
| collate_fn=dataset.collate_fn, |
| drop_last=False, |
| num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, |
| pin_memory=False, |
| ) |
| else: |
| if num_gpus > 1: |
| loader = DataLoader( |
| dataset, |
| sampler=sampler, |
| batch_size=config.eval_batch_size if is_eval else config.batch_size, |
| collate_fn=dataset.collate_fn, |
| num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, |
| pin_memory=False, |
| ) |
| else: |
| loader = DataLoader( |
| dataset, |
| batch_sampler=sampler, |
| collate_fn=dataset.collate_fn, |
| num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, |
| pin_memory=False, |
| ) |
| return loader |
|
|
| def get_optimizer(self) -> List: |
| """Initiate and return the GAN optimizers based on the config parameters. |
| It returns two optimizers in a list: the first one is for the discriminator and the second one is for the generator. |
| Returns: |
| List: optimizers. |
| """ |
| |
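| # discriminator optimizer |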
| optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) |
|
|
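| # gather every parameter that does not belong to the discriminator for the generator optimizer |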
| gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) |
| optimizer1 = get_optimizer( |
| self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters |
| ) |
| return [optimizer0, optimizer1] |
|
|
| def get_lr(self) -> List: |
| """Set the initial learning rates for each optimizer. |
| |
| Returns: |
| List: learning rates for each optimizer. |
| """ |
| return [self.config.lr_disc, self.config.lr_gen] |
|
|
| def get_scheduler(self, optimizer) -> List: |
| """Set the schedulers for each optimizer. |
| |
| Args: |
| optimizer (List[`torch.optim.Optimizer`]): List of optimizers. |
| |
| Returns: |
| List: Schedulers, one for each optimizer. |
| """ |
| scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) |
| scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) |
| return [scheduler_D, scheduler_G] |
|
|
| def get_criterion(self): |
| """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in |
| `train_step()`""" |
| from TTS.tts.layers.losses import ( |
| VitsDiscriminatorLoss, |
| VitsGeneratorLoss, |
| ) |
|
|
| return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] |
|
|
| def load_checkpoint( |
| self, config, checkpoint_path, eval=False, strict=True, cache=False |
| ): |
| """Load the model checkpoint and setup for training or inference""" |
| state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) |
| |
| |
| |
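| # compat band-aid for pre-trained models: drop the speaker encoder weights baked into old checkpoints |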
| state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} |
|
|
| if self.args.encoder_sample_rate is not None and eval: |
| |
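| # the audio resampler is not used at inference time |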
| self.audio_resampler = None |
|
|
| |
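| # handle fine-tuning with new speakers by growing the checkpoint's speaker embedding to match the model |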
| if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: |
| num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] |
| print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") |
| emb_g = state["model"]["emb_g.weight"] |
| new_row = torch.randn(num_new_speakers, emb_g.shape[1]) |
| emb_g = torch.cat([emb_g, new_row], dim=0) |
| state["model"]["emb_g.weight"] = emb_g |
| |
| self.load_state_dict(state["model"], strict=strict) |
|
|
| if eval: |
| self.eval() |
| assert not self.training |
|
|
| def load_fairseq_checkpoint( |
| self, config, checkpoint_dir, eval=False, strict=True |
| ): |
| """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms |
| Performs the key remapping and config/tokenizer adjustments needed for compatibility. |
| |
| Args: |
| config (Coqpit): 🐸TTS model config. |
| checkpoint_dir (str): Path to the checkpoint directory. |
| eval (bool, optional): Set to True for evaluation. Defaults to False. |
| strict (bool, optional): Whether to strictly enforce that the checkpoint keys match the model. Defaults to True. |
| """ |
| import json |
|
|
| from TTS.tts.utils.text.cleaners import basic_cleaners |
|
|
| self.disc = None |
| |
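| # set paths |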
| config_file = os.path.join(checkpoint_dir, "config.json") |
| checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") |
| vocab_file = os.path.join(checkpoint_dir, "vocab.txt") |
| |
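| # load the original fairseq config |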
| with open(config_file, "r", encoding="utf-8") as file: |
| |
| config_org = json.load(file) |
| self.config.audio.sample_rate = config_org["data"]["sampling_rate"] |
| |
| |
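| # init characters and tokenizer from the fairseq vocabulary |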
| vocab = FairseqVocab(vocab_file) |
| self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) |
| self.tokenizer = TTSTokenizer( |
| use_phonemes=False, |
| text_cleaner=basic_cleaners, |
| characters=vocab, |
| phonemizer=None, |
| add_blank=config_org["data"]["add_blank"], |
| use_eos_bos=False, |
| ) |
| |
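| # load and convert the fairseq checkpoint keys |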
| new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) |
| self.load_state_dict(new_chk, strict=strict) |
| if eval: |
| self.eval() |
| assert not self.training |
|
|
| @staticmethod |
| def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): |
| """Initiate model from config |
| |
| Args: |
| config (VitsConfig): Model config. |
| samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. |
| Defaults to None. |
| """ |
| from TTS.utils.audio import AudioProcessor |
|
|
| upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() |
|
|
| if not config.model_args.encoder_sample_rate: |
| assert ( |
| upsample_rate == config.audio.hop_length |
| ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" |
| else: |
| encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate |
| effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor |
| assert ( |
| upsample_rate == effective_hop_length |
| ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" |
|
|
| ap = AudioProcessor.init_from_config(config, verbose=verbose) |
| tokenizer, new_config = TTSTokenizer.init_from_config(config) |
| speaker_manager = SpeakerManager.init_from_config(config, samples) |
| language_manager = LanguageManager.init_from_config(config) |
|
|
| if config.model_args.speaker_encoder_model_path: |
| speaker_manager.init_encoder( |
| config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path |
| ) |
| return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) |
|
|
| def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True): |
| """Export model to ONNX format for inference |
| |
| Args: |
| output_path (str): Path to save the exported model. |
| verbose (bool): Print verbose information. Defaults to True. |
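| |
| Example (illustrative sketch; ``config`` and ``checkpoint_path`` are placeholders): |
| >>> model = Vits.init_from_config(config) |
| >>> model.load_checkpoint(config, checkpoint_path, eval=True) |
| >>> model.export_onnx(output_path="coqui_vits.onnx") |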
| """ |
|
|
| |
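| # save state to roll back after export |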
| _forward = self.forward |
| disc = None |
| if hasattr(self, "disc"): |
| disc = self.disc |
| training = self.training |
|
|
| |
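| # set export mode |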
| self.disc = None |
| self.eval() |
|
|
| def onnx_inference(text, text_lengths, scales, sid=None, langid=None): |
| noise_scale = scales[0] |
| length_scale = scales[1] |
| noise_scale_dp = scales[2] |
| self.noise_scale = noise_scale |
| self.length_scale = length_scale |
| self.noise_scale_dp = noise_scale_dp |
| return self.inference( |
| text, |
| aux_input={ |
| "x_lengths": text_lengths, |
| "d_vectors": None, |
| "speaker_ids": sid, |
| "language_ids": langid, |
| "durations": None, |
| }, |
| )["model_outputs"] |
|
|
| self.forward = onnx_inference |
|
|
| |
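| # set dummy inputs |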
| dummy_input_length = 100 |
| sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) |
| sequence_lengths = torch.LongTensor([sequences.size(1)]) |
| scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) |
| dummy_input = (sequences, sequence_lengths, scales) |
| input_names = ["input", "input_lengths", "scales"] |
|
|
| if self.num_speakers > 0: |
| speaker_id = torch.LongTensor([0]) |
| dummy_input += (speaker_id,) |
| input_names.append("sid") |
|
|
| if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0: |
| language_id = torch.LongTensor([0]) |
| dummy_input += (language_id,) |
| input_names.append("langid") |
|
|
| |
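| # export to ONNX |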
| torch.onnx.export( |
| model=self, |
| args=dummy_input, |
| opset_version=15, |
| f=output_path, |
| verbose=verbose, |
| input_names=input_names, |
| output_names=["output"], |
| dynamic_axes={ |
| "input": {0: "batch_size", 1: "phonemes"}, |
| "input_lengths": {0: "batch_size"}, |
| "output": {0: "batch_size", 1: "time1", 2: "time2"}, |
| }, |
| ) |
|
|
| |
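| # rollback: restore the original forward, discriminator, and training mode |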
| self.forward = _forward |
| if training: |
| self.train() |
| if disc is not None: |
| self.disc = disc |
|
|
| def load_onnx(self, model_path: str, cuda=False): |
| import onnxruntime as ort |
|
|
| providers = [ |
| "CPUExecutionProvider" |
| if cuda is False |
| else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) |
| ] |
| sess_options = ort.SessionOptions() |
| self.onnx_sess = ort.InferenceSession( |
| model_path, |
| sess_options=sess_options, |
| providers=providers, |
| ) |
|
|
| def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): |
| """ONNX inference""" |
|
|
| if isinstance(x, torch.Tensor): |
| x = x.cpu().numpy() |
|
|
| if x_lengths is None: |
| x_lengths = np.array([x.shape[1]], dtype=np.int64) |
|
|
| if isinstance(x_lengths, torch.Tensor): |
| x_lengths = x_lengths.cpu().numpy() |
| scales = np.array( |
| [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp], |
| dtype=np.float32, |
| ) |
| input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} |
| if speaker_id is not None: |
| input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() |
| if language_id is not None: |
| input_params["langid"] = torch.tensor([language_id]).cpu().numpy() |
|
|
| audio = self.onnx_sess.run( |
| ["output"], |
| input_params, |
| ) |
| return audio[0][0] |
|
|
|
|
| |
| |
| |
|
|
|
|
| class VitsCharacters(BaseCharacters): |
| """Characters class for VITs model for compatibility with pre-trained models""" |
|
|
| def __init__( |
| self, |
| graphemes: str = _characters, |
| punctuations: str = _punctuations, |
| pad: str = _pad, |
| ipa_characters: str = _phonemes, |
| ) -> None: |
| if ipa_characters is not None: |
| graphemes += ipa_characters |
| super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True) |
|
|
| def _create_vocab(self): |
| self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] |
| self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} |
| |
| self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} |
|
|
| @staticmethod |
| def init_from_config(config: Coqpit): |
| if config.characters is not None: |
| _pad = config.characters["pad"] |
| _punctuations = config.characters["punctuations"] |
| _letters = config.characters["characters"] |
| _letters_ipa = config.characters["phonemes"] |
| return ( |
| VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), |
| config, |
| ) |
| characters = VitsCharacters() |
| new_config = replace(config, characters=characters.to_config()) |
| return characters, new_config |
|
|
| def to_config(self) -> "CharactersConfig": |
| return CharactersConfig( |
| characters=self._characters, |
| punctuations=self._punctuations, |
| pad=self._pad, |
| eos=None, |
| bos=None, |
| blank=self._blank, |
| is_unique=False, |
| is_sorted=True, |
| ) |
|
|
|
|
| class FairseqVocab(BaseVocabulary): |
| def __init__(self, vocab: str): |
| super(FairseqVocab).__init__() |
| self.vocab = vocab |
|
|
| @property |
| def vocab(self): |
| """Return the vocabulary dictionary.""" |
| return self._vocab |
|
|
| @vocab.setter |
| def vocab(self, vocab_file): |
| with open(vocab_file, encoding="utf-8") as f: |
| self._vocab = [x.replace("\n", "") for x in f.readlines()] |
| self.blank = self._vocab[0] |
| self.pad = " " |
| self._char_to_id = {s: i for i, s in enumerate(self._vocab)} |
| self._id_to_char = {i: s for i, s in enumerate(self._vocab)} |
|
|