Spaces:
Runtime error
Runtime error
| import torch | |
| import random | |
| import warnings | |
| from transformers import BertTokenizer, BertTokenizerFast, BatchEncoding | |
| from typing import List, Union, Tuple, Any, Dict | |
def whole_word_mask(tokenizer: Union[BertTokenizer, BertTokenizerFast],
                    input_tokens: List[str],
                    mlm_prob: float,
                    max_predictions=512) -> List[int]:
    """
    Select whole words for masking and return one 0/1 flag per input token.

    Tokens that start with "##" are grouped with the token preceding them, so a
    word made of several pieces is always selected or skipped as a unit.
    "[CLS]" and "[SEP]" are never candidates.

    Args:
        tokenizer: Only inspected for its type; a warning is emitted when it is
            neither a ``BertTokenizer`` nor a ``BertTokenizerFast``.
        input_tokens: Token strings of a single sequence.
        mlm_prob: Target share of tokens to select. The target count is
            ``round(len(input_tokens) * mlm_prob)``, but never less than 1.
        max_predictions: Upper bound applied to the target count.

    Returns:
        A list of the same length as ``input_tokens`` holding 1 at selected
        positions and 0 everywhere else.

    Raises:
        ValueError: If the set of selected positions and the ordered list of
            selected positions end up with different sizes.
    """
    if not isinstance(tokenizer, (BertTokenizer, BertTokenizerFast)):
        warnings.warn(
            "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
            "Please refer to the documentation for more information."
        )

    # Build one group of positions per word: a "##" piece joins the previous group.
    word_groups = []
    for position, piece in enumerate(input_tokens):
        if piece in ("[CLS]", "[SEP]"):
            continue
        if word_groups and piece.startswith("##"):
            word_groups[-1].append(position)
        else:
            word_groups.append([position])

    # Random visiting order decides which words get picked.
    random.shuffle(word_groups)
    budget = min(max_predictions, max(1, int(round(len(input_tokens) * mlm_prob))))

    picked_in_order = []
    picked = set()
    for group in word_groups:
        if len(picked_in_order) >= budget:
            break
        # A word that would push the total past the budget is skipped, not split.
        if len(picked_in_order) + len(group) > budget:
            continue
        if any(position in picked for position in group):
            continue
        picked.update(group)
        picked_in_order.extend(group)

    if len(picked) != len(picked_in_order):
        raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")

    return [int(position in picked) for position in range(len(input_tokens))]
def torch_mask_tokens(tokenizer: Union[BertTokenizer, BertTokenizerFast],
                      inputs: torch.Tensor,
                      mask_labels: torch.Tensor,
                      all_use_mask_token: bool = False) -> Tuple[Any, Any]:
    """
    Build masked inputs and labels for masked language modeling.

    Positions flagged in ``mask_labels`` (e.g. by whole word masking) are the
    masking candidates; special tokens and padding are removed from that set.
    By default 80% of the remaining positions become the mask token, 10% a
    random token id and 10% keep their original id. With
    ``all_use_mask_token=True`` every remaining position becomes the mask token.

    Args:
        tokenizer: Supplies the mask token, the special-token mask, the pad
            token id and (via ``len``) the range of random token ids.
        inputs: Token ids. Each element of ``inputs.tolist()`` is passed to
            ``get_special_tokens_mask``, so a (batch, seq_len) layout is
            assumed -- TODO confirm against callers.
        mask_labels: Tensor shaped like ``inputs``; non-zero marks a candidate.
        all_use_mask_token: Replace every selected position with the mask
            token instead of applying the 80/10/10 split.

    Returns:
        ``(masked_inputs, labels)``. ``labels`` holds the original id at
        selected positions and -100 elsewhere. ``inputs`` is not modified.

    Raises:
        ValueError: If the tokenizer has no mask token.
    """
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    masked_inputs = inputs.clone()

    # Start from the caller's candidate flags, then drop special tokens and padding.
    probability_matrix = mask_labels.clone()
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool, device=probability_matrix.device),
        value=0.0,
    )

    # Check the public id that is actually compared below rather than the
    # private ``_pad_token`` attribute, which may not exist on every tokenizer.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is not None:
        probability_matrix.masked_fill_(labels.eq(pad_token_id), value=0.0)

    masked_indices = probability_matrix.bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    if all_use_mask_token:
        masked_inputs[masked_indices] = mask_token_id
        return masked_inputs, labels

    # Random helper tensors live on the same device as the ids so the boolean
    # combinations below also work for non-CPU inputs.
    device = labels.device

    # 80% of the time, replace the selected token with the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    masked_inputs[indices_replaced] = mask_token_id

    # Half of the remaining 20% (10% overall): replace with a random token id.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=device)
    masked_inputs[indices_random] = random_words[indices_random]

    # The last 10% of selected positions keep their original token id.
    return masked_inputs, labels
def merge_batch_dict(src_batch_dict: Union[Dict, BatchEncoding],
                     tgt_batch_dict: Union[Dict, BatchEncoding],
                     prefix: str = None):
    """
    Copy every entry of ``src_batch_dict`` into ``tgt_batch_dict`` in place.

    Each value is duplicated with its ``clone()`` method before being stored,
    and each key is stored as ``prefix + key``. A missing or empty ``prefix``
    leaves the keys unchanged. Nothing is returned.
    """
    key_prefix = prefix if prefix else ''
    for name, value in src_batch_dict.items():
        tgt_batch_dict[key_prefix + name] = value.clone()