import os
import datetime
import glob
import itertools
import json
import math
import re

import subprocess
import sys
import warnings

pid_data = {"process_pids": []}

# libuv-based rendezvous is problematic on Windows, so fall back to the legacy TCP store there.
os.environ["USE_LIBUV"] = "0" if sys.platform == "win32" else "1"

from typing import Tuple
from collections import deque
from distutils.util import strtobool
from random import randint, shuffle
from time import time as ttime, sleep

from tqdm import TqdmExperimentalWarning
from tqdm.rich import trange, tqdm
from pesq import pesq
import numpy as np
import psutil

warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

import torch
import torch.nn as nn
import torchaudio
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torch.amp import autocast
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
import torch.multiprocessing as mp
import auraloss

now_dir = os.getcwd()
sys.path.append(now_dir)

import rvc.lib.zluda

from utils import (
    HParams,
    plot_spectrogram_to_numpy,
    summarize,
    load_checkpoint,
    save_checkpoint,
    latest_checkpoint_path,
    load_wav_to_torch,
    load_config_from_json,
    mel_spec_similarity,
    flush_writer,
    block_tensorboard_flush_on_exit,
    si_sdr,
    wave_to_mel,
    small_model_naming,
    old_session_cleanup,
    verify_remap_checkpoint,
    print_init_setup,
    train_loader_safety,
    verify_spk_dim,
)
from losses import (
    discriminator_loss,
    generator_loss,
    feature_loss,
    kl_loss,
    phase_loss,
)
from mel_processing import (
    spec_to_mel_torch,
    MultiScaleMelSpectrogramLoss,
)
from rvc.train.process.extract_model import extract_model
from rvc.lib.algorithm import commons
from rvc.train.utils import replace_keys_in_dict

# Positional CLI arguments, in the order supplied by the launcher.
model_name = sys.argv[1]
epoch_save_frequency = int(sys.argv[2])
total_epoch_count = int(sys.argv[3])
pretrainG = sys.argv[4]
pretrainD = sys.argv[5]
gpus = sys.argv[6]
batch_size = int(sys.argv[7])
sample_rate = int(sys.argv[8])
save_only_latest_net_models = strtobool(sys.argv[9])
save_weight_models = strtobool(sys.argv[10])
cache_data_in_gpu = strtobool(sys.argv[11])
use_warmup = strtobool(sys.argv[12])
warmup_duration = int(sys.argv[13])
cleanup = strtobool(sys.argv[14])
vocoder = sys.argv[15]
architecture = sys.argv[16]
optimizer_choice = sys.argv[17]
use_checkpointing = strtobool(sys.argv[18])
use_tf32 = bool(strtobool(sys.argv[19]))
use_benchmark = bool(strtobool(sys.argv[20]))
use_deterministic = bool(strtobool(sys.argv[21]))
spectral_loss = sys.argv[22]
lr_scheduler = sys.argv[23]
exp_decay_gamma = float(sys.argv[24])
use_validation = strtobool(sys.argv[25])
double_d_update = strtobool(sys.argv[26])
use_custom_lr = strtobool(sys.argv[27])
custom_lr_g, custom_lr_d = (
    (float(sys.argv[28]), float(sys.argv[29])) if use_custom_lr else (None, None)
)
assert not use_custom_lr or (custom_lr_g and custom_lr_d), "Invalid custom LR values."
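
# Illustrative launch (script name and every value below are examples only;
# the argument order matches the parsing above):
#   python train.py MyVoice 5 300 pretrained/G_48k.pth pretrained/D_48k.pth \
#       0 8 48000 1 1 0 1 5 0 HiFi-GAN RVC AdamW 0 1 1 0 \
#       "Multi-Scale Mel Loss" "cosine annealing" 0.999 1 0 0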

current_dir = os.getcwd()
experiment_dir = os.path.join(current_dir, "logs", model_name)
config_save_path = os.path.join(experiment_dir, "config.json")
dataset_path = os.path.join(experiment_dir, "sliced_audios")
model_info_path = os.path.join(experiment_dir, "model_info.json")

config = load_config_from_json(config_save_path)
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")

# Training precision: bf16 takes priority over fp16; fp32 is the fallback.
if config.train.bf16_run:
    train_dtype = torch.bfloat16
elif config.train.fp16_run:
    train_dtype = torch.float16
else:
    train_dtype = torch.float32

global_step = 0
d_updates_per_step = 2 if double_d_update else 1
warmup_completed = False
from_scratch = False
use_lr_scheduler = lr_scheduler != "none"

torch.backends.cuda.matmul.allow_tf32 = use_tf32
torch.backends.cudnn.allow_tf32 = use_tf32
torch.backends.cudnn.benchmark = use_benchmark
torch.backends.cudnn.deterministic = use_deterministic

randomized = False
benchmark_mode = True
enable_persistent_workers = True
debug_shapes = False

# Weight of the Multi-Resolution STFT loss term.
c_stft = 21.0

import logging

logging.getLogger("torch").setLevel(logging.ERROR)


class EpochRecorder:
    """
    Records the time elapsed per epoch.
    """

    def __init__(self):
        self.last_time = ttime()

    def record(self):
        """
        Records the elapsed time and returns a formatted string.
        """
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time = round(elapsed_time, 1)
        elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        current_time = datetime.datetime.now().strftime("%H:%M:%S")
        return f"Current time: {current_time} | Time per epoch: {elapsed_time_str}"
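
# Example usage of EpochRecorder:
#   recorder = EpochRecorder()
#   ...train one epoch...
#   print(recorder.record())  # -> "Current time: 14:05:32 | Time per epoch: 0:01:47"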


def setup_env_and_distr(rank, n_gpus, device, device_id, config):
    # Only rank 0 writes TensorBoard events; the huge flush interval plus the
    # exit hook means flushing is done manually via flush_writer().
    if rank == 0:
        writer_eval = SummaryWriter(
            log_dir=os.path.join(experiment_dir, "eval"),
            flush_secs=86400,
        )
        block_tensorboard_flush_on_exit(writer_eval)
    else:
        writer_eval = None

    # NCCL is CUDA-only and unavailable on Windows, so fall back to Gloo there.
    dist.init_process_group(
        backend="gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl",
        init_method="env://",
        world_size=n_gpus if device.type == "cuda" else 1,
        rank=rank if device.type == "cuda" else 0,
    )

    torch.manual_seed(config.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(device_id)

    return writer_eval


def prepare_dataloaders(config, n_gpus, rank, batch_size, use_validation, benchmark_mode):
    from data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid,
    )

    if not benchmark_mode and use_validation:
        # 90/10 train/validation split, seeded so every rank splits identically.
        full_dataset = TextAudioLoaderMultiNSFsid(config.data)
        train_len = int(0.90 * len(full_dataset))
        val_len = len(full_dataset) - train_len
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset,
            [train_len, val_len],
            generator=torch.Generator().manual_seed(config.train.seed),
        )
        # The bucket sampler needs per-item lengths, which random_split hides
        # behind a Subset; re-expose them for both splits.
        train_dataset.lengths = [full_dataset.lengths[i] for i in train_dataset.indices]
        val_dataset.lengths = [full_dataset.lengths[i] for i in val_dataset.indices]
    else:
        train_dataset = TextAudioLoaderMultiNSFsid(config.data)
        val_dataset = None

    train_sampler = DistributedBucketSampler(
        train_dataset,
        batch_size * n_gpus,
        [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )

    collate_fn = TextAudioCollateMultiNSFsid()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=enable_persistent_workers,
        prefetch_factor=8,
    )
    val_loader = None
    if val_dataset:
        val_sampler = DistributedBucketSampler(
            val_dataset,
            batch_size * n_gpus,
            [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
            num_replicas=n_gpus,
            rank=rank,
            shuffle=False,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_sampler=val_sampler,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=1,
            pin_memory=True,
        )

    train_loader_safety(benchmark_mode, train_loader)

    return train_loader, val_loader


def get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized):
    from rvc.lib.algorithm.synthesizers import Synthesizer

    return Synthesizer(
        config.data.filter_length // 2 + 1,
        config.train.segment_size // config.data.hop_length,
        **config.model,
        use_f0=True,
        sr=sample_rate,
        vocoder=vocoder,
        checkpointing=use_checkpointing,
        randomized=randomized,
    )


def get_d_model(config, vocoder, use_checkpointing):
    # RingFormer training adds a multi-resolution discriminator (MRD) on top
    # of the usual multi-period (MPD) and multi-scale (MSD) discriminators.
    if vocoder == "RingFormer":
        from rvc.lib.algorithm.discriminators.multi import MPD_MSD_MRD_Combined

        return MPD_MSD_MRD_Combined(
            config.model.use_spectral_norm,
            use_checkpointing=use_checkpointing,
            **dict(config.mrd),
        )
    else:
        from rvc.lib.algorithm.discriminators.multi import MPD_MSD_Combined

        return MPD_MSD_Combined(
            config.model.use_spectral_norm,
            use_checkpointing=use_checkpointing,
        )
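
# All optimizer variants below share the same Adam-style hyperparameters
# (betas=(0.8, 0.99), eps=1e-9), as in the VITS/RVC training recipe; only the
# implementation and a few extras (Kahan summation, lookahead, adaptive step
# size) differ between choices.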
def get_optimizers(
    net_g,
    net_d,
    config,
    optimizer_choice,
    custom_lr_g,
    custom_lr_d,
    use_custom_lr,
    total_epoch_count,
    train_loader,
):
    common_args_g = dict(
        lr=custom_lr_g if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
    )
    common_args_d = dict(
        lr=custom_lr_d if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
    )
    # The bf16-native AdamW variant additionally uses Kahan summation to
    # compensate for bf16's reduced mantissa precision.
    common_args_g_bf16 = dict(
        lr=custom_lr_g if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0.0,
        use_kahan_summation=True,
    )
    common_args_d_bf16 = dict(
        lr=custom_lr_d if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0.0,
        use_kahan_summation=True,
    )

    if optimizer_choice == "Ranger21":
        from rvc.train.custom_optimizers.ranger21 import Ranger21

        ranger_args = dict(
            num_epochs=total_epoch_count,
            num_batches_per_epoch=len(train_loader),
            use_madgrad=False,
            use_warmup=False,
            warmdown_active=False,
            use_cheb=False,
            lookahead_active=True,
            normloss_active=False,
            normloss_factor=1e-4,
            softplus=False,
            use_adaptive_gradient_clipping=True,
            agc_clipping_value=0.01,
            agc_eps=1e-3,
            using_gc=True,
            gc_conv_only=True,
            using_normgc=False,
        )
        optim_g = Ranger21(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g, **ranger_args)
        optim_d = Ranger21(net_d.parameters(), **common_args_d, **ranger_args)

    elif optimizer_choice == "RAdam":
        import torch_optimizer

        optim_g = torch_optimizer.RAdam(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = torch_optimizer.RAdam(net_d.parameters(), **common_args_d)

    elif optimizer_choice == "AdamW":
        optim_g = torch.optim.AdamW(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = torch.optim.AdamW(net_d.parameters(), **common_args_d)

    elif optimizer_choice == "AdamW_BF16":
        from rvc.train.custom_optimizers.adamw_bfloat import BFF_AdamW

        optim_g = BFF_AdamW(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g_bf16)
        optim_d = BFF_AdamW(net_d.parameters(), **common_args_d_bf16)

    elif optimizer_choice == "Prodigy":
        from rvc.train.custom_optimizers.prodigy import Prodigy

        # Prodigy estimates the step size itself; lr acts as a multiplier,
        # so it defaults to 1.0 unless a custom LR is requested.
        prodigy_args = dict(
            betas=(0.8, 0.99),
            weight_decay=0.0,
            decouple=True,
        )
        optim_g = Prodigy(filter(lambda p: p.requires_grad, net_g.parameters()), lr=custom_lr_g if use_custom_lr else 1.0, **prodigy_args)
        optim_d = Prodigy(net_d.parameters(), lr=custom_lr_d if use_custom_lr else 1.0, **prodigy_args)

    elif optimizer_choice == "DiffGrad":
        from rvc.train.custom_optimizers.diffgrad import diffgrad

        optim_g = diffgrad(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = diffgrad(net_d.parameters(), **common_args_d)

    else:
        raise ValueError(f"Unknown optimizer choice: {optimizer_choice}")

    return optim_g, optim_d


def setup_models_for_training(net_g, net_d, device, device_id, n_gpus):
    net_g = net_g.to(device_id) if device.type == "cuda" else net_g.to(device)
    net_d = net_d.to(device_id) if device.type == "cuda" else net_d.to(device)
    # Wrap in DDP only for true multi-GPU runs; single-GPU and CPU skip it.
    if n_gpus > 1 and device.type == "cuda":
        net_g = DDP(net_g, device_ids=[device_id])
        net_d = DDP(net_d, device_ids=[device_id])

    return net_g, net_d


def load_models_and_optimizers(
    config,
    pretrainG,
    pretrainD,
    vocoder,
    use_checkpointing,
    randomized,
    sample_rate,
    optimizer_choice,
    custom_lr_g,
    custom_lr_d,
    use_custom_lr,
    total_epoch_count,
    train_loader,
    device,
    device_id,
    n_gpus,
    rank,
):
    try:
        print(" ██████ Starting the training ...")

        g_checkpoint_path = latest_checkpoint_path(experiment_dir, "G_*.pth")
        d_checkpoint_path = latest_checkpoint_path(experiment_dir, "D_*.pth")

        if g_checkpoint_path and d_checkpoint_path:
            # Resume: rebuild the models and optimizers, then restore both states.
            net_g = get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized)
            net_d = get_d_model(config, vocoder, use_checkpointing)

            optim_g, optim_d = get_optimizers(
                net_g, net_d, config, optimizer_choice, custom_lr_g, custom_lr_d,
                use_custom_lr, total_epoch_count, train_loader,
            )

            net_g, net_d = setup_models_for_training(net_g, net_d, device, device_id, n_gpus)

            _, _, _, epoch_str = load_checkpoint(architecture, g_checkpoint_path, net_g, optim_g)
            _, _, _, epoch_str = load_checkpoint(architecture, d_checkpoint_path, net_d, optim_d)

            epoch_str += 1
            global_step = (epoch_str - 1) * len(train_loader)
            print(f"[RESUMING] (G) & (D) at global_step: {global_step} and epoch count: {epoch_str - 1}")
        else:
            raise FileNotFoundError("No checkpoints found.")

    except FileNotFoundError:
        # Fresh run: start at epoch 1 and optionally load pretrained weights.
        epoch_str = 1
        global_step = 0

        net_g = get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized)
        net_d = get_d_model(config, vocoder, use_checkpointing)

        if pretrainG != "" and pretrainG != "None":
            if rank == 0:
                print(f"Loading pretrained (G) '{pretrainG}'")
            verify_remap_checkpoint(pretrainG, net_g, architecture)

        if pretrainD != "" and pretrainD != "None":
            if rank == 0:
                print(f"Loading pretrained (D) '{pretrainD}'")
            verify_remap_checkpoint(pretrainD, net_d, architecture)

        net_g, net_d = setup_models_for_training(net_g, net_d, device, device_id, n_gpus)

        optim_g, optim_d = get_optimizers(
            net_g, net_d, config, optimizer_choice, custom_lr_g, custom_lr_d,
            use_custom_lr, total_epoch_count, train_loader,
        )

    return net_g, net_d, optim_g, optim_d, epoch_str, global_step
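
# Warmup sketch: with LambdaLR the effective LR at epoch e (0-based) is
#   lr(e) = base_lr * min(1, (e + 1) / warmup_duration)
# i.e. a linear ramp that reaches the base LR after `warmup_duration` epochs,
# after which the decay scheduler (if any) takes over.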
def prepare_schedulers(
    optim_g,
    optim_d,
    use_warmup,
    warmup_duration,
    use_lr_scheduler,
    lr_scheduler,
    exp_decay_gamma,
    total_epoch_count,
    epoch_str,
):
    warmup_scheduler_g, warmup_scheduler_d = None, None
    scheduler_g, scheduler_d = None, None

    if use_warmup:
        warmup_scheduler_g = torch.optim.lr_scheduler.LambdaLR(
            optim_g, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_duration)
        )
        warmup_scheduler_d = torch.optim.lr_scheduler.LambdaLR(
            optim_d, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_duration)
        )

    if not use_warmup:
        # The decay schedulers resume from `last_epoch`, which requires an
        # 'initial_lr' entry in every param group.
        for param_group in optim_g.param_groups:
            if "initial_lr" not in param_group:
                param_group["initial_lr"] = param_group["lr"]
        for param_group in optim_d.param_groups:
            if "initial_lr" not in param_group:
                param_group["initial_lr"] = param_group["lr"]

    if use_lr_scheduler:
        if lr_scheduler == "exp decay":
            scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=exp_decay_gamma, last_epoch=epoch_str - 1)
            scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=exp_decay_gamma, last_epoch=epoch_str - 1)
        elif lr_scheduler == "cosine annealing":
            scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optim_g, T_max=total_epoch_count, eta_min=3e-5, last_epoch=epoch_str - 1)
            scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optim_d, T_max=total_epoch_count, eta_min=3e-5, last_epoch=epoch_str - 1)

    return warmup_scheduler_g, warmup_scheduler_d, scheduler_g, scheduler_d
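
# A custom audio-preview reference can be supplied by placing three arrays in
# logs/reference/: ref_feats.npy (content features), ref_f0c.npy (coarse f0)
# and ref_f0f.npy (continuous f0). Otherwise the first training batch is used.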
def get_reference_sample(train_loader, device, config):
    reference_path = os.path.join("logs", "reference")
    use_custom_ref = all([
        os.path.isfile(os.path.join(reference_path, "ref_feats.npy")),
        os.path.isfile(os.path.join(reference_path, "ref_f0c.npy")),
        os.path.isfile(os.path.join(reference_path, "ref_f0f.npy")),
    ])

    if use_custom_ref:
        print("[REFERENCE] Using custom reference input from 'logs/reference/'")
        # Content features run at half the f0 frame rate; repeat each feature
        # frame twice to align the sequences before trimming to a common length.
        phone = torch.FloatTensor(np.repeat(np.load(os.path.join(reference_path, "ref_feats.npy")), 2, axis=0)).unsqueeze(0).to(device)
        pitch = torch.LongTensor(np.load(os.path.join(reference_path, "ref_f0c.npy"))).unsqueeze(0).to(device)
        pitchf = torch.FloatTensor(np.load(os.path.join(reference_path, "ref_f0f.npy"))).unsqueeze(0).to(device)

        min_len = min(phone.shape[1], pitch.shape[1], pitchf.shape[1])
        phone, pitch, pitchf = phone[:, :min_len, :], pitch[:, :min_len], pitchf[:, :min_len]
        phone_lengths = torch.LongTensor([phone.shape[1]]).to(device)
        sid = torch.LongTensor([0]).to(device)
    else:
        print("[REFERENCE] No custom reference found. Fetching from the first batch of the train_loader.")
        info = next(iter(train_loader))
        phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
        phone, phone_lengths, pitch, pitchf, sid = phone.to(device), phone_lengths.to(device), pitch.to(device), pitchf.to(device), sid.to(device)

        batch_indices = []
        for batch in train_loader.batch_sampler:
            batch_indices = batch
            break

        if isinstance(train_loader.dataset, torch.utils.data.Subset):
            file_paths = train_loader.dataset.dataset.get_file_paths(batch_indices)
        else:
            file_paths = train_loader.dataset.get_file_paths(batch_indices)

        file_name = os.path.basename(file_paths[0])
        print(f"[REFERENCE] Origin of the ref: {file_name}")

    return (phone, phone_lengths, pitch, pitchf, sid, config.train.seed)


def main():
    """
    Main function to start the training process.
    """
    global gpus

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    # Sanity check: the dataset audio must match the selected sample rate.
    wavs = glob.glob(os.path.join(dataset_path, "*.wav"))
    if wavs:
        _, sr = load_wav_to_torch(wavs[0])
        if sr != sample_rate:
            print(f"Error: The selected sample rate ({sample_rate} Hz) does not match the dataset audio sample rate ({sr} Hz).")
            os._exit(1)
    else:
        print("No wav file found.")

    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpus = [int(item) for item in gpus.split("-")]
        n_gpus = len(gpus)
    else:
        device = torch.device("cpu")
        gpus = [0]
        n_gpus = 1
        print("No GPU detected, fallback to CPU. This will take a very long time ...")

    def start():
        """
        Starts the training process with multi-GPU support or CPU.
        """
        children = []

        for rank, device_id in enumerate(gpus):
            subproc = mp.Process(
                target=run,
                args=(
                    rank,
                    n_gpus,
                    experiment_dir,
                    pretrainG,
                    pretrainD,
                    total_epoch_count,
                    epoch_save_frequency,
                    save_weight_models,
                    save_only_latest_net_models,
                    config,
                    device,
                    device_id,
                ),
            )
            children.append(subproc)
            subproc.start()
            pid_data["process_pids"].append(subproc.pid)

        for child in children:
            child.join()

    if cleanup:
        old_session_cleanup(now_dir, model_name)
    start()
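
# Each worker process created in start() rendezvouses over TCP using the
# MASTER_ADDR/MASTER_PORT values exported in main() (init_method="env://"),
# with one process per entry in `gpus`.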
def run(
    rank,
    n_gpus,
    experiment_dir,
    pretrainG,
    pretrainD,
    total_epoch_count,
    epoch_save_frequency,
    save_weight_models,
    save_only_latest_net_models,
    config,
    device,
    device_id,
):
    """
    Runs the training loop on a specific GPU or CPU.

    Args:
        rank (int): The rank of the current process within the distributed training setup.
        n_gpus (int): The total number of GPUs available for training.
        experiment_dir (str): The directory where experiment logs and checkpoints will be saved.
        pretrainG (str): Path to the pre-trained generator model.
        pretrainD (str): Path to the pre-trained discriminator model.
        total_epoch_count (int): The total number of epochs for training.
        epoch_save_frequency (int): How often (in epochs) checkpoints are saved.
        save_weight_models (int): Whether to save small weight models. 0 for no, 1 for yes.
        save_only_latest_net_models (int): Whether to save only the latest G/D or one pair per epoch.
        config (object): Configuration object containing training parameters.
        device (torch.device): The device to use for training (CPU or GPU).
        device_id (int): Index of the CUDA device assigned to this process.
    """
    global global_step, warmup_completed, optimizer_choice, from_scratch

    if "warmup_completed" not in globals():
        warmup_completed = False

    print_init_setup(
        warmup_duration,
        rank,
        use_warmup,
        config,
        optimizer_choice,
        d_updates_per_step,
        use_validation,
        lr_scheduler,
        exp_decay_gamma,
    )

    writer_eval = setup_env_and_distr(
        rank,
        n_gpus,
        device,
        device_id,
        config,
    )

    train_loader, val_loader = prepare_dataloaders(
        config,
        n_gpus,
        rank,
        batch_size,
        use_validation,
        benchmark_mode,
    )

    # Make sure the speaker-embedding dimension matches the checkpoints.
    spk_dim = verify_spk_dim(config, model_info_path, experiment_dir, latest_checkpoint_path, rank, pretrainG)
    config.model.spk_embed_dim = spk_dim

    if spectral_loss == "L1 Mel Loss":
        fn_spectral_loss = torch.nn.L1Loss()
        print(" ██████ Spectral loss: Single-Scale (L1) Mel loss function")
    elif spectral_loss == "Multi-Scale Mel Loss":
        fn_spectral_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)
        print(" ██████ Spectral loss: Multi-Scale Mel loss function")
    elif spectral_loss == "Multi-Res STFT Loss":
        fn_spectral_loss = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[1024, 2048, 512],
            hop_sizes=[80, 160, 40],
            win_lengths=[480, 960, 240],
            window="hann_window",
            w_sc=1.0,
            w_log_mag=1.0,
            w_lin_mag=0.0,
            w_phs=0.0,
            sample_rate=sample_rate,
            scale=None,
            n_bins=None,
            perceptual_weighting=True,
            scale_invariance=False,
            output="loss",
            reduction="mean",
            mag_distance="L1",
            device=device,
        )
        print(" ██████ Spectral loss: Multi-Resolution STFT loss function")
    else:
        print("ERROR: Chosen spectral loss is undefined. Exiting.")
        sys.exit(1)

    net_g, net_d, optim_g, optim_d, epoch_str, global_step = load_models_and_optimizers(
        config,
        pretrainG,
        pretrainD,
        vocoder,
        use_checkpointing,
        randomized,
        sample_rate,
        optimizer_choice,
        custom_lr_g,
        custom_lr_d,
        use_custom_lr,
        total_epoch_count,
        train_loader,
        device,
        device_id,
        n_gpus,
        rank,
    )

    if pretrainG in ["", "None"] and pretrainD in ["", "None"]:
        from_scratch = True
        if rank == 0:
            print(" ██████ No pretrains used: Average loss disabled!")

    warmup_scheduler_g, warmup_scheduler_d, scheduler_g, scheduler_d = prepare_schedulers(
        optim_g,
        optim_d,
        use_warmup,
        warmup_duration,
        use_lr_scheduler,
        lr_scheduler,
        exp_decay_gamma,
        total_epoch_count,
        epoch_str,
    )

    # RingFormer needs a Hann window for its STFT-based magnitude/phase loss.
    hann_window = torch.hann_window(config.model.gen_istft_n_fft).to(device) if vocoder == "RingFormer" else None

    # The grad scaler is only needed for fp16; bf16 and fp32 run unscaled.
    gradscaler = torch.amp.GradScaler(enabled=(device.type == "cuda" and train_dtype == torch.float16))

    reference = get_reference_sample(train_loader, device, config)

    cache = []
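
    # Schedule handoff: while warmup is active only the LambdaLR ramp steps;
    # once `epoch == warmup_duration` the decay scheduler (exp decay or cosine
    # annealing), if enabled, is stepped once per epoch instead.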
    for epoch in range(epoch_str, total_epoch_count + 1):
        training_loop(
            rank,
            epoch,
            config,
            [net_g, net_d],
            [optim_g, optim_d],
            train_loader,
            val_loader if use_validation else None,
            [writer_eval],
            cache,
            total_epoch_count,
            epoch_save_frequency,
            save_weight_models,
            save_only_latest_net_models,
            device,
            device_id,
            reference,
            fn_spectral_loss,
            n_gpus,
            gradscaler,
            hann_window,
        )
        if use_warmup and epoch <= warmup_duration:
            if warmup_scheduler_g:
                warmup_scheduler_g.step()
            if warmup_scheduler_d:
                warmup_scheduler_d.step()

            if epoch == warmup_duration:
                warmup_completed = True
                print(f" ██████ Warmup completed at epochs: {warmup_duration}")
                print(f" ██████ LR G: {optim_g.param_groups[0]['lr']}")
                print(f" ██████ LR D: {optim_d.param_groups[0]['lr']}")

                if lr_scheduler == "exp decay":
                    print(f" ██████ Starting the exponential lr decay with gamma of {exp_decay_gamma}")
                elif lr_scheduler == "cosine annealing":
                    print(" ██████ Starting cosine annealing scheduler")

        if use_lr_scheduler and (not use_warmup or warmup_completed):
            scheduler_g.step()
            scheduler_d.step()
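
# Per-step flow inside training_loop: generator forward -> one or two
# discriminator updates on the detached output -> generator update against
# the refreshed discriminator -> loss bookkeeping and periodic logging.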
def training_loop(
    rank,
    epoch,
    config,
    nets,
    optims,
    train_loader,
    val_loader,
    writers,
    cache,
    total_epoch_count,
    epoch_save_frequency,
    save_weight_models,
    save_only_latest_net_models,
    device,
    device_id,
    reference,
    fn_spectral_loss,
    n_gpus,
    gradscaler,
    hann_window=None,
):
    """
    Trains and evaluates the model for one epoch.

    Args:
        rank (int): Rank of the current process.
        epoch (int): Current epoch number.
        config (Namespace): Hyperparameters.
        nets (list): List of models [net_g, net_d].
        optims (list): List of optimizers [optim_g, optim_d].
        train_loader: Training dataloader.
        val_loader: Validation dataloader.
        writers (list): List of TensorBoard writers [writer_eval].
        cache (list): List used to cache batches in GPU memory.
        reference: Fixed reference sample for periodic audio previews.
        fn_spectral_loss: The configured spectral loss function.
        gradscaler: Gradient scaler (active only for fp16 training).
        hann_window: Hann window for the RingFormer STFT loss, if applicable.
    """
    global global_step, warmup_completed

    net_g, net_d = nets
    optim_g, optim_d = optims

    if writers is not None:
        writer = writers[0]

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()

    # Optionally keep every batch resident on the GPU after the first epoch.
    if device.type == "cuda" and cache_data_in_gpu:
        data_iterator = cache
        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
                cache.append((batch_idx, info))
        else:
            shuffle(cache)
    else:
        data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()

    if not from_scratch:
        # Running sums for the per-epoch average losses (an extra slot holds
        # the RingFormer magnitude/phase ("sd") loss when that vocoder is used).
        tensor_count = 7 if vocoder == "RingFormer" else 6
        epoch_loss_tensor = torch.zeros(tensor_count, device=device)
        num_batches_in_epoch = 0

    # Rolling 50-step windows for the smoothed dashboard metrics.
    avg_50_cache = {
        "grad_norm_d_clipped_50": deque(maxlen=50),
        "grad_norm_g_clipped_50": deque(maxlen=50),
        "loss_disc_50": deque(maxlen=50),
        "loss_adv_50": deque(maxlen=50),
        "loss_gen_total_50": deque(maxlen=50),
        "loss_fm_50": deque(maxlen=50),
        "loss_mel_50": deque(maxlen=50),
        "loss_kl_50": deque(maxlen=50),
    }
    if vocoder == "RingFormer":
        avg_50_cache.update({
            "loss_sd_50": deque(maxlen=50),
        })

    use_amp = (config.train.bf16_run or config.train.fp16_run) and device.type == "cuda"

    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:
            global_step += 1

            if not from_scratch:
                num_batches_in_epoch += 1

            if device.type == "cuda" and not cache_data_in_gpu:
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
            elif device.type != "cuda":
                info = [tensor.to(device) for tensor in info]
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                y,
                y_lengths,
                sid,
            ) = info

            # Generator forward pass.
            with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)

                if vocoder == "RingFormer":
                    y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), (mag, phase) = model_output
                else:
                    y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model_output

                # When training on random segments, slice the target waveform
                # to the same window the generator produced.
                if randomized:
                    y = commons.slice_segments(
                        y,
                        ids_slice * config.data.hop_length,
                        config.train.segment_size,
                        dim=3,
                    )

                if vocoder == "RingFormer":
                    reshaped_y = y.view(-1, y.size(-1))
                    reshaped_y_hat = y_hat.view(-1, y_hat.size(-1))
                    y_stft = torch.stft(reshaped_y, n_fft=config.model.gen_istft_n_fft, hop_length=config.model.gen_istft_hop_size, win_length=config.model.gen_istft_n_fft, window=hann_window, return_complex=True)
                    y_hat_stft = torch.stft(reshaped_y_hat, n_fft=config.model.gen_istft_n_fft, hop_length=config.model.gen_istft_hop_size, win_length=config.model.gen_istft_n_fft, window=hann_window, return_complex=True)
                    target_magnitude = torch.abs(y_stft)
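
            # Discriminator update(s): with double_d_update enabled the
            # discriminator is stepped twice on the same batch before the
            # generator update (d_updates_per_step = 2).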
            for _ in range(d_updates_per_step):
                with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                    # Detach y_hat so no gradients flow back into the generator.
                    y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())

                # Losses are computed in fp32 for numerical stability.
                with autocast(device_type="cuda", enabled=False):
                    loss_disc = discriminator_loss(y_d_hat_r, y_d_hat_g)

                optim_d.zero_grad()
                if train_dtype == torch.float16:
                    gradscaler.scale(loss_disc).backward()
                    gradscaler.unscale_(optim_d)
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(net_d.parameters(), max_norm=999999)
                    grad_norm_d_clipped = commons.get_total_norm([p.grad for p in net_d.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=False)
                    gradscaler.step(optim_d)
                else:
                    loss_disc.backward()
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(net_d.parameters(), max_norm=999999)
                    grad_norm_d_clipped = commons.get_total_norm([p.grad for p in net_d.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=True)
                    optim_d.step()

            # Generator update against the freshly updated discriminator.
            with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                _, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)

            with autocast(device_type="cuda", enabled=False):
                # Spectral reconstruction loss, per the configured variant.
                if spectral_loss == "L1 Mel Loss":
                    y_mel = wave_to_mel(config, y, half=train_dtype)
                    y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)
                    loss_mel = fn_spectral_loss(y_mel, y_hat_mel) * config.train.c_mel
                elif spectral_loss == "Multi-Scale Mel Loss":
                    loss_mel = fn_spectral_loss(y, y_hat) * config.train.c_mel / 3.0
                elif spectral_loss == "Multi-Res STFT Loss":
                    loss_mel = fn_spectral_loss(y_hat.float(), y.float()) * c_stft

                # Feature-matching, adversarial and KL terms.
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_adv = generator_loss(y_d_hat_g)
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl

                if vocoder == "RingFormer":
                    # Extra magnitude + phase reconstruction loss on the ISTFT head.
                    loss_magnitude = torch.nn.functional.l1_loss(mag, target_magnitude)
                    loss_phase = phase_loss(y_stft, y_hat_stft)
                    loss_sd = (loss_magnitude + loss_phase) * 0.7

                if vocoder == "RingFormer":
                    loss_gen_total = loss_adv + loss_fm + loss_mel + loss_kl + loss_sd
                else:
                    loss_gen_total = loss_adv + loss_fm + loss_mel + loss_kl

            optim_g.zero_grad()
            if train_dtype == torch.float16:
                gradscaler.scale(loss_gen_total).backward()
                gradscaler.unscale_(optim_g)
                grad_norm_g = torch.nn.utils.clip_grad_norm_(net_g.parameters(), max_norm=999999)
                grad_norm_g_clipped = commons.get_total_norm([p.grad for p in net_g.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=False)
                gradscaler.step(optim_g)
                gradscaler.update()
            else:
                loss_gen_total.backward()
                grad_norm_g = torch.nn.utils.clip_grad_norm_(net_g.parameters(), max_norm=999999)
                grad_norm_g_clipped = commons.get_total_norm([p.grad for p in net_g.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=True)
                optim_g.step()
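
            # Bookkeeping: the epoch-average sums are only kept when training
            # from pretrains (`from_scratch` disables them, as announced at
            # startup); the 50-step rolling windows are always maintained.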
            if not from_scratch:
                epoch_loss_tensor[0].add_(loss_disc.detach())
                epoch_loss_tensor[1].add_(loss_adv.detach())
                epoch_loss_tensor[2].add_(loss_gen_total.detach())
                epoch_loss_tensor[3].add_(loss_fm.detach())
                epoch_loss_tensor[4].add_(loss_mel.detach())
                epoch_loss_tensor[5].add_(loss_kl.detach())
                if vocoder == "RingFormer":
                    epoch_loss_tensor[6].add_(loss_sd.detach())

            avg_50_cache["grad_norm_d_clipped_50"].append(grad_norm_d_clipped)
            avg_50_cache["grad_norm_g_clipped_50"].append(grad_norm_g_clipped)
            avg_50_cache["loss_disc_50"].append(loss_disc.detach())
            avg_50_cache["loss_adv_50"].append(loss_adv.detach())
            avg_50_cache["loss_gen_total_50"].append(loss_gen_total.detach())
            avg_50_cache["loss_fm_50"].append(loss_fm.detach())
            avg_50_cache["loss_mel_50"].append(loss_mel.detach())
            avg_50_cache["loss_kl_50"].append(loss_kl.detach())
            if vocoder == "RingFormer":
                avg_50_cache["loss_sd_50"].append(loss_sd.detach())

            if rank == 0 and global_step % 50 == 0:
                scalar_dict_50 = {}

                if from_scratch:
                    lr_d = optim_d.param_groups[0]["lr"]
                    lr_g = optim_g.param_groups[0]["lr"]
                    scalar_dict_50.update({
                        "learning_rate/lr_d": lr_d,
                        "learning_rate/lr_g": lr_g,
                    })
                    if optimizer_choice == "Prodigy":
                        # Prodigy's adaptive step-size estimate lives in 'd'.
                        prodigy_lr_g = optim_g.param_groups[0].get("d", 0)
                        prodigy_lr_d = optim_d.param_groups[0].get("d", 0)
                        scalar_dict_50.update({
                            "learning_rate/prodigy_lr_g": prodigy_lr_g,
                            "learning_rate/prodigy_lr_d": prodigy_lr_d,
                        })

                scalar_dict_50.update({
                    "grad_avg_50/norm_d_clipped_50": sum(avg_50_cache["grad_norm_d_clipped_50"])
                    / len(avg_50_cache["grad_norm_d_clipped_50"]),
                    "grad_avg_50/norm_g_clipped_50": sum(avg_50_cache["grad_norm_g_clipped_50"])
                    / len(avg_50_cache["grad_norm_g_clipped_50"]),
                    "loss_avg_50/loss_disc_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_disc_50"]))),
                    "loss_avg_50/loss_adv_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_adv_50"]))),
                    "loss_avg_50/loss_gen_total_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_gen_total_50"]))),
                    "loss_avg_50/loss_fm_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_fm_50"]))),
                    "loss_avg_50/loss_mel_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_mel_50"]))),
                    "loss_avg_50/loss_kl_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_kl_50"]))),
                })
                if vocoder == "RingFormer":
                    scalar_dict_50.update({
                        "loss_avg_50/loss_sd_50": torch.mean(
                            torch.stack(list(avg_50_cache["loss_sd_50"]))),
                    })

                summarize(writer=writer, global_step=global_step, scalars=scalar_dict_50)
                flush_writer(writer, rank)

            pbar.update(1)

    # End of epoch: synchronize ranks and release cached GPU memory.
    if n_gpus > 1 and device.type == "cuda":
        dist.barrier()

    with torch.no_grad():
        torch.cuda.empty_cache()

    if rank == 0:
        # Compare the last generated slice against the ground-truth mel.
        mel = spec_to_mel_torch(
            spec,
            config.data.filter_length,
            config.data.n_mel_channels,
            config.data.sample_rate,
            config.data.mel_fmin,
            config.data.mel_fmax,
        )

        if train_dtype == torch.float16:
            mel = mel.half()

        if randomized:
            y_mel = commons.slice_segments(
                mel,
                ids_slice,
                config.train.segment_size // config.data.hop_length,
                dim=3,
            )
        else:
            y_mel = mel

        y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)

        mel_similarity = mel_spec_similarity(y_hat_mel, y_mel)
        print(f"Mel Spectrogram Similarity: {mel_similarity:.2f}%")
        writer.add_scalar("Metric/Mel_Spectrogram_Similarity", mel_similarity, global_step)

        lr_d = optim_d.param_groups[0]["lr"]
        lr_g = optim_g.param_groups[0]["lr"]

        # Once per epoch: log the averaged losses (skipped for from-scratch runs).
        if global_step % len(train_loader) == 0 and not from_scratch:
            avg_epoch_loss = epoch_loss_tensor / num_batches_in_epoch

            scalar_dict_avg = {
                "loss_avg/loss_disc": avg_epoch_loss[0],
                "loss_avg/loss_adv": avg_epoch_loss[1],
                "loss_avg/loss_gen_total": avg_epoch_loss[2],
                "loss_avg/loss_fm": avg_epoch_loss[3],
                "loss_avg/loss_mel": avg_epoch_loss[4],
                "loss_avg/loss_kl": avg_epoch_loss[5],
                "learning_rate/lr_d": lr_d,
                "learning_rate/lr_g": lr_g,
            }
            if optimizer_choice == "Prodigy":
                prodigy_lr_g = optim_g.param_groups[0].get("d", 0)
                prodigy_lr_d = optim_d.param_groups[0].get("d", 0)
                scalar_dict_avg.update({
                    "learning_rate/prodigy_lr_g": prodigy_lr_g,
                    "learning_rate/prodigy_lr_d": prodigy_lr_d,
                })
            if vocoder == "RingFormer":
                scalar_dict_avg.update({
                    "loss_avg/loss_sd": avg_epoch_loss[6],
                })

            summarize(writer=writer, global_step=global_step, scalars=scalar_dict_avg)
            flush_writer(writer, rank)
            num_batches_in_epoch = 0
            epoch_loss_tensor.zero_()

        # bf16 tensors cannot be converted to NumPy directly, so plot in fp32
        # unless we are already running in fp16.
        if train_dtype == torch.float16:
            plot_dtype = torch.float16
        else:
            plot_dtype = torch.float32

        image_dict = {
            "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].detach().cpu().to(plot_dtype).numpy()),
            "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].detach().cpu().to(plot_dtype).numpy()),
            "all/mel": plot_spectrogram_to_numpy(mel[0].detach().cpu().to(plot_dtype).numpy()),
        }

        if epoch % epoch_save_frequency == 0:
            if not benchmark_mode and use_validation:
                validation_loop(
                    net_g.module if hasattr(net_g, "module") else net_g,
                    val_loader,
                    device,
                    config,
                    writer,
                    global_step,
                )

            # Render an audio preview from the fixed reference sample.
            net_g.eval()
            with torch.no_grad():
                if hasattr(net_g, "module"):
                    o, *_ = net_g.module.infer(*reference)
                else:
                    o, *_ = net_g.infer(*reference)
            net_g.train()
            audio_dict = {f"gen/audio_{epoch}e_{global_step}s": o[0, :, :]}

            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                audios=audio_dict,
                audio_sample_rate=config.data.sample_rate,
            )
            flush_writer(writer, rank)
        else:
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
            )
            flush_writer(writer, rank)

    model_add = []
    done = False

    if rank == 0:
        record = f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}"
        print(record)

        if epoch % epoch_save_frequency == 0:
            # "2333333" is the sentinel suffix used when only the latest
            # G/D checkpoints are kept (they get overwritten in place).
            checkpoint_suffix = f"{2333333 if save_only_latest_net_models else global_step}.pth"

            save_checkpoint(
                architecture,
                net_g,
                optim_g,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "G_" + checkpoint_suffix),
            )
            save_checkpoint(
                architecture,
                net_d,
                optim_d,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "D_" + checkpoint_suffix),
            )

            if save_weight_models:
                weight_model_name = small_model_naming(model_name, epoch, global_step)
                model_add.append(os.path.join(experiment_dir, weight_model_name))

        if epoch >= total_epoch_count:
            print(
                f"Training has been successfully completed with {epoch} epochs, {global_step} steps and a generator loss of {round(loss_gen_total.item(), 3)}."
            )

            # Always export a final weight model at the end of training.
            weight_model_name = small_model_naming(model_name, epoch, global_step)
            model_add.append(os.path.join(experiment_dir, weight_model_name))

            done = True

        if model_add:
            ckpt = (
                net_g.module.state_dict()
                if hasattr(net_g, "module")
                else net_g.state_dict()
            )
            for m in model_add:
                if not os.path.exists(m):
                    extract_model(
                        ckpt=ckpt,
                        sr=sample_rate,
                        name=model_name,
                        model_path=m,
                        epoch=epoch,
                        step=global_step,
                        hps=config,
                        vocoder=vocoder,
                        architecture=architecture,
                    )
    if done:
        pid_data["process_pids"].clear()

        if rank == 0:
            writer.flush()
            writer.close()

        os._exit(2333333)

    with torch.no_grad():
        torch.cuda.empty_cache()
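
# Validation reports four metrics per pass: L1 mel distance, multi-resolution
# STFT loss, wideband PESQ (computed on 16 kHz resampled audio; clips that
# PESQ rejects are skipped) and SI-SDR.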
def validation_loop(net_g, val_loader, device, config, writer, global_step):
    net_g.eval()
    torch.cuda.empty_cache()

    total_mel_error = 0.0
    total_mrstft_loss = 0.0
    total_pesq = 0.0
    valid_pesq_count = 0
    total_si_sdr = 0.0
    count = 0

    mrstft = auraloss.freq.MultiResolutionSTFTLoss(device=device)
    # Wideband PESQ is defined for 16 kHz input, so resample both signals.
    resample_to_16k = torchaudio.transforms.Resample(orig_freq=config.data.sample_rate, new_freq=16000).to(device)

    hop_length = config.data.hop_length
    sample_rate = config.data.sample_rate

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            phone, phone_lengths, pitch, pitchf, spec, spec_lengths, y, _, sid = [t.to(device) for t in batch]

            y_hat, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, pitchf, sid)

            y_len = y.shape[-1]

            y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)
            mel = wave_to_mel(config, y, half=train_dtype)

            # Generated and reference audio may differ slightly in length;
            # trim both to the shorter one before comparing.
            y_hat_mel_len = y_hat_mel.shape[-1]
            mel_len = mel.shape[-1]
            min_t = min(y_hat_mel_len, mel_len)

            mel_loss = F.l1_loss(y_hat_mel[..., :min_t], mel[..., :min_t])
            total_mel_error += mel_loss.item()

            y_hat_len = y_hat.shape[-1]
            min_samples = min_t * hop_length
            min_samples = min(min_samples, y_len, y_hat_len)

            stft_loss = mrstft(y_hat[..., :min_samples], y[..., :min_samples])
            total_mrstft_loss += stft_loss.item()

            si_sdr_score = si_sdr(y_hat.squeeze(1), y.squeeze(1))
            total_si_sdr += si_sdr_score.item()

            try:
                y_16k_batch = resample_to_16k(y).cpu().numpy()
                y_hat_16k_batch = resample_to_16k(y_hat.squeeze(1)).cpu().numpy()

                for i in range(y_16k_batch.shape[0]):
                    y_16k_f = np.squeeze(y_16k_batch[i]).astype(np.float32)
                    y_hat_16k_f = np.squeeze(y_hat_16k_batch[i]).astype(np.float32)

                    try:
                        pesq_score = pesq(16000, y_16k_f, y_hat_16k_f, mode="wb")
                        total_pesq += pesq_score
                        valid_pesq_count += 1
                    except Exception as e:
                        print(f"[PESQ skipped] {e}")

            except Exception as e:
                print(f"[PESQ skipped outer] {e}")

            count += 1

    avg_mel = total_mel_error / count
    avg_mrstft = total_mrstft_loss / count
    avg_pesq = total_pesq / max(valid_pesq_count, 1)
    avg_si_sdr = total_si_sdr / count

    if writer is not None:
        writer.add_scalar("validation/loss/mel_l1", avg_mel, global_step)
        writer.add_scalar("validation/loss/mrstft", avg_mrstft, global_step)
        writer.add_scalar("validation/score/pesq", avg_pesq, global_step)
        writer.add_scalar("validation/score/si_sdr", avg_si_sdr, global_step)

    net_g.train()
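
# CUDA tensors cannot be shared across fork()ed processes, so the "spawn"
# start method is required for the multiprocessing workers.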
if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()