# (page residue from the hosting site: "Spaces: Sleeping")
"""
lingua graph
"""
# import numpy as np
# logger = logging.getLogger(__name__)
import copy
import json
import re
from enum import Enum
from typing import Union

from .basegraph import Node, Edge, DirectedGraph, LearnableGraph, GraphVisualizer
from .utils import positions2spans, CompactJSONEncoder
# Compiled pattern for a positional argument placeholder such as "{0}" or "{12}".
arg_placeholder_pattern = re.compile(r'{\d+}')
def sent_spans2list(spans):
    """
    Convert sentence spans into a list of plain dictionaries.

    :param spans: iterable whose items are either a string label or a
        ``(start, end)`` tuple
    :return: list holding ``{'label': ...}`` for every string item and
        ``{'start': ..., 'end': ...}`` for every tuple item, in input order
    :raises AssertionError: if a non-string item is not a 2-tuple
    """
    spans_list = []
    for span in spans:
        if isinstance(span, str):
            spans_list.append({'label': span})
            continue
        assert isinstance(span, tuple) and len(span) == 2
        start, end = span
        spans_list.append({'start': start, 'end': end})
    return spans_list
def is_arg_placeholder(string):
    """
    Tell whether ``string`` begins with an argument placeholder such as ``{0}``.

    The pattern is matched at the start of the string only, so trailing
    characters after the placeholder do not prevent a match.

    :param string: text to test
    :return: ``True`` if ``arg_placeholder_pattern`` matches at position 0
    """
    return arg_placeholder_pattern.match(string) is not None
class GPGNode(Node):
    """
    Base node of a GPG graph.

    A node carries an identifier (``id``, exposed through the ``ID``
    property), a ``pos`` tag and an optional ``confidence`` value.
    Hashing and equality are based on the identifier only.
    """

    def __init__(self, id=None, pos=None, confidence=None, *args, **kwargs):
        """
        :param id: node identifier
        :param pos: tag stored on the node
        :param confidence: optional confidence value
        """
        # NOTE(review): ``Node.__init__`` is not invoked here (as in the
        # original); extra ``args``/``kwargs`` are accepted and ignored.
        self.id = id
        self.pos = pos
        self.confidence = confidence

    @property
    def ID(self):
        """
        :return: the node identifier
        """
        return self.id

    @ID.setter
    def ID(self, id):
        """
        Set the node identifier.
        """
        self.id = id

    def __hash__(self):
        """
        :return: hash of the node identifier
        """
        return hash(self.ID)

    def __eq__(self, another):
        """
        Two nodes are equal when their identifiers are equal.

        :param another: object to compare with
        :return: comparison result, or ``NotImplemented`` when ``another``
            exposes no ``ID`` (so ``node == None`` is ``False`` instead of
            raising ``AttributeError``)
        """
        try:
            other_id = another.ID
        except AttributeError:
            return NotImplemented
        return self.ID == other_id

    def copy(self):
        """
        :return: a deep copy of this node
        """
        return copy.deepcopy(self)
def get_start(x):
    """
    Return the start position of a span.

    :param x: either a ``(start, end)`` tuple/list or a single integer position
    :return: the first element of the pair, or the integer itself
    :raises ValueError: for any other kind of value
    """
    if isinstance(x, (tuple, list)):
        return x[0]
    if isinstance(x, int):
        return x
    raise ValueError('unexpected span')
def standardize_spans(spans):
    """
    Normalize a sequence of spans into a hashable tuple.

    :param spans: sequence whose items are an integer position, a string
        symbol, or a two-element tuple/list
    :return: tuple in which every integer ``i`` became ``(i, i)``, every
        pair became a tuple, and strings are kept unchanged; adjacent
        spans are *not* merged
    :raises AssertionError: if a tuple/list item does not have two elements
    :raises ValueError: if an item has any other type
    """
    standardized = []
    for span in spans:
        if isinstance(span, int):
            standardized.append((span, span))
        elif isinstance(span, str):
            standardized.append(span)
        elif isinstance(span, (tuple, list)):
            assert len(span) == 2
            standardized.append(tuple(span))
        else:
            raise ValueError('Invalid span: {}'.format(span))
    return tuple(standardized)
def readable_spans(spans):
    """
    Produce a compact, human-friendly version of a span sequence.

    :param spans: iterable of integers, strings and ``(start, end)`` pairs
    :return: tuple in which a pair with ``start == end`` is collapsed to
        that single position; every other item is kept as it is
    """
    readable = []
    for span in spans:
        if not isinstance(span, (int, str)):
            start, end = span
            if start == end:
                span = start
        readable.append(span)
    return tuple(readable)
class GPGPhraseNode(GPGNode):
    """
    Node covering words of a sentence.

    ``spans`` is a tuple whose items are either inclusive ``(start, end)``
    word-index pairs or string symbols. It is read-only on the node; the
    setter raises and points to ``GPGraph.modify_node_spans``.
    """

    def __init__(self, spans=None, id=None, pos=None, confidence=None, *args, **kwargs):
        """
        :param spans: spans covered by the node; standardized with
            ``standardize_spans`` unless ``None``
        :param id: node identifier
        :param pos: tag stored on the node
        :param confidence: optional confidence value
        """
        super().__init__(id=id, pos=pos, confidence=confidence, *args, **kwargs)
        self._spans = standardize_spans(spans) if spans is not None else None
        self.contexts = list()

    @staticmethod
    def merge_continuous_spans(spans):
        """
        Merge neighbouring index spans that touch each other.

        Two consecutive items ``(a, b)`` and ``(c, d)`` are merged into
        ``(a, d)`` when ``b + 1 == c``. A string item is kept in place and
        breaks the run, so spans on either side of it are never merged.

        :param spans: iterable of integers, strings and ``(start, end)`` pairs
        :return: tuple of merged spans
        """
        merged = []
        for span in spans:
            if isinstance(span, str):
                merged.append(span)
                continue
            if isinstance(span, int):
                span = (span, span)
            start, end = span
            previous = merged[-1] if merged else None
            if previous is not None and not isinstance(previous, str) and previous[1] + 1 == start:
                # extend the run that ends right before this span
                merged[-1] = (previous[0], end)
            else:
                merged.append((start, end))
        return tuple(merged)

    def has_symbols(self):
        """
        :return: ``True`` if at least one span is a string symbol
        """
        return any(isinstance(span, str) for span in self._spans or ())

    @property
    def spans(self):
        """
        :return: the standardized spans of this node (``None`` if never set)
        """
        return self._spans

    @spans.setter
    def spans(self, spans):
        """
        Direct assignment is refused.

        :raises Exception: always
        """
        raise Exception("spans should not be set directly. please use GPGraph.modify_node_spans")

    @property
    def readable_spans(self):
        """
        :return: spans with single-word pairs collapsed to one integer
        """
        return readable_spans(self._spans)

    def __str__(self):
        """
        :return: comma-separated string form of the spans
        """
        return ",".join(map(str, self.spans or ()))

    def __contains__(self, word_id):
        """
        :param word_id: word index, or a string symbol
        :return: ``True`` if a string span equals ``word_id`` or an index
            span includes it
        """
        is_symbol = isinstance(word_id, str)
        for span in self._spans or ():
            if isinstance(span, str):
                if span == word_id:
                    return True
            elif not is_symbol:
                # a string can never lie inside an integer range; comparing
                # them would raise TypeError
                start, end = span
                if start <= word_id <= end:
                    return True
        return False

    def words(self, with_aux=True):
        """
        Iterate over the content of the spans in order.

        :param with_aux: when true, string symbols are yielded as well
        :return: generator of word indexes (and symbols if requested)
        """
        for span in self._spans or ():
            if isinstance(span, str):
                if with_aux:
                    yield span
            else:
                start, end = span
                yield from range(start, end + 1)

    def indexes(self, with_aux=True):
        """
        Iterate over the word indexes of the spans; string symbols are skipped.

        :param with_aux: accepted for signature parity with ``words``; unused
        :return: generator of word indexes
        """
        for span in self._spans or ():
            if not isinstance(span, str):
                start, end = span
                yield from range(start, end + 1)
class GPGAuxNode(GPGNode):
    """
    Auxiliary node identified by a text ``label`` rather than by word spans.
    """

    def __init__(self, label=None, id=None, pos=None, confidence=None, *args, **kwargs):
        """
        :param label: text label of the node
        :param id: node identifier
        :param pos: tag stored on the node
        :param confidence: optional confidence value
        """
        super().__init__(id=id, pos=pos, confidence=confidence, *args, **kwargs)
        self.label = label
        self.contexts = list()

    def __str__(self):
        """
        :return: the label
        """
        return self.label
class GPGTextNode(GPGNode):
    """
    Node that stores its text directly.
    """

    def __init__(self, text, pos, confidence=None):
        """
        :param text: text of the node
        :param pos: tag stored on the node
        :param confidence: optional confidence value
        """
        super().__init__(pos=pos, confidence=confidence)
        self.text = text

    def __str__(self):
        """
        :return: the node text
        """
        return self.text
import networkx as nx


class GPGEdge(Edge):
    """
    Labelled edge between a pair of nodes.

    Behaves like a single relation (that is, a string) for code
    compatibility: ``str(edge)`` is the label and ``value`` aliases it.
    """

    def __init__(self, label=None, mod=False, confidence=None):
        """
        :param label: relation label
        :param mod: flag stored on the edge
        :param confidence: optional confidence value
        """
        super().__init__()
        self.label = label
        self.mod = mod
        self.contexts = []
        self.confidence = confidence

    def __str__(self):
        """
        :return: the label
        """
        return self.label

    def __bool__(self):
        """
        :return: ``True`` when the edge has a label
        """
        return self.label is not None

    @property
    def value(self):
        """
        :return: the label
        """
        return self.label

    @value.setter
    def value(self, value):
        """
        Set the label.
        """
        self.label = value
| def _empty_hook(): | |
| """ | |
| _empty_hook | |
| """ | |
| pass | |
class GraphRootMixin:
    """
    Mixin adding a cached ``root`` to a graph.

    The host class must provide ``nodes()`` and ``g`` (whose ``in_degree``
    is indexable by node ID).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.root_ = None

    @property
    def root(self):
        """
        Return the root, computing and caching it on first access.

        * roots = [nodes with zero in-degree].
        * if only one root found, return it, whether it is virtual or not
        * else return the node with label == 'Root', that is, the virtual root

        :return: the root node
        :raises Exception: if neither rule selects exactly one node
        """
        # getattr: ``root_`` is missing when this mixin's __init__ was not
        # reached through the host class' super() chain
        if getattr(self, "root_", None) is None:
            self.update_root()
        return self.root_

    @root.setter
    def root(self, value):
        """
        Override the cached root.
        """
        self.root_ = value

    def update_root(self):
        """
        Recompute the cached root from the current graph.

        :raises Exception: if there is not exactly one zero-in-degree node
            and not exactly one auxiliary node labelled ``'Root'``
        """
        topo_roots = [node for node in self.nodes()
                      if self.g.in_degree[node.ID] == 0]
        if len(topo_roots) == 1:
            self.root_ = topo_roots[0]
            return

        virtual_roots = [node for node in self.nodes()
                         if isinstance(node, GPGAuxNode) and node.label == 'Root']
        if len(virtual_roots) != 1:
            raise Exception("Bad Graph with multiple roots or no roots. ")
        self.root_ = virtual_roots[0]
class TextGPGraph(DirectedGraph, LearnableGraph, GraphRootMixin):
    """
    Graph whose nodes are text nodes or auxiliary nodes.
    """

    def add_node(self, node: Union[GPGTextNode, GPGAuxNode], reuse_id=False):
        """
        Add a node to the graph.

        :param node: a ``GPGTextNode`` or ``GPGAuxNode``
        :param reuse_id: forwarded to the base ``add_node``
        :return: whatever the base ``add_node`` returns
        """
        assert isinstance(node, (GPGTextNode, GPGAuxNode)), type(node)
        return super().add_node(node, reuse_id=reuse_id)

    def add_edge(self, ni, nj, e):
        """
        Add an edge to the graph.

        :param ni: start node
        :param nj: end node
        :param e: a ``GPGEdge``, or a string used as the label of a new one
        """
        edge = GPGEdge(label=e) if isinstance(e, str) else e
        return super().add_edge(ni, nj, edge)

    def node_text(self, node, interval=" "):
        """
        Return the text of a node.

        :param node: node to describe
        :param interval: unused here
        :return: the label of an auxiliary node, otherwise ``node.text``
        """
        if isinstance(node, GPGAuxNode):
            return node.label
        return node.text

    def add_relation(self, ni, nj, rel):
        """
        Add an edge built from a relation.

        :param ni: start node
        :param nj: end node
        :param rel: label string, or a ``GPGEdge`` whose label and
            confidence are copied into a new edge
        :raises Exception: if ``rel`` is neither of those
        """
        if isinstance(rel, str):
            edge = GPGEdge(rel)
        elif isinstance(rel, GPGEdge):
            edge = GPGEdge(rel.label, confidence=rel.confidence)
        else:
            raise Exception("Unknown rel type")
        return super().add_edge(ni, nj, edge)

    def remove_relation(self, ni, nj):
        """
        Remove the edge between two nodes.
        """
        super().remove_edge_between(ni, nj)
class GPGraph(DirectedGraph, LearnableGraph, GraphRootMixin):
    """
    Directed graph over phrase nodes (word spans) and auxiliary nodes.

    Besides the underlying graph it keeps ``meta`` (a dict), ``words``
    (the sentence tokens) and ``spans2node``, an index from standardized
    spans to the phrase node carrying them.
    """

    def __init__(self, g=None):
        """
        :param g: optional graph handed to the base class; when it is a
            ``GPGraph`` its ``meta``, ``words``, cached root and
            ``spans2node`` are shared with the new instance (not copied)
        """
        super().__init__(g=g)
        self.meta = dict()
        self.words = []
        self.context_hook = _empty_hook
        self.spans2node = dict()
        if not hasattr(self, "root_"):
            # presumably set by GraphRootMixin.__init__; make sure it exists
            # even if the base classes do not chain to it
            self.root_ = None
        if isinstance(g, GPGraph):
            self.meta = g.meta
            self.words = g.words
            # the root cache attribute is ``root_`` (the original wrote to
            # ``_root``, which nothing reads)
            self.root_ = getattr(g, "root_", None)
            self.spans2node = g.spans2node

    @staticmethod
    def _build_spans_index(graph):
        """Build a spans -> phrase node mapping from the nodes of ``graph``."""
        return {node.spans: node
                for node in graph.nodes()
                if isinstance(node, GPGPhraseNode)}

    def __copy__(self):
        """
        :return: shallow copy; ``meta`` and ``words`` are shallow-copied and
            the span index is rebuilt from the copy's own nodes
        """
        copied = DirectedGraph.__copy__(self)
        copied.meta = copy.copy(self.meta)
        copied.words = copy.copy(self.words)
        copied.context_hook = self.context_hook
        copied.spans2node = GPGraph._build_spans_index(copied)
        return copied

    def __deepcopy__(self, memodict=None):
        """
        :param memodict: memo dictionary supplied by ``copy.deepcopy`` (unused)
        :return: deep copy; the span index is rebuilt so that it refers to
            the copied nodes
        """
        copied = DirectedGraph.__deepcopy__(self)
        copied.meta = copy.deepcopy(self.meta)
        copied.words = copy.deepcopy(self.words)
        copied.context_hook = copy.deepcopy(self.context_hook)
        copied.spans2node = GPGraph._build_spans_index(copied)
        return copied

    def set_words(self, words):
        """
        :param words: list of sentence tokens
        :raises Exception: if ``words`` is not a list
        """
        if not isinstance(words, list):
            raise Exception("words input not list.")
        self.words = words

    def add_words(self, words, pos=None):
        """
        Get or create the phrase node covering the given word positions.

        :param words: word positions, converted with ``positions2spans``
        :param pos: tag for a newly created node
        :return: the phrase node
        """
        return self.add_spans(positions2spans(words), pos)

    def set_context_hook(self, hook):
        """
        :param hook: callable returning the contexts to attach to nodes and
            edges added afterwards (or a falsy value for none)
        """
        self.context_hook = hook

    def clear_context_hook(self):
        """
        Restore the default hook, which attaches nothing.
        """
        self.context_hook = _empty_hook

    def add_node(self, node, reuse_id=False):
        """
        Add a phrase or auxiliary node.

        :param node: ``GPGPhraseNode`` or ``GPGAuxNode``
        :param reuse_id: forwarded to the base ``add_node``
        :return: whatever the base ``add_node`` returns
        :raises Exception: if a phrase node with the same spans is indexed
        """
        assert isinstance(node, (GPGPhraseNode, GPGAuxNode)), f"Invalid node type: {type(node)}"
        context = self.context_hook()
        if context:
            node.contexts.extend(context)
        if not isinstance(node, GPGPhraseNode):
            return super().add_node(node, reuse_id=reuse_id)

        node._spans = standardize_spans(node.spans)
        if node.spans in self.spans2node:
            raise Exception("Repeated node found: {}".format(node.spans))
        added = super().add_node(node, reuse_id=reuse_id)
        self.spans2node[node.spans] = added
        return added

    def modify_node_spans(self, node, spans):
        """
        Change the spans of a node and keep the span index in sync.

        :param node: phrase node whose current spans are indexed
        :param spans: new spans (standardized before being stored)
        """
        assert node.spans in self.spans2node, f"Node spans {node.spans} not found in spans2node: {self.spans2node}"
        if self.spans2node[node.spans] == node:
            del self.spans2node[node.spans]
        node._spans = standardize_spans(spans)
        self.spans2node[node.spans] = node

    def remove_node(self, node):
        """
        Remove a node; a phrase node is also dropped from the span index.

        :param node: node to remove
        :return: whatever the base ``remove_node`` returns
        """
        if isinstance(node, GPGPhraseNode):
            indexed = self.spans2node.get(node.spans)
            # only drop the entry that really belongs to this node, the same
            # check modify_node_spans performs
            if indexed is not None and indexed == node:
                del self.spans2node[node.spans]
        return super().remove_node(node)

    def add_spans(self, spans, pos=None):
        """
        Get or create the phrase node for the given spans.

        :param spans: spans of the node
        :param pos: tag for a newly created node
        :return: the phrase node indexed under the standardized spans
        """
        spans = standardize_spans(spans)
        if spans not in self.spans2node:
            added_node = GPGPhraseNode(spans)
            added_node.pos = pos
            self.spans2node[spans] = super().add_node(added_node)
        return self.spans2node[spans]

    def has_node(self, node: GPGNode):
        """
        :param node: node to look for
        :return: ``True`` if the underlying graph has the node's ID
        """
        return bool(self.g.has_node(node.ID))

    def has_word(self, words):
        """
        :param words: word positions
        :return: ``True`` if some phrase node has exactly these spans
        """
        # node spans are stored standardized, so compare in the same form
        spans = standardize_spans(positions2spans(words))
        return any(isinstance(node, GPGPhraseNode) and node.spans == spans
                   for node in self.nodes())

    def get_node_by_words(self, positions):
        """
        :param positions: word positions
        :return: the indexed phrase node for these positions, or ``None``
        """
        return self.get_node_by_spans(positions2spans(positions))

    def get_node_by_spans(self, spans):
        """
        :param spans: spans to look up
        :return: the indexed phrase node for these spans, or ``None``
        """
        return self.spans2node.get(standardize_spans(spans))

    def has_relation(self, node1: GPGNode,
                     node2: GPGNode,
                     direct_link=True):
        """
        Test whether two nodes are connected, in either direction.

        :param node1: first node (``None`` gives ``False``)
        :param node2: second node (``None`` gives ``False``)
        :param direct_link: when true only a single edge counts, otherwise
            any directed path counts
        :return: ``True`` if the nodes are connected as requested
        """
        if node1 is None or node2 is None:
            return False
        first, second = node1.ID, node2.ID
        if direct_link:
            return bool(self.g.has_edge(first, second) or self.g.has_edge(second, first))
        return bool(nx.has_path(self.g, first, second) or nx.has_path(self.g, second, first))

    def add_aux(self, label, pos=None):
        """
        Create and add an auxiliary node.

        :param label: label of the node
        :param pos: tag of the node
        :return: the created node
        """
        node = GPGAuxNode(label)
        node.pos = pos
        self.add_node(node)
        return node

    def get_aux(self, label):
        """
        :param label: label to look for
        :return: generator over the auxiliary nodes carrying that label
        """
        for node_id in self.g.nodes:
            node = self.get_node(node_id)
            if isinstance(node, GPGAuxNode) and node.label == label:
                yield node

    def get_edge(self, node1, node2):
        """
        :param node1: start node
        :param node2: end node
        :return: the edge from the base ``get_edge``, or ``None`` if that
            lookup raises
        """
        try:
            return super().get_edge(node1, node2)
        except Exception:
            # best-effort lookup; no longer swallows KeyboardInterrupt/SystemExit
            return None

    def spans(self):
        """
        :return: list of all index spans (as tuples) of the phrase nodes,
            sorted by start position; string symbols are left out
        """
        spans = []
        for node in self.nodes():
            if not isinstance(node, GPGPhraseNode):
                continue
            for span in node.spans:
                if isinstance(span, str):
                    continue
                if isinstance(span, int):
                    span = (span, span)
                spans.append(tuple(span))
        spans.sort(key=lambda x: x[0])
        return spans

    def parents_on_path(self, node, ancestor):
        """
        :param node: target node
        :param ancestor: node the paths start from
        :return: generator yielding, for each simple path from ``ancestor``
            to ``node``, the node right before ``node``
        """
        for path in nx.all_simple_paths(self.g, ancestor.ID, node.ID):
            if len(path) >= 2:  # a one-node path has no parent on it
                yield self.get_node(path[-2])

    def paths(self, node1, node2):
        """
        :param node1: start node
        :param node2: end node
        :return: generator of simple paths, each a list of nodes
        """
        for path in nx.all_simple_paths(self.g, node1.ID, node2.ID):
            yield [self.get_node(x) for x in path]

    def replace(self, old_node, new_node):
        """
        Replace ``old_node`` by ``new_node``, moving all its edges over.

        :param old_node: node to remove
        :param new_node: node taking its place (added if not in the graph)
        :raises Exception: if both nodes have the same ID
        """
        if new_node.ID == old_node.ID:
            raise Exception("Bad business logic: cannot replace a node with itself")
        if not self.has_node(new_node):
            self.add_node(new_node)
        for child, rel in self.children(old_node):
            self.g.add_edge(new_node.ID, child.ID, Edge=rel)
        for parent, rel in self.parents(old_node):
            self.g.add_edge(parent.ID, new_node.ID, Edge=rel)
        self.remove_node(old_node)

    def add_edge(self, start_node, end_node, edge):
        """
        Add an edge between two nodes already in the graph.

        :param start_node: start node
        :param end_node: end node
        :param edge: a ``GPGEdge``, or a string used as the label of a new one
        :raises ValueError: if either node is not in the graph
        """
        if isinstance(edge, str):
            # convert first: the context below is attached to an edge object
            edge = GPGEdge(label=edge)
        if start_node.ID not in self.g.nodes():
            raise ValueError('start node not in graph')
        if end_node.ID not in self.g.nodes():
            raise ValueError('end node not in graph')
        context = self.context_hook()
        if context:
            edge.contexts.extend(context)
        edge.start = start_node.ID
        edge.end = end_node.ID
        self.g.add_edge(start_node.ID, end_node.ID, Edge=edge)

    def add_relation(self, node1, node2, rel, confidence=None):
        """
        Add an edge built from a relation.

        :param node1: start node
        :param node2: end node
        :param rel: label string, or a ``GPGEdge`` whose label is copied
        :param confidence: confidence of the new edge; defaults to the
            confidence of ``rel`` when that is a ``GPGEdge``
        :raises Exception: if ``rel`` is neither a string nor a ``GPGEdge``
        """
        if isinstance(rel, str):
            edge = GPGEdge(rel, confidence=confidence)
        elif isinstance(rel, GPGEdge):
            if confidence is None:
                confidence = rel.confidence
            edge = GPGEdge(rel.label, confidence=confidence)
        else:
            raise Exception("Unknown rel type")
        self.add_edge(node1, node2, edge)

    def merge_continuous_spans(self):
        """
        Merge touching spans inside every non-auxiliary node.
        """
        for node in list(self.nodes()):
            if isinstance(node, GPGAuxNode):
                continue
            merged = GPGPhraseNode.merge_continuous_spans(node.spans)
            self.modify_node_spans(node, merged)

    def remove_relation(self, node1, node2):
        """
        Remove the edge between two nodes.
        """
        super().remove_edge_between(node1, node2)

    def relations(self):
        """
        :return: the edges of the graph, as given by the base ``edges()``
        """
        return super().edges()

    def node_text(self, node, interval=""):
        """
        Return the text of a node.

        :param node: node to describe
        :param interval: separator placed between the pieces of a phrase node
        :return: for a phrase node, its symbols and the words of its index
            spans joined by ``interval``; otherwise the node label
        """
        if not isinstance(node, GPGPhraseNode):
            return node.label

        node_texts = []
        for span in node.spans:
            if isinstance(span, str):
                node_texts.append(span)
            elif isinstance(span, tuple) and isinstance(span[0], str):
                node_texts.append(span[0])
            else:
                start, end = span
                node_texts.extend(self.words[i] for i in range(start, end + 1))
        return interval.join(node_texts)

    def topological_sort(self):
        """
        :return: generator over the nodes in topological order
        """
        for node_id in nx.topological_sort(self.g):
            yield self.get_node(node_id)

    @staticmethod
    def parse(json_obj):
        """
        Parse a JSON object into a GPGraph.

        Nodes may be lists ``[id, spans, pos, confidence?]`` or dicts with
        ``spans``, ``type`` and optional ``confidence``; edges may be lists
        ``[start, label, end, confidence?]`` or dicts with ``start``,
        ``end``, ``label`` and optional ``confidence``.

        :param json_obj: dict, or JSON string, with the keys ``meta``,
            ``words`` and ``graph`` (``nodes`` and ``edges``)
        :return: GPGraph instance
        :raises TypeError: for an unsupported word, node or edge format
        """
        if isinstance(json_obj, str):
            json_obj = json.loads(json_obj)
        assert isinstance(json_obj, dict)

        graph = GPGraph()
        graph.meta = json_obj["meta"]
        words = json_obj["words"]
        graph.words = [0] * len(words)
        if words:  # an empty sentence has no first word to inspect
            if isinstance(words[0], list):
                for word_id, word in words:
                    graph.words[word_id] = word
            elif isinstance(words[0], str):
                for word_id, word in enumerate(words):
                    graph.words[word_id] = word
            else:
                raise TypeError("Invalid word format")

        if len(json_obj["graph"]["nodes"]) == 0:
            return graph

        nodes = dict()
        for node_info in json_obj["graph"]["nodes"]:
            if isinstance(node_info, list):
                node_id, spans, pos = node_info[:3]
                confidence = node_info[3] if len(node_info) > 3 else None
            elif isinstance(node_info, dict):
                spans = node_info['spans']
                pos = node_info['type']
                node_id = len(nodes)  # dict nodes are numbered by position
                confidence = node_info.get('confidence')
            else:
                raise TypeError("Invalid node format")

            spans_is_seq = isinstance(spans, (list, tuple))
            if isinstance(spans, str) or (spans_is_seq and len(spans) == 1 and isinstance(spans[0], str)):
                aux_node = GPGAuxNode(spans[0] if spans_is_seq else spans)
                aux_node.ID = node_id
                node = graph.add_node(aux_node, reuse_id=True)
            else:
                node = GPGPhraseNode(spans)
                node.ID = node_id
                graph.add_node(node, reuse_id=True)
            node.pos = pos
            node.confidence = confidence
            nodes[node_id] = node

        graph.node_id_base = max(node.ID for node in graph.nodes()) + 1

        for edge_info in json_obj["graph"]["edges"]:
            if isinstance(edge_info, list):
                n1_id, edge_label, n2_id = edge_info[:3]
                confidence = edge_info[3] if len(edge_info) > 3 else None
            elif isinstance(edge_info, dict):
                n1_id = edge_info['start']
                n2_id = edge_info['end']
                edge_label = edge_info['label']
                confidence = edge_info.get('confidence')
            else:
                raise TypeError("Invalid edge format")
            graph.add_relation(nodes[n1_id], nodes[n2_id], edge_label, confidence)
        return graph

    def data(self):
        """
        Build the serializable form of the graph.

        :return: dict with ``meta``, ``words`` and ``graph``; nodes and
            edges are dicts, edges refer to nodes by their position in the
            node list, and ``confidence`` is present only when set
        """
        node2idx = dict()
        node_items = []
        for idx, node in enumerate(self.nodes()):
            node2idx[node.ID] = idx
            spans = node.spans if isinstance(node, GPGPhraseNode) else node.label
            item = {'spans': spans, 'type': node.pos}
            if node.confidence is not None:
                item['confidence'] = node.confidence
            node_items.append(item)

        edge_items = []
        for n1, edge, n2 in self.relations():
            item = {'start': node2idx[n1.ID], 'end': node2idx[n2.ID], 'label': edge.label}
            if edge.confidence is not None:
                item['confidence'] = edge.confidence
            edge_items.append(item)

        graph = dict()
        graph["nodes"] = node_items
        graph["edges"] = edge_items
        data = dict()
        data["meta"] = self.meta
        data['words'] = self.words
        data['graph'] = graph
        return data

    def save(self, output_file_path):
        """
        Write ``data()`` as JSON to a file.

        :param output_file_path: path of the file to write
        """
        data = self.data()
        with open(output_file_path, "w", encoding="UTF8") as output_file:
            json.dump(data, output_file, cls=CompactJSONEncoder, ensure_ascii=False)
class GPGraphVisualizer(GraphVisualizer):
    """
    Visualizer producing record-shaped node labels for a GPGraph.
    """

    def __init__(self, debug=False):
        """
        :param debug: when true, the contexts of nodes and edges are added
            to their labels
        """
        self.debug = debug

    def escape(self, node_text):
        """
        Backslash-escape the characters that have a meaning in record labels.

        :param node_text: raw text
        :return: text with ``{ } < > " |`` escaped
        """
        # "|" separates the fields of a record label, so it must be escaped too
        special_tokens = '{}<>"|'
        for token in special_tokens:
            node_text = node_text.replace(token, "\\" + token)
        return node_text

    def node_label(self, graph, node, no_text=False, *args, **kwargs):
        """
        Build the label of a node: ID, text, spans (phrase nodes only),
        ``pos`` and, if present, ``concept``, separated by `` | ``.

        :param graph: graph providing ``node_text``
        :param node: node to label
        :param no_text: when true the text field is left empty
        :return: the label string
        """
        node_text = "" if no_text else self.escape(graph.node_text(node))
        components = [str(node.ID), node_text]
        if isinstance(node, GPGPhraseNode):
            components.append(self.escape(str(tuple(node.readable_spans))))
        # str() keeps a missing pos as 'None' and tolerates non-string values
        components.append(str(node.pos))
        if hasattr(node, 'concept'):
            components.append(str(node.concept))
        label = " | ".join(components)
        if self.debug and node.contexts:
            label = "{{{0}}}".format(" | ".join([label, "\n".join(node.contexts)]))
        return label

    def node_style(self, graph, node, *args, **kwargs):
        """
        :return: style attributes of a node (grey filled record)
        """
        return {'shape': "record", 'fillcolor': "grey", 'style': "filled"}

    def edge_label(self, graph, edge, *args, **kwargs):
        """
        :param edge: edge to label
        :return: the edge label, with its contexts appended in debug mode
        """
        edge_label = edge.label
        if self.debug and edge.contexts:
            edge_label = "{{{0}}}".format("|".join([edge_label, "\n".join(edge.contexts)]))
        return edge_label

    def edge_style(self, graph, edge, *args, **kwargs):
        """
        :return: style attributes of an edge (none)
        """
        return {}
from typing import List
import string


def get_word_positions(sentence: str, words: List[str]) -> List[tuple]:
    """
    Get the starting and ending character positions for each word in the sentence.

    Words are searched left to right, each one starting from the end of the
    previous match.

    Args:
        sentence (str): The original sentence.
        words (List[str]): The list of words in the sentence.
    Returns:
        List[tuple]: One ``(start, end)`` tuple per word; ``end`` is exclusive.
    Raises:
        ValueError: If a word cannot be found after the previous match.
    """
    positions = []
    cursor = 0
    for word in words:
        start_pos = sentence.find(word, cursor)
        if start_pos == -1:
            raise ValueError(f"Word '{word}' in [{words}]not found in the sentence [{sentence}] starting from position {cursor}.")
        cursor = start_pos + len(word)
        positions.append((start_pos, cursor))
    return positions
def get_node_text(node: GPGPhraseNode, sentence: str, word_positions: List[tuple]):
    """
    Rebuild the text of a phrase node from the original sentence.

    :param node: node whose ``words(with_aux=True)`` gives word indexes and
        string symbols
    :param sentence: the original sentence
    :param word_positions: per-word ``(start, end)`` character offsets into
        ``sentence``
    :return: concatenation of the pieces; a word preceded by whitespace in
        the sentence gets a leading space, symbols are added as they are
    """
    pieces = []
    for w in node.words(with_aux=True):
        if not isinstance(w, int):
            pieces.append(w)
            continue
        word_pos = word_positions[w]
        word_text = sentence[word_pos[0]: word_pos[1]]
        if word_pos[0] > 0 and sentence[word_pos[0] - 1] in string.whitespace:
            word_text = " " + word_text
        pieces.append(word_text)
    return "".join(pieces)
class GraphValidator(object):
    """
    Base class of graph validators; this base implementation checks nothing.
    """

    def name(self):
        """
        :return: the name of the validator (``None`` in this base class)
        """
        return None

    def validate(self, graph: "GPGraph"):
        """
        check whether the graph is valid, and return the severity of the error, and details of the error
        The severity of the error:
        * is a float between 0 and 1. The larger the value, the more severe the error.
        * 0 means the graph is perfect in the aspect of the check

        This base implementation does nothing and returns ``None``.
        """
        pass