| | |
| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from . import FairseqDataset, data_utils |
| |
|
| |
|
def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    """Pad a list of monolingual samples into a single mini-batch dict.

    Each sample is a dict with "id", "source" and "target" keys. Returns an
    empty dict for an empty sample list; otherwise a dict with "id",
    "nsentences", "ntokens", "net_input" (with "src_tokens"/"src_lengths")
    and "target" entries.
    """
    if not samples:
        return {}

    def merge(key, is_list=False):
        # Right-pad the given field across samples. When the field holds a
        # list of tensors (multiple targets), pad each slot independently.
        if not is_list:
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                pad_idx,
                eos_idx,
                left_pad=False,
                pad_to_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
            )
        return [
            data_utils.collate_tokens(
                [s[key][i] for s in samples],
                pad_idx,
                eos_idx,
                left_pad=False,
                pad_to_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
            )
            for i in range(len(samples[0][key]))
        ]

    src_tokens = merge("source")
    first_target = samples[0]["target"]
    if first_target is not None:
        target = merge("target", isinstance(first_target, list))
    else:
        # No explicit target: fall back to the (padded) source tokens.
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
| |
|
| |
|
class MonolingualDataset(FairseqDataset):
    """
    A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap. When ``targets``
            is set, ``dataset[i]`` must yield a ``(source, future_target,
            past_target)`` triple of token tensors; otherwise a single
            source tensor.
        sizes (List[int]): sentence lengths
        src_vocab (~fairseq.data.Dictionary): source vocabulary
        tgt_vocab (~fairseq.data.Dictionary, optional): target vocabulary;
            falls back to ``src_vocab`` when not given.
        add_eos_for_other_targets (bool, optional): append EOS to the source
            for "self"/"past" targets when it is missing (default: False).
        shuffle (bool, optional): shuffle the elements before batching
            (default: False).
        targets (List[str], optional): which targets to produce; each entry
            must be one of "self", "future" or "past" (default: None).
        add_bos_token (bool, optional): prepend BOS to source (and target)
            in ``__getitem__`` (default: False).
        fixed_pad_length (int, optional): forwarded to the collater as
            ``pad_to_length`` (default: None).
        pad_to_bsz (int, optional): forwarded to the collater as
            ``pad_to_bsz`` (default: None).
        src_lang_idx (int, optional): source language index; only stored
            on the instance here.
        tgt_lang_idx (int, optional): target language index; only stored
            on the instance here.
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab=None,
        add_eos_for_other_targets=False,
        shuffle=False,
        targets=None,
        add_bos_token=False,
        fixed_pad_length=None,
        pad_to_bsz=None,
        src_lang_idx=None,
        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        # NOTE: `or` (not `is None`) — any falsy tgt_vocab falls back to src_vocab.
        self.tgt_vocab = tgt_vocab or src_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
        self.fixed_pad_length = fixed_pad_length
        self.pad_to_bsz = pad_to_bsz
        self.src_lang_idx = src_lang_idx
        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be none or one of 'self', 'future', 'past'"
        # An empty target list is normalized to "no targets at all".
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for one example.

        ``target`` is None when ``self.targets`` is None; otherwise a tensor
        (single target) or list of tensors (multiple targets).
        """
        if self.targets is not None:
            # NOTE(review): the wrapped dataset yields the source plus two
            # shifted variants; presumably *future_target* is the sentence
            # shifted one position ahead of *source* and *past_target* one
            # position behind — confirm against the underlying dataset.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        """Assemble the requested target(s); returns (source, target).

        ``target`` is a single tensor when exactly one target kind is
        requested, a list of tensors when several are, and ``future_target``
        when ``self.targets`` is None.
        """
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # Append EOS to the source so "self"/"past" targets end on it.
                source = torch.cat([source, source.new([self.vocab.eos()])])

                if "future" in self.targets:
                    # Keep future_target the same length as the grown source;
                    # the appended EOS has no next token, so pad.
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # Rebuild past_target to match the grown source: pad the
                    # first slot (no predecessor), then shift, ending with the
                    # token before the appended EOS. `source[-2, None]` keeps
                    # it a 1-element tensor for torch.cat.
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            # Unwrap a single-element target list to a bare tensor.
            if len(target) == 1:
                target = target[0]
        else:
            # No targets configured: default to the future (next-token) target.
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        # Prepend BOS to source (and target, when present) if configured.
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        return self.sizes[indices]

    def _filter_vocab(self, target):
        # Map token ids outside the target vocabulary to unk. Only needed when
        # source and target vocabularies differ in size. NOTE: mutates the
        # target tensor(s) in place via boolean-mask assignment.
        if len(self.tgt_vocab) != len(self.vocab):

            def _filter(target):
                mask = target.ge(len(self.tgt_vocab))
                if mask.any():
                    target[mask] = self.tgt_vocab.unk()
                return target

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the right.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the right.
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.fixed_pad_length,
            self.pad_to_bsz,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # lexsort uses the LAST key as primary: sort by size, breaking ties
        # with the (possibly shuffled) base order.
        order.append(self.sizes)
        return np.lexsort(order)

    @property
    def supports_prefetch(self):
        # Delegate to the wrapped dataset; default False when it has no opinion.
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
|