Update infer/modules/train/train.py
infer/modules/train/train.py · CHANGED · +155 −113
@@ -8,6 +8,7 @@ now_dir = os.getcwd()
 sys.path.append(os.path.join(now_dir))
 
 import datetime
+from tqdm import tqdm  # Added import
 
 from infer.lib.train import utils
 
@@ -105,6 +106,7 @@ def main():
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
     logger = utils.get_logger(hps.model_dir)
+    logger.info(f"Starting training with {n_gpus} GPU(s)")
     for i in range(n_gpus):
         subproc = mp.Process(
             target=run,
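For context, the pattern main() follows here — one mp.Process per GPU and a randomly chosen MASTER_PORT — looks roughly like the sketch below. The worker function, world-size handling, and address/port values are illustrative stand-ins, not the project's actual code.

import os
from random import randint

import torch
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Each spawned process owns one GPU; rank doubles as the device index.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    print(f"worker {rank}/{world_size - 1} started")


if __name__ == "__main__":
    mp.set_start_method("spawn")  # same start method the training script sets
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))  # same trick as in the diff
    n_gpus = max(torch.cuda.device_count(), 1)
    procs = []
    for rank in range(n_gpus):
        p = mp.Process(target=worker, args=(rank, n_gpus))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()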
@@ -120,9 +122,8 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
 def run(rank, n_gpus, hps, logger: logging.Logger):
     global global_step
     if rank == 0:
-
+        logger.info(f"Process {rank}/{n_gpus-1} started")
         logger.info(hps)
-        # utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
 
@@ -140,18 +141,17 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
     train_sampler = DistributedBucketSampler(
         train_dataset,
         hps.train.batch_size * n_gpus,
-
-        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
+        [100, 200, 300, 400, 500, 600, 700, 800, 900],
         num_replicas=n_gpus,
         rank=rank,
         shuffle=True,
     )
-
-    # num_workers=8 -> num_workers=4
+
     if hps.if_f0 == 1:
         collate_fn = TextAudioCollateMultiNSFsid()
     else:
         collate_fn = TextAudioCollate()
+
     train_loader = DataLoader(
         train_dataset,
         num_workers=4,
@@ -162,6 +162,11 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         persistent_workers=True,
         prefetch_factor=8,
     )
+
+    if rank == 0:
+        logger.info(f"Training dataset size: {len(train_dataset)}")
+        logger.info(f"Number of batches per epoch: {len(train_loader)}")
+
     if hps.if_f0 == 1:
         net_g = RVC_Model_f0(
             hps.data.filter_length // 2 + 1,
@@ -177,11 +182,14 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
             **hps.model,
             is_half=hps.train.fp16_run,
         )
+
     if torch.cuda.is_available():
         net_g = net_g.cuda(rank)
+
     net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
     if torch.cuda.is_available():
         net_d = net_d.cuda(rank)
+
     optim_g = torch.optim.AdamW(
         net_g.parameters(),
         hps.train.learning_rate,
@@ -194,8 +202,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-
-    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         pass
    elif torch.cuda.is_available():
@@ -205,52 +212,43 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         net_g = DDP(net_g)
         net_d = DDP(net_d)
 
-    try:
+    try:
         _, _, _, epoch_str = utils.load_checkpoint(
             utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
-        )
+        )
         if rank == 0:
-            logger.info("
-
+            logger.info("Loaded discriminator checkpoint")
+
         _, _, _, epoch_str = utils.load_checkpoint(
             utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
         )
         global_step = (epoch_str - 1) * len(train_loader)
-
-
-    except:
-        # traceback.print_exc()
+        if rank == 0:
+            logger.info(f"Resuming from epoch {epoch_str}, global step {global_step}")
+    except:
         epoch_str = 1
         global_step = 0
         if hps.pretrainG != "":
             if rank == 0:
-                logger.info("
+                logger.info(f"Loading pretrained generator from {hps.pretrainG}")
             if hasattr(net_g, "module"):
-
-
-
-                )
-                )  ## test: do not load the optimizer
+                net_g.module.load_state_dict(
+                    torch.load(hps.pretrainG, map_location="cpu")["model"]
+                )
             else:
-
-
-
-                )
-                )  ## test: do not load the optimizer
+                net_g.load_state_dict(
+                    torch.load(hps.pretrainG, map_location="cpu")["model"]
+                )
         if hps.pretrainD != "":
             if rank == 0:
-                logger.info("
+                logger.info(f"Loading pretrained discriminator from {hps.pretrainD}")
             if hasattr(net_d, "module"):
-
-
-                    torch.load(hps.pretrainD, map_location="cpu")["model"]
-                )
+                net_d.module.load_state_dict(
+                    torch.load(hps.pretrainD, map_location="cpu")["model"]
                 )
             else:
-
-
-                    torch.load(hps.pretrainD, map_location="cpu")["model"]
-                )
+                net_d.load_state_dict(
+                    torch.load(hps.pretrainD, map_location="cpu")["model"]
                 )
 
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
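The pretrained-weights block above keeps the usual pattern for a model that may or may not be wrapped in DistributedDataParallel: reach the underlying module via .module when present, then load the "model" entry of a CPU-mapped checkpoint. A minimal, self-contained sketch of that pattern; the "model" key mirrors the checkpoints used in the diff, while the file name and toy module are made up for illustration.

import torch
import torch.nn as nn


def load_pretrained(net: nn.Module, path: str) -> None:
    # Checkpoints in the diff store the weights under the "model" key.
    state = torch.load(path, map_location="cpu")["model"]
    # DDP wraps the real network in .module; unwrap before loading.
    target = net.module if hasattr(net, "module") else net
    target.load_state_dict(state)


if __name__ == "__main__":
    net = nn.Linear(4, 2)
    torch.save({"model": net.state_dict()}, "pretrained_example.pth")  # hypothetical file
    load_pretrained(net, "pretrained_example.pth")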
@@ -263,6 +261,11 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
     scaler = GradScaler(enabled=hps.train.fp16_run)
 
     cache = []
+
+    if rank == 0:
+        logger.info(f"Starting training from epoch {epoch_str} to {hps.train.epochs}")
+        logger.info(f"Total epochs to train: {hps.train.epochs - epoch_str + 1}")
+
     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
@@ -313,12 +316,16 @@ def train_and_evaluate(
 
     # Prepare data iterator
     if hps.if_cache_data_in_gpu == True:
-        # Use Cache
-        data_iterator = cache
         if cache == []:
-
+            if rank == 0:
+                logger.info("Caching data in GPU...")
+            cache_progress = tqdm(total=len(train_loader),
+                                  desc="Caching",
+                                  position=0,
+                                  leave=True,
+                                  disable=(rank != 0))
+
             for batch_idx, info in enumerate(train_loader):
-                # Unpack
                 if hps.if_f0 == 1:
                     (
                         phone,
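The two progress bars introduced in this commit (one while caching, one per training epoch) follow the plain tqdm API: create the bar with a fixed total, call update(1) per step, push the current losses through set_postfix, and close it at the end. A tiny self-contained sketch of that pattern, with fake loss values and a single rank:

import time
from tqdm import tqdm

total_steps = 50
log_interval = 10
rank = 0  # in the diff, only rank 0 draws the bar (disable=(rank != 0))

progress = tqdm(total=total_steps,
                desc="Epoch 1/10",
                position=0,
                leave=True,
                disable=(rank != 0))

for step in range(total_steps):
    time.sleep(0.01)  # stand-in for one training step
    progress.update(1)
    if step % log_interval == 0:
        # set_postfix appends key=value pairs after the bar, like the G/D/Mel/KL readout in the diff.
        progress.set_postfix({"G": f"{1.0 / (step + 1):.3f}", "Step": step})

progress.close()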
@@ -341,7 +348,7 @@ def train_and_evaluate(
                         wave_lengths,
                         sid,
                     ) = info
-
+
                 if torch.cuda.is_available():
                     phone = phone.cuda(rank, non_blocking=True)
                     phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
@@ -352,8 +359,7 @@ def train_and_evaluate(
                     spec = spec.cuda(rank, non_blocking=True)
                     spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                     wave = wave.cuda(rank, non_blocking=True)
-
-                # Cache on list
+
                 if hps.if_f0 == 1:
                     cache.append(
                         (
@@ -386,18 +392,31 @@ def train_and_evaluate(
                             ),
                         )
                     )
-
-
-
+
+                if rank == 0:
+                    cache_progress.update(1)
+
+            if rank == 0:
+                cache_progress.close()
+                logger.info(f"Cached {len(cache)} batches in GPU")
+
+        shuffle(cache)
+        data_iterator = cache
     else:
-        # Loader
         data_iterator = enumerate(train_loader)
 
-    #
+    # Initialize tqdm progress bar for training
+    if rank == 0:
+        epoch_progress = tqdm(total=len(train_loader),
+                              desc=f"Epoch {epoch}/{hps.train.epochs}",
+                              position=0,
+                              leave=True,
+                              bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}')
+
     epoch_recorder = EpochRecorder()
+
     for batch_idx, info in data_iterator:
-        #
-        ## Unpack
+        # Unpack data
         if hps.if_f0 == 1:
             (
                 phone,
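When if_cache_data_in_gpu is enabled, the loop above fills the cache once on the first epoch and then reshuffles and replays the cached batches on every epoch. The control flow, reduced to a sketch with a toy list standing in for the real DataLoader and batch tuples:

from random import shuffle

import torch

loader = [(i, torch.randn(2, 3)) for i in range(8)]  # stand-in for a DataLoader
cache = []

for epoch in range(3):
    if cache == []:
        # First epoch: move every batch to the GPU once and keep it.
        for batch_idx, batch in loader:
            if torch.cuda.is_available():
                batch = batch.cuda(non_blocking=True)
            cache.append((batch_idx, batch))
    # Every epoch: iterate the cached batches in a new random order.
    shuffle(cache)
    for batch_idx, batch in cache:
        pass  # the training step would go here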
@@ -412,7 +431,7 @@ def train_and_evaluate(
             ) = info
         else:
             phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
-
+
         if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
             phone = phone.cuda(rank, non_blocking=True)
             phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
@@ -423,9 +442,8 @@ def train_and_evaluate(
             spec = spec.cuda(rank, non_blocking=True)
             spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
             wave = wave.cuda(rank, non_blocking=True)
-            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
 
-        #
+        # Forward pass
         with autocast(enabled=hps.train.fp16_run):
             if hps.if_f0 == 1:
                 (
@@ -443,6 +461,7 @@ def train_and_evaluate(
                     z_mask,
                     (z, z_p, m_p, logs_p, m_q, logs_q),
                 ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
+
             mel = spec_to_mel_torch(
                 spec,
                 hps.data.filter_length,
@@ -454,6 +473,7 @@ def train_and_evaluate(
             y_mel = commons.slice_segments(
                 mel, ids_slice, hps.train.segment_size // hps.data.hop_length
             )
+
             with autocast(enabled=False):
                 y_hat_mel = mel_spectrogram_torch(
                     y_hat.float().squeeze(1),
@@ -465,26 +485,30 @@ def train_and_evaluate(
                     hps.data.mel_fmin,
                     hps.data.mel_fmax,
                 )
+
             if hps.train.fp16_run == True:
                 y_hat_mel = y_hat_mel.half()
+
             wave = commons.slice_segments(
                 wave, ids_slice * hps.data.hop_length, hps.train.segment_size
-            )
+            )
 
-            # Discriminator
+            # Discriminator forward
             y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
             with autocast(enabled=False):
                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                     y_d_hat_r, y_d_hat_g
                 )
+
+        # Discriminator backward
         optim_d.zero_grad()
         scaler.scale(loss_disc).backward()
         scaler.unscale_(optim_d)
         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
         scaler.step(optim_d)
 
+        # Generator forward
         with autocast(enabled=hps.train.fp16_run):
-            # Generator
             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
             with autocast(enabled=False):
                 loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
@@ -492,6 +516,8 @@ def train_and_evaluate(
                 loss_fm = feature_loss(fmap_r, fmap_g)
                 loss_gen, losses_gen = generator_loss(y_d_hat_g)
                 loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
+
+        # Generator backward
         optim_g.zero_grad()
         scaler.scale(loss_gen_all).backward()
         scaler.unscale_(optim_g)
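The discriminator and generator updates above follow the standard torch.cuda.amp recipe: forward under autocast, scale the loss before backward, unscale before gradient clipping, step each optimizer, and call update() once after both steps. A condensed sketch of that recipe with toy modules and made-up losses in place of the real networks:

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
fp16_run = torch.cuda.is_available()

net_d = nn.Linear(16, 1).to(device)   # stand-in discriminator
net_g = nn.Linear(16, 16).to(device)  # stand-in generator
optim_d = torch.optim.AdamW(net_d.parameters(), 1e-4)
optim_g = torch.optim.AdamW(net_g.parameters(), 1e-4)
scaler = GradScaler(enabled=fp16_run)

x = torch.randn(8, 16, device=device)

# Discriminator step: detach the generator output so only net_d gets gradients.
with autocast(enabled=fp16_run):
    fake = net_g(x)
    loss_disc = net_d(fake.detach()).pow(2).mean()

optim_d.zero_grad()
scaler.scale(loss_disc).backward()
scaler.unscale_(optim_d)                               # unscale so clipping sees true grads
torch.nn.utils.clip_grad_value_(net_d.parameters(), 1.0)
scaler.step(optim_d)

# Generator step.
with autocast(enabled=fp16_run):
    loss_gen = (1 - net_d(net_g(x))).pow(2).mean()

optim_g.zero_grad()
scaler.scale(loss_gen).backward()
scaler.unscale_(optim_g)
torch.nn.utils.clip_grad_value_(net_g.parameters(), 1.0)
scaler.step(optim_g)
scaler.update()  # one update() after both optimizers have stepped, as in the diff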
@@ -499,39 +525,43 @@ def train_and_evaluate(
         scaler.step(optim_g)
         scaler.update()
 
+        # Update progress bar and logging
         if rank == 0:
+            if epoch_progress is not None:
+                epoch_progress.update(1)
+
+            # Update progress bar description with current losses
+            if batch_idx % hps.train.log_interval == 0:
+                postfix_dict = {
+                    'G': f'{loss_gen_all:.3f}',
+                    'D': f'{loss_disc:.3f}',
+                    'Mel': f'{loss_mel:.3f}',
+                    'KL': f'{loss_kl:.3f}',
+                    'Step': global_step
+                }
+                epoch_progress.set_postfix(postfix_dict)
+
             if global_step % hps.train.log_interval == 0:
                 lr = optim_g.param_groups[0]["lr"]
-
-
-
-
-                )
-
-
-
-
-                    loss_kl = 9
-
-                logger.info([global_step, lr])
-                logger.info(
-                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
-                )
+
+                logger.info(f"\nEpoch: {epoch} [{batch_idx}/{len(train_loader)}]")
+                logger.info(f"Global Step: {global_step}")
+                logger.info(f"Learning Rate: {lr:.6f}")
+                logger.info(f"Generator Loss: {loss_gen_all:.3f} (FM: {loss_fm:.3f}, Mel: {loss_mel:.3f}, KL: {loss_kl:.3f})")
+                logger.info(f"Discriminator Loss: {loss_disc:.3f}")
+                logger.info(f"Grad Norm - G: {grad_norm_g:.3f}, D: {grad_norm_d:.3f}")
+
+                # Tensorboard logging
                 scalar_dict = {
                     "loss/g/total": loss_gen_all,
                     "loss/d/total": loss_disc,
                     "learning_rate": lr,
                     "grad_norm_d": grad_norm_d,
                     "grad_norm_g": grad_norm_g,
+                    "loss/g/fm": loss_fm,
+                    "loss/g/mel": loss_mel,
+                    "loss/g/kl": loss_kl,
                 }
-                scalar_dict.update(
-                    {
-                        "loss/g/fm": loss_fm,
-                        "loss/g/mel": loss_mel,
-                        "loss/g/kl": loss_kl,
-                    }
-                )
-
                 scalar_dict.update(
                     {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                 )
@@ -541,6 +571,7 @@ def train_and_evaluate(
                 scalar_dict.update(
                     {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                 )
+
                 image_dict = {
                     "slice/mel_org": utils.plot_spectrogram_to_numpy(
                         y_mel[0].data.cpu().numpy()
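The extra fm/mel/kl entries now land directly in scalar_dict, which utils.summarize forwards to the SummaryWriter opened at the start of run(). Stripped of the project helpers, the equivalent TensorBoard calls look roughly like this; the log directory and values are placeholders:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/example")
global_step = 100

scalar_dict = {
    "loss/g/total": 24.5,
    "loss/d/total": 3.2,
    "learning_rate": 1e-4,
    "loss/g/mel": 18.7,
}

# A summarize-style helper presumably just loops over the dict like this.
for k, v in scalar_dict.items():
    writer.add_scalar(k, v, global_step)
writer.close()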
@@ -552,89 +583,100 @@ def train_and_evaluate(
                         mel[0].data.cpu().numpy()
                     ),
                 }
+
                 utils.summarize(
                     writer=writer,
                     global_step=global_step,
                     images=image_dict,
                     scalars=scalar_dict,
                 )
+
         global_step += 1
-
-
+
+    # Close progress bar
+    if rank == 0 and epoch_progress is not None:
+        epoch_progress.close()
+
+    # Save checkpoints
     if epoch % hps.save_every_epoch == 0 and rank == 0:
         if hps.if_latest == 0:
+            save_path_g = os.path.join(hps.model_dir, f"G_{global_step}.pth")
+            save_path_d = os.path.join(hps.model_dir, f"D_{global_step}.pth")
             utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-
+                save_path_g,
             )
             utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
                 epoch,
-
+                save_path_d,
             )
+            logger.info(f"Saved checkpoints: {save_path_g}, {save_path_d}")
         else:
+            save_path_g = os.path.join(hps.model_dir, "G_2333333.pth")
+            save_path_d = os.path.join(hps.model_dir, "D_2333333.pth")
             utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-
+                save_path_g,
             )
             utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
                 epoch,
-
+                save_path_d,
            )
+            logger.info(f"Saved latest checkpoints: {save_path_g}, {save_path_d}")
+
     if rank == 0 and hps.save_every_weights == "1":
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-
-
-
-
-
-
-
-
-
-
-                    epoch,
-                    hps.version,
-                    hps,
-                ),
-            )
+
+        model_name = hps.name + f"_e{epoch}_s{global_step}"
+        save_result = savee(
+            ckpt,
+            hps.sample_rate,
+            hps.if_f0,
+            model_name,
+            epoch,
+            hps.version,
+            hps,
         )
+        logger.info(f"Saved weights checkpoint: {model_name}: {save_result}")
 
+    # Log epoch completion
     if rank == 0:
-        logger.info("
+        logger.info(f"Completed Epoch {epoch} {epoch_recorder.record()}")
+        logger.info(f"Global Step: {global_step}")
+
+    # End training if completed
     if epoch >= hps.total_epoch and rank == 0:
-        logger.info("Training
-
+        logger.info("Training completed!")
+
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-
-
-
-            savee(
-                ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
-            )
-        )
+
+        final_save = savee(
+            ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
         )
-
-
+        logger.info(f"Saved final model: {final_save}")
+
+        sleep(2)  # Give time for final logging
+        os._exit(0)
 
 
 if __name__ == "__main__":
     torch.multiprocessing.set_start_method("spawn")
-    main()
+    main()
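Checkpoint paths are now built explicitly: G_{global_step}.pth / D_{global_step}.pth when every save point is kept, or the fixed G_2333333.pth / D_2333333.pth pair when if_latest keeps a single rolling copy. A hedged sketch of that naming plus a resumable save; the dict layout below is an assumption for illustration, not necessarily what utils.save_checkpoint writes:

import os

import torch
import torch.nn as nn


def save_checkpoint(net, optim, learning_rate, epoch, path):
    # Assumed layout; the project's utils.save_checkpoint may store more fields.
    state = net.module.state_dict() if hasattr(net, "module") else net.state_dict()
    torch.save(
        {"model": state, "optimizer": optim.state_dict(),
         "learning_rate": learning_rate, "iteration": epoch},
        path,
    )


if __name__ == "__main__":
    model_dir = "logs/example"  # hypothetical directory
    os.makedirs(model_dir, exist_ok=True)
    net_g = nn.Linear(4, 4)
    optim_g = torch.optim.AdamW(net_g.parameters(), 1e-4)
    global_step, epoch, if_latest = 1000, 5, 0

    if if_latest == 0:
        save_path_g = os.path.join(model_dir, f"G_{global_step}.pth")  # one file per save point
    else:
        save_path_g = os.path.join(model_dir, "G_2333333.pth")         # rolling single file
    save_checkpoint(net_g, optim_g, 1e-4, epoch, save_path_g)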