Update meldataset.py
meldataset.py CHANGED (+27 -30)
@@ -18,8 +18,6 @@ import librosa
 import logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
-from utils import *
-
 # from text_utils import TextCleaner
 np.random.seed(1)
 random.seed(1)
@@ -65,11 +63,12 @@ class TextCleaner:
         print(text)
         return indexes
 
+
 class MelDataset(torch.utils.data.Dataset):
     def __init__(self,
                  data_list,
-                 …
-                 …
+                 # dict_path=DEFAULT_DICT_PATH,
+                 sr=44100
                  ):
 
         spect_params = SPECT_PARAMS
@@ -81,14 +80,14 @@ class MelDataset(torch.utils.data.Dataset):
         self.sr = sr
 
         self.to_melspec = torchaudio.transforms.MelSpectrogram(sample_rate=44_100,
-                                                               …
-                                                               …
-                                                               …
-                                                               …
+                                                               n_mels=128,
+                                                               n_fft=2048,
+                                                               win_length=2048,
+                                                               hop_length=512)
         self.mean, self.std = -4, 4
 
-        #
-
+        # self.g2p = hibiki_phon()
+
 
     def __len__(self):
         return len(self.data_list)
@@ -108,15 +107,12 @@ class MelDataset(torch.utils.data.Dataset):
 
         length_feature = acoustic_feature.size(1)
         acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
-
-        # Generate attention prior matrix
-        text_len = text_tensor.size(0)
-        mel_len = acoustic_feature.size(1)
-        attn_prior = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_len)).float()
 
-        return wave_tensor, acoustic_feature, text_tensor, attn_prior, data[0]
+        return wave_tensor, acoustic_feature, text_tensor, data[0]
+
 
     def _load_tensor(self, data):
+
         wave_path, text, speaker_id = data
         speaker_id = int(speaker_id)
         wave, sr = sf.read(wave_path)
@@ -124,7 +120,10 @@ class MelDataset(torch.utils.data.Dataset):
             wave = wave[:, 0].squeeze()
         if sr != 44100:
             wave = librosa.resample(wave, orig_sr=sr, target_sr=44100)
+            # print(wave_path, sr)
 
+        # wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
+
         text = self.text_cleaner(text)
 
         text.insert(0, 0)
@@ -134,8 +133,14 @@ class MelDataset(torch.utils.data.Dataset):
 
         return wave, text, speaker_id
 
-
+
+
 class Collater(object):
+    """
+    Args:
+        return_wave (bool): if true, will return the wave data along with spectrogram.
+    """
+
     def __init__(self, return_wave=False):
         self.text_pad_index = 0
         self.return_wave = return_wave
@@ -156,33 +161,25 @@ class Collater(object):
         texts = torch.zeros((batch_size, max_text_length)).long()
         input_lengths = torch.zeros(batch_size).long()
         output_lengths = torch.zeros(batch_size).long()
-
-        # Add tensor for attention priors
-        attn_priors = torch.zeros((batch_size, max_mel_length, max_text_length)).float()
-
         paths = ['' for _ in range(batch_size)]
-
-        for bid, (_, mel, text, attn_prior, path) in enumerate(batch):
+        for bid, (_, mel, text, path) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            input_lengths[bid] = text_size
            output_lengths[bid] = mel_size
-
-            # Handle attention prior
-            attn_priors[bid, :mel_size, :text_size] = attn_prior
-
            paths[bid] = path
            assert(text_size < (mel_size//2))
 
        if self.return_wave:
            waves = [b[0] for b in batch]
-            return texts, input_lengths, mels, output_lengths, …
+            return texts, input_lengths, mels, output_lengths, paths, waves
+
+        return texts, input_lengths, mels, output_lengths
+
 
-        return texts, input_lengths, mels, output_lengths, attn_priors
 
-# Update the build_dataloader function to use the new MelDataset and Collater
 def build_dataloader(path_list,
                      validation=False,
                      batch_size=4,
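Note on the new spectrogram settings: this commit pins the extractor to 128 mel bands with a 2048-point FFT/window and a 512-sample hop at 44.1 kHz, and keeps the fixed mean/std pair (-4, 4). The hunk does not show where those constants are applied; the sketch below assumes the usual log-mel normalization used by StyleTTS-style datasets, so treat wave_to_normalized_mel as an illustration, not the file's actual code.

import torch
import torchaudio

# Parameters copied from this commit; the normalization form is an assumption.
to_melspec = torchaudio.transforms.MelSpectrogram(sample_rate=44_100,
                                                  n_mels=128,
                                                  n_fft=2048,
                                                  win_length=2048,
                                                  hop_length=512)
mean, std = -4, 4

def wave_to_normalized_mel(wave: torch.Tensor) -> torch.Tensor:
    # Log-compress the mel power spectrogram, then standardize with the
    # dataset's fixed statistics. The 1e-5 floor guards against log(0).
    mel = to_melspec(wave)  # shape: (n_mels, frames)
    return (torch.log(1e-5 + mel) - mean) / std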