| | |
| | |
| | |
| |
|
| | import os |
| | import json |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import matplotlib |
| | from scipy.io import wavfile |
| | from matplotlib import pyplot as plt |
| |
|
| | matplotlib.use("Agg") |
| |
|
| | import hashlib |
| | import os |
| |
|
| | import requests |
| | from tqdm import tqdm |
| |
|
# Download URL of each downloadable file, keyed by its short name.
URL_MAP = {
    "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
    "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
}

# File name each entry is stored under inside the local root directory.
CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}

# Expected MD5 hex digest of each file, compared against the copy on disk.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

# Module-wide default device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
| |
|
def read_list(fname):
    """Return the lines of the text file ``fname`` as a list of strings.

    Newline characters are stripped from both ends of every line; all other
    whitespace is preserved.
    """
    with open(fname, "r") as handle:
        return [line.strip("\n") for line in handle.readlines()]
| |
|
| |
|
def build_dataset_json_from_list(list_path):
    """Build a dataset dictionary from a plain-text list file.

    Every line of ``list_path`` is either ``<wav>|<caption>`` or a bare
    ``<caption>``. Lines without a ``|`` get an empty ``"wav"`` entry.

    Parameters
    ----------
    list_path : str
        Path of the text file, read with ``read_list``.

    Returns
    -------
    dict
        ``{"data": [{"wav": ..., "caption": ...}, ...]}`` with one entry per
        line, in file order.
    """
    data = []
    for line in read_list(list_path):
        # Split on the FIRST "|" only, so a caption may itself contain "|"
        # (a plain ``split("|")`` raised ValueError on such lines).
        wav, separator, caption = line.partition("|")
        if not separator:
            wav, caption = "", line
        data.append({"wav": wav, "caption": caption})
    return {"data": data}
| |
|
| |
|
def load_json(fname):
    """Parse the JSON file at ``fname`` and return the decoded object."""
    with open(fname, "r") as handle:
        return json.load(handle)
| |
|
| |
|
def read_json(dataset_json_file):
    """Return the value stored under the top-level ``"data"`` key of a JSON file."""
    with open(dataset_json_file, "r") as handle:
        parsed = json.load(handle)
    entries = parsed["data"]
    return entries
| |
|
| |
|
def copy_test_subset_data(metadata, testset_copy_target_path):
    """Copy the wav file of every metadata entry into a target directory.

    The target directory is created when missing. When it already contains
    exactly ``len(metadata)`` entries the copy is skipped; otherwise the
    entries currently in it are removed and all files are copied again.

    Parameters
    ----------
    metadata : sequence of dict
        Each item must provide the source file path under the ``"wav"`` key.
    testset_copy_target_path : str
        Directory that receives the copies.

    Notes
    -----
    Removal and copy failures are printed and skipped (best effort), as in
    the previous shell-based implementation.
    """
    # Local import keeps the module's import block untouched.
    import shutil

    os.makedirs(testset_copy_target_path, exist_ok=True)
    existing = os.listdir(testset_copy_target_path)
    if len(existing) == len(metadata):
        return

    # The directory content does not match the metadata: start from scratch.
    for name in existing:
        try:
            os.remove(os.path.join(testset_copy_target_path, name))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        # shutil.copy avoids building a shell command, so paths with spaces
        # or shell metacharacters are handled and no external `cp` is needed.
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            print(e)
| |
|
| |
|
def listdir_nohidden(path):
    """Yield the entries of ``path`` whose names do not start with a dot."""
    yield from (entry for entry in os.listdir(path) if not entry.startswith("."))
| |
|
| |
|
def get_restore_step(path):
    """Pick the checkpoint file in ``path`` that training should resume from.

    Selection order:

    1. ``final.ckpt`` if present.
    2. If there is no ``last.ckpt``: the ``*step=<N>*.ckpt`` file with the
       largest ``N``.
    3. Otherwise the ``last-v<K>.ckpt`` file with the largest ``K``, falling
       back to ``last.ckpt`` when no versioned file exists.

    Parameters
    ----------
    path : str
        Directory containing the checkpoint files.

    Returns
    -------
    tuple
        ``(file_name, step)``; ``step`` is the parsed step number in case 2
        and ``0`` in cases 1 and 3.

    Raises
    ------
    ValueError
        If neither ``final.ckpt``, ``last.ckpt`` nor any ``step=<N>``
        checkpoint exists in ``path``.
    """
    # Local import keeps the module's import block untouched.
    import re

    checkpoints = os.listdir(path)
    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if not os.path.exists(os.path.join(path, "last.ckpt")):
        # Only consider files that really carry a step number; unrelated
        # files in the directory used to crash the parsing.
        stepped = []
        for name in checkpoints:
            match = re.search(r"step=(\d+)", name)
            if match and name.endswith(".ckpt"):
                stepped.append((int(match.group(1)), name))
        if not stepped:
            raise ValueError("No 'step=<N>' checkpoint found in {}".format(path))
        step, fname = max(stepped, key=lambda item: item[0])
        return fname, step

    # Previously the current version was appended to the list *before* being
    # compared with its maximum, so a versioned file could never win and
    # "last.ckpt" was always returned. Choose the highest version explicitly.
    versioned = []
    for name in checkpoints:
        match = re.fullmatch(r"last-v(\d+)\.ckpt", name)
        if match:
            versioned.append((int(match.group(1)), name))
    if versioned:
        return max(versioned, key=lambda item: item[0])[1], 0
    return "last.ckpt", 0
| |
|
| |
|
def download(url, local_path, chunk_size=1024):
    """Stream the content at ``url`` into the file ``local_path``.

    Parameters
    ----------
    url : str
        Address to fetch with an HTTP GET request.
    local_path : str
        Destination file; its parent directory is created when needed.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status, so that an error page is
        never written to ``local_path`` as if it were the payload.
    """
    parent = os.path.dirname(local_path)
    if parent:
        # os.makedirs("") raises, so skip it for a bare file name.
        os.makedirs(parent, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received; the last
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
| |
|
| |
|
def md5_hash(path):
    """Return the MD5 hex digest of the file at ``path``.

    The file is hashed in fixed-size chunks so that large checkpoint files
    are never held in memory in full.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the empty bytes returned on EOF.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
| |
|
| |
|
def get_ckpt_path(name, root, check=False):
    """Return the local path of the file registered as ``name``, downloading it if needed.

    The file lives at ``root/CKPT_MAP[name]``. It is downloaded from
    ``URL_MAP[name]`` when it does not exist, or when ``check`` is true and
    its MD5 digest differs from ``MD5_MAP[name]``. A freshly downloaded file
    is always verified against the expected digest.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])

    # Fast path: the file is there and is either unchecked or verified.
    if os.path.exists(path):
        if not check or md5_hash(path) == MD5_MAP[name]:
            return path

    print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
    download(URL_MAP[name], path)
    md5 = md5_hash(path)
    assert md5 == MD5_MAP[name], md5
    return path
| |
|
| |
|
class KeyNotFoundError(Exception):
    """Lookup failure that records the underlying cause and the key path.

    The exception message lists the requested keys and the keys visited so
    far (each only when given), followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        optional_lines = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_lines
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
| |
|
| |
|
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Look up a value inside nested lists/dicts using a path-like key.

    Callable nodes met along the path are called and replaced in-place by
    their result when :attr:`expand` is ``True``.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        Path such as ``key/to/value``. Segments addressing a non-dict node
        are converted with ``int`` and used as indices.
    splitval : str
        Delimiter separating the segments of `key`.
    default : obj
        Value returned when the path cannot be resolved. ``None`` means
        "no default": the failure is raised instead.
    expand : bool
        Whether callable nodes on the path are called and replaced.
    pass_success : bool
        When ``True`` return ``(value, success)`` instead of just the value,
        where ``success`` is ``False`` if the default was used.

    Returns
    -------
    The value found at :attr:`key`, or ``default`` when the lookup fails and
    ``default`` is not ``None`` (paired with the success flag when
    :attr:`pass_success` is set).

    Raises
    ------
    KeyNotFoundError
        If the path cannot be resolved and :attr:`default` is ``None``.
    """
    path = key.split(splitval)
    node = list_or_dict
    found = True

    try:
        seen = []
        container = None
        container_key = None
        for segment in path:
            if callable(node):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=path,
                        visited=seen,
                    )
                # Replace the callable by its result where it was stored.
                node = node()
                container[container_key] = node

            container_key = segment
            container = node

            try:
                node = node[segment] if isinstance(node, dict) else node[int(segment)]
            except (KeyError, IndexError, ValueError) as err:
                raise KeyNotFoundError(err, keys=path, visited=seen)

            seen.append(segment)

        # The final value itself may be a callable that needs expanding.
        if expand and callable(node):
            node = node()
            container[container_key] = node
    except KeyNotFoundError:
        if default is None:
            raise
        node = default
        found = False

    return (node, found) if pass_success else node
| |
|
| |
|
def to_device(data, device):
    """Convert the NumPy arrays of a batch tuple to tensors on ``device``.

    Two layouts are handled, told apart by ``len(data)``:

    * 12 items: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len,
      mels, mel_lens, max_mel_len, pitches, energies, durations)``
    * 6 items: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len)``

    ``speakers``, ``texts`` and ``durations`` are cast to long, ``mels`` and
    ``pitches`` to float; ``src_lens``, ``mel_lens`` and ``energies`` keep
    their dtype. The remaining items are passed through unchanged.

    Returns the converted tuple, or ``None`` for any other length.
    """
    size = len(data)

    if size == 12:
        (
            ids, raw_texts, speakers, texts, src_lens, max_src_len,
            mels, mel_lens, max_mel_len, pitches, energies, durations,
        ) = data
        return (
            ids,
            raw_texts,
            torch.from_numpy(speakers).long().to(device),
            torch.from_numpy(texts).long().to(device),
            torch.from_numpy(src_lens).to(device),
            max_src_len,
            torch.from_numpy(mels).float().to(device),
            torch.from_numpy(mel_lens).to(device),
            max_mel_len,
            torch.from_numpy(pitches).float().to(device),
            torch.from_numpy(energies).to(device),
            torch.from_numpy(durations).long().to(device),
        )

    if size == 6:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data
        return (
            ids,
            raw_texts,
            torch.from_numpy(speakers).long().to(device),
            torch.from_numpy(texts).long().to(device),
            torch.from_numpy(src_lens).to(device),
            max_src_len,
        )

    return None
| |
|
| |
|
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to ``logger`` under ``tag``.

    ``fig`` is forwarded to ``logger.add_figure``. ``audio`` is divided by
    1.1 times its peak absolute value and forwarded to ``logger.add_audio``
    with ``sample_rate=sampling_rate``. ``step`` is accepted but not used.
    """
    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is None:
        return

    # Scale so the loudest sample ends up just below full scale.
    peak = max(abs(audio))
    scaled = audio / (peak * 1.1)
    logger.add_audio(tag, scaled, sample_rate=sampling_rate)
| |
|
| |
|
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Parameters
    ----------
    lengths : torch.Tensor
        One length per batch item.
    max_len : int, optional
        Width of the mask; defaults to the largest value in ``lengths``.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape ``(batch, max_len)`` that is ``True`` at every
        position whose index is ``>=`` the item's length (i.e. the padding).
    """
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Build the position indices on the same device as `lengths`; using the
    # module-level default device failed whenever the two differed.
    ids = (
        torch.arange(0, max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask
| |
|
| |
|
def expand(values, durations):
    """Repeat each value according to its duration and return a NumPy array.

    ``values[i]`` appears ``int(durations[i])`` times in the output, in order;
    durations that truncate to zero or less contribute nothing.
    """
    repeated = []
    for item, duration in zip(values, durations):
        count = max(0, int(duration))
        repeated.extend([item] * count)
    return np.array(repeated)
| |
|
| |
|
def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Plot and (optionally) vocode one randomly chosen sample of a batch.

    Parameters
    ----------
    targets : sequence
        Batch targets; item 0 holds the basenames and item 6 the target mel
        spectrograms (indexed as ``[batch, frame, ...]``).
    predictions : sequence
        Model outputs; items 0 and 1 are the raw and postnet mel predictions
        and item 9 the per-sample mel lengths.
    vocoder : object or None
        Passed to ``vocoder_infer``; when ``None`` no audio is produced.
    model_config, preprocess_config : dict
        Forwarded unchanged to ``vocoder_infer``.

    Returns
    -------
    tuple
        ``(fig, wav_reconstruction, wav_prediction, basename)``; both wavs
        are ``None`` when ``vocoder`` is ``None``.

    Notes
    -----
    ``plot_mel`` in this module takes only the spectrograms and their titles.
    The earlier call passed ``(mel, pitch, energy)`` tuples plus a ``stats``
    argument, which does not match that signature and raised ``TypeError``;
    the pitch/energy/stats preparation that only fed that call was dropped.
    """
    index = np.random.choice(targets[6].size(0))

    basename = targets[0][index]
    mel_len = predictions[9][index].item()
    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)

    fig = plot_mel(
        [
            mel_prediction.cpu().numpy(),
            postnet_mel_prediction.cpu().numpy(),
            mel_target.cpu().numpy(),
        ],
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )

    if vocoder is not None:
        from .model_util import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
| |
|
| |
|
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode an input mel batch and its prediction.

    Both tensors are permuted with ``(0, 2, 1)`` before being handed to
    ``vocoder_infer``. Returns ``(wav_reconstruction, wav_prediction)``, or
    ``(None, None)`` when ``vocoder`` is ``None``. ``labels`` is not used.
    """
    if vocoder is None:
        return None, None

    from .model_util import vocoder_infer

    reconstructed = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    predicted = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return reconstructed, predicted
| |
|
| |
|
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Write a spectrogram image and a wav file for every sample of a batch.

    Parameters
    ----------
    targets : sequence
        Batch targets; item 0 holds the basenames used for the file names.
    predictions : sequence
        Model outputs; item 1 is the postnet mel prediction (indexed as
        ``[batch, frame, ...]``) and item 9 the per-sample mel lengths.
    vocoder, model_config : object
        Forwarded unchanged to ``vocoder_infer``.
    preprocess_config : dict
        Provides ``preprocessing/stft/hop_length`` and
        ``preprocessing/audio/sampling_rate``; also forwarded to
        ``vocoder_infer``.
    path : str
        Output directory for ``<basename>_postnet_2.png`` and
        ``<basename>.wav``.

    Notes
    -----
    ``plot_mel`` in this module takes only the spectrograms and their titles.
    The earlier call passed ``(mel, pitch, energy)`` tuples plus a ``stats``
    argument, which does not match that signature and raised ``TypeError``;
    the pitch/energy/stats preparation that only fed that call was dropped.
    """
    basenames = targets[0]

    for i in range(len(predictions[1])):
        basename = basenames[i]
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)

        fig = plot_mel(
            [mel_prediction.cpu().numpy()],
            ["Synthetized Spectrogram by PostNet"],
        )
        # Save and close this exact figure instead of relying on pyplot's
        # implicit "current figure".
        fig.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close(fig)

    from .model_util import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    # Convert mel-frame counts to waveform sample counts.
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
| |
|
| |
|
def plot_mel(data, titles=None):
    """Draw each spectrogram of ``data`` in its own row of a single figure.

    ``data`` is a sequence of 2-D arrays shown with ``imshow`` (origin at the
    bottom); ``titles`` optionally gives one title per row. Returns the
    matplotlib figure.
    """
    n_rows = len(data)
    fig, axes = plt.subplots(n_rows, 1, squeeze=False)
    if titles is None:
        titles = [None] * n_rows

    for row, mel in enumerate(data):
        ax = axes[row][0]
        ax.imshow(mel, origin="lower", aspect="auto")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        # Hide the y-axis ticks and labels; keep small x tick labels.
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
| |
|
| |
|
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays with ``PAD`` to a common length and stack them.

    The common length is the length of the longest input. Returns an array
    of shape ``(len(inputs), longest)``.
    """
    longest = max(len(seq) for seq in inputs)
    rows = [
        np.pad(
            seq, (0, longest - seq.shape[0]), mode="constant", constant_values=PAD
        )
        for seq in inputs
    ]
    return np.stack(rows)
| |
|
| |
|
def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along their first axis and stack them.

    Every array is extended to ``maxlen`` rows when ``maxlen`` is given (and
    truthy), otherwise to the row count of the longest input; the second
    axis keeps its size. Raises ``ValueError`` if an array has more rows
    than the target length.
    """

    def pad_rows(matrix, target_rows):
        n_rows = np.shape(matrix)[0]
        if n_rows > target_rows:
            raise ValueError("not max_len")

        n_cols = np.shape(matrix)[1]
        # np.pad applies the (before, after) pair to every axis, so the
        # extra columns are cut off again right after.
        grown = np.pad(
            matrix, (0, target_rows - n_rows), mode="constant", constant_values=0
        )
        return grown[:, :n_cols]

    target = maxlen if maxlen else max(np.shape(x)[0] for x in inputs)
    return np.stack([pad_rows(x, target) for x in inputs])
| |
|
| |
|
def pad(input_ele, mel_max_length=None):
    """Zero-pad tensors along their first dimension and stack them.

    Parameters
    ----------
    input_ele : sequence of torch.Tensor
        Tensors of rank >= 1 that agree on every dimension but the first.
    mel_max_length : int, optional
        Target size of the first dimension; defaults to the largest first
        dimension among the inputs.

    Returns
    -------
    torch.Tensor
        Stacked tensor of shape ``(len(input_ele), max_len, ...)``.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(element.size(0) for element in input_ele)

    out_list = []
    for element in input_ele:
        # F.pad lists (before, after) pairs starting from the LAST dimension,
        # so leave all trailing dimensions untouched and pad only dim 0.
        # This covers the former 1-D and 2-D cases and extends to any rank;
        # other ranks used to reuse the previous item's result or fail with
        # an unbound variable.
        pad_spec = (0, 0) * (element.dim() - 1) + (0, max_len - element.size(0))
        out_list.append(F.pad(element, pad_spec, "constant", 0.0))
    out_padded = torch.stack(out_list)
    return out_padded