Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import, division, print_function | |
| import logging | |
| import os | |
| import json | |
| import random | |
| import glob | |
| import re | |
| import torch | |
| import tqdm | |
| import torch.utils.data | |
| logger = logging.getLogger(__name__) | |
class Seq2seqDatasetForBert(torch.utils.data.Dataset):
    """Sequence-to-sequence training dataset over pre-tokenized features.

    Each feature dict holds "source_ids" and "target_ids". An item yields
    (source_ids, target_ids, pseudo_ids, num_source_tokens,
    num_target_tokens[, span_ids]) where pseudo_ids is a corrupted copy of
    the target: each token is kept, replaced by a random vocab id, or masked.
    """

    def __init__(
            self, features, max_source_len, max_target_len,
            vocab_size, cls_id, sep_id, pad_id, mask_id,
            random_prob, keep_prob, offset, num_training_instances,
            span_len=1, span_prob=1.0):
        self.features = features
        self.num_training_instances = num_training_instances
        self.offset = offset
        if offset > 0:
            logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset)
        # Sequence-length limits.
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        # Special token ids and vocabulary size.
        self.cls_id = cls_id
        self.sep_id = sep_id
        self.pad_id = pad_id
        self.mask_id = mask_id
        self.vocab_size = vocab_size
        # Per-token corruption probabilities for the pseudo target.
        self.keep_prob = keep_prob
        self.random_prob = random_prob
        # Span-masking configuration (span_len <= 1 disables spans).
        self.span_len = span_len
        self.span_prob = span_prob

    def __len__(self):
        # Virtual size: the training schedule decides how many instances to draw.
        return int(self.num_training_instances)

    def __trunk(self, ids, max_len):
        """Clip *ids* to at most max_len - 1 tokens and append SEP."""
        return ids[:max_len - 1] + [self.sep_id]

    def __pad(self, ids, max_len):
        """Right-pad *ids* with PAD up to exactly max_len."""
        assert len(ids) <= max_len
        return ids + [self.pad_id] * (max_len - len(ids))

    def __getitem__(self, idx):
        # Wrap around the feature list so idx may exceed len(features).
        feature = self.features[(self.offset + idx) % len(self.features)]
        source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len)
        target_ids = self.__trunk(feature["target_ids"], self.max_target_len)

        def corrupt(token_id):
            # Keep / replace-with-random / mask, driven by one random draw.
            roll = random.random()
            if roll < self.keep_prob:
                return token_id
            if roll < self.keep_prob + self.random_prob:
                return random.randint(0, self.vocab_size - 1)
            return self.mask_id

        pseudo_ids = [corrupt(t) for t in target_ids]

        num_source_tokens = len(source_ids)
        num_target_tokens = len(target_ids)
        source_ids = self.__pad(source_ids, self.max_source_len)
        target_ids = self.__pad(target_ids, self.max_target_len)
        pseudo_ids = self.__pad(pseudo_ids, self.max_target_len)

        if self.span_len <= 1:
            return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens

        # Group target positions into spans: with probability span_prob a span
        # covers 2..span_len tokens, otherwise a single token.
        span_ids = []
        next_span = 1
        while len(span_ids) < num_target_tokens:
            if random.random() < self.span_prob:
                width = min(random.randint(2, self.span_len),
                            num_target_tokens - len(span_ids))
            else:
                width = 1
            span_ids.extend([next_span] * width)
            next_span += 1
        span_ids = self.__pad(span_ids, self.max_target_len)
        return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids
| # DONE: finish this!!! the 2D input id settings. | |
class Seq2seqDatasetForLayoutlm(torch.utils.data.Dataset):
    """Seq2seq training dataset for LayoutLM-style 2D inputs.

    When ``layout_flag`` is True, every entry in "source_ids"/"target_ids" is
    a 5-element list ``[token_id, x0, y0, x1, y1]``; otherwise entries are
    plain int ids (BERT-style).  ``__getitem__`` returns
    (source_ids, target_ids, pseudo_ids, num_source_tokens,
    num_target_tokens[, span_ids], target_index).
    """

    def __init__(
            self, features, max_source_len, max_target_len,
            vocab_size, cls_id, sep_id, pad_id, mask_id,
            random_prob, keep_prob, offset, num_training_instances, layout_flag=True,
            span_len=1, span_prob=1.0):
        self.layout_flag = layout_flag
        self.features = features
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        self.offset = offset
        if offset > 0:
            logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset)
        self.cls_id = cls_id
        self.sep_id = sep_id
        self.pad_id = pad_id
        self.random_prob = random_prob
        self.keep_prob = keep_prob
        self.mask_id = mask_id
        self.vocab_size = vocab_size
        self.num_training_instances = num_training_instances
        self.span_len = span_len
        self.span_prob = span_prob
        # Special target-index value used at truncation/padding positions.
        self.index_sp_id = 0

    def __len__(self):
        # Virtual size: driven by the training schedule, not len(features).
        return int(self.num_training_instances)

    def __clip_index(self, ids):
        # Indices falling outside the source window are reset to 0
        # (the special/ignore index).
        replace_value = 0
        for i in range(len(ids)):
            if ids[i] > self.max_source_len - 1:
                ids[i] = replace_value
        return ids

    def __trunk(self, ids, max_len, simple=False, value=None):
        """Clip to max_len - 1 entries and append a SEP entry.

        ``simple`` selects scalar entries; otherwise entries are layout
        tokens and SEP gets a full-page box [1000, 1000, 1000, 1000].
        """
        trunk_value = value if value is not None else self.sep_id
        if len(ids) > max_len - 1:
            ids = ids[:max_len - 1]
        if simple:
            ids = ids + [trunk_value]
        else:
            ids = ids + [[trunk_value, 1000, 1000, 1000, 1000]]
        return ids

    def __pad(self, ids, max_len, simple=False, value=None):
        """Right-pad to max_len (scalar pads when ``simple``, else pad boxes)."""
        pad_value = value if value is not None else self.pad_id
        if len(ids) < max_len:
            if simple:
                return ids + [pad_value] * (max_len - len(ids))
            else:
                return ids + [[pad_value, 0, 0, 0, 0]] * (max_len - len(ids))
        else:
            assert len(ids) == max_len
            return ids

    def __build_span_ids(self, num_target_tokens):
        # Shared by both item builders: assign consecutive span ids; with
        # probability span_prob a span covers 2..span_len tokens, else one.
        span_ids = []
        span_id = 1
        while len(span_ids) < num_target_tokens:
            p = random.random()
            if p < self.span_prob:
                span_len = random.randint(2, self.span_len)
                span_len = min(span_len, num_target_tokens - len(span_ids))
            else:
                span_len = 1
            span_ids.extend([span_id] * span_len)
            span_id += 1
        return span_ids

    def __getitem__(self, idx):
        if self.layout_flag:
            return self.__getitem_layout__(idx)
        else:
            return self.__getitem_bert__(idx)

    def __getitem_bert__(self, idx):
        """Build one plain-id example (layout_flag=False)."""
        idx = (self.offset + idx) % len(self.features)
        feature = self.features[idx]
        source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len, simple=True)
        target_ids = self.__trunk(feature["target_ids"], self.max_target_len, simple=True)
        target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id)
        # Pseudo target: keep / randomize / mask each target token.
        pseudo_ids = []
        for tk_id in target_ids:
            p = random.random()
            if p < self.keep_prob:
                pseudo_ids.append(tk_id)
            elif p < self.keep_prob + self.random_prob:
                pseudo_ids.append(random.randint(0, self.vocab_size - 1))
            else:
                pseudo_ids.append(self.mask_id)
        num_source_tokens = len(source_ids)
        num_target_tokens = len(target_ids)
        source_ids = self.__pad(source_ids, self.max_source_len, simple=True)
        target_ids = self.__pad(target_ids, self.max_target_len, simple=True)
        pseudo_ids = self.__pad(pseudo_ids, self.max_target_len, simple=True)
        target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id)
        target_index = self.__clip_index(target_index)
        if self.span_len > 1:
            span_ids = self.__build_span_ids(num_target_tokens)
            # FIX: pad with scalar pad ids (simple=True).  The previous call
            # used the default layout padding, which appended 5-element
            # [pad, 0, 0, 0, 0] lists into a flat int list and broke tensor
            # conversion in batch_list_to_batch_tensors.
            span_ids = self.__pad(span_ids, self.max_target_len, simple=True)
            return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index
        else:
            return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index

    def __getitem_layout__(self, idx):
        """Build one layout example (entries are [id, x0, y0, x1, y1])."""
        # TODO: how to initialize the random and masked tokens' pos emb
        # Simple Solution: only mask the text
        idx = (self.offset + idx) % len(self.features)
        feature = self.features[idx]
        source_ids = self.__trunk([[self.cls_id, 0, 0, 0, 0]] + feature["source_ids"], self.max_source_len)
        target_ids = self.__trunk(feature["target_ids"], self.max_target_len)
        target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id)
        # Pseudo target: corrupted tokens get a zero box (text-only masking).
        pseudo_ids = []
        for tk_id in target_ids:
            p = random.random()
            if p < self.keep_prob:
                pseudo_ids.append(tk_id)
            elif p < self.keep_prob + self.random_prob:
                pseudo_ids.append([random.randint(0, self.vocab_size - 1)] + [0, 0, 0, 0])
            else:
                pseudo_ids.append([self.mask_id] + [0, 0, 0, 0])
        num_source_tokens = len(source_ids)
        num_target_tokens = len(target_ids)
        source_ids = self.__pad(source_ids, self.max_source_len)
        target_ids = self.__pad(target_ids, self.max_target_len)
        pseudo_ids = self.__pad(pseudo_ids, self.max_target_len)
        target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id)
        target_index = self.__clip_index(target_index)
        if self.span_len > 1:
            span_ids = self.__build_span_ids(num_target_tokens)
            # FIX: span ids are scalars, so pad them with simple=True
            # (same defect as in __getitem_bert__).
            span_ids = self.__pad(span_ids, self.max_target_len, simple=True)
            return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index
        else:
            return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index
def batch_list_to_batch_tensors(batch):
    """Collate a list of example tuples into per-field batch tensors.

    Fields that are already tensors are stacked; everything else (nested
    lists / ints) is converted to a long tensor.
    """
    return [
        torch.stack(column) if isinstance(column[0], torch.Tensor)
        else torch.tensor(column, dtype=torch.long)
        for column in zip(*batch)
    ]
def get_max_epoch_model(output_dir):
    """Return the largest epoch number for which *both* a model checkpoint
    ("model.<N>.bin") and an optimizer checkpoint ("optim.<N>.bin") exist in
    *output_dir*, or None when no complete pair is found.
    """
    fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin"))
    fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin"))
    if (not fn_model_list) or (not fn_optim_list):
        return None
    # Removed a stray no-op statement (`os.path.basename(output_dir)`) whose
    # result was discarded.
    model_epochs = {int(os.path.basename(fn).split('.')[1]) for fn in fn_model_list}
    optim_epochs = {int(os.path.basename(fn).split('.')[1]) for fn in fn_optim_list}
    # Only epochs present in both sets are resumable.
    both_set = model_epochs & optim_epochs
    return max(both_set) if both_set else None
def load_and_cache_examples(
        example_file, tokenizer, local_rank, cached_features_file, shuffle=True):
    """Read JSON-lines (src, tgt) pairs, tokenize them, and cache the features.

    Each line of *example_file* is a JSON object with "src" and "tgt" that are
    either raw strings (tokenized here) or pre-tokenized lists.  Returns a list
    of {"source_ids", "target_ids"} dicts.  *local_rank* serializes feature
    creation across distributed workers via barriers.
    """
    # Make sure only the first process in distributed training process the
    # dataset; the others will use the cache.
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if cached_features_file is not None and os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", example_file)
        examples = []
        with open(example_file, mode="r", encoding="utf-8") as reader:
            # FIX: removed a debug leftover (`if i == 100: break`) that
            # silently truncated every dataset to its first 100 lines.
            for line in reader:
                examples.append(json.loads(line))
        features = []
        for example in tqdm.tqdm(examples):
            # "src"/"tgt" may be pre-tokenized (list) or a raw string.
            if isinstance(example["src"], list):
                source_tokens = example["src"]
                target_tokens = example["tgt"]
            else:
                source_tokens = tokenizer.tokenize(example["src"])
                target_tokens = tokenizer.tokenize(example["tgt"])
            features.append({
                "source_ids": tokenizer.convert_tokens_to_ids(source_tokens),
                "target_ids": tokenizer.convert_tokens_to_ids(target_tokens),
            })
        if shuffle:
            random.shuffle(features)
        if local_rank in [-1, 0] and cached_features_file is not None:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
    # Make sure only the first process in distributed training process the
    # dataset; the others will use the cache.
    if local_rank == 0:
        torch.distributed.barrier()
    return features
def load_and_cache_line_order_examples(
        example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024,
        layout_flag=True, shuffle=True,
        src_shuffle_rate=0,
        file_info_flag=False,
):
    """Build (or load from cache) features for the line-ordering task.

    Each line of *example_path* is a JSON object with "src"/"tgt" line layouts,
    a "tgt_index" permutation, and a "bleu" score.  Every line is prefixed with
    the token id of its position string (str(i)).  With probability
    *src_shuffle_rate* the source lines are randomly permuted and the target
    indices remapped through the same permutation.  *max_src_length* and
    *layout_flag* are accepted for signature compatibility with the other
    loaders but are not used here.
    """
    # Make sure only the first process in distributed training process the
    # dataset; the others will use the cache.
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    # FIX: this branch was disabled with a hard-coded `and False` (debug
    # leftover), so a saved cache was never actually loaded.
    if cached_features_file is not None and os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset at %s", example_path)
        examples = []
        with open(example_path, 'r') as layout_reader:
            logger.info(f'Start loading {example_path}')
            for line in layout_reader:
                examples.append(json.loads(line))
        features = []
        for layout in tqdm.tqdm(examples):
            bleu = layout['bleu']
            if random.random() < src_shuffle_rate:
                # Randomly permute the source lines; the target indices must
                # be remapped through the same permutation.
                src_layout = layout['src']
                tgt_index = layout['tgt_index']
                source_length = len(src_layout)
                shuffle_index = list(range(source_length))
                random.shuffle(shuffle_index)
                shuffle_layout = ['' for _ in range(source_length)]
                for i, j in enumerate(shuffle_index):
                    # NOTE: map i-th token to j-th token
                    shuffle_layout[j] = src_layout[i]
                shuffle_target_index = [shuffle_index[i] for i in tgt_index]
                layout['tgt_index'] = shuffle_target_index
                layout['src'] = shuffle_layout
            # (Removed an unused `mask = tokenizer.mask_token_id` assignment.)
            # Prefix each line with the id of its position token str(i).
            src_ids = [tokenizer.convert_tokens_to_ids([str(tmp_i)])[:1] + src_layout
                       for tmp_i, src_layout in enumerate(layout['src'])]
            tgt_ids = [tokenizer.convert_tokens_to_ids([str(tmp_i)])[:1] + tgt_layout
                       for tmp_i, tgt_layout in enumerate(layout['tgt'])]
            tgt_index = layout['tgt_index']
            feature = {
                "source_ids": src_ids,
                "target_ids": tgt_ids,
                "target_index": tgt_index,
                'bleu': bleu
            }
            if file_info_flag:
                file_info = {'original_filename': layout['filename'], 'filename': layout['filename'],
                             'page_idx': 0}
                feature['file_info'] = file_info
            features.append(feature)
        if shuffle:
            random.shuffle(features)
        if local_rank in [-1, 0] and cached_features_file is not None:
            # Create the cache directory if missing (consistent with
            # load_and_cache_layoutlm_examples).
            cache_dir = os.path.dirname(cached_features_file)
            if cache_dir and not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
    # Make sure only the first process in distributed training process the
    # dataset; the others will use the cache.
    if local_rank == 0:
        torch.distributed.barrier()
    return features
def load_and_cache_layoutlm_examples(
        example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024,
        layout_flag=True, shuffle=True,
        src_shuffle_rate=0,
        file_info_flag=False
):
    """Build (or load from cache) LayoutLM seq2seq features.

    Each feature holds:
      - "source_ids"/"target_ids": token entries, each ``[id, x0, y0, x1, y1]``
        when ``layout_flag`` else a bare int id;
      - "target_index": for every target subword, the index of the matching
        source subword (clipped to ``max_src_length - 1``);
      - "bleu": per-example score copied from the text record (0 when absent);
      - "file_info" when ``file_info_flag`` is set.

    ``src_shuffle_rate`` is the probability of randomly permuting an example's
    source words; target indices are remapped through the same permutation.
    ``local_rank`` serializes feature creation across distributed workers.
    """
    # Make sure only the first process in distributed training process the dataset, and the others will use the cache
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if cached_features_file is not None and os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset at %s", example_path)
        examples = []
        # A directory holds paired "*text*.json" / "*layout*.json" files;
        # a single file path is paired with its "layout" counterpart by
        # replacing the first "text"/"txt" in its name.
        if os.path.isdir(example_path):
            text_files = glob.glob(f'{example_path}/*text*.json')
            layout_files = [re.sub('text|txt', 'layout', x, 1) for x in text_files]
        else:
            text_files = [example_path]
            layout_files = [re.sub('text|txt', 'layout', example_path, 1)]
        for text_file, layout_file in zip(text_files, layout_files):
            with open(text_file, mode='r', encoding='utf-8') as text_reader, \
                    open(layout_file, mode='r', encoding='utf-8') as layout_reader:
                logger.info(f'Start loading {text_file}')
                # Text and layout files are parallel: line i of each describes
                # the same example.
                for i, (text_line, layout_line) in enumerate(zip(text_reader, layout_reader)):
                    if (i + 1) % 10000 == 0:
                        logger.info(f'{i + 1} lines ...')
                    examples.append((json.loads(text_line), json.loads(layout_line)))
        features = []

        def tokenize_text_and_layout_src(_text, _layout, _layout_flag):
            # Tokenize the source words; also build a word-index -> subword
            # token-index mapping so target indices can follow subword splits.
            ret = []
            index_split = {}
            words = _text.split()
            # note: (OLD) the index should start from 1: 0-the cls token in src
            # note: (NEW) we need to remove the src embedding's CLS SEP token so we can still start from 0
            # note: (NEWER) we need to at least one blank pos for ignore index in loss function (we use sep's index)
            # NOTE: (NEWER-ER) 1 for all padding tgt index
            new_token_index = 1  # first ordinary index
            for i, (word, box) in enumerate(zip(words, _layout)):
                # Skip degenerate boxes (x1 < x0 or y1 < y0).
                if (not box[2] >= box[0]) or (not box[3] >= box[1]):
                    continue
                tokens = tokenizer.tokenize(word)
                tokens = tokenizer.convert_tokens_to_ids(tokens)
                new_token_ids = []
                for token in tokens:
                    if _layout_flag:
                        ret.append([token] + box)
                    else:
                        ret.append(token)
                    new_token_ids.append(new_token_index)
                    new_token_index += 1
                index_split[i] = new_token_ids
            return ret, index_split

        def tokenize_text_and_layout_tgt(_text, _layout, _index, _index_split, _layout_flag):
            # Tokenize the target words and map each subword to its source
            # subword index via _index_split (built by the *_src helper).
            # NOTE(review): assumes every word index in _index that survives
            # the box filter is a key of _index_split — verify for datasets
            # where src and tgt boxes can be filtered differently.
            ret = []
            ret_index = []
            words = _text.split()
            for word, box, i in zip(words, _layout, _index):
                if (not box[2] >= box[0]) or (not box[3] >= box[1]):
                    continue
                tokens = tokenizer.tokenize(word)
                tokens = tokenizer.convert_tokens_to_ids(tokens)
                for token, ii in zip(tokens, _index_split[i]):
                    if _layout_flag:
                        ret.append([token] + box)
                    else:
                        ret.append(token)
                    # Clip so the index always fits inside the source window.
                    ii = min(ii, max_src_length - 1)
                    ret_index.append(ii)
            return ret, ret_index

        for text, layout in tqdm.tqdm(examples):
            if 'bleu' in text:
                bleu = text['bleu']
            else:
                bleu = 0
            if random.random() < src_shuffle_rate:
                # Randomly permute the source words; the target indices must
                # be remapped through the same permutation.
                # DONE: the random src! here has bug! index also need shuffle
                src_text = text['src']
                src_layout = layout['src']
                tgt_index = text['tgt_index']
                src_text = src_text.split()
                source_length = len(src_text)
                shuffle_index = list(range(source_length))
                random.shuffle(shuffle_index)
                shuffle_text = ['' for _ in range(source_length)]
                shuffle_layout = ['' for _ in range(source_length)]
                for i, j in enumerate(shuffle_index):
                    # NOTE: map i-th token to j-th token
                    shuffle_text[j] = src_text[i]
                    shuffle_layout[j] = src_layout[i]
                shuffle_target_index = [shuffle_index[i] for i in tgt_index]
                text['src'] = ' '.join(shuffle_text)
                text['tgt_index'] = shuffle_target_index
                layout['src'] = shuffle_layout
            src_ids, src_index_split = tokenize_text_and_layout_src(text['src'], layout['src'],
                                                                    _layout_flag=layout_flag)
            tgt_ids, tgt_index = tokenize_text_and_layout_tgt(text['tgt'], layout['tgt'], text['tgt_index'],
                                                              src_index_split, _layout_flag=layout_flag)
            feature = {
                "source_ids": src_ids,
                "target_ids": tgt_ids,
                "target_index": tgt_index,
                'bleu': bleu
            }
            if file_info_flag:
                file_info = {'original_filename': text['original_filename'], 'filename': text['filename'], 'page_idx': text['page_idx']}
                feature['file_info'] = file_info
            features.append(feature)
        if shuffle:
            random.shuffle(features)
        if local_rank in [-1, 0] and cached_features_file is not None:
            if not os.path.exists(os.path.dirname(cached_features_file)):
                os.makedirs(os.path.dirname(cached_features_file))
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
    # Make sure only the first process in distributed training process the dataset, and the others will use the cache
    if local_rank == 0:
        torch.distributed.barrier()
    return features
def convert_src_layout_inputs_to_tokens(inputs, converter, max_src_length, layout_flag=True):
    """Map source id entries back to token strings, truncated to max_src_length.

    With ``layout_flag`` each entry is ``[id, x0, y0, x1, y1]`` and the id is
    replaced by its token while the box is kept; otherwise entries are plain
    ids and the converted token list is returned directly.
    """
    converted = []
    if layout_flag:
        for record in inputs:
            entries = record['source_ids']
            tokens = converter([entry[0] for entry in entries])
            rebuilt = [[tok] + entry[1:] for tok, entry in zip(tokens, entries)]
            converted.append(rebuilt[: max_src_length])
    else:
        for record in inputs:
            converted.append(converter(record["source_ids"])[: max_src_length])
    return converted
def convert_tgt_layout_inputs_to_tokens(inputs, converter, max_tgt_length, layout_flag=True):
    """Map target id entries back to token strings, truncated to max_tgt_length.

    With ``layout_flag`` each entry is ``[id, x0, y0, x1, y1]`` and only the
    ids are converted (boxes are dropped); otherwise entries are plain ids.
    """
    converted = []
    for record in inputs:
        ids = record["target_ids"]
        if layout_flag:
            ids = [entry[0] for entry in ids]
        converted.append(converter(ids)[: max_tgt_length])
    return converted
def get_tokens_from_src_and_index(src, index, modifier=None):
    """Look up the token at each index of *src*.

    Entries of *src* may be plain tokens or ``[token, x0, y0, x1, y1]`` lists
    (only the token is returned).  Indices are clamped to the last source
    position.  FIX: *modifier* defaulted to None but was called
    unconditionally, raising TypeError; it is now applied only when given.
    """
    result = []
    for i in index:
        if modifier is not None:
            i = modifier(i)
        # Clamp out-of-range indices to the last entry.
        i = min(i, len(src) - 1)
        entry = src[i]
        result.append(entry[0] if isinstance(entry, list) else entry)
    return result
def get_layout_from_src_and_index(src, index, modifier=None):
    """Collect the distinct layout boxes referenced by *index*, in first-seen
    order.

    Each *src* entry is ``[token, x0, y0, x1, y1]``; the box (entry[1:]) is
    kept once per distinct value.  Indices are clamped to the last source
    position.  FIX: *modifier* defaulted to None but was called
    unconditionally, raising TypeError; it is now applied only when given.
    """
    result = []
    seen = set()
    for i in index:
        if modifier is not None:
            i = modifier(i)
        # Clamp out-of-range indices to the last entry.
        i = min(i, len(src) - 1)
        layout = src[i][1:]
        # De-duplicate by value (repr) while preserving first-seen order.
        key = repr(layout)
        if key not in seen:
            result.append(layout)
            seen.add(key)
    return result
def get_everything_from_src_and_index(src, index, modifier=None):
    """Return the full *src* entry (token plus box, when present) at each
    index.

    Indices are clamped to the last source position.  FIX: *modifier*
    defaulted to None but was called unconditionally, raising TypeError; it
    is now applied only when given.
    """
    result = []
    for i in index:
        if modifier is not None:
            i = modifier(i)
        # Clamp out-of-range indices to the last entry.
        i = min(i, len(src) - 1)
        result.append(src[i])
    return result