| | import torch |
| | import random |
| | import json |
| | import numpy as np |
| | import pdb |
| | import os.path as osp |
| | from model import BertTokenizer |
| | import torch.distributed as dist |
| |
|
| |
|
class SeqDataset(torch.utils.data.Dataset):
    """Map-style dataset of text samples with optional aligned reference lists.

    ``chi_ref`` and ``kpi_ref``, when given, are indexed with the same index as
    ``data``; a missing reference list yields ``None`` for that slot.
    """

    def __init__(self, data, chi_ref=None, kpi_ref=None):
        self.data = data
        self.chi_ref = chi_ref
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, chi_ref_or_None, kpi_ref_or_None)`` for ``index``."""

        def lookup(refs):
            return None if refs is None else refs[index]

        return self.data[index], lookup(self.chi_ref), lookup(self.kpi_ref)
| |
|
| |
|
class OrderDataset(torch.utils.data.Dataset):
    """Map-style dataset of order samples with an optional aligned ``kpi_ref`` list."""

    def __init__(self, data, kpi_ref=None):
        self.data = data
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, kpi_ref_or_None)`` for ``index``."""
        sample = self.data[index]
        ref = None if self.kpi_ref is None else self.kpi_ref[index]
        return sample, ref
| |
|
| |
|
class KGDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns the stored KG records unchanged."""

    def __init__(self, data):
        self.data = data
        # Length is captured once at construction time and exposed as ``len``.
        self.len = len(data)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.data[index]
| |
|
| | |
| |
|
| |
|
class Collator_base(object):
    """Collate function for the sequence (MLM) pre-training task.

    Tokenizes the raw texts of a batch to ``args.maxlength`` and, unless
    ``args.only_test`` is set or ``args.use_mlm_task`` is off, masks tokens
    with the strategy selected by ``args.mask_stratege``:

    * ``'rand'`` -> :meth:`torch_mask_tokens` (random token-level masking),
    * ``'wwm'``  -> :meth:`wwm_mask_tokens` (whole-word masking driven by the
      per-sample ``chinese_ref`` tokens),
    * anything else -> :meth:`domain_mask_tokens` (not implemented yet).
    """

    def __init__(self, args, tokenizer, special_token=None):
        self.tokenizer = tokenizer
        if special_token is None:
            self.special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
        else:
            self.special_token = special_token

        self.text_maxlength = args.maxlength
        self.mlm_probability = args.mlm_probability
        self.args = args
        # ``special_token_mask`` narrows the words excluded from whole-word
        # masking; note that it overrides an explicitly passed ``special_token``.
        if self.args.special_token_mask:
            self.special_token = ['|', '[NUM]']

        if not self.args.only_test and self.args.use_mlm_task:
            if args.mask_stratege == 'rand':
                self.mask_func = self.torch_mask_tokens
            elif args.mask_stratege == 'wwm':
                if args.rank == 0:
                    print("use word-level Mask ...")
                assert args.add_special_word == 1
                self.mask_func = self.wwm_mask_tokens
            else:
                if args.rank == 0:
                    print("use token-level Mask ...")
                # Not implemented: the first collated batch raises NotImplementedError.
                self.mask_func = self.domain_mask_tokens

    def __call__(self, batch):
        """Tokenize ``batch`` and build MLM ``labels``.

        Args:
            batch: Sequence of ``(text, chinese_ref, kpi_ref)`` items, as returned
                by ``SeqDataset.__getitem__``.

        Returns:
            The tokenizer's batch encoding (``input_ids``, ``attention_mask``)
            extended with ``chinese_ref`` (always), ``kpi_ref`` (only when
            ``args.use_NumEmb``) and ``labels`` (-100 where a position is not scored).
        """
        kpi_ref = [item[2] for item in batch] if self.args.use_NumEmb else None
        chinese_ref = [item[1] for item in batch]
        texts = [item[0] for item in batch]

        encoded = self.tokenizer.batch_encode_plus(
            texts,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )
        # ``return_special_tokens_mask`` is not requested above, so this is
        # normally None and the masking functions recompute the mask from the ids.
        special_tokens_mask = encoded.pop("special_tokens_mask", None)

        encoded["chinese_ref"] = chinese_ref
        if kpi_ref is not None:
            encoded["kpi_ref"] = kpi_ref

        if not self.args.only_test and self.args.use_mlm_task:
            encoded["input_ids"], encoded["labels"] = self.mask_func(
                encoded, special_tokens_mask=special_tokens_mask
            )
        else:
            # No masking: every non-padding token is a target.
            labels = encoded["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            encoded["labels"] = labels

        return encoded

    def _mask_token_id(self):
        """Return the tokenizer's mask-token id; raise ValueError if it has no mask token."""
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        return self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

    def _apply_mlm_replacement(self, inputs, labels, masked_indices, mask_token_id):
        """Apply the 80% [MASK] / 10% random / 10% unchanged rule to ``inputs`` in place."""
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = mask_token_id

        # Half of the remaining 20% get a random token; the rest stay unchanged.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        return inputs, labels

    def torch_mask_tokens(self, inputs, special_tokens_mask=None):
        """Randomly mask tokens for MLM: 80% [MASK], 10% random, 10% original.

        Args:
            inputs: A tensor of token ids, or a mapping holding one under ``"input_ids"``.
                The id tensor is modified in place.
            special_tokens_mask: Optional 0/1 tensor of positions that must never be
                masked; recomputed from the tokenizer when ``None``.

        Returns:
            ``(input_ids, labels)`` where ``labels`` is -100 at unmasked positions.

        Raises:
            ValueError: If the tokenizer has no mask token.
        """
        if not torch.is_tensor(inputs):
            inputs = inputs["input_ids"]
        mask_token_id = self._mask_token_id()
        labels = inputs.clone()

        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100
        return self._apply_mlm_replacement(inputs, labels, masked_indices, mask_token_id)

    def wwm_mask_tokens(self, inputs, special_tokens_mask=None):
        """Whole-word masking driven by the per-sample ``chinese_ref`` tokens.

        ``special_tokens_mask`` is accepted for interface compatibility with the
        other masking functions and is not used.
        """
        ref_tokens = inputs["chinese_ref"]
        input_ids = inputs["input_ids"]
        mask_labels = [self._whole_word_mask(ref_tokens[i]) for i in range(len(input_ids))]
        # NOTE(review): the per-sample 0/1 lists are padded with tokenizer.pad_token_id;
        # this is only "not masked" when pad_token_id == 0 (true for BERT vocabularies).
        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, self.text_maxlength, pad_to_multiple_of=None)
        return self.torch_mask_tokens_4wwm(input_ids, batch_mask)

    def _is_special_word(self, token, special_words=None):
        """Return True if ``token`` matches one of ``self.special_token``, ignoring case.

        ``special_words`` may carry a precomputed set of lower-cased special tokens.
        """
        if special_words is None:
            special_words = {tok.lower() for tok in self.special_token}
        return token.lower() in special_words

    def _whole_word_mask(self, input_tokens, max_predictions=512):
        """Return 0/1 whole-word-mask labels for one sample's reference tokens.

        Tokens starting with ``##`` are grouped with the preceding token into one
        word. Words are shuffled and whole words are selected until the budget of
        ``round((num_candidates + 2) * mlm_probability)`` tokens (at least 1, at
        most ``max_predictions``) is reached.

        Args:
            input_tokens: Reference tokens of one sample (its ``chinese_ref`` entry).
            max_predictions: Upper bound on the number of masked tokens.

        Returns:
            List of length ``min(len(input_tokens), text_maxlength)`` with 1 at
            masked positions.
        """
        assert isinstance(self.tokenizer, (BertTokenizer))

        # Lower-case both sides: the configured list is upper-case ('[SEP]', ...).
        special_words = {tok.lower() for tok in self.special_token}
        cand_indexes = []
        num_candidates = 0
        for i, token in enumerate(input_tokens):
            # NOTE(review): position text_maxlength - 1 is never a candidate; confirm intended.
            if i >= self.text_maxlength - 1:
                break
            if self._is_special_word(token, special_words):
                continue
            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
            num_candidates += 1

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round((num_candidates + 2) * self.mlm_probability))))

        covered_indexes = set()
        for index_set in cand_indexes:
            if len(covered_indexes) >= num_to_predict:
                break
            # A word that would overshoot the budget is skipped; a shorter one may still fit.
            if len(covered_indexes) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            covered_indexes.update(index_set)

        return [1 if i in covered_indexes else 0 for i in range(min(len(input_tokens), self.text_maxlength))]

    def torch_mask_tokens_4wwm(self, inputs, mask_labels):
        """Mask ``inputs`` at the positions selected by ``mask_labels`` (whole-word mask).

        Special tokens and padding are never masked. Selected positions become
        [MASK] 80% of the time, a random token 10%, and stay unchanged 10%.

        Args:
            inputs: Tensor of token ids, modified in place.
            mask_labels: 0/1 tensor with the same shape as ``inputs``; modified in place.

        Returns:
            ``(input_ids, labels)`` where ``labels`` is -100 at unmasked positions.

        Raises:
            ValueError: If the tokenizer has no mask token or the shapes differ.
        """
        mask_token_id = self._mask_token_id()
        labels = inputs.clone()
        probability_matrix = mask_labels

        if tuple(probability_matrix.shape) != tuple(labels.shape):
            # Fail loudly instead of stopping a (possibly distributed) job in a debugger.
            raise ValueError(
                f"Whole-word mask shape {tuple(probability_matrix.shape)} does not match input shape "
                f"{tuple(labels.shape)} (max len {self.text_maxlength}, "
                f"pad_token_id: {self.tokenizer.pad_token_id})"
            )

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token_id is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100
        return self._apply_mlm_replacement(inputs, labels, masked_indices, mask_token_id)

    def domain_mask_tokens(self, inputs, special_tokens_mask=None):
        """Token-level domain masking — not implemented yet."""
        raise NotImplementedError(
            "domain_mask_tokens is not implemented; use mask_stratege 'rand' or 'wwm'."
        )
| |
|
| |
|
class Collator_kg(object):
    """Collate function for KG triples with in-batch negative sampling.

    Successive calls alternate between corrupting heads ("head-batch") and tails
    ("tail-batch"). Negatives are drawn from the other entities of the same batch,
    excluding any candidate that would form a triple present in ``data``.
    """

    def __init__(self, args, tokenizer, data):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        # Toggled on every call, so the first call uses "tail-batch".
        self.cross_sampling_flag = 0
        self.neg_num = args.neg_num
        self.data = data
        # Hashable view of the known triples. ``[h, r, t] in self.data`` on a list
        # is a linear scan, and it is evaluated O(batch_size ** 2) times per batch.
        self._known_triples = {tuple(triple) for triple in data}
        self.args = args

    def __call__(self, batch):
        return self.sampling(batch)

    def sampling(self, data):
        """Tokenize the positive triples and sample ``neg_num`` negatives per triple.

        Args:
            data: The batch of ``(head, relation, tail)`` triples.

        Returns:
            Dict with ``mode`` ("head-batch"/"tail-batch"), ``positive_sample``
            (tokenized heads, relations, tails), ``negative_sample`` (tokenized
            negatives, ``neg_num`` per triple in batch order) and ``neg_index``
            (for each negative, its last position in heads + relations + tails).
        """
        batch_data = {}
        neg_ent_sample = []

        self.cross_sampling_flag = 1 - self.cross_sampling_flag
        head_mode = self.cross_sampling_flag == 0
        batch_data['mode'] = "head-batch" if head_mode else "tail-batch"

        head_list = []
        rel_list = []
        tail_list = []
        for index, (head, relation, tail) in enumerate(data):
            if head_mode:
                candidates = self.find_neghead(data, index, relation, tail)
            else:
                candidates = self.find_negtail(data, index, relation, head)
            neg_ent_sample.extend(random.sample(candidates, self.neg_num))
            head_list.append(head)
            rel_list.append(relation)
            tail_list.append(tail)

        neg_ent_batch = self.batch_tokenizer(neg_ent_sample)
        head_batch = self.batch_tokenizer(head_list)
        rel_batch = self.batch_tokenizer(rel_list)
        tail_batch = self.batch_tokenizer(tail_list)

        # Later occurrences overwrite earlier ones, so each entity maps to its last index.
        ent_list = head_list + rel_list + tail_list
        ent_dict = {k: v for v, k in enumerate(ent_list)}
        neg_index = torch.tensor([ent_dict[i] for i in neg_ent_sample], dtype=torch.long)

        batch_data["positive_sample"] = (head_batch, rel_batch, tail_batch)
        batch_data['negative_sample'] = neg_ent_batch
        batch_data['neg_index'] = neg_index
        return batch_data

    def batch_tokenizer(self, input_list):
        """Tokenize ``input_list`` to fixed-length ``text_maxlength`` tensors."""
        return self.tokenizer.batch_encode_plus(
            input_list,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )

    def _fill_to_neg_num(self, candidates):
        """Repeat random picks from ``candidates`` until it holds at least ``neg_num`` items.

        Raises:
            ValueError: If more negatives are needed but there is no candidate at all
                (previously this looped forever).
        """
        if not candidates and self.neg_num > 0:
            raise ValueError(
                f"Cannot draw {self.neg_num} negative sample(s): no other entity in the batch "
                "forms an unseen triple (e.g. batch size 1). Use a larger batch or drop_last=True."
            )
        while len(candidates) < self.neg_num:
            candidates.extend(random.sample(candidates, min(self.neg_num - len(candidates), len(candidates))))
        return candidates

    def find_neghead(self, data, index, rel, ta):
        """Heads of the other batch triples that do not form a known ``(head, rel, ta)`` triple."""
        candidates = [
            head
            for i, (head, _relation, _tail) in enumerate(data)
            if i != index and (head, rel, ta) not in self._known_triples
        ]
        return self._fill_to_neg_num(candidates)

    def find_negtail(self, data, index, rel, he):
        """Tails of the other batch triples that do not form a known ``(he, rel, tail)`` triple."""
        candidates = [
            tail
            for i, (_head, _relation, tail) in enumerate(data)
            if i != index and (he, rel, tail) not in self._known_triples
        ]
        return self._fill_to_neg_num(candidates)
| |
|
| | |
| |
|
| |
|
def load_data(logger, args):
    """Load the sequence (MLM) corpus, shuffle it and split it into train/test sets.

    Reads ``{seq_data_name}_cws.json`` from ``args.data_path`` and, when enabled,
    the per-sample reference files ``{seq_data_name}_chinese_ref.json``
    (``args.use_mlm_task``) and ``{seq_data_name}_kpi_ref.json`` (``args.use_NumEmb``).
    The reference lists are reordered with the same permutation as the samples,
    so ``ref[i]`` still describes ``data[i]`` after shuffling. Previously only the
    samples were shuffled, which misaligned them with their references.

    Returns:
        ``(train_set, test_set, train_test_split)`` where ``train_test_split`` is
        the number of training samples and ``test_set`` is ``None`` when no test
        samples remain.

    Raises:
        ValueError: If a reference file's length differs from the number of samples.
    """
    data_path = args.data_path
    data_name = args.seq_data_name
    with open(osp.join(data_path, f'{data_name}_cws.json'), "r", encoding="utf-8") as fp:
        data = json.load(fp)
    if args.rank == 0:
        logger.info(f"[Start] Loading Seq dataset: [{len(data)}]...")

    # Shuffling an index list consumes the RNG exactly like shuffling ``data``
    # directly, so the resulting sample order is unchanged for a given seed.
    order = list(range(len(data)))
    random.shuffle(order)
    data = [data[i] for i in order]

    train_test_split = int(args.train_ratio * len(data))
    train_data = data[:train_test_split]
    test_data = data[train_test_split:]

    if args.use_mlm_task:
        if args.rank == 0:
            print("using the domain words .....")
        chinese_ref = _load_aligned_ref(osp.join(data_path, f'{data_name}_chinese_ref.json'), order)
        chi_ref_train = chinese_ref[:train_test_split]
        chi_ref_eval = chinese_ref[train_test_split:]
    else:
        chi_ref_train = None
        chi_ref_eval = None

    if args.use_NumEmb:
        if args.rank == 0:
            print("using the kpi and num .....")
        kpi_ref = _load_aligned_ref(osp.join(data_path, f'{data_name}_kpi_ref.json'), order)
        kpi_ref_train = kpi_ref[:train_test_split]
        kpi_ref_eval = kpi_ref[train_test_split:]
    else:
        kpi_ref_train = None
        kpi_ref_eval = None

    test_set = None
    train_set = SeqDataset(train_data, chi_ref=chi_ref_train, kpi_ref=kpi_ref_train)
    if len(test_data) > 0:
        test_set = SeqDataset(test_data, chi_ref=chi_ref_eval, kpi_ref=kpi_ref_eval)
    if args.rank == 0:
        logger.info("[End] Loading Seq dataset...")
    return train_set, test_set, train_test_split


def _load_aligned_ref(path, order):
    """Load a per-sample JSON reference list and reorder it by ``order``."""
    with open(path, 'r', encoding="utf-8") as f:
        ref = json.load(f)
    if len(ref) != len(order):
        raise ValueError(f"{path} has {len(ref)} entries but the dataset has {len(order)} samples")
    return [ref[i] for i in order]
| |
|
| | |
| |
|
| |
|
def load_data_kg(logger, args):
    """Load the KG triples from ``{args.kg_data_name}.json`` under ``args.data_path``.

    Returns:
        ``(train_set, train_data)``: a ``KGDataset`` over the parsed JSON list and
        the parsed list itself.
    """
    data_path = args.data_path
    if args.rank == 0:
        logger.info("[Start] Loading KG dataset...")

    kg_data_name = args.kg_data_name
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(osp.join(data_path, f'{kg_data_name}.json'), "r", encoding="utf-8") as fp:
        train_data = json.load(fp)

    train_set = KGDataset(train_data)
    if args.rank == 0:
        logger.info("[End] Loading KG dataset...")
    return train_set, train_data
| |
|
| |
|
| | def _torch_collate_batch(examples, tokenizer, max_length=None, pad_to_multiple_of=None): |
| | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" |
| | import numpy as np |
| | import torch |
| |
|
| | |
| | if isinstance(examples[0], (list, tuple, np.ndarray)): |
| | examples = [torch.tensor(e, dtype=torch.long) for e in examples] |
| |
|
| | length_of_first = examples[0].size(0) |
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | if tokenizer._pad_token is None: |
| | raise ValueError( |
| | "You are attempting to pad samples but the tokenizer you are using" |
| | f" ({tokenizer.__class__.__name__}) does not have a pad token." |
| | ) |
| |
|
| | |
| |
|
| | if max_length is None: |
| | pdb.set_trace() |
| | max_length = max(x.size(0) for x in examples) |
| |
|
| | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): |
| | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of |
| | result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) |
| | for i, example in enumerate(examples): |
| | if tokenizer.padding_side == "right": |
| | result[i, : example.shape[0]] = example |
| | else: |
| | result[i, -example.shape[0]:] = example |
| |
|
| | return result |
| |
|
| |
|
def load_order_data(logger, args):
    """Load the order (downstream) dataset and split it into train/test sets.

    The file ``{name}.json`` is looked up in ``args.data_path`` first and then in
    ``args.data_path/downstream_task``. ``name`` is ``args.order_test_name`` when
    it is non-empty, otherwise ``args.order_data_name``.

    Returns:
        ``(train_set, test_set, train_test_split)`` with ``train_test_split =
        int(args.train_ratio * len(data))``; ``test_set`` is ``None`` when the
        split leaves no test samples.

    Raises:
        FileNotFoundError: If the file exists in neither location.
    """
    if args.rank == 0:
        logger.info("[Start] Loading Order dataset...")

    data_path = args.data_path
    data_name = args.order_test_name if len(args.order_test_name) > 0 else args.order_data_name
    dp = osp.join(data_path, f'{data_name}.json')
    if not osp.exists(dp):
        dp = osp.join(data_path, 'downstream_task', f'{data_name}.json')
        if not osp.exists(dp):
            raise FileNotFoundError(
                f"Order dataset '{data_name}.json' not found in {data_path} or its downstream_task folder"
            )
    with open(dp, "r", encoding="utf-8") as fp:
        data = json.load(fp)

    train_test_split = int(args.train_ratio * len(data))

    # The data is treated as two halves; the first ``mid_split`` items of each
    # half form the TEST set and the remainder of both halves the train set.
    # NOTE(review): the test set therefore has about ``train_test_split`` items,
    # i.e. ``train_ratio`` acts as a test fraction here — confirm this is intended.
    mid_split = int(train_test_split / 2)
    mid = int(len(data) / 2)
    test_data = data[0: mid_split] + data[mid: mid + mid_split]
    train_data = data[mid_split: mid] + data[mid + mid_split: len(data)]

    test_set = None
    train_set = OrderDataset(train_data)
    if len(test_data) > 0:
        test_set = OrderDataset(test_data)
    if args.rank == 0:
        logger.info("[End] Loading Order dataset...")
    return train_set, test_set, train_test_split
| |
|
| |
|
class Collator_order(object):
    """Collate function for the order task.

    Each batch item is ``(sample, kpi_ref)`` where ``sample[0]`` holds at least
    ``args.order_num`` texts and ``sample[1][0]`` is the integer label code.
    """

    def __init__(self, args, tokenizer):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        self.args = args
        self.order_num = args.order_num
        # Smoothed targets used for label codes 1 and "anything else".
        self.p_label, self.n_label = smooth_BCE(args.eps)

    def _soft_label(self, code):
        """Map a label code to its target: 2 -> 1, 1 -> p_label, otherwise n_label."""
        if code == 2:
            return 1
        if code == 1:
            return self.p_label
        return self.n_label

    def __call__(self, batch):
        """Return ``(encoded_texts, float_labels)``.

        Texts are ordered position-major: all samples' first text, then all
        samples' second text, and so on up to ``order_num``.
        """
        texts = [item[0][0][pos] for pos in range(self.order_num) for item in batch]
        targets = [self._soft_label(item[0][1][0]) for item in batch]
        encoded = self.tokenizer.batch_encode_plus(
            texts,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )
        return encoded, torch.FloatTensor(targets)
| |
|
| |
|
def smooth_BCE(eps=0.1):
    """Return label-smoothed BCE targets ``(positive, negative)``.

    The positive target is ``1 - eps / 2`` and the negative target is ``eps / 2``.
    """
    negative = 0.5 * eps
    positive = 1.0 - negative
    return positive, negative
| |
|