| |
| |
|
|
|
|
| """ |
| This script computes smatch score between two AMRs. |
| For detailed description of smatch, see http://www.isi.edu/natural-language/amr/smatch-13.pdf |
| |
| """ |
|
|
| from __future__ import division |
| from __future__ import print_function |
|
|
| try: |
| import smatch.amr |
| except: |
| import amr |
| import os |
| import random |
| import sys |
|
|
| |
# Total number of hill-climbing rounds run for each AMR pair. The first round
# starts from the smart (concept-based) initialization, the remaining rounds
# start from random mappings. main() overwrites this with "restart number + 1".
iteration_num = 5


# Verbose output switches (set by main() from -v and --vv).
# verbose prints the triples and the best mapping of every AMR pair;
# veryVerbose additionally traces the hill-climbing search itself.
verbose = False
veryVerbose = False


# True: output one document-level score.
# False: output one score per AMR pair (set by main() from --ms).
single_score = True


# True: output precision and recall in addition to the F-score
# (set by main() from --pr).
pr_flag = False


# Stream that error messages are printed to.
ERROR_LOG = sys.stderr


# Stream that verbose / debugging messages are printed to.
DEBUG_LOG = sys.stderr


# Cache of already evaluated node mappings for the AMR pair being scored:
# key is tuple(node mapping), value is the triple match number of that mapping.
# Filled by compute_match, move_gain and swap_gain; cleared after each pair.
match_triple_dict = {}
|
|
|
|
def build_arg_parser():
    """
    Build an argument parser using argparse. Use it when python version is 2.7 or later.

    Returns:
        an argparse.ArgumentParser that understands the smatch options:
        -f (two input files, opened for reading as UTF-8), -r, --significant,
        -v, --vv, --ms, --pr, --justinstance, --justattribute, --justrelation

    """
    # The module itself only imports argparse inside its "__main__" section, so
    # import it here as well: the function then also works when the module has
    # merely been imported by other code.
    import argparse

    parser = argparse.ArgumentParser(description="Smatch calculator -- arguments")
    parser.add_argument('-f', nargs=2, required=True, type=argparse.FileType('r', encoding="utf-8"),
                        help='Two files containing AMR pairs. AMRs in each file are separated by a single blank line')
    parser.add_argument('-r', type=int, default=4, help='Restart number (Default:4)')
    parser.add_argument('--significant', type=int, default=2, help='significant digits to output (default: 2)')
    parser.add_argument('-v', action='store_true', help='Verbose output (Default:false)')
    parser.add_argument('--vv', action='store_true', help='Very Verbose output (Default:false)')
    # Note the trailing space in the first fragment: the two literals are
    # concatenated, and without it the help text read "score)instead".
    parser.add_argument('--ms', action='store_true', default=False,
                        help='Output multiple scores (one AMR pair a score) '
                             'instead of a single document-level smatch score (Default: false)')
    parser.add_argument('--pr', action='store_true', default=False,
                        help="Output precision and recall as well as the f-score. Default: false")
    parser.add_argument('--justinstance', action='store_true', default=False,
                        help="just pay attention to matching instances")
    parser.add_argument('--justattribute', action='store_true', default=False,
                        help="just pay attention to matching attributes")
    parser.add_argument('--justrelation', action='store_true', default=False,
                        help="just pay attention to matching relations")

    return parser
|
|
|
|
def build_arg_parser2():
    """
    Build an argument parser using optparse. Use it when python version is 2.5 or 2.6.

    Returns:
        an optparse.OptionParser that understands the smatch options. Unlike the
        argparse variant, -f yields the two file names as plain strings (the
        caller has to open them) and a missing -f is not rejected by the parser.

    """
    # The module itself only imports optparse inside its "__main__" section, so
    # import it here as well: the function then also works when the module has
    # merely been imported by other code.
    import optparse

    usage_str = "Smatch calculator -- arguments"
    parser = optparse.OptionParser(usage=usage_str)
    parser.add_option("-f", "--files", nargs=2, dest="f", type="string",
                      help='Two files containing AMR pairs. AMRs in each file are '
                           'separated by a single blank line. This option is required.')
    parser.add_option("-r", "--restart", dest="r", type="int", help='Restart number (Default: 4)')
    parser.add_option('--significant', dest="significant", type="int", default=2,
                      help='significant digits to output (default: 2)')
    parser.add_option("-v", "--verbose", action='store_true', dest="v", help='Verbose output (Default:False)')
    parser.add_option("--vv", "--veryverbose", action='store_true', dest="vv",
                      help='Very Verbose output (Default:False)')
    parser.add_option("--ms", "--multiple_score", action='store_true', dest="ms",
                      help='Output multiple scores (one AMR pair a score) instead of '
                           'a single document-level smatch score (Default: False)')
    parser.add_option('--pr', "--precision_recall", action='store_true', dest="pr",
                      help="Output precision and recall as well as the f-score. Default: false")
    parser.add_option('--justinstance', action='store_true', default=False,
                      help="just pay attention to matching instances")
    parser.add_option('--justattribute', action='store_true', default=False,
                      help="just pay attention to matching attributes")
    parser.add_option('--justrelation', action='store_true', default=False,
                      help="just pay attention to matching relations")
    # vv gets an explicit False like the other switches, so it is a real
    # boolean (as the help text promises) instead of None when not given.
    parser.set_defaults(r=4, v=False, vv=False, ms=False, pr=False)
    return parser
|
|
|
|
def get_best_match(instance1, attribute1, relation1,
                   instance2, attribute2, relation2,
                   prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True):
    """
    Search for the node mapping with the most matching triples (hill-climbing with restarts).
    Arguments:
        instance1: instance triples of AMR 1 ("instance", node name, node value)
        attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value)
        relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name)
        instance2: instance triples of AMR 2 ("instance", node name, node value)
        attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value)
        relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name)
        prefix1: prefix label for AMR 1
        prefix2: prefix label for AMR 2
        doinstance: take instance triples into account
        doattribute: take attribute triples into account
        dorelation: take relation triples into account
    Returns:
        best_match: the node mapping that results in the highest triple matching number
        best_match_num: the highest triple matching number

    """
    # Candidate targets for every AMR 1 node plus the weights of all node pairs.
    candidate_mappings, weight_dict = compute_pool(
        instance1, attribute1, relation1,
        instance2, attribute2, relation2,
        prefix1, prefix2,
        doinstance=doinstance, doattribute=doattribute, dorelation=dorelation)
    if veryVerbose:
        print("Candidate mappings:", file=DEBUG_LOG)
        print(candidate_mappings, file=DEBUG_LOG)
        print("Weight dictionary", file=DEBUG_LOG)
        print(weight_dict, file=DEBUG_LOG)

    top_score = 0
    # Until something better is found, every AMR 1 node is unmapped (-1).
    top_mapping = [-1] * len(instance1)
    for round_no in range(iteration_num):
        if veryVerbose:
            print("Iteration", round_no, file=DEBUG_LOG)
        # Only the very first round starts from the concept-based guess;
        # every restart begins with a random mapping.
        if round_no != 0:
            current = random_init_mapping(candidate_mappings)
        else:
            current = smart_init_mapping(candidate_mappings, instance1, instance2)
        score = compute_match(current, weight_dict)
        if veryVerbose:
            print("Node mapping at start", current, file=DEBUG_LOG)
            print("Triple match number at start:", score, file=DEBUG_LOG)
        # Climb for as long as a single move or swap still improves the score.
        while True:
            gain, improved = get_best_gain(current, candidate_mappings, weight_dict,
                                           len(instance2), score)
            if veryVerbose:
                print("Gain after the hill-climbing", gain, file=DEBUG_LOG)
            if not gain > 0:
                break
            score = score + gain
            current = improved[:]
            if veryVerbose:
                print("Update triple match number to:", score, file=DEBUG_LOG)
                print("Current mapping:", current, file=DEBUG_LOG)
        # Remember the best local optimum over all rounds.
        if score > top_score:
            top_mapping, top_score = current[:], score
    return top_mapping, top_score
|
|
|
|
def normalize(item):
    """
    Prepare an item for comparison: strip trailing "¦" characters, lowercase
    the rest, then strip trailing underscores.
    """
    trimmed = item.rstrip("¦")
    lowered = trimmed.lower()
    return lowered.rstrip('_')
|
|
|
|
def compute_pool(instance1, attribute1, relation1,
                 instance2, attribute2, relation2,
                 prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True):
    """
    Collect, for every node of AMR 1, the nodes of AMR 2 it may be mapped to, together with
    the number of triples each such node pair (or combination of two node pairs) would match.

    Arguments:
        instance1: instance triples of AMR 1
        attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value)
        relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name)
        instance2: instance triples of AMR 2
        attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value)
        relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name)
        prefix1: prefix label for AMR 1 (node names are this prefix followed by the node index)
        prefix2: prefix label for AMR 2
        doinstance: take instance triples into account
        doattribute: take attribute triples into account
        dorelation: take relation triples into account
    Returns:
        candidate_mapping: a list with one set per instance triple of AMR 1. The ith set holds the
                           node indices (in AMR 2) that the ith node (in AMR 1) can map to with a
                           non-zero triple match.
        weight_dict: dictionary keyed by node pair (index in AMR 1, index in AMR 2). Each value is
                     another dictionary: key -1 holds the matches that depend on this pair alone
                     (instance triples, attribute triples, relations from a node to itself), every
                     other key is a second node pair that yields relation triple matches when it
                     is chosen together with the first pair.

    """
    cut1 = len(prefix1)
    cut2 = len(prefix2)
    candidate_mapping = []
    weight_dict = {}

    def add_match(pair, key):
        # Count one more matching triple for `pair` under `key`; a pair seen
        # for the first time starts with a stand-alone weight (-1) of 0.
        weights = weight_dict.setdefault(pair, {-1: 0})
        weights[key] = weights.get(key, 0) + 1

    for item1 in instance1:
        # One (initially empty) candidate set per node of AMR 1.
        candidate_mapping.append(set())
        if not doinstance:
            continue
        for item2 in instance2:
            # Instance triples match when both the label and the concept agree.
            if normalize(item1[0]) != normalize(item2[0]) or \
                    normalize(item1[2]) != normalize(item2[2]):
                continue
            # The node index is whatever follows the prefix in the node name.
            index1 = int(item1[1][cut1:])
            index2 = int(item2[1][cut2:])
            candidate_mapping[index1].add(index2)
            add_match((index1, index2), -1)

    if doattribute:
        for item1 in attribute1:
            for item2 in attribute2:
                # Attribute triples match on attribute name and attribute value.
                if normalize(item1[0]) != normalize(item2[0]) or \
                        normalize(item1[2]) != normalize(item2[2]):
                    continue
                index1 = int(item1[1][cut1:])
                index2 = int(item2[1][cut2:])
                candidate_mapping[index1].add(index2)
                add_match((index1, index2), -1)

    if dorelation:
        for item1 in relation1:
            for item2 in relation2:
                # Relation triples only have to agree on the relation name here; whether
                # they really match depends on how both of their end nodes are mapped.
                if normalize(item1[0]) != normalize(item2[0]):
                    continue
                head1 = int(item1[1][cut1:])
                head2 = int(item2[1][cut2:])
                tail1 = int(item1[2][cut1:])
                tail2 = int(item2[2][cut2:])
                candidate_mapping[head1].add(head2)
                candidate_mapping[tail1].add(tail2)
                head_pair = (head1, head2)
                tail_pair = (tail1, tail2)
                if head_pair == tail_pair:
                    # Both ends collapse into one node pair: the match then
                    # depends on that single pair only.
                    add_match(head_pair, -1)
                    continue
                # Handle the pair with the smaller AMR 1 index first.
                if head1 > tail1:
                    head_pair, tail_pair = tail_pair, head_pair
                # Register the relation in both directions.
                add_match(head_pair, tail_pair)
                add_match(tail_pair, head_pair)
    return candidate_mapping, weight_dict
|
|
|
|
def smart_init_mapping(candidate_mapping, instance1, instance2):
    """
    Build a start mapping that prefers candidates carrying the same concept (smart initialization)
    Arguments:
        candidate_mapping: candidate node match list
        instance1: instance triples of AMR 1
        instance2: instance triples of AMR 2
    Returns:
        initialized node mapping between two AMRs: a list whose ith element is the
        AMR 2 node index chosen for node i of AMR 1, or -1 if node i stays unmapped

    """
    random.seed()
    taken = set()
    result = []
    # AMR 1 nodes that have candidates, but none with an identical, still free concept.
    postponed = []
    for position, options in enumerate(candidate_mapping):
        if not options:
            # Nothing this node could be mapped to.
            result.append(-1)
            continue
        concept = instance1[position][2]
        for target in options:
            # Take the first candidate with the very same concept that is still free.
            if instance2[target][2] == concept and target not in taken:
                taken.add(target)
                result.append(target)
                break
        else:
            # No such candidate: put in a placeholder and decide in the second pass.
            postponed.append(position)
            result.append(-1)
    # Second pass: give each postponed node a random candidate that is still free.
    for position in postponed:
        pool = list(candidate_mapping[position])
        while pool:
            pick = random.randint(0, len(pool) - 1)
            target = pool[pick]
            if target not in taken:
                taken.add(target)
                result[position] = target
                break
            # Already used by another node: drop it and draw again.
            del pool[pick]
    return result
|
|
|
|
def random_init_mapping(candidate_mapping):
    """
    Draw a random node mapping from the candidates.
    Args:
        candidate_mapping: candidate node match list
    Returns:
        randomly-generated node mapping between two AMRs: a list whose ith element is
        the AMR 2 node index drawn for node i of AMR 1, or -1 if node i has no
        candidate that is still free

    """
    random.seed()
    taken = set()
    result = []
    for options in candidate_mapping:
        pool = list(options)
        # Stays -1 when there are no candidates or all of them are in use.
        chosen = -1
        while pool:
            pick = random.randint(0, len(pool) - 1)
            target = pool[pick]
            if target not in taken:
                taken.add(target)
                chosen = target
                break
            # Candidate already assigned to an earlier node: drop it and draw again.
            del pool[pick]
        result.append(chosen)
    return result
|
|
|
|
def compute_match(mapping, weight_dict):
    """
    Count the triples matched by a node mapping, using the weights in weight_dict.
    Args:
        mapping: a list of node index in AMR 2. The ith element (value j) means node i in AMR 1
                 maps to node j in AMR 2; -1 means node i is not mapped.
        weight_dict: weight dictionary (node pair -> weights of that pair)
    Returns:
        matching triple number
    Side effects:
        results are looked up in and stored into the module-level cache match_triple_dict
    Complexity: O(m*n) , m is the node number of AMR 1, n is the node number of AMR 2

    """
    if veryVerbose:
        print("Computing match for mapping", file=DEBUG_LOG)
        print(mapping, file=DEBUG_LOG)
    cache_key = tuple(mapping)
    if cache_key in match_triple_dict:
        # This exact mapping has been evaluated before.
        if veryVerbose:
            print("saved value", match_triple_dict[cache_key], file=DEBUG_LOG)
        return match_triple_dict[cache_key]
    total = 0
    for node1, node2 in enumerate(mapping):
        if node2 == -1:
            # Unmapped node: contributes nothing.
            continue
        pair = (node1, node2)
        if pair not in weight_dict:
            continue
        if veryVerbose:
            print("node_pair", pair, file=DEBUG_LOG)
        for partner, weight in weight_dict[pair].items():
            if partner == -1:
                # Matches that depend on this node pair alone.
                total += weight
                if veryVerbose:
                    print("instance/attribute match", weight, file=DEBUG_LOG)
            elif partner[0] < node1:
                # Relation weights are stored under both node pairs; pairs with a
                # smaller AMR 1 index have been handled in an earlier iteration.
                continue
            elif mapping[partner[0]] == partner[1]:
                total += weight
                if veryVerbose:
                    print("relation match with", partner, weight, file=DEBUG_LOG)
    if veryVerbose:
        print("match computing complete, result:", total, file=DEBUG_LOG)
    # Remember the result for later calls with the same mapping.
    match_triple_dict[cache_key] = total
    return total
|
|
|
|
def move_gain(mapping, node_id, old_id, new_id, weight_dict, match_num):
    """
    Compute how the triple match number changes when one node is remapped
    Arguments:
        mapping: current node mapping
        node_id: remapped node in AMR 1
        old_id: original node id in AMR 2 to which node_id is mapped
        new_id: new node in AMR 2 to which node_id is mapped
        weight_dict: weight dictionary
        match_num: the original triple matching number
    Returns:
        the triple match gain number (might be negative)
    Side effects:
        the triple match number of the remapped mapping is stored in match_triple_dict

    """
    moved = mapping[:]
    moved[node_id] = new_id
    cache_key = tuple(moved)
    if cache_key in match_triple_dict:
        # The remapped mapping has been scored before: the gain is the difference.
        return match_triple_dict[cache_key] - match_num

    def pair_total(pair, layout):
        # Sum of the weights of `pair` that are active under the mapping `layout`:
        # its stand-alone weight plus every relation whose partner pair is in `layout`.
        total = 0
        for partner, weight in weight_dict.get(pair, {}).items():
            if partner == -1 or layout[partner[0]] == partner[1]:
                total += weight
        return total

    # What the new node pair adds minus what the old node pair used to contribute.
    gain = pair_total((node_id, new_id), moved) - pair_total((node_id, old_id), mapping)
    match_triple_dict[cache_key] = match_num + gain
    return gain
|
|
|
|
def swap_gain(mapping, node_id1, mapping_id1, node_id2, mapping_id2, weight_dict, match_num):
    """
    Compute how the triple match number changes when two nodes exchange their targets
    Arguments:
        mapping: current node mapping list
        node_id1: node 1 index in AMR 1
        mapping_id1: the node index in AMR 2 node 1 maps to (in the current mapping)
        node_id2: node 2 index in AMR 1
        mapping_id2: the node index in AMR 2 node 2 maps to (in the current mapping)
        weight_dict: weight dictionary
        match_num: the original matching triple number
    Returns:
        the gain number (might be negative)
    Side effects:
        the triple match number of the swapped mapping is stored in match_triple_dict

    """
    swapped = mapping[:]
    # After the swap node 1 points to node 2's old target and vice versa.
    swapped[node_id1] = mapping_id2
    swapped[node_id2] = mapping_id1
    cache_key = tuple(swapped)
    if cache_key in match_triple_dict:
        # The swapped mapping has been scored before: the gain is the difference.
        return match_triple_dict[cache_key] - match_num

    # The "first" node pair is the one of the node with the smaller AMR 1 index.
    if node_id1 > node_id2:
        new_first, new_second = (node_id2, mapping_id1), (node_id1, mapping_id2)
        old_first, old_second = (node_id2, mapping_id2), (node_id1, mapping_id1)
    else:
        new_first, new_second = (node_id1, mapping_id2), (node_id2, mapping_id1)
        old_first, old_second = (node_id1, mapping_id1), (node_id2, mapping_id2)

    def pair_total(pair, layout, skip_node1):
        # Sum of the weights of `pair` that are active under the mapping `layout`.
        # With skip_node1 set, relation entries whose partner is node_id1 are left
        # out -- presumably so a relation between the two swapped nodes is not
        # counted a second time.
        total = 0
        for partner, weight in weight_dict.get(pair, {}).items():
            if partner == -1:
                total += weight
            elif skip_node1 and partner[0] == node_id1:
                continue
            elif layout[partner[0]] == partner[1]:
                total += weight
        return total

    # New contributions are judged against the swapped mapping, the lost
    # contributions against the current one.
    gain = (pair_total(new_first, swapped, False)
            + pair_total(new_second, swapped, True)
            - pair_total(old_first, mapping, False)
            - pair_total(old_second, mapping, True))
    match_triple_dict[cache_key] = match_num + gain
    return gain
|
|
|
|
def get_best_gain(mapping, candidate_mappings, weight_dict, instance_len, cur_match_num):
    """
    One hill-climbing step: find the single move or swap with the largest gain and apply it
    Arguments:
        mapping: current node mapping
        candidate_mappings: the candidates mapping list
        weight_dict: the weight dictionary
        instance_len: the number of the nodes in AMR 2
        cur_match_num: current triple match number
    Returns:
        largest_gain: the best gain we can get via swap/move operation (0 if nothing improves)
        cur_mapping: a copy of mapping with that operation applied (an unchanged copy
                     if nothing improves)

    """
    best_gain = 0
    # Whether the best operation found so far is a swap (True) or a move (False).
    best_is_swap = True
    # Operands of the best operation: for a move, the AMR 1 node and its new AMR 2
    # target; for a swap, the two AMR 1 nodes whose targets are exchanged.
    first = None
    second = None
    # AMR 2 nodes that no AMR 1 node is currently mapped to.
    free_targets = set(range(instance_len))
    for target in mapping:
        free_targets.discard(target)

    # Moves: remap one AMR 1 node to a free AMR 2 node that is among its candidates.
    for node, current in enumerate(mapping):
        for target in free_targets:
            if target not in candidate_mappings[node]:
                continue
            if veryVerbose:
                print("Remap node", node, "from ", current, "to", target, file=DEBUG_LOG)
            mv_gain = move_gain(mapping, node, current, target, weight_dict, cur_match_num)
            if veryVerbose:
                print("Move gain:", mv_gain, file=DEBUG_LOG)
                # Debug only: compare the incremental gain with a full evaluation.
                new_mapping = mapping[:]
                new_mapping[node] = target
                new_match_num = compute_match(new_mapping, weight_dict)
                if new_match_num != cur_match_num + mv_gain:
                    print(mapping, new_mapping, file=ERROR_LOG)
                    print("Inconsistency in computing: move gain", cur_match_num, mv_gain, new_match_num,
                          file=ERROR_LOG)
            if mv_gain > best_gain:
                best_gain = mv_gain
                first, second = node, target
                best_is_swap = False

    # Swaps: exchange the targets of every pair of AMR 1 nodes.
    node_count = len(mapping)
    for left, left_target in enumerate(mapping):
        for right in range(left + 1, node_count):
            right_target = mapping[right]
            if veryVerbose:
                print("Swap node", left, "and", right, file=DEBUG_LOG)
                print("Before swapping:", left, "-", left_target, ",", right, "-", right_target, file=DEBUG_LOG)
                print(mapping, file=DEBUG_LOG)
                print("After swapping:", left, "-", right_target, ",", right, "-", left_target, file=DEBUG_LOG)
            sw_gain = swap_gain(mapping, left, left_target, right, right_target, weight_dict, cur_match_num)
            if veryVerbose:
                print("Swap gain:", sw_gain, file=DEBUG_LOG)
                # Debug only: compare the incremental gain with a full evaluation.
                new_mapping = mapping[:]
                new_mapping[left] = right_target
                new_mapping[right] = left_target
                print(new_mapping, file=DEBUG_LOG)
                new_match_num = compute_match(new_mapping, weight_dict)
                if new_match_num != cur_match_num + sw_gain:
                    print(mapping, new_mapping, file=ERROR_LOG)
                    print("Inconsistency in computing: swap gain", cur_match_num, sw_gain, new_match_num,
                          file=ERROR_LOG)
            if sw_gain > best_gain:
                best_gain = sw_gain
                first, second = left, right
                best_is_swap = True

    # Apply the winning operation (if any) to a copy of the mapping.
    cur_mapping = mapping[:]
    if first is None:
        if veryVerbose:
            print("no move/swap gain found", file=DEBUG_LOG)
    elif best_is_swap:
        if veryVerbose:
            print("Use swap gain", file=DEBUG_LOG)
        cur_mapping[first], cur_mapping[second] = cur_mapping[second], cur_mapping[first]
    else:
        if veryVerbose:
            print("Use move gain", file=DEBUG_LOG)
        cur_mapping[first] = second
    if veryVerbose:
        print("Original mapping", mapping, file=DEBUG_LOG)
        print("Current mapping", cur_mapping, file=DEBUG_LOG)
    return best_gain, cur_mapping
|
|
|
|
def print_alignment(mapping, instance1, instance2):
    """
    Render a node mapping as a readable alignment string
    Args:
        mapping: current node mapping list
        instance1: nodes of AMR 1
        instance2: nodes of AMR 2
    Returns:
        one "name(value)-name(value)" entry per AMR 1 node, separated by single
        spaces; an unmapped node is shown as "name(value)-Null"

    """
    def describe(triple):
        # "node name(node value)" of an instance triple.
        return triple[1] + "(" + triple[2] + ")"

    pieces = []
    for source, target in zip(instance1, mapping):
        if target == -1:
            partner = "Null"
        else:
            partner = describe(instance2[target])
        pieces.append(describe(source) + "-" + partner)
    return " ".join(pieces)
|
|
|
|
def compute_f(match_num, test_num, gold_num):
    """
    Turn the triple counts into precision, recall and f-score
    Args:
        match_num: matching triple number
        test_num: triple number of AMR 1 (test file)
        gold_num: triple number of AMR 2 (gold file)
    Returns:
        precision: match_num/test_num
        recall: match_num/gold_num
        f_score: 2*precision*recall/(precision+recall)
        All three are 0.0 when test_num or gold_num is 0; f_score is 0.0 when
        precision + recall is 0.
    """
    # Without triples on either side nothing can be computed.
    if test_num == 0 or gold_num == 0:
        return 0.00, 0.00, 0.00
    precision = float(match_num) / float(test_num)
    recall = float(match_num) / float(gold_num)
    if (precision + recall) == 0:
        # Avoid dividing by zero when nothing matched.
        if veryVerbose:
            print("F-score:", "0.0", file=DEBUG_LOG)
        return precision, recall, 0.00
    f_score = 2 * precision * recall / (precision + recall)
    if veryVerbose:
        print("F-score:", f_score, file=DEBUG_LOG)
    return precision, recall, f_score
|
|
|
|
def generate_amr_lines(f1, f2):
    """
    Read one AMR line at a time from each file handle.
    Stops as soon as either input yields no further AMR; if only one of them is
    exhausted, an error is printed to ERROR_LOG and the surplus AMRs are ignored.
    :param f1: file handle (or any iterable of strings) to read AMR 1 lines from
    :param f2: file handle (or any iterable of strings) to read AMR 2 lines from
    :return: generator of cur_amr1, cur_amr2 pairs: one-line AMR strings
    """
    while True:
        cur_amr1 = amr.AMR.get_amr_line(f1)
        cur_amr2 = amr.AMR.get_amr_line(f2)
        if cur_amr1 and cur_amr2:
            yield cur_amr1, cur_amr2
            continue
        # At least one input is exhausted; complain if the other one is not.
        if cur_amr2:
            print("Error: File 1 has less AMRs than file 2", file=ERROR_LOG)
            print("Ignoring remaining AMRs", file=ERROR_LOG)
        elif cur_amr1:
            print("Error: File 2 has less AMRs than file 1", file=ERROR_LOG)
            print("Ignoring remaining AMRs", file=ERROR_LOG)
        return
|
|
|
|
def get_amr_match(cur_amr1, cur_amr2, sent_num=1, justinstance=False, justattribute=False, justrelation=False,
                  limit = None,
                  instance1 = None, attributes1 = None, relation1 = None, prefix1 = None,
                  instance2 = None, attributes2 = None, relation2 = None, prefix2 = None):
    """
    Compute the best triple match for one pair of AMRs.
    Arguments:
        cur_amr1: one-line string of AMR 1
        cur_amr2: one-line string of AMR 2
        sent_num: number of the AMR pair (only used in verbose output)
        justinstance: just pay attention to matching instances
        justattribute: just pay attention to matching attributes
        justrelation: just pay attention to matching relations
        limit: if not None, the module-level iteration_num is set to this value
               (the new value stays in effect after the call)
        instance1, attributes1, relation1, prefix1: triples and node prefix of AMR 1;
               only used as given when cur_amr1/cur_amr2 are not both non-empty,
               otherwise they are replaced by the values parsed from cur_amr1
        instance2, attributes2, relation2, prefix2: the same for AMR 2
    Returns:
        best_match_num: the highest triple matching number found
        test_triple_num: number of triples of AMR 1 that were taken into account
        gold_triple_num: number of triples of AMR 2 that were taken into account
    Side effects:
        clears the module-level cache match_triple_dict before returning

    """
    global iteration_num
    if limit is not None:
        iteration_num = limit
    if cur_amr1 and cur_amr2:
        parsed = []
        for number, line in ((1, cur_amr1), (2, cur_amr2)):
            try:
                parsed.append(amr.AMR.parse_AMR_line(line))
            except Exception as e:
                print("Error in parsing amr %d: %s" % (number, line), file=ERROR_LOG)
                print("Please check if the AMR is ill-formatted. Ignoring remaining AMRs", file=ERROR_LOG)
                print("Error message: %s" % e, file=ERROR_LOG)
        # NOTE(review): if one of the two lines failed to parse, this unpacking raises.
        amr1, amr2 = parsed
        prefix1, prefix2 = "a", "b"
        # Rename the nodes of both AMRs to "<prefix><index>".
        amr1.rename_node(prefix1)
        amr2.rename_node(prefix2)
        instance1, attributes1, relation1 = amr1.get_triples()
        instance2, attributes2, relation2 = amr2.get_triples()
    if verbose:
        print("AMR pair", sent_num, file=DEBUG_LOG)
        print("============================================", file=DEBUG_LOG)
        print("AMR 1 (one-line):", cur_amr1, file=DEBUG_LOG)
        print("AMR 2 (one-line):", cur_amr2, file=DEBUG_LOG)
        print("Instance triples of AMR 1:", len(instance1), file=DEBUG_LOG)
        print(instance1, file=DEBUG_LOG)
        print("Attribute triples of AMR 1:", len(attributes1), file=DEBUG_LOG)
        print(attributes1, file=DEBUG_LOG)
        print("Relation triples of AMR 1:", len(relation1), file=DEBUG_LOG)
        print(relation1, file=DEBUG_LOG)
        print("Instance triples of AMR 2:", len(instance2), file=DEBUG_LOG)
        print(instance2, file=DEBUG_LOG)
        print("Attribute triples of AMR 2:", len(attributes2), file=DEBUG_LOG)
        print(attributes2, file=DEBUG_LOG)
        print("Relation triples of AMR 2:", len(relation2), file=DEBUG_LOG)
        print(relation2, file=DEBUG_LOG)
    # Each "just..." switch turns the other two triple kinds off.
    doinstance = not (justattribute or justrelation)
    doattribute = not (justinstance or justrelation)
    dorelation = not (justinstance or justattribute)
    best_mapping, best_match_num = get_best_match(instance1, attributes1, relation1,
                                                  instance2, attributes2, relation2,
                                                  prefix1, prefix2, doinstance=doinstance,
                                                  doattribute=doattribute, dorelation=dorelation)
    if verbose:
        print("best match number", best_match_num, file=DEBUG_LOG)
        print("best node mapping", best_mapping, file=DEBUG_LOG)
        print("Best node mapping alignment:", print_alignment(best_mapping, instance1, instance2), file=DEBUG_LOG)
    # Count only the kind of triples that was asked for (the first switch set wins).
    if justinstance:
        counted1, counted2 = [instance1], [instance2]
    elif justattribute:
        counted1, counted2 = [attributes1], [attributes2]
    elif justrelation:
        counted1, counted2 = [relation1], [relation2]
    else:
        counted1 = [instance1, attributes1, relation1]
        counted2 = [instance2, attributes2, relation2]
    test_triple_num = sum(len(triples) for triples in counted1)
    gold_triple_num = sum(len(triples) for triples in counted2)
    # The cached match numbers are only valid for this AMR pair.
    match_triple_dict.clear()
    return best_match_num, test_triple_num, gold_triple_num
|
|
|
|
def score_amr_pairs(f1, f2, justinstance=False, justattribute=False, justrelation=False):
    """
    Score one pair of AMR lines at a time from each file handle
    :param f1: file handle (or any iterable of strings) to read AMR 1 lines from
    :param f2: file handle (or any iterable of strings) to read AMR 2 lines from
    :param justinstance: just pay attention to matching instances
    :param justattribute: just pay attention to matching attributes
    :param justrelation: just pay attention to matching relations
    :return: generator of (precision, recall, f-score) tuples: one tuple per AMR pair when the
             module-level single_score is false, otherwise a single tuple computed from the
             triple counts summed over all pairs
    """
    # Triple counts accumulated over the whole document.
    matched_total = 0
    test_total = 0
    gold_total = 0
    pairs = generate_amr_lines(f1, f2)
    for sent_num, (cur_amr1, cur_amr2) in enumerate(pairs, 1):
        counts = get_amr_match(cur_amr1, cur_amr2,
                               sent_num=sent_num,
                               justinstance=justinstance,
                               justattribute=justattribute,
                               justrelation=justrelation)
        best_match_num, test_triple_num, gold_triple_num = counts
        matched_total += best_match_num
        test_total += test_triple_num
        gold_total += gold_triple_num
        # The cached match numbers belong to the pair that was just scored.
        match_triple_dict.clear()
        if not single_score:
            yield compute_f(best_match_num, test_triple_num, gold_triple_num)
    if verbose:
        print("Total match number, total triple number in AMR 1, and total triple number in AMR 2:", file=DEBUG_LOG)
        print(matched_total, test_total, gold_total, file=DEBUG_LOG)
        print("---------------------------------------------------------------------------------", file=DEBUG_LOG)
    if single_score:
        yield compute_f(matched_total, test_total, gold_total)
|
|
|
|
def main(arguments):
    """
    Main function of smatch score calculation
    Arguments:
        arguments: parsed command-line options. The attributes read are r, ms, v, vv, pr,
                   significant, justinstance, justattribute, justrelation and f, where f
                   holds the two opened input files (AMR 1 file first).
    Side effects:
        sets the module-level switches iteration_num, single_score, verbose, veryVerbose
        and pr_flag, prints the scores to standard output and closes both input files.

    """
    global verbose
    global veryVerbose
    global iteration_num
    global single_score
    global pr_flag

    # One smart-initialized round plus the requested number of random restarts.
    iteration_num = arguments.r + 1
    if arguments.ms:
        single_score = False
    if arguments.v:
        verbose = True
    if arguments.vv:
        veryVerbose = True
    if arguments.pr:
        pr_flag = True

    # e.g. "%.2f" for two significant digits
    floatdisplay = "%%.%df" % arguments.significant
    # Use the files of the object that was passed in, not a module-level "args",
    # which only exists when the module is run as a script.
    test_file, gold_file = arguments.f[0], arguments.f[1]
    try:
        for (precision, recall, best_f_score) in score_amr_pairs(test_file, gold_file,
                                                                 justinstance=arguments.justinstance,
                                                                 justattribute=arguments.justattribute,
                                                                 justrelation=arguments.justrelation):
            if pr_flag:
                print("Precision: " + floatdisplay % precision)
                print("Recall: " + floatdisplay % recall)
            print("F-score: " + floatdisplay % best_f_score)
    finally:
        # Close the inputs even when scoring fails half-way.
        test_file.close()
        gold_file.close()
|
|
|
|
if __name__ == "__main__":
    parser = None
    args = None
    # Python 2.5 and 2.6 have no argparse: fall back to optparse there.
    if sys.version_info[0] == 2 and sys.version_info[1] < 7:
        import optparse

        if len(sys.argv) == 1:
            print("No argument given. Please run smatch.py -h to see the argument description.", file=ERROR_LOG)
            sys.exit(1)
        parser = build_arg_parser2()
        (args, opts) = parser.parse_args()
        file_handle = []
        # optparse cannot mark an option as required, so check for -f by hand.
        if args.f is None:
            print("smatch.py requires -f option to indicate two files "
                  "containing AMR as input. Please run smatch.py -h to "
                  "see the argument description.", file=ERROR_LOG)
            sys.exit(1)
        # -f is declared with nargs=2, so exactly two file names are expected.
        assert (len(args.f) == 2)
        for file_path in args.f:
            if not os.path.exists(file_path):
                # Report the file that is actually missing, not always the first one.
                print("Given file", file_path, "does not exist", file=ERROR_LOG)
                sys.exit(1)
            file_handle.append(open(file_path))
        # Hand main() opened files, just like the argparse branch does.
        args.f = tuple(file_handle)
    # Python 2.7 or later: use argparse.
    else:
        import argparse

        parser = build_arg_parser()
        args = parser.parse_args()
    main(args)
|
|