import math
from typing import Dict, List, Tuple

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset, data_utils
from fairseq.data.concat_dataset import ConcatDataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset


class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modelling. The dataset wraps around
    TokenBlockDataset or BlockPairDataset and creates a batch where the input
    blocks are masked according to the specified masking probability.
    Additionally, the batch can also contain sentence-level targets if this
    is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only BlockPairDataset
            and TokenBlockDataset are supported.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is
            the token associated with the sentence embedding (Eg: CLS for
            BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset generates a pair
            of blocks along with a sentence_target or not. Setting it to True
            assumes that the underlying dataset generates a label for the
            pair of sentences which is surfaced as sentence_target. The
            default value assumes a single block with no sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: specifies what percentage of the blocks should be
            masked.
        masking_prob: specifies the probability of a given token being
            replaced with the "MASK" token.
        random_token_prob: specifies the probability of a given token being
            replaced by a random token from the vocabulary.
    """
    def __init__(
        self,
        dataset: FairseqDataset,
        sizes: np.ndarray,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        classif_token_idx: int,
        sep_token_idx: int,
        seed: int = 1,
        shuffle: bool = True,
        has_pairs: bool = True,
        segment_id: int = 0,
        masking_ratio: float = 0.15,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        assert isinstance(
            dataset, (TokenBlockDataset, BlockPairDataset, ConcatDataset)
        ), (
            "MaskedLMDataset only wraps TokenBlockDataset, BlockPairDataset "
            "or ConcatDataset"
        )
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # In the single-block setting, _collate prepends the classification
        # token to every example, so each size grows by 1.
        if not has_pairs:
            self.sizes = self.sizes + 1
    def __getitem__(self, index: int):
        # With pairs, the underlying dataset returns both blocks along with
        # the sentence-level target; otherwise it returns a single block.
        if self.has_pairs:
            (block_one, block_two, sentence_target) = self.dataset[index]
        else:
            block_one = self.dataset[index]

        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two if self.has_pairs else None,
            "sentence_target": sentence_target if self.has_pairs else None,
        }
    def __len__(self):
        return len(self.dataset)
    def _mask_block(
        self,
        sentence: np.ndarray,
        mask_idx: int,
        pad_idx: int,
        dictionary_token_range: Tuple,
    ):
        """
        Mask tokens for Masked Language Model training.
        Samples masking_ratio tokens that will be predicted by the LM.

        Note: this function may not be efficient enough since we have
        multiple conversions between np and torch; we can replace them with
        torch operators later.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: range of indices in dictionary which can
                be used for random word replacement
                (e.g. without special characters)
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * self.masking_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)
        target = np.copy(sentence)

        for i in range(sent_length):
            if i in mask:
                rand = np.random.random()

                # replace with the mask token if the probability is less than
                # masking_prob (e.g. 0.8)
                if rand < self.masking_prob:
                    masked_sent[i] = mask_idx

                # replace with a random token if the probability is less than
                # masking_prob + random_token_prob (e.g. 0.9); otherwise keep
                # the original token (it is still predicted)
                elif rand < (self.masking_prob + self.random_token_prob):
                    # sample random token from dictionary
                    masked_sent[i] = np.random.randint(
                        dictionary_token_range[0], dictionary_token_range[1]
                    )
            else:
                # positions that were not sampled are not predicted, so their
                # target is replaced by pad_idx
                target[i] = pad_idx

        return masked_sent, target
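
    # Worked example (illustrative, with made-up indices): for a 10-token
    # sentence [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] and
    # masking_ratio=0.15, ceil(10 * 0.15) = 2 positions are sampled, say 3
    # and 7. One possible outcome:
    #   masked_sent = [11, 12, 13, <mask>, 15, 16, 17, 42, 19, 20]
    #   target      = [pad, pad, pad, 14, pad, pad, pad, 18, pad, pad]
    # Position 3 became the mask token, position 7 a random vocabulary token
    # (42); only those two positions keep their true ids in target, and every
    # other target entry is pad_idx so it is ignored by the loss.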
    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
        1. Mask the input blocks. If has_pairs is True then we have two
           blocks to mask.
        2. Prepend the first masked block tensor with the special token
           used as sentence embedding. Eg: CLS in BERT. This happens
           irrespective of the value of has_pairs.
        3. If has_pairs is True, then append the first masked block with the
           special separator token (Eg: SEP for BERT) and compute segment
           label accordingly. In this case, also append the second masked
           block with this special separator token and compute its segment
           label.
        4. For the targets tensor, prepend and append with padding index
           accordingly.
        5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}

        # To ensure determinism, the masking for a batch is seeded from the
        # dataset seed plus the id of the first sample in the batch.
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:
                # token range is needed for replacing with random tokens
                # during masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask the first block according to the specified
                # probabilities
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to
                # both blocks after masking and re-compute the segment labels
                # based on the new lengths
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range
                    )
                    tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # CLS + block one + SEP
                    segments_one = np.zeros(len(tokens_one))
                    # block two + SEP
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor(
                [s["sentence_target"] for s in samples]
            )
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }

    def collater(self, samples: List[Dict]):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index: int):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. This value is used when
        filtering a dataset with max-positions.
        """
        return self.sizes[index]

    def ordered_indices(self):
        """
        Return an ordered list of indices. Batches will be constructed based
        on this order.
        """
        if self.shuffle:
            return np.random.permutation(len(self))
        else:
            order = [np.arange(len(self))]
            order.append(self.sizes)
            return np.lexsort(order)
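
    # np.lexsort sorts by its last key first, so the unshuffled order sorts
    # examples by size, with the original index breaking ties. For example,
    # sizes [5, 3, 5, 2] yield the order [3, 1, 0, 2].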

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
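

if __name__ == "__main__":
    # Minimal smoke test for the masking logic alone; a sketch, not part of
    # the library API. It bypasses __init__ (so no real TokenBlockDataset or
    # BlockPairDataset is required) and uses made-up indices: mask_idx=4,
    # pad_idx=1 and a vocabulary range of (5, 100).
    demo = MaskedLMDataset.__new__(MaskedLMDataset)
    demo.masking_ratio = 0.15
    demo.masking_prob = 0.8
    demo.random_token_prob = 0.1

    sentence = np.arange(5, 25)  # 20 fake token ids
    masked_sent, target = demo._mask_block(
        sentence, mask_idx=4, pad_idx=1, dictionary_token_range=(5, 100)
    )
    print("masked:", masked_sent)
    print("target:", target)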