import glob
import json
import operator
import os
import shutil
import time
from random import shuffle
from typing import List, Optional, Union

import faiss
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchaudio
import tqdm
from sklearn.cluster import MiniBatchKMeans
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from . import commons, utils
from .checkpoints import save
from .config import DatasetMetadata, TrainConfig
from .data_utils import (DistributedBucketSampler, TextAudioCollate,
                         TextAudioCollateMultiNSFsid, TextAudioLoader,
                         TextAudioLoaderMultiNSFsid)
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
                     SynthesizerTrnMs256NSFSidNono)
from .preprocessing.extract_feature import (MODELS_DIR, get_embedder,
                                            load_embedder)


def is_audio_file(file: str):
    if "." not in file:
        return False
    ext = os.path.splitext(file)[1]
    return ext.lower() in [
        ".wav",
        ".flac",
        ".ogg",
        ".mp3",
        ".m4a",
        ".wma",
        ".aiff",
    ]
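
# For example (illustrative filenames, not part of any dataset):
#   is_audio_file("take_01.WAV")  -> True   (the extension check is case-insensitive)
#   is_audio_file("notes.txt")    -> False
#   is_audio_file("README")       -> False  (no extension at all)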


def glob_dataset(
    glob_str: str,
    speaker_id: int,
    multiple_speakers: bool = False,
    recursive: bool = True,
    training_dir: str = ".",
):
    globs = glob_str.split(",")
    speaker_count = 0
    datasets_speakers = []
    speaker_to_id_mapping = {}
    for glob_str in globs:
        if os.path.isdir(glob_str):
            if multiple_speakers:
                # Multispeaker format:
                # dataset_path/
                # - speakername/
                #     - {wav name here}.wav
                #     - ...
                # - next_speakername/
                #     - {wav name here}.wav
                #     - ...
                # - ...
                print("Multispeaker dataset enabled; processing speakers.")
                for dir in tqdm.tqdm(os.listdir(glob_str)):
                    print("Speaker ID " + str(speaker_count) + ": " + dir)
                    speaker_to_id_mapping[dir] = speaker_count
                    speaker_path = os.path.join(glob_str, dir)
                    for audio in tqdm.tqdm(os.listdir(speaker_path)):
                        audio_path = os.path.join(speaker_path, audio)
                        if is_audio_file(audio_path):
                            datasets_speakers.append((audio_path, speaker_count))
                    speaker_count += 1
                speaker_info_path = os.path.join(training_dir, "speaker_info.json")
                with open(speaker_info_path, "w") as outfile:
                    json.dump(speaker_to_id_mapping, outfile)
                print("Dumped speaker info to {}".format(speaker_info_path))
                continue  # Skip the single-speaker extend below
            glob_str = os.path.join(glob_str, "**", "*")
        print("Single speaker dataset enabled; processing speaker as ID " + str(speaker_id) + ".")
        datasets_speakers.extend(
            [
                (file, speaker_id)
                for file in glob.iglob(glob_str, recursive=recursive)
                if is_audio_file(file)
            ]
        )
    return sorted(datasets_speakers)
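
# A sketch of the returned structure (paths are illustrative, not real files):
#   glob_dataset("dataset", 0, multiple_speakers=True)
#   -> [("dataset/alice/001.wav", 0), ("dataset/bob/001.wav", 1), ...]
# i.e. a sorted list of (audio_path, speaker_id) tuples; with a single speaker,
# every tuple carries the `speaker_id` that was passed in.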


def create_dataset_meta(training_dir: str, f0: bool):
    gt_wavs_dir = os.path.join(training_dir, "0_gt_wavs")
    co256_dir = os.path.join(training_dir, "3_feature256")

    def list_data(dir: str):
        files = []
        for subdir in os.listdir(dir):
            speaker_dir = os.path.join(dir, subdir)
            for name in os.listdir(speaker_dir):
                files.append(os.path.join(subdir, name.split(".")[0]))
        return files

    names = set(list_data(gt_wavs_dir)) & set(list_data(co256_dir))

    if f0:
        f0_dir = os.path.join(training_dir, "2a_f0")
        f0nsf_dir = os.path.join(training_dir, "2b_f0nsf")
        names = names & set(list_data(f0_dir)) & set(list_data(f0nsf_dir))

    meta = {
        "files": {},
    }
    for name in names:
        speaker_id = os.path.dirname(name).split("_")[0]
        speaker_id = int(speaker_id) if speaker_id.isdecimal() else 0
        if f0:
            gt_wav_path = os.path.join(gt_wavs_dir, f"{name}.wav")
            co256_path = os.path.join(co256_dir, f"{name}.npy")
            f0_path = os.path.join(f0_dir, f"{name}.wav.npy")
            f0nsf_path = os.path.join(f0nsf_dir, f"{name}.wav.npy")
            meta["files"][name] = {
                "gt_wav": gt_wav_path,
                "co256": co256_path,
                "f0": f0_path,
                "f0nsf": f0nsf_path,
                "speaker_id": speaker_id,
            }
        else:
            gt_wav_path = os.path.join(gt_wavs_dir, f"{name}.wav")
            co256_path = os.path.join(co256_dir, f"{name}.npy")
            meta["files"][name] = {
                "gt_wav": gt_wav_path,
                "co256": co256_path,
                "speaker_id": speaker_id,
            }

    with open(os.path.join(training_dir, "meta.json"), "w") as f:
        json.dump(meta, f, indent=2)
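
# The resulting meta.json looks roughly like this (illustrative entry name):
#   {
#     "files": {
#       "0_speaker/000": {
#         "gt_wav": ".../0_gt_wavs/0_speaker/000.wav",
#         "co256": ".../3_feature256/0_speaker/000.npy",
#         "f0": ".../2a_f0/0_speaker/000.wav.npy",        # only when f0 is True
#         "f0nsf": ".../2b_f0nsf/0_speaker/000.wav.npy",  # only when f0 is True
#         "speaker_id": 0
#       }
#     }
#   }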


def change_speaker(net_g, speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths):
    """
    Randomly change the formant/speaker of a batch.
    Inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
    """
    N = phone.shape[0]
    device = phone.device
    dtype = phone.dtype

    f0_bin = 256
    f0_max = 1100.0
    f0_min = 50.0
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)
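
    # A worked example of the mel-scale pitch mapping applied below
    # (HTK-style mel formula, matching the constants above):
    #   f0_mel = 1127 * ln(1 + f0 / 700)
    #   coarse = (f0_mel - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
    # e.g. f0 = 440 Hz gives f0_mel ≈ 550, which lands around coarse bin 122
    # of the 1..255 range.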
    pitch_median = torch.median(pitchf, 1).values
    lo = 75.0 + 25.0 * (pitch_median >= 200).to(dtype=dtype)
    hi = 250.0 + 150.0 * (pitch_median >= 200).to(dtype=dtype)
    pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1)

    # Shift the pitch within a half-octave range.
    shift_pitch = torch.exp2((1.0 - 2.0 * torch.rand(N)) / 4).unsqueeze(1).to(device, dtype)

    new_sid = np.random.choice(np.arange(len(speaker_info))[speaker_info > 0], size=N)

    rel_pitch = pitchf / pitch_median
    new_pitch_median = torch.from_numpy(speaker_info[new_sid]).to(device, dtype).unsqueeze(1) * shift_pitch
    new_pitchf = new_pitch_median * rel_pitch
    new_sid = torch.from_numpy(new_sid).to(device)

    new_pitch = 1127.0 * torch.log(1.0 + new_pitchf / 700.0)
    new_pitch = (new_pitch - f0_mel_min) * (f0_bin - 2.0) / (f0_mel_max - f0_mel_min) + 1.0
    new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int)

    aug_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0]
    aug_wave_16k = torchaudio.functional.resample(aug_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
    padding_mask = torch.arange(aug_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)

    inputs = {
        "source": aug_wave_16k.to(device, dtype),
        "padding_mask": padding_mask.to(device),
        "output_layer": embedding_output_layer,
    }

    logits = embedder.extract_features(**inputs)
    if phone.shape[-1] == 768:
        feats = logits[0]
    else:
        feats = embedder.final_proj(logits[0])
    feats = torch.repeat_interleave(feats, 2, 1)
    new_phone = torch.zeros(phone.shape).to(device, dtype)
    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
    return new_phone.to(device), aug_wave


def change_speaker_nono(net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths):
    """
    Randomly change the formant/speaker of a batch (no-f0 variant).
    Inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
    """
    N = phone.shape[0]
    device = phone.device
    dtype = phone.dtype

    new_sid = np.random.randint(net_g.spk_embed_dim, size=N)
    new_sid = torch.from_numpy(new_sid).to(device)

    aug_wave = net_g.infer(phone, phone_lengths, new_sid)[0]
    aug_wave_16k = torchaudio.functional.resample(aug_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
    padding_mask = torch.arange(aug_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)

    inputs = {
        "source": aug_wave_16k.to(device, dtype),
        "padding_mask": padding_mask.to(device),
        "output_layer": embedding_output_layer,
    }

    logits = embedder.extract_features(**inputs)
    if phone.shape[-1] == 768:
        feats = logits[0]
    else:
        feats = embedder.final_proj(logits[0])
    feats = torch.repeat_interleave(feats, 2, 1)
    new_phone = torch.zeros(phone.shape).to(device, dtype)
    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
    return new_phone.to(device), aug_wave
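
# Unlike change_speaker, this no-f0 variant has no pitch statistics to remap,
# so it simply resynthesizes each item with a random speaker embedding and
# re-extracts phone features from the augmented waveform.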


def train_index(
    training_dir: str,
    model_name: str,
    out_dir: str,
    emb_ch: int,
    num_cpu_process: int,
    maximum_index_size: Optional[int],
):
    checkpoint_path = os.path.join(out_dir, model_name)
    feature_256_dir = os.path.join(training_dir, "3_feature256")
    index_dir = os.path.join(os.path.dirname(checkpoint_path), f"{model_name}_index")
    os.makedirs(index_dir, exist_ok=True)
    for speaker_id in tqdm.tqdm(
        sorted([dir for dir in os.listdir(feature_256_dir) if dir.isdecimal()])
    ):
        feature_256_spk_dir = os.path.join(feature_256_dir, speaker_id)
        speaker_id = int(speaker_id)
        npys = []
        for name in [
            os.path.join(feature_256_spk_dir, file)
            for file in os.listdir(feature_256_spk_dir)
            if file.endswith(".npy")
        ]:
            phone = np.load(name)  # `name` is already a full path
            npys.append(phone)

        # Shuffle big_npy to prevent reproducing the sound source.
        big_npy = np.concatenate(npys, 0)
        big_npy_idx = np.arange(big_npy.shape[0])
        np.random.shuffle(big_npy_idx)
        big_npy = big_npy[big_npy_idx]

        if maximum_index_size is not None and big_npy.shape[0] > maximum_index_size:
            kmeans = MiniBatchKMeans(
                n_clusters=maximum_index_size,
                batch_size=256 * num_cpu_process,
                init="random",
                compute_labels=False,
            )
            kmeans.fit(big_npy)
            big_npy = kmeans.cluster_centers_

        # Recommended parameters from
        # https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
        emb_ch = big_npy.shape[1]
        emb_ch_half = emb_ch // 2
        n_ivf = int(8 * np.sqrt(big_npy.shape[0]))
        if big_npy.shape[0] >= 1_000_000:
            index = faiss.index_factory(emb_ch, f"IVF{n_ivf},PQ{emb_ch_half}x4fsr,RFlat")
        else:
            index = faiss.index_factory(emb_ch, f"IVF{n_ivf},Flat")

        index.train(big_npy)
        batch_size_add = 8192
        for i in range(0, big_npy.shape[0], batch_size_add):
            index.add(big_npy[i : i + batch_size_add])

        np.save(
            os.path.join(index_dir, f"{model_name}.{speaker_id}.big.npy"),
            big_npy,
        )
        faiss.write_index(
            index,
            os.path.join(index_dir, f"{model_name}.{speaker_id}.index"),
        )
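
# A minimal sketch (assumption, not part of this module) of how a saved index
# could be queried later with the standard faiss API:
#
#   index = faiss.read_index(os.path.join(index_dir, f"{model_name}.0.index"))
#   index.nprobe = 1                                  # inverted lists to probe
#   scores, ids = index.search(query.astype(np.float32), k=8)
#
# where `query` is an (n, emb_ch) array of embedder features for new audio.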


def train_model(
    gpus: List[int],
    config: TrainConfig,
    training_dir: str,
    model_name: str,
    out_dir: str,
    sample_rate: str,
    f0: bool,
    batch_size: int,
    augment: bool,
    augment_path: Optional[str],
    speaker_info_path: Optional[str],
    cache_batch: bool,
    total_epoch: int,
    save_every_epoch: int,
    save_wav_with_checkpoint: bool,
    pretrain_g: str,
    pretrain_d: str,
    embedder_name: str,
    embedding_output_layer: int,
    save_only_last: bool = False,
    device: Optional[Union[str, torch.device]] = None,
):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(utils.find_empty_port())

    deterministic = torch.backends.cudnn.deterministic
    benchmark = torch.backends.cudnn.benchmark
    PREV_CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None)

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False

    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in gpus])

    start = time.perf_counter()

    # On macOS (MPS), mp.spawn causes trouble, so call training_runner directly.
    if device is not None:
        training_runner(
            0,  # rank
            1,  # world size
            config,
            training_dir,
            model_name,
            out_dir,
            sample_rate,
            f0,
            batch_size,
            augment,
            augment_path,
            speaker_info_path,
            cache_batch,
            total_epoch,
            save_every_epoch,
            save_wav_with_checkpoint,
            pretrain_g,
            pretrain_d,
            embedder_name,
            embedding_output_layer,
            save_only_last,
            device,
        )
    else:
        mp.spawn(
            training_runner,
            nprocs=len(gpus),
            args=(
                len(gpus),
                config,
                training_dir,
                model_name,
                out_dir,
                sample_rate,
                f0,
                batch_size,
                augment,
                augment_path,
                speaker_info_path,
                cache_batch,
                total_epoch,
                save_every_epoch,
                save_wav_with_checkpoint,
                pretrain_g,
                pretrain_d,
                embedder_name,
                embedding_output_layer,
                save_only_last,
                device,
            ),
        )

    end = time.perf_counter()
    print(f"Time: {end - start}")

    if PREV_CUDA_VISIBLE_DEVICES is None:
        del os.environ["CUDA_VISIBLE_DEVICES"]
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = PREV_CUDA_VISIBLE_DEVICES

    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark


def training_runner(
    rank: int,
    world_size: int,
    config: TrainConfig,
    training_dir: str,
    model_name: str,
    out_dir: str,
    sample_rate: str,
    f0: bool,
    batch_size: int,
    augment: bool,
    augment_path: Optional[str],
    speaker_info_path: Optional[str],
    cache_in_gpu: bool,
    total_epoch: int,
    save_every_epoch: int,
    save_wav_with_checkpoint: bool,
    pretrain_g: str,
    pretrain_d: str,
    embedder_name: str,
    embedding_output_layer: int,
    save_only_last: bool = False,
    device: Optional[Union[str, torch.device]] = None,
):
    config.train.batch_size = batch_size

    log_dir = os.path.join(training_dir, "logs")
    state_dir = os.path.join(training_dir, "state")
    training_files_path = os.path.join(training_dir, "meta.json")
    training_meta = DatasetMetadata.parse_file(training_files_path)
    embedder_out_channels = config.model.emb_channels

    is_multi_process = world_size > 1

    if device is not None:
        if isinstance(device, str):
            device = torch.device(device)

    global_step = 0
    is_main_process = rank == 0

    if is_main_process:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(state_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo", init_method="env://", rank=rank, world_size=world_size
        )
    if is_multi_process:
        torch.cuda.set_device(rank)

    torch.manual_seed(config.train.seed)

    if f0:
        train_dataset = TextAudioLoaderMultiNSFsid(training_meta, config.data)
    else:
        train_dataset = TextAudioLoader(training_meta, config.data)

    train_sampler = DistributedBucketSampler(
        train_dataset,
        config.train.batch_size * world_size,
        [100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )

    if f0:
        collate_fn = TextAudioCollateMultiNSFsid()
    else:
        collate_fn = TextAudioCollate()

    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )

    speaker_info = None
    if os.path.exists(os.path.join(training_dir, "speaker_info.json")):
        with open(os.path.join(training_dir, "speaker_info.json"), "r") as f:
            speaker_info = json.load(f)
            config.model.spk_embed_dim = len(speaker_info)

    if f0:
        net_g = SynthesizerTrnMs256NSFSid(
            config.data.filter_length // 2 + 1,
            config.train.segment_size // config.data.hop_length,
            **config.model.dict(),
            is_half=False,  # config.train.fp16_run,
            sr=int(sample_rate[:-1] + "000"),
        )
    else:
        net_g = SynthesizerTrnMs256NSFSidNono(
            config.data.filter_length // 2 + 1,
            config.train.segment_size // config.data.hop_length,
            **config.model.dict(),
            is_half=False,  # config.train.fp16_run,
            sr=int(sample_rate[:-1] + "000"),
        )
    if is_multi_process:
        net_g = net_g.cuda(rank)
    else:
        net_g = net_g.to(device=device)

    if config.version == "v1":
        periods = [2, 3, 5, 7, 11, 17]
    elif config.version == "v2":
        periods = [2, 3, 5, 7, 11, 17, 23, 37]
    else:
        raise ValueError(f"Unknown model version: {config.version}")

    net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm, periods=periods)
    if is_multi_process:
        net_d = net_d.cuda(rank)
    else:
        net_d = net_d.to(device=device)

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        config.train.learning_rate,
        betas=config.train.betas,
        eps=config.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        config.train.learning_rate,
        betas=config.train.betas,
        eps=config.train.eps,
    )

    last_d_state = utils.latest_checkpoint_path(state_dir, "D_*.pth")
    last_g_state = utils.latest_checkpoint_path(state_dir, "G_*.pth")

    if last_d_state is None or last_g_state is None:
        epoch = 1
        global_step = 0
        if os.path.exists(pretrain_g) and os.path.exists(pretrain_d):
            net_g_state = torch.load(pretrain_g, map_location="cpu")["model"]
            emb_spk_size = (config.model.spk_embed_dim, config.model.gin_channels)
            emb_phone_size = (config.model.hidden_channels, config.model.emb_channels)
            if emb_spk_size != net_g_state["emb_g.weight"].size():
                # Broadcast the mean speaker embedding to the new speaker count.
                original_weight = net_g_state["emb_g.weight"]
                net_g_state["emb_g.weight"] = original_weight.mean(
                    dim=0, keepdim=True
                ) * torch.ones(
                    emb_spk_size,
                    device=original_weight.device,
                    dtype=original_weight.dtype,
                )
            if emb_phone_size != net_g_state["enc_p.emb_phone.weight"].size():
                # Interpolate the phone-embedding weight to the new shape.
                orig_shape = net_g_state["enc_p.emb_phone.weight"].size()
                if net_g_state["enc_p.emb_phone.weight"].dtype == torch.half:
                    net_g_state["enc_p.emb_phone.weight"] = (
                        F.interpolate(
                            net_g_state["enc_p.emb_phone.weight"]
                            .float()
                            .unsqueeze(0)
                            .unsqueeze(0),
                            size=emb_phone_size,
                            mode="bilinear",
                        )
                        .half()
                        .squeeze(0)
                        .squeeze(0)
                    )
                else:
                    net_g_state["enc_p.emb_phone.weight"] = (
                        F.interpolate(
                            net_g_state["enc_p.emb_phone.weight"]
                            .unsqueeze(0)
                            .unsqueeze(0),
                            size=emb_phone_size,
                            mode="bilinear",
                        )
                        .squeeze(0)
                        .squeeze(0)
                    )
                print(
                    "interpolated pretrained state enc_p.emb_phone from",
                    orig_shape,
                    "to",
                    emb_phone_size,
                )
            if is_multi_process:
                net_g.module.load_state_dict(net_g_state)
            else:
                net_g.load_state_dict(net_g_state)
            del net_g_state

            if is_multi_process:
                net_d.module.load_state_dict(
                    torch.load(pretrain_d, map_location="cpu")["model"]
                )
            else:
                net_d.load_state_dict(
                    torch.load(pretrain_d, map_location="cpu")["model"]
                )
            if is_main_process:
                print(f"loaded pretrained {pretrain_g} {pretrain_d}")
    else:
        _, _, _, epoch = utils.load_checkpoint(last_d_state, net_d, optim_d)
        _, _, _, epoch = utils.load_checkpoint(last_g_state, net_g, optim_g)
        if is_main_process:
            print(f"loaded last state {last_d_state} {last_g_state}")
        epoch += 1
        global_step = (epoch - 1) * len(train_loader)

    if augment:
        # Load the embedder used to re-extract features from augmented audio.
        embedder_filepath, _, embedder_load_from = get_embedder(embedder_name)

        if embedder_load_from == "local":
            embedder_filepath = os.path.join(
                MODELS_DIR, "embeddings", embedder_filepath
            )
        embedder, _ = load_embedder(embedder_filepath, device)
        if not config.train.fp16_run:
            embedder = embedder.float()

        if augment_path is not None:
            state_dict = torch.load(augment_path, map_location="cpu")
            if state_dict["f0"] == 1:
                augment_net_g = SynthesizerTrnMs256NSFSid(
                    **state_dict["params"], is_half=config.train.fp16_run
                )
                augment_speaker_info = np.load(speaker_info_path)
            else:
                augment_net_g = SynthesizerTrnMs256NSFSidNono(
                    **state_dict["params"], is_half=config.train.fp16_run
                )
            augment_net_g.load_state_dict(state_dict["weight"], strict=False)
            augment_net_g.eval().to(device)
        else:
            augment_net_g = net_g
            if f0:
                # Estimate a median f0 per speaker for pitch remapping.
                medians = [[] for _ in range(augment_net_g.spk_embed_dim)]
                for file in training_meta.files.values():
                    f0f = np.load(file.f0nsf)
                    if np.any(f0f > 0):
                        medians[file.speaker_id].append(np.median(f0f[f0f > 0]))
                augment_speaker_info = np.array(
                    [np.median(x) if len(x) else 0.0 for x in medians]
                )
                np.save(
                    os.path.join(training_dir, "speaker_info.npy"),
                    augment_speaker_info,
                )

    if is_multi_process:
        net_g = DDP(net_g, device_ids=[rank])
        net_d = DDP(net_d, device_ids=[rank])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=config.train.lr_decay, last_epoch=epoch - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=config.train.lr_decay, last_epoch=epoch - 2
    )

    scaler = GradScaler(enabled=config.train.fp16_run)

    cache = []
    progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader)))
    progress_bar.set_postfix(epoch=epoch)
    step = -1 + len(train_loader) * (epoch - 1)

    for epoch in range(epoch, total_epoch + 1):
        train_loader.batch_sampler.set_epoch(epoch)

        net_g.train()
        net_d.train()

        use_cache = len(cache) == len(train_loader)
        data = cache if use_cache else enumerate(train_loader)

        if is_main_process:
            lr = optim_g.param_groups[0]["lr"]

        if use_cache:
            shuffle(cache)

        for batch_idx, batch in data:
            step += 1
            progress_bar.update(1)
            if f0:
                (
                    phone,
                    phone_lengths,
                    pitch,
                    pitchf,
                    spec,
                    spec_lengths,
                    wave,
                    wave_lengths,
                    sid,
                ) = batch
            else:
                (
                    phone,
                    phone_lengths,
                    spec,
                    spec_lengths,
                    wave,
                    wave_lengths,
                    sid,
                ) = batch

            if not use_cache:
                phone, phone_lengths = (
                    phone.to(device=device, non_blocking=True),
                    phone_lengths.to(device=device, non_blocking=True),
                )
                if f0:
                    pitch, pitchf = (
                        pitch.to(device=device, non_blocking=True),
                        pitchf.to(device=device, non_blocking=True),
                    )
                sid = sid.to(device=device, non_blocking=True)
                spec, spec_lengths = (
                    spec.to(device=device, non_blocking=True),
                    spec_lengths.to(device=device, non_blocking=True),
                )
                wave, wave_lengths = (
                    wave.to(device=device, non_blocking=True),
                    wave_lengths.to(device=device, non_blocking=True),
                )
                if cache_in_gpu:
                    if f0:
                        cache.append(
                            (
                                batch_idx,
                                (
                                    phone,
                                    phone_lengths,
                                    pitch,
                                    pitchf,
                                    spec,
                                    spec_lengths,
                                    wave,
                                    wave_lengths,
                                    sid,
                                ),
                            )
                        )
                    else:
                        cache.append(
                            (
                                batch_idx,
                                (
                                    phone,
                                    phone_lengths,
                                    spec,
                                    spec_lengths,
                                    wave,
                                    wave_lengths,
                                    sid,
                                ),
                            )
                        )

            with autocast(enabled=config.train.fp16_run):
                if augment:
                    with torch.no_grad():
                        if type(augment_net_g) == SynthesizerTrnMs256NSFSid:
                            new_phone, aug_wave = change_speaker(
                                augment_net_g,
                                augment_speaker_info,
                                embedder,
                                embedding_output_layer,
                                phone,
                                phone_lengths,
                                pitch,
                                pitchf,
                                spec_lengths,
                            )
                        else:
                            new_phone, aug_wave = change_speaker_nono(
                                augment_net_g,
                                embedder,
                                embedding_output_layer,
                                phone,
                                phone_lengths,
                                spec_lengths,
                            )
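
                    # The blend below uses weight = 0.5 ** (step / len(train_loader)),
                    # which starts at 1.0 and halves each epoch: early training keeps
                    # the original phone features, later training leans increasingly
                    # on the augmented ones.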
                    weight = np.power(0.5, step / len(train_loader))  # use the raw phone embedding early in training
                    phone = phone * weight + new_phone * (1.0 - weight)
                if f0:
                    (
                        y_hat,
                        ids_slice,
                        x_mask,
                        z_mask,
                        (z, z_p, m_p, logs_p, m_q, logs_q),
                    ) = net_g(
                        phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
                    )
                else:
                    (
                        y_hat,
                        ids_slice,
                        x_mask,
                        z_mask,
                        (z, z_p, m_p, logs_p, m_q, logs_q),
                    ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
                mel = spec_to_mel_torch(
                    spec,
                    config.data.filter_length,
                    config.data.n_mel_channels,
                    config.data.sampling_rate,
                    config.data.mel_fmin,
                    config.data.mel_fmax,
                )
                y_mel = commons.slice_segments(
                    mel, ids_slice, config.train.segment_size // config.data.hop_length
                )
                with autocast(enabled=False):
                    y_hat_mel = mel_spectrogram_torch(
                        y_hat.float().squeeze(1),
                        config.data.filter_length,
                        config.data.n_mel_channels,
                        config.data.sampling_rate,
                        config.data.hop_length,
                        config.data.win_length,
                        config.data.mel_fmin,
                        config.data.mel_fmax,
                    )
                if config.train.fp16_run and device != torch.device("mps"):
                    y_hat_mel = y_hat_mel.half()
                wave_slice = commons.slice_segments(
                    wave, ids_slice * config.data.hop_length, config.train.segment_size
                )  # slice

                # Discriminator
                y_d_hat_r, y_d_hat_g, _, _ = net_d(wave_slice, y_hat.detach())
                with autocast(enabled=False):
                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                        y_d_hat_r, y_d_hat_g
                    )
            optim_d.zero_grad()
            scaler.scale(loss_disc).backward()
            scaler.unscale_(optim_d)
            grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
            scaler.step(optim_d)

            with autocast(enabled=config.train.fp16_run):
                # Generator
                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave_slice, y_hat)
                with autocast(enabled=False):
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
                    loss_kl = (
                        kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
                    )
                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
            scaler.unscale_(optim_g)
            grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
            scaler.step(optim_g)
            scaler.update()
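
            # Note: scaler.update() runs once per iteration, after both optimizer
            # steps, so the AMP loss scale is adjusted with knowledge of any
            # inf/NaN gradients from either backward pass.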
            if is_main_process:
                progress_bar.set_postfix(
                    epoch=epoch,
                    loss_g=float(loss_gen_all) if loss_gen_all is not None else 0.0,
                    loss_d=float(loss_disc) if loss_disc is not None else 0.0,
                    lr=float(lr) if lr is not None else 0.0,
                    use_cache=use_cache,
                )
                if global_step % config.train.log_interval == 0:
                    lr = optim_g.param_groups[0]["lr"]
                    # Clamp outliers so the TensorBoard curves stay readable.
                    if loss_mel > 50:
                        loss_mel = 50
                    if loss_kl > 5:
                        loss_kl = 5

                    scalar_dict = {
                        "loss/g/total": loss_gen_all,
                        "loss/d/total": loss_disc,
                        "learning_rate": lr,
                        "grad_norm_d": grad_norm_d,
                        "grad_norm_g": grad_norm_g,
                    }
                    scalar_dict.update(
                        {
                            "loss/g/fm": loss_fm,
                            "loss/g/mel": loss_mel,
                            "loss/g/kl": loss_kl,
                        }
                    )
                    scalar_dict.update(
                        {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                    )
                    scalar_dict.update(
                        {
                            "loss/d_r/{}".format(i): v
                            for i, v in enumerate(losses_disc_r)
                        }
                    )
                    scalar_dict.update(
                        {
                            "loss/d_g/{}".format(i): v
                            for i, v in enumerate(losses_disc_g)
                        }
                    )
                    image_dict = {
                        "slice/mel_org": utils.plot_spectrogram_to_numpy(
                            y_mel[0].data.cpu().numpy()
                        ),
                        "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                            y_hat_mel[0].data.cpu().numpy()
                        ),
                        "all/mel": utils.plot_spectrogram_to_numpy(
                            mel[0].data.cpu().numpy()
                        ),
                    }
                    utils.summarize(
                        writer=writer,
                        global_step=global_step,
                        images=image_dict,
                        scalars=scalar_dict,
                    )
            global_step += 1

        if is_main_process and save_every_epoch != 0 and epoch % save_every_epoch == 0:
            if save_only_last:
                old_g_path = os.path.join(state_dir, f"G_{epoch - save_every_epoch}.pth")
                old_d_path = os.path.join(state_dir, f"D_{epoch - save_every_epoch}.pth")
                old_wav_path = os.path.join(state_dir, f"wav_sample_{epoch - save_every_epoch}")
                if os.path.exists(old_g_path):
                    os.remove(old_g_path)
                if os.path.exists(old_d_path):
                    os.remove(old_d_path)
                if os.path.exists(old_wav_path):
                    shutil.rmtree(old_wav_path)
            if save_wav_with_checkpoint:
                with autocast(enabled=config.train.fp16_run):
                    with torch.no_grad():
                        if f0:
                            pred_wave = net_g.infer(phone, phone_lengths, pitch, pitchf, sid)[0]
                        else:
                            pred_wave = net_g.infer(phone, phone_lengths, sid)[0]
                os.makedirs(os.path.join(state_dir, f"wav_sample_{epoch}"), exist_ok=True)
                for i in range(pred_wave.shape[0]):
                    torchaudio.save(
                        filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_true.wav"),
                        src=wave[i].detach().cpu().float(),
                        sample_rate=int(sample_rate[:-1] + "000"),
                    )
                    torchaudio.save(
                        filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_pred.wav"),
                        src=pred_wave[i].detach().cpu().float(),
                        sample_rate=int(sample_rate[:-1] + "000"),
                    )
                    if augment:
                        torchaudio.save(
                            filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_aug.wav"),
                            src=aug_wave[i].detach().cpu().float(),
                            sample_rate=int(sample_rate[:-1] + "000"),
                        )
            utils.save_state(
                net_g,
                optim_g,
                config.train.learning_rate,
                epoch,
                os.path.join(state_dir, f"G_{epoch}.pth"),
            )
            utils.save_state(
                net_d,
                optim_d,
                config.train.learning_rate,
                epoch,
                os.path.join(state_dir, f"D_{epoch}.pth"),
            )
            save(
                net_g,
                config.version,
                sample_rate,
                f0,
                embedder_name,
                embedder_out_channels,
                embedding_output_layer,
                os.path.join(training_dir, "checkpoints", f"{model_name}-{epoch}.pth"),
                epoch,
                speaker_info,
            )

        scheduler_g.step()
        scheduler_d.step()

    if is_main_process:
        print("Training is done. The program is closing.")
        save(
            net_g,
            config.version,
            sample_rate,
            f0,
            embedder_name,
            embedder_out_channels,
            embedding_output_layer,
            os.path.join(out_dir, f"{model_name}.pth"),
            epoch,
            speaker_info,
        )
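

# A hypothetical invocation (values and paths are placeholders, assuming a
# TrainConfig has been built elsewhere in the package):
#
#   train_model(
#       gpus=[0],
#       config=config,
#       training_dir="training/my-model",
#       model_name="my-model",
#       out_dir="out",
#       sample_rate="40k",
#       f0=True,
#       batch_size=8,
#       augment=False,
#       augment_path=None,
#       speaker_info_path=None,
#       cache_batch=True,
#       total_epoch=100,
#       save_every_epoch=10,
#       save_wav_with_checkpoint=False,
#       pretrain_g="models/pretrained/G40k.pth",
#       pretrain_d="models/pretrained/D40k.pth",
#       embedder_name="hubert_base",
#       embedding_output_layer=12,
#   )
#
# followed by train_index(...) to build the per-speaker faiss indexes.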