Spaces:
Sleeping
Sleeping
| import os.path as osp | |
| from pyvis.network import Network | |
| import torch | |
| import numpy as np | |
| from src.tools.graph import k_hop_subgraph | |
| from src.tools.node import Node, register_node | |
| from torch_geometric.utils import to_undirected, is_undirected | |
# Palette used when drawing the graph; entries are looked up by node/edge
# type id. Values are hex codes or CSS colour names. 'plum' stands in for the
# former 'lightpurple', which is not a valid CSS named colour.
color_types = ['#97c2fc', 'lightgreen', 'lightpink', 'plum']
class SemiStructureKB:
    """Base class for a semi-structured knowledge base stored as a typed graph.

    Nodes carry free-form information (``node_info``) and a type id; edges are
    stored as an edge index tensor plus one type id per edge. Subclasses are
    expected to implement ``get_doc_info``, ``get_rel_info`` and
    ``node_attr_dict``, and may define a ``candidate_types`` attribute to
    restrict which nodes count as candidates.
    """

    def __init__(self, node_info, edge_index,
                 node_type_dict=None,
                 edge_type_dict=None,
                 node_types=None, edge_types=None,
                 indirected=True, **kwargs):
        """
        Store the graph and build the sparse adjacency structures.

        Args:
            node_info (dict): Maps each node ID to a dictionary with
                information about that node. The per-node dictionary may have
                an arbitrary (e.g., hierarchical) structure.
            edge_index (torch.LongTensor): Edge index of shape
                (2, num_edges).
            node_type_dict (dict): Maps node type ID to node type name.
            edge_type_dict (dict): Maps edge type ID to edge type name.
            node_types (torch.LongTensor): Type ID of every node.
            edge_types (torch.LongTensor): Type ID of every edge.
            indirected (bool): If True and ``edge_index`` is not already
                undirected, the graph is symmetrised.
            **kwargs: Accepted for compatibility; not used here.
        """
        self.node_info = node_info
        self.edge_index = edge_index
        self.edge_type_dict = edge_type_dict
        self.node_type_dict = node_type_dict
        self.node_types = node_types
        self.edge_types = edge_types

        if indirected and not is_undirected(self.edge_index):
            # NOTE(review): reduce='mean' is applied to the edge type ids, so
            # duplicate (u, v) pairs carrying different types would be merged
            # into an averaged id - confirm inputs never contain such pairs.
            self.edge_index, self.edge_types = to_undirected(
                self.edge_index, self.edge_types,
                num_nodes=self.num_nodes(), reduce='mean')
            self.edge_types = self.edge_types.long()

        # `candidate_types` is not set in this class; it is only honoured
        # when something else (e.g. a subclass) has defined it.
        if hasattr(self, 'candidate_types'):
            self.candidate_ids = self.get_candidate_ids()
        else:
            self.candidate_ids = list(range(len(self.node_info)))
        self.num_candidates = len(self.candidate_ids)
        self._build_sparse_adj()

    def __len__(self) -> int:
        """Return the number of entries in ``node_info``."""
        return len(self.node_info)

    def __getitem__(self, idx):
        """Return a fresh ``Node`` populated from ``node_info[int(idx)]``."""
        idx = int(idx)
        node = Node()
        register_node(node, self.node_info[idx])
        return node

    def get_doc_info(self, idx,
                     add_rel=False, compact=False) -> str:
        """
        Return a text document containing information about the node.

        Must be implemented by subclasses.

        Args:
            idx (int): node index
            add_rel (bool): whether to add relational information explicitly
            compact (bool): whether to compact the text
        """
        raise NotImplementedError

    def _build_sparse_adj(self):
        """
        Build the sparse adjacency matrices.

        Sets ``self.sparse_adj`` (all edges) and ``self.sparse_adj_by_type``
        (one matrix per edge type name). Every stored entry has value 1.
        """
        num_nodes = self.num_nodes()
        shape = torch.Size([num_nodes, num_nodes])
        device = self.edge_index.device

        self.sparse_adj = torch.sparse_coo_tensor(
            self.edge_index,
            torch.ones(self.edge_index.shape[1], device=device),
            shape)

        self.sparse_adj_by_type = {}
        for edge_type in self.rel_type_lst():
            mask = self.edge_types == self.edge_type2id(edge_type)
            typed_edge_index = self.edge_index[:, mask]
            self.sparse_adj_by_type[edge_type] = torch.sparse_coo_tensor(
                typed_edge_index,
                torch.ones(typed_edge_index.shape[1], device=device),
                shape)

    def get_rel_info(self, idx, rel_type=None) -> str:
        """
        Return relational information about the node as text.

        Must be implemented by subclasses.

        Args:
            idx (int): node index
            rel_type: presumably restricts the output to one relation type
                when given - the exact semantics are left to the subclass.
        """
        raise NotImplementedError

    def get_candidate_ids(self) -> list:
        """
        Get the sorted IDs of all nodes whose type is in ``candidate_types``.

        The IDs are gathered with plain Python lists so they stay integers
        even when one of the candidate types has no nodes.
        """
        assert hasattr(self, 'candidate_types')
        return sorted(
            node_id
            for candidate_type in self.candidate_types
            for node_id in self.get_node_ids_by_type(candidate_type)
        )

    def num_nodes(self, node_type_id=None):
        """
        Return the number of nodes.

        Args:
            node_type_id: if given, count only the nodes of this type ID; the
                count is then the result of summing a comparison mask rather
                than a plain ``int``.
        """
        if node_type_id is None:
            return len(self.node_types)
        return (self.node_types == node_type_id).sum()

    def num_edges(self, node_type_id=None):
        """
        Return the number of edges.

        Args:
            node_type_id: despite its name, this is compared against the
                *edge* type IDs; if given, only edges of that type are
                counted.
        """
        if node_type_id is None:
            return len(self.edge_types)
        return (self.edge_types == node_type_id).sum()

    def rel_type_lst(self):
        """Return the edge type names (the values of ``edge_type_dict``)."""
        return list(self.edge_type_dict.values())

    def node_type_lst(self):
        """Return the node type names (the values of ``node_type_dict``)."""
        return list(self.node_type_dict.values())

    def node_attr_dict(self):
        """Subclass hook; no behaviour is defined in this base class."""
        raise NotImplementedError

    def is_rel_type(self, edge_type: str):
        """Return True if ``edge_type`` is a known edge type name."""
        return edge_type in self.rel_type_lst()

    def edge_type2id(self, edge_type: str) -> int:
        """
        Get the edge type ID given the edge type.

        Raises:
            ValueError: if ``edge_type`` is not in ``edge_type_dict``.
        """
        try:
            idx = list(self.edge_type_dict.values()).index(edge_type)
        except ValueError:
            raise ValueError(f"Edge type {edge_type} not found") from None
        return list(self.edge_type_dict.keys())[idx]

    def node_type2id(self, node_type: str) -> int:
        """
        Get the node type ID given the node type.

        Raises:
            ValueError: if ``node_type`` is not in ``node_type_dict``.
        """
        try:
            idx = list(self.node_type_dict.values()).index(node_type)
        except ValueError:
            raise ValueError(f"Node type {node_type} not found") from None
        return list(self.node_type_dict.keys())[idx]

    def get_node_type_by_id(self, node_id: int) -> str:
        """Get the node type given the node ID."""
        return self.node_type_dict[self.node_types[node_id].item()]

    def get_edge_type_by_id(self, edge_id: int) -> str:
        """Get the edge type given the edge ID."""
        return self.edge_type_dict[self.edge_types[edge_id].item()]

    def get_node_ids_by_type(self, node_type: str) -> list:
        """Get the node IDs given the node type."""
        mask = self.node_types == self.node_type2id(node_type)
        return torch.arange(self.num_nodes())[mask].tolist()

    def get_node_ids_by_value(self, node_type, key, value) -> list:
        """
        Get the IDs of nodes of ``node_type`` whose attribute ``key`` equals
        ``value``. Nodes without that attribute are skipped.
        """
        indices = []
        for idx in self.get_node_ids_by_type(node_type):
            node = self[idx]  # build the Node once per candidate
            if hasattr(node, key) and getattr(node, key) == value:
                indices.append(idx)
        return indices

    def get_edge_ids_by_type(self, edge_type: str) -> list:
        """Get the edge IDs given the edge type."""
        mask = self.edge_types == self.edge_type2id(edge_type)
        return torch.arange(self.num_edges())[mask].tolist()

    def sample_paths(self, node_types: list, edge_types: list, start_node_id=None, size=1) -> list:
        """
        Randomly sample paths that follow the given node and edge types.

        Use "*" to indicate any edge type.

        Args:
            node_types (list): node type names along the path.
            edge_types (list): edge type names between consecutive nodes;
                must be one shorter than ``node_types``.
            start_node_id: if given, every path starts at this node;
                otherwise the start is drawn from the nodes of
                ``node_types[0]``.
            size (int): number of paths to return.

        Returns:
            A list of ``size`` paths (lists of node IDs), or ``[]`` when
            ``start_node_id`` is given and has no matching first-hop
            neighbour.

        Note:
            Dead-end walks are discarded and retried without a limit, so the
            call does not return if fewer than ``size`` paths can ever be
            completed (other than the early ``[]`` case above).
        """
        assert len(node_types) == len(edge_types) + 1

        # Validate the requested schema once instead of recomputing the
        # tuple list for every hop.
        if any(edge_type != "*" for edge_type in edge_types):
            valid_tuples = set(self.get_tuples())
            for i, edge_type in enumerate(edge_types):
                if edge_type == "*":
                    continue
                _tuple = (node_types[i], edge_type, node_types[i + 1])
                assert _tuple in valid_tuples, f"{_tuple} invalid"

        paths = []
        while len(paths) < size:
            p = []
            for i, node_type in enumerate(node_types):
                if i == 0:
                    if start_node_id is not None:
                        node_idx = start_node_id
                    else:
                        node_idx = int(np.random.choice(self.get_node_ids_by_type(node_type)))
                else:
                    neighbor_nodes = torch.LongTensor(
                        self.get_neighbor_nodes(node_idx, edge_types[i - 1]))
                    # Keep only neighbours of the type required at *this*
                    # position of the path.
                    node_type_id = self.node_type2id(node_type)
                    neighbor_nodes = neighbor_nodes[
                        self.node_types[neighbor_nodes] == node_type_id].tolist()
                    if len(neighbor_nodes) == 0:
                        if i == 1 and start_node_id is not None:
                            # A fixed start with no valid first hop can never
                            # succeed, so give up instead of retrying.
                            return []
                        break
                    node_idx = int(np.random.choice(neighbor_nodes))
                p.append(node_idx)
            if len(p) == len(node_types):
                paths.append(p)
        return paths

    def get_all_paths(self, start_node_id: int,
                      node_types: list, edge_types: list,
                      max_num=None, direction='in-and-out') -> list:
        """
        Enumerate all paths from ``start_node_id`` that follow the given node
        and edge types.

        Use "*" to indicate any edge type.

        Args:
            start_node_id (int): first node of every path; its own type is
                not checked against ``node_types[0]``.
            node_types (list): node type names along the path.
            edge_types (list): edge type names between consecutive nodes;
                must be one shorter than ``node_types``.
            max_num: if given and the number of collected paths exceeds it,
                'max_num reached' is printed and ``[]`` is returned. It is
                not forwarded to the recursive calls.
            direction: currently unused.

        Returns:
            A list of paths (lists of node IDs); empty if none exist.
        """
        assert len(node_types) == len(edge_types) + 1

        neighbor_nodes = torch.LongTensor(
            self.get_neighbor_nodes(start_node_id, edge_types[0]))
        node_type_id = self.node_type2id(node_types[1])
        neighbor_nodes = neighbor_nodes[
            self.node_types[neighbor_nodes] == node_type_id].tolist()

        if len(neighbor_nodes) == 0:
            return []
        if len(node_types) == 2:
            return [[start_node_id, node_idx] for node_idx in neighbor_nodes]

        paths = []
        for iter_start_node_id in neighbor_nodes:
            subpaths = self.get_all_paths(iter_start_node_id, node_types[1:], edge_types[1:])
            if len(subpaths) == 0:
                continue
            for subpath in subpaths:
                paths.append([start_node_id] + subpath)
            if max_num is not None and len(paths) > max_num:
                print('max_num reached')
                return []
        return paths

    def get_tuples(self) -> list:
        """
        Get all (node type, edge type, node type) name triples that occur in
        the graph, sorted and without duplicates.
        """
        col, row = self.edge_index.tolist()
        col_types = self.node_types[col].tolist()
        row_types = self.node_types[row].tolist()
        tuples_by_id = set(zip(col_types, self.edge_types.tolist(), row_types))
        tuples = {
            (self.node_type_dict[n_i], self.edge_type_dict[e], self.node_type_dict[n_j])
            for n_i, e, n_j in tuples_by_id
        }
        return sorted(tuples)

    def get_neighbor_nodes(self, idx, edge_type: str = "*") -> list:
        """
        Get the neighbor nodes given the node ID and the edge type.

        Args:
            idx (int): node index
            edge_type (str): edge type, use "*" to indicate any edge type.
        """
        if edge_type == "*":
            adj = self.sparse_adj
        else:
            adj = self.sparse_adj_by_type[edge_type]
        return adj[idx].coalesce().indices().view(-1).tolist()

    def k_hop_neighbor(self, node_idx, num_hops, **kwargs):
        """
        Extract the ``num_hops``-hop subgraph around ``node_idx``.

        Extra keyword arguments are forwarded to ``k_hop_subgraph``.

        Returns:
            ``(subset, edge_index, node_types, edge_types)`` where the two
            type tensors are ``self.node_types`` / ``self.edge_types``
            indexed by the returned subset and edge mask.
        """
        subset, edge_index, _, edge_mask = k_hop_subgraph(node_idx,
                                                          num_hops,
                                                          self.edge_index,
                                                          num_nodes=self.num_nodes(),
                                                          flow='bidirectional',
                                                          **kwargs)
        node_types = self.node_types[subset]
        edge_types = self.edge_types[edge_mask]
        return subset, edge_index, node_types, edge_types

    def visualize(self, path='.'):
        """
        Render the whole graph with pyvis into ``<path>/nodes.html``.

        Each node is labelled with the first character of the attribute named
        after its node type, falling back to its ``title`` attribute. Colours
        come from ``color_types`` and wrap around when there are more types
        than palette entries.
        """
        net = Network()
        num_colors = len(color_types)

        for idx in range(self.num_nodes()):
            type_id = self.node_types[idx].item()
            node = self[idx]
            try:
                label = getattr(node, self.node_type_dict[type_id])[:1]
            except (AttributeError, KeyError, TypeError):
                label = getattr(node, 'title')[:1]
            net.add_node(idx, label=label, color=color_types[type_id % num_colors])

        for idx in range(self.num_edges()):
            net.add_edge(self.edge_index[0][idx].item(),
                         self.edge_index[1][idx].item(),
                         color=color_types[self.edge_types[idx].item() % num_colors])

        net.toggle_physics(True)
        net.show(osp.join(path, 'nodes.html'), notebook=False)