import glob import json import operator import os import shutil import time from random import shuffle from typing import * import faiss import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp import torchaudio import tqdm from sklearn.cluster import MiniBatchKMeans from torch.cuda.amp import GradScaler, autocast from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from . import commons, utils from .checkpoints import save from .config import DatasetMetadata, TrainConfig from .data_utils import (DistributedBucketSampler, TextAudioCollate, TextAudioCollateMultiNSFsid, TextAudioLoader, TextAudioLoaderMultiNSFsid) from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid, SynthesizerTrnMs256NSFSidNono) from .preprocessing.extract_feature import (MODELS_DIR, get_embedder, load_embedder) def is_audio_file(file: str): if "." not in file: return False ext = os.path.splitext(file)[1] return ext.lower() in [ ".wav", ".flac", ".ogg", ".mp3", ".m4a", ".wma", ".aiff", ] def glob_dataset( glob_str: str, speaker_id: int, multiple_speakers: bool = False, recursive: bool = True, training_dir: str = ".", ): globs = glob_str.split(",") speaker_count = 0 datasets_speakers = [] speaker_to_id_mapping = {} for glob_str in globs: if os.path.isdir(glob_str): if multiple_speakers: # Multispeaker format: # dataset_path/ # - speakername/ # - {wav name here}.wav # - ... # - next_speakername/ # - {wav name here}.wav # - ... # - ... print("Multispeaker dataset enabled; Processing speakers.") for dir in tqdm.tqdm(os.listdir(glob_str)): print("Speaker ID " + str(speaker_count) + ": " + dir) speaker_to_id_mapping[dir] = speaker_count speaker_path = glob_str + "/" + dir for audio in tqdm.tqdm(os.listdir(speaker_path)): if is_audio_file(glob_str + "/" + dir + "/" + audio): datasets_speakers.append((glob_str + "/" + dir + "/" + audio, speaker_count)) speaker_count += 1 with open(os.path.join(training_dir, "speaker_info.json"), "w") as outfile: print("Dumped speaker info to {}".format(os.path.join(training_dir, "speaker_info.json"))) json.dump(speaker_to_id_mapping, outfile) continue # Skip the normal speaker extend glob_str = os.path.join(glob_str, "**", "*") print("Single speaker dataset enabled; Processing speaker as ID " + str(speaker_id) + ".") datasets_speakers.extend( [ (file, speaker_id) for file in glob.iglob(glob_str, recursive=recursive) if is_audio_file(file) ] ) return sorted(datasets_speakers) def create_dataset_meta(training_dir: str, f0: bool): gt_wavs_dir = os.path.join(training_dir, "0_gt_wavs") co256_dir = os.path.join(training_dir, "3_feature256") def list_data(dir: str): files = [] for subdir in os.listdir(dir): speaker_dir = os.path.join(dir, subdir) for name in os.listdir(speaker_dir): files.append(os.path.join(subdir, name.split(".")[0])) return files names = set(list_data(gt_wavs_dir)) & set(list_data(co256_dir)) if f0: f0_dir = os.path.join(training_dir, "2a_f0") f0nsf_dir = os.path.join(training_dir, "2b_f0nsf") names = names & set(list_data(f0_dir)) & set(list_data(f0nsf_dir)) meta = { "files": {}, } for name in names: speaker_id = os.path.dirname(name).split("_")[0] speaker_id = int(speaker_id) if speaker_id.isdecimal() else 0 if f0: gt_wav_path = os.path.join(gt_wavs_dir, f"{name}.wav") co256_path = os.path.join(co256_dir, f"{name}.npy") f0_path = os.path.join(f0_dir, f"{name}.wav.npy") f0nsf_path = os.path.join(f0nsf_dir, f"{name}.wav.npy") meta["files"][name] = { "gt_wav": gt_wav_path, "co256": co256_path, "f0": f0_path, "f0nsf": f0nsf_path, "speaker_id": speaker_id, } else: gt_wav_path = os.path.join(gt_wavs_dir, f"{name}.wav") co256_path = os.path.join(co256_dir, f"{name}.npy") meta["files"][name] = { "gt_wav": gt_wav_path, "co256": co256_path, "speaker_id": speaker_id, } with open(os.path.join(training_dir, "meta.json"), "w") as f: json.dump(meta, f, indent=2) def change_speaker(net_g, speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths): """ random change formant inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179 """ N = phone.shape[0] device = phone.device dtype = phone.dtype f0_bin = 256 f0_max = 1100.0 f0_min = 50.0 f0_mel_min = 1127 * np.log(1 + f0_min / 700) f0_mel_max = 1127 * np.log(1 + f0_max / 700) pitch_median = torch.median(pitchf, 1).values lo = 75. + 25. * (pitch_median >= 200).to(dtype=dtype) hi = 250. + 150. * (pitch_median >= 200).to(dtype=dtype) pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1) shift_pitch = torch.exp2((1. - 2. * torch.rand(N)) / 4).unsqueeze(1).to(device, dtype) # ピッチを半オクターブの範囲でずらす new_sid = np.random.choice(np.arange(len(speaker_info))[speaker_info > 0], size=N) rel_pitch = pitchf / pitch_median new_pitch_median = torch.from_numpy(speaker_info[new_sid]).to(device, dtype).unsqueeze(1) * shift_pitch new_pitchf = new_pitch_median * rel_pitch new_sid = torch.from_numpy(new_sid).to(device) new_pitch = 1127. * torch.log(1. + new_pitchf / 700.) new_pitch = (pitch - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1. new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int) aug_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0] aug_wave_16k = torchaudio.functional.resample(aug_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1) padding_mask = torch.arange(aug_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device) inputs = { "source": aug_wave_16k.to(device, dtype), "padding_mask": padding_mask.to(device), "output_layer": embedding_output_layer } logits = embedder.extract_features(**inputs) if phone.shape[-1] == 768: feats = logits[0] else: feats = embedder.final_proj(logits[0]) feats = torch.repeat_interleave(feats, 2, 1) new_phone = torch.zeros(phone.shape).to(device, dtype) new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]] return new_phone.to(device), aug_wave def change_speaker_nono(net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths): """ random change formant inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179 """ N = phone.shape[0] device = phone.device dtype = phone.dtype new_sid = np.random.randint(net_g.spk_embed_dim, size=N) new_sid = torch.from_numpy(new_sid).to(device) aug_wave = net_g.infer(phone, phone_lengths, new_sid)[0] aug_wave_16k = torchaudio.functional.resample(aug_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1) padding_mask = torch.arange(aug_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device) inputs = { "source": aug_wave_16k.to(device, dtype), "padding_mask": padding_mask.to(device), "output_layer": embedding_output_layer } logits = embedder.extract_features(**inputs) if phone.shape[-1] == 768: feats = logits[0] else: feats = embedder.final_proj(logits[0]) feats = torch.repeat_interleave(feats, 2, 1) new_phone = torch.zeros(phone.shape).to(device, dtype) new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]] return new_phone.to(device), aug_wave def train_index( training_dir: str, model_name: str, out_dir: str, emb_ch: int, num_cpu_process: int, maximum_index_size: Optional[int], ): checkpoint_path = os.path.join(out_dir, model_name) feature_256_dir = os.path.join(training_dir, "3_feature256") index_dir = os.path.join(os.path.dirname(checkpoint_path), f"{model_name}_index") os.makedirs(index_dir, exist_ok=True) for speaker_id in tqdm.tqdm( sorted([dir for dir in os.listdir(feature_256_dir) if dir.isdecimal()]) ): feature_256_spk_dir = os.path.join(feature_256_dir, speaker_id) speaker_id = int(speaker_id) npys = [] for name in [ os.path.join(feature_256_spk_dir, file) for file in os.listdir(feature_256_spk_dir) if file.endswith(".npy") ]: phone = np.load(os.path.join(feature_256_spk_dir, name)) npys.append(phone) # shuffle big_npy to prevent reproducing the sound source big_npy = np.concatenate(npys, 0) big_npy_idx = np.arange(big_npy.shape[0]) np.random.shuffle(big_npy_idx) big_npy = big_npy[big_npy_idx] if not maximum_index_size is None and big_npy.shape[0] > maximum_index_size: kmeans = MiniBatchKMeans( n_clusters=maximum_index_size, batch_size=256 * num_cpu_process, init="random", compute_labels=False, ) kmeans.fit(big_npy) big_npy = kmeans.cluster_centers_ # recommend parameter in https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index emb_ch = big_npy.shape[1] emb_ch_half = emb_ch // 2 n_ivf = int(8 * np.sqrt(big_npy.shape[0])) if big_npy.shape[0] >= 1_000_000: index = faiss.index_factory( emb_ch, f"IVF{n_ivf},PQ{emb_ch_half}x4fsr,RFlat" ) else: index = faiss.index_factory(emb_ch, f"IVF{n_ivf},Flat") index.train(big_npy) batch_size_add = 8192 for i in range(0, big_npy.shape[0], batch_size_add): index.add(big_npy[i : i + batch_size_add]) np.save( os.path.join(index_dir, f"{model_name}.{speaker_id}.big.npy"), big_npy, ) faiss.write_index( index, os.path.join(index_dir, f"{model_name}.{speaker_id}.index"), ) def train_model( gpus: List[int], config: TrainConfig, training_dir: str, model_name: str, out_dir: str, sample_rate: int, f0: bool, batch_size: int, augment: bool, augment_path: Optional[str], speaker_info_path: Optional[str], cache_batch: bool, total_epoch: int, save_every_epoch: int, save_wav_with_checkpoint: bool, pretrain_g: str, pretrain_d: str, embedder_name: str, embedding_output_layer: int, save_only_last: bool = False, device: Optional[Union[str, torch.device]] = None, ): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(utils.find_empty_port()) deterministic = torch.backends.cudnn.deterministic benchmark = torch.backends.cudnn.benchmark PREV_CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in gpus]) start = time.perf_counter() # Mac(MPS)でやると、mp.spawnでなんかトラブルが出るので普通にtraining_runnerを呼び出す。 if device is not None: training_runner( 0, # rank 1, # world size config, training_dir, model_name, out_dir, sample_rate, f0, batch_size, augment, augment_path, speaker_info_path, cache_batch, total_epoch, save_every_epoch, save_wav_with_checkpoint, pretrain_g, pretrain_d, embedder_name, embedding_output_layer, save_only_last, device, ) else: mp.spawn( training_runner, nprocs=len(gpus), args=( len(gpus), config, training_dir, model_name, out_dir, sample_rate, f0, batch_size, augment, augment_path, speaker_info_path, cache_batch, total_epoch, save_every_epoch, save_wav_with_checkpoint, pretrain_g, pretrain_d, embedder_name, embedding_output_layer, save_only_last, device, ), ) end = time.perf_counter() print(f"Time: {end - start}") if PREV_CUDA_VISIBLE_DEVICES is None: del os.environ["CUDA_VISIBLE_DEVICES"] else: os.environ["CUDA_VISIBLE_DEVICES"] = PREV_CUDA_VISIBLE_DEVICES torch.backends.cudnn.deterministic = deterministic torch.backends.cudnn.benchmark = benchmark def training_runner( rank: int, world_size: List[int], config: TrainConfig, training_dir: str, model_name: str, out_dir: str, sample_rate: int, f0: bool, batch_size: int, augment: bool, augment_path: Optional[str], speaker_info_path: Optional[str], cache_in_gpu: bool, total_epoch: int, save_every_epoch: int, save_wav_with_checkpoint: bool, pretrain_g: str, pretrain_d: str, embedder_name: str, embedding_output_layer: int, save_only_last: bool = False, device: Optional[Union[str, torch.device]] = None, ): config.train.batch_size = batch_size log_dir = os.path.join(training_dir, "logs") state_dir = os.path.join(training_dir, "state") training_files_path = os.path.join(training_dir, "meta.json") training_meta = DatasetMetadata.parse_file(training_files_path) embedder_out_channels = config.model.emb_channels is_multi_process = world_size > 1 if device is not None: if type(device) == str: device = torch.device(device) global_step = 0 is_main_process = rank == 0 if is_main_process: os.makedirs(log_dir, exist_ok=True) os.makedirs(state_dir, exist_ok=True) writer = SummaryWriter(log_dir=log_dir) if torch.cuda.is_available(): torch.cuda.empty_cache() if not dist.is_initialized(): dist.init_process_group( backend="gloo", init_method="env://", rank=rank, world_size=world_size ) if is_multi_process: torch.cuda.set_device(rank) torch.manual_seed(config.train.seed) if f0: train_dataset = TextAudioLoaderMultiNSFsid(training_meta, config.data) else: train_dataset = TextAudioLoader(training_meta, config.data) train_sampler = DistributedBucketSampler( train_dataset, config.train.batch_size * world_size, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=world_size, rank=rank, shuffle=True, ) if f0: collate_fn = TextAudioCollateMultiNSFsid() else: collate_fn = TextAudioCollate() train_loader = DataLoader( train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, prefetch_factor=8, ) speaker_info = None if os.path.exists(os.path.join(training_dir, "speaker_info.json")): with open(os.path.join(training_dir, "speaker_info.json"), "r") as f: speaker_info = json.load(f) config.model.spk_embed_dim = len(speaker_info) if f0: net_g = SynthesizerTrnMs256NSFSid( config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model.dict(), is_half=False, # config.train.fp16_run, sr=int(sample_rate[:-1] + "000"), ) else: net_g = SynthesizerTrnMs256NSFSidNono( config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model.dict(), is_half=False, # config.train.fp16_run, sr=int(sample_rate[:-1] + "000"), ) if is_multi_process: net_g = net_g.cuda(rank) else: net_g = net_g.to(device=device) if config.version == "v1": periods = [2, 3, 5, 7, 11, 17] elif config.version == "v2": periods = [2, 3, 5, 7, 11, 17, 23, 37] net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm, periods=periods) if is_multi_process: net_d = net_d.cuda(rank) else: net_d = net_d.to(device=device) optim_g = torch.optim.AdamW( net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps, ) optim_d = torch.optim.AdamW( net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps, ) last_d_state = utils.latest_checkpoint_path(state_dir, "D_*.pth") last_g_state = utils.latest_checkpoint_path(state_dir, "G_*.pth") if last_d_state is None or last_g_state is None: epoch = 1 global_step = 0 if os.path.exists(pretrain_g) and os.path.exists(pretrain_d): net_g_state = torch.load(pretrain_g, map_location="cpu")["model"] emb_spk_size = (config.model.spk_embed_dim, config.model.gin_channels) emb_phone_size = (config.model.hidden_channels, config.model.emb_channels) if emb_spk_size != net_g_state["emb_g.weight"].size(): original_weight = net_g_state["emb_g.weight"] net_g_state["emb_g.weight"] = original_weight.mean(dim=0, keepdims=True) * torch.ones(emb_spk_size, device=original_weight.device, dtype=original_weight.dtype) if emb_phone_size != net_g_state["enc_p.emb_phone.weight"].size(): # interpolate orig_shape = net_g_state["enc_p.emb_phone.weight"].size() if net_g_state["enc_p.emb_phone.weight"].dtype == torch.half: net_g_state["enc_p.emb_phone.weight"] = ( F.interpolate( net_g_state["enc_p.emb_phone.weight"] .float() .unsqueeze(0) .unsqueeze(0), size=emb_phone_size, mode="bilinear", ) .half() .squeeze(0) .squeeze(0) ) else: net_g_state["enc_p.emb_phone.weight"] = ( F.interpolate( net_g_state["enc_p.emb_phone.weight"] .unsqueeze(0) .unsqueeze(0), size=emb_phone_size, mode="bilinear", ) .squeeze(0) .squeeze(0) ) print( "interpolated pretrained state enc_p.emb_phone from", orig_shape, "to", emb_phone_size, ) if is_multi_process: net_g.module.load_state_dict(net_g_state) else: net_g.load_state_dict(net_g_state) del net_g_state if is_multi_process: net_d.module.load_state_dict( torch.load(pretrain_d, map_location="cpu")["model"] ) else: net_d.load_state_dict( torch.load(pretrain_d, map_location="cpu")["model"] ) if is_main_process: print(f"loaded pretrained {pretrain_g} {pretrain_d}") else: _, _, _, epoch = utils.load_checkpoint(last_d_state, net_d, optim_d) _, _, _, epoch = utils.load_checkpoint(last_g_state, net_g, optim_g) if is_main_process: print(f"loaded last state {last_d_state} {last_g_state}") epoch += 1 global_step = (epoch - 1) * len(train_loader) if augment: # load embedder embedder_filepath, _, embedder_load_from = get_embedder(embedder_name) if embedder_load_from == "local": embedder_filepath = os.path.join( MODELS_DIR, "embeddings", embedder_filepath ) embedder, _ = load_embedder(embedder_filepath, device) if not config.train.fp16_run: embedder = embedder.float() if (augment_path is not None): state_dict = torch.load(augment_path, map_location="cpu") if state_dict["f0"] == 1: augment_net_g = SynthesizerTrnMs256NSFSid( **state_dict["params"], is_half=config.train.fp16_run ) augment_speaker_info = np.load(speaker_info_path) else: augment_net_g = SynthesizerTrnMs256NSFSidNono( **state_dict["params"], is_half=config.train.fp16_run ) augment_net_g.load_state_dict(state_dict["weight"], strict=False) augment_net_g.eval().to(device) else: augment_net_g = net_g if f0: medians = [[] for _ in range(augment_net_g.spk_embed_dim)] for file in training_meta.files.values(): f0f = np.load(file.f0nsf) if np.any(f0f > 0): medians[file.speaker_id].append(np.median(f0f[f0f > 0])) augment_speaker_info = np.array([np.median(x) if len(x) else 0. for x in medians]) np.save(os.path.join(training_dir, "speaker_info.npy"), augment_speaker_info) if is_multi_process: net_g = DDP(net_g, device_ids=[rank]) net_d = DDP(net_d, device_ids=[rank]) scheduler_g = torch.optim.lr_scheduler.ExponentialLR( optim_g, gamma=config.train.lr_decay, last_epoch=epoch - 2 ) scheduler_d = torch.optim.lr_scheduler.ExponentialLR( optim_d, gamma=config.train.lr_decay, last_epoch=epoch - 2 ) scaler = GradScaler(enabled=config.train.fp16_run) cache = [] progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader))) progress_bar.set_postfix(epoch=epoch) step = -1 + len(train_loader) * (epoch - 1) for epoch in range(epoch, total_epoch + 1): train_loader.batch_sampler.set_epoch(epoch) net_g.train() net_d.train() use_cache = len(cache) == len(train_loader) data = cache if use_cache else enumerate(train_loader) if is_main_process: lr = optim_g.param_groups[0]["lr"] if use_cache: shuffle(cache) for batch_idx, batch in data: step += 1 progress_bar.update(1) if f0: ( phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid, ) = batch else: ( phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid, ) = batch if not use_cache: phone, phone_lengths = ( phone.to(device=device, non_blocking=True), phone_lengths.to(device=device, non_blocking=True), ) if f0: pitch, pitchf = ( pitch.to(device=device, non_blocking=True), pitchf.to(device=device, non_blocking=True), ) sid = sid.to(device=device, non_blocking=True) spec, spec_lengths = ( spec.to(device=device, non_blocking=True), spec_lengths.to(device=device, non_blocking=True), ) wave, wave_lengths = ( wave.to(device=device, non_blocking=True), wave_lengths.to(device=device, non_blocking=True), ) if cache_in_gpu: if f0: cache.append( ( batch_idx, ( phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid, ), ) ) else: cache.append( ( batch_idx, ( phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid, ), ) ) with autocast(enabled=config.train.fp16_run): if augment: with torch.no_grad(): if type(augment_net_g) == SynthesizerTrnMs256NSFSid: new_phone, aug_wave = change_speaker(augment_net_g, augment_speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths) else: new_phone, aug_wave = change_speaker_nono(augment_net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths) weight = np.power(.5, step / len(train_loader)) # 学習の初期はそのままのphone embeddingを使う phone = phone * weight + new_phone * (1. - weight) if f0: ( y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), ) = net_g( phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid ) else: ( y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), ) = net_g(phone, phone_lengths, spec, spec_lengths, sid) mel = spec_to_mel_torch( spec, config.data.filter_length, config.data.n_mel_channels, config.data.sampling_rate, config.data.mel_fmin, config.data.mel_fmax, ) y_mel = commons.slice_segments( mel, ids_slice, config.train.segment_size // config.data.hop_length ) with autocast(enabled=False): y_hat_mel = mel_spectrogram_torch( y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sampling_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax, ) if config.train.fp16_run == True and device != torch.device("mps"): y_hat_mel = y_hat_mel.half() wave_slice = commons.slice_segments( wave, ids_slice * config.data.hop_length, config.train.segment_size ) # slice # Discriminator y_d_hat_r, y_d_hat_g, _, _ = net_d(wave_slice, y_hat.detach()) with autocast(enabled=False): loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( y_d_hat_r, y_d_hat_g ) optim_d.zero_grad() scaler.scale(loss_disc).backward() scaler.unscale_(optim_d) grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) scaler.step(optim_d) with autocast(enabled=config.train.fp16_run): # Generator y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave_slice, y_hat) with autocast(enabled=False): loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel loss_kl = ( kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl ) loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g) grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) scaler.step(optim_g) scaler.update() if is_main_process: progress_bar.set_postfix( epoch=epoch, loss_g=float(loss_gen_all) if loss_gen_all is not None else 0.0, loss_d=float(loss_disc) if loss_disc is not None else 0.0, lr=float(lr) if lr is not None else 0.0, use_cache=use_cache, ) if global_step % config.train.log_interval == 0: lr = optim_g.param_groups[0]["lr"] # Amor For Tensorboard display if loss_mel > 50: loss_mel = 50 if loss_kl > 5: loss_kl = 5 scalar_dict = { "loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g, } scalar_dict.update( { "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, } ) scalar_dict.update( {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} ) scalar_dict.update( { "loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r) } ) scalar_dict.update( { "loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g) } ) image_dict = { "slice/mel_org": utils.plot_spectrogram_to_numpy( y_mel[0].data.cpu().numpy() ), "slice/mel_gen": utils.plot_spectrogram_to_numpy( y_hat_mel[0].data.cpu().numpy() ), "all/mel": utils.plot_spectrogram_to_numpy( mel[0].data.cpu().numpy() ), } utils.summarize( writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict, ) global_step += 1 if is_main_process and save_every_epoch != 0 and epoch % save_every_epoch == 0: if save_only_last: old_g_path = os.path.join( state_dir, f"G_{epoch - save_every_epoch}.pth" ) old_d_path = os.path.join( state_dir, f"D_{epoch - save_every_epoch}.pth" ) old_wav_path = os.path.join( state_dir, f"wav_sample_{epoch - save_every_epoch}" ) if os.path.exists(old_g_path): os.remove(old_g_path) if os.path.exists(old_d_path): os.remove(old_d_path) if os.path.exists(old_wav_path): shutil.rmtree(old_wav_path) if save_wav_with_checkpoint: with autocast(enabled=config.train.fp16_run): with torch.no_grad(): if f0: pred_wave = net_g.infer(phone, phone_lengths, pitch, pitchf, sid)[0] else: pred_wave = net_g.infer(phone, phone_lengths, sid)[0] os.makedirs(os.path.join(state_dir, f"wav_sample_{epoch}"), exist_ok=True) for i in range(pred_wave.shape[0]): torchaudio.save(filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_true.wav"), src=wave[i].detach().cpu().float(), sample_rate=int(sample_rate[:-1] + "000")) torchaudio.save(filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_pred.wav"), src=pred_wave[i].detach().cpu().float(), sample_rate=int(sample_rate[:-1] + "000")) if augment: torchaudio.save(filepath=os.path.join(state_dir, f"wav_sample_{epoch}", f"{i:02}_y_aug.wav"), src=aug_wave[i].detach().cpu().float(), sample_rate=int(sample_rate[:-1] + "000")) utils.save_state( net_g, optim_g, config.train.learning_rate, epoch, os.path.join(state_dir, f"G_{epoch}.pth"), ) utils.save_state( net_d, optim_d, config.train.learning_rate, epoch, os.path.join(state_dir, f"D_{epoch}.pth"), ) save( net_g, config.version, sample_rate, f0, embedder_name, embedder_out_channels, embedding_output_layer, os.path.join(training_dir, "checkpoints", f"{model_name}-{epoch}.pth"), epoch, speaker_info ) scheduler_g.step() scheduler_d.step() if is_main_process: print("Training is done. The program is closed.") save( net_g, config.version, sample_rate, f0, embedder_name, embedder_out_channels, embedding_output_layer, os.path.join(out_dir, f"{model_name}.pth"), epoch, speaker_info )