ho22joshua committed on
Commit
adc6050
·
1 Parent(s): 5c5e769

added model and training scripts

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 JO5HO4
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/GCN.py ADDED
@@ -0,0 +1,1944 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dgl
2
+ import dgl.nn as dglnn
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import sys
9
+ import os
10
+ file_path = os.getcwd()
11
+ sys.path.append(file_path)
12
+
13
+ import root_gnn_base.dataset as datasets
14
+ from root_gnn_base import utils
15
+
16
+ import gc
17
+
18
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
    """Build the modules of one perceptron layer.

    Args:
        in_size: Input width of the linear layer.
        out_size: Output width of the linear layer.
        activation: Zero-argument callable producing the activation module.
        dropout: Probability handed to ``nn.Dropout``.

    Returns:
        A plain Python list ``[Linear, activation(), Dropout]`` (not an
        ``nn.Module``); callers are expected to wrap it themselves.
    """
    return [
        nn.Linear(in_size, out_size),
        activation(),
        nn.Dropout(dropout),
    ]
24
+
25
def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
    """Stack ``Make_SLP`` layers and finish with a LayerNorm.

    For ``n_layers > 1`` the widths are ``in_size -> hid_size -> ... ->
    out_size`` using exactly ``n_layers`` perceptron layers. For any
    ``n_layers <= 1`` a single ``in_size -> out_size`` layer is built and
    ``hid_size`` is ignored. Every layer, including the last, carries its
    activation and dropout; ``LayerNorm(out_size)`` is appended at the end.

    Returns:
        An ``nn.Sequential`` holding the stacked modules.
    """
    if n_layers > 1:
        widths = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
    else:
        widths = [in_size, out_size]

    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.extend(Make_SLP(fan_in, fan_out, activation, dropout))
    modules.append(torch.nn.LayerNorm(out_size))
    return nn.Sequential(*modules)
36
+
37
class MLP(nn.Module):
    """Feed-forward network: a ``Make_MLP`` trunk followed by a linear head.

    The trunk is built with ``n_layers - 1`` layers and maps ``in_size`` to
    ``hid_size``; the head maps ``hid_size`` to ``out_size`` with no
    activation. Extra keyword arguments are accepted, printed and ignored.
    """

    def __init__(self, in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0, **kwargs):
        super().__init__()
        print(f'Unused args while creating MLP: {kwargs}')
        trunk_depth = n_layers - 1
        self.layers = Make_MLP(in_size, hid_size, hid_size, trunk_depth, activation, dropout)
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, x):
        """Run ``x`` through the trunk and then the linear head."""
        hidden = self.layers(x)
        return self.linear(hidden)
46
+
47
def broadcast_global_to_nodes(g, globals):
    """Repeat each per-graph row of ``globals`` once per node of that graph.

    Row ``i`` of ``globals`` is repeated ``g.batch_num_nodes()[i]`` times
    along dim 0, so the result has one row per node of the batched graph.
    """
    return globals.repeat_interleave(g.batch_num_nodes(), dim=0)
50
+
51
def broadcast_global_to_edges(g, globals):
    """Repeat each per-graph row of ``globals`` once per edge of that graph.

    Row ``i`` of ``globals`` is repeated ``g.batch_num_edges()[i]`` times
    along dim 0, so the result has one row per edge of the batched graph.
    """
    return globals.repeat_interleave(g.batch_num_edges(), dim=0)
54
+
55
def copy_v(edges):
    """Edge function that exposes the destination node's 'h' as 'm_v'."""
    dst_hidden = edges.dst['h']
    return {'m_v': dst_hidden}
57
+
58
def partial_reset(model : nn.Module):
    """Replace ``model.classify`` with a freshly initialised linear layer.

    The new layer has the same input/output widths and lives on the same
    device as the old one. The global torch RNG is seeded with 2 first, so
    the new weights are reproducible (this also reseeds the RNG for the
    caller). The new weight matrix is printed.
    """
    old_head = model.classify
    out_size, in_size = old_head.weight.shape[0], old_head.weight.shape[1]
    device = next(old_head.parameters()).device

    torch.manual_seed(2)
    model.classify = nn.Linear(in_size, out_size)
    model.classify.to(device)
    print(model.classify.weight)
66
+
67
def print_model(model: nn.Module):
    """Print ``model`` using its string representation."""
    print(model)
69
+
70
def print_mlp(layer):
    """Print each direct child of ``layer``.

    Linear children are shown as their ``state_dict()`` (weights and bias);
    every other child is printed as-is.
    """
    for child in layer.children():
        shown = child.state_dict() if isinstance(child, nn.Linear) else child
        print(shown)
76
+
77
+
78
# Re-initialise the whole GNN in place (nothing is returned).
def full_reset(model : nn.Module):
    """Reset every encoder/update/decoder MLP of ``model`` and its classifier.

    ``model`` must expose ``node_encoder``, ``edge_encoder``,
    ``global_encoder``, ``node_update``, ``edge_update``, ``global_update``
    and ``global_decoder``. Each direct child of those modules that offers
    ``reset_parameters`` is re-initialised; finally ``partial_reset``
    replaces ``model.classify``.
    """
    component_names = (
        'node_encoder', 'edge_encoder', 'global_encoder',
        'node_update', 'edge_update', 'global_update',
        'global_decoder',
    )
    # Resolve all components first so a missing attribute raises before
    # any layer has been touched.
    components = [getattr(model, name) for name in component_names]

    for component in components:
        for layer in component.children():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
    partial_reset(model)
89
+
90
+
91
class GCN(nn.Module):
    """Graph classifier: linear layers around a block of GraphConv layers.

    ``self.layers`` holds, in order: one input projection
    (``in_size -> hid_size``), ``n_layers`` linear layers, ``n_layers``
    ``dglnn.GraphConv`` layers and another ``n_layers`` linear layers, all
    ``hid_size`` wide. Extra keyword arguments are printed and ignored.
    """

    def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        input_proj = [nn.Linear(in_size, hid_size)]
        pre_conv = [nn.Linear(hid_size, hid_size) for _ in range(n_layers)]
        convs = [dglnn.GraphConv(hid_size, hid_size) for _ in range(n_layers)]
        post_conv = [nn.Linear(hid_size, hid_size) for _ in range(n_layers)]
        self.layers = nn.ModuleList(input_proj + pre_conv + convs + post_conv)

        self.classify = nn.Linear(hid_size, out_size)
        #self.dropout = nn.Dropout(0.05)

    def forward(self, g):
        """Classify each graph in ``g`` from its ``ndata['features']``.

        ReLU follows every layer in ``self.layers``. Node states are then
        averaged per graph and passed to ``self.classify``.
        """
        h = g.ndata['features']
        # GraphConv layers sit at positions n_layers+1 .. 2*n_layers and
        # are the only ones that need the graph as an argument.
        first_conv = self.n_layers + 1
        last_conv = 2 * self.n_layers
        for idx, layer in enumerate(self.layers):
            if first_conv <= idx <= last_conv:
                h = layer(g, h)
            else:
                h = layer(h)
            h = F.relu(h)

        # local_scope keeps the temporary 'h' field off the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            graph_repr = dgl.mean_nodes(g, 'h')
            return self.classify(graph_repr)
121
+
122
class GCN_global(nn.Module):
    """GraphConv model with a per-graph global state.

    Nodes are encoded from ``ndata['features']``. The global state is
    encoded from the number of nodes in each graph. The same
    ``node_update``, ``conv`` and ``global_update`` modules are applied
    ``n_layers`` times. Extra keyword arguments are printed and ignored.
    """

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        #encoder
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        #GCN
        self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        """Return one prediction row per graph in the batched graph ``g``.

        The temporary ``'h'`` node field is written inside
        ``g.local_scope()`` so the caller's graph is left unchanged.
        """
        h = self.node_encoder(g.ndata['features'])
        h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
        with g.local_scope():
            for _ in range(self.n_layers):
                h = self.node_update(h)
                h = self.conv(g, h)
                g.ndata['h'] = h
                # Global state sees its previous value and the mean node state.
                h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
151
+
152
class GCN_global_2way(nn.Module):
    """GraphConv model where nodes and the global state inform each other.

    Same layout as ``GCN_global``, except that ``node_update`` takes the
    node state concatenated with its graph's global state (hence the
    ``2*hid_size`` input). Extra keyword arguments are printed and ignored.
    """

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        #encoder
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        #GCN
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        """Return one prediction row per graph in the batched graph ``g``.

        The temporary ``'h'`` node field is written inside
        ``g.local_scope()`` so the caller's graph is left unchanged.
        """
        h = self.node_encoder(g.ndata['features'])
        h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
        with g.local_scope():
            for _ in range(self.n_layers):
                # Global -> node: every node sees its graph's global state.
                h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
                h = self.conv(g, h)
                g.ndata['h'] = h
                # Node -> global: global state sees the mean node state.
                h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
181
+
182
class Edge_Network(nn.Module):
    """Message-passing network with node, edge and per-graph global states.

    Node and edge encoders take their input widths from ``sample_graph``
    (``ndata['features']`` / ``edata['features']``). If ``sample_global`` is
    empty or has zero columns, the model has no global inputs and uses a
    node count per graph instead. The same edge/node/global update MLPs are
    applied ``n_proc_steps`` times. Extra keyword arguments are printed and
    ignored.

    Note: ``forward`` and ``representation`` write the fields 'h', 'e',
    'm_u', 'm_v' and 'h_e' onto the graph they are given.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        # Never populated, but it must stay the first registered child: the
        # Transferred_Learning* classes index this model's children() by
        # position (1 = node_encoder, ..., 7 = global_decoder).
        self.layers = nn.ModuleList()
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    @staticmethod
    def _count_unpadded_nodes(features, nodes_per_graph):
        """Count, per graph, the nodes whose feature row is not all zeros.

        Args:
            features: 2-D node feature tensor of the batched graph.
            nodes_per_graph: 1-D integer tensor of node counts per graph.

        Returns:
            1-D int64 tensor on ``features.device``, one count per graph.
        """
        is_real = torch.any(features != 0, dim=1).to(torch.long)
        n_graphs = nodes_per_graph.shape[0]
        # graph_ids[k] = index of the graph that node k belongs to.
        graph_ids = torch.repeat_interleave(
            torch.arange(n_graphs, device=nodes_per_graph.device), nodes_per_graph
        ).to(features.device)
        counts = torch.zeros(n_graphs, dtype=torch.long, device=features.device)
        counts.index_add_(0, graph_ids, is_real)
        return counts

    def _process(self, g, global_feats):
        """Encode inputs and run message passing; return the global state.

        Shared by ``forward`` and ``representation``. The returned tensor
        is the global state before ``global_decoder`` is applied.
        """
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        weighted = "w" in g.ndata
        node_counts = None
        if weighted:
            # With per-node weights present, all-zero feature rows are
            # treated as padding and excluded from the node count.
            node_counts = self._count_unpadded_nodes(g.ndata['features'], g.batch_num_nodes())[:, None]
            # NOTE(review): this overrides caller-supplied global_feats even
            # when has_global is True -- confirm that is intended.
            global_feats = node_counts.to(torch.float)

        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            # Edge update: own state, both endpoint states, global state.
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            # Node update: own state, summed incoming edge states, global state.
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            if weighted:
                # (B, 1) counts broadcast across the hidden dimension, so
                # this works for any hid_size.
                mean_nodes = dgl.sum_nodes(g, 'h', 'w') / node_counts
                h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
            else:
                h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        return h_global

    def forward(self, g, global_feats):
        """Return the classifier output, one row per graph in ``g``."""
        h_global = self._process(g, global_feats)
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)

    def representation(self, g, global_feats):
        """Return intermediate and final outputs of the read-out head.

        Returns:
            Tuple ``(before_global_decoder, after_global_decoder,
            after_classify)``.
        """
        before_global_decoder = self._process(g, global_feats)
        after_global_decoder = self.global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        """Return the Linear weights of every MLP plus the classifier weight."""
        sections = (
            ("node_encoder", self.node_encoder),
            ("edge_encoder", self.edge_encoder),
            ("global_encoder", self.global_encoder),
            ("node_update", self.node_update),
            ("edge_update", self.edge_update),
            ("global_update", self.global_update),
            ("global_decoder", self.global_decoder),
        )
        lines = []
        for name, mlp in sections:
            lines.append(name)
            for layer in mlp.children():
                if isinstance(layer, nn.Linear):
                    lines.append(str(layer.state_dict()))
        lines.append("classify")
        lines.append(str(self.classify.weight))
        return "\n".join(lines)
311
+
312
class Transferred_Learning(nn.Module):
    """Frozen pretrained network followed by a new trainable decoder head.

    The pretrained model is built with ``utils.buildFromConfig``, loaded
    from ``pretraining_path``, stripped of its last child and frozen. Its
    remaining children are addressed by position; this assumes the layout
    of ``Edge_Network`` (1 = node_encoder ... 7 = global_decoder) -- TODO
    confirm for other pretrained model types. Extra keyword arguments are
    printed and ignored.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # host; load_state_dict then copies onto the model's own device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child (the pretrained classifier).
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self.pretrained_model[7](x)

    def forward(self, g, global_feats):
        """Run the frozen network on ``g``, then the new decoder and classifier.

        Writes 'h', 'e', 'm_u', 'm_v' and 'h_e' onto ``g``. When the model
        has no global inputs, ``global_feats`` is replaced by the per-graph
        node count.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(self.global_decoder(h_global))
393
+
394
class Transferred_Learning_Graph(nn.Module):
    """Frozen pretrained message passing followed by new trainable steps.

    The pretrained model is built with ``utils.buildFromConfig``, loaded
    from ``pretraining_path``, stripped of its last child and frozen. After
    ``n_proc_steps`` frozen steps, ``additional_proc_steps`` steps run with
    new update MLPs, then a new decoder and classifier. Pretrained children
    are addressed by position; this assumes the ``Edge_Network`` layout --
    TODO confirm. The new update MLPs take ``hid_size``-wide states, so the
    pretrained hidden width presumably must equal ``hid_size``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # host; load_state_dict then copies onto the model's own device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child (the pretrained classifier).
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        self.additional_proc_steps = additional_proc_steps

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self.pretrained_model[6](x)

    def forward(self, g, global_feats):
        """Run frozen steps, then new trainable steps, then decode and classify.

        Writes 'h', 'e', 'm_u', 'm_v' and 'h_e' onto ``g``. When the model
        has no global inputs, ``global_feats`` is replaced by the per-graph
        node count.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)

        # Frozen message passing with the pretrained update MLPs.
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))

        # Trainable message passing with the new update MLPs.
        for _ in range(self.additional_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))

        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
486
+
487
class Transferred_Learning_Parallel(nn.Module):
    """Frozen pretrained network and a new network run side by side.

    The pretrained model is built with ``utils.buildFromConfig``, loaded
    from ``pretraining_path``, stripped of its last child and frozen. A
    second, trainable network with the same structure is built from
    scratch. Both decoded global states are concatenated and classified.
    Pretrained children are addressed by position; this assumes the
    ``Edge_Network`` layout -- TODO confirm.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Same detection as Edge_Network, so an empty sample_global (e.g.
        # an empty list) means "no global inputs" instead of raising.
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # host; load_state_dict then copies onto the model's own device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child (the pretrained classifier).
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self.pretrained_model[7](x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen network on ``g`` and return its decoded global state.

        Args:
            g: Batched graph; 'h', 'e', 'm_u', 'm_v', 'h_e' are written to it.
            global_feats: Per-graph global inputs. Required when the model
                has global inputs; ignored otherwise (the per-graph node
                count is used instead).

        Raises:
            ValueError: If the model has global inputs and none were given.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError("global_feats is required when the model has global inputs")
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Classify from the concatenated frozen and trainable global states.

        The frozen pass runs on ``g.clone()`` so its temporary fields do
        not collide with those of the trainable pass, which writes to ``g``.
        """
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
594
+
595
class Transferred_Learning_Sequential(nn.Module):
    """Frozen pretrained network followed by a new MLP and classifier.

    The pretrained model is built with ``utils.buildFromConfig``, loaded
    from ``pretraining_path``, stripped of its last child and frozen. Its
    decoded global state feeds a trainable MLP and a linear classifier.
    Pretrained children are addressed by position; this assumes the
    ``Edge_Network`` layout -- TODO confirm.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Same detection as Edge_Network, so an empty sample_global (e.g.
        # an empty list) means "no global inputs" instead of raising.
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # host; load_state_dict then copies onto the model's own device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child (the pretrained classifier).
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        #encoder
        self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self.pretrained_model[7](x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen network on ``g`` and return its decoded global state.

        Args:
            g: Batched graph; 'h', 'e', 'm_u', 'm_v', 'h_e' are written to it.
            global_feats: Per-graph global inputs. Required when the model
                has global inputs; ignored otherwise (the per-graph node
                count is used instead).

        Raises:
            ValueError: If the model has global inputs and none were given.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError("global_feats is required when the model has global inputs")
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Classify from the frozen network's output via the new MLP.

        The frozen pass runs on ``g.clone()``, so ``g`` itself is not
        written to.
        """
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
678
+
679
+
680
class Transferred_Learning_Message_Passing(nn.Module):
    """Frozen pretrained graph network whose per-step global states feed an MLP head.

    After each message-passing iteration of the pretrained network the global
    vector is recorded; the recorded vectors are concatenated along dim 1 and
    passed through ``self.mlp`` and ``self.classify``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        """Build the frozen backbone and the trainable head.

        Args:
            pretraining_path: checkpoint file containing a ``'model_state_dict'`` entry.
            pretraining_model: config handed to ``utils.buildFromConfig``; the MLP
                input width is ``args.hid_size * args.n_proc_steps`` of this config.
            sample_graph: forwarded to ``utils.buildFromConfig``.
            sample_global: forwarded to ``utils.buildFromConfig``; global features
                are considered present when ``sample_global.shape[1] != 0``.
            hid_size: hidden/output width of the MLP head.
            out_size: number of outputs of the final linear layer.
            n_layers: passed to ``Make_MLP`` for the head.
            n_proc_steps: number of message-passing iterations in ``Pretrained_Output``.
            dropout: passed to ``Make_MLP`` for the head.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
        # CPU-only host; load_state_dict copies the values into existing parameters.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child of the pretrained model and index the rest by position.
        # NOTE(review): positions 1..7 are assumed to be node/edge/global encoders,
        # node/edge/global updates and the global decoder - confirm against the
        # pretrained model class.
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        # Trainable head; its input is the concatenation of one global vector per step.
        # NOTE(review): the width uses the pretrained config's n_proc_steps while
        # Pretrained_Output iterates self.n_proc_steps times - the two must match.
        self.mlp = Make_MLP(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(hid_size, out_size)

    def _apply_pretrained(self, index, x):
        """Feed ``x`` through every layer of pretrained child number ``index``."""
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self._apply_pretrained(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained network and return the per-step global vectors.

        ``g`` is modified in place, so callers pass a clone.

        Args:
            g: batched DGL graph with ``'features'`` in both ``ndata`` and ``edata``.
            global_feats: per-graph global features; required when the model was
                built with global features, ignored otherwise.

        Returns:
            The global vectors of all iterations concatenated along dim 1, or
            ``None`` when ``n_proc_steps`` is 0.

        Raises:
            ValueError: if global features are expected but ``global_feats`` is None.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No global features available: use the node count of each graph instead.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError('global_feats must be provided when the model uses global features')
        h_global = self.TL_global_encoder(global_feats)

        per_step_globals = []
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
            per_step_globals.append(h_global)
        # The decoder output is not part of the returned value; the call is kept so
        # the sequence of module invocations matches the original implementation.
        h_global = self.TL_global_decoder(h_global)
        if not per_step_globals:
            return None
        return torch.cat(per_step_globals, dim=1)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``."""
        # The pretrained pass works on a clone so the caller's graph is untouched.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
770
+
771
class Transferred_Learning_Message_Passing_Parallel(nn.Module):
    """Frozen pretrained network run alongside a freshly initialised graph network.

    The per-step global vectors of the pretrained network are concatenated with
    the decoded global vector of the new network and classified by one linear layer.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        """Build the frozen backbone, the trainable graph network and the classifier.

        Args:
            pretraining_path: checkpoint file containing a ``'model_state_dict'`` entry.
            pretraining_model: config handed to ``utils.buildFromConfig``; its
                ``args.hid_size * args.n_proc_steps`` contributes to the classifier width.
            sample_graph: supplies node/edge feature widths and is forwarded to
                ``utils.buildFromConfig``.
            sample_global: global features are considered present when
                ``sample_global.shape[1] != 0``.
            hid_size: hidden width of every trainable MLP.
            out_size: number of classifier outputs.
            n_layers: passed to every ``Make_MLP`` call.
            n_proc_steps: number of message-passing iterations (both networks).
            dropout: passed to every ``Make_MLP`` call.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        # Without global features a single value (the node count) is used per graph.
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
        # CPU-only host; load_state_dict copies the values into existing parameters.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child of the pretrained model and index the rest by position.
        # NOTE(review): positions 1..7 are assumed to be node/edge/global encoders,
        # node/edge/global updates and the global decoder - confirm against the
        # pretrained model class.
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        # NOTE(review): the width uses the pretrained config's n_proc_steps while
        # Pretrained_Output iterates self.n_proc_steps times - the two must match.
        self.classify = nn.Linear(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'] + hid_size, out_size)

    def _apply_pretrained(self, index, x):
        """Feed ``x`` through every layer of pretrained child number ``index``."""
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self._apply_pretrained(7, x)

    @staticmethod
    def _message_passing_step(g, h_global, edge_fn, node_fn, global_fn):
        """Run one edge -> node -> global update on ``g`` and return the new global vector."""
        g.apply_edges(dgl.function.copy_u('h', 'm_u'))
        g.apply_edges(copy_v)
        g.edata['e'] = edge_fn(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
        g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
        g.ndata['h'] = node_fn(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
        return global_fn(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained network and return the per-step global vectors.

        ``g`` is modified in place, so callers pass a clone.

        Args:
            g: batched DGL graph with ``'features'`` in both ``ndata`` and ``edata``.
            global_feats: per-graph global features; required when the model was
                built with global features, ignored otherwise.

        Returns:
            The global vectors of all iterations concatenated along dim 1, or
            ``None`` when ``n_proc_steps`` is 0.

        Raises:
            ValueError: if global features are expected but ``global_feats`` is None.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No global features available: use the node count of each graph instead.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError('global_feats must be provided when the model uses global features')
        h_global = self.TL_global_encoder(global_feats)

        per_step_globals = []
        for _ in range(self.n_proc_steps):
            h_global = self._message_passing_step(g, h_global, self.TL_edge_update, self.TL_node_update, self.TL_global_update)
            per_step_globals.append(h_global)
        # The decoder output is not part of the returned value; the call is kept so
        # the sequence of module invocations matches the original implementation.
        h_global = self.TL_global_decoder(h_global)
        if not per_step_globals:
            return None
        return torch.cat(per_step_globals, dim=1)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``.

        The pretrained pass runs on a clone; the trainable network writes its
        intermediate fields onto ``g`` itself.
        """
        pretrained_message = self.Pretrained_Output(g.clone(), global_feats)

        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            h_global = self._message_passing_step(g, h_global, self.edge_update, self.node_update, self.global_update)
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_message, h_global), dim=1))
883
+
884
class Transferred_Learning_Finetuning(nn.Module):
    """Pretrained graph network with a new linear classifier, optionally partly frozen.

    The pretrained model is rebuilt with ``utils.buildFromConfig``, restored from
    a checkpoint and stripped of its last child module. A new ``nn.Linear``
    classifier is placed on top of its decoded global vector.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=False, **kwargs):
        """Build the backbone and the classifier.

        Args:
            pretraining_path: checkpoint file containing a ``'model_state_dict'`` entry.
            pretraining_model: config handed to ``utils.buildFromConfig``; its
                ``['args']['hid_size']`` is the classifier input width.
            sample_graph: forwarded to ``utils.buildFromConfig``.
            sample_global: forwarded to ``utils.buildFromConfig``; global features
                are considered absent when it is empty or has zero columns.
            hid_size, n_layers, dropout: accepted for interface compatibility; not
                used to build any layer here (``n_layers`` is stored).
            out_size: number of classifier outputs.
            n_proc_steps: number of message-passing iterations.
            frozen_pretraining: when true, freeze every pretrained parameter except
                those of child 7 (the module applied by ``TL_global_decoder``).
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
        # CPU-only host; load_state_dict copies the values into existing parameters.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child of the pretrained model and index the rest by position.
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        print(f"Freeze Pretraining = {frozen_pretraining}")
        if frozen_pretraining:
            for param in self.pretrained_model.parameters():
                param.requires_grad = False  # Freeze all layers
            # Re-enable training for child 7 only. Iterating the module itself
            # yields sub-modules, so .parameters() is needed to reach the tensors
            # whose requires_grad flag autograd actually consults.
            for param in self.pretrained_model[7].parameters():
                param.requires_grad = True

        # Fixes the global RNG state so the classifier initialisation is reproducible.
        # NOTE(review): this reseeds torch globally as a side effect of construction.
        torch.manual_seed(2)
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

    def _apply_pretrained(self, index, x):
        """Feed ``x`` through every layer of pretrained child number ``index``."""
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self._apply_pretrained(7, x)

    def _encode_and_propagate(self, g, global_feats):
        """Encode ``g`` and run all message-passing steps.

        Writes intermediate fields onto ``g`` and returns the global vector as it
        is before the global decoder is applied.

        Raises:
            ValueError: if global features are expected but ``global_feats`` is None.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No global features available: use the node count of each graph instead.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError('global_feats must be provided when the model uses global features')
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        return h_global

    def Pretrained_Output(self, g, global_feats=None):
        """Return the decoded global vector of the pretrained network for ``g``.

        ``g`` is modified in place, so callers pass a clone. ``global_feats`` is
        required when the model was built with global features.
        """
        return self.TL_global_decoder(self._encode_and_propagate(g, global_feats))

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``."""
        h_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(h_global)

    def representation(self, g, global_feats):
        """Return the global vector at three stages of the network.

        Unlike ``forward`` this operates on ``g`` directly (no clone).

        Returns:
            Tuple ``(before_global_decoder, after_global_decoder, after_classify)``.
        """
        before_global_decoder = self._encode_and_propagate(g, global_feats)
        after_global_decoder = self.TL_global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        """Print the state of every ``nn.Linear`` in the backbone and the classifier weight.

        The information is written to stdout; the returned string is empty.
        """
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]

        for offset, layer_name in enumerate(layer_names):
            print(layer_name)
            # Names are paired with pretrained children 1..7 in order.
            for layer in self.pretrained_model[offset + 1].children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())

        print("classify")
        print(self.classify.weight)
        return ""
1011
+
1012
+
1013
class Transferred_Learning_Parallel_Finetuning(nn.Module):
    """Trainable pretrained network run alongside a freshly initialised graph network.

    The decoded global vectors of both networks are concatenated and classified
    by one linear layer. ``parameters()`` is overridden to return optimizer
    parameter groups with separate learning rates for the two parts.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, learning_rate=0.0001, **kwargs):
        """Build both networks and the classifier.

        Args:
            pretraining_path: checkpoint file containing a ``'model_state_dict'`` entry.
            pretraining_model: config handed to ``utils.buildFromConfig``; its
                ``['args']['hid_size']`` contributes to the classifier width.
            sample_graph: supplies node/edge feature widths and is forwarded to
                ``utils.buildFromConfig``.
            sample_global: global features are considered present when
                ``sample_global.shape[1] != 0``.
            hid_size: hidden width of every new MLP.
            out_size: number of classifier outputs.
            n_layers: passed to every ``Make_MLP`` call.
            n_proc_steps: number of message-passing iterations (both networks).
            dropout: passed to every ``Make_MLP`` call.
            learning_rate: stored for ``parameters()``; only a dict with the keys
                ``"trainable_lr"`` / ``"finetuning_lr"`` has an effect there.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')

        self.learning_rate = learning_rate

        # Plain lists of modules used by parameters() to form the two groups.
        self.parallel_params = []
        self.finetuning_params = []

        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        # Without global features a single value (the node count) is used per graph.
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
        # CPU-only host; load_state_dict copies the values into existing parameters.
        checkpoint = torch.load(pretraining_path, map_location='cpu')
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child of the pretrained model and index the rest by position.
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        self.finetuning_params.append(self.pretrained_model)

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

        self.parallel_params.extend([
            self.node_encoder,
            self.edge_encoder,
            self.global_encoder,
            self.node_update,
            self.edge_update,
            self.global_update,
            self.global_decoder,
            self.classify,
        ])

    def _apply_pretrained(self, index, x):
        """Feed ``x`` through every layer of pretrained child number ``index``."""
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        """Apply pretrained child 1 to ``x``."""
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        """Apply pretrained child 2 to ``x``."""
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        """Apply pretrained child 3 to ``x``."""
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        """Apply pretrained child 4 to ``x``."""
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        """Apply pretrained child 5 to ``x``."""
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        """Apply pretrained child 6 to ``x``."""
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        """Apply pretrained child 7 to ``x``."""
        return self._apply_pretrained(7, x)

    @staticmethod
    def _message_passing_step(g, h_global, edge_fn, node_fn, global_fn):
        """Run one edge -> node -> global update on ``g`` and return the new global vector."""
        g.apply_edges(dgl.function.copy_u('h', 'm_u'))
        g.apply_edges(copy_v)
        g.edata['e'] = edge_fn(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
        g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
        g.ndata['h'] = node_fn(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
        return global_fn(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))

    def Pretrained_Output(self, g, global_feats=None):
        """Return the decoded global vector of the pretrained network for ``g``.

        ``g`` is modified in place, so callers pass a clone.

        Args:
            g: batched DGL graph with ``'features'`` in both ``ndata`` and ``edata``.
            global_feats: per-graph global features; required when the model was
                built with global features, ignored otherwise.

        Raises:
            ValueError: if global features are expected but ``global_feats`` is None.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No global features available: use the node count of each graph instead.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError('global_feats must be provided when the model uses global features')
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            h_global = self._message_passing_step(g, h_global, self.TL_edge_update, self.TL_node_update, self.TL_global_update)
        return self.TL_global_decoder(h_global)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``.

        The pretrained pass runs on a clone; the new network writes its
        intermediate fields onto ``g`` itself.
        """
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)

        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            h_global = self._message_passing_step(g, h_global, self.edge_update, self.node_update, self.global_update)
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim=1))

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups instead of a flat parameter iterator.

        Each module in ``parallel_params`` / ``finetuning_params`` becomes one
        ``{'params': [...], 'lr': ...}`` dict. When ``self.learning_rate`` is a
        dict, a truthy ``"trainable_lr"`` / ``"finetuning_lr"`` entry sets the
        rate of the respective groups; a missing or falsy entry, or a non-dict
        ``learning_rate``, falls back to 0.0001.

        Args:
            recurse: accepted for signature compatibility with
                ``nn.Module.parameters``; not used.

        Returns:
            A list of parameter-group dicts. Code that expects plain tensors from
            ``model.parameters()`` will not work with this return value.
        """
        default_lr = 0.0001
        lr_config = self.learning_rate if isinstance(self.learning_rate, dict) else {}
        # .get avoids a KeyError when only one of the two rates is configured.
        trainable_lr = lr_config.get("trainable_lr") or default_lr
        finetuning_lr = lr_config.get("finetuning_lr") or default_lr

        # Lists (not generators) so a returned group can be iterated more than once.
        params = [{'params': list(section.parameters()), 'lr': trainable_lr} for section in self.parallel_params]
        params.extend({'params': list(section.parameters()), 'lr': finetuning_lr} for section in self.finetuning_params)
        return params
1148
+
1149
class Attention(nn.Module):
    """Graph classifier that mixes node states with one multi-head attention pass.

    Nodes are encoded, attended within each graph (honouring a padding mask),
    combined with the global vector, pooled per graph and classified.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        """Build encoders, update MLPs, decoder, classifier and attention layers.

        Args:
            sample_graph: supplies the node feature width via ``ndata['features']``.
            sample_global: global features are considered present when
                ``sample_global.shape[1] != 0``.
            hid_size: hidden width of every MLP and the attention embedding size.
            out_size: number of classifier outputs.
            n_layers: passed to every ``Make_MLP`` call.
            n_proc_steps: stored; ``forward`` performs a single attention pass.
            dropout: passed to ``Make_MLP`` and ``nn.MultiheadAttention``.
            num_heads: number of attention heads.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        # Without global features a single value (the node count) is used per graph.
        gl_size = sample_global.shape[1] if self.has_global else 1

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        #attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``.

        Args:
            g: batched DGL graph with ``ndata['features']`` and
                ``ndata['padding_mask']``; an optional ``ndata['w']`` supplies
                per-node weights for the readout.
                NOTE(review): every graph in the batch is assumed to hold the same
                number of nodes (the first graph's count is used for reshaping).
            global_feats: per-graph global features; replaced by node counts when
                the model has no global features or when ``'w'`` is present.
        """
        h = self.node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        sum_weights = None
        if "w" in g.ndata:
            # A node counts as real (non-padded) when any of its features is non-zero.
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Count the real nodes of each graph by splitting the mask per graph.
            per_graph_masks = torch.split(non_padded_nodes_mask, g.batch_num_nodes().tolist())
            batch_num_nodes = torch.stack([mask.sum() for mask in per_graph_masks])
            # Shape (num_graphs, 1): broadcasts over any hidden width in the division
            # below instead of being tied to a fixed width of 64.
            sum_weights = batch_num_nodes[:, None]
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)

        h_original_shape = h.shape
        # One entry per graph in the batch; avoids materialising every sub-graph.
        num_graphs = g.batch_num_nodes().shape[0]
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        h = g.ndata['h']
        batched_shape = (num_graphs, num_nodes, h_original_shape[1])
        query = torch.reshape(self.queries(h), batched_shape)
        key = torch.reshape(self.keys(h), batched_shape)
        value = torch.reshape(self.values(h), batched_shape)
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
        g.ndata['h'] = h
        if sum_weights is not None:
            # Weighted sum of node states divided by the number of real nodes.
            mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        else:
            # No node weights on the graph: fall back to the plain per-graph mean.
            mean_nodes = dgl.mean_nodes(g, 'h')
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
1226
+
1227
class Attention_Edge_Network(nn.Module):
    """Edge/node/global message-passing network with attention before every step.

    In each processing step node states are first mixed by multi-head attention
    within their graph, then the usual edge, node and global updates are applied.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        """Build encoders, update MLPs, decoder, classifier and attention layers.

        Args:
            sample_graph: supplies node/edge feature widths via ``'features'``.
            sample_global: global features are considered present when
                ``sample_global.shape[1] != 0``.
            hid_size: hidden width of every MLP and the attention embedding size.
            out_size: number of classifier outputs.
            n_layers: passed to every ``Make_MLP`` call.
            n_proc_steps: number of attention + message-passing iterations.
            dropout: passed to ``Make_MLP`` and ``nn.MultiheadAttention``.
            num_heads: number of attention heads.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        # Without global features a single value (the node count) is used per graph.
        gl_size = sample_global.shape[1] if self.has_global else 1

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        #attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``.

        Args:
            g: batched DGL graph with ``'features'`` in ``ndata`` and ``edata`` plus
                ``ndata['padding_mask']`` and ``ndata['w']`` (weights for the node
                readout).
                NOTE(review): every graph in the batch is assumed to hold the same
                number of nodes (the first graph's count is used for reshaping).
            global_feats: per-graph global features; replaced by node counts when
                the model has no global features.
        """
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        h_original_shape = g.ndata['h'].shape
        # One entry per graph in the batch; avoids materialising every sub-graph.
        num_graphs = g.batch_num_nodes().shape[0]
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
        batched_shape = (num_graphs, num_nodes, h_original_shape[1])

        for _ in range(self.n_proc_steps):
            h = g.ndata['h']
            query = torch.reshape(self.queries(h), batched_shape)
            key = torch.reshape(self.keys(h), batched_shape)
            value = torch.reshape(self.values(h), batched_shape)
            h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
            # Store the attended states on the graph; every step below reads
            # g.ndata['h'], so without this assignment the attention result would
            # never influence the output.
            g.ndata['h'] = torch.reshape(h, h_original_shape)

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
1295
+
1296
class Attention_Unbatched(nn.Module):
    """Message-passing network that applies attention to each graph separately.

    Every processing step splits the batch into individual graphs, runs
    single-head attention over the nodes of each one, re-batches them and then
    performs the edge, node and global updates.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        """Create all sub-modules.

        Args:
            sample_graph: supplies node/edge feature widths via ``'features'``.
            sample_global: global features are considered present when
                ``sample_global.shape[1] != 0``.
            hid_size: hidden width of every MLP and the attention embedding size.
            out_size: number of classifier outputs.
            n_layers: passed to every ``Make_MLP`` call.
            n_proc_steps: number of attention + message-passing iterations.
            dropout: passed to ``Make_MLP`` and ``nn.MultiheadAttention``.
            num_heads: accepted but not used.
                NOTE(review): the attention layer is always built with one head.
            **kwargs: ignored; only printed.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        if self.has_global:
            gl_size = sample_global.shape[1]
        else:
            # A single value (the node count) stands in for global features.
            gl_size = 1

        node_in = sample_graph.ndata['features'].shape[1]
        edge_in = sample_graph.edata['features'].shape[1]

        # Encoders
        self.node_encoder = Make_MLP(node_in, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(edge_in, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Message-passing updates
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Decoder and classifier
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # Attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, 1, dropout=dropout)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def _attend_per_graph(self, g):
        """Apply attention to the nodes of each graph in ``g`` and return a new batch."""
        graphs = dgl.unbatch(g)
        for single in graphs:
            node_state = single.ndata['h']
            attended, _ = self.multihead_attn(self.queries(node_state), self.keys(node_state), self.values(node_state))
            single.ndata['h'] = attended
        return dgl.batch(graphs)

    def forward(self, g, global_feats):
        """Return classifier outputs for the batched graph ``g``.

        Args:
            g: batched DGL graph with ``'features'`` in ``ndata`` and ``edata``.
            global_feats: per-graph global features; replaced by node counts when
                the model has no global features.
        """
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        step = 0
        while step < self.n_proc_steps:
            g = self._attend_per_graph(g)

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)

            edge_inputs = torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1)
            g.edata['e'] = self.edge_update(edge_inputs)

            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))

            node_inputs = torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1)
            g.ndata['h'] = self.node_update(node_inputs)

            global_inputs = torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1)
            h_global = self.global_update(global_inputs)
            step += 1

        decoded = self.global_decoder(h_global)
        return self.classify(decoded)
1357
+
1358
class Transferred_Learning_Attention(nn.Module):
    """Graph classifier that reuses parts of a pretrained GNN and adds self-attention.

    The pretrained model is rebuilt from its config, loaded from a checkpoint,
    stripped of its last child module and kept as an ``nn.Sequential`` so its
    remaining children can be addressed by position. Positions 1, 3 and 7 are
    used as node encoder, global encoder and global decoder respectively.
    NOTE(review): those positions assume the pretrained model registers its
    children in the same order as the other GNN classes in this file - confirm.

    NOTE: ``parameters()`` is overridden to return optimizer parameter groups
    (a list of dicts), not an iterator of ``nn.Parameter``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, num_heads, dropout=0, learning_rate=0.0001, **kwargs):
        """Build the model.

        Args:
            pretraining_path: checkpoint file holding ``'model_state_dict'``.
            pretraining_model: config dict for the pretrained model; its
                ``['args']['hid_size']`` sizes the final classifier input.
            sample_graph, sample_global: examples forwarded to the pretrained
                model's constructor; ``sample_global.shape[1]`` decides whether
                global features are used.
            hid_size: width of the attention and update layers.
            out_size: number of classifier outputs.
            n_layers: depth passed to ``Make_MLP``.
            n_proc_steps: stored but not used by ``forward``.
            num_heads: number of attention heads.
            dropout: dropout used by attention and the MLPs.
            learning_rate: optional dict with ``'pretraining_lr'`` and/or
                ``'attention_lr'``; any other value selects the 0.0001 default.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0

        self.learning_rate = learning_rate

        self.pretraining_params = []
        self.attention_params = []

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the last child and keep the remaining ones position-addressable.
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        for position in (1, 3, 7):
            self.pretraining_params.append(self.pretrained_model[position])

        #attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

        # Sections trained with the "attention" learning rate.
        self.attention_params.extend([
            self.multihead_attn,
            self.queries,
            self.keys,
            self.values,
            self.classify,
            self.node_update,
            self.global_update,
        ])

    def TL_node_encoder(self, x):
        """Apply the pretrained child at position 1, layer by layer."""
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        """Apply the pretrained child at position 3, layer by layer."""
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        """Apply the pretrained child at position 7, layer by layer."""
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    @staticmethod
    def _count_unpadded_nodes(g, feature_key='features'):
        """Return, per graph in the batch, the number of nodes with any non-zero feature."""
        is_real = torch.any(g.ndata[feature_key] != 0, dim=1)
        sizes = g.batch_num_nodes().tolist()
        return torch.stack([chunk.sum() for chunk in torch.split(is_real, sizes)])

    def forward(self, g, global_feats):
        """Encode nodes, apply masked self-attention per graph, and classify.

        Args:
            g: batched DGL graph with node data 'features' and 'padding_mask'
                and, optionally, per-node weights 'w'.
                NOTE(review): the reshape below assumes every graph in the
                batch has the same number of nodes (padded) - confirm.
            global_feats: per-graph global features; replaced by node counts
                when no global features were configured or when 'w' is present.

        Returns:
            Output of ``self.classify`` on the decoded global state.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        has_weights = "w" in g.ndata
        real_node_counts = None
        if has_weights:
            # With padded graphs the global feature is the count of real nodes.
            real_node_counts = self._count_unpadded_nodes(g)
            global_feats = real_node_counts[:, None].to(torch.float)

        h_global = self.TL_global_encoder(global_feats)

        flat_shape = h.shape
        num_graphs = g.batch_size
        nodes_per_graph = g.batch_num_nodes()[0].item()
        batched_shape = (num_graphs, nodes_per_graph, flat_shape[1])
        padding_mask = torch.reshape(g.ndata['padding_mask'] > 0, (num_graphs, nodes_per_graph))

        query = torch.reshape(self.queries(h), batched_shape)
        key = torch.reshape(self.keys(h), batched_shape)
        value = torch.reshape(self.values(h), batched_shape)
        attended, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(attended, flat_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
        g.ndata['h'] = h

        if has_weights:
            # Weighted mean over real nodes; broadcasting covers any hidden size.
            mean_nodes = dgl.sum_nodes(g, 'h', 'w') / real_node_counts[:, None]
        else:
            # Without weights fall back to the plain mean over nodes.
            mean_nodes = dgl.mean_nodes(g, 'h')

        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(h_global)

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups for the pretrained and new sections.

        ``recurse`` is accepted for signature compatibility and ignored.
        A missing or falsy entry in the learning-rate dict selects 0.0001.
        """
        lr_config = self.learning_rate if isinstance(self.learning_rate, dict) else {}
        pretraining_lr = lr_config.get("pretraining_lr") or 0.0001
        attention_lr = lr_config.get("attention_lr") or 0.0001

        params = []
        for model_section in self.pretraining_params:
            params.append({'params': list(model_section.parameters()), 'lr': pretraining_lr})
        for model_section in self.attention_params:
            params.append({'params': list(model_section.parameters()), 'lr': attention_lr})
        return params
1483
+
1484
class Multimodel_Transferred_Learning(nn.Module):
    """Classifier that concatenates the global states of several pretrained GNNs.

    Each pretrained model is rebuilt from its config, loaded from a checkpoint,
    stripped of its last child and kept as an ``nn.Sequential`` whose children
    are addressed by position (1 node encoder, 2 edge encoder, 3 global encoder,
    4 node update, 5 edge update, 6 global update, 7 global decoder).
    NOTE(review): those positions assume the child registration order used by
    the other GNN classes in this file - confirm.

    NOTE: the pretrained models live in a plain Python list, so they are not
    part of ``state_dict()``; ``to()`` and ``train()`` are overridden to reach
    them. ``parameters()`` returns optimizer parameter groups (a list of
    dicts), not an iterator of ``nn.Parameter``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
        """Build the model.

        Args:
            pretraining_path: iterable of checkpoint paths, one per model.
            pretraining_model: iterable of model configs, one per model; each
                ``['args']['hid_size']`` contributes to the final MLP input.
            sample_graph, sample_global: examples forwarded to each pretrained
                model's constructor.
            hid_size, out_size, n_layers, dropout: sizes for the final MLP and
                classifier.
            n_proc_steps: message-passing steps run through each pretrained model.
            frozen_pretraining: when true, pretrained weights get
                ``requires_grad = False`` and are left out of the optimizer groups.
            learning_rate: optional dict with ``'pretraining_lr'`` and/or
                ``'model_lr'``.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0

        self.learning_rate = learning_rate
        input_size = 0

        self.pretraining_params = []
        self.model_params = []

        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})

            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(path)['model_state_dict']
            model.load_state_dict(self._strip_module_prefix(checkpoint))
            pretrained_layers = list(model.children())[:-1]

            model = nn.Sequential(*pretrained_layers)

            # Freeze Weights
            print(f"Freeze Pretraining = {frozen_pretraining}")
            if (frozen_pretraining):
                for param in model.parameters():
                    param.requires_grad = False # Freeze all layers
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        self.model_params.append(self.final_mlp)
        self.model_params.append(self.classify)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with leading 'module.' prefixes removed from keys."""
        prefix = 'module.'
        cleaned = {}
        for key, value in state_dict.items():
            # Only strip at the start so names merely containing 'module.' survive.
            while key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    def _apply_section(self, x, model_idx, section_idx):
        """Run ``x`` through child ``section_idx`` of pretrained model ``model_idx``.

        If that fails with NotImplementedError or IndexError the section is
        looked up one level deeper (``model[1][section_idx]``), always starting
        again from the original input.
        """
        model = self.pretrained_models[model_idx]
        try:
            out = x
            for layer in model[section_idx]:
                out = layer(out)
            return out
        except (NotImplementedError, IndexError):
            out = x
            for layer in model[1][section_idx]:
                out = layer(out)
            return out

    def TL_node_encoder(self, x, model_idx):
        """Apply the pretrained node encoder (position 1)."""
        return self._apply_section(x, model_idx, 1)

    def TL_edge_encoder(self, x, model_idx):
        """Apply the pretrained edge encoder (position 2)."""
        return self._apply_section(x, model_idx, 2)

    def TL_global_encoder(self, x, model_idx):
        """Apply the pretrained global encoder (position 3)."""
        return self._apply_section(x, model_idx, 3)

    def TL_node_update(self, x, model_idx):
        """Apply the pretrained node-update network (position 4)."""
        return self._apply_section(x, model_idx, 4)

    def TL_edge_update(self, x, model_idx):
        """Apply the pretrained edge-update network (position 5)."""
        return self._apply_section(x, model_idx, 5)

    def TL_global_update(self, x, model_idx):
        """Apply the pretrained global-update network (position 6)."""
        return self._apply_section(x, model_idx, 6)

    def TL_global_decoder(self, x, model_idx):
        """Apply the pretrained global decoder (position 7)."""
        return self._apply_section(x, model_idx, 7)

    def Pretrained_Output(self, g, model_idx, global_feats=None):
        """Run the full message-passing stack of one pretrained model.

        Args:
            g: batched DGL graph; its node and edge data are overwritten.
            model_idx: index into ``self.pretrained_models``.
            global_feats: per-graph global features; required when the model
                was built with global features, otherwise ignored.

        Returns:
            The global state after ``n_proc_steps`` steps (the pretrained
            global decoder is not applied).

        Raises:
            ValueError: if global features are required but not supplied.
        """
        h = self.TL_node_encoder(g.ndata['features'], model_idx)
        e = self.TL_edge_encoder(g.edata['features'], model_idx)
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError("global_feats must be provided when the model uses global features")
        h_global = self.TL_global_encoder(global_feats, model_idx)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
        # h_global = self.TL_global_decoder(h_global, model_idx)
        return h_global

    def forward(self, g, global_feats):
        """Concatenate every pretrained model's global state and classify it."""
        h_global = []
        for i in range(len(self.pretrained_models)):
            # Each model works on its own clone so the scratch data does not collide.
            h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
        h_global = torch.cat(h_global, dim=1)
        return self.classify(self.final_mlp(h_global))

    def to(self, *args, **kwargs):
        """Move or cast the model, including the unregistered pretrained models."""
        super().to(*args, **kwargs)
        for model in self.pretrained_models:
            model.to(*args, **kwargs)
        return self

    def train(self, mode: bool = True):
        """Set training or evaluation mode, including on the pretrained models."""
        super().train(mode)
        for model in self.pretrained_models:
            model.train(mode)
        return self

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups for the pretrained models and the head.

        ``recurse`` is accepted for signature compatibility and ignored.
        Missing or falsy learning-rate entries select 0.00001 for the
        pretrained models and 0.0001 for the head.
        """
        lr_config = self.learning_rate if isinstance(self.learning_rate, dict) else {}
        pretraining_lr = lr_config.get("pretraining_lr") or 0.00001
        model_lr = lr_config.get("model_lr") or 0.0001

        params = []
        for model_section in self.pretraining_params:
            params.append({'params': list(model_section.parameters()), 'lr': pretraining_lr})
        for model_section in self.model_params:
            params.append({'params': list(model_section.parameters()), 'lr': model_lr})
        return params
1648
+
1649
+
1650
class MultiModel(nn.Module):
    """Classifier combining a freshly built GNN with several pretrained GNNs.

    The new GNN and every pretrained model each produce a per-graph global
    state; the states are concatenated and passed through a final MLP and a
    linear classifier. Pretrained models are kept as ``nn.Sequential``
    containers whose children are addressed by position (1 node encoder,
    2 edge encoder, 3 global encoder, 4 node update, 5 edge update, 6 global
    update, 7 global decoder).
    NOTE(review): those positions assume the child registration order used by
    the other GNN classes in this file - confirm.

    NOTE: the pretrained models live in a plain Python list, so they are not
    part of ``state_dict()``; ``to()`` and ``train()`` are overridden to reach
    them. ``parameters()`` returns optimizer parameter groups (a list of
    dicts), not an iterator of ``nn.Parameter``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
        """Build the model.

        Args:
            pretraining_path: iterable of checkpoint paths, one per model.
            pretraining_model: iterable of model configs, one per model; each
                ``['args']['hid_size']`` contributes to the final MLP input.
            sample_graph, sample_global: examples used to size the encoders and
                forwarded to each pretrained model's constructor.
            hid_size, out_size, n_layers, dropout: sizes for the new GNN, the
                final MLP and the classifier.
            n_proc_steps: number of message-passing steps.
            frozen_pretraining: when true, pretrained weights get
                ``requires_grad = False`` and are left out of the optimizer groups.
            learning_rate: optional dict with ``'pretraining_lr'`` (indexable,
                one entry per pretrained model) and/or ``'model_lr'``.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.learning_rate = learning_rate
        input_size = 0

        self.model_params = []
        self.pretraining_params = []

        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})

            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(path)['model_state_dict']
            model.load_state_dict(self._strip_module_prefix(checkpoint))
            pretrained_layers = list(model.children())[:-1]

            model = nn.Sequential(*pretrained_layers)

            # Freeze Weights
            print(f"Freeze Pretraining = {frozen_pretraining}")
            if (frozen_pretraining):
                for param in model.parameters():
                    param.requires_grad = False # Freeze all layers
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # Every newly created section must be handed to the optimizer; leaving
        # the GNN trunk out would keep it at its random initialisation.
        self.model_params.extend([
            self.final_mlp,
            self.classify,
            self.node_encoder,
            self.edge_encoder,
            self.global_encoder,
            self.node_update,
            self.edge_update,
            self.global_update,
        ])

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with leading 'module.' prefixes removed from keys."""
        prefix = 'module.'
        cleaned = {}
        for key, value in state_dict.items():
            # Only strip at the start so names merely containing 'module.' survive.
            while key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    def _apply_section(self, x, model_idx, section_idx):
        """Run ``x`` through child ``section_idx`` of pretrained model ``model_idx``.

        If that fails with NotImplementedError or IndexError the section is
        looked up one level deeper (``model[1][section_idx]``), always starting
        again from the original input.
        """
        model = self.pretrained_models[model_idx]
        try:
            out = x
            for layer in model[section_idx]:
                out = layer(out)
            return out
        except (NotImplementedError, IndexError):
            out = x
            for layer in model[1][section_idx]:
                out = layer(out)
            return out

    def TL_node_encoder(self, x, model_idx):
        """Apply the pretrained node encoder (position 1)."""
        return self._apply_section(x, model_idx, 1)

    def TL_edge_encoder(self, x, model_idx):
        """Apply the pretrained edge encoder (position 2)."""
        return self._apply_section(x, model_idx, 2)

    def TL_global_encoder(self, x, model_idx):
        """Apply the pretrained global encoder (position 3)."""
        return self._apply_section(x, model_idx, 3)

    def TL_node_update(self, x, model_idx):
        """Apply the pretrained node-update network (position 4)."""
        return self._apply_section(x, model_idx, 4)

    def TL_edge_update(self, x, model_idx):
        """Apply the pretrained edge-update network (position 5)."""
        return self._apply_section(x, model_idx, 5)

    def TL_global_update(self, x, model_idx):
        """Apply the pretrained global-update network (position 6)."""
        return self._apply_section(x, model_idx, 6)

    def TL_global_decoder(self, x, model_idx):
        """Apply the pretrained global decoder (position 7)."""
        return self._apply_section(x, model_idx, 7)

    def Pretrained_Output(self, g, model_idx, global_feats=None):
        """Run the full message-passing stack of one pretrained model.

        Args:
            g: batched DGL graph; its node and edge data are overwritten.
            model_idx: index into ``self.pretrained_models``.
            global_feats: per-graph global features; required when the model
                was built with global features, otherwise ignored.

        Returns:
            The global state after ``n_proc_steps`` steps (the pretrained
            global decoder is not applied).

        Raises:
            ValueError: if global features are required but not supplied.
        """
        h = self.TL_node_encoder(g.ndata['features'], model_idx)
        e = self.TL_edge_encoder(g.edata['features'], model_idx)
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError("global_feats must be provided when the model uses global features")
        h_global = self.TL_global_encoder(global_feats, model_idx)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
        # h_global = self.TL_global_decoder(h_global, model_idx)
        return h_global

    def forward(self, g, global_feats):
        """Run the new GNN and every pretrained model, then classify the joined states."""
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = [h_global]
        for i in range(len(self.pretrained_models)):
            # Each pretrained model works on its own clone of the graph.
            h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
        h_global = torch.cat(h_global, dim=1)
        return self.classify(self.final_mlp(h_global))

    def to(self, *args, **kwargs):
        """Move or cast the model, including the unregistered pretrained models."""
        super().to(*args, **kwargs)
        for model in self.pretrained_models:
            model.to(*args, **kwargs)
        return self

    def train(self, mode: bool = True):
        """Set training or evaluation mode, including on the pretrained models."""
        super().train(mode)
        for model in self.pretrained_models:
            model.train(mode)
        return self

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups for the pretrained models and the new sections.

        ``recurse`` is accepted for signature compatibility and ignored.
        ``learning_rate['pretraining_lr']`` is indexed per pretrained model;
        missing or falsy entries select 0.00001 (pretrained) and 0.0001 (new).
        """
        lr_config = self.learning_rate if isinstance(self.learning_rate, dict) else {}
        pretraining_lrs = lr_config.get("pretraining_lr")
        model_lr = lr_config.get("model_lr")

        params = []
        for i, model_section in enumerate(self.pretraining_params):
            if pretraining_lrs:
                print(f"Pretraining LR = {pretraining_lrs[i]}")
                params.append({'params': list(model_section.parameters()), 'lr': pretraining_lrs[i]})
            else:
                print(f"Pretraining LR = 0.00001")
                params.append({'params': list(model_section.parameters()), 'lr': 0.00001})
        for model_section in self.model_params:
            if model_lr:
                print(f"Model LR = {model_lr}")
                params.append({'params': list(model_section.parameters()), 'lr': model_lr})
            else:
                print(f"Model LR = 0.0001")
                params.append({'params': list(model_section.parameters()), 'lr': 0.0001})
        return params
1849
+
1850
+
1851
class Clustering(nn.Module):
    """Message-passing GNN that embeds a graph and its augmented copy.

    The same network is applied to the node/edge data stored under 'features'
    and under 'augmented_features'; ``forward`` returns both embeddings
    concatenated along the feature dimension.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        """Build encoders, update networks and the global decoder.

        Args:
            sample_graph: example graph used to size the node and edge encoders.
            sample_global: example global features; an empty value or zero
                columns means the node count is used as the global input.
            hid_size: hidden width of every MLP.
            out_size: width of the embedding produced by the decoder.
            n_layers: depth passed to ``Make_MLP``.
            n_proc_steps: number of message-passing steps.
            dropout: dropout passed to ``Make_MLP``.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        #encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        #GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        #decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout)

    @staticmethod
    def _count_unpadded_nodes(g, feature_key):
        """Return, per graph in the batch, the number of nodes with any non-zero feature."""
        is_real = torch.any(g.ndata[feature_key] != 0, dim=1)
        sizes = g.batch_num_nodes().tolist()
        return torch.stack([chunk.sum() for chunk in torch.split(is_real, sizes)])

    def model_forward(self, g, global_feats, features = 'features'):
        """Embed the graph using the node/edge data stored under ``features``.

        Args:
            g: batched DGL graph; if it carries per-node weights 'w', node
                averages are weighted and normalised by the number of nodes
                that have a non-zero feature row.
            global_feats: per-graph global features; replaced by node counts
                when no global features were configured or when 'w' is present.
            features: key of the node and edge data to encode.

        Returns:
            The decoded global state, one row per graph.
        """
        h = self.node_encoder(g.ndata[features])
        e = self.edge_encoder(g.edata[features])

        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        has_weights = "w" in g.ndata
        real_node_counts = None
        if has_weights:
            real_node_counts = self._count_unpadded_nodes(g, features)
            global_feats = real_node_counts[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            if has_weights:
                # Broadcasting the (num_graphs, 1) counts works for any hidden size.
                mean_nodes = dgl.sum_nodes(g, 'h', 'w') / real_node_counts[:, None]
            else:
                mean_nodes = dgl.mean_nodes(g, 'h')
            h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Return the original and augmented embeddings concatenated along dim 1."""
        h_global = self.model_forward(g, global_feats, 'features')
        h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
        return torch.cat((h_global, h_global_augmented), dim=1)

    def representation(self, g, global_feats):
        """Return (original embedding, augmented embedding, their concatenation)."""
        h_global = self.model_forward(g, global_feats, 'features')
        h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
        return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1)

    def __str__(self):
        """Return each section's name followed by the state of its direct Linear children."""
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]

        lines = []
        for name in layer_names:
            lines.append(name)
            for layer in getattr(self, name).children():
                if isinstance(layer, nn.Linear):
                    lines.append(str(layer.state_dict()))

        # This class defines no classifier itself; report one only if present.
        classifier = getattr(self, "classify", None)
        if classifier is not None:
            lines.append("classify")
            lines.append(str(classifier.weight))
        return "\n".join(lines)
models/__pycache__/GCN.cpython-38.pyc ADDED
Binary file (57 kB). View file
 
models/__pycache__/loss.cpython-38.pyc ADDED
Binary file (11.4 kB). View file
 
models/loss.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ from root_gnn_base import utils
4
+ import numpy as np
5
+
6
class MaskedLoss():
    """Base class for losses that ignore samples selected by target-based rules.

    Each rule is a dict with keys ``'op'`` (one of eq, gt, lt, ge, le, ne),
    ``'idx'`` (target column) and ``'val'`` (value compared against). A sample
    is excluded as soon as any rule matches it.
    """

    # Comparison applied as op(targets[:, idx], val); a True result excludes the sample.
    _EXCLUDE_OPS = {
        'eq': torch.eq,
        'gt': torch.gt,
        'lt': torch.lt,
        'ge': torch.ge,
        'le': torch.le,
        'ne': torch.ne,
    }

    def __init__(self, mask = None):
        """Store the exclusion rules; ``None`` means no rules."""
        # A fresh list per instance avoids sharing one default list between objects.
        self.mask = [] if mask is None else mask

    def make_mask(self, targets):
        """Return a boolean tensor that is True for the samples to keep.

        Args:
            targets: 2-D tensor indexed as ``targets[:, idx]`` by every rule.

        Raises:
            ValueError: if a rule names an unknown operation.
        """
        keep = torch.ones_like(targets[:, 0], dtype=torch.bool)
        for m in self.mask:
            compare = self._EXCLUDE_OPS.get(m['op'])
            if compare is None:
                raise ValueError(f'Unknown mask op {m["op"]}')
            keep[compare(targets[:, m['idx']], m['val'])] = False
        return keep
28
+
29
class MaskedL1Loss(MaskedLoss):
    """L1 loss evaluated only on the samples kept by the mask rules."""

    def __init__(self, mask = None, index = 0):
        """Store the mask rules and the target column to regress.

        Args:
            mask: list of exclusion rules understood by ``MaskedLoss``;
                ``None`` means no rules.
            index: column of ``targets`` compared against the predictions.
        """
        # Pass a fresh list so instances never share one default list.
        super().__init__([] if mask is None else mask)
        self.index = index
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets):
        """Return the L1 loss between kept predictions and the selected target column.

        NOTE(review): the selected targets are 1-D; if ``logits`` is
        (batch, 1), ``nn.L1Loss`` will broadcast the two against each other -
        confirm the shape callers pass in.
        """
        keep = self.make_mask(targets)
        kept_targets = targets[keep][:, self.index]
        return self.loss(logits[keep], kept_targets)
38
+
39
class BCEWithLogitsLoss():
    """Binary cross-entropy on the first logit column against float-cast targets."""

    def __init__(self, weight=None, reduction='mean'):
        """Create the underlying ``nn.BCEWithLogitsLoss`` with the given options."""
        self.loss = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)

    def __call__(self, logits, targets):
        """Return the loss of ``logits[:, 0]`` against ``targets`` cast to float."""
        first_column = logits[:, 0]
        float_targets = targets.float()
        return self.loss(first_column, float_targets)
45
+
46
class MultiScore():
    """Apply a separate score function to each column range of the network output."""

    def __init__(self, scores):
        """Build one score function per config entry.

        Each entry is passed to ``utils.buildFromConfig`` and must provide
        ``'start_idx'`` and ``'end_idx'`` delimiting its column slice.
        """
        self.score_fcns = []
        self.start_idx = []
        self.end_idx = []
        for cfg in scores:
            self.score_fcns.append(utils.buildFromConfig(cfg))
            self.start_idx.append(cfg['start_idx'])
            self.end_idx.append(cfg['end_idx'])

    def __call__(self, last_layer):
        """Return the per-slice scores concatenated along dim 1."""
        pieces = [
            fcn(last_layer[:, lo:hi])
            for fcn, lo, hi in zip(self.score_fcns, self.start_idx, self.end_idx)
        ]
        return torch.cat(pieces, dim=1)
61
+
62
class MultiLoss():
    """Weighted sum of several losses, each fed its own output and label columns."""

    def __init__(self, losses):
        """Build one loss per config entry.

        Each entry is passed to ``utils.buildFromConfig`` and must provide
        ``'label_start_idx'``, ``'label_end_idx'``, ``'output_start_idx'`` and
        ``'output_end_idx'``; ``'weight'`` defaults to 1.0 and ``'label_type'``
        to ``'float'``.
        """
        self.loss_fcns = []
        self.label_start_idx = []
        self.label_end_idx = []
        self.output_start_idx = []
        self.output_end_idx = []
        self.weights = []
        self.label_types = []
        for cfg in losses:
            self.loss_fcns.append(utils.buildFromConfig(cfg))
            self.label_start_idx.append(cfg['label_start_idx'])
            self.label_end_idx.append(cfg['label_end_idx'])
            self.output_start_idx.append(cfg['output_start_idx'])
            self.output_end_idx.append(cfg['output_end_idx'])
            self.weights.append(cfg.get('weight', 1.0))
            self.label_types.append(cfg.get('label_type', 'float'))

    def __call__(self, logits, targets):
        """Return the weighted sum of all configured losses.

        For each loss the predictions are the 2-D output slice. The labels are
        the first label column cast to int for ``'int'`` labels, the single
        label column (1-D) when the label range has width one, and the 2-D
        label slice otherwise.
        """
        total = 0
        for i, loss_fcn in enumerate(self.loss_fcns):
            label_lo = self.label_start_idx[i]
            label_hi = self.label_end_idx[i]
            predictions = logits[:, self.output_start_idx[i]:self.output_end_idx[i]]

            if self.label_types[i] == 'int':
                labels = targets[:, label_lo].to(int)
            elif label_hi - label_lo == 1:
                labels = targets[:, label_lo]
            else:
                labels = targets[:, label_lo:label_hi]

            total += self.weights[i] * loss_fcn(predictions, labels)
        return total
95
+
96
class AdvLoss():
    """Primary loss minus a weighted adversarial loss evaluated where targets[:, 0] == 0."""

    def __init__(self, loss, adv_loss, adv_weight=1.0):
        """Build both loss functions from their configs and store the adversarial weight."""
        self.loss_fcn = utils.buildFromConfig(loss)
        self.adv_loss_fcn = utils.buildFromConfig(adv_loss)
        self.adv_weight = adv_weight

    def __call__(self, logits, targets):
        """Return ``loss - adv_weight * adv_loss``.

        The primary loss compares ``logits[:, 0]`` with ``targets[:, 0]`` on
        every sample. The adversarial loss receives ``logits[:, 1]`` and the
        full target rows of the samples whose first target column equals 0.
        """
        selected = targets[:, 0] == 0
        primary = self.loss_fcn(logits[:, 0], targets[:, 0])
        adversary_inputs = logits[selected][:, 1]
        adversary_targets = targets[selected]
        adversarial = self.adv_loss_fcn(adversary_inputs, adversary_targets)
        return primary - self.adv_weight * adversarial
107
+
108
class MassWindowAdvLoss(AdvLoss):
    """Adversarial loss restricted to samples inside an open window of targets[:, 1].

    NOTE(review): the class name suggests ``targets[:, 1]`` holds a mass - confirm.
    """

    # Class-level defaults keep instances created without __init__ working.
    mass_min = 5
    mass_max = 25

    def __init__(self, loss, adv_loss, adv_weight=1.0, mass_min=5, mass_max=25):
        """Build both losses and store the window bounds (both exclusive)."""
        super().__init__(loss, adv_loss, adv_weight)
        self.mass_min = mass_min
        self.mass_max = mass_max

    def __call__(self, logits, targets):
        """Return ``loss - adv_weight * adv_loss``.

        The primary loss compares ``logits[:, 0]`` with ``targets[:, 0]`` on
        every sample. The adversarial loss compares ``logits[:, 1]`` with
        ``targets[:, 1]`` on samples whose first target column equals 0 and
        whose second column lies strictly inside the window. When no sample
        qualifies, only the primary loss is returned, so a mean-reduced
        adversarial loss cannot turn the result into NaN.
        """
        window_values = targets[:, 1]
        mask = (targets[:, 0] == 0) & (window_values > self.mass_min) & (window_values < self.mass_max)
        loss = self.loss_fcn(logits[:, 0], targets[:, 0])
        if not bool(mask.any()):
            return loss
        adv_loss = self.adv_loss_fcn(logits[mask][:, 1], targets[mask][:, 1])
        return loss - self.adv_weight * adv_loss
117
+
118
class KDELoss(MaskedLoss):
    """Kernel-density based dependence measure between a score and a target column.

    Both quantities are scaled by their root-mean-square, Gaussian kernels are
    built from all pairwise differences, and the returned value combines the
    joint and factorised kernel sums.
    """

    def __init__(self, mask = None, index = 0):
        """Store the mask rules and the target column to use.

        Args:
            mask: list of exclusion rules understood by ``MaskedLoss``;
                ``None`` means no rules.
            index: column of ``targets`` used as the second variable.
        """
        self.index = index
        # Pass a fresh list so instances never share one default list.
        super().__init__([] if mask is None else mask)

    def __call__(self, logits, targets):
        """Return the scalar loss for the samples kept by the mask.

        NOTE(review): the score normaliser averages over every column of
        ``logits`` while only column 0 is used as the score; the two agree only
        for single-column logits - confirm this is intended.
        """
        mask = self.make_mask(targets)
        logits = logits[mask]
        targets = targets[mask][:, self.index]
        N = logits.shape[0]

        # Scale both variables by their root-mean-square.
        masses = targets / torch.sqrt(torch.mean(targets**2))
        scores = logits[:, 0] / torch.sqrt(torch.mean(logits**2))

        # Bandwidth factor N^(-2/6) applied to each variable's variance.
        factor_2d = (1.0*N) ** (-2/6)
        covs = (factor_2d * torch.var(masses), factor_2d * torch.var(scores))

        # Pairwise differences, shape (N, N).
        m_diffs = torch.unsqueeze(masses, 1) - torch.unsqueeze(masses, 0)
        s_diffs = torch.unsqueeze(scores, 1) - torch.unsqueeze(scores, 0)

        ymm = torch.exp(- (m_diffs**2) / (4 * covs[0]))
        yss = torch.exp(- (s_diffs**2) / (4 * covs[1]))

        # Joint-joint, marginal-marginal and joint-marginal kernel sums.
        integral_rho_2d_rho_2d = torch.einsum('ij,ij->', ymm, yss)
        integral_rho_1d_rho_1d = torch.einsum('ij,kl->', ymm, yss)
        integral_rho_2d_rho_1d = torch.einsum('ij,ik->', ymm, yss)
        raw_integral = integral_rho_2d_rho_2d - 2 * integral_rho_2d_rho_1d / N + integral_rho_1d_rho_1d / N**2
        return raw_integral / (4 * torch.pi * N**2)
145
+
146
class MultiLabelLoss():
    """Weighted sum of per-label losses over the columns of ``logits``.

    Label type ``"r"`` uses an element-wise MSE loss, ``"c"`` uses
    ``BCEWithLogitsLoss``. Column ``i`` of ``logits``/``targets`` is paired
    with the ``i``-th loss function that was created.
    """

    def __init__(self, label_names, label_types, label_weights = None):
        """
        Args:
            label_names: accepted for interface compatibility; not used here.
            label_types: iterable of ``"r"`` / ``"c"`` markers, one per label.
            label_weights: optional per-label weights; defaults to all ones.
        """
        self.loss_fcn = []
        if label_weights:
            self.weights = torch.tensor(label_weights)
        else:
            self.weights = torch.ones(len(label_types))
        for label_type in label_types:
            if label_type == "r":
                # reduction='none' is the supported spelling of the deprecated
                # reduce=False: keep one loss value per sample.
                self.loss_fcn.append(torch.nn.MSELoss(reduction='none'))
            elif label_type == "c":
                # Mean-reduced scalar; it broadcasts over the per-sample vector
                # in __call__, which leaves the final mean unchanged.
                self.loss_fcn.append(torch.nn.BCEWithLogitsLoss())
            # NOTE(review): any other marker adds no loss function, which
            # shifts the column pairing of every later label.
        print(f"self.weights = {self.weights}")

    def __call__(self, logits, targets):
        """Return the mean over samples of the weighted per-label losses."""
        targets = targets.float()
        # Use logits.device rather than get_device(): the latter returns -1
        # for CPU tensors, which torch.zeros rejects as a device index.
        loss = torch.zeros(logits.shape[0], device=logits.device)
        for i, loss_fcn in enumerate(self.loss_fcn):
            loss += self.weights[i] * loss_fcn(logits[:, i], targets[:, i])
        return torch.mean(loss)
166
+
167
+
168
class MultiLabelFinish():
    """Post-process raw network outputs column by column.

    Columns of type ``"c"`` are passed through the logistic sigmoid
    (``torch.special.expit``); columns of type ``"r"`` are left untouched.
    """

    def __init__(self, label_names, label_types):
        """
        Args:
            label_names: accepted for interface compatibility; not used here.
            label_types: iterable of ``"r"`` / ``"c"`` markers, one per column.
        """
        self.finish_fcn = []
        for label_type in label_types:
            if label_type == "r":
                self.finish_fcn.append(None)
            elif label_type == "c":
                self.finish_fcn.append(torch.special.expit)

    def __call__(self, logits):
        """Apply the per-column finish functions in place and return ``logits``."""
        for i, finish in enumerate(self.finish_fcn):
            if finish:
                # Apply the sigmoid to the real-valued logits. Casting them to
                # an integer dtype first (as before) truncated the logits, so
                # every value in (-1, 1) mapped to exactly 0.5.
                logits[:, i] = finish(logits[:, i])
        return logits
182
+
183
class ContrastiveClusterLoss():
    """Sum of a contrastive term, a clustering term and a weighted variance term.

    The model output is expected to hold two embeddings side by side: the
    first half of the columns is one view, the second half the other view.
    """

    def __init__(self, k=10, temperature=1, alpha=1):
        """
        Args:
            k: number of clusters handed to ``clustering_loss``.
            temperature: temperature handed to ``contrastive_loss``.
            alpha: weight of the variance term.
        """
        self.k, self.temperature, self.alpha = k, temperature, alpha

    def __call__(self, logits, targets):
        """Return the combined loss; ``targets`` does not enter the value."""
        targets = targets.float()
        combined = logits.float()

        # Split the columns into the two views and L2-normalise each row.
        half = len(logits[0]) // 2
        view_a = normalize_embeddings(combined[:, :half])
        view_b = normalize_embeddings(combined[:, half:])

        contrastive_term = contrastive_loss(view_a, view_b, self.temperature)
        cluster_term, _ = clustering_loss(view_a, self.k)
        spread = variance_regularization(view_a) + variance_regularization(view_b)

        return torch.mean(contrastive_term + cluster_term + self.alpha * spread)
204
+
205
class ContrastiveClusterFinish():
    """Report the three loss components separately for a two-view output.

    Unlike ContrastiveClusterLoss, the two halves are used as they are
    (no row normalisation is applied here).
    """

    def __init__(self, k = 10, temperature = 1, max_cluster_iterations = 10):
        """
        Args:
            k: number of clusters handed to ``clustering_loss``.
            temperature: temperature handed to ``contrastive_loss``.
            max_cluster_iterations: iteration cap handed to ``clustering_loss``.
        """
        self.k, self.temperature = k, temperature
        self.max_cluster_iterations = max_cluster_iterations

        print(f"ContrastiveClusterFinish: k = {k}, temperature = {temperature}")

    def __call__(self, logits):
        """Return ``(contrastive, clustering, variance)`` for ``logits``."""
        combined = logits.float()

        # First half of the columns is one view, second half the other.
        half = len(logits[0]) // 2
        view_a, view_b = combined[:, :half], combined[:, half:]

        contrastive_term = contrastive_loss(view_a, view_b, self.temperature)
        cluster_term, _ = clustering_loss(view_a, self.k, self.max_cluster_iterations)
        spread = variance_regularization(view_a) + variance_regularization(view_b)

        return contrastive_term, cluster_term, spread
226
+
227
def s(z_i, z_j):
    """Return the pairwise Euclidean (p=2) distances between rows of the inputs.

    Non-tensor inputs are converted with ``torch.tensor`` first; the result is
    whatever ``torch.cdist`` returns for the two operands.
    """
    left, right = (
        operand if isinstance(operand, torch.Tensor) else torch.tensor(operand)
        for operand in (z_i, z_j)
    )
    return torch.cdist(left, right, p=2)
237
+
238
def contrastive_loss(logits, logits_augmented, temperature=1, margin=1.0):
    """Summed contrastive loss between two views, anchored on the first view.

    Rows of ``logits`` and ``logits_augmented`` with the same index form the
    positive pair. For each anchor row ``k`` of ``logits`` the contribution is
    ``-log(exp(sim[k, k+N]) / sum_{j != k} exp(sim[k, j]))`` where ``sim`` is
    the cosine similarity divided by ``temperature`` over all ``2N`` rows.

    The log-ratio is evaluated with ``logsumexp`` so that small temperatures
    no longer overflow ``exp`` (which produced inf/inf = NaN), and all anchors
    are processed in one vectorised pass on the input's device.

    Args:
        logits: (N, D) embeddings of the first view (tensor or array-like).
        logits_augmented: (N, D) embeddings of the second view.
        temperature: divisor applied to the similarities.
        margin: unused; kept for interface compatibility.

    Returns:
        0-dim tensor holding the loss summed over the N anchors.
    """
    logits = torch.tensor(logits) if not isinstance(logits, torch.Tensor) else logits
    logits_augmented = torch.tensor(logits_augmented) if not isinstance(logits_augmented, torch.Tensor) else logits_augmented

    n = len(logits)
    z = torch.cat((logits, logits_augmented), dim=0)
    norms = torch.linalg.norm(z, dim=1)
    similarity = (torch.mm(z, z.t()) / temperature) / torch.outer(norms, norms)

    # Only the first N rows act as anchors.
    anchors = similarity[:n]
    rows = torch.arange(n, device=similarity.device)
    positives = anchors[rows, rows + n]

    # Exclude each anchor's similarity with itself from the denominator.
    self_pairs = torch.zeros_like(anchors, dtype=torch.bool)
    self_pairs[rows, rows] = True
    log_denominator = torch.logsumexp(anchors.masked_fill(self_pairs, float('-inf')), dim=1)

    # -log(exp(pos) / denom) == log(denom) - pos
    return torch.sum(log_denominator - positives)
257
+
258
+
259
def clustering_loss(logits, k=10, max_iterations=10):
    """Run a short k-means on ``logits`` and return the within-cluster cost.

    Cluster means start from randomly chosen rows. Each iteration assigns
    every row to its nearest mean and recomputes the means; iteration stops
    when an assignment repeats (convergence or a cycle) or after
    ``max_iterations`` rounds.

    Args:
        logits: (N, D) tensor of points.
        k: requested number of clusters; capped at N so that fewer points than
            clusters no longer raises an IndexError.
        max_iterations: maximum number of assignment/update rounds.

    Returns:
        ``(loss, cluster_means)`` where ``loss`` is the sum over points of the
        squared distance to the nearest mean and ``cluster_means`` has shape
        ``(min(k, N), D)``.
    """
    n_points = logits.size(0)
    if n_points == 0:
        # Nothing to cluster: zero loss, no means.
        return logits.sum(), logits[:0]
    k = min(k, n_points)

    # Step 1: initialise the means from k distinct random rows.
    cluster_means = logits[torch.randperm(n_points)[:k]]

    seen_assignments = []
    for _ in range(max_iterations):
        # Step 2: assign each point to its nearest mean.
        distances = torch.cdist(logits, cluster_means, p=2)
        cluster_assignments = torch.argmin(distances, dim=1)

        # An assignment seen before means we converged (same as the previous
        # round) or entered a cycle; either way stop iterating.
        if any(torch.equal(cluster_assignments, prev) for prev in seen_assignments):
            break
        seen_assignments.append(cluster_assignments.clone())

        # Step 3: recompute the means from the current assignment.
        new_cluster_means = torch.zeros_like(cluster_means)
        for i in range(k):
            assigned_points = logits[cluster_assignments == i]
            if assigned_points.size(0) > 0:
                new_cluster_means[i] = assigned_points.mean(dim=0)
            else:
                # Empty cluster: re-seed its mean from a random point.
                new_cluster_means[i] = logits[torch.randint(0, n_points, (1,)).item()]
        cluster_means = new_cluster_means

    # Step 4: cost is the summed squared distance to the nearest mean.
    distances = torch.cdist(logits, cluster_means, p=2)
    min_distances = torch.min(distances, dim=1)[0]
    loss = torch.sum(min_distances ** 2)

    return loss, cluster_means
303
+
304
def normalize_embeddings(embeddings):
    """Scale each row of ``embeddings`` to unit L2 norm along dim 1.

    Uses ``torch.nn.functional.normalize``, which clamps the norm from below
    (eps=1e-12), so an all-zero row stays zero instead of becoming NaN
    through a division by zero.
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)
306
+
307
def variance_regularization(embeddings):
    """Mean squared deviation of ``embeddings`` from their per-column mean.

    The column means are taken over dim 0; the squared deviations are then
    averaged over every element, giving a single scalar tensor.
    """
    centered = embeddings - embeddings.mean(dim=0)
    return (centered ** 2).mean()
311
+
root_gnn_base/batched_dataset.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dgl.dataloading import GraphDataLoader
2
+ from torch.utils.data.sampler import SubsetRandomSampler
3
+ from torch.utils.data.sampler import SequentialSampler
4
+ from dgl.data import DGLDataset
5
+ import torch
6
+ import time
7
+ import os
8
+ import dgl
9
+ from root_gnn_base import utils
10
+
11
def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwargs):
    """Build a shuffling GraphDataLoader over the entries selected by ``mask_fn``.

    Args:
        dataset: dataset handed to ``GraphDataLoader``.
        batch_size: number of graphs per batch.
        mask_fn: callable mapping the dataset to a boolean mask of length
            ``len(dataset)``; ``None`` keeps every entry.
        drop_last: forwarded to the loader.
        **kwargs: accepted but ignored (``num_workers`` is fixed to 0).

    Returns:
        A ``GraphDataLoader`` that samples the selected indices in random order.
    """
    # Identity check: `== None` would dispatch to a custom __eq__ if present.
    if mask_fn is None:
        keep = torch.ones(len(dataset), dtype=torch.bool)
    else:
        keep = mask_fn(dataset)
    selected = torch.arange(len(dataset))[keep]
    dloader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(selected), batch_size=batch_size, drop_last=drop_last, num_workers = 0)
    return dloader
16
+
17
# Dataset of pre-batched (optionally shuffled and padded) graphs. The per-batch
# node/edge counts are stored next to the graphs, so batching information is
# restored when the dataset is loaded back from disk.
class PreBatchedDataset(DGLDataset):
    """Batches ``start_dataset`` once and serves the batches as dataset items.

    Each item is ``(batched_graph, labels, tracking, globals)``. With
    ``save_to_disk`` the batches are written to / read from one file per chunk
    named ``{name}_{chunk}_{suffix}.bin`` under ``save_dir``.
    """

    def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', **kwargs):
        """
        Args:
            start_dataset: source dataset; its ``load()`` is called immediately.
            batch_size: graphs per batch.
            mask_fn: callable returning a boolean mask over ``start_dataset``;
                ``None`` keeps every entry.
            drop_last: forwarded to the data loader.
            save_to_disk: enable the on-disk cache.
            suffix: suffix used in cache file names.
            chunks: number of chunk files the dataset is split into.
            chunkno: chunk processed by this instance (negative: no chunking).
            shuffle: sample the selected entries in random order when True.
            padding_mode: 'STEPS', 'FIXED', 'NODE' or 'NONE'.
            **kwargs: ignored (printed for diagnostics).
        """
        print(f'Unused kwargs: {kwargs}')
        self.start_dataset = start_dataset
        self.start_dataset.load()

        self.batch_size = batch_size
        self.chunks = chunks
        self.chunkno = chunkno
        self.mask_fn = mask_fn
        self.drop_last = drop_last
        self.graphs = []
        self.label = []  # legacy (misspelled) attribute, kept for compatibility
        # Initialise every container __getitem__ reads, so the instance is
        # consistent even before process()/load() has filled them.
        self.labels = []
        self.tracking = []
        self.globals = []
        self.batch_num_nodes = []
        self.batch_num_edges = []
        self.padding_mode = padding_mode
        self.save_to_disk = save_to_disk
        self.shuffle = shuffle
        self.suffix = suffix
        self.current_chunk = None
        self.current_chunk_idx = -1
        super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)

    def process(self):
        """Batch (and pad) the source dataset, restricted to this chunk's range."""
        n_total = len(self.start_dataset)
        first = 0
        last = n_total
        if self.chunks > 1 and self.chunkno >= 0:
            first = int(self.chunkno / self.chunks * n_total)
            last = int((self.chunkno + 1) / self.chunks * n_total)
            print(f'Processing chunk {self.chunkno} of {self.chunks} from {first} to {last} of {len(self.start_dataset)}')

        if self.shuffle:
            # mask_fn defaults to None; treat that as "keep everything"
            # instead of failing on a call to None.
            if self.mask_fn is None:
                selected = torch.ones(n_total, dtype=torch.bool)
            else:
                selected = self.mask_fn(self.start_dataset)
            positions = torch.arange(n_total)
            mask = torch.logical_and(torch.logical_and(selected, positions >= first), positions < last)
            dloader = GraphDataLoader(self.start_dataset, sampler=SubsetRandomSampler(positions[mask]), batch_size=self.batch_size, drop_last=self.drop_last)
        else:
            # Unshuffled processing is meant for inference, so mask_fn is not
            # applied. The chunk window still is: iterating the whole dataset
            # here made every chunk file contain all events, duplicating them
            # once the chunks are concatenated in load().
            dloader = GraphDataLoader(self.start_dataset, sampler=range(first, last), batch_size=self.batch_size, drop_last=self.drop_last)

        self.graphs = []
        self.labels = []
        self.tracking = []
        self.globals = []
        self.batch_num_nodes = []
        self.batch_num_edges = []
        max_edges = 0
        max_nodes = 0
        load_batch_start = time.time()
        for batch, label, tracking, global_feat in dloader:
            max_edges = max(max_edges, batch.num_edges())
            max_nodes = max(max_nodes, batch.num_nodes())
            self.graphs.append(batch)
            self.labels.append(label)
            self.tracking.append(tracking)
            self.globals.append(global_feat)
        load_batch_end = time.time()
        print(f'Loaded {len(self.graphs)} batches in {load_batch_end - load_batch_start} seconds')

        if self.padding_mode == 'STEPS':
            pad_node, pad_edge = utils.pad_size(self.batch_size, max_edges, max_nodes)
        elif self.padding_mode == 'FIXED':
            print('Padding to fixed size. This is currently hardcoded.')
            pad_node = 16000
            pad_edge = 104000
        else:
            # 'NONE', 'NODE' and unknown modes request no batch-level padding.
            pad_node = 0
            pad_edge = 0
        print(f'Max edges: {max_edges}, Max nodes: {max_nodes}, Padding to {pad_edge} edges and {pad_node} nodes')

        pad_start = time.time()
        if self.padding_mode == 'NODE':
            # Pad relative to the largest graph inside each batch.
            for i in range(len(self.graphs)):
                unbatched_g = dgl.unbatch(self.graphs[i])
                max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
                self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes)
                self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
                self.batch_num_edges.append(self.graphs[i].batch_num_edges())
        else:
            for i in range(len(self.graphs)):
                self.graphs[i] = utils.pad_batch(self.graphs[i], pad_edge, pad_node)
                self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
                self.batch_num_edges.append(self.graphs[i].batch_num_edges())
        pad_end = time.time()
        print(f'Padded {len(self.graphs)} batches in {pad_end - pad_start} seconds')

    def save(self):
        """Write this chunk's batches and their label tensors to disk."""
        if not self.save_to_disk:
            return
        graph_path = os.path.join(self.save_dir, f'{self.name}_{self.chunkno}_{self.suffix}.bin')
        print(f'Saving dataset to {graph_path}')
        if len(self.graphs) == 0:
            return
        # torch.stack needs equally shaped entries per key.
        dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.stack(self.labels), 'batch_num_nodes': torch.stack(self.batch_num_nodes), 'batch_num_edges': torch.stack(self.batch_num_edges), 'tracking': torch.stack(self.tracking), 'globals': torch.stack(self.globals)})

    def has_cache(self):
        """Return True only if a cache file exists for every chunk."""
        if not self.save_to_disk:
            return False
        for ch in range(self.chunks):
            graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
            if not os.path.exists(graph_path):
                print(f'Cache file {graph_path} does not exist, not loading from cache.')
                return False
        return True

    def load(self):
        """Load all chunk files and restore each graph's batch sizes."""
        if not self.save_to_disk:
            return
        self.graphs = []
        label_chunks = []
        tracking_chunks = []
        global_chunks = []
        for ch in range(self.chunks):
            graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
            print(f'Loading dataset from {graph_path}')
            graphs, label_dict = dgl.load_graphs(graph_path)
            label_chunks.append(label_dict['labels'])
            tracking_chunks.append(label_dict['tracking'])
            global_chunks.append(label_dict['globals'])
            # Batch sizes are stored separately; re-attach them to each graph.
            for g, bnn, bne in zip(graphs, label_dict['batch_num_nodes'], label_dict['batch_num_edges']):
                g.set_batch_num_nodes(bnn)
                g.set_batch_num_edges(bne)
            self.graphs.extend(graphs)
        self.labels = torch.cat(label_chunks)
        self.tracking = torch.cat(tracking_chunks)
        self.globals = torch.cat(global_chunks)

    def __getitem__(self, idx):
        """Return ``(graph, labels, tracking, globals)`` for batch ``idx``."""
        return self.graphs[idx], self.labels[idx], self.tracking[idx], self.globals[idx]

    def __len__(self):
        """Number of batches held in memory."""
        return len(self.graphs)
146
+
147
# Variant of PreBatchedDataset that reads only the label tensors up front and
# loads the graphs of one chunk file at a time, on demand.
class LazyPreBatchedDataset(PreBatchedDataset):
    """PreBatchedDataset that keeps at most one chunk of graphs in memory."""

    def __init__(self, **kwargs):
        """Accepts the same keyword arguments as ``PreBatchedDataset``."""
        # print(f'Unused kwargs: {kwargs}')
        self.current_chunk = None
        self.current_chunk_idx = -10
        self.label_chunks = []
        super().__init__(**kwargs)

    def load(self):
        """Read the label dictionaries of every chunk file (graphs stay on disk)."""
        if not self.save_to_disk:
            return
        self.label_chunks = []
        for ch in range(self.chunks):
            graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
            print(f'Loading dataset from {graph_path}')
            label_dict = dgl.data.graph_serialize.load_labels_v2(graph_path)
            self.label_chunks.append(label_dict)

    def __getitem__(self, idx):
        """Return ``(graph, labels, tracking, globals)`` for global batch ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            # Resolve negative indices against the whole dataset, not just
            # the first chunk.
            idx += total
        if not 0 <= idx < total:
            # Previously an out-of-range index fell through to a bogus chunk
            # lookup instead of signalling the end of the dataset.
            raise IndexError(f'index {idx} is out of range for dataset of length {total}')

        # Locate the chunk holding idx and the position inside that chunk.
        offset = 0
        for chunk_idx, chunk in enumerate(self.label_chunks):
            count = len(chunk['labels'])
            if idx < offset + count:
                ev_idx = idx - offset
                break
            offset += count

        if chunk_idx != self.current_chunk_idx:
            # print(f"rank {self.rank} getting data from {self.name}_{chunk_idx}_{self.suffix}.bin")
            self.current_chunk, _ = dgl.load_graphs(os.path.join(self.save_dir, f'{self.name}_{chunk_idx}_{self.suffix}.bin'))
            self.current_chunk_idx = chunk_idx

        labels = self.label_chunks[chunk_idx]
        g = self.current_chunk[ev_idx]
        # Batch sizes are stored with the labels; re-attach them to the graph.
        g.set_batch_num_nodes(labels['batch_num_nodes'][ev_idx])
        g.set_batch_num_edges(labels['batch_num_edges'][ev_idx])
        return g, labels['labels'][ev_idx], labels['tracking'][ev_idx], labels['globals'][ev_idx]

    def __len__(self):
        """Total number of batches across all loaded label chunks."""
        return sum(len(chunk['labels']) for chunk in self.label_chunks)
root_gnn_base/custom_scheduler.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import math
3
+ import torch
4
+ from torch import inf
5
+ from functools import wraps, partial
6
+ import warnings
7
+ import weakref
8
+ from collections import Counter
9
+ from bisect import bisect_right
10
+
11
+ from models import GCN
12
+
13
+
14
+
15
+
16
### Code from: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau

# Short alias used by the isinstance checks in the scheduler classes below.
Optimizer = torch.optim.Optimizer

# NOTE(review): this list was copied from torch.optim.lr_scheduler. Confirm that
# every name is actually defined in this module, otherwise
# `from <this module> import *` raises AttributeError.
__all__ = [
    'LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
    'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
    'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler',
]

# Message emitted whenever a scheduler's step() receives an explicit epoch.
EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)
32
+
33
+
34
def update_LR(opt, lr):
    """Set the ``'lr'`` entry of every parameter group of ``opt`` to ``lr``."""
    for group in opt.param_groups:
        group['lr'] = lr
37
+
38
def print_LR(opt):
    """Print the current learning rate of each parameter group of ``opt``."""
    for group in opt.param_groups:
        current = group['lr']
        print(f"LR = {current}")
41
+
42
+ def _check_verbose_deprecated_warning(verbose):
43
+ """Raises a warning when verbose is not the default value."""
44
+ if verbose != "deprecated":
45
+ warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
46
+ "to access the learning rate.", UserWarning)
47
+ return verbose
48
+ return False
49
+
50
class LRScheduler:
    """Base class for epoch-driven learning-rate schedulers (PyTorch copy).

    Subclasses implement ``get_lr`` (and optionally ``_get_closed_form_lr``);
    ``step`` writes the returned values into the optimizer's parameter groups.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):
        """
        Args:
            optimizer: wrapped ``torch.optim.Optimizer``.
            last_epoch: index of the last epoch; ``-1`` starts fresh and records
                each group's current ``lr`` as its ``initial_lr``.
            verbose: deprecated flag; any value other than ``"deprecated"``
                triggers a warning.

        Raises:
            TypeError: if ``optimizer`` is not an ``Optimizer``.
            KeyError: when resuming and a group lacks ``initial_lr``.
        """
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        resuming = last_epoch != -1
        if resuming:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{i}] when resuming an optimizer")
        else:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Hold the optimizer only weakly so the wrapper does not create a
            # reference cycle, and work with the unbound function.
            owner_ref = weakref.ref(method.__self__)
            unbound = method.__func__
            owner_cls = owner_ref().__class__
            del method

            @wraps(unbound)
            def wrapper(*args, **kwargs):
                owner = owner_ref()
                owner._step_count += 1
                rebound = unbound.__get__(owner, owner_cls)
                return rebound(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = _check_verbose_deprecated_warning(verbose)

        self._initial_step()

    def _initial_step(self):
        """Zero both step counters, then perform the first ``step()``."""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Return every entry of ``self.__dict__`` except the optimizer."""
        return {name: value for name, value in self.__dict__.items() if name != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore scheduler state.

        Args:
            state_dict (dict): object returned from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return the learning rates computed by the most recent ``step``."""
        return self._last_lr

    def get_lr(self):
        """Compute the learning rates (chainable form); subclasses override."""
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Print the learning rate of ``group`` when ``is_verbose`` is truthy."""
        if not is_verbose:
            return
        if epoch is None:
            print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
        else:
            epoch_str = ("%.2f" if isinstance(epoch, float) else
                         "%.5d") % epoch
            print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')

    def step(self, epoch=None):
        """Advance the schedule and write the new rates into the optimizer.

        Args:
            epoch: deprecated explicit epoch; when given, a warning is issued
                and ``_get_closed_form_lr`` is used if the subclass defines it.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        # Pair each parameter group with its newly computed rate.
        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
185
+
186
+
187
# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
    """Backwards-compatible name for :class:`LRScheduler`; adds no behaviour."""
191
+
192
+
193
+ class _enable_get_lr_call:
194
+
195
+ def __init__(self, o):
196
+ self.o = o
197
+
198
+ def __enter__(self):
199
+ self.o._get_lr_called_within_step = True
200
+ return self
201
+
202
+ def __exit__(self, type, value, traceback):
203
+ self.o._get_lr_called_within_step = False
204
+
205
+
206
class Dynamic_LR(LRScheduler):
    """Rescale the learning rate when a monitored metric stops improving.

    Adapted from ``ReduceLROnPlateau``. After more than ``patience``
    consecutive non-improving calls to :meth:`step`, every group's learning
    rate is multiplied by ``factor`` and clamped to ``[min_lr, max_lr]``.
    Unlike the original, ``factor`` may be >= 1 and the best value seen so far
    is reset whenever the rate is adjusted.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode the metric is expected
            to decrease, in `max` mode to increase. Default: 'max'.
        factor (float): new_lr = lr * factor. Default: 0.1.
        patience (int): Number of non-improving epochs tolerated before the
            learning rate is adjusted. Default: 10.
        plateau_var (str): Key of the monitored entry in the ``metrics``
            mapping passed to :meth:`step`. Default: "test_auc".
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been adjusted. Default: 0.
        min_lr (float or list): Lower bound on the learning rate, either one
            scalar for all param groups or one value per group. Default: 0.
        max_lr (float or list): Upper bound on the learning rate, either one
            scalar for all param groups or one value per group. Default: 1e-4.
        eps (float): If the difference between new and old lr is not larger
            than eps, the update is ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message on every step.
            Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = Dynamic_LR(optimizer, mode='max', plateau_var='test_auc')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     metrics = validate(...)   # mapping containing 'test_auc'
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(model, metrics)
    """

    def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
                 plateau_var = "test_auc",
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
        # The upstream `factor >= 1.0` check is intentionally disabled so the
        # rate can also be increased (bounded by max_lr).
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Expand each bound independently: previously a list min_lr forced
        # list(max_lr), which failed for the scalar default, and a list max_lr
        # with a scalar min_lr produced a list of lists.
        n_groups = len(optimizer.param_groups)
        self.min_lrs = self._expand_bound(min_lr, n_groups, 'min_lrs')
        self.max_lrs = self._expand_bound(max_lr, n_groups, 'max_lrs')

        self.patience = patience
        self.plateau_var = plateau_var

        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    @staticmethod
    def _expand_bound(bound, n_groups, label):
        """Return ``bound`` as a list with one entry per parameter group."""
        if isinstance(bound, (list, tuple)):
            if len(bound) != n_groups:
                raise ValueError(f"expected {n_groups} {label}, got {len(bound)}")
            return list(bound)
        return [bound] * n_groups

    def _reset(self):
        """Resets the best value, num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, model, metrics, epoch=None):
        """Record one epoch's metric and adjust the learning rate if needed.

        Args:
            model: not used by this scheduler; part of the shared step signature.
            metrics: mapping from which ``metrics[plateau_var]`` is read.
            epoch: deprecated explicit epoch number.
        """
        # convert the monitored value to float, in case it's a zero-dim Tensor
        current = float(metrics[self.plateau_var])
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            if self.verbose:
                print("Model is improving!")
            self.best = current
            self.num_bad_epochs = 0
        else:
            if self.verbose:
                print(f"Model is not improving :( best = {self.best}, current = {current}")
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Scale every group's lr by ``factor`` within its [min, max] bounds."""
        print("Adjusting Learning Rate")
        # Also forgets the best value seen so far.
        self._reset()
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            new_lr = min(new_lr, self.max_lrs[i])
            if abs(old_lr - new_lr) > self.eps:
                param_group['lr'] = new_lr

    def get_last_lr(self):
        """Return the learning rates recorded at the end of the last step."""
        return self._last_lr

    @property
    def in_cooldown(self):
        """True while the cooldown counter is positive."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return True if ``a`` improves on ``best`` under mode and threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the mode settings and set the worst-possible metric value."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return every entry of ``self.__dict__`` except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore state and re-derive the mode-dependent settings."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
392
+
393
class Action_On_Plateau():
    """Watch one entry of a metrics dict and call ``action`` on a plateau.

    A plateau is more than ``patience`` consecutive ``step`` calls in which
    ``metrics[plateau_var]`` did not improve according to ``is_better``.
    The base ``action`` only prints (when verbose); subclasses override it.
    """

    def __init__(self, mode = 'max', patience=10,
                 plateau_var = "test_auc",
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 eps=1e-8, verbose=False):
        self.plateau_var = plateau_var
        self.patience = patience
        self.cooldown = cooldown
        self.verbose = verbose
        self.eps = eps
        self.last_epoch = 0
        # Sets mode / threshold / threshold_mode / mode_worse (and validates them).
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        # Sets best / cooldown_counter / num_bad_epochs.
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, model, metrics, epoch=None):
        """Record one evaluation and trigger ``action`` once patience is exceeded."""
        # float() also handles a zero-dim tensor stored in the metrics dict
        current = float(metrics[self.plateau_var])
        if epoch is not None:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        else:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        improved = self.is_better(current, self.best)
        if improved:
            if self.verbose:
                print("Model is improving!")
            self.best = current
            self.num_bad_epochs = 0
        else:
            if self.verbose:
                print(f"Model is not improving :( best = {self.best}, current = {current}")
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self.action(model, metrics, epoch)

    def action(self, model, metrics, epoch=None):
        """Hook run on a plateau; the base version only reports when verbose."""
        if self.verbose:
            print("Doing my action")

    @property
    def in_cooldown(self):
        """True while the cooldown counter is still positive."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Compare ``a`` against ``best`` under the configured mode/threshold."""
        setting = (self.mode, self.threshold_mode)
        if setting == ('min', 'rel'):
            return a < best * (1. - self.threshold)
        if setting == ('min', 'abs'):
            return a < best - self.threshold
        if setting == ('max', 'rel'):
            return a > best * (self.threshold + 1.)
        # mode == 'max' and threshold_mode == 'abs' (and any other combination)
        return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate and store the comparison settings; raises ValueError if unknown."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        # the worse value for the chosen mode
        self.mode_worse = inf if mode == 'min' else -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
487
+
488
class Partial_Reset(Action_On_Plateau):
    """Plateau watcher whose action calls ``GCN.partial_reset`` on the model."""

    def __init__(self, mode='max', patience=10, plateau_var="test_auc",
                 threshold=0.0001, threshold_mode='rel', cooldown=0,
                 eps=1e-8, verbose=False):
        super().__init__(mode=mode, patience=patience, plateau_var=plateau_var,
                         threshold=threshold, threshold_mode=threshold_mode,
                         cooldown=cooldown, eps=eps, verbose=verbose)

    def action(self, model, metrics, epoch=None):
        """Reset the model partially, then restart plateau tracking with a cooldown."""
        print("Partial Reset!!")
        GCN.partial_reset(model)
        # _reset() forgets the best value seen so far and zeroes both counters ...
        self._reset()
        # ... after which the configured cooldown is armed.
        self.cooldown_counter = self.cooldown
        self.num_bad_epochs = 0
503
+
504
+
505
class Full_Reset(Action_On_Plateau):
    """Plateau watcher whose action calls ``GCN.full_reset`` on the model."""

    def __init__(self, mode='max', patience=10, plateau_var="test_auc",
                 threshold=0.0001, threshold_mode='rel', cooldown=0,
                 eps=1e-8, verbose=False):
        super().__init__(mode=mode, patience=patience, plateau_var=plateau_var,
                         threshold=threshold, threshold_mode=threshold_mode,
                         cooldown=cooldown, eps=eps, verbose=verbose)

    def action(self, model, metrics, epoch=None):
        """Reset the model fully, then restart plateau tracking with a cooldown."""
        print("Full Reset!!")
        GCN.full_reset(model)
        # _reset() forgets the best value seen so far and zeroes both counters ...
        self._reset()
        # ... after which the configured cooldown is armed.
        self.cooldown_counter = self.cooldown
        self.num_bad_epochs = 0
520
+
521
class Dynamic_LR_AND_Partial_Reset():
    """Combine a ``Dynamic_LR`` scheduler with a ``Partial_Reset`` watcher.

    ``reset_patience`` and ``reset_plateau_var`` configure the reset watcher
    and default to the learning-rate scheduler's ``patience`` and
    ``plateau_var`` when left as ``None``.
    """

    def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
                 plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):

        # Identity test: `== None` would invoke a custom __eq__ on the argument.
        if reset_patience is None:
            reset_patience = patience
        if reset_plateau_var is None:
            reset_plateau_var = plateau_var

        self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience=patience,
                                     plateau_var=plateau_var, threshold=threshold,
                                     threshold_mode=threshold_mode, cooldown=cooldown,
                                     min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)

        # NOTE(review): `verbose` is not forwarded to the reset watcher; kept as in
        # the original, confirm whether that is intentional.
        self.partial_reset = Partial_Reset(mode=mode, patience=reset_patience,
                                           plateau_var=reset_plateau_var,
                                           threshold=threshold, threshold_mode=threshold_mode,
                                           cooldown=cooldown, eps=eps)

    def step(self, model, metrics, epoch=None):
        """Step the learning-rate scheduler first, then the reset watcher."""
        self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
        self.partial_reset.step(model=model, metrics=metrics, epoch=epoch)
543
+
544
class Dynamic_LR_AND_Full_Reset():
    """Combine a ``Dynamic_LR`` scheduler with a ``Full_Reset`` watcher.

    ``reset_patience`` and ``reset_plateau_var`` configure the reset watcher
    and default to the learning-rate scheduler's ``patience`` and
    ``plateau_var`` when left as ``None``.
    """

    def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
                 plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):

        # Identity test: `== None` would invoke a custom __eq__ on the argument.
        if reset_patience is None:
            reset_patience = patience
        if reset_plateau_var is None:
            reset_plateau_var = plateau_var

        self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience=patience,
                                     plateau_var=plateau_var, threshold=threshold,
                                     threshold_mode=threshold_mode, cooldown=cooldown,
                                     min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)

        # NOTE(review): `verbose` is not forwarded to the reset watcher; kept as in
        # the original, confirm whether that is intentional.
        self.full_reset = Full_Reset(mode=mode, patience=reset_patience,
                                     plateau_var=reset_plateau_var,
                                     threshold=threshold, threshold_mode=threshold_mode,
                                     cooldown=cooldown, eps=eps)

    def step(self, model, metrics, epoch=None):
        """Step the learning-rate scheduler first, then the reset watcher."""
        self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
        self.full_reset.step(model=model, metrics=metrics, epoch=epoch)
root_gnn_base/dataset.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dgl.data import DGLDataset
2
+ import dgl
3
+ import ROOT
4
+ import torch
5
+ import os
6
+ import glob
7
+ import time
8
+ import numpy as np
9
+ from root_gnn_base import utils
10
+
11
def node_features_from_tree(ch, node_branch_names, node_branch_types, node_feature_scales):
    """Build the per-node feature matrix for the current tree entry.

    Args:
        ch: object whose attributes are the branch values of the current entry.
        node_branch_names: one entry per feature column. Each entry is either
            the string 'CALC_E' (column 0 * cosh(column 1)), the string
            'NODE_TYPE' (index of the node type), or a sequence with one item
            per node type: a branch name, a numeric constant, or 'CALC_E'.
        node_branch_types: per node type, 'single' (one node) or 'vector'
            (one node per element of the branch).
        node_feature_scales: multiplied onto the stacked feature matrix.

    Returns:
        (features, lengths): the scaled matrix with one row per node, and the
        number of nodes contributed by each node type.
    """
    # Node multiplicity per type is taken from the first feature's branches.
    lengths = []
    for name, kind in zip(node_branch_names[0], node_branch_types):
        if kind == 'single':
            lengths.append(1)
        elif kind == 'vector':
            lengths.append(len(getattr(ch, name)))
        else:
            print('Unknown node branch type: {}'.format(kind))

    columns = []
    for spec in node_branch_names:
        if spec == 'CALC_E':
            columns.append(columns[0]*torch.cosh(columns[1]))
        elif spec == 'NODE_TYPE':
            type_ids = [type_id for type_id, count in enumerate(lengths) for _ in range(count)]
            columns.append(torch.tensor(type_ids))
        else:
            values = []
            for position, (count, name, kind) in enumerate(zip(lengths, spec, node_branch_types)):
                if isinstance(name, (int, float, complex)):
                    # constant padding value for every node of this type
                    values.extend([name,]*count)
                elif name == 'CALC_E':
                    # slice of the already-built columns belonging to this node type
                    first = sum(lengths[:position])
                    last = sum(lengths[:position+1])
                    values.extend(columns[0][first:last]*torch.cosh(columns[1][first:last]))
                elif kind == 'single':
                    values.append(getattr(ch, name))
                elif kind == 'vector':
                    values.extend(getattr(ch, name))
            columns.append(torch.tensor(values))
    return torch.stack(columns, dim=1) * node_feature_scales, lengths
47
+
48
def full_connected_graph(n_nodes, self_loops=True):
    """Return a DGL graph with an edge for every ordered node pair.

    With ``self_loops=False`` the i -> i edges are dropped, except when there
    is a single node, where the one self loop is kept.
    """
    pair_ids = np.arange(n_nodes*n_nodes)
    senders = pair_ids // n_nodes
    receivers = pair_ids % n_nodes
    if n_nodes > 1 and not self_loops:
        off_diagonal = senders != receivers
        senders, receivers = senders[off_diagonal], receivers[off_diagonal]
    return dgl.graph((senders, receivers))
58
+
59
def check_selection(ch, selection):
    """Evaluate one cut on the current entry.

    Args:
        ch: object exposing the branch values as attributes.
        selection: ``(var, cut, op)`` where ``op`` is one of
            '>', '<', '==', '>=', '<=', '!='.

    Returns:
        The result of comparing ``getattr(ch, var)`` with ``cut``.

    Raises:
        ValueError: if ``op`` is not a supported operator. Previously an
            unknown operator returned None, which silently rejected every
            event.
    """
    import operator  # local import keeps this block self-contained

    comparisons = {
        '>': operator.gt,
        '<': operator.lt,
        '==': operator.eq,
        '>=': operator.ge,
        '<=': operator.le,
        '!=': operator.ne,
    }
    var, cut, op = selection
    if op not in comparisons:
        raise ValueError('Unknown selection operator {!r} for variable {!r}'.format(op, var))
    return comparisons[op](getattr(ch, var), cut)
67
+
68
def check_selections(ch, selections):
    """Return True only if every selection passes; stops at the first failure."""
    return all(check_selection(ch, selection) for selection in selections)
73
+
74
+ #Base dataset class for making graphs from ROOT ntuples.
75
+ class RootDataset(DGLDataset):
76
def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
             selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
    """Store the dataset configuration, then hand over to the base constructor.

    Args:
        label: label attached to every graph built by this dataset.
        file_names: glob pattern, or list of patterns, relative to ``raw_dir``.
        node_branch_names / node_branch_types / node_feature_scales: passed to
            ``node_features_from_tree``; the scales are converted to a float tensor.
        selections: list of ``(var, cut, op)`` cuts.
        save: whether results are written to / read from ``save_dir``.
        tree_name: name of the tree to chain.
        fold_var, weight_var: stored as tracking entries 0 and 1; a missing
            ``weight_var`` is replaced by the constant 1.
        chunks: number of pieces the entries are split into.
        process_chunks: chunk indices handled by this instance (default: all).
        global_features, tracking_info: extra per-event branches to record.
        **kwargs: reported as unused and otherwise ignored.
    """
    print(f'Unused args while creating RootDataset: {kwargs}')
    self.label = label
    self.counts = []
    # The list arguments are copied: the signature uses shared `[]` defaults and
    # tracking_info is modified below, so working on the caller's (or the
    # default) object would grow it by two entries on every instantiation.
    self.selections = list(selections)
    self.save_to_disk = save
    self.file_names = file_names
    self.node_branch_names = node_branch_names
    self.node_branch_types = node_branch_types
    self.node_feature_scales = torch.tensor([float(sf) for sf in node_feature_scales])
    self.tree_name = tree_name
    self.fold_var = fold_var
    if weight_var is None:
        weight_var = 1
    self.tracking_info = [fold_var, weight_var] + list(tracking_info)
    self.global_features = list(global_features)
    self.chunks = chunks
    self.process_chunks = process_chunks
    if self.process_chunks is None:
        self.process_chunks = list(range(self.chunks))
    self.times = [0, 0]
    # NOTE(review): the base constructor presumably triggers load()/process(),
    # which is why every attribute is set before this call - confirm against DGL.
    super().__init__(name=name, raw_dir=raw_dir, save_dir=save_dir)
101
+
102
def get_list_of_branches(self):
    """Collect every branch name this dataset reads from the tree.

    Gathers string entries of the per-type node feature lists (skipping the
    computed 'CALC_E' marker), string global features, string tracking
    entries and the variable of each selection. Duplicates are not removed.
    """
    branches = [
        branch
        for feat in self.node_branch_names
        if isinstance(feat, list)
        for branch in feat
        if isinstance(branch, str) and branch != 'CALC_E'
    ]
    branches.extend(feat for feat in self.global_features if isinstance(feat, str))
    branches.extend(feat for feat in self.tracking_info if isinstance(feat, str))
    branches.extend(selection[0] for selection in self.selections)
    return branches
120
+
121
def make_graph(self, ch):
    """Build a fully connected graph (no self loops) for the current entry.

    Rows whose first feature is exactly 0 are dropped before the graph is
    built. Time spent in the two stages is accumulated in ``self.times``.
    """
    started = time.time()
    node_feats, _ = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
    node_feats = node_feats[node_feats[:,0] != 0]
    feats_done = time.time()
    graph = full_connected_graph(node_feats.shape[0], self_loops=False)
    graph.ndata['features'] = node_feats
    graph_done = time.time()
    self.times[0] += feats_done - started
    self.times[1] += graph_done - feats_done
    return graph
132
+
133
def process(self):
    """Read the input trees, build one graph per selected entry, save per chunk.

    Files matching ``self.file_names`` under ``self.raw_dir`` are chained, only
    the needed branches are enabled, and the entry range is split into
    ``self.chunks`` pieces, of which those listed in ``self.process_chunks``
    are processed and handed to ``save_chunk``.

    NOTE(review): the nesting below was reconstructed from a flattened listing;
    in particular the position of the bare ``return`` (after the chunk loop)
    is an assumption - confirm against the repository.
    """
    times = [0, 0, 0]
    oldtime = time.time()
    if isinstance(self.file_names, str):
        self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
    else:
        self.files = []
        for file_name in self.file_names:
            self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
    self.chain = ROOT.TChain(self.tree_name)

    if len(self.files) == 0:
        # NOTE(review): os.path.join receives self.file_names directly, which
        # looks like it would raise when file_names is a list - confirm.
        print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
    for file in self.files:
        # Each Add() is bracketed by a two-minute timeout from utils.
        utils.set_timeout(60*2)
        self.chain.Add(file)
        utils.unset_timeout()
    # Disable every branch, then re-enable only the ones that are read.
    branches = self.get_list_of_branches()
    self.chain.SetBranchStatus('*', 0)
    for branch in branches:
        self.chain.SetBranchStatus(branch, 1)
    newtime = time.time()
    times[0] += newtime - oldtime
    chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
    chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]

    self.graph_chunks = []
    self.label_chunks = []
    self.tracking_chunks = []
    self.global_chunks = []
    # chunk_id counts the kept chunks in ascending order; save_chunk maps it
    # through self.process_chunks.
    # NOTE(review): that mapping only lines up if process_chunks is sorted - confirm.
    chunk_id = -1
    for chunk in chunks:
        chunk_id += 1
        graphs = []
        labels = []
        tracking = []
        globals = []
        for ientry in chunk:
            if (ientry % 10000 == 0):
                print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
            self.chain.GetEntry(ientry)
            passed = True
            for selection in self.selections:
                if not check_selection(self.chain, selection):
                    passed = False
                    # remaining selections are still evaluated; the result stays False
                    continue
            oldtime = newtime
            newtime = time.time()
            times[1] += newtime - oldtime
            if passed:
                graphs.append(self.make_graph(self.chain))
                labels.append( self.label )
                tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
                globals.append(torch.zeros(len(self.global_features)))
                # String entries are read from the tree; anything else is stored as a constant.
                for i_ti, tr_branch in enumerate(self.tracking_info):
                    if isinstance(tr_branch, str):
                        tracking[-1][i_ti] = getattr(self.chain, tr_branch)
                    else:
                        tracking[-1][i_ti] = tr_branch
                for i_gl, gl_branch in enumerate(self.global_features):
                    globals[-1][i_gl] = getattr(self.chain, gl_branch)
            oldtime = newtime
            newtime = time.time()
            times[2] += newtime - oldtime

        labels = torch.tensor(labels)
        # NOTE(review): torch.stack on an empty list (no entry of the chunk
        # passed the selections) presumably raises - confirm.
        tracking = torch.stack(tracking)
        globals = torch.stack(globals)

        # self.labels = labels
        # self.tracking = tracking
        # self.global_features = globals
        # self.graphs = graphs

        self.save_chunk(chunk_id, graphs, labels, tracking, globals)

    # Everything below this return is unreachable: results live only in the
    # files written by save_chunk and the *_chunks lists stay empty.
    return
    self.graphs = self.graph_chunks[0]
    for chunk in self.graph_chunks[1:]:
        self.graphs += chunk
    self.labels = torch.cat(self.label_chunks)
    self.tracking = torch.cat(self.tracking_chunks)
    self.global_features = torch.cat(self.global_chunks)
    print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
    print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
218
+
219
def save(self):
    """Write the in-memory graphs and label tensors to ``self.save_dir``.

    Does nothing when saving is disabled. With a single chunk one
    ``<name>.bin`` file is written; otherwise one ``<name>_<chunk>.bin`` file
    per entry of ``self.process_chunks``.
    """
    if not self.save_to_disk:
        return
    graph_path = os.path.join(self.save_dir, self.name + '.bin')
    if self.chunks != 1:
        print(len(self.graph_chunks))
        for i in range(len(self.process_chunks)):
            print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
            chunk_path = str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin')
            chunk_graphs = self.graph_chunks[i]
            chunk_labels = {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]}
            dgl.save_graphs(chunk_path, chunk_graphs, chunk_labels)
        return
    # print(len(self.graphs))
    # print(len(self.labels))
    # print(len(self.tracking))
    # print(len(self.globals))
    print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
    all_graphs = self.graphs
    dgl.save_graphs(str(graph_path), all_graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
236
+
237
def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
    """Write one processed chunk to ``<name>_<process_chunks[chunk_id]>.bin``.

    Does nothing when saving is disabled.
    """
    if not self.save_to_disk:
        return
    graph_path = os.path.join(self.save_dir, self.name + '.bin')
    print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
    target = str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin')
    label_dict = {'labels': labels, 'tracking': tracking, 'global': globals}
    dgl.save_graphs(target, graphs, label_dict)
243
+
244
def has_cache(self):
    """Report whether the saved file(s) for this dataset already exist.

    Always False when saving is disabled. With one chunk ``<name>.bin`` is
    checked; otherwise every ``<name>_<chunk>.bin`` in ``self.process_chunks``
    must exist.
    """
    print(f'Checking for cache of {self.name}')
    if not self.save_to_disk:
        print('Skipping load.')
        return False
    if self.chunks == 1:
        return os.path.exists(os.path.join(self.save_dir, self.name + '.bin'))
    for chunk_no in self.process_chunks:
        graph_path = os.path.join(self.save_dir, self.name + f'_{chunk_no}.bin')
        if not os.path.exists(graph_path):
            print(f'File {graph_path} does not exist, processing.')
            return False
    return True
259
+
260
def load(self):
    """Load graphs and label tensors from ``self.save_dir`` into memory.

    With a single chunk ``<name>.bin`` is read. Otherwise every chunk listed
    in ``self.process_chunks`` is read from ``<name>_<chunk>.bin`` and the
    results are concatenated; a chunk that fails to load is reported and
    skipped.
    """
    if self.chunks == 1:
        print(f'Loading dataset from {os.path.join(self.save_dir, self.name + ".bin")}')
        graphs, label_dict = dgl.load_graphs(os.path.join(self.save_dir, self.name + '.bin'))
        self.graphs = graphs
        self.labels = label_dict['labels']
        self.tracking = label_dict['tracking']
        self.global_features = label_dict['global']
        return

    self.graphs = []
    self.labels = []
    self.tracking = []
    self.global_features = []
    # Iterate over the configured chunk ids themselves. Looping over
    # range(self.chunks) and indexing process_chunks raised (and swallowed)
    # an IndexError per missing position when only a subset was configured.
    for chunk_no in self.process_chunks:
        path = os.path.join(self.save_dir, self.name + f'_{chunk_no}.bin')
        try:
            print(f'Loading dataset from {path}')
            graphs, label = dgl.load_graphs(path)
            # Read every field before appending anything, so a broken chunk
            # cannot leave the four containers misaligned.
            chunk_labels = label['labels']
            chunk_tracking = label['tracking']
            chunk_global = label['global']
        except Exception as e:
            print(e)
            continue
        self.graphs.extend(graphs)
        self.labels.append(chunk_labels)
        self.tracking.append(chunk_tracking)
        self.global_features.append(chunk_global)
    self.labels = torch.cat(self.labels)
    self.tracking = torch.cat(self.tracking)
    self.global_features = torch.cat(self.global_features)
286
+
287
+ def __getitem__(self, idx):
288
+ return self.graphs[idx], self.labels[idx], self.tracking[idx], self.global_features[idx]
289
+
290
+ def __len__(self):
291
+ return len(self.graphs)
292
+
293
+ #Dataset with edge features added (deta, dphi, dR)
294
class EdgeDataset(RootDataset):
    """RootDataset whose graphs also carry edge features (deta, dphi, dR)."""

    def make_graph(self, ch):
        """Build the base graph and attach per-edge differences of node
        feature columns 1 and 2, plus their quadrature sum."""
        graph = super().make_graph(ch)
        src, dst = graph.edges()
        nodes = graph.ndata['features']
        deta = nodes[src, 1] - nodes[dst, 1]
        dphi = nodes[src, 2] - nodes[dst, 2]
        # Fold the difference back by one turn when it leaves [-pi, pi].
        dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
        dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
        dR = torch.sqrt(deta**2 + dphi**2)
        graph.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
        return graph
305
+
306
class tHbbEdgeDataset(RootDataset):
    """Edge-feature dataset that can drop selected nodes before building the graph.

    ``exclude_branches`` holds one entry per node type ("sector"): either
    ``None`` (keep all nodes of that type) or a list whose items are an int
    (index within the sector to drop) or a str (branch holding that index).
    """

    def __init__(self, exclude_branches=None, **kwargs):
        self.exclude_branches = exclude_branches
        super().__init__(**kwargs)

    def get_list_of_branches(self):
        """Base branch list plus every branch named in ``exclude_branches``."""
        br = super().get_list_of_branches()
        # `or []`: the constructor default is None, which cannot be iterated.
        for sector in self.exclude_branches or []:
            if sector is None:
                continue
            br.extend(excl for excl in sector if isinstance(excl, str))
        return br

    def make_graph(self, ch):
        """Build a fully connected graph over the non-excluded nodes and attach
        (deta, dphi, dR) edge features."""
        features, lengths = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)

        include_mask = torch.ones(features.shape[0], dtype=torch.bool)
        node_idx = 0  # row offset of the current sector's first node
        for sector, length in zip(self.exclude_branches or [], lengths):
            if sector is not None:
                for excl in sector:
                    if isinstance(excl, int):
                        include_mask[excl + node_idx] = False
                    elif isinstance(excl, str):
                        # Read from the entry passed in rather than self.chain.
                        # NOTE(review): a negative branch value would index from
                        # the end of the mask - confirm the branch convention.
                        include_mask[getattr(ch, excl) + node_idx] = False
            # Advance for every sector. Previously the offset only moved past
            # None sectors, so sectors following an excluded one were mis-indexed.
            node_idx += length

        kept = features[include_mask]
        g = full_connected_graph(kept.shape[0], self_loops=False)
        g.ndata['features'] = kept

        u, v = g.edges()
        deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
        dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
        # Fold the difference back by one turn when it leaves [-pi, pi].
        dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
        dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
        dR = torch.sqrt(deta**2 + dphi**2)
        g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
        return g
346
+
347
class LazyDataset(EdgeDataset):
    """EdgeDataset that keeps only ``buffer_size`` chunk files in memory.

    ``load`` reads just the label dictionaries to learn each chunk's length;
    ``__getitem__`` loads the chunk that holds the requested entry on demand
    and caches it in a small ring buffer.
    """

    def __init__(self, buffer_size = 2, **kwargs):
        self.buffer = [None,] * buffer_size          # loaded (graphs, label_dict) per slot
        self.buffer_ptr = 0                          # slot written most recently
        self.get_item_calls = 0
        self.buffer_indices = [-1,] * buffer_size    # chunk position held by each slot
        super().__init__(**kwargs)

    def __getitem__(self, idx):
        """Return ``(graph, label, tracking, global)`` for global index ``idx``.

        Raises:
            IndexError: if ``idx`` is outside the dataset. Previously such an
                index fell through to a lookup of chunk -1.
        """
        self.get_item_calls += 1
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f'index {idx} is out of range for a dataset of {total} entries')

        # Walk the per-chunk counts to find the chunk and the offset inside it.
        seen = 0
        for chunk_idx, count in enumerate(self.counts):
            if idx < seen + count:
                ev_idx = idx - seen
                break
            seen += count

        buf_idx = self.buffer_get(chunk_idx)
        graphs, label_dict = self.buffer[buf_idx]
        if ev_idx >= len(graphs):
            # Diagnostics for a chunk file shorter than its recorded count.
            print(f'Getting event {ev_idx} from chunk {chunk_idx} from buffer {buf_idx}. Calls: {self.get_item_calls}')
            print(len(self.buffer))
            print(self.counts)
            print(len(graphs))
        return graphs[ev_idx], label_dict['labels'][ev_idx], label_dict['tracking'][ev_idx], label_dict['global'][ev_idx]

    def buffer_get(self, buffer_idx):
        """Return the buffer slot holding chunk position ``buffer_idx``, loading
        the chunk file into the next ring slot if it is not cached."""
        if buffer_idx in self.buffer_indices:
            return self.buffer_indices.index(buffer_idx)
        # Map the position to the chunk id used in the file name, as load()
        # does; the two only coincide when process_chunks is 0..n-1.
        path = os.path.join(self.save_dir, self.name + f'_{self.process_chunks[buffer_idx]}.bin')
        print(f'Loading dataset from {path}', flush=True)
        self.buffer_ptr = (self.buffer_ptr + 1) % len(self.buffer)
        self.buffer[self.buffer_ptr] = dgl.load_graphs(path)
        self.buffer_indices[self.buffer_ptr] = buffer_idx
        return self.buffer_ptr

    def load(self):
        """Read each configured chunk's labels to record its length and tracking.

        Any failure is printed and leaves ``counts`` / ``tracking`` in their
        partially filled state, as before.
        """
        self.counts = []
        self.tracking = []
        try:
            # Iterate over the chunk ids themselves; indexing process_chunks with
            # range(self.chunks) failed whenever only a subset was configured.
            for chunk_no in self.process_chunks:
                path = os.path.join(self.save_dir, self.name + f'_{chunk_no}.bin')
                print(f'Loading dataset from {path}')
                l = dgl.data.graph_serialize.load_labels_v2(path)
                self.counts.append(len(l['tracking']))
                self.tracking.append(l['tracking'])
            self.tracking = torch.cat(self.tracking)
        except Exception as e:
            print(e)

    def __len__(self):
        """Total number of entries over all recorded chunks."""
        return sum(self.counts)
401
+
402
+ class MultiLabelDataset(EdgeDataset):
403
def __init__(self, **kwargs):
    """Forward all keyword arguments unchanged to the parent constructor."""
    super().__init__(**kwargs)
405
+
406
def get_list_of_branches(self):
    """Parent branch list plus the branches referenced by ``self.label``.

    A label entry contributes its own name when it is a string and its
    'branch' value when it is a dict; other entries contribute nothing.
    """
    names = super().get_list_of_branches()
    for entry in self.label:
        if isinstance(entry, str):
            names.append(entry)
        elif isinstance(entry, dict):
            names.append(entry['branch'])
    return names
414
+
415
def get_label(self, ch):
    """Assemble the label tensor for the current entry.

    Each item of ``self.label`` yields one value: a string is read from ``ch``,
    a dict is read from ``ch`` via its 'branch' and multiplied by its 'scale',
    and a number is used as a constant. Other items are skipped.
    """
    values = []
    for spec in self.label:
        if isinstance(spec, str):
            values.append((getattr(ch, spec)))
        elif isinstance(spec, dict):
            values.append(getattr(ch, spec['branch'])*float(spec['scale']))
        elif isinstance(spec, (float, int)):
            values.append(spec)

    return torch.tensor(values)
426
+
427
def process(self):
    """Read the input trees, build graphs with per-entry labels, save per chunk.

    Same flow as the single-label ``process`` visible earlier in this file,
    except that the label of each selected entry comes from ``self.get_label``
    and the chunk's labels are stacked rather than converted from a list.

    NOTE(review): the nesting below was reconstructed from a flattened listing;
    in particular the position of the bare ``return`` (after the chunk loop)
    is an assumption - confirm against the repository.
    """
    times = [0, 0, 0]
    oldtime = time.time()
    if isinstance(self.file_names, str):
        self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
    else:
        self.files = []
        for file_name in self.file_names:
            self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
    self.chain = ROOT.TChain(self.tree_name)
    if len(self.files) == 0:
        # NOTE(review): os.path.join receives self.file_names directly, which
        # looks like it would raise when file_names is a list - confirm.
        print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
    for file in self.files:
        # Each Add() is bracketed by a two-minute timeout from utils.
        utils.set_timeout(60*2)
        self.chain.Add(file)
        utils.unset_timeout()
    # Disable every branch, then re-enable only the ones that are read.
    branches = self.get_list_of_branches()
    self.chain.SetBranchStatus('*', 0)
    for branch in branches:
        self.chain.SetBranchStatus(branch, 1)
    newtime = time.time()
    times[0] += newtime - oldtime
    chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
    chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
    self.graph_chunks = []
    self.label_chunks = []
    self.tracking_chunks = []
    self.global_chunks = []
    # chunk_id counts the kept chunks in ascending order; save_chunk maps it
    # through self.process_chunks.
    # NOTE(review): that mapping only lines up if process_chunks is sorted - confirm.
    chunk_id = -1
    for chunk in chunks:
        chunk_id += 1
        graphs = []
        labels = []
        tracking = []
        globals = []
        for ientry in chunk:
            if (ientry % 10000 == 0):
                print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
            self.chain.GetEntry(ientry)
            passed = True
            for selection in self.selections:
                if not check_selection(self.chain, selection):
                    passed = False
                    # remaining selections are still evaluated; the result stays False
                    continue
            oldtime = newtime
            newtime = time.time()
            times[1] += newtime - oldtime
            if passed:
                graphs.append(self.make_graph(self.chain))
                labels.append(self.get_label(self.chain))
                tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
                globals.append(torch.zeros(len(self.global_features)))
                # String entries are read from the tree; anything else is stored as a constant.
                for i_ti, tr_branch in enumerate(self.tracking_info):
                    if isinstance(tr_branch, str):
                        tracking[-1][i_ti] = getattr(self.chain, tr_branch)
                    else:
                        tracking[-1][i_ti] = tr_branch
                for i_gl, gl_branch in enumerate(self.global_features):
                    globals[-1][i_gl] = getattr(self.chain, gl_branch)
            oldtime = newtime
            newtime = time.time()
            times[2] += newtime - oldtime

        # NOTE(review): torch.stack on an empty list (no entry of the chunk
        # passed the selections) presumably raises - confirm.
        labels = torch.stack(labels)
        self.save_chunk(chunk_id, graphs, labels, torch.stack(tracking), torch.stack(globals))
        # self.graph_chunks.append(graphs)
        # self.label_chunks.append(labels)
        # self.tracking_chunks.append(torch.stack(tracking))
        # self.global_chunks.append(torch.stack(globals))
        # self.counts.append(len(graphs))
    # Everything below this return is unreachable: results live only in the
    # files written by save_chunk and the *_chunks lists stay empty.
    return
    self.graphs = self.graph_chunks[0]
    for chunk in self.graph_chunks[1:]:
        self.graphs += chunk

    self.labels = torch.cat(self.label_chunks)
    self.tracking = torch.cat(self.tracking_chunks)
    self.global_features = torch.cat(self.global_chunks)
    print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
    print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
507
+
508
class LazyMultiLabelDataset(MultiLabelDataset, LazyDataset):
    """Multi-label dataset with LazyDataset's chunk buffering."""

    def __init__(self, buffer_size = 2, **kwargs):
        # Enter the constructor chain at LazyDataset explicitly so the buffer
        # attributes are created before the remaining initialisers run.
        LazyDataset.__init__(self, buffer_size=buffer_size, **kwargs)
511
+
512
class MultiLabeltHbbDataset(MultiLabelDataset, tHbbEdgeDataset):
    """Multi-label dataset combined with tHbbEdgeDataset's node exclusion."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_list_of_branches(self):
        """Inherited branch list with the string entries of every non-None
        sector of ``self.exclude_branches`` appended."""
        names = super().get_list_of_branches()
        for sector in self.exclude_branches:
            if sector == None:
                continue
            names.extend(excl for excl in sector if type(excl) == str)
        return names
525
+
526
+
527
+ class AugmentedDataset(RootDataset):
528
+
529
def __init__(self, seed = 2, feature_index = None, node_mapping = None, **kwargs):
    """Seed numpy's global RNG, store the column/type lookups, init the parent.

    Args:
        seed: passed to ``np.random.seed``.
        feature_index: maps feature names to node-feature column indices;
            defaults to pt/eta/phi/energy/btag/charge/node_type = 0..6.
        node_mapping: maps object names to their node-type ids; defaults to
            jet/ele/mu/ph/MET = 0..4.
        **kwargs: forwarded to the parent constructor.
    """
    self.seed = seed
    np.random.seed(seed)
    if feature_index is None:
        feature_index = {"pt": 0, "eta": 1, "phi": 2, "energy": 3, "btag": 4, "charge": 5, "node_type": 6}
    if node_mapping is None:
        node_mapping = {"jet": 0, "ele": 1, "mu": 2, "ph": 3, "MET": 4}
    # Assign unconditionally: previously a caller-supplied mapping was never
    # stored, so the attribute was missing whenever a custom one was given.
    self.feature_index = feature_index
    self.node_mapping = node_mapping
    super().__init__(**kwargs)
537
+
538
def detector_noise(self, node_features):
    """Draw a random smearing term for every node.

    Returns an array shaped like ``node_features`` that is zero everywhere
    except the pt column of jet/electron/muon nodes and the energy column of
    photon nodes, which receive zero-mean normal noise. The width depends on
    the node's |eta| band and its pt (or energy); outside the listed bands
    the width is 0.

    NOTE(review): numpy functions are applied to what is presumably a torch
    tensor taken from the graph; this relies on numpy/torch interop - confirm.
    NOTE(review): the numeric coefficients are taken as-is; their source and
    units are not visible here.
    """
    noise = np.zeros_like(node_features)

    node_types = node_features[:, self.feature_index["node_type"]]
    pts = node_features[:, self.feature_index["pt"]]
    etas = node_features[:, self.feature_index["eta"]]
    energies = node_features[:, self.feature_index["energy"]]

    # Noise calculation for jets
    jet_mask = (node_types == self.node_mapping["jet"])
    jet_pts = pts[jet_mask]
    jet_etas = etas[jet_mask]

    if (jet_mask.sum() > 0):
        # Jets with pt <= 0.1 get no noise; otherwise the width is chosen by |eta| band.
        jet_resolutions = np.where(
            jet_pts <= 0.1, 0.0,
            np.where(
                np.abs(jet_etas) <= 0.5, np.sqrt(0.06**2 + jet_pts**2 * 1.3e-3**2),
                np.where(
                    np.abs(jet_etas) <= 1.5, np.sqrt(0.10**2 + jet_pts**2 * 1.7e-3**2),
                    np.where(
                        np.abs(jet_etas) <= 2.5, np.sqrt(0.25**2 + jet_pts**2 * 3.1e-3**2),
                        0.0
                    )
                )
            )
        )
        noise[jet_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=jet_resolutions)

    # Noise calculation for electrons
    ele_mask = (node_types == self.node_mapping["ele"])
    ele_pts = pts[ele_mask]
    ele_etas = etas[ele_mask]

    if (ele_mask.sum() > 0):
        ele_resolutions = np.where(
            np.abs(ele_etas) <= 0.5, np.sqrt(0.03**2 + ele_pts**2 * 1.3e-3**2),
            np.where(
                np.abs(ele_etas) <= 1.5, np.sqrt(0.05**2 + ele_pts**2 * 1.7e-3**2),
                np.where(
                    np.abs(ele_etas) <= 2.5, np.sqrt(0.15**2 + ele_pts**2 * 3.1e-3**2),
                    0.0
                )
            )
        )
        noise[ele_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=ele_resolutions)

    # Noise calculation for muons
    mu_mask = (node_types == self.node_mapping["mu"])
    mu_pts = pts[mu_mask]
    mu_etas = etas[mu_mask]

    if (mu_mask.sum() > 0):
        mu_resolutions = np.where(
            np.abs(mu_etas) <= 0.5, np.sqrt(0.01**2 + mu_pts**2 * 1.0e-4**2),
            np.where(
                np.abs(mu_etas) <= 1.5, np.sqrt(0.015**2 + mu_pts**2 * 1.5e-4**2),
                np.where(
                    np.abs(mu_etas) <= 2.5, np.sqrt(0.025**2 + mu_pts**2 * 3.5e-4**2),
                    0.0
                )
            )
        )
        noise[mu_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=mu_resolutions)

    # Noise calculation for photons
    ph_mask = (node_types == self.node_mapping["ph"])
    ph_etas = etas[ph_mask]
    ph_energies = energies[ph_mask]

    if (ph_mask.sum() > 0):
        # Photons are smeared in energy rather than pt.
        ph_resolutions = np.where(
            np.abs(ph_etas) <= 3.2, np.sqrt(ph_energies**2 * 0.0017**2 + ph_energies * 0.101**2),
            np.where(
                np.abs(ph_etas) <= 4.9, np.sqrt(ph_energies**2 * 0.0350**2 + ph_energies * 0.285**2),
                0.0
            )
        )
        noise[ph_mask, self.feature_index["energy"]] = np.random.normal(loc=0.0, scale=ph_resolutions)
    return noise
618
+
619
def make_graph(self, ch):
    """Build the base graph and add an augmented copy of its node features.

    The augmented copy receives one random phi rotation shared by all nodes,
    random sign flips of eta and phi, and detector noise. If the last node is
    of type MET, its pt is recomputed from the other augmented nodes. Edge
    features (deta, dphi, dR) are computed for both the original and the
    augmented node features.
    """
    g = super().make_graph(ch)

    # Work on a copy: the in-place column updates below would otherwise also
    # rewrite g.ndata['features'], because both entries shared one tensor.
    g.ndata['augmented_features'] = g.ndata['features'].clone()

    num_nodes = len(g.ndata['features'][:, 0])

    # Rotations: phi -> phi + delta_phi
    phi_index = self.feature_index["phi"]
    # Generate a single delta_phi for all nodes
    delta_phi = np.random.uniform(low=-np.pi, high=np.pi)

    # Apply the same delta_phi to all nodes, wrapped back into [-pi, pi)
    g.ndata['augmented_features'][:, phi_index] = (g.ndata['augmented_features'][:, phi_index] + delta_phi + np.pi) % (2 * np.pi) - np.pi

    # Reflections: eta -> -1 * eta, phi -> -1 * phi
    eta_index = self.feature_index["eta"]

    eta_reflection = np.random.choice([-1, 1])
    phi_reflection = np.random.choice([-1, 1])

    g.ndata['augmented_features'][:, eta_index] = g.ndata['augmented_features'][:, eta_index] * eta_reflection
    g.ndata['augmented_features'][:, phi_index] = g.ndata['augmented_features'][:, phi_index] * phi_reflection

    # Detector noise: additive, zero-mean, from detector_noise()
    noise = self.detector_noise(g.ndata['augmented_features'])
    g.ndata['augmented_features'] = g.ndata['augmented_features'] + noise

    pt_index = self.feature_index["pt"]
    # Guard against an empty graph before looking at the last node.
    if num_nodes > 0 and (g.ndata['augmented_features'][-1][self.feature_index["node_type"]] == self.node_mapping["MET"]):
        # Initialize sums of px and py
        sum_px = 0
        sum_py = 0

        # Loop over all nodes except the last one (MET node)
        for i in range(len(g.ndata['augmented_features']) - 1):
            pt = g.ndata['augmented_features'][i][pt_index]
            phi = g.ndata['augmented_features'][i][phi_index]

            # Compute px and py
            px = pt * np.cos(phi)
            py = pt * np.sin(phi)

            # Sum px and py
            sum_px += px
            sum_py += py

        # Calculate MET
        # NOTE(review): only the MET magnitude is updated; its phi keeps the
        # rotated/reflected original value - confirm that is intended.
        g.ndata['augmented_features'][-1][pt_index] = np.sqrt(sum_px**2 + sum_py**2)

    u, v = g.edges()
    deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
    dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
    dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
    dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
    dR = torch.sqrt(deta**2 + dphi**2)
    g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)

    deta = g.ndata['augmented_features'][u, 1] - g.ndata['augmented_features'][v, 1]
    dphi = g.ndata['augmented_features'][u, 2] - g.ndata['augmented_features'][v, 2]
    dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
    dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
    dR = torch.sqrt(deta**2 + dphi**2)
    g.edata['augmented_features'] = torch.stack([deta, dphi, dR], dim=1)

    return g
root_gnn_base/photon_ID_dataset.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from root_gnn_base import dataset
2
+ import dgl
3
+ import torch
4
+ import numpy as np
5
+
6
def radius_graph(features, radii, self_loops=False):
    """Build a DGL graph linking every ordered node pair that passes all cuts.

    Parameters:
        features : 2-D array-like
            One row per node; indexed as ``features[index_array, column]``.
        radii : dict
            Maps a feature column index to a threshold. A pair (u, v) is kept
            only if ``|features[u, k] - features[v, k]| < r`` for every
            ``k, r`` in the dict.
        self_loops : bool
            When False (and there is more than one node) pairs with u == v are
            removed before the cuts are applied.

    Returns:
        dgl.DGLGraph with exactly ``features.shape[0]`` nodes.
    """
    n_nodes = features.shape[0]
    # Enumerate all n*n ordered pairs: sender = i // n, receiver = i % n.
    pair_ids = np.arange(n_nodes * n_nodes)
    senders = pair_ids // n_nodes
    receivers = pair_ids % n_nodes
    if not self_loops and n_nodes > 1:
        mask = senders != receivers
        senders = senders[mask]
        receivers = receivers[mask]
    for k, r in radii.items():
        d = features[senders, k] - features[receivers, k]
        mask = np.abs(d) < r
        senders = senders[mask]
        receivers = receivers[mask]
    # Pass num_nodes explicitly: otherwise DGL infers the node count from the
    # largest edge endpoint and silently drops trailing isolated nodes, which
    # breaks a later per-node feature assignment.
    return dgl.graph((senders, receivers), num_nodes=n_nodes)
22
+
23
class PhotonIDDataset(dataset.LazyMultiLabelDataset):
    """Lazy multi-label dataset whose graphs connect nodes close in selected feature columns."""

    def __init__(self, eta_radius, phi_radius, **kwargs):
        # The radii must exist before the base constructor runs.
        # NOTE(review): presumably because the base class may build graphs
        # during construction - confirm against dataset.LazyMultiLabelDataset.
        self.eta_radius = eta_radius
        self.phi_radius = phi_radius
        super().__init__(**kwargs)

    def make_graph(self, ch):
        """Build one graph from ``ch`` with node features and 6 edge features.

        Edge features are stacked as [deta, dphi, dR, dx, dy, dz], where dphi
        is wrapped into [-pi, pi].
        """
        features, _ = dataset.node_features_from_tree(
            ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales
        )
        # Drop rows whose first feature is exactly zero.
        features = features[features[:,0] != 0]
        # Cuts on Delta Eta (col 1), Delta Phi (col 2), Adjacent Layer (col 6).
        # Self loops ensure last cell is included even if disconnected.
        cuts = {1: self.eta_radius, 2: self.phi_radius, 6: 1.1}
        g = radius_graph(features, cuts, self_loops=True)
        g.ndata['features'] = features

        src, dst = g.edges()
        deta = features[src, 1] - features[dst, 1]
        stored = g.ndata['features']
        dphi = stored[src, 2] - stored[dst, 2]
        # Wrap the azimuthal difference into [-pi, pi].
        dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
        dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
        dR = torch.sqrt(deta**2 + dphi**2)
        # Columns 3..5 are differenced as dx, dy, dz.
        dx, dy, dz = (features[src, col] - features[dst, col] for col in (3, 4, 5))
        g.edata['features'] = torch.stack([deta, dphi, dR, dx, dy, dz], dim=1)
        return g
root_gnn_base/similarity.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ from sklearn.decomposition import PCA
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ from sklearn.metrics.pairwise import euclidean_distances
6
+ from sklearn.preprocessing import StandardScaler
7
+
8
+ from scipy.stats import wasserstein_distance
9
+
10
def cka(rep_a, rep_b, size=None):
    """
    Computes the linear Centered Kernel Alignment (CKA) between two
    representation matrices rep_a and rep_b.
    If size is provided, CKA is computed on a randomly selected subset of rows.

    Parameters:
    rep_a : np.ndarray
        First representation matrix of size (n_samples, n_features_a).
    rep_b : np.ndarray
        Second representation matrix of size (n_samples, n_features_b).
    size : int, optional
        Number of samples to use for the CKA calculation. If None (or not
        smaller than the number of rows), the full dataset is used.

    Returns:
    float
        CKA similarity between rep_a and rep_b; 0 when either centered Gram
        matrix is identically zero.
    """

    def gram_linear(x):
        """Compute the Gram (kernel) matrix using a linear kernel."""
        return x @ x.T

    def center_gram(gram):
        """Double-center the Gram matrix (H K H with H = I - 11^T / n).

        Done with row/column means in O(n^2) rather than dense n x n
        matrix products.
        """
        row_mean = gram.mean(axis=1, keepdims=True)
        col_mean = gram.mean(axis=0, keepdims=True)
        return gram - row_mean - col_mean + gram.mean()

    # If size is specified, randomly sample the same rows from both inputs
    if size is not None and size < rep_a.shape[0]:
        indices = np.random.choice(rep_a.shape[0], size, replace=False)
        rep_a = rep_a[indices]
        rep_b = rep_b[indices]

    centered_gram_a = center_gram(gram_linear(rep_a))
    centered_gram_b = center_gram(gram_linear(rep_b))

    # CKA = <Ka, Kb>_F / (||Ka||_F * ||Kb||_F)
    numerator = np.sum(centered_gram_a * centered_gram_b)
    denominator = np.sqrt(np.sum(centered_gram_a**2) * np.sum(centered_gram_b**2))

    return numerator / denominator if denominator != 0 else 0
58
+
59
def cca(X, Y, size = None, num_components=10):
    """
    Mean canonical correlation between two datasets (CCA).

    Parameters:
    X : np.ndarray
        First dataset, shape (n_samples, n_features_X).
    Y : np.ndarray
        Second dataset, shape (n_samples, n_features_Y).
    size : int, optional
        If given and smaller than n_samples, the same random subset of rows
        is drawn from X and Y before the analysis.
    num_components : int
        Number of leading canonical correlations to average.

    Returns:
    float
        Mean of the `num_components` largest canonical correlations,
        each constrained to [0, 1].
    """

    # If sample size is specified, randomly sample a subset of the data
    if size is not None and size < X.shape[0]:
        indices = np.random.choice(X.shape[0], size, replace=False)
        X = X[indices]
        Y = Y[indices]

    # Standardize both datasets (mean = 0, variance = 1)
    X = StandardScaler().fit_transform(X)
    Y = StandardScaler().fit_transform(Y)

    n_x = X.shape[1]
    n_y = Y.shape[1]

    # Covariance matrices (atleast_2d keeps the single-feature case a matrix)
    C_XX = np.atleast_2d(np.cov(X, rowvar=False))
    C_YY = np.atleast_2d(np.cov(Y, rowvar=False))
    C_XY = np.cov(X, Y, rowvar=False)[:n_x, n_x:]  # Cross-covariance of X and Y

    # Regularization term to avoid singular matrices
    reg = 1e-6
    inv_C_XX = np.linalg.inv(C_XX + reg * np.eye(n_x))
    inv_C_YY = np.linalg.inv(C_YY + reg * np.eye(n_y))

    # The squared canonical correlations are the eigenvalues of
    # C_XX^-1 C_XY C_YY^-1 C_YX. This product is NOT symmetric in general, so
    # the general eigen-solver is required (eigh would read one triangle only).
    # The eigenvalues are real in exact arithmetic; drop round-off imaginary parts.
    A = inv_C_XX @ C_XY @ inv_C_YY @ C_XY.T
    eigvals = np.linalg.eigvals(A).real

    # Sort in descending order so the leading components are the strongest
    eigvals = np.sort(eigvals)[::-1]

    # Canonical correlations (square root of the eigenvalues, constrained to [0,1])
    corrs = np.sqrt(np.clip(eigvals[:num_components], 0, 1))

    return np.mean(corrs)
128
+
129
def pca(X, Y, size=1000, n_components=3, bins=30):
    """Histogram-based similarity score in [0, 1] between PCA projections of X and Y.

    Each dataset is projected with its own, independently fitted PCA. For every
    component the two projections are histogrammed on shared bin edges, and
    `wasserstein_distance` is evaluated on the two arrays of histogram
    densities. The distances are summed and mapped to `1 - total`, clipped
    to [0, 1].

    Parameters:
    X, Y : array-like, shape (n_samples, n_features)
    size : int
        Accepted for signature parity with the other metrics; not used here.
    n_components : int
        Number of principal components compared.
    bins : int
        Number of histogram bins.
    """
    X_pca = PCA(n_components=n_components).fit_transform(X)
    Y_pca = PCA(n_components=n_components).fit_transform(Y)

    # Shared bin edges spanning the full range of both projections
    lowest = min(X_pca.min(), Y_pca.min())
    highest = max(X_pca.max(), Y_pca.max())
    bin_edges = np.linspace(lowest, highest, bins + 1)

    def density(projection, component):
        return np.histogram(projection[:, component], bins=bin_edges, density=True)[0]

    # Accumulate the per-component distance between the two density arrays
    total_distance = 0
    for component in range(n_components):
        total_distance += wasserstein_distance(
            density(X_pca, component), density(Y_pca, component)
        )

    # Normalisation constant; placeholder value of 1.0
    max_distance = 1.0
    similarity_score = 1 - (total_distance / max_distance)

    # Ensure the score stays in [0, 1]
    return max(0, min(1, similarity_score))
root_gnn_base/uproot_dataset.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from root_gnn_base import dataset
2
+ import torch
3
+ import uproot
4
+ import glob
5
+ import os
6
+ import awkward as ak
7
+ import numpy as np
8
+ import time
9
+
10
def node_features_from_ak(ch, node_branch_names, node_branch_types, node_feature_scales):
    """Assemble per-node feature arrays from an awkward record of branches.

    `node_branch_names` is indexed as [feature][node_type]. Each entry is a
    branch name to read from `ch`, a numeric constant, or the whole feature
    entry is one of the sentinels 'CALC_E' / 'NODE_TYPE'.
    `node_branch_types[i] == 'single'` marks node types with one value per
    event; those get a new length-1 node axis.

    Returns the concatenation over node types multiplied by
    `node_feature_scales`, with axis order (feature, event, node).
    """
    node_types = []
    # NOTE(review): assumes the first feature entry is a per-node-type list
    # (not one of the string sentinels) - confirm against the configs.
    n_types = len(node_branch_names[0])
    for i in range(n_types):
        features = []
        branch_type = node_branch_types[i]
        for j in range(len(node_branch_names)):
            if node_branch_names[j] == 'CALC_E':
                # Derived from the first two features already collected;
                # presumably pt * cosh(eta) - verify feature ordering.
                features.append(features[0] * np.cosh(features[1]))
            elif node_branch_names[j] == 'NODE_TYPE':
                # Constant column holding the index of this node type.
                features.append(ak.full_like(features[0], i))
            elif isinstance(node_branch_names[j][i], str):
                features.append(ch[node_branch_names[j][i]])
            elif isinstance(node_branch_names[j][i], (int, float)):
                # Numeric entry: constant column shaped like the first feature.
                features.append(ak.full_like(features[0], node_branch_names[j][i]))
        if branch_type == 'single':
            # Add a node axis so per-event scalars line up with per-node lists.
            features = [f[:,np.newaxis] for f in features]
        node_types.append(ak.Array(features))
    node_features = ak.concatenate(node_types, axis=2) * node_feature_scales #axis order at this point is (feature, event, node)
    return node_features
30
+
31
class UprootDataset(dataset.RootDataset):
    """RootDataset variant that reads all input files in one pass with uproot."""

    def process(self):
        """Load the files, build one graph per event and attach node features.

        Sets `self.files`, `self.chain`, `self.graphs` and replaces
        `self.label` with a (n_events, 2) tensor of [label, fold variable].
        Prints timing for each stage.
        """
        starttime = time.time()
        self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
        branches = self.get_list_of_branches()
        # Read the requested branches of `tree_name` from every matched file.
        self.chain = uproot.concatenate([f + ':' + self.tree_name for f in self.files], branches, num_workers=4)
        node_features = node_features_from_ak(self.chain, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
        loadtime = time.time()
        n_nodes = ak.num(node_features[0], axis=1) #number of nodes for each event
        ftime = time.time()
        self.graphs = [dataset.full_connected_graph(n, False) for n in n_nodes]
        itime = time.time()
        for i in range(len(self.graphs)):
            if i % 10000 == 0:
                print(f'Processing event {i}/{len(self.graphs)}')
            # node_features is (feature, event, node); transpose the event
            # slice to (node, feature) for DGL.
            self.graphs[i].ndata['features'] = torch.transpose(torch.tensor(node_features[:,i,:]),0,1).to(torch.float)
        # Column 0: the dataset-wide label repeated per event; column 1: the
        # fold variable read from the tree as int64.
        self.label = torch.stack([torch.full((len(self.graphs),),torch.tensor(self.label)), torch.tensor(ak.values_astype(self.chain[self.fold_var], np.int64))], dim=1)
        gtime = time.time()
        print()
        print(f'load time: {loadtime - starttime} s')
        print(f'feature time: {ftime - loadtime} s')
        print(f'graph time: {itime - ftime} s')
        print(f'graph data time: {gtime - itime} s')
54
+
root_gnn_base/utils.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import yaml
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import dgl
8
+ import signal
9
+
10
def buildFromConfig(conf, run_time_args = None):
    """Instantiate the class described by a config mapping.

    Parameters:
        conf : mapping with keys 'module' (importable module path), 'class'
            (attribute name in that module) and optionally 'args' (keyword
            arguments for the constructor).
        run_time_args : optional dict of extra keyword arguments supplied by
            the caller and merged with conf['args'].

    Returns:
        The constructed object, or None (after printing a notice) when the
        config has no 'module' key.
    """
    if 'module' not in conf:
        print('No module specified in config. Returning None.')
        return None
    module = importlib.import_module(conf['module'])
    cls = getattr(module, conf['class'])
    # A missing or empty 'args' entry means "no configured arguments".
    configured_args = conf.get('args') or {}
    return cls(**configured_args, **(run_time_args or {}))
17
+
18
def cycler(iterable):
    """Yield the items of `iterable` over and over by re-iterating it.

    The iterable is re-iterated on every pass (so a re-shuffling data loader
    gives a fresh order each cycle). If a pass produces no items - an empty
    container or an exhausted one-shot iterator - the generator stops instead
    of spinning forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return
23
+
24
def include_config(conf):
    """Merge the YAML files listed under conf['include'] into `conf` in place.

    Files are applied in list order with dict.update, so keys from an included
    file overwrite keys already present. The 'include' key is removed
    afterwards. Does nothing when there is no 'include' key.
    """
    if 'include' not in conf:
        return
    for included_path in conf['include']:
        with open(included_path) as handle:
            conf.update(yaml.load(handle, Loader=yaml.FullLoader))
    del conf['include']
30
+
31
def load_config(config_file):
    """Read a YAML config file, resolve its 'include' list, and return the dict."""
    with open(config_file) as handle:
        loaded = yaml.load(handle, Loader=yaml.FullLoader)
    include_config(loaded)
    return loaded
36
+
37
+ #Timeout function from https://stackoverflow.com/questions/492519/timeout-on-a-function-call
38
class TimeoutException(Exception):
    """Raised by the SIGALRM handler when a timeout expires."""
40
+
41
def timeout_handler(signum, frame):
    """Signal handler that turns a delivered signal into a TimeoutException."""
    raise TimeoutException
43
+
44
def set_timeout(timeout):
    """Install the timeout handler for SIGALRM and schedule the alarm `timeout` seconds ahead."""
    alarm_signal = signal.SIGALRM
    signal.signal(alarm_signal, timeout_handler)
    signal.alarm(timeout)
47
+
48
def unset_timeout():
    """Cancel any pending alarm, then restore the default SIGALRM disposition."""
    cancel_pending = 0
    signal.alarm(cancel_pending)
    signal.signal(signal.SIGALRM, signal.SIG_DFL)
51
+
52
def make_padding_graph(batch, pad_nodes, pad_edges):
    """Append one zero-filled padding graph to `batch`.

    The padding graph has `pad_nodes` nodes and `pad_edges` edges, and carries
    an all-zero tensor for every node/edge data key present on `batch`.

    Returns a new batched graph: the original batch plus the padding graph.
    """
    senders = []
    receivers = []
    # Deterministic filler edges: edge i goes from i // pad_nodes to
    # (i + 1) % pad_nodes.
    senders = torch.arange(0,pad_edges) // pad_nodes
    receivers = torch.arange(1,pad_edges+1) % pad_nodes
    if pad_nodes < 0 or pad_edges < 0 or pad_edges > pad_nodes * pad_nodes / 2:
        # Only reports the condition; execution continues.
        # NOTE(review): negative sizes are not rejected here - confirm that
        # callers never hit this with a batch larger than the padding target.
        print('Batch is larger than padding size or e > n^2/2. Repeating edges as necessary.')
        print(f'Batch nodes: {batch.num_nodes()}, Batch edges: {batch.num_edges()}, Padding nodes: {pad_nodes}, Padding edges: {pad_edges}')
        # Wrap sender ids back into the valid node range.
        senders = senders % pad_nodes
    padg = dgl.graph((senders[:pad_edges], receivers[:pad_edges]), num_nodes = pad_nodes)
    for k in batch.ndata.keys():
        # Uses shape[1], so node data is taken to be 2-D.
        padg.ndata[k] = torch.zeros( (pad_nodes, batch.ndata[k].shape[1]) )
    for k in batch.edata.keys():
        padg.edata[k] = torch.zeros( (pad_edges, batch.edata[k].shape[1]) )
    return dgl.batch([batch, padg.to(batch.device)])
67
+
68
def pad_size(graphs, edges, nodes, edge_per_graph=3, node_per_graph=14):
    """Return (pad_nodes, pad_edges): the next multiple of the per-batch step above each count.

    The node step is `graphs * node_per_graph` and the edge step is
    `graphs * edge_per_graph`. A count that is already an exact multiple is
    still bumped to the following multiple.
    """
    node_steps = nodes // (node_per_graph * graphs) + 1
    edge_steps = edges // (edge_per_graph * graphs) + 1
    pad_nodes = node_steps * graphs * node_per_graph
    pad_edges = edge_steps * graphs * edge_per_graph
    return pad_nodes, pad_edges
72
+
73
def pad_batch_to_step_per_graph(batch, edge_per_graph=3, node_per_graph=14):
    """Pad `batch` with a padding graph sized from the batch's node/edge counts.

    The padding sizes are the batch's node (edge) count modulo
    `n_graphs * node_per_graph` (`n_graphs * edge_per_graph`).

    NOTE(review): this adds the remainder, not the distance to the next
    multiple of the step, so the padded total is generally not a multiple of
    the step. Confirm this is intended.
    """
    n_graphs = batch.batch_num_nodes().shape[0]
    node_step = int(n_graphs * node_per_graph)
    edge_step = int(n_graphs * edge_per_graph)
    pad_nodes = (batch.num_nodes() + node_per_graph * n_graphs) % node_step
    pad_edges = (batch.num_edges() + edge_per_graph * n_graphs) % edge_step
    return make_padding_graph(batch, pad_nodes, pad_edges)
78
+
79
def pad_batch(batch, edges = 104000, nodes = 16000):
    """Pad `batch` up to fixed totals of `nodes` nodes and `edges` edges.

    Passing 0 for both disables padding and returns `batch` unchanged.
    Otherwise a padding graph holding the missing nodes/edges is appended.
    """
    if edges == 0 and nodes == 0:
        return batch
    missing_nodes = nodes - batch.num_nodes()
    missing_edges = edges - batch.num_edges()
    return make_padding_graph(batch, missing_nodes, missing_edges)
87
+
88
def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
    """Pad every graph in `batch` with isolated nodes up to `max_num_nodes`.

    Parameters:
        batch : batched DGL graph with node data 'features'.
        max_num_nodes : target node count per graph; graphs that already have
            at least this many nodes are left untouched.
        hid_size : width of the per-node weight tensor stored in ndata['w'].

    Returns:
        A new batched graph with two extra node fields:
        'padding_mask' (bool, True where a node's features are all zero) and
        'w' (float, shape (n_nodes, hid_size), 0 for masked nodes and 1
        otherwise). Any real node whose features are all zero is flagged too.
    """
    print(f"Padding each graph to have {max_num_nodes} nodes")

    unbatched = dgl.unbatch(batch)
    for g in unbatched:
        num_nodes_to_add = max_num_nodes - g.number_of_nodes()
        if num_nodes_to_add > 0:
            g.add_nodes(num_nodes_to_add)  # Add isolated nodes

    batch = dgl.batch(unbatched)

    feats = batch.ndata['features']
    # Vectorised all-zero-row test; flatten trailing dims so any feature rank works.
    padding_mask = torch.count_nonzero(feats.reshape(feats.shape[0], -1), dim=1) == 0
    # Create the weights on the same device as the graph data so the ndata
    # assignment below also works for batches living on a GPU.
    global_update_weights = torch.ones((feats.shape[0], hid_size), device=feats.device)
    global_update_weights[padding_mask] = 0

    batch.ndata['w'] = global_update_weights
    batch.ndata['padding_mask'] = padding_mask

    return batch
111
+
112
+
113
def fold_selection(fold_config, sample):
    """Build a mask function selecting the events that belong to `sample`'s folds.

    `fold_config['n_folds']` is the number of folds. `fold_config[sample]` is
    either one fold index (int) or a list of fold indices. The returned
    callable takes an object with a 2-D `tracking` attribute and compares
    `tracking[:, 0] % n_folds` against the fold(s).

    Raises ValueError when the fold option is neither an int nor a list.
    """
    n_folds = fold_config['n_folds']
    folds_opt = fold_config[sample]
    opt_type = type(folds_opt)

    if opt_type is int:
        def in_fold(x):
            return x.tracking[:,0] % n_folds == folds_opt
        return in_fold

    if opt_type is list:
        print("fold type is list")
        print(f"fold_config = {fold_config}")
        print(f"folds_opt = {folds_opt}")

        def in_listed_folds(x):
            # Sum of per-fold matches; an event is kept when exactly one matches.
            return sum([x.tracking[:,0] % n_folds == f for f in folds_opt]) == 1
        return in_listed_folds

    raise ValueError("Invalid fold selection option with type {}".format(opt_type))
126
+
127
def fold_selection_name(fold_config, sample):
    """Return the file-name suffix describing `sample`'s fold selection.

    Format: 'n_<n_folds>_f_<fold>' for a single fold, or
    'n_<n_folds>_f_<f1>_<f2>_...' for a list of folds.
    Raises ValueError when the fold option is neither an int nor a list.
    """
    n_folds = fold_config['n_folds']
    folds_opt = fold_config[sample]
    opt_type = type(folds_opt)
    if opt_type is int:
        return f'n_{n_folds}_f_{folds_opt}'
    if opt_type is list:
        joined = "_".join(map(str, folds_opt))
        return f'n_{n_folds}_f_{joined}'
    raise ValueError("Invalid fold selection option with type {}".format(opt_type))
136
+
137
+ #Return the index and checkpoint of the last epoch.
138
def get_last_epoch(config, max_ep = -1, device = None):
    """Find and load the newest consecutive checkpoint in the training directory.

    Epochs are probed from 0 upward as 'model_epoch_<ep>.pt' inside
    config['Training_Directory']; the scan stops at the first missing file.
    A negative `max_ep` means "scan up to config['Training']['epochs']".

    Returns (last_epoch, checkpoint); (-1, None) when no checkpoint exists.
    `device` is forwarded to torch.load as map_location.
    """
    def checkpoint_path(epoch):
        return os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.pt')

    if max_ep < 0:
        max_ep = config['Training']['epochs']

    last_epoch = -1
    for ep in range(max_ep):
        if not os.path.exists(checkpoint_path(ep)):
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', checkpoint_path(ep))
            break
        last_epoch = ep

    checkpoint = None
    if last_epoch >= 0:
        checkpoint = torch.load(checkpoint_path(last_epoch), map_location=device)
    return last_epoch, checkpoint
153
+
154
+ #Return the index and checkpoint of the requested epoch, or of the last consecutive epoch found before it.
155
def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
    """Load the checkpoint for `target_epoch`, or the last consecutive one before it.

    Epochs 0..target_epoch are probed as 'model_epoch_<ep>.pt'; the scan stops
    at the first missing file, so the returned epoch can be lower than
    `target_epoch`. With `from_ryan=True` the training directory is looked up
    under a fixed shared prefix instead of relative to the working directory.

    Returns (epoch, checkpoint); (-1, None) when no checkpoint exists.
    `device` is forwarded to torch.load as map_location.
    """
    def checkpoint_path(epoch):
        if from_ryan:
            base_dir = '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/' + config['Training_Directory']
        else:
            base_dir = config['Training_Directory']
        return os.path.join(base_dir, f'model_epoch_{epoch}.pt')

    last_epoch = -1
    for ep in range(target_epoch + 1):
        if not os.path.exists(checkpoint_path(ep)):
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', checkpoint_path(ep))
            break
        last_epoch = ep

    checkpoint = None
    if last_epoch >= 0:
        checkpoint = torch.load(checkpoint_path(last_epoch), map_location=device)
    return last_epoch, checkpoint
179
+
180
+ #Convert training logs into dict for plotting.
181
def read_log(config):
    """Parse the training log into per-column numpy arrays for plotting.

    Reads '<Training_Directory>/training.log' and keeps only lines containing
    'Epoch'. Each such line is a '|'-separated list of '<label> <value>'
    fields. The labels are taken from the first kept line.

    Returns:
        dict mapping each label to a float array with one entry per kept line;
        an empty dict when the log has no 'Epoch' lines.
    """
    with open(config['Training_Directory'] + '/training.log', 'r') as f:
        lines = [l for l in f if 'Epoch' in l]
    if not lines:
        # Nothing logged yet: report "no data" instead of failing on lines[0].
        return {}

    def fields_of(line):
        # Blank fields (e.g. from a trailing '|') carry no data; skip them.
        return [tokens for tokens in (field.split() for field in line.split('|')) if tokens]

    labels = [tokens[0] for tokens in fields_of(lines[0])]
    log = {label : np.zeros(len(lines)) for label in labels}
    for i, line in enumerate(lines):
        for tokens in fields_of(line):
            log[tokens[0]][i] = float(tokens[1])
    return log
196
+
197
+ #Plot training logs.
198
def plot_log(log, output_file):
    """Plot training curves from a parsed log and save them to `output_file`.

    `log` must provide the columns 'Epoch', 'Time', 'Loss', 'Test_Loss',
    'Accuracy' and 'Test_AUC'. The 2x2 figure shows cumulative time, train/test
    loss, test accuracy and test AUC versus epoch. The figure is closed after
    saving so repeated calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(2, 2, figsize=(10,10))
    try:
        #Time (cumulative over epochs)
        ax[0][0].plot(log['Epoch'], np.cumsum(log['Time']), label='Time')
        ax[0][0].set_xlabel('Epoch')
        ax[0][0].set_ylabel('Time (s)')
        ax[0][0].legend()

        # Disabled alternative for the top-left panel (learning rate):
        # ax[0][0].plot(log['Epoch'], log['LR'], label='Learning Rate')
        # ax[0][0].set_xlabel('Epoch')
        # ax[0][0].set_ylabel('Learning Rate')
        # ax[0][0].set_yscale('log')
        # ax[0][0].legend()

        #Loss
        ax[0][1].plot(log['Epoch'], log['Loss'], label='Train Loss')
        ax[0][1].plot(log['Epoch'], log['Test_Loss'], label='Test Loss')
        ax[0][1].set_xlabel('Epoch')
        ax[0][1].set_ylabel('Loss')
        ax[0][1].legend()

        #Accuracy (y-axis fixed to a narrow window)
        ax[1][0].plot(log['Epoch'], log['Accuracy'], label='Test Accuracy')
        ax[1][0].set_xlabel('Epoch')
        ax[1][0].set_ylabel('Accuracy')
        ax[1][0].set_ylim((0.44, 0.56))
        ax[1][0].legend()

        #AUC
        ax[1][1].plot(log['Epoch'], log['Test_AUC'], label='Test AUC')
        ax[1][1].set_xlabel('Epoch')
        ax[1][1].set_ylabel('AUC')
        ax[1][1].legend()

        fig.savefig(output_file)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
236
+
237
class EarlyStop():
    """Patience-based early-stopping tracker.

    Call `update(value)` once per epoch. `should_stop` becomes True after
    `patience` consecutive updates that fail to beat the best value by more
    than `threshold` (lower is better for mode 'min', higher for 'max').
    """

    def __init__(self, patience=15, threshold=1e-8, mode='min'):
        self.patience = patience
        self.threshold = threshold
        self.mode = mode
        self.count = 0
        self.current_best = np.inf if mode == 'min' else -np.inf
        self.should_stop = False

    def update(self, value):
        """Record a new metric value and refresh the stop flag."""
        if self.mode == 'min':  # Minimizing loss
            improved = value < self.current_best - self.threshold
        elif self.mode == 'max':  # Maximizing metric
            improved = value > self.current_best + self.threshold
        else:
            improved = None  # unrecognised mode: the counters are left alone

        if improved is not None:
            if improved:
                self.current_best = value
                self.count = 0
            else:
                self.count += 1

        # Check if patience is exceeded
        if self.count >= self.patience:
            self.should_stop = True

    def reset(self):
        """Return to the freshly-constructed state (settings are kept)."""
        self.should_stop = False
        self.current_best = np.inf if self.mode == 'min' else -np.inf
        self.count = 0

    def to_str(self):
        """Return a multi-line human-readable status report."""
        status = (
            f"EarlyStop Status:\n"
            f" Mode: {'Minimize' if self.mode == 'min' else 'Maximize'}\n"
            f" Patience: {self.patience}\n"
            f" Threshold: {self.threshold:.3e}\n"
            f" Current Best: {self.current_best:.6f}\n"
            f" Consecutive Epochs Without Improvement: {self.count}\n"
            f" Stopping Triggered: {'Yes' if self.should_stop else 'No'}"
        )
        return status

    def to_dict(self):
        """Return the settings and counters as a plain dict."""
        return {
            'patience': self.patience,
            'threshold': self.threshold,
            'mode': self.mode,
            'count': self.count,
            'current_best': self.current_best,
            'should_stop': self.should_stop,
        }

    @classmethod
    def load_from_dict(cls, state_dict):
        """Rebuild an instance from a dict produced by `to_dict`."""
        instance = cls(
            patience=state_dict['patience'],
            threshold=state_dict['threshold'],
            mode=state_dict['mode'],
        )
        for field in ('count', 'current_best', 'should_stop'):
            setattr(instance, field, state_dict[field])
        return instance
303
+
304
+
305
def graph_augmentation(graph):
    """Placeholder hook: announces the call and returns None; `graph` is not used."""
    print("Augmenting Graph")
    return None
scripts/find_free_port.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # find_free_port.py
2
def find_free_port():
    """Ask the OS for an unused TCP port and return its number as a string.

    Binds a temporary socket to port 0 so the OS picks a free port, reads the
    chosen port back, and closes the socket before returning.
    """
    import socket
    from contextlib import closing

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port = probe.getsockname()[1]
        return str(port)
10
+
11
+ if __name__ == "__main__":
12
+ print(find_free_port())
scripts/inference.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ file_path = os.getcwd()
4
+ sys.path.append(file_path)
5
+
6
+ import argparse
7
+ import yaml
8
+
9
+ import torch
10
+ import dgl
11
+ from dgl.data import DGLDataset
12
+ from dgl.dataloading import GraphDataLoader
13
+ from torch.utils.data import SubsetRandomSampler, SequentialSampler
14
+
15
+
16
def my_error_handler(level, abort, location, msg):
    """Error callback that appends the message to 'error_log.txt'.

    The message is always logged to the file (nothing is printed). When
    `abort` is truthy a RuntimeError is raised afterwards; `level` is unused.
    """
    entry = f"Error in {location}: {msg}\n"
    with open("error_log.txt", "a") as log_file:
        log_file.write(entry)

    # Only fatal errors interrupt the caller.
    if not abort:
        return
    raise RuntimeError(f"Fatal error in {location}: {msg}")
27
+
28
class CustomPreBatchedDataset(DGLDataset):
    """Wrap a dataset with a GraphDataLoader over a masked subset of its items.

    Parameters:
        start_dataset : source dataset; must expose `name` and `save_dir`.
        batch_size : number of graphs per batch.
        mask_fn : callable(dataset) -> boolean tensor selecting items;
            defaults to selecting everything.
        drop_last : forwarded to the data loader.
        shuffle : iterate the selected items in random order when True.

    After `process()` the loader is available as `self.dataloader`.
    """

    def __init__(self, start_dataset, batch_size, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
        # Attributes are set before the base constructor runs.
        # NOTE(review): presumably because DGLDataset.__init__ may call
        # process() - confirm against the DGL version in use.
        self.start_dataset = start_dataset
        self.batch_size = batch_size
        self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
        self.drop_last = drop_last
        self.shuffle = shuffle
        super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)

    def process(self):
        """Apply the mask and build `self.dataloader` over the selected items."""
        mask = self.mask_fn(self.start_dataset)
        # Plain Python ints, so the dataset is always indexed with ints.
        indices = torch.arange(len(self.start_dataset))[mask].tolist()
        print(f"Number of elements after masking: {len(indices)}") # Debugging print

        if self.shuffle:
            sampler = SubsetRandomSampler(indices)
        else:
            # SequentialSampler(indices) would yield positions 0..len-1 and
            # ignore the mask; the index list itself is the in-order sampler.
            sampler = indices

        self.dataloader = GraphDataLoader(
            self.start_dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            drop_last=self.drop_last
        )
        print(f"Batch size set in DataLoader: {self.batch_size}") # Debugging print

    def __getitem__(self, idx):
        """Return the first batch built from the requested item index/indices."""
        if isinstance(idx, int):
            idx = [idx]
        # Use the requested indices themselves as the sampler (see process()).
        dloader = GraphDataLoader(self.start_dataset, sampler=list(idx), batch_size=self.batch_size, drop_last=False)
        return next(iter(dloader))

    def __len__(self):
        """Number of items in the wrapped dataset (not the number of batches)."""
        return len(self.start_dataset)
64
+
65
def include_config(conf):
    """Merge every YAML file named in conf['include'] into `conf`, in place.

    Later files and included keys win over existing keys (dict.update). The
    'include' entry is deleted once all files are merged. A config without
    'include' is left untouched.
    """
    if 'include' not in conf:
        return
    for extra_file in conf['include']:
        with open(extra_file) as stream:
            conf.update(yaml.load(stream, Loader=yaml.FullLoader))
    del conf['include']
71
+
72
def load_config(config_file):
    """Parse a YAML config file and expand its 'include' entries; returns the dict."""
    with open(config_file) as stream:
        settings = yaml.load(stream, Loader=yaml.FullLoader)
    include_config(settings)
    return settings
77
+
78
def main():
    """Command-line entry point: score one ROOT file with a trained model.

    Loads the config, builds the (eager) dataset for the requested chunk of
    `--target`, restores a checkpoint, runs the model over all batches and
    stores the result. With `--write` the input tree is copied to
    `--destination` with an extra score branch; otherwise scores, labels and
    tracking info are saved as an .npz file. Returns early when the
    destination exists and `--clobber` is not given.
    """
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--config', type=str, required=True)
    add_arg('--target', type=str, required=True)
    add_arg('--destination', type=str, default='')
    add_arg('--chunkno', type=int, default=0)
    add_arg('--chunks', type=int, default=1)
    add_arg('--write', action='store_true')
    add_arg('--ckpt', type=int, default=-1)
    add_arg('--clobber', action='store_true')
    add_arg('--tree', type=str, default='')
    add_arg('--branch_name', type=str, default='score')
    args = parser.parse_args()

    config = load_config(args.config)
    if args.destination == '':
        # Default: <Training_Directory>/inference/<target file name>
        args.destination = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
    if not args.write:
        args.destination = args.destination.replace('.root', '') + f'_chunk{args.chunkno}.npz'

    if os.path.exists(args.destination):
        print(f'File {args.destination} already exists.')
        if args.clobber:
            print('Clobbering.')
        else:
            print('Exiting.')
            return
    else:
        print(f'Writing to {args.destination}')

    # Heavy imports are deferred so the early exit above stays fast.
    import time
    start = time.time()
    import ROOT
    import torch
    from array import array
    import numpy as np
    from root_gnn_base import batched_dataset as dataset
    from root_gnn_base import utils
    end = time.time()
    print('Imports finished in {:.2f} seconds'.format(end - start))

    start = time.time()
    dset_config = config['Datasets'][list(config['Datasets'].keys())[0]]
    # Inference uses the eager counterpart of each lazy dataset class.
    eager_class = {
        'LazyDataset': 'EdgeDataset',
        'LazyMultiLabelDataset': 'MultiLabelDataset',
        'PhotonIDDataset': 'UnlazyPhotonIDDataset',
        'kNNDataset': 'UnlazyKNNDataset',
    }
    dset_config['class'] = eager_class.get(dset_config['class'], dset_config['class'])
    dset_config['args']['raw_dir'] = os.path.split(args.target)[0]
    dset_config['args']['file_names'] = os.path.split(args.target)[1]
    dset_config['args']['save'] = False
    dset_config['args']['chunks'] = args.chunks
    dset_config['args']['process_chunks'] = [args.chunkno,]
    dset_config['args']['selections'] = []

    dset_config['args']['save_dir'] = os.path.dirname(args.destination)

    if args.tree != '':
        dset_config['args']['tree_name'] = args.tree

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dstart = time.time()
    dset = utils.buildFromConfig(dset_config)
    dend = time.time()
    print('Dataset finished in {:.2f} seconds'.format(dend - dstart))

    print(dset)

    batch_size = config['Training']['batch_size']
    lstart = time.time()
    loader = CustomPreBatchedDataset(dset, batch_size)
    loader.process()
    lend = time.time()
    print('Loader finished in {:.2f} seconds'.format(lend - lstart))
    sample_graph, _, _, global_sample = loader[0]

    print('dset length =', len(dset))
    print('loader length =', len(loader))

    model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
    if args.ckpt < 0:
        ep, checkpoint = utils.get_last_epoch(config, args.ckpt, device=device)
    else:
        ep, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)
    if checkpoint is None:
        # Fail with a clear message instead of a TypeError on None below.
        raise FileNotFoundError(f"No model checkpoint found in {config['Training_Directory']}")

    # Checkpoints saved from wrapped/compiled models carry key prefixes that
    # the bare model does not have; strip them before loading.
    mds_copy = {}
    for key, value in checkpoint['model_state_dict'].items():
        newkey = key.replace('module.', '')
        newkey = newkey.replace('_orig_mod.', '')
        mds_copy[newkey] = value
    model.load_state_dict(mds_copy)
    model.eval()

    end = time.time()
    print('Model and dataset finished in {:.2f} seconds'.format(end - start))
    print('Starting inference')
    start = time.time()

    finish_fn = torch.nn.Sigmoid()
    if 'Loss' in config:
        finish_fn = utils.buildFromConfig(config['Loss']['finish'])
    # Contrastive-cluster models store the raw model output as the score.
    use_raw_prediction = finish_fn.__class__.__name__ == "ContrastiveClusterFinish"

    scores = []
    labels = []
    tracking_info = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch, label, track, global_feats in loader.dataloader:
            batch = batch.to(device)
            pred = model(batch, global_feats.to(device))
            if use_raw_prediction:
                scores.append(pred.detach().cpu().numpy())
            else:
                scores.append(finish_fn(pred).detach().cpu().numpy())
            labels.append(label.detach().cpu().numpy())
            tracking_info.append(track.detach().cpu().numpy())

    scores = np.concatenate(scores)
    labels = np.concatenate(labels)
    tracking_info = np.concatenate(tracking_info)
    end = time.time()

    print('Inference finished in {:.2f} seconds'.format(end - start))

    # Create the destination directory if it doesn't exist (a bare file name
    # has no directory part to create).
    out_dir = os.path.split(args.destination)[0]
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if args.write:
        # Silence ROOT messages below fatal level.
        ROOT.gErrorIgnoreLevel = ROOT.kFatal

        # Open the original ROOT file
        infile = ROOT.TFile.Open(args.target)
        tree = infile.Get(dset_config['args']['tree_name'])

        # Create a new ROOT file to write the modified tree
        outfile = ROOT.TFile.Open(args.destination, 'RECREATE')

        # Clone only the tree structure (0 entries); entries are copied one
        # by one in the loop below together with their score.
        outtree = tree.CloneTree(0)

        # Determine if scores is a list of single values or vectors
        from ROOT import std
        if isinstance(scores[0], (list, tuple, np.ndarray)):  # Check if scores contains vectors
            # Create a new branch for scores as a vector of floats
            scores_branch_vec = std.vector('float')()
            outtree.Branch(args.branch_name, scores_branch_vec)
            is_vector = True
        else:  # Scores contains single values
            # Create a new branch for scores as a single float
            score_branch_arr = array('f', [0])
            outtree.Branch(args.branch_name, score_branch_arr, f'{args.branch_name}/F')
            is_vector = False

        # Write scores to the new branch
        print(f'Writing {len(scores)} scores to tree')

        for i in range(tree.GetEntries()):
            tree.GetEntry(i)

            if is_vector:
                # Refill the vector with this entry's scores
                scores_branch_vec.clear()
                for value in scores[i]:
                    scores_branch_vec.push_back(float(value))
            else:
                # Fill the score branch with the current single score
                score_branch_arr[0] = float(scores[i])

            # Fill the output tree with all branches, including the new scores branch
            outtree.Fill()

        # Write the modified tree to the new file
        print(f'Writing to file {args.destination}')
        print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
        outtree.Write()
        outfile.Close()
        infile.Close()
    else:
        np.savez(args.destination, scores=scores, labels=labels, tracking_info=tracking_info)
285
+ if __name__ == '__main__':
286
+ main()
287
+
288
+
289
+
scripts/prep_data.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ file_path = os.getcwd()
4
+ sys.path.append(file_path)
5
+
6
+ import root_gnn_base.utils as utils
7
+ import argparse
8
+ from root_gnn_base.batched_dataset import PreBatchedDataset
9
+ from root_gnn_base.batched_dataset import LazyPreBatchedDataset
10
+
11
def main():
    """Pre-batch the train and test folds of one dataset named in the config.

    Command-line flags:
      --config        path handed to ``utils.load_config`` (required)
      --dataset       key looked up under ``config['Datasets']`` (required)
      --chunk         chunk index; forwarded as ``chunkno`` and, unless
                      --shuffle_mode is given, as ``process_chunks``
      --shuffle_mode  build the dataset without a ``process_chunks`` override
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, required=True)
    cli.add_argument('--dataset', type=str, required=True)
    cli.add_argument('--chunk', type=int, default=0)
    cli.add_argument('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
    args = cli.parse_args()

    config = utils.load_config(args.config)
    dset_config = config['Datasets'][args.dataset]
    # Global training batch size is the default; a per-dataset value wins below.
    batch_size = config['Training']['batch_size']

    if args.shuffle_mode:
        dset = utils.buildFromConfig(dset_config)
    else:
        dset = utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk,]})

    if 'batch_size' in dset_config:
        batch_size = dset_config['batch_size']

    shuffle_chunks = dset_config.get('shuffle_chunks', 10)
    padding_mode = dset_config.get('padding_mode', 'STEPS')
    fold_conf = dset_config["folding"]
    print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")

    use_lazy = dset_config["class"] == "LazyMultiLabelDataset"
    # The datasets are constructed for their side effects only; the instances
    # themselves are not kept.
    for fold in ("train", "test"):
        selector = utils.fold_selection(fold_conf, fold)
        tag = utils.fold_selection_name(fold_conf, fold)
        if use_lazy:
            LazyPreBatchedDataset(
                start_dataset = dset,
                batch_size = batch_size,
                mask_fn = selector,
                suffix = tag,
                chunks = shuffle_chunks,
                chunkno = args.chunk,
                padding_mode = padding_mode,
            )
        else:
            PreBatchedDataset(
                dset,
                batch_size,
                selector,
                suffix = tag,
                chunks = shuffle_chunks,
                chunkno = args.chunk,
                padding_mode = padding_mode,
            )
41
+
42
+ if __name__ == "__main__":
43
+ main()
scripts/training_script.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ import datetime
4
+ import yaml
5
+ import os
6
+
7
+ start_time = time.time()
8
+
9
+ import dgl
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ import sys
14
+ file_path = os.getcwd()
15
+ sys.path.append(file_path)
16
+
17
+ import root_gnn_base.batched_dataset as datasets
18
+ from root_gnn_base import utils
19
+ import root_gnn_base.custom_scheduler as lr_utils
20
+ from models import GCN
21
+
22
+ import numpy as np
23
+ from sklearn.metrics import roc_auc_score
24
+ import resource
25
+ import gc
26
+
27
+ import torch.distributed as dist
28
+ import torch.multiprocessing as mp
29
+ from torch.utils.data.distributed import DistributedSampler
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+
32
+ print("import time: {:.4f} s".format(time.time() - start_time))
33
+
34
def mem():
    """Print this process's ``ru_maxrss`` (from ``resource.getrusage``) divided by 1024 twice."""
    usage = resource.getrusage(resource.RUSAGE_SELF)
    scaled = usage.ru_maxrss / 1024 / 1024
    print(f'Current memory usage: {scaled} GB')
36
+
37
def gpu_mem():
    """Print CUDA allocator statistics, then the host memory line from ``mem()``.

    Uses ``torch.cuda.memory_reserved`` / ``max_memory_reserved`` in place of the
    deprecated ``memory_cached`` / ``max_memory_cached`` aliases. All values are
    byte counts divided by 1024**3.
    """
    bytes_per_gb = 1024 * 1024 * 1024
    print()
    print('GPU Memory Usage:')
    # A per-tensor census over gc.get_objects() used to accumulate into this
    # counter; that loop is disabled, so the printed value is always 0.
    numel_total = 0
    print(f'Current GPU memory usage: {torch.cuda.memory_allocated() / bytes_per_gb} GB')
    print(f'Current GPU cache usage: {torch.cuda.memory_reserved() / bytes_per_gb} GB')
    print(f'Current GPU max memory usage: {torch.cuda.max_memory_allocated() / bytes_per_gb} GB')
    print(f'Current GPU max cache usage: {torch.cuda.max_memory_reserved() / bytes_per_gb} GB')
    print(f'Numel in current tensors: {numel_total}')
    mem()
54
+
55
+
56
+ ## epoch stores the epoch number I want to evaluate the model at
57
## epoch stores the epoch number I want to evaluate the model at
def evaluate(val_loaders, model, config, device, epoch = -1):
    """Run the model over ``val_loaders`` and dump its outputs to an ``.npz`` file.

    Parameters
    ----------
    val_loaders : sequence of iterables yielding ``(batch, label, track, global_feats)``.
    model : module exposing ``representation(batch, global_feats)``; it may be
        wrapped by ``torch.compile`` and/or ``DistributedDataParallel``.
    config : dict; reads ``config['Training']['batch_size']`` and
        ``config['Training_Directory']``.
    device : device the inputs are moved to.
    epoch : checkpoint epoch to load; ``-1`` loads the last available one.

    Returns
    -------
    ``"No validation data"`` when ``val_loaders`` is empty, otherwise ``None``.
    On success writes ``evaluation_{epoch}.npz`` into the training directory
    with keys ``logits``, ``scores``, ``labels``, ``before_decoder``,
    ``after_decoder`` and ``tracking``.
    """
    print("Evaluating")

    if epoch != -1:
        print(f"Evalulating at epoch {epoch}")
        last_ep, checkpoint = utils.get_specific_epoch(config, epoch, from_ryan=False)
        print(f"Evaluating at epoch = {last_ep}")
    else:
        last_ep, checkpoint = utils.get_last_epoch(config)

    # Checkpoints are stored with un-prefixed parameter names, while a
    # torch.compile wrapper prefixes keys with '_orig_mod.' and DDP with
    # 'module.'; DDP also does not forward `.representation`. Peel the wrappers
    # off so loading and inference use the underlying module directly.
    eval_model = model
    while True:
        if isinstance(eval_model, nn.parallel.DistributedDataParallel):
            eval_model = eval_model.module
        elif hasattr(eval_model, '_orig_mod'):
            eval_model = eval_model._orig_mod
        else:
            break

    if checkpoint is not None:
        # Drop any 'module.' left in the saved keys by a DDP-wrapped model.
        cleaned_state = {
            key.replace('module.', ''): value
            for key, value in checkpoint['model_state_dict'].items()
        }
        eval_model.load_state_dict(cleaned_state)
        print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
    else:
        print("Warning: no checkpoint found; evaluating the model with its current weights")

    if len(val_loaders) == 0:
        return "No validation data"

    scores = []
    labels = []
    before_decoder = []
    after_decoder = []
    tracking = []

    batch_size = config["Training"]["batch_size"]
    # Cap each loader at roughly 1e5 samples.
    batch_limit = int(np.ceil(1e5 / batch_size))

    eval_model.eval()
    with torch.no_grad():
        for loader in val_loaders:
            for batch_count, (batch, label, track, global_feats) in enumerate(loader, start=1):
                before_global_decoder, after_global_decoder, after_classify = eval_model.representation(
                    batch.to(device), global_feats.to(device)
                )

                scores.append(after_classify.to("cpu"))
                before_decoder.append(before_global_decoder.to("cpu"))
                after_decoder.append(after_global_decoder.to("cpu"))
                labels.append(label.to("cpu"))
                tracking.append(track.to("cpu"))

                if batch_count >= batch_limit:
                    break

    if not scores:  # If validation set is empty.
        return

    logits = torch.cat(scores)
    probabilities = torch.sigmoid(logits)

    # Save the NumPy arrays to a .npz file
    outfile = f"{config['Training_Directory']}/evaluation_{epoch}.npz"
    np.savez(
        outfile,
        logits=logits.numpy(),
        scores=probabilities.numpy(),
        labels=torch.cat(labels).numpy(),
        before_decoder=torch.cat(before_decoder).numpy(),
        after_decoder=torch.cat(after_decoder).numpy(),
        tracking=torch.cat(tracking).numpy(),
    )

    print(f"saved scores to {outfile}")
    return
141
+
142
+
143
def train(train_loaders, test_loaders, model, device, config, args, rank):
    """Train ``model``, evaluating on ``test_loaders`` and checkpointing every epoch.

    Parameters
    ----------
    train_loaders, test_loaders : lists of loaders yielding
        ``(graph, label, tracking, global_feats)`` tuples.
    model : the network; accessed through ``model._orig_mod`` unless
        ``args.nocompile`` is set.
    device : device that batches are moved to.
    config : dict; reads ``Loss`` (optional), ``Training`` (``learning_rate``,
        ``epochs``, optional ``gamma``, ``early_termination``,
        ``custom_scheduler``), ``Training_Directory`` and ``Training_Name``.
    args : parsed CLI namespace; reads ``nocompile``, ``restart``, ``savecache``,
        ``multigpu``, ``multinode`` and (for multinode) ``global_rank``.
    rank : local rank, used to pick the gathering process in multi-GPU mode.

    Side effects: appends to (or, with ``args.restart``, truncates)
    ``training.log`` and writes ``model_epoch_{n}.pt`` / ``model_epoch_{n}.npz``
    into the training directory.
    """
    nocompile = args.nocompile
    restart = args.restart

    # define train/val samples, loss function and optimizer
    if 'Loss' not in config:
        loss_fcn = nn.BCEWithLogitsLoss()
        finish_fn = torch.nn.Sigmoid()
    else:
        loss_fcn = utils.buildFromConfig(config['Loss'])
        finish_fn = utils.buildFromConfig(config['Loss']['finish'])
    loss_name = loss_fcn.__class__.__name__

    optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
    gamma = config['Training']['gamma'] if 'gamma' in config['Training'] else 1

    early_termination = utils.EarlyStop()
    if 'early_termination' in config['Training']:
        early_termination.patience = config['Training']['early_termination']['patience']
        early_termination.threshold = config['Training']['early_termination']['threshold']
        early_termination.mode = config['Training']['early_termination']['mode']

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)
    custom_scheduler = None
    if 'custom_scheduler' in config['Training']:
        run_time_args = {}
        scheduler_class = config['Training']['custom_scheduler']['class']
        if scheduler_class in ('Dynamic_LR', 'Dynamic_LR_AND_Partial_Reset', 'Dynamic_LR_AND_Full_Reset'):
            run_time_args = {'optimizer': optimizer}
        custom_scheduler = utils.buildFromConfig(config['Training']['custom_scheduler'], run_time_args=run_time_args)

    starting_epoch = 0
    if not restart:
        last_ep, checkpoint = utils.get_last_epoch(config)
        if checkpoint is not None:
            if nocompile:
                # Normalise key names, then re-add the DDP prefix if the model is wrapped.
                stripped = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
                checkpoint['model_state_dict'] = stripped
                if args.multinode or args.multigpu:
                    checkpoint['model_state_dict'] = {'module.' + k: v for k, v in stripped.items()}
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # NOTE(review): `_orig_mod` exists on a torch.compile wrapper; if the
                # model is additionally wrapped in DDP this lookup likely fails — confirm.
                model._orig_mod.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            starting_epoch = checkpoint['epoch'] + 1
            if 'early_stop' in checkpoint:
                early_termination = utils.EarlyStop.load_from_dict(checkpoint['early_stop'])
                print(early_termination.to_str())
                print("EarlyStop state restored successfully.")
                if early_termination.should_stop:
                    # `epoch` is not bound yet at this point; report the checkpoint's epoch.
                    print(f"Early Termination at Epoch {checkpoint['epoch']}")
                    return
            else:
                # Keep the instance built above so the settings read from
                # config['Training']['early_termination'] are not thrown away.
                print("'early_stop' not found in checkpoint. Initializing a new EarlyStop instance.")
            print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
        log = open(config['Training_Directory'] + '/training.log', 'a', buffering=1)
    else:
        log = open(config['Training_Directory'] + '/training.log', 'w', buffering=1)

    train_cyclers = [utils.cycler((loader)) for loader in train_loaders]

    if args.savecache:
        # Run one forward pass on the largest batch (by node count) of every dataset.
        max_batch = [None,] * len(train_loaders)
        for dset_i, loader in enumerate(train_loaders):
            mbs = 0
            for batch in loader:
                if batch[0].num_nodes() > mbs:
                    mbs = batch[0].num_nodes()
                    max_batch[dset_i] = batch[0]
            print(f'Max batch size for dataset {dset_i}: {mbs}')
        big_batch = dgl.batch(max_batch).to(device)
        with torch.no_grad():
            # NOTE(review): every other call passes global features as a second
            # argument — confirm the model accepts a graph alone.
            model(big_batch)

    # [wait for next batch, host->device transfer, forward/backward, eval, save]
    cumulative_times = [0,0,0,0,0]
    log.write(f'Training {config["Training_Name"]} {datetime.datetime.now()} \n')
    print(f"Starting training for {config['Training']['epochs']} epochs")

    # Padding is considered active for any padding_mode other than 'NONE' and 'NODE'.
    first_dataset = train_loaders[0].dataset
    if hasattr(first_dataset, 'padding_mode'):
        is_padded = first_dataset.padding_mode not in ('NONE', 'NODE')
    else:
        is_padded = False

    lr_utils.print_LR(optimizer)

    distributed = args.multigpu or args.multinode

    # training loop
    for epoch in range(starting_epoch, config['Training']['epochs']):
        start = time.time()
        run = start
        if distributed:
            dist.barrier()

        # training
        model.train()
        ibatch = 0
        total_loss = 0
        # The first loader only sets the number of steps per epoch; the batches
        # actually trained on are drawn from the cyclers below.
        for _ in train_loaders[0]:
            batch_start = time.time()
            logits = torch.tensor([])
            tlabels = torch.tensor([])
            batch_lengths = []
            for cycler in train_cyclers:
                graph, label, _, global_feats = next(cycler)
                graph = graph.to(device)
                label = label.to(device)
                global_feats = global_feats.to(device)
                if is_padded: #Padding the globals to match padded graphs.
                    global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
                load = time.time()
                if len(logits) == 0:
                    logits = model(graph, global_feats)
                    tlabels = label
                else:
                    logits = torch.concatenate((logits, model(graph, global_feats)), dim=0)
                    tlabels = torch.concatenate((tlabels, label), dim=0)
                # Index of the last output row of this dataset's batch (the padding row when padded).
                batch_lengths.append(logits.shape[0] - 1)

            if is_padded:
                keepmask = torch.full_like(logits[:,0], True, dtype=torch.bool)
                keepmask[batch_lengths] = False
                logits = logits[keepmask]
            tlabels = tlabels.to(torch.float)
            if logits.shape[1] == 1 and loss_name == 'BCEWithLogitsLoss':
                logits = logits[:,0]
                tlabels = tlabels.to(torch.float)
            if loss_name == 'CrossEntropyLoss':
                tlabels = tlabels.to(torch.long)
            loss = loss_fcn(logits, tlabels.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().cpu().item()
            ibatch += 1
            cumulative_times[0] += batch_start - run
            cumulative_times[1] += load - batch_start
            run = time.time()
            cumulative_times[2] += run - load
            if ibatch % 1000 == 0:
                print(f'Batch {ibatch} out of {len(train_loaders[0])}', end='\r')

        if args.multigpu:
            print(f'Rank {rank} Epoch Done.')
        elif args.multinode:
            print(f'Rank {args.global_rank} Epoch Done.')
        else:
            print("Epoch Done.")

        # validation
        scores = []
        labels = []
        weights = []
        model.eval()
        with torch.no_grad():
            for loader in test_loaders:
                for batch, label, track, global_feats in loader:
                    #Don't use compiled model for testing since we can't control the batch size.
                    #We could before, but it assumes each dataset has the same number of batches...
                    if is_padded:
                        global_feats = torch.cat([global_feats, torch.zeros(1, len(global_feats[0]))])
                    if nocompile:
                        batch_scores = model(batch.to(device), global_feats.to(device))
                    else:
                        batch_scores = model._orig_mod(batch.to(device), global_feats.to(device))
                    if is_padded:
                        scores.append(batch_scores[:-1,:])
                    else:
                        scores.append(batch_scores)
                    labels.append(label)
                    weights.append(track[:,1])
        eval_end = time.time()
        cumulative_times[3] += eval_end - run

        if scores == []: #If validation set is empty.
            continue
        logits = torch.concatenate(scores).to(device)
        labels = torch.concatenate(labels).to(device)
        weights = torch.concatenate(weights).to(device)

        if distributed:
            world = dist.get_world_size()
            gathered_logits = [torch.zeros_like(logits) for _ in range(world)]
            gathered_labels = [torch.zeros_like(labels) for _ in range(world)]
            gathered_weights = [torch.zeros_like(weights) for _ in range(world)]

            dist.barrier()
            if (args.multigpu and rank != 0) or (args.multinode and args.global_rank != 0):
                dist.gather(logits, dst=0)
                dist.gather(labels, dst=0)
                dist.gather(weights, dst=0)
                # NOTE(review): non-zero ranks skip everything below, including
                # scheduler.step() and the early-stop check — confirm this is intended.
                continue
            else:
                dist.gather(logits, gather_list=gathered_logits)
                dist.gather(labels, gather_list=gathered_labels)
                dist.gather(weights, gather_list=gathered_weights)

            logits = torch.concatenate(gathered_logits)
            labels = torch.concatenate(gathered_labels)
            weights = torch.concatenate(gathered_weights)

        wgt_mask = weights > 0

        print(f"Num batches trained = {ibatch}")

        #Note: This section is a bit ugly. Very conditional. Should maybe config defined behavior?
        if loss_name == "ContrastiveClusterLoss":
            scores = logits
            test_auc = 0
            acc = 0
            contrastive_cluster_loss = finish_fn(logits)

        elif loss_name == "MultiLabelLoss":
            scores = finish_fn(logits)
            preds = torch.round(scores)
            # One exact-match accuracy per label column.
            multilabel_accuracy = []
            for i in range(len(labels[0])):
                multilabel_accuracy.append(torch.sum(preds[:, i].to("cpu") == labels[:, i].to("cpu")) / len(labels))
            test_auc = 0
            acc = np.mean(multilabel_accuracy)

        elif logits.shape[1] == 1 and loss_name == 'BCEWithLogitsLoss': #Proxy for binary classification.
            logits = logits[:,0]
            scores = finish_fn(logits)
            labels = labels.to(torch.float)
            preds = scores > 0.5
            test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), sample_weight=weights[wgt_mask].to("cpu"))
            acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels)

        elif logits.shape[1] == 1 and loss_name == 'MSELoss':
            logits = logits[:,0]
            scores = finish_fn(logits)
            labels = labels.to(torch.float)
            acc = 0
            test_auc = 0

        else:
            preds = torch.argmax(logits, dim=1)
            scores = finish_fn(logits)
            if labels.dim() == 1: #Multi-class
                acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels) #TODO: Make each class weighted equally?

                labels = labels.to("cpu")
                weights = weights.to("cpu")
                logits = logits.to("cpu")
                wgt_mask = wgt_mask.to("cpu")

                labels_onehot = np.zeros((len(labels), len(scores[0])))
                labels_onehot[np.arange(len(labels)), labels] = 1

                try:
                    test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
                except ValueError:
                    test_auc = np.nan
            else: #Multi-loss
                acc = torch.sum(preds.to("cpu") == labels[:,0].to("cpu")) / len(labels)
                try:
                    test_auc = roc_auc_score(labels[:,0][wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
                except ValueError:
                    test_auc = np.nan

        if loss_name == "MultiLabelLoss":
            multilabel_log_str = "MultiLabel_Accuracy "
            for accuracy in multilabel_accuracy:
                multilabel_log_str += f" | {accuracy:.4f}"
            log.write(multilabel_log_str + '\n')
            print(multilabel_log_str, flush=True)
        elif loss_name == "ContrastiveClusterLoss":
            contrastive_cluster_log_str = "ContrastiveClusterLoss "
            contrastive_cluster_log_str += f"Contrastive Loss: {contrastive_cluster_loss[0]:.4f}, Clustering Loss: {contrastive_cluster_loss[1]:.4f}, Variance Loss: {contrastive_cluster_loss[2]:.4f}"
            log.write(contrastive_cluster_log_str + '\n')
            print(contrastive_cluster_log_str, flush=True)

        test_loss = loss_fcn(logits, labels)
        end = time.time()
        # max(..., 1) avoids a ZeroDivisionError when the first loader yielded no batches.
        log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
            epoch, optimizer.param_groups[0]['lr'], total_loss/max(ibatch, 1), acc, test_loss, test_auc, end - start
        )
        log.write(log_str + '\n')
        print(log_str, flush=True)

        if nocompile:
            state_dict = model.state_dict()
        else:
            state_dict = model._orig_mod.state_dict()
        # Store parameter names without the DDP 'module.' prefix.
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        torch.save({
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'early_stop': early_termination.to_dict()
        }, os.path.join(config['Training_Directory'], f"model_epoch_{epoch}.pt"))
        np.savez(os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.npz'), scores=scores.to("cpu"), labels=labels.to("cpu"))
        save_end = time.time()
        cumulative_times[4] += save_end - eval_end

        early_termination.update(test_loss)
        if early_termination.should_stop:
            log_str = f"Early Termination at Epoch {epoch}"
            log.write(log_str + "\n")
            print(log_str)
            log_str = early_termination.to_str()
            log.write(log_str + "\n")
            print(log_str)
            break

        if custom_scheduler:
            custom_scheduler.step(model, {'test_auc':test_auc})
        scheduler.step()

    print(f"Load: {cumulative_times[0]:.4f} s")
    print(f"Batch: {cumulative_times[1]:.4f} s")
    print(f"Train: {cumulative_times[2]:.4f} s")
    print(f"Eval: {cumulative_times[3]:.4f} s")
    print(f"Save: {cumulative_times[4]:.4f} s")
    log.close()
532
+
533
def find_free_port():
    """Bind a throwaway TCP socket to port 0 and return the port the OS picked, as a string."""
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _, chosen_port = probe.getsockname()
        return str(chosen_port)
    finally:
        # The socket is always released, so the port is free again on return.
        probe.close()
541
+
542
def init_process_group(world_size, rank, port):
    """Initialise the NCCL process group for single-node training.

    Sets ``MASTER_ADDR`` to ``localhost`` and ``MASTER_PORT`` to ``port`` in the
    environment, then calls ``dist.init_process_group`` with the ``env://``
    init method and a 300 second timeout.

    Parameters
    ----------
    world_size : total number of processes in the group.
    rank : rank of the calling process.
    port : rendezvous port, as an int or a string.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ only accepts strings; callers may hand over an int port.
    os.environ['MASTER_PORT'] = str(port)

    dist.init_process_group(
        backend="nccl",
        init_method='env://',
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=300),
    )
554
+
555
def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
    """Build datasets and model from the config, then train (or evaluate).

    Parameters
    ----------
    rank : local process rank; also selects the CUDA device in distributed modes.
    args : parsed CLI namespace (``config``, ``directory``, ``plot``, ``multigpu``,
        ``multinode``, ``cpu``, ``lazy``, ``statistics``, ``preshuffle``,
        ``nocompile``, ``evaluate``, ...).
    world_size : number of participating processes.
    port : rendezvous port forwarded to ``init_process_group`` in multi-GPU mode.
    seed : forwarded to the model builder; also the NumPy seed when ``args`` has
        no ``seed`` attribute.

    All options are read from ``args`` rather than from a module-level namespace,
    so the function does not depend on globals created under
    ``if __name__ == "__main__"``.
    """
    #Load config file
    config = utils.load_config(args.config)

    if args.directory:
        print(f"New training directory: { config['Training_Directory'] + args.directory}")
        config['Training_Directory'] = config['Training_Directory'] + args.directory

    if not os.path.exists(config['Training_Directory']):
        os.makedirs(config['Training_Directory'], exist_ok=True)
    with open(config['Training_Directory'] + '/config.yaml', 'w') as f:
        yaml.dump(config, f)
    batch_size = config["Training"]["batch_size"]

    if args.plot:
        rl = utils.read_log(config)
        utils.plot_log(rl, config['Training_Directory'] + '/training.png')
        print('Log at ' + config['Training_Directory'] + '/training.log')
        print('Plotted at ' + config['Training_Directory'] + '/training.png')
        sys.exit()

    if args.multigpu:
        print(f"Setting up multigpu")
        start_time = time.time()
        # MASTER_PORT is written to os.environ, which requires a string.
        init_process_group(world_size, rank, str(port))
        print("multigpu setup time: {:.4f} s".format(time.time() - start_time))
        device = torch.device(f'cuda:{rank}')
        # set_device actually selects the GPU; torch.cuda.device(...) only builds
        # a context manager and has no effect unless entered.
        torch.cuda.set_device(device)
    elif args.multinode:
        device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(device)
        print(f"global rank = {args.global_rank}, local rank = {rank}, device = {device}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.cpu:
        print(f"Using CPU")
        device = "cpu"

    train_loaders = []
    test_loaders = []
    val_loaders = []
    load_start = time.time()

    torch.backends.cuda.matmul.allow_tf32 = True

    ldr_type = datasets.LazyPreBatchedDataset if args.lazy else datasets.PreBatchedDataset

    #Load datasets
    if args.statistics:
        args.statistics = int(args.statistics)
        print(f"Training Dataset Size: {args.statistics}")
        num_batches = int(np.ceil(args.statistics / batch_size))
        np.random.seed(getattr(args, 'seed', seed))

    for dset_conf in config["Datasets"]:
        dset_settings = config["Datasets"][dset_conf]
        dset = utils.buildFromConfig(dset_settings)
        # A per-dataset batch size, once seen, also applies to the datasets that follow.
        if 'batch_size' in dset_settings:
            batch_size = dset_settings['batch_size']
        fold_conf = dset_settings["folding"]
        shuffle_chunks = dset_settings.get("shuffle_chunks", 10)
        padding_mode = dset_settings.get("padding_mode", "STEPS")

        if args.preshuffle:
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode)
            gsamp, _, _, global_samp = ldr[0]
            sampler = None

            if args.statistics:
                # Random subset (with replacement) of pre-built batches.
                sampler = np.random.choice(range(len(ldr)), size=num_batches)

            if args.multigpu:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if args.multinode:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)
            train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))

            sampler = None
            # The test fold gets its own selection mask rather than the train one.
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode)
            if args.multigpu:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if args.multinode:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)

            test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))

            if "validation" in fold_conf:
                # NOTE(review): `sampler` here was built for the test dataset above —
                # confirm it is meant to index the validation dataset too.
                val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
            else:
                print("No validation set for dataset ", dset_conf)
        else:
            train_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "train")))
            gsamp, _, _, global_samp = dset[0]
            test_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "test")))
            if "validation" in fold_conf:
                val_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "validation")))
            else:
                print("No validation set for dataset ", dset_conf)

    load_end = time.time()
    print("Load time: {:.4f} s".format(load_end - load_start))
    # The sample graph/global features come from the last dataset in the config.
    model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
    if not args.nocompile:
        model = torch.compile(model)
    if args.multigpu or args.multinode:
        # Wrap exactly once, even if both distributed flags are set.
        print(f"Trying to create DDP model")
        start_time = time.time()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
        print("model creation time: {:.4f} s".format(time.time() - start_time))

    if type(model) == GCN.Clustering:
        print("clustering")

    if args.evaluate is not None:
        evaluate(test_loaders, model, config, device, args.evaluate)
        sys.exit()

    # model training
    print("Training...")
    gpu_mem()
    train(train_loaders, test_loaders, model, device, config, args, rank)
701
+
702
+ # test the model
703
+ # print("Testing...")
704
+ # evaluate(val_loaders, model, config, device)
705
+
706
+ # if args.multigpu or args.multinode:
707
+ # dist.destroy_process_group()
708
+
709
+ # if rank == 0:
710
+ # rl = utils.read_log(config)
711
+ # utils.plot_log(rl, config['Training_Directory'] + '/training.png')
712
+ # print('Log at ' + config['Training_Directory'] + '/training.log')
713
+ # print('Plotted at ' + config['Training_Directory'] + '/training.png')
714
+
715
if __name__ == "__main__":
    # Command-line entry point: parse the options, then launch `main` in
    # exactly one of three modes (multi-GPU spawn, multi-node, single process).
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg("--config", type=str, help="Config file.", required=True)
    add_arg("--restart", action="store_true", help="Restart training from scratch.")
    add_arg("--preshuffle", action="store_true", help="Shuffle data before training.")
    add_arg("--lazy", action="store_true", help="Lazy loading of data.")
    add_arg("--nocompile", action="store_true", help="Disable JIT compilation.")
    add_arg("--evaluate", type = int, help="Skip training and go to evaluation.")
    add_arg("--plot", action="store_true", help="Plot training logs.")
    add_arg("--multigpu", action="store_true", help="Use multiple GPUs.")
    add_arg("--multinode", action="store_true", help="Use multiple nodes.")
    add_arg("--savecache", action="store_true", help="")
    add_arg("--cpu", action="store_true", help="Uses the cpu only")
    add_arg("--statistics", type=float, help="Size of training data")
    add_arg("--directory", type=str, help="Append to Training Directory")
    add_arg("--seed", type=int, default=2, help="Sets random seed")

    pargs = parser.parse_args()

    # The three launch modes are mutually exclusive. The original code used two
    # independent `if` statements, so the trailing `else` belonged only to the
    # `--multinode` check and a `--multigpu` run fell through to a second,
    # single-process `main(0, pargs)` after the spawned workers had finished.
    # An if/elif/else chain guarantees exactly one launch. If both flags are
    # given, `--multigpu` takes precedence.
    if pargs.multigpu:
        # Fixed number of worker processes; the same value is handed to `main`
        # as its world size so the two can never disagree.
        n_procs = 4
        port = find_free_port()
        torch.backends.cudnn.enabled = False
        # mp.spawn calls main(i, pargs, n_procs, port) for i in range(n_procs)
        # and, with join=True, blocks until every worker has exited.
        mp.spawn(main, args=(pargs, n_procs, port), nprocs=n_procs, join=True)
    elif pargs.multinode:
        # Rank information is read from the environment; a missing variable
        # raises KeyError. NOTE(review): presumably these are set by the
        # distributed launcher (e.g. torchrun) -- confirm against the job script.
        global_rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"global_rank = {global_rank}, local_rank = {local_rank}, world_size = {world_size}")

        dist.init_process_group(backend="nccl")
        torch.backends.cudnn.enabled = False

        # Attach the global rank to the parsed arguments so `main` can read it.
        pargs.global_rank = global_rank

        main(rank = local_rank, args=pargs, world_size=world_size)
    else:
        # Single-process run: rank 0, remaining `main` parameters at defaults.
        main(0, pargs)
754
+
755
+