| | |
| | import os |
| | import os.path as osp |
| | import time |
| | import random |
| | import numpy as np |
| | import random |
| | import soundfile as sf |
| | import librosa |
| | import re, unicodedata |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | import torchaudio |
| | from torch.utils.data import DataLoader |
| |
|
| | import logging |
# Per-module logger. Its own level is pinned to DEBUG so records are not
# filtered at this logger; attached/ancestor handlers decide what is emitted.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
| |
|
| | import pandas as pd |
| |
|
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Complete symbol inventory: pad first, then punctuation, ASCII letters and
# IPA glyphs, one entry per code point. A symbol's id is its list position.
# NOTE: the apostrophe occurs twice in _letters_ipa, so this list has a
# duplicate entry. Reordering or deduplicating would shift every later id.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# symbol -> id. For the duplicated apostrophe the later position wins,
# exactly as sequential assignment would behave.
dicts = {symbol: position for position, symbol in enumerate(symbols)}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
class TextCleaner:
    """Convert a text string into a list of integer symbol ids.

    * The input is first normalised to Unicode NFC, so a decomposed sequence
      (e.g. ``e`` + combining acute) becomes the single pre-composed code
      point a lookup table would contain.
    * The text is tokenised with ``_EVENT_RE``: any ``<...>`` span (e.g.
      ``<evt_gasp>``) is one token, and everything else is split into single
      characters. ``.`` does not match newlines, so ``\\n`` is dropped.
    * A token missing from the lookup table maps to ``unk_id``. With the
      defaults that is ``0``, the id of the pad symbol ``$`` in the
      module-level ``symbols`` list. There is no dedicated ``<unk>`` entry.

    Args:
        lookup: token -> id mapping. Defaults to the module-level ``dicts``.
        unk_id: id emitted for tokens that are not in ``lookup``.
    """

    _EVENT_RE = re.compile(r"<[^>]+>|.")

    def __init__(self, lookup=None, unk_id=0):
        # Resolved at construction time so the shared module table is used
        # unless a caller explicitly supplies its own mapping.
        self.lookup = dicts if lookup is None else lookup
        self.unk_id = unk_id

    def __call__(self, text: str):
        """Return the list of ids for ``text`` (see the class docstring)."""
        text = unicodedata.normalize("NFC", text)
        lookup, unk_id = self.lookup, self.unk_id
        return [lookup.get(tok, unk_id) for tok in self._EVENT_RE.findall(text)]
| | |
| |
|
# Seed both RNGs at import time so later draws from `random` / `np.random`
# are reproducible.
random.seed(1)
np.random.seed(1)

# STFT settings and mel settings for the log-mel front end.
SPECT_PARAMS = dict(n_fft=2048, win_length=1200, hop_length=300)
MEL_PARAMS = dict(n_mels=80)

# Shared mel transform used by preprocess(). Its kwargs are exactly
# MEL_PARAMS + SPECT_PARAMS. sample_rate is not passed, so torchaudio's
# default applies.
to_mel = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS, **SPECT_PARAMS)

# Normalisation constants for the log-mel: (log_mel - mean) / std.
mean, std = -4, 4
| |
|
def preprocess(wave):
    """Turn a numpy waveform into a normalised log-mel tensor.

    The samples are cast to a float32 tensor and passed through the
    module-level ``to_mel`` transform. A leading singleton dimension is then
    added, and the result is mapped through ``(log(mel + 1e-5) - mean) / std``.

    Args:
        wave: numpy array of audio samples (callers in this file pass the
            zero-padded waveform loaded with soundfile/librosa).

    Returns:
        torch tensor: the normalised log-mel with an extra dimension 0.
    """
    samples = torch.from_numpy(wave).float()
    mel = to_mel(samples).unsqueeze(0)
    return (torch.log(mel + 1e-5) - mean) / std
| |
|
class FilePathDataset(torch.utils.data.Dataset):
    """Map-style dataset of utterances, each paired with reference data.

    ``data_list`` lines are ``"wave_path|text|speaker"``. A line that splits
    into any number of fields other than three gets a speaker of ``0``
    appended. Each item is the 8-tuple ``(speaker_id, acoustic_feature,
    text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave)``:

    * ``acoustic_feature``: log-mel of the zero-padded wave, trimmed to an
      even number of frames;
    * ``text_tensor`` / ``ref_text``: symbol ids wrapped in a leading and a
      trailing ``0``; ``ref_text`` comes from a random OOD text;
    * ``ref_mel_tensor`` / ``ref_label``: log-mel (at most ``max_mel_length``
      frames) and speaker id of a random utterance by the same speaker,
      possibly the item itself.

    Args:
        data_list: iterable of ``|``-separated metadata lines.
        root_path: directory that wave paths are joined onto.
        sr: stored as ``self.sr``. NOTE(review): audio is always resampled to
            24000 Hz regardless of this value; confirm before relying on it.
        data_augmentation: stored flag, forced off when ``validation``.
        validation: disables ``data_augmentation``.
        OOD_data: path of the OOD text file. The text is field 1 of each
            line if field 0 of the first line contains ``.wav``, else field 0.
        min_length: minimum character length of a sampled OOD text.

    Raises:
        ValueError: if ``OOD_data`` is empty or contains no text of at least
            ``min_length`` characters (OOD sampling would never terminate).
    """

    def __init__(self,
                 data_list,
                 root_path,
                 sr=24000,
                 data_augmentation=False,
                 validation=False,
                 OOD_data="Data/OOD_texts.txt",
                 min_length=50,
                 ):
        _data_list = [l.strip().split('|') for l in data_list]
        self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
        self.text_cleaner = TextCleaner()
        self.sr = sr

        self.df = pd.DataFrame(self.data_list)

        # Row indices grouped by integer speaker id, used to draw a reference
        # utterance. Keying on int() (the same parse _load_tensor applies)
        # makes the string speaker of 3-field lines and the int 0 default of
        # other lines comparable. A string-vs-int filter never matched the
        # int default.
        self._speaker_rows = {}
        for row_idx, entry in enumerate(self.data_list):
            try:
                speaker = int(entry[2])
            except (IndexError, ValueError):
                # Malformed row: _load_tensor raises if it is ever loaded.
                continue
            self._speaker_rows.setdefault(speaker, []).append(row_idx)

        # Not used by this class (features come from module-level
        # preprocess()); kept so existing attribute access keeps working.
        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192

        self.min_length = min_length
        with open(OOD_data, 'r', encoding='utf-8') as f:
            tl = f.readlines()
        if not tl:
            raise ValueError(f"OOD text file is empty: {OOD_data}")
        idx = 1 if '.wav' in tl[0].split('|')[0] else 0
        self.ptexts = [t.split('|')[idx] for t in tl]
        if not any(len(t) >= min_length for t in self.ptexts):
            raise ValueError(
                f"no OOD text in {OOD_data} has at least {min_length} characters")

        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        path = data[0]

        wave, text_tensor, speaker_id = self._load_tensor(data)

        acoustic_feature = preprocess(wave).squeeze()
        # Drop a trailing frame if needed so the frame count is even.
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]

        # Random reference utterance from the same speaker. This row is always
        # a candidate, because _load_tensor just parsed its speaker.
        candidates = self._speaker_rows[speaker_id]
        ref_data = self.data_list[candidates[np.random.randint(len(candidates))]]
        ref_mel_tensor, ref_label = self._load_data(ref_data[:3])

        ref_ids = self.text_cleaner(self._sample_ood_text())
        ref_text = torch.LongTensor([0, *ref_ids, 0])

        return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave

    def _sample_ood_text(self):
        """Return a random OOD text at least ``self.min_length`` characters long."""
        text = ""
        while len(text) < self.min_length:
            # randint's upper bound is exclusive, so every line is reachable.
            text = self.ptexts[np.random.randint(0, len(self.ptexts))]
        return text

    def _load_tensor(self, data):
        """Load one ``(wave_path, text, speaker)`` entry.

        Returns:
            ``(wave, text_ids, speaker_id)``. ``wave`` is 24 kHz audio (first
            channel only) with 5000 zero samples on each side. ``text_ids`` is
            a LongTensor framed by ``0`` ids. ``speaker_id`` is an int.
        """
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        full_path = osp.join(self.root_path, wave_path)
        try:
            wave, sr = sf.read(full_path, dtype="float32")
        except Exception as e:
            logger.error("[BAD] %s -> %s", full_path, e)
            raise
        if wave.ndim == 2:
            # (frames, channels): keep the first channel only.
            wave = wave[:, 0]
        if sr != 24000:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
            logger.debug("resampled %s from %s Hz", wave_path, sr)

        # np.zeros is float64, so the padded wave is float64.
        wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)

        text = torch.LongTensor([0, *self.text_cleaner(text), 0])

        return wave, text, speaker_id

    def _load_data(self, data):
        """Return ``(mel, speaker_id)``, randomly cropping mel to ``max_mel_length`` frames."""
        wave, _, speaker_id = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # +1 because randint excludes its upper bound; otherwise the last
            # full-length window could never be chosen.
            random_start = np.random.randint(0, mel_length - self.max_mel_length + 1)
            mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]

        return mel_tensor, speaker_id
| |
|
| |
|
class Collater(object):
    """Collate ``FilePathDataset`` items into padded batch tensors.

    Items are 8-tuples ``(speaker_id, mel, text, ref_text, ref_mel,
    ref_label, path, wave)``. Before padding, the batch is reordered by
    decreasing ``mel`` frame count.

    Args:
        return_wave: stored as ``self.return_wave``. ``__call__`` does not
            consult it and always returns the waves.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0    # fill value for texts / ref_texts
        self.min_mel_length = 192  # not used by __call__; kept as an attribute
        self.max_mel_length = 192  # fixed frame width of the ref_mels tensor
        self.return_wave = return_wave

    def __call__(self, batch):
        """Pad and stack ``batch``.

        Returns:
            ``(waves, texts, input_lengths, ref_texts, ref_lengths, mels,
            output_lengths, ref_mels)``:

            * ``waves`` is a list in the sorted order;
            * ``texts`` / ``ref_texts`` are long tensors padded with
              ``text_pad_index``;
            * ``mels`` is zero-padded to the longest item;
            * ``ref_mels`` is zero-padded to ``max_mel_length`` frames (a
              longer ``ref_mel`` raises on assignment);
            * the ``*_lengths`` tensors hold the unpadded sizes.
        """
        batch_size = len(batch)

        # Longest mel first (argsort is ascending, so reverse it).
        order = np.argsort([item[1].shape[1] for item in batch])[::-1]
        batch = [batch[i] for i in order]

        nmels = batch[0][1].size(0)
        max_mel_length = max(item[1].shape[1] for item in batch)
        max_text_length = max(item[2].shape[0] for item in batch)
        max_rtext_length = max(item[3].shape[0] for item in batch)

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
        texts = torch.full((batch_size, max_text_length), self.text_pad_index, dtype=torch.long)
        ref_texts = torch.full((batch_size, max_rtext_length), self.text_pad_index, dtype=torch.long)
        input_lengths = torch.zeros(batch_size).long()
        ref_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        waves = [None] * batch_size

        # Speaker ids, reference labels and paths are not part of the output.
        for bid, (_, mel, text, ref_text, ref_mel, _, _, wave) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            rtext_size = ref_text.size(0)

            mels[bid, :, :mel_size] = mel
            ref_mels[bid, :, :ref_mel.size(1)] = ref_mel
            texts[bid, :text_size] = text
            ref_texts[bid, :rtext_size] = ref_text
            input_lengths[bid] = text_size
            ref_lengths[bid] = rtext_size
            output_lengths[bid] = mel_size
            waves[bid] = wave

        return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
| |
|
| |
|
| |
|
def build_dataloader(path_list,
                     root_path,
                     validation=False,
                     OOD_data="Data/OOD_texts.txt",
                     min_length=50,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a ``DataLoader`` over a ``FilePathDataset`` batched by ``Collater``.

    Args:
        path_list: metadata lines forwarded to ``FilePathDataset``.
        root_path: wave root directory forwarded to ``FilePathDataset``.
        validation: forwarded to the dataset. When true, the loader does not
            shuffle and keeps the last partial batch.
        OOD_data: OOD text file forwarded to ``FilePathDataset``.
        min_length: minimum OOD text length forwarded to ``FilePathDataset``.
        batch_size: samples per batch.
        num_workers: ``DataLoader`` worker processes.
        device: host memory is pinned unless this is ``'cpu'``.
        collate_config: extra keyword arguments for ``Collater``
            (``None`` means none).
        dataset_config: extra keyword arguments for ``FilePathDataset``
            (``None`` means none). They must not repeat the ones passed
            explicitly here.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # None defaults avoid a single mutable dict shared across all calls.
    if collate_config is None:
        collate_config = {}
    if dataset_config is None:
        dataset_config = {}

    dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length,
                              validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=(not validation),
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))

    return data_loader
| |
|
| |
|