# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| modified from lhoste.dataset.speech_synthesis.py | |
| """ | |
import math
from typing import List, Union

import h5py
import numpy as np
import torch
from tokenizers import Tokenizer

# Phoneme inventory: the IDs stored in the HDF5 archive index into this table.
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}


def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert a tokenized phoneme ID sequence back to a phoneme string.

    :param tokens: phoneme token IDs
    :return: recovered phoneme string
    """
    return "".join(symbols[i] for i in tokens)
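
# Example (a minimal sketch): the IDs index into `symbols` above, so IDs 1-3
# map to the first three punctuation marks of the inventory.
#
#   >>> seq2phone([1, 2, 3])
#   ',.!'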


class DynamicBatchSampler(torch.utils.data.Sampler):
    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0,
                 max_size=1000, max_tokens=None, max_sentences=None,
                 drop_last=False):
        """
        :param sampler: the underlying (e.g. distributed) index sampler
        :param num_tokens_fn: function returning the length of the sample at a given idx
        :param num_buckets: number of buckets; bucketing groups samples of
            similar length into the same batch
        :param min_size: minimum sample length; shorter samples are filtered
            out. Buckets are laid out based on this value.
        :param max_size: maximum sample length; longer samples are filtered out
        :param max_sentences: batch size cap; the final batch size is
            controlled jointly by max_sentences and max_tokens
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not both be None; please specify at least one"
        if max_tokens is not None:
            assert max_size <= max_tokens, "max_size should not exceed max_tokens"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sample at index {} of size {} is outside [{}, {}] and is ignored".format(
                    idx, idx_length, self.min_size, self.max_size))
                continue

            index_buckets = math.floor((idx_length - self.min_size)
                                       / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)

            # Estimated padded batch cost if idx were added to this bucket:
            # (batch size) * (longest sample in the bucket).
            num_tokens = (len(buckets[index_buckets]) + 1) * max(sample_len[index_buckets], idx_length)
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch before adding idx to the bucket
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0
            buckets[index_buckets].append(idx)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

        # process the left-over samples remaining in the buckets
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            num_tokens = (len(leftover_batch) + 1) * max(leftover_sample_len, idx_length)
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # The number of batches depends on the data and is not known in
        # advance, so len() is unsupported; do not call len(dataloader).
        raise NotImplementedError
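
# A worked example of the budgeting above: with num_tokens_fn returning
# per-utterance durations in seconds and max_tokens=120, a bucket whose
# longest utterance is ~10 s fills up at roughly 12 utterances per batch,
# since the batch cost is (batch size) x (longest sample in the bucket).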


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Each annotation line is expected to be pipe-separated:
        # path|duration|language|text|... (the trailing column is discarded).
        ls = [l.split("|") for l in lines]
        ls_T = list(zip(*ls))
        del ls_T[-1]
        self.h5_paths, self.durations, self.langs, self.texts = \
            list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
        self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        return self.durations[idx]

    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive
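
    # Note: opening the HDF5 file lazily (rather than in __init__) matters
    # when the dataloader uses worker processes: h5py handles are not safely
    # shared across a fork, so each worker opens its own handle on first use.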

    def __getitem__(self, idx):
        archive = self.archive()
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]

        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within the dataloader
        phones = seq2phone(phone_tokens)
        phones = phones.replace(" ", "_")
        if not len(phones):
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        assert len(cptpho_tokens)

        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }
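
# For illustration (inferred from the code above, not a shipped sample):
# 'audio_features' is the (8, T) array of codec tokens stored in the HDF5
# archive, 'audio_features_lens' is T, and 'text_tokens' holds the BPE IDs
# produced by the tokenizer loaded from `tokenizer_path`.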


def collate(batch):
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # create an empty tensor with maximum audio feature length; audio is
    # padded with -1
    audio_features_s = torch.zeros(
        [len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # create an empty tensor with maximum text tokens length; [PAD] token id
    # is 3
    text_tokens_s = torch.zeros(
        [len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch
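
# Resulting batch shapes for a batch of size B (derived from the code above):
#   audio_features:      (B, T_max, 8) int64, padded with -1
#   audio_features_lens: (B,) int64
#   text_tokens:         (B, L_max) int64, padded with token id 3
#   text_tokens_lens:    (B,) int64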


def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0,
                      num_workers=0, num_buckets=10, max_duration=120):
    train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
                                 ann_path=f"{data_dir}/audio_ann_sum.txt",
                                 tokenizer_path=f"{data_dir}/bpe_69.json")
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    # Batch by duration: num_tokens_fn returns per-utterance durations, so
    # max_duration caps the padded total duration of each batch, and max_size
    # filters out utterances longer than 20 (in the same units, presumably
    # seconds).
    dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur,
                                          num_buckets=num_buckets, max_size=20,
                                          max_tokens=max_duration)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=num_workers,
                                               collate_fn=collate,
                                               batch_sampler=dynamic_sampler)
    return train_loader
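

# A minimal smoke test, assuming the HDF5 archive, annotation file, and BPE
# tokenizer actually exist under the default data_dir (those paths are
# illustrative and not shipped with this file). DistributedSampler works
# without an initialized process group when num_replicas and rank are given
# explicitly.
if __name__ == "__main__":
    loader = create_dataloader(n_gpus=1, rank=0)
    for batch in loader:
        print(batch['text_tokens'].shape, batch['audio_features'].shape)
        break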