Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| AMR (Abstract Meaning Representation) structure | |
| For detailed description of AMR, see http://www.isi.edu/natural-language/amr/a.pdf | |
| """ | |
| from __future__ import print_function | |
| from collections import defaultdict | |
| import sys | |
| # change this if needed | |
| ERROR_LOG = sys.stderr | |
| # change this if needed | |
| DEBUG_LOG = sys.stderr | |
| class AMR(object): | |
| """ | |
| AMR is a rooted, labeled graph to represent semantics. | |
| This class has the following members: | |
| nodes: list of node in the graph. Its ith element is the name of the ith node. For example, a node name | |
| could be "a1", "b", "g2", .etc | |
| node_values: list of node labels (values) of the graph. Its ith element is the value associated with node i in | |
| nodes list. In AMR, such value is usually a semantic concept (e.g. "boy", "want-01") | |
| root: root node name | |
| relations: list of edges connecting two nodes in the graph. Each entry is a link between two nodes, i.e. a triple | |
| <relation name, node1 name, node 2 name>. In AMR, such link denotes the relation between two semantic | |
| concepts. For example, "arg0" means that one of the concepts is the 0th argument of the other. | |
| attributes: list of edges connecting a node to an attribute name and its value. For example, if the polarity of | |
| some node is negative, there should be an edge connecting this node and "-". A triple < attribute name, | |
| node name, attribute value> is used to represent such attribute. It can also be viewed as a relation. | |
| """ | |
| def __init__(self, node_list=None, node_value_list=None, relation_list=None, attribute_list=None): | |
| """ | |
| node_list: names of nodes in AMR graph, e.g. "a11", "n" | |
| node_value_list: values of nodes in AMR graph, e.g. "group" for a node named "g" | |
| relation_list: list of relations between two nodes | |
| attribute_list: list of attributes (links between one node and one constant value) | |
| """ | |
| # initialize AMR graph nodes using list of nodes name | |
| # root, by default, is the first in var_list | |
| if node_list is None: | |
| self.nodes = [] | |
| self.root = None | |
| else: | |
| self.nodes = node_list[:] | |
| if len(node_list) != 0: | |
| self.root = node_list[0] | |
| else: | |
| self.root = None | |
| if node_value_list is None: | |
| self.node_values = [] | |
| else: | |
| self.node_values = node_value_list[:] | |
| if relation_list is None: | |
| self.relations = [] | |
| else: | |
| self.relations = relation_list[:] | |
| if attribute_list is None: | |
| self.attributes = [] | |
| else: | |
| self.attributes = attribute_list[:] | |
| def rename_node(self, prefix): | |
| """ | |
| Rename AMR graph nodes to prefix + node_index to avoid nodes with the same name in two different AMRs. | |
| """ | |
| node_map_dict = {} | |
| # map each node to its new name (e.g. "a1") | |
| for i in range(0, len(self.nodes)): | |
| node_map_dict[self.nodes[i]] = prefix + str(i) | |
| # update node name | |
| for i, v in enumerate(self.nodes): | |
| self.nodes[i] = node_map_dict[v] | |
| # update node name in relations | |
| for node_relations in self.relations: | |
| for i, l in enumerate(node_relations): | |
| node_relations[i][1] = node_map_dict[l[1]] | |
| def get_triples(self): | |
| """ | |
| Get the triples in three lists. | |
| instance_triple: a triple representing an instance. E.g. instance(w, want-01) | |
| attribute triple: relation of attributes, e.g. polarity(w, - ) | |
| and relation triple, e.g. arg0 (w, b) | |
| """ | |
| instance_triple = [] | |
| relation_triple = [] | |
| attribute_triple = [] | |
| for i in range(len(self.nodes)): | |
| instance_triple.append(("instance", self.nodes[i], self.node_values[i])) | |
| # l[0] is relation name | |
| # l[1] is the other node this node has relation with | |
| for l in self.relations[i]: | |
| relation_triple.append((l[0], self.nodes[i], l[1])) | |
| # l[0] is the attribute name | |
| # l[1] is the attribute value | |
| for l in self.attributes[i]: | |
| attribute_triple.append((l[0], self.nodes[i], l[1])) | |
| return instance_triple, attribute_triple, relation_triple | |
| def get_triples2(self): | |
| """ | |
| Get the triples in two lists: | |
| instance_triple: a triple representing an instance. E.g. instance(w, want-01) | |
| relation_triple: a triple representing all relations. E.g arg0 (w, b) or E.g. polarity(w, - ) | |
| Note that we do not differentiate between attribute triple and relation triple. Both are considered as relation | |
| triples. | |
| All triples are represented by (triple_type, argument 1 of the triple, argument 2 of the triple) | |
| """ | |
| instance_triple = [] | |
| relation_triple = [] | |
| for i in range(len(self.nodes)): | |
| # an instance triple is instance(node name, node value). | |
| # For example, instance(b, boy). | |
| instance_triple.append(("instance", self.nodes[i], self.node_values[i])) | |
| # l[0] is relation name | |
| # l[1] is the other node this node has relation with | |
| for l in self.relations[i]: | |
| relation_triple.append((l[0], self.nodes[i], l[1])) | |
| # l[0] is the attribute name | |
| # l[1] is the attribute value | |
| for l in self.attributes[i]: | |
| relation_triple.append((l[0], self.nodes[i], l[1])) | |
| return instance_triple, relation_triple | |
| def __str__(self): | |
| """ | |
| Generate AMR string for better readability | |
| """ | |
| lines = [] | |
| for i in range(len(self.nodes)): | |
| lines.append("Node "+ str(i) + " " + self.nodes[i]) | |
| lines.append("Value: " + self.node_values[i]) | |
| lines.append("Relations:") | |
| for relation in self.relations[i]: | |
| lines.append("Node " + relation[1] + " via " + relation[0]) | |
| for attribute in self.attributes[i]: | |
| lines.append("Attribute: " + attribute[0] + " value " + attribute[1]) | |
| return "\n".join(lines) | |
| def __repr__(self): | |
| return self.__str__() | |
| def output_amr(self): | |
| """ | |
| Output AMR string | |
| """ | |
| print(self.__str__(), file=DEBUG_LOG) | |
| def get_amr_line(input_f): | |
| """ | |
| Read the file containing AMRs. AMRs are separated by a blank line. | |
| Each call of get_amr_line() returns the next available AMR (in one-line form). | |
| Note: this function does not verify if the AMR is valid | |
| """ | |
| cur_amr = [] | |
| has_content = False | |
| for line in input_f: | |
| line = line.strip() | |
| if line == "": | |
| if not has_content: | |
| # empty lines before current AMR | |
| continue | |
| else: | |
| # end of current AMR | |
| break | |
| if line.strip().startswith("#"): | |
| # ignore the comment line (starting with "#") in the AMR file | |
| continue | |
| else: | |
| has_content = True | |
| cur_amr.append(line.strip()) | |
| return "".join(cur_amr) | |
| def parse_AMR_line(line): | |
| """ | |
| Parse a AMR from line representation to an AMR object. | |
| This parsing algorithm scans the line once and process each character, in a shift-reduce style. | |
| """ | |
| # Current state. It denotes the last significant symbol encountered. 1 for (, 2 for :, 3 for /, | |
| # and 0 for start state or ')' | |
| # Last significant symbol is ( --- start processing node name | |
| # Last significant symbol is : --- start processing relation name | |
| # Last significant symbol is / --- start processing node value (concept name) | |
| # Last significant symbol is ) --- current node processing is complete | |
| # Note that if these symbols are inside parenthesis, they are not significant symbols. | |
| state = 0 | |
| # node stack for parsing | |
| stack = [] | |
| # current not-yet-reduced character sequence | |
| cur_charseq = [] | |
| # key: node name value: node value | |
| node_dict = {} | |
| # node name list (order: occurrence of the node) | |
| node_name_list = [] | |
| # key: node name: value: list of (relation name, the other node name) | |
| node_relation_dict1 = defaultdict(list) | |
| # key: node name, value: list of (attribute name, const value) or (relation name, unseen node name) | |
| node_relation_dict2 = defaultdict(list) | |
| # current relation name | |
| cur_relation_name = "" | |
| # having unmatched quote string | |
| in_quote = False | |
| for i, c in enumerate(line.strip()): | |
| if c == " ": | |
| # allow space in relation name | |
| if state == 2: | |
| cur_charseq.append(c) | |
| continue | |
| if c == "\"": | |
| # flip in_quote value when a quote symbol is encountered | |
| # insert placeholder if in_quote from last symbol | |
| if in_quote: | |
| cur_charseq.append('¦') | |
| in_quote = not in_quote | |
| elif c == "(": | |
| # not significant symbol if inside quote | |
| if in_quote: | |
| cur_charseq.append(c) | |
| continue | |
| # get the attribute name | |
| # e.g :arg0 (x ... | |
| # at this point we get "arg0" | |
| if state == 2: | |
| # in this state, current relation name should be empty | |
| if cur_relation_name != "": | |
| print("Format error when processing ", line[0:i + 1], file=ERROR_LOG) | |
| return None | |
| # update current relation name for future use | |
| cur_relation_name = "".join(cur_charseq).strip() | |
| cur_charseq[:] = [] | |
| state = 1 | |
| elif c == ":": | |
| # not significant symbol if inside quote | |
| if in_quote: | |
| cur_charseq.append(c) | |
| continue | |
| # Last significant symbol is "/". Now we encounter ":" | |
| # Example: | |
| # :OR (o2 / *OR* | |
| # :mod (o3 / official) | |
| # gets node value "*OR*" at this point | |
| if state == 3: | |
| node_value = "".join(cur_charseq) | |
| # clear current char sequence | |
| cur_charseq[:] = [] | |
| # pop node name ("o2" in the above example) | |
| cur_node_name = stack[-1] | |
| # update node name/value map | |
| node_dict[cur_node_name] = node_value | |
| # Last significant symbol is ":". Now we encounter ":" | |
| # Example: | |
| # :op1 w :quant 30 | |
| # or :day 14 :month 3 | |
| # the problem is that we cannot decide if node value is attribute value (constant) | |
| # or node value (variable) at this moment | |
| elif state == 2: | |
| temp_attr_value = "".join(cur_charseq) | |
| cur_charseq[:] = [] | |
| parts = temp_attr_value.split() | |
| if len(parts) < 2: | |
| print("Error in processing; part len < 2", line[0:i + 1], file=ERROR_LOG) | |
| return None | |
| # For the above example, node name is "op1", and node value is "w" | |
| # Note that this node name might not be encountered before | |
| relation_name = parts[0].strip() | |
| relation_value = parts[1].strip() | |
| # We need to link upper level node to the current | |
| # top of stack is upper level node | |
| if len(stack) == 0: | |
| print("Error in processing", line[:i], relation_name, relation_value, file=ERROR_LOG) | |
| return None | |
| # if we have not seen this node name before | |
| if relation_value not in node_dict: | |
| node_relation_dict2[stack[-1]].append((relation_name, relation_value)) | |
| else: | |
| node_relation_dict1[stack[-1]].append((relation_name, relation_value)) | |
| state = 2 | |
| elif c == "/": | |
| if in_quote: | |
| cur_charseq.append(c) | |
| continue | |
| # Last significant symbol is "(". Now we encounter "/" | |
| # Example: | |
| # (d / default-01 | |
| # get "d" here | |
| if state == 1: | |
| node_name = "".join(cur_charseq) | |
| cur_charseq[:] = [] | |
| # if this node name is already in node_dict, it is duplicate | |
| if node_name in node_dict: | |
| print("Duplicate node name ", node_name, " in parsing AMR", file=ERROR_LOG) | |
| return None | |
| # push the node name to stack | |
| stack.append(node_name) | |
| # add it to node name list | |
| node_name_list.append(node_name) | |
| # if this node is part of the relation | |
| # Example: | |
| # :arg1 (n / nation) | |
| # cur_relation_name is arg1 | |
| # node name is n | |
| # we have a relation arg1(upper level node, n) | |
| if cur_relation_name != "": | |
| # if relation name ends with "-of", e.g."arg0-of", | |
| # it is reverse of some relation. For example, if a is "arg0-of" b, | |
| # we can also say b is "arg0" a. | |
| # If the relation name ends with "-of", we store the reverse relation. | |
| if True or not cur_relation_name.endswith("-of"): | |
| # stack[-2] is upper_level node we encountered, as we just add node_name to stack | |
| node_relation_dict1[stack[-2]].append((cur_relation_name, node_name)) | |
| else: | |
| # cur_relation_name[:-3] is to delete "-of" | |
| node_relation_dict1[node_name].append((cur_relation_name[:-3], stack[-2])) | |
| # clear current_relation_name | |
| cur_relation_name = "" | |
| else: | |
| # error if in other state | |
| print("Error in parsing AMR", line[0:i + 1], file=ERROR_LOG) | |
| return None | |
| state = 3 | |
| elif c == ")": | |
| if in_quote: | |
| cur_charseq.append(c) | |
| continue | |
| # stack should be non-empty to find upper level node | |
| if len(stack) == 0: | |
| print("Unmatched parenthesis at position", i, "in processing", line[0:i + 1], file=ERROR_LOG) | |
| return None | |
| # Last significant symbol is ":". Now we encounter ")" | |
| # Example: | |
| # :op2 "Brown") or :op2 w) | |
| # get \"Brown\" or w here | |
| if state == 2: | |
| temp_attr_value = "".join(cur_charseq) | |
| cur_charseq[:] = [] | |
| parts = temp_attr_value.split() | |
| if len(parts) < 2: | |
| print("Error processing", line[:i + 1], temp_attr_value, file=ERROR_LOG) | |
| return None | |
| relation_name = parts[0].strip() | |
| relation_value = parts[1].strip() | |
| # store reverse of the relation | |
| # we are sure relation_value is a node here, as "-of" relation is only between two nodes | |
| if False and relation_name.endswith("-of"): | |
| node_relation_dict1[relation_value].append((relation_name[:-3], stack[-1])) | |
| # attribute value not seen before | |
| # Note that it might be a constant attribute value, or an unseen node | |
| # process this after we have seen all the node names | |
| elif relation_value not in node_dict: | |
| node_relation_dict2[stack[-1]].append((relation_name, relation_value)) | |
| else: | |
| node_relation_dict1[stack[-1]].append((relation_name, relation_value)) | |
| # Last significant symbol is "/". Now we encounter ")" | |
| # Example: | |
| # :arg1 (n / nation) | |
| # we get "nation" here | |
| elif state == 3: | |
| node_value = "".join(cur_charseq) | |
| cur_charseq[:] = [] | |
| cur_node_name = stack[-1] | |
| # map node name to its value | |
| node_dict[cur_node_name] = node_value | |
| # pop from stack, as the current node has been processed | |
| stack.pop() | |
| cur_relation_name = "" | |
| state = 0 | |
| else: | |
| # not significant symbols, so we just shift. | |
| cur_charseq.append(c) | |
| #create data structures to initialize an AMR | |
| node_value_list = [] | |
| relation_list = [] | |
| attribute_list = [] | |
| for v in node_name_list: | |
| if v not in node_dict: | |
| print("Error: Node name not found", v, file=ERROR_LOG) | |
| return None | |
| else: | |
| node_value_list.append(node_dict[v]) | |
| # build relation list and attribute list for this node | |
| node_rel_list = [] | |
| node_attr_list = [] | |
| if v in node_relation_dict1: | |
| for v1 in node_relation_dict1[v]: | |
| node_rel_list.append([v1[0], v1[1]]) | |
| if v in node_relation_dict2: | |
| for v2 in node_relation_dict2[v]: | |
| # if value is in quote, it is a constant value | |
| # strip the quote and put it in attribute map | |
| if v2[1][0] == "\"" and v2[1][-1] == "\"": | |
| node_attr_list.append([[v2[0]], v2[1][1:-1]]) | |
| # if value is a node name | |
| elif v2[1] in node_dict: | |
| node_rel_list.append([v2[0], v2[1]]) | |
| else: | |
| node_attr_list.append([v2[0], v2[1]]) | |
| # each node has a relation list and attribute list | |
| relation_list.append(node_rel_list) | |
| attribute_list.append(node_attr_list) | |
| # add TOP as an attribute. The attribute value is the top node value | |
| attribute_list[0].append(["TOP", node_value_list[0]]) | |
| result_amr = AMR(node_name_list, node_value_list, relation_list, attribute_list) | |
| return result_amr | |
| # test AMR parsing | |
| # run by amr.py [file containing AMR] | |
| # a unittest can also be used. | |
| if __name__ == "__main__": | |
| if len(sys.argv) < 2: | |
| print("No file given", file=ERROR_LOG) | |
| exit(1) | |
| amr_count = 1 | |
| for line in open(sys.argv[1]): | |
| cur_line = line.strip() | |
| if cur_line == "" or cur_line.startswith("#"): | |
| continue | |
| print("AMR", amr_count, file=DEBUG_LOG) | |
| current = AMR.parse_AMR_line(cur_line) | |
| current.output_amr() | |
| amr_count += 1 | |