stephenhoang committed on
Commit
d3f42ba
·
1 Parent(s): e835266

Remove torchaudio; compute mel with librosa

Browse files
Files changed (1) hide show
  1. inference.py +138 -24
inference.py CHANGED
@@ -1,21 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
 
2
  import yaml
3
- from munch import Munch
4
  import numpy as np
5
  import librosa
6
- import noisereduce as nr
7
- from meldataset import TextCleaner
8
  import torch
9
- import torchaudio
 
 
10
  from nltk.tokenize import word_tokenize
11
- import nltk
12
- nltk.download('punkt_tab')
13
 
 
14
  from models import ProsodyPredictor, TextEncoder, StyleEncoder
15
  from Modules.hifigan import Decoder
16
 
17
- import sys
18
- import phonemizer
 
 
 
19
  if sys.platform.startswith("win"):
20
  try:
21
  from phonemizer.backend.espeak.wrapper import EspeakWrapper
@@ -24,23 +106,30 @@ if sys.platform.startswith("win"):
24
  except Exception as e:
25
  print(e)
26
 
 
27
  def espeak_phn(text, lang):
28
  try:
29
- my_phonemizer = phonemizer.backend.EspeakBackend(language=lang, preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
 
 
 
 
 
30
  return my_phonemizer.phonemize([text])[0]
31
  except Exception as e:
32
  print(e)
 
 
33
 
34
  class Preprocess:
35
  def __text_normalize(self, text):
36
  punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
37
  map_to = "."
38
  punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
39
- #replace punctuation that acts like a comma or period
40
  text = punctuation_pattern.sub(map_to, text)
41
- #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
42
- text = re.sub(r'\s+', ' ', text).strip()
43
  return text
 
44
  def __merge_fragments(self, texts, n):
45
  merged = []
46
  i = 0
@@ -52,30 +141,55 @@ class Preprocess:
52
  j += 1
53
  merged.append(fragment)
54
  i = j
55
- if len(merged[-1].split()) < n and len(merged) > 1: #handle last sentence
 
56
  merged[-2] = merged[-2] + ", " + merged[-1]
57
  del merged[-1]
58
- else:
59
- merged[-1] = merged[-1]
60
  return merged
61
- def wave_preprocess(self, wave):
62
- to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  mean, std = -4, 4
64
- wave_tensor = torch.from_numpy(wave).float()
65
- mel_tensor = to_mel(wave_tensor)
66
- mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
 
67
  return mel_tensor
 
68
  def text_preprocess(self, text, n_merge=12):
69
- text_norm = self.__text_normalize(text).split(".")#split by sentences.
70
  text_norm = [s.strip() for s in text_norm]
71
- text_norm = list(filter(lambda x: x != '', text_norm)) #filter empty index
72
- text_norm = self.__merge_fragments(text_norm, n=n_merge) #merge if a sentence has less that n
73
  return text_norm
 
74
  def length_to_mask(self, lengths):
75
  mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
76
- mask = torch.gt(mask+1, lengths.unsqueeze(1))
77
  return mask
78
 
 
79
  #For inference only
80
  class StyleTTS2(torch.nn.Module):
81
  def __init__(self, config_path, models_path):
 
1
+ # import re
2
+ # import yaml
3
+ # from munch import Munch
4
+ # import numpy as np
5
+ # import librosa
6
+ # import noisereduce as nr
7
+ # from meldataset import TextCleaner
8
+ # import torch
9
+ # import torchaudio
10
+ # from nltk.tokenize import word_tokenize
11
+ # import nltk
12
+ # nltk.download('punkt_tab')
13
+
14
+ # from models import ProsodyPredictor, TextEncoder, StyleEncoder
15
+ # from Modules.hifigan import Decoder
16
+
17
+ # import sys
18
+ # import phonemizer
19
+ # if sys.platform.startswith("win"):
20
+ # try:
21
+ # from phonemizer.backend.espeak.wrapper import EspeakWrapper
22
+ # import espeakng_loader
23
+ # EspeakWrapper.set_library(espeakng_loader.get_library_path())
24
+ # except Exception as e:
25
+ # print(e)
26
+
27
+ # def espeak_phn(text, lang):
28
+ # try:
29
+ # my_phonemizer = phonemizer.backend.EspeakBackend(language=lang, preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
30
+ # return my_phonemizer.phonemize([text])[0]
31
+ # except Exception as e:
32
+ # print(e)
33
+
34
+ # class Preprocess:
35
+ # def __text_normalize(self, text):
36
+ # punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
37
+ # map_to = "."
38
+ # punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
39
+ # #replace punctuation that acts like a comma or period
40
+ # text = punctuation_pattern.sub(map_to, text)
41
+ # #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
42
+ # text = re.sub(r'\s+', ' ', text).strip()
43
+ # return text
44
+ # def __merge_fragments(self, texts, n):
45
+ # merged = []
46
+ # i = 0
47
+ # while i < len(texts):
48
+ # fragment = texts[i]
49
+ # j = i + 1
50
+ # while len(fragment.split()) < n and j < len(texts):
51
+ # fragment += ", " + texts[j]
52
+ # j += 1
53
+ # merged.append(fragment)
54
+ # i = j
55
+ # if len(merged[-1].split()) < n and len(merged) > 1: #handle last sentence
56
+ # merged[-2] = merged[-2] + ", " + merged[-1]
57
+ # del merged[-1]
58
+ # else:
59
+ # merged[-1] = merged[-1]
60
+ # return merged
61
+ # def wave_preprocess(self, wave):
62
+ # to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
63
+ # mean, std = -4, 4
64
+ # wave_tensor = torch.from_numpy(wave).float()
65
+ # mel_tensor = to_mel(wave_tensor)
66
+ # mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
67
+ # return mel_tensor
68
+ # def text_preprocess(self, text, n_merge=12):
69
+ # text_norm = self.__text_normalize(text).split(".")#split by sentences.
70
+ # text_norm = [s.strip() for s in text_norm]
71
+ # text_norm = list(filter(lambda x: x != '', text_norm)) #filter empty index
72
+ # text_norm = self.__merge_fragments(text_norm, n=n_merge) #merge if a sentence has less that n
73
+ # return text_norm
74
+ # def length_to_mask(self, lengths):
75
+ # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
76
+ # mask = torch.gt(mask+1, lengths.unsqueeze(1))
77
+ # return mask
78
+
79
+
80
  import re
81
+ import sys
82
  import yaml
83
+ import nltk
84
  import numpy as np
85
  import librosa
 
 
86
  import torch
87
+ import phonemizer
88
+ import noisereduce as nr
89
+ from munch import Munch
90
  from nltk.tokenize import word_tokenize
 
 
91
 
92
+ from meldataset import TextCleaner
93
  from models import ProsodyPredictor, TextEncoder, StyleEncoder
94
  from Modules.hifigan import Decoder
95
 
96
+ # Không download ở runtime trên Space (dễ treo / fail do network)
97
+ # nltk.download('punkt_tab')
98
+ # Nếu bạn cần, chuyển sang packages/requirements hoặc chạy local build step.
99
+ # Trên Space, khuyến nghị bỏ phụ thuộc NLTK hoặc thay bằng tokenizer đơn giản.
100
+
101
  if sys.platform.startswith("win"):
102
  try:
103
  from phonemizer.backend.espeak.wrapper import EspeakWrapper
 
106
  except Exception as e:
107
  print(e)
108
 
109
+
110
# Cache of eSpeak backends, keyed by language code. Constructing an
# EspeakBackend is expensive; the original rebuilt it on every call.
_ESPEAK_BACKENDS = {}


def espeak_phn(text, lang):
    """Phonemize *text* for language *lang* using the eSpeak backend.

    Returns the phonemized string; on any failure the error is printed
    and the raw input text is returned unchanged (best-effort fallback).
    """
    try:
        backend = _ESPEAK_BACKENDS.get(lang)
        if backend is None:
            backend = phonemizer.backend.EspeakBackend(
                language=lang,
                preserve_punctuation=True,
                with_stress=True,
                language_switch="remove-flags",
            )
            _ESPEAK_BACKENDS[lang] = backend
        return backend.phonemize([text])[0]
    except Exception as e:
        print(e)
        return text
122
+
123
 
124
  class Preprocess:
125
  def __text_normalize(self, text):
126
  punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
127
  map_to = "."
128
  punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
 
129
  text = punctuation_pattern.sub(map_to, text)
130
+ text = re.sub(r"\s+", " ", text).strip()
 
131
  return text
132
+
133
  def __merge_fragments(self, texts, n):
134
  merged = []
135
  i = 0
 
141
  j += 1
142
  merged.append(fragment)
143
  i = j
144
+
145
+ if len(merged) > 1 and len(merged[-1].split()) < n:
146
  merged[-2] = merged[-2] + ", " + merged[-1]
147
  del merged[-1]
 
 
148
  return merged
149
+
150
+ def wave_preprocess(self, wave, sr=24000):
151
+ """
152
+ Thay torchaudio bằng librosa để tránh dependency torchaudio trên HF Space.
153
+ Output giống shape cũ: (1, 80, T)
154
+ """
155
+ if wave is None:
156
+ raise ValueError("wave is None")
157
+ wave = np.asarray(wave)
158
+ if wave.ndim != 1:
159
+ wave = wave.squeeze()
160
+ wave = wave.astype(np.float32)
161
+
162
+ # Mel spectrogram (power). Nếu muốn khớp torchaudio default power=2.0, để power=2.0.
163
+ mel = librosa.feature.melspectrogram(
164
+ y=wave,
165
+ sr=sr,
166
+ n_fft=2048,
167
+ win_length=1200,
168
+ hop_length=300,
169
+ n_mels=80,
170
+ power=2.0,
171
+ ) # (80, T)
172
+
173
  mean, std = -4, 4
174
+ mel = np.log(1e-5 + mel)
175
+ mel = (mel - mean) / std
176
+
177
+ mel_tensor = torch.from_numpy(mel).float().unsqueeze(0) # (1, 80, T)
178
  return mel_tensor
179
+
180
  def text_preprocess(self, text, n_merge=12):
181
+ text_norm = self.__text_normalize(text).split(".")
182
  text_norm = [s.strip() for s in text_norm]
183
+ text_norm = list(filter(lambda x: x != "", text_norm))
184
+ text_norm = self.__merge_fragments(text_norm, n=n_merge)
185
  return text_norm
186
+
187
  def length_to_mask(self, lengths):
188
  mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
189
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
190
  return mask
191
 
192
+
193
  #For inference only
194
  class StyleTTS2(torch.nn.Module):
195
  def __init__(self, config_path, models_path):