| | from collections import defaultdict |
| |
|
| |
|
def read_conjunctions(cfg):
    """Read the conjunction-mapping file referenced by *cfg*.

    The file consists of groups separated by blank lines. The first line of
    each group is the original sentence text; every following line in the
    group is a conjunction-split variant of that sentence.

    Args:
        cfg: configuration object exposing a ``conjunctions_file`` path.

    Returns:
        tuple: ``(conj_sentences, conj2sent)`` where ``conj_sentences`` is the
        list of all conjunction-split sentences (file order) and ``conj2sent``
        maps each of them back to its original sentence.
    """
    conj2sent = {}
    with open(cfg.conjunctions_file, 'r') as fin:
        # True while the next non-blank line starts a new group (i.e. holds
        # the original sentence rather than one of its split variants).
        expect_sentence = True
        current_sent_text = ''
        for line in fin:
            if line == '\n':
                expect_sentence = True
            elif expect_sentence:
                current_sent_text = line.rstrip('\n')
                expect_sentence = False
            else:
                conj2sent[line.rstrip('\n')] = current_sent_text
    # Dicts preserve insertion order, so this is the file order of variants.
    return list(conj2sent.keys()), conj2sent
| |
|
| |
|
def print_predictions(outputs, file_path, vocab, sequence_label_domain=None):
    """print_predictions prints prediction results

    Writes one tab-separated record block per sentence: the token text first,
    then whichever annotation sections are present in that sentence's output
    dict, terminated by a blank line.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
        sequence_label_domain (str, optional): sequence label domain. Defaults to None.
    """

    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            # Only the first `seq_len` positions are real tokens; everything
            # beyond that is padding and is sliced off throughout.
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'text' in sent_output:
                print(f"Text\t{sent_output['text']}", file=fout)

            # Gold vs. predicted per-token sequence labels; only emitted when
            # both are available.
            if 'sequence_labels' in sent_output and 'sequence_label_preds' in sent_output:
                sequence_labels = [
                    vocab.get_token_from_index(true_sequence_label, sequence_label_domain)
                    for true_sequence_label in sent_output['sequence_labels'][:seq_len]
                ]
                sequence_label_preds = [
                    vocab.get_token_from_index(pred_sequence_label, sequence_label_domain)
                    for pred_sequence_label in sent_output['sequence_label_preds'][:seq_len]
                ]

                print("Sequence-Label-True\t{}".format(' '.join(sequence_labels)), file=fout)
                print("Sequence-Label-Pred\t{}".format(' '.join(sequence_label_preds)), file=fout)

            # Gold joint entity/relation label matrix, one output line per
            # token row, truncated to seq_len in both dimensions.
            if 'joint_label_matrix' in sent_output:
                for row in sent_output['joint_label_matrix'][:seq_len]:
                    print("Joint-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            # Predicted joint label matrix, same layout as above.
            if 'joint_label_preds' in sent_output:
                for row in sent_output['joint_label_preds'][:seq_len]:
                    print("Joint-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            # Gold entities; here a span is a flat (start, end) token range.
            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'span2ent')
                    assert ent != 'None', "true relation can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():

                    print("Ent-Span-Pred\t{}".format(span), file=fout)
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            # Gold relations between two spans.
            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'span2rel')
                    assert rel != 'None', "true relation can not be `None`."

                    # A trailing '<' marks a reversed relation: swap the spans.
                    # rel[:-2] then drops what is presumably a 2-character
                    # direction suffix on the label — TODO confirm the label
                    # scheme against the vocabulary.
                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            # Predicted relations; note `rel` is used directly (no vocab
            # lookup), so prediction values are already label strings.
            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():

                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            # Blank line terminates the sentence record.
            print(file=fout)
| |
|
| |
|
def print_extractions_allennlp_format(cfg, outputs, file_path, vocab):
    """Write predicted open-IE extractions in AllenNLP OIE evaluation format.

    Each emitted line is ``<sentence>\\t<arg1> ... </arg1> <rel> ... </rel>
    <arg2> ... </arg2>``. Sentences that were conjunction-split upstream are
    mapped back to their original sentence text via the conjunctions file.

    Args:
        cfg: configuration exposing ``conjunctions_file``.
        outputs (list): prediction outputs, one dict per sentence.
        file_path (str): destination file path.
        vocab (Vocabulary): vocabulary used to map token ids back to text.
    """
    conj_sentences, conj2sent = read_conjunctions(cfg)
    ext_texts = []  # deduplicates extraction strings across ALL sentences
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            extractions = {}
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # NOTE(review): the last 6 positions are dropped — presumably
            # appended special tokens; confirm against the data pipeline.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len - 6]]
            sentence = ' '.join(tokens)
            if sentence in conj_sentences:
                sentence = conj2sent[sentence]

            # Group predicted argument spans under their relation span:
            # extractions[rel_span][role] -> list of argument spans.
            # (Replaces the original bare try/except KeyError dance with
            # setdefault; behavior is identical.)
            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    if rel == '' or rel == ' ':
                        continue
                    if sent_output['all_ent_preds'][span1] == 'Relation':
                        rel_span, arg_span = span1, span2
                    else:
                        rel_span, arg_span = span2, span1
                    args = extractions.setdefault(rel_span, defaultdict(list))
                    if arg_span not in args[rel]:
                        args[rel].append(arg_span)

            # Merge relation spans whose Subject and Object argument lists are
            # identical into one combined relation span (keeping the first
            # span's argument dict).
            # NOTE(review): the original also had chain-merging branches that
            # were unreachable (their conditions were excluded by the outer
            # guard), so omitting them does not change behavior.
            to_remove_rel_spans = set()
            to_add = {}
            for rel_span1, d1 in extractions.items():
                for rel_span2, d2 in extractions.items():
                    if (rel_span1 != rel_span2
                            and rel_span1 not in to_remove_rel_spans
                            and rel_span2 not in to_remove_rel_spans
                            and d1["Subject"] == d2["Subject"]
                            and d1["Object"] == d2["Object"]):
                        to_add[rel_span1 + rel_span2] = d1
                        to_remove_rel_spans.add(rel_span1)
                        to_remove_rel_spans.add(rel_span2)
            for tm in to_remove_rel_spans:
                del extractions[tm]
            extractions.update(to_add)

            # Render each extraction and emit it unless it is a duplicate or
            # lacks a relation/subject.
            for rel_sp, d in extractions.items():
                subject_text = _join_arg_spans(tokens, d["Subject"])
                object_text = _join_arg_spans(tokens, d["Object"])
                # '[unused1]' is the placeholder token for an implied copula.
                rel_text = " ".join(
                    " ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in rel_sp).replace('[unused1]', 'is')
                ext = f'<arg1> {subject_text} </arg1> <rel> {rel_text} </rel> <arg2> {object_text} </arg2>'
                if ext not in ext_texts and (rel_text != '' and subject_text != ''):
                    print("{}\t{}".format(sentence, ext), file=fout)
                    ext_texts.append(ext)


def _join_arg_spans(tokens, spans):
    """Render an argument (a list of spans, each a tuple of (start, end)
    sub-spans) as a token string.

    With multiple spans, only the FIRST sub-span of each is used, ordered by
    start position (this mirrors the original behavior — TODO confirm it is
    intentional); with exactly one span, all of its sub-spans are joined.
    """
    if len(spans) > 1:
        ordered = [sp[0] for sp in sorted(spans, key=lambda sp: sp[0][0])]
        return " ".join(" ".join(tokens[s[0]:s[1]]) for s in ordered)
    if len(spans) == 1:
        return " ".join(" ".join(tokens[s[0]:s[1]]) for s in spans[0])
    return ""
| |
|
| |
|
def print_predictions_for_joint_decoding(outputs, file_path, vocab):
    """print_predictions prints prediction results

    Writes one tab-separated record block per sentence: token text, then the
    annotation sections present in the output dict, terminated by a blank
    line. Unlike `print_predictions`, spans here are tuples of (start, end)
    sub-spans, so entity text joins every sub-span's token range.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
    """

    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            # Only the first `seq_len` positions are real tokens; the rest is
            # padding and is sliced off throughout.
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            # Gold joint entity/relation label matrix, one line per token row,
            # truncated to seq_len in both dimensions.
            if 'joint_label_matrix' in sent_output:
                for row in sent_output['joint_label_matrix'][:seq_len]:
                    print("Joint-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            # Predicted joint label matrix, same layout as above.
            if 'joint_label_preds' in sent_output:
                for row in sent_output['joint_label_preds'][:seq_len]:
                    print("Joint-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            # Gold entities; each span is a tuple of (start, end) sub-spans.
            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    assert ent != 'None', "true relation can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join([' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            # Predicted entities; `ent` is already a label string here.
            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():

                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            # Gold relations between two multi-sub-span entities.
            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."
                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)

            # Predicted relations; `rel` is used directly (no vocab lookup).
            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():

                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)

            # Blank line terminates the sentence record.
            print(file=fout)
| |
|
| |
|
def print_predictions_for_entity_rel_decoding(outputs, file_path, vocab):
    """print_predictions prints prediction results

    Same record format as the joint-decoding printer, but entity and relation
    labels come from two separate matrices (`entity_label_preds` vs.
    `relation_label_matrix`/`relation_label_preds`).

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
    """

    with open(file_path, 'w') as fout:

        for sent_output in outputs:
            # Only the first `seq_len` positions are real tokens; the rest is
            # padding and is sliced off throughout.
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            # Predicted entity label matrix, one line per token row.
            if 'entity_label_preds' in sent_output:
                for row in sent_output['entity_label_preds'][:seq_len]:
                    print("Ent-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            # Relation matrices store label ids offset by 2 relative to the
            # `ent_rel_id` vocabulary, with 0 meaning "no relation" —
            # presumably the first two vocab ids are reserved for entity
            # labels; TODO confirm against the vocabulary construction.
            if 'relation_label_matrix' in sent_output:
                for row in sent_output['relation_label_matrix'][:seq_len]:
                    print("Rel-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item + 2, 'ent_rel_id') if item != 0 else "None" for item in row[:seq_len]])),
                          file=fout)

            if 'relation_label_preds' in sent_output:
                for row in sent_output['relation_label_preds'][:seq_len]:
                    print("Rel-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item + 2, 'ent_rel_id') if item != 0 else "None" for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            # Gold entities; each span is a tuple of (start, end) sub-spans.
            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    assert ent != 'None', "true relation can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            # Predicted entities; `ent` is already a label string here.
            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            # Gold relations between two multi-sub-span entities.
            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."

                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)
            # Predicted relations; `rel` is used directly (no vocab lookup).
            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():

                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)), file=fout)

            # Blank line terminates the sentence record.
            print(file=fout)
| |
|
def print_predictions_for_relation_decoding(outputs, file_path, vocab):
    """Write per-sentence token text plus the true and predicted relation
    label matrices, one tab-separated line per matrix row."""
    with open(file_path, 'w') as fout:
        emit = fout.write
        for record in outputs:
            # Positions beyond `seq_len` are padding and are dropped.
            length = record['seq_len']
            assert 'tokens' in record
            words = [vocab.get_token_from_index(idx, 'tokens') for idx in record['tokens'][:length]]
            emit("Token\t%s\n" % ' '.join(words))
            # Emit the gold matrix first, then the predictions, in the same
            # row-per-line layout.
            for key, tag in (('relation_label_matrix', 'Relation-Label-True'),
                             ('relation_label_preds', 'Relation-Label-Pred')):
                if key in record:
                    for row in record[key][:length]:
                        labels = ' '.join(
                            vocab.get_token_from_index(cell, 'ent_rel_id') for cell in row[:length])
                        emit("%s\t%s\n" % (tag, labels))
| |
|