stephenhoang committed on
Commit
e835266
·
1 Parent(s): 65ec5ec

Fix encoding header in meldataset.py

Browse files
Files changed (1) hide show
  1. meldataset.py +148 -66
meldataset.py CHANGED
@@ -7,7 +7,10 @@ import soundfile as sf
7
  import librosa
8
 
9
  import torch
10
- import torchaudio
 
 
 
11
  import torch.utils.data
12
  import torch.distributed as dist
13
  from multiprocessing import Pool
@@ -18,115 +21,194 @@ logger.setLevel(logging.DEBUG)
18
 
19
  import pandas as pd
20
 
21
- class TextCleaner:
22
- def __init__(self, symbol_dict, debug=True):
23
- self.word_index_dictionary = symbol_dict
24
- self.debug = debug
25
- def __call__(self, text):
26
- indexes = []
27
- for char in text:
28
- try:
29
- indexes.append(self.word_index_dictionary[char])
30
- except KeyError as e:
31
- if self.debug:
32
- print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", char)
33
- print("To ignore set 'debug' to false in the config")
34
- continue
35
- return indexes
36
-
37
- np.random.seed(1)
38
- random.seed(1)
39
  SPECT_PARAMS = {
40
  "n_fft": 2048,
41
  "win_length": 1200,
42
- "hop_length": 300
43
  }
 
 
44
  MEL_PARAMS = {
45
  "n_mels": 80,
 
 
 
46
  }
47
 
48
- to_mel = torchaudio.transforms.MelSpectrogram(
49
- n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
50
  mean, std = -4, 4
51
 
52
- def preprocess(wave):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  wave_tensor = torch.from_numpy(wave).float()
54
- mel_tensor = to_mel(wave_tensor)
55
- mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
56
- return mel_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  class FilePathDataset(torch.utils.data.Dataset):
59
- def __init__(self,
60
- data_list,
61
- root_path,
62
- symbol_dict,
63
- sr=24000,
64
- data_augmentation=False,
65
- validation=False,
66
- debug=True
67
- ):
68
-
69
- _data_list = [l.strip().split('|') for l in data_list]
70
- self.data_list = _data_list #[data if len(data) == 3 else (*data, 0) for data in _data_list] #append speakerid=0 for all
 
 
71
  self.text_cleaner = TextCleaner(symbol_dict, debug)
72
  self.sr = sr
73
 
74
  self.df = pd.DataFrame(self.data_list)
75
 
76
- self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
 
77
 
78
  self.mean, self.std = -4, 4
79
  self.data_augmentation = data_augmentation and (not validation)
80
  self.max_mel_length = 192
81
-
82
  self.root_path = root_path
83
 
84
  def __len__(self):
85
  return len(self.data_list)
86
 
87
- def __getitem__(self, idx):
88
  data = self.data_list[idx]
89
  path = data[0]
90
-
91
  wave, text_tensor = self._load_tensor(data)
92
-
93
- mel_tensor = preprocess(wave).squeeze()
94
-
95
- acoustic_feature = mel_tensor.squeeze()
96
  length_feature = acoustic_feature.size(1)
97
- acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
98
-
99
  return acoustic_feature, text_tensor, path, wave
100
 
101
  def _load_tensor(self, data):
102
- wave_path, text = data
 
 
 
103
  wave, sr = sf.read(osp.join(self.root_path, wave_path))
104
- if wave.shape[-1] == 2:
105
  wave = wave[:, 0].squeeze()
106
- if sr != 24000:
107
- wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
108
- print(wave_path, sr)
109
-
110
- # Adding half a second padding.
111
- wave = np.concatenate([np.zeros([12000]), wave, np.zeros([12000])], axis=0)
112
-
113
- text = self.text_cleaner(text)
114
-
115
- text.insert(0, 0)
116
- text.append(0)
117
-
118
- text = torch.LongTensor(text)
119
 
120
- return wave, text
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def _load_data(self, data):
123
  wave, text_tensor = self._load_tensor(data)
124
- mel_tensor = preprocess(wave).squeeze()
125
 
126
  mel_length = mel_tensor.size(1)
127
  if mel_length > self.max_mel_length:
128
  random_start = np.random.randint(0, mel_length - self.max_mel_length)
129
- mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
130
 
131
  return mel_tensor
132
 
 
7
  import librosa
8
 
9
  import torch
10
+ try:
11
+ import torchaudio
12
+ except ImportError:
13
+ torchaudio = None
14
  import torch.utils.data
15
  import torch.distributed as dist
16
  from multiprocessing import Pool
 
21
 
22
  import pandas as pd
23
 
24
+ # class TextCleaner:
25
+ # def __init__(self, symbol_dict, debug=True):
26
+ # self.word_index_dictionary = symbol_dict
27
+ # self.debug = debug
28
+ # def __call__(self, text):
29
+ # indexes = []
30
+ # for char in text:
31
+ # try:
32
+ # indexes.append(self.word_index_dictionary[char])
33
+ # except KeyError as e:
34
+ # if self.debug:
35
+ # print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", char)
36
+ # print("To ignore set 'debug' to false in the config")
37
+ # continue
38
+ # return indexes
39
+
40
+
41
+
42
# STFT parameters shared by spectrogram extraction.
SPECT_PARAMS = dict(n_fft=2048, win_length=1200, hop_length=300)

# Full parameter set for MelSpectrogram (avoids missing n_fft/win/hop).
MEL_PARAMS = dict(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)

# Log-mel normalization statistics applied after torch.log.
mean, std = -4, 4
57
 
58
+
59
+
60
# Cache one MelSpectrogram transform per sample rate.
_MEL_CACHE = {}


def _require_torchaudio(context: str) -> None:
    """Raise a clear RuntimeError when torchaudio is not installed.

    The module-level import sets ``torchaudio = None`` on ImportError so that
    inference-only environments (e.g. HF Spaces) can still import this module.
    """
    if torchaudio is None:
        raise RuntimeError(
            f"torchaudio is required for {context} but is not installed in this environment. "
            "For HF Spaces inference, you should not instantiate FilePathDataset / mel extraction."
        )


def get_mel_transform(sample_rate: int = 24000):
    """Return a cached ``torchaudio.transforms.MelSpectrogram`` for *sample_rate*.

    Transforms are built lazily and memoized in ``_MEL_CACHE`` so repeated
    calls (e.g. from a DataLoader worker) reuse the same instance.
    """
    _require_torchaudio("mel extraction")
    if sample_rate not in _MEL_CACHE:
        _MEL_CACHE[sample_rate] = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=MEL_PARAMS["n_mels"],
            n_fft=MEL_PARAMS["n_fft"],
            win_length=MEL_PARAMS["win_length"],
            hop_length=MEL_PARAMS["hop_length"],
        )
    return _MEL_CACHE[sample_rate]


def preprocess(wave, sample_rate: int = 24000):
    """Convert a waveform to a normalized log-mel spectrogram.

    Args:
        wave: 1-D array-like of float samples. Non-array input is coerced
            with ``np.asarray`` and multi-dim input is squeezed to 1-D.
        sample_rate: sample rate used to select the mel transform.

    Returns:
        Mel tensor of shape ``(1, n_mels, T)``, log-compressed and
        normalized by the module-level ``mean``/``std``.
    """
    _require_torchaudio("preprocess()")
    # Coerce before touching .ndim: the previous code dereferenced wave.ndim
    # first, which raised AttributeError for plain Python lists.
    wave = np.asarray(wave)
    if wave.ndim != 1:
        wave = wave.squeeze()
    wave_tensor = torch.from_numpy(wave).float()

    to_mel = get_mel_transform(sample_rate)
    mel = to_mel(wave_tensor)  # (n_mels, T)
    # 1e-5 floor avoids log(0) on silent frames.
    mel = (torch.log(mel + 1e-5) - mean) / std
    return mel.unsqueeze(0)  # (1, n_mels, T)
99
+
100
+
101
class TextCleaner:
    """Map text to a list of integer symbol ids via ``symbol_dict``.

    Tokenization rule:
      - if the (stripped) text contains a space, it is treated as a
        space-separated token list (typical for IPA input);
      - otherwise it is split into individual characters.
    Unknown tokens are dropped; when ``debug`` is true, a summary of the
    missing tokens is printed.
    """

    def __init__(self, symbol_dict, debug=True):
        self.symbol_dict = symbol_dict
        self.debug = debug

    def __call__(self, text: str):
        cleaned = (text or "").strip()

        # Space-separated input is an explicit token list; otherwise
        # fall back to per-character tokenization.
        if " " in cleaned:
            tokens = [tok for tok in cleaned.split(" ") if tok]
        else:
            tokens = list(cleaned)

        ids, missing = [], []
        for tok in tokens:
            if tok in self.symbol_dict:
                ids.append(self.symbol_dict[tok])
            else:
                missing.append(tok)

        if self.debug and missing:
            # Cap the printed sample at 30 tokens to avoid log spam.
            print(f"[TextCleaner] missing {len(missing)} symbols. sample={missing[:30]}")

        return ids
133
+
134
 
135
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(mel, text_ids, path, wave)`` from a ``path|text`` list.

    Each entry of *data_list* is a string ``"wav_path|text"`` (optionally with
    a trailing speaker-id field, which is ignored by loading).
    """

    def __init__(
        self,
        data_list,
        root_path,
        symbol_dict,
        sr=24000,
        data_augmentation=False,
        validation=False,
        debug=True,
    ):
        # Training dataloader needs torchaudio; fail fast with a clear message.
        _require_torchaudio("FilePathDataset (training dataloader)")

        _data_list = [l.strip().split("|") for l in data_list]
        self.data_list = _data_list  # [wav_path, text] (or with an extra speaker_id)
        self.text_cleaner = TextCleaner(symbol_dict, debug)
        self.sr = sr

        self.df = pd.DataFrame(self.data_list)

        # training-only: mel transform
        self.to_melspec = get_mel_transform(self.sr)

        self.mean, self.std = -4, 4
        # Augmentation is disabled on validation splits.
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192
        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        path = data[0]

        wave, text_tensor = self._load_tensor(data)

        mel_tensor = preprocess(wave, sample_rate=self.sr).squeeze()  # (n_mels, T)

        acoustic_feature = mel_tensor
        length_feature = acoustic_feature.size(1)
        # Trim to an even number of frames.
        acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]

        return acoustic_feature, text_tensor, path, wave

    def _load_tensor(self, data):
        # data may be [wave_path, text] or [wave_path, text, speaker_id]
        wave_path = data[0]
        text = data[1]

        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        # Keep only the first channel of stereo input.
        if isinstance(wave, np.ndarray) and wave.ndim == 2 and wave.shape[-1] == 2:
            wave = wave[:, 0].squeeze()

        if sr != self.sr:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)

        # Pad 0.5 s of silence on each side. This was hard-coded as 12000
        # (= 24000 * 0.5), which was only correct for a 24 kHz target rate;
        # derive it from self.sr instead (identical for the default sr).
        pad = self.sr // 2
        wave = np.concatenate([np.zeros([pad]), wave, np.zeros([pad])], axis=0)

        text_ids = self.text_cleaner(text)

        # BOS/EOS = 0, matching the original training setup.
        text_ids.insert(0, 0)
        text_ids.append(0)

        text_tensor = torch.LongTensor(text_ids)
        return wave, text_tensor

    def _load_data(self, data):
        wave, text_tensor = self._load_tensor(data)
        mel_tensor = preprocess(wave, sample_rate=self.sr).squeeze()

        # Randomly crop long utterances to max_mel_length frames.
        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            random_start = np.random.randint(0, mel_length - self.max_mel_length)
            mel_tensor = mel_tensor[:, random_start : random_start + self.max_mel_length]

        return mel_tensor
214