Yoshitaka16 committed
Commit 4f471db · verified · 1 Parent(s): c159a35

Upload train.py

Files changed (1)
  1. train.py +995 -0
train.py ADDED
@@ -0,0 +1,995 @@
1
+ import datetime
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import shutil
7
+ import signal
8
+ import sys
9
+ from collections import deque
10
+ from random import randint, shuffle
11
+ from time import time as ttime
12
+
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ import torch.multiprocessing as mp
19
+ from torch.nn.parallel import DistributedDataParallel as DDP
20
+ from torch.utils.data import DataLoader
21
+ from torch.utils.tensorboard import SummaryWriter
22
+
23
+ # Zluda hijack
24
+ import ultimate_rvc.rvc.lib.zluda
25
+ from ultimate_rvc.common import TRAINING_MODELS_DIR
26
+ from ultimate_rvc.rvc.common import RVC_TRAINING_MODELS_DIR
27
+ from ultimate_rvc.rvc.lib.algorithm import commons
28
+ from ultimate_rvc.rvc.train.losses import (
29
+ discriminator_loss,
30
+ feature_loss,
31
+ generator_loss,
32
+ kl_loss,
33
+ )
34
+ from ultimate_rvc.rvc.train.mel_processing import (
35
+ MultiScaleMelSpectrogramLoss,
36
+ mel_spectrogram_torch,
37
+ spec_to_mel_torch,
38
+ )
39
+ from ultimate_rvc.rvc.train.process.extract_model import extract_model
40
+ from ultimate_rvc.rvc.train.utils import (
41
+ HParams,
42
+ latest_checkpoint_path,
43
+ load_checkpoint,
44
+ load_wav_to_torch,
45
+ plot_spectrogram_to_numpy,
46
+ remove_sox_libmso6_from_ld_preload,
47
+ save_checkpoint,
48
+ summarize,
49
+ )
50
+
51
+ logging.getLogger("torch").setLevel(logging.ERROR)
52
+ logger = logging.getLogger(__name__)
53
+
54
+ torch.backends.cudnn.deterministic = False
55
+ torch.backends.cudnn.benchmark = True
56
+ torch.multiprocessing.set_start_method("spawn", force=True)
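+ # The USE_LIBUV toggle below works around newer PyTorch defaulting its
+ # TCPStore rendezvous to libuv, which is unreliable in some Windows builds.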
57
+ os.environ["USE_LIBUV"] = "0" if sys.platform == "win32" else "1"
58
+
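+ # Module-level training knobs: `randomized` selects random segment slicing
+ # during training, and the *_lr_coeff values scale the base learning rate
+ # for the discriminator and generator respectively.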
59
+ randomized = True
60
+ optimizer = "AdamW" # "RAdam"
61
+ d_lr_coeff = 1.0
62
+ g_lr_coeff = 1.0
63
+ global_step = 0
64
+ lowest_g_value = {"value": float("inf"), "epoch": 0}
65
+ lowest_d_value = {"value": float("inf"), "epoch": 0}
66
+ consecutive_increases_gen = 0
67
+ consecutive_increases_disc = 0
68
+
69
+ avg_losses = {
70
+ "grad_d_50": deque(maxlen=50),
71
+ "grad_g_50": deque(maxlen=50),
72
+ "disc_loss_50": deque(maxlen=50),
73
+ "fm_loss_50": deque(maxlen=50),
74
+ "kl_loss_50": deque(maxlen=50),
75
+ "mel_loss_50": deque(maxlen=50),
76
+ "gen_loss_50": deque(maxlen=50),
77
+ }
78
+
79
+
80
+ class EpochRecorder:
81
+ """
82
+ Records the time elapsed per epoch.
83
+ """
84
+
85
+ def __init__(self):
86
+ self.last_time = ttime()
87
+
88
+ def record(self):
89
+ """
90
+ Records the elapsed time and returns a formatted string.
91
+ """
92
+ now_time = ttime()
93
+ elapsed_time = now_time - self.last_time
94
+ self.last_time = now_time
95
+ elapsed_time = round(elapsed_time, 1)
96
+ elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
97
+ current_time = datetime.datetime.now().strftime("%H:%M:%S")
98
+ return f"time={current_time} | speed={elapsed_time_str}"
99
+
100
+
101
+ def verify_checkpoint_shapes(checkpoint_path: str, model: torch.nn.Module) -> None:
102
+ """
103
+ Verify that the checkpoint and model have the same
104
+ architecture.
105
+ """
106
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
107
+ checkpoint_state_dict = checkpoint["model"]
108
+ try:
109
+ if hasattr(model, "module"):
110
+ model_state_dict = model.module.load_state_dict(checkpoint_state_dict)
111
+ else:
112
+ model_state_dict = model.load_state_dict(checkpoint_state_dict)
113
+ except RuntimeError:
114
+ logger.error( # noqa: TRY400
115
+ "The parameters of the pretrain model such as the sample rate or"
116
+ " architecture do not match the selected model.",
117
+ )
118
+ sys.exit(1)
119
+ else:
120
+ del checkpoint
121
+ del checkpoint_state_dict
122
+ del model_state_dict
123
+
124
+
125
+ def main(
126
+ model_name: str,
127
+ sample_rate: int,
128
+ vocoder: str,
129
+ total_epoch: int,
130
+ batch_size: int,
131
+ save_every_epoch: int,
132
+ save_only_latest: bool,
133
+ save_every_weights: bool,
134
+ pretrain_g: str,
135
+ pretrain_d: str,
136
+ overtraining_detector: bool,
137
+ overtraining_threshold: int,
138
+ cleanup: bool,
139
+ cache_data_in_gpu: bool,
140
+ checkpointing: bool,
141
+ device_type: str,
142
+ gpus: set[int] | None,
143
+ ) -> None:
144
+ """
145
+ Start the training process.
146
+
147
+ Raises:
148
+ RuntimeError: If the sample rate of the pretrained model does not match the dataset audio sample rate.
149
+
150
+ """
151
+ remove_sox_libmso6_from_ld_preload()
152
+ experiment_dir = os.path.join(TRAINING_MODELS_DIR, model_name)
153
+ config_save_path = os.path.join(experiment_dir, "config.json")
154
+
155
+ # Use a Manager to share the per-epoch loss accumulators across training processes
156
+ manager = mp.Manager()
157
+ global_gen_loss = manager.list([0] * total_epoch)
158
+ global_disc_loss = manager.list([0] * total_epoch)
159
+
160
+ with open(config_save_path) as f:
161
+ config = json.load(f)
162
+ config = HParams(**config)
163
+ config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
164
+
165
+ # Set up distributed training environment for master node.
166
+ # master node is localhost because we are running on a single local
167
+ # machine. master port is randomly selected
168
+ os.environ["MASTER_ADDR"] = "localhost"
169
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
170
+ logger.info("MASTER_PORT: %s", os.environ["MASTER_PORT"])
171
+ # Check sample rate
172
+ wavs = glob.glob(
173
+ os.path.join(experiment_dir, "sliced_audios", "*.wav"),
174
+ )
175
+ if wavs:
176
+ _, sr = load_wav_to_torch(wavs[0])
177
+ if sr != sample_rate:
178
+ error_message = (
179
+ f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match"
180
+ f" dataset audio sample rate ({sr} Hz)."
181
+ )
182
+ raise RuntimeError(error_message)
183
+ else:
184
+ logger.warning("No wav file found.")
185
+
186
+ device = torch.device(device_type)
187
+ gpus = gpus or {0}
188
+ n_gpus = len(gpus)
189
+
190
+ if device.type == "cpu":
191
+ logger.warning("Training with CPU, this will take a long time.")
192
+
193
+ def start() -> None:
194
+ """Start the training process with multi-GPU support or CPU."""
195
+ children = []
196
+ pid_data = {"process_pids": []}
197
+ with open(config_save_path) as pid_file:
198
+ try:
199
+ existing_data = json.load(pid_file)
200
+ pid_data.update(existing_data)
201
+ except json.JSONDecodeError:
202
+ pass
203
+ with open(config_save_path, "w") as pid_file:
204
+ for rank, device_id in enumerate(gpus):
205
+ subproc = mp.Process(
206
+ target=run,
207
+ args=(
208
+ rank,
209
+ n_gpus,
210
+ experiment_dir,
211
+ pretrain_g,
212
+ pretrain_d,
213
+ total_epoch,
214
+ save_every_weights,
215
+ config,
216
+ device,
217
+ device_id,
218
+ model_name,
219
+ sample_rate,
220
+ vocoder,
221
+ batch_size,
222
+ save_every_epoch,
223
+ save_only_latest,
224
+ overtraining_detector,
225
+ overtraining_threshold,
226
+ checkpointing,
227
+ cache_data_in_gpu,
228
+ global_gen_loss,
229
+ global_disc_loss,
230
+ ),
231
+ )
232
+ children.append(subproc)
233
+ subproc.start()
234
+ pid_data["process_pids"].append(subproc.pid)
235
+ json.dump(pid_data, pid_file, indent=4)
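+ # A child stopped via SIGTERM (user cancel) reports that signal as its exit
+ # code (negative on POSIX), so it is not counted as a training failure below.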
236
+ cancel_signal = signal.SIGTERM if os.name == "nt" else -signal.SIGTERM
237
+ error_codes = []
238
+ for i in range(n_gpus):
239
+ children[i].join()
240
+ exit_code = children[i].exitcode
241
+ if exit_code != 0:
242
+ logger.warning(
243
+ "Training process %d exited with code %s.",
244
+ i,
245
+ exit_code,
246
+ )
247
+ if exit_code != cancel_signal:
248
+ error_codes.append(exit_code)
249
+ if error_codes:
250
+ err_msg = (
251
+ "One or more training processes failed. See the logs or console for"
252
+ " details."
253
+ )
254
+ raise RuntimeError(err_msg)
255
+
256
+ if cleanup:
257
+ logger.info("Removing files from the prior training attempt...")
258
+
259
+ # Clean up unnecessary files
260
+ for entry in os.scandir(os.path.join(TRAINING_MODELS_DIR, model_name)):
261
+ if entry.is_file():
262
+ _, file_extension = os.path.splitext(entry.name)
263
+ if file_extension in {".0", ".pth", ".index"}:
264
+ os.remove(entry.path)
265
+ elif entry.is_dir() and entry.name == "eval":
266
+ shutil.rmtree(entry.path)
267
+
268
+ logger.info("Cleanup done!")
269
+ start()
270
+
271
+
272
+ def run(
273
+ rank,
274
+ n_gpus,
275
+ experiment_dir,
276
+ pretrain_g,
277
+ pretrain_d,
278
+ custom_total_epoch,
279
+ custom_save_every_weights,
280
+ config,
281
+ device,
282
+ device_id,
283
+ model_name,
284
+ sample_rate,
285
+ vocoder,
286
+ batch_size,
287
+ save_every_epoch,
288
+ save_only_latest,
289
+ overtraining_detector,
290
+ overtraining_threshold,
291
+ checkpointing,
292
+ cache_data_in_gpu,
293
+ global_gen_loss,
294
+ global_disc_loss,
295
+ ):
296
+ """
297
+ Runs the training loop on a specific GPU or CPU.
298
+
299
+ Args:
300
+ rank (int): The rank of the current process within the distributed training setup.
301
+ n_gpus (int): The total number of GPUs available for training.
302
+ experiment_dir (str): The directory where experiment logs and checkpoints will be saved.
303
+ pretrain_g (str): Path to the pre-trained generator model.
304
+ pretrain_d (str): Path to the pre-trained discriminator model.
305
+ custom_total_epoch (int): The total number of epochs for training.
306
+ custom_save_every_weights (int): The interval (in epochs) at which to save model weights.
307
+ config (object): Configuration object containing training parameters.
308
+ device (torch.device): The device to use for training (CPU or GPU).
309
+
310
+ """
311
+ global global_step, optimizer, lowest_d_value, lowest_g_value, consecutive_increases_gen, consecutive_increases_disc
312
+
313
+ if rank == 0:
314
+ writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
315
+ else:
316
+ writer_eval = None
317
+
318
+ # Initialize distributed training environment for child node.
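+ # NCCL is only available for CUDA on Linux, so Windows and CPU runs fall
+ # back to the gloo backend.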
319
+ dist.init_process_group(
320
+ backend="gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl",
321
+ init_method="env://",
322
+ world_size=n_gpus if device.type == "cuda" else 1,
323
+ rank=rank if device.type == "cuda" else 0,
324
+ )
325
+
326
+ torch.manual_seed(config.train.seed)
327
+
328
+ if device.type == "cuda":
329
+ torch.cuda.set_device(device_id)
330
+
331
+ # Create datasets and dataloaders
332
+ from ultimate_rvc.rvc.train.data_utils import (
333
+ DistributedBucketSampler,
334
+ TextAudioCollateMultiNSFsid,
335
+ TextAudioLoaderMultiNSFsid,
336
+ )
337
+
338
+ train_dataset = TextAudioLoaderMultiNSFsid(config.data)
339
+ collate_fn = TextAudioCollateMultiNSFsid()
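+ # Bucket samples by length (boundaries are in frames) so each batch holds
+ # similarly sized utterances and padding stays small.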
340
+ train_sampler = DistributedBucketSampler(
341
+ train_dataset,
342
+ batch_size * n_gpus,
343
+ [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
344
+ num_replicas=n_gpus,
345
+ rank=rank,
346
+ shuffle=True,
347
+ )
348
+
349
+ train_loader = DataLoader(
350
+ train_dataset,
351
+ num_workers=4,
352
+ shuffle=False,
353
+ pin_memory=True,
354
+ collate_fn=collate_fn,
355
+ batch_sampler=train_sampler,
356
+ persistent_workers=True,
357
+ prefetch_factor=8,
358
+ )
359
+ if len(train_loader) < 3:
360
+ logger.error(
361
+ "Not enough data in the training set. Perhaps you forgot to slice the"
362
+ " audio files in preprocess?",
363
+ )
364
+ sys.exit(1)
365
+ else:
366
+ g_file = latest_checkpoint_path(experiment_dir, "G_*.pth")
367
+ if g_file is not None:
368
+ logger.info("Checking saved weights...")
369
+ g = torch.load(g_file, map_location="cpu", weights_only=False)
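+ # Infer which optimizer produced the saved state: AdamW param groups carry
+ # an 'amsgrad' key, while torch's RAdam carries 'decoupled_weight_decay',
+ # so the choice is reverted to whatever matches the checkpoint.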
370
+ if (
371
+ optimizer == "RAdam"
372
+ and "amsgrad" in g["optimizer"]["param_groups"][0].keys()
373
+ ):
374
+ optimizer = "AdamW"
375
+ logger.info(
376
+ "Optimizer choice has been reverted to %s to match the saved D/G"
377
+ " weights.",
378
+ optimizer,
379
+ )
380
+ elif (
381
+ optimizer == "AdamW"
382
+ and "decoupled_weight_decay" in g["optimizer"]["param_groups"][0].keys()
383
+ ):
384
+ optimizer = "RAdam"
385
+ logger.info(
386
+ "Optimizer choice has been reverted to %s to match the saved D/G"
387
+ " weights.",
388
+ optimizer,
389
+ )
390
+ del g
391
+
392
+ # Initialize models and optimizers
393
+ from ultimate_rvc.rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
394
+ from ultimate_rvc.rvc.lib.algorithm.synthesizers import Synthesizer
395
+
396
+ # NOTE checkpointing here controls gradient checkpointing, i.e. whether
397
+ # activations are recomputed in the backward pass instead of being stored
398
+
399
+ net_g = Synthesizer(
400
+ config.data.filter_length // 2 + 1,
401
+ config.train.segment_size // config.data.hop_length,
402
+ **config.model,
403
+ use_f0=True,
404
+ sr=sample_rate,
405
+ vocoder=vocoder,
406
+ checkpointing=checkpointing,
407
+ randomized=randomized,
408
+ )
409
+
410
+ net_d = MultiPeriodDiscriminator(
411
+ config.model.use_spectral_norm,
412
+ checkpointing=checkpointing,
413
+ )
414
+
415
+ if device.type == "cuda":
416
+ net_g = net_g.cuda(device_id)
417
+ net_d = net_d.cuda(device_id)
418
+ else:
419
+ net_g = net_g.to(device)
420
+ net_d = net_d.to(device)
421
+
422
+ if optimizer == "AdamW":
423
+ optimizer = torch.optim.AdamW
424
+ elif optimizer == "RAdam":
425
+ optimizer = torch.optim.RAdam
426
+
427
+ optim_g = optimizer(
428
+ net_g.parameters(),
429
+ config.train.learning_rate * g_lr_coeff,
430
+ betas=config.train.betas,
431
+ eps=config.train.eps,
432
+ )
433
+ optim_d = optimizer(
434
+ net_d.parameters(),
435
+ config.train.learning_rate * d_lr_coeff,
436
+ betas=config.train.betas,
437
+ eps=config.train.eps,
438
+ )
439
+
440
+ fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)
441
+
442
+ # Wrap models with DDP for multi-gpu processing
443
+ if n_gpus > 1 and device.type == "cuda":
444
+ net_g = DDP(net_g, device_ids=[device_id])
445
+ net_d = DDP(net_d, device_ids=[device_id])
446
+
447
+ # Load checkpoint if available
448
+ try:
449
+ logger.info("Starting training...")
450
+ _, _, _, epoch_str, lowest_d_value, consecutive_increases_disc = (
451
+ load_checkpoint(
452
+ latest_checkpoint_path(experiment_dir, "D_*.pth"),
453
+ net_d,
454
+ optim_d,
455
+ )
456
+ )
457
+ _, _, _, epoch_str, lowest_g_value, consecutive_increases_gen = load_checkpoint(
458
+ latest_checkpoint_path(experiment_dir, "G_*.pth"),
459
+ net_g,
460
+ optim_g,
461
+ )
462
+ epoch_str += 1
463
+ global_step = (epoch_str - 1) * len(train_loader)
464
+
465
+ except Exception:
466
+ epoch_str = 1
467
+ global_step = 0
468
+ if pretrain_g not in {"", "None"}:
469
+ if rank == 0:
470
+ verify_checkpoint_shapes(pretrain_g, net_g)
471
+ logger.info("Loaded pretrained (G) '%s'", pretrain_g)
472
+ if hasattr(net_g, "module"):
473
+ net_g.module.load_state_dict(
474
+ torch.load(pretrain_g, map_location="cpu", weights_only=False)[
475
+ "model"
476
+ ],
477
+ )
478
+ else:
479
+ net_g.load_state_dict(
480
+ torch.load(pretrain_g, map_location="cpu", weights_only=False)[
481
+ "model"
482
+ ],
483
+ )
484
+
485
+ if pretrain_d not in {"", "None"}:
486
+ if rank == 0:
487
+ verify_checkpoint_shapes(pretrain_d, net_d)
488
+ logger.info("Loaded pretrained (D) '%s'", pretrain_d)
489
+ if hasattr(net_d, "module"):
490
+ net_d.module.load_state_dict(
491
+ torch.load(pretrain_d, map_location="cpu", weights_only=False)[
492
+ "model"
493
+ ],
494
+ )
495
+ else:
496
+ net_d.load_state_dict(
497
+ torch.load(pretrain_d, map_location="cpu", weights_only=False)[
498
+ "model"
499
+ ],
500
+ )
501
+
502
+ # Initialize schedulers
503
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
504
+ optim_g,
505
+ gamma=config.train.lr_decay,
506
+ last_epoch=epoch_str - 2,
507
+ )
508
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
509
+ optim_d,
510
+ gamma=config.train.lr_decay,
511
+ last_epoch=epoch_str - 2,
512
+ )
513
+
514
+ cache = []
515
+ # get the first sample as reference for tensorboard evaluation
516
+ # custom reference temporarily disabled
517
+ if False and os.path.isfile(
518
+ os.path.join(RVC_TRAINING_MODELS_DIR, "reference", f"ref{sample_rate}.wav"),
519
+ ):
520
+ phone = np.load(
521
+ os.path.join(
522
+ RVC_TRAINING_MODELS_DIR,
523
+ "reference",
524
+ f"ref{sample_rate}_feats.npy",
525
+ ),
526
+ )
527
+ # expanding x2 to match pitch size
528
+ phone = np.repeat(phone, 2, axis=0)
529
+ phone = torch.FloatTensor(phone).unsqueeze(0).to(device)
530
+ phone_lengths = torch.LongTensor([phone.size(1)]).to(device)  # frame count of the single reference utterance
531
+ pitch = np.load(
532
+ os.path.join(
533
+ RVC_TRAINING_MODELS_DIR,
534
+ "reference",
535
+ f"ref{sample_rate}_f0c.npy",
536
+ ),
537
+ )
538
+ # removed last frame to match features
539
+ pitch = torch.LongTensor(pitch[:-1]).unsqueeze(0).to(device)
540
+ pitchf = np.load(
541
+ os.path.join(
542
+ RVC_TRAINING_MODELS_DIR,
543
+ "reference",
544
+ f"ref{sample_rate}_f0f.npy",
545
+ ),
546
+ )
547
+ # removed last frame to match features
548
+ pitchf = torch.FloatTensor(pitchf[:-1]).unsqueeze(0).to(device)
549
+ sid = torch.LongTensor([0]).to(device)
550
+ reference = (
551
+ phone,
552
+ phone_lengths,
553
+ pitch,
554
+ pitchf,
555
+ sid,
556
+ )
557
+ else:
558
+ for info in train_loader:
559
+ phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
560
+ if device.type == "cuda":
561
+ reference = (
562
+ phone.cuda(device_id, non_blocking=True),
563
+ phone_lengths.cuda(device_id, non_blocking=True),
564
+ pitch.cuda(device_id, non_blocking=True),
565
+ pitchf.cuda(device_id, non_blocking=True),
566
+ sid.cuda(device_id, non_blocking=True),
567
+ )
568
+ else:
569
+ reference = (
570
+ phone.to(device),
571
+ phone_lengths.to(device),
572
+ pitch.to(device),
573
+ pitchf.to(device),
574
+ sid.to(device),
575
+ )
576
+ break
577
+
578
+ for epoch in range(epoch_str, custom_total_epoch + 1):
579
+ train_and_evaluate(
580
+ rank,
581
+ epoch,
582
+ config,
583
+ [net_g, net_d],
584
+ [optim_g, optim_d],
585
+ [scheduler_g, scheduler_d],
586
+ [train_loader, None],
587
+ [writer_eval],
588
+ cache,
589
+ custom_save_every_weights,
590
+ custom_total_epoch,
591
+ device,
592
+ device_id,
593
+ reference,
594
+ fn_mel_loss,
595
+ model_name,
596
+ experiment_dir,
597
+ sample_rate,
598
+ vocoder,
599
+ save_every_epoch,
600
+ save_only_latest,
601
+ overtraining_detector,
602
+ overtraining_threshold,
603
+ cache_data_in_gpu,
604
+ global_gen_loss,
605
+ global_disc_loss,
606
+ )
607
+
608
+
609
+ def train_and_evaluate(
610
+ rank,
611
+ epoch,
612
+ config,
613
+ nets,
614
+ optims,
615
+ schedulers,
616
+ loaders,
617
+ writers,
618
+ cache,
619
+ custom_save_every_weights,
620
+ custom_total_epoch,
621
+ device,
622
+ device_id,
623
+ reference,
624
+ fn_mel_loss,
625
+ model_name,
626
+ experiment_dir,
627
+ sample_rate,
628
+ vocoder,
629
+ save_every_epoch,
630
+ save_only_latest,
631
+ overtraining_detector,
632
+ overtraining_threshold,
633
+ cache_data_in_gpu,
634
+ global_gen_loss,
635
+ global_disc_loss,
636
+ ) -> None:
637
+ """Train and evaluate the model for one epoch."""
638
+ global global_step, lowest_g_value, lowest_d_value, consecutive_increases_gen, consecutive_increases_disc
639
+
640
+ model_add = []
641
+ checkpoint_idxs = []
642
+ done = False
643
+
644
+ net_g, net_d = nets
645
+ optim_g, optim_d = optims
646
+ scheduler_g, scheduler_d = schedulers
647
+ skip_g_scheduler, skip_d_scheduler = False, False
648
+ train_loader = loaders[0] if loaders is not None else None
649
+ if writers is not None:
650
+ writer = writers[0]
651
+
652
+ train_loader.batch_sampler.set_epoch(epoch)
653
+
654
+ net_g.train()
655
+ net_d.train()
656
+
657
+ # Data caching
658
+ if device.type == "cuda" and cache_data_in_gpu:
659
+ if cache == []:
660
+ for batch_idx, info in enumerate(train_loader):
661
+ # phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid
662
+ info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
663
+ cache.append((batch_idx, info))
664
+ shuffle(cache)
665
+ data_iterator = cache
666
+ else:
667
+ data_iterator = enumerate(train_loader)
668
+
669
+ epoch_recorder = EpochRecorder()
670
+ with tqdm(total=len(train_loader), leave=False) as pbar:
671
+ for batch_idx, info in data_iterator:
672
+ if device.type == "cuda" and not cache_data_in_gpu:
673
+ info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
674
+ elif device.type != "cuda":
675
+ info = [tensor.to(device) for tensor in info]
676
+ # otherwise the iterator walks the cached list, whose tensors are already on the device
677
+
678
+ (
679
+ phone,
680
+ phone_lengths,
681
+ pitch,
682
+ pitchf,
683
+ spec,
684
+ spec_lengths,
685
+ wave,
686
+ wave_lengths,
687
+ sid,
688
+ ) = info
689
+
690
+ # Forward pass
691
+ model_output = net_g(
692
+ phone,
693
+ phone_lengths,
694
+ pitch,
695
+ pitchf,
696
+ spec,
697
+ spec_lengths,
698
+ sid,
699
+ )
700
+ y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (
701
+ model_output
702
+ )
703
+ # slice the original waveform to match the generated slice
704
+ if randomized:
705
+ wave = commons.slice_segments(
706
+ wave,
707
+ ids_slice * config.data.hop_length,
708
+ config.train.segment_size,
709
+ dim=3,
710
+ )
711
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
712
+ loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
713
+ # Discriminator backward and update
714
+ global_disc_loss[epoch - 1] += loss_disc.item()
715
+ optim_d.zero_grad()
716
+ loss_disc.backward()
717
+ grad_norm_d = commons.grad_norm(net_d.parameters())
718
+ optim_d.step()
719
+
720
+ # Generator backward and update
721
+ _, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
722
+ loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0
723
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
724
+ loss_fm = feature_loss(fmap_r, fmap_g)
725
+ loss_gen, _ = generator_loss(y_d_hat_g)
726
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
727
+ global_gen_loss[epoch - 1] += loss_gen_all.item()
728
+ optim_g.zero_grad()
729
+ loss_gen_all.backward()
730
+ grad_norm_g = commons.grad_norm(net_g.parameters())
731
+ optim_g.step()
732
+
733
+ global_step += 1
734
+
735
+ # queue for rolling losses over 50 steps
736
+ avg_losses["grad_d_50"].append(grad_norm_d)
737
+ avg_losses["grad_g_50"].append(grad_norm_g)
738
+ avg_losses["disc_loss_50"].append(loss_disc.detach())
739
+ avg_losses["fm_loss_50"].append(loss_fm.detach())
740
+ avg_losses["kl_loss_50"].append(loss_kl.detach())
741
+ avg_losses["mel_loss_50"].append(loss_mel.detach())
742
+ avg_losses["gen_loss_50"].append(loss_gen_all.detach())
743
+
744
+ if rank == 0 and global_step % 50 == 0:
745
+ # logging rolling averages
746
+ scalar_dict = {
747
+ "grad_avg_50/norm_d": (
748
+ sum(avg_losses["grad_d_50"]) / len(avg_losses["grad_d_50"])
749
+ ),
750
+ "grad_avg_50/norm_g": (
751
+ sum(avg_losses["grad_g_50"]) / len(avg_losses["grad_g_50"])
752
+ ),
753
+ "loss_avg_50/d/total": torch.mean(
754
+ torch.stack(list(avg_losses["disc_loss_50"])),
755
+ ),
756
+ "loss_avg_50/g/fm": torch.mean(
757
+ torch.stack(list(avg_losses["fm_loss_50"])),
758
+ ),
759
+ "loss_avg_50/g/kl": torch.mean(
760
+ torch.stack(list(avg_losses["kl_loss_50"])),
761
+ ),
762
+ "loss_avg_50/g/mel": torch.mean(
763
+ torch.stack(list(avg_losses["mel_loss_50"])),
764
+ ),
765
+ "loss_avg_50/g/total": torch.mean(
766
+ torch.stack(list(avg_losses["gen_loss_50"])),
767
+ ),
768
+ }
769
+ summarize(
770
+ writer=writer,
771
+ global_step=global_step,
772
+ scalars=scalar_dict,
773
+ )
774
+
775
+ pbar.update(1)
776
+ # end of batch train
777
+ # end of tqdm
778
+ scheduler_d.step()
779
+ scheduler_g.step()
780
+
781
+ with torch.no_grad():
782
+ torch.cuda.empty_cache()
783
+ # Logging and checkpointing
784
+ if rank == 0:
785
+ avg_global_disc_loss = global_disc_loss[epoch - 1] / len(train_loader.dataset)
786
+ avg_global_gen_loss = global_gen_loss[epoch - 1] / len(train_loader.dataset)
787
+
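+ # An epoch only counts as a new best if the average loss improves by more
+ # than min_delta; smaller changes are treated as noise.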
788
+ min_delta = 0.004
789
+
790
+ if avg_global_disc_loss < lowest_d_value["value"] - min_delta:
791
+ lowest_d_value = {"value": avg_global_disc_loss, "epoch": epoch}
792
+ consecutive_increases_disc = 0
793
+ else:
794
+ consecutive_increases_disc += 1
795
+
796
+ if avg_global_gen_loss < lowest_g_value["value"] - min_delta:
797
+ logger.info(
798
+ "New best epoch %d with average generator loss %.3f and discriminator"
799
+ " loss %.3f",
800
+ epoch,
801
+ avg_global_gen_loss,
802
+ avg_global_disc_loss,
803
+ )
804
+ lowest_g_value = {"value": avg_global_gen_loss, "epoch": epoch}
805
+ consecutive_increases_gen = 0
806
+ model_add.append(
807
+ os.path.join(experiment_dir, f"{model_name}_best.pth"),
808
+ )
809
+ else:
810
+ consecutive_increases_gen += 1
811
+
812
+ # used for tensorboard chart - all/mel
813
+ mel = spec_to_mel_torch(
814
+ spec,
815
+ config.data.filter_length,
816
+ config.data.n_mel_channels,
817
+ config.data.sample_rate,
818
+ config.data.mel_fmin,
819
+ config.data.mel_fmax,
820
+ )
821
+ # used for tensorboard chart - slice/mel_org
822
+ if randomized:
823
+ y_mel = commons.slice_segments(
824
+ mel,
825
+ ids_slice,
826
+ config.train.segment_size // config.data.hop_length,
827
+ dim=3,
828
+ )
829
+ else:
830
+ y_mel = mel
831
+ # used for tensorboard chart - slice/mel_gen
832
+ y_hat_mel = mel_spectrogram_torch(
833
+ y_hat.float().squeeze(1),
834
+ config.data.filter_length,
835
+ config.data.n_mel_channels,
836
+ config.data.sample_rate,
837
+ config.data.hop_length,
838
+ config.data.win_length,
839
+ config.data.mel_fmin,
840
+ config.data.mel_fmax,
841
+ )
842
+
843
+ lr = optim_g.param_groups[0]["lr"]
844
+
845
+ scalar_dict = {
846
+ "loss/g/total": loss_gen_all,
847
+ "loss/d/total": loss_disc,
848
+ "learning_rate": lr,
849
+ "grad/norm_d": grad_norm_d,
850
+ "grad/norm_g": grad_norm_g,
851
+ "loss/g/fm": loss_fm,
852
+ "loss/g/mel": loss_mel,
853
+ "loss/g/kl": loss_kl,
854
+ }
855
+
856
+ image_dict = {
857
+ "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
858
+ "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
859
+ "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
860
+ }
861
+ overtrain_info = ""
862
+ # Print training progress
863
+ lowest_g_value_rounded = float(lowest_g_value["value"])
864
+ lowest_g_value_rounded = round(lowest_g_value_rounded, 3)
865
+
866
+ record = f"{model_name} | epoch={epoch} | {epoch_recorder.record()}"
867
+ record += (
868
+ f" | best avg-gen-loss={lowest_g_value_rounded:.3f} (epoch"
869
+ f" {lowest_g_value['epoch']})"
870
+ )
871
+ # Check overtraining
872
+ if overtraining_detector:
873
+ overtrain_info = (
874
+ f"Average epoch generator loss {avg_global_gen_loss:.3f} and"
875
+ f" discriminator loss {avg_global_disc_loss:.3f}"
876
+ )
877
+
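+ # Countdown until the overtraining detector stops training: the generator
+ # is allowed `overtraining_threshold` epochs without a new best, and the
+ # discriminator twice that.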
878
+ remaining_epochs_gen = max(
879
+ overtraining_threshold - consecutive_increases_gen,
880
+ 0,
881
+ )
882
+ remaining_epochs_disc = max(
883
+ overtraining_threshold * 2 - consecutive_increases_disc,
884
+ 0,
885
+ )
886
+ record += (
887
+ " | overtrain countdown: g="
888
+ f"{remaining_epochs_gen},d={remaining_epochs_disc} |"
889
+ f" avg-gen-loss={avg_global_gen_loss:.3f} | avg-"
890
+ f"disc-loss={avg_global_disc_loss:.3f}"
891
+ )
892
+
893
+ if remaining_epochs_disc == 0 or remaining_epochs_gen == 0:
894
+ record += (
895
+ f"\nOvertraining detected at epoch {epoch} with average"
896
+ f" generator loss {avg_global_gen_loss:.3f} and discriminator loss"
897
+ f" {avg_global_disc_loss:.3f}"
898
+ )
899
+ done = True
900
+ print(record)
901
+
902
+ # Save weights, checkpoints and reference inference results
903
+ # every N epochs
904
+ if epoch % save_every_epoch == 0:
905
+ with torch.no_grad():
906
+ if hasattr(net_g, "module"):
907
+ o, *_ = net_g.module.infer(*reference)
908
+ else:
909
+ o, *_ = net_g.infer(*reference)
910
+ audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}
911
+ summarize(
912
+ writer=writer,
913
+ global_step=global_step,
914
+ images=image_dict,
915
+ scalars=scalar_dict,
916
+ audios=audio_dict,
917
+ audio_sample_rate=config.data.sample_rate,
918
+ )
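+ # 2333333 is the fixed index used for the rolling "latest" checkpoint
+ # (G_2333333.pth / D_2333333.pth); a per-epoch index is added below when
+ # save_only_latest is disabled.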
919
+ checkpoint_idxs.append(2333333)
920
+ if not save_only_latest:
921
+ checkpoint_idxs.append(epoch)
922
+
923
+ if custom_save_every_weights:
924
+ model_add.append(
925
+ os.path.join(experiment_dir, f"{model_name}_{epoch}.pth"),
926
+ )
927
+ else:
928
+ summarize(
929
+ writer=writer,
930
+ global_step=global_step,
931
+ images=image_dict,
932
+ scalars=scalar_dict,
933
+ )
934
+ for idx in checkpoint_idxs:
935
+ save_checkpoint(
936
+ net_g,
937
+ optim_g,
938
+ config.train.learning_rate,
939
+ epoch,
940
+ lowest_g_value,
941
+ consecutive_increases_gen,
942
+ os.path.join(experiment_dir, f"G_{idx}.pth"),
943
+ )
944
+ save_checkpoint(
945
+ net_d,
946
+ optim_d,
947
+ config.train.learning_rate,
948
+ epoch,
949
+ lowest_d_value,
950
+ consecutive_increases_disc,
951
+ os.path.join(experiment_dir, f"D_{idx}.pth"),
952
+ )
953
+ if model_add:
954
+ ckpt = (
955
+ net_g.module.state_dict()
956
+ if hasattr(net_g, "module")
957
+ else net_g.state_dict()
958
+ )
959
+ for m in model_add:
960
+ extract_model(
961
+ ckpt=ckpt,
962
+ sr=sample_rate,
963
+ name=model_name,
964
+ model_dir=m,
965
+ epoch=epoch,
966
+ step=global_step,
967
+ hps=config,
968
+ overtrain_info=overtrain_info,
969
+ vocoder=vocoder,
970
+ )
971
+ # Check completion
972
+ if epoch >= custom_total_epoch:
973
+ lowest_g_value_rounded = float(lowest_g_value["value"])
974
+ lowest_g_value_rounded = round(lowest_g_value_rounded, 3)
975
+ print(
976
+ f"Training has been successfully completed with {epoch} epoch(s),"
977
+ f" {global_step} step(s) and {round(avg_global_gen_loss, 3)} average"
978
+ " generator loss.",
979
+ )
980
+ print(
981
+ f"Lowest average generator loss: {lowest_g_value_rounded} at epoch"
982
+ f" {lowest_g_value['epoch']}",
983
+ )
984
+
985
+ done = True
986
+ with torch.no_grad():
987
+ torch.cuda.empty_cache()
988
+ if done:
989
+ pid_file_path = os.path.join(experiment_dir, "config.json")
990
+ with open(pid_file_path) as pid_file:
991
+ pid_data = json.load(pid_file)
992
+ with open(pid_file_path, "w") as pid_file:
993
+ pid_data.pop("process_pids", None)
994
+ json.dump(pid_data, pid_file, indent=4)
995
+ os._exit(0)
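+
+ # A minimal, hypothetical sketch of how a launcher might call `main`; every
+ # value below is illustrative only (paths, vocoder name and hyperparameters
+ # are assumptions, not defaults shipped with this script):
+ #
+ #     main(
+ #         model_name="my_voice",
+ #         sample_rate=40000,
+ #         vocoder="HiFi-GAN",
+ #         total_epoch=500,
+ #         batch_size=8,
+ #         save_every_epoch=10,
+ #         save_only_latest=True,
+ #         save_every_weights=True,
+ #         pretrain_g="pretrained/G40k.pth",
+ #         pretrain_d="pretrained/D40k.pth",
+ #         overtraining_detector=True,
+ #         overtraining_threshold=50,
+ #         cleanup=False,
+ #         cache_data_in_gpu=False,
+ #         checkpointing=False,
+ #         device_type="cuda",
+ #         gpus={0},
+ #     )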