Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import random | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
# Module-level logger shared by the hub interfaces defined in this file.
logger: logging.Logger = logging.getLogger(name=__name__)
class TTSHubInterface(nn.Module):
    """Hub-facing wrapper that synthesizes a waveform from raw text.

    Holds a config, a task and a model, builds the task's generator once at
    construction time, and exposes :meth:`predict` for single-utterance
    synthesis. The preprocessing helpers are classmethods so they can be used
    without instantiating the interface.
    """

    def __init__(self, cfg, task, model):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model
        self.model.eval()

        self.update_cfg_with_data_cfg(self.cfg, self.task.data_cfg)
        # Build the generator once so repeated predict() calls reuse it.
        self.generator = self.task.build_generator([self.model], self.cfg)

    @classmethod
    def phonemize(
        cls,
        text: str,
        lang: Optional[str],
        phonemizer: Optional[str] = None,
        preserve_punct: bool = False,
        to_simplified_zh: bool = False,
    ):
        """Convert ``text`` to a space-separated phoneme string.

        Args:
            text: raw input text.
            lang: language code; required only for the ``"ipa"`` phonemizer
                ("en" and "fr" are mapped to espeak's "en-us"/"fr-fr").
            phonemizer: one of ``"g2p"``, ``"g2pc"``, ``"ipa"``; any other
                value (including None) returns ``text`` unchanged.
            preserve_punct: for ``"g2p"`` only, keep punctuation and mark
                word boundaries with ``|`` instead of dropping non-alnum
                symbols.
            to_simplified_zh: convert traditional Chinese to simplified first.

        Raises:
            AssertionError: if ``phonemizer == "ipa"`` and ``lang`` is None.
        """
        if to_simplified_zh:
            import hanziconv

            text = hanziconv.HanziConv.toSimplified(text)

        if phonemizer == "g2p":
            import g2p_en

            g2p = g2p_en.G2p()
            if preserve_punct:
                return " ".join("|" if p == " " else p for p in g2p(text))
            # Map pause-like punctuation to the "sp" (silence) token, then
            # drop everything that is not alphanumeric.
            res = [{",": "sp", ";": "sp"}.get(p, p) for p in g2p(text)]
            return " ".join(p for p in res if p.isalnum())
        elif phonemizer == "g2pc":
            import g2pc

            g2p = g2pc.G2pC()
            # Field 3 of each g2pc result tuple is used as the phoneme string.
            return " ".join([w[3] for w in g2p(text)])
        elif phonemizer == "ipa":
            assert lang is not None, "the 'ipa' phonemizer requires `lang`"
            # Alias the package so it does not shadow the `phonemizer` argument.
            import phonemizer as phonemizer_lib
            from phonemizer.separator import Separator

            lang_map = {"en": "en-us", "fr": "fr-fr"}
            return phonemizer_lib.phonemize(
                text,
                backend="espeak",
                language=lang_map.get(lang, lang),
                separator=Separator(word="| ", phone=" "),
            )
        return text

    @classmethod
    def tokenize(cls, text: str, tkn_cfg: Dict[str, str]):
        """Apply SentencePiece tokenization if ``tkn_cfg`` names a model.

        Returns ``text`` unchanged when ``tkn_cfg`` has no
        ``"sentencepiece_model"`` entry; otherwise the space-joined pieces.
        """
        sentencepiece_model = tkn_cfg.get("sentencepiece_model", None)
        if sentencepiece_model is None:
            return text

        assert Path(
            sentencepiece_model
        ).exists(), f"sentencepiece model not found: {sentencepiece_model}"
        import sentencepiece as sp

        spm = sp.SentencePieceProcessor()
        spm.Load(sentencepiece_model)
        return " ".join(spm.Encode(text, out_type=str))

    @classmethod
    def update_cfg_with_data_cfg(cls, cfg, data_cfg):
        """Copy the vocoder type from ``data_cfg`` into ``cfg["task"]``.

        Falls back to ``"griffin_lim"`` when the data config names no type.
        """
        cfg["task"].vocoder = data_cfg.vocoder.get("type", "griffin_lim")

    @classmethod
    def get_model_input(
        cls, task, text: str, speaker: Optional[int] = None, verbose: bool = False
    ):
        """Build a batch-of-one model input for ``text``.

        The text is phonemized and tokenized according to
        ``task.data_cfg.hub`` / ``task.data_cfg.bpe_tokenizer`` and encoded
        with ``task.src_dict``.

        Args:
            task: task providing ``data_cfg``, ``speaker_to_id`` and
                ``src_dict``.
            text: raw input text.
            speaker: optional speaker index. NOTE: a ``"speaker"`` entry in
                the hub config takes precedence over this argument.
            verbose: log the intermediate text representations.

        Returns:
            dict with ``net_input`` (``src_tokens`` of shape (1, T),
            ``src_lengths`` of shape (1,), ``prev_output_tokens=None``),
            ``target_lengths=None`` and ``speaker`` (a (1, 1) long tensor or
            None).
        """
        hub_cfg = task.data_cfg.hub
        phonemized = cls.phonemize(
            text,
            hub_cfg.get("lang", None),
            hub_cfg.get("phonemizer", None),
            hub_cfg.get("preserve_punct", False),
            hub_cfg.get("to_simplified_zh", False),
        )
        tkn_cfg = task.data_cfg.bpe_tokenizer
        tokenized = cls.tokenize(phonemized, tkn_cfg)
        if verbose:
            logger.info(f"text: {text}")
            logger.info(f"phonemized: {phonemized}")
            logger.info(f"tokenized: {tokenized}")

        spk = hub_cfg.get("speaker", speaker)
        n_speakers = len(task.speaker_to_id or {})
        if spk is None and n_speakers > 0:
            # No speaker requested for a multi-speaker task: pick one at random.
            spk = random.randint(0, n_speakers - 1)
        if spk is not None:
            # Clamp into [0, n_speakers - 1]; collapses to 0 when the task has
            # no speaker table.
            spk = max(0, min(spk, n_speakers - 1))
        if verbose:
            logger.info(f"speaker: {spk}")
        spk = None if spk is None else torch.tensor([[spk]], dtype=torch.long)

        src_tokens = task.src_dict.encode_line(tokenized, add_if_not_exist=False).view(
            1, -1
        )
        # Use the encoded length so src_lengths always agrees with src_tokens,
        # including any special tokens (e.g. EOS) added by encode_line.
        src_lengths = torch.tensor([src_tokens.size(1)], dtype=torch.long)
        return {
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": None,
            },
            "target_lengths": None,
            "speaker": spk,
        }

    @classmethod
    def get_prediction(cls, task, model, generator, sample) -> Tuple[torch.Tensor, int]:
        """Run ``generator`` on ``sample`` and return the first waveform.

        Returns:
            ``(waveform, task.sr)``; ``task.sr`` is presumably the output
            sampling rate.
        """
        prediction = generator.generate(model, sample)
        return prediction[0]["waveform"], task.sr

    def predict(
        self, text: str, speaker: Optional[int] = None, verbose: bool = False
    ) -> Tuple[torch.Tensor, int]:
        """Synthesize ``text`` with this interface's task, model and generator."""
        sample = self.get_model_input(self.task, text, speaker, verbose=verbose)
        return self.get_prediction(self.task, self.model, self.generator, sample)
class VocoderHubInterface(nn.Module):
    """Vocoder interface to run vocoder models through hub. Currently we only support unit vocoder"""

    def __init__(self, cfg, model):
        super().__init__()
        self.vocoder = model
        self.vocoder.eval()
        # Output sample rate is fixed at 16 kHz for this interface.
        self.sr = 16000
        self.multispkr = self.vocoder.model.multispkr
        if self.multispkr:
            logger.info("multi-speaker vocoder")
            # following the default in codehifigan to set to 200
            self.num_speakers = cfg.get("num_speakers", 200)

    def get_model_input(
        self,
        text: str,
        speaker: Optional[int] = -1,
    ):
        """Build the vocoder input dict from a string of unit ids.

        Args:
            text: whitespace-separated integer unit ids.
            speaker: speaker index for multi-speaker vocoders. ``None`` or
                ``-1`` selects a random speaker; ``0`` is a valid speaker id.
                Ignored for single-speaker vocoders.

        Returns:
            dict with ``"code"`` (a (1, T) long tensor) and, for multi-speaker
            vocoders, ``"spkr"`` (a (1, 1) long tensor).

        Raises:
            AssertionError: if ``speaker`` is outside ``[-1, num_speakers)``
                on a multi-speaker vocoder.
        """
        units = [int(u) for u in text.strip().split()]
        x = {
            "code": torch.LongTensor(units).view(1, -1),
        }
        # Only None means "unspecified"; a plain truthiness test would also
        # swallow the valid speaker id 0 and replace it with a random speaker.
        if speaker is None:
            speaker = -1
        if self.multispkr:
            assert (
                -1 <= speaker < self.num_speakers
            ), f"invalid --speaker-id ({speaker}) with total #speakers = {self.num_speakers}"
            spk = random.randint(0, self.num_speakers - 1) if speaker == -1 else speaker
            x["spkr"] = torch.LongTensor([spk]).view(1, 1)
        return x

    def get_prediction(self, sample, dur_prediction: Optional[bool] = True):
        """Run the vocoder on ``sample``; return ``(vocoder output, self.sr)``."""
        wav = self.vocoder(sample, dur_prediction)
        return wav, self.sr

    def predict(
        self,
        text: str,
        speaker: Optional[int] = None,
        dur_prediction: Optional[bool] = True,
    ):
        """Convert a unit-id string into audio via :meth:`get_model_input`."""
        sample = self.get_model_input(text, speaker)
        return self.get_prediction(sample, dur_prediction)