| """ |
| Knowledge Graph data structure for the Graph Reasoning Network. |
| |
| Each node represents a fact, each edge represents a relationship between facts. |
| The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction. |
| """ |
| import torch |
| import torch.nn.functional as F |
| from torch_geometric.data import Data |
| from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected |
| from typing import Optional, List, Tuple, Dict |
| import json |
|
|
|
|
class KnowledgeGraph:
    """
    A dynamic knowledge graph that can grow over time.

    Nodes: facts (with text descriptions and learned embeddings)
    Edges: relationships between facts (typed, directed)

    Storage is a set of parallel containers that must stay aligned:

    - ``node_features`` has one row per node; ``node_texts`` and
      ``node_types`` have one entry per node, in the same order.
    - ``edge_index`` has one column ``[src, dst]`` per edge;
      ``edge_features`` has one row per edge; ``edge_texts`` and
      ``relation_types`` have one entry per edge, in the same order.
    - ``entity_type_to_id`` / ``relation_type_to_id`` assign an integer id to
      each type name in order of first appearance.

    Supports:
    - Adding new nodes and edges
    - Navigating to find relevant subgraphs
    - Extracting DAG subgraphs as traceable reasoning chains
    - Serialization for persistence
    """

    def __init__(self, node_dim: int = 256, edge_dim: int = 64):
        """Create an empty graph.

        Args:
            node_dim: Width of each node feature row.
            edge_dim: Width of each edge feature row.
        """
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.node_features = torch.zeros(0, node_dim)
        self.node_texts: List[str] = []
        self.node_types: List[str] = []
        self.edge_index = torch.zeros(2, 0, dtype=torch.long)
        self.edge_features = torch.zeros(0, edge_dim)
        self.edge_texts: List[str] = []
        self.relation_types: List[str] = []
        self.entity_type_to_id: Dict[str, int] = {}
        self.relation_type_to_id: Dict[str, int] = {}
        # Optional markers; nothing in this class assigns them besides load().
        self.start_node_idx: Optional[int] = None
        self.end_node_idx: Optional[int] = None

    @property
    def num_nodes(self) -> int:
        """Number of rows in ``node_features``."""
        return self.node_features.shape[0]

    @property
    def num_edges(self) -> int:
        """Number of columns in ``edge_index``."""
        return self.edge_index.shape[1]

    @property
    def num_entity_types(self) -> int:
        """Number of distinct entity type names registered so far."""
        return len(self.entity_type_to_id)

    @property
    def num_relation_types(self) -> int:
        """Number of distinct relation type names registered so far."""
        return len(self.relation_type_to_id)

    def add_node(self, features: torch.Tensor, text: str = "",
                 entity_type: str = "default") -> int:
        """Add a new fact node to the graph. Returns the node index.

        Args:
            features: Feature tensor of shape ``(node_dim,)`` or
                ``(1, node_dim)``.
            text: Text description stored alongside the node.
            entity_type: Type name; registered in ``entity_type_to_id`` on
                first use.

        Raises:
            AssertionError: If the last dimension is not ``node_dim``.
            ValueError: If ``features`` describes more than one node. Adding
                several rows at once would desynchronise ``node_features``
                from ``node_texts`` / ``node_types``.
        """
        assert features.shape[-1] == self.node_dim, (
            f"expected node features of width {self.node_dim}, "
            f"got {features.shape[-1]}")
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.node_dim):
            raise ValueError(
                f"add_node takes exactly one node: expected shape "
                f"({self.node_dim},) or (1, {self.node_dim}), "
                f"got {tuple(features.shape)}")
        self.node_features = torch.cat([self.node_features, features], dim=0)
        self.node_texts.append(text)
        self.node_types.append(entity_type)
        if entity_type not in self.entity_type_to_id:
            self.entity_type_to_id[entity_type] = len(self.entity_type_to_id)
        return self.num_nodes - 1

    def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None,
                 text: str = "", relation_type: str = "default") -> int:
        """Add a directed edge (relationship) between two fact nodes.

        Args:
            src: Index of the source node.
            dst: Index of the destination node.
            features: Feature tensor of shape ``(edge_dim,)`` or
                ``(1, edge_dim)``; a zero row is used when omitted.
            text: Text description stored alongside the edge.
            relation_type: Type name; registered in ``relation_type_to_id``
                on first use.

        Returns:
            The index of the new edge.

        Raises:
            AssertionError: If ``src`` or ``dst`` is not an existing node.
            ValueError: If ``features`` does not describe exactly one edge of
                width ``edge_dim``.
        """
        assert 0 <= src < self.num_nodes and 0 <= dst < self.num_nodes, (
            f"edge ({src}, {dst}) references a node outside "
            f"[0, {self.num_nodes})")
        if features is None:
            features = self.edge_features.new_zeros(1, self.edge_dim)
        elif features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.edge_dim):
            raise ValueError(
                f"add_edge takes exactly one edge: expected shape "
                f"({self.edge_dim},) or (1, {self.edge_dim}), "
                f"got {tuple(features.shape)}")
        new_edge = torch.tensor([[src], [dst]], dtype=torch.long,
                                device=self.edge_index.device)
        # Build both tensors before assigning either, so a failure cannot
        # leave edge_index one column ahead of edge_features / edge_texts.
        grown_index = torch.cat([self.edge_index, new_edge], dim=1)
        grown_features = torch.cat([self.edge_features, features], dim=0)
        self.edge_index = grown_index
        self.edge_features = grown_features
        self.edge_texts.append(text)
        self.relation_types.append(relation_type)
        if relation_type not in self.relation_type_to_id:
            self.relation_type_to_id[relation_type] = len(self.relation_type_to_id)
        return self.num_edges - 1

    def to_pyg_data(self) -> 'Data':
        """Convert to PyG Data object for GNN processing.

        ``edge_type`` / ``node_type`` hold the integer ids of each edge's
        relation type and each node's entity type; a name missing from the
        corresponding vocabulary maps to id 0.
        """
        # torch.tensor([]) with an explicit dtype already yields an empty 1-D
        # long tensor, so the empty graph needs no special case.
        rel_type_ids = torch.tensor(
            [self.relation_type_to_id.get(r, 0) for r in self.relation_types],
            dtype=torch.long)
        entity_type_ids = torch.tensor(
            [self.entity_type_to_id.get(t, 0) for t in self.node_types],
            dtype=torch.long)
        return Data(x=self.node_features, edge_index=self.edge_index,
                    edge_attr=self.edge_features, edge_type=rel_type_ids,
                    node_type=entity_type_ids, num_nodes=self.num_nodes)

    @staticmethod
    def _forward_edge_mask(node_scores: torch.Tensor, src: torch.Tensor,
                           dst: torch.Tensor) -> torch.Tensor:
        """Return a bool mask of edges that go strictly "forward".

        Nodes are ordered by ``(score, index)``; an edge is kept only when its
        source precedes its destination in that order. Because the order is
        strict and total, the kept edges cannot form a cycle: self-loops are
        dropped and, between equal-score nodes, only the low-index to
        high-index direction survives.
        """
        src_scores = node_scores[src]
        dst_scores = node_scores[dst]
        return (src_scores < dst_scores) | ((src_scores == dst_scores) & (src < dst))

    def _induced_subgraph(self, node_mask: torch.Tensor,
                          node_scores: Optional[torch.Tensor] = None) -> 'KnowledgeGraph':
        """Build the subgraph induced by ``node_mask`` (bool, one entry per node).

        Nodes keep their relative order from this graph and are relabelled
        ``0..k-1``. When ``node_scores`` is given, induced edges are further
        restricted to those accepted by ``_forward_edge_mask``.
        ``start_node_idx`` / ``end_node_idx`` are not carried over.
        """
        new_edge_index, new_edge_attr, edge_mask = subgraph(
            node_mask, self.edge_index, self.edge_features,
            relabel_nodes=True, return_edge_mask=True)
        # Positions (in this graph's edge list) of the edges that survived.
        kept_edges = edge_mask.nonzero(as_tuple=True)[0]
        if node_scores is not None and kept_edges.numel() > 0:
            forward = self._forward_edge_mask(
                node_scores,
                self.edge_index[0][kept_edges],
                self.edge_index[1][kept_edges])
            new_edge_index = new_edge_index[:, forward]
            new_edge_attr = new_edge_attr[forward]
            kept_edges = kept_edges[forward]

        # Take the node list from the mask itself so that texts/types line up
        # with node_features[node_mask] and with the relabelled edge_index,
        # both of which follow ascending original index.
        kept_nodes = node_mask.nonzero(as_tuple=True)[0].tolist()
        kept_edge_list = kept_edges.tolist()

        sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim)
        sub_kg.node_features = self.node_features[node_mask]
        sub_kg.node_texts = [self.node_texts[i] for i in kept_nodes]
        sub_kg.node_types = [self.node_types[i] for i in kept_nodes]
        sub_kg.edge_index = new_edge_index
        sub_kg.edge_features = new_edge_attr
        sub_kg.edge_texts = [self.edge_texts[i] for i in kept_edge_list]
        sub_kg.relation_types = [self.relation_types[i] for i in kept_edge_list]
        # Copy the vocabularies so type ids mean the same thing in both graphs.
        sub_kg.entity_type_to_id = dict(self.entity_type_to_id)
        sub_kg.relation_type_to_id = dict(self.relation_type_to_id)
        return sub_kg

    def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph':
        """Extract a subgraph containing only the specified nodes.

        Args:
            node_indices: Indices of the nodes to keep. Order and duplicates
                are ignored: the result lists nodes in ascending original
                index, matching the relabelled ``edge_index``.

        Returns:
            A new graph with the selected nodes, the edges whose both
            endpoints were selected, and copies of the type vocabularies.
        """
        node_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        node_mask[node_indices] = True
        return self._induced_subgraph(node_mask)

    def extract_dag_subgraph(self, node_scores: torch.Tensor,
                             threshold: float = 0.5) -> 'KnowledgeGraph':
        """Extract a DAG subgraph. Removes cycles by keeping only forward edges.

        Nodes with ``score > threshold`` are selected. Among the edges between
        selected nodes, only those running from a lower-scored node to a
        higher-scored one are kept; ties are broken by node index so the
        result is acyclic even when scores are equal (self-loops are dropped).

        Args:
            node_scores: One score per node; any shape with ``num_nodes``
                elements (e.g. ``(N,)`` or ``(N, 1)``).
            threshold: Selection cut-off (exclusive).

        Returns:
            The extracted graph; a fresh empty graph when no node passes the
            threshold.

        Raises:
            ValueError: If ``node_scores`` does not have one element per node.
        """
        scores = node_scores.reshape(-1)
        if scores.numel() != self.num_nodes:
            raise ValueError(
                f"node_scores has {scores.numel()} elements, "
                f"expected one per node ({self.num_nodes})")
        node_mask = scores > threshold
        if not bool(node_mask.any()):
            return KnowledgeGraph(self.node_dim, self.edge_dim)
        return self._induced_subgraph(node_mask, node_scores=scores)

    def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor:
        """Return the indices of nodes adjacent to ``node_idx``.

        Args:
            node_idx: Node whose neighbours are requested.
            direction: ``"out"`` returns destinations of edges leaving the
                node; ``"in"`` returns sources of edges entering it. Any other
                value returns the de-duplicated union of both.

        Returns:
            A 1-D long tensor. For ``"out"`` / ``"in"`` it has one entry per
            matching edge (duplicates possible for parallel edges).
        """
        sources, targets = self.edge_index[0], self.edge_index[1]
        if direction == "out":
            return targets[sources == node_idx]
        if direction == "in":
            return sources[targets == node_idx]
        successors = targets[sources == node_idx]
        predecessors = sources[targets == node_idx]
        return torch.cat([successors, predecessors]).unique()

    def save(self, path: str):
        """Serialize the full graph state to ``path`` with ``torch.save``."""
        torch.save({
            'node_features': self.node_features, 'node_texts': self.node_texts,
            'node_types': self.node_types, 'edge_index': self.edge_index,
            'edge_features': self.edge_features, 'edge_texts': self.edge_texts,
            'relation_types': self.relation_types,
            'entity_type_to_id': self.entity_type_to_id,
            'relation_type_to_id': self.relation_type_to_id,
            'node_dim': self.node_dim, 'edge_dim': self.edge_dim,
            'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx,
        }, path)

    @classmethod
    def load(cls, path: str) -> 'KnowledgeGraph':
        """Rebuild a graph from a file written by ``save``.

        ``start_node_idx`` / ``end_node_idx`` default to ``None`` when the
        file does not contain them.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load files from trusted
        # sources.
        state = torch.load(path, weights_only=False)
        kg = cls(state['node_dim'], state['edge_dim'])
        kg.node_features = state['node_features']
        kg.node_texts = state['node_texts']
        kg.node_types = state['node_types']
        kg.edge_index = state['edge_index']
        kg.edge_features = state['edge_features']
        kg.edge_texts = state['edge_texts']
        kg.relation_types = state['relation_types']
        kg.entity_type_to_id = state['entity_type_to_id']
        kg.relation_type_to_id = state['relation_type_to_id']
        kg.start_node_idx = state.get('start_node_idx')
        kg.end_node_idx = state.get('end_node_idx')
        return kg

    def __repr__(self):
        return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, "
                f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")
|
|