Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) | |
| # Github source: https://github.com/mbzuai-nlp/ArTST | |
| # Based on the SpeechT5, fairseq and ESPnet code bases | |
| # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet | |
| # -------------------------------------------------------- | |
| import math | |
| import numpy as np | |
| import torch | |
| from fairseq.data import FairseqDataset, data_utils | |
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of text-pretraining samples into one padded mini-batch.

    Args:
        samples (List[dict]): items with keys ``"id"``, ``"source"`` and,
            optionally, ``"target"``.
        pad_idx (int): index used to pad shorter sequences.
        eos_idx (int): accepted for signature compatibility; it is not
            forwarded to ``collate_tokens`` (see the comment in ``merge``).
        vocab: accepted for signature compatibility; not used here.
        left_pad_source (bool): pad source sequences on the left.
        left_pad_target (bool): pad target sequences on the left.
        input_feeding (bool): must be ``True``; a shifted copy of the target
            is then emitted as ``prev_output_tokens``.
        pad_to_length (dict, optional): if given, it is indexed with
            ``"source"`` (and ``"target"`` when targets are present) to get
            the length each side is padded to.

    Returns:
        dict: ``{}`` for an empty ``samples`` list, otherwise a batch with the
        keys ``id``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths`` and, with targets, ``prev_output_tokens``),
        ``target``, ``nsentences``, ``sort_order`` and ``task_name``. All
        per-sample tensors are ordered by descending source length.
    """
    assert input_feeding, "collate() only supports input_feeding=True"
    if not samples:
        return {}

    def requested_length(key):
        # Look up the padding length for one side, if any was requested.
        return pad_to_length[key] if pad_to_length is not None else None

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=requested_length("source"),
    )

    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=requested_length("target"),
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=requested_length("target"),
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # Number of sentences in the batch. The previous value,
        # samples[0]["source"].size(0), was the token length of the first
        # sample rather than a sentence count.
        "nsentences": len(samples),
        "sort_order": sort_order,
        "task_name": 'text_pretrain',
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
class TextPretrainDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Each item is a noised ``source`` / ``target`` pair built from one entry
    of the wrapped dataset. The noise (sentence permutation, whole-word
    masking, insertion, rolling) is seeded from ``(seed, epoch, index)``.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
        args: argparse arguments. The attributes read here are ``mask``,
            ``mask_random``, ``insert``, ``rotate``, ``permute_sentences``,
            ``bpe``, ``replace_length``, ``mask_length`` and (for
            ``mask_length == "span-poisson"``) ``poisson_lambda``.
        eos (int, optional): end-of-sentence index; defaults to
            ``vocab.eos()``.
        item_transform_func (callable, optional): called as
            ``func(source, target)`` after noising; must return the new
            ``(source, target)`` pair.
        iid_noise_target (bool): when true, every masked position in the
            source is replaced by successive entries of ``uni_mask_idxs`` and
            a matching target is built by :meth:`add_whole_word_mask`.
        uni_mask_idxs (torch.Tensor, optional): mask indices; required to be
            a tensor when ``iid_noise_target`` is true.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        args,
        eos=None,
        item_transform_func=None,
        iid_noise_target=False,
        uni_mask_idxs=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        if iid_noise_target:
            assert isinstance(uni_mask_idxs, torch.Tensor), "if use iid_noise_target, the uni_mask_idxs must be a tensor which contain the mask indexs"
        self.iid_noise_target = iid_noise_target
        self.uni_mask_idxs = uni_mask_idxs
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert
        self.rotate_ratio = args.rotate
        self.permute_sentence_ratio = args.permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        # Token treated as a sentence boundary by permute_sentences().
        if args.bpe != "gpt2":
            self.full_stop_index = self.vocab.eos()
        else:
            assert args.bpe == "gpt2"
            self.full_stop_index = self.vocab.index("13")

        self.replace_length = args.replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            _lambda = args.poisson_lambda

            # Tabulate Poisson(k; lambda) = e^-lambda * lambda^k / k! for
            # k = 0, 1, ... and stop once a term drops below 1e-7 (or at
            # k = 127).
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    # NOTE(review): exposed as a property because fairseq reads this flag as
    # an attribute on the dataset -- confirm no caller invokes it with "()".
    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is part of the per-item noise seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return the noised example ``{"id", "source", "target"}`` for *index*."""
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source, new_target = self.add_whole_word_mask(source, self.mask_ratio)
                if new_target is not None:
                    target = new_target

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # there can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        """Return a copy of *source* with a fraction *p* of its sentences shuffled.

        Sentences are delimited by ``self.full_stop_index``; the first token
        of *source* is left in place.
        """
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        """Return a per-token word-start indicator; first and last are always 0."""
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def _mask_tokens(self, source, indices, mask_random):
        # Overwrite source[indices] with the mask token, then replace the
        # positions selected by mask_random with random vocabulary ids.
        source[indices] = self.mask_idx
        source[indices[mask_random]] = torch.randint(
            1, len(self.vocab), size=(mask_random.sum(),)
        )

    def add_whole_word_mask(self, source, p):
        """Mask (or delete) whole words of *source* at ratio *p*.

        *source* is modified in place where tokens are replaced.

        Returns:
            tuple: ``(source, target)``. ``target`` is ``None`` unless
            ``self.iid_noise_target`` is set and at least one span is
            actually masked. Every exit path returns a 2-tuple because
            ``__getitem__`` unpacks the result.
        """
        source_ori = source.clone()
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            # Nothing to mask; previously returned a bare tensor here.
            return source, None

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                # Only insertions were sampled; previously returned a bare
                # tensor here.
                noised = self.add_insertion_noise(
                    source, num_inserts / source.size(0)
                )
                return noised, None

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            self._mask_tokens(source, indices, mask_random)

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._mask_tokens(source, indices, mask_random)
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._mask_tokens(source, indices, mask_random)

                assert source_length - 1 not in indices

        if not self.iid_noise_target:
            source = source[to_keep]
            target = None
        else:
            ## Prepare source
            source_mask_idx = (source == self.mask_idx).nonzero().view(-1)
            source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
            source = source[to_keep]

            ## Prepare target
            to_keep[source_mask_idx] = 0

            # source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...]
            source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0))
            # target: source_length + mask_length
            target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
            # target: [0, 0, 0, X, 0, 0, Y, ....]
            target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]

            target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0))

            # Copy original value to target and target_to_keep
            target_to_keep[target == 0] = to_keep
            target_to_keep[-1] = 0
            target[target == 0] = source_ori

            target = target[~target_to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source, target

    def add_permuted_noise(self, tokens, p):
        """Shuffle a fraction *p* of the interior tokens of *tokens* in place.

        The first and last positions are never touched. The number of
        permuted tokens is capped at the number of interior positions, so
        large *p* or very short inputs no longer index out of range.
        """
        num_words = len(tokens)
        num_candidates = max(num_words - 2, 0)
        num_to_permute = min(
            math.ceil(((num_words * 2) * p) / 2.0), num_candidates
        )
        substitutions = torch.randperm(num_candidates)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the interior of *tokens* by a random offset, keeping both ends."""
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random interior slots.

        A ``self.random_ratio`` share of the inserted tokens are random
        vocabulary ids; the rest are ``self.mask_idx``.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        # Positions 0 and the last one are excluded from insertion.
        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): forwarded to :func:`collate`
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # np.asarray lets a plain Python list of sizes be fancy-indexed too;
        # the stable sort keeps the (shuffled) order among equal sizes.
        sizes = np.asarray(self.sizes)
        return indices[np.argsort(sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        # This class only stores ``self.dataset``; the former ``self.src`` /
        # ``self.tgt`` attributes are never assigned.
        self.dataset.prefetch(indices)

    # NOTE(review): exposed as a property because fairseq reads this flag as
    # an attribute; as a plain method it would always be truthy.
    @property
    def supports_prefetch(self):
        return bool(getattr(self.dataset, "supports_prefetch", False))