Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| This script computes smatch score between two AMRs. | |
| For detailed description of smatch, see http://www.isi.edu/natural-language/amr/smatch-13.pdf | |
| """ | |
| from __future__ import division | |
| from __future__ import print_function | |
| try: | |
| import smatch.amr | |
| except: | |
| import amr | |
| import os | |
| import random | |
| import sys | |
| # total number of iteration in smatch computation | |
| iteration_num = 5 | |
| # verbose output switch. | |
| # Default false (no verbose output) | |
| verbose = False | |
| veryVerbose = False | |
| # single score output switch. | |
| # Default true (compute a single score for all AMRs in two files) | |
| single_score = True | |
| # precision and recall output switch. | |
| # Default false (do not output precision and recall, just output F score) | |
| pr_flag = False | |
| # Error log location | |
| ERROR_LOG = sys.stderr | |
| # Debug log location | |
| DEBUG_LOG = sys.stderr | |
| # dictionary to save pre-computed node mapping and its resulting triple match count | |
| # key: tuples of node mapping | |
| # value: the matching triple count | |
| match_triple_dict = {} | |
| def build_arg_parser(): | |
| """ | |
| Build an argument parser using argparse. Use it when python version is 2.7 or later. | |
| """ | |
| parser = argparse.ArgumentParser(description="Smatch calculator -- arguments") | |
| parser.add_argument('-f', nargs=2, required=True, type=argparse.FileType('r', encoding="utf-8"), | |
| help='Two files containing AMR pairs. AMRs in each file are separated by a single blank line') | |
| parser.add_argument('-r', type=int, default=4, help='Restart number (Default:4)') | |
| parser.add_argument('--significant', type=int, default=2, help='significant digits to output (default: 2)') | |
| parser.add_argument('-v', action='store_true', help='Verbose output (Default:false)') | |
| parser.add_argument('--vv', action='store_true', help='Very Verbose output (Default:false)') | |
| parser.add_argument('--ms', action='store_true', default=False, | |
| help='Output multiple scores (one AMR pair a score)' | |
| 'instead of a single document-level smatch score (Default: false)') | |
| parser.add_argument('--pr', action='store_true', default=False, | |
| help="Output precision and recall as well as the f-score. Default: false") | |
| parser.add_argument('--justinstance', action='store_true', default=False, | |
| help="just pay attention to matching instances") | |
| parser.add_argument('--justattribute', action='store_true', default=False, | |
| help="just pay attention to matching attributes") | |
| parser.add_argument('--justrelation', action='store_true', default=False, | |
| help="just pay attention to matching relations") | |
| return parser | |
| def build_arg_parser2(): | |
| """ | |
| Build an argument parser using optparse. Use it when python version is 2.5 or 2.6. | |
| """ | |
| usage_str = "Smatch calculator -- arguments" | |
| parser = optparse.OptionParser(usage=usage_str) | |
| parser.add_option("-f", "--files", nargs=2, dest="f", type="string", | |
| help='Two files containing AMR pairs. AMRs in each file are ' \ | |
| 'separated by a single blank line. This option is required.') | |
| parser.add_option("-r", "--restart", dest="r", type="int", help='Restart number (Default: 4)') | |
| parser.add_option('--significant', dest="significant", type="int", default=2, | |
| help='significant digits to output (default: 2)') | |
| parser.add_option("-v", "--verbose", action='store_true', dest="v", help='Verbose output (Default:False)') | |
| parser.add_option("--vv", "--veryverbose", action='store_true', dest="vv", | |
| help='Very Verbose output (Default:False)') | |
| parser.add_option("--ms", "--multiple_score", action='store_true', dest="ms", | |
| help='Output multiple scores (one AMR pair a score) instead of ' \ | |
| 'a single document-level smatch score (Default: False)') | |
| parser.add_option('--pr', "--precision_recall", action='store_true', dest="pr", | |
| help="Output precision and recall as well as the f-score. Default: false") | |
| parser.add_option('--justinstance', action='store_true', default=False, | |
| help="just pay attention to matching instances") | |
| parser.add_option('--justattribute', action='store_true', default=False, | |
| help="just pay attention to matching attributes") | |
| parser.add_option('--justrelation', action='store_true', default=False, | |
| help="just pay attention to matching relations") | |
| parser.set_defaults(r=4, v=False, ms=False, pr=False) | |
| return parser | |
| def get_best_match(instance1, attribute1, relation1, | |
| instance2, attribute2, relation2, | |
| prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True): | |
| """ | |
| Get the highest triple match number between two sets of triples via hill-climbing. | |
| Arguments: | |
| instance1: instance triples of AMR 1 ("instance", node name, node value) | |
| attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value) | |
| relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name) | |
| instance2: instance triples of AMR 2 ("instance", node name, node value) | |
| attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value) | |
| relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name) | |
| prefix1: prefix label for AMR 1 | |
| prefix2: prefix label for AMR 2 | |
| Returns: | |
| best_match: the node mapping that results in the highest triple matching number | |
| best_match_num: the highest triple matching number | |
| """ | |
| # Compute candidate pool - all possible node match candidates. | |
| # In the hill-climbing, we only consider candidate in this pool to save computing time. | |
| # weight_dict is a dictionary that maps a pair of node | |
| (candidate_mappings, weight_dict) = compute_pool(instance1, attribute1, relation1, | |
| instance2, attribute2, relation2, | |
| prefix1, prefix2, doinstance=doinstance, doattribute=doattribute, | |
| dorelation=dorelation) | |
| if veryVerbose: | |
| print("Candidate mappings:", file=DEBUG_LOG) | |
| print(candidate_mappings, file=DEBUG_LOG) | |
| print("Weight dictionary", file=DEBUG_LOG) | |
| print(weight_dict, file=DEBUG_LOG) | |
| best_match_num = 0 | |
| # initialize best match mapping | |
| # the ith entry is the node index in AMR 2 which maps to the ith node in AMR 1 | |
| best_mapping = [-1] * len(instance1) | |
| for i in range(iteration_num): | |
| if veryVerbose: | |
| print("Iteration", i, file=DEBUG_LOG) | |
| if i == 0: | |
| # smart initialization used for the first round | |
| cur_mapping = smart_init_mapping(candidate_mappings, instance1, instance2) | |
| else: | |
| # random initialization for the other round | |
| cur_mapping = random_init_mapping(candidate_mappings) | |
| # compute current triple match number | |
| match_num = compute_match(cur_mapping, weight_dict) | |
| if veryVerbose: | |
| print("Node mapping at start", cur_mapping, file=DEBUG_LOG) | |
| print("Triple match number at start:", match_num, file=DEBUG_LOG) | |
| while True: | |
| # get best gain | |
| (gain, new_mapping) = get_best_gain(cur_mapping, candidate_mappings, weight_dict, | |
| len(instance2), match_num) | |
| if veryVerbose: | |
| print("Gain after the hill-climbing", gain, file=DEBUG_LOG) | |
| # hill-climbing until there will be no gain for new node mapping | |
| if gain <= 0: | |
| break | |
| # otherwise update match_num and mapping | |
| match_num += gain | |
| cur_mapping = new_mapping[:] | |
| if veryVerbose: | |
| print("Update triple match number to:", match_num, file=DEBUG_LOG) | |
| print("Current mapping:", cur_mapping, file=DEBUG_LOG) | |
| if match_num > best_match_num: | |
| best_mapping = cur_mapping[:] | |
| best_match_num = match_num | |
| return best_mapping, best_match_num | |
| def normalize(item): | |
| """ | |
| lowercase and remove quote signifiers from items that are about to be compared | |
| """ | |
| item = item.rstrip("¦") | |
| return item.lower().rstrip('_') | |
| def compute_pool(instance1, attribute1, relation1, | |
| instance2, attribute2, relation2, | |
| prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True): | |
| """ | |
| compute all possible node mapping candidates and their weights (the triple matching number gain resulting from | |
| mapping one node in AMR 1 to another node in AMR2) | |
| Arguments: | |
| instance1: instance triples of AMR 1 | |
| attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value) | |
| relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name) | |
| instance2: instance triples of AMR 2 | |
| attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value) | |
| relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name | |
| prefix1: prefix label for AMR 1 | |
| prefix2: prefix label for AMR 2 | |
| Returns: | |
| candidate_mapping: a list of candidate nodes. | |
| The ith element contains the node indices (in AMR 2) the ith node (in AMR 1) can map to. | |
| (resulting in non-zero triple match) | |
| weight_dict: a dictionary which contains the matching triple number for every pair of node mapping. The key | |
| is a node pair. The value is another dictionary. key {-1} is triple match resulting from this node | |
| pair alone (instance triples and attribute triples), and other keys are node pairs that can result | |
| in relation triple match together with the first node pair. | |
| """ | |
| candidate_mapping = [] | |
| weight_dict = {} | |
| for instance1_item in instance1: | |
| # each candidate mapping is a set of node indices | |
| candidate_mapping.append(set()) | |
| if doinstance: | |
| for instance2_item in instance2: | |
| # if both triples are instance triples and have the same value | |
| if normalize(instance1_item[0]) == normalize(instance2_item[0]) and \ | |
| normalize(instance1_item[2]) == normalize(instance2_item[2]): | |
| # get node index by stripping the prefix | |
| node1_index = int(instance1_item[1][len(prefix1):]) | |
| node2_index = int(instance2_item[1][len(prefix2):]) | |
| candidate_mapping[node1_index].add(node2_index) | |
| node_pair = (node1_index, node2_index) | |
| # use -1 as key in weight_dict for instance triples and attribute triples | |
| if node_pair in weight_dict: | |
| weight_dict[node_pair][-1] += 1 | |
| else: | |
| weight_dict[node_pair] = {} | |
| weight_dict[node_pair][-1] = 1 | |
| if doattribute: | |
| for attribute1_item in attribute1: | |
| for attribute2_item in attribute2: | |
| # if both attribute relation triple have the same relation name and value | |
| if normalize(attribute1_item[0]) == normalize(attribute2_item[0]) \ | |
| and normalize(attribute1_item[2]) == normalize(attribute2_item[2]): | |
| node1_index = int(attribute1_item[1][len(prefix1):]) | |
| node2_index = int(attribute2_item[1][len(prefix2):]) | |
| candidate_mapping[node1_index].add(node2_index) | |
| node_pair = (node1_index, node2_index) | |
| # use -1 as key in weight_dict for instance triples and attribute triples | |
| if node_pair in weight_dict: | |
| weight_dict[node_pair][-1] += 1 | |
| else: | |
| weight_dict[node_pair] = {} | |
| weight_dict[node_pair][-1] = 1 | |
| if dorelation: | |
| for relation1_item in relation1: | |
| for relation2_item in relation2: | |
| # if both relation share the same name | |
| if normalize(relation1_item[0]) == normalize(relation2_item[0]): | |
| node1_index_amr1 = int(relation1_item[1][len(prefix1):]) | |
| node1_index_amr2 = int(relation2_item[1][len(prefix2):]) | |
| node2_index_amr1 = int(relation1_item[2][len(prefix1):]) | |
| node2_index_amr2 = int(relation2_item[2][len(prefix2):]) | |
| # add mapping between two nodes | |
| candidate_mapping[node1_index_amr1].add(node1_index_amr2) | |
| candidate_mapping[node2_index_amr1].add(node2_index_amr2) | |
| node_pair1 = (node1_index_amr1, node1_index_amr2) | |
| node_pair2 = (node2_index_amr1, node2_index_amr2) | |
| if node_pair2 != node_pair1: | |
| # update weight_dict weight. Note that we need to update both entries for future search | |
| # i.e weight_dict[node_pair1][node_pair2] | |
| # weight_dict[node_pair2][node_pair1] | |
| if node1_index_amr1 > node2_index_amr1: | |
| # swap node_pair1 and node_pair2 | |
| node_pair1 = (node2_index_amr1, node2_index_amr2) | |
| node_pair2 = (node1_index_amr1, node1_index_amr2) | |
| if node_pair1 in weight_dict: | |
| if node_pair2 in weight_dict[node_pair1]: | |
| weight_dict[node_pair1][node_pair2] += 1 | |
| else: | |
| weight_dict[node_pair1][node_pair2] = 1 | |
| else: | |
| weight_dict[node_pair1] = {-1: 0, node_pair2: 1} | |
| if node_pair2 in weight_dict: | |
| if node_pair1 in weight_dict[node_pair2]: | |
| weight_dict[node_pair2][node_pair1] += 1 | |
| else: | |
| weight_dict[node_pair2][node_pair1] = 1 | |
| else: | |
| weight_dict[node_pair2] = {-1: 0, node_pair1: 1} | |
| else: | |
| # two node pairs are the same. So we only update weight_dict once. | |
| # this generally should not happen. | |
| if node_pair1 in weight_dict: | |
| weight_dict[node_pair1][-1] += 1 | |
| else: | |
| weight_dict[node_pair1] = {-1: 1} | |
| return candidate_mapping, weight_dict | |
| def smart_init_mapping(candidate_mapping, instance1, instance2): | |
| """ | |
| Initialize mapping based on the concept mapping (smart initialization) | |
| Arguments: | |
| candidate_mapping: candidate node match list | |
| instance1: instance triples of AMR 1 | |
| instance2: instance triples of AMR 2 | |
| Returns: | |
| initialized node mapping between two AMRs | |
| """ | |
| random.seed() | |
| matched_dict = {} | |
| result = [] | |
| # list to store node indices that have no concept match | |
| no_word_match = [] | |
| for i, candidates in enumerate(candidate_mapping): | |
| if not candidates: | |
| # no possible mapping | |
| result.append(-1) | |
| continue | |
| # node value in instance triples of AMR 1 | |
| value1 = instance1[i][2] | |
| for node_index in candidates: | |
| value2 = instance2[node_index][2] | |
| # find the first instance triple match in the candidates | |
| # instance triple match is having the same concept value | |
| if value1 == value2: | |
| if node_index not in matched_dict: | |
| result.append(node_index) | |
| matched_dict[node_index] = 1 | |
| break | |
| if len(result) == i: | |
| no_word_match.append(i) | |
| result.append(-1) | |
| # if no concept match, generate a random mapping | |
| for i in no_word_match: | |
| candidates = list(candidate_mapping[i]) | |
| while candidates: | |
| # get a random node index from candidates | |
| rid = random.randint(0, len(candidates) - 1) | |
| candidate = candidates[rid] | |
| if candidate in matched_dict: | |
| candidates.pop(rid) | |
| else: | |
| matched_dict[candidate] = 1 | |
| result[i] = candidate | |
| break | |
| return result | |
| def random_init_mapping(candidate_mapping): | |
| """ | |
| Generate a random node mapping. | |
| Args: | |
| candidate_mapping: candidate_mapping: candidate node match list | |
| Returns: | |
| randomly-generated node mapping between two AMRs | |
| """ | |
| # if needed, a fixed seed could be passed here to generate same random (to help debugging) | |
| random.seed() | |
| matched_dict = {} | |
| result = [] | |
| for c in candidate_mapping: | |
| candidates = list(c) | |
| if not candidates: | |
| # -1 indicates no possible mapping | |
| result.append(-1) | |
| continue | |
| found = False | |
| while candidates: | |
| # randomly generate an index in [0, length of candidates) | |
| rid = random.randint(0, len(candidates) - 1) | |
| candidate = candidates[rid] | |
| # check if it has already been matched | |
| if candidate in matched_dict: | |
| candidates.pop(rid) | |
| else: | |
| matched_dict[candidate] = 1 | |
| result.append(candidate) | |
| found = True | |
| break | |
| if not found: | |
| result.append(-1) | |
| return result | |
| def compute_match(mapping, weight_dict): | |
| """ | |
| Given a node mapping, compute match number based on weight_dict. | |
| Args: | |
| mappings: a list of node index in AMR 2. The ith element (value j) means node i in AMR 1 maps to node j in AMR 2. | |
| Returns: | |
| matching triple number | |
| Complexity: O(m*n) , m is the node number of AMR 1, n is the node number of AMR 2 | |
| """ | |
| # If this mapping has been investigated before, retrieve the value instead of re-computing. | |
| if veryVerbose: | |
| print("Computing match for mapping", file=DEBUG_LOG) | |
| print(mapping, file=DEBUG_LOG) | |
| if tuple(mapping) in match_triple_dict: | |
| if veryVerbose: | |
| print("saved value", match_triple_dict[tuple(mapping)], file=DEBUG_LOG) | |
| return match_triple_dict[tuple(mapping)] | |
| match_num = 0 | |
| # i is node index in AMR 1, m is node index in AMR 2 | |
| for i, m in enumerate(mapping): | |
| if m == -1: | |
| # no node maps to this node | |
| continue | |
| # node i in AMR 1 maps to node m in AMR 2 | |
| current_node_pair = (i, m) | |
| if current_node_pair not in weight_dict: | |
| continue | |
| if veryVerbose: | |
| print("node_pair", current_node_pair, file=DEBUG_LOG) | |
| for key in weight_dict[current_node_pair]: | |
| if key == -1: | |
| # matching triple resulting from instance/attribute triples | |
| match_num += weight_dict[current_node_pair][key] | |
| if veryVerbose: | |
| print("instance/attribute match", weight_dict[current_node_pair][key], file=DEBUG_LOG) | |
| # only consider node index larger than i to avoid duplicates | |
| # as we store both weight_dict[node_pair1][node_pair2] and | |
| # weight_dict[node_pair2][node_pair1] for a relation | |
| elif key[0] < i: | |
| continue | |
| elif mapping[key[0]] == key[1]: | |
| match_num += weight_dict[current_node_pair][key] | |
| if veryVerbose: | |
| print("relation match with", key, weight_dict[current_node_pair][key], file=DEBUG_LOG) | |
| if veryVerbose: | |
| print("match computing complete, result:", match_num, file=DEBUG_LOG) | |
| # update match_triple_dict | |
| match_triple_dict[tuple(mapping)] = match_num | |
| return match_num | |
| def move_gain(mapping, node_id, old_id, new_id, weight_dict, match_num): | |
| """ | |
| Compute the triple match number gain from the move operation | |
| Arguments: | |
| mapping: current node mapping | |
| node_id: remapped node in AMR 1 | |
| old_id: original node id in AMR 2 to which node_id is mapped | |
| new_id: new node in to which node_id is mapped | |
| weight_dict: weight dictionary | |
| match_num: the original triple matching number | |
| Returns: | |
| the triple match gain number (might be negative) | |
| """ | |
| # new node mapping after moving | |
| new_mapping = (node_id, new_id) | |
| # node mapping before moving | |
| old_mapping = (node_id, old_id) | |
| # new nodes mapping list (all node pairs) | |
| new_mapping_list = mapping[:] | |
| new_mapping_list[node_id] = new_id | |
| # if this mapping is already been investigated, use saved one to avoid duplicate computing | |
| if tuple(new_mapping_list) in match_triple_dict: | |
| return match_triple_dict[tuple(new_mapping_list)] - match_num | |
| gain = 0 | |
| # add the triple match incurred by new_mapping to gain | |
| if new_mapping in weight_dict: | |
| for key in weight_dict[new_mapping]: | |
| if key == -1: | |
| # instance/attribute triple match | |
| gain += weight_dict[new_mapping][-1] | |
| elif new_mapping_list[key[0]] == key[1]: | |
| # relation gain incurred by new_mapping and another node pair in new_mapping_list | |
| gain += weight_dict[new_mapping][key] | |
| # deduct the triple match incurred by old_mapping from gain | |
| if old_mapping in weight_dict: | |
| for k in weight_dict[old_mapping]: | |
| if k == -1: | |
| gain -= weight_dict[old_mapping][-1] | |
| elif mapping[k[0]] == k[1]: | |
| gain -= weight_dict[old_mapping][k] | |
| # update match number dictionary | |
| match_triple_dict[tuple(new_mapping_list)] = match_num + gain | |
| return gain | |
| def swap_gain(mapping, node_id1, mapping_id1, node_id2, mapping_id2, weight_dict, match_num): | |
| """ | |
| Compute the triple match number gain from the swapping | |
| Arguments: | |
| mapping: current node mapping list | |
| node_id1: node 1 index in AMR 1 | |
| mapping_id1: the node index in AMR 2 node 1 maps to (in the current mapping) | |
| node_id2: node 2 index in AMR 1 | |
| mapping_id2: the node index in AMR 2 node 2 maps to (in the current mapping) | |
| weight_dict: weight dictionary | |
| match_num: the original matching triple number | |
| Returns: | |
| the gain number (might be negative) | |
| """ | |
| new_mapping_list = mapping[:] | |
| # Before swapping, node_id1 maps to mapping_id1, and node_id2 maps to mapping_id2 | |
| # After swapping, node_id1 maps to mapping_id2 and node_id2 maps to mapping_id1 | |
| new_mapping_list[node_id1] = mapping_id2 | |
| new_mapping_list[node_id2] = mapping_id1 | |
| if tuple(new_mapping_list) in match_triple_dict: | |
| return match_triple_dict[tuple(new_mapping_list)] - match_num | |
| gain = 0 | |
| new_mapping1 = (node_id1, mapping_id2) | |
| new_mapping2 = (node_id2, mapping_id1) | |
| old_mapping1 = (node_id1, mapping_id1) | |
| old_mapping2 = (node_id2, mapping_id2) | |
| if node_id1 > node_id2: | |
| new_mapping2 = (node_id1, mapping_id2) | |
| new_mapping1 = (node_id2, mapping_id1) | |
| old_mapping1 = (node_id2, mapping_id2) | |
| old_mapping2 = (node_id1, mapping_id1) | |
| if new_mapping1 in weight_dict: | |
| for key in weight_dict[new_mapping1]: | |
| if key == -1: | |
| gain += weight_dict[new_mapping1][-1] | |
| elif new_mapping_list[key[0]] == key[1]: | |
| gain += weight_dict[new_mapping1][key] | |
| if new_mapping2 in weight_dict: | |
| for key in weight_dict[new_mapping2]: | |
| if key == -1: | |
| gain += weight_dict[new_mapping2][-1] | |
| # to avoid duplicate | |
| elif key[0] == node_id1: | |
| continue | |
| elif new_mapping_list[key[0]] == key[1]: | |
| gain += weight_dict[new_mapping2][key] | |
| if old_mapping1 in weight_dict: | |
| for key in weight_dict[old_mapping1]: | |
| if key == -1: | |
| gain -= weight_dict[old_mapping1][-1] | |
| elif mapping[key[0]] == key[1]: | |
| gain -= weight_dict[old_mapping1][key] | |
| if old_mapping2 in weight_dict: | |
| for key in weight_dict[old_mapping2]: | |
| if key == -1: | |
| gain -= weight_dict[old_mapping2][-1] | |
| # to avoid duplicate | |
| elif key[0] == node_id1: | |
| continue | |
| elif mapping[key[0]] == key[1]: | |
| gain -= weight_dict[old_mapping2][key] | |
| match_triple_dict[tuple(new_mapping_list)] = match_num + gain | |
| return gain | |
| def get_best_gain(mapping, candidate_mappings, weight_dict, instance_len, cur_match_num): | |
| """ | |
| Hill-climbing method to return the best gain swap/move can get | |
| Arguments: | |
| mapping: current node mapping | |
| candidate_mappings: the candidates mapping list | |
| weight_dict: the weight dictionary | |
| instance_len: the number of the nodes in AMR 2 | |
| cur_match_num: current triple match number | |
| Returns: | |
| the best gain we can get via swap/move operation | |
| """ | |
| largest_gain = 0 | |
| # True: using swap; False: using move | |
| use_swap = True | |
| # the node to be moved/swapped | |
| node1 = None | |
| # store the other node affected. In swap, this other node is the node swapping with node1. In move, this other | |
| # node is the node node1 will move to. | |
| node2 = None | |
| # unmatched nodes in AMR 2 | |
| unmatched = set(range(instance_len)) | |
| # exclude nodes in current mapping | |
| # get unmatched nodes | |
| for nid in mapping: | |
| if nid in unmatched: | |
| unmatched.remove(nid) | |
| for i, nid in enumerate(mapping): | |
| # current node i in AMR 1 maps to node nid in AMR 2 | |
| for nm in unmatched: | |
| if nm in candidate_mappings[i]: | |
| # remap i to another unmatched node (move) | |
| # (i, m) -> (i, nm) | |
| if veryVerbose: | |
| print("Remap node", i, "from ", nid, "to", nm, file=DEBUG_LOG) | |
| mv_gain = move_gain(mapping, i, nid, nm, weight_dict, cur_match_num) | |
| if veryVerbose: | |
| print("Move gain:", mv_gain, file=DEBUG_LOG) | |
| new_mapping = mapping[:] | |
| new_mapping[i] = nm | |
| new_match_num = compute_match(new_mapping, weight_dict) | |
| if new_match_num != cur_match_num + mv_gain: | |
| print(mapping, new_mapping, file=ERROR_LOG) | |
| print("Inconsistency in computing: move gain", cur_match_num, mv_gain, new_match_num, | |
| file=ERROR_LOG) | |
| if mv_gain > largest_gain: | |
| largest_gain = mv_gain | |
| node1 = i | |
| node2 = nm | |
| use_swap = False | |
| # compute swap gain | |
| for i, m in enumerate(mapping): | |
| for j in range(i + 1, len(mapping)): | |
| m2 = mapping[j] | |
| # swap operation (i, m) (j, m2) -> (i, m2) (j, m) | |
| # j starts from i+1, to avoid duplicate swap | |
| if veryVerbose: | |
| print("Swap node", i, "and", j, file=DEBUG_LOG) | |
| print("Before swapping:", i, "-", m, ",", j, "-", m2, file=DEBUG_LOG) | |
| print(mapping, file=DEBUG_LOG) | |
| print("After swapping:", i, "-", m2, ",", j, "-", m, file=DEBUG_LOG) | |
| sw_gain = swap_gain(mapping, i, m, j, m2, weight_dict, cur_match_num) | |
| if veryVerbose: | |
| print("Swap gain:", sw_gain, file=DEBUG_LOG) | |
| new_mapping = mapping[:] | |
| new_mapping[i] = m2 | |
| new_mapping[j] = m | |
| print(new_mapping, file=DEBUG_LOG) | |
| new_match_num = compute_match(new_mapping, weight_dict) | |
| if new_match_num != cur_match_num + sw_gain: | |
| print(mapping, new_mapping, file=ERROR_LOG) | |
| print("Inconsistency in computing: swap gain", cur_match_num, sw_gain, new_match_num, | |
| file=ERROR_LOG) | |
| if sw_gain > largest_gain: | |
| largest_gain = sw_gain | |
| node1 = i | |
| node2 = j | |
| use_swap = True | |
| # generate a new mapping based on swap/move | |
| cur_mapping = mapping[:] | |
| if node1 is not None: | |
| if use_swap: | |
| if veryVerbose: | |
| print("Use swap gain", file=DEBUG_LOG) | |
| temp = cur_mapping[node1] | |
| cur_mapping[node1] = cur_mapping[node2] | |
| cur_mapping[node2] = temp | |
| else: | |
| if veryVerbose: | |
| print("Use move gain", file=DEBUG_LOG) | |
| cur_mapping[node1] = node2 | |
| else: | |
| if veryVerbose: | |
| print("no move/swap gain found", file=DEBUG_LOG) | |
| if veryVerbose: | |
| print("Original mapping", mapping, file=DEBUG_LOG) | |
| print("Current mapping", cur_mapping, file=DEBUG_LOG) | |
| return largest_gain, cur_mapping | |
| def print_alignment(mapping, instance1, instance2): | |
| """ | |
| print the alignment based on a node mapping | |
| Args: | |
| mapping: current node mapping list | |
| instance1: nodes of AMR 1 | |
| instance2: nodes of AMR 2 | |
| """ | |
| result = [] | |
| for instance1_item, m in zip(instance1, mapping): | |
| r = instance1_item[1] + "(" + instance1_item[2] + ")" | |
| if m == -1: | |
| r += "-Null" | |
| else: | |
| instance2_item = instance2[m] | |
| r += "-" + instance2_item[1] + "(" + instance2_item[2] + ")" | |
| result.append(r) | |
| return " ".join(result) | |
| def compute_f(match_num, test_num, gold_num): | |
| """ | |
| Compute the f-score based on the matching triple number, | |
| triple number of AMR set 1, | |
| triple number of AMR set 2 | |
| Args: | |
| match_num: matching triple number | |
| test_num: triple number of AMR 1 (test file) | |
| gold_num: triple number of AMR 2 (gold file) | |
| Returns: | |
| precision: match_num/test_num | |
| recall: match_num/gold_num | |
| f_score: 2*precision*recall/(precision+recall) | |
| """ | |
| if test_num == 0 or gold_num == 0: | |
| return 0.00, 0.00, 0.00 | |
| precision = float(match_num) / float(test_num) | |
| recall = float(match_num) / float(gold_num) | |
| if (precision + recall) != 0: | |
| f_score = 2 * precision * recall / (precision + recall) | |
| if veryVerbose: | |
| print("F-score:", f_score, file=DEBUG_LOG) | |
| return precision, recall, f_score | |
| else: | |
| if veryVerbose: | |
| print("F-score:", "0.0", file=DEBUG_LOG) | |
| return precision, recall, 0.00 | |
| def generate_amr_lines(f1, f2): | |
| """ | |
| Read one AMR line at a time from each file handle | |
| :param f1: file handle (or any iterable of strings) to read AMR 1 lines from | |
| :param f2: file handle (or any iterable of strings) to read AMR 2 lines from | |
| :return: generator of cur_amr1, cur_amr2 pairs: one-line AMR strings | |
| """ | |
| while True: | |
| cur_amr1 = amr.AMR.get_amr_line(f1) | |
| cur_amr2 = amr.AMR.get_amr_line(f2) | |
| if not cur_amr1 and not cur_amr2: | |
| pass | |
| elif not cur_amr1: | |
| print("Error: File 1 has less AMRs than file 2", file=ERROR_LOG) | |
| print("Ignoring remaining AMRs", file=ERROR_LOG) | |
| elif not cur_amr2: | |
| print("Error: File 2 has less AMRs than file 1", file=ERROR_LOG) | |
| print("Ignoring remaining AMRs", file=ERROR_LOG) | |
| else: | |
| yield cur_amr1, cur_amr2 | |
| continue | |
| break | |
| def get_amr_match(cur_amr1, cur_amr2, sent_num=1, justinstance=False, justattribute=False, justrelation=False, | |
| limit = None, | |
| instance1 = None, attributes1 = None, relation1 = None, prefix1 = None, | |
| instance2 = None, attributes2 = None, relation2 = None, prefix2 = None): | |
| global iteration_num | |
| if limit is not None: iteration_num = limit | |
| if cur_amr1 and cur_amr2: | |
| amr_pair = [] | |
| for i, cur_amr in (1, cur_amr1), (2, cur_amr2): | |
| try: | |
| amr_pair.append(amr.AMR.parse_AMR_line(cur_amr)) | |
| except Exception as e: | |
| print("Error in parsing amr %d: %s" % (i, cur_amr), file=ERROR_LOG) | |
| print("Please check if the AMR is ill-formatted. Ignoring remaining AMRs", file=ERROR_LOG) | |
| print("Error message: %s" % e, file=ERROR_LOG) | |
| amr1, amr2 = amr_pair | |
| prefix1 = "a" | |
| prefix2 = "b" | |
| # Rename node to "a1", "a2", .etc | |
| amr1.rename_node(prefix1) | |
| # Renaming node to "b1", "b2", .etc | |
| amr2.rename_node(prefix2) | |
| (instance1, attributes1, relation1) = amr1.get_triples() | |
| (instance2, attributes2, relation2) = amr2.get_triples() | |
| if verbose: | |
| print("AMR pair", sent_num, file=DEBUG_LOG) | |
| print("============================================", file=DEBUG_LOG) | |
| print("AMR 1 (one-line):", cur_amr1, file=DEBUG_LOG) | |
| print("AMR 2 (one-line):", cur_amr2, file=DEBUG_LOG) | |
| print("Instance triples of AMR 1:", len(instance1), file=DEBUG_LOG) | |
| print(instance1, file=DEBUG_LOG) | |
| print("Attribute triples of AMR 1:", len(attributes1), file=DEBUG_LOG) | |
| print(attributes1, file=DEBUG_LOG) | |
| print("Relation triples of AMR 1:", len(relation1), file=DEBUG_LOG) | |
| print(relation1, file=DEBUG_LOG) | |
| print("Instance triples of AMR 2:", len(instance2), file=DEBUG_LOG) | |
| print(instance2, file=DEBUG_LOG) | |
| print("Attribute triples of AMR 2:", len(attributes2), file=DEBUG_LOG) | |
| print(attributes2, file=DEBUG_LOG) | |
| print("Relation triples of AMR 2:", len(relation2), file=DEBUG_LOG) | |
| print(relation2, file=DEBUG_LOG) | |
| # optionally turn off some of the node comparison | |
| doinstance = doattribute = dorelation = True | |
| if justinstance: | |
| doattribute = dorelation = False | |
| if justattribute: | |
| doinstance = dorelation = False | |
| if justrelation: | |
| doinstance = doattribute = False | |
| (best_mapping, best_match_num) = get_best_match(instance1, attributes1, relation1, | |
| instance2, attributes2, relation2, | |
| prefix1, prefix2, doinstance=doinstance, | |
| doattribute=doattribute, dorelation=dorelation) | |
| if verbose: | |
| print("best match number", best_match_num, file=DEBUG_LOG) | |
| print("best node mapping", best_mapping, file=DEBUG_LOG) | |
| print("Best node mapping alignment:", print_alignment(best_mapping, instance1, instance2), file=DEBUG_LOG) | |
| if justinstance: | |
| test_triple_num = len(instance1) | |
| gold_triple_num = len(instance2) | |
| elif justattribute: | |
| test_triple_num = len(attributes1) | |
| gold_triple_num = len(attributes2) | |
| elif justrelation: | |
| test_triple_num = len(relation1) | |
| gold_triple_num = len(relation2) | |
| else: | |
| test_triple_num = len(instance1) + len(attributes1) + len(relation1) | |
| gold_triple_num = len(instance2) + len(attributes2) + len(relation2) | |
| match_triple_dict.clear() | |
| if cur_amr1 and cur_amr2: | |
| return best_match_num, test_triple_num, gold_triple_num | |
| else: | |
| return best_match_num, test_triple_num, gold_triple_num, best_mapping | |
| def score_amr_pairs(f1, f2, justinstance=False, justattribute=False, justrelation=False): | |
| """ | |
| Score one pair of AMR lines at a time from each file handle | |
| :param f1: file handle (or any iterable of strings) to read AMR 1 lines from | |
| :param f2: file handle (or any iterable of strings) to read AMR 2 lines from | |
| :param justinstance: just pay attention to matching instances | |
| :param justattribute: just pay attention to matching attributes | |
| :param justrelation: just pay attention to matching relations | |
| :return: generator of cur_amr1, cur_amr2 pairs: one-line AMR strings | |
| """ | |
| # matching triple number, triple number in test file, triple number in gold file | |
| total_match_num = total_test_num = total_gold_num = 0 | |
| # Read amr pairs from two files | |
| for sent_num, (cur_amr1, cur_amr2) in enumerate(generate_amr_lines(f1, f2), start=1): | |
| best_match_num, test_triple_num, gold_triple_num = get_amr_match(cur_amr1, cur_amr2, | |
| sent_num=sent_num, # sentence number | |
| justinstance=justinstance, | |
| justattribute=justattribute, | |
| justrelation=justrelation) | |
| total_match_num += best_match_num | |
| total_test_num += test_triple_num | |
| total_gold_num += gold_triple_num | |
| # clear the matching triple dictionary for the next AMR pair | |
| match_triple_dict.clear() | |
| if not single_score: # if each AMR pair should have a score, compute and output it here | |
| yield compute_f(best_match_num, test_triple_num, gold_triple_num) | |
| if verbose: | |
| print("Total match number, total triple number in AMR 1, and total triple number in AMR 2:", file=DEBUG_LOG) | |
| print(total_match_num, total_test_num, total_gold_num, file=DEBUG_LOG) | |
| print("---------------------------------------------------------------------------------", file=DEBUG_LOG) | |
| if single_score: # output document-level smatch score (a single f-score for all AMR pairs in two files) | |
| yield compute_f(total_match_num, total_test_num, total_gold_num) | |
| def main(arguments): | |
| """ | |
| Main function of smatch score calculation | |
| """ | |
| global verbose | |
| global veryVerbose | |
| global iteration_num | |
| global single_score | |
| global pr_flag | |
| global match_triple_dict | |
| # set the iteration number | |
| # total iteration number = restart number + 1 | |
| iteration_num = arguments.r + 1 | |
| if arguments.ms: | |
| single_score = False | |
| if arguments.v: | |
| verbose = True | |
| if arguments.vv: | |
| veryVerbose = True | |
| if arguments.pr: | |
| pr_flag = True | |
| # significant digits to print out | |
| floatdisplay = "%%.%df" % arguments.significant | |
| for (precision, recall, best_f_score) in score_amr_pairs(args.f[0], args.f[1], | |
| justinstance=arguments.justinstance, | |
| justattribute=arguments.justattribute, | |
| justrelation=arguments.justrelation): | |
| # print("Sentence", sent_num) | |
| if pr_flag: | |
| print("Precision: " + floatdisplay % precision) | |
| print("Recall: " + floatdisplay % recall) | |
| print("F-score: " + floatdisplay % best_f_score) | |
| args.f[0].close() | |
| args.f[1].close() | |
| if __name__ == "__main__": | |
| parser = None | |
| args = None | |
| # use optparse if python version is 2.5 or 2.6 | |
| if sys.version_info[0] == 2 and sys.version_info[1] < 7: | |
| import optparse | |
| if len(sys.argv) == 1: | |
| print("No argument given. Please run smatch.py -h to see the argument description.", file=ERROR_LOG) | |
| exit(1) | |
| parser = build_arg_parser2() | |
| (args, opts) = parser.parse_args() | |
| file_handle = [] | |
| if args.f is None: | |
| print("smatch.py requires -f option to indicate two files \ | |
| containing AMR as input. Please run smatch.py -h to \ | |
| see the argument description.", file=ERROR_LOG) | |
| exit(1) | |
| # assert there are 2 file names following -f. | |
| assert (len(args.f) == 2) | |
| for file_path in args.f: | |
| if not os.path.exists(file_path): | |
| print("Given file", args.f[0], "does not exist", file=ERROR_LOG) | |
| exit(1) | |
| file_handle.append(open(file_path)) | |
| # use opened files | |
| args.f = tuple(file_handle) | |
| # use argparse if python version is 2.7 or later | |
| else: | |
| import argparse | |
| parser = build_arg_parser() | |
| args = parser.parse_args() | |
| main(args) | |