import os

import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

# Module-level cache holding the VITS symbol inventory and the symbol-to-id
# lookup table, built once at import time.
CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}


def get_vits_phoneme_ids(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"]
    sequence = []

    for symbol in clean_text:
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    # Interleave the blank token (id 0) between phoneme ids and append one at
    # the end, as the VITS front end expects: [s1, s2] -> [0, s1, 0, s2, 0].
    inserted_zero_sequence = [0] * (len(sequence) * 2)
    inserted_zero_sequence[1::2] = sequence
    inserted_zero_sequence = inserted_zero_sequence + [0]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    # Unlike get_vits_phoneme_ids, this variant does not interleave blank
    # tokens; it truncates and pads the raw symbol ids.
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids_no_padding"
    # "⚠" (from the special-symbol set) marks the end of the sequence
    clean_text = metadata["phonemes"] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id.keys():
            print("%s is not in the vocabulary: %s" % (symbol, clean_text))
            symbol = "_"  # Fall back to the pad symbol for unknown characters
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]
    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}


def calculate_relative_bandwidth(config, dl_output, metadata):
    assert "stft" in dl_output.keys()

    # The last dimension of the stft feature is the frequency dimension
    freq_dimensions = dl_output["stft"].size(-1)

    # Cumulative energy distribution over frequency bins
    freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    # Rescale the bin indices to integers in [0, 1000)
    lower_idx = int((lower_idx / freq_dimensions) * 1000)
    higher_idx = int((higher_idx / freq_dimensions) * 1000)

    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}


def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    assert "log_mel_spec" in dl_output.keys()

    # Undo the log compression (clipped for numerical stability) to work on a
    # linear-scale mel spectrogram.
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the mel spectrogram is the frequency dimension
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    # Map the mel-bin indices onto the latent frequency axis and mark the
    # occupied band with a binary mask.
    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    lower_idx = int(latent_f_size * float(lower_idx / freq_dimensions))
    higher_idx = int(latent_f_size * float(higher_idx / freq_dimensions))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }
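

# A sketch of the extra-channel variant with a made-up config; latent_t_size
# and latent_f_size are hypothetical values, and the flat log-mel input makes
# the occupied band easy to predict.
def _demo_bandwidth_extra_channel():
    fake_config = {"model": {"params": {"latent_t_size": 256, "latent_f_size": 16}}}
    fake_dl = {"stft": torch.ones((100, 512)), "log_mel_spec": torch.zeros((100, 64))}
    out = calculate_mel_spec_relative_bandwidth_as_extra_channel(fake_config, fake_dl, None)
    assert out["mel_spec_bandwidth_cond_extra_channel"].shape == (256, 16)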


def waveform_rs_48k(config, dl_output, metadata):
    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]

    if sampling_rate != 48000:
        waveform_48k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=48000
        )
    else:
        waveform_48k = waveform

    return {"waveform_48k": waveform_48k}


def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    assert (
        "phoneme" not in metadata.keys()
    ), "The speech metadata you are using appears to belong to FastSpeech. Please check dataset_root.json"

    if "phonemes" in metadata.keys():
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
        new_item["text"] = ""  # We assume TTS data does not have a text description
    else:
        fake_metadata = {"phonemes": ""}  # Use an empty phoneme sequence
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)

    return new_item


def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    if "phoneme" in metadata.keys():
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
        new_item["text"] = ""
    else:
        fake_metadata = {"phoneme": []}
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)

    return new_item


def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 135

    phonemes_lookup_dict = {
        "K": 0,
        "IH2": 1,
        "NG": 2,
        "OW2": 3,
        "AH2": 4,
        "F": 5,
        "AE0": 6,
        "IY0": 7,
        "SH": 8,
        "G": 9,
        "W": 10,
        "UW1": 11,
        "AO2": 12,
        "AW2": 13,
        "UW0": 14,
        "EY2": 15,
        "UW2": 16,
        "AE2": 17,
        "IH0": 18,
        "P": 19,
        "D": 20,
        "ER1": 21,
        "AA1": 22,
        "EH0": 23,
        "UH1": 24,
        "N": 25,
        "V": 26,
        "AY1": 27,
        "EY1": 28,
        "UH2": 29,
        "EH1": 30,
        "L": 31,
        "AA2": 32,
        "R": 33,
        "OY1": 34,
        "Y": 35,
        "ER2": 36,
        "S": 37,
        "AE1": 38,
        "AH1": 39,
        "JH": 40,
        "ER0": 41,
        "EH2": 42,
        "IY2": 43,
        "OY2": 44,
        "AW1": 45,
        "IH1": 46,
        "IY1": 47,
        "OW0": 48,
        "AO0": 49,
        "AY0": 50,
        "EY0": 51,
        "AY2": 52,
        "UH0": 53,
        "M": 54,
        "TH": 55,
        "T": 56,
        "OY0": 57,
        "AW0": 58,
        "DH": 59,
        "Z": 60,
        "spn": 61,
        "AH0": 62,
        "sp": 63,
        "AO1": 64,
        "OW1": 65,
        "ZH": 66,
        "B": 67,
        "AA0": 68,
        "CH": 69,
        "HH": 70,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_fs2_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme sequence is specified in your dataset"

    # Look up ids, silently dropping symbols outside the vocabulary
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is over five times the pad length and will be heavily truncated: %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 250

    phonemes_lookup_dict = {
        " ": 0,
        "AA": 1,
        "AE": 2,
        "AH": 3,
        "AO": 4,
        "AW": 5,
        "AY": 6,
        "B": 7,
        "CH": 8,
        "D": 9,
        "DH": 10,
        "EH": 11,
        "ER": 12,
        "EY": 13,
        "F": 14,
        "G": 15,
        "HH": 16,
        "IH": 17,
        "IY": 18,
        "JH": 19,
        "K": 20,
        "L": 21,
        "M": 22,
        "N": 23,
        "NG": 24,
        "OW": 25,
        "OY": 26,
        "P": 27,
        "R": 28,
        "S": 29,
        "SH": 30,
        "T": 31,
        "TH": 32,
        "UH": 33,
        "UW": 34,
        "V": 35,
        "W": 36,
        "Y": 37,
        "Z": 38,
        "ZH": 39,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme sequence is specified in your dataset"

    # Look up ids, silently dropping symbols outside the vocabulary
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is over five times the pad length and will be heavily truncated: %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_kaldi_fbank_feature(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Cut or zero-pad the fbank to match the HiFi-GAN mel spectrogram length
    TARGET_LEN = log_mel_spec_hifigan.size(0)
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 32000:
        waveform_32k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=32000
        )
    else:
        waveform_32k = waveform

    waveform_32k = waveform_32k - waveform_32k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_32k,
        htk_compat=True,
        sample_frequency=32000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Cut or zero-pad the fbank to match the HiFi-GAN mel spectrogram length
    TARGET_LEN = log_mel_spec_hifigan.size(0)
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


# Use the beat and downbeat information as music conditions
def extract_drum_beat(config, dl_output, metadata):
    def visualization(conditional_signal, mel_spectrogram, filename):
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))
        plt.subplot(211)
        plt.imshow(np.array(conditional_signal).T, aspect="auto")
        plt.title("Conditional Signal")
        plt.subplot(212)
        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
        plt.title("Mel Spectrogram")
        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]
    # The dataloader segment length before torch resampling is performed
    original_segment_length_before_resample = int(sampling_rate * duration)

    random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])

    # Beat and downbeat sample indices, relative to the segmented audio
    beat = [
        x - random_start_sample
        for x in metadata["beat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]
    downbeat = [
        x - random_start_sample
        for x in metadata["downbeat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]

    latent_shape = (
        config["model"]["params"]["latent_t_size"],
        config["model"]["params"]["latent_f_size"],
    )
    conditional_signal = torch.zeros(latent_shape)

    # Encoding: 0.0 = none; -0.5 = beat; 1.0 = downbeat; 0.5 = beat + downbeat
    for each in beat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)
        conditional_signal[beat_index, :] -= 0.5

    for each in downbeat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)
        conditional_signal[beat_index, :] += 1.0

    # visualization(conditional_signal, dl_output["log_mel_spec"], filename=os.path.basename(dl_output["fname"]) + ".png")

    return {"cond_beat_downbeat": conditional_signal}