| """ | |
| Synthesize a given text using the trained DiT models. | |
| """ | |
| import json | |
| import os | |
| os.environ["NLTK_DATA"] = "nltk_data" | |
| import torch | |
| import yaml | |
| from g2p_en import G2p | |
| from vocos import Vocos | |
| from sample import sample | |
def synthesize(
    text,
    duration_model_config,
    duration_model_checkpoint,
    acoustic_model_config,
    acoustic_model_checkpoint,
    speaker_id,
    cfg_scale=4.0,
    num_sampling_steps=1000,
):
| """ | |
| Synthesize speech from text using trained DiT models. | |
| Args: | |
| text (str): Input text to synthesize | |
| duration_model_config (str): Path to duration model config file | |
| duration_model_checkpoint (str): Path to duration model checkpoint | |
| acoustic_model_config (str): Path to acoustic model config file | |
| acoustic_model_checkpoint (str): Path to acoustic model checkpoint | |
| speaker_id (str): Speaker ID to use for synthesis | |
| cfg_scale (float): Classifier-free guidance scale (default: 4.0) | |
| num_sampling_steps (int): Number of sampling steps for diffusion (default: 1000) | |
| Returns: | |
| numpy.ndarray: Audio waveform array | |
| int: Sample rate (24000) | |
| """ | |
    print("Text:", text)

    # Read duration model config
    with open(duration_model_config, "r") as f:
        duration_config = yaml.safe_load(f)

    # Get the data directory from data_path
    data_dir = os.path.dirname(duration_config["data"]["data_path"])

    # Read maps.json from the same directory
    with open(os.path.join(data_dir, "maps.json"), "r") as f:
        maps = json.load(f)
    phone_to_idx = maps["phone_to_idx"]
    phone_kind_to_idx = maps["phone_kind_to_idx"]
    speaker_id_to_idx = maps["speaker_id_to_idx"]
    # Step 1: Text to phonemes
    def text_to_phonemes(text, insert_empty=True):
        g2p = G2p()
        phonemes = g2p(text)

        # Group the flat phoneme stream into words, splitting on spaces
        words = []
        word = []
        for p in phonemes:
            if p == " ":
                if len(word) > 0:
                    words.append(word)
                    word = []
            else:
                word.append(p)
        if len(word) > 0:
            words.append(word)

        # Map punctuation to EMPTY, skip symbols not in phone_to_idx, and
        # tag each phone with its position within its word
        phones = []
        phone_kinds = []
        for word in words:
            for i, p in enumerate(word):
                if p in [",", ".", "!", "?", ";", ":"]:
                    p = "EMPTY"
                elif p not in phone_to_idx:
                    continue

                if p == "EMPTY":
                    phone_kind = "EMPTY"
                elif len(word) == 1:
                    phone_kind = "WORD"
                elif i == 0:
                    phone_kind = "START"
                elif i == len(word) - 1:
                    phone_kind = "END"
                else:
                    phone_kind = "MIDDLE"

                phones.append(p)
                phone_kinds.append(phone_kind)

        # Pad both ends with EMPTY so the sequence starts and ends in
        # silence (guard against empty input to avoid an IndexError)
        if insert_empty and len(phones) > 0:
            if phones[0] != "EMPTY":
                phones.insert(0, "EMPTY")
                phone_kinds.insert(0, "EMPTY")
            if phones[-1] != "EMPTY":
                phones.append("EMPTY")
                phone_kinds.append("EMPTY")
        return phones, phone_kinds
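    # For example, "Hi there." should map to something like
    #   phones      = ['EMPTY', 'HH', 'AY1', 'DH', 'EH1', 'R', 'EMPTY']
    #   phone_kinds = ['EMPTY', 'START', 'END', 'START', 'MIDDLE', 'END', 'EMPTY']
    # (the exact output depends on the g2p_en version and on phone_to_idx).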
    phonemes, phone_kinds = text_to_phonemes(text)

    # Convert phonemes and phone kinds to vocabulary indices
    phoneme_indices = [phone_to_idx[p] for p in phonemes]
    phone_kind_indices = [phone_kind_to_idx[p] for p in phone_kinds]
    print("Phonemes:", phonemes)
    # Step 2: Duration prediction
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_phoneme_indices = torch.tensor(phoneme_indices)[None, :].long().to(device)
    torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id))
    torch_phone_kind_indices = (
        torch.tensor(phone_kind_indices)[None, :].long().to(device)
    )
    samples = sample(
        duration_model_config,
        duration_model_checkpoint,
        cfg_scale=cfg_scale,
        num_sampling_steps=num_sampling_steps,
        seed=0,
        speaker_id=torch_speaker_id,
        phone=torch_phoneme_indices,
        phone_kind=torch_phone_kind_indices,
    )
    phoneme_durations = samples[-1][0, 0]
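    # Assuming sample() returns the denoising trajectory, samples[-1] is the
    # fully denoised output; [0, 0] selects batch 0, channel 0, giving one
    # (still normalized) duration value per phoneme.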
    # Step 3: Acoustic prediction
    # First, convert phoneme durations to a number of frames per phoneme (min 1 frame)
    SAMPLE_RATE = 24000
    HOP_LENGTH = 256
    N_FFT = 1024
    N_MELS = 100
    time_per_frame = HOP_LENGTH / SAMPLE_RATE

    # Convert predicted durations to raw durations (in seconds) using the
    # data mean and std from the config
    if duration_config["data"]["normalize"]:
        mean = duration_config["data"]["data_mean"]
        std = duration_config["data"]["data_std"]
        raw_durations = phoneme_durations * std + mean
    else:
        raw_durations = phoneme_durations
    raw_durations = raw_durations.clamp(min=time_per_frame, max=1.0)

    end_time = torch.cumsum(raw_durations, dim=0)
    end_frame = end_time / time_per_frame
    int_end_frame = end_frame.floor().int()
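    # Worked example: with HOP_LENGTH=256 at 24 kHz each frame spans
    # 256 / 24000 ≈ 10.67 ms, so a phoneme whose cumulative end time is
    # 0.25 s ends at frame floor(0.25 / 0.01067) = 23.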
    # Repeat each phoneme index over the frames it spans: the gap between
    # the running sequence length and the phoneme's end frame is the
    # number of frames that phoneme occupies
    repeated_phoneme_indices = []
    repeated_phone_kind_indices = []
    for i in range(len(phonemes)):
        end = int(int_end_frame[i])
        repeated_phoneme_indices.extend(
            [phoneme_indices[i]] * (end - len(repeated_phoneme_indices))
        )
        repeated_phone_kind_indices.extend(
            [phone_kind_indices[i]] * (end - len(repeated_phone_kind_indices))
        )

    torch_phoneme_indices = (
        torch.tensor(repeated_phoneme_indices)[None, :].long().to(device)
    )
    torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id))
    torch_phone_kind_indices = (
        torch.tensor(repeated_phone_kind_indices)[None, :].long().to(device)
    )
    samples = sample(
        acoustic_model_config,
        acoustic_model_checkpoint,
        cfg_scale=cfg_scale,
        num_sampling_steps=num_sampling_steps,
        seed=0,
        speaker_id=torch_speaker_id,
        phone=torch_phoneme_indices,
        phone_kind=torch_phone_kind_indices,
    )
    mel = samples[-1][0]
    # Un-normalize the mel spectrogram if the acoustic model was trained
    # on normalized data
    with open(acoustic_model_config, "r") as f:
        acoustic_config = yaml.safe_load(f)
    if acoustic_config["data"]["normalize"]:
        mean = acoustic_config["data"]["data_mean"]
        std = acoustic_config["data"]["data_std"]
        raw_mel = mel * std + mean
    else:
        raw_mel = mel
    # Step 4: Vocoder
    # Vocos expects a (batch, n_mels, frames) mel tensor
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    audio = vocos.decode(raw_mel.cpu()[None, :, :]).squeeze().cpu().numpy()
    return audio, SAMPLE_RATE
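

# A minimal usage sketch. The config/checkpoint paths and the speaker id
# below are placeholders; soundfile is one common way to write the returned
# waveform, but any WAV writer works.
if __name__ == "__main__":
    import soundfile as sf

    audio, sr = synthesize(
        text="Hello world, this is a test.",
        duration_model_config="configs/duration.yaml",  # placeholder path
        duration_model_checkpoint="checkpoints/duration.pt",  # placeholder path
        acoustic_model_config="configs/acoustic.yaml",  # placeholder path
        acoustic_model_checkpoint="checkpoints/acoustic.pt",  # placeholder path
        speaker_id="0",  # placeholder speaker id
    )
    sf.write("output.wav", audio, sr)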