stephenhoang committed on
Commit
fc5f72b
·
verified ·
1 Parent(s): 59629ab

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +95 -691
inference.py CHANGED
@@ -1,476 +1,3 @@
1
- # import re
2
- # import yaml
3
- # from munch import Munch
4
- # import numpy as np
5
- # import librosa
6
- # import noisereduce as nr
7
- # from meldataset import TextCleaner
8
- # import torch
9
- # import torchaudio
10
- # from nltk.tokenize import word_tokenize
11
- # import nltk
12
- # nltk.download('punkt_tab')
13
-
14
- # from models import ProsodyPredictor, TextEncoder, StyleEncoder
15
- # from Modules.hifigan import Decoder
16
-
17
- # import sys
18
- # import phonemizer
19
- # if sys.platform.startswith("win"):
20
- # try:
21
- # from phonemizer.backend.espeak.wrapper import EspeakWrapper
22
- # import espeakng_loader
23
- # EspeakWrapper.set_library(espeakng_loader.get_library_path())
24
- # except Exception as e:
25
- # print(e)
26
-
27
- # def espeak_phn(text, lang):
28
- # try:
29
- # my_phonemizer = phonemizer.backend.EspeakBackend(language=lang, preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
30
- # return my_phonemizer.phonemize([text])[0]
31
- # except Exception as e:
32
- # print(e)
33
-
34
- # class Preprocess:
35
- # def __text_normalize(self, text):
36
- # punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
37
- # map_to = "."
38
- # punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
39
- # #replace punctuation that acts like a comma or period
40
- # text = punctuation_pattern.sub(map_to, text)
41
- # #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
42
- # text = re.sub(r'\s+', ' ', text).strip()
43
- # return text
44
- # def __merge_fragments(self, texts, n):
45
- # merged = []
46
- # i = 0
47
- # while i < len(texts):
48
- # fragment = texts[i]
49
- # j = i + 1
50
- # while len(fragment.split()) < n and j < len(texts):
51
- # fragment += ", " + texts[j]
52
- # j += 1
53
- # merged.append(fragment)
54
- # i = j
55
- # if len(merged[-1].split()) < n and len(merged) > 1: #handle last sentence
56
- # merged[-2] = merged[-2] + ", " + merged[-1]
57
- # del merged[-1]
58
- # else:
59
- # merged[-1] = merged[-1]
60
- # return merged
61
- # def wave_preprocess(self, wave):
62
- # to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
63
- # mean, std = -4, 4
64
- # wave_tensor = torch.from_numpy(wave).float()
65
- # mel_tensor = to_mel(wave_tensor)
66
- # mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
67
- # return mel_tensor
68
- # def text_preprocess(self, text, n_merge=12):
69
- # text_norm = self.__text_normalize(text).split(".")#split by sentences.
70
- # text_norm = [s.strip() for s in text_norm]
71
- # text_norm = list(filter(lambda x: x != '', text_norm)) #filter empty index
72
- # text_norm = self.__merge_fragments(text_norm, n=n_merge) #merge if a sentence has fewer than n words
73
- # return text_norm
74
- # def length_to_mask(self, lengths):
75
- # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
76
- # mask = torch.gt(mask+1, lengths.unsqueeze(1))
77
- # return mask
78
-
79
-
80
- # import re
81
- # import sys
82
- # import yaml
83
- # import nltk
84
- # import numpy as np
85
- # import librosa
86
- # import torch
87
- # import phonemizer
88
- # import noisereduce as nr
89
- # from munch import Munch
90
- # from nltk.tokenize import word_tokenize
91
-
92
- # from meldataset import TextCleaner
93
- # from models import ProsodyPredictor, TextEncoder, StyleEncoder
94
- # from Modules.hifigan import Decoder
95
-
96
- # # Không download ở runtime trên Space (dễ treo / fail do network)
97
- # # nltk.download('punkt_tab')
98
- # # Nếu bạn cần, chuyển sang packages/requirements hoặc chạy local build step.
99
- # # Trên Space, khuyến nghị bỏ phụ thuộc NLTK hoặc thay bằng tokenizer đơn giản.
100
-
101
- # if sys.platform.startswith("win"):
102
- # try:
103
- # from phonemizer.backend.espeak.wrapper import EspeakWrapper
104
- # import espeakng_loader
105
- # EspeakWrapper.set_library(espeakng_loader.get_library_path())
106
- # except Exception as e:
107
- # print(e)
108
-
109
-
110
- # def espeak_phn(text, lang):
111
- # try:
112
- # my_phonemizer = phonemizer.backend.EspeakBackend(
113
- # language=lang,
114
- # preserve_punctuation=True,
115
- # with_stress=True,
116
- # language_switch="remove-flags",
117
- # )
118
- # return my_phonemizer.phonemize([text])[0]
119
- # except Exception as e:
120
- # print(e)
121
- # return text
122
-
123
-
124
- # class Preprocess:
125
- # def __text_normalize(self, text):
126
- # punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
127
- # map_to = "."
128
- # punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
129
- # text = punctuation_pattern.sub(map_to, text)
130
- # text = re.sub(r"\s+", " ", text).strip()
131
- # return text
132
-
133
- # def __merge_fragments(self, texts, n):
134
- # merged = []
135
- # i = 0
136
- # while i < len(texts):
137
- # fragment = texts[i]
138
- # j = i + 1
139
- # while len(fragment.split()) < n and j < len(texts):
140
- # fragment += ", " + texts[j]
141
- # j += 1
142
- # merged.append(fragment)
143
- # i = j
144
-
145
- # if len(merged) > 1 and len(merged[-1].split()) < n:
146
- # merged[-2] = merged[-2] + ", " + merged[-1]
147
- # del merged[-1]
148
- # return merged
149
-
150
- # def wave_preprocess(self, wave, sr=24000):
151
- # """
152
- # Thay torchaudio bằng librosa để tránh dependency torchaudio trên HF Space.
153
- # Output giống shape cũ: (1, 80, T)
154
- # """
155
- # if wave is None:
156
- # raise ValueError("wave is None")
157
- # wave = np.asarray(wave)
158
- # if wave.ndim != 1:
159
- # wave = wave.squeeze()
160
- # wave = wave.astype(np.float32)
161
-
162
- # # Mel spectrogram (power). Nếu muốn khớp torchaudio default power=2.0, để power=2.0.
163
- # mel = librosa.feature.melspectrogram(
164
- # y=wave,
165
- # sr=sr,
166
- # n_fft=2048,
167
- # win_length=1200,
168
- # hop_length=300,
169
- # n_mels=80,
170
- # power=2.0,
171
- # ) # (80, T)
172
-
173
- # mean, std = -4, 4
174
- # mel = np.log(1e-5 + mel)
175
- # mel = (mel - mean) / std
176
-
177
- # mel_tensor = torch.from_numpy(mel).float().unsqueeze(0) # (1, 80, T)
178
- # return mel_tensor
179
-
180
- # def text_preprocess(self, text, n_merge=12):
181
- # text_norm = self.__text_normalize(text).split(".")
182
- # text_norm = [s.strip() for s in text_norm]
183
- # text_norm = list(filter(lambda x: x != "", text_norm))
184
- # text_norm = self.__merge_fragments(text_norm, n=n_merge)
185
- # return text_norm
186
-
187
- # def length_to_mask(self, lengths):
188
- # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
189
- # mask = torch.gt(mask + 1, lengths.unsqueeze(1))
190
- # return mask
191
-
192
-
193
- # #For inference only
194
- # class StyleTTS2(torch.nn.Module):
195
- # def __init__(self, config_path, models_path):
196
- # super().__init__()
197
- # self.register_buffer("get_device", torch.empty(0))
198
- # self.preprocess = Preprocess()
199
- # self.ref_s = None
200
- # config = yaml.safe_load(open(config_path, "r", encoding="utf-8"))
201
-
202
- # try:
203
- # symbols = (
204
- # list(config['symbol']['pad']) +
205
- # list(config['symbol']['punctuation']) +
206
- # list(config['symbol']['letters']) +
207
- # list(config['symbol']['letters_ipa']) +
208
- # list(config['symbol']['extend'])
209
- # )
210
- # symbol_dict = {}
211
- # for i in range(len((symbols))):
212
- # symbol_dict[symbols[i]] = i
213
-
214
- # n_token = len(symbol_dict) + 1
215
- # print("\nFound:", n_token, "symbols")
216
- # except Exception as e:
217
- # print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
218
- # raise SystemExit(1)
219
-
220
- # args = self.__recursive_munch(config['model_params'])
221
- # args['n_token'] = n_token
222
-
223
- # self.cleaner = TextCleaner(symbol_dict, debug=False)
224
-
225
- # assert args.decoder.type in ['hifigan'], 'Decoder type unknown'
226
-
227
- # self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
228
- # resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
229
- # upsample_rates = args.decoder.upsample_rates,
230
- # upsample_initial_channel=args.decoder.upsample_initial_channel,
231
- # resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
232
- # upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
233
- # self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
234
- # self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
235
- # self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)# acoustic style encoder
236
-
237
- # self.__load_models(models_path)
238
-
239
- # def __recursive_munch(self, d):
240
- # if isinstance(d, dict):
241
- # return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
242
- # elif isinstance(d, list):
243
- # return [self.__recursive_munch(v) for v in d]
244
- # else:
245
- # return d
246
-
247
- # def __init_replacement_func(self, replacements):
248
- # replacement_iter = iter(replacements)
249
- # def replacement(match):
250
- # return next(replacement_iter)
251
- # return replacement
252
-
253
- # def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
254
- # mean = tensor.mean()
255
- # std = tensor.std()
256
- # z = (tensor - mean) / std
257
-
258
- # # Identify outliers
259
- # outlier_mask = torch.abs(z) > threshold
260
- # # Compute replacement value, respecting sign
261
- # sign = torch.sign(tensor - mean)
262
- # replacement = mean + sign * (threshold * std * factor)
263
-
264
- # result = tensor.clone()
265
- # result[outlier_mask] = replacement[outlier_mask]
266
-
267
- # return result
268
-
269
- # def __load_models(self, models_path):
270
- # module_params = []
271
- # model = {'decoder':self.decoder, 'predictor':self.predictor, 'text_encoder':self.text_encoder, 'style_encoder':self.style_encoder}
272
-
273
- # params_whole = torch.load(models_path, map_location='cpu')
274
- # params = params_whole['net']
275
- # params = {key: value for key, value in params.items() if key in model.keys()}
276
-
277
- # for key in model:
278
- # try:
279
- # model[key].load_state_dict(params[key])
280
- # except:
281
- # from collections import OrderedDict
282
- # state_dict = params[key]
283
- # new_state_dict = OrderedDict()
284
- # for k, v in state_dict.items():
285
- # name = k[7:] # remove `module.`
286
- # new_state_dict[name] = v
287
- # model[key].load_state_dict(new_state_dict, strict=False)
288
-
289
- # total_params = sum(p.numel() for p in model[key].parameters())
290
- # print(key,":",total_params)
291
- # module_params.append(total_params)
292
-
293
- # print('\nTotal',":",sum(module_params))
294
-
295
- # def __compute_style(self, path, denoise, split_dur):
296
- # device = self.get_device.device
297
- # denoise = min(denoise, 1)
298
- # if split_dur != 0: split_dur = max(int(split_dur), 1)
299
- # max_samples = 24000*20 #max 20 seconds ref audio
300
- # print("Computing the style for:", path)
301
-
302
- # wave, sr = librosa.load(path, sr=24000)
303
- # audio, index = librosa.effects.trim(wave, top_db=30)
304
- # if sr != 24000:
305
- # audio = librosa.resample(audio, sr, 24000)
306
- # if len(audio) > max_samples:
307
- # audio = audio[:max_samples]
308
-
309
- # if denoise > 0.0:
310
- # audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
311
- # audio = audio*(1-denoise) + audio_denoise*denoise
312
-
313
- # with torch.no_grad():
314
- # if split_dur>0 and len(audio)/sr>=4: #Only effective if audio length is >= 4s
315
- # #This option will split the ref audio to multiple parts, calculate styles and average them
316
- # count = 0
317
- # ref_s = None
318
- # jump = sr*split_dur
319
- # total_len = len(audio)
320
-
321
- # #Need to init before the loop
322
- # mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
323
- # ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
324
- # count += 1
325
- # for i in range(jump, total_len, jump):
326
- # if i+jump >= total_len:
327
- # left_dur = (total_len-i)/sr
328
- # if left_dur >= 1: #Still count if left over dur is >= 1s
329
- # mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
330
- # ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
331
- # count += 1
332
- # continue
333
- # mel_tensor = self.preprocess.wave_preprocess(audio[i:i+jump]).to(device)
334
- # ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
335
- # count += 1
336
- # ref_s /= count
337
- # else:
338
- # mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
339
- # ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
340
-
341
- # return ref_s
342
-
343
- # def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
344
- # device = self.get_device.device
345
- # speed = min(max(speed, 0.0001), 2) #speed range [0, 2]
346
-
347
- # phonem = ' '.join(word_tokenize(phonem))
348
- # tokens = self.cleaner(phonem)
349
- # tokens.insert(0, 0)
350
- # tokens.append(0)
351
- # tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
352
-
353
- # with torch.no_grad():
354
- # input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
355
- # text_mask = self.preprocess.length_to_mask(input_lengths).to(device)
356
-
357
- # # encode
358
- # t_en = self.text_encoder(tokens, input_lengths, text_mask)
359
- # s = ref_s.to(device)
360
-
361
- # # cal alignment
362
- # d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
363
- # x, _ = self.predictor.lstm(d)
364
- # duration = self.predictor.duration_proj(x)
365
- # duration = torch.sigmoid(duration).sum(axis=-1)
366
-
367
- # if prev_d_mean != 0:#Stabilize speaking speed between splits
368
- # dur_stats = torch.empty(duration.shape).normal_(mean=prev_d_mean, std=duration.std()).to(device)
369
- # else:
370
- # dur_stats = torch.empty(duration.shape).normal_(mean=duration.mean(), std=duration.std()).to(device)
371
- # duration = duration*(1-t) + dur_stats*t
372
- # duration[:,1:-2] = self.__replace_outliers_zscore(duration[:,1:-2]) #Normalize outlier
373
-
374
- # duration /= speed
375
-
376
- # pred_dur = torch.round(duration.squeeze()).clamp(min=1)
377
- # pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
378
- # c_frame = 0
379
- # for i in range(pred_aln_trg.size(0)):
380
- # pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
381
- # c_frame += int(pred_dur[i].data)
382
- # alignment = pred_aln_trg.unsqueeze(0).to(device)
383
-
384
- # # encode prosody
385
- # en = (d.transpose(-1, -2) @ alignment)
386
- # F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
387
- # asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
388
-
389
- # out = self.decoder(asr, F0_pred, N_pred, s)
390
-
391
- # return out.squeeze().cpu().numpy(), duration.mean()
392
-
393
- # def get_styles(self, speakers, denoise=0.3, avg_style=True):
394
- # if avg_style: split_dur = 2
395
- # else: split_dur = 0
396
- # styles = {}
397
- # for id in speakers:
398
- # ref_s = self.__compute_style(speakers[id]['path'], denoise=denoise, split_dur=split_dur)
399
- # styles[id] = {
400
- # 'style': ref_s,
401
- # 'path': speakers[id]['path'],
402
- # 'lang': speakers[id]['lang'],
403
- # 'speed': speakers[id]['speed'],
404
- # }
405
- # return styles
406
-
407
- # def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker= "[id_1]"):
408
- # if stabilize: smooth_value=0.2
409
- # else: smooth_value=0
410
-
411
- # list_wav = []
412
- # prev_d_mean = 0
413
- # lang_pattern = r'\[([^\]]+)\]\{([^}]+)\}'
414
-
415
- # text = re.sub(r'[\n\r\t\f\v]', '', text)
416
- # #fix lang tokens span to multiple sents
417
- # find_lang_tokens = re.findall(lang_pattern, text)
418
- # if find_lang_tokens:
419
- # cus_text = []
420
- # for lang, t in find_lang_tokens:
421
- # parts = self.preprocess.text_preprocess(t, n_merge=0)
422
- # parts = ".".join([f"[{lang}]" + f"{{{p}}}"for p in parts])
423
- # cus_text.append(parts)
424
- # replacement_func = self.__init_replacement_func(cus_text)
425
- # text = re.sub(lang_pattern, replacement_func, text)
426
-
427
- # texts = re.split(r'(\[id_\d+\])', text) #split the text by speaker ids while keeping the ids.
428
- # if len(texts) <= 1 or bool(re.match(r'(\[id_\d+\])', texts[0]) == False): #Add a default speaker
429
- # texts.insert(0, default_speaker)
430
- # curr_id = None
431
- # for i in range(len(texts)): #remove consecutive ids
432
- # if bool(re.match(r'(\[id_\d+\])', texts[i])):
433
- # if texts[i]!=curr_id:
434
- # curr_id = texts[i]
435
- # else:
436
- # texts[i] = ''
437
- # del curr_id
438
- # texts = list(filter(lambda x: x != '', texts))
439
-
440
- # print("Generating Audio...")
441
- # for i in texts:
442
- # if bool(re.match(r'(\[id_\d+\])', i)):
443
- # #Set up env for matched speaker
444
- # speaker_id = i.strip('[]')
445
- # current_ref_s = styles[speaker_id]['style']
446
- # speed = styles[speaker_id]['speed']
447
- # continue
448
- # text_norm = self.preprocess.text_preprocess(i, n_merge=n_merge)
449
- # for sentence in text_norm:
450
- # cus_phonem = []
451
- # find_lang_tokens = re.findall(lang_pattern, sentence)
452
- # if find_lang_tokens:
453
- # for lang, t in find_lang_tokens:
454
- # try:
455
- # phonem = espeak_phn(t, lang)
456
- # cus_phonem.append(phonem)
457
- # except Exception as e:
458
- # print(e)
459
-
460
- # replacement_func = self.__init_replacement_func(cus_phonem)
461
- # phonem = espeak_phn(sentence, styles[speaker_id]['lang'])
462
- # phonem = re.sub(lang_pattern, replacement_func, phonem)
463
-
464
- # wav, prev_d_mean = self.__inference(phonem, current_ref_s, speed=speed, prev_d_mean=prev_d_mean, t=smooth_value)
465
- # wav = wav[4000:-4000] #Remove weird pulse and silent tokens
466
- # list_wav.append(wav)
467
-
468
- # final_wav = np.concatenate(list_wav)
469
- # final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0) # add padding
470
- # return final_wav
471
-
472
-
473
-
474
  import re
475
  import sys
476
  import yaml
@@ -485,9 +12,8 @@ from meldataset import TextCleaner
485
  from models import ProsodyPredictor, TextEncoder, StyleEncoder
486
  from Modules.hifigan import Decoder
487
 
488
-
489
  # -------------------------
490
- # Windows-only espeak-ng loader (không ảnh hưởng Linux/Space)
491
  # -------------------------
492
  if sys.platform.startswith("win"):
493
  try:
@@ -497,34 +23,29 @@ if sys.platform.startswith("win"):
497
  except Exception as e:
498
  print(e)
499
 
 
 
 
 
500
 
501
- def espeak_phn(text, lang):
502
  """
503
- Trả về phoneme string từ espeak backend.
504
- Nếu backend fail, trả về text gốc (để không crash).
505
  """
506
  try:
507
- my_phonemizer = phonemizer.backend.EspeakBackend(
508
  language=lang,
509
  preserve_punctuation=True,
510
  with_stress=True,
511
  language_switch="remove-flags",
512
  )
513
- return my_phonemizer.phonemize([text])[0]
 
 
 
 
514
  except Exception as e:
515
- print("[espeak_phn error]", e)
516
- return text
517
-
518
-
519
- # -------------------------
520
- # Tokenization thay cho nltk.word_tokenize
521
- # Với phoneme/IPA, normalize whitespace là đủ.
522
- # -------------------------
523
- _TOKEN_RE = re.compile(r"\S+")
524
-
525
- def normalize_phonem_tokens(phonem: str) -> str:
526
- return " ".join(_TOKEN_RE.findall((phonem or "").strip()))
527
-
528
 
529
  class Preprocess:
530
  def __text_normalize(self, text):
@@ -553,17 +74,7 @@ class Preprocess:
553
  return merged
554
 
555
  def wave_preprocess(self, wave, sr=24000):
556
- """
557
- Không dùng torchaudio.
558
- Tạo log-mel bằng librosa, output shape (1, 80, T) giống code gốc.
559
- """
560
- if wave is None:
561
- raise ValueError("wave is None")
562
- wave = np.asarray(wave)
563
- if wave.ndim != 1:
564
- wave = wave.squeeze()
565
- wave = wave.astype(np.float32)
566
-
567
  mel = librosa.feature.melspectrogram(
568
  y=wave,
569
  sr=sr,
@@ -577,58 +88,41 @@ class Preprocess:
577
  mean, std = -4, 4
578
  mel = np.log(1e-5 + mel)
579
  mel = (mel - mean) / std
580
-
581
- mel_tensor = torch.from_numpy(mel).float().unsqueeze(0) # (1, 80, T)
582
- return mel_tensor
583
 
584
  def text_preprocess(self, text, n_merge=12):
585
  text_norm = self.__text_normalize(text).split(".")
586
- text_norm = [s.strip() for s in text_norm]
587
- text_norm = list(filter(lambda x: x != "", text_norm))
588
- text_norm = self.__merge_fragments(text_norm, n=n_merge)
589
- return text_norm
590
 
591
  def length_to_mask(self, lengths):
592
  mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
593
- mask = torch.gt(mask + 1, lengths.unsqueeze(1))
594
- return mask
595
-
596
 
597
- # For inference only
598
  class StyleTTS2(torch.nn.Module):
599
  def __init__(self, config_path, models_path):
600
  super().__init__()
601
  self.register_buffer("get_device", torch.empty(0))
602
  self.preprocess = Preprocess()
603
- self.ref_s = None
604
 
605
  config = yaml.safe_load(open(config_path, "r", encoding="utf-8"))
606
 
607
- try:
608
- symbols = (
609
- list(config["symbol"]["pad"])
610
- + list(config["symbol"]["punctuation"])
611
- + list(config["symbol"]["letters"])
612
- + list(config["symbol"]["letters_ipa"])
613
- + list(config["symbol"]["extend"])
614
- )
615
- symbol_dict = {symbols[i]: i for i in range(len(symbols))}
616
- n_token = len(symbol_dict) + 1
617
- print("\nFound:", n_token, "symbols")
618
- except Exception as e:
619
- print(
620
- f"\nERROR: Cannot find {e} in config file!\n"
621
- "Your config file is likely outdated, please download updated version from the repository."
622
- )
623
- raise SystemExit(1)
624
 
625
  args = self.__recursive_munch(config["model_params"])
626
  args["n_token"] = n_token
627
 
628
  self.cleaner = TextCleaner(symbol_dict, debug=False)
629
 
630
- assert args.decoder.type in ["hifigan"], "Decoder type unknown"
631
-
632
  self.decoder = Decoder(
633
  dim_in=args.hidden_dim,
634
  style_dim=args.style_dim,
@@ -663,21 +157,14 @@ class StyleTTS2(torch.nn.Module):
663
  def __recursive_munch(self, d):
664
  if isinstance(d, dict):
665
  return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
666
- elif isinstance(d, list):
667
  return [self.__recursive_munch(v) for v in d]
668
- else:
669
- return d
670
-
671
- def __init_replacement_func(self, replacements):
672
- replacement_iter = iter(replacements)
673
- def replacement(match):
674
- return next(replacement_iter)
675
- return replacement
676
 
677
  def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
678
  mean = tensor.mean()
679
  std = tensor.std()
680
- z = (tensor - mean) / std
681
  outlier_mask = torch.abs(z) > threshold
682
  sign = torch.sign(tensor - mean)
683
  replacement = mean + sign * (threshold * std * factor)
@@ -686,7 +173,6 @@ class StyleTTS2(torch.nn.Module):
686
  return result
687
 
688
  def __load_models(self, models_path):
689
- module_params = []
690
  model = {
691
  "decoder": self.decoder,
692
  "predictor": self.predictor,
@@ -696,45 +182,28 @@ class StyleTTS2(torch.nn.Module):
696
 
697
  params_whole = torch.load(models_path, map_location="cpu")
698
  params = params_whole["net"]
699
- params = {key: value for key, value in params.items() if key in model.keys()}
700
 
701
- for key in model:
702
  try:
703
- model[key].load_state_dict(params[key])
704
  except Exception:
705
  from collections import OrderedDict
706
- state_dict = params[key]
707
  new_state_dict = OrderedDict()
708
- for k, v in state_dict.items():
709
- name = k[7:] # remove `module.`
710
- new_state_dict[name] = v
711
- model[key].load_state_dict(new_state_dict, strict=False)
712
-
713
- total_params = sum(p.numel() for p in model[key].parameters())
714
- print(key, ":", total_params)
715
- module_params.append(total_params)
716
 
717
- print("\nTotal", ":", sum(module_params))
718
 
719
  def __compute_style(self, path, denoise, split_dur):
720
  device = self.get_device.device
721
  denoise = min(float(denoise), 1.0)
722
- if split_dur != 0:
723
- split_dur = max(int(split_dur), 1)
724
-
725
- max_samples = 24000 * 20
726
- print("Computing the style for:", path)
727
 
728
  wave, sr = librosa.load(path, sr=24000)
729
  audio, _ = librosa.effects.trim(wave, top_db=30)
730
 
731
- if sr != 24000:
732
- audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
733
- sr = 24000
734
-
735
- if len(audio) > max_samples:
736
- audio = audio[:max_samples]
737
-
738
  if denoise > 0.0:
739
  audio_denoise = nr.reduce_noise(
740
  y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300
@@ -743,49 +212,39 @@ class StyleTTS2(torch.nn.Module):
743
 
744
  with torch.no_grad():
745
  if split_dur > 0 and len(audio) / sr >= 4:
746
- count = 0
747
  jump = sr * split_dur
748
  total_len = len(audio)
 
 
749
 
750
- mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
751
- ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
752
- count += 1
753
-
754
- for i in range(jump, total_len, jump):
755
- if i + jump >= total_len:
756
- left_dur = (total_len - i) / sr
757
- if left_dur >= 1:
758
- mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
759
- ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
760
- count += 1
761
  continue
762
-
763
- mel_tensor = self.preprocess.wave_preprocess(audio[i : i + jump]).to(device)
764
- ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
765
  count += 1
766
 
767
- ref_s /= count
 
 
 
 
768
  else:
769
- mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
770
- ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
771
 
772
  return ref_s
773
 
774
- def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
775
  device = self.get_device.device
776
- speed = min(max(float(speed), 0.0001), 2.0)
777
 
778
  phonem = normalize_phonem_tokens(phonem)
779
-
780
  tokens = self.cleaner(phonem)
781
- tokens.insert(0, 0)
782
- tokens.append(0)
783
-
784
- # Guard: nếu cleaner trả rỗng thì fail sớm thay vì tạo audio 0s
785
- if len(tokens) <= 2:
786
- return np.zeros((0,), dtype=np.float32), 0.0
787
-
788
- tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
789
 
790
  with torch.no_grad():
791
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
@@ -797,41 +256,34 @@ class StyleTTS2(torch.nn.Module):
797
  d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
798
  x, _ = self.predictor.lstm(d)
799
  duration = self.predictor.duration_proj(x)
800
- duration = torch.sigmoid(duration).sum(axis=-1)
801
 
802
  if prev_d_mean != 0:
803
- dur_stats = torch.empty(duration.shape).normal_(
804
- mean=prev_d_mean, std=duration.std()
805
- ).to(device)
806
  else:
807
- dur_stats = torch.empty(duration.shape).normal_(
808
- mean=duration.mean(), std=duration.std()
809
- ).to(device)
810
 
811
  duration = duration * (1 - t) + dur_stats * t
812
- if duration.shape[1] > 3:
813
- duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])
814
- duration /= speed
815
 
816
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
817
 
818
  L = int(input_lengths.item())
819
  T = int(pred_dur.sum().item())
820
- if T <= 0:
821
- return np.zeros((0,), dtype=np.float32), float(duration.mean().item())
822
 
823
- pred_aln_trg = torch.zeros((L, T))
824
- c_frame = 0
825
  for i in range(L):
826
  di = int(pred_dur[i].item())
827
- pred_aln_trg[i, c_frame : c_frame + di] = 1
828
- c_frame += di
829
 
830
- alignment = pred_aln_trg.unsqueeze(0).to(device)
831
 
832
- en = (d.transpose(-1, -2) @ alignment)
833
  F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
834
- asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
835
 
836
  out = self.decoder(asr, F0_pred, N_pred, s)
837
 
@@ -840,17 +292,13 @@ class StyleTTS2(torch.nn.Module):
840
  def get_styles(self, speakers, denoise=0.3, avg_style=True):
841
  split_dur = 2 if avg_style else 0
842
  styles = {}
843
- for sid in speakers:
844
- ref_s = self.__compute_style(
845
- speakers[sid]["path"],
846
- denoise=denoise,
847
- split_dur=split_dur,
848
- )
849
  styles[sid] = {
850
  "style": ref_s,
851
- "path": speakers[sid]["path"],
852
- "lang": speakers[sid]["lang"],
853
- "speed": speakers[sid]["speed"],
854
  }
855
  return styles
856
 
@@ -860,89 +308,45 @@ class StyleTTS2(torch.nn.Module):
860
  list_wav = []
861
  prev_d_mean = 0.0
862
  lang_pattern = r"\[([^\]]+)\]\{([^}]+)\}"
863
-
864
  text = re.sub(r"[\n\r\t\f\v]", "", text)
865
 
866
- # fix lang tokens span to multiple sents
867
- find_lang_tokens = re.findall(lang_pattern, text)
868
- if find_lang_tokens:
869
- cus_text = []
870
- for lang, t in find_lang_tokens:
871
- parts = self.preprocess.text_preprocess(t, n_merge=0)
872
- parts = ".".join([f"[{lang}]{{{p}}}" for p in parts])
873
- cus_text.append(parts)
874
- replacement_func = self.__init_replacement_func(cus_text)
875
- text = re.sub(lang_pattern, replacement_func, text)
876
-
877
- texts = re.split(r"(\[id_\d+\])", text)
878
- if len(texts) <= 1 or (re.match(r"(\[id_\d+\])", texts[0]) is None):
879
- texts.insert(0, default_speaker)
880
-
881
- # remove consecutive ids
882
- curr_id = None
883
- for i in range(len(texts)):
884
- if re.match(r"(\[id_\d+\])", texts[i]):
885
- if texts[i] != curr_id:
886
- curr_id = texts[i]
887
- else:
888
- texts[i] = ""
889
- texts = list(filter(lambda x: x != "", texts))
890
-
891
- print("Generating Audio...")
892
 
893
  speaker_id = None
894
  current_ref_s = None
895
  speed = 1.0
896
 
897
- for seg in texts:
898
- if re.match(r"(\[id_\d+\])", seg):
899
- speaker_id = seg.strip("[]") # "id_1"
900
- if speaker_id not in styles:
901
- raise KeyError(f"speaker_id '{speaker_id}' not found in styles keys={list(styles.keys())[:5]}...")
902
  current_ref_s = styles[speaker_id]["style"]
903
  speed = styles[speaker_id]["speed"]
904
  continue
905
 
906
- if speaker_id is None or current_ref_s is None:
907
- # input text không có speaker tag hợp lệ
908
- speaker_id = default_speaker.strip("[]")
909
- current_ref_s = styles[speaker_id]["style"]
910
- speed = styles[speaker_id]["speed"]
911
-
912
- text_norm = self.preprocess.text_preprocess(seg, n_merge=n_merge)
913
- for sentence in text_norm:
914
- cus_phonem = []
915
- find_lang_tokens = re.findall(lang_pattern, sentence)
916
- if find_lang_tokens:
917
- for lang, t in find_lang_tokens:
918
- cus_phonem.append(espeak_phn(t, lang))
919
 
920
- replacement_func = self.__init_replacement_func(cus_phonem)
 
921
  phonem = espeak_phn(sentence, styles[speaker_id]["lang"])
922
- phonem = re.sub(lang_pattern, replacement_func, phonem)
923
-
924
  wav, prev_d_mean = self.__inference(
925
- phonem,
926
- current_ref_s,
927
- speed=speed,
928
- prev_d_mean=prev_d_mean,
929
- t=smooth_value,
930
  )
931
 
932
- if wav is None or wav.shape[0] == 0:
933
- continue
934
-
935
  trim = 4000
936
  if wav.shape[0] > 2 * trim:
937
  wav = wav[trim:-trim]
938
 
939
- # chỉ append 1 lần
940
- list_wav.append(wav)
941
 
942
  if len(list_wav) == 0:
943
- # trả một đoạn silence ngắn để tránh 0s file
944
- return np.zeros((2400,), dtype=np.float32)
945
 
946
- final_wav = np.concatenate(list_wav, axis=0)
947
- final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0)
948
  return final_wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import sys
3
  import yaml
 
12
  from models import ProsodyPredictor, TextEncoder, StyleEncoder
13
  from Modules.hifigan import Decoder
14
 
 
15
  # -------------------------
16
+ # Windows-only espeak-ng loader
17
  # -------------------------
18
  if sys.platform.startswith("win"):
19
  try:
 
23
  except Exception as e:
24
  print(e)
25
 
26
# Matches one maximal run of non-whitespace characters (a single token).
_TOKEN_RE = re.compile(r"\S+")


def normalize_phonem_tokens(phonem: str) -> str:
    """Return *phonem* with every whitespace run collapsed to one space.

    Leading and trailing whitespace is dropped. A falsy argument
    (``None`` or ``""``) yields an empty string.
    """
    if not phonem:
        return ""
    pieces = _TOKEN_RE.findall(phonem)
    return " ".join(pieces)
30
 
31
def espeak_phn(text: str, lang: str) -> str:
    """Phonemize *text* with the espeak backend for language *lang*.

    Failures are raised instead of being swallowed, so a missing
    espeak-ng / libespeak-ng1 install or a missing voice (e.g. 'vi') is
    noticed immediately.

    Returns:
        The phoneme string with surrounding whitespace stripped.

    Raises:
        RuntimeError: if the backend cannot be created, phonemization
            fails, or the output is empty. The underlying exception is
            attached as ``__cause__``.
    """
    try:
        backend = phonemizer.backend.EspeakBackend(
            language=lang,
            preserve_punctuation=True,
            with_stress=True,
            language_switch="remove-flags",
        )
        out = backend.phonemize([text])[0]
        out = (out or "").strip()
        if not out:
            raise RuntimeError(f"phonemizer returned empty output for lang='{lang}', text='{text[:50]}'")
        return out
    except Exception as e:
        # Chain explicitly so the original traceback stays attached as the cause.
        raise RuntimeError(f"espeak/phonemizer failed (lang={lang}). Error: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  class Preprocess:
51
  def __text_normalize(self, text):
 
74
  return merged
75
 
76
  def wave_preprocess(self, wave, sr=24000):
77
+ wave = np.asarray(wave, dtype=np.float32).squeeze()
 
 
 
 
 
 
 
 
 
 
78
  mel = librosa.feature.melspectrogram(
79
  y=wave,
80
  sr=sr,
 
88
  mean, std = -4, 4
89
  mel = np.log(1e-5 + mel)
90
  mel = (mel - mean) / std
91
+ return torch.from_numpy(mel).float().unsqueeze(0) # (1, 80, T)
 
 
92
 
93
  def text_preprocess(self, text, n_merge=12):
94
  text_norm = self.__text_normalize(text).split(".")
95
+ text_norm = [s.strip() for s in text_norm if s.strip()]
96
+ return self.__merge_fragments(text_norm, n=n_merge)
 
 
97
 
98
  def length_to_mask(self, lengths):
99
  mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
100
+ return torch.gt(mask + 1, lengths.unsqueeze(1))
 
 
101
 
 
102
  class StyleTTS2(torch.nn.Module):
103
  def __init__(self, config_path, models_path):
104
  super().__init__()
105
  self.register_buffer("get_device", torch.empty(0))
106
  self.preprocess = Preprocess()
 
107
 
108
  config = yaml.safe_load(open(config_path, "r", encoding="utf-8"))
109
 
110
+ symbols = (
111
+ list(config["symbol"]["pad"])
112
+ + list(config["symbol"]["punctuation"])
113
+ + list(config["symbol"]["letters"])
114
+ + list(config["symbol"]["letters_ipa"])
115
+ + list(config["symbol"]["extend"])
116
+ )
117
+ symbol_dict = {s: i for i, s in enumerate(symbols)}
118
+ n_token = len(symbol_dict) + 1
119
+ print("\nFound:", n_token, "symbols")
 
 
 
 
 
 
 
120
 
121
  args = self.__recursive_munch(config["model_params"])
122
  args["n_token"] = n_token
123
 
124
  self.cleaner = TextCleaner(symbol_dict, debug=False)
125
 
 
 
126
  self.decoder = Decoder(
127
  dim_in=args.hidden_dim,
128
  style_dim=args.style_dim,
 
157
  def __recursive_munch(self, d):
158
  if isinstance(d, dict):
159
  return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
160
+ if isinstance(d, list):
161
  return [self.__recursive_munch(v) for v in d]
162
+ return d
 
 
 
 
 
 
 
163
 
164
  def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
165
  mean = tensor.mean()
166
  std = tensor.std()
167
+ z = (tensor - mean) / (std + 1e-8)
168
  outlier_mask = torch.abs(z) > threshold
169
  sign = torch.sign(tensor - mean)
170
  replacement = mean + sign * (threshold * std * factor)
 
173
  return result
174
 
175
  def __load_models(self, models_path):
 
176
  model = {
177
  "decoder": self.decoder,
178
  "predictor": self.predictor,
 
182
 
183
  params_whole = torch.load(models_path, map_location="cpu")
184
  params = params_whole["net"]
185
+ params = {k: v for k, v in params.items() if k in model}
186
 
187
+ for k in model:
188
  try:
189
+ model[k].load_state_dict(params[k])
190
  except Exception:
191
  from collections import OrderedDict
 
192
  new_state_dict = OrderedDict()
193
+ for kk, vv in params[k].items():
194
+ new_state_dict[kk[7:]] = vv # strip "module."
195
+ model[k].load_state_dict(new_state_dict, strict=False)
 
 
 
 
 
196
 
197
+ print(k, ":", sum(p.numel() for p in model[k].parameters()))
198
 
199
  def __compute_style(self, path, denoise, split_dur):
200
  device = self.get_device.device
201
  denoise = min(float(denoise), 1.0)
202
+ split_dur = int(split_dur) if split_dur else 0
 
 
 
 
203
 
204
  wave, sr = librosa.load(path, sr=24000)
205
  audio, _ = librosa.effects.trim(wave, top_db=30)
206
 
 
 
 
 
 
 
 
207
  if denoise > 0.0:
208
  audio_denoise = nr.reduce_noise(
209
  y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300
 
212
 
213
  with torch.no_grad():
214
  if split_dur > 0 and len(audio) / sr >= 4:
 
215
  jump = sr * split_dur
216
  total_len = len(audio)
217
+ ref_s = None
218
+ count = 0
219
 
220
+ for i in range(0, total_len, jump):
221
+ seg = audio[i : min(i + jump, total_len)]
222
+ if len(seg) < sr: # <1s thì bỏ
 
 
 
 
 
 
 
 
223
  continue
224
+ mel = self.preprocess.wave_preprocess(seg).to(device)
225
+ s = self.style_encoder(mel.unsqueeze(1))
226
+ ref_s = s if ref_s is None else (ref_s + s)
227
  count += 1
228
 
229
+ if ref_s is None:
230
+ mel = self.preprocess.wave_preprocess(audio).to(device)
231
+ ref_s = self.style_encoder(mel.unsqueeze(1))
232
+ else:
233
+ ref_s = ref_s / count
234
  else:
235
+ mel = self.preprocess.wave_preprocess(audio).to(device)
236
+ ref_s = self.style_encoder(mel.unsqueeze(1))
237
 
238
  return ref_s
239
 
240
+ def __inference(self, phonem, ref_s, speed=1.0, prev_d_mean=0.0, t=0.1):
241
  device = self.get_device.device
242
+ speed = float(np.clip(speed, 1e-4, 2.0))
243
 
244
  phonem = normalize_phonem_tokens(phonem)
 
245
  tokens = self.cleaner(phonem)
246
+ tokens = [0] + tokens + [0]
247
+ tokens = torch.LongTensor(tokens).unsqueeze(0).to(device)
 
 
 
 
 
 
248
 
249
  with torch.no_grad():
250
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
 
256
  d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
257
  x, _ = self.predictor.lstm(d)
258
  duration = self.predictor.duration_proj(x)
259
+ duration = torch.sigmoid(duration).sum(dim=-1)
260
 
261
  if prev_d_mean != 0:
262
+ dur_stats = torch.empty_like(duration).normal_(mean=prev_d_mean, std=duration.std() + 1e-8).to(device)
 
 
263
  else:
264
+ dur_stats = torch.empty_like(duration).normal_(mean=duration.mean(), std=duration.std() + 1e-8).to(device)
 
 
265
 
266
  duration = duration * (1 - t) + dur_stats * t
267
+ duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])
268
+ duration = duration / speed
 
269
 
270
+ pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
271
 
272
  L = int(input_lengths.item())
273
  T = int(pred_dur.sum().item())
274
+ pred_aln_trg = torch.zeros((L, T), device=device)
 
275
 
276
+ c = 0
 
277
  for i in range(L):
278
  di = int(pred_dur[i].item())
279
+ pred_aln_trg[i, c : c + di] = 1
280
+ c += di
281
 
282
+ alignment = pred_aln_trg.unsqueeze(0)
283
 
284
+ en = d.transpose(-1, -2) @ alignment
285
  F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
286
+ asr = t_en @ pred_aln_trg.unsqueeze(0)
287
 
288
  out = self.decoder(asr, F0_pred, N_pred, s)
289
 
 
292
  def get_styles(self, speakers, denoise=0.3, avg_style=True):
293
  split_dur = 2 if avg_style else 0
294
  styles = {}
295
+ for sid, meta in speakers.items():
296
+ ref_s = self.__compute_style(meta["path"], denoise=denoise, split_dur=split_dur)
 
 
 
 
297
  styles[sid] = {
298
  "style": ref_s,
299
+ "path": meta["path"],
300
+ "lang": meta["lang"],
301
+ "speed": meta["speed"],
302
  }
303
  return styles
304
 
 
308
  list_wav = []
309
  prev_d_mean = 0.0
310
  lang_pattern = r"\[([^\]]+)\]\{([^}]+)\}"
 
311
  text = re.sub(r"[\n\r\t\f\v]", "", text)
312
 
313
+ # split by speaker tags
314
+ parts = re.split(r"(\[id_\d+\])", text)
315
+ if len(parts) <= 1 or re.match(r"(\[id_\d+\])", parts[0]) is None:
316
+ parts.insert(0, default_speaker)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  speaker_id = None
319
  current_ref_s = None
320
  speed = 1.0
321
 
322
+ for p in parts:
323
+ if re.match(r"(\[id_\d+\])", p):
324
+ speaker_id = p.strip("[]")
 
 
325
  current_ref_s = styles[speaker_id]["style"]
326
  speed = styles[speaker_id]["speed"]
327
  continue
328
 
329
+ if not p.strip():
330
+ continue
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ for sentence in self.preprocess.text_preprocess(p, n_merge=n_merge):
333
+ # phonemize
334
  phonem = espeak_phn(sentence, styles[speaker_id]["lang"])
 
 
335
  wav, prev_d_mean = self.__inference(
336
+ phonem, current_ref_s, speed=speed, prev_d_mean=prev_d_mean, t=smooth_value
 
 
 
 
337
  )
338
 
339
+ # trim an toàn
 
 
340
  trim = 4000
341
  if wav.shape[0] > 2 * trim:
342
  wav = wav[trim:-trim]
343
 
344
+ if wav.size > 0:
345
+ list_wav.append(wav)
346
 
347
  if len(list_wav) == 0:
348
+ return np.zeros((2400,), dtype=np.float32) # 0.1s silence để không crash
 
349
 
350
+ final_wav = np.concatenate(list_wav)
351
+ final_wav = np.concatenate([np.zeros((4000,), dtype=np.float32), final_wav, np.zeros((4000,), dtype=np.float32)])
352
  return final_wav