import itertools
import os
import warnings
from typing import cast

import librosa
import matplotlib.pyplot as plt
from matplotlib import font_manager as fm
import pyloudnorm
import sounddevice
import soundfile
import torch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # these imports tend to emit noisy deprecation warnings on load
    from audioseal.builder import create_generator
    from omegaconf import DictConfig
    from omegaconf import OmegaConf
    from speechbrain.pretrained import EncoderClassifier
    from torchaudio.transforms import Resample

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm
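
# ToucanTTSInterface wires the full inference pipeline into one torch.nn.Module:
# text -> articulatory phoneme features -> spectrogram (ToucanTTS) -> waveform
# (HiFiGAN), followed by loudness normalization and AudioSeal watermarking.
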
class ToucanTTSInterface(torch.nn.Module):

    def __init__(
        self,
        device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
        tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Shan", "best.pt"),  # path to the ToucanTTS checkpoint, or just a shorthand when run standalone
        vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint
        language="eng",  # initial language of the model, can be changed later with the setter methods
        enhance=None,  # legacy argument
    ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # default to the shorthand system
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")
        if "USER" not in os.environ:
            os.environ["USER"] = ""  # that's the case under Windows, but omegaconf needs this

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            watermark_conf = cast(
                DictConfig,
                OmegaConf.load("InferenceInterfaces/audioseal_wm_16bits.yaml"),
            )
            self.watermark = create_generator(watermark_conf)
            self.watermark.load_state_dict(
                torch.load("Models/audioseal/generator.pth", map_location="cpu")["model"]
            )  # originally downloaded from https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth

        ################################
        #   build text to phone        #
        ################################
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

        #####################################
        #   load phone to features model    #
        #####################################
        checkpoint = torch.load(tts_model_path, map_location="cpu")
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))

        ######################################
        #   load features to style models    #
        ######################################
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(device)},
            savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"),
        )

        ################################
        #   load mel to wave model     #
        ################################
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)  # the vocoder outputs 24 kHz audio

        ################################
        #   set defaults               #
        ################################
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)  # input_sr was 100 before, which looked like a typo; 16 kHz matches the embedding frontend
        self.phone2mel.eval()
        self.vocoder.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()
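
    # A construction sketch. The defaults above expect checkpoints under
    # MODELS_DIR; any shorthand other than "Shan" is an assumption about which
    # models exist locally.
    #
    #     tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu",
    #                              tts_model_path="Shan")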

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if not isinstance(path_to_reference_audio, list):
            path_to_reference_audio = [path_to_reference_audio]
        if len(path_to_reference_audio) > 0:
            for path in path_to_reference_audio:
                assert os.path.exists(path)
            speaker_embs = list()
            for path in path_to_reference_audio:
                wave, sr = soundfile.read(path)
                if len(wave.shape) > 1:  # oh no, we found a stereo audio!
                    if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
                        wave = wave.transpose()  # if yes, we switch the axes.
                    wave = librosa.to_mono(wave)
                wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(
                    torch.tensor(wave, device=self.device, dtype=torch.float32)
                )
                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(
                    wavs=wave.to(self.device).squeeze().unsqueeze(0)
                ).squeeze()
                speaker_embs.append(speaker_embedding)
            self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
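
    # Voice-cloning sketch: the reference paths are hypothetical placeholders; any
    # files soundfile can read should work, and multiple references are averaged
    # into a single ECAPA speaker embedding.
    #
    #     tts.set_utterance_embedding(["speaker_ref_1.wav", "speaker_ref_2.wav"])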

    def set_language(self, lang_id):
        """
        The lang_id parameter actually refers to the language's shorthand, not a numeric ID.
        The name has become ambiguous since the introduction of actual language IDs.
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        # map regional variants (and a few codes without accent embeddings) back to a base language
        if lang_id in ["ajp", "ajt", "lak", "lno", "nul", "pii", "plj", "slq", "smd", "snb", "tpw", "wya", "zua",
                       "en-us", "en-sc", "fr-be", "fr-sw", "pt-br", "spa-lat", "vi-ctr", "vi-so"]:
            if lang_id == "vi-so" or lang_id == "vi-ctr":
                lang_id = "vie"
            elif lang_id == "spa-lat":
                lang_id = "spa"
            elif lang_id == "pt-br":
                lang_id = "por"
            elif lang_id == "fr-sw" or lang_id == "fr-be":
                lang_id = "fra"
            elif lang_id == "en-sc" or lang_id == "en-us":
                lang_id = "eng"
            else:
                # no clue where these others are even coming from, they are not in ISO 639-2
                lang_id = "eng"
        self.lang_id = get_language_id(lang_id).to(self.device)
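
    # Language-switching sketch: set_language() updates the phonemizer and the
    # accent embedding together, so a regional variant like "en-us" is phonemized
    # as US English but mapped back to "eng" for the accent, as handled above.
    #
    #     tts.set_language("en-us")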

    def forward(
        self,
        text,
        view=False,
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
        pause_duration_scaling_factor=1.0,
        durations=None,
        pitch=None,
        energy=None,
        input_is_phones=False,
        return_plot_as_filepath=False,
        loudness_in_db=-24.0,
        glow_sampling_temperature=0.2,
    ):
        """
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                 1.0 means no scaling happens, higher values increase durations for the whole
                                 utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                              1.0 means no scaling happens, higher values increase variance of the pitch curve,
                              lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                               1.0 means no scaling happens, higher values increase variance of the energy curve,
                               lower values decrease variance of the energy curve.
        """
        with torch.inference_mode():
            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(
                phones,
                return_duration_pitch_energy=True,
                utterance_embedding=self.default_utterance_embedding,
                durations=durations,
                pitch=pitch,
                energy=energy,
                lang_id=self.lang_id,
                duration_scaling_factor=duration_scaling_factor,
                pitch_variance_scale=pitch_variance_scale,
                energy_variance_scale=energy_variance_scale,
                pause_duration_scaling_factor=pause_duration_scaling_factor,
                glow_sampling_temperature=glow_sampling_temperature,
            )
            wave, _, _ = self.vocoder(mel.unsqueeze(0))
            wave = wave.squeeze().cpu().numpy()

        sr = 24000
        try:
            loudness = self.meter.integrated_loudness(wave)
            wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
        except ValueError:
            # if the audio is too short, a value error will arise
            pass

        with torch.inference_mode():
            # mix a faint AudioSeal watermark into the synthesized audio
            wave_tensor = torch.tensor(wave)
            watermark = (
                self.watermark.get_watermark(wave_tensor.to(self.device).unsqueeze(0).unsqueeze(0))
                .squeeze()
                .detach()
                .cpu()
            )
            wave = (wave_tensor + 0.1 * watermark).detach().numpy()

        if view or return_plot_as_filepath:
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
            # fpath = "./src/fonts/Shan.ttf"
            fpath = os.path.join(os.path.dirname(__file__), "src/fonts/Shan.ttf")
            prop = fm.FontProperties(fname=fpath)
            ax.imshow(mel.cpu().numpy(), origin="lower", cmap="GnBu")
            ax.yaxis.set_visible(False)
            duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
            ax.xaxis.grid(True, which="minor")
            ax.set_xticks(label_positions, minor=False)
            if input_is_phones:
                phones = text.replace(" ", "|")
            else:
                phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
            ax.set_xticklabels(phones)
            word_boundaries = list()
            for label_index, phone in enumerate(phones):
                if phone == "|":
                    word_boundaries.append(label_positions[label_index])
            try:
                # place each word label halfway between the surrounding boundaries
                prev_word_boundary = 0
                word_label_positions = list()
                for word_boundary in word_boundaries:
                    word_label_positions.append((word_boundary + prev_word_boundary) / 2)
                    prev_word_boundary = word_boundary
                word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
                secondary_ax = ax.secondary_xaxis("bottom")
                secondary_ax.tick_params(axis="x", direction="out", pad=24)
                secondary_ax.set_xticks(word_label_positions, minor=False)
                secondary_ax.set_xticklabels(text.split(), fontproperties=prop)
                secondary_ax.tick_params(axis="x", colors="orange")
                secondary_ax.xaxis.label.set_color("orange")
            except (ValueError, IndexError):
                ax.set_title(text)
            ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
            ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
            plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=0.9, wspace=0.0, hspace=0.0)
            ax.set_aspect("auto")
            if return_plot_as_filepath:
                plt.savefig("tmp.png")
                return wave, sr, "tmp.png"
        return wave, sr
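
    # Prosody-control sketch following the docstring ranges above; the sentence is
    # a placeholder. Slower speech with a flatter pitch contour would be e.g.
    #
    #     wave, sr = tts("This is a test.", duration_scaling_factor=1.2, pitch_variance_scale=0.8)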

    def read_to_file(
        self,
        text_list,
        file_location,
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
        pause_duration_scaling_factor=1.0,
        silent=False,
        dur_list=None,
        pitch_list=None,
        energy_list=None,
        glow_sampling_temperature=0.2,
    ):
        """
        Args:
            silent: whether to suppress the progress printout
            text_list: a list of strings to be read
            file_location: the path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        silence = torch.zeros([14300])  # roughly 0.6 seconds of silence at the 24 kHz output rate
        wav = silence.clone()
        for text, durations, pitch, energy in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            if text.strip() != "":
                if not silent:
                    print("Now synthesizing: {}".format(text))
                spoken_sentence, sr = self(
                    text,
                    durations=durations.to(self.device) if durations is not None else None,
                    pitch=pitch.to(self.device) if pitch is not None else None,
                    energy=energy.to(self.device) if energy is not None else None,
                    duration_scaling_factor=duration_scaling_factor,
                    pitch_variance_scale=pitch_variance_scale,
                    energy_variance_scale=energy_variance_scale,
                    pause_duration_scaling_factor=pause_duration_scaling_factor,
                    glow_sampling_temperature=glow_sampling_temperature,
                )
                spoken_sentence = torch.tensor(spoken_sentence).cpu()
                wav = torch.cat((wav, spoken_sentence, silence), 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")

    def read_aloud(
        self,
        text,
        view=False,
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
        blocking=False,
        glow_sampling_temperature=0.2,
    ):
        if text.strip() == "":
            return
        wav, sr = self(
            text,
            view,
            duration_scaling_factor=duration_scaling_factor,
            pitch_variance_scale=pitch_variance_scale,
            energy_variance_scale=energy_variance_scale,
            glow_sampling_temperature=glow_sampling_temperature,
        )
        silence = torch.zeros([sr // 2])  # half a second of padding on both sides
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()
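

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming the default checkpoints from __init__
    # are available under MODELS_DIR; the sentence and output path are placeholders.
    tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu")
    tts.read_to_file(text_list=["This is a test sentence."], file_location="demo_output.wav")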