|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch.utils.data |
|
|
|
|
|
from fairseq.data import data_utils |
|
|
|
|
|
|
|
|
class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch.

        The base implementation does nothing; classes that care about the
        current epoch override this method.

        Args:
            epoch: the new epoch number
        """
        pass
|
|
|
|
|
|
|
|
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        """Return the sample at *index*. Must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples. Must be implemented by subclasses."""
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        """Return the attribute named *attr* of this dataset, or ``None`` if
        the dataset has no such attribute.

        The *index* argument is not used by this base implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.

        If :func:`get_batch_shapes` returns a list of shapes, each shape's
        batch size is first resolved/adjusted (see ``adjust_bsz`` below) and
        the resulting array is forwarded as *fixed_shapes*; otherwise
        *fixed_shapes* is passed through as ``None``.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                """Resolve and constrain the batch size of one fixed shape.

                A ``None`` batch size is derived from *max_tokens*. After that
                the size is either capped at *max_sentences* or, only when
                *max_sentences* is ``None``, rounded down to a multiple of
                *required_batch_size_multiple* (sizes smaller than the
                multiple are left as they are).
                """
                if bsz is None:
                    assert max_tokens is not None, 'Must specify --max-tokens'
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= (bsz % required_batch_size_multiple)
                return bsz

            fixed_shapes = np.array([
                [adjust_bsz(bsz, num_tokens), num_tokens]
                for (bsz, num_tokens) in fixed_shapes
            ])

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """ Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, (int, float)):
            # Fast path: a scalar limit can be applied with one vectorised
            # comparison when the dataset exposes its sizes as an array
            # (or as a single-element list wrapping one).
            sizes = getattr(self, 'sizes', None)
            if isinstance(sizes, np.ndarray):
                selected_sizes = sizes[indices]
            elif isinstance(sizes, list) and len(sizes) == 1:
                selected_sizes = sizes[0][indices]
            else:
                selected_sizes = None

            if selected_sizes is not None:
                # Both comparisons are kept (rather than negating one mask) so
                # the result is identical to comparing twice against the
                # looked-up sizes.
                ignored = indices[selected_sizes > max_sizes].tolist()
                indices = indices[selected_sizes <= max_sizes]
                return indices, ignored

        # General path: non-scalar limits, or no usable ``sizes`` attribute.
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored
|
|
|
|
|
|
|
|
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """For datasets that need to be read sequentially, usually because the data
    is being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        """Iterate over the dataset. Must be implemented by subclasses."""
        raise NotImplementedError
|
|
|