| import logging |
| import os |
| import random |
| from typing import Any |
|
|
| import torch |
| import torch.distributed as dist |
| from coqpit import Coqpit |
| from torch import nn |
| from torch.utils.data import DataLoader |
| from torch.utils.data.sampler import WeightedRandomSampler |
| from trainer.torch import DistributedSampler, DistributedSamplerWrapper |
| from trainer.trainer import Trainer |
|
|
| from TTS.model import BaseTrainerModel |
| from TTS.tts.datasets.dataset import TTSDataset |
| from TTS.tts.utils.data import get_length_balancer_weights |
| from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights |
| from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights |
| from TTS.utils.audio.processor import AudioProcessor |
|
|
| |
|
|
# Module-level logger, named after this module via ``logging.getLogger(__name__)``.
logger = logging.getLogger(__name__)
|
|
|
|
class BaseVC(BaseTrainerModel):
    """Base VC class. Every new voice conversion model must inherit this.

    It defines common VC-specific functions on top of the :py:class:`~TTS.model.BaseTrainerModel`.
    """

    MODEL_TYPE = "vc"

    def __init__(
        self,
        config: Coqpit,
        ap: AudioProcessor | None = None,
        speaker_manager: SpeakerManager | None = None,
        language_manager: LanguageManager | None = None,
    ) -> None:
        """Store the config and optional managers, then derive ``self.args`` from ``config``.

        Args:
            config: Either a ``*Config`` (training config embedding ``model_args``) or a ``*Args`` object.
            ap: Optional audio processor, handed to the dataset built in ``get_data_loader``.
            speaker_manager: Optional speaker manager used for speaker ids / d-vectors.
            language_manager: Optional language manager used for language ids.
        """
        super().__init__()
        self.config = config
        self.ap = ap
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit) -> None:
        """Set up model args based on the config type (``ModelConfig`` or ``ModelArgs``).

        ``ModelArgs`` has all the fields required to initialize the model architecture.

        ``ModelConfig`` has all the fields required for training, inference and contains ``ModelArgs``.

        If the config is for training with a name like ``*Config``, then the model args are embedded in
        ``config.model_args``.

        If the config is for the model with a name like ``*Args``, then we assign them directly.

        Raises:
            ValueError: If the config class name contains neither ``Config`` nor ``Args``.
        """
        # "Config" is checked first, so a class name containing both substrings is treated as a Config.
        if "Config" in config.__class__.__name__:
            self.config = config
            self.args = config.model_args
        elif "Args" in config.__class__.__name__:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    def init_multispeaker(self, config: Coqpit, data: list[Any] | None = None) -> None:
        """Set up for multi-speaker use.

        Initialize a speaker embedding layer if needed and define the expected
        embedding channel size for defining ``in_channels`` size of the connected layers.

        This implementation yields 3 possible outcomes:

        1. If ``config.use_speaker_embedding`` and ``config.use_d_vector_file`` are False, do nothing.
        2. If ``config.use_d_vector_file`` is True, set expected embedding channel size to ``config.d_vector_dim`` or 512.
        3. If ``config.use_speaker_embedding`` is True and ``config.use_d_vector_file`` is False, also initialize
           a speaker embedding layer with channel size of ``config.d_vector_dim`` or 512.

        You can override this function for new models.

        Args:
            config (Coqpit): Model configuration.
            data: Unused by this implementation; kept for interface compatibility with overrides.
        """
        # Number of speakers: prefer the speaker manager, fall back to the config.
        # NOTE(review): if neither source is available, ``self.num_speakers`` stays unset and
        # the embedding layer below will fail; callers presumably guarantee one of them.
        if self.speaker_manager is not None:
            self.num_speakers = self.speaker_manager.num_speakers
        elif hasattr(config, "num_speakers"):
            self.num_speakers = config.num_speakers

        # Final speaker-embedding size, shared by learned embeddings and external d-vectors.
        if config.use_speaker_embedding or config.use_d_vector_file:
            self.embedded_speaker_dim = (
                config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
            )

        # A learned embedding table is only needed when d-vectors are not supplied externally.
        if config.use_speaker_embedding and not config.use_d_vector_file:
            logger.info("Init speaker_embedding layer.")
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def get_aux_input_from_test_sentences(self, sentence_info: str | list[str]) -> dict[str, Any]:
        """Build auxiliary inference inputs from one test-sentence entry.

        Args:
            sentence_info: Either a plain text string, or a list of 1-4 items in the order
                ``[text, speaker_name, style_wav, language_name]``. Lists of any other length
                leave every field as ``None``.

        Returns:
            Dict with keys ``text``, ``speaker_id``, ``style_wav``, ``d_vector`` and ``language_id``;
            entries that cannot be resolved are ``None``.
        """
        config = self.config.model_args if hasattr(self.config, "model_args") else self.config

        # Unpack the sentence entry.
        text, speaker_name, style_wav, language_name = None, None, None, None
        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # Resolve the speaker: a d-vector when d-vector files are used, otherwise an embedding id.
        # Without an explicit speaker name, a random one is drawn from the manager.
        speaker_id, d_vector, language_id = None, None, None
        if self.speaker_manager is not None:
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_id()
                else:
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # Resolve the language id only when explicitly named.
        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
        }

    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Generic batch formatting for batches produced by ``TTSDataset.collate_fn``.

        This is the dataset built in ``get_data_loader``. You must override this if you use a custom dataset.

        Args:
            batch: Collated batch dict, read via the keys accessed below.

        Returns:
            dict: The renamed batch entries plus derived ``stop_targets`` (reduced by ``config.r``),
            ``stop_target_lengths``, ``durations`` (``None`` when ``attns`` is ``None``) and the
            ``max_text_length`` / ``max_spec_length`` floats.
        """
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        energy = batch["energy"]
        language_ids = batch["language_ids"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # Derive per-token durations from precomputed alignments, if available.
        # NOTE(review): each ``am`` appears to be indexed as (?, text, mel) — confirm against the dataset.
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # For every mel frame, the text index with the highest attention weight.
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # Frames per token; tokens that receive no frame keep a minimum duration of 1.
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # The minimum-1 padding can overshoot the spectrogram length; take the surplus
                # back from the longest tokens so the durations sum to the mel length.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert dur.sum() == mel_lengths[idx], (
                    f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                )
                durations[idx, : text_lengths[idx]] = dur

        # Reduce frame-level stop targets to one flag per group of ``r`` frames:
        # a group is a stop step if any frame in it is.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).float()
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
            "energy": energy,
            "language_ids": language_ids,
            "audio_unique_names": batch["audio_unique_names"],
        }

    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
        """Build the sampler for ``dataset``, or ``None`` to let the ``DataLoader`` decide.

        Each enabled balancer (language, speaker, length) contributes its per-sample weights
        scaled by its ``*_weighted_sampler_alpha`` (default 1.0); the contributions are summed
        and fed to a ``WeightedRandomSampler``. With more than one GPU, the result is wrapped for
        distributed training, or replaced by a plain ``DistributedSampler`` if no balancer is on.

        Args:
            config: Training config holding the ``use_*_weighted_sampler`` flags and alphas.
            dataset: Dataset whose ``samples`` are weighted.
            num_gpus: Number of GPUs used for training.

        Returns:
            A sampler instance, or ``None``.
        """
        # (enable flag, alpha key, log message, weight function) — applied in this order.
        balancers = (
            (
                "use_language_weighted_sampler",
                "language_weighted_sampler_alpha",
                "Using Language weighted sampler with alpha: %.2f",
                get_language_balancer_weights,
            ),
            (
                "use_speaker_weighted_sampler",
                "speaker_weighted_sampler_alpha",
                "Using Speaker weighted sampler with alpha: %.2f",
                get_speaker_balancer_weights,
            ),
            (
                "use_length_weighted_sampler",
                "length_weighted_sampler_alpha",
                "Using Length weighted sampler with alpha: %.2f",
                get_length_balancer_weights,
            ),
        )

        data_items = dataset.samples
        weights = None
        for flag, alpha_key, message, weight_fn in balancers:
            if not getattr(config, flag, False):
                continue
            alpha = getattr(config, alpha_key, 1.0)
            logger.info(message, alpha)
            contribution = weight_fn(data_items) * alpha
            weights = contribution if weights is None else weights + contribution

        sampler = WeightedRandomSampler(weights, len(weights)) if weights is not None else None

        # Distributed training needs every rank to see a distinct shard of the data.
        if num_gpus > 1:
            sampler = DistributedSampler(dataset) if sampler is None else DistributedSamplerWrapper(sampler)
        return sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: dict,
        is_eval: bool,
        samples: list[dict] | list[list],
        verbose: bool,
        num_gpus: int,
        rank: int | None = None,
    ) -> "DataLoader | None":
        """Build the train or eval ``DataLoader``.

        Args:
            config: Training config.
            assets: Unused by this implementation.
            is_eval: Build the eval loader (no augmentation, eval batch size/workers).
            samples: Dataset samples.
            verbose: Unused by this implementation.
            num_gpus: Number of GPUs; above 1 a barrier is hit and a distributed sampler is used.
            rank: Unused by this implementation.

        Returns:
            The loader, or ``None`` when ``is_eval`` is set but ``config.run_eval`` is off.
        """
        if is_eval and not config.run_eval:
            return None

        # Speaker id / d-vector lookups. Each mapping is only built when its flag is on, so no
        # further gating on ``config.use_d_vector_file`` is needed (that field may be absent
        # from the top-level config when there is no speaker manager).
        speaker_id_mapping = None
        d_vector_mapping = None
        if self.speaker_manager is not None:
            if hasattr(config, "model_args"):
                speaker_id_mapping = (
                    self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
                )
                d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                # Mirror the flag onto the top-level config.
                # NOTE(review): presumably read elsewhere from ``config`` directly — confirm before removing.
                config.use_d_vector_file = config.model_args.use_d_vector_file
            else:
                speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None

        # Language id lookup.
        if self.language_manager is not None:
            language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
        else:
            language_id_mapping = None

        dataset = TTSDataset(
            outputs_per_step=config.r if "r" in config else 1,
            compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
            compute_f0=config.get("compute_f0", False),
            f0_cache_path=config.get("f0_cache_path", None),
            compute_energy=config.get("compute_energy", False),
            energy_cache_path=config.get("energy_cache_path", None),
            samples=samples,
            ap=self.ap,
            return_wav=config.return_wav if "return_wav" in config else False,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_text_len=config.min_text_len,
            max_text_len=config.max_text_len,
            min_audio_len=config.min_audio_len,
            max_audio_len=config.max_audio_len,
            phoneme_cache_path=config.phoneme_cache_path,
            precompute_num_workers=config.precompute_num_workers,
            use_noise_augment=False if is_eval else config.use_noise_augment,
            speaker_id_mapping=speaker_id_mapping,
            d_vector_mapping=d_vector_mapping,
            tokenizer=None,
            start_by_longest=config.start_by_longest,
            language_id_mapping=language_id_mapping,
        )

        # Synchronize all ranks before sample preprocessing.
        if num_gpus > 1:
            dist.barrier()

        dataset.preprocess_samples()

        sampler = self.get_sampler(config, dataset, num_gpus)

        # DataLoader forbids shuffle=True together with a custom sampler.
        return DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_eval else config.batch_size,
            shuffle=config.shuffle if sampler is None else False,
            collate_fn=dataset.collate_fn,
            drop_last=config.drop_last,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )

    def _get_test_aux_input(
        self,
    ) -> dict[str, Any]:
        """Pick a random speaker conditioning input for test runs.

        Returns:
            Dict with ``speaker_id`` (a one-element list of ids, or ``None``), ``d_vector``
            and ``style_wav`` (always ``None``).
        """
        d_vector = None
        if self.speaker_manager is not None and self.config.use_d_vector_file:
            d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
            # NOTE(review): the trailing comma wraps the one-element list in a tuple; kept as-is
            # because subclasses may depend on this shape — confirm whether it is intended.
            d_vector = (random.sample(sorted(d_vector), 1),)

        # Guard on the speaker manager as well, mirroring the d_vector branch above.
        speaker_id = None
        if self.config.use_speaker_embedding and self.speaker_manager is not None:
            speaker_id = random.sample(sorted(self.speaker_manager.name_to_id.values()), 1)

        return {
            "speaker_id": speaker_id,
            "d_vector": d_vector,
            "style_wav": None,
        }

    def test_run(self, assets: dict) -> tuple[dict, dict]:
        """Generic test run for ``vc`` models used by ``Trainer``.

        You can override this for a different behaviour.

        Args:
            assets (dict): A dict of training assets. For ``vc`` models, it must include ``{'audio_processor': ap}``.

        Returns:
            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.

        Raises:
            NotImplementedError: Always; subclasses must implement their own test run.
        """
        raise NotImplementedError

    def on_init_start(self, trainer: Trainer) -> None:
        """Save ``speakers.pth`` and ``language_ids.json`` at the beginning of training and update both paths.

        The new paths are written to ``trainer.config`` (and ``trainer.config.model_args`` when
        present), and the config is re-saved as ``config.json`` in ``trainer.output_path``.
        """
        if self.speaker_manager is not None:
            output_path = os.path.join(trainer.output_path, "speakers.pth")
            self.speaker_manager.save_ids_to_file(output_path)
            trainer.config.speakers_file = output_path
            # some models don't have `model_args` set
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.speakers_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            logger.info("`speakers.pth` is saved to %s", output_path)
            logger.info("`speakers_file` is updated in the config.json.")

        if self.language_manager is not None:
            output_path = os.path.join(trainer.output_path, "language_ids.json")
            self.language_manager.save_ids_to_file(output_path)
            trainer.config.language_ids_file = output_path
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.language_ids_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            logger.info("`language_ids.json` is saved to %s", output_path)
            logger.info("`language_ids_file` is updated in the config.json.")
|
|