Ryouko committed
Commit 1613a66 · verified · 1 Parent(s): 1fff6c5

Delete infer/lib/train

infer/lib/train/data_utils.py DELETED
@@ -1,517 +0,0 @@
- import os
- import traceback
- import logging
-
- logger = logging.getLogger(__name__)
-
- import numpy as np
- import torch
- import torch.utils.data
-
- from infer.lib.train.mel_processing import spectrogram_torch
- from infer.lib.train.utils import load_filepaths_and_text, load_wav_to_torch
-
-
- class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
-     """
-     1) loads audio, text pairs
-     2) normalizes text and converts them to sequences of integers
-     3) computes spectrograms from audio files.
-     """
-
-     def __init__(self, audiopaths_and_text, hparams):
-         self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
-         self.max_wav_value = hparams.max_wav_value
-         self.sampling_rate = hparams.sampling_rate
-         self.filter_length = hparams.filter_length
-         self.hop_length = hparams.hop_length
-         self.win_length = hparams.win_length
-         self.sampling_rate = hparams.sampling_rate
-         self.min_text_len = getattr(hparams, "min_text_len", 1)
-         self.max_text_len = getattr(hparams, "max_text_len", 5000)
-         self._filter()
-
-     def _filter(self):
-         """
-         Filter text & store spec lengths
-         """
-         # Store spectrogram lengths for Bucketing
-         # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
-         # spec_length = wav_length // hop_length
-         audiopaths_and_text_new = []
-         lengths = []
-         for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
-             if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
-                 audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
-                 lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
-         self.audiopaths_and_text = audiopaths_and_text_new
-         self.lengths = lengths
-
-     def get_sid(self, sid):
-         sid = torch.LongTensor([int(sid)])
-         return sid
-
-     def get_audio_text_pair(self, audiopath_and_text):
-         # separate filename and text
-         file = audiopath_and_text[0]
-         phone = audiopath_and_text[1]
-         pitch = audiopath_and_text[2]
-         pitchf = audiopath_and_text[3]
-         dv = audiopath_and_text[4]
-
-         phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
-         spec, wav = self.get_audio(file)
-         dv = self.get_sid(dv)
-
-         len_phone = phone.size()[0]
-         len_spec = spec.size()[-1]
-         # print(123,phone.shape,pitch.shape,spec.shape)
-         if len_phone != len_spec:
-             len_min = min(len_phone, len_spec)
-             # amor
-             len_wav = len_min * self.hop_length
-
-             spec = spec[:, :len_min]
-             wav = wav[:, :len_wav]
-
-             phone = phone[:len_min, :]
-             pitch = pitch[:len_min]
-             pitchf = pitchf[:len_min]
-
-         return (spec, wav, phone, pitch, pitchf, dv)
-
-     def get_labels(self, phone, pitch, pitchf):
-         phone = np.load(phone)
-         phone = np.repeat(phone, 2, axis=0)
-         pitch = np.load(pitch)
-         pitchf = np.load(pitchf)
-         n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
-         # print(234,phone.shape,pitch.shape)
-         phone = phone[:n_num, :]
-         pitch = pitch[:n_num]
-         pitchf = pitchf[:n_num]
-         phone = torch.FloatTensor(phone)
-         pitch = torch.LongTensor(pitch)
-         pitchf = torch.FloatTensor(pitchf)
-         return phone, pitch, pitchf
-
-     def get_audio(self, filename):
-         audio, sampling_rate = load_wav_to_torch(filename)
-         if sampling_rate != self.sampling_rate:
-             raise ValueError(
-                 "{} SR doesn't match target {} SR".format(
-                     sampling_rate, self.sampling_rate
-                 )
-             )
-         audio_norm = audio
-         # audio_norm = audio / self.max_wav_value
-         # audio_norm = audio / np.abs(audio).max()
-
-         audio_norm = audio_norm.unsqueeze(0)
-         spec_filename = filename.replace(".wav", ".spec.pt")
-         if os.path.exists(spec_filename):
-             try:
-                 spec = torch.load(spec_filename)
-             except Exception:
-                 logger.warning("%s %s", spec_filename, traceback.format_exc())
-                 spec = spectrogram_torch(
-                     audio_norm,
-                     self.filter_length,
-                     self.sampling_rate,
-                     self.hop_length,
-                     self.win_length,
-                     center=False,
-                 )
-                 spec = torch.squeeze(spec, 0)
-                 torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-         else:
-             spec = spectrogram_torch(
-                 audio_norm,
-                 self.filter_length,
-                 self.sampling_rate,
-                 self.hop_length,
-                 self.win_length,
-                 center=False,
-             )
-             spec = torch.squeeze(spec, 0)
-             torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-         return spec, audio_norm
-
-     def __getitem__(self, index):
-         return self.get_audio_text_pair(self.audiopaths_and_text[index])
-
-     def __len__(self):
-         return len(self.audiopaths_and_text)
-
-
- class TextAudioCollateMultiNSFsid:
-     """Zero-pads model inputs and targets"""
-
-     def __init__(self, return_ids=False):
-         self.return_ids = return_ids
-
-     def __call__(self, batch):
-         """Collates a training batch from normalized text and audio
-         PARAMS
-         ------
-         batch: [text_normalized, spec_normalized, wav_normalized]
-         """
-         # Right zero-pad all one-hot text sequences to max input length
-         _, ids_sorted_decreasing = torch.sort(
-             torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
-         )
-
-         max_spec_len = max([x[0].size(1) for x in batch])
-         max_wave_len = max([x[1].size(1) for x in batch])
-         spec_lengths = torch.LongTensor(len(batch))
-         wave_lengths = torch.LongTensor(len(batch))
-         spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
-         wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
-         spec_padded.zero_()
-         wave_padded.zero_()
-
-         max_phone_len = max([x[2].size(0) for x in batch])
-         phone_lengths = torch.LongTensor(len(batch))
-         phone_padded = torch.FloatTensor(
-             len(batch), max_phone_len, batch[0][2].shape[1]
-         )  # (spec, wav, phone, pitch)
-         pitch_padded = torch.LongTensor(len(batch), max_phone_len)
-         pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
-         phone_padded.zero_()
-         pitch_padded.zero_()
-         pitchf_padded.zero_()
-         # dv = torch.FloatTensor(len(batch), 256)  # gin=256
-         sid = torch.LongTensor(len(batch))
-
-         for i in range(len(ids_sorted_decreasing)):
-             row = batch[ids_sorted_decreasing[i]]
-
-             spec = row[0]
-             spec_padded[i, :, : spec.size(1)] = spec
-             spec_lengths[i] = spec.size(1)
-
-             wave = row[1]
-             wave_padded[i, :, : wave.size(1)] = wave
-             wave_lengths[i] = wave.size(1)
-
-             phone = row[2]
-             phone_padded[i, : phone.size(0), :] = phone
-             phone_lengths[i] = phone.size(0)
-
-             pitch = row[3]
-             pitch_padded[i, : pitch.size(0)] = pitch
-             pitchf = row[4]
-             pitchf_padded[i, : pitchf.size(0)] = pitchf
-
-             # dv[i] = row[5]
-             sid[i] = row[5]
-
-         return (
-             phone_padded,
-             phone_lengths,
-             pitch_padded,
-             pitchf_padded,
-             spec_padded,
-             spec_lengths,
-             wave_padded,
-             wave_lengths,
-             # dv
-             sid,
-         )
-
-
- class TextAudioLoader(torch.utils.data.Dataset):
-     """
-     1) loads audio, text pairs
-     2) normalizes text and converts them to sequences of integers
-     3) computes spectrograms from audio files.
-     """
-
-     def __init__(self, audiopaths_and_text, hparams):
-         self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
-         self.max_wav_value = hparams.max_wav_value
-         self.sampling_rate = hparams.sampling_rate
-         self.filter_length = hparams.filter_length
-         self.hop_length = hparams.hop_length
-         self.win_length = hparams.win_length
-         self.sampling_rate = hparams.sampling_rate
-         self.min_text_len = getattr(hparams, "min_text_len", 1)
-         self.max_text_len = getattr(hparams, "max_text_len", 5000)
-         self._filter()
-
-     def _filter(self):
-         """
-         Filter text & store spec lengths
-         """
-         # Store spectrogram lengths for Bucketing
-         # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
-         # spec_length = wav_length // hop_length
-         audiopaths_and_text_new = []
-         lengths = []
-         for audiopath, text, dv in self.audiopaths_and_text:
-             if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
-                 audiopaths_and_text_new.append([audiopath, text, dv])
-                 lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
-         self.audiopaths_and_text = audiopaths_and_text_new
-         self.lengths = lengths
-
-     def get_sid(self, sid):
-         sid = torch.LongTensor([int(sid)])
-         return sid
-
-     def get_audio_text_pair(self, audiopath_and_text):
-         # separate filename and text
-         file = audiopath_and_text[0]
-         phone = audiopath_and_text[1]
-         dv = audiopath_and_text[2]
-
-         phone = self.get_labels(phone)
-         spec, wav = self.get_audio(file)
-         dv = self.get_sid(dv)
-
-         len_phone = phone.size()[0]
-         len_spec = spec.size()[-1]
-         if len_phone != len_spec:
-             len_min = min(len_phone, len_spec)
-             len_wav = len_min * self.hop_length
-             spec = spec[:, :len_min]
-             wav = wav[:, :len_wav]
-             phone = phone[:len_min, :]
-         return (spec, wav, phone, dv)
-
-     def get_labels(self, phone):
-         phone = np.load(phone)
-         phone = np.repeat(phone, 2, axis=0)
-         n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
-         phone = phone[:n_num, :]
-         phone = torch.FloatTensor(phone)
-         return phone
-
-     def get_audio(self, filename):
-         audio, sampling_rate = load_wav_to_torch(filename)
-         if sampling_rate != self.sampling_rate:
-             raise ValueError(
-                 "{} SR doesn't match target {} SR".format(
-                     sampling_rate, self.sampling_rate
-                 )
-             )
-         audio_norm = audio
-         # audio_norm = audio / self.max_wav_value
-         # audio_norm = audio / np.abs(audio).max()
-
-         audio_norm = audio_norm.unsqueeze(0)
-         spec_filename = filename.replace(".wav", ".spec.pt")
-         if os.path.exists(spec_filename):
-             try:
-                 spec = torch.load(spec_filename)
-             except Exception:
-                 logger.warning("%s %s", spec_filename, traceback.format_exc())
-                 spec = spectrogram_torch(
-                     audio_norm,
-                     self.filter_length,
-                     self.sampling_rate,
-                     self.hop_length,
-                     self.win_length,
-                     center=False,
-                 )
-                 spec = torch.squeeze(spec, 0)
-                 torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-         else:
-             spec = spectrogram_torch(
-                 audio_norm,
-                 self.filter_length,
-                 self.sampling_rate,
-                 self.hop_length,
-                 self.win_length,
-                 center=False,
-             )
-             spec = torch.squeeze(spec, 0)
-             torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-         return spec, audio_norm
-
-     def __getitem__(self, index):
-         return self.get_audio_text_pair(self.audiopaths_and_text[index])
-
-     def __len__(self):
-         return len(self.audiopaths_and_text)
-
-
- class TextAudioCollate:
-     """Zero-pads model inputs and targets"""
-
-     def __init__(self, return_ids=False):
-         self.return_ids = return_ids
-
-     def __call__(self, batch):
-         """Collates a training batch from normalized text and audio
-         PARAMS
-         ------
-         batch: [text_normalized, spec_normalized, wav_normalized]
-         """
-         # Right zero-pad all one-hot text sequences to max input length
-         _, ids_sorted_decreasing = torch.sort(
-             torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
-         )
-
-         max_spec_len = max([x[0].size(1) for x in batch])
-         max_wave_len = max([x[1].size(1) for x in batch])
-         spec_lengths = torch.LongTensor(len(batch))
-         wave_lengths = torch.LongTensor(len(batch))
-         spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
-         wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
-         spec_padded.zero_()
-         wave_padded.zero_()
-
-         max_phone_len = max([x[2].size(0) for x in batch])
-         phone_lengths = torch.LongTensor(len(batch))
-         phone_padded = torch.FloatTensor(
-             len(batch), max_phone_len, batch[0][2].shape[1]
-         )
-         phone_padded.zero_()
-         sid = torch.LongTensor(len(batch))
-
-         for i in range(len(ids_sorted_decreasing)):
-             row = batch[ids_sorted_decreasing[i]]
-
-             spec = row[0]
-             spec_padded[i, :, : spec.size(1)] = spec
-             spec_lengths[i] = spec.size(1)
-
-             wave = row[1]
-             wave_padded[i, :, : wave.size(1)] = wave
-             wave_lengths[i] = wave.size(1)
-
-             phone = row[2]
-             phone_padded[i, : phone.size(0), :] = phone
-             phone_lengths[i] = phone.size(0)
-
-             sid[i] = row[3]
-
-         return (
-             phone_padded,
-             phone_lengths,
-             spec_padded,
-             spec_lengths,
-             wave_padded,
-             wave_lengths,
-             sid,
-         )
-
-
- class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
-     """
-     Maintain similar input lengths in a batch.
-     Length groups are specified by boundaries.
-     Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
-
-     It removes samples which are not included in the boundaries.
-     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
-     """
-
-     def __init__(
-         self,
-         dataset,
-         batch_size,
-         boundaries,
-         num_replicas=None,
-         rank=None,
-         shuffle=True,
-     ):
-         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
-         self.lengths = dataset.lengths
-         self.batch_size = batch_size
-         self.boundaries = boundaries
-
-         self.buckets, self.num_samples_per_bucket = self._create_buckets()
-         self.total_size = sum(self.num_samples_per_bucket)
-         self.num_samples = self.total_size // self.num_replicas
-
-     def _create_buckets(self):
-         buckets = [[] for _ in range(len(self.boundaries) - 1)]
-         for i in range(len(self.lengths)):
-             length = self.lengths[i]
-             idx_bucket = self._bisect(length)
-             if idx_bucket != -1:
-                 buckets[idx_bucket].append(i)
-
-         for i in range(len(buckets) - 1, -1, -1):
-             if len(buckets[i]) == 0:
-                 buckets.pop(i)
-                 self.boundaries.pop(i + 1)
-
-         num_samples_per_bucket = []
-         for i in range(len(buckets)):
-             len_bucket = len(buckets[i])
-             total_batch_size = self.num_replicas * self.batch_size
-             rem = (
-                 total_batch_size - (len_bucket % total_batch_size)
-             ) % total_batch_size
-             num_samples_per_bucket.append(len_bucket + rem)
-         return buckets, num_samples_per_bucket
-
-     def __iter__(self):
-         # deterministically shuffle based on epoch
-         g = torch.Generator()
-         g.manual_seed(self.epoch)
-
-         indices = []
-         if self.shuffle:
-             for bucket in self.buckets:
-                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
-         else:
-             for bucket in self.buckets:
-                 indices.append(list(range(len(bucket))))
-
-         batches = []
-         for i in range(len(self.buckets)):
-             bucket = self.buckets[i]
-             len_bucket = len(bucket)
-             ids_bucket = indices[i]
-             num_samples_bucket = self.num_samples_per_bucket[i]
-
-             # add extra samples to make it evenly divisible
-             rem = num_samples_bucket - len_bucket
-             ids_bucket = (
-                 ids_bucket
-                 + ids_bucket * (rem // len_bucket)
-                 + ids_bucket[: (rem % len_bucket)]
-             )
-
-             # subsample
-             ids_bucket = ids_bucket[self.rank :: self.num_replicas]
-
-             # batching
-             for j in range(len(ids_bucket) // self.batch_size):
-                 batch = [
-                     bucket[idx]
-                     for idx in ids_bucket[
-                         j * self.batch_size : (j + 1) * self.batch_size
-                     ]
-                 ]
-                 batches.append(batch)
-
-         if self.shuffle:
-             batch_ids = torch.randperm(len(batches), generator=g).tolist()
-             batches = [batches[i] for i in batch_ids]
-         self.batches = batches
-
-         assert len(self.batches) * self.batch_size == self.num_samples
-         return iter(self.batches)
-
-     def _bisect(self, x, lo=0, hi=None):
-         if hi is None:
-             hi = len(self.boundaries) - 1
-
-         if hi > lo:
-             mid = (hi + lo) // 2
-             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
-                 return mid
-             elif x <= self.boundaries[mid]:
-                 return self._bisect(x, lo, mid)
-             else:
-                 return self._bisect(x, mid + 1, hi)
-         else:
-             return -1
-
-     def __len__(self):
-         return self.num_samples // self.batch_size
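
For reference, the deleted classes above were combined into a training DataLoader roughly as follows. A minimal wiring sketch, assuming an experiment directory produced by the usual preprocessing steps; the paths, batch size, and bucket boundaries here are illustrative examples, not taken from this commit:

    import torch.utils.data

    from infer.lib.train.data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid,
    )
    from infer.lib.train.utils import get_hparams_from_file

    # "logs/my-exp/..." paths are assumed examples
    hps = get_hparams_from_file("logs/my-exp/config.json")
    dataset = TextAudioLoaderMultiNSFsid("logs/my-exp/filelist.txt", hps.data)
    sampler = DistributedBucketSampler(
        dataset,
        batch_size=4,
        boundaries=[100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=1,  # single-process run for illustration
        rank=0,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=2,
        batch_sampler=sampler,  # yields bucketed batches of similar spectrogram length
        collate_fn=TextAudioCollateMultiNSFsid(),
        pin_memory=True,
    )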
 
infer/lib/train/losses.py DELETED
@@ -1,58 +0,0 @@
- import torch
-
-
- def feature_loss(fmap_r, fmap_g):
-     loss = 0
-     for dr, dg in zip(fmap_r, fmap_g):
-         for rl, gl in zip(dr, dg):
-             rl = rl.float().detach()
-             gl = gl.float()
-             loss += torch.mean(torch.abs(rl - gl))
-
-     return loss * 2
-
-
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-     loss = 0
-     r_losses = []
-     g_losses = []
-     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-         dr = dr.float()
-         dg = dg.float()
-         r_loss = torch.mean((1 - dr) ** 2)
-         g_loss = torch.mean(dg**2)
-         loss += r_loss + g_loss
-         r_losses.append(r_loss.item())
-         g_losses.append(g_loss.item())
-
-     return loss, r_losses, g_losses
-
-
- def generator_loss(disc_outputs):
-     loss = 0
-     gen_losses = []
-     for dg in disc_outputs:
-         dg = dg.float()
-         l = torch.mean((1 - dg) ** 2)
-         gen_losses.append(l)
-         loss += l
-
-     return loss, gen_losses
-
-
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
-     """
-     z_p, logs_q: [b, h, t_t]
-     m_p, logs_p: [b, h, t_t]
-     """
-     z_p = z_p.float()
-     logs_q = logs_q.float()
-     m_p = m_p.float()
-     logs_p = logs_p.float()
-     z_mask = z_mask.float()
-
-     kl = logs_p - logs_q - 0.5
-     kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-     kl = torch.sum(kl * z_mask)
-     l = kl / torch.sum(z_mask)
-     return l
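
For reference, discriminator_loss and generator_loss above are the least-squares GAN objectives (real target 1, fake target 0), and kl_loss is the masked, sample-based KL term between posterior and prior Gaussians. A quick sanity-check sketch; the shapes are assumed for illustration:

    import torch

    from infer.lib.train.losses import discriminator_loss, generator_loss

    disc_real = [torch.ones(2, 10)]   # a perfect discriminator on real data
    disc_fake = [torch.zeros(2, 10)]  # and on generated data
    loss, r_losses, g_losses = discriminator_loss(disc_real, disc_fake)
    assert loss.item() == 0.0         # (1 - 1)^2 and 0^2 both vanish

    gen_loss, _ = generator_loss(disc_fake)
    assert gen_loss.item() == 1.0     # the generator wants D(fake) == 1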
 
infer/lib/train/mel_processing.py DELETED
@@ -1,132 +0,0 @@
- import torch
- import torch.utils.data
- from librosa.filters import mel as librosa_mel_fn
- import logging
-
- logger = logging.getLogger(__name__)
-
- MAX_WAV_VALUE = 32768.0
-
-
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-     """
-     PARAMS
-     ------
-     C: compression factor
-     """
-     return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
- def dynamic_range_decompression_torch(x, C=1):
-     """
-     PARAMS
-     ------
-     C: compression factor used to compress
-     """
-     return torch.exp(x) / C
-
-
- def spectral_normalize_torch(magnitudes):
-     return dynamic_range_compression_torch(magnitudes)
-
-
- def spectral_de_normalize_torch(magnitudes):
-     return dynamic_range_decompression_torch(magnitudes)
-
-
- # Reusable banks
- mel_basis = {}
- hann_window = {}
-
-
- def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
-     """Convert waveform into Linear-frequency Linear-amplitude spectrogram.
-
-     Args:
-         y :: (B, T) - Audio waveforms
-         n_fft
-         sampling_rate
-         hop_size
-         win_size
-         center
-     Returns:
-         :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
-     """
-     # Validation
-     if torch.min(y) < -1.07:
-         logger.debug("min value is %s", str(torch.min(y)))
-     if torch.max(y) > 1.07:
-         logger.debug("max value is %s", str(torch.max(y)))
-
-     # Window - Cache if needed
-     global hann_window
-     dtype_device = str(y.dtype) + "_" + str(y.device)
-     wnsize_dtype_device = str(win_size) + "_" + dtype_device
-     if wnsize_dtype_device not in hann_window:
-         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-             dtype=y.dtype, device=y.device
-         )
-
-     # Padding
-     y = torch.nn.functional.pad(
-         y.unsqueeze(1),
-         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
-         mode="reflect",
-     )
-     y = y.squeeze(1)
-
-     # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
-     spec = torch.stft(
-         y,
-         n_fft,
-         hop_length=hop_size,
-         win_length=win_size,
-         window=hann_window[wnsize_dtype_device],
-         center=center,
-         pad_mode="reflect",
-         normalized=False,
-         onesided=True,
-         return_complex=False,
-     )
-
-     # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
-     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-     return spec
-
-
- def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
-     # MelBasis - Cache if needed
-     global mel_basis
-     dtype_device = str(spec.dtype) + "_" + str(spec.device)
-     fmax_dtype_device = str(fmax) + "_" + dtype_device
-     if fmax_dtype_device not in mel_basis:
-         mel = librosa_mel_fn(
-             sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-         )
-         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-             dtype=spec.dtype, device=spec.device
-         )
-
-     # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
-     melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-     melspec = spectral_normalize_torch(melspec)
-     return melspec
-
-
- def mel_spectrogram_torch(
-     y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
- ):
-     """Convert waveform into Mel-frequency Log-amplitude spectrogram.
-
-     Args:
-         y :: (B, T) - Waveforms
-     Returns:
-         melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
-     """
-     # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
-     spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
-
-     # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
-     melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
-
-     return melspec
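
A usage sketch of the pipeline above, waveform -> linear spectrogram -> mel spectrogram. The 2048/400/125 values are assumptions in the spirit of a typical 40k configuration, not taken from this commit:

    import torch

    from infer.lib.train.mel_processing import spec_to_mel_torch, spectrogram_torch

    y = torch.randn(1, 16000)  # (B, T) dummy waveform
    spec = spectrogram_torch(y, n_fft=2048, sampling_rate=40000, hop_size=400, win_size=2048)
    mel = spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0.0, fmax=None)
    print(spec.shape, mel.shape)  # torch.Size([1, 1025, 40]) torch.Size([1, 125, 40])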
 
infer/lib/train/process_ckpt.py DELETED
@@ -1,261 +0,0 @@
- import os
- import sys
- import traceback
- from collections import OrderedDict
-
- import torch
-
- from i18n.i18n import I18nAuto
-
- i18n = I18nAuto()
-
-
- def savee(ckpt, sr, if_f0, name, epoch, version, hps):
-     try:
-         opt = OrderedDict()
-         opt["weight"] = {}
-         for key in ckpt.keys():
-             if "enc_q" in key:
-                 continue
-             opt["weight"][key] = ckpt[key].half()
-         opt["config"] = [
-             hps.data.filter_length // 2 + 1,
-             32,
-             hps.model.inter_channels,
-             hps.model.hidden_channels,
-             hps.model.filter_channels,
-             hps.model.n_heads,
-             hps.model.n_layers,
-             hps.model.kernel_size,
-             hps.model.p_dropout,
-             hps.model.resblock,
-             hps.model.resblock_kernel_sizes,
-             hps.model.resblock_dilation_sizes,
-             hps.model.upsample_rates,
-             hps.model.upsample_initial_channel,
-             hps.model.upsample_kernel_sizes,
-             hps.model.spk_embed_dim,
-             hps.model.gin_channels,
-             hps.data.sampling_rate,
-         ]
-         opt["info"] = "%sepoch" % epoch
-         opt["sr"] = sr
-         opt["f0"] = if_f0
-         opt["version"] = version
-         torch.save(opt, "assets/weights/%s.pth" % name)
-         return "Success."
-     except Exception:
-         return traceback.format_exc()
-
-
- def show_info(path):
-     try:
-         a = torch.load(path, map_location="cpu")
-         return "Model info: %s\nSampling rate: %s\nPitch guidance (f0): %s\nVersion: %s" % (
-             a.get("info", "None"),
-             a.get("sr", "None"),
-             a.get("f0", "None"),
-             a.get("version", "None"),
-         )
-     except Exception:
-         return traceback.format_exc()
-
-
- def extract_small_model(path, name, sr, if_f0, info, version):
-     try:
-         ckpt = torch.load(path, map_location="cpu")
-         if "model" in ckpt:
-             ckpt = ckpt["model"]
-         opt = OrderedDict()
-         opt["weight"] = {}
-         for key in ckpt.keys():
-             if "enc_q" in key:
-                 continue
-             opt["weight"][key] = ckpt[key].half()
-         if sr == "40k":
-             opt["config"] = [
-                 1025,
-                 32,
-                 192,
-                 192,
-                 768,
-                 2,
-                 6,
-                 3,
-                 0,
-                 "1",
-                 [3, 7, 11],
-                 [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-                 [10, 10, 2, 2],
-                 512,
-                 [16, 16, 4, 4],
-                 109,
-                 256,
-                 40000,
-             ]
-         elif sr == "48k":
-             if version == "v1":
-                 opt["config"] = [
-                     1025,
-                     32,
-                     192,
-                     192,
-                     768,
-                     2,
-                     6,
-                     3,
-                     0,
-                     "1",
-                     [3, 7, 11],
-                     [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-                     [10, 6, 2, 2, 2],
-                     512,
-                     [16, 16, 4, 4, 4],
-                     109,
-                     256,
-                     48000,
-                 ]
-             else:
-                 opt["config"] = [
-                     1025,
-                     32,
-                     192,
-                     192,
-                     768,
-                     2,
-                     6,
-                     3,
-                     0,
-                     "1",
-                     [3, 7, 11],
-                     [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-                     [12, 10, 2, 2],
-                     512,
-                     [24, 20, 4, 4],
-                     109,
-                     256,
-                     48000,
-                 ]
-         elif sr == "32k":
-             if version == "v1":
-                 opt["config"] = [
-                     513,
-                     32,
-                     192,
-                     192,
-                     768,
-                     2,
-                     6,
-                     3,
-                     0,
-                     "1",
-                     [3, 7, 11],
-                     [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-                     [10, 4, 2, 2, 2],
-                     512,
-                     [16, 16, 4, 4, 4],
-                     109,
-                     256,
-                     32000,
-                 ]
-             else:
-                 opt["config"] = [
-                     513,
-                     32,
-                     192,
-                     192,
-                     768,
-                     2,
-                     6,
-                     3,
-                     0,
-                     "1",
-                     [3, 7, 11],
-                     [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-                     [10, 8, 2, 2],
-                     512,
-                     [20, 16, 4, 4],
-                     109,
-                     256,
-                     32000,
-                 ]
-         if info == "":
-             info = "Extracted model."
-         opt["info"] = info
-         opt["version"] = version
-         opt["sr"] = sr
-         opt["f0"] = int(if_f0)
-         torch.save(opt, "assets/weights/%s.pth" % name)
-         return "Success."
-     except Exception:
-         return traceback.format_exc()
-
-
- def change_info(path, info, name):
-     try:
-         ckpt = torch.load(path, map_location="cpu")
-         ckpt["info"] = info
-         if name == "":
-             name = os.path.basename(path)
-         torch.save(ckpt, "assets/weights/%s" % name)
-         return "Success."
-     except Exception:
-         return traceback.format_exc()
-
-
- def merge(path1, path2, alpha1, sr, f0, info, name, version):
-     try:
-
-         def extract(ckpt):
-             a = ckpt["model"]
-             opt = OrderedDict()
-             opt["weight"] = {}
-             for key in a.keys():
-                 if "enc_q" in key:
-                     continue
-                 opt["weight"][key] = a[key]
-             return opt
-
-         ckpt1 = torch.load(path1, map_location="cpu")
-         ckpt2 = torch.load(path2, map_location="cpu")
-         cfg = ckpt1["config"]
-         if "model" in ckpt1:
-             ckpt1 = extract(ckpt1)
-         else:
-             ckpt1 = ckpt1["weight"]
-         if "model" in ckpt2:
-             ckpt2 = extract(ckpt2)
-         else:
-             ckpt2 = ckpt2["weight"]
-         if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
-             return "Fail to merge the models. The model architectures are not the same."
-         opt = OrderedDict()
-         opt["weight"] = {}
-         for key in ckpt1.keys():
-             # try:
-             if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
-                 min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
-                 opt["weight"][key] = (
-                     alpha1 * (ckpt1[key][:min_shape0].float())
-                     + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
-                 ).half()
-             else:
-                 opt["weight"][key] = (
-                     alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
-                 ).half()
-             # except:
-             #     pdb.set_trace()
-         opt["config"] = cfg
-         """
-         if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000]
-         elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000]
-         elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
-         """
-         opt["sr"] = sr
-         opt["f0"] = 1 if f0 == i18n("是") else 0  # i18n("是") == "Yes"
-         opt["version"] = version
-         opt["info"] = info
-         torch.save(opt, "assets/weights/%s.pth" % name)
-         return "Success."
-     except Exception:
-         return traceback.format_exc()
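
For reference, a hypothetical call shrinking a full training checkpoint into an inference-only weights file via extract_small_model above; the paths and epoch label are assumed examples:

    from infer.lib.train.process_ckpt import extract_small_model

    msg = extract_small_model(
        path="logs/my-exp/G_23333.pth",  # full G checkpoint, with enc_q/optimizer state dropped
        name="my-exp",                   # written to assets/weights/my-exp.pth
        sr="40k",
        if_f0=1,
        info="extracted at epoch 100",
        version="v2",
    )
    print(msg)  # "Success." on success, otherwise the traceback string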
 
infer/lib/train/utils.py DELETED
@@ -1,478 +0,0 @@
- import argparse
- import glob
- import json
- import logging
- import os
- import subprocess
- import sys
- import shutil
-
- import numpy as np
- import torch
- from scipy.io.wavfile import read
-
- MATPLOTLIB_FLAG = False
-
- logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
- logger = logging
-
-
- def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-
-     ##################
-     def go(model, bkey):
-         saved_state_dict = checkpoint_dict[bkey]
-         if hasattr(model, "module"):
-             state_dict = model.module.state_dict()
-         else:
-             state_dict = model.state_dict()
-         new_state_dict = {}
-         for k, v in state_dict.items():  # the shapes the model expects
-             try:
-                 new_state_dict[k] = saved_state_dict[k]
-                 if saved_state_dict[k].shape != state_dict[k].shape:
-                     logger.warning(
-                         "shape-%s-mismatch. need: %s, get: %s",
-                         k,
-                         state_dict[k].shape,
-                         saved_state_dict[k].shape,
-                     )
-                     raise KeyError
-             except Exception:
-                 # logger.info(traceback.format_exc())
-                 logger.info("%s is not in the checkpoint", k)  # missing from the pretrained checkpoint
-                 new_state_dict[k] = v  # fall back to the model's own randomly initialized value
-         if hasattr(model, "module"):
-             model.module.load_state_dict(new_state_dict, strict=False)
-         else:
-             model.load_state_dict(new_state_dict, strict=False)
-         return model
-
-     go(combd, "combd")
-     model = go(sbd, "sbd")
-     #############
-     logger.info("Loaded model weights")
-
-     iteration = checkpoint_dict["iteration"]
-     learning_rate = checkpoint_dict["learning_rate"]
-     if (
-         optimizer is not None and load_opt == 1
-     ):  ### if the optimizer state cannot be loaded (e.g. it is empty), re-initializing may also affect the LR schedule update, so this is caught at the outermost level of the train script
-         # try:
-         optimizer.load_state_dict(checkpoint_dict["optimizer"])
-         # except:
-         #     traceback.print_exc()
-     logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
-     return model, optimizer, learning_rate, iteration
-
-
- # def load_checkpoint(checkpoint_path, model, optimizer=None):
- #     assert os.path.isfile(checkpoint_path)
- #     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
- #     iteration = checkpoint_dict['iteration']
- #     learning_rate = checkpoint_dict['learning_rate']
- #     if optimizer is not None:
- #         optimizer.load_state_dict(checkpoint_dict['optimizer'])
- #     # print(1111)
- #     saved_state_dict = checkpoint_dict['model']
- #     # print(1111)
- #
- #     if hasattr(model, 'module'):
- #         state_dict = model.module.state_dict()
- #     else:
- #         state_dict = model.state_dict()
- #     new_state_dict = {}
- #     for k, v in state_dict.items():
- #         try:
- #             new_state_dict[k] = saved_state_dict[k]
- #         except:
- #             logger.info("%s is not in the checkpoint" % k)
- #             new_state_dict[k] = v
- #     if hasattr(model, 'module'):
- #         model.module.load_state_dict(new_state_dict)
- #     else:
- #         model.load_state_dict(new_state_dict)
- #     logger.info("Loaded checkpoint '{}' (epoch {})".format(
- #         checkpoint_path, iteration))
- #     return model, optimizer, learning_rate, iteration
- def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-
-     saved_state_dict = checkpoint_dict["model"]
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     new_state_dict = {}
-     for k, v in state_dict.items():  # the shapes the model expects
-         try:
-             new_state_dict[k] = saved_state_dict[k]
-             if saved_state_dict[k].shape != state_dict[k].shape:
-                 logger.warning(
-                     "shape-%s-mismatch|need-%s|get-%s",
-                     k,
-                     state_dict[k].shape,
-                     saved_state_dict[k].shape,
-                 )
-                 raise KeyError
-         except Exception:
-             # logger.info(traceback.format_exc())
-             logger.info("%s is not in the checkpoint", k)  # missing from the pretrained checkpoint
-             new_state_dict[k] = v  # fall back to the model's own randomly initialized value
-     if hasattr(model, "module"):
-         model.module.load_state_dict(new_state_dict, strict=False)
-     else:
-         model.load_state_dict(new_state_dict, strict=False)
-     logger.info("Loaded model weights")
-
-     iteration = checkpoint_dict["iteration"]
-     learning_rate = checkpoint_dict["learning_rate"]
-     if (
-         optimizer is not None and load_opt == 1
-     ):  ### same caveat as above: load failures are caught at the outermost level of the train script
-         # try:
-         optimizer.load_state_dict(checkpoint_dict["optimizer"])
-         # except:
-         #     traceback.print_exc()
-     logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
-     return model, optimizer, learning_rate, iteration
-
-
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info(
-         "Saving model and optimizer state at epoch {} to {}".format(
-             iteration, checkpoint_path
-         )
-     )
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     torch.save(
-         {
-             "model": state_dict,
-             "iteration": iteration,
-             "optimizer": optimizer.state_dict(),
-             "learning_rate": learning_rate,
-         },
-         checkpoint_path,
-     )
-
-
- def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info(
-         "Saving model and optimizer state at epoch {} to {}".format(
-             iteration, checkpoint_path
-         )
-     )
-     if hasattr(combd, "module"):
-         state_dict_combd = combd.module.state_dict()
-     else:
-         state_dict_combd = combd.state_dict()
-     if hasattr(sbd, "module"):
-         state_dict_sbd = sbd.module.state_dict()
-     else:
-         state_dict_sbd = sbd.state_dict()
-     torch.save(
-         {
-             "combd": state_dict_combd,
-             "sbd": state_dict_sbd,
-             "iteration": iteration,
-             "optimizer": optimizer.state_dict(),
-             "learning_rate": learning_rate,
-         },
-         checkpoint_path,
-     )
-
-
- def summarize(
-     writer,
-     global_step,
-     scalars={},
-     histograms={},
-     images={},
-     audios={},
-     audio_sampling_rate=22050,
- ):
-     for k, v in scalars.items():
-         writer.add_scalar(k, v, global_step)
-     for k, v in histograms.items():
-         writer.add_histogram(k, v, global_step)
-     for k, v in images.items():
-         writer.add_image(k, v, global_step, dataformats="HWC")
-     for k, v in audios.items():
-         writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-     f_list = glob.glob(os.path.join(dir_path, regex))
-     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-     x = f_list[-1]
-     logger.debug(x)
-     return x
-
-
- def plot_spectrogram_to_numpy(spectrogram):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-     plt.colorbar(im, ax=ax)
-     plt.xlabel("Frames")
-     plt.ylabel("Channels")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def plot_alignment_to_numpy(alignment, info=None):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(6, 4))
-     im = ax.imshow(
-         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-     )
-     fig.colorbar(im, ax=ax)
-     xlabel = "Decoder timestep"
-     if info is not None:
-         xlabel += "\n\n" + info
-     plt.xlabel(xlabel)
-     plt.ylabel("Encoder timestep")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def load_wav_to_torch(full_path):
-     sampling_rate, data = read(full_path)
-     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
- def load_filepaths_and_text(filename, split="|"):
-     with open(filename, encoding="utf-8") as f:
-         filepaths_and_text = [line.strip().split(split) for line in f]
-     return filepaths_and_text
-
-
- def get_hparams(init=True):
-     """
-     todo:
-         the trailing group of seven:
-             save frequency / total epochs  done
-             bs  done
-             pretrainG / pretrainD  done
-             GPU ids: os.environ["CUDA_VISIBLE_DEVICES"]  done
-             if_latest  done
-             model: if_f0  done
-             sample rate: auto-select the config  done
-             whether to cache the dataset in GPU memory: if_cache_data_in_gpu  done
-
-         -m:
-             auto-derive the training_files path, replacing hps.data.training_files in train_nsf_load_pretrain.py  done
-             drop the -c flag
-     """
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "-se",
-         "--save_every_epoch",
-         type=int,
-         required=True,
-         help="checkpoint save frequency (epoch)",
-     )
-     parser.add_argument(
-         "-te", "--total_epoch", type=int, required=True, help="total_epoch"
-     )
-     parser.add_argument(
-         "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
-     )
-     parser.add_argument(
-         "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path"
-     )
-     parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
-     parser.add_argument(
-         "-bs", "--batch_size", type=int, required=True, help="batch size"
-     )
-     parser.add_argument(
-         "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
-     )  # -m
-     parser.add_argument(
-         "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
-     )
-     parser.add_argument(
-         "-sw",
-         "--save_every_weights",
-         type=str,
-         default="0",
-         help="save the extracted model in weights directory when saving checkpoints",
-     )
-     parser.add_argument(
-         "-v", "--version", type=str, required=True, help="model version"
-     )
-     parser.add_argument(
-         "-f0",
-         "--if_f0",
-         type=int,
-         required=True,
-         help="use f0 as one of the inputs of the model, 1 or 0",
-     )
-     parser.add_argument(
-         "-l",
-         "--if_latest",
-         type=int,
-         required=True,
-         help="if only save the latest G/D pth file, 1 or 0",
-     )
-     parser.add_argument(
-         "-c",
-         "--if_cache_data_in_gpu",
-         type=int,
-         required=True,
-         help="if caching the dataset in GPU memory, 1 or 0",
-     )
-
-     args = parser.parse_args()
-     name = args.experiment_dir
-     experiment_dir = os.path.join("./logs", args.experiment_dir)
-
-     config_save_path = os.path.join(experiment_dir, "config.json")
-     with open(config_save_path, "r") as f:
-         config = json.load(f)
-
-     hparams = HParams(**config)
-     hparams.model_dir = hparams.experiment_dir = experiment_dir
-     hparams.save_every_epoch = args.save_every_epoch
-     hparams.name = name
-     hparams.total_epoch = args.total_epoch
-     hparams.pretrainG = args.pretrainG
-     hparams.pretrainD = args.pretrainD
-     hparams.version = args.version
-     hparams.gpus = args.gpus
-     hparams.train.batch_size = args.batch_size
-     hparams.sample_rate = args.sample_rate
-     hparams.if_f0 = args.if_f0
-     hparams.if_latest = args.if_latest
-     hparams.save_every_weights = args.save_every_weights
-     hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
-     hparams.data.training_files = "%s/filelist.txt" % experiment_dir
-     return hparams
-
-
- def get_hparams_from_dir(model_dir):
-     config_save_path = os.path.join(model_dir, "config.json")
-     with open(config_save_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_file(config_path):
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     return hparams
-
-
- def check_git_hash(model_dir):
-     source_dir = os.path.dirname(os.path.realpath(__file__))
-     if not os.path.exists(os.path.join(source_dir, ".git")):
-         logger.warning(
-             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                 source_dir
-             )
-         )
-         return
-
-     cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-     path = os.path.join(model_dir, "githash")
-     if os.path.exists(path):
-         saved_hash = open(path).read()
-         if saved_hash != cur_hash:
-             logger.warning(
-                 "git hash values are different. {}(saved) != {}(current)".format(
-                     saved_hash[:8], cur_hash[:8]
-                 )
-             )
-     else:
-         open(path, "w").write(cur_hash)
-
-
- def get_logger(model_dir, filename="train.log"):
-     global logger
-     logger = logging.getLogger(os.path.basename(model_dir))
-     logger.setLevel(logging.DEBUG)
-
-     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-     h = logging.FileHandler(os.path.join(model_dir, filename))
-     h.setLevel(logging.DEBUG)
-     h.setFormatter(formatter)
-     logger.addHandler(h)
-     return logger
-
-
- class HParams:
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if isinstance(v, dict):
-                 v = HParams(**v)
-             self[k] = v
-
-     def keys(self):
-         return self.__dict__.keys()
-
-     def items(self):
-         return self.__dict__.items()
-
-     def values(self):
-         return self.__dict__.values()
-
-     def __len__(self):
-         return len(self.__dict__)
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-     def __contains__(self, key):
-         return key in self.__dict__
-
-     def __repr__(self):
-         return self.__dict__.__repr__()
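
A minimal sketch of the HParams container defined above: nested dicts become attribute-accessible objects (the config content here is an assumed example):

    config = {"train": {"batch_size": 4, "lr": 1e-4}, "data": {"sampling_rate": 40000}}
    hps = HParams(**config)
    print(hps.train.batch_size)       # 4
    print(hps.data["sampling_rate"])  # 40000 -- item access works too
    print("train" in hps)             # True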