File size: 9,500 Bytes
beb1305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Knowledge Graph data structure for the Graph Reasoning Network.

Each node represents a fact, each edge represents a relationship between facts.
The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction.
"""
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected
from typing import Optional, List, Tuple, Dict
import json


class KnowledgeGraph:
    """
    A dynamic knowledge graph that can grow over time.

    Nodes: facts (with text descriptions and learned embeddings)
    Edges: relationships between facts (typed, directed)

    Supports:
    - Adding new nodes and edges
    - Navigating to find relevant subgraphs
    - Extracting DAG subgraphs as traceable reasoning chains
    - Serialization for persistence

    Storage is a set of index-aligned parallel containers: row ``i`` of
    ``node_features`` belongs with ``node_texts[i]`` and ``node_types[i]``;
    column ``j`` of ``edge_index`` belongs with row ``j`` of
    ``edge_features``, ``edge_texts[j]`` and ``relation_types[j]``. Every
    method that mutates or slices the graph preserves this alignment.
    """

    def __init__(self, node_dim: int = 256, edge_dim: int = 64):
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        # One feature row per node; node_texts / node_types are aligned with it.
        self.node_features = torch.zeros(0, node_dim)
        self.node_texts: List[str] = []
        self.node_types: List[str] = []
        # Shape (2, num_edges): row 0 holds source nodes, row 1 destinations.
        self.edge_index = torch.zeros(2, 0, dtype=torch.long)
        self.edge_features = torch.zeros(0, edge_dim)
        self.edge_texts: List[str] = []
        self.relation_types: List[str] = []
        # Label -> integer id, assigned in first-seen order.
        self.entity_type_to_id: Dict[str, int] = {}
        self.relation_type_to_id: Dict[str, int] = {}
        # Optional designated start/end nodes; persisted by save()/load().
        self.start_node_idx: Optional[int] = None
        self.end_node_idx: Optional[int] = None

    @property
    def num_nodes(self) -> int:
        """Number of nodes (rows of ``node_features``)."""
        return self.node_features.shape[0]

    @property
    def num_edges(self) -> int:
        """Number of edges (columns of ``edge_index``)."""
        return self.edge_index.shape[1]

    @property
    def num_entity_types(self) -> int:
        """Number of distinct entity-type labels seen so far."""
        return len(self.entity_type_to_id)

    @property
    def num_relation_types(self) -> int:
        """Number of distinct relation-type labels seen so far."""
        return len(self.relation_type_to_id)

    def add_node(self, features: torch.Tensor, text: str = "",
                 entity_type: str = "default") -> int:
        """Add a new fact node to the graph.

        Args:
            features: Node embedding of shape ``(node_dim,)`` or ``(1, node_dim)``.
            text: Human-readable description of the fact.
            entity_type: Type label; a new integer id is assigned on first use.

        Returns:
            The index of the newly added node.
        """
        assert features.shape[-1] == self.node_dim
        if features.dim() == 1:
            features = features.unsqueeze(0)
        # Exactly one row per call, otherwise node_features would drift out of
        # alignment with node_texts / node_types (which get one entry each).
        assert features.dim() == 2 and features.shape[0] == 1, \
            "add_node expects a single feature vector"
        self.node_features = torch.cat([self.node_features, features], dim=0)
        self.node_texts.append(text)
        self.node_types.append(entity_type)
        if entity_type not in self.entity_type_to_id:
            self.entity_type_to_id[entity_type] = len(self.entity_type_to_id)
        return self.num_nodes - 1

    def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None,
                 text: str = "", relation_type: str = "default") -> int:
        """Add a directed edge (relationship) from ``src`` to ``dst``.

        Args:
            src: Index of the source node.
            dst: Index of the destination node.
            features: Edge features of shape ``(edge_dim,)`` or ``(1, edge_dim)``;
                a zero vector is used when omitted.
            text: Human-readable description of the relationship.
            relation_type: Relation label; a new integer id is assigned on first use.

        Returns:
            The index of the newly added edge.
        """
        assert 0 <= src < self.num_nodes and 0 <= dst < self.num_nodes
        if features is None:
            features = torch.zeros(1, self.edge_dim)
        elif features.dim() == 1:
            features = features.unsqueeze(0)
        assert features.dim() == 2 and features.shape[0] == 1, \
            "add_edge expects a single edge feature vector"
        # Build both concatenated tensors before assigning either, so a failing
        # torch.cat (e.g. wrong feature width) leaves the graph untouched rather
        # than adding an edge_index column with no matching feature row.
        new_edge = torch.tensor([[src], [dst]], dtype=torch.long)
        edge_index = torch.cat([self.edge_index, new_edge], dim=1)
        edge_features = torch.cat([self.edge_features, features], dim=0)
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.edge_texts.append(text)
        self.relation_types.append(relation_type)
        if relation_type not in self.relation_type_to_id:
            self.relation_type_to_id[relation_type] = len(self.relation_type_to_id)
        return self.num_edges - 1

    def to_pyg_data(self) -> Data:
        """Convert to a PyG ``Data`` object for GNN processing.

        Relation and entity labels are mapped to their integer ids and exposed
        as ``edge_type`` and ``node_type`` (1-D long tensors).
        """
        # torch.tensor([]) with an explicit dtype already yields a (0,) tensor,
        # so empty graphs need no special case.
        rel_type_ids = torch.tensor(
            [self.relation_type_to_id.get(r, 0) for r in self.relation_types],
            dtype=torch.long)
        entity_type_ids = torch.tensor(
            [self.entity_type_to_id.get(t, 0) for t in self.node_types],
            dtype=torch.long)
        return Data(x=self.node_features, edge_index=self.edge_index,
                    edge_attr=self.edge_features, edge_type=rel_type_ids,
                    node_type=entity_type_ids, num_nodes=self.num_nodes)

    def _make_sub_kg(self, node_mask: torch.Tensor, edge_index: torch.Tensor,
                     edge_attr: torch.Tensor, kept_edges: List[int]) -> 'KnowledgeGraph':
        """Assemble a sub-graph from a boolean node mask and relabelled edges.

        ``edge_index`` / ``edge_attr`` must already be relabelled to the
        ascending order of selected nodes, and ``kept_edges`` must list, in
        column order, the original edge ids they came from.
        """
        # Ascending node order matches both node_features[node_mask] and PyG's
        # relabel_nodes=True numbering, keeping texts/types aligned with rows.
        kept_nodes = node_mask.nonzero(as_tuple=True)[0].tolist()
        sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim)
        sub_kg.node_features = self.node_features[node_mask]
        sub_kg.node_texts = [self.node_texts[i] for i in kept_nodes]
        sub_kg.node_types = [self.node_types[i] for i in kept_nodes]
        sub_kg.edge_index = edge_index
        sub_kg.edge_features = edge_attr
        sub_kg.edge_texts = [self.edge_texts[i] for i in kept_edges]
        sub_kg.relation_types = [self.relation_types[i] for i in kept_edges]
        sub_kg.entity_type_to_id = dict(self.entity_type_to_id)
        sub_kg.relation_type_to_id = dict(self.relation_type_to_id)
        return sub_kg

    def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph':
        """Extract the sub-graph induced by the specified nodes.

        Nodes in the result appear in ascending original-index order regardless
        of the order (or duplicates) in ``node_indices``; edges are kept when
        both endpoints are selected.
        """
        mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        mask[node_indices] = True
        new_edge_index, new_edge_attr, edge_mask = subgraph(
            mask, self.edge_index, self.edge_features, relabel_nodes=True, return_edge_mask=True)
        kept_edges = edge_mask.nonzero(as_tuple=True)[0].tolist()
        return self._make_sub_kg(mask, new_edge_index, new_edge_attr, kept_edges)

    def extract_dag_subgraph(self, node_scores: torch.Tensor,
                             threshold: float = 0.5) -> 'KnowledgeGraph':
        """Extract an acyclic sub-graph of the nodes scoring above ``threshold``.

        An induced edge ``u -> v`` is kept only when ``(score[u], u)`` is
        strictly less than ``(score[v], v)``. That is a strict total order on
        nodes, so the result is guaranteed to be a DAG: equal-score ties are
        oriented from the lower to the higher node index and self-loops are
        dropped.

        Args:
            node_scores: One score per node (indexed by node id).
            threshold: Nodes with score strictly greater than this are kept.

        Returns:
            The DAG sub-graph; an empty graph when no node passes the threshold.
        """
        mask = node_scores > threshold
        selected_indices = mask.nonzero(as_tuple=True)[0]
        if selected_indices.numel() == 0:
            return KnowledgeGraph(self.node_dim, self.edge_dim)
        new_edge_index, new_edge_attr, edge_mask = subgraph(
            mask, self.edge_index, self.edge_features, relabel_nodes=True, return_edge_mask=True)
        orig_edge_ids = edge_mask.nonzero(as_tuple=True)[0]
        src = self.edge_index[0, orig_edge_ids]
        dst = self.edge_index[1, orig_edge_ids]
        src_score = node_scores[src]
        dst_score = node_scores[dst]
        # Lexicographic (score, index) comparison: a plain `<=` would keep
        # self-loops and both directions between tied nodes, i.e. cycles.
        forward = (src_score < dst_score) | ((src_score == dst_score) & (src < dst))
        # Always filter edge features too (even when they have zero width),
        # so their row count matches the surviving edge_index columns.
        return self._make_sub_kg(mask, new_edge_index[:, forward], new_edge_attr[forward],
                                 orig_edge_ids[forward].tolist())

    def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor:
        """Return indices of nodes adjacent to ``node_idx``.

        Args:
            node_idx: Node whose neighbours are requested.
            direction: ``"out"`` for destinations of edges leaving the node,
                ``"in"`` for sources of edges entering it; any other value
                returns the union of both, deduplicated (via ``Tensor.unique``).

        Returns:
            1-D long tensor of node indices. For ``"out"``/``"in"``, parallel
            edges produce repeated entries.
        """
        if direction == "out":
            mask = self.edge_index[0] == node_idx
            return self.edge_index[1][mask]
        elif direction == "in":
            mask = self.edge_index[1] == node_idx
            return self.edge_index[0][mask]
        else:
            out_mask = self.edge_index[0] == node_idx
            in_mask = self.edge_index[1] == node_idx
            return torch.cat([self.edge_index[1][out_mask], self.edge_index[0][in_mask]]).unique()

    def save(self, path: str):
        """Serialize the full graph state to ``path`` with ``torch.save``."""
        torch.save({
            'node_features': self.node_features, 'node_texts': self.node_texts,
            'node_types': self.node_types, 'edge_index': self.edge_index,
            'edge_features': self.edge_features, 'edge_texts': self.edge_texts,
            'relation_types': self.relation_types,
            'entity_type_to_id': self.entity_type_to_id,
            'relation_type_to_id': self.relation_type_to_id,
            'node_dim': self.node_dim, 'edge_dim': self.edge_dim,
            'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx,
        }, path)

    @classmethod
    def load(cls, path: str, map_location=None) -> 'KnowledgeGraph':
        """Restore a graph previously written by ``save``.

        Args:
            path: File produced by ``save``.
            map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
                GPU-saved tensors on a CPU-only machine). Defaults to ``None``,
                which is ``torch.load``'s own default.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects, so only
        # load files from trusted sources. The saved state holds only tensors
        # and plain Python containers, so weights_only=True should suffice —
        # confirm and switch.
        state = torch.load(path, map_location=map_location, weights_only=False)
        kg = cls(state['node_dim'], state['edge_dim'])
        kg.node_features = state['node_features']
        kg.node_texts = state['node_texts']
        kg.node_types = state['node_types']
        kg.edge_index = state['edge_index']
        kg.edge_features = state['edge_features']
        kg.edge_texts = state['edge_texts']
        kg.relation_types = state['relation_types']
        kg.entity_type_to_id = state['entity_type_to_id']
        kg.relation_type_to_id = state['relation_type_to_id']
        # .get() tolerates files saved before start/end indices were persisted.
        kg.start_node_idx = state.get('start_node_idx')
        kg.end_node_idx = state.get('end_node_idx')
        return kg

    def __repr__(self):
        return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, "
                f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")