"""
Graph data problem designed for graph2graph learning
"""
import copy
import itertools
from typing import List

import networkx as nx


class Node:
    """
    Base class for graph nodes.

    A node is identified by its ``ID``; hashing and equality are both defined
    in terms of that ID. Subclasses may override the ``ID`` property.
    """

    @property
    def ID(self):
        """
        :return: the identifier assigned to this node, or None if none was set
        """
        # ``getattr`` keeps this safe for subclasses whose ``__init__`` never
        # touched the ID.
        return getattr(self, "_id", None)

    @ID.setter
    def ID(self, value):
        """
        :param value: identifier to store on this node
        """
        self._id = value

    def __hash__(self):
        """
        :return: hash of the node ID
        """
        return hash(self.ID)

    def __eq__(self, another):
        """
        :param another: object to compare with
        :return: True when both objects carry the same ID; NotImplemented when
            ``another`` has no ``ID`` attribute
        """
        if not hasattr(another, "ID"):
            return NotImplemented
        return self.ID == another.ID


class Edge(object):
    """
    Base class for graph edges.
    """

    def __init__(self, start=None, end=None):
        """
        :param start: ID of the start node (overwritten by ``Graph.add_edge``)
        :param end: ID of the end node (overwritten by ``Graph.add_edge``)
        """
        self.start = start
        self.end = end


class Graph(object):
    """
    Thin wrapper around a networkx graph.

    Every networkx node stores its ``Node`` object under the ``"Node"``
    attribute and every networkx edge stores its ``Edge`` object under the
    ``"Edge"`` attribute.
    """

    def __init__(self, g=None):
        """
        :param g: None for an empty graph, or an ``nx.Graph`` / ``Graph`` whose
            content is deep-copied into the new instance
        """
        super().__init__()

        if g is None:
            self.g = nx.Graph()
            # next ID handed out by ``add_node`` when ``reuse_id`` is False
            self.node_id_base = 0
            return

        assert isinstance(g, (nx.Graph, Graph))

        if isinstance(g, Graph):
            self.g = copy.deepcopy(g.g)
            self.node_id_base = g.node_id_base
        else:
            self.g = copy.deepcopy(g)
            self.node_id_base = max(g.nodes, default=0) + 1

    def nodes(self):
        """
        :return: generator over the ``Node`` objects of the graph
        """
        for node_id in self.g.nodes:
            yield self.get_node(node_id)

    def edges(self):
        """
        :return: generator over ``(start_node, edge, end_node)`` triples
        """
        for start_id, end_id in self.g.edges():
            edge = self.g[start_id][end_id]["Edge"]
            yield self.get_node(start_id), edge, self.get_node(end_id)

    def has_node(self, node):
        """
        :param node: node to look up (by ID)
        :return: True when a node with that ID is in the graph
        """
        return node.ID in self.g

    def has_edge(self, node1, node2):
        """
        :param node1: start node
        :param node2: end node
        :return: True when ``node2`` is adjacent to ``node1``
        """
        return node2.ID in self.g[node1.ID]

    def number_of_nodes(self):
        """
        :return: number of nodes
        """
        return nx.number_of_nodes(self.g)

    def number_of_edges(self):
        """
        :return: number of edges
        """
        return nx.number_of_edges(self.g)

    def get_node(self, node_id):
        """
        :param node_id: ID of the wanted node
        :return: the ``Node`` object stored for ``node_id``
        :raises KeyError: when the ID is not in the graph
        """
        return self.g.nodes[node_id]["Node"]

    def remove_node(self, node):
        """
        :param node: node to remove (looked up by ID)
        """
        self.g.remove_node(node.ID)

    def remove_edge(self, edge):
        """
        :param edge: edge to remove; its ``start`` / ``end`` IDs locate it
        """
        self.g.remove_edge(edge.start, edge.end)

    def remove_edge_between(self, node1, node2):
        """
        Remove the edge between two nodes if there is one.

        :param node1: start node
        :param node2: end node
        """
        if self.g.has_edge(node1.ID, node2.ID):
            self.g.remove_edge(node1.ID, node2.ID)

    def get_edge(self, node1, node2):
        """
        :param node1: start node or its ID
        :param node2: end node or its ID
        :return: the ``Edge`` object stored between the two nodes
        :raises Exception: when there is no such edge
        """
        if isinstance(node1, Node):
            node1 = node1.ID
        if isinstance(node2, Node):
            node2 = node2.ID

        try:
            return self.g[node1][node2]["Edge"]
        except KeyError as err:
            raise Exception(
                "There is no edge between node {0} and {1}".format(node1, node2)
            ) from err

    def add_node(self, n, reuse_id=False):
        """
        :param n: node to insert
        :param reuse_id: keep ``n.ID`` instead of allocating a fresh ID
        :return: the inserted node
        """
        # NOTE(review): with reuse_id=True ``node_id_base`` is not advanced,
        # so a later auto-assigned ID may coincide with a reused one.
        if reuse_id:
            node_id = n.ID
        else:
            node_id = self.node_id_base
            self.node_id_base += 1

        n.ID = node_id
        self.g.add_node(node_id, Node=n)
        return n

    def add_edge(self, ni, nj, e):
        """
        :param ni: start node or its ID
        :param nj: end node or its ID
        :param e: edge object; its ``start`` / ``end`` are set to the node IDs
        """
        if not isinstance(ni, Node):
            ni = self.get_node(ni)
        if not isinstance(nj, Node):
            nj = self.get_node(nj)

        e.start = ni.ID
        e.end = nj.ID
        self.g.add_edge(ni.ID, nj.ID, Edge=e)

    def neighbors(self, node):
        """
        :param node: node whose adjacency is listed
        :return: generator over ``(edge, neighbor_node)`` pairs
        """
        adjacency = self.g[node.ID]
        for neighbor_id in adjacency:
            yield adjacency[neighbor_id]["Edge"], self.get_node(neighbor_id)

    def connected_components(self):
        """
        :return: generator over lists of nodes, one list per component
        """
        components = nx.algorithms.components.connected_components(self.g)
        for component in components:
            yield [self.get_node(x) for x in component]

    def breadth_first_dag(self, start_node):
        """
        Orient every edge from the node visited earlier in a breadth-first
        search to the node visited later.

        Nodes and edges are copied through their own ``copy()`` methods.

        :param start_node: node the breadth-first search starts from
        :return: a ``DirectedGraph`` with the same nodes and edges
        :raises AssertionError: when node or edge counts differ from this
            graph (e.g. some edges are not reachable from ``start_node``)
        """
        dag = DirectedGraph()
        for node in self.nodes():
            dag.add_node(node.copy(), reuse_id=True)

        ordered_ids = [start_node.ID]
        ordered_ids.extend(v for _, v in nx.bfs_edges(self.g, start_node.ID))
        rank = {node_id: pos for pos, node_id in enumerate(ordered_ids)}

        for u in ordered_ids:
            # Only neighbors visited later than ``u``, in visiting order, so
            # edges are inserted in the same order as a pairwise scan would.
            later = sorted(
                (v for v in self.g[u] if rank.get(v, -1) > rank[u]),
                key=rank.__getitem__,
            )
            node_u = self.get_node(u)
            for v in later:
                node_v = self.get_node(v)
                edge = self.get_edge(node_u, node_v).copy()
                dag.add_edge(node_u, node_v, edge)

        assert self.number_of_nodes() == dag.number_of_nodes()
        assert self.number_of_edges() == dag.number_of_edges()

        return dag

    def breadth_first_tree(self, start_node):
        """
        Build the breadth-first search tree rooted at ``start_node``.

        Nodes are copied through their own ``copy()`` method.

        :param start_node: root of the tree
        :return: a ``DirectedGraph`` holding the tree
        """
        dag = DirectedGraph()
        dag.add_node(start_node.copy(), reuse_id=True)

        def _get_or_copy_node(node_id):
            try:
                return dag.get_node(node_id)
            except KeyError:
                copied_node = self.get_node(node_id).copy()
                dag.add_node(copied_node, reuse_id=True)
                return copied_node

        for u, v in nx.bfs_edges(self.g, start_node.ID):
            node_u = _get_or_copy_node(u)
            node_v = _get_or_copy_node(v)
            # NOTE(review): unlike breadth_first_dag the edge object is shared
            # with this graph, and add_edge rewrites its start/end.
            edge = self.get_edge(u, v)
            dag.add_edge(node_u, node_v, edge)

        return dag

    def __copy__(self):
        """
        :return: a new graph of the same type wrapping ``self.g.copy()``
        """
        copied = type(self)()
        copied.g = self.g.copy()
        copied.node_id_base = self.node_id_base
        return copied

    def __deepcopy__(self, memodict=None):
        """
        :param memodict: memo dictionary of the ``copy.deepcopy`` protocol
        :return: a new graph of the same type with deep-copied nodes and edges
        """
        if memodict is None:
            memodict = {}

        copied = type(self)()
        memodict[id(self)] = copied

        for node in self.nodes():
            new_node = copy.deepcopy(node, memodict)
            new_node = copied.add_node(new_node, reuse_id=True)
            assert new_node.ID == node.ID, \
                "Node ID is not copied correctly {0} {1}".format(new_node.ID, node.ID)

        for s_node, edge, e_node in self.edges():
            copied.add_edge(s_node, e_node, copy.deepcopy(edge, memodict))

        copied.node_id_base = self.node_id_base
        return copied

    def offsprings(self, node, filter=None):
        """
        :param node: node the depth-first traversal starts from
        :param filter: optional predicate; only nodes it accepts are yielded
        :return: generator over nodes in depth-first post-order
        """
        for node_id in nx.dfs_postorder_nodes(self.g, node.ID):
            candidate = self.get_node(node_id)
            if not filter or filter(candidate):
                yield candidate

    def subgraph(self, nodes):
        """
        :param nodes: nodes (or node IDs) to keep
        :return: a new graph of the same type induced by ``nodes``
        """
        node_ids = [n.ID if isinstance(n, Node) else n for n in nodes]

        result = self.__class__()
        result.g = self.g.subgraph(node_ids).copy()
        # Carry the ID counter over (as __copy__ does) so nodes added to the
        # subgraph later do not reuse an ID that is already present.
        result.node_id_base = self.node_id_base
        return result

    def has_path(self, node1, node2):
        """
        :param node1: source node
        :param node2: target node
        :return: True when a path from ``node1`` to ``node2`` exists
        """
        return nx.algorithms.shortest_paths.has_path(self.g, node1.ID, node2.ID)

    def dual(self):
        """
        return the dual graph
        the dual graph is the graph with edges corresponding nodes
        and nodes corresponding edges

        The ``value`` attribute of each edge (None when absent) becomes the
        ``value`` of its dual node, and the ``value`` of each node becomes the
        ``value`` of the dual edges created for it.

        :return: the dual ``Graph``
        :raises AssertionError: when a node has fewer than two incident edges
        """
        dual = Graph()

        # maps (start_id, end_id) in both directions to the dual node
        edge_node_map = {}
        for _, edge, _ in self.edges():
            dual_node = Node()
            dual_node.value = getattr(edge, "value", None)
            dual.add_node(dual_node)
            edge_node_map[(edge.start, edge.end)] = dual_node
            edge_node_map[(edge.end, edge.start)] = dual_node

        for node in self.nodes():
            incident = list(self.g.edges(node.ID))

            # since the end node is added
            assert len(incident) >= 2, "Edge number should larger than 2 " \
                                       "since the end node is added"

            for first, second in itertools.combinations(incident, 2):
                dual_edge = Edge()
                dual_edge.value = getattr(node, "value", None)
                dual.add_edge(edge_node_map[first], edge_node_map[second], dual_edge)

        return dual


class DirectedGraph(Graph):
    """
    Directed Graph
    """

    def __init__(self, g=None):
        """
        :param g: None for an empty directed graph, otherwise as for ``Graph``
        """
        if g is None:
            g = nx.DiGraph()

        super().__init__(g=g)

    def is_connected(self):
        """
        :return: True when the graph is weakly connected
        """
        return nx.algorithms.components.is_weakly_connected(self.g)

    def connected_components(self):
        """
        :return: generator over lists of nodes, one per weakly connected
            component
        """
        components = nx.algorithms.components.weakly_connected_components(self.g)
        for component in components:
            yield [self.get_node(x) for x in component]

    def is_leaf(self, node):
        """
        :param node: node to test
        :return: True when the node has no children
        """
        return self.g.out_degree(node.ID) == 0

    def children(self, node):
        """
        :param node: parent node
        :return: generator over ``(child_node, edge)`` pairs
        """
        for child_id in self.g.successors(node.ID):
            child = self.get_node(child_id)
            yield child, self.get_edge(node, child)

    def offsprings(self, node, filter=None):
        """
        :param node: start node; always yielded first, without filtering
        :param filter: optional predicate applied to the descendants
        :return: generator over ``node`` followed by its accepted descendants
        """
        yield node
        for node_id in nx.descendants(self.g, node.ID):
            candidate = self.get_node(node_id)
            if not filter or filter(candidate):
                yield candidate

    def ancestors(self, node, filter=None):
        """
        :param node: start node; always yielded first, without filtering
        :param filter: optional predicate applied to the ancestors
        :return: generator over ``node`` followed by its accepted ancestors
        """
        yield node
        for node_id in nx.ancestors(self.g, node.ID):
            candidate = self.get_node(node_id)
            if not filter or filter(candidate):
                yield candidate

    def parents(self, node):
        """
        :param node: child node
        :return: generator over ``(parent_node, edge)`` pairs
        """
        for parent_id in self.g.predecessors(node.ID):
            parent = self.get_node(parent_id)
            yield parent, self.get_edge(parent, node)

    def topological_sort(self):
        """
        :return: generator over the nodes in a topological order
        """
        for node_id in nx.topological_sort(self.g):
            yield self.get_node(node_id)


class LearnableGraph(object):
    """
    Mixin that converts a graph into numpy arrays.

    It reads ``self.g`` and ``self.edges()``, so it is meant to be combined
    with ``Graph`` or one of its subclasses.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def node_types(self, orders: List[Node], node_voc_functor):
        """
        :param orders: nodes in the wanted order
        :param node_voc_functor: callable mapping a node to its type value
        :return: numpy array with one entry per node of ``orders``
        """
        import numpy as np

        return np.array([node_voc_functor(x) for x in orders])

    def adjmatrix(self, node_orders: List[Node], edge_voc_functor, empty_id=0):
        """
        :param node_orders: nodes fixing the row / column order
        :param edge_voc_functor: callable mapping an edge to its type value
        :param empty_id: value stored where there is no edge
        :return: ``(in_a, out_a)`` int32 matrices; ``out_a[u, v]`` and
            ``in_a[v, u]`` hold the type of edge ``u -> v``; for undirected
            graphs both matrices are filled symmetrically
        """
        import numpy as np

        n_nodes = len(node_orders)
        node2index = {node: idx for idx, node in enumerate(node_orders)}
        undirected = not nx.is_directed(self.g)

        in_a = np.ones([n_nodes, n_nodes], dtype=np.int32) * empty_id
        out_a = np.ones([n_nodes, n_nodes], dtype=np.int32) * empty_id

        for u, edge, v in self.edges():
            u_idx = node2index[u]
            v_idx = node2index[v]
            e_idx = edge_voc_functor(edge)

            out_a[u_idx, v_idx] = e_idx
            in_a[v_idx, u_idx] = e_idx

            if undirected:
                in_a[u_idx, v_idx] = e_idx
                out_a[v_idx, u_idx] = e_idx

        return (in_a, out_a)

    def to_tensor(self, node_orders: List[Node], node_voc, edge_voc, end_node=None):
        """
        :param node_orders: nodes fixing the order of all outputs
        :param node_voc: callable mapping a node to its type value
        :param edge_voc: callable mapping an edge to its type value
        :param end_node: when truthy, a type value appended as one extra node
            that has no edges
        :return: ``(node_types, a_in, a_out)``
        """
        import numpy as np

        node_types = self.node_types(node_orders, node_voc)
        a_in, a_out = self.adjmatrix(node_orders, edge_voc)

        if end_node:
            node_types = np.concatenate(
                [node_types, np.asarray([end_node], dtype=node_types.dtype)]
            )
            # Pad one zero row and column. An in-place ndarray.resize would
            # re-flow the flat buffer and move existing entries to other cells.
            grow = ((0, 1), (0, 1))
            a_in = np.pad(a_in, grow, mode="constant", constant_values=0)
            a_out = np.pad(a_out, grow, mode="constant", constant_values=0)

        return node_types, a_in, a_out


def valid_alignment(choices):
    """
    Enumerate all ways of picking one element per position such that no
    element is picked twice.

    :param choices: sequence of iterables, the candidates of each position
    :return: generator over tuples with one distinct element per position
    """
    sets = [set(seq) for seq in choices]
    n = len(sets)
    seen = set()
    result = [None] * n

    def _inner(i):
        if i == n:
            yield tuple(result)
            return
        for elt in sets[i] - seen:
            seen.add(elt)
            result[i] = elt
            yield from _inner(i + 1)
            seen.remove(elt)

    yield from _inner(0)


def _respects_partial_order(nx_graph, node_ids):
    """
    :param nx_graph: directed networkx graph
    :param node_ids: node IDs in the order to check
    :return: True when no node is preceded by one of its descendants or
        followed by one of its ancestors
    """
    for pos, node_id in enumerate(node_ids):
        if nx.descendants(nx_graph, node_id).intersection(node_ids[:pos]):
            return False
        if nx.ancestors(nx_graph, node_id).intersection(node_ids[pos + 1:]):
            return False
    return True


def is_valid_topology_sort(dag, node_objs):
    """
    Decide whether ``node_objs`` can be matched, in order, to distinct nodes
    of ``dag`` (comparing against each node's ``object`` attribute) such that
    the order is compatible with the edges of ``dag``.

    :param dag: graph whose nodes expose an ``object`` attribute
    :param node_objs: objects in the order to check
    :return: True when at least one valid matching exists
    """
    target_nodes = list(dag.nodes())
    target_objects = [n.object for n in target_nodes]

    choices = []
    for node_obj in node_objs:
        candidates = [j for j, target in enumerate(target_objects) if target == node_obj]
        if not candidates:
            return False
        choices.append(candidates)

    # valid_alignment only yields tuples of distinct indices, so no extra
    # duplicate check is needed here.
    for align in valid_alignment(choices):
        node_ids = [target_nodes[i].ID for i in align]
        if _respects_partial_order(dag.g, node_ids):
            return True

    return False


def not_isomorphic(graph_a, graph_b):
    """
    Quick check based on ``nx.faster_could_be_isomorphic``.

    :param graph_a: first graph wrapper
    :param graph_b: second graph wrapper
    :return: True when the quick check rules out an isomorphism; False means
        the graphs could be isomorphic, not that they are
    """
    return not nx.faster_could_be_isomorphic(graph_a.g, graph_b.g)


def dot2image(dot_string, file_name=None, program="dot", format=None, return_img=False):
    """
    Render a DOT description with a Graphviz program.

    :param dot_string: graph description in DOT syntax
    :param file_name: output path; a temporary file is used when omitted and
        ``return_img`` is set
    :param program: layout command, optionally followed by extra arguments
    :param format: output format passed to ``-T``; "svg" when not given
    :param return_img: open the rendered file with PIL and return the image
    :return: the PIL image when ``return_img`` is set, otherwise None
    :raises ValueError: when neither ``file_name`` nor ``return_img`` is given
    :raises AssertionError: when the program cannot be run or fails
    """
    import os
    import shlex
    import subprocess
    import tempfile

    if not format:
        format = "svg"
    if not file_name and not return_img:
        raise ValueError("dot2image needs a file_name or return_img=True")

    temp_paths = []
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".dot", delete=False,
                                         encoding="utf-8") as dot_file:
            temp_paths.append(dot_file.name)
            dot_file.write(dot_string)

        use_temp_output = not file_name
        if use_temp_output:
            handle, file_name = tempfile.mkstemp(suffix="." + format)
            os.close(handle)
            temp_paths.append(file_name)

        # An argument list avoids shell quoting problems in file names;
        # shlex.split keeps "program with arguments" strings working.
        command = shlex.split(program) + ["-T", format, dot_file.name, "-o", str(file_name)]
        try:
            return_val = subprocess.run(command).returncode
        except OSError as err:
            raise AssertionError("failed to run {0}".format(program)) from err
        if return_val != 0:
            # AssertionError is what the former ``assert`` raised
            raise AssertionError("{0} exited with status {1}".format(program, return_val))

        if not return_img:
            return None

        from PIL import Image

        image = Image.open(file_name)
        if use_temp_output:
            # read the data before the temporary file is removed
            image.load()
        return image
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass


class GraphVisualizer(object):
    """
    Basic graph visualizer that emits Graphviz DOT.

    Subclasses customize the output by overriding the label / style hooks.
    """

    def node_label(self, graph, node, *args, **kwargs):
        """
        :param graph: graph being drawn
        :param node: node to label
        :return: label text, ``str(node)`` by default
        """
        return str(node)

    def node_style(self, graph, node, *args, **kwargs):
        """
        :param graph: graph being drawn
        :param node: node to style
        :return: dict of extra DOT attributes, empty by default
        """
        return {}

    def edge_label(self, graph, edge, *args, **kwargs):
        """
        :param graph: graph being drawn
        :param edge: edge to label
        :return: label text, ``str(edge)`` by default
        """
        return str(edge)

    def edge_style(self, graph, edge, *args, **kwargs):
        """
        :param graph: graph being drawn
        :param edge: edge to style
        :return: dict of extra DOT attributes, empty by default
        """
        return {}

    def visualize(self, graph, file_name=None, return_img=False, format="svg",
                  no_text=False, *args, **kwargs):
        """
        :param graph: graph wrapper to draw
        :param file_name: when given, the rendering is written to this path
        :param return_img: return the rendered image instead of the DOT text
        :param format: output format handed to ``dot2image``
        :param no_text: forwarded to ``node_label`` as a keyword argument
        :return: the rendered image when ``return_img`` is set, otherwise the
            DOT source
        """
        import io

        dot_buffer = io.StringIO()
        dot_buffer.write("strict digraph {\n")

        # DOT nodes are named by their position, not by their graph ID
        node2index = {}
        for index, node_id in enumerate(graph.g.nodes()):
            node = graph.get_node(node_id)
            node_label = self.node_label(graph, node, *args, no_text=no_text, **kwargs)
            node_style = self.node_style(graph, node, *args, **kwargs)

            node2index[node.ID] = index

            node_attr = ['label="{0}"'.format(node_label)]
            node_attr.extend('{0}="{1}"'.format(k, v) for k, v in node_style.items())

            dot_buffer.write('{0}\t[{1}]; \n'.format(index, ", ".join(node_attr)))

        for s, e in graph.g.edges():
            edge = graph.g[s][e]["Edge"]
            edge_label = self.edge_label(graph, edge, *args, **kwargs)
            edge_style = self.edge_style(graph, edge, *args, **kwargs)

            edge_attr = ['label="{0}"'.format(edge_label)]
            edge_attr.extend('{0}="{1}"'.format(k, v) for k, v in edge_style.items())

            dot_buffer.write('{0}\t->\t{1}\t[{2}];\n'.format(
                node2index[s], node2index[e], ", ".join(edge_attr)
            ))

        dot_buffer.write("}\n")
        dot_string = dot_buffer.getvalue()

        result = dot_string
        if file_name or return_img:
            image = dot2image(dot_string, file_name=file_name,
                              return_img=return_img, format=format)
            if return_img:
                result = image

        return result