| |
|
|
| """TTS mode decoding.""" |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
| import shutil |
| import sys |
| import time |
| from typing import Optional |
| from typing import Sequence |
| from typing import Tuple |
| from typing import Union |
| from collections import defaultdict |
| import json |
|
|
| import matplotlib |
| import numpy as np |
| import soundfile as sf |
| import torch |
| from typeguard import check_argument_types |
|
|
| from espnet.utils.cli_utils import get_commandline_args |
| from espnet2.fileio.npy_scp import NpyScpWriter |
| from espnet2.tasks.tts import TTSTask |
| from espnet2.torch_utils.device_funcs import to_device |
| from espnet2.torch_utils.set_all_random_seed import set_all_random_seed |
| from espnet2.tts.duration_calculator import DurationCalculator |
| from espnet2.tts.fastspeech import FastSpeech |
| from espnet2.tts.fastspeech2 import FastSpeech2 |
| from espnet2.tts.fastespeech import FastESpeech |
| from espnet2.tts.tacotron2 import Tacotron2 |
| from espnet2.tts.transformer import Transformer |
| from espnet2.utils import config_argparse |
| from espnet2.utils.get_default_kwargs import get_default_kwargs |
| from espnet2.utils.griffin_lim import Spectrogram2Waveform |
| from espnet2.utils.nested_dict_action import NestedDictAction |
| from espnet2.utils.types import str2bool |
| from espnet2.utils.types import str2triple_str |
| from espnet2.utils.types import str_or_none |
|
|
|
|
class Text2Speech:
    """Text2Speech class.

    Loads a trained TTS model through ``TTSTask`` and exposes it as a callable
    that maps one text (plus optional auxiliary inputs) to generated features
    and, when a vocoder can be configured, a waveform.

    Examples:
        >>> import soundfile
        >>> text2speech = Text2Speech("config.yml", "model.pth")
        >>> wav = text2speech("Hello World")[0]
        >>> soundfile.write("out.wav", wav.numpy(), text2speech.fs, "PCM_16")

    """

    def __init__(
        self,
        train_config: Optional[Union[Path, str]],
        model_file: Optional[Union[Path, str]] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        vocoder_conf: dict = None,
        dtype: str = "float32",
        device: str = "cpu",
    ):
        """Build the model, the decoding options and the optional vocoder.

        Args:
            train_config: Training configuration passed to
                ``TTSTask.build_model_from_file``.
            model_file: Model parameter file passed to
                ``TTSTask.build_model_from_file``.
            threshold: Decoding option forwarded to Tacotron2 / Transformer.
            minlenratio: Decoding option forwarded to Tacotron2 / Transformer.
            maxlenratio: Decoding option forwarded to Tacotron2 / Transformer.
            use_teacher_forcing: Forwarded to every model type; also makes
                ``speech`` a required input (see ``use_speech``).
            use_att_constraint: Decoding option forwarded to Tacotron2 only.
            backward_window: Decoding option forwarded to Tacotron2 only.
            forward_window: Decoding option forwarded to Tacotron2 only.
            speed_control_alpha: Forwarded as ``alpha`` to FastSpeech,
                FastSpeech2 and FastESpeech.
            vocoder_conf: Keyword arguments for ``Spectrogram2Waveform``.
                The given dict is not modified.
            dtype: Name of the torch dtype the model is cast to.
            device: Device the model is built on and inputs are moved to.

        """
        assert check_argument_types()

        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device
        )
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing

        logging.info(f"Normalization:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")

        # Collect only the decoding options that the loaded model type accepts.
        decode_config = {}
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_config.update(
                {
                    "threshold": threshold,
                    "maxlenratio": maxlenratio,
                    "minlenratio": minlenratio,
                }
            )
        if isinstance(self.tts, Tacotron2):
            decode_config.update(
                {
                    "use_att_constraint": use_att_constraint,
                    "forward_window": forward_window,
                    "backward_window": backward_window,
                }
            )
        if isinstance(self.tts, (FastSpeech, FastSpeech2, FastESpeech)):
            decode_config.update({"alpha": speed_control_alpha})
        decode_config.update({"use_teacher_forcing": use_teacher_forcing})

        self.decode_config = decode_config

        # Merge into a copy so the caller's dict is never mutated.
        vocoder_conf = self._merge_vocoder_conf(vocoder_conf, self.feats_extract)
        if (
            "n_fft" in vocoder_conf
            and "n_shift" in vocoder_conf
            and "fs" in vocoder_conf
        ):
            self.spc2wav = Spectrogram2Waveform(**vocoder_conf)
            logging.info(f"Vocoder: {self.spc2wav}")
        else:
            self.spc2wav = None
            logging.info("Vocoder is not used because vocoder_conf is not sufficient")

    @staticmethod
    def _merge_vocoder_conf(vocoder_conf: Optional[dict], feats_extract) -> dict:
        """Return a new dict of vocoder options overlaid with feature parameters.

        Args:
            vocoder_conf: User supplied options, or None for no options.
            feats_extract: Object providing ``get_parameters()``, or None.

        Returns:
            A fresh dict; values from ``feats_extract.get_parameters()`` take
            precedence over those in ``vocoder_conf``.

        """
        merged = {} if vocoder_conf is None else dict(vocoder_conf)
        if feats_extract is not None:
            merged.update(feats_extract.get_parameters())
        return merged

    @torch.no_grad()
    def __call__(
        self,
        text: Union[str, torch.Tensor, np.ndarray],
        speech: Union[torch.Tensor, np.ndarray] = None,
        durations: Union[torch.Tensor, np.ndarray] = None,
        ref_embs: torch.Tensor = None,
        spembs: Union[torch.Tensor, np.ndarray] = None,
        fg_inds: torch.Tensor = None,
    ):
        """Run inference for a single input.

        Args:
            text: Input text; a ``str`` is converted with ``preprocess_fn``.
            speech: Optional speech input; required when ``use_speech`` is True.
            durations: Optional durations forwarded to the model.
            ref_embs: Optional reference embedding forwarded to the model.
            spembs: Optional speaker embedding forwarded to the model.
            fg_inds: Optional tensor forwarded to the model as ``fg_inds``.

        Returns:
            Tuple ``(wav, outs, outs_denorm, probs, att_ws, duration,
            focus_rate, ref_embs)``. ``wav`` is None when no vocoder is
            configured; ``duration`` and ``focus_rate`` are None when the
            model returns no attention weights; ``ref_embs`` is the value
            returned by ``model.inference``.

        Raises:
            RuntimeError: If speech is required but not given.

        """
        assert check_argument_types()

        if self.use_speech and speech is None:
            raise RuntimeError("missing required argument: 'speech'")

        if isinstance(text, str):
            # Convert raw text with the same preprocessing used to build inputs.
            text = self.preprocess_fn("<dummy>", {"text": text})["text"]
        batch = {
            "text": text,
            "ref_embs": ref_embs,
            "ar_prior_inference": True,
            "fg_inds": fg_inds,
        }
        if speech is not None:
            batch["speech"] = speech
        if durations is not None:
            batch["durations"] = durations
        if spembs is not None:
            batch["spembs"] = spembs

        batch = to_device(batch, self.device)
        # The last returned element (a prior loss) is not used at inference.
        outs, outs_denorm, probs, att_ws, ref_embs, _ar_prior_loss = (
            self.model.inference(**batch, **self.decode_config)
        )

        if att_ws is not None:
            duration, focus_rate = self.duration_calculator(att_ws)
        else:
            duration, focus_rate = None, None

        if self.spc2wav is not None:
            wav = torch.tensor(self.spc2wav(outs_denorm.cpu().numpy()))
        else:
            wav = None

        return wav, outs, outs_denorm, probs, att_ws, duration, focus_rate, ref_embs

    @property
    def fs(self) -> Optional[int]:
        """Return the vocoder's ``fs`` attribute, or None without a vocoder."""
        if self.spc2wav is not None:
            return self.spc2wav.fs
        else:
            return None

    @property
    def use_speech(self) -> bool:
        """Check whether to require speech in inference.

        Returns:
            bool: True if speech is required else False.

        """
        # Speech is needed for teacher forcing or when the model has
        # a truthy ``use_gst`` attribute.
        return self.use_teacher_forcing or getattr(self.tts, "use_gst", False)
|
|
|
|
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    ref_embs: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    Iterates over the input data one utterance at a time, runs
    ``Text2Speech`` on each item and writes results under ``output_dir``:
    normalized / denormalized features, feature shapes, durations, focus
    rates, attention and stop-probability plots, waveforms, and the reference
    embeddings returned by the model.

    Args:
        output_dir: Directory that receives all outputs.
        batch_size: Must be 1; batch decoding is not implemented.
        dtype: Data type name used for the model and the data loader.
        ngpu: 0 selects CPU, 1 selects CUDA; larger values are rejected.
        seed: Random seed passed to ``set_all_random_seed``.
        num_workers: Number of data loader workers.
        log_level: Logging level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Input data triples for the iterator.
        key_file: Optional key file for the iterator.
        train_config: Training configuration file.
        model_file: Model parameter file.
        ref_embs: Path of a reference embedding readable by ``torch.load``;
            None or an empty string disables it.
        threshold: Forwarded to ``Text2Speech``.
        minlenratio: Forwarded to ``Text2Speech``.
        maxlenratio: Forwarded to ``Text2Speech``; also used for the
            maximum-length warning.
        use_teacher_forcing: Forwarded to ``Text2Speech``.
        use_att_constraint: Forwarded to ``Text2Speech``.
        backward_window: Forwarded to ``Text2Speech``.
        forward_window: Forwarded to ``Text2Speech``.
        speed_control_alpha: Forwarded to ``Text2Speech``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        vocoder_conf: Forwarded to ``Text2Speech``.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
        RuntimeError: If attention weights are neither 2- nor 4-dimensional.

    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # A truthiness check accepts both None and "" as "no reference embedding".
    # NOTE: torch.load unpickles the file; only pass trusted paths.
    if ref_embs:
        ref_emb_in = torch.load(ref_embs).squeeze(0)
    else:
        ref_emb_in = None

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    text2speech = Text2Speech(
        train_config=train_config,
        model_file=model_file,
        threshold=threshold,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        use_teacher_forcing=use_teacher_forcing,
        use_att_constraint=use_att_constraint,
        backward_window=backward_window,
        forward_window=forward_window,
        speed_control_alpha=speed_control_alpha,
        vocoder_conf=vocoder_conf,
        dtype=dtype,
        device=device,
    )

    # 3. Build data-iterator; drop the "speech" input when it is not needed.
    if not text2speech.use_speech:
        data_path_and_name_and_type = list(
            filter(lambda x: x[1] != "speech", data_path_and_name_and_type)
        )
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False),
        collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Prepare output directories
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)
    (output_dir / "durations").mkdir(parents=True, exist_ok=True)
    (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True)

    # Select the non-interactive backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # Defined up front so the cleanup below also works for an empty loader.
    att_ws = None
    probs = None

    with NpyScpWriter(
        output_dir / "norm",
        output_dir / "norm/feats.scp",
    ) as norm_writer, NpyScpWriter(
        output_dir / "denorm", output_dir / "denorm/feats.scp"
    ) as denorm_writer, open(
        output_dir / "speech_shape/speech_shape", "w"
    ) as shape_writer, open(
        output_dir / "durations/durations", "w"
    ) as duration_writer, open(
        output_dir / "focus_rates/focus_rates", "w"
    ) as focus_rate_writer, open(
        output_dir / "ref_embs", "w"
    ) as ref_embs_writer:
        # Distinct reference embeddings seen so far, and for each list index
        # the utterance keys that produced that embedding.
        ref_embs_list = []
        ref_embs_dict = defaultdict(list)
        for idx, (keys, batch) in enumerate(loader, 1):
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert _bs == 1, _bs

            # Change to single sequence and remove *_lengths entries
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            start_time = time.perf_counter()
            (
                wav,
                outs,
                outs_denorm,
                probs,
                att_ws,
                duration,
                focus_rate,
                out_ref_embs,
            ) = text2speech(ref_embs=ref_emb_in, **batch)

            key = keys[0]
            insize = next(iter(batch.values())).size(0) + 1
            logging.info(
                "inference speed = {:.1f} frames / sec.".format(
                    int(outs.size(0)) / (time.perf_counter() - start_time)
                )
            )
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(f"output length reaches maximum length ({key}).")

            norm_writer[key] = outs.cpu().numpy()
            shape_writer.write(f"{key} " + ",".join(map(str, outs.shape)) + "\n")

            denorm_writer[key] = outs_denorm.cpu().numpy()

            if duration is not None:
                # Save duration and focus rates
                duration_writer.write(
                    f"{key} " + " ".join(map(str, duration.cpu().numpy())) + "\n"
                )
                focus_rate_writer.write(f"{key} {float(focus_rate):.5f}\n")

                # Plot attention weights as a grid of images
                att_ws = att_ws.cpu().numpy()

                if att_ws.ndim == 2:
                    att_ws = att_ws[None][None]
                elif att_ws.ndim != 4:
                    raise RuntimeError(f"Must be 2 or 4 dimension: {att_ws.ndim}")

                n_rows, n_cols = att_ws.shape[0], att_ws.shape[1]
                w, h = plt.figaspect(n_rows / n_cols)
                fig = plt.Figure(
                    figsize=(
                        w * 1.3 * min(n_rows, 2.5),
                        h * 1.3 * min(n_cols, 2.5),
                    )
                )
                fig.suptitle(f"{key}")
                # squeeze=False always yields a 2-D array of axes, so 1xN and
                # Nx1 grids iterate the same way as the general case.
                axes = fig.subplots(n_rows, n_cols, squeeze=False)
                for ax, att_w in zip(axes, att_ws):
                    for ax_, att_w_ in zip(ax, att_w):
                        ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                        ax_.set_xlabel("Input")
                        ax_.set_ylabel("Output")
                        ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax_.yaxis.set_major_locator(MaxNLocator(integer=True))

                fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]})
                fig.savefig(output_dir / f"att_ws/{key}.png")
                fig.clf()

            if probs is not None:
                # Plot stop token prediction
                probs = probs.cpu().numpy()

                fig = plt.Figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.plot(probs)
                ax.set_title(f"{key}")
                ax.set_xlabel("Output")
                ax.set_ylabel("Stop probability")
                ax.set_ylim(0, 1)
                ax.grid(which="both")

                fig.set_tight_layout(True)
                fig.savefig(output_dir / f"probs/{key}.png")
                fig.clf()

            # Save the waveform when a vocoder produced one
            if wav is not None:
                sf.write(
                    f"{output_dir}/wav/{key}.wav", wav.numpy(), text2speech.fs, "PCM_16"
                )

            if out_ref_embs is not None:
                # Reuse the index of an identical embedding, else register it.
                ref_emb_key = -1
                for index, ref_emb in enumerate(ref_embs_list):
                    if torch.equal(ref_emb, out_ref_embs):
                        ref_emb_key = index
                        break
                if ref_emb_key == -1:
                    ref_emb_key = len(ref_embs_list)
                    ref_embs_list.append(out_ref_embs)
                ref_embs_dict[ref_emb_key].append(key)

        ref_embs_writer.write(json.dumps(ref_embs_dict))
        for index, ref_emb in enumerate(ref_embs_list):
            filename = "ref_embs_" + str(index) + ".pt"
            torch.save(ref_emb, output_dir / filename)

    # Remove directories that received no output, judged from the last
    # utterance (or from the initial None values if there was no utterance).
    if att_ws is None:
        shutil.rmtree(output_dir / "att_ws")
        shutil.rmtree(output_dir / "durations")
        shutil.rmtree(output_dir / "focus_rates")
    if probs is None:
        shutil.rmtree(output_dir / "probs")
|
|
|
|
def get_parser():
    """Get argument parser.

    Returns:
        A ``config_argparse.ArgumentParser`` holding the general options, the
        input-data options, the model-file options, the decoding options and
        the vocoder options of the TTS decoding script.

    """
    parser = config_argparse.ArgumentParser(
        description="TTS Decode",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use "_" instead of "-" as separator.
    # "-" is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The path of output directory",
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument(
        "--key_file",
        type=str_or_none,
    )
    group.add_argument(
        "--allow_variable_data_keys",
        type=str2bool,
        default=False,
    )
    # The default must be a str: inference() type-checks this value as
    # Optional[str], so the former bool default (False) was rejected.
    group.add_argument(
        "--ref_embs",
        type=str,
        default="",
        help="Path of a reference embedding file loaded with torch.load. "
        "An empty string disables it",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--train_config",
        type=str,
        help="Training configuration file.",
    )
    group.add_argument(
        "--model_file",
        type=str,
        help="Model parameter file.",
    )

    group = parser.add_argument_group("Decoding related")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=10.0,
        help="Maximum length ratio in decoding",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Minimum length ratio in decoding",
    )
    group.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold value in decoding",
    )
    group.add_argument(
        "--use_att_constraint",
        type=str2bool,
        default=False,
        help="Whether to use attention constraint",
    )
    group.add_argument(
        "--backward_window",
        type=int,
        default=1,
        help="Backward window value in attention constraint",
    )
    group.add_argument(
        "--forward_window",
        type=int,
        default=3,
        help="Forward window value in attention constraint",
    )
    group.add_argument(
        "--use_teacher_forcing",
        type=str2bool,
        default=False,
        help="Whether to use teacher forcing",
    )
    parser.add_argument(
        "--speed_control_alpha",
        type=float,
        default=1.0,
        help="Alpha in FastSpeech to change the speed of generated speech",
    )

    group = parser.add_argument_group("Grriffin-Lim related")
    group.add_argument(
        "--vocoder_conf",
        action=NestedDictAction,
        default=get_default_kwargs(Spectrogram2Waveform),
        help="The configuration for Grriffin-Lim",
    )
    return parser
|
|
|
|
def main(cmd=None):
    """Run TTS model decoding.

    Echoes the command line to stderr, parses ``cmd`` with the parser from
    ``get_parser()`` and passes the parsed options to ``inference`` as keyword
    arguments.

    Args:
        cmd: Argument list handed to ``parse_args``; None lets argparse read
            the process arguments.

    """
    print(get_commandline_args(), file=sys.stderr)
    options = vars(get_parser().parse_args(cmd))
    # "config" is not an inference() parameter, so drop it when present.
    options.pop("config", None)
    inference(**options)
|
|
|
|
# Run the decoding entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|