Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| """ | |
| @Author : Bao | |
| @Date : 2020/4/14 | |
| @Desc : | |
| @Last modified by : Bao | |
| @Last modified date : 2020/8/12 | |
| """ | |
| import os | |
| import copy | |
| import logging | |
| import ujson as json | |
| import torch | |
| from tqdm import tqdm | |
| from torch.utils.data import TensorDataset | |
| import tensorflow as tf | |
| import cjjpy as cjj | |
| import sys | |
# Prefer the package-relative import; fall back to a path-based absolute
# import when this module is used outside its package (e.g. run as a script).
try:
    from ...mrc_client.answer_generator import assemble_answers_to_one
except ImportError:
    # Was a bare `except:`, which would also swallow KeyboardInterrupt and
    # unrelated errors; only import failures should trigger the fallback.
    # NOTE(review): the '...' argument to AbsParentDir mirrors the relative
    # import depth — confirm it resolves to the intended parent directory.
    sys.path.append(cjj.AbsParentDir(__file__, '...'))
    from mrc_client.answer_generator import assemble_answers_to_one

logger = logging.getLogger(__name__)
class InputExample(object):
    """Container for one claim-verification example: the claim, its evidence
    sentences, generated questions/answers, and assembled evidential
    candidates, plus optional veracity and NLI labels."""

    def __init__(self, guid, claim, evidences, questions, answers,
                 evidential, label=None, nli_labels=None):
        # Mirror every constructor argument into a same-named attribute.
        for field, value in (
            ("guid", guid),
            ("claim", claim),
            ("evidences", evidences),
            ("questions", questions),
            ("answers", answers),
            ("evidential", evidential),
            ("label", label),
            ("nli_labels", nli_labels),
        ):
            setattr(self, field, value)

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        # Deep copy so callers can mutate the result without touching us.
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return "{}\n".format(json.dumps(self.to_dict(), indent=2, sort_keys=True))
class InputFeatures(object):
    """Tensor-ready features for one example: a claim-evidence encoding
    (c_*) plus a list of per-candidate claim-answer encodings (q_*_list)."""

    def __init__(
        self,
        guid,
        c_input_ids,
        c_attention_mask,
        c_token_type_ids,
        q_input_ids_list,
        q_attention_mask_list,
        q_token_type_ids_list,
        nli_labels=None,
        label=None,
    ):
        # Mirror every constructor argument into a same-named attribute.
        for field, value in (
            ("guid", guid),
            ("c_input_ids", c_input_ids),
            ("c_attention_mask", c_attention_mask),
            ("c_token_type_ids", c_token_type_ids),
            ("q_input_ids_list", q_input_ids_list),
            ("q_attention_mask_list", q_attention_mask_list),
            ("q_token_type_ids_list", q_token_type_ids_list),
            ("nli_labels", nli_labels),
            ("label", label),
        ):
            setattr(self, field, value)

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        # Deep copy so callers can mutate the result without touching us.
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return "{}\n".format(json.dumps(self.to_dict(), indent=2, sort_keys=True))
| def _create_input_ids_from_token_ids(token_ids_a, token_ids_b, tokenizer, max_seq_length): | |
| pair = len(token_ids_b) != 0 | |
| # Truncate sequences. | |
| num_special_tokens_to_add = tokenizer.num_special_tokens_to_add(pair=pair) | |
| while len(token_ids_a) + len(token_ids_b) > max_seq_length - num_special_tokens_to_add: | |
| if len(token_ids_b) > 0: | |
| token_ids_b = token_ids_b[:-1] | |
| else: | |
| token_ids_a = token_ids_a[:-1] | |
| # Add special tokens to input_ids. | |
| input_ids = tokenizer.build_inputs_with_special_tokens(token_ids_a, token_ids_b if pair else None) | |
| # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to. | |
| attention_mask = [1] * len(input_ids) | |
| # Create token_type_ids. | |
| token_type_ids = tokenizer.create_token_type_ids_from_sequences(token_ids_a, token_ids_b if pair else None) | |
| # Pad up to the sequence length. | |
| padding_length = max_seq_length - len(input_ids) | |
| if tokenizer.padding_side == "right": | |
| input_ids = input_ids + ([tokenizer.pad_token_id] * padding_length) | |
| attention_mask = attention_mask + ([0] * padding_length) | |
| token_type_ids = token_type_ids + ([tokenizer.pad_token_type_id] * padding_length) | |
| else: | |
| input_ids = ([tokenizer.pad_token_id] * padding_length) + input_ids | |
| attention_mask = ([0] * padding_length) + attention_mask | |
| token_type_ids = ([tokenizer.pad_token_type_id] * padding_length) + token_type_ids | |
| assert len(input_ids) == max_seq_length | |
| assert len(attention_mask) == max_seq_length | |
| assert len(token_type_ids) == max_seq_length | |
| return input_ids, attention_mask, token_type_ids | |
def convert_examples_to_features(
    examples,
    tokenizer,
    max_seq1_length=256,
    max_seq2_length=128,
    verbose=True
):
    """Encode each InputExample into an InputFeatures.

    Sequence 1 pairs the sep-joined evidences (text b, placed first) with the
    claim (text a), padded/truncated to max_seq1_length. Sequence 2 pairs each
    candidate in example.evidential with the claim, one encoding of
    max_seq2_length per candidate.

    Args:
        examples: iterable of InputExample.
        tokenizer: HuggingFace-style tokenizer (encode, sep_token_id, ...).
        max_seq1_length: max length of the claim-evidence sequence.
        max_seq2_length: max length of each claim-candidate sequence.
        verbose: show a tqdm bar and log the first few encoded examples.

    Returns:
        list of InputFeatures, one per example.
    """
    features = []
    # Renamed from `iter`, which shadowed the builtin.
    example_iter = tqdm(examples, desc="Converting Examples") if verbose else examples
    for ex_index, example in enumerate(example_iter):
        encoded_outputs = {"guid": example.guid, 'label': example.label,
                           'nli_labels': example.nli_labels}

        # ****** for sequence 1 ******* #
        token_ids_a, token_ids_b = [], []
        # text a in sequence 1
        token_ids = tokenizer.encode(example.claim, add_special_tokens=False)  # encode claim
        token_ids_a.extend(token_ids)
        # text b in sequence 1: all evidences joined with sep tokens
        for evidence in example.evidences:
            token_ids = tokenizer.encode(evidence, add_special_tokens=False)  # encode evidence
            token_ids_b.extend(token_ids + [tokenizer.sep_token_id])
        # Remove last sep token in token_ids_b.
        token_ids_b = token_ids_b[:-1]
        # Reserve room for special tokens; 4 presumably matches the
        # RoBERTa-style <s> ... </s></s> ... </s> layout — TODO confirm.
        token_ids_b = token_ids_b[:max_seq1_length - len(token_ids_a) - 4]
        # premise </s> </s> hypothesis (evidence first, claim second)
        input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids(
            token_ids_b,
            token_ids_a,
            tokenizer,
            max_seq1_length,
        )
        encoded_outputs["c_input_ids"] = input_ids
        encoded_outputs["c_attention_mask"] = attention_mask
        encoded_outputs["c_token_type_ids"] = token_type_ids

        # ****** for sequence 2 ******* #
        encoded_outputs["q_input_ids_list"] = []  # m x L
        encoded_outputs["q_attention_mask_list"] = []
        encoded_outputs["q_token_type_ids_list"] = []
        for candidate in example.evidential:
            # text a in sequence 2
            token_ids_a = tokenizer.encode(example.claim, add_special_tokens=False)  # encode claim
            # text b in sequence 2
            token_ids_b = tokenizer.encode(candidate, add_special_tokens=False)  # encode candidate answer
            # premise </s> </s> hypothesis (candidate first, claim second)
            input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids(
                token_ids_b,
                token_ids_a,
                tokenizer,
                max_seq2_length,
            )
            encoded_outputs["q_input_ids_list"].append(input_ids)
            encoded_outputs["q_attention_mask_list"].append(attention_mask)
            encoded_outputs["q_token_type_ids_list"].append(token_type_ids)

        features.append(InputFeatures(**encoded_outputs))

        if ex_index < 5 and verbose:
            logger.info("*** Example ***")
            logger.info("guid: {}".format(example.guid))
            logger.info("c_input_ids: {}".format(encoded_outputs["c_input_ids"]))
            for input_ids in encoded_outputs['q_input_ids_list']:
                logger.info('q_input_ids: {}'.format(input_ids))
            logger.info("label: {}".format(example.label))
            logger.info("nli_labels: {}".format(example.nli_labels))
    return features
class DataProcessor:
    """Loads claim-verification data, converts it to features, and packs
    them into `TensorDataset`s.

    Supports on-disk caching of converted features (via torch.save/load)
    and both a file-based loader (`load_and_cache_data`) and an in-memory
    converter for prediction (`convert_inputs_to_dataset`).
    """

    def __init__(
        self,
        model_name_or_path,
        max_seq1_length,
        max_seq2_length,
        max_num_questions,
        cand_k,
        data_dir='',
        cache_dir_name='cache_check',
        overwrite_cache=False,
        mask_rate=0.
    ):
        # Model identifier; only its last path component is used in cache names.
        self.model_name_or_path = model_name_or_path
        # Max length of the claim-evidence encoding.
        self.max_seq1_length = max_seq1_length
        # Max length of each claim-candidate encoding (also the pad-row width).
        self.max_seq2_length = max_seq2_length
        # Cap on candidate encodings / NLI label rows kept per example.
        self.max_num_questions = max_num_questions
        # Number of candidates passed to assemble_answers_to_one.
        self.k = cand_k
        # Mask rate forwarded to assemble_answers_to_one.
        self.mask_rate = mask_rate
        self.data_dir = data_dir
        self.cached_data_dir = os.path.join(data_dir, cache_dir_name)
        self.overwrite_cache = overwrite_cache
        # Label-to-id mapping used by _load_line; unknown labels map to None.
        self.label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1}

    def _format_file(self, role):
        # role is a split name (e.g. "train"); the file is JSON-lines.
        return os.path.join(self.data_dir, "{}.json".format(role))

    def load_and_cache_data(self, role, tokenizer, data_tag):
        """Read `<data_dir>/<role>.json`, convert to features, and return a
        TensorDataset.

        Features are cached under `cached_data_dir`, keyed by role, model
        basename, sequence lengths, k, and data_tag. The cache is only
        written when the role name contains 'train' or 'eval'.
        """
        tf.io.gfile.makedirs(self.cached_data_dir)
        cached_file = os.path.join(
            self.cached_data_dir,
            "cached_features_{}_{}_{}_{}_{}_{}".format(
                role,
                # Last non-empty path component of the model name/path.
                list(filter(None, self.model_name_or_path.split("/"))).pop(),
                str(self.max_seq1_length),
                str(self.max_seq2_length),
                str(self.k),
                data_tag
            ),
        )
        if os.path.exists(cached_file) and not self.overwrite_cache:
            logger.info("Loading features from cached file {}".format(cached_file))
            features = torch.load(cached_file)
        else:
            examples = []
            with tf.io.gfile.GFile(self._format_file(role)) as f:
                data = f.readlines()
            for line in tqdm(data):
                sample = self._load_line(line)
                examples.append(InputExample(**sample))
            features = convert_examples_to_features(examples, tokenizer,
                                                    self.max_seq1_length, self.max_seq2_length)
            if 'train' in role or 'eval' in role:
                logger.info("Saving features into cached file {}".format(cached_file))
                torch.save(features, cached_file)
        return self._create_tensor_dataset(features, tokenizer)

    def convert_inputs_to_dataset(self, inputs, tokenizer, verbose=True):
        """Convert in-memory records (JSON strings or dicts) straight to a
        prediction TensorDataset (no labels tensor, no caching)."""
        examples = []
        for line in inputs:
            sample = self._load_line(line)
            examples.append(InputExample(**sample))
        features = convert_examples_to_features(examples, tokenizer,
                                                self.max_seq1_length, self.max_seq2_length, verbose)
        return self._create_tensor_dataset(features, tokenizer, do_predict=True)

    def _create_tensor_dataset(self, features, tokenizer, do_predict=False):
        """Stack InputFeatures into a TensorDataset.

        Claim tensors are N x L1; question tensors are N x m x L2 with
        m = max_num_questions (each example truncated/padded to m rows).
        When do_predict is False, a long-tensor of labels is appended.
        """
        all_c_input_ids = torch.tensor([f.c_input_ids for f in features], dtype=torch.long)
        all_c_attention_mask = torch.tensor([f.c_attention_mask for f in features], dtype=torch.long)
        all_c_token_type_ids = torch.tensor([f.c_token_type_ids for f in features], dtype=torch.long)

        all_q_input_ids_list = []
        all_q_attention_mask_list = []
        all_q_token_type_ids_list = []

        def _trunc_agg(self, feature, pad_token):
            # feature: m x L. Truncate to max_num_questions rows, then pad
            # with all-`pad_token` rows of width max_seq2_length.
            # (`self` is passed explicitly; this is a nested helper, not a method.)
            _input_list = [v for v in feature[:self.max_num_questions]]
            while len(_input_list) < self.max_num_questions:
                _input_list.append([pad_token] * self.max_seq2_length)
            return _input_list

        for f in features:  # N x m x L
            all_q_input_ids_list.append(_trunc_agg(self, f.q_input_ids_list, tokenizer.pad_token_id))
            all_q_attention_mask_list.append(_trunc_agg(self, f.q_attention_mask_list, 0))
            all_q_token_type_ids_list.append(_trunc_agg(self, f.q_token_type_ids_list, tokenizer.pad_token_type_id))
        all_q_input_ids_list = torch.tensor(all_q_input_ids_list, dtype=torch.long)
        all_q_attention_mask_list = torch.tensor(all_q_attention_mask_list, dtype=torch.long)
        all_q_token_type_ids_list = torch.tensor(all_q_token_type_ids_list, dtype=torch.long)

        all_nli_labels_list = []
        for f in features:
            # Truncate/pad per-question NLI label triples to max_num_questions rows.
            all_nli_labels_list.append(f.nli_labels[:self.max_num_questions]
                                       + max(0, (self.max_num_questions - len(f.nli_labels))) * [[0., 0., 0.]])
        all_nli_labels = torch.tensor(all_nli_labels_list, dtype=torch.float)

        if not do_predict:
            all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
            dataset = TensorDataset(
                all_c_input_ids, all_c_attention_mask, all_c_token_type_ids,
                all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list,
                all_nli_labels, all_labels,
            )
        else:
            dataset = TensorDataset(
                all_c_input_ids, all_c_attention_mask, all_c_token_type_ids,
                all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list,
                all_nli_labels,
            )
        return dataset

    def _load_line(self, line):
        """Parse one raw record (JSON string or dict) into InputExample kwargs."""
        if isinstance(line, str):
            line = json.loads(line)
        guid = line["id"]
        claim = line["claim"]
        # TODO: hack no evidence situation
        evidences = line["evidence"] if len(line['evidence']) > 0 else ['no idea'] * 5
        questions = line["questions"]
        answers = line["answers"]
        # Merge the k candidate answers into one evidential string per question.
        evidential = assemble_answers_to_one(line, self.k, mask_rate=self.mask_rate)['evidential_assembled']
        label = line.get("label", None)
        nli_labels = line.get('nli_labels', [[0., 0., 0.]] * len(questions))
        for i, e in enumerate(evidential):
            if '<mask>' in e:
                # Masked candidates carry no usable NLI signal; zero it out.
                nli_labels[i] = [0., 0., 0.]
        answers = [v[0] for v in answers]  # k = 1
        # Unknown/missing labels map to None (prediction mode).
        label = self.label2id.get(label)
        sample = {
            "guid": guid,
            "claim": claim,
            "evidences": evidences,
            "questions": questions,
            "answers": answers,
            "evidential": evidential,  # already assembled.
            "label": label,
            'nli_labels': nli_labels
        }
        return sample