Create train_single_gpu.py
Browse files- RingFormer/train_single_gpu.py +295 -0
RingFormer/train_single_gpu.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
| 3 |
+
import itertools
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 11 |
+
from torch.utils.data import DistributedSampler, DataLoader
|
| 12 |
+
import torch.multiprocessing as mp
|
| 13 |
+
from torch.distributed import init_process_group
|
| 14 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 15 |
+
from env import AttrDict, build_env
|
| 16 |
+
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
| 17 |
+
from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
|
| 18 |
+
discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
|
| 19 |
+
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
|
| 20 |
+
from stft import TorchSTFT
|
| 21 |
+
from Utils.JDC.model import JDCNet
|
| 22 |
+
|
| 23 |
+
torch.backends.cudnn.benchmark = True
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def train(rank, a, h):
|
| 27 |
+
if h.num_gpus > 1:
|
| 28 |
+
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
| 29 |
+
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
| 30 |
+
|
| 31 |
+
torch.cuda.manual_seed(h.seed)
|
| 32 |
+
device = torch.device('cuda:{:d}'.format(rank))
|
| 33 |
+
|
| 34 |
+
F0_model = JDCNet(num_class=1, seq_len=192)
|
| 35 |
+
params = torch.load(h.F0_path)['net']
|
| 36 |
+
F0_model.load_state_dict(params)
|
| 37 |
+
|
| 38 |
+
generator = Generator(h, F0_model).to(device)
|
| 39 |
+
mpd = MultiPeriodDiscriminator().to(device)
|
| 40 |
+
msd = MultiScaleSubbandCQTDiscriminator().to(device)
|
| 41 |
+
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
|
| 42 |
+
|
| 43 |
+
if rank == 0:
|
| 44 |
+
print(generator)
|
| 45 |
+
os.makedirs(a.checkpoint_path, exist_ok=True)
|
| 46 |
+
print("checkpoints directory : ", a.checkpoint_path)
|
| 47 |
+
|
| 48 |
+
if os.path.isdir(a.checkpoint_path):
|
| 49 |
+
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
| 50 |
+
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
| 51 |
+
|
| 52 |
+
steps = 0
|
| 53 |
+
if cp_g is None or cp_do is None:
|
| 54 |
+
state_dict_do = None
|
| 55 |
+
last_epoch = -1
|
| 56 |
+
else:
|
| 57 |
+
state_dict_g = load_checkpoint(cp_g, device)
|
| 58 |
+
state_dict_do = load_checkpoint(cp_do, device)
|
| 59 |
+
generator.load_state_dict(state_dict_g['generator'])
|
| 60 |
+
mpd.load_state_dict(state_dict_do['mpd'])
|
| 61 |
+
msd.load_state_dict(state_dict_do['msd'])
|
| 62 |
+
steps = state_dict_do['steps'] + 1
|
| 63 |
+
last_epoch = state_dict_do['epoch']
|
| 64 |
+
|
| 65 |
+
if h.num_gpus > 1:
|
| 66 |
+
generator = DistributedDataParallel(generator, device_ids=[rank], find_unused_parameters=True).to(device)
|
| 67 |
+
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
| 68 |
+
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
| 69 |
+
|
| 70 |
+
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
| 71 |
+
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
| 72 |
+
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
| 73 |
+
|
| 74 |
+
if state_dict_do is not None:
|
| 75 |
+
optim_g.load_state_dict(state_dict_do['optim_g'])
|
| 76 |
+
optim_d.load_state_dict(state_dict_do['optim_d'])
|
| 77 |
+
|
| 78 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
| 79 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
| 80 |
+
|
| 81 |
+
training_filelist, validation_filelist = get_dataset_filelist(a)
|
| 82 |
+
|
| 83 |
+
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
|
| 84 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
| 85 |
+
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
| 86 |
+
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
|
| 87 |
+
|
| 88 |
+
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
| 89 |
+
|
| 90 |
+
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
| 91 |
+
sampler=train_sampler,
|
| 92 |
+
batch_size=h.batch_size,
|
| 93 |
+
pin_memory=True,
|
| 94 |
+
drop_last=True)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if rank == 0:
|
| 100 |
+
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
|
| 101 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
| 102 |
+
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
| 103 |
+
base_mels_path=a.input_mels_dir)
|
| 104 |
+
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
| 105 |
+
sampler=None,
|
| 106 |
+
batch_size=1,
|
| 107 |
+
pin_memory=True,
|
| 108 |
+
drop_last=True)
|
| 109 |
+
|
| 110 |
+
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
| 111 |
+
|
| 112 |
+
generator.train()
|
| 113 |
+
mpd.train()
|
| 114 |
+
msd.train()
|
| 115 |
+
for epoch in range(max(0, last_epoch), a.training_epochs):
|
| 116 |
+
if rank == 0:
|
| 117 |
+
start = time.time()
|
| 118 |
+
print("Epoch: {}".format(epoch+1))
|
| 119 |
+
|
| 120 |
+
if h.num_gpus > 1:
|
| 121 |
+
train_sampler.set_epoch(epoch)
|
| 122 |
+
|
| 123 |
+
for i, batch in enumerate(train_loader):
|
| 124 |
+
if rank == 0:
|
| 125 |
+
start_b = time.time()
|
| 126 |
+
x, y, _, y_mel = batch
|
| 127 |
+
x = torch.autograd.Variable(x.to(device, non_blocking=True))
|
| 128 |
+
y = torch.autograd.Variable(y.to(device, non_blocking=True))
|
| 129 |
+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
| 130 |
+
y = y.unsqueeze(1)
|
| 131 |
+
# y_g_hat = generator(x)
|
| 132 |
+
spec, phase = generator(x)
|
| 133 |
+
|
| 134 |
+
y_g_hat = stft.inverse(spec, phase)
|
| 135 |
+
|
| 136 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
| 137 |
+
h.fmin, h.fmax_for_loss)
|
| 138 |
+
|
| 139 |
+
optim_d.zero_grad()
|
| 140 |
+
|
| 141 |
+
# MPD
|
| 142 |
+
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
| 143 |
+
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
| 144 |
+
loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
|
| 145 |
+
|
| 146 |
+
# MSD
|
| 147 |
+
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
| 148 |
+
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
| 149 |
+
loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
|
| 150 |
+
|
| 151 |
+
loss_disc_all = loss_disc_s + loss_disc_f
|
| 152 |
+
|
| 153 |
+
loss_disc_all.backward()
|
| 154 |
+
optim_d.step()
|
| 155 |
+
|
| 156 |
+
# Generator
|
| 157 |
+
optim_g.zero_grad()
|
| 158 |
+
|
| 159 |
+
# L1 Mel-Spectrogram Loss
|
| 160 |
+
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
| 161 |
+
|
| 162 |
+
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
| 163 |
+
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
| 164 |
+
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
| 165 |
+
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
| 166 |
+
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
| 167 |
+
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
| 168 |
+
|
| 169 |
+
loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
|
| 170 |
+
loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
|
| 171 |
+
|
| 172 |
+
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
| 173 |
+
|
| 174 |
+
loss_gen_all.backward()
|
| 175 |
+
optim_g.step()
|
| 176 |
+
|
| 177 |
+
if rank == 0:
|
| 178 |
+
# STDOUT logging
|
| 179 |
+
if steps % a.stdout_interval == 0:
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
| 182 |
+
|
| 183 |
+
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
|
| 184 |
+
format(steps, loss_gen_all, mel_error, time.time() - start_b))
|
| 185 |
+
|
| 186 |
+
# checkpointing
|
| 187 |
+
if steps % a.checkpoint_interval == 0 and steps != 0:
|
| 188 |
+
checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
|
| 189 |
+
save_checkpoint(checkpoint_path,
|
| 190 |
+
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
| 191 |
+
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
| 192 |
+
save_checkpoint(checkpoint_path,
|
| 193 |
+
{'mpd': (mpd.module if h.num_gpus > 1
|
| 194 |
+
else mpd).state_dict(),
|
| 195 |
+
'msd': (msd.module if h.num_gpus > 1
|
| 196 |
+
else msd).state_dict(),
|
| 197 |
+
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
| 198 |
+
'epoch': epoch})
|
| 199 |
+
|
| 200 |
+
# Tensorboard summary logging
|
| 201 |
+
if steps % a.summary_interval == 0:
|
| 202 |
+
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
| 203 |
+
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
| 204 |
+
|
| 205 |
+
# Validation
|
| 206 |
+
if steps % a.validation_interval == 0: # and steps != 0:
|
| 207 |
+
generator.eval()
|
| 208 |
+
torch.cuda.empty_cache()
|
| 209 |
+
val_err_tot = 0
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
for j, batch in enumerate(validation_loader):
|
| 212 |
+
x, y, _, y_mel = batch
|
| 213 |
+
# y_g_hat = generator(x.to(device))
|
| 214 |
+
spec, phase = generator(x.to(device))
|
| 215 |
+
|
| 216 |
+
y_g_hat = stft.inverse(spec, phase)
|
| 217 |
+
|
| 218 |
+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
| 219 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
| 220 |
+
h.hop_size, h.win_size,
|
| 221 |
+
h.fmin, h.fmax_for_loss)
|
| 222 |
+
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
| 223 |
+
|
| 224 |
+
if j <= 4:
|
| 225 |
+
if steps == 0:
|
| 226 |
+
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
| 227 |
+
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
| 228 |
+
|
| 229 |
+
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
| 230 |
+
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
|
| 231 |
+
h.sampling_rate, h.hop_size, h.win_size,
|
| 232 |
+
h.fmin, h.fmax)
|
| 233 |
+
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
| 234 |
+
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
| 235 |
+
|
| 236 |
+
val_err = val_err_tot / (j+1)
|
| 237 |
+
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
| 238 |
+
|
| 239 |
+
generator.train()
|
| 240 |
+
|
| 241 |
+
steps += 1
|
| 242 |
+
|
| 243 |
+
scheduler_g.step()
|
| 244 |
+
scheduler_d.step()
|
| 245 |
+
|
| 246 |
+
if rank == 0:
|
| 247 |
+
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def main():
|
| 251 |
+
print('Initializing Training Process..')
|
| 252 |
+
|
| 253 |
+
parser = argparse.ArgumentParser()
|
| 254 |
+
|
| 255 |
+
parser.add_argument('--group_name', default=None)
|
| 256 |
+
parser.add_argument('--input_wavs_dir', default='')
|
| 257 |
+
parser.add_argument('--input_mels_dir', default='ft_dataset')
|
| 258 |
+
# parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
|
| 259 |
+
parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/eng_norm.txt')
|
| 260 |
+
parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/valid_eng.txt')
|
| 261 |
+
parser.add_argument('--checkpoint_path', default='cp_ringformer_LIBRI')
|
| 262 |
+
parser.add_argument('--config', default='config_v1.json')
|
| 263 |
+
parser.add_argument('--training_epochs', default=3100, type=int)
|
| 264 |
+
parser.add_argument('--stdout_interval', default=10, type=int)
|
| 265 |
+
parser.add_argument('--checkpoint_interval', default=2500, type=int)
|
| 266 |
+
parser.add_argument('--summary_interval', default=100, type=int)
|
| 267 |
+
parser.add_argument('--validation_interval', default=1000, type=int)
|
| 268 |
+
parser.add_argument('--fine_tuning', default=False, type=bool)
|
| 269 |
+
|
| 270 |
+
a = parser.parse_args()
|
| 271 |
+
|
| 272 |
+
with open(a.config) as f:
|
| 273 |
+
data = f.read()
|
| 274 |
+
|
| 275 |
+
json_config = json.loads(data)
|
| 276 |
+
h = AttrDict(json_config)
|
| 277 |
+
build_env(a.config, 'config.json', a.checkpoint_path)
|
| 278 |
+
|
| 279 |
+
torch.manual_seed(h.seed)
|
| 280 |
+
if torch.cuda.is_available():
|
| 281 |
+
torch.cuda.manual_seed(h.seed)
|
| 282 |
+
h.num_gpus = torch.cuda.device_count()
|
| 283 |
+
h.batch_size = int(h.batch_size / h.num_gpus)
|
| 284 |
+
print('Batch size per GPU :', h.batch_size)
|
| 285 |
+
else:
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
if h.num_gpus > 1:
|
| 289 |
+
mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
|
| 290 |
+
else:
|
| 291 |
+
train(0, a, h)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == '__main__':
|
| 295 |
+
main()
|