| |
| |
| |
| |
|
|
| import math |
|
|
| import numpy as np |
| import torch |
|
|
| from . import FairseqDataset, data_utils |
|
|
|
|
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of samples into a mini-batch sorted by source length.

    Args:
        samples (List[dict]): items with an ``"id"``, a ``"source"`` tensor and
            optionally a ``"target"`` tensor.
        pad_idx (int): padding index handed to ``data_utils.collate_tokens``.
        eos_idx (int): accepted for interface compatibility; not used here.
        vocab: accepted for interface compatibility; not used here.
        left_pad_source (bool, optional): pad the source on the left.
        left_pad_target (bool, optional): pad the target on the left.
        input_feeding (bool, optional): must be truthy; when a target exists,
            an EOS-first copy of it is added as ``prev_output_tokens``.
        pad_to_length (dict, optional): minimum padded lengths, looked up under
            the keys ``"source"`` and ``"target"``.

    Returns:
        dict: ``{}`` for an empty ``samples``; otherwise a batch with the keys
        ``id``, ``ntokens``, ``net_input``, ``target``, ``nsentences`` and
        ``sort_order``. All per-sample tensors are ordered by descending
        source length.
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            # None is passed instead of ``eos_idx``; presumably collate_tokens
            # then relies on each sample's own final token -- confirm against
            # data_utils.
            eos_idx=None,
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    # ``ids`` rather than ``id`` so the builtin is not shadowed.
    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )

    # Sort the batch by descending source length and remember the permutation
    # so every other tensor can be reordered consistently.
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target_pad_to = pad_to_length["target"] if pad_to_length is not None else None
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=target_pad_to,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # Same targets, collated with ``move_eos_to_beginning=True``, used
            # as the decoder's previous-output input.
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=target_pad_to,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # Number of samples in the batch (previously the token length of the
        # first sample's source was reported here).
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
|
|
|
|
class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Each item is a dict with ``"id"``, ``"source"`` and ``"target"``: the
    target is a copy of the wrapped dataset's tokens and the source is the same
    sequence after the enabled noising steps (sentence permutation, masking,
    insertion, rotation) have been applied in that order.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths. ``ordered_indices`` indexes this
            with a NumPy index array, so it must support that.
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool): shuffle the elements before batching.
        seed: Seed for random number generator for reproducibility.
        mask (float): fraction of word starts to mask.
        mask_random (float): probability that a masked position receives a
            random token instead of ``mask_idx``; also the fraction of
            inserted tokens that are random.
        insert (float): number of tokens to insert, as a ratio of the
            sequence length.
        rotate (float): probability of rotating an item's tokens.
        permute_sentences (float): fraction of sentences to shuffle.
        bpe (str): when ``"gpt2"`` the sentence delimiter is
            ``vocab.index("13")``, otherwise ``vocab.eos()``.
        replace_length (int): one of ``-1``, ``0``, ``1``. ``0`` deletes the
            masked tokens, ``1`` replaces each masked span with a single mask
            token, ``-1`` replaces every token of the span with a mask token.
        mask_length (str): one of ``"subword"``, ``"word"``,
            ``"span-poisson"``. ``"span-poisson"`` samples span lengths from a
            truncated Poisson distribution.
        poisson_lambda (float): lambda of that Poisson distribution.
        eos (int, optional): end-of-sentence index; defaults to ``vocab.eos()``.
        item_transform_func (callable, optional): called as
            ``f(source, target)`` and must return the (possibly modified) pair.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        mask,
        mask_random,
        insert,
        rotate,
        permute_sentences,
        bpe,
        replace_length,
        mask_length,
        poisson_lambda,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = mask
        self.random_ratio = mask_random
        self.insert_ratio = insert
        self.rotate_ratio = rotate
        self.permute_sentence_ratio = permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        # Token treated as a sentence delimiter by ``permute_sentences``.
        if bpe == "gpt2":
            self.full_stop_index = self.vocab.index("13")
        else:
            self.full_stop_index = self.vocab.eos()

        self.replace_length = replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            _lambda = poisson_lambda

            # Tabulate the Poisson pmf e^-lambda * lambda^k / k! for
            # k = 0, 1, ... (at most 128 terms), stopping once a term falls
            # below 1e-7, and sample span lengths from that table.
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Always ``True`` for this dataset."""
        return True

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is part of the per-item seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return the noised ``source`` and clean ``target`` for ``index``."""
        # presumably seeds NumPy's RNG from (seed, epoch, index) for the
        # duration of the block -- confirm against data_utils.numpy_seed
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            # The target is cloned because some noising steps write into
            # ``source`` in place.
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)

        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        # NOTE(review): ``<=`` admits an index equal to len(vocab); confirm
        # whether ``<`` was intended.
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        """Return a copy of ``source`` with a fraction ``p`` of its sentences shuffled.

        Sentences are delimited by ``self.full_stop_index``; the first and
        last tokens keep their positions.
        """
        full_stops = source == self.full_stop_index
        # Treat the token before the final one as a full stop so the trailing
        # span counts as a sentence.
        full_stops[-2] = 1

        # Positions that are full stops while the preceding token is not;
        # +2 turns each into an exclusive slice end in ``source``.
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Start writing after the first token, which stays in place.
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        """Return a per-position word-start indicator for ``source``.

        Without ``mask_whole_word`` every position counts as a word start.
        The first and last positions are always cleared.
        """
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        """Mask a fraction ``p`` of the words in ``source``.

        Depending on ``replace_length`` the selected tokens are overwritten
        in place with ``mask_idx`` (or a random token) and/or dropped from the
        returned tensor. Zero-length spans sampled from the span distribution
        become insertions instead.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Keep sampling until the spans cover the masking budget.
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to the budget: shorten the span that crosses it and drop
            # everything after it.
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Zero-length spans are handled separately, as insertions.
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        # Mark the final position as a (large) word start so the span
        # extension loops below stop before the last token.
        is_word_start[
            -1
        ] = 255
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # Keep the position but overwrite it with the mask token, or with
            # a random token where ``mask_random`` is set.
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            # Extend each span one token at a time; a span's remaining length
            # is reduced whenever the next token starts a new word.
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # Delete the token.
                    to_keep[indices] = 0
                else:
                    # Keep the position but overwrite it.
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # All spans are one word long: extend until the next word start.
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # Delete the token.
                    to_keep[indices] = 0
                else:
                    # Keep the position but overwrite it.
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        """Shuffle a fraction ``p`` of the interior tokens of ``tokens`` in place."""
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        # +1 keeps index 0 out of the candidate positions; the last position
        # is excluded by permuting only ``num_words - 2`` candidates.
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the interior of ``tokens`` by a random offset.

        The first and last tokens keep their positions.
        """
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens into ``tokens``.

        A ``random_ratio`` share of the inserted tokens are random vocabulary
        items and the rest are ``mask_idx``. Nothing is inserted at the first
        or last position of the result.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        # Candidate slots exclude the first and last position of the result.
        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        # -1 marks slots that have not been filled yet (checked below).
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # The stable sort keeps the (possibly shuffled) order among equal sizes.
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        # The wrapped dataset is the only underlying data source set in
        # ``__init__``; there are no ``src`` / ``tgt`` attributes.
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset reports prefetch support."""
        return getattr(self.dataset, "supports_prefetch", False)
|
|