Yoshitaka16 committed
Commit 3b097af (verified) · Parent(s): 8e26279

Delete train.py

Files changed (1):
    train.py +0 -995
train.py DELETED
@@ -1,995 +0,0 @@
import datetime
import glob
import json
import logging
import os
import shutil
import signal
import sys
from collections import deque
from random import randint, shuffle
from time import time as ttime

import numpy as np
from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Zluda hijack
import ultimate_rvc.rvc.lib.zluda
from ultimate_rvc.common import TRAINING_MODELS_DIR
from ultimate_rvc.rvc.common import RVC_TRAINING_MODELS_DIR
from ultimate_rvc.rvc.lib.algorithm import commons
from ultimate_rvc.rvc.train.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from ultimate_rvc.rvc.train.mel_processing import (
    MultiScaleMelSpectrogramLoss,
    mel_spectrogram_torch,
    spec_to_mel_torch,
)
from ultimate_rvc.rvc.train.process.extract_model import extract_model
from ultimate_rvc.rvc.train.utils import (
    HParams,
    latest_checkpoint_path,
    load_checkpoint,
    load_wav_to_torch,
    plot_spectrogram_to_numpy,
    remove_sox_libmso6_from_ld_preload,
    save_checkpoint,
    summarize,
)

logging.getLogger("torch").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_start_method("spawn", force=True)
os.environ["USE_LIBUV"] = "0" if sys.platform == "win32" else "1"

randomized = True
optimizer = "AdamW"  # "RAdam"
d_lr_coeff = 1.0
g_lr_coeff = 1.0
global_step = 0
lowest_g_value = {"value": float("inf"), "epoch": 0}
lowest_d_value = {"value": float("inf"), "epoch": 0}
consecutive_increases_gen = 0
consecutive_increases_disc = 0

avg_losses = {
    "grad_d_50": deque(maxlen=50),
    "grad_g_50": deque(maxlen=50),
    "disc_loss_50": deque(maxlen=50),
    "fm_loss_50": deque(maxlen=50),
    "kl_loss_50": deque(maxlen=50),
    "mel_loss_50": deque(maxlen=50),
    "gen_loss_50": deque(maxlen=50),
}
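# Each deque keeps the most recent 50 step values so TensorBoard can show
# smoothed rolling averages alongside the raw per-step losses.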


class EpochRecorder:
    """
    Records the time elapsed per epoch.
    """

    def __init__(self):
        self.last_time = ttime()

    def record(self):
        """
        Records the elapsed time and returns a formatted string.
        """
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time = round(elapsed_time, 1)
        elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        current_time = datetime.datetime.now().strftime("%H:%M:%S")
        return f"time={current_time} | speed={elapsed_time_str}"


def verify_checkpoint_shapes(checkpoint_path: str, model: torch.nn.Module) -> None:
    """
    Verify that the checkpoint and model have the same
    architecture.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    checkpoint_state_dict = checkpoint["model"]
    try:
        if hasattr(model, "module"):
            model_state_dict = model.module.load_state_dict(checkpoint_state_dict)
        else:
            model_state_dict = model.load_state_dict(checkpoint_state_dict)
    except RuntimeError:
        logger.error(  # noqa: TRY400
            "The parameters of the pretrained model, such as the sample rate or"
            " architecture, do not match the selected model.",
        )
        sys.exit(1)
    else:
        del checkpoint
        del checkpoint_state_dict
        del model_state_dict


def main(
    model_name: str,
    sample_rate: int,
    vocoder: str,
    total_epoch: int,
    batch_size: int,
    save_every_epoch: int,
    save_only_latest: bool,
    save_every_weights: bool,
    pretrain_g: str,
    pretrain_d: str,
    overtraining_detector: bool,
    overtraining_threshold: int,
    cleanup: bool,
    cache_data_in_gpu: bool,
    checkpointing: bool,
    device_type: str,
    gpus: set[int] | None,
) -> None:
    """
    Start the training process.

    Raises:
        RuntimeError: If the sample rate of the pretrained model does not match the dataset audio sample rate.

    """
    remove_sox_libmso6_from_ld_preload()
    experiment_dir = os.path.join(TRAINING_MODELS_DIR, model_name)
    config_save_path = os.path.join(experiment_dir, "config.json")

    # Use a Manager to create shared lists
    manager = mp.Manager()
    global_gen_loss = manager.list([0] * total_epoch)
    global_disc_loss = manager.list([0] * total_epoch)

    with open(config_save_path) as f:
        config = json.load(f)
    config = HParams(**config)
    config.data.training_files = os.path.join(experiment_dir, "filelist.txt")

    # Set up the distributed training environment for the master node.
    # The master node is localhost because we are running on a single
    # local machine; the master port is randomly selected.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    logger.info("MASTER_PORT: %s", os.environ["MASTER_PORT"])
    # Check sample rate
    wavs = glob.glob(
        os.path.join(experiment_dir, "sliced_audios", "*.wav"),
    )
    if wavs:
        _, sr = load_wav_to_torch(wavs[0])
        if sr != sample_rate:
            error_message = (
                f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match"
                f" dataset audio sample rate ({sr} Hz)."
            )
            raise RuntimeError(error_message)
    else:
        logger.warning("No wav file found.")

    device = torch.device(device_type)
    gpus = gpus or {0}
    n_gpus = len(gpus)
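    # One training process is spawned per requested GPU (see start() below);
    # CPU runs use a single process.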
189
-
190
- if device.type == "cpu":
191
- logger.warning("Training with CPU, this will take a long time.")
192
-
193
- def start() -> None:
194
- """Start the training process with multi-GPU support or CPU."""
195
- children = []
196
- pid_data = {"process_pids": []}
197
- with open(config_save_path) as pid_file:
198
- try:
199
- existing_data = json.load(pid_file)
200
- pid_data.update(existing_data)
201
- except json.JSONDecodeError:
202
- pass
203
- with open(config_save_path, "w") as pid_file:
204
- for rank, device_id in enumerate(gpus):
205
- subproc = mp.Process(
206
- target=run,
207
- args=(
208
- rank,
209
- n_gpus,
210
- experiment_dir,
211
- pretrain_g,
212
- pretrain_d,
213
- total_epoch,
214
- save_every_weights,
215
- config,
216
- device,
217
- device_id,
218
- model_name,
219
- sample_rate,
220
- vocoder,
221
- batch_size,
222
- save_every_epoch,
223
- save_only_latest,
224
- overtraining_detector,
225
- overtraining_threshold,
226
- checkpointing,
227
- cache_data_in_gpu,
228
- global_gen_loss,
229
- global_disc_loss,
230
- ),
231
- )
232
- children.append(subproc)
233
- subproc.start()
234
- pid_data["process_pids"].append(subproc.pid)
235
- json.dump(pid_data, pid_file, indent=4)
236
- cancel_signal = signal.SIGTERM if os.name == "nt" else -signal.SIGTERM
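        # A child that exited with cancel_signal was terminated via SIGTERM;
        # such exits are treated as user cancellation, not training failures.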
        error_codes = []
        for i, dev_id in enumerate(gpus):
            children[i].join()
            exit_code = children[i].exitcode
            if exit_code != 0:
                logger.warning(
                    "Process running on device %s exited with code %s.",
                    dev_id,
                    exit_code,
                )
                if exit_code != cancel_signal:
                    error_codes.append(exit_code)
        if error_codes:
            err_msg = (
                "One or more training processes failed. See the logs or console for"
                " details."
            )
            raise RuntimeError(err_msg)

    if cleanup:
        logger.info("Removing files from the prior training attempt...")

        # Clean up unnecessary files
        for entry in os.scandir(os.path.join(TRAINING_MODELS_DIR, model_name)):
            if entry.is_file():
                _, file_extension = os.path.splitext(entry.name)
                if file_extension in {".0", ".pth", ".index"}:
                    os.remove(entry.path)
            elif entry.is_dir() and entry.name == "eval":
                shutil.rmtree(entry.path)

        logger.info("Cleanup done!")
    start()


def run(
    rank,
    n_gpus,
    experiment_dir,
    pretrain_g,
    pretrain_d,
    custom_total_epoch,
    custom_save_every_weights,
    config,
    device,
    device_id,
    model_name,
    sample_rate,
    vocoder,
    batch_size,
    save_every_epoch,
    save_only_latest,
    overtraining_detector,
    overtraining_threshold,
    checkpointing,
    cache_data_in_gpu,
    global_gen_loss,
    global_disc_loss,
):
    """
    Runs the training loop on a specific GPU or CPU.

    Args:
        rank (int): The rank of the current process within the distributed training setup.
        n_gpus (int): The total number of GPUs available for training.
        experiment_dir (str): The directory where experiment logs and checkpoints will be saved.
        pretrain_g (str): Path to the pre-trained generator model.
        pretrain_d (str): Path to the pre-trained discriminator model.
        custom_total_epoch (int): The total number of epochs for training.
        custom_save_every_weights (int): The interval (in epochs) at which to save model weights.
        config (object): Configuration object containing training parameters.
        device (torch.device): The device to use for training (CPU or GPU).

    """
    global global_step, optimizer, lowest_d_value, lowest_g_value, consecutive_increases_gen, consecutive_increases_disc

    if rank == 0:
        writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
    else:
        writer_eval = None

    # Initialize distributed training environment for child node.
    dist.init_process_group(
        backend="gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl",
        init_method="env://",
        world_size=n_gpus if device.type == "cuda" else 1,
        rank=rank if device.type == "cuda" else 0,
    )

    torch.manual_seed(config.train.seed)

    if device.type == "cuda":
        torch.cuda.set_device(device_id)

    # Create datasets and dataloaders
    from ultimate_rvc.rvc.train.data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid,
    )

    train_dataset = TextAudioLoaderMultiNSFsid(config.data)
    collate_fn = TextAudioCollateMultiNSFsid()
    train_sampler = DistributedBucketSampler(
        train_dataset,
        batch_size * n_gpus,
        [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
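    # The bucket boundaries above group utterances of similar length into the
    # same batch to minimize padding.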

    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )
    if len(train_loader) < 3:
        logger.error(
            "Not enough data in the training set. Perhaps you forgot to slice the"
            " audio files in preprocess?",
        )
        sys.exit(1)
    else:
        g_file = latest_checkpoint_path(experiment_dir, "G_*.pth")
        if g_file is not None:
            logger.info("Checking saved weights...")
            g = torch.load(g_file, map_location="cpu", weights_only=False)
            if (
                optimizer == "RAdam"
                and "amsgrad" in g["optimizer"]["param_groups"][0]
            ):
                optimizer = "AdamW"
                logger.info(
                    "Optimizer choice has been reverted to %s to match the saved D/G"
                    " weights.",
                    optimizer,
                )
            elif (
                optimizer == "AdamW"
                and "decoupled_weight_decay" in g["optimizer"]["param_groups"][0]
            ):
                optimizer = "RAdam"
                logger.info(
                    "Optimizer choice has been reverted to %s to match the saved D/G"
                    " weights.",
                    optimizer,
                )
            del g

    # Initialize models and optimizers
    from ultimate_rvc.rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
    from ultimate_rvc.rvc.lib.algorithm.synthesizers import Synthesizer

    # NOTE: checkpointing here controls gradient (activation) checkpointing:
    # when enabled, activations are recomputed during the backward pass
    # instead of being kept in memory from the forward pass.

    net_g = Synthesizer(
        config.data.filter_length // 2 + 1,
        config.train.segment_size // config.data.hop_length,
        **config.model,
        use_f0=True,
        sr=sample_rate,
        vocoder=vocoder,
        checkpointing=checkpointing,
        randomized=randomized,
    )

    net_d = MultiPeriodDiscriminator(
        config.model.use_spectral_norm,
        checkpointing=checkpointing,
    )

    if device.type == "cuda":
        net_g = net_g.cuda(device_id)
        net_d = net_d.cuda(device_id)
    else:
        net_g = net_g.to(device)
        net_d = net_d.to(device)

    if optimizer == "AdamW":
        optimizer = torch.optim.AdamW
    elif optimizer == "RAdam":
        optimizer = torch.optim.RAdam

    optim_g = optimizer(
        net_g.parameters(),
        config.train.learning_rate * g_lr_coeff,
        betas=config.train.betas,
        eps=config.train.eps,
    )
    optim_d = optimizer(
        net_d.parameters(),
        config.train.learning_rate * d_lr_coeff,
        betas=config.train.betas,
        eps=config.train.eps,
    )

    fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)

    # Wrap models with DDP for multi-GPU processing
    if n_gpus > 1 and device.type == "cuda":
        net_g = DDP(net_g, device_ids=[device_id])
        net_d = DDP(net_d, device_ids=[device_id])

    # Load checkpoint if available
    try:
        logger.info("Starting training...")
        _, _, _, epoch_str, lowest_d_value, consecutive_increases_disc = (
            load_checkpoint(
                latest_checkpoint_path(experiment_dir, "D_*.pth"),
                net_d,
                optim_d,
            )
        )
        _, _, _, epoch_str, lowest_g_value, consecutive_increases_gen = load_checkpoint(
            latest_checkpoint_path(experiment_dir, "G_*.pth"),
            net_g,
            optim_g,
        )
        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)

    except Exception:
        epoch_str = 1
        global_step = 0
        if pretrain_g not in {"", "None"}:
            if rank == 0:
                verify_checkpoint_shapes(pretrain_g, net_g)
                logger.info("Loaded pretrained (G) '%s'", pretrain_g)
            if hasattr(net_g, "module"):
                net_g.module.load_state_dict(
                    torch.load(pretrain_g, map_location="cpu", weights_only=False)[
                        "model"
                    ],
                )
            else:
                net_g.load_state_dict(
                    torch.load(pretrain_g, map_location="cpu", weights_only=False)[
                        "model"
                    ],
                )

        if pretrain_d not in {"", "None"}:
            if rank == 0:
                verify_checkpoint_shapes(pretrain_d, net_d)
                logger.info("Loaded pretrained (D) '%s'", pretrain_d)
            if hasattr(net_d, "module"):
                net_d.module.load_state_dict(
                    torch.load(pretrain_d, map_location="cpu", weights_only=False)[
                        "model"
                    ],
                )
            else:
                net_d.load_state_dict(
                    torch.load(pretrain_d, map_location="cpu", weights_only=False)[
                        "model"
                    ],
                )

    # Initialize schedulers
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g,
        gamma=config.train.lr_decay,
        last_epoch=epoch_str - 2,
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d,
        gamma=config.train.lr_decay,
        last_epoch=epoch_str - 2,
    )
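    # last_epoch=epoch_str - 2 resumes the exponential LR decay from the
    # checkpointed epoch instead of restarting the schedule.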

    cache = []
    # get the first sample as reference for tensorboard evaluation
    # custom reference temporarily disabled
    if False and os.path.isfile(
        os.path.join(RVC_TRAINING_MODELS_DIR, "reference", f"ref{sample_rate}.wav"),
    ):
        phone = np.load(
            os.path.join(
                RVC_TRAINING_MODELS_DIR,
                "reference",
                f"ref{sample_rate}_feats.npy",
            ),
        )
        # expanding x2 to match pitch size
        phone = np.repeat(phone, 2, axis=0)
        phone = torch.FloatTensor(phone).unsqueeze(0).to(device)
        phone_lengths = torch.LongTensor(phone.size(0)).to(device)
        pitch = np.load(
            os.path.join(
                RVC_TRAINING_MODELS_DIR,
                "reference",
                f"ref{sample_rate}_f0c.npy",
            ),
        )
        # removed last frame to match features
        pitch = torch.LongTensor(pitch[:-1]).unsqueeze(0).to(device)
        pitchf = np.load(
            os.path.join(
                RVC_TRAINING_MODELS_DIR,
                "reference",
                f"ref{sample_rate}_f0f.npy",
            ),
        )
        # removed last frame to match features
        pitchf = torch.FloatTensor(pitchf[:-1]).unsqueeze(0).to(device)
        sid = torch.LongTensor([0]).to(device)
        reference = (
            phone,
            phone_lengths,
            pitch,
            pitchf,
            sid,
        )
    else:
        for info in train_loader:
            phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
            if device.type == "cuda":
                reference = (
                    phone.cuda(device_id, non_blocking=True),
                    phone_lengths.cuda(device_id, non_blocking=True),
                    pitch.cuda(device_id, non_blocking=True),
                    pitchf.cuda(device_id, non_blocking=True),
                    sid.cuda(device_id, non_blocking=True),
                )
            else:
                reference = (
                    phone.to(device),
                    phone_lengths.to(device),
                    pitch.to(device),
                    pitchf.to(device),
                    sid.to(device),
                )
            break
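            # Only the first batch is needed; it becomes the fixed reference
            # used for the periodic TensorBoard audio previews.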

    for epoch in range(epoch_str, custom_total_epoch + 1):
        train_and_evaluate(
            rank,
            epoch,
            config,
            [net_g, net_d],
            [optim_g, optim_d],
            [scheduler_g, scheduler_d],
            [train_loader, None],
            [writer_eval],
            cache,
            custom_save_every_weights,
            custom_total_epoch,
            device,
            device_id,
            reference,
            fn_mel_loss,
            model_name,
            experiment_dir,
            sample_rate,
            vocoder,
            save_every_epoch,
            save_only_latest,
            overtraining_detector,
            overtraining_threshold,
            cache_data_in_gpu,
            global_gen_loss,
            global_disc_loss,
        )


def train_and_evaluate(
    rank,
    epoch,
    config,
    nets,
    optims,
    schedulers,
    loaders,
    writers,
    cache,
    custom_save_every_weights,
    custom_total_epoch,
    device,
    device_id,
    reference,
    fn_mel_loss,
    model_name,
    experiment_dir,
    sample_rate,
    vocoder,
    save_every_epoch,
    save_only_latest,
    overtraining_detector,
    overtraining_threshold,
    cache_data_in_gpu,
    global_gen_loss,
    global_disc_loss,
) -> None:
    """Train and evaluate the model for one epoch."""
    global global_step, lowest_g_value, lowest_d_value, consecutive_increases_gen, consecutive_increases_disc

    model_add = []
    checkpoint_idxs = []
    done = False

    net_g, net_d = nets
    optim_g, optim_d = optims
    scheduler_g, scheduler_d = schedulers
    skip_g_scheduler, skip_d_scheduler = False, False
    train_loader = loaders[0] if loaders is not None else None
    if writers is not None:
        writer = writers[0]

    train_loader.batch_sampler.set_epoch(epoch)
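    # set_epoch reseeds the bucket sampler so each epoch sees a different
    # shuffle order while all ranks stay in sync.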

    net_g.train()
    net_d.train()

    # Data caching
    if device.type == "cuda" and cache_data_in_gpu:
        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                # phone, phone_lengths, pitch, pitchf, spec, spec_lengths,
                # wave, wave_lengths, sid
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
                cache.append((batch_idx, info))
        shuffle(cache)
        data_iterator = cache
    else:
        data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()
    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:
            if device.type == "cuda" and not cache_data_in_gpu:
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
            elif device.type != "cuda":
                info = [tensor.to(device) for tensor in info]
            # else the iterator is going through a cached list with a device
            # already assigned

            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info

            # Forward pass
            model_output = net_g(
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                sid,
            )
            y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (
                model_output
            )
            # slice the original waveform to match the generated slice
            if randomized:
                wave = commons.slice_segments(
                    wave,
                    ids_slice * config.data.hop_length,
                    config.train.segment_size,
                    dim=3,
                )
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
            # Discriminator backward and update
            global_disc_loss[epoch - 1] += loss_disc.item()
            optim_d.zero_grad()
            loss_disc.backward()
            grad_norm_d = commons.grad_norm(net_d.parameters())
            optim_d.step()

            # Generator backward and update
            _, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _ = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
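            # Total generator objective: adversarial + feature matching + mel
            # reconstruction + KL divergence.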
            global_gen_loss[epoch - 1] += loss_gen_all.item()
            optim_g.zero_grad()
            loss_gen_all.backward()
            grad_norm_g = commons.grad_norm(net_g.parameters())
            optim_g.step()

            global_step += 1

            # queue for rolling losses over 50 steps
            avg_losses["grad_d_50"].append(grad_norm_d)
            avg_losses["grad_g_50"].append(grad_norm_g)
            avg_losses["disc_loss_50"].append(loss_disc.detach())
            avg_losses["fm_loss_50"].append(loss_fm.detach())
            avg_losses["kl_loss_50"].append(loss_kl.detach())
            avg_losses["mel_loss_50"].append(loss_mel.detach())
            avg_losses["gen_loss_50"].append(loss_gen_all.detach())

            if rank == 0 and global_step % 50 == 0:
                # logging rolling averages
                scalar_dict = {
                    "grad_avg_50/norm_d": (
                        sum(avg_losses["grad_d_50"]) / len(avg_losses["grad_d_50"])
                    ),
                    "grad_avg_50/norm_g": (
                        sum(avg_losses["grad_g_50"]) / len(avg_losses["grad_g_50"])
                    ),
                    "loss_avg_50/d/total": torch.mean(
                        torch.stack(list(avg_losses["disc_loss_50"])),
                    ),
                    "loss_avg_50/g/fm": torch.mean(
                        torch.stack(list(avg_losses["fm_loss_50"])),
                    ),
                    "loss_avg_50/g/kl": torch.mean(
                        torch.stack(list(avg_losses["kl_loss_50"])),
                    ),
                    "loss_avg_50/g/mel": torch.mean(
                        torch.stack(list(avg_losses["mel_loss_50"])),
                    ),
                    "loss_avg_50/g/total": torch.mean(
                        torch.stack(list(avg_losses["gen_loss_50"])),
                    ),
                }
                summarize(
                    writer=writer,
                    global_step=global_step,
                    scalars=scalar_dict,
                )

            pbar.update(1)
            # end of batch train
        # end of tqdm
    scheduler_d.step()
    scheduler_g.step()

    with torch.no_grad():
        torch.cuda.empty_cache()
    # Logging and checkpointing
    if rank == 0:
        avg_global_disc_loss = global_disc_loss[epoch - 1] / len(train_loader.dataset)
        avg_global_gen_loss = global_gen_loss[epoch - 1] / len(train_loader.dataset)

        min_delta = 0.004
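        # An epoch counts as an improvement only if it beats the best loss so
        # far by at least min_delta; anything less advances the overtraining
        # counters below.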
789
-
790
- if avg_global_disc_loss < lowest_d_value["value"] - min_delta:
791
- lowest_d_value = {"value": avg_global_disc_loss, "epoch": epoch}
792
- consecutive_increases_disc = 0
793
- else:
794
- consecutive_increases_disc += 1
795
-
796
- if avg_global_gen_loss < lowest_g_value["value"] - min_delta:
797
- logger.info(
798
- "New best epoch %d with average generator loss %.3f and discriminator"
799
- " loss %.3f",
800
- epoch,
801
- avg_global_gen_loss,
802
- avg_global_disc_loss,
803
- )
804
- lowest_g_value = {"value": avg_global_gen_loss, "epoch": epoch}
805
- consecutive_increases_gen = 0
806
- model_add.append(
807
- os.path.join(experiment_dir, f"{model_name}_best.pth"),
808
- )
809
- else:
810
- consecutive_increases_gen += 1
811
-
812
- # used for tensorboard chart - all/mel
813
- mel = spec_to_mel_torch(
814
- spec,
815
- config.data.filter_length,
816
- config.data.n_mel_channels,
817
- config.data.sample_rate,
818
- config.data.mel_fmin,
819
- config.data.mel_fmax,
820
- )
821
- # used for tensorboard chart - slice/mel_org
822
- if randomized:
823
- y_mel = commons.slice_segments(
824
- mel,
825
- ids_slice,
826
- config.train.segment_size // config.data.hop_length,
827
- dim=3,
828
- )
829
- else:
830
- y_mel = mel
831
- # used for tensorboard chart - slice/mel_gen
832
- y_hat_mel = mel_spectrogram_torch(
833
- y_hat.float().squeeze(1),
834
- config.data.filter_length,
835
- config.data.n_mel_channels,
836
- config.data.sample_rate,
837
- config.data.hop_length,
838
- config.data.win_length,
839
- config.data.mel_fmin,
840
- config.data.mel_fmax,
841
- )
842
-
843
- lr = optim_g.param_groups[0]["lr"]
844
-
845
- scalar_dict = {
846
- "loss/g/total": loss_gen_all,
847
- "loss/d/total": loss_disc,
848
- "learning_rate": lr,
849
- "grad/norm_d": grad_norm_d,
850
- "grad/norm_g": grad_norm_g,
851
- "loss/g/fm": loss_fm,
852
- "loss/g/mel": loss_mel,
853
- "loss/g/kl": loss_kl,
854
- }
855
-
856
- image_dict = {
857
- "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
858
- "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
859
- "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
860
- }
861
- overtrain_info = ""
862
- # Print training progress
863
- lowest_g_value_rounded = float(lowest_g_value["value"])
864
- lowest_g_value_rounded = round(lowest_g_value_rounded, 3)
865
-
866
- record = f"{model_name} | epoch={epoch} | {epoch_recorder.record()}"
867
- record += (
868
- f" | best avg-gen-loss={lowest_g_value_rounded:.3f} (epoch"
869
- f" {lowest_g_value['epoch']})"
870
- )
871
- # Check overtraining
872
- if overtraining_detector:
873
- overtrain_info = (
874
- f"Average epoch generator loss {avg_global_gen_loss:.3f} and"
875
- f" discriminator loss {avg_global_disc_loss:.3f}"
876
- )
877
-
878
- remaining_epochs_gen = max(
879
- overtraining_threshold - consecutive_increases_gen,
880
- 0,
881
- )
882
- remaining_epochs_disc = max(
883
- overtraining_threshold * 2 - consecutive_increases_disc,
884
- 0,
885
- )
886
- record += (
887
- " | overtrain countdown: g="
888
- f"{remaining_epochs_gen},d={remaining_epochs_disc} |"
889
- f" avg-gen-loss={avg_global_gen_loss:.3f} | avg-"
890
- f"disc-loss={avg_global_disc_loss:.3f}"
891
- )
892
-
893
- if remaining_epochs_disc == 0 or remaining_epochs_gen == 0:
894
- record += (
895
- f"\nOvertraining detected at epoch {epoch} with average"
896
- f" generator loss {avg_global_gen_loss:.3f} and discriminator loss"
897
- f" {avg_global_disc_loss:.3f}"
898
- )
899
- done = True
900
- print(record)
901
-
902
- # Save weights, checkpoints and reference inference results
903
- # every N epochs
904
- if epoch % save_every_epoch == 0:
905
- with torch.no_grad():
906
- if hasattr(net_g, "module"):
907
- o, *_ = net_g.module.infer(*reference)
908
- else:
909
- o, *_ = net_g.infer(*reference)
910
- audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}
911
- summarize(
912
- writer=writer,
913
- global_step=global_step,
914
- images=image_dict,
915
- scalars=scalar_dict,
916
- audios=audio_dict,
917
- audio_sample_rate=config.data.sample_rate,
918
- )
919
- checkpoint_idxs.append(2333333)
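            # 2333333 is the fixed index used for the rolling "latest" G/D
            # checkpoint files (G_2333333.pth / D_2333333.pth), overwritten on
            # each save.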
            if not save_only_latest:
                checkpoint_idxs.append(epoch)

            if custom_save_every_weights:
                model_add.append(
                    os.path.join(experiment_dir, f"{model_name}_{epoch}.pth"),
                )
        else:
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                scalars=scalar_dict,
            )
        for idx in checkpoint_idxs:
            save_checkpoint(
                net_g,
                optim_g,
                config.train.learning_rate,
                epoch,
                lowest_g_value,
                consecutive_increases_gen,
                os.path.join(experiment_dir, f"G_{idx}.pth"),
            )
            save_checkpoint(
                net_d,
                optim_d,
                config.train.learning_rate,
                epoch,
                lowest_d_value,
                consecutive_increases_disc,
                os.path.join(experiment_dir, f"D_{idx}.pth"),
            )
        if model_add:
            ckpt = (
                net_g.module.state_dict()
                if hasattr(net_g, "module")
                else net_g.state_dict()
            )
            for m in model_add:
                extract_model(
                    ckpt=ckpt,
                    sr=sample_rate,
                    name=model_name,
                    model_dir=m,
                    epoch=epoch,
                    step=global_step,
                    hps=config,
                    overtrain_info=overtrain_info,
                    vocoder=vocoder,
                )
        # Check completion
        if epoch >= custom_total_epoch:
            lowest_g_value_rounded = float(lowest_g_value["value"])
            lowest_g_value_rounded = round(lowest_g_value_rounded, 3)
            print(
                f"Training has been successfully completed with {epoch} epoch(s),"
                f" {global_step} step(s) and {round(avg_global_gen_loss, 3)} average"
                " generator loss.",
            )
            print(
                f"Lowest average generator loss: {lowest_g_value_rounded} at epoch"
                f" {lowest_g_value['epoch']}",
            )

            done = True
    with torch.no_grad():
        torch.cuda.empty_cache()
    if done:
        pid_file_path = os.path.join(experiment_dir, "config.json")
        with open(pid_file_path) as pid_file:
            pid_data = json.load(pid_file)
        with open(pid_file_path, "w") as pid_file:
            pid_data.pop("process_pids", None)
            json.dump(pid_data, pid_file, indent=4)
        os._exit(0)