| import copy |
| from distutils.version import LooseVersion |
| from io import StringIO |
| from pathlib import Path |
| from typing import Callable |
| from typing import Collection |
| from typing import Dict |
| from typing import Iterable |
| from typing import Tuple |
| from typing import Union |
|
|
| import kaldiio |
| import numpy as np |
| import soundfile |
| import torch |
| from typeguard import check_argument_types |
|
|
| from espnet2.train.dataset import ESPnetDataset |
|
|
| if LooseVersion(torch.__version__) >= LooseVersion("1.2"): |
| from torch.utils.data.dataset import IterableDataset |
| else: |
| from torch.utils.data.dataset import Dataset as IterableDataset |
|
|
|
|
def load_kaldi(input):
    """Load one entry with ``kaldiio.load_mat`` and return only its array.

    ``kaldiio.load_mat`` may hand back either a bare ``np.ndarray`` or a
    2-tuple that pairs an ``int`` with an ``np.ndarray`` (in either order).
    In the tuple case the ``int`` is discarded.
    NOTE(review): the int is presumably a sampling rate -- confirm against
    the kaldiio documentation.

    Args:
        input: Whatever ``kaldiio.load_mat`` accepts as its argument.

    Returns:
        The ``np.ndarray`` part of the loaded entry.

    Raises:
        AssertionError: If a tuple result does not have exactly two elements,
            or a non-tuple result is not an ``np.ndarray``.
        RuntimeError: If a 2-tuple is not an (int, ndarray) pair in either
            order.
    """
    loaded = kaldiio.load_mat(input)

    if not isinstance(loaded, tuple):
        # Bare-matrix result: only the type needs checking.
        assert isinstance(loaded, np.ndarray), type(loaded)
        return loaded

    assert len(loaded) == 2, len(loaded)
    first, second = loaded[0], loaded[1]

    # (int, ndarray) ordering: the array is the second element.
    if isinstance(first, int) and isinstance(second, np.ndarray):
        return second

    # (ndarray, int) ordering: the array is the first element.
    if isinstance(second, int) and isinstance(first, np.ndarray):
        return first

    raise RuntimeError(f"Unexpected type: {type(first)}, {type(second)}")
|
|
|
|
# Registry of streaming loaders: maps a data-type name to a callable that
# receives the value field of one input line (a path or inline text) and
# returns the loaded object.
#
# The integer loaders use the builtin ``int`` as dtype.  They previously used
# ``np.long``, which was only an alias of ``int``; that alias was deprecated in
# NumPy 1.20 and removed in NumPy 1.24, where looking it up raises
# AttributeError at load time.
DATA_TYPES = {
    # Keep element 0 of soundfile.read()'s return value (the sample data).
    "sound": lambda x: soundfile.read(x)[0],
    "kaldi_ark": load_kaldi,
    "npy": np.load,
    # Space-separated integers, always returned as at least a 1-D array.
    "text_int": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=int, delimiter=" "
    ),
    # Comma-separated integers, always returned as at least a 1-D array.
    "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=int, delimiter=","),
    # Space-separated float32 values, always at least 1-D.
    "text_float": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
    ),
    # Comma-separated float32 values, always at least 1-D.
    "csv_float": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
    ),
    # Raw text is passed through unchanged.
    "text": lambda x: x,
}
|
|
|
|
class IterableESPnetDataset(IterableDataset):
    """Pytorch Dataset class for ESPNet.

    Entries whose type is a key of ``DATA_TYPES`` are streamed: their files
    are read line by line, in lockstep, while iterating.  Entries of any other
    type are handed to an ``ESPnetDataset`` and looked up by utterance id.

    Examples:
        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                                  ('token_int', 'output', 'text_int')],
        ...                                )
        >>> for uid, data in dataset:
        ...     data
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        key_file: str = None,
    ):
        """Register the data sources and split them by loading strategy.

        Args:
            path_name_type_list: ``(path, name, type)`` triples; ``name`` must
                be unique and becomes the key in each yielded dict.
            preprocess: Optional callable ``(uid, data) -> data`` applied to
                every utterance before dtype casting.
            float_dtype: dtype that float arrays are cast to.
            int_dtype: dtype that signed-integer arrays are cast to.
            key_file: Optional file whose first column gives the utterance ids
                to iterate; when omitted, the first streamed file is used.

        Raises:
            ValueError: If ``path_name_type_list`` is empty.
            RuntimeError: If a ``name`` appears more than once.
        """
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        # Work on a private copy of the caller's list.
        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file

        # name -> (path, type); doubles as the registry of known data names.
        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []

        for path, name, _type in path_name_type_list:
            if name in self.debug_info:
                raise RuntimeError(f'"{name}" is duplicated for data-key')
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                # No streaming loader for this type: defer to ESPnetDataset.
                non_iterable_list.append((path, name, _type))
            else:
                self.path_name_type_list.append((path, name, _type))

        if len(non_iterable_list) != 0:
            self.non_iterable_dataset = ESPnetDataset(
                path_name_type_list=non_iterable_list,
                preprocess=preprocess,
                float_dtype=float_dtype,
                int_dtype=int_dtype,
            )
        else:
            self.non_iterable_dataset = None

        # True when a file named "utt2category" sits next to the first path.
        self.apply_utt2category = Path(
            Path(path_name_type_list[0][0]).parent, "utt2category"
        ).exists()

    def has_name(self, name) -> bool:
        """Return whether ``name`` was registered as a data key."""
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        """Return all registered data keys, in registration order."""
        return tuple(self.debug_info)

    def __repr__(self):
        """Return a multi-line summary of every data key and the preprocessor."""
        parts = [self.__class__.__name__, "("]
        for name, (path, _type) in self.debug_info.items():
            parts.append(f'\n {name}: {{"path": "{path}", "type": "{_type}"}}')
        parts.append(f"\n preprocess: {self.preprocess})")
        return "".join(parts)

    @staticmethod
    def _read_row(files, uid, linenum):
        """Read the next line of every file and return its (keys, values).

        ``linenum`` is the 1-based row being read and is used only in error
        messages.  Raises RuntimeError when a file is exhausted, when a line
        has no value field, or when the files disagree on the key.
        """
        keys = []
        values = []
        for f in files:
            line = next(f, None)
            if line is None:
                raise RuntimeError(f"{uid} is not found in the files")
            sps = line.rstrip().split(maxsplit=1)
            if len(sps) != 2:
                raise RuntimeError(
                    f"This line doesn't include a space:"
                    f" {f}:L{linenum}: {line})"
                )
            key, value = sps
            keys.append(key)
            values.append(value)

        for k_idx, k in enumerate(keys):
            if k != keys[0]:
                raise RuntimeError(
                    f"Keys are mismatched. Text files (idx={k_idx}) is "
                    f"not sorted or not having same keys at L{linenum}"
                )
        return keys, values

    def _cast_dtypes(self, data):
        """Cast every array in ``data`` in place to the configured dtypes."""
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value
        return data

    def __iter__(self) -> Iterable[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        """Yield ``(uid, data)`` for each utterance id, in file order.

        The ids come from ``key_file`` when given, otherwise from the first
        column of the first streamed file, otherwise from iterating the
        non-iterable dataset.  Inside a DataLoader worker only every
        ``num_workers``-th id (offset by the worker id) is produced.

        Raises:
            RuntimeError: If an id cannot be found in the streamed files, a
                line is malformed, the files disagree on a key, a value is
                not an ``np.ndarray`` after preprocessing, or no id was read.
            NotImplementedError: If an array is neither float nor signed int.
        """
        # Every handle opened here is tracked so the ``finally`` below closes
        # it on exhaustion, on error, and when the generator is abandoned.
        opened = []
        try:
            if self.key_file is not None:
                uid_path = self.key_file
            elif len(self.path_name_type_list) != 0:
                uid_path = self.path_name_type_list[0][0]
            else:
                uid_path = None

            if uid_path is not None:
                uid_file = open(uid_path, encoding="utf-8")
                opened.append(uid_file)
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0] for line in uid_file
                )
            else:
                uid_iter = iter(self.non_iterable_dataset)

            files = []
            for entry in self.path_name_type_list:
                handle = open(entry[0], encoding="utf-8")
                opened.append(handle)
                files.append(handle)

            worker_info = torch.utils.data.get_worker_info()

            # Row counter shared by all files: they advance in lockstep, so it
            # is bumped once per row rather than once per file.
            linenum = 0
            count = 0
            for count, uid in enumerate(uid_iter, 1):
                # With several workers, each one keeps only its share of ids.
                if worker_info is not None:
                    if (count - 1) % worker_info.num_workers != worker_info.id:
                        continue

                # 1. Advance all files together until the row for this id.
                while True:
                    linenum += 1
                    keys, values = self._read_row(files, uid, linenum)
                    if len(keys) == 0 or keys[0] == uid:
                        break

                # 2. Turn each value field into data with its type's loader.
                data = {}
                for value, (path, name, _type) in zip(
                    values, self.path_name_type_list
                ):
                    data[name] = DATA_TYPES[_type](value)
                if self.non_iterable_dataset is not None:
                    # Remaining entries are looked up by id.
                    _, from_non_iterable = self.non_iterable_dataset[uid]
                    data.update(from_non_iterable)

                # 3. Optional user preprocessing.
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)

                # 4. Enforce the configured precision.
                yield uid, self._cast_dtypes(data)

            if count == 0:
                raise RuntimeError("No iteration")
        finally:
            for handle in opened:
                handle.close()
|
|