Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import traceback | |
| import types | |
| from functools import wraps | |
| from itertools import chain | |
| import numpy as np | |
| import torch.utils.data | |
| import torch.nn.functional as F | |
| from torch.utils.data import ConcatDataset | |
| from .hparams import hparams | |
def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Collate a list of same-rank tensors, dispatching on their rank.

    Rank-1 inputs go through ``collate_1d`` (which honors ``shift_id``);
    anything else is handed to ``collate_2d``.
    """
    one_dimensional = len(values[0].shape) == 1
    if one_dimensional:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Pad a list of 1d tensors into a single [batch, size] tensor.

    Unused positions are filled with ``pad_idx``. With ``left_pad`` the
    content is right-aligned. With ``shift_right`` each row holds the source
    shifted right by one, with ``shift_id`` placed in the first slot.
    """
    size = max_len if max_len is not None else max(v.size(0) for v in values)
    out = values[0].new(len(values), size).fill_(pad_idx)
    for row, src in zip(out, values):
        n = len(src)
        dst = row[size - n:] if left_pad else row[:n]
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)
    return out
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Pad a list of 2d tensors into a single [batch, size, dim] tensor.

    Rows are padded with ``pad_idx`` along the first axis; the trailing
    feature dimension is taken from the first element. ``left_pad``
    right-aligns content; ``shift_right`` shifts each item down by one row
    (the first row keeps the ``pad_idx`` fill).
    """
    size = max_len if max_len is not None else max(v.size(0) for v in values)
    out = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)
    for row, src in zip(out, values):
        n = len(src)
        dst = row[size - n:] if left_pad else row[:n]
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)
    return out
def collate_xd(values, pad_value=0, max_len=None):
    """Stack tensors of arbitrary (shared) rank into one batch tensor.

    Pads along dim 0 with ``pad_value``; trailing dimensions are taken from
    the first element. Result: [batch, size, *rest].
    """
    lead = max(v.size(0) for v in values) if max_len is None else max_len
    shape = (len(values), lead, *values[0].shape[1:])
    out = torch.full(shape, fill_value=pad_value, dtype=values[0].dtype, device=values[0].device)
    for slot, v in zip(out, values):
        slot[:len(v), ...] = v
    return out
def pad_or_cut_1d(values: torch.tensor, tgt_len, pad_value=0):
    """Right-pad with ``pad_value`` (or truncate) a 1d tensor to ``tgt_len``."""
    cur_len = values.shape[0]
    if cur_len >= tgt_len:
        return values[:tgt_len]
    return F.pad(values, [0, tgt_len - cur_len], value=pad_value)
def pad_or_cut_2d(values: torch.tensor, tgt_len, dim=-1, pad_value=0):
    """Pad with ``pad_value`` (or truncate) a 2d tensor to ``tgt_len`` along ``dim``.

    Raises RuntimeError for any ``dim`` outside {0, 1, -1, -2}.
    """
    if dim in (0, -2):
        cur_len = values.shape[0]
        if cur_len >= tgt_len:
            return values[:tgt_len]
        return F.pad(values, [0, 0, 0, tgt_len - cur_len], value=pad_value)
    if dim in (1, -1):
        cur_len = values.shape[1]
        if cur_len >= tgt_len:
            return values[:, :tgt_len]
        return F.pad(values, [0, tgt_len - cur_len], value=pad_value)
    raise RuntimeError(f"Wrong dim number {dim} while the tensor only has {len(values.shape)} dimensions.")
def pad_or_cut_3d(values: torch.tensor, tgt_len, dim=-1, pad_value=0):
    """Pad with ``pad_value`` (or truncate) a 3d tensor to ``tgt_len`` along ``dim``.

    Raises RuntimeError for any ``dim`` outside {0, 1, 2, -1, -2, -3}.
    """
    if dim in (0, -3):
        cur_len = values.shape[0]
        if cur_len >= tgt_len:
            return values[:tgt_len]
        return F.pad(values, [0, 0, 0, 0, 0, tgt_len - cur_len], value=pad_value)
    if dim in (1, -2):
        cur_len = values.shape[1]
        if cur_len >= tgt_len:
            return values[:, :tgt_len]
        return F.pad(values, [0, 0, 0, tgt_len - cur_len], value=pad_value)
    if dim in (2, -1):
        cur_len = values.shape[2]
        if cur_len >= tgt_len:
            return values[:, :, :tgt_len]
        return F.pad(values, [0, tgt_len - cur_len], value=pad_value)
    raise RuntimeError(f"Wrong dim number {dim} while the tensor only has {len(values.shape)} dimensions.")
def pad_or_cut_xd(values, tgt_len, dim=-1, pad_value=0):
    """Dispatch to the rank-specific pad-or-cut helper.

    Only ranks 1-3 are supported; anything else raises NotImplementedError.
    (``dim`` is ignored for 1d inputs, matching pad_or_cut_1d's signature.)
    """
    rank = len(values.shape)
    if rank == 1:
        return pad_or_cut_1d(values, tgt_len, pad_value)
    if rank == 2:
        return pad_or_cut_2d(values, tgt_len, dim, pad_value)
    if rank == 3:
        return pad_or_cut_3d(values, tgt_len, dim, pad_value)
    raise NotImplementedError
| def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): | |
| if len(batch) == 0: | |
| return 0 | |
| if len(batch) == max_sentences: | |
| return 1 | |
| if num_tokens > max_tokens: | |
| return 1 | |
| return 0 | |
def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """Group dataset indices into mini-batches bucketed by size.

    Batches may contain sequences of different lengths; each batch's cost is
    len(batch) * longest-sample-in-batch (i.e. the padded size).

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): returns the number of tokens at a given index
        max_tokens (int, optional): max padded tokens per batch (default: None)
        max_sentences (int, optional): max samples per batch (default: None)
        required_batch_size_multiple (int, optional): require each emitted
            batch size to be a multiple of N when possible (default: 1)

    Returns:
        List[List[int]]: the list of batches.
    """
    max_tokens = sys.maxsize if max_tokens is None else max_tokens
    max_sentences = sys.maxsize if max_sentences is None else max_sentences
    bsz_mult = required_batch_size_multiple
    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    batches = []
    cur_batch = []
    cur_lens = []
    longest = 0
    for idx in indices:
        n_tok = num_tokens_fn(idx)
        cur_lens.append(n_tok)
        longest = max(longest, n_tok)
        assert longest <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, longest, max_tokens)
        )
        # Padded cost of the batch if idx were appended now.
        would_be = (len(cur_batch) + 1) * longest
        is_full = len(cur_batch) > 0 and (
            len(cur_batch) == max_sentences or would_be > max_tokens
        )
        if is_full:
            # Emit the largest multiple of bsz_mult we can; keep the rest
            # (and its lengths) as the seed of the next batch.
            keep = max(
                bsz_mult * (len(cur_batch) // bsz_mult),
                len(cur_batch) % bsz_mult,
            )
            batches.append(cur_batch[:keep])
            cur_batch = cur_batch[keep:]
            cur_lens = cur_lens[keep:]
            longest = max(cur_lens) if cur_lens else 0
        cur_batch.append(idx)
    if cur_batch:
        batches.append(cur_batch)
    return batches
def build_dataloader(dataset, shuffle, max_tokens=None, max_sentences=None,
                     required_batch_size_multiple=-1, endless=False, apply_batch_by_size=True, pin_memory=False,
                     use_ddp=False):
    """Build a torch DataLoader with size-bucketed batch sampling.

    Args:
        dataset: object exposing ordered_indices(), num_tokens(index),
            collater(samples) and num_workers (see BaseDataset in this file).
        shuffle (bool): shuffle the batch order.
        max_tokens (int, optional): max padded tokens per batch; scaled by
            the device count when use_ddp is set.
        max_sentences (int, optional): max samples per batch; scaled by the
            device count when use_ddp is set.
        required_batch_size_multiple (int): -1 means "use the device count".
        endless (bool): repeat the batch list 1000 times (re-shuffled each
            pass when shuffle is set).
        apply_batch_by_size (bool): bucket by token count; otherwise chunk
            the ordered indices into fixed groups of max_sentences.
        pin_memory (bool): forwarded to DataLoader.
        use_ddp (bool): shard every batch across distributed ranks.

    Returns:
        torch.utils.data.DataLoader over the precomputed batch list.
    """
    import torch.distributed as dist
    devices_cnt = torch.cuda.device_count()
    if devices_cnt == 0:
        devices_cnt = 1
    if not use_ddp:
        devices_cnt = 1
    if required_batch_size_multiple == -1:
        required_batch_size_multiple = devices_cnt

    def shuffle_batches(batches):
        np.random.shuffle(batches)
        return batches

    if max_tokens is not None:
        max_tokens *= devices_cnt
    if max_sentences is not None:
        max_sentences *= devices_cnt
    indices = dataset.ordered_indices()
    if apply_batch_by_size:
        batch_sampler = batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )
    else:
        batch_sampler = []
        for i in range(0, len(indices), max_sentences):
            batch_sampler.append(indices[i:i + max_sentences])

    if shuffle:
        batches = shuffle_batches(list(batch_sampler))
        if endless:
            batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
    else:
        batches = batch_sampler
        if endless:
            batches = [b for _ in range(1000) for b in batches]
    # NOTE(review): relies on dataset.num_workers being a property returning
    # an int (BaseDataset below defines it that way) — confirm it is never a
    # plain method on custom datasets.
    num_workers = dataset.num_workers
    if use_ddp:
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
        # Shard each batch across ranks while covering every sample. Uneven
        # batches are padded by repeating the last index up to the next
        # multiple of num_replicas so all ranks get the same share.
        batches_ = []
        for x in batches:
            rem = len(x) % num_replicas
            if rem == 0:
                batches_.append(x[rank::num_replicas])
            else:
                # BUGFIX: pad by (num_replicas - rem), not by rem; padding by
                # rem only reaches a multiple of num_replicas when rem == 0,
                # leaving ranks with unequal batch sizes otherwise.
                x_ = x + [x[-1]] * (num_replicas - rem)
                batches_.append(x_[rank::num_replicas])
        batches = batches_
    return torch.utils.data.DataLoader(dataset,
                                       collate_fn=dataset.collater,
                                       batch_sampler=batches,
                                       num_workers=num_workers,
                                       pin_memory=pin_memory)
def unpack_dict_to_list(samples):
    """Split a batched sample dict into a list of per-item dicts.

    The batch size is read from ``samples['outputs'].size(0)``. For each key,
    value ``v`` contributes ``v[i]`` to the i-th dict; values that cannot be
    indexed (ints, floats, ...) are skipped on purpose (best-effort
    unbatching).
    """
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; non-indexable values are still
            # silently skipped by design.
            except Exception:
                pass
        samples_.append(res)
    return samples_
def remove_padding(x, padding_idx=0):
    """Strip padded positions from a 1d [T] or 2d [T, H] array/tensor.

    1d: keeps entries different from ``padding_idx``. 2d: keeps rows whose
    absolute sum differs from ``padding_idx`` (all-pad rows when it is 0).
    Returns None for None input.
    """
    if x is None:
        return None
    assert len(x.shape) in [1, 2]
    if len(x.shape) == 1:  # [T]
        return x[x != padding_idx]
    # [T, H]
    return x[np.abs(x).sum(-1) != padding_idx]
def data_loader(fn):
    """
    Decorator that lazily evaluates and memoizes a data-loader factory.

    ``fn(self)`` runs once; the result is cached on the instance under
    ``'_lazy_<fn name>'`` and returned on every later call. An
    AttributeError raised *inside* fn would be indistinguishable from the
    cache-miss AttributeError, so it is converted to RuntimeError.

    :param fn: method (taking only self) that builds the loader
    :return: wrapped method with lazy, memoized evaluation
    """
    attr_name = '_lazy_' + fn.__name__

    # BUGFIX: the original called `wraps(fn)` and discarded the result, so
    # the wrapper never inherited fn's metadata; apply it as a decorator.
    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)  # Lazy evaluation, done only once.
            except AttributeError as e:
                # Guard against AttributeError suppression. (Issue #142)
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)  # Memoize evaluation.
        return value
    return _get_data_loader
class BaseDataset(torch.utils.data.Dataset):
    """Base class for size-aware datasets used with batch_by_size.

    Subclasses populate ``self.sizes`` (per-example lengths) and implement
    ``__getitem__`` / ``collater``.
    """

    def __init__(self, shuffle):
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        # BUGFIX: must be a property — this class itself evaluates
        # len(self._sizes) and np.array(self._sizes), which fail on a
        # bound method (the decorator was evidently lost).
        return self.sizes

    def __getitem__(self, index):
        raise NotImplementedError

    def collater(self, samples):
        raise NotImplementedError

    def __len__(self):
        return len(self._sizes)

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return min(self._sizes[index], hparams['max_frames'])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
            if self.sort_by_len:
                # mergesort is stable, so the shuffled order is kept among
                # examples of equal size
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices.tolist()

    @property
    def num_workers(self):
        # BUGFIX: property, so build_dataloader's `dataset.num_workers`
        # (read without a call) yields an int, not a bound method.
        return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
class BaseConcatDataset(ConcatDataset):
    """ConcatDataset that mirrors the BaseDataset size/ordering interface."""

    def collater(self, samples):
        # All sub-datasets are assumed to share one collater.
        return self.datasets[0].collater(samples)

    @property
    def _sizes(self):
        # BUGFIX: must be a property — ordered_indices below evaluates
        # np.array(self._sizes), which fails on a bound method. The
        # concatenated size list is computed once and cached on self.sizes.
        if not hasattr(self, 'sizes'):
            self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets]))
        return self.sizes

    def size(self, index):
        return min(self._sizes[index], hparams['max_frames'])

    def num_tokens(self, index):
        return self.size(index)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.datasets[0].shuffle:
            indices = np.random.permutation(len(self))
            if self.datasets[0].sort_by_len:
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        # NOTE(review): returns an ndarray, whereas BaseDataset returns a
        # list — both iterate fine in build_dataloader, kept as-is.
        return indices

    @property
    def num_workers(self):
        # BUGFIX: property for the same reason as _sizes; with BaseDataset's
        # num_workers also a property this forwards an int.
        return self.datasets[0].num_workers