import glob
import importlib
import os

import numpy as np
import torch
import torch.distributed as dist
from resemblyzer import VoiceEncoder
from torch.utils.data import DistributedSampler
from tqdm import tqdm

import utils
from tasks.base_task import BaseDataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset


class EndlessDistributedSampler(DistributedSampler):
    """Distributed sampler that pre-builds a long, quasi-endless index stream.

    Instead of reshuffling at every epoch boundary, it concatenates 1000 passes
    over the dataset up front (shuffled when `shuffle=True`), then strides the
    result across ranks.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle

        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = [i for _ in range(1000) for i in torch.randperm(
                len(self.dataset), generator=g).tolist()]
        else:
            indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))]
        # Trim so the index list splits evenly across ranks, then take this rank's slice.
        indices = indices[:len(indices) // self.num_replicas * self.num_replicas]
        indices = indices[self.rank::self.num_replicas]
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
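

# Usage sketch (hypothetical; passing num_replicas/rank explicitly avoids needing
# an initialized torch.distributed process group): the sampler pre-computes 1000
# passes over the dataset, so the loader below yields roughly 1000 epochs' worth
# of batches before it is exhausted.
#
#     sampler = EndlessDistributedSampler(dataset, num_replicas=1, rank=0)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)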


class VocoderDataset(BaseDataset):
    def __init__(self, prefix, shuffle=False):
        super().__init__(shuffle)
        self.hparams = hparams
        self.prefix = prefix
        self.data_dir = hparams['binary_data_dir']
        self.is_infer = prefix == 'test'
        self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
        self.aux_context_window = hparams['aux_context_window']
        self.hop_size = hparams['hop_size']
        if self.is_infer and hparams['test_input_dir'] != '':
            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            self.avail_idxs = [i for i, _ in enumerate(self.sizes)]
        elif self.is_infer and hparams['test_mel_dir'] != '':
            self.indexed_ds, self.sizes = self.load_mel_inputs(hparams['test_mel_dir'])
            self.avail_idxs = [i for i, _ in enumerate(self.sizes)]
        else:
            self.indexed_ds = None
            self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
            # Keep only items long enough to crop a full training window from.
            self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if
                               s - 2 * self.aux_context_window > self.batch_max_frames]
            print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
            self.sizes = [s for idx, s in enumerate(self.sizes) if
                          s - 2 * self.aux_context_window > self.batch_max_frames]
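
    # Worked example (hypothetical hparams): with max_samples=20480 and
    # hop_size=256, batch_max_frames = 20480 // 256 = 80 during training, so an
    # item needs more than 80 + 2 * aux_context_window mel frames to be kept.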
    def _get_item(self, index):
        # Open the indexed dataset lazily so each DataLoader worker gets its own handle.
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        item = self.indexed_ds[index]
        return item
    def __getitem__(self, index):
        index = self.avail_idxs[index]
        item = self._get_item(index)
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "mel": torch.FloatTensor(item['mel']),
            "wav": torch.FloatTensor(item['wav'].astype(np.float32)),
        }
        if 'pitch' in item:
            sample['pitch'] = torch.LongTensor(item['pitch'])
            sample['f0'] = torch.FloatTensor(item['f0'])

        if hparams.get('use_spk_embed', False):
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if hparams.get('use_emo_embed', False):
            sample["emo_embed"] = torch.Tensor(item['emo_embed'])

        return sample
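
    # Shape sketch (assuming 80 mel bins): for an item with T mel frames,
    # sample['mel'] is [T, 80], sample['wav'] has T * hop_size samples, and the
    # optional 'pitch'/'f0' entries are frame-level tensors of length T.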
    def collater(self, batch):
        if len(batch) == 0:
            return {}

        y_batch, c_batch, p_batch, f0_batch = [], [], [], []
        item_name = []
        have_pitch = 'pitch' in batch[0]
        for idx in range(len(batch)):
            item_name.append(batch[idx]['item_name'])
            x = batch[idx]['wav'] if self.hparams['use_wav'] else None
            c = batch[idx]['mel'].squeeze(0)
            if have_pitch:
                p = batch[idx]['pitch']
                f0 = batch[idx]['f0']
            if self.hparams['use_wav']:
                self._assert_ready_for_upsampling(x, c, self.hop_size, 0)
            if len(c) - 2 * self.aux_context_window > self.batch_max_frames:
                # Randomly crop a training window; at inference (batch_max_frames == 0),
                # use the full item minus the context margin.
                batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 else len(
                    c) - 2 * self.aux_context_window - 1
                batch_max_steps = batch_max_frames * self.hop_size
                interval_start = self.aux_context_window
                interval_end = len(c) - batch_max_frames - self.aux_context_window
                start_frame = np.random.randint(interval_start, interval_end)
                start_step = start_frame * self.hop_size
                if self.hparams['use_wav']:
                    y = x[start_step: start_step + batch_max_steps]
                c = c[start_frame - self.aux_context_window:
                      start_frame + self.aux_context_window + batch_max_frames]
                if have_pitch:
                    p = p[start_frame - self.aux_context_window:
                          start_frame + self.aux_context_window + batch_max_frames]
                    f0 = f0[start_frame - self.aux_context_window:
                            start_frame + self.aux_context_window + batch_max_frames]
                if self.hparams['use_wav']:
                    self._assert_ready_for_upsampling(y, c, self.hop_size, self.aux_context_window)
            else:
                print(f"Removed short sample from batch (length={len(c)} frames).")
                continue
            if self.hparams['use_wav']:
                y_batch += [y.reshape(-1, 1)]
            c_batch += [c]
            if have_pitch:
                p_batch += [p]
                f0_batch += [f0]

        # Pad each list of examples into a batch tensor.
        if self.hparams['use_wav']:
            y_batch = utils.collate_2d(y_batch, 0).transpose(2, 1)
        c_batch = utils.collate_2d(c_batch, 0).transpose(2, 1)
        if have_pitch:
            p_batch = utils.collate_1d(p_batch, 0)
            f0_batch = utils.collate_1d(f0_batch, 0)
        else:
            p_batch, f0_batch = None, None

        # Noise input matched to the wav shape, for noise-conditioned vocoders.
        if self.hparams['use_wav']:
            z_batch = torch.randn(y_batch.size())
        else:
            z_batch = []
        return {
            'z': z_batch,
            'mels': c_batch,
            'wavs': y_batch,
            'pitches': p_batch,
            'f0': f0_batch,
            'item_name': item_name
        }
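
    # Cropping arithmetic sketch (hypothetical values): with hop_size=256,
    # batch_max_frames=80 and aux_context_window=2, each wav crop spans
    # 80 * 256 = 20480 samples while the matching mel crop keeps
    # 2 + 80 + 2 = 84 frames; the extra context frames condition the model
    # but correspond to no output audio.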
    @staticmethod
    def _assert_ready_for_upsampling(x, c, hop_size, context_window):
        """Assert the audio and feature lengths are correctly adjusted for upsampling."""
        assert len(x) == (len(c) - 2 * context_window) * hop_size
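
    # For example (hypothetical numbers): with hop_size=256 and context_window=2,
    # a mel of 84 frames must pair with a wav of (84 - 4) * 256 = 20480 samples.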
    def load_test_inputs(self, test_input_dir, spk_id=0):
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/**/*.mp3'))
        sizes = []
        items = []

        # Resolve the binarizer class from its dotted path in hparams.
        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg = ".".join(binarizer_cls.split(".")[:-1])
        cls_name = binarizer_cls.split(".")[-1]
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
        binarization_args = hparams['binarization_args']

        for wav_fn in inp_wav_paths:
            item_name = wav_fn[len(test_input_dir) + 1:].replace("/", "_")
            item = binarizer_cls.process_item(item_name, wav_fn, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes
    def load_mel_inputs(self, test_input_dir, spk_id=0):
        inp_mel_paths = sorted(glob.glob(f'{test_input_dir}/*.npy'))
        sizes = []
        items = []

        # Resolve the binarizer class from its dotted path in hparams.
        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg = ".".join(binarizer_cls.split(".")[:-1])
        cls_name = binarizer_cls.split(".")[-1]
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
        binarization_args = hparams['binarization_args']

        for mel_fn in inp_mel_paths:
            mel_input = torch.FloatTensor(np.load(mel_fn))
            item_name = mel_fn[len(test_input_dir) + 1:].replace("/", "_")
            item = binarizer_cls.process_mel_item(item_name, mel_input, None, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes
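

# End-to-end sketch (hypothetical usage; assumes hparams has been populated,
# e.g. via utils.hparams.set_hparams, before the dataset is built):
#
#     train_ds = VocoderDataset('train', shuffle=True)
#     loader = torch.utils.data.DataLoader(
#         train_ds, batch_size=hparams['batch_size'], collate_fn=train_ds.collater)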