Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 Tomoki Hayashi | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| """Dataset modules based on kaldi-style scp files.""" | |
| import logging | |
| from multiprocessing import Manager | |
| import kaldiio | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from parallel_wavegan.utils import HDF5ScpLoader | |
| from parallel_wavegan.utils import NpyScpLoader | |
| def _get_feats_scp_loader(feats_scp): | |
| # read the first line of feats.scp file | |
| with open(feats_scp) as f: | |
| key, value = f.readlines()[0].replace("\n", "").split() | |
| # check scp type | |
| if ":" in value: | |
| value_1, value_2 = value.split(":") | |
| if value_1.endswith(".ark"): | |
| # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index | |
| return kaldiio.load_scp(feats_scp) | |
| elif value_1.endswith(".h5"): | |
| # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats | |
| return HDF5ScpLoader(feats_scp) | |
| else: | |
| raise ValueError("Not supported feats.scp type.") | |
| else: | |
| if value.endswith(".h5"): | |
| # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5 | |
| return HDF5ScpLoader(feats_scp) | |
| elif value.endswith(".npy"): | |
| # npy case: utt_id_1 /path/to/utt_id_1.npy | |
| return NpyScpLoader(feats_scp) | |
| else: | |
| raise ValueError("Not supported feats.scp type.") | |
class AudioMelSCPDataset(Dataset):
    """PyTorch compatible audio and mel dataset based on kaldi-style scp files."""

    def __init__(
        self,
        wav_scp,
        feats_scp,
        segments=None,
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        The entries of ``wav_scp`` and ``feats_scp`` are paired by position
        while filtering; items are later loaded from both loaders with the
        key taken from ``wav_scp``.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            feats_scp (str): Kaldi-style feats.scp file.
            segments (str): Kaldi-style segments file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # load scp as lazy dict
        audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
        mel_loader = _get_feats_scp_loader(feats_scp)
        audio_keys = list(audio_loader.keys())
        mel_keys = list(mel_loader.keys())

        # assert the number of files
        # NOTE: checked BEFORE filtering; both lists are filtered in lock-step
        # below, so a mismatch could otherwise be masked (or crash with an
        # IndexError before this message is ever shown).
        assert len(audio_keys) == len(
            mel_keys
        ), f"Number of audio and mel files are different ({len(audio_keys)} vs {len(mel_keys)})."

        # pair the entries by position so that both filters always look at the
        # length of the entry they are deciding about
        pairs = list(zip(audio_keys, mel_keys))

        # filter by threshold
        if audio_length_threshold is not None:
            # audio loader items are (sampling rate, signal) tuples
            kept = [
                (audio_key, mel_key)
                for audio_key, mel_key in pairs
                if audio_loader[audio_key][1].shape[0] > audio_length_threshold
            ]
            if len(pairs) != len(kept):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(pairs)} -> {len(kept)})."
                )
            pairs = kept
        if mel_length_threshold is not None:
            # lengths are looked up by key: the list may already have been
            # shortened by the audio filter, so positions in the full loader
            # no longer correspond to positions in ``pairs``
            kept = [
                (audio_key, mel_key)
                for audio_key, mel_key in pairs
                if mel_loader[mel_key].shape[0] > mel_length_threshold
            ]
            if len(pairs) != len(kept):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(pairs)} -> {len(kept)})."
                )
            pairs = kept

        self.audio_loader = audio_loader
        self.mel_loader = mel_loader
        self.utt_ids = [audio_key for audio_key, _ in pairs]
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # an empty tuple marks an item which has not been cached yet
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).
            ndarray: Feature (T', C).

        """
        # a non-empty cache entry means the item was already loaded
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]

        utt_id = self.utt_ids[idx]
        fs, audio = self.audio_loader[utt_id]
        mel = self.mel_loader[utt_id]

        # normalize audio signal to be [-1, 1]
        audio = audio.astype(np.float32)
        audio /= 1 << (16 - 1)  # assume that wav is PCM 16 bit

        if self.return_sampling_rate:
            audio = (audio, fs)

        if self.return_utt_id:
            items = utt_id, audio, mel
        else:
            items = audio, mel

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)
class AudioSCPDataset(Dataset):
    """Audio-only PyTorch dataset backed by a kaldi-style wav.scp file."""

    def __init__(
        self,
        wav_scp,
        segments=None,
        audio_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            segments (str): Kaldi-style segments file.
            audio_length_threshold (int): Audio whose length is not greater
                than this value is removed from the dataset.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # the scp is opened as a lazy dict: utt_id -> (sampling rate, signal)
        loader = kaldiio.load_scp(wav_scp, segments=segments)
        utt_ids = list(loader.keys())

        # drop the utterances which are too short
        if audio_length_threshold is not None:
            num_total = len(utt_ids)
            lengths = [signal.shape[0] for _, signal in loader.values()]
            utt_ids = [
                utt_id
                for utt_id, length in zip(utt_ids, lengths)
                if length > audio_length_threshold
            ]
            num_kept = len(utt_ids)
            if num_total != num_kept:
                logging.warning(
                    "Some files are filtered by audio length threshold "
                    f"({num_total} -> {num_kept})."
                )

        self.audio_loader = loader
        self.utt_ids = utt_ids
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # one empty tuple per utterance marks "not cached yet"
            self.caches += [()] * len(self.utt_ids)

    def __getitem__(self, idx):
        """Load the item at the given index.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).

        """
        caching = self.allow_cache
        if caching and len(self.caches[idx]) != 0:
            # already loaded once: serve the stored item
            return self.caches[idx]

        utt_id = self.utt_ids[idx]
        fs, audio = self.audio_loader[utt_id]

        # scale the signal into [-1, 1] (assume that wav is PCM 16 bit)
        audio = audio.astype(np.float32)
        audio /= 2 ** 15

        sample = (audio, fs) if self.return_sampling_rate else audio
        items = (utt_id, sample) if self.return_utt_id else sample

        if caching:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return the number of utterances kept in the dataset.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)
class MelSCPDataset(Dataset):
    """Feature-only PyTorch dataset backed by a kaldi-style feats.scp file."""

    def __init__(
        self,
        feats_scp,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            feats_scp (str): Kaldi-style feats.scp file.
            mel_length_threshold (int): Features whose length is not greater
                than this value are removed from the dataset.
            return_utt_id (bool): Whether to return utterance id.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # the scp is opened as a lazy dict: utt_id -> feature
        loader = _get_feats_scp_loader(feats_scp)
        utt_ids = list(loader.keys())

        # drop the utterances whose feature is too short
        if mel_length_threshold is not None:
            num_total = len(utt_ids)
            lengths = [feat.shape[0] for feat in loader.values()]
            utt_ids = [
                utt_id
                for utt_id, length in zip(utt_ids, lengths)
                if length > mel_length_threshold
            ]
            num_kept = len(utt_ids)
            if num_total != num_kept:
                logging.warning(
                    "Some files are filtered by mel length threshold "
                    f"({num_total} -> {num_kept})."
                )

        self.mel_loader = loader
        self.utt_ids = utt_ids
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # one empty tuple per utterance marks "not cached yet"
            self.caches += [()] * len(self.utt_ids)

    def __getitem__(self, idx):
        """Load the item at the given index.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Feature (T', C).

        """
        caching = self.allow_cache
        if caching and len(self.caches[idx]) != 0:
            # already loaded once: serve the stored item
            return self.caches[idx]

        utt_id = self.utt_ids[idx]
        mel = self.mel_loader[utt_id]
        items = (utt_id, mel) if self.return_utt_id else mel

        if caching:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return the number of utterances kept in the dataset.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)