| import logging |
| import os |
| from dataclasses import dataclass, field |
| from itertools import chain |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torchaudio |
| from coqpit import Coqpit |
| from librosa.filters import mel as librosa_mel_fn |
| from torch import nn |
| from torch.cuda.amp.autocast_mode import autocast |
| from torch.nn import functional as F |
| from torch.utils.data import DataLoader |
| from torch.utils.data.sampler import WeightedRandomSampler |
| from trainer.io import load_fsspec |
| from trainer.torch import DistributedSampler, DistributedSamplerWrapper |
| from trainer.trainer_utils import get_optimizer, get_scheduler |
|
|
| from src.utils.TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample |
| from src.utils.TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel |
| from src.utils.TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss |
| from src.utils.TTS.tts.layers.vits.discriminator import VitsDiscriminator |
| from src.utils.TTS.tts.models.base_tts import BaseTTSE2E |
| from src.utils.TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask |
| from src.utils.TTS.tts.utils.speakers import SpeakerManager |
| from src.utils.TTS.tts.utils.text.tokenizer import TTSTokenizer |
| from src.utils.TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram |
| from src.utils.TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 |
| from src.utils.TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy |
| from src.utils.TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy |
| from src.utils.TTS.utils.audio.processor import AudioProcessor |
| from src.utils.TTS.vocoder.layers.losses import MultiScaleSTFTLoss |
| from src.utils.TTS.vocoder.models.hifigan_generator import HifiganGenerator |
| from src.utils.TTS.vocoder.utils.generic_utils import plot_results |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def id_to_torch(aux_id, cuda=False): |
| if aux_id is not None: |
| aux_id = np.asarray(aux_id) |
| aux_id = torch.from_numpy(aux_id) |
| if cuda: |
| return aux_id.cuda() |
| return aux_id |
|
|
|
|
| def embedding_to_torch(d_vector, cuda=False): |
| if d_vector is not None: |
| d_vector = np.asarray(d_vector) |
| d_vector = torch.from_numpy(d_vector).float() |
| d_vector = d_vector.squeeze().unsqueeze(0) |
| if cuda: |
| return d_vector.cuda() |
| return d_vector |
|
|
|
|
| def numpy_to_torch(np_array, dtype, cuda=False): |
| if np_array is None: |
| return None |
| tensor = torch.as_tensor(np_array, dtype=dtype) |
| if cuda: |
| return tensor.cuda() |
| return tensor |
|
|
|
|
| def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: |
| batch_size = lengths.shape[0] |
| max_len = torch.max(lengths).item() |
| ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) |
| mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) |
| return mask |
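| # Illustration: lengths = torch.tensor([2, 3]) yields the [2, 3] boolean mask
| # [[False, False, True], [False, False, False]], i.e. True marks padded positions.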
|
|
|
|
| def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: |
| out_list = torch.jit.annotate(List[torch.Tensor], []) |
| for batch in input_ele: |
| if len(batch.shape) == 1: |
| one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) |
| else: |
| one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) |
| out_list.append(one_batch_padded) |
| out_padded = torch.stack(out_list) |
| return out_padded |
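| # Handles both 1-D (e.g. duration) and 2-D [T, C] feature tensors; padding is
| # always applied along the first (time) dimension up to `max_len`.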
|
|
|
|
| def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: |
| return torch.ceil(lens / stride).int() |
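| # Output length of a sequence after striding (e.g. a strided convolution or
| # pooling), rounded up so no trailing frames are dropped.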
|
|
|
|
| def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: |
| assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." |
| return torch.randn(shape) * np.sqrt(2 / shape[1]) |
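| # He-style initialization: zero-mean Gaussian with std = sqrt(2 / fan_in), where
| # fan_in is the embedding dimension shape[1].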
|
|
|
|
| |
| def calc_same_padding(kernel_size: int) -> Tuple[int, int]: |
| pad = kernel_size // 2 |
| return (pad, pad - (kernel_size + 1) % 2) |
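| # Example: calc_same_padding(5) -> (2, 2) and calc_same_padding(4) -> (2, 1), so a
| # stride-1 convolution keeps the output length equal to the input length.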
|
|
|
|
| hann_window = {} |
| mel_basis = {} |
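| # Module-level caches keyed by window size / n_fft, dtype and device, so the Hann
| # window and mel filterbank are built once per configuration rather than per call.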
|
|
|
|
| @torch.no_grad() |
| def weights_reset(m: nn.Module): |
| |
| reset_parameters = getattr(m, "reset_parameters", None) |
| if callable(reset_parameters): |
| m.reset_parameters() |
|
|
|
|
| def get_module_weights_sum(mdl: nn.Module): |
| dict_sums = {} |
| for name, w in mdl.named_parameters(): |
| if "weight" in name: |
| value = w.data.sum().item() |
| dict_sums[name] = value |
| return dict_sums |
|
|
|
|
| def load_audio(file_path: str): |
| """Load the audio file normalized in [-1, 1] |
| |
| Return Shapes: |
| - x: :math:`[1, T]` |
| """ |
| x, sr = torchaudio.load( |
| file_path, |
| ) |
| assert (x > 1).sum() + (x < -1).sum() == 0, "Audio samples must lie in [-1, 1]."
| return x, sr |
|
|
|
|
| def _amp_to_db(x, C=1, clip_val=1e-5): |
| return torch.log(torch.clamp(x, min=clip_val) * C) |
|
|
|
|
| def _db_to_amp(x, C=1): |
| return torch.exp(x) / C |
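| # _amp_to_db and _db_to_amp form an inverse pair implementing natural-log dynamic
| # range compression/expansion of spectrogram magnitudes.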
|
|
|
|
| def amp_to_db(magnitudes): |
| output = _amp_to_db(magnitudes) |
| return output |
|
|
|
|
| def db_to_amp(magnitudes): |
| output = _db_to_amp(magnitudes) |
| return output |
|
|
|
|
| def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): |
| y = y.squeeze(1) |
|
|
| if torch.min(y) < -1.0: |
| logger.info("min value is %.3f", torch.min(y)) |
| if torch.max(y) > 1.0: |
| logger.info("max value is %.3f", torch.max(y)) |
|
|
| global hann_window |
| dtype_device = str(y.dtype) + "_" + str(y.device) |
| wnsize_dtype_device = str(win_length) + "_" + dtype_device |
| if wnsize_dtype_device not in hann_window: |
| hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) |
|
|
| y = torch.nn.functional.pad( |
| y.unsqueeze(1), |
| (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), |
| mode="reflect", |
| ) |
| y = y.squeeze(1) |
|
|
| spec = torch.view_as_real( |
| torch.stft( |
| y, |
| n_fft, |
| hop_length=hop_length, |
| win_length=win_length, |
| window=hann_window[wnsize_dtype_device], |
| center=center, |
| pad_mode="reflect", |
| normalized=False, |
| onesided=True, |
| return_complex=True, |
| ) |
| ) |
|
|
| return spec |
|
|
|
|
| def wav_to_spec(y, n_fft, hop_length, win_length, center=False): |
| """ |
| Args Shapes: |
| - y : :math:`[B, 1, T]` |
| |
| Return Shapes: |
| - spec : :math:`[B,C,T]` |
| """ |
| spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) |
| spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) |
| return spec |
|
|
|
|
| def wav_to_energy(y, n_fft, hop_length, win_length, center=False): |
| spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) |
|
|
| spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) |
| return torch.norm(spec, dim=1, keepdim=True) |
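| # Frame-level energy is the L2 norm of the linear magnitude spectrum over the
| # frequency axis: [B, C_freq, T] -> [B, 1, T].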
|
|
|
|
| def name_mel_basis(spec, n_fft, fmax): |
| n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" |
| return n_fft_len |
|
|
|
|
| def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): |
| """ |
| Args Shapes: |
| - spec : :math:`[B,C,T]` |
| |
| Return Shapes: |
| - mel : :math:`[B,C,T]` |
| """ |
| global mel_basis |
| mel_basis_key = name_mel_basis(spec, n_fft, fmax) |
| |
| if mel_basis_key not in mel_basis: |
| |
| # librosa >= 0.10 requires keyword arguments here
| mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
| mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) |
| mel = torch.matmul(mel_basis[mel_basis_key], spec) |
| mel = amp_to_db(mel) |
| return mel |
|
|
|
|
| def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): |
| """ |
| Args Shapes: |
| - y : :math:`[B, 1, T_y]` |
| |
| Return Shapes: |
| - spec : :math:`[B,C,T_spec]` |
| """ |
| y = y.squeeze(1) |
|
|
| if torch.min(y) < -1.0: |
| logger.info("min value is %.3f", torch.min(y)) |
| if torch.max(y) > 1.0: |
| logger.info("max value is %.3f", torch.max(y)) |
|
|
| global mel_basis, hann_window |
| mel_basis_key = name_mel_basis(y, n_fft, fmax) |
| wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) |
| if mel_basis_key not in mel_basis: |
| |
| mel = librosa_mel_fn( |
| sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax |
| ) |
| mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) |
| if wnsize_dtype_device not in hann_window: |
| hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) |
|
|
| y = torch.nn.functional.pad( |
| y.unsqueeze(1), |
| (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), |
| mode="reflect", |
| ) |
| y = y.squeeze(1) |
|
|
| spec = torch.view_as_real( |
| torch.stft( |
| y, |
| n_fft, |
| hop_length=hop_length, |
| win_length=win_length, |
| window=hann_window[wnsize_dtype_device], |
| center=center, |
| pad_mode="reflect", |
| normalized=False, |
| onesided=True, |
| return_complex=True, |
| ) |
| ) |
|
|
| spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) |
| spec = torch.matmul(mel_basis[mel_basis_key], spec) |
| spec = amp_to_db(spec) |
| return spec |
|
|
|
|
| |
| |
| |
|
|
|
|
| def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
| """Create inverse-frequency balancing weights for torch's WeightedRandomSampler."""
| attr_names_samples = np.array([item[attr_name] for item in items]) |
| unique_attr_names = np.unique(attr_names_samples).tolist() |
| attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] |
| attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) |
| weight_attr = 1.0 / attr_count |
| dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) |
| dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) |
| if multi_dict is not None: |
| multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) |
| dataset_samples_weight *= multiplier_samples |
| return ( |
| torch.from_numpy(dataset_samples_weight).float(), |
| unique_attr_names, |
| np.unique(dataset_samples_weight).tolist(), |
| ) |
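| # Minimal usage sketch (illustrative; `samples` is a hypothetical metadata list
| # with a "speaker_name" key per item):
| # weights, names, _ = get_attribute_balancer_weights(samples, "speaker_name")
| # sampler = WeightedRandomSampler(weights, len(weights))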
|
|
|
|
| class ForwardTTSE2eF0Dataset(F0Dataset): |
| """Override F0Dataset to avoid slow computing of pitches""" |
|
|
| def __init__( |
| self, |
| ap, |
| samples: Union[List[List], List[Dict]], |
| cache_path: str = None, |
| precompute_num_workers=0, |
| normalize_f0=True, |
| ): |
| super().__init__( |
| samples=samples, |
| ap=ap, |
| cache_path=cache_path, |
| precompute_num_workers=precompute_num_workers, |
| normalize_f0=normalize_f0, |
| ) |
|
|
| def _compute_and_save_pitch(self, wav_file, pitch_file=None): |
| wav, _ = load_audio(wav_file) |
| f0 = compute_f0( |
| x=wav.numpy()[0], |
| sample_rate=self.ap.sample_rate, |
| hop_length=self.ap.hop_length, |
| pitch_fmax=self.ap.pitch_fmax, |
| pitch_fmin=self.ap.pitch_fmin, |
| win_length=self.ap.win_length, |
| ) |
| |
| # drop the last F0 frame when the audio length is not a multiple of hop_length,
| # so the pitch track matches the spectrogram frame count
| if wav.shape[1] % self.ap.hop_length != 0:
| f0 = f0[:-1]
| if pitch_file: |
| np.save(pitch_file, f0) |
| return f0 |
|
|
| def compute_or_load(self, wav_file, audio_name): |
| """ |
| Compute pitch (or load it from cache) and return a numpy array of pitch values.
| """ |
| pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) |
| if not os.path.exists(pitch_file): |
| pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) |
| else: |
| pitch = np.load(pitch_file) |
| return pitch.astype(np.float32) |
|
|
|
|
| class ForwardTTSE2eDataset(TTSDataset): |
| def __init__(self, *args, **kwargs): |
| |
| # disable F0 computation in the base TTSDataset; the custom F0 dataset below handles it
| compute_f0 = kwargs.pop("compute_f0", False)
| kwargs["compute_f0"] = False
| self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") |
|
|
| super().__init__(*args, **kwargs) |
|
|
| self.compute_f0 = compute_f0 |
| self.pad_id = self.tokenizer.characters.pad_id |
| self.ap = kwargs["ap"] |
|
|
| if self.compute_f0: |
| self.f0_dataset = ForwardTTSE2eF0Dataset( |
| ap=self.ap, |
| samples=self.samples, |
| cache_path=kwargs["f0_cache_path"], |
| precompute_num_workers=kwargs["precompute_num_workers"], |
| ) |
|
|
| if self.attn_prior_cache_path is not None: |
| os.makedirs(self.attn_prior_cache_path, exist_ok=True) |
|
|
| def __getitem__(self, idx): |
| item = self.samples[idx] |
|
|
| rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") |
| rel_wav_path = str(rel_wav_path).replace("/", "_") |
|
|
| raw_text = item["text"] |
| wav, _ = load_audio(item["audio_file"]) |
| wav_filename = os.path.basename(item["audio_file"]) |
|
|
| try:
| token_ids = self.get_token_ids(idx, item["text"])
| except Exception as e:
| logger.exception("%s %s", idx, item)
| raise OSError(f"Failed to compute token IDs for sample {idx}") from e
| f0 = None |
| if self.compute_f0: |
| f0 = self.get_f0(idx)["f0"] |
|
|
| |
| |
| |
| # after phonemization the token sequence may be too long or the audio too short;
| # fall back to the next sample
| if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
| self.rescue_item_idx += 1 |
| return self.__getitem__(self.rescue_item_idx) |
|
|
| attn_prior = None |
| if self.attn_prior_cache_path is not None: |
| attn_prior = self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) |
|
|
| return { |
| "raw_text": raw_text, |
| "token_ids": token_ids, |
| "token_len": len(token_ids), |
| "wav": wav, |
| "pitch": f0, |
| "wav_file": wav_filename, |
| "speaker_name": item["speaker_name"], |
| "language_name": item["language"], |
| "attn_prior": attn_prior, |
| "audio_unique_name": item["audio_unique_name"], |
| } |
|
|
| def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): |
| """Load or compute and save the attention prior.""" |
| attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") |
| |
| if os.path.exists(attn_prior_file): |
| return np.load(attn_prior_file) |
| else: |
| token_len = len(token_ids) |
| mel_len = wav.shape[1] // self.ap.hop_length |
| attn_prior = compute_attn_prior(token_len, mel_len) |
| np.save(attn_prior_file, attn_prior) |
| return attn_prior |
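| # The cached prior is a [T_token, T_mel] near-diagonal map (see compute_attn_prior)
| # that biases the aligner toward monotonic attention early in training.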
|
|
| @property |
| def lengths(self): |
| lens = [] |
| for item in self.samples: |
| _, wav_file, *_ = _parse_sample(item) |
| audio_len = os.path.getsize(wav_file) / 16 * 8  # bytes -> number of samples, assuming 16-bit audio
| lens.append(audio_len) |
| return lens |
|
|
| def collate_fn(self, batch): |
| """ |
| Return Shapes: |
| - tokens: :math:`[B, T]` |
| - token_lens :math:`[B]` |
| - token_rel_lens :math:`[B]` |
| - pitch :math:`[B, T]` |
| - waveform: :math:`[B, 1, T]` |
| - waveform_lens: :math:`[B]` |
| - waveform_rel_lens: :math:`[B]` |
| - speaker_names: :math:`[B]` |
| - language_names: :math:`[B]` |
| - audiofile_paths: :math:`[B]` |
| - raw_texts: :math:`[B]` |
| - attn_prior: :math:`[[T_token, T_mel]]` |
| """ |
| B = len(batch) |
| batch = {k: [dic[k] for dic in batch] for k in batch[0]} |
|
|
| max_text_len = max([len(x) for x in batch["token_ids"]]) |
| token_lens = torch.LongTensor(batch["token_len"]) |
| token_rel_lens = token_lens / token_lens.max() |
|
|
| wav_lens = [w.shape[1] for w in batch["wav"]] |
| wav_lens = torch.LongTensor(wav_lens) |
| wav_lens_max = torch.max(wav_lens) |
| wav_rel_lens = wav_lens / wav_lens_max |
|
|
| pitch_padded = None |
| if self.compute_f0: |
| pitch_lens = [p.shape[0] for p in batch["pitch"]] |
| pitch_lens = torch.LongTensor(pitch_lens) |
| pitch_lens_max = torch.max(pitch_lens) |
| pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) |
| pitch_padded = pitch_padded.zero_() + self.pad_id |
|
|
| token_padded = torch.LongTensor(B, max_text_len) |
| wav_padded = torch.FloatTensor(B, 1, wav_lens_max) |
|
|
| token_padded = token_padded.zero_() + self.pad_id |
| wav_padded = wav_padded.zero_() + self.pad_id |
|
|
| for i in range(B): |
| token_ids = batch["token_ids"][i] |
| token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) |
|
|
| wav = batch["wav"][i] |
| wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) |
|
|
| if self.compute_f0: |
| pitch = batch["pitch"][i] |
| pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) |
|
|
| return { |
| "text_input": token_padded, |
| "text_lengths": token_lens, |
| "text_rel_lens": token_rel_lens, |
| "pitch": pitch_padded, |
| "waveform": wav_padded, |
| "waveform_lens": wav_lens, |
| "waveform_rel_lens": wav_rel_lens, |
| "speaker_names": batch["speaker_name"], |
| "language_names": batch["language_name"], |
| "audio_unique_names": batch["audio_unique_name"], |
| "audio_files": batch["wav_file"], |
| "raw_text": batch["raw_text"], |
| "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, |
| } |
|
|
|
|
| |
| |
| |
|
|
|
|
| @dataclass |
| class VocoderConfig(Coqpit): |
| resblock_type_decoder: str = "1" |
| resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) |
| resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) |
| upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) |
| upsample_initial_channel_decoder: int = 512 |
| upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) |
| use_spectral_norm_discriminator: bool = False |
| upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) |
| periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) |
| pretrained_model_path: Optional[str] = None |
|
|
|
|
| @dataclass |
| class DelightfulTtsAudioConfig(Coqpit): |
| sample_rate: int = 22050 |
| hop_length: int = 256 |
| win_length: int = 1024 |
| fft_size: int = 1024 |
| mel_fmin: float = 0.0 |
| mel_fmax: float = 8000 |
| num_mels: int = 100 |
| pitch_fmax: float = 640.0 |
| pitch_fmin: float = 1.0 |
| resample: bool = False |
| preemphasis: float = 0.0 |
| ref_level_db: int = 20 |
| do_sound_norm: bool = False |
| log_func: str = "np.log10" |
| do_trim_silence: bool = True |
| trim_db: int = 45 |
| do_rms_norm: bool = False |
| db_level: Optional[float] = None
| power: float = 1.5 |
| griffin_lim_iters: int = 60 |
| spec_gain: int = 20 |
| do_amp_to_db_linear: bool = True |
| do_amp_to_db_mel: bool = True |
| min_level_db: int = -100 |
| max_norm: float = 4.0 |
|
|
|
|
| @dataclass |
| class DelightfulTtsArgs(Coqpit): |
| num_chars: int = 100 |
| spec_segment_size: int = 32 |
| n_hidden_conformer_encoder: int = 512 |
| n_layers_conformer_encoder: int = 6 |
| n_heads_conformer_encoder: int = 8 |
| dropout_conformer_encoder: float = 0.1 |
| kernel_size_conv_mod_conformer_encoder: int = 7 |
| kernel_size_depthwise_conformer_encoder: int = 7 |
| lrelu_slope: float = 0.3 |
| n_hidden_conformer_decoder: int = 512 |
| n_layers_conformer_decoder: int = 6 |
| n_heads_conformer_decoder: int = 8 |
| dropout_conformer_decoder: float = 0.1 |
| kernel_size_conv_mod_conformer_decoder: int = 11 |
| kernel_size_depthwise_conformer_decoder: int = 11 |
| bottleneck_size_p_reference_encoder: int = 4 |
| bottleneck_size_u_reference_encoder: int = 512 |
| ref_enc_filters_reference_encoder: List[int] = field(default_factory=lambda: [32, 32, 64, 64, 128, 128])
| ref_enc_size_reference_encoder: int = 3
| ref_enc_strides_reference_encoder: List[int] = field(default_factory=lambda: [1, 2, 1, 2, 1])
| ref_enc_pad_reference_encoder: List[int] = field(default_factory=lambda: [1, 1])
| ref_enc_gru_size_reference_encoder: int = 32 |
| ref_attention_dropout_reference_encoder: float = 0.2 |
| token_num_reference_encoder: int = 32 |
| predictor_kernel_size_reference_encoder: int = 5 |
| n_hidden_variance_adaptor: int = 512 |
| kernel_size_variance_adaptor: int = 5 |
| dropout_variance_adaptor: float = 0.5 |
| n_bins_variance_adaptor: int = 256 |
| emb_kernel_size_variance_adaptor: int = 3 |
| use_speaker_embedding: bool = False |
| num_speakers: int = 0 |
| speakers_file: Optional[str] = None
| d_vector_file: Optional[str] = None
| speaker_embedding_channels: int = 384 |
| use_d_vector_file: bool = False |
| d_vector_dim: int = 0 |
| freeze_vocoder: bool = False |
| freeze_text_encoder: bool = False |
| freeze_duration_predictor: bool = False |
| freeze_pitch_predictor: bool = False |
| freeze_energy_predictor: bool = False |
| freeze_basis_vectors_predictor: bool = False |
| freeze_decoder: bool = False |
| length_scale: float = 1.0 |
|
|
|
|
| |
| |
| |
| class DelightfulTTS(BaseTTSE2E): |
| """ |
| Paper:: |
| https://arxiv.org/pdf/2110.12612.pdf |
| |
| Paper Abstract:: |
| This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. |
| The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: |
| The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems |
| with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves |
| the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and |
| propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training |
| efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and |
| implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and |
| inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) |
| For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. |
| Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test |
| and 4.35 in SMOS test, which indicates the effectiveness of our proposed system.
| |
| |
| Model training:: |
| text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg |
| spec --------^ |
| |
| Examples: |
| >>> from src.utils.TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
| >>> from src.utils.TTS.tts.models.delightful_tts import DelightfulTTS
| >>> config = DelightfulTTSConfig()
| >>> model = DelightfulTTS.init_from_config(config)
| """ |
|
|
| |
| def __init__( |
| self, |
| config: Coqpit, |
| ap, |
| tokenizer: "TTSTokenizer" = None, |
| speaker_manager: SpeakerManager = None, |
| ): |
| super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) |
| self.ap = ap |
|
|
| self._set_model_args(config) |
| self.init_multispeaker(config) |
| self.binary_loss_weight = None |
|
|
| self.args.out_channels = self.config.audio.num_mels |
| self.args.num_mels = self.config.audio.num_mels |
| self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) |
|
|
| self.waveform_decoder = HifiganGenerator( |
| self.config.audio.num_mels, |
| 1, |
| self.config.vocoder.resblock_type_decoder, |
| self.config.vocoder.resblock_dilation_sizes_decoder, |
| self.config.vocoder.resblock_kernel_sizes_decoder, |
| self.config.vocoder.upsample_kernel_sizes_decoder, |
| self.config.vocoder.upsample_initial_channel_decoder, |
| self.config.vocoder.upsample_rates_decoder, |
| inference_padding=0, |
| |
| conv_pre_weight_norm=False, |
| conv_post_weight_norm=False, |
| conv_post_bias=False, |
| ) |
|
|
| if self.config.init_discriminator: |
| self.disc = VitsDiscriminator( |
| use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, |
| periods=self.config.vocoder.periods_discriminator, |
| ) |
|
|
| @property |
| def device(self): |
| return next(self.parameters()).device |
|
|
| @property |
| def energy_scaler(self): |
| return self.acoustic_model.energy_scaler |
|
|
| @property |
| def length_scale(self): |
| return self.acoustic_model.length_scale |
|
|
| @length_scale.setter |
| def length_scale(self, value): |
| self.acoustic_model.length_scale = value |
|
|
| @property |
| def pitch_mean(self): |
| return self.acoustic_model.pitch_mean |
|
|
| @pitch_mean.setter |
| def pitch_mean(self, value): |
| self.acoustic_model.pitch_mean = value |
|
|
| @property |
| def pitch_std(self): |
| return self.acoustic_model.pitch_std |
|
|
| @pitch_std.setter |
| def pitch_std(self, value): |
| self.acoustic_model.pitch_std = value |
|
|
| @property |
| def mel_basis(self): |
| return build_mel_basis( |
| sample_rate=self.ap.sample_rate, |
| fft_size=self.ap.fft_size, |
| num_mels=self.ap.num_mels, |
| mel_fmax=self.ap.mel_fmax, |
| mel_fmin=self.ap.mel_fmin, |
| ) |
|
|
| def init_for_training(self) -> None: |
| self.train_disc = ( |
| self.config.steps_to_start_discriminator <= 0 |
| ) |
| self.update_energy_scaler = True |
|
|
| def init_multispeaker(self, config: Coqpit): |
| """Init for multi-speaker training. |
| |
| Args: |
| config (Coqpit): Model configuration. |
| """ |
| self.embedded_speaker_dim = 0 |
| self.num_speakers = self.args.num_speakers |
| self.audio_transform = None |
|
|
| if self.speaker_manager: |
| self.num_speakers = self.speaker_manager.num_speakers |
| self.args.num_speakers = self.speaker_manager.num_speakers |
|
|
| if self.args.use_speaker_embedding: |
| self._init_speaker_embedding() |
|
|
| if self.args.use_d_vector_file: |
| self._init_d_vector() |
|
|
| def _init_speaker_embedding(self): |
| |
| if self.num_speakers > 0: |
| logger.info("Initialization of speaker-embedding layers.") |
| self.embedded_speaker_dim = self.args.speaker_embedding_channels |
| self.args.embedded_speaker_dim = self.args.speaker_embedding_channels |
|
|
| def _init_d_vector(self): |
| |
| if hasattr(self, "emb_g"): |
| raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") |
| self.embedded_speaker_dim = self.args.d_vector_dim |
| self.args.embedded_speaker_dim = self.args.d_vector_dim |
|
|
| def _freeze_layers(self): |
| if self.args.freeze_vocoder: |
| for param in self.waveform_decoder.parameters():  # the vocoder generator
| param.requires_grad = False |
|
|
| if self.args.freeze_text_encoder: |
| for param in self.text_encoder.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_duration_predictor: |
| for param in self.duration_predictor.parameters():
| param.requires_grad = False |
|
|
| if self.args.freeze_pitch_predictor: |
| for param in self.pitch_predictor.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_energy_predictor: |
| for param in self.energy_predictor.parameters(): |
| param.requires_grad = False |
|
|
| if self.args.freeze_decoder: |
| for param in self.decoder.parameters(): |
| param.requires_grad = False |
|
|
| def forward( |
| self, |
| x: torch.LongTensor, |
| x_lengths: torch.LongTensor, |
| spec_lengths: torch.LongTensor, |
| spec: torch.FloatTensor, |
| waveform: torch.FloatTensor, |
| pitch: torch.FloatTensor = None, |
| energy: torch.FloatTensor = None, |
| attn_priors: torch.FloatTensor = None, |
| d_vectors: torch.FloatTensor = None, |
| speaker_idx: torch.LongTensor = None, |
| ) -> Dict: |
| """Model's forward pass. |
| |
| Args: |
| x (torch.LongTensor): Input character sequences. |
| x_lengths (torch.LongTensor): Input sequence lengths. |
| spec_lengths (torch.LongTensor): Spectrogram sequence lengths. Defaults to None.
| spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. |
| waveform (torch.FloatTensor): Waveform. Defaults to None. |
| pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. |
| energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. |
| attn_priors (torch.FloatTensor): Attention priors for the aligner network. Defaults to None.
| d_vectors (torch.FloatTensor): d-vectors for multi-speaker training. Defaults to None.
| speaker_idx (torch.LongTensor): Speaker IDs for multi-speaker training. Defaults to None.
| |
| Shapes: |
| - x: :math:`[B, T_max]` |
| - x_lengths: :math:`[B]` |
| - spec_lengths: :math:`[B]` |
| - spec: :math:`[B, T_max2, C_spec]` |
| - waveform: :math:`[B, 1, T_max2 * hop_length]` |
| - g: :math:`[B, C]` |
| - pitch: :math:`[B, 1, T_max2]` |
| - energy: :math:`[B, 1, T_max2]` |
| """ |
| encoder_outputs = self.acoustic_model( |
| tokens=x, |
| src_lens=x_lengths, |
| mel_lens=spec_lengths, |
| mels=spec, |
| pitches=pitch, |
| energies=energy, |
| attn_priors=attn_priors, |
| d_vectors=d_vectors, |
| speaker_idx=speaker_idx, |
| ) |
|
|
| |
| vocoder_input = encoder_outputs["model_outputs"] |
|
|
| vocoder_input_slices, slice_ids = rand_segments( |
| x=vocoder_input.transpose(1, 2), |
| x_lengths=spec_lengths, |
| segment_size=self.args.spec_segment_size, |
| let_short_samples=True, |
| pad_short=True, |
| ) |
| if encoder_outputs["spk_emb"] is not None: |
| g = encoder_outputs["spk_emb"].unsqueeze(-1) |
| else: |
| g = None |
|
|
| vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) |
| wav_seg = segment( |
| waveform, |
| slice_ids * self.ap.hop_length, |
| self.args.spec_segment_size * self.ap.hop_length, |
| pad_short=True, |
| ) |
| model_outputs = {**encoder_outputs} |
| model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] |
| model_outputs["model_outputs"] = vocoder_output |
| model_outputs["waveform_seg"] = wav_seg |
| model_outputs["slice_ids"] = slice_ids |
| return model_outputs |
|
|
| @torch.no_grad() |
| def inference( |
| self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None |
| ): |
| encoder_outputs = self.acoustic_model.inference( |
| tokens=x, |
| d_vectors=aux_input["d_vectors"], |
| speaker_idx=aux_input["speaker_ids"], |
| pitch_transform=pitch_transform, |
| energy_transform=energy_transform, |
| p_control=None, |
| d_control=None, |
| ) |
| vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) |
| if encoder_outputs["spk_emb"] is not None: |
| g = encoder_outputs["spk_emb"].unsqueeze(-1) |
| else: |
| g = None |
|
|
| vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) |
| model_outputs = {**encoder_outputs} |
| model_outputs["model_outputs"] = vocoder_output |
| return model_outputs |
|
|
| @torch.no_grad() |
| def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): |
| encoder_outputs = self.acoustic_model.inference( |
| tokens=x, |
| d_vectors=aux_input["d_vectors"], |
| speaker_idx=aux_input["speaker_ids"], |
| ) |
| model_outputs = {**encoder_outputs} |
| return model_outputs |
|
|
| def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): |
| if optimizer_idx == 0: |
| tokens = batch["text_input"] |
| token_lengths = batch["text_lengths"]
| mel = batch["mel_input"] |
| mel_lens = batch["mel_lengths"] |
| waveform = batch["waveform"] |
| pitch = batch["pitch"] |
| d_vectors = batch["d_vectors"] |
| speaker_ids = batch["speaker_ids"] |
| attn_priors = batch["attn_priors"] |
| energy = batch["energy"] |
|
|
| |
| outputs = self.forward( |
| x=tokens, |
| x_lengths=token_lengths,
| spec_lengths=mel_lens, |
| spec=mel, |
| waveform=waveform, |
| pitch=pitch, |
| energy=energy, |
| attn_priors=attn_priors, |
| d_vectors=d_vectors, |
| speaker_idx=speaker_ids, |
| ) |
|
|
| |
| self.model_outputs_cache = outputs |
|
|
| if self.train_disc: |
| |
| scores_d_fake, _, scores_d_real, _ = self.disc( |
| outputs["model_outputs"].detach(), outputs["waveform_seg"] |
| ) |
|
|
| |
| with autocast(enabled=False): |
| loss_dict = criterion[optimizer_idx]( |
| scores_disc_fake=scores_d_fake, |
| scores_disc_real=scores_d_real, |
| ) |
| return outputs, loss_dict |
| return None, None |
|
|
| if optimizer_idx == 1: |
| mel = batch["mel_input"] |
| |
| with autocast(enabled=False): |
| mel_slice = segment( |
| mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True |
| ) |
|
|
| mel_slice_hat = wav_to_mel( |
| y=self.model_outputs_cache["model_outputs"].float(), |
| n_fft=self.ap.fft_size, |
| sample_rate=self.ap.sample_rate, |
| num_mels=self.ap.num_mels, |
| hop_length=self.ap.hop_length, |
| win_length=self.ap.win_length, |
| fmin=self.ap.mel_fmin, |
| fmax=self.ap.mel_fmax, |
| center=False, |
| ) |
|
|
| scores_d_fake = None |
| feats_d_fake = None |
| feats_d_real = None |
|
|
| if self.train_disc: |
| |
| scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( |
| self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] |
| ) |
|
|
| |
| with autocast(enabled=True): |
| loss_dict = criterion[optimizer_idx]( |
| mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), |
| mel_target=batch["mel_input"], |
| mel_lens=batch["mel_lengths"], |
| dur_output=self.model_outputs_cache["dr_log_pred"], |
| dur_target=self.model_outputs_cache["dr_log_target"].detach(), |
| pitch_output=self.model_outputs_cache["pitch_pred"], |
| pitch_target=self.model_outputs_cache["pitch_target"], |
| energy_output=self.model_outputs_cache["energy_pred"], |
| energy_target=self.model_outputs_cache["energy_target"], |
| src_lens=batch["text_lengths"], |
| waveform=self.model_outputs_cache["waveform_seg"], |
| waveform_hat=self.model_outputs_cache["model_outputs"], |
| p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], |
| p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], |
| u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], |
| u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], |
| aligner_logprob=self.model_outputs_cache["aligner_logprob"], |
| aligner_hard=self.model_outputs_cache["aligner_mas"], |
| aligner_soft=self.model_outputs_cache["aligner_soft"], |
| binary_loss_weight=self.binary_loss_weight, |
| feats_fake=feats_d_fake, |
| feats_real=feats_d_real, |
| scores_fake=scores_d_fake, |
| spec_slice=mel_slice, |
| spec_slice_hat=mel_slice_hat, |
| skip_disc=not self.train_disc, |
| ) |
|
|
| loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() |
| loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() |
| loss_dict["avg_text_batch_occupancy"] = ( |
| batch["text_lengths"].float() / batch["text_lengths"].float().max() |
| ).mean() |
| loss_dict["avg_mel_batch_occupancy"] = ( |
| batch["mel_lengths"].float() / batch["mel_lengths"].float().max() |
| ).mean() |
|
|
| return self.model_outputs_cache, loss_dict |
| raise ValueError(" [!] Unexpected `optimizer_idx`.") |
|
|
| def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): |
| return self.train_step(batch, criterion, optimizer_idx) |
|
|
| def _log(self, batch, outputs, name_prefix="train"): |
| figures, audios = {}, {} |
|
|
| |
| model_outputs = outputs[1]["acoustic_model_outputs"] |
| alignments = outputs[1]["alignments"] |
| mel_input = batch["mel_input"] |
|
|
| pred_spec = model_outputs[0].data.cpu().numpy() |
| gt_spec = mel_input[0].data.cpu().numpy() |
| align_img = alignments[0].data.cpu().numpy() |
|
|
| figures = { |
| "prediction": plot_spectrogram(pred_spec, None, output_fig=False), |
| "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), |
| "alignment": plot_alignment(align_img, output_fig=False), |
| } |
|
|
| |
| pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) |
| pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) |
| chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) |
| pitch_figures = { |
| "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), |
| "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), |
| } |
| figures.update(pitch_figures) |
|
|
| |
| energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) |
| energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) |
| chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) |
| energy_figures = { |
| "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), |
| "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), |
| } |
| figures.update(energy_figures) |
|
|
| |
| alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() |
| figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) |
|
|
| |
| encoder_audio = mel_to_wav_numpy( |
| mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio |
| ) |
| audios[f"{name_prefix}/encoder_audio"] = encoder_audio |
|
|
| |
| y_hat = outputs[1]["model_outputs"] |
| y = outputs[1]["waveform_seg"] |
|
|
| vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) |
| figures.update(vocoder_figures) |
|
|
| sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() |
| audios[f"{name_prefix}/vocoder_audio"] = sample_voice |
| return figures, audios |
|
|
| def train_log( |
| self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int |
| ): |
| """Create visualizations and waveform examples. |
| |
| For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
| be projected onto Tensorboard. |
| |
| Args: |
| batch (Dict): Model inputs used at the previous training step. |
| outputs (Dict): Model outputs generated at the previous training step. |
| |
| Returns: |
| Tuple[Dict, np.ndarray]: training plots and output waveform. |
| """ |
| figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") |
| logger.train_figures(steps, figures) |
| logger.train_audios(steps, audios, self.ap.sample_rate) |
|
|
| def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: |
| figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") |
| logger.eval_figures(steps, figures) |
| logger.eval_audios(steps, audios, self.ap.sample_rate) |
|
|
| def get_aux_input_from_test_sentences(self, sentence_info): |
| if hasattr(self.config, "model_args"): |
| config = self.config.model_args |
| else: |
| config = self.config |
|
|
| |
| text, speaker_name, style_wav = None, None, None |
|
|
| if isinstance(sentence_info, list): |
| if len(sentence_info) == 1: |
| text = sentence_info[0] |
| elif len(sentence_info) == 2: |
| text, speaker_name = sentence_info |
| elif len(sentence_info) == 3: |
| text, speaker_name, style_wav = sentence_info |
| else: |
| text = sentence_info |
|
|
| |
| speaker_id, d_vector = None, None |
| if hasattr(self, "speaker_manager"): |
| if config.use_d_vector_file: |
| if speaker_name is None: |
| d_vector = self.speaker_manager.get_random_embedding() |
| else: |
| d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) |
| elif config.use_speaker_embedding: |
| if speaker_name is None: |
| speaker_id = self.speaker_manager.get_random_id() |
| else: |
| speaker_id = self.speaker_manager.name_to_id[speaker_name] |
|
|
| return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} |
|
|
| def plot_outputs(self, text, wav, alignment, outputs): |
| figures = {} |
| pitch_avg_pred = outputs["pitch"].cpu() |
| energy_avg_pred = outputs["energy"].cpu() |
| spec = wav_to_mel( |
| y=torch.from_numpy(wav[None, :]), |
| n_fft=self.ap.fft_size, |
| sample_rate=self.ap.sample_rate, |
| num_mels=self.ap.num_mels, |
| hop_length=self.ap.hop_length, |
| win_length=self.ap.win_length, |
| fmin=self.ap.mel_fmin, |
| fmax=self.ap.mel_fmax, |
| center=False, |
| )[0].transpose(0, 1) |
| pitch = compute_f0( |
| x=wav[0], |
| sample_rate=self.ap.sample_rate, |
| hop_length=self.ap.hop_length, |
| pitch_fmax=self.ap.pitch_fmax, |
| ) |
| input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) |
| input_text = input_text.replace("<BLNK>", "_") |
| durations = outputs["durations"] |
| pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) |
| pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean |
| figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) |
| figures["spectrogram"] = plot_spectrogram(spec) |
| figures["pitch_from_wav"] = plot_pitch(pitch, spec) |
| figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) |
| figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) |
| figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) |
| return figures |
|
|
| def synthesize( |
| self, |
| text: str, |
| speaker_id: str = None, |
| d_vector: torch.tensor = None, |
| pitch_transform=None, |
| **kwargs, |
| ): |
| |
| is_cuda = next(self.parameters()).is_cuda |
|
|
| |
| text_inputs = np.asarray( |
| self.tokenizer.text_to_ids(text, language=None), |
| dtype=np.int32, |
| ) |
|
|
| |
| _speaker_id = None |
| if speaker_id is not None and self.args.use_speaker_embedding: |
| if isinstance(speaker_id, str) and self.args.use_speaker_embedding: |
| |
| _speaker_id = self.speaker_manager.name_to_id[speaker_id] |
| _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) |
|
|
| if speaker_id is not None and self.args.use_d_vector_file: |
| |
| d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) |
| d_vector = embedding_to_torch(d_vector, cuda=is_cuda) |
|
|
| text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) |
| text_inputs = text_inputs.unsqueeze(0) |
|
|
| |
| outputs = self.inference( |
| text_inputs, |
| aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, |
| pitch_transform=pitch_transform, |
| |
| ) |
|
|
| |
| wav = outputs["model_outputs"][0].data.cpu().numpy() |
| alignments = outputs["alignments"] |
| return_dict = { |
| "wav": wav, |
| "alignments": alignments, |
| "text_inputs": text_inputs, |
| "outputs": outputs, |
| } |
| return return_dict |
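| # Illustrative call (assuming a trained instance `model`):
| # out = model.synthesize("Hello world.", speaker_id=None)
| # wav, alignments = out["wav"], out["alignments"]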
|
|
| def synthesize_with_gl(self, text: str, speaker_id, d_vector): |
| is_cuda = next(self.parameters()).is_cuda |
|
|
| |
| text_inputs = np.asarray( |
| self.tokenizer.text_to_ids(text, language=None), |
| dtype=np.int32, |
| ) |
| |
| if speaker_id is not None: |
| speaker_id = id_to_torch(speaker_id, cuda=is_cuda) |
|
|
| if d_vector is not None: |
| d_vector = embedding_to_torch(d_vector, cuda=is_cuda) |
|
|
| text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) |
| text_inputs = text_inputs.unsqueeze(0) |
|
|
| |
| outputs = self.inference_spec_decoder( |
| x=text_inputs, |
| aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, |
| ) |
|
|
| |
| S = outputs["model_outputs"].cpu().numpy()[0].T |
| S = db_to_amp_numpy(x=S, gain=1, base=None) |
| wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) |
| alignments = outputs["alignments"] |
| return_dict = { |
| "wav": wav[None, :], |
| "alignments": alignments, |
| "text_inputs": text_inputs, |
| "outputs": outputs, |
| } |
| return return_dict |
|
|
| @torch.no_grad() |
| def test_run(self, assets) -> Tuple[Dict, Dict]: |
| """Generic test run for `tts` models used by `Trainer`. |
| |
| You can override this for a different behaviour. |
| |
| Returns: |
| Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. |
| """ |
| logger.info("Synthesizing test sentences.") |
| test_audios = {} |
| test_figures = {} |
| test_sentences = self.config.test_sentences |
| for idx, s_info in enumerate(test_sentences): |
| aux_inputs = self.get_aux_input_from_test_sentences(s_info) |
| outputs = self.synthesize( |
| aux_inputs["text"], |
| config=self.config, |
| speaker_id=aux_inputs["speaker_id"], |
| d_vector=aux_inputs["d_vector"], |
| ) |
| outputs_gl = self.synthesize_with_gl( |
| aux_inputs["text"], |
| speaker_id=aux_inputs["speaker_id"], |
| d_vector=aux_inputs["d_vector"], |
| ) |
| |
| test_audios["{}-audio".format(idx)] = outputs["wav"].T |
| test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T |
| test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) |
| return {"figures": test_figures, "audios": test_audios} |
|
|
| def test_log( |
| self, outputs: dict, logger: "Logger", assets: dict, steps: int |
| ) -> None: |
| logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) |
| logger.test_figures(steps, outputs["figures"]) |
|
|
| def format_batch(self, batch: Dict) -> Dict: |
| """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" |
| speaker_ids = None |
| d_vectors = None |
|
|
| |
| if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: |
| speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] |
|
|
| if speaker_ids is not None: |
| speaker_ids = torch.LongTensor(speaker_ids) |
| batch["speaker_ids"] = speaker_ids |
|
|
| |
| if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: |
| d_vector_mapping = self.speaker_manager.embeddings |
| d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] |
| d_vectors = torch.FloatTensor(d_vectors) |
|
|
| batch["d_vectors"] = d_vectors |
| batch["speaker_ids"] = speaker_ids |
| return batch |
|
|
| def format_batch_on_device(self, batch): |
| """Compute spectrograms on the device.""" |
|
|
| ac = self.ap |
|
|
| |
| batch["mel_input"] = wav_to_mel( |
| batch["waveform"], |
| hop_length=ac.hop_length, |
| win_length=ac.win_length, |
| n_fft=ac.fft_size, |
| num_mels=ac.num_mels, |
| sample_rate=ac.sample_rate, |
| fmin=ac.mel_fmin, |
| fmax=ac.mel_fmax, |
| center=False, |
| ) |
|
|
| |
| |
| |
| |
| batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None |
| batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() |
|
|
| |
| batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) |
|
|
| |
| |
|
|
| if self.config.use_attn_priors: |
| attn_priors_np = batch["attn_priors"] |
|
|
| batch["attn_priors"] = torch.zeros( |
| batch["mel_input"].shape[0], |
| batch["mel_lengths"].max(), |
| batch["text_lengths"].max(), |
| device=batch["mel_input"].device, |
| ) |
|
|
| for i in range(batch["mel_input"].shape[0]): |
| batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( |
| attn_priors_np[i] |
| ) |
|
|
| batch["energy"] = None |
| batch["energy"] = wav_to_energy( |
| batch["waveform"], |
| hop_length=ac.hop_length, |
| win_length=ac.win_length, |
| n_fft=ac.fft_size, |
| center=False, |
| ) |
| batch["energy"] = self.energy_scaler(batch["energy"]) |
| return batch |
|
|
| def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): |
| weights = None |
| data_items = dataset.samples |
| if getattr(config, "use_weighted_sampler", False): |
| for attr_name, alpha in config.weighted_sampler_attrs.items(): |
| logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha) |
| multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) |
| logger.info(multi_dict) |
| weights, attr_names, attr_weights = get_attribute_balancer_weights( |
| attr_name=attr_name, items=data_items, multi_dict=multi_dict |
| ) |
| weights = weights * alpha |
| logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) |
|
|
| if weights is not None: |
| sampler = WeightedRandomSampler(weights, len(weights)) |
| else: |
| sampler = None |
| |
| if sampler is None: |
| sampler = DistributedSampler(dataset) if num_gpus > 1 else None |
| else: |
| sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler |
| return sampler |
|
|
| def get_data_loader( |
| self, |
| config: Coqpit, |
| assets: Dict, |
| is_eval: bool, |
| samples: Union[List[Dict], List[List]], |
| verbose: bool, |
| num_gpus: int, |
| rank: int = None, |
| ) -> "DataLoader": |
| if is_eval and not config.run_eval: |
| loader = None |
| else: |
| |
| dataset = ForwardTTSE2eDataset( |
| samples=samples, |
| ap=self.ap, |
| batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, |
| min_text_len=config.min_text_len, |
| max_text_len=config.max_text_len, |
| min_audio_len=config.min_audio_len, |
| max_audio_len=config.max_audio_len, |
| phoneme_cache_path=config.phoneme_cache_path, |
| precompute_num_workers=config.precompute_num_workers, |
| compute_f0=config.compute_f0, |
| f0_cache_path=config.f0_cache_path, |
| attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, |
| tokenizer=self.tokenizer, |
| start_by_longest=config.start_by_longest, |
| ) |
|
|
| |
| if num_gpus > 1: |
| dist.barrier() |
|
|
| |
| dataset.preprocess_samples() |
|
|
| |
| sampler = self.get_sampler(config, dataset, num_gpus) |
|
|
| loader = DataLoader( |
| dataset, |
| batch_size=config.eval_batch_size if is_eval else config.batch_size, |
| shuffle=False, |
| drop_last=False, |
| sampler=sampler, |
| collate_fn=dataset.collate_fn, |
| num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, |
| pin_memory=True, |
| ) |
|
|
| |
| self.pitch_mean = dataset.f0_dataset.mean |
| self.pitch_std = dataset.f0_dataset.std |
| return loader |
|
|
| def get_criterion(self): |
| return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] |
|
|
| def get_optimizer(self) -> List: |
| """Initiate and return the GAN optimizers based on the config parameters. |
| It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. |
| Returns: |
| List: optimizers. |
| """ |
| optimizer_disc = get_optimizer( |
| self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc |
| ) |
| gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) |
| optimizer_gen = get_optimizer( |
| self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters |
| ) |
| return [optimizer_disc, optimizer_gen] |
|
|
| def get_lr(self) -> List: |
| """Set the initial learning rates for each optimizer. |
| |
| Returns: |
| List: learning rates for each optimizer. |
| """ |
| return [self.config.lr_disc, self.config.lr_gen] |
|
|
| def get_scheduler(self, optimizer) -> List: |
| """Set the schedulers for each optimizer. |
| |
| Args: |
| optimizer (List[`torch.optim.Optimizer`]): List of optimizers. |
| |
| Returns: |
| List: Schedulers, one for each optimizer. |
| """ |
| scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0])
| scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1])
| return [scheduler_D, scheduler_G] |
|
|
| def on_epoch_end(self, trainer): |
| |
| |
| self.energy_scaler.eval() |
|
|
| @staticmethod |
| def init_from_config( |
| config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None |
| ): |
| """Initiate model from config |
| |
| Args: |
| config (ForwardTTSE2eConfig): Model config. |
| samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. |
| Defaults to None. |
| """ |
|
|
| tokenizer, new_config = TTSTokenizer.init_from_config(config) |
| speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) |
| ap = AudioProcessor.init_from_config(config=config) |
| return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) |
|
|
| def load_checkpoint(self, config, checkpoint_path, eval=False): |
| """Load model from a checkpoint created by the 👟""" |
| |
| state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) |
| self.load_state_dict(state["model"]) |
| if eval: |
| self.eval() |
| assert not self.training |
|
|
| def get_state_dict(self): |
| """Custom state dict of the model with all the necessary components for inference.""" |
| save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} |
|
|
| if hasattr(self, "emb_g"): |
| save_state["speaker_ids"] = self.speaker_manager.speaker_names |
|
|
| if self.args.use_d_vector_file: |
| # no-op placeholder: saving d-vectors with the checkpoint is not implemented
| ...
| return save_state |
|
|
| def save(self, config, checkpoint_path): |
| """Save model to a file.""" |
| save_state = self.get_state_dict()
| save_state["pitch_mean"] = self.pitch_mean |
| save_state["pitch_std"] = self.pitch_std |
| torch.save(save_state, checkpoint_path) |
|
|
| def on_train_step_start(self, trainer) -> None: |
| """Enable the discriminator training based on `steps_to_start_discriminator` |
| |
| Args: |
| trainer (Trainer): Trainer object. |
| """ |
| self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0)
| self.train_disc = ( |
| trainer.total_steps_done >= self.config.steps_to_start_discriminator |
| ) |
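| # The binary alignment loss weight thus ramps linearly from 0 to 1 over
| # `binary_loss_warmup_epochs`, while the discriminator switches on once
| # `steps_to_start_discriminator` global steps have been reached.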
|
|
|
|
| class DelightfulTTSLoss(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
|
|
| self.mse_loss = nn.MSELoss() |
| self.mae_loss = nn.L1Loss() |
| self.forward_sum_loss = ForwardSumLoss() |
| self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) |
|
|
| self.mel_loss_alpha = config.mel_loss_alpha |
| self.aligner_loss_alpha = config.aligner_loss_alpha |
| self.pitch_loss_alpha = config.pitch_loss_alpha |
| self.energy_loss_alpha = config.energy_loss_alpha |
| self.u_prosody_loss_alpha = config.u_prosody_loss_alpha |
| self.p_prosody_loss_alpha = config.p_prosody_loss_alpha |
| self.dur_loss_alpha = config.dur_loss_alpha |
| self.char_dur_loss_alpha = config.char_dur_loss_alpha |
| self.binary_alignment_loss_alpha = config.binary_align_loss_alpha |
|
|
| self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha |
| self.feat_loss_alpha = config.feat_loss_alpha |
| self.gen_loss_alpha = config.gen_loss_alpha |
| self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha |
|
|
| @staticmethod |
| def _binary_alignment_loss(alignment_hard, alignment_soft): |
| """Binary loss that forces soft alignments to match the hard alignments as |
| explained in `https://arxiv.org/pdf/2108.10447.pdf`. |
| """ |
| log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() |
| return -log_sum / alignment_hard.sum() |
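| # This is the mean negative log-probability the soft alignment assigns to the
| # positions selected by the hard (MAS) alignment; it is minimal when the soft
| # attention collapses onto the hard path.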
|
|
| @staticmethod |
| def feature_loss(feats_real, feats_generated): |
| loss = 0 |
| for dr, dg in zip(feats_real, feats_generated): |
| for rl, gl in zip(dr, dg): |
| rl = rl.float().detach() |
| gl = gl.float() |
| loss += torch.mean(torch.abs(rl - gl)) |
| return loss * 2 |
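| # HiFi-GAN style feature-matching loss: mean L1 distance between real and fake
| # discriminator feature maps, summed over layers and sub-discriminators (the
| # final x2 scaling follows the original HiFi-GAN implementation).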
|
|
| @staticmethod |
| def generator_loss(scores_fake): |
| loss = 0 |
| gen_losses = [] |
| for dg in scores_fake: |
| dg = dg.float() |
| l = torch.mean((1 - dg) ** 2) |
| gen_losses.append(l) |
| loss += l |
|
|
| return loss, gen_losses |
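| # LSGAN generator objective: pushes discriminator scores on generated audio
| # toward 1, using a least-squares rather than a cross-entropy criterion.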
|
|
| def forward( |
| self, |
| mel_output, |
| mel_target, |
| mel_lens, |
| dur_output, |
| dur_target, |
| pitch_output, |
| pitch_target, |
| energy_output, |
| energy_target, |
| src_lens, |
| waveform, |
| waveform_hat, |
| p_prosody_ref, |
| p_prosody_pred, |
| u_prosody_ref, |
| u_prosody_pred, |
| aligner_logprob, |
| aligner_hard, |
| aligner_soft, |
| binary_loss_weight=None, |
| feats_fake=None, |
| feats_real=None, |
| scores_fake=None, |
| spec_slice=None, |
| spec_slice_hat=None, |
| skip_disc=False, |
| ): |
| """ |
| Shapes: |
| - mel_output: :math:`(B, C_mel, T_mel)` |
| - mel_target: :math:`(B, C_mel, T_mel)` |
| - mel_lens: :math:`(B)` |
| - dur_output: :math:`(B, T_src)` |
| - dur_target: :math:`(B, T_src)` |
| - pitch_output: :math:`(B, 1, T_src)` |
| - pitch_target: :math:`(B, 1, T_src)` |
| - energy_output: :math:`(B, 1, T_src)` |
| - energy_target: :math:`(B, 1, T_src)` |
| - src_lens: :math:`(B)` |
| - waveform: :math:`(B, 1, T_wav)` |
| - waveform_hat: :math:`(B, 1, T_wav)` |
| - p_prosody_ref: :math:`(B, T_src, 4)` |
| - p_prosody_pred: :math:`(B, T_src, 4)` |
| - u_prosody_ref: :math:`(B, 1, 256)`
| - u_prosody_pred: :math:`(B, 1, 256)`
| - aligner_logprob: :math:`(B, 1, T_mel, T_src)` |
| - aligner_hard: :math:`(B, T_mel, T_src)` |
| - aligner_soft: :math:`(B, T_mel, T_src)` |
| - spec_slice: :math:`(B, C_mel, T_mel)` |
| - spec_slice_hat: :math:`(B, C_mel, T_mel)` |
| """ |
| loss_dict = {} |
| src_mask = sequence_mask(src_lens).to(mel_output.device) |
| mel_mask = sequence_mask(mel_lens).to(mel_output.device) |
|
|
| dur_target.requires_grad = False |
| mel_target.requires_grad = False |
| pitch_target.requires_grad = False |
|
|
| masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) |
| mel_targets = mel_target.masked_select(mel_mask[:, None]) |
| mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) |
|
|
| p_prosody_ref = p_prosody_ref.detach() |
| p_prosody_loss = 0.5 * self.mae_loss( |
| p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), |
| p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), |
| ) |
|
|
| u_prosody_ref = u_prosody_ref.detach() |
| u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) |
|
|
| duration_loss = self.mse_loss(dur_output, dur_target) |
|
|
| pitch_output = pitch_output.masked_select(src_mask[:, None]) |
| pitch_target = pitch_target.masked_select(src_mask[:, None]) |
| pitch_loss = self.mse_loss(pitch_output, pitch_target) |
|
|
| energy_output = energy_output.masked_select(src_mask[:, None]) |
| energy_target = energy_target.masked_select(src_mask[:, None]) |
| energy_loss = self.mse_loss(energy_output, energy_target) |
|
|
| forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) |
|
|
| total_loss = ( |
| (mel_loss * self.mel_loss_alpha) |
| + (duration_loss * self.dur_loss_alpha) |
| + (u_prosody_loss * self.u_prosody_loss_alpha) |
| + (p_prosody_loss * self.p_prosody_loss_alpha) |
| + (pitch_loss * self.pitch_loss_alpha) |
| + (energy_loss * self.energy_loss_alpha) |
| + (forward_sum_loss * self.aligner_loss_alpha) |
| ) |
|
|
| if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: |
| binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) |
| total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight |
| if binary_loss_weight: |
| loss_dict["loss_binary_alignment"] = ( |
| self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight |
| ) |
| else: |
| loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss |
|
|
| loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss |
| loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss |
| loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss |
| loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss |
| loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss |
| loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss |
| loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss |
| loss_dict["loss"] = total_loss |
|
|
| |
| if not skip_disc: |
| loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha |
| loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha |
| loss_dict["vocoder_loss_feat"] = loss_feat |
| loss_dict["vocoder_loss_gen"] = loss_gen |
| loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen |
|
|
| loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha |
| loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) |
| loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha |
| loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha |
|
|
| loss_dict["vocoder_loss_mel"] = loss_mel |
| loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg |
| loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc |
|
|
| loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg |
| return loss_dict |
|
|