# Pre-processing pipelines and batch-collation helpers for seq2seq decoding.
import logging
import math
from random import choice, randint, shuffle
from random import random as rand

import numpy as np
import torch
import torch.utils.data
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
def get_random_word(vocab_words):
    """Return a word drawn uniformly at random from *vocab_words*.

    Raises:
        IndexError: if *vocab_words* is empty.
    """
    # random.choice (already imported by this file) is the idiomatic,
    # equivalent replacement for manual randint-based indexing.
    return choice(vocab_words)
def batch_list_to_batch_tensors(batch):
    """Collate a list of per-sample tuples into one tensor (or None) per field.

    Fields whose first element is None stay None; tensor fields are stacked
    along a new batch dimension; everything else becomes a long tensor.
    """
    def _collate(field):
        # All samples are assumed to agree on the field's kind, so the
        # first element is enough to decide how to collate.
        head = field[0]
        if head is None:
            return None
        if isinstance(head, torch.Tensor):
            return torch.stack(field)
        return torch.tensor(field, dtype=torch.long)

    return [_collate(field) for field in zip(*batch)]
| def _get_word_split_index(tokens, st, end): | |
| split_idx = [] | |
| i = st | |
| while i < end: | |
| if (not tokens[i].startswith('##')) or (i == st): | |
| split_idx.append(i) | |
| i += 1 | |
| split_idx.append(end) | |
| return split_idx | |
| def _expand_whole_word(tokens, st, end): | |
| new_st, new_end = st, end | |
| while (new_st >= 0) and tokens[new_st].startswith('##'): | |
| new_st -= 1 | |
| while (new_end < len(tokens)) and tokens[new_end].startswith('##'): | |
| new_end += 1 | |
| return new_st, new_end | |
class Pipeline():
    """ Pre-process Pipeline Class : callable """

    def __init__(self):
        super().__init__()
        # Configuration knobs a subclass may fill in; all unset by default.
        for attr in ('skipgram_prb', 'skipgram_size', 'pre_whole_word',
                     'mask_whole_word', 'word_subsample_prb', 'sp_prob',
                     'pieces_dir', 'vocab_words', 'skipgram_size_geo_list'):
            setattr(self, attr, None)
        self.pieces_threshold = 10   # default piece-count threshold
        self.call_count = 0          # number of __call__ invocations
        self.offline_mode = False
        self.span_same_mask = False

    def __call__(self, instance):
        # Subclasses implement the actual per-instance pre-processing.
        raise NotImplementedError
class Preprocess4Seq2seqDecoder(Pipeline):
    """ Pre-processing steps for pretraining transformer """

    def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128,
                 mode="s2s", pos_shift=False, source_type_id=0, target_type_id=1,
                 cls_token='[CLS]', sep_token='[SEP]', pad_token='[PAD]', layout_flag=False):
        super().__init__()
        self.max_len = max_len
        self.vocab_words = vocab_words  # vocabulary (sub)words
        self.indexer = indexer  # function from token to token index
        # Causal (lower-triangular) template used to build attention masks.
        self._tril_matrix = torch.tril(torch.ones((max_len, max_len), dtype=torch.long))
        self.task_idx = 3  # relax projection layer for different tasks
        assert mode in ("s2s", "l2r")
        self.mode = mode
        self.max_tgt_length = max_tgt_length
        self.pos_shift = pos_shift
        self.layout_flag = layout_flag
        if layout_flag:
            # With layout, each "token" is [text, x0, y0, x1, y1]; special
            # tokens carry fixed boxes (CLS/PAD at origin, SEP at 1000s).
            self.cls_token = [cls_token, 0, 0, 0, 0]
            self.sep_token = [sep_token, 1000, 1000, 1000, 1000]
            self.pad_token = [pad_token, 0, 0, 0, 0]
        else:
            self.cls_token = cls_token
            self.sep_token = sep_token
            self.pad_token = pad_token
        self.source_type_id = source_type_id
        self.target_type_id = target_type_id
        self.cc = 0  # processed-instance counter; limits debug logging below

    def __call__(self, instance):
        """Turn (tokens_a, max_a_len) into decoder-ready model inputs.

        Args:
            instance: tuple of (source tokens, max source length in batch).

        Returns:
            (input_ids, segment_ids, position_ids, input_mask, mask_qkv,
            task_idx) where the id lists cover the source span and the mask
            is (max_len_in_batch, max_len_in_batch).
        """
        tokens_a, max_a_len = instance
        # NOTE: must pad to the max src length so every item in the batch
        # shares one source span.  BUGFIX: a hard-coded `max_a_len = 511`
        # here overrode the caller-supplied value, desynchronizing the id
        # lists from the attention-mask size whenever the real max differed.
        padded_tokens_a = [self.cls_token] + tokens_a + [self.sep_token]
        assert len(padded_tokens_a) <= max_a_len + 2
        if max_a_len + 2 > len(padded_tokens_a):
            padded_tokens_a += [self.pad_token] * \
                (max_a_len + 2 - len(padded_tokens_a))
        assert len(padded_tokens_a) == max_a_len + 2
        max_len_in_batch = min(self.max_tgt_length + max_a_len + 2, self.max_len)
        tokens = padded_tokens_a
        segment_ids = [self.source_type_id] * (len(padded_tokens_a)) \
            + [self.target_type_id] * (max_len_in_batch - len(padded_tokens_a))
        mask_qkv = None

        # Position ids: real source positions, then 0 for source padding,
        # then target positions continuing right after the real source.
        position_ids = []
        for i in range(len(tokens_a) + 2):
            position_ids.append(i)
        for i in range(len(tokens_a) + 2, max_a_len + 2):
            position_ids.append(0)
        for i in range(max_a_len + 2, max_len_in_batch):
            position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2)

        # Token Indexing
        if not self.layout_flag:
            input_ids = self.indexer(tokens)
        else:
            # Index only the text part; re-attach each token's layout box.
            raw_text = [x[0] for x in tokens]
            raw_text_ids = self.indexer(raw_text)
            input_ids = [[i] + x[1:] for i, x in zip(raw_text_ids, tokens)]

        self.cc += 1
        if self.cc < 5:
            # Log the first few instances for debugging.
            if not self.layout_flag:
                logger.info("Input src = %s" % " ".join(self.vocab_words[tk_id] for tk_id in input_ids))
            else:
                logger.info("Input src = %s" % " ".join(self.vocab_words[tk_id[0]] for tk_id in input_ids))

        # Zero Padding: start from an all-zero (no-attend) mask.
        input_mask = torch.zeros(
            max_len_in_batch, max_len_in_batch, dtype=torch.long)
        if self.mode == "s2s":
            # Every position may attend to the real source tokens.
            input_mask[:, :len(tokens_a)+2].fill_(1)
        else:
            # l2r: causal attention over the source span as well.
            st, end = 0, len(tokens_a) + 2
            input_mask[st:end, st:end].copy_(
                self._tril_matrix[:end, :end])
            input_mask[end:, :len(tokens_a)+2].fill_(1)
        # Target positions attend causally among themselves.
        second_st, second_end = len(padded_tokens_a), max_len_in_batch
        input_mask[second_st:second_end, second_st:second_end].copy_(
            self._tril_matrix[:second_end-second_st, :second_end-second_st])
        return input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx