Yoshitaka16 committed on
Commit b6e3132 · verified · 1 Parent(s): 3b097af

Upload train.py

Files changed (1)
  1. train.py +1419 -0
train.py ADDED
@@ -0,0 +1,1419 @@
import os
import datetime
import glob
import itertools
import json
import math
import re
#import signal
import subprocess
import sys
import warnings

pid_data = {"process_pids": []}
os.environ["USE_LIBUV"] = "0" if sys.platform == "win32" else "1"

from typing import Tuple
from collections import deque
from distutils.util import strtobool
from random import randint, shuffle
from time import time as ttime, sleep


from tqdm import TqdmExperimentalWarning
from tqdm.rich import trange, tqdm
from pesq import pesq
import numpy as np
import psutil

warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

import torch
import torch.nn as nn
import torchaudio
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torch.amp import autocast
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
import torch.multiprocessing as mp
import auraloss

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

import rvc.lib.zluda  # Zluda hijack

from utils import (
    HParams,
    plot_spectrogram_to_numpy,
    summarize,
    load_checkpoint,
    save_checkpoint,
    latest_checkpoint_path,
    load_wav_to_torch,
    load_config_from_json,
    mel_spec_similarity,
    flush_writer,
    block_tensorboard_flush_on_exit,
    si_sdr,
    wave_to_mel,
    small_model_naming,
    old_session_cleanup,
    verify_remap_checkpoint,
    print_init_setup,
    train_loader_safety,
    verify_spk_dim,
)
from losses import (
    discriminator_loss,
    generator_loss,
    feature_loss,
    kl_loss,
    phase_loss,
)
from mel_processing import (
    spec_to_mel_torch,
    MultiScaleMelSpectrogramLoss,
)
from rvc.train.process.extract_model import extract_model
from rvc.lib.algorithm import commons
from rvc.train.utils import replace_keys_in_dict


# Parse command line arguments start region ===========================

model_name = sys.argv[1]
epoch_save_frequency = int(sys.argv[2])
total_epoch_count = int(sys.argv[3])
pretrainG = sys.argv[4]
pretrainD = sys.argv[5]
gpus = sys.argv[6]
batch_size = int(sys.argv[7])
sample_rate = int(sys.argv[8])
save_only_latest_net_models = strtobool(sys.argv[9])
save_weight_models = strtobool(sys.argv[10])
cache_data_in_gpu = strtobool(sys.argv[11])
use_warmup = strtobool(sys.argv[12])
warmup_duration = int(sys.argv[13])
cleanup = strtobool(sys.argv[14])
vocoder = sys.argv[15]
architecture = sys.argv[16]
optimizer_choice = sys.argv[17]
use_checkpointing = strtobool(sys.argv[18])
use_tf32 = bool(strtobool(sys.argv[19]))
use_benchmark = bool(strtobool(sys.argv[20]))
use_deterministic = bool(strtobool(sys.argv[21]))
spectral_loss = sys.argv[22]
lr_scheduler = sys.argv[23]
exp_decay_gamma = float(sys.argv[24])
use_validation = strtobool(sys.argv[25])
double_d_update = strtobool(sys.argv[26])
use_custom_lr = strtobool(sys.argv[27])
custom_lr_g, custom_lr_d = (float(sys.argv[28]), float(sys.argv[29])) if use_custom_lr else (None, None)
assert not use_custom_lr or (custom_lr_g and custom_lr_d), "Invalid custom LR values."

# Parse command line arguments end region ===========================

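# Illustrative invocation (hypothetical values; in practice the launcher/UI passes these
# 27+ positional arguments in the order parsed above):
#   python train.py MyVoice 10 500 pretrained/G_48k.pth pretrained/D_48k.pth 0 8 48000 \
#       1 1 0 1 5 0 HiFi-GAN RVC AdamW 0 1 1 0 "Multi-Scale Mel Loss" none 0.999 0 0 0
# sys.argv[28] and sys.argv[29] (custom G/D learning rates) are only read when use_custom_lr is truthy.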

current_dir = os.getcwd()
experiment_dir = os.path.join(current_dir, "logs", model_name)
config_save_path = os.path.join(experiment_dir, "config.json")
dataset_path = os.path.join(experiment_dir, "sliced_audios")
model_info_path = os.path.join(experiment_dir, "model_info.json")


# Load the config from json
config = load_config_from_json(config_save_path)
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")


# AMP precision / dtype init
if config.train.bf16_run:
    train_dtype = torch.bfloat16
elif config.train.fp16_run:
    train_dtype = torch.float16
else:
    train_dtype = torch.float32


# Globals ( do not touch these. )
global_step = 0
d_updates_per_step = 2 if double_d_update else 1
warmup_completed = False
from_scratch = False
use_lr_scheduler = lr_scheduler != "none"


# Torch backends config
torch.backends.cuda.matmul.allow_tf32 = use_tf32
torch.backends.cudnn.allow_tf32 = use_tf32
torch.backends.cudnn.benchmark = use_benchmark
torch.backends.cudnn.deterministic = use_deterministic


# Globals ( tweakable )
randomized = False
benchmark_mode = True
enable_persistent_workers = True
debug_shapes = False


# EXPERIMENTAL
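# c_stft weights the Multi-Resolution STFT variant of the spectral loss used in training_loop
# (loss_mel = fn_spectral_loss(...) * c_stft); 21.0 is the current value, 18.0 the previous one.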
c_stft = 21.0  # 18.0


##################################################################

import logging
logging.getLogger("torch").setLevel(logging.ERROR)


class EpochRecorder:
    """
    Records the time elapsed per epoch.
    """

    def __init__(self):
        self.last_time = ttime()

    def record(self):
        """
        Records the elapsed time and returns a formatted string.
        """
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time = round(elapsed_time, 1)
        elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        current_time = datetime.datetime.now().strftime("%H:%M:%S")

        return f"Current time: {current_time} | Time per epoch: {elapsed_time_str}"

def setup_env_and_distr(rank, n_gpus, device, device_id, config):
    if rank == 0:
        writer_eval = SummaryWriter(
            log_dir=os.path.join(experiment_dir, "eval"),
            flush_secs=86400  # Workaround for the periodic background flush timer.
        )
        block_tensorboard_flush_on_exit(writer_eval)
    else:
        writer_eval = None

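    # NCCL is CUDA-only and unavailable on Windows, so the init below falls back to gloo in those cases.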
    dist.init_process_group(
        backend="gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl",
        init_method="env://",
        world_size=n_gpus if device.type == "cuda" else 1,
        rank=rank if device.type == "cuda" else 0,
    )

    torch.manual_seed(config.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(device_id)

    return writer_eval

def prepare_dataloaders(config, n_gpus, rank, batch_size, use_validation, benchmark_mode):
    from data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid
    )

    if not benchmark_mode and use_validation:
        full_dataset = TextAudioLoaderMultiNSFsid(config.data)
        train_len = int(0.90 * len(full_dataset))
        val_len = len(full_dataset) - train_len
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_len, val_len], generator=torch.Generator().manual_seed(config.train.seed)
        )
        train_dataset.lengths = [full_dataset.lengths[i] for i in train_dataset.indices]
        val_dataset.lengths = [full_dataset.lengths[i] for i in val_dataset.indices]
    else:
        train_dataset = TextAudioLoaderMultiNSFsid(config.data)
        val_dataset = None

    train_sampler = DistributedBucketSampler(
        train_dataset,
        batch_size * n_gpus,
        [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True
    )

    collate_fn = TextAudioCollateMultiNSFsid()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=enable_persistent_workers,
        prefetch_factor=8
    )
    val_loader = None
    if val_dataset:
        val_sampler = DistributedBucketSampler(
            val_dataset,
            batch_size * n_gpus,
            [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
            num_replicas=n_gpus,
            rank=rank,
            shuffle=False
        )
        val_loader = DataLoader(
            val_dataset, batch_sampler=val_sampler, shuffle=False, collate_fn=collate_fn,
            num_workers=1, pin_memory=True
        )

    train_loader_safety(benchmark_mode, train_loader)

    return train_loader, val_loader

def get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized):
    from rvc.lib.algorithm.synthesizers import Synthesizer
    return Synthesizer(
        config.data.filter_length // 2 + 1,
        config.train.segment_size // config.data.hop_length,
        **config.model,
        use_f0 = True,
        sr = sample_rate,
        vocoder = vocoder,
        checkpointing = use_checkpointing,
        randomized = randomized,
    )

def get_d_model(config, vocoder, use_checkpointing):
    if vocoder == "RingFormer":
        from rvc.lib.algorithm.discriminators.multi import MPD_MSD_MRD_Combined
        # MPD + MSD + MRD ( unified ) - RingFormer architecture v1
        return MPD_MSD_MRD_Combined(
            config.model.use_spectral_norm,
            use_checkpointing=use_checkpointing,
            **dict(config.mrd)
        )
    else:  # For HiFi-GAN, RefineGan or MRF-HiFi-GAN
        from rvc.lib.algorithm.discriminators.multi import MPD_MSD_Combined
        # MPD + MSD ( unified ) - Original RVC Setup
        return MPD_MSD_Combined(
            config.model.use_spectral_norm,
            use_checkpointing=use_checkpointing
        )

def get_optimizers(
    net_g,
    net_d,
    config,
    optimizer_choice,
    custom_lr_g,
    custom_lr_d,
    use_custom_lr,
    total_epoch_count,
    train_loader
):
    # Base / Common kwargs for gen and disc
    common_args_g = dict(
        lr=custom_lr_g if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
    )
    common_args_d = dict(
        lr=custom_lr_d if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
    )
    common_args_g_bf16 = dict(
        lr=custom_lr_g if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0.0,
        use_kahan_summation=True,
    )
    common_args_d_bf16 = dict(
        lr=custom_lr_d if use_custom_lr else config.train.learning_rate,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0.0,
        use_kahan_summation=True,
    )
    if optimizer_choice == "Ranger21":
        from rvc.train.custom_optimizers.ranger21 import Ranger21
        ranger_args = dict(
            num_epochs=total_epoch_count,
            num_batches_per_epoch=len(train_loader),
            use_madgrad=False,
            use_warmup=False,
            warmdown_active=False,
            use_cheb=False,
            lookahead_active=True,
            normloss_active=False,
            normloss_factor=1e-4,
            softplus=False,
            use_adaptive_gradient_clipping=True,
            agc_clipping_value=0.01,
            agc_eps=1e-3,
            using_gc=True,
            gc_conv_only=True,
            using_normgc=False,
        )
        optim_g = Ranger21(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g, **ranger_args)
        optim_d = Ranger21(net_d.parameters(), **common_args_d, **ranger_args)

    elif optimizer_choice == "RAdam":
        import torch_optimizer
        optim_g = torch_optimizer.RAdam(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = torch_optimizer.RAdam(net_d.parameters(), **common_args_d)

    elif optimizer_choice == "AdamW":
        optim_g = torch.optim.AdamW(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = torch.optim.AdamW(net_d.parameters(), **common_args_d)

    elif optimizer_choice == "AdamW_BF16":
        from rvc.train.custom_optimizers.adamw_bfloat import BFF_AdamW
        optim_g = BFF_AdamW(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g_bf16)
        optim_d = BFF_AdamW(net_d.parameters(), **common_args_d_bf16)

    elif optimizer_choice == "Prodigy":
        from rvc.train.custom_optimizers.prodigy import Prodigy
        prodigy_args = dict(
            betas=(0.8, 0.99),
            weight_decay=0.0,
            decouple=True,
        )
        optim_g = Prodigy(filter(lambda p: p.requires_grad, net_g.parameters()), lr=custom_lr_g if use_custom_lr else 1.0, **prodigy_args)
        optim_d = Prodigy(net_d.parameters(), lr=custom_lr_d if use_custom_lr else 1.0, **prodigy_args)

    elif optimizer_choice == "DiffGrad":
        from rvc.train.custom_optimizers.diffgrad import diffgrad
        optim_g = diffgrad(filter(lambda p: p.requires_grad, net_g.parameters()), **common_args_g)
        optim_d = diffgrad(net_d.parameters(), **common_args_d)

    else:
        raise ValueError(f"Unknown optimizer choice: {optimizer_choice}")
    return optim_g, optim_d

def setup_models_for_training(net_g, net_d, device, device_id, n_gpus):
    net_g = net_g.to(device_id) if device.type == "cuda" else net_g.to(device)
    net_d = net_d.to(device_id) if device.type == "cuda" else net_d.to(device)
    if n_gpus > 1 and device.type == "cuda":
        net_g = DDP(net_g, device_ids=[device_id])  # find_unused_parameters=True)
        net_d = DDP(net_d, device_ids=[device_id])  # find_unused_parameters=True)

    return net_g, net_d

def load_models_and_optimizers(config, pretrainG, pretrainD, vocoder, use_checkpointing, randomized, sample_rate, optimizer_choice, custom_lr_g, custom_lr_d, use_custom_lr, total_epoch_count, train_loader, device, device_id, n_gpus, rank):
    try:
        print(" ██████ Starting the training ...")

        # Confirm presence of checkpoints
        g_checkpoint_path = latest_checkpoint_path(experiment_dir, "G_*.pth")
        d_checkpoint_path = latest_checkpoint_path(experiment_dir, "D_*.pth")

        # If they exist, we attempt to resume the training
        if g_checkpoint_path and d_checkpoint_path:

            # Init the models
            net_g = get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized)
            net_d = get_d_model(config, vocoder, use_checkpointing)


            # Init the optimizers
            optim_g, optim_d = get_optimizers(net_g, net_d, config, optimizer_choice, custom_lr_g, custom_lr_d, use_custom_lr, total_epoch_count, train_loader)

            # Move the models to an appropriate device ( and optionally wrap with DDP for multi-gpu )
            net_g, net_d = setup_models_for_training(net_g, net_d, device, device_id, n_gpus)

            # Load the model and optim states
            _, _, _, epoch_str = load_checkpoint(architecture, g_checkpoint_path, net_g, optim_g)
            _, _, _, epoch_str = load_checkpoint(architecture, d_checkpoint_path, net_d, optim_d)

            epoch_str += 1
            global_step = (epoch_str - 1) * len(train_loader)
            print(f"[RESUMING] (G) & (D) at global_step: {global_step} and epoch count: {epoch_str - 1}")
        else:
            raise FileNotFoundError("No checkpoints found.")

    except FileNotFoundError:
        # If no checkpoints are available, use the pretrains directly
        epoch_str = 1
        global_step = 0

        # Init the models
        net_g = get_g_model(config, sample_rate, vocoder, use_checkpointing, randomized)
        net_d = get_d_model(config, vocoder, use_checkpointing)

        # Loading the pretrained Generator model
        if (pretrainG != "" and pretrainG != "None"):
            if rank == 0:
                print(f"Loading pretrained (G) '{pretrainG}'")
            verify_remap_checkpoint(pretrainG, net_g, architecture)


        # Loading the pretrained Discriminator model
        if pretrainD != "" and pretrainD != "None":
            if rank == 0:
                print(f"Loading pretrained (D) '{pretrainD}'")
            verify_remap_checkpoint(pretrainD, net_d, architecture)

        # Load the models and optionally wrap with DDP
        net_g, net_d = setup_models_for_training(net_g, net_d, device, device_id, n_gpus)
        # Init the optimizers
        optim_g, optim_d = get_optimizers(net_g, net_d, config, optimizer_choice, custom_lr_g, custom_lr_d, use_custom_lr, total_epoch_count, train_loader)
    return net_g, net_d, optim_g, optim_d, epoch_str, global_step

def prepare_schedulers(optim_g, optim_d, use_warmup, warmup_duration, use_lr_scheduler, lr_scheduler, exp_decay_gamma, total_epoch_count, epoch_str):
    warmup_scheduler_g, warmup_scheduler_d = None, None
    scheduler_g, scheduler_d = None, None

    if use_warmup:
        warmup_scheduler_g = torch.optim.lr_scheduler.LambdaLR(
            optim_g, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_duration)
        )
        warmup_scheduler_d = torch.optim.lr_scheduler.LambdaLR(
            optim_d, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_duration)
        )

    if not use_warmup:
        for param_group in optim_g.param_groups:  # For Generator
            if 'initial_lr' not in param_group:
                param_group['initial_lr'] = param_group['lr']
        for param_group in optim_d.param_groups:  # For Discriminator
            if 'initial_lr' not in param_group:
                param_group['initial_lr'] = param_group['lr']
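    # The schedulers below are created with last_epoch=epoch_str - 1; PyTorch then expects an
    # 'initial_lr' entry in every param group, which the warmup LambdaLR records automatically
    # and the loop above backfills when warmup is disabled.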

    if use_lr_scheduler:
        if lr_scheduler == "exp decay":
            # Exponential decay lr scheduler
            scheduler_g = torch.optim.lr_scheduler.ExponentialLR( optim_g, gamma=exp_decay_gamma, last_epoch=epoch_str - 1 )
            scheduler_d = torch.optim.lr_scheduler.ExponentialLR( optim_d, gamma=exp_decay_gamma, last_epoch=epoch_str - 1 )
        elif lr_scheduler == "cosine annealing":
            scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR( optim_g, T_max=total_epoch_count, eta_min=3e-5, last_epoch=epoch_str - 1 )
            scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR( optim_d, T_max=total_epoch_count, eta_min=3e-5, last_epoch=epoch_str - 1 )

    return warmup_scheduler_g, warmup_scheduler_d, scheduler_g, scheduler_d

def get_reference_sample(train_loader, device, config):
    reference_path = os.path.join("logs", "reference")
    use_custom_ref = all([
        os.path.isfile(os.path.join(reference_path, "ref_feats.npy")),
        os.path.isfile(os.path.join(reference_path, "ref_f0c.npy")),
        os.path.isfile(os.path.join(reference_path, "ref_f0f.npy")),
    ])

    if use_custom_ref:
        print("[REFERENCE] Using custom reference input from 'logs\\reference\\'")

        phone = torch.FloatTensor(np.repeat(np.load(os.path.join(reference_path, "ref_feats.npy")), 2, axis=0)).unsqueeze(0).to(device)
        pitch = torch.LongTensor(np.load(os.path.join(reference_path, "ref_f0c.npy"))).unsqueeze(0).to(device)
        pitchf = torch.FloatTensor(np.load(os.path.join(reference_path, "ref_f0f.npy"))).unsqueeze(0).to(device)

        min_len = min(phone.shape[1], pitch.shape[1], pitchf.shape[1])

        phone, pitch, pitchf = phone[:, :min_len, :], pitch[:, :min_len], pitchf[:, :min_len]
        phone_lengths = torch.LongTensor([phone.shape[1]]).to(device)

        sid = torch.LongTensor([0]).to(device)
    else:
        print("[REFERENCE] No custom reference found. Fetching from the first batch of the train_loader.")

        info = next(iter(train_loader))
        phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
        phone, phone_lengths, pitch, pitchf, sid = phone.to(device), phone_lengths.to(device), pitch.to(device), pitchf.to(device), sid.to(device)

        batch_indices = []
        for batch in train_loader.batch_sampler:
            batch_indices = batch
            break

        if isinstance(train_loader.dataset, torch.utils.data.Subset):
            file_paths = train_loader.dataset.dataset.get_file_paths(batch_indices)
        else:
            file_paths = train_loader.dataset.get_file_paths(batch_indices)

        file_name = os.path.basename(file_paths[0])
        print(f"[REFERENCE] Origin of the ref: {file_name}")

    return (phone, phone_lengths, pitch, pitchf, sid, config.train.seed)

def main():
    """
    Main function to start the training process.
    """
    global gpus

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    wavs = glob.glob(os.path.join(os.path.join(experiment_dir, "sliced_audios"), "*.wav"))
    if wavs:
        _, sr = load_wav_to_torch(wavs[0])
        if sr != sample_rate:
            print(f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz).")
            os._exit(1)
    else:
        print("No wav file found.")

    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpus = [int(item) for item in gpus.split("-")]
        n_gpus = len(gpus)
    else:
        device = torch.device("cpu")
        gpus = [0]
        n_gpus = 1
        print("No GPU detected, fallback to CPU. This will take a very long time ...")

    def start():
        """
        Starts the training process with multi-GPU support or CPU.
        """
        children = []

        for rank, device_id in enumerate(gpus):
            subproc = mp.Process(
                target=run,
                args=(
                    rank,
                    n_gpus,
                    experiment_dir,
                    pretrainG,
                    pretrainD,
                    total_epoch_count,
                    epoch_save_frequency,
                    save_weight_models,
                    save_only_latest_net_models,
                    config,
                    device,
                    device_id,
                ),
            )
            children.append(subproc)
            subproc.start()
            pid_data["process_pids"].append(subproc.pid)

        for i in range(n_gpus):
            children[i].join()

    if cleanup:
        old_session_cleanup(now_dir, model_name)
    start()

def run(
    rank,
    n_gpus,
    experiment_dir,
    pretrainG,
    pretrainD,
    total_epoch_count,
    epoch_save_frequency,
    save_weight_models,
    save_only_latest_net_models,
    config,
    device,
    device_id,
):
    """
    Runs the training loop on a specific GPU or CPU.

    Args:
        rank (int): The rank of the current process within the distributed training setup.
        n_gpus (int): The total number of GPUs available for training.
        experiment_dir (str): The directory where experiment logs and checkpoints will be saved.
        pretrainG (str): Path to the pre-trained generator model.
        pretrainD (str): Path to the pre-trained discriminator model.
        total_epoch_count (int): The total number of epochs for training.
        epoch_save_frequency (int): How often (in epochs) checkpoints are saved.
        save_weight_models (int): Whether to save small weight models. 0 for no, 1 for yes.
        save_only_latest_net_models (int): Whether to save only the latest G/D or one per save epoch.
        config (object): Configuration object containing training parameters.
        device (torch.device): The device to use for training (CPU or GPU).
    """
    global global_step, warmup_completed, optimizer_choice, from_scratch

    if 'warmup_completed' not in globals():
        warmup_completed = False

    # Initial print / session info for console
    print_init_setup(
        warmup_duration,
        rank,
        use_warmup,
        config,
        optimizer_choice,
        d_updates_per_step,
        use_validation,
        lr_scheduler,
        exp_decay_gamma
    )

    # Initial setup
    writer_eval = setup_env_and_distr(
        rank,
        n_gpus,
        device,
        device_id,
        config
    )

    # Dataloading and loaders preparation
    train_loader, val_loader = prepare_dataloaders(
        config,
        n_gpus,
        rank,
        batch_size,
        use_validation,
        benchmark_mode
    )

    # Spk dim verif
    spk_dim = verify_spk_dim(config, model_info_path, experiment_dir, latest_checkpoint_path, rank, pretrainG)
    config.model.spk_embed_dim = spk_dim

    # Spectral loss init
    if spectral_loss == "L1 Mel Loss":
        fn_spectral_loss = torch.nn.L1Loss()
        print(" ██████ Spectral loss: Single-Scale (L1) Mel loss function")
    elif spectral_loss == "Multi-Scale Mel Loss":
        fn_spectral_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)
        print(" ██████ Spectral loss: Multi-Scale Mel loss function")
    elif spectral_loss == "Multi-Res STFT Loss":
        fn_spectral_loss = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes = [1024, 2048, 512],
            hop_sizes = [80, 160, 40],  # stock: 120, 240, 50
            win_lengths = [480, 960, 240],  # stock: 600, 1200, 240
            window = "hann_window",
            w_sc = 1.0,
            w_log_mag = 1.0,
            w_lin_mag = 0.0,
            w_phs=0.0,
            sample_rate = sample_rate,
            scale = None,
            n_bins = None,
            perceptual_weighting = True,
            scale_invariance = False,
            output= "loss",  # "loss", "full"
            reduction = "mean",  # "none", "mean", "sum"
            mag_distance = "L1",  # "L1", "L2"
            device=device,
        )
        print(" ██████ Spectral loss: Multi-Resolution STFT loss function")
    else:
        print("ERROR: Chosen spectral loss is undefined. Exiting.")
        sys.exit(1)

    # Loading of models and optims
    net_g, net_d, optim_g, optim_d, epoch_str, global_step = load_models_and_optimizers(
        config,
        pretrainG,
        pretrainD,
        vocoder,
        use_checkpointing,
        randomized,
        sample_rate,
        optimizer_choice,
        custom_lr_g,
        custom_lr_d,
        use_custom_lr,
        total_epoch_count,
        train_loader,
        device,
        device_id,
        n_gpus,
        rank
    )

    # from-scratch checker ( disables average loss )
    if pretrainG in ["", "None"] and pretrainD in ["", "None"]:
        from_scratch = True
        if rank == 0:
            print(" ██████ No pretrains used: Average loss disabled!")

    # Prepare the schedulers
    warmup_scheduler_g, warmup_scheduler_d, scheduler_g, scheduler_d = prepare_schedulers(
        optim_g,
        optim_d,
        use_warmup,
        warmup_duration,
        use_lr_scheduler,
        lr_scheduler,
        exp_decay_gamma,
        total_epoch_count,
        epoch_str
    )

    # Hann window for stft ( for RingFormer only. )
    hann_window = torch.hann_window(config.model.gen_istft_n_fft).to(device) if vocoder == "RingFormer" else None

    # GradScaler for FP16 training
    gradscaler = torch.amp.GradScaler(enabled=(device.type == "cuda" and train_dtype == torch.float16))
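    # With bf16 or fp32 runs the scaler is constructed disabled, so scale()/step()/update() act as pass-throughs.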

    # Reference sample for live-infer
    reference = get_reference_sample(train_loader, device, config)

    # Batch cache, used when cache_data_in_gpu is enabled
    cache = []

    for epoch in range(epoch_str, total_epoch_count + 1):
        training_loop(
            rank,
            epoch,
            config,
            [net_g, net_d],
            [optim_g, optim_d],
            train_loader,
            val_loader if use_validation else None,
            [writer_eval],
            cache,
            total_epoch_count,
            epoch_save_frequency,
            save_weight_models,
            save_only_latest_net_models,
            device,
            device_id,
            reference,
            fn_spectral_loss,
            n_gpus,
            gradscaler,
            hann_window,
        )
        if use_warmup and epoch <= warmup_duration:
            if warmup_scheduler_g:
                warmup_scheduler_g.step()
            if warmup_scheduler_d:
                warmup_scheduler_d.step()

            # Logging of finished warmup
            if epoch == warmup_duration:
                warmup_completed = True
                print(f" ██████ Warmup completed at epochs: {warmup_duration}")
                print(f" ██████ LR G: {optim_g.param_groups[0]['lr']}")
                print(f" ██████ LR D: {optim_d.param_groups[0]['lr']}")
                # scheduler:
                if lr_scheduler == "exp decay":
                    print(f" ██████ Starting the exponential lr decay with gamma of {exp_decay_gamma}")
                elif lr_scheduler == "cosine annealing":
                    print(" ██████ Starting cosine annealing scheduler ")

        if use_lr_scheduler and (not use_warmup or warmup_completed):
            # Once the warmup phase is completed ( or warmup is disabled ), step the lr scheduler
            scheduler_g.step()
            scheduler_d.step()

def training_loop(
    rank,
    epoch,
    config,
    nets,
    optims,
    train_loader,
    val_loader,
    writers,
    cache,
    total_epoch_count,
    epoch_save_frequency,
    save_weight_models,
    save_only_latest_net_models,
    device,
    device_id,
    reference,
    fn_spectral_loss,
    n_gpus,
    gradscaler,
    hann_window=None,
):
    """
    Trains and evaluates the model for one epoch.

    Args:
        rank (int): Rank of the current process.
        epoch (int): Current epoch number.
        config (Namespace): Hyperparameters.
        nets (list): List of models [net_g, net_d].
        optims (list): List of optimizers [optim_g, optim_d].
        train_loader: training dataloader.
        val_loader: validation dataloader.
        writers (list): List of TensorBoard writers [writer_eval].
        cache (list): List to cache data in GPU memory.
    """
    global global_step, warmup_completed, dynamic_c_kl

    net_g, net_d = nets
    optim_g, optim_d = optims

    train_loader = train_loader if train_loader is not None else None
    if not benchmark_mode and use_validation:
        val_loader = val_loader if val_loader is not None else None

    if writers is not None:
        writer = writers[0]

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()

    # Data caching
    if device.type == "cuda" and cache_data_in_gpu:
        data_iterator = cache
        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                # phone, phone_lengths, pitch, pitchf, spec, spec_lengths, y, y_lengths, sid
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
                cache.append((batch_idx, info))
        else:
            shuffle(cache)
    else:
        data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()

    if not from_scratch:
        # Tensors init for averaged losses:
        tensor_count = 7 if vocoder == "RingFormer" else 6
        epoch_loss_tensor = torch.zeros(tensor_count, device=device)
        num_batches_in_epoch = 0

    avg_50_cache = {
        "grad_norm_d_clipped_50": deque(maxlen=50),
        "grad_norm_g_clipped_50": deque(maxlen=50),
        "loss_disc_50": deque(maxlen=50),
        "loss_adv_50": deque(maxlen=50),
        "loss_gen_total_50": deque(maxlen=50),
        "loss_fm_50": deque(maxlen=50),
        "loss_mel_50": deque(maxlen=50),
        "loss_kl_50": deque(maxlen=50),
    }
    if vocoder == "RingFormer":
        avg_50_cache.update({
            "loss_sd_50": deque(maxlen=50),
        })

    use_amp = (config.train.bf16_run or config.train.fp16_run) and device.type == "cuda"
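    # Autocast is only enabled for fp16/bf16 CUDA runs; the loss blocks below open
    # autocast(enabled=False) regions so the loss math itself stays in full precision.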

    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:

            global_step += 1

            if not from_scratch:
                num_batches_in_epoch += 1

            if device.type == "cuda" and not cache_data_in_gpu:
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
            elif device.type != "cuda":
                info = [tensor.to(device) for tensor in info]
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                y,
                y_lengths,
                sid,
            ) = info

            # Generator forward pass:
            with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
                # Unpacking:
                if vocoder == "RingFormer":
                    y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), (mag, phase) = (model_output)
                else:
                    y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (model_output)

                # Slice the original waveform ( y ) to match the generated slice:
                if randomized:
                    y = commons.slice_segments(
                        y,
                        ids_slice * config.data.hop_length,
                        config.train.segment_size,
                        dim=3,
                    )

                if vocoder == "RingFormer":
                    reshaped_y = y.view(-1, y.size(-1))
                    reshaped_y_hat = y_hat.view(-1, y_hat.size(-1))
                    y_stft = torch.stft(reshaped_y, n_fft=config.model.gen_istft_n_fft, hop_length=config.model.gen_istft_hop_size, win_length=config.model.gen_istft_n_fft, window=hann_window, return_complex=True)
                    y_hat_stft = torch.stft(reshaped_y_hat, n_fft=config.model.gen_istft_n_fft, hop_length=config.model.gen_istft_hop_size, win_length=config.model.gen_istft_n_fft, window=hann_window, return_complex=True)
                    target_magnitude = torch.abs(y_stft)  # shape: [B, F, T]

            # Discriminator forward pass:
            for _ in range(d_updates_per_step):  # default is 1 update per step
                with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                    y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())

                with autocast(device_type="cuda", enabled=False):
                    # Compute discriminator loss:
                    loss_disc = discriminator_loss(y_d_hat_r, y_d_hat_g)

                # Discriminator backward and update:
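                # max_norm=999999 effectively disables clipping; the logged norm is re-measured
                # just below via commons.get_total_norm.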
                optim_d.zero_grad()
                if train_dtype == torch.float16:
                    # 0. GradScaler handling
                    gradscaler.scale(loss_disc).backward()
                    gradscaler.unscale_(optim_d)
                    # 1. Grads norm clip
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(net_d.parameters(), max_norm=999999)
                    # 2. Retrieve the clipped grads
                    grad_norm_d_clipped = commons.get_total_norm([p.grad for p in net_d.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=False)
                    # 3. Optimization step
                    gradscaler.step(optim_d)
                else:
                    loss_disc.backward()
                    # 1. Grads norm clip
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(net_d.parameters(), max_norm=999999)  # 1000 / 999999
                    # 2. Retrieve the clipped grads
                    grad_norm_d_clipped = commons.get_total_norm([p.grad for p in net_d.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=True)
                    # 3. Optimization step
                    optim_d.step()

            # Run discriminator on generated output
            with autocast(device_type="cuda", enabled=use_amp, dtype=train_dtype):
                _, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)

            # Compute generator losses:
            with autocast(device_type="cuda", enabled=False):

                # Spectral loss ( kept as "loss_mel" in the code to stay consistent with older logs / graphs ):
                if spectral_loss == "L1 Mel Loss":
                    y_mel = wave_to_mel(config, y, half=train_dtype)
                    y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)
                    loss_mel = fn_spectral_loss(y_mel, y_hat_mel) * config.train.c_mel
                elif spectral_loss == "Multi-Scale Mel Loss":
                    loss_mel = fn_spectral_loss(y, y_hat) * config.train.c_mel / 3.0
                elif spectral_loss == "Multi-Res STFT Loss":
                    loss_mel = fn_spectral_loss(y_hat.float(), y.float()) * c_stft

                # Feature Matching loss
                loss_fm = feature_loss(fmap_r, fmap_g)

                # Generator loss
                loss_adv = generator_loss(y_d_hat_g)

                # KL ( Kullback–Leibler divergence ) loss
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl

                if vocoder == "RingFormer":
                    # RingFormer related; Phase, Magnitude and SD:
                    loss_magnitude = torch.nn.functional.l1_loss(mag, target_magnitude)
                    loss_phase = phase_loss(y_stft, y_hat_stft)
                    loss_sd = (loss_magnitude + loss_phase) * 0.7

                # Total generator loss
                if vocoder == "RingFormer":
                    loss_gen_total = loss_adv + loss_fm + loss_mel + loss_kl + loss_sd
                else:
                    loss_gen_total = loss_adv + loss_fm + loss_mel + loss_kl


            # Generator backward and update:
            optim_g.zero_grad()
            if train_dtype == torch.float16:
                # 0. GradScaler handling
                gradscaler.scale(loss_gen_total).backward()
                gradscaler.unscale_(optim_g)
                # 1. Grads norm clip
                grad_norm_g = torch.nn.utils.clip_grad_norm_(net_g.parameters(), max_norm=999999)
                # 2. Retrieve the clipped grads
                grad_norm_g_clipped = commons.get_total_norm([p.grad for p in net_g.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=False)
                # 3. Optimization step
                gradscaler.step(optim_g)
                gradscaler.update()
            else:
                loss_gen_total.backward()
                # 1. Grads norm clip
                grad_norm_g = torch.nn.utils.clip_grad_norm_(net_g.parameters(), max_norm=999999)  # 1000 / 999999
                # 2. Retrieve the clipped grads
                grad_norm_g_clipped = commons.get_total_norm([p.grad for p in net_g.parameters() if p.grad is not None], norm_type=2.0, error_if_nonfinite=True)
                # 3. Optimization step
                optim_g.step()


            if not from_scratch:
                # Loss accumulation in the epoch_loss_tensor
                epoch_loss_tensor[0].add_(loss_disc.detach())
                epoch_loss_tensor[1].add_(loss_adv.detach())
                epoch_loss_tensor[2].add_(loss_gen_total.detach())
                epoch_loss_tensor[3].add_(loss_fm.detach())
                epoch_loss_tensor[4].add_(loss_mel.detach())
                epoch_loss_tensor[5].add_(loss_kl.detach())
                if vocoder == "RingFormer":
                    epoch_loss_tensor[6].add_(loss_sd.detach())

            # queue for rolling losses / grads over 50 steps
            # Grads:
            avg_50_cache["grad_norm_d_clipped_50"].append(grad_norm_d_clipped)
            avg_50_cache["grad_norm_g_clipped_50"].append(grad_norm_g_clipped)
            # Losses:
            avg_50_cache["loss_disc_50"].append(loss_disc.detach())
            avg_50_cache["loss_adv_50"].append(loss_adv.detach())
            avg_50_cache["loss_gen_total_50"].append(loss_gen_total.detach())
            avg_50_cache["loss_fm_50"].append(loss_fm.detach())
            avg_50_cache["loss_mel_50"].append(loss_mel.detach())
            avg_50_cache["loss_kl_50"].append(loss_kl.detach())
            if vocoder == "RingFormer":
                avg_50_cache["loss_sd_50"].append(loss_sd.detach())

            if rank == 0 and global_step % 50 == 0:
                scalar_dict_50 = {}
                # Learning rate retrieval for avg-50 variation:
                if from_scratch:
                    lr_d = optim_d.param_groups[0]["lr"]
                    lr_g = optim_g.param_groups[0]["lr"]
                    scalar_dict_50.update({
                        "learning_rate/lr_d": lr_d,
                        "learning_rate/lr_g": lr_g,
                    })
                    if optimizer_choice == "Prodigy":
                        prodigy_lr_g = optim_g.param_groups[0].get('d', 0)
                        prodigy_lr_d = optim_d.param_groups[0].get('d', 0)
                        scalar_dict_50.update({
                            "learning_rate/prodigy_lr_g": prodigy_lr_g,
                            "learning_rate/prodigy_lr_d": prodigy_lr_d,
                        })
                # logging rolling averages
                scalar_dict_50.update({
                    # Grads:
                    "grad_avg_50/norm_d_clipped_50": sum(avg_50_cache["grad_norm_d_clipped_50"])
                    / len(avg_50_cache["grad_norm_d_clipped_50"]),
                    "grad_avg_50/norm_g_clipped_50": sum(avg_50_cache["grad_norm_g_clipped_50"])
                    / len(avg_50_cache["grad_norm_g_clipped_50"]),
                    # Losses:
                    "loss_avg_50/loss_disc_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_disc_50"]))),
                    "loss_avg_50/loss_adv_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_adv_50"]))),
                    "loss_avg_50/loss_gen_total_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_gen_total_50"]))),
                    "loss_avg_50/loss_fm_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_fm_50"]))),
                    "loss_avg_50/loss_mel_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_mel_50"]))),
                    "loss_avg_50/loss_kl_50": torch.mean(
                        torch.stack(list(avg_50_cache["loss_kl_50"]))),
                })
                if vocoder == "RingFormer":
                    scalar_dict_50.update({
                        # Losses:
                        "loss_avg_50/loss_sd_50": torch.mean(
                            torch.stack(list(avg_50_cache["loss_sd_50"]))),
                    })

                summarize(writer=writer, global_step=global_step, scalars=scalar_dict_50)
                flush_writer(writer, rank)

            pbar.update(1)
        # end of batch train
    # end of tqdm

    if n_gpus > 1 and device.type == 'cuda':
        dist.barrier()

    with torch.no_grad():
        torch.cuda.empty_cache()

    # Logging and checkpointing
    if rank == 0:
        # Used for tensorboard chart - all/mel
        mel = spec_to_mel_torch(
            spec,
            config.data.filter_length,
            config.data.n_mel_channels,
            config.data.sample_rate,
            config.data.mel_fmin,
            config.data.mel_fmax,
        )

        # For fp16 we need to .half() the mel spec
        if train_dtype == torch.float16:
            mel = mel.half()

        # Used for tensorboard chart - slice/mel_org
        if randomized:
            y_mel = commons.slice_segments(
                mel,
                ids_slice,
                config.train.segment_size // config.data.hop_length,
                dim=3,
            )
        else:
            y_mel = mel

        # used for tensorboard chart - slice/mel_gen
        y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)

        # Mel similarity metric:
        mel_similarity = mel_spec_similarity(y_hat_mel, y_mel)
        print(f'Mel Spectrogram Similarity: {mel_similarity:.2f}%')
        writer.add_scalar('Metric/Mel_Spectrogram_Similarity', mel_similarity, global_step)

        # Learning rate retrieval for avg-epoch variation:
        lr_d = optim_d.param_groups[0]["lr"]
        lr_g = optim_g.param_groups[0]["lr"]

        # Calculate the avg epoch loss:
        if global_step % len(train_loader) == 0 and not from_scratch:  # At each epoch completion
            avg_epoch_loss = epoch_loss_tensor / num_batches_in_epoch

            scalar_dict_avg = {
                "loss_avg/loss_disc": avg_epoch_loss[0],
                "loss_avg/loss_adv": avg_epoch_loss[1],
                "loss_avg/loss_gen_total": avg_epoch_loss[2],
                "loss_avg/loss_fm": avg_epoch_loss[3],
                "loss_avg/loss_mel": avg_epoch_loss[4],
                "loss_avg/loss_kl": avg_epoch_loss[5],
                "learning_rate/lr_d": lr_d,
                "learning_rate/lr_g": lr_g,
            }
            if optimizer_choice == "Prodigy":
                prodigy_lr_g = optim_g.param_groups[0].get('d', 0)
                prodigy_lr_d = optim_d.param_groups[0].get('d', 0)
                scalar_dict_avg.update({
                    "learning_rate/prodigy_lr_g": prodigy_lr_g,
                    "learning_rate/prodigy_lr_d": prodigy_lr_d,
                })
            if vocoder == "RingFormer":
                scalar_dict_avg.update({
                    "loss_avg/loss_sd": avg_epoch_loss[6],
                })

            summarize(writer=writer, global_step=global_step, scalars=scalar_dict_avg)
            flush_writer(writer, rank)
            num_batches_in_epoch = 0
            epoch_loss_tensor.zero_()

        # Determine the plot data type
        if train_dtype == torch.float16:
            plot_dtype = torch.float16
        else:
            plot_dtype = torch.float32

        image_dict = {
            "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].detach().cpu().to(plot_dtype).numpy()),
            "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].detach().cpu().to(plot_dtype).numpy()),
            "all/mel": plot_spectrogram_to_numpy(mel[0].detach().cpu().to(plot_dtype).numpy()),
        }


        # At each epoch save point:
        if epoch % epoch_save_frequency == 0:
            if not benchmark_mode and use_validation:
                # Running validation
                validation_loop(
                    net_g.module if hasattr(net_g, "module") else net_g,
                    val_loader,
                    device,
                    config,
                    writer,
                    global_step,
                )
            # Inferencing on reference sample

            # with torch.amp.autocast(
            #     device_type="cuda", enabled=use_amp, dtype=train_dtype
            # ):

            net_g.eval()
            with torch.no_grad():
                if hasattr(net_g, "module"):
                    o, *_ = net_g.module.infer(*reference)
                else:
                    o, *_ = net_g.infer(*reference)
            net_g.train()
            audio_dict = {f"gen/audio_{epoch}e_{global_step}s": o[0, :, :]}  # Eval-infer samples
            # Logging
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                audios=audio_dict,
                audio_sample_rate=config.data.sample_rate,
            )
            flush_writer(writer, rank)
        else:
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
            )
            flush_writer(writer, rank)

    # Save checkpoint
    model_add = []
    done = False

    if rank == 0:
        # Print training progress
        record = f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}"
        print(record)

        # Save weights every N epochs
        if epoch % epoch_save_frequency == 0:
            checkpoint_suffix = f"{2333333 if save_only_latest_net_models else global_step}.pth"
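            # With save_only_latest_net_models the fixed 2333333 suffix makes every save overwrite the previous G_/D_ files.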
            # Save Generator checkpoint
            save_checkpoint(
                architecture,
                net_g,
                optim_g,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "G_" + checkpoint_suffix),
            )
            # Save Discriminator checkpoint
            save_checkpoint(
                architecture,
                net_d,
                optim_d,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "D_" + checkpoint_suffix),
            )
            # Save small weight model
            if save_weight_models:
                weight_model_name = small_model_naming(model_name, epoch, global_step)
                model_add.append(os.path.join(experiment_dir, weight_model_name))

        # Check completion
        if epoch >= total_epoch_count:
            print(
                f"Training has been successfully completed with {epoch} epochs, {global_step} steps and {round(loss_gen_total.item(), 3)} loss gen."
            )
            # Final model
            weight_model_name = small_model_naming(model_name, epoch, global_step)
            model_add.append(os.path.join(experiment_dir, weight_model_name))

            done = True

        if model_add:
            ckpt = (
                net_g.module.state_dict()
                if hasattr(net_g, "module")
                else net_g.state_dict()
            )
            for m in model_add:
                if not os.path.exists(m):
                    extract_model(
                        ckpt=ckpt,
                        sr=sample_rate,
                        name=model_name,
                        model_path=m,
                        epoch=epoch,
                        step=global_step,
                        hps=config,
                        vocoder=vocoder,
                        architecture=architecture,
                    )
        if done:
            # Clean-up process IDs from memory
            pid_data["process_pids"].clear()  # Clear the PID list when done

            if rank == 0:
                writer.flush()
                writer.close()

            os._exit(2333333)

    with torch.no_grad():
        torch.cuda.empty_cache()


def validation_loop(net_g, val_loader, device, config, writer, global_step):
    net_g.eval()
    torch.cuda.empty_cache()

    total_mel_error = 0.0
    total_mrstft_loss = 0.0
    total_pesq = 0.0
    valid_pesq_count = 0
    total_si_sdr = 0.0
    count = 0

    mrstft = auraloss.freq.MultiResolutionSTFTLoss(device=device)
    resample_to_16k = torchaudio.transforms.Resample(orig_freq=config.data.sample_rate, new_freq=16000).to(device)
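    # Wide-band PESQ is defined for 16 kHz signals, hence the resample before scoring.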

    hop_length = config.data.hop_length
    sample_rate = config.data.sample_rate

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            phone, phone_lengths, pitch, pitchf, spec, spec_lengths, y, _, sid = [t.to(device) for t in batch]

            # Infer
            y_hat, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, pitchf, sid)

            # Get reference min-length ( according to gt wave's length )
            y_len = y.shape[-1]

            # Obtaining mel specs
            y_hat_mel = wave_to_mel(config, y_hat, half=train_dtype)  # generator-source mel
            mel = wave_to_mel(config, y, half=train_dtype)  # gt-source mel

            # Mel loss:
            y_hat_mel_len = y_hat_mel.shape[-1]
            mel_len = mel.shape[-1]

            min_t = min(y_hat_mel_len, mel_len)

            mel_loss = F.l1_loss(y_hat_mel[..., :min_t], mel[..., :min_t])
            total_mel_error += mel_loss.item()

            # STFT loss:
            y_hat_len = y_hat.shape[-1]

            min_samples = min_t * hop_length
            min_samples = min(min_samples, y_len, y_hat_len)

            stft_loss = mrstft(y_hat[..., :min_samples], y[..., :min_samples])
            total_mrstft_loss += stft_loss.item()

            # si_sdr:
            si_sdr_score = si_sdr(y_hat.squeeze(1), y.squeeze(1))
            total_si_sdr += si_sdr_score.item()

            # PESQ:
            try:
                y_16k_batch = resample_to_16k(y).cpu().numpy()  # (B, T)
                y_hat_16k_batch = resample_to_16k(y_hat.squeeze(1)).cpu().numpy()  # (B, T)

                for i in range(y_16k_batch.shape[0]):
                    y_16k_f = np.squeeze(y_16k_batch[i]).astype(np.float32)
                    y_hat_16k_f = np.squeeze(y_hat_16k_batch[i]).astype(np.float32)

                    try:
                        pesq_score = pesq(16000, y_16k_f, y_hat_16k_f, mode="wb")
                        total_pesq += pesq_score
                        valid_pesq_count += 1
                    except Exception as e:
                        print(f"[PESQ skipped] {e}")

            except Exception as e:
                print(f"[PESQ skipped outer] {e}")

            count += 1

    avg_mel = total_mel_error / count
    avg_mrstft = total_mrstft_loss / count
    avg_pesq = total_pesq / max(valid_pesq_count, 1)
    avg_si_sdr = total_si_sdr / count

    if writer is not None:
        writer.add_scalar("validation/loss/mel_l1", avg_mel, global_step)
        writer.add_scalar("validation/loss/mrstft", avg_mrstft, global_step)
        writer.add_scalar("validation/score/pesq", avg_pesq, global_step)
        writer.add_scalar("validation/score/si_sdr", avg_si_sdr, global_step)

    net_g.train()

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()