stephenhoang commited on
Commit
bb2d9d3
·
1 Parent(s): 4ea28e6

update models

Browse files
Files changed (2) hide show
  1. Models/config.yaml +80 -0
  2. inference.py +356 -0
Models/config.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: ./Models/Finetune
2
+ save_freq: 1
3
+ log_interval: 10
4
+ device: cuda
5
+ epochs: 50
6
+ batch_size: 2
7
+ max_len: 310 # maximum number of frames
8
+ pretrained_model: ./Models/Finetune/base_model.pth
9
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
10
+ debug: true
11
+
12
+ data_params:
13
+ train_data: "../../Data_Speech/viVoice/train.txt"
14
+ val_data: "../../Data_Speech/combine/combine_val.txt"
15
+ root_path: "../../Data_Speech/"
16
+
17
+ symbol: #Total 189 symbols
18
+ pad: "$"
19
+ punctuation: ';:,.!?¡¿—…"«»“” '
20
+ letters: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
21
+ letters_ipa: "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
22
+ extend: "∫̆ăη͡123456" #ADD MORE SYMBOLS HERE
23
+
24
+ preprocess_params:
25
+ sr: 24000
26
+ spect_params:
27
+ n_fft: 2048
28
+ win_length: 1200
29
+ hop_length: 300
30
+
31
+ training_strats:
32
+ #All modules: 'decoder', 'predictor', 'text_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor', 'mpd', 'msd'
33
+ freeze_modules: [''] # Not updated when training.
34
+ ignore_modules: [''] # Not loaded => fresh start. IMPORTANT: 'text_aligner' and 'pitch_extractor' are utility pretrained models; do NOT ignore them.
35
+
36
+ model_params:
37
+ dim_in: 64
38
+ hidden_dim: 512
39
+ max_conv_dim: 512
40
+ n_layer: 3
41
+ n_mels: 80
42
+ max_dur: 50 # maximum duration of a single phoneme
43
+ style_dim: 128 # style vector size
44
+
45
+ dropout: 0.2
46
+
47
+ ASR_params:
48
+ input_dim: 80
49
+ hidden_dim: 256
50
+ n_layers: 6
51
+ token_embedding_dim: 512
52
+
53
+ JDC_params:
54
+ num_class: 1
55
+ seq_len: 192
56
+
57
+ # config for decoder
58
+ decoder:
59
+ type: 'hifigan' # either hifigan or istftnet
60
+ resblock_kernel_sizes: [3,7,11]
61
+ upsample_rates : [10,5,3,2]
62
+ upsample_initial_channel: 512
63
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
64
+ upsample_kernel_sizes: [20,10,6,4]
65
+
66
+ loss_params:
67
+ lambda_mel: 5. # mel reconstruction loss
68
+ lambda_gen: 1. # generator loss
69
+
70
+ lambda_mono: 1. # monotonic alignment loss (TMA)
71
+ lambda_s2s: 1. # sequence-to-sequence loss (TMA)
72
+
73
+ lambda_F0: 1. # F0 reconstruction loss
74
+ lambda_norm: 1. # norm reconstruction loss
75
+ lambda_dur: 1. # duration loss
76
+ lambda_ce: 20. # duration predictor probability output CE loss
77
+
78
+ optimizer_params:
79
+ lr: 0.0001 # general learning rate
80
+ ft_lr: 0.00001 # learning rate for acoustic modules
inference.py CHANGED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import sys

import yaml
from munch import Munch
import numpy as np
import librosa
import noisereduce as nr
import torch
import torchaudio
import nltk
from nltk.tokenize import word_tokenize
import phonemizer

from meldataset import TextCleaner
from models import ProsodyPredictor, TextEncoder, StyleEncoder
from Modules.hifigan import Decoder

# Download the NLTK sentence tokenizer only when it is missing, so importing
# this module does not hit the network (or print download chatter) on every run.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

# On Windows, espeak-ng is usually not on the library search path; point
# phonemizer at the copy bundled with espeakng_loader. Best effort only:
# failures are printed and phonemization falls back to whatever is installed.
if sys.platform.startswith("win"):
    try:
        from phonemizer.backend.espeak.wrapper import EspeakWrapper
        import espeakng_loader
        EspeakWrapper.set_library(espeakng_loader.get_library_path())
    except Exception as e:
        print(e)
27
# Cache of one EspeakBackend per language code: backend construction is
# expensive, and the original code rebuilt it on every call.
_ESPEAK_BACKENDS = {}

def espeak_phn(text, lang):
    """Phonemize *text* with espeak for language code *lang*.

    Returns the phoneme string, or None (after printing the error) when the
    backend cannot be created or phonemization fails — callers must tolerate
    a None result.
    """
    try:
        backend = _ESPEAK_BACKENDS.get(lang)
        if backend is None:
            backend = phonemizer.backend.EspeakBackend(language=lang, preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
            _ESPEAK_BACKENDS[lang] = backend
        return backend.phonemize([text])[0]
    except Exception as e:
        print(e)
34
class Preprocess:
    """Text and waveform preprocessing helpers for StyleTTS2 inference."""

    def __text_normalize(self, text):
        """Map comma/period-like punctuation to '.' and collapse whitespace."""
        punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
        map_to = "."
        punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
        # Replace punctuation that acts like a comma or period.
        text = punctuation_pattern.sub(map_to, text)
        # Collapse consecutive whitespace into single spaces and trim the ends.
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def __merge_fragments(self, texts, n):
        """Greedily merge consecutive fragments until each has at least *n* words.

        A too-short trailing fragment is folded into its predecessor.
        Returns [] for empty input (the original crashed with IndexError).
        """
        if not texts:
            return []
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j
        if len(merged[-1].split()) < n and len(merged) > 1:  # handle last sentence
            merged[-2] = merged[-2] + ", " + merged[-1]
            del merged[-1]
        return merged

    def wave_preprocess(self, wave):
        """Convert a mono float waveform (numpy array) to a normalized log-mel tensor."""
        # Build the mel transform once and cache it on the instance; its
        # parameters are invariant, so rebuilding per call was wasted work.
        to_mel = getattr(self, "_to_mel", None)
        if to_mel is None:
            to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
            self._to_mel = to_mel
        mean, std = -4, 4
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = to_mel(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
        return mel_tensor

    def text_preprocess(self, text, n_merge=12):
        """Normalize *text*, split it into sentences, and merge sentences shorter than *n_merge* words."""
        text_norm = self.__text_normalize(text).split(".")  # split by sentences
        text_norm = [s.strip() for s in text_norm]
        text_norm = [s for s in text_norm if s != '']  # drop empty fragments
        text_norm = self.__merge_fragments(text_norm, n=n_merge)  # merge if a sentence has fewer than n words
        return text_norm

    def length_to_mask(self, lengths):
        """Build a boolean padding mask (True = padded position) from a 1-D length tensor."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask + 1, lengths.unsqueeze(1))
        return mask
79
# For inference only.
class StyleTTS2(torch.nn.Module):
    """Inference-only StyleTTS2 wrapper: loads config/weights and synthesizes speech."""

    def __init__(self, config_path, models_path):
        """
        Args:
            config_path: path to the YAML model config (must contain 'symbol' and 'model_params').
            models_path: path to the checkpoint (.pth) whose 'net' entry holds the state dicts.
        """
        super().__init__()
        # Empty buffer whose .device tracks whatever device the module is moved to.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()
        self.ref_s = None
        # Use a context manager so the config file handle is closed deterministically
        # (the original left it to the garbage collector).
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        try:
            symbols = (
                list(config['symbol']['pad']) +
                list(config['symbol']['punctuation']) +
                list(config['symbol']['letters']) +
                list(config['symbol']['letters_ipa']) +
                list(config['symbol']['extend'])
            )
            # Duplicate symbols keep the index of their last occurrence (dict semantics).
            symbol_dict = {s: i for i, s in enumerate(symbols)}

            n_token = len(symbol_dict) + 1
            print("\nFound:", n_token, "symbols")
        except Exception as e:
            print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
            raise SystemExit(1)

        args = self.__recursive_munch(config['model_params'])
        args['n_token'] = n_token

        self.cleaner = TextCleaner(symbol_dict, debug=False)

        assert args.decoder.type in ['hifigan'], 'Decoder type unknown'

        self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                               resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                               upsample_rates=args.decoder.upsample_rates,
                               upsample_initial_channel=args.decoder.upsample_initial_channel,
                               resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                               upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
        self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
        self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
        self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # acoustic style encoder

        self.__load_models(models_path)

    def __recursive_munch(self, d):
        """Recursively convert nested dicts/lists into attribute-accessible Munch objects."""
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        return d

    def __init_replacement_func(self, replacements):
        """Return a re.sub callback that yields *replacements* one by one, in order."""
        replacement_iter = iter(replacements)
        def replacement(match):
            return next(replacement_iter)
        return replacement

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        """Clamp values whose |z-score| exceeds *threshold* toward the mean, preserving sign."""
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / std

        # Identify outliers.
        outlier_mask = torch.abs(z) > threshold
        # Compute replacement value, respecting sign.
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)

        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        """Load the four sub-module state dicts from the checkpoint and report parameter counts."""
        module_params = []
        model = {'decoder': self.decoder, 'predictor': self.predictor, 'text_encoder': self.text_encoder, 'style_encoder': self.style_encoder}

        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
        params_whole = torch.load(models_path, map_location='cpu')
        params = params_whole['net']
        params = {key: value for key, value in params.items() if key in model}

        for key in model:
            try:
                model[key].load_state_dict(params[key])
            except Exception:
                # Checkpoints saved from DataParallel carry a 'module.' key prefix;
                # strip it and retry non-strictly. (Was a bare `except:`, which also
                # swallowed KeyboardInterrupt/SystemExit.)
                from collections import OrderedDict
                state_dict = params[key]
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    new_state_dict[k[7:]] = v  # remove `module.`
                model[key].load_state_dict(new_state_dict, strict=False)

            total_params = sum(p.numel() for p in model[key].parameters())
            print(key, ":", total_params)
            module_params.append(total_params)

        print('\nTotal', ":", sum(module_params))

    def __compute_style(self, path, denoise, split_dur):
        """Encode a reference audio file into a style vector.

        Args:
            path: audio file path (loaded at 24 kHz, trimmed, capped at 20 s).
            denoise: 0..1 blend weight of the noise-reduced signal.
            split_dur: if > 0, split the audio into split_dur-second chunks,
                encode each, and average the styles.
        """
        device = self.get_device.device
        denoise = min(denoise, 1)
        if split_dur != 0:
            split_dur = max(int(split_dur), 1)
        max_samples = 24000 * 20  # max 20 seconds of reference audio
        print("Computing the style for:", path)

        wave, sr = librosa.load(path, sr=24000)
        audio, index = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            # Defensive only: librosa.load(sr=24000) already resampled. Keyword
            # args are required since librosa 0.10 removed the positional form.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:  # splitting only helps for audio >= 4 s
                # Split the reference audio into chunks, encode each, and average the styles.
                jump = sr * split_dur
                total_len = len(audio)

                # Encode the first chunk before the loop to initialize the accumulator.
                mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
                count = 1
                for i in range(jump, total_len, jump):
                    if i + jump >= total_len:
                        left_dur = (total_len - i) / sr
                        if left_dur >= 1:  # only count a leftover chunk of at least 1 s
                            mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
                            ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                            count += 1
                        continue
                    mel_tensor = self.preprocess.wave_preprocess(audio[i:i + jump]).to(device)
                    ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                    count += 1
                ref_s /= count
            else:
                mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))

        return ref_s

    def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
        """Synthesize one phonemized sentence.

        Returns:
            (waveform as a numpy array, mean predicted duration — feed back as
            prev_d_mean on the next call to stabilize pace across splits).
        """
        device = self.get_device.device
        speed = min(max(speed, 0.0001), 2)  # clamp speed to (0, 2]

        phonem = ' '.join(word_tokenize(phonem))
        tokens = self.cleaner(phonem)
        tokens.insert(0, 0)  # pad token at both ends
        tokens.append(0)
        tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)

            # Encode the token sequence.
            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            # Predict per-token durations.
            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)

            if prev_d_mean != 0:  # stabilize speaking speed between splits
                dur_stats = torch.empty(duration.shape).normal_(mean=prev_d_mean, std=duration.std()).to(device)
            else:
                dur_stats = torch.empty(duration.shape).normal_(mean=duration.mean(), std=duration.std()).to(device)
            duration = duration * (1 - t) + dur_stats * t
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])  # normalize outliers

            duration /= speed

            # Expand durations into a hard monotonic alignment matrix.
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            # Encode prosody and decode the waveform.
            en = (d.transpose(-1, -2) @ alignment)
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = (t_en @ alignment)  # reuse alignment instead of rebuilding the same tensor

            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), duration.mean()

    def get_styles(self, speakers, denoise=0.3, avg_style=True):
        """Compute style vectors for each speaker.

        Args:
            speakers: mapping {id: {'path': ..., 'lang': ..., 'speed': ...}}.
            denoise: noise-reduction blend passed to style computation.
            avg_style: if True, average styles over 2-second chunks.
        """
        split_dur = 2 if avg_style else 0
        styles = {}
        for speaker_id in speakers:  # renamed from `id` to avoid shadowing the builtin
            ref_s = self.__compute_style(speakers[speaker_id]['path'], denoise=denoise, split_dur=split_dur)
            styles[speaker_id] = {
                'style': ref_s,
                'path': speakers[speaker_id]['path'],
                'lang': speakers[speaker_id]['lang'],
                'speed': speakers[speaker_id]['speed'],
            }
        return styles

    def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker="[id_1]"):
        """Synthesize *text* (optionally multi-speaker "[id_N]" / multi-lingual "[lang]{...}") into one waveform."""
        smooth_value = 0.2 if stabilize else 0

        list_wav = []
        prev_d_mean = 0
        lang_pattern = r'\[([^\]]+)\]\{([^}]+)\}'

        text = re.sub(r'[\n\r\t\f\v]', '', text)
        # Re-wrap language tokens so a single [lang]{...} span cannot cover multiple sentences.
        find_lang_tokens = re.findall(lang_pattern, text)
        if find_lang_tokens:
            cus_text = []
            for lang, t in find_lang_tokens:
                parts = self.preprocess.text_preprocess(t, n_merge=0)
                parts = ".".join([f"[{lang}]" + f"{{{p}}}" for p in parts])
                cus_text.append(parts)
            replacement_func = self.__init_replacement_func(cus_text)
            text = re.sub(lang_pattern, replacement_func, text)

        texts = re.split(r'(\[id_\d+\])', text)  # split by speaker ids while keeping the ids
        texts = [t for t in texts if t != '']  # drop empty fragments produced by the split
        # BUGFIX: the original checked `bool(re.match(...) == False)`, which is always
        # False (a Match object or None never equals False), so text that did not
        # start with a speaker id later crashed with NameError. Prepend the default
        # speaker whenever the first fragment is not an id token.
        if len(texts) == 0 or not re.match(r'\[id_\d+\]', texts[0]):
            texts.insert(0, default_speaker)
        curr_id = None
        for i in range(len(texts)):  # blank out repeated occurrences of the same id
            if re.match(r'\[id_\d+\]', texts[i]):
                if texts[i] != curr_id:
                    curr_id = texts[i]
                else:
                    texts[i] = ''
        texts = list(filter(lambda x: x != '', texts))

        print("Generating Audio...")
        for i in texts:
            if re.match(r'\[id_\d+\]', i):
                # Switch the active speaker.
                speaker_id = i.strip('[]')
                current_ref_s = styles[speaker_id]['style']
                speed = styles[speaker_id]['speed']
                continue
            text_norm = self.preprocess.text_preprocess(i, n_merge=n_merge)
            for sentence in text_norm:
                cus_phonem = []
                find_lang_tokens = re.findall(lang_pattern, sentence)
                if find_lang_tokens:
                    for lang, t in find_lang_tokens:
                        try:
                            phonem = espeak_phn(t, lang)
                            cus_phonem.append(phonem)
                        except Exception as e:
                            print(e)

                replacement_func = self.__init_replacement_func(cus_phonem)
                phonem = espeak_phn(sentence, styles[speaker_id]['lang'])
                phonem = re.sub(lang_pattern, replacement_func, phonem)

                wav, prev_d_mean = self.__inference(phonem, current_ref_s, speed=speed, prev_d_mean=prev_d_mean, t=smooth_value)
                wav = wav[4000:-4000]  # remove the pulse artifact and silent tokens at the edges
                list_wav.append(wav)

        final_wav = np.concatenate(list_wav)
        final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0)  # add edge padding
        return final_wav