Respair committed on
Commit
4d3f1b8
·
verified ·
1 Parent(s): faa8827

Update meldataset.py

Browse files
Files changed (1) hide show
  1. meldataset.py +27 -30
meldataset.py CHANGED
@@ -18,8 +18,6 @@ import librosa
18
  import logging
19
  logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.DEBUG)
21
- from utils import *
22
-
23
  # from text_utils import TextCleaner
24
  np.random.seed(1)
25
  random.seed(1)
@@ -65,11 +63,12 @@ class TextCleaner:
65
  print(text)
66
  return indexes
67
 
 
68
  class MelDataset(torch.utils.data.Dataset):
69
  def __init__(self,
70
  data_list,
71
- sr=44100,
72
- scaling_factor=1.0 # Add scaling_factor parameter
73
  ):
74
 
75
  spect_params = SPECT_PARAMS
@@ -81,14 +80,14 @@ class MelDataset(torch.utils.data.Dataset):
81
  self.sr = sr
82
 
83
  self.to_melspec = torchaudio.transforms.MelSpectrogram(sample_rate=44_100,
84
- n_mels=128,
85
- n_fft=2048,
86
- win_length=2048,
87
- hop_length=512)
88
  self.mean, self.std = -4, 4
89
 
90
- # Add the beta-binomial interpolator
91
- self.beta_binomial_interpolator = BetaBinomialInterpolator(scaling_factor=scaling_factor)
92
 
93
  def __len__(self):
94
  return len(self.data_list)
@@ -108,15 +107,12 @@ class MelDataset(torch.utils.data.Dataset):
108
 
109
  length_feature = acoustic_feature.size(1)
110
  acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
111
-
112
- # Generate attention prior matrix
113
- text_len = text_tensor.size(0)
114
- mel_len = acoustic_feature.size(1)
115
- attn_prior = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_len)).float()
116
 
117
- return wave_tensor, acoustic_feature, text_tensor, attn_prior, data[0]
 
118
 
119
  def _load_tensor(self, data):
 
120
  wave_path, text, speaker_id = data
121
  speaker_id = int(speaker_id)
122
  wave, sr = sf.read(wave_path)
@@ -124,7 +120,10 @@ class MelDataset(torch.utils.data.Dataset):
124
  wave = wave[:, 0].squeeze()
125
  if sr != 44100:
126
  wave = librosa.resample(wave, orig_sr=sr, target_sr=44100)
 
127
 
 
 
128
  text = self.text_cleaner(text)
129
 
130
  text.insert(0, 0)
@@ -134,8 +133,14 @@ class MelDataset(torch.utils.data.Dataset):
134
 
135
  return wave, text, speaker_id
136
 
137
- # Now modify the Collater class to handle the attention prior
 
138
  class Collater(object):
 
 
 
 
 
139
  def __init__(self, return_wave=False):
140
  self.text_pad_index = 0
141
  self.return_wave = return_wave
@@ -156,33 +161,25 @@ class Collater(object):
156
  texts = torch.zeros((batch_size, max_text_length)).long()
157
  input_lengths = torch.zeros(batch_size).long()
158
  output_lengths = torch.zeros(batch_size).long()
159
-
160
- # Add tensor for attention priors
161
- attn_priors = torch.zeros((batch_size, max_mel_length, max_text_length)).float()
162
-
163
  paths = ['' for _ in range(batch_size)]
164
-
165
- for bid, (_, mel, text, attn_prior, path) in enumerate(batch):
166
  mel_size = mel.size(1)
167
  text_size = text.size(0)
168
  mels[bid, :, :mel_size] = mel
169
  texts[bid, :text_size] = text
170
  input_lengths[bid] = text_size
171
  output_lengths[bid] = mel_size
172
-
173
- # Handle attention prior
174
- attn_priors[bid, :mel_size, :text_size] = attn_prior
175
-
176
  paths[bid] = path
177
  assert(text_size < (mel_size//2))
178
 
179
  if self.return_wave:
180
  waves = [b[0] for b in batch]
181
- return texts, input_lengths, mels, output_lengths, attn_priors, paths, waves
 
 
 
182
 
183
- return texts, input_lengths, mels, output_lengths, attn_priors
184
 
185
- # Update the build_dataloader function to use the new MelDataset and Collater
186
  def build_dataloader(path_list,
187
  validation=False,
188
  batch_size=4,
 
18
  import logging
19
  logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.DEBUG)
 
 
21
  # from text_utils import TextCleaner
22
  np.random.seed(1)
23
  random.seed(1)
 
63
  print(text)
64
  return indexes
65
 
66
+
67
  class MelDataset(torch.utils.data.Dataset):
68
  def __init__(self,
69
  data_list,
70
+ # dict_path=DEFAULT_DICT_PATH,
71
+ sr=44100
72
  ):
73
 
74
  spect_params = SPECT_PARAMS
 
80
  self.sr = sr
81
 
82
  self.to_melspec = torchaudio.transforms.MelSpectrogram(sample_rate=44_100,
83
+ n_mels=128,
84
+ n_fft=2048,
85
+ win_length=2048,
86
+ hop_length=512)
87
  self.mean, self.std = -4, 4
88
 
89
+ # self.g2p = hibiki_phon()
90
+
91
 
92
  def __len__(self):
93
  return len(self.data_list)
 
107
 
108
  length_feature = acoustic_feature.size(1)
109
  acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
 
 
 
 
 
110
 
111
+ return wave_tensor, acoustic_feature, text_tensor, data[0]
112
+
113
 
114
  def _load_tensor(self, data):
115
+
116
  wave_path, text, speaker_id = data
117
  speaker_id = int(speaker_id)
118
  wave, sr = sf.read(wave_path)
 
120
  wave = wave[:, 0].squeeze()
121
  if sr != 44100:
122
  wave = librosa.resample(wave, orig_sr=sr, target_sr=44100)
123
+ # print(wave_path, sr)
124
 
125
+ # wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
126
+
127
  text = self.text_cleaner(text)
128
 
129
  text.insert(0, 0)
 
133
 
134
  return wave, text, speaker_id
135
 
136
+
137
+
138
  class Collater(object):
139
+ """
140
+ Args:
141
+ return_wave (bool): if true, will return the wave data along with spectrogram.
142
+ """
143
+
144
  def __init__(self, return_wave=False):
145
  self.text_pad_index = 0
146
  self.return_wave = return_wave
 
161
  texts = torch.zeros((batch_size, max_text_length)).long()
162
  input_lengths = torch.zeros(batch_size).long()
163
  output_lengths = torch.zeros(batch_size).long()
 
 
 
 
164
  paths = ['' for _ in range(batch_size)]
165
+ for bid, (_, mel, text, path) in enumerate(batch):
 
166
  mel_size = mel.size(1)
167
  text_size = text.size(0)
168
  mels[bid, :, :mel_size] = mel
169
  texts[bid, :text_size] = text
170
  input_lengths[bid] = text_size
171
  output_lengths[bid] = mel_size
 
 
 
 
172
  paths[bid] = path
173
  assert(text_size < (mel_size//2))
174
 
175
  if self.return_wave:
176
  waves = [b[0] for b in batch]
177
+ return texts, input_lengths, mels, output_lengths, paths, waves
178
+
179
+ return texts, input_lengths, mels, output_lengths
180
+
181
 
 
182
 
 
183
  def build_dataloader(path_list,
184
  validation=False,
185
  batch_size=4,