ho22joshua commited on
Commit
5ca4f82
·
1 Parent(s): 4d16332

added root_gnn_dgl directory

Browse files
models/GCN.py DELETED
@@ -1,1944 +0,0 @@
1
- import dgl
2
- import dgl.nn as dglnn
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- import sys
9
- import os
10
- file_path = os.getcwd()
11
- sys.path.append(file_path)
12
-
13
- import root_gnn_base.dataset as datasets
14
- from root_gnn_base import utils
15
-
16
- import gc
17
-
18
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
    """Build one "single-layer perceptron" block as a plain list of modules.

    The block is Linear(in_size -> out_size), then an instance of
    *activation*, then Dropout(dropout).  Returned as a list so callers
    can concatenate several blocks before wrapping them in nn.Sequential.
    """
    return [
        nn.Linear(in_size, out_size),
        activation(),
        nn.Dropout(dropout),
    ]
24
-
25
def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
    """Build an MLP of *n_layers* Make_SLP blocks, finished with LayerNorm.

    With n_layers <= 1 a single block maps in_size directly to out_size;
    otherwise the widths go in_size -> hid_size -> ... -> out_size.
    Returns an nn.Sequential.
    """
    if n_layers <= 1:
        blocks = Make_SLP(in_size, out_size, activation, dropout)
    else:
        # n_layers + 1 widths give exactly n_layers (in, out) pairs
        widths = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
        blocks = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            blocks += Make_SLP(w_in, w_out, activation, dropout)
    blocks.append(torch.nn.LayerNorm(out_size))
    return nn.Sequential(*blocks)
36
-
37
class MLP(nn.Module):
    """Plain MLP: a Make_MLP hidden stack followed by a final linear head."""

    def __init__(self, in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0, **kwargs):
        super().__init__()
        print(f'Unused args while creating MLP: {kwargs}')
        # Hidden stack keeps hid_size width; the head maps to out_size.
        self.layers = Make_MLP(in_size, hid_size, hid_size, n_layers-1, activation, dropout)
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, x):
        hidden = self.layers(x)
        return self.linear(hidden)
46
-
47
def broadcast_global_to_nodes(g, globals):
    """Repeat each graph's global feature row once per node of that graph.

    *globals* has one row per graph in the batched graph *g*; the result
    has one row per node, aligned with g.ndata ordering.
    """
    return torch.repeat_interleave(globals, g.batch_num_nodes(), dim=0)
50
-
51
def broadcast_global_to_edges(g, globals):
    """Repeat each graph's global feature row once per edge of that graph.

    *globals* has one row per graph in the batched graph *g*; the result
    has one row per edge, aligned with g.edata ordering.
    """
    return torch.repeat_interleave(globals, g.batch_num_edges(), dim=0)
54
-
55
def copy_v(edges):
    """DGL edge UDF: expose each edge's destination-node feature 'h' as message 'm_v'."""
    message = edges.dst['h']
    return {'m_v': message}
57
-
58
def partial_reset(model : nn.Module):
    """Re-initialise only the final `classify` linear head of *model*, in place.

    The replacement layer keeps the old in/out sizes and device.  The RNG is
    seeded so the re-initialised weights are reproducible; the new weights
    are printed for inspection.
    """
    out_size, in_size = model.classify.weight.shape
    device = model.classify.weight.device
    torch.manual_seed(2)  # deterministic re-init
    fresh = nn.Linear(in_size, out_size)
    fresh.to(device)
    model.classify = fresh
    print(model.classify.weight)
66
-
67
def print_model(model: nn.Module):
    """Dump the model's repr (its module tree) to stdout."""
    print(model)
69
-
70
def print_mlp(layer):
    """Print every child of *layer*: the full state_dict for Linear children,
    the module repr for anything else."""
    for child in layer.children():
        payload = child.state_dict() if isinstance(child, nn.Linear) else child
        print(payload)
76
-
77
-
78
def full_reset(model : nn.Module):
    """Reset the whole GNN in place: every encoder/update/decoder MLP, then
    the classify head (via partial_reset).  Returns None; *model* is mutated."""
    stacks = (model.node_encoder, model.edge_encoder, model.global_encoder,
              model.node_update, model.edge_update, model.global_update,
              model.global_decoder)
    for stack in stacks:
        for child in stack.children():
            # LayerNorm / Linear expose reset_parameters; Dropout/ReLU do not
            if hasattr(child, 'reset_parameters'):
                child.reset_parameters()
    partial_reset(model)
89
-
90
-
91
class GCN(nn.Module):
    """Graph classifier: linear pre-stack, GraphConv stack, linear post-stack,
    mean-node readout, linear classify head.  ReLU after every layer."""

    def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        # Layer layout: 1 input Linear, n_layers hidden Linears,
        # n_layers GraphConvs, n_layers output-side Linears.
        pre = [nn.Linear(in_size, hid_size)]
        pre += [nn.Linear(hid_size, hid_size) for _ in range(n_layers)]
        convs = [dglnn.GraphConv(hid_size, hid_size) for _ in range(n_layers)]
        post = [nn.Linear(hid_size, hid_size) for _ in range(n_layers)]
        self.layers = nn.ModuleList()
        self.layers.extend(pre + convs + post)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = g.ndata['features']
        conv_lo = self.n_layers + 1       # index of the first GraphConv
        conv_hi = 2 * self.n_layers + 1   # one past the last GraphConv
        for idx, layer in enumerate(self.layers):
            # GraphConv layers take the graph as well; Linears take only h
            h = layer(g, h) if conv_lo <= idx < conv_hi else layer(h)
            h = F.relu(h)
        with g.local_scope():
            g.ndata['h'] = h
            # Graph representation = average of node embeddings.
            return self.classify(dgl.mean_nodes(g, 'h'))
121
-
122
class GCN_global(nn.Module):
    """GCN that carries an explicit per-graph global state.

    The global state is initialised from each graph's node count and is
    refreshed after every GraphConv round with the mean node embedding.
    """

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        # encoders
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        # message passing
        self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        node_h = self.node_encoder(g.ndata['features'])
        # global state starts from the node count of each graph
        counts = g.batch_num_nodes()[:, None].to(torch.float)
        global_h = self.global_encoder(counts)
        for _ in range(self.n_layers):
            node_h = self.conv(g, self.node_update(node_h))
            g.ndata['h'] = node_h
            pooled = dgl.mean_nodes(g, 'h')
            global_h = self.global_update(torch.cat((global_h, pooled), dim=1))
        return self.classify(self.global_decoder(global_h))
151
-
152
class GCN_global_2way(nn.Module):
    """Two-way variant of GCN_global: the global state also feeds back into
    the node update (broadcast to every node) each round."""

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        # encoders
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        # message passing (node update sees node state + broadcast global)
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        node_h = self.node_encoder(g.ndata['features'])
        counts = g.batch_num_nodes()[:, None].to(torch.float)
        global_h = self.global_encoder(counts)
        for _ in range(self.n_layers):
            # each node also sees its graph's current global state
            node_in = torch.cat((node_h, broadcast_global_to_nodes(g, global_h)), dim=1)
            node_h = self.conv(g, self.node_update(node_in))
            g.ndata['h'] = node_h
            pooled = dgl.mean_nodes(g, 'h')
            global_h = self.global_update(torch.cat((global_h, pooled), dim=1))
        return self.classify(self.global_decoder(global_h))
181
-
182
class Edge_Network(nn.Module):
    """Full message-passing GNN with node, edge and global blocks.

    Each processing step runs edge update -> node update -> global update.
    If the graph carries node weights ('w' in g.ndata), node pooling is a
    weighted sum divided by the number of non-padded nodes per graph
    (rows of all-zero features count as padding).

    Fixes vs. the original:
    - `sum_weights` used `.repeat(1, 64)`, silently assuming hid_size == 64;
      a broadcastable column vector gives identical results for 64 and
      correct results for any other hidden size.
    - `forward` and `representation` duplicated the whole pipeline; `forward`
      now delegates to `representation`, and the shared stages live in
      private helpers.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Global features are "present" only if the sample is non-empty and
        # has at least one column.
        self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # encoders
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # GNN update blocks
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _non_padded_counts(self, g):
        """Per-graph count of nodes whose feature row is not all zeros."""
        mask = torch.any(g.ndata['features'] != 0, dim=1)
        counts = []
        start = 0
        for num_nodes in g.batch_num_nodes():
            end = start + num_nodes
            counts.append(mask[start:end].sum().item())
            start = end
        return torch.tensor(counts, device=g.ndata['features'].device)

    def _encode(self, g, global_feats):
        """Encode node/edge/global features in place; return (h_global, sum_weights)."""
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            # fall back to node counts as the sole global feature
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        sum_weights = None
        if "w" in g.ndata:
            counts = self._non_padded_counts(g)
            # column vector; broadcasts over the hidden dim in the division
            # (the old code hard-coded .repeat(1, 64))
            sum_weights = counts[:, None]
            global_feats = counts[:, None].to(torch.float)
        return self.global_encoder(global_feats), sum_weights

    def _propagate(self, g, h_global, sum_weights):
        """One edge -> node -> global message-passing round; return new h_global."""
        g.apply_edges(dgl.function.copy_u('h', 'm_u'))
        g.apply_edges(copy_v)
        g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
        g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
        g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
        if "w" in g.ndata:
            # weighted node sum normalised by the non-padded node count
            mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        else:
            mean_nodes = dgl.mean_nodes(g, 'h')
        return self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim=1))

    def forward(self, g, global_feats):
        """Return classification logits per graph in the batch."""
        return self.representation(g, global_feats)[2]

    def representation(self, g, global_feats):
        """Return (pre-decoder, post-decoder, post-classify) global embeddings."""
        h_global, sum_weights = self._encode(g, global_feats)
        for _ in range(self.n_proc_steps):
            h_global = self._propagate(g, h_global, sum_weights)
        before_global_decoder = h_global
        after_global_decoder = self.global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        """Print every sub-MLP's Linear weights plus the classify head; return ''."""
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]
        layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
                  self.node_update, self.edge_update, self.global_update, self.global_decoder]
        for name, stack in zip(layer_names, layers):
            print(name)
            for layer in stack.children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())
        print("classify")
        print(self.classify.weight)
        return ""
311
-
312
class Transferred_Learning(nn.Module):
    """Frozen pretrained GNN backbone followed by a fresh decoder + classify head.

    The backbone is rebuilt from a config (utils.buildFromConfig), loaded from
    a checkpoint, stripped of its old classify head, and frozen.  The TL_*
    accessors run inputs through the backbone's sub-modules by index.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # drop the old classify head; keep the rest as an indexable Sequential
        backbone = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*backbone)

        # freeze the whole backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _apply_pretrained(self, idx, x):
        """Run *x* through frozen backbone sub-module *idx*, layer by layer."""
        for sub in self.pretrained_model[idx]:
            x = sub(x)
        return x

    def TL_node_encoder(self, x):
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        return self._apply_pretrained(7, x)

    def forward(self, g, global_feats):
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            edge_in = torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1)
            g.edata['e'] = self.TL_edge_update(edge_in)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_in = torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1)
            g.ndata['h'] = self.TL_node_update(node_in)
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(self.global_decoder(h_global))
393
-
394
class Transferred_Learning_Graph(nn.Module):
    # Frozen pretrained GNN backbone whose message passing is continued by
    # freshly trained update blocks, then a new decoder/classify head.
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
        """Build the frozen backbone plus new trainable update/decoder modules.

        pretraining_path: checkpoint file holding 'model_state_dict'.
        pretraining_model: config dict consumed by utils.buildFromConfig.
            NOTE(review): hid_size presumably must match the backbone's hidden
            size for the new update blocks to accept its embeddings — not
            checked here; confirm against the config.
        additional_proc_steps: extra trainable message-passing rounds run
            after the backbone's n_proc_steps frozen rounds.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the backbone's classify head; keep the rest indexable by position.
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        self.additional_proc_steps = additional_proc_steps

        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers

        # GNN: new, trainable update blocks for the extra processing rounds
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    # The TL_* helpers below run x through the frozen backbone's sub-module
    # at a fixed positional index of the stripped nn.Sequential.
    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def forward(self, g, global_feats):
        """Frozen rounds first, then trainable rounds, then decode + classify."""
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            # fall back to node counts as the sole global feature
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        # frozen message-passing rounds (backbone weights)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        # trainable message-passing rounds (new update blocks)
        for j in range(self.additional_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))

        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
486
-
487
class Transferred_Learning_Parallel(nn.Module):
    """Frozen pretrained backbone run in parallel with a fresh GNN branch;
    both global embeddings are concatenated before the classify head.

    Fix vs. the original: Pretrained_Output referenced `global_feats` without
    it ever being bound when `self.has_global` was True (NameError).  It now
    takes the global features as an optional argument, passed by forward().
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # drop the backbone's classify head
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # freeze the whole backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # encoder (fresh branch)
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # GNN (fresh branch)
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder; the head sees both branches' embeddings concatenated
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

    def _apply_pretrained(self, idx, x):
        """Run *x* through frozen backbone sub-module *idx*, layer by layer."""
        for sub in self.pretrained_model[idx]:
            x = sub(x)
        return x

    def TL_node_encoder(self, x):
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        return self._apply_pretrained(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen backbone on *g* and return its decoded global embedding.

        global_feats: per-graph global features; ignored (recomputed from node
        counts) when the model has no global features.  Optional for backward
        compatibility — the original signature took only *g*.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        # frozen branch runs on a clone so its ndata/edata writes don't leak
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
594
-
595
class Transferred_Learning_Sequential(nn.Module):
    """Frozen pretrained backbone feeding a fresh MLP + classify head.

    Fix vs. the original: Pretrained_Output referenced `global_feats` without
    it ever being bound when `self.has_global` was True (NameError).  It now
    takes the global features as an optional argument, passed by forward().
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # drop the backbone's classify head
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # freeze the whole backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # fresh head on top of the backbone's decoded global embedding
        self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(hid_size, out_size)

    def _apply_pretrained(self, idx, x):
        """Run *x* through frozen backbone sub-module *idx*, layer by layer."""
        for sub in self.pretrained_model[idx]:
            x = sub(x)
        return x

    def TL_node_encoder(self, x):
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        return self._apply_pretrained(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen backbone on *g* and return its decoded global embedding.

        global_feats: per-graph global features; ignored (recomputed from node
        counts) when the model has no global features.  Optional for backward
        compatibility — the original signature took only *g*.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        # frozen backbone runs on a clone so its ndata/edata writes don't leak
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
678
-
679
-
680
class Transferred_Learning_Message_Passing(nn.Module):
    """Frozen backbone whose per-round global embeddings are concatenated
    (hid_size * n_proc_steps wide) and fed to a fresh MLP + classify head.

    Fix vs. the original: Pretrained_Output referenced `global_feats` without
    it ever being bound when `self.has_global` was True (NameError).  It now
    takes the global features as an optional argument, passed by forward().
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # drop the backbone's classify head
        pretrained_layers = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # freeze the whole backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # head input width = one backbone global embedding per processing round
        self.mlp = Make_MLP(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(hid_size, out_size)

    def _apply_pretrained(self, idx, x):
        """Run *x* through frozen backbone sub-module *idx*, layer by layer."""
        for sub in self.pretrained_model[idx]:
            x = sub(x)
        return x

    def TL_node_encoder(self, x):
        return self._apply_pretrained(1, x)

    def TL_edge_encoder(self, x):
        return self._apply_pretrained(2, x)

    def TL_global_encoder(self, x):
        return self._apply_pretrained(3, x)

    def TL_node_update(self, x):
        return self._apply_pretrained(4, x)

    def TL_edge_update(self, x):
        return self._apply_pretrained(5, x)

    def TL_global_update(self, x):
        return self._apply_pretrained(6, x)

    def TL_global_decoder(self, x):
        return self._apply_pretrained(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen backbone and return all rounds' global embeddings,
        concatenated along dim 1 (requires n_proc_steps >= 1).

        global_feats: per-graph global features; ignored (recomputed from node
        counts) when the model has no global features.  Optional for backward
        compatibility — the original signature took only *g*.
        """
        message_passing = None
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
            # accumulate each round's global state for the downstream head
            if (message_passing is None):
                message_passing = h_global.clone()
            else:
                message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
        h_global = self.TL_global_decoder(h_global)
        return message_passing

    def forward(self, g, global_feats):
        # frozen backbone runs on a clone so its ndata/edata writes don't leak
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
770
-
771
class Transferred_Learning_Message_Passing_Parallel(nn.Module):
    """Two-branch transfer-learning GNN.

    A frozen pretrained GNN contributes the concatenation of its per-step
    global snapshots; a freshly trained GNN branch contributes its decoded
    global state. Both are concatenated and fed to a linear classifier.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained model from config, load its weights, and
        # drop its classification head (the last child module).
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Fresh, trainable branch: encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Fresh branch: GNN update networks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Fresh branch: decoder.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Freeze the pretrained branch entirely.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # NOTE(review): assumes the pretrained model ran the same
        # n_proc_steps as this one — otherwise the snapshot width below
        # will not match; confirm against the pretraining config.
        self.classify = nn.Linear(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'] + hid_size, out_size)

    # --- Thin wrappers over the pretrained children (see __init__ order). ---
    def TL_node_encoder(self, x):
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        return self.pretrained_model[7](x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen branch; return the per-step global snapshots,
        concatenated along dim 1.

        Args:
            g: batched DGLGraph with 'features' on nodes and edges.
            global_feats: per-graph global features; required when
                ``self.has_global``. Defaults to None for backward
                compatibility with the old single-argument call.
        """
        message_passing = None
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # Fall back to per-graph node counts as the global feature.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            # BUG FIX: original read an undefined local (NameError) here
            # whenever self.has_global was True.
            raise ValueError("global_feats must be provided when the model has global features")
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
            # Snapshot the global state after each processing step.
            if message_passing is None:
                message_passing = h_global.clone()
            else:
                message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
        h_global = self.TL_global_decoder(h_global)
        return message_passing

    def forward(self, g, global_feats):
        """Run both branches and classify their concatenated global states."""
        # Clone so the frozen branch's message passing does not clobber the
        # node/edge state the trainable branch is about to write.
        pretrained_message = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_message, h_global), dim=1))
883
-
884
class Transferred_Learning_Finetuning(nn.Module):
    """Fine-tune a pretrained message-passing GNN.

    Reuses the pretrained encode/process/decode stack (optionally frozen
    except for the global decoder) and trains a fresh linear classifier on
    top of the decoded global state.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, dropout=0,
                 frozen_pretraining=False, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        # sample_global may be an empty container; treat that as "no globals".
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained model, load its weights, and drop its
        # classification head (the last child module).
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        print(f"Freeze Pretraining = {frozen_pretraining}")
        if (frozen_pretraining):
            for param in self.pretrained_model.parameters():
                param.requires_grad = False  # Freeze all layers ...
            # BUG FIX: the original iterated the decoder's sub-MODULES and
            # set `requires_grad` on them, which does not unfreeze any
            # weights; iterate .parameters() so the global decoder actually
            # stays trainable as intended.
            for param in self.pretrained_model[7].parameters():
                param.requires_grad = True  # ... but keep the global decoder trainable.

        # Fixed seed so the new head initializes reproducibly.
        torch.manual_seed(2)
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

    # --- Thin wrappers over the pretrained children (see __init__ order). ---
    def TL_node_encoder(self, x):
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        return self.pretrained_model[7](x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained stack on ``g`` and return the decoded global
        state.

        Args:
            g: batched DGLGraph with 'features' on nodes and edges.
            global_feats: per-graph global features; required when
                ``self.has_global``. Defaults to None for backward
                compatibility with the old single-argument call.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            # BUG FIX: original read an undefined local (NameError) here
            # whenever self.has_global was True.
            raise ValueError("global_feats must be provided when the model has global features")
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Classify the pretrained stack's decoded global state."""
        h_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(h_global)

    def representation(self, g, global_feats):
        """Return intermediate embeddings for analysis.

        Returns a tuple: (global state before the decoder, after the
        decoder, and the classifier logits).
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))

        before_global_decoder = h_global
        after_global_decoder = self.TL_global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        """Print the Linear weights of every pretrained sub-network plus the
        new head (debug helper); always returns the empty string."""
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]
        layers = [self.pretrained_model[1], self.pretrained_model[2], self.pretrained_model[3],
                  self.pretrained_model[4], self.pretrained_model[5], self.pretrained_model[6],
                  self.pretrained_model[7]]
        for i in range(len(layers)):
            print(layer_names[i])
            for layer in layers[i].children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())
        print("classify")
        print(self.classify.weight)
        return ""
1011
-
1012
-
1013
class Transferred_Learning_Parallel_Finetuning(nn.Module):
    """Two-branch transfer learning with per-branch learning rates.

    A pretrained GNN (fine-tuned, not frozen) runs in parallel with a fresh
    GNN branch; their decoded global states are concatenated and classified.
    ``parameters()`` is overridden to return optimizer param groups with a
    separate learning rate per branch.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, dropout=0,
                 learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')

        # Scalar, or dict with 'trainable_lr' / 'finetuning_lr' keys.
        self.learning_rate = learning_rate

        self.parallel_params = []
        self.finetuning_params = []

        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained model, load its weights, drop its head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # The pretrained branch is fine-tuned at the 'finetuning_lr' rate.
        self.finetuning_params.append(self.pretrained_model)

        # Fresh branch: encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Fresh branch: GNN update networks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Fresh branch: decoder and joint classifier.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

        self.parallel_params.append(self.node_encoder)
        self.parallel_params.append(self.edge_encoder)
        self.parallel_params.append(self.global_encoder)
        self.parallel_params.append(self.node_update)
        self.parallel_params.append(self.edge_update)
        self.parallel_params.append(self.global_update)
        self.parallel_params.append(self.global_decoder)
        self.parallel_params.append(self.classify)

    # --- Thin wrappers over the pretrained children (see __init__ order). ---
    def TL_node_encoder(self, x):
        return self.pretrained_model[1](x)

    def TL_edge_encoder(self, x):
        return self.pretrained_model[2](x)

    def TL_global_encoder(self, x):
        return self.pretrained_model[3](x)

    def TL_node_update(self, x):
        return self.pretrained_model[4](x)

    def TL_edge_update(self, x):
        return self.pretrained_model[5](x)

    def TL_global_update(self, x):
        return self.pretrained_model[6](x)

    def TL_global_decoder(self, x):
        return self.pretrained_model[7](x)

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained stack on ``g``; return its decoded global state.

        ``global_feats`` is required when ``self.has_global`` (defaults to
        None for backward compatibility with the old single-argument call).
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            # BUG FIX: original read an undefined local (NameError) here
            # whenever self.has_global was True.
            raise ValueError("global_feats must be provided when the model has global features")
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Run both branches; classify their concatenated global states."""
        # Clone so the pretrained branch does not clobber node/edge state
        # the fresh branch is about to write.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim=1))

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups with separate learning rates for
        the fresh ('trainable_lr') and pretrained ('finetuning_lr') branches.

        NOTE: this override returns param-group dicts, not a parameter
        iterator — it is intended to be passed directly to an optimizer.
        """
        # BUG FIXES vs. the original: use isinstance instead of
        # `type(...) == dict`; use .get() so a dict missing a key falls back
        # instead of raising KeyError; honor a scalar learning_rate instead
        # of silently ignoring it (default 0.0001 is unchanged).
        default_lr = self.learning_rate if isinstance(self.learning_rate, (int, float)) else 0.0001
        params = []
        for model_section in self.parallel_params:
            lr = self.learning_rate.get("trainable_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(), 'lr': lr if lr else default_lr})
        for model_section in self.finetuning_params:
            lr = self.learning_rate.get("finetuning_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(), 'lr': lr if lr else default_lr})
        return params
1148
-
1149
class Attention(nn.Module):
    """Graph classifier using one multi-head self-attention pass over node
    embeddings (no edge network), followed by a global update and decoder."""

    def __init__(self, sample_graph, sample_global, hid_size, out_size,
                 n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        # Kept so the weighted-mean width tracks the model width instead of
        # the previously hard-coded 64.
        self.hid_size = hid_size
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # GNN
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        """Classify a batched, padded graph ``g``.

        Requires g.ndata['padding_mask']; when g.ndata['w'] exists, node
        means are weighted by 'w' and normalized by the count of non-padded
        nodes per graph.
        """
        h = self.node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes).
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Count non-padded nodes per graph in the batch.
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            # BUG FIX: the divisor width was hard-coded to 64; use hid_size
            # so any model width works.
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            # Replace the global feature with the true (non-padded) node count.
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)

        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        # NOTE(review): assumes every graph in the batch is padded to the
        # same node count — confirm against the dataset collation.
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        # Self-attention over the nodes of each graph (padded keys masked).
        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim=1))
        g.ndata['h'] = h
        if sum_weights is None:
            # ROBUSTNESS: the original unconditionally computed the weighted
            # mean and crashed (sum_nodes missing 'w' / division by None)
            # whenever g.ndata had no 'w'; fall back to a plain mean.
            mean_nodes = dgl.mean_nodes(g, 'h')
        else:
            mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
1226
-
1227
class Attention_Edge_Network(nn.Module):
    """Edge-network GNN where each processing step first refines node
    embeddings with multi-head self-attention over the padded batch."""

    def __init__(self, sample_graph, sample_global, hid_size, out_size,
                 n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        """Classify a batched, padded graph; requires g.ndata['padding_mask']
        and g.ndata['w'] (weighted node mean)."""
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        h = g.ndata['h']
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        # NOTE(review): assumes every graph in the batch is padded to the
        # same node count — confirm against the dataset collation.
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        for i in range(self.n_proc_steps):
            # Self-attention over each graph's (padded) node set.
            h = g.ndata['h']
            query = self.queries(h)
            key = self.keys(h)
            value = self.values(h)
            query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
            key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
            value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
            h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
            h = torch.reshape(h, h_original_shape)
            # BUG FIX: the original computed the attention output and then
            # discarded it (never wrote it back before the edge/node update,
            # making the attention a dead computation every step). Write it
            # back so the message passing consumes the attended embeddings.
            g.ndata['h'] = h

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
1295
-
1296
class Attention_Unbatched(nn.Module):
    """Edge-network GNN with per-graph self-attention: each processing step
    unbatches the graph and attends over each graph's nodes separately
    (correct but slow — one attention call per graph per step)."""

    def __init__(self, sample_graph, sample_global, hid_size, out_size,
                 n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # GNN
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # attention
        # BUG FIX: the original hard-coded 1 head, silently ignoring the
        # num_heads argument; the default (1) keeps old behavior.
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        """Classify a batched graph; no padding mask needed since attention
        runs per unbatched graph."""
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        for i in range(self.n_proc_steps):
            # Attend over each graph's own nodes, then re-batch.
            unbatched_g = dgl.unbatch(g)
            for graph in unbatched_g:
                h = graph.ndata['h']
                h, _ = self.multihead_attn(self.queries(h), self.keys(h), self.values(h))
                graph.ndata['h'] = h
            g = dgl.batch(unbatched_g)

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
1357
-
1358
class Transferred_Learning_Attention(nn.Module):
    """Attention head on top of a pretrained GNN's node/global encoders and
    global decoder, with per-section learning rates via ``parameters()``."""

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, num_heads,
                 dropout=0, learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        # Kept so the weighted-mean width tracks the model width instead of
        # the previously hard-coded 64.
        self.hid_size = hid_size
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Scalar, or dict with 'pretraining_lr' / 'attention_lr' keys.
        self.learning_rate = learning_rate

        self.pretraining_params = []
        self.attention_params = []

        # Rebuild the pretrained model, load its weights, drop its head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Only the reused pretrained sections get the 'pretraining_lr'.
        self.pretraining_params.append(self.pretrained_model[1])
        self.pretraining_params.append(self.pretrained_model[3])
        self.pretraining_params.append(self.pretrained_model[7])

        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

        self.attention_params.append(self.multihead_attn)
        self.attention_params.append(self.queries)
        self.attention_params.append(self.keys)
        self.attention_params.append(self.values)
        self.attention_params.append(self.classify)
        self.attention_params.append(self.node_update)
        self.attention_params.append(self.global_update)

    # --- Thin wrappers over the reused pretrained children. ---
    def TL_node_encoder(self, x):
        return self.pretrained_model[1](x)

    def TL_global_encoder(self, x):
        return self.pretrained_model[3](x)

    def TL_global_decoder(self, x):
        return self.pretrained_model[7](x)

    def forward(self, g, global_feats):
        """Classify a batched, padded graph ``g``.

        Requires g.ndata['padding_mask']; when g.ndata['w'] exists, node
        means are weighted by 'w' and normalized by the count of non-padded
        nodes per graph.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes).
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Count non-padded nodes per graph in the batch.
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            # BUG FIX: the divisor width was hard-coded to 64; use hid_size
            # so any model width works.
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            # Replace the global feature with the true (non-padded) node count.
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.TL_global_encoder(global_feats)

        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        # NOTE(review): assumes every graph in the batch is padded to the
        # same node count — confirm against the dataset collation.
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        # Self-attention over the nodes of each graph (padded keys masked).
        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim=1))
        g.ndata['h'] = h
        if sum_weights is None:
            # ROBUSTNESS: the original unconditionally computed the weighted
            # mean and crashed (sum_nodes missing 'w' / division by None)
            # whenever g.ndata had no 'w'; fall back to a plain mean.
            mean_nodes = dgl.mean_nodes(g, 'h')
        else:
            mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(h_global)

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups with separate learning rates for
        the pretrained ('pretraining_lr') and attention ('attention_lr')
        sections.

        NOTE: this override returns param-group dicts, not a parameter
        iterator — it is intended to be passed directly to an optimizer.
        """
        # BUG FIXES vs. the original: isinstance instead of `type(...) ==
        # dict`; .get() so a dict missing a key falls back instead of raising
        # KeyError; honor a scalar learning_rate (default 0.0001 unchanged).
        default_lr = self.learning_rate if isinstance(self.learning_rate, (int, float)) else 0.0001
        params = []
        for model_section in self.pretraining_params:
            lr = self.learning_rate.get("pretraining_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(), 'lr': lr if lr else default_lr})
        for model_section in self.attention_params:
            lr = self.learning_rate.get("attention_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(), 'lr': lr if lr else default_lr})
        return params
1483
-
1484
class Multimodel_Transferred_Learning(nn.Module):
    """Ensemble of (optionally frozen) pretrained GNN encoders.

    Each pretrained model produces a per-graph global embedding; the
    embeddings are concatenated and fed through a fresh MLP + linear
    classifier. Only the MLP/classifier (and, if unfrozen, the pretrained
    weights) are trained.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, dropout=0,
                 frozen_pretraining=True, learning_rate=None, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.learning_rate = learning_rate
        input_size = 0  # summed hidden sizes of all pretrained models

        self.pretraining_params = []  # sections trained at the pretraining LR
        self.model_params = []        # sections trained at the model LR

        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph,
                                                 'sample_global': sample_global})

            checkpoint = torch.load(path)['model_state_dict']
            # Strip DataParallel's 'module.' prefix so keys match.
            new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
            model.load_state_dict(new_state_dict)
            # Drop the final (classification) child; keep the encoder stack.
            pretrained_layers = list(model.children())[:-1]
            model = nn.Sequential(*pretrained_layers)

            print(f"Freeze Pretraining = {frozen_pretraining}")
            if frozen_pretraining:
                for param in model.parameters():
                    param.requires_grad = False  # Freeze all layers
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        self.model_params.append(self.final_mlp)
        self.model_params.append(self.classify)

    def _apply_pretrained_section(self, x, model_idx, section):
        """Run ``x`` through child ``section`` of pretrained model ``model_idx``.

        Pretrained checkpoints come in two layouts (sections at the top level,
        or nested one level deeper inside child 1); fall back to the nested
        layout when the flat one is absent. Consolidates seven previously
        copy-pasted TL_* walkers.
        """
        try:
            for layer in self.pretrained_models[model_idx][section]:
                x = layer(x)
            return x
        except (NotImplementedError, IndexError):
            for layer in self.pretrained_models[model_idx][1][section]:
                x = layer(x)
            return x

    def TL_node_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 1)

    def TL_edge_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 2)

    def TL_global_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 3)

    def TL_node_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 4)

    def TL_edge_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 5)

    def TL_global_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 6)

    def TL_global_decoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 7)

    def Pretrained_Output(self, g, model_idx, global_feats=None):
        """Full message-passing pass through pretrained model ``model_idx``.

        Returns the (un-decoded) per-graph global embedding.

        Fix: the original referenced ``global_feats`` without it being a
        parameter, which raised ``NameError`` whenever ``has_global`` was
        True; it is now an (optional, backward-compatible) argument.
        """
        h = self.TL_node_encoder(g.ndata['features'], model_idx)
        e = self.TL_edge_encoder(g.edata['features'], model_idx)
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            # Use node counts as a stand-in global feature.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats, model_idx)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(
                torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'],
                           broadcast_global_to_edges(g, h_global)), dim=1), model_idx)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(
                torch.cat((g.ndata['h'], g.ndata['h_e'],
                           broadcast_global_to_nodes(g, h_global)), dim=1), model_idx)
            h_global = self.TL_global_update(
                torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1), model_idx)
        # h_global = self.TL_global_decoder(h_global, model_idx)
        return h_global

    def forward(self, g, global_feats):
        """Concatenate every pretrained model's global embedding and classify."""
        h_global = []
        for i in range(len(self.pretrained_models)):
            # Clone so each pretrained pass sees untouched graph features.
            h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
        h_global = torch.cat(h_global, dim=1)
        return self.classify(self.final_mlp(h_global))

    def to(self, device):
        """Move pretrained models and trainable heads to ``device``."""
        for i in range(len(self.pretrained_models)):
            self.pretrained_models[i].to(device)
        self.classify.to(device)
        self.final_mlp.to(device)
        return self

    def parameters(self, recurse: bool = True):
        """Optimizer param groups: pretrained sections default to 1e-5,
        model sections to 1e-4, overridable via the ``learning_rate`` dict.

        Fix: ``dict.get`` instead of direct indexing, so a dict missing one
        key no longer raises ``KeyError``.
        """
        params = []
        for model_section in self.pretraining_params:
            lr = self.learning_rate.get("pretraining_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(),
                           'lr': lr if lr else 0.00001})
        for model_section in self.model_params:
            lr = self.learning_rate.get("model_lr") if isinstance(self.learning_rate, dict) else None
            params.append({'params': model_section.parameters(),
                           'lr': lr if lr else 0.0001})
        return params
1649
-
1650
class MultiModel(nn.Module):
    """Trainable GNN plus an ensemble of (optionally frozen) pretrained GNNs.

    A fresh encoder/message-passing stack is trained from scratch; its global
    embedding is concatenated with each pretrained model's embedding before
    the final MLP + classifier.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global,
                 hid_size, out_size, n_layers, n_proc_steps, dropout=0,
                 frozen_pretraining=True, learning_rate=None, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.learning_rate = learning_rate
        input_size = 0  # summed hidden sizes of the pretrained models

        self.model_params = []        # sections trained at the model LR
        self.pretraining_params = []  # sections trained at the pretraining LR

        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph,
                                                 'sample_global': sample_global})

            checkpoint = torch.load(path)['model_state_dict']
            # Strip DataParallel's 'module.' prefix so keys match.
            new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
            model.load_state_dict(new_state_dict)
            # Drop the final (classification) child; keep the encoder stack.
            pretrained_layers = list(model.children())[:-1]
            model = nn.Sequential(*pretrained_layers)

            print(f"Freeze Pretraining = {frozen_pretraining}")
            if frozen_pretraining:
                for param in model.parameters():
                    param.requires_grad = False  # Freeze all layers
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        # Trainable encoders
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Trainable message-passing updates
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Head: own embedding (hid_size) + all pretrained embeddings.
        self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        self.model_params.append(self.final_mlp)
        self.model_params.append(self.classify)

    def _apply_pretrained_section(self, x, model_idx, section):
        """Run ``x`` through child ``section`` of pretrained model ``model_idx``.

        Falls back to the nested layout (child 1 wraps the sections) when the
        flat layout raises; consolidates seven copy-pasted TL_* walkers.
        """
        try:
            for layer in self.pretrained_models[model_idx][section]:
                x = layer(x)
            return x
        except (NotImplementedError, IndexError):
            for layer in self.pretrained_models[model_idx][1][section]:
                x = layer(x)
            return x

    def TL_node_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 1)

    def TL_edge_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 2)

    def TL_global_encoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 3)

    def TL_node_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 4)

    def TL_edge_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 5)

    def TL_global_update(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 6)

    def TL_global_decoder(self, x, model_idx):
        return self._apply_pretrained_section(x, model_idx, 7)

    def Pretrained_Output(self, g, model_idx, global_feats=None):
        """Full message-passing pass through pretrained model ``model_idx``.

        Fix: the original referenced ``global_feats`` without it being a
        parameter, raising ``NameError`` whenever ``has_global`` was True;
        it is now an optional, backward-compatible argument.
        """
        h = self.TL_node_encoder(g.ndata['features'], model_idx)
        e = self.TL_edge_encoder(g.edata['features'], model_idx)
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats, model_idx)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(
                torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'],
                           broadcast_global_to_edges(g, h_global)), dim=1), model_idx)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(
                torch.cat((g.ndata['h'], g.ndata['h_e'],
                           broadcast_global_to_nodes(g, h_global)), dim=1), model_idx)
            h_global = self.TL_global_update(
                torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1), model_idx)
        # h_global = self.TL_global_decoder(h_global, model_idx)
        return h_global

    def forward(self, g, global_feats):
        """Run the trainable GNN, append each pretrained embedding, classify."""
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(
                torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'],
                           broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(
                torch.cat((g.ndata['h'], g.ndata['h_e'],
                           broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(
                torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = [h_global]
        for i in range(len(self.pretrained_models)):
            # Clone so each pretrained pass sees untouched graph features.
            h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
        h_global = torch.cat(h_global, dim=1)
        return self.classify(self.final_mlp(h_global))

    def to(self, device):
        """Move pretrained models and all trainable sections to ``device``."""
        for i in range(len(self.pretrained_models)):
            self.pretrained_models[i].to(device)
        self.classify.to(device)
        self.final_mlp.to(device)
        self.node_encoder.to(device)
        self.edge_encoder.to(device)
        self.global_encoder.to(device)

        self.node_update.to(device)
        self.edge_update.to(device)
        self.global_update.to(device)
        return self

    def parameters(self, recurse: bool = True):
        """Optimizer param groups. ``learning_rate['pretraining_lr']`` is a
        per-model list; ``learning_rate['model_lr']`` a scalar. Defaults:
        1e-5 for pretrained sections, 1e-4 for model sections.

        Fix: ``dict.get`` instead of direct indexing, so a dict missing one
        key no longer raises ``KeyError``.
        """
        params = []
        for i, model_section in enumerate(self.pretraining_params):
            lrs = self.learning_rate.get("pretraining_lr") if isinstance(self.learning_rate, dict) else None
            if lrs:
                print(f"Pretraining LR = {lrs[i]}")
                params.append({'params': model_section.parameters(), 'lr': lrs[i]})
            else:
                print(f"Pretraining LR = 0.00001")
                params.append({'params': model_section.parameters(), 'lr': 0.00001})
        for model_section in self.model_params:
            lr = self.learning_rate.get("model_lr") if isinstance(self.learning_rate, dict) else None
            if lr:
                print(f"Model LR = {lr}")
                params.append({'params': model_section.parameters(), 'lr': lr})
            else:
                print(f"Model LR = 0.0001")
                params.append({'params': model_section.parameters(), 'lr': 0.0001})
        return params
1850
-
1851
class Clustering(nn.Module):
    """GNN encoder producing paired global embeddings for the original and
    augmented features of the same graph (for contrastive/clustering training).
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size,
                 n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # sample_global may be an empty container; guard before .shape.
        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Encoders
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Message-passing updates
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout)

    def model_forward(self, g, global_feats, features='features'):
        """One full encode → message-pass → decode pass using the node/edge
        field named ``features`` ('features' or 'augmented_features')."""
        h = self.node_encoder(g.ndata[features])
        e = self.edge_encoder(g.edata[features])

        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes).
            non_padded_nodes_mask = torch.any(g.ndata[features] != 0, dim=1)
            # Count real nodes per graph according to the batch boundaries.
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                batch_num_nodes.append(non_padded_nodes_mask[start_idx:end_idx].sum().item())
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata[features].device)
            # NOTE(review): 64 presumably equals hid_size — confirm before
            # changing the hidden width.
            sum_weights = batch_num_nodes[:, None].repeat(1, 64)
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(
                torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'],
                           broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(
                torch.cat((g.ndata['h'], g.ndata['h_e'],
                           broadcast_global_to_nodes(g, h_global)), dim=1))
            if "w" in g.ndata:
                # Weighted mean over non-padded nodes only.
                mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
                h_global = self.global_update(
                    torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim=1))
            else:
                h_global = self.global_update(
                    torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Concatenated embeddings of the original and augmented views."""
        h_global = self.model_forward(g, global_feats, 'features')
        h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
        return torch.cat((h_global, h_global_augmented), dim=1)

    def representation(self, g, global_feats):
        """Both view embeddings separately plus their concatenation."""
        h_global = self.model_forward(g, global_feats, 'features')
        h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
        return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1)

    def __str__(self):
        """Dump all Linear-layer weights for debugging; returns ''."""
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]

        layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
                  self.node_update, self.edge_update, self.global_update, self.global_decoder]

        for name, section in zip(layer_names, layers):
            print(name)
            for layer in section.children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())

        # Fix: Clustering defines no `classify` attribute, so the original
        # unconditional access raised AttributeError; guard it.
        if hasattr(self, 'classify'):
            print("classify")
            print(self.classify.weight)
        return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__pycache__/GCN.cpython-38.pyc DELETED
Binary file (57 kB)
 
models/__pycache__/loss.cpython-38.pyc DELETED
Binary file (11.4 kB)
 
models/loss.py DELETED
@@ -1,311 +0,0 @@
1
- from torch import nn
2
- import torch
3
- from root_gnn_base import utils
4
- import numpy as np
5
-
6
- class MaskedLoss():
7
- def __init__(self, mask = []):
8
- self.mask = mask
9
-
10
- def make_mask(self, targets):
11
- mask = torch.ones_like(targets[:,0])
12
- for m in self.mask:
13
- if m['op'] == 'eq':
14
- mask[targets[:,m['idx']] == m['val']] = 0
15
- elif m['op'] == 'gt':
16
- mask[targets[:,m['idx']] > m['val']] = 0
17
- elif m['op'] == 'lt':
18
- mask[targets[:,m['idx']] < m['val']] = 0
19
- elif m['op'] == 'ge':
20
- mask[targets[:,m['idx']] >= m['val']] = 0
21
- elif m['op'] == 'le':
22
- mask[targets[:,m['idx']] <= m['val']] = 0
23
- elif m['op'] == 'ne':
24
- mask[targets[:,m['idx']] != m['val']] = 0
25
- else:
26
- raise ValueError(f'Unknown mask op {m["op"]}')
27
- return mask == 1
28
-
29
- class MaskedL1Loss(MaskedLoss):
30
- def __init__(self, mask = [], index = 0):
31
- super().__init__(mask)
32
- self.index = index
33
- self.loss = nn.L1Loss()
34
-
35
- def __call__(self, logits, targets):
36
- mask = self.make_mask(targets)
37
- return self.loss(logits[mask], targets[mask][:,self.index])
38
-
39
- class BCEWithLogitsLoss():
40
- def __init__(self, weight=None, reduction='mean'):
41
- self.loss = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)
42
-
43
- def __call__(self, logits, targets):
44
- return self.loss(logits[:,0], targets.float())
45
-
46
- class MultiScore():
47
- def __init__(self, scores):
48
- self. score_fcns = []
49
- self.start_idx = []
50
- self.end_idx = []
51
- for score in scores:
52
- self.score_fcns.append(utils.buildFromConfig(score))
53
- self.start_idx.append(score['start_idx'])
54
- self.end_idx.append(score['end_idx'])
55
-
56
- def __call__(self, last_layer):
57
- scores = []
58
- for i in range(len(self.score_fcns)):
59
- scores.append(self.score_fcns[i](last_layer[:, self.start_idx[i]:self.end_idx[i]]))
60
- return torch.cat(scores, dim=1)
61
-
62
- class MultiLoss():
63
- def __init__(self, losses):
64
- self.loss_fcns = []
65
- self.label_start_idx = []
66
- self.label_end_idx = []
67
- self.output_start_idx = []
68
- self.output_end_idx = []
69
- self.weights = []
70
- self.label_types = []
71
- for loss in losses:
72
- self.loss_fcns.append(utils.buildFromConfig(loss))
73
- self.label_start_idx.append(loss['label_start_idx'])
74
- self.label_end_idx.append(loss['label_end_idx'])
75
- self.output_start_idx.append(loss['output_start_idx'])
76
- self.output_end_idx.append(loss['output_end_idx'])
77
- self.weights.append(loss.get('weight', 1.0))
78
- self.label_types.append(loss.get('label_type', 'float'))
79
-
80
- def __call__(self, logits, targets):
81
- loss = 0
82
- # print(logits.shape, targets.shape)
83
- for i in range(len(self.loss_fcns)):
84
- if self.label_types[i] == 'int':
85
- # print('loss', i, self.label_start_idx[i], self.label_end_idx[i], self.output_start_idx[i], self.output_end_idx[i])
86
- # print(logits[:, self.output_start_idx[i]:self.output_end_idx[i]].shape, targets[:, self.label_start_idx[i]].shape)
87
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]].to(int))
88
- elif self.label_end_idx[i] - self.label_start_idx[i] == 1:
89
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]])
90
- else:
91
- # print('loos', i, self.label_start_idx[i], self.label_end_idx[i], self.output_start_idx[i], self.output_end_idx[i])
92
- # print(logits[:, self.output_start_idx[i]:self.output_end_idx[i]].shape, targets[:, self.label_start_idx[i]:self.label_end_idx[i]].shape)
93
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]:self.label_end_idx[i]])
94
- return loss
95
-
96
- class AdvLoss():
97
- def __init__(self, loss, adv_loss, adv_weight=1.0):
98
- self.loss_fcn = utils.buildFromConfig(loss)
99
- self.adv_loss_fcn = utils.buildFromConfig(adv_loss)
100
- self.adv_weight = adv_weight
101
-
102
- def __call__(self, logits, targets):
103
- mask = targets[:,0] == 0
104
- loss = self.loss_fcn(logits[:,0], targets[:,0])
105
- adv_loss = self.adv_loss_fcn(logits[mask][:,1], targets[mask])
106
- return loss - self.adv_weight * adv_loss
107
-
108
- class MassWindowAdvLoss(AdvLoss):
109
- def __call__(self, logits, targets):
110
- mask = (targets[:,0] == 0) & (targets[:,1] > 5) & (targets[:,1] < 25)
111
- print(mask, mask.shape, mask.sum())
112
- loss = self.loss_fcn(logits[:,0], targets[:,0])
113
- print(loss)
114
- adv_loss = self.adv_loss_fcn(logits[mask][:,1], targets[mask][:,1])
115
- print(adv_loss)
116
- return loss - self.adv_weight * adv_loss
117
-
118
- class KDELoss(MaskedLoss):
119
- def __init__(self, mask = [], index = 0):
120
- self.index = index
121
- super().__init__(mask)
122
-
123
- def __call__(self, logits, targets):
124
- mask = self.make_mask(targets)
125
- logits = logits[mask]
126
- targets = targets[mask][:,self.index]
127
- N = logits.shape[0]
128
- masses = targets / torch.sqrt(torch.mean(targets**2))
129
- scores = logits[:,0] / torch.sqrt(torch.mean(logits**2))
130
-
131
- factor_2d = (1.0*N) ** (-2/6)
132
- covs = (factor_2d * torch.var(masses), factor_2d * torch.var(scores))
133
-
134
- m_diffs = torch.unsqueeze(masses, 1) - torch.unsqueeze(masses, 0)
135
- s_diffs = torch.unsqueeze(scores, 1) - torch.unsqueeze(scores, 0)
136
-
137
- ymm = torch.exp(- (m_diffs**2) / (4 * covs[0]))
138
- yss = torch.exp(- (s_diffs**2) / (4 * covs[1]))
139
-
140
- integral_rho_2d_rho_2d = torch.einsum('ij,ij->', ymm, yss)
141
- integral_rho_1d_rho_1d = torch.einsum('ij,kl->', ymm, yss)
142
- integral_rho_2d_rho_1d = torch.einsum('ij,ik->', ymm, yss)
143
- raw_integral = integral_rho_2d_rho_2d - 2 * integral_rho_2d_rho_1d / N + integral_rho_1d_rho_1d / N**2
144
- return raw_integral / (4 * torch.pi * N**2)
145
-
146
- class MultiLabelLoss():
147
- def __init__(self, label_names, label_types, label_weights = None):
148
- self.loss_fcn = []
149
- if (label_weights):
150
- self.weights = torch.tensor(label_weights)
151
- else:
152
- self.weights = torch.ones(len(label_types))
153
- for type in label_types:
154
- if (type == "r"):
155
- self.loss_fcn.append(torch.nn.MSELoss(reduce=False))
156
- elif (type == "c"):
157
- self.loss_fcn.append(torch.nn.BCEWithLogitsLoss())
158
- print(f"self.weights = {self.weights}")
159
-
160
- def __call__(self, logits, targets):
161
- targets = targets.float()
162
- loss = torch.zeros(len(logits[:, 0]), device = logits.get_device())
163
- for i in range(len(self.loss_fcn)):
164
- loss += self.weights[i] * self.loss_fcn[i](logits[:, i], targets[:, i])
165
- return torch.mean(loss)
166
-
167
-
168
- class MultiLabelFinish():
169
- def __init__(self, label_names, label_types):
170
- self.finish_fcn = []
171
- for type in label_types:
172
- if (type == "r"):
173
- self.finish_fcn.append(None)
174
- elif (type == "c"):
175
- self.finish_fcn.append(torch.special.expit)
176
-
177
- def __call__(self, logits):
178
- for i in range(len(self.finish_fcn)):
179
- if (self.finish_fcn[i]):
180
- logits[:, i] = self.finish_fcn[i](logits[:, i].to(torch.long))
181
- return logits
182
-
183
- class ContrastiveClusterLoss():
184
- def __init__(self, k=10, temperature=1, alpha=1):
185
- self.k = k
186
- self.temperature = temperature
187
- self.alpha = alpha
188
-
189
- def __call__(self, logits, targets):
190
- targets = targets.float()
191
- logits_combined = logits.float()
192
-
193
- hid_size = int(len(logits[0]) / 2)
194
-
195
- logits = normalize_embeddings(logits_combined[:, :hid_size])
196
- logits_augmented = normalize_embeddings(logits_combined[:, hid_size:])
197
-
198
- contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
199
- clustering, _ = clustering_loss(logits, self.k)
200
-
201
- variance_loss = variance_regularization(logits) + variance_regularization(logits_augmented)
202
-
203
- return torch.mean(contrastive + clustering + self.alpha * variance_loss)
204
-
205
- class ContrastiveClusterFinish():
206
- def __init__(self, k = 10, temperature = 1, max_cluster_iterations = 10):
207
- self.k = k
208
- self.temperature = temperature
209
- self.max_cluster_iterations = max_cluster_iterations
210
-
211
- print(f"ContrastiveClusterFinish: k = {k}, temperature = {temperature}")
212
-
213
- def __call__(self, logits):
214
- logits_combined = logits.float()
215
-
216
- hid_size = int(len(logits[0]) / 2)
217
-
218
- logits = logits_combined[:, :hid_size]
219
- logits_augmented = logits_combined[:, hid_size:]
220
-
221
- contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
222
- clustering, _ = clustering_loss(logits, self.k, self.max_cluster_iterations)
223
- variance = variance_regularization(logits) + variance_regularization(logits_augmented)
224
-
225
- return contrastive, clustering, variance
226
-
227
- def s(z_i, z_j):
228
- z_i = torch.tensor(z_i) if not isinstance(z_i, torch.Tensor) else z_i
229
- z_j = torch.tensor(z_j) if not isinstance(z_j, torch.Tensor) else z_j
230
-
231
- return torch.cdist(z_i, z_j, p=2)
232
- # dot_product = torch.dot(z_i, z_j)
233
- # norm_i = torch.linalg.norm(z_i)
234
- # norm_j = torch.linalg.norm(z_j)
235
-
236
- # return dot_product / (norm_i * norm_j)
237
-
238
- def contrastive_loss(logits, logits_augmented, temperature=1, margin=1.0):
239
- logits = torch.tensor(logits) if not isinstance(logits, torch.Tensor) else logits
240
- logits_augmented = torch.tensor(logits_augmented) if not isinstance(logits_augmented, torch.Tensor) else logits_augmented
241
-
242
- z = torch.cat((logits, logits_augmented), dim=0)
243
- similarity_matrix = torch.mm(z, z.t()) / temperature
244
- norms = torch.linalg.norm(z, dim=1)
245
- norm_matrix = torch.ger(norms, norms)
246
- similarity_matrix = similarity_matrix / norm_matrix
247
- mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool)
248
-
249
- loss = 0
250
- for k in range(len(logits)):
251
- numerator = torch.exp(similarity_matrix[k, k + len(logits)])
252
- denominator = torch.sum(torch.exp(similarity_matrix[k, ~mask[k]]))
253
-
254
- loss += -torch.log(numerator / denominator)
255
-
256
- return loss
257
-
258
-
259
- def clustering_loss(logits, k=10, max_iterations=10):
260
- # Step 1: Initialize cluster means
261
- indices = torch.randperm(logits.size(0))[:k]
262
- cluster_means = logits[indices]
263
-
264
- prev_assignments = None
265
- assignment_history = []
266
- iteration = 0
267
-
268
- while iteration < max_iterations:
269
- iteration += 1
270
-
271
- # Step 2: Assign each data point to the nearest cluster mean
272
- distances = torch.cdist(logits, cluster_means, p=2) # Compute distances between logits and cluster means
273
- cluster_assignments = torch.argmin(distances, dim=1) # Assign each point to the nearest cluster mean
274
-
275
- # Check for convergence: if assignments do not change, break the loop
276
- if prev_assignments is not None and torch.equal(cluster_assignments, prev_assignments):
277
- break
278
-
279
- # Check for cycles: if assignments have been seen before, break the loop
280
- if any(torch.equal(cluster_assignments, prev) for prev in assignment_history):
281
- break
282
-
283
- assignment_history.append(cluster_assignments.clone())
284
- prev_assignments = cluster_assignments.clone()
285
-
286
- # Step 3: Update cluster means based on assignments
287
- new_cluster_means = torch.zeros_like(cluster_means)
288
- for i in range(k):
289
- assigned_points = logits[cluster_assignments == i]
290
- if assigned_points.size(0) > 0:
291
- new_cluster_means[i] = assigned_points.mean(dim=0)
292
- else:
293
- # If no points are assigned to the cluster, reinitialize the mean randomly
294
- new_cluster_means[i] = logits[torch.randint(0, logits.size(0), (1,)).item()]
295
- cluster_means = new_cluster_means
296
-
297
- # Step 4: Compute the clustering loss
298
- distances = torch.cdist(logits, cluster_means, p=2)
299
- min_distances = torch.min(distances, dim=1)[0]
300
- loss = torch.sum(min_distances ** 2)
301
-
302
- return loss, cluster_means
303
-
304
- def normalize_embeddings(embeddings):
305
- return embeddings / embeddings.norm(dim=1, keepdim=True)
306
-
307
- def variance_regularization(embeddings):
308
- mean_embedding = embeddings.mean(dim=0)
309
- variance = ((embeddings - mean_embedding) ** 2).mean()
310
- return variance
311
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/batched_dataset.py DELETED
@@ -1,190 +0,0 @@
1
- from dgl.dataloading import GraphDataLoader
2
- from torch.utils.data.sampler import SubsetRandomSampler
3
- from torch.utils.data.sampler import SequentialSampler
4
- from dgl.data import DGLDataset
5
- import torch
6
- import time
7
- import os
8
- import dgl
9
- from root_gnn_base import utils
10
-
11
- def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwargs):
12
- if mask_fn == None:
13
- mask_fn = lambda x: torch.ones(len(x), dtype=torch.bool)
14
- dloader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(torch.arange(len(dataset))[mask_fn(dataset)]), batch_size=batch_size, drop_last=drop_last, num_workers = 0)
15
- return dloader
16
-
17
- #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
18
- class PreBatchedDataset(DGLDataset):
19
- def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', **kwargs):
20
- print(f'Unused kwargs: {kwargs}')
21
- self.start_dataset = start_dataset
22
- self.start_dataset.load()
23
-
24
- self.batch_size = batch_size
25
- self.chunks = chunks
26
- self.chunkno = chunkno
27
- self.mask_fn = mask_fn
28
- self.drop_last = drop_last
29
- self.graphs = []
30
- self.label = []
31
- self.padding_mode = padding_mode
32
- self.save_to_disk = save_to_disk
33
- self.shuffle = shuffle
34
- self.suffix = suffix
35
- self.current_chunk = None
36
- self.current_chunk_idx = -1
37
- super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)
38
-
39
- def process(self):
40
- first = 0
41
- last = len(self.start_dataset)
42
- if self.chunks > 1 and self.chunkno >= 0:
43
- first = int(self.chunkno / self.chunks * len(self.start_dataset))
44
- last = int((self.chunkno + 1) / self.chunks * len(self.start_dataset))
45
- print(f'Processing chunk {self.chunkno} of {self.chunks} from {first} to {last} of {len(self.start_dataset)}')
46
- mask = torch.logical_and(torch.logical_and(self.mask_fn(self.start_dataset), torch.arange(len(self.start_dataset)) >= first), torch.arange(len(self.start_dataset)) < last)
47
- if self.shuffle:
48
- dloader = GraphDataLoader(self.start_dataset, sampler=SubsetRandomSampler(torch.arange(len(self.start_dataset))[mask]), batch_size=self.batch_size, drop_last=self.drop_last)
49
- else: #Only don't shuffle if we're doing inference. Then we want all of the events anyways?
50
- dloader = GraphDataLoader(self.start_dataset, sampler=SequentialSampler(self.start_dataset), batch_size=self.batch_size, drop_last=self.drop_last)
51
- self.graphs = []
52
- self.labels = []
53
- self.tracking = []
54
- self.globals = []
55
- self.batch_num_nodes = []
56
- self.batch_num_edges = []
57
- max_edges = 0
58
- max_nodes = 0
59
- load_batch_start = time.time()
60
- for batch, label, tracking, global_feat in dloader:
61
- if batch.num_edges() > max_edges:
62
- max_edges = batch.num_edges()
63
- if batch.num_nodes() > max_nodes:
64
- max_nodes = batch.num_nodes()
65
- self.graphs.append(batch)
66
- self.labels.append(label)
67
- self.tracking.append(tracking)
68
- self.globals.append(global_feat)
69
- load_batch_end = time.time()
70
- print(f'Loaded {len(self.graphs)} batches in {load_batch_end - load_batch_start} seconds')
71
- if self.padding_mode == 'STEPS':
72
- pad_node, pad_edge = utils.pad_size(self.batch_size, max_edges, max_nodes)
73
- elif self.padding_mode == 'FIXED':
74
- print('Padding to fixed size. This is currently hardcoded.')
75
- pad_node = 16000
76
- pad_edge = 104000
77
- elif self.padding_mode == 'NONE':
78
- pad_node = 0
79
- pad_edge = 0
80
- else:
81
- pad_node = 0
82
- pad_edge = 0
83
- print(f'Max edges: {max_edges}, Max nodes: {max_nodes}, Padding to {pad_edge} edges and {pad_node} nodes')
84
- pad_start = time.time()
85
- if self.padding_mode == 'NODE':
86
- for i in range(len(self.graphs)):
87
- unbatched_g = dgl.unbatch(self.graphs[i])
88
- max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
89
- self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes)
90
- self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
91
- self.batch_num_edges.append(self.graphs[i].batch_num_edges())
92
- else:
93
- for i in range(len(self.graphs)):
94
- self.graphs[i] = utils.pad_batch(self.graphs[i], pad_edge, pad_node)
95
- self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
96
- self.batch_num_edges.append(self.graphs[i].batch_num_edges())
97
- pad_end = time.time()
98
- print(f'Padded {len(self.graphs)} batches in {pad_end - pad_start} seconds')
99
-
100
- def save(self):
101
- if not self.save_to_disk:
102
- return
103
- graph_path = os.path.join(self.save_dir, f'{self.name}_{self.chunkno}_{self.suffix}.bin')
104
- print(f'Saving dataset to {graph_path}')
105
- if len(self.graphs) == 0:
106
- return
107
- dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.stack(self.labels), 'batch_num_nodes': torch.stack(self.batch_num_nodes), 'batch_num_edges': torch.stack(self.batch_num_edges), 'tracking': torch.stack(self.tracking), 'globals': torch.stack(self.globals)})
108
-
109
- def has_cache(self):
110
- if not self.save_to_disk:
111
- return False
112
- for ch in range(self.chunks):
113
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
114
- if not os.path.exists(graph_path):
115
- print(f'Cache file {graph_path} does not exist, not loading from cache.')
116
- return False
117
- return True
118
-
119
- def load(self):
120
- if not self.save_to_disk:
121
- return
122
- self.graphs = []
123
- label_chunks = []
124
- tracking_chunks = []
125
- global_chunks = []
126
- for ch in range(self.chunks):
127
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
128
- print(f'Loading dataset from {graph_path}')
129
- graphs, label_dict = dgl.load_graphs(graph_path)
130
- label_chunks.append(label_dict['labels'])
131
- tracking_chunks.append(label_dict['tracking'])
132
- global_chunks.append(label_dict['globals'])
133
- for g, bnn, bne in zip(graphs, label_dict['batch_num_nodes'], label_dict['batch_num_edges']):
134
- g.set_batch_num_nodes(bnn)
135
- g.set_batch_num_edges(bne)
136
- self.graphs.extend(graphs)
137
- self.labels = torch.cat(label_chunks)
138
- self.tracking = torch.cat(tracking_chunks)
139
- self.globals = torch.cat(global_chunks)
140
-
141
- def __getitem__(self, idx):
142
- return self.graphs[idx], self.labels[idx], self.tracking[idx], self.globals[idx]
143
-
144
- def __len__(self):
145
- return len(self.graphs)
146
-
147
- #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
148
- class LazyPreBatchedDataset(PreBatchedDataset):
149
- def __init__(self, **kwargs):
150
- # print(f'Unused kwargs: {kwargs}')
151
- self.current_chunk = None
152
- self.current_chunk_idx = -10
153
- self.label_chunks = []
154
- super().__init__(**kwargs)
155
-
156
- def load(self):
157
- if not self.save_to_disk:
158
- return
159
- self.label_chunks = []
160
- for ch in range(self.chunks):
161
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
162
- print(f'Loading dataset from {graph_path}')
163
- label_dict = dgl.data.graph_serialize.load_labels_v2(graph_path)
164
- self.label_chunks.append(label_dict)
165
-
166
- def __getitem__(self, idx):
167
- chunk_idx = -1
168
- sum = 0
169
- ev_idx = -999
170
- for i in range(len(self.label_chunks)):
171
- count = len(self.label_chunks[i]['labels'])
172
- if idx < sum + count:
173
- chunk_idx = i
174
- ev_idx = idx - sum
175
- break
176
- sum += count
177
- if chunk_idx != self.current_chunk_idx:
178
- # print(f"rank {self.rank} getting data from {self.name}_{chunk_idx}_{self.suffix}.bin")
179
- self.current_chunk, _ = dgl.load_graphs(os.path.join(self.save_dir, f'{self.name}_{chunk_idx}_{self.suffix}.bin'))
180
- self.current_chunk_idx = chunk_idx
181
- g = self.current_chunk[ev_idx]
182
- g.set_batch_num_nodes(self.label_chunks[chunk_idx]['batch_num_nodes'][ev_idx])
183
- g.set_batch_num_edges(self.label_chunks[chunk_idx]['batch_num_edges'][ev_idx])
184
- return g, self.label_chunks[chunk_idx]['labels'][ev_idx], self.label_chunks[chunk_idx]['tracking'][ev_idx], self.label_chunks[chunk_idx]['globals'][ev_idx]
185
-
186
- def __len__(self):
187
- l = 0
188
- for chunk in self.label_chunks:
189
- l += len(chunk['labels'])
190
- return l
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/custom_scheduler.py DELETED
@@ -1,565 +0,0 @@
1
- import types
2
- import math
3
- import torch
4
- from torch import inf
5
- from functools import wraps, partial
6
- import warnings
7
- import weakref
8
- from collections import Counter
9
- from bisect import bisect_right
10
-
11
- from models import GCN
12
-
13
-
14
-
15
-
16
- ### Code from: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau
17
-
18
- Optimizer = torch.optim.Optimizer
19
-
20
- __all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
21
- 'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
22
- 'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
23
-
24
- EPOCH_DEPRECATION_WARNING = (
25
- "The epoch parameter in `scheduler.step()` was not necessary and is being "
26
- "deprecated where possible. Please use `scheduler.step()` to step the "
27
- "scheduler. During the deprecation, if epoch is different from None, the "
28
- "closed form is used instead of the new chainable form, where available. "
29
- "Please open an issue if you are unable to replicate your use case: "
30
- "https://github.com/pytorch/pytorch/issues/new/choose."
31
- )
32
-
33
-
34
- def update_LR(opt, lr):
35
- for param_group in opt.param_groups:
36
- param_group['lr'] = lr
37
-
38
- def print_LR(opt):
39
- for param_group in opt.param_groups:
40
- print(f"LR = {param_group['lr']}")
41
-
42
- def _check_verbose_deprecated_warning(verbose):
43
- """Raises a warning when verbose is not the default value."""
44
- if verbose != "deprecated":
45
- warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
46
- "to access the learning rate.", UserWarning)
47
- return verbose
48
- return False
49
-
50
- class LRScheduler:
51
-
52
- def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):
53
-
54
- # Attach optimizer
55
- if not isinstance(optimizer, Optimizer):
56
- raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
57
- self.optimizer = optimizer
58
-
59
- # Initialize epoch and base learning rates
60
- if last_epoch == -1:
61
- for group in optimizer.param_groups:
62
- group.setdefault('initial_lr', group['lr'])
63
- else:
64
- for i, group in enumerate(optimizer.param_groups):
65
- if 'initial_lr' not in group:
66
- raise KeyError("param 'initial_lr' is not specified "
67
- f"in param_groups[{i}] when resuming an optimizer")
68
- self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
69
- self.last_epoch = last_epoch
70
-
71
- # Following https://github.com/pytorch/pytorch/issues/20124
72
- # We would like to ensure that `lr_scheduler.step()` is called after
73
- # `optimizer.step()`
74
- def with_counter(method):
75
- if getattr(method, '_with_counter', False):
76
- # `optimizer.step()` has already been replaced, return.
77
- return method
78
-
79
- # Keep a weak reference to the optimizer instance to prevent
80
- # cyclic references.
81
- instance_ref = weakref.ref(method.__self__)
82
- # Get the unbound method for the same purpose.
83
- func = method.__func__
84
- cls = instance_ref().__class__
85
- del method
86
-
87
- @wraps(func)
88
- def wrapper(*args, **kwargs):
89
- instance = instance_ref()
90
- instance._step_count += 1
91
- wrapped = func.__get__(instance, cls)
92
- return wrapped(*args, **kwargs)
93
-
94
- # Note that the returned function here is no longer a bound method,
95
- # so attributes like `__func__` and `__self__` no longer exist.
96
- wrapper._with_counter = True
97
- return wrapper
98
-
99
- self.optimizer.step = with_counter(self.optimizer.step)
100
- self.verbose = _check_verbose_deprecated_warning(verbose)
101
-
102
- self._initial_step()
103
-
104
- def _initial_step(self):
105
- """Initialize step counts and performs a step"""
106
- self.optimizer._step_count = 0
107
- self._step_count = 0
108
- self.step()
109
-
110
- def state_dict(self):
111
- """Returns the state of the scheduler as a :class:`dict`.
112
-
113
- It contains an entry for every variable in self.__dict__ which
114
- is not the optimizer.
115
- """
116
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
117
-
118
- def load_state_dict(self, state_dict):
119
- """Loads the schedulers state.
120
-
121
- Args:
122
- state_dict (dict): scheduler state. Should be an object returned
123
- from a call to :meth:`state_dict`.
124
- """
125
- self.__dict__.update(state_dict)
126
-
127
- def get_last_lr(self):
128
- """ Return last computed learning rate by current scheduler.
129
- """
130
- return self._last_lr
131
-
132
- def get_lr(self):
133
- # Compute learning rate using chainable form of the scheduler
134
- raise NotImplementedError
135
-
136
- def print_lr(self, is_verbose, group, lr, epoch=None):
137
- """Display the current learning rate.
138
- """
139
- if is_verbose:
140
- if epoch is None:
141
- print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
142
- else:
143
- epoch_str = ("%.2f" if isinstance(epoch, float) else
144
- "%.5d") % epoch
145
- print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')
146
-
147
-
148
- def step(self, epoch=None):
149
- # Raise a warning if old pattern is detected
150
- # https://github.com/pytorch/pytorch/issues/20124
151
- if self._step_count == 1:
152
- if not hasattr(self.optimizer.step, "_with_counter"):
153
- warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
154
- "initialization. Please, make sure to call `optimizer.step()` before "
155
- "`lr_scheduler.step()`. See more details at "
156
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
157
-
158
- # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
159
- elif self.optimizer._step_count < 1:
160
- warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
161
- "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
162
- "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
163
- "will result in PyTorch skipping the first value of the learning rate schedule. "
164
- "See more details at "
165
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
166
- self._step_count += 1
167
-
168
- with _enable_get_lr_call(self):
169
- if epoch is None:
170
- self.last_epoch += 1
171
- values = self.get_lr()
172
- else:
173
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
174
- self.last_epoch = epoch
175
- if hasattr(self, "_get_closed_form_lr"):
176
- values = self._get_closed_form_lr()
177
- else:
178
- values = self.get_lr()
179
-
180
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
181
- param_group, lr = data
182
- param_group['lr'] = lr
183
-
184
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
185
-
186
-
187
- # Including _LRScheduler for backwards compatibility
188
- # Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
189
- class _LRScheduler(LRScheduler):
190
- pass
191
-
192
-
193
- class _enable_get_lr_call:
194
-
195
- def __init__(self, o):
196
- self.o = o
197
-
198
- def __enter__(self):
199
- self.o._get_lr_called_within_step = True
200
- return self
201
-
202
- def __exit__(self, type, value, traceback):
203
- self.o._get_lr_called_within_step = False
204
-
205
-
206
- class Dynamic_LR(LRScheduler):
207
- """Reduce learning rate when a metric has stopped improving.
208
- Models often benefit from reducing the learning rate by a factor
209
- of 2-10 once learning stagnates. This scheduler reads a metrics
210
- quantity and if no improvement is seen for a 'patience' number
211
- of epochs, the learning rate is reduced.
212
-
213
- Args:
214
- optimizer (Optimizer): Wrapped optimizer.
215
- mode (str): One of `min`, `max`. In `min` mode, lr will
216
- be reduced when the quantity monitored has stopped
217
- decreasing; in `max` mode it will be reduced when the
218
- quantity monitored has stopped increasing. Default: 'min'.
219
- factor (float): Factor by which the learning rate will be
220
- reduced. new_lr = lr * factor. Default: 0.1.
221
- patience (int): Number of epochs with no improvement after
222
- which learning rate will be reduced. For example, if
223
- `patience = 2`, then we will ignore the first 2 epochs
224
- with no improvement, and will only decrease the LR after the
225
- 3rd epoch if the loss still hasn't improved then.
226
- Default: 10.
227
- threshold (float): Threshold for measuring the new optimum,
228
- to only focus on significant changes. Default: 1e-4.
229
- threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
230
- dynamic_threshold = best * ( 1 + threshold ) in 'max'
231
- mode or best * ( 1 - threshold ) in `min` mode.
232
- In `abs` mode, dynamic_threshold = best + threshold in
233
- `max` mode or best - threshold in `min` mode. Default: 'rel'.
234
- cooldown (int): Number of epochs to wait before resuming
235
- normal operation after lr has been reduced. Default: 0.
236
- min_lr (float or list): A scalar or a list of scalars. A
237
- lower bound on the learning rate of all param groups
238
- or each group respectively. Default: 0.
239
- eps (float): Minimal decay applied to lr. If the difference
240
- between new and old lr is smaller than eps, the update is
241
- ignored. Default: 1e-8.
242
- verbose (bool): If ``True``, prints a message to stdout for
243
- each update. Default: ``False``.
244
-
245
- .. deprecated:: 2.2
246
- ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
247
- learning rate.
248
-
249
- Example:
250
- >>> # xdoctest: +SKIP
251
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
252
- >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
253
- >>> for epoch in range(10):
254
- >>> train(...)
255
- >>> val_loss = validate(...)
256
- >>> # Note that step should be called after validate()
257
- >>> scheduler.step(val_loss)
258
- """
259
-
260
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
261
- plateau_var = "test_auc",
262
- threshold=1e-4, threshold_mode='rel', cooldown=0,
263
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
264
-
265
- """
266
- if factor >= 1.0:
267
- raise ValueError('Factor should be < 1.0.')
268
- """
269
- self.factor = factor
270
-
271
- # Attach optimizer
272
- if not isinstance(optimizer, Optimizer):
273
- raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
274
- self.optimizer = optimizer
275
-
276
- if isinstance(min_lr, (list, tuple)):
277
- if len(min_lr) != len(optimizer.param_groups):
278
- raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
279
- self.min_lrs = list(min_lr)
280
- self.max_lrs = list(max_lr)
281
- else:
282
- self.min_lrs = [min_lr] * len(optimizer.param_groups)
283
- self.max_lrs = [max_lr] * len(optimizer.param_groups)
284
-
285
- self.patience = patience
286
- self.plateau_var = plateau_var
287
-
288
- self.verbose = verbose
289
- self.cooldown = cooldown
290
- self.cooldown_counter = 0
291
- self.mode = mode
292
- self.threshold = threshold
293
- self.threshold_mode = threshold_mode
294
- self.best = None
295
- self.num_bad_epochs = None
296
- self.mode_worse = None # the worse value for the chosen mode
297
- self.eps = eps
298
- self.last_epoch = 0
299
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
300
- self._init_is_better(mode=mode, threshold=threshold,
301
- threshold_mode=threshold_mode)
302
- self._reset()
303
-
304
- def _reset(self):
305
- """Resets num_bad_epochs counter and cooldown counter."""
306
- self.best = self.mode_worse
307
- self.cooldown_counter = 0
308
- self.num_bad_epochs = 0
309
-
310
- def step(self, model, metrics, epoch=None):
311
- # convert `metrics` to float, in case it's a zero-dim Tensor
312
- current = float(metrics[self.plateau_var])
313
- if epoch is None:
314
- epoch = self.last_epoch + 1
315
- else:
316
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
317
- self.last_epoch = epoch
318
-
319
- if self.is_better(current, self.best):
320
- if(self.verbose):
321
- print("Model is improving!")
322
- self.best = current
323
- self.num_bad_epochs = 0
324
- else:
325
- if(self.verbose):
326
- print(f"Model is not improving :( best = {self.best}, current = {current}")
327
- self.num_bad_epochs += 1
328
-
329
- if self.in_cooldown:
330
- self.cooldown_counter -= 1
331
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
332
-
333
- if self.num_bad_epochs > self.patience:
334
- self._reduce_lr(epoch)
335
- self.cooldown_counter = self.cooldown
336
- self.num_bad_epochs = 0
337
-
338
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
339
-
340
- def _reduce_lr(self, epoch):
341
- print("Adjusting Learning Rate")
342
- self._reset()
343
- for i, param_group in enumerate(self.optimizer.param_groups):
344
- old_lr = float(param_group['lr'])
345
- new_lr = max(old_lr * self.factor, self.min_lrs[i])
346
- new_lr = min(new_lr, self.max_lrs[i])
347
- if abs(old_lr - new_lr) > self.eps:
348
- param_group['lr'] = new_lr
349
-
350
- def get_last_lr(self):
351
- return self._last_lr
352
- @property
353
- def in_cooldown(self):
354
- return self.cooldown_counter > 0
355
-
356
- def is_better(self, a, best):
357
- if self.mode == 'min' and self.threshold_mode == 'rel':
358
- rel_epsilon = 1. - self.threshold
359
- return a < best * rel_epsilon
360
-
361
- elif self.mode == 'min' and self.threshold_mode == 'abs':
362
- return a < best - self.threshold
363
-
364
- elif self.mode == 'max' and self.threshold_mode == 'rel':
365
- rel_epsilon = self.threshold + 1.
366
- return a > best * rel_epsilon
367
-
368
- else: # mode == 'max' and epsilon_mode == 'abs':
369
- return a > best + self.threshold
370
-
371
- def _init_is_better(self, mode, threshold, threshold_mode):
372
- if mode not in {'min', 'max'}:
373
- raise ValueError('mode ' + mode + ' is unknown!')
374
- if threshold_mode not in {'rel', 'abs'}:
375
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
376
-
377
- if mode == 'min':
378
- self.mode_worse = inf
379
- else: # mode == 'max':
380
- self.mode_worse = -inf
381
-
382
- self.mode = mode
383
- self.threshold = threshold
384
- self.threshold_mode = threshold_mode
385
-
386
- def state_dict(self):
387
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
388
-
389
- def load_state_dict(self, state_dict):
390
- self.__dict__.update(state_dict)
391
- self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
392
-
393
- class Action_On_Plateau():
394
-
395
- def __init__(self, mode = 'max', patience=10,
396
- plateau_var = "test_auc",
397
- threshold=1e-4, threshold_mode='rel', cooldown=0,
398
- eps=1e-8, verbose=False):
399
-
400
- self.patience = patience
401
- self.plateau_var = plateau_var
402
-
403
- self.verbose = verbose
404
- self.cooldown = cooldown
405
- self.cooldown_counter = 0
406
- self.mode = mode
407
- self.threshold = threshold
408
- self.threshold_mode = threshold_mode
409
- self.best = None
410
- self.num_bad_epochs = None
411
- self.mode_worse = None # the worse value for the chosen mode
412
- self.eps = eps
413
- self.last_epoch = 0
414
- self._init_is_better(mode=mode, threshold=threshold,
415
- threshold_mode=threshold_mode)
416
- self._reset()
417
-
418
- def _reset(self):
419
- """Resets num_bad_epochs counter and cooldown counter."""
420
- self.best = self.mode_worse
421
- self.cooldown_counter = 0
422
- self.num_bad_epochs = 0
423
-
424
- def step(self, model, metrics, epoch=None):
425
- # convert `metrics` to float, in case it's a zero-dim Tensor
426
- current = float(metrics[self.plateau_var])
427
- if epoch is None:
428
- epoch = self.last_epoch + 1
429
- else:
430
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
431
- self.last_epoch = epoch
432
-
433
- if self.is_better(current, self.best):
434
- if(self.verbose):
435
- print("Model is improving!")
436
- self.best = current
437
- self.num_bad_epochs = 0
438
- else:
439
- if(self.verbose):
440
- print(f"Model is not improving :( best = {self.best}, current = {current}")
441
- self.num_bad_epochs += 1
442
-
443
- if self.in_cooldown:
444
- self.cooldown_counter -= 1
445
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
446
-
447
- if self.num_bad_epochs > self.patience:
448
- self.action(model, metrics, epoch)
449
-
450
- def action(self, model, metrics, epoch=None):
451
- if(self.verbose):
452
- print("Doing my action")
453
-
454
- @property
455
- def in_cooldown(self):
456
- return self.cooldown_counter > 0
457
-
458
- def is_better(self, a, best):
459
- if self.mode == 'min' and self.threshold_mode == 'rel':
460
- rel_epsilon = 1. - self.threshold
461
- return a < best * rel_epsilon
462
-
463
- elif self.mode == 'min' and self.threshold_mode == 'abs':
464
- return a < best - self.threshold
465
-
466
- elif self.mode == 'max' and self.threshold_mode == 'rel':
467
- rel_epsilon = self.threshold + 1.
468
- return a > best * rel_epsilon
469
-
470
- else: # mode == 'max' and epsilon_mode == 'abs':
471
- return a > best + self.threshold
472
-
473
- def _init_is_better(self, mode, threshold, threshold_mode):
474
- if mode not in {'min', 'max'}:
475
- raise ValueError('mode ' + mode + ' is unknown!')
476
- if threshold_mode not in {'rel', 'abs'}:
477
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
478
-
479
- if mode == 'min':
480
- self.mode_worse = inf
481
- else: # mode == 'max':
482
- self.mode_worse = -inf
483
-
484
- self.mode = mode
485
- self.threshold = threshold
486
- self.threshold_mode = threshold_mode
487
-
488
- class Partial_Reset(Action_On_Plateau):
489
-
490
- def __init__(self, mode='max', patience=10, plateau_var="test_auc",
491
- threshold=0.0001, threshold_mode='rel', cooldown=0,
492
- eps=1e-8, verbose=False):
493
-
494
- super().__init__(mode, patience, plateau_var, threshold,
495
- threshold_mode, cooldown, eps, verbose)
496
-
497
- def action(self, model, metrics, epoch=None):
498
- print("Partial Reset!!")
499
- GCN.partial_reset(model)
500
- self._reset()
501
- self.cooldown_counter = self.cooldown
502
- self.num_bad_epochs = 0
503
-
504
-
505
- class Full_Reset(Action_On_Plateau):
506
-
507
- def __init__(self, mode='max', patience=10, plateau_var="test_auc",
508
- threshold=0.0001, threshold_mode='rel', cooldown=0,
509
- eps=1e-8, verbose=False):
510
-
511
- super().__init__(mode, patience, plateau_var, threshold,
512
- threshold_mode, cooldown, eps, verbose)
513
-
514
- def action(self, model, metrics, epoch=None):
515
- print("Full Reset!!")
516
- GCN.full_reset(model)
517
- self._reset()
518
- self.cooldown_counter = self.cooldown
519
- self.num_bad_epochs = 0
520
-
521
- class Dynamic_LR_AND_Partial_Reset():
522
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
523
- plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
524
- threshold=1e-4, threshold_mode='rel', cooldown=0,
525
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
526
-
527
- if (reset_patience == None):
528
- reset_patience = patience
529
- if(reset_plateau_var == None):
530
- reset_plateau_var = plateau_var
531
-
532
- self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience = patience,
533
- plateau_var=plateau_var, threshold=threshold, threshold_mode =threshold_mode,
534
- cooldown=cooldown, min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)
535
-
536
- self.partial_reset = Partial_Reset(mode=mode, patience=reset_patience, plateau_var=reset_plateau_var,
537
- threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown,
538
- eps=eps)
539
-
540
- def step(self, model, metrics, epoch=None):
541
- self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
542
- self.partial_reset.step(model=model, metrics=metrics, epoch=epoch)
543
-
544
- class Dynamic_LR_AND_Full_Reset():
545
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
546
- plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
547
- threshold=1e-4, threshold_mode='rel', cooldown=0,
548
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
549
-
550
- if (reset_patience == None):
551
- reset_patience = patience
552
- if(reset_plateau_var == None):
553
- reset_plateau_var = plateau_var
554
-
555
- self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience = patience,
556
- plateau_var=plateau_var, threshold=threshold, threshold_mode =threshold_mode,
557
- cooldown=cooldown, min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)
558
-
559
- self.full_reset = Full_Reset(mode=mode, patience=reset_patience, plateau_var=reset_plateau_var,
560
- threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown,
561
- eps=eps)
562
-
563
- def step(self, model, metrics, epoch=None):
564
- self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
565
- self.full_reset.step(model=model, metrics=metrics, epoch=epoch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/dataset.py DELETED
@@ -1,685 +0,0 @@
1
- from dgl.data import DGLDataset
2
- import dgl
3
- import ROOT
4
- import torch
5
- import os
6
- import glob
7
- import time
8
- import numpy as np
9
- from root_gnn_base import utils
10
-
11
- def node_features_from_tree(ch, node_branch_names, node_branch_types, node_feature_scales):
12
- lengths = []
13
- for branch, node_type in zip(node_branch_names[0], node_branch_types):
14
- if node_type == 'single':
15
- lengths.append(1)
16
- elif node_type == 'vector':
17
- lengths.append(len(getattr(ch, branch)))
18
- else:
19
- print('Unknown node branch type: {}'.format(node_type))
20
- features = []
21
- for node_feat in node_branch_names:
22
- if node_feat == 'CALC_E':
23
- features.append(features[0]*torch.cosh(features[1]))
24
- continue
25
- elif node_feat == 'NODE_TYPE':
26
- feat = []
27
- for i, length in enumerate(lengths):
28
- feat.extend([i,]*length)
29
- features.append(torch.tensor(feat))
30
- continue
31
- feat = []
32
- itype = 0
33
- for length, branch, node_type in zip(lengths, node_feat, node_branch_types):
34
- if isinstance(branch, (int, float, complex)):
35
- feat.extend([branch,]*length)
36
- elif branch == 'CALC_E':
37
- this_type_starts_at = sum(lengths[:itype])
38
- this_type_ends_at = sum(lengths[:itype+1])
39
- feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
40
- elif node_type == 'single':
41
- feat.append(getattr(ch, branch))
42
- elif node_type == 'vector':
43
- feat.extend(getattr(ch, branch))
44
- itype += 1
45
- features.append(torch.tensor(feat))
46
- return torch.stack(features, dim=1) * node_feature_scales, lengths
47
-
48
- def full_connected_graph(n_nodes, self_loops=True):
49
- senders = []
50
- receivers = []
51
- senders = np.arange(n_nodes*n_nodes) // n_nodes
52
- receivers = np.arange(n_nodes*n_nodes) % n_nodes
53
- if not self_loops and n_nodes > 1:
54
- mask = senders != receivers
55
- senders = senders[mask]
56
- receivers = receivers[mask]
57
- return dgl.graph((senders, receivers))
58
-
59
- def check_selection(ch, selection):
60
- var, cut, op = selection
61
- if op == '>':
62
- return getattr(ch, var) > cut
63
- elif op == '<':
64
- return getattr(ch, var) < cut
65
- elif op == '==':
66
- return getattr(ch, var) == cut
67
-
68
- def check_selections(ch, selections):
69
- for selection in selections:
70
- if not check_selection(ch, selection):
71
- return False
72
- return True
73
-
74
- #Base dataset class for making graphs from ROOT ntuples.
75
- class RootDataset(DGLDataset):
76
- def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
77
- selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
78
- print(f'Unused args while creating RootDataset: {kwargs}')
79
- self.label = label
80
- self.counts = []
81
- self.selections = selections
82
- self.save_to_disk = save
83
- self.file_names = file_names
84
- self.node_branch_names = node_branch_names
85
- self.node_branch_types = node_branch_types
86
- self.node_feature_scales = torch.tensor([float(sf) for sf in node_feature_scales])
87
- self.tree_name = tree_name
88
- self.fold_var = fold_var
89
- self.tracking_info = tracking_info
90
- self.tracking_info.insert(0, fold_var)
91
- if weight_var == None:
92
- weight_var = 1
93
- self.tracking_info.insert(1, weight_var)
94
- self.global_features = global_features
95
- self.chunks = chunks
96
- self.process_chunks = process_chunks
97
- if self.process_chunks is None:
98
- self.process_chunks = [i for i in range(self.chunks)]
99
- self.times = [0, 0]
100
- super().__init__(name=name, raw_dir=raw_dir, save_dir=save_dir)
101
-
102
- def get_list_of_branches(self):
103
- branches = []
104
- for feat in self.node_branch_names:
105
- if isinstance(feat, list):
106
- for branch in feat:
107
- if branch == 'CALC_E':
108
- continue
109
- if isinstance(branch, str):
110
- branches.append(branch)
111
- for feat in self.global_features:
112
- if isinstance(feat, str):
113
- branches.append(feat)
114
- for feat in self.tracking_info:
115
- if isinstance(feat, str):
116
- branches.append(feat)
117
- for selection in self.selections:
118
- branches.append(selection[0])
119
- return branches
120
-
121
- def make_graph(self, ch):
122
- t1 = time.time()
123
- features, _ = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
124
- features = features[features[:,0] != 0]
125
- t2 = time.time()
126
- g = full_connected_graph(features.shape[0], self_loops=False)
127
- g.ndata['features'] = features
128
- t3 = time.time()
129
- self.times[0] += t2 - t1
130
- self.times[1] += t3 - t2
131
- return g
132
-
133
- def process(self):
134
- times = [0, 0, 0]
135
- oldtime = time.time()
136
- if isinstance(self.file_names, str):
137
- self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
138
- else:
139
- self.files = []
140
- for file_name in self.file_names:
141
- self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
142
- self.chain = ROOT.TChain(self.tree_name)
143
-
144
- if len(self.files) == 0:
145
- print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
146
- for file in self.files:
147
- utils.set_timeout(60*2)
148
- self.chain.Add(file)
149
- utils.unset_timeout()
150
- branches = self.get_list_of_branches()
151
- self.chain.SetBranchStatus('*', 0)
152
- for branch in branches:
153
- self.chain.SetBranchStatus(branch, 1)
154
- newtime = time.time()
155
- times[0] += newtime - oldtime
156
- chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
157
- chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
158
-
159
- self.graph_chunks = []
160
- self.label_chunks = []
161
- self.tracking_chunks = []
162
- self.global_chunks = []
163
- chunk_id = -1
164
- for chunk in chunks:
165
- chunk_id += 1
166
- graphs = []
167
- labels = []
168
- tracking = []
169
- globals = []
170
- for ientry in chunk:
171
- if (ientry % 10000 == 0):
172
- print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
173
- self.chain.GetEntry(ientry)
174
- passed = True
175
- for selection in self.selections:
176
- if not check_selection(self.chain, selection):
177
- passed = False
178
- continue
179
- oldtime = newtime
180
- newtime = time.time()
181
- times[1] += newtime - oldtime
182
- if passed:
183
- graphs.append(self.make_graph(self.chain))
184
- labels.append( self.label )
185
- tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
186
- globals.append(torch.zeros(len(self.global_features)))
187
- for i_ti, tr_branch in enumerate(self.tracking_info):
188
- if isinstance(tr_branch, str):
189
- tracking[-1][i_ti] = getattr(self.chain, tr_branch)
190
- else:
191
- tracking[-1][i_ti] = tr_branch
192
- for i_gl, gl_branch in enumerate(self.global_features):
193
- globals[-1][i_gl] = getattr(self.chain, gl_branch)
194
- oldtime = newtime
195
- newtime = time.time()
196
- times[2] += newtime - oldtime
197
-
198
- labels = torch.tensor(labels)
199
- tracking = torch.stack(tracking)
200
- globals = torch.stack(globals)
201
-
202
- # self.labels = labels
203
- # self.tracking = tracking
204
- # self.global_features = globals
205
- # self.graphs = graphs
206
-
207
- self.save_chunk(chunk_id, graphs, labels, tracking, globals)
208
-
209
- return
210
- self.graphs = self.graph_chunks[0]
211
- for chunk in self.graph_chunks[1:]:
212
- self.graphs += chunk
213
- self.labels = torch.cat(self.label_chunks)
214
- self.tracking = torch.cat(self.tracking_chunks)
215
- self.global_features = torch.cat(self.global_chunks)
216
- print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
217
- print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
218
-
219
- def save(self):
220
- """save the graph list and the labels"""
221
- if not self.save_to_disk:
222
- return
223
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
224
- if self.chunks == 1:
225
- # print(len(self.graphs))
226
- # print(len(self.labels))
227
- # print(len(self.tracking))
228
- # print(len(self.globals))
229
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
230
- dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
231
- else:
232
- print(len(self.graph_chunks))
233
- for i in range(len(self.process_chunks)):
234
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
235
- dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
236
-
237
- def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
238
- if not self.save_to_disk:
239
- return
240
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
241
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
242
- dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
243
-
244
- def has_cache(self):
245
- print(f'Checking for cache of {self.name}')
246
- if not self.save_to_disk:
247
- print('Skipping load.')
248
- return False
249
- if self.chunks == 1:
250
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
251
- return os.path.exists(graph_path)
252
- else:
253
- for i in range(len(self.process_chunks)):
254
- graph_path = os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin')
255
- if not os.path.exists(graph_path):
256
- print(f'File {graph_path} does not exist, processing.')
257
- return False
258
- return True
259
-
260
- def load(self):
261
- if self.chunks == 1:
262
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + ".bin")}')
263
- graphs, label_dict = dgl.load_graphs(os.path.join(self.save_dir, self.name + '.bin'))
264
- self.graphs = graphs
265
- self.labels = label_dict['labels']
266
- self.tracking = label_dict['tracking']
267
- self.global_features = label_dict['global']
268
- else:
269
- self.graphs = []
270
- self.labels = []
271
- self.tracking = []
272
- self.global_features = []
273
- for i in range(self.chunks):
274
- try:
275
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
276
- graphs, label = dgl.load_graphs(os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin'))
277
- self.graphs.extend(graphs)
278
- self.labels.append(label['labels'])
279
- self.tracking.append(label['tracking'])
280
- self.global_features.append(label['global'])
281
- except Exception as e:
282
- print(e)
283
- self.labels = torch.cat(self.labels)
284
- self.tracking = torch.cat(self.tracking)
285
- self.global_features = torch.cat(self.global_features)
286
-
287
- def __getitem__(self, idx):
288
- return self.graphs[idx], self.labels[idx], self.tracking[idx], self.global_features[idx]
289
-
290
- def __len__(self):
291
- return len(self.graphs)
292
-
293
- #Dataset with edge features added (deta, dphi, dR)
294
- class EdgeDataset(RootDataset):
295
- def make_graph(self, ch):
296
- g = super().make_graph(ch)
297
- u, v = g.edges()
298
- deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
299
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
300
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
301
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
302
- dR = torch.sqrt(deta**2 + dphi**2)
303
- g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
304
- return g
305
-
306
- class tHbbEdgeDataset(RootDataset):
307
- def __init__(self, exclude_branches=None, **kwargs):
308
- self.exclude_branches = exclude_branches
309
- super().__init__(**kwargs)
310
-
311
- def get_list_of_branches(self):
312
- br = super().get_list_of_branches()
313
- for sector in self.exclude_branches:
314
- if sector == None:
315
- continue
316
- for excl in sector:
317
- if type(excl) == str:
318
- br.append(excl)
319
- return br
320
-
321
- def make_graph(self, ch):
322
- features, lengths = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
323
-
324
- include_mask = torch.ones(features.shape[0], dtype=torch.bool)
325
- node_idx = 0
326
- for sector, length in zip(self.exclude_branches, lengths):
327
- if sector == None:
328
- node_idx += length
329
- continue
330
- for excl in sector:
331
- if type(excl) == int:
332
- include_mask[excl + node_idx] = False
333
- elif type(excl) == str:
334
- include_mask[getattr(self.chain, excl) + node_idx] = False
335
- g = full_connected_graph(features[include_mask].shape[0], self_loops=False)
336
- g.ndata['features'] = features[include_mask]
337
-
338
- u, v = g.edges()
339
- deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
340
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
341
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
342
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
343
- dR = torch.sqrt(deta**2 + dphi**2)
344
- g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
345
- return g
346
-
347
- class LazyDataset(EdgeDataset):
348
- def __init__(self, buffer_size = 2, **kwargs):
349
- self.buffer = [None,] * buffer_size
350
- self.buffer_ptr = 0
351
- self.get_item_calls = 0
352
- self.buffer_indices = [-1,] * buffer_size
353
- super().__init__(**kwargs)
354
-
355
- def __getitem__(self, idx):
356
- self.get_item_calls += 1
357
- chunk_idx = -1
358
- sum = 0
359
- ev_idx = -999
360
- for i, count in enumerate(self.counts):
361
- sum += count
362
- if idx < sum:
363
- chunk_idx = i
364
- ev_idx = idx - sum + count
365
- break
366
- buf_idx = self.buffer_get(chunk_idx)
367
- if ev_idx >= len(self.buffer[buf_idx][0]):
368
- print(f'Getting event {ev_idx} from chunk {chunk_idx} from buffer {buf_idx}. Calls: {self.get_item_calls}')
369
- print(len(self.buffer))
370
- print(self.counts)
371
- print(len(self.buffer[buf_idx][0]))
372
- return self.buffer[buf_idx][0][ev_idx], self.buffer[buf_idx][1]['labels'][ev_idx], self.buffer[buf_idx][1]['tracking'][ev_idx], self.buffer[buf_idx][1]['global'][ev_idx]
373
-
374
- def buffer_get(self, buffer_idx):
375
- if buffer_idx in self.buffer_indices:
376
- for i in range(len(self.buffer)):
377
- if self.buffer_indices[i] == buffer_idx:
378
- return i
379
- else:
380
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{buffer_idx}.bin")}', flush=True)
381
- self.buffer_ptr = (self.buffer_ptr + 1) % len(self.buffer)
382
- self.buffer[self.buffer_ptr] = dgl.load_graphs(os.path.join(self.save_dir, self.name + f'_{buffer_idx}.bin'))
383
- self.buffer_indices[self.buffer_ptr] = buffer_idx
384
- return self.buffer_ptr
385
-
386
- def load(self):
387
- self.counts = []
388
- self.tracking = []
389
- try:
390
- for i in range(self.chunks):
391
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
392
- l = dgl.data.graph_serialize.load_labels_v2(os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin'))
393
- self.counts.append(len(l['tracking']))
394
- self.tracking.append(l['tracking'])
395
- self.tracking = torch.cat(self.tracking)
396
- except Exception as e:
397
- print(e)
398
-
399
- def __len__(self):
400
- return sum(self.counts)
401
-
402
- class MultiLabelDataset(EdgeDataset):
403
- def __init__(self, **kwargs):
404
- super().__init__(**kwargs)
405
-
406
- def get_list_of_branches(self):
407
- br = super().get_list_of_branches()
408
- for l in self.label:
409
- if isinstance(l, str):
410
- br.append(l)
411
- if isinstance(l, dict):
412
- br.append(l['branch'])
413
- return br
414
-
415
- def get_label(self, ch):
416
- label = []
417
- for l in self.label:
418
- if isinstance(l, str):
419
- label.append((getattr(ch, l)))
420
- if isinstance(l, dict):
421
- label.append(getattr(ch, l['branch'])*float(l['scale']))
422
- if isinstance(l, float) or isinstance(l, int):
423
- label.append(l)
424
-
425
- return torch.tensor(label)
426
-
427
- def process(self):
428
- times = [0, 0, 0]
429
- oldtime = time.time()
430
- if isinstance(self.file_names, str):
431
- self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
432
- else:
433
- self.files = []
434
- for file_name in self.file_names:
435
- self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
436
- self.chain = ROOT.TChain(self.tree_name)
437
- if len(self.files) == 0:
438
- print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
439
- for file in self.files:
440
- utils.set_timeout(60*2)
441
- self.chain.Add(file)
442
- utils.unset_timeout()
443
- branches = self.get_list_of_branches()
444
- self.chain.SetBranchStatus('*', 0)
445
- for branch in branches:
446
- self.chain.SetBranchStatus(branch, 1)
447
- newtime = time.time()
448
- times[0] += newtime - oldtime
449
- chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
450
- chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
451
- self.graph_chunks = []
452
- self.label_chunks = []
453
- self.tracking_chunks = []
454
- self.global_chunks = []
455
- chunk_id = -1
456
- for chunk in chunks:
457
- chunk_id += 1
458
- graphs = []
459
- labels = []
460
- tracking = []
461
- globals = []
462
- for ientry in chunk:
463
- if (ientry % 10000 == 0):
464
- print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
465
- self.chain.GetEntry(ientry)
466
- passed = True
467
- for selection in self.selections:
468
- if not check_selection(self.chain, selection):
469
- passed = False
470
- continue
471
- oldtime = newtime
472
- newtime = time.time()
473
- times[1] += newtime - oldtime
474
- if passed:
475
- graphs.append(self.make_graph(self.chain))
476
- labels.append(self.get_label(self.chain))
477
- tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
478
- globals.append(torch.zeros(len(self.global_features)))
479
- for i_ti, tr_branch in enumerate(self.tracking_info):
480
- if isinstance(tr_branch, str):
481
- tracking[-1][i_ti] = getattr(self.chain, tr_branch)
482
- else:
483
- tracking[-1][i_ti] = tr_branch
484
- for i_gl, gl_branch in enumerate(self.global_features):
485
- globals[-1][i_gl] = getattr(self.chain, gl_branch)
486
- oldtime = newtime
487
- newtime = time.time()
488
- times[2] += newtime - oldtime
489
-
490
- labels = torch.stack(labels)
491
- self.save_chunk(chunk_id, graphs, labels, torch.stack(tracking), torch.stack(globals))
492
- # self.graph_chunks.append(graphs)
493
- # self.label_chunks.append(labels)
494
- # self.tracking_chunks.append(torch.stack(tracking))
495
- # self.global_chunks.append(torch.stack(globals))
496
- # self.counts.append(len(graphs))
497
- return
498
- self.graphs = self.graph_chunks[0]
499
- for chunk in self.graph_chunks[1:]:
500
- self.graphs += chunk
501
-
502
- self.labels = torch.cat(self.label_chunks)
503
- self.tracking = torch.cat(self.tracking_chunks)
504
- self.global_features = torch.cat(self.global_chunks)
505
- print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
506
- print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
507
-
508
class LazyMultiLabelDataset(MultiLabelDataset, LazyDataset):
    """Multi-label dataset with lazy, buffered chunk loading.

    Combines MultiLabelDataset's labeling scheme with LazyDataset's
    chunk buffering (``buffer_size`` chunks kept in memory at once).
    """
    def __init__(self, buffer_size = 2, **kwargs):
        # NOTE(review): deliberately calls LazyDataset.__init__ directly instead
        # of super(), presumably to control which base runs initialization in
        # the MRO — confirm before converting to super().
        LazyDataset.__init__(self, buffer_size=buffer_size, **kwargs)
511
-
512
class MultiLabeltHbbDataset(MultiLabelDataset, tHbbEdgeDataset):
    """tHbb edge dataset with multi-label targets.

    Extends the base branch list so that string entries listed in
    ``self.exclude_branches`` are still read from the tree.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_list_of_branches(self):
        """Return the base branch list plus any string branch names found in
        ``self.exclude_branches`` sectors (``None`` sectors are skipped).
        """
        branches = super().get_list_of_branches()
        for sector in self.exclude_branches:
            # Fixed: identity comparison for None (was `== None`).
            if sector is None:
                continue
            for excl in sector:
                # Fixed: isinstance instead of `type(excl) == str`.
                if isinstance(excl, str):
                    branches.append(excl)
        return branches
525
-
526
-
527
class AugmentedDataset(RootDataset):
    """RootDataset that attaches an augmented copy of each graph's node and
    edge features: a shared random phi rotation, independent eta/phi sign
    reflections, per-object detector-resolution noise, and a recomputed MET
    node, all stored under the 'augmented_features' keys.
    """

    def __init__(self, seed = 2, feature_index = None, node_mapping = None, **kwargs):
        """
        Parameters
        ----------
        seed : int
            Seed for numpy's global RNG (all augmentations are random).
        feature_index : dict or None
            Maps feature name -> column in the node-feature matrix.
        node_mapping : dict or None
            Maps object-type name -> integer node_type code.
        """
        self.seed = seed
        np.random.seed(seed)
        # Fixed: caller-supplied feature_index/node_mapping were previously
        # discarded — the attribute was only assigned in the default branch.
        if feature_index is None:
            feature_index = {"pt": 0, "eta": 1, "phi": 2, "energy": 3, "btag": 4, "charge": 5, "node_type": 6}
        self.feature_index = feature_index
        if node_mapping is None:
            node_mapping = {"jet": 0, "ele": 1, "mu": 2, "ph": 3, "MET": 4}
        self.node_mapping = node_mapping
        super().__init__(**kwargs)

    def detector_noise(self, node_features):
        """Return additive per-node noise sampled from eta/pt(energy)-dependent
        resolution parameterizations.

        Jets, electrons and muons are smeared in pt; photons in energy.
        Entries outside the parameterized |eta| ranges get zero noise.
        """
        noise = np.zeros_like(node_features)

        node_types = node_features[:, self.feature_index["node_type"]]
        pts = node_features[:, self.feature_index["pt"]]
        etas = node_features[:, self.feature_index["eta"]]
        energies = node_features[:, self.feature_index["energy"]]

        # Jets: pt resolution by |eta| band; pt <= 0.1 or |eta| > 2.5 unsmeared.
        jet_mask = (node_types == self.node_mapping["jet"])
        jet_pts = pts[jet_mask]
        jet_etas = etas[jet_mask]
        if jet_mask.sum() > 0:
            jet_resolutions = np.where(
                jet_pts <= 0.1, 0.0,
                np.where(
                    np.abs(jet_etas) <= 0.5, np.sqrt(0.06**2 + jet_pts**2 * 1.3e-3**2),
                    np.where(
                        np.abs(jet_etas) <= 1.5, np.sqrt(0.10**2 + jet_pts**2 * 1.7e-3**2),
                        np.where(
                            np.abs(jet_etas) <= 2.5, np.sqrt(0.25**2 + jet_pts**2 * 3.1e-3**2),
                            0.0
                        )
                    )
                )
            )
            noise[jet_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=jet_resolutions)

        # Electrons: pt resolution by |eta| band.
        ele_mask = (node_types == self.node_mapping["ele"])
        ele_pts = pts[ele_mask]
        ele_etas = etas[ele_mask]
        if ele_mask.sum() > 0:
            ele_resolutions = np.where(
                np.abs(ele_etas) <= 0.5, np.sqrt(0.03**2 + ele_pts**2 * 1.3e-3**2),
                np.where(
                    np.abs(ele_etas) <= 1.5, np.sqrt(0.05**2 + ele_pts**2 * 1.7e-3**2),
                    np.where(
                        np.abs(ele_etas) <= 2.5, np.sqrt(0.15**2 + ele_pts**2 * 3.1e-3**2),
                        0.0
                    )
                )
            )
            noise[ele_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=ele_resolutions)

        # Muons: pt resolution by |eta| band.
        mu_mask = (node_types == self.node_mapping["mu"])
        mu_pts = pts[mu_mask]
        mu_etas = etas[mu_mask]
        if mu_mask.sum() > 0:
            mu_resolutions = np.where(
                np.abs(mu_etas) <= 0.5, np.sqrt(0.01**2 + mu_pts**2 * 1.0e-4**2),
                np.where(
                    np.abs(mu_etas) <= 1.5, np.sqrt(0.015**2 + mu_pts**2 * 1.5e-4**2),
                    np.where(
                        np.abs(mu_etas) <= 2.5, np.sqrt(0.025**2 + mu_pts**2 * 3.5e-4**2),
                        0.0
                    )
                )
            )
            noise[mu_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=mu_resolutions)

        # Photons: energy (not pt) is smeared, by |eta| band.
        ph_mask = (node_types == self.node_mapping["ph"])
        ph_etas = etas[ph_mask]
        ph_energies = energies[ph_mask]
        if ph_mask.sum() > 0:
            ph_resolutions = np.where(
                np.abs(ph_etas) <= 3.2, np.sqrt(ph_energies**2 * 0.0017**2 + ph_energies * 0.101**2),
                np.where(
                    np.abs(ph_etas) <= 4.9, np.sqrt(ph_energies**2 * 0.0350**2 + ph_energies * 0.285**2),
                    0.0
                )
            )
            noise[ph_mask, self.feature_index["energy"]] = np.random.normal(loc=0.0, scale=ph_resolutions)
        return noise

    def make_graph(self, ch):
        """Build the base graph, then add augmented node/edge feature copies."""
        g = super().make_graph(ch)

        # Fixed: a plain `g.ndata['augmented_features'] = g.ndata['features']`
        # aliased the same tensor, so the in-place rotation/reflection below
        # also corrupted the original features (and therefore the unaugmented
        # edge features computed from them). Clone to keep them independent.
        g.ndata['augmented_features'] = g.ndata['features'].clone()

        pt_index = self.feature_index["pt"]
        eta_index = self.feature_index["eta"]
        phi_index = self.feature_index["phi"]

        # Rotation: one shared delta_phi per event, result wrapped to [-pi, pi).
        delta_phi = np.random.uniform(low=-np.pi, high=np.pi)
        g.ndata['augmented_features'][:, phi_index] = (g.ndata['augmented_features'][:, phi_index] + delta_phi + np.pi) % (2 * np.pi) - np.pi

        # Reflections: independently flip the signs of eta and phi.
        eta_reflection = np.random.choice([-1, 1])
        phi_reflection = np.random.choice([-1, 1])
        g.ndata['augmented_features'][:, eta_index] = g.ndata['augmented_features'][:, eta_index] * eta_reflection
        g.ndata['augmented_features'][:, phi_index] = g.ndata['augmented_features'][:, phi_index] * phi_reflection

        # Detector noise: additive smearing from detector_noise().
        noise = self.detector_noise(g.ndata['augmented_features'])
        g.ndata['augmented_features'] = g.ndata['augmented_features'] + noise

        # Recompute MET (assumed to be the last node — TODO confirm ordering)
        # from the vector sum of the smeared objects' transverse momenta.
        if g.ndata['augmented_features'][-1][self.feature_index["node_type"]] == self.node_mapping["MET"]:
            sum_px = 0
            sum_py = 0
            for i in range(len(g.ndata['augmented_features']) - 1):
                pt = g.ndata['augmented_features'][i][pt_index]
                phi = g.ndata['augmented_features'][i][phi_index]
                sum_px += pt * np.cos(phi)
                sum_py += pt * np.sin(phi)
            g.ndata['augmented_features'][-1][pt_index] = np.sqrt(sum_px**2 + sum_py**2)

        u, v = g.edges()

        def _edge_features(node_feats):
            # (deta, dphi wrapped to [-pi, pi], dR) per edge.
            deta = node_feats[u, eta_index] - node_feats[v, eta_index]
            dphi = node_feats[u, phi_index] - node_feats[v, phi_index]
            dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
            dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
            dR = torch.sqrt(deta**2 + dphi**2)
            return torch.stack([deta, dphi, dR], dim=1)

        g.edata['features'] = _edge_features(g.ndata['features'])
        g.edata['augmented_features'] = _edge_features(g.ndata['augmented_features'])

        return g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/photon_ID_dataset.py DELETED
@@ -1,44 +0,0 @@
1
- from root_gnn_base import dataset
2
- import dgl
3
- import torch
4
- import numpy as np
5
-
6
def radius_graph(features, radii, self_loops=False):
    """Build a DGL graph connecting node pairs whose separation along every
    constrained feature column is below the given radius.

    Parameters
    ----------
    features : tensor/array of shape (n_nodes, n_features)
    radii : dict
        Maps feature column index -> maximum |difference| allowed for an edge.
    self_loops : bool
        Keep i->i edges (ensures isolated nodes still appear in the graph).
    """
    n_nodes = features.shape[0]
    # Start from the fully connected sender/receiver index lists.
    # (Fixed: removed dead `senders = []` / `receivers = []` initializations.)
    senders = np.arange(n_nodes*n_nodes) // n_nodes
    receivers = np.arange(n_nodes*n_nodes) % n_nodes
    if not self_loops and n_nodes > 1:
        keep = senders != receivers
        senders = senders[keep]
        receivers = receivers[keep]
    # Successively prune pairs that are too far apart along each feature.
    for k, r in radii.items():
        d = features[senders, k] - features[receivers, k]
        keep = np.abs(d) < r
        senders = senders[keep]
        receivers = receivers[keep]
    # NOTE(review): num_nodes is not forced, so a node with no incident edge is
    # dropped — callers rely on self_loops=True to keep every node. Confirm
    # before passing num_nodes=n_nodes here.
    return dgl.graph((senders, receivers))
22
-
23
class PhotonIDDataset(dataset.LazyMultiLabelDataset):
    """Lazy multi-label dataset whose per-event graphs connect calorimeter
    cells lying within configurable eta/phi radii (and adjacent layers)."""

    def __init__(self, eta_radius, phi_radius, **kwargs):
        self.eta_radius = eta_radius
        self.phi_radius = phi_radius
        super().__init__(**kwargs)

    def make_graph(self, ch):
        """Build one event graph from tree entry ``ch``."""
        features, _ = dataset.node_features_from_tree(
            ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
        features = features[features[:, 0] != 0]
        # Columns: 1 = eta, 2 = phi, 6 = layer (radius 1.1 links adjacent layers).
        # Self loops ensure the last cell is included even if disconnected.
        g = radius_graph(features, {1: self.eta_radius, 2: self.phi_radius, 6: 1.1}, self_loops=True)
        g.ndata['features'] = features
        src, dst = g.edges()
        deta = features[src, 1] - features[dst, 1]
        dphi = features[src, 2] - features[dst, 2]
        # Wrap the phi difference into [-pi, pi].
        dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
        dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
        dR = torch.sqrt(deta**2 + dphi**2)
        dx = features[src, 3] - features[dst, 3]
        dy = features[src, 4] - features[dst, 4]
        dz = features[src, 5] - features[dst, 5]
        g.edata['features'] = torch.stack([deta, dphi, dR, dx, dy, dz], dim=1)
        return g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/similarity.py DELETED
@@ -1,158 +0,0 @@
1
- import numpy as np
2
- import scipy
3
- from sklearn.decomposition import PCA
4
- from sklearn.metrics.pairwise import cosine_similarity
5
- from sklearn.metrics.pairwise import euclidean_distances
6
- from sklearn.preprocessing import StandardScaler
7
-
8
- from scipy.stats import wasserstein_distance
9
-
10
def cka(rep_a, rep_b, size=None):
    """Linear Centered Kernel Alignment (CKA) between two representation
    matrices.

    Parameters
    ----------
    rep_a, rep_b : np.ndarray
        Representations of shape (n_samples, n_features_a/b).
    size : int, optional
        If given and smaller than n_samples, CKA is computed on that many
        randomly chosen rows (same rows from both matrices).

    Returns
    -------
    float
        CKA similarity in [0, 1]; 0 if either centered Gram matrix is zero.
    """
    n_samples = rep_a.shape[0]
    if size is not None and size < n_samples:
        picked = np.random.choice(n_samples, size, replace=False)
        rep_a = rep_a[picked]
        rep_b = rep_b[picked]

    def _centered_gram(x):
        """Linear-kernel Gram matrix, doubly centered (H K H with H = I - 1/n)."""
        gram = x @ x.T
        n = gram.shape[0]
        centering = np.eye(n) - np.ones((n, n)) / n
        return centering @ gram @ centering

    gram_a = _centered_gram(rep_a)
    gram_b = _centered_gram(rep_b)

    hsic = np.sum(gram_a * gram_b)
    norm = np.sqrt(np.sum(gram_a**2) * np.sum(gram_b**2))
    return hsic / norm if norm != 0 else 0
58
-
59
def cca(X, Y, size = None, num_components=10):
    """
    Perform Canonical Correlation Analysis (CCA) between two datasets and
    return the mean canonical correlation over the leading components.

    Parameters:
    X : np.ndarray
        First dataset, shape (n_samples, n_features_X).
    Y : np.ndarray
        Second dataset, shape (n_samples, n_features_Y).
    size : int, optional
        Number of randomly sampled rows to use; None uses all samples.
    num_components : int
        Number of leading canonical correlations to average.

    Returns:
    float
        Mean of the first `num_components` canonical correlations,
        each clipped to [0, 1].
    """

    # If sample size is specified, randomly sample a subset of the data.
    if size is not None and size < X.shape[0]:
        indices = np.random.choice(X.shape[0], size, replace=False)
        X = X[indices]
        Y = Y[indices]

    # Standardize both datasets (mean = 0, variance = 1).
    scaler_X = StandardScaler()
    scaler_Y = StandardScaler()
    X = scaler_X.fit_transform(X)
    Y = scaler_Y.fit_transform(Y)

    # Covariance matrices.
    C_XX = np.cov(X, rowvar=False)
    C_YY = np.cov(Y, rowvar=False)
    # Cross-covariance block of the joint covariance.
    C_XY = np.cov(X, Y, rowvar=False)[:X.shape[1], X.shape[1]:]

    # Regularization term to avoid singular matrices.
    reg = 1e-6
    inv_C_XX = np.linalg.inv(C_XX + reg * np.eye(C_XX.shape[0]))
    inv_C_YY = np.linalg.inv(C_YY + reg * np.eye(C_YY.shape[0]))

    # Generalized eigenvalue problems for the two sides of the CCA.
    A = inv_C_XX @ C_XY @ inv_C_YY @ C_XY.T
    B = inv_C_YY @ C_XY.T @ inv_C_XX @ C_XY

    # Eigen-decomposition; eigenvalues are squared canonical correlations.
    eigvals_X, eigvecs_X = np.linalg.eigh(A)
    eigvals_Y, eigvecs_Y = np.linalg.eigh(B)

    # Sort in descending order so the leading components come first.
    idx_X = np.argsort(eigvals_X)[::-1]
    idx_Y = np.argsort(eigvals_Y)[::-1]
    eigvecs_X = eigvecs_X[:, idx_X]
    eigvecs_Y = eigvecs_Y[:, idx_Y]

    # Canonical correlations (sqrt of eigenvalues, constrained to [0, 1]).
    corrs = np.sqrt(np.clip(eigvals_X[:num_components], 0, 1))

    # Fixed: removed the unreachable `return w_X, w_Y, corrs` that followed
    # this statement; the function has always returned the mean correlation.
    return np.mean(corrs)
128
-
129
def pca(X, Y, size=1000, n_components=3, bins=30):
    """Similarity score in [0, 1] between two representations, based on the
    Wasserstein distance between histograms of their leading PCA components.

    Parameters:
    X, Y : np.ndarray of shape (n_samples, n_features)
    size : int, optional
        Number of randomly sampled rows to use (None uses all samples).
        Fixed: this parameter was previously accepted but ignored; it now
        subsamples like the sibling `cka`/`cca` helpers.
    n_components : int
        Number of PCA components compared.
    bins : int
        Number of shared histogram bins.
    """
    # Step 0: optional row subsampling (same rows from both matrices).
    if size is not None and size < X.shape[0]:
        indices = np.random.choice(X.shape[0], size, replace=False)
        X = X[indices]
        Y = Y[indices]

    # Step 1: project each dataset onto its own leading components.
    pca_X = PCA(n_components=n_components)
    X_pca = pca_X.fit_transform(X)
    pca_Y = PCA(n_components=n_components)
    Y_pca = pca_Y.fit_transform(Y)

    # Step 2: common bin edges spanning both projections.
    min_value = min(X_pca.min(), Y_pca.min())
    max_value = max(X_pca.max(), Y_pca.max())
    bin_edges = np.linspace(min_value, max_value, bins + 1)

    # Step 3: per-component normalized histograms on the shared bins.
    histograms_X = [np.histogram(X_pca[:, i], bins=bin_edges, density=True)[0] for i in range(n_components)]
    histograms_Y = [np.histogram(Y_pca[:, i], bins=bin_edges, density=True)[0] for i in range(n_components)]

    # Step 4: sum the Wasserstein distances between corresponding histograms.
    total_distance = sum(
        wasserstein_distance(histograms_X[i], histograms_Y[i])
        for i in range(n_components)
    )

    # Step 5: normalize into a similarity score.
    # NOTE(review): max_distance = 1.0 is a placeholder normalization; replace
    # with a dataset-appropriate maximum if scores saturate at 0.
    max_distance = 1.0
    similarity_score = 1 - (total_distance / max_distance)
    return max(0, min(1, similarity_score))  # clamp to [0, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/uproot_dataset.py DELETED
@@ -1,54 +0,0 @@
1
- from root_gnn_base import dataset
2
- import torch
3
- import uproot
4
- import glob
5
- import os
6
- import awkward as ak
7
- import numpy as np
8
- import time
9
-
10
def node_features_from_ak(ch, node_branch_names, node_branch_types, node_feature_scales):
    """Assemble scaled per-node feature arrays from an awkward-array chain.

    node_branch_names is indexed [feature][node_type]: each entry is either a
    branch name (str), a constant (int/float), or one of the sentinels
    'CALC_E' / 'NODE_TYPE' applying to the whole feature row.
    Returns an awkward array; per the inline note its axis order at the end
    is (feature, event, node).
    """
    node_types = []
    # Number of node types is taken from the first feature's entry list.
    n_types = len(node_branch_names[0])
    for i in range(n_types):
        features = []
        branch_type = node_branch_types[i]
        for j in range(len(node_branch_names)):
            if node_branch_names[j] == 'CALC_E':
                # Derived energy: assumes feature 0 is pt and feature 1 is eta
                # (E = pt * cosh(eta)) — TODO confirm column convention.
                features.append(features[0] * np.cosh(features[1]))
            elif node_branch_names[j] == 'NODE_TYPE':
                # Constant node-type code for this object type.
                features.append(ak.full_like(features[0], i))
            elif isinstance(node_branch_names[j][i], str):
                # Read the named branch from the chain.
                features.append(ch[node_branch_names[j][i]])
            elif isinstance(node_branch_names[j][i], (int, float)):
                # Broadcast a constant to the shape of the first feature.
                features.append(ak.full_like(features[0], node_branch_names[j][i]))
        if branch_type == 'single':
            # Scalar-per-event branches get a singleton node axis.
            features = [f[:,np.newaxis] for f in features]
        node_types.append(ak.Array(features))
    node_features = ak.concatenate(node_types, axis=2) * node_feature_scales #axis order at this point is (feature, event, node)
    return node_features
30
-
31
class UprootDataset(dataset.RootDataset):
    """RootDataset variant that loads events with uproot/awkward instead of a
    ROOT TChain, building one fully connected graph per event."""

    def process(self):
        """Load all matching files, build per-event graphs with node features,
        and assemble the (label, fold) tensor; prints stage timings."""
        starttime = time.time()
        self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
        branches = self.get_list_of_branches()
        # Concatenate the named tree across all files into one awkward record.
        self.chain = uproot.concatenate([f + ':' + self.tree_name for f in self.files], branches, num_workers=4)
        node_features = node_features_from_ak(self.chain, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
        loadtime = time.time()
        n_nodes = ak.num(node_features[0], axis=1) #number of nodes for each event
        ftime = time.time()
        self.graphs = [dataset.full_connected_graph(n, False) for n in n_nodes]
        itime = time.time()
        for i in range(len(self.graphs)):
            if i % 10000 == 0:
                print(f'Processing event {i}/{len(self.graphs)}')
            # Transpose (feature, node) -> (node, feature) for this event.
            self.graphs[i].ndata['features'] = torch.transpose(torch.tensor(node_features[:,i,:]),0,1).to(torch.float)
        # NOTE(review): self.label is read on the right-hand side and then
        # overwritten with a (n_events, 2) tensor of (class label, fold index)
        # — the scalar class label must be set before process() runs; confirm.
        self.label = torch.stack([torch.full((len(self.graphs),),torch.tensor(self.label)), torch.tensor(ak.values_astype(self.chain[self.fold_var], np.int64))], dim=1)
        gtime = time.time()
        print()
        print(f'load time: {loadtime - starttime} s')
        print(f'feature time: {ftime - loadtime} s')
        print(f'graph time: {itime - ftime} s')
        print(f'graph data time: {gtime - itime} s')
54
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_base/utils.py DELETED
@@ -1,307 +0,0 @@
1
- import importlib
2
- import yaml
3
- import os
4
- import torch
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import dgl
8
- import signal
9
-
10
def buildFromConfig(conf, run_time_args=None):
    """Instantiate ``conf['class']`` from ``conf['module']`` with
    ``conf['args']`` merged with ``run_time_args``.

    Returns the new instance, or None (after printing a warning) when the
    config names no module.
    """
    # Fixed: mutable default argument ({}) replaced with a None sentinel.
    if run_time_args is None:
        run_time_args = {}
    if 'module' not in conf:
        print('No module specified in config. Returning None.')
        return None
    module = importlib.import_module(conf['module'])
    cls = getattr(module, conf['class'])
    return cls(**conf['args'], **run_time_args)
17
-
18
def cycler(iterable):
    """Endlessly re-iterate *iterable*, restarting from the start each pass.

    The source is re-iterated rather than cached, so it should be a
    re-iterable container (a one-shot iterator would be exhausted after the
    first pass).
    """
    while True:
        yield from iterable
23
-
24
def include_config(conf):
    """Merge every YAML file listed under conf['include'] into conf, in place.

    The 'include' key is removed afterwards. Later includes overwrite keys
    from earlier ones (plain dict.update semantics, no recursion).
    """
    if 'include' not in conf:
        return
    for path in conf['include']:
        with open(path) as fh:
            conf.update(yaml.load(fh, Loader=yaml.FullLoader))
    del conf['include']
30
-
31
def load_config(config_file):
    """Load a YAML config file and resolve its 'include' directives."""
    with open(config_file) as fh:
        conf = yaml.load(fh, Loader=yaml.FullLoader)
    include_config(conf)
    return conf
36
-
37
- #Timeout function from https://stackoverflow.com/questions/492519/timeout-on-a-function-call
38
#Timeout function from https://stackoverflow.com/questions/492519/timeout-on-a-function-call
class TimeoutException(Exception):
    """Raised by timeout_handler when a SIGALRM fires before the guarded call finishes."""
    pass
40
-
41
def timeout_handler(signum, frame):
    """SIGALRM handler: abort the current call by raising TimeoutException."""
    raise TimeoutException()
43
-
44
def set_timeout(timeout):
    """Arm a SIGALRM that raises TimeoutException after `timeout` seconds.

    Installs timeout_handler as the process-wide SIGALRM handler; pair with
    unset_timeout() to cancel and restore the default handler.
    """
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
47
-
48
def unset_timeout():
    """Cancel any pending alarm and restore the default SIGALRM handler."""
    signal.alarm(0)
    signal.signal(signal.SIGALRM, signal.SIG_DFL)
51
-
52
def make_padding_graph(batch, pad_nodes, pad_edges):
    """Append a zero-featured padding graph of `pad_nodes` nodes and
    `pad_edges` edges to `batch`, mirroring its ndata/edata keys.

    Edges are laid out deterministically (sender i*pad_nodes..., receiver
    shifted by one so edges are not self-loops); if the requested sizes are
    inconsistent, edges are repeated as needed with a warning.
    """
    # NOTE(review): these two empty-list assignments are dead code — both
    # names are immediately reassigned below.
    senders = []
    receivers = []
    senders = torch.arange(0,pad_edges) // pad_nodes
    receivers = torch.arange(1,pad_edges+1) % pad_nodes
    if pad_nodes < 0 or pad_edges < 0 or pad_edges > pad_nodes * pad_nodes / 2:
        print('Batch is larger than padding size or e > n^2/2. Repeating edges as necessary.')
        print(f'Batch nodes: {batch.num_nodes()}, Batch edges: {batch.num_edges()}, Padding nodes: {pad_nodes}, Padding edges: {pad_edges}')
        # Wrap sender indices back into range so the graph stays valid.
        senders = senders % pad_nodes
    padg = dgl.graph((senders[:pad_edges], receivers[:pad_edges]), num_nodes = pad_nodes)
    # Zero-fill every node/edge feature the real batch carries.
    for k in batch.ndata.keys():
        padg.ndata[k] = torch.zeros( (pad_nodes, batch.ndata[k].shape[1]) )
    for k in batch.edata.keys():
        padg.edata[k] = torch.zeros( (pad_edges, batch.edata[k].shape[1]) )
    return dgl.batch([batch, padg.to(batch.device)])
67
-
68
def pad_size(graphs, edges, nodes, edge_per_graph=3, node_per_graph=14):
    """Round `nodes`/`edges` up to the next multiple of the per-batch step.

    The step is (per-graph count) * (graphs per batch); counts already on a
    multiple are still bumped to the next one.
    """
    node_step = node_per_graph * graphs
    edge_step = edge_per_graph * graphs
    pad_nodes = (nodes // node_step + 1) * node_step
    pad_edges = (edges // edge_step + 1) * edge_step
    return pad_nodes, pad_edges
72
-
73
def pad_batch_to_step_per_graph(batch, edge_per_graph=3, node_per_graph=14):
    """Pad `batch` by a node/edge count derived from the per-graph step size."""
    # Number of graphs in the batch.
    n_graphs = batch.batch_num_nodes().shape[0]
    # NOTE(review): (x + step) % step reduces to x % step, so this pads by the
    # current remainder rather than up to the next multiple; padding to a
    # multiple would be (-x) % step. Confirm which behavior is intended.
    pad_nodes = (batch.num_nodes() + node_per_graph * n_graphs) % int(n_graphs * node_per_graph)
    pad_edges = (batch.num_edges() + edge_per_graph * n_graphs) % int(n_graphs * edge_per_graph)
    return make_padding_graph(batch, pad_nodes, pad_edges)
78
-
79
def pad_batch(batch, edges = 104000, nodes = 16000):
    """Pad `batch` up to fixed total node/edge counts via make_padding_graph.

    Passing edges == nodes == 0 disables padding and returns the batch as-is.
    """
    if edges == 0 and nodes == 0:
        return batch
    # Fixed: removed the dead `pad_nodes = 0` / `pad_edges = 0` assignments
    # that were immediately overwritten.
    pad_nodes = nodes - batch.num_nodes()
    pad_edges = edges - batch.num_edges()
    return make_padding_graph(batch, pad_nodes, pad_edges)
87
-
88
def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
    """Pad every graph in `batch` to exactly `max_num_nodes` nodes.

    Added nodes are isolated and zero-featured. Also attaches per-node data:
    'padding_mask' (bool, True for all-zero feature rows) and 'w'
    (hid_size-wide weights, zeroed for padding nodes).
    """
    print(f"Padding each graph to have {max_num_nodes} nodes")

    unbatched = dgl.unbatch(batch)
    for g in unbatched:
        num_nodes_to_add = max_num_nodes - g.number_of_nodes()
        if num_nodes_to_add > 0:
            g.add_nodes(num_nodes_to_add)  # added nodes get zero-filled features

    batch = dgl.batch(unbatched)

    feats = batch.ndata['features']
    # A node counts as padding iff its feature row is entirely zero — same
    # criterion as before, but vectorized instead of a Python loop over rows.
    # (A genuine all-zero physics node is also flagged, as previously.)
    padding_mask = torch.count_nonzero(feats, dim=1) == 0
    global_update_weights = torch.ones((feats.shape[0], hid_size))
    global_update_weights[padding_mask] = 0

    batch.ndata['w'] = global_update_weights
    batch.ndata['padding_mask'] = padding_mask

    return batch
111
-
112
-
113
def fold_selection(fold_config, sample):
    """Build a predicate selecting events belonging to the configured fold(s).

    fold_config['n_folds'] is the total fold count; fold_config[sample] is
    either a single fold index (int) or a list of fold indices. The returned
    callable maps a dataset-like object with a `tracking` tensor (fold key in
    column 0) to a boolean mask.

    Raises ValueError for any other fold-option type.
    """
    n_folds = fold_config['n_folds']
    folds_opt = fold_config[sample]
    # Note: type() checks kept (not isinstance) so bool is still rejected.
    # Fixed: removed the unused `folds = []` local and debug prints.
    if type(folds_opt) == int:
        return lambda x: x.tracking[:, 0] % n_folds == folds_opt
    elif type(folds_opt) == list:
        # Event passes when its fold matches exactly one of the listed folds.
        return lambda x: sum(x.tracking[:, 0] % n_folds == f for f in folds_opt) == 1
    else:
        raise ValueError("Invalid fold selection option with type {}".format(type(folds_opt)))
126
-
127
def fold_selection_name(fold_config, sample):
    """Return a short string tag identifying a fold selection,
    e.g. 'n_5_f_2' or 'n_5_f_0_3'."""
    n_folds = fold_config['n_folds']
    folds_opt = fold_config[sample]
    if type(folds_opt) == int:
        return f'n_{n_folds}_f_{folds_opt}'
    if type(folds_opt) == list:
        fold_tag = '_'.join(str(f) for f in folds_opt)
        return f'n_{n_folds}_f_{fold_tag}'
    raise ValueError(f'Invalid fold selection option with type {type(folds_opt)}')
136
-
137
- #Return the index and checkpoint of the last epoch.
138
#Return the index and checkpoint of the last epoch.
def get_last_epoch(config, max_ep = -1, device = None):
    """Scan for consecutive `model_epoch_{ep}.pt` files and load the last one.

    Returns (last_epoch, checkpoint); (-1, None) when no checkpoint exists.
    """
    last_epoch = -1
    checkpoint = None
    if max_ep < 0:
        max_ep = config['Training']['epochs']
    for ep in range(max_ep):
        ckpt_path = os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt')
        if not os.path.exists(ckpt_path):
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', ckpt_path)
            break
        last_epoch = ep
    if last_epoch >= 0:
        checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
    return last_epoch, checkpoint
153
-
154
- #Return the index and checkpoint of the last epoch.
155
#Return the index and checkpoint of the last epoch.
def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
    """Find the last consecutive checkpoint up to `target_epoch` and load it.

    When `from_ryan` is set, checkpoints are read from the shared CFS
    directory instead of the local training directory.
    Returns (last_epoch, checkpoint); (-1, None) when nothing is found.
    """
    # Fixed: the from_ryan/local branches duplicated the whole scan loop;
    # collapse them into a single path prefix.
    prefix = '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/' if from_ryan else ''
    last_epoch = -1
    checkpoint = None
    for ep in range(target_epoch + 1):
        path = os.path.join(prefix + config['Training_Directory'], f'model_epoch_{ep}.pt')
        if os.path.exists(path):
            last_epoch = ep
        else:
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', path)
            break
    if last_epoch >= 0:
        # Preserve the original load-path construction exactly (the from_ryan
        # branch concatenates the prefix before the join).
        if from_ryan:
            load_path = prefix + os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt')
        else:
            load_path = os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt')
        checkpoint = torch.load(load_path, map_location=device)
    return last_epoch, checkpoint
179
-
180
- #Convert training logs into dict for plotting.
181
#Convert training logs into dict for plotting.
def read_log(config):
    """Parse `<Training_Directory>/training.log` into {label: np.ndarray}.

    Only lines containing 'Epoch' are kept; each line is pipe-delimited into
    `label value` fields, with labels taken from the first kept line.
    """
    with open(config['Training_Directory'] + '/training.log', 'r') as fh:
        epoch_lines = [ln for ln in fh.readlines() if 'Epoch' in ln]
    labels = [field.split()[0] for field in epoch_lines[0].split('|')]
    log = {label: np.zeros(len(epoch_lines)) for label in labels}
    for i, line in enumerate(epoch_lines):
        for field in line.split('|'):
            tokens = field.split()
            log[tokens[0]][i] = float(tokens[1])
    return log
196
-
197
- #Plot training logs.
198
#Plot training logs.
def plot_log(log, output_file):
    """Render a 2x2 summary figure (cumulative time, losses, accuracy, AUC)
    from a log dict produced by read_log, and save it to `output_file`."""
    fig, ax = plt.subplots(2, 2, figsize=(10,10))
    #Time

    # Per-epoch times are cumulated into total wall-clock time.
    ax[0][0].plot(log['Epoch'], np.cumsum(log['Time']), label='Time')
    ax[0][0].set_xlabel('Epoch')
    ax[0][0].set_ylabel('Time (s)')
    ax[0][0].legend()

    # Disabled alternative panel: learning-rate curve in the same slot.
    """
    ax[0][0].plot(log['Epoch'], log['LR'], label='Learning Rate')
    ax[0][0].set_xlabel('Epoch')
    ax[0][0].set_ylabel('Learning Rate')
    ax[0][0].set_yscale('log')
    ax[0][0].legend()
    """

    #Loss
    ax[0][1].plot(log['Epoch'], log['Loss'], label='Train Loss')
    ax[0][1].plot(log['Epoch'], log['Test_Loss'], label='Test Loss')
    ax[0][1].set_xlabel('Epoch')
    ax[0][1].set_ylabel('Loss')
    ax[0][1].legend()

    #Accuracy
    ax[1][0].plot(log['Epoch'], log['Accuracy'], label='Test Accuracy')
    ax[1][0].set_xlabel('Epoch')
    ax[1][0].set_ylabel('Accuracy')
    # NOTE(review): hard-coded y-range clips accuracies outside [0.44, 0.56].
    ax[1][0].set_ylim((0.44, 0.56))
    ax[1][0].legend()

    #AUC
    ax[1][1].plot(log['Epoch'], log['Test_AUC'], label='Test AUC')
    ax[1][1].set_xlabel('Epoch')
    ax[1][1].set_ylabel('AUC')
    ax[1][1].legend()

    fig.savefig(output_file)
236
-
237
class EarlyStop():
    """Track a validation metric across epochs and flag when it stalls.

    After `patience` consecutive updates without an improvement of more than
    `threshold` (minimizing when mode='min', maximizing when mode='max'),
    `should_stop` becomes True.
    """

    def __init__(self, patience=15, threshold=1e-8, mode='min'):
        self.patience = patience
        self.threshold = threshold
        self.mode = mode
        self.count = 0
        self.current_best = np.inf if mode == 'min' else -np.inf
        self.should_stop = False

    def update(self, value):
        """Record a new metric value and update the stall counter."""
        if self.mode == 'min':  # minimizing a loss
            if value < self.current_best - self.threshold:
                self.current_best, self.count = value, 0
            else:
                self.count += 1
        elif self.mode == 'max':  # maximizing a metric
            if value > self.current_best + self.threshold:
                self.current_best, self.count = value, 0
            else:
                self.count += 1
        # Trip the stop flag once patience is exhausted.
        if self.count >= self.patience:
            self.should_stop = True

    def reset(self):
        """Clear the counter, best value, and stop flag."""
        self.count = 0
        self.current_best = np.inf if self.mode == 'min' else -np.inf
        self.should_stop = False

    def to_str(self):
        """Human-readable multi-line status summary."""
        mode_name = 'Minimize' if self.mode == 'min' else 'Maximize'
        lines = [
            "EarlyStop Status:",
            f"  Mode: {mode_name}",
            f"  Patience: {self.patience}",
            f"  Threshold: {self.threshold:.3e}",
            f"  Current Best: {self.current_best:.6f}",
            f"  Consecutive Epochs Without Improvement: {self.count}",
            f"  Stopping Triggered: {'Yes' if self.should_stop else 'No'}",
        ]
        return "\n".join(lines)

    def to_dict(self):
        """Serializable snapshot of the full state (for checkpointing)."""
        return {
            'patience': self.patience,
            'threshold': self.threshold,
            'mode': self.mode,
            'count': self.count,
            'current_best': self.current_best,
            'should_stop': self.should_stop,
        }

    @classmethod
    def load_from_dict(cls, state_dict):
        """Rebuild an instance from a to_dict() snapshot."""
        instance = cls(
            patience=state_dict['patience'],
            threshold=state_dict['threshold'],
            mode=state_dict['mode'],
        )
        instance.count = state_dict['count']
        instance.current_best = state_dict['current_best']
        instance.should_stop = state_dict['should_stop']
        return instance
303
-
304
-
305
def graph_augmentation(graph):
    """Placeholder augmentation hook: currently only logs and returns None.

    The `graph` argument is accepted but not yet used.
    """
    print("Augmenting Graph")
    return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/find_free_port.py DELETED
@@ -1,12 +0,0 @@
1
# find_free_port.py
def find_free_port():
    """Bind an ephemeral TCP port, release it, and return its number as a string."""
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(('', 0))
        # Allow immediate reuse so the caller can bind the reported port.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(sock.getsockname()[1])

if __name__ == "__main__":
    print(find_free_port())
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/inference.py DELETED
@@ -1,289 +0,0 @@
1
- import sys
2
- import os
3
- file_path = os.getcwd()
4
- sys.path.append(file_path)
5
-
6
- import argparse
7
- import yaml
8
-
9
- import torch
10
- import dgl
11
- from dgl.data import DGLDataset
12
- from dgl.dataloading import GraphDataLoader
13
- from torch.utils.data import SubsetRandomSampler, SequentialSampler
14
-
15
-
16
- def my_error_handler(level, abort, location, msg):
17
- # Log the error message to a file instead of printing
18
- with open("error_log.txt", "a") as log_file:
19
- log_file.write(f"Error in {location}: {msg}\n")
20
-
21
- # Optionally, print the error message to the console
22
- # print(f"Error in {location}: {msg}")
23
-
24
- # Decide whether to abort based on the error level
25
- if abort:
26
- raise RuntimeError(f"Fatal error in {location}: {msg}")
27
-
28
- class CustomPreBatchedDataset(DGLDataset):
29
- def __init__(self, start_dataset, batch_size, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
30
- self.start_dataset = start_dataset
31
- self.batch_size = batch_size
32
- self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
33
- self.drop_last = drop_last
34
- self.shuffle = shuffle
35
- super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)
36
-
37
- def process(self):
38
- mask = self.mask_fn(self.start_dataset)
39
- indices = torch.arange(len(self.start_dataset))[mask]
40
- print(f"Number of elements after masking: {len(indices)}") # Debugging print
41
-
42
- if self.shuffle:
43
- sampler = SubsetRandomSampler(indices)
44
- else:
45
- sampler = SequentialSampler(indices)
46
-
47
- self.dataloader = GraphDataLoader(
48
- self.start_dataset,
49
- sampler=sampler,
50
- batch_size=self.batch_size,
51
- drop_last=self.drop_last
52
- )
53
- print(f"Batch size set in DataLoader: {self.batch_size}") # Debugging print
54
-
55
- def __getitem__(self, idx):
56
- if isinstance(idx, int):
57
- idx = [idx]
58
- sampler = SequentialSampler(idx)
59
- dloader = GraphDataLoader(self.start_dataset, sampler=sampler, batch_size=self.batch_size, drop_last=False)
60
- return next(iter(dloader))
61
-
62
- def __len__(self):
63
- return len(self.start_dataset)
64
-
65
- def include_config(conf):
66
- if 'include' in conf:
67
- for i in conf['include']:
68
- with open(i) as f:
69
- conf.update(yaml.load(f, Loader=yaml.FullLoader))
70
- del conf['include']
71
-
72
- def load_config(config_file):
73
- with open(config_file) as f:
74
- conf = yaml.load(f, Loader=yaml.FullLoader)
75
- include_config(conf)
76
- return conf
77
-
78
- def main():
79
- parser = argparse.ArgumentParser()
80
- add_arg = parser.add_argument
81
- add_arg('--config', type=str, required=True)
82
- add_arg('--target', type=str, required=True)
83
- add_arg('--destination', type=str, default='')
84
- add_arg('--chunkno', type=int, default=0)
85
- add_arg('--chunks', type=int, default=1)
86
- add_arg('--write', action='store_true')
87
- add_arg('--ckpt', type=int, default=-1)
88
- add_arg('--clobber', action='store_true')
89
- add_arg('--tree', type=str, default='')
90
- add_arg('--branch_name', type=str, default='score')
91
- args = parser.parse_args()
92
-
93
- config = load_config(args.config)
94
- if args.destination == '':
95
- args.destination = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
96
- else:
97
- args.destination = args.destination
98
- if not args.write:
99
- args.destination = args.destination.replace('.root', '') + f'_chunk{args.chunkno}.npz'
100
-
101
- if os.path.exists(args.destination):
102
- print(f'File {args.destination} already exists.')
103
- if args.clobber:
104
- print('Clobbering.')
105
- else:
106
- print('Exiting.')
107
- return
108
- else:
109
- print(f'Writing to {args.destination}')
110
-
111
- import time
112
- start = time.time()
113
- import ROOT
114
- import torch
115
- from array import array
116
- import numpy as np
117
- from root_gnn_base import batched_dataset as dataset
118
- from root_gnn_base import utils
119
- end = time.time()
120
- print('Imports finished in {:.2f} seconds'.format(end - start))
121
-
122
- start = time.time()
123
- dset_config = config['Datasets'][list(config['Datasets'].keys())[0]]
124
- if dset_config['class'] == 'LazyDataset':
125
- dset_config['class'] = 'EdgeDataset'
126
- elif dset_config['class'] == 'LazyMultiLabelDataset':
127
- dset_config['class'] = 'MultiLabelDataset'
128
- elif dset_config['class'] == 'PhotonIDDataset':
129
- dset_config['class'] = 'UnlazyPhotonIDDataset'
130
- elif dset_config['class'] == 'kNNDataset':
131
- dset_config['class'] = 'UnlazyKNNDataset'
132
- dset_config['args']['raw_dir'] = os.path.split(args.target)[0]
133
- dset_config['args']['file_names'] = os.path.split(args.target)[1]
134
- dset_config['args']['save'] = False
135
- dset_config['args']['chunks'] = args.chunks
136
- dset_config['args']['process_chunks'] = [args.chunkno,]
137
- dset_config['args']['selections'] = []
138
-
139
- dset_config['args']['save_dir'] = os.path.dirname(args.destination)
140
-
141
- if args.tree != '':
142
- dset_config['args']['tree_name'] = args.tree
143
-
144
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
145
-
146
- dstart = time.time()
147
- dset = utils.buildFromConfig(dset_config)
148
- dend = time.time()
149
- print('Dataset finished in {:.2f} seconds'.format(dend - dstart))
150
-
151
- print(dset)
152
-
153
- batch_size = config['Training']['batch_size']
154
- lstart = time.time()
155
- loader = CustomPreBatchedDataset(dset, batch_size)
156
- loader.process()
157
- # loader = dataset.PreBatchedDataset(dset, batch_size, shuffle=False, drop_last=False, save_to_disk=False, chunks = 1, num_workers=0)
158
- lend = time.time()
159
- print('Loader finished in {:.2f} seconds'.format(lend - lstart))
160
- sample_graph, _, _, global_sample = loader[0]
161
-
162
- print('dset length =', len(dset))
163
- print('loader length =', len(loader))
164
-
165
- model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
166
- if args.ckpt < 0:
167
- ep, checkpoint = utils.get_last_epoch(config, args.ckpt, device=device)
168
- else:
169
- ep, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)
170
- #Bad filler for models which were compiled. Have to remove this prefix.
171
- mds_copy = {}
172
- for key in checkpoint['model_state_dict'].keys():
173
- newkey = key.replace('module.', '')
174
- newkey = newkey.replace('_orig_mod.', '')
175
- mds_copy[newkey] = checkpoint['model_state_dict'][key]
176
- model.load_state_dict(mds_copy)
177
- model.eval()
178
-
179
- end = time.time()
180
- print('Model and dataset finished in {:.2f} seconds'.format(end - start))
181
- print('Starting inference')
182
- start = time.time()
183
-
184
- finish_fn = torch.nn.Sigmoid()
185
- if 'Loss' in config:
186
- finish_fn = utils.buildFromConfig(config['Loss']['finish'])
187
-
188
- scores = []
189
- labels = []
190
- tracking_info = []
191
- ibatch = 0
192
-
193
- for batch, label, track, globals in loader.dataloader:
194
- batch = batch.to(device)
195
- pred = model(batch, globals.to(device))
196
- ibatch += 1
197
- # scores.append(finish_fn(pred).detach().cpu().numpy())
198
- if (finish_fn.__class__.__name__ == "ContrastiveClusterFinish"):
199
- scores.append(pred.detach().cpu().numpy())
200
- else:
201
- scores.append(finish_fn(pred).detach().cpu().numpy())
202
- labels.append(label.detach().cpu().numpy())
203
- tracking_info.append(track.detach().cpu().numpy())
204
-
205
- # for batch, label, track, globals in loader:
206
- # batch = batch.to(device)
207
- # pred = model(batch, globals.to(device))
208
- # print(f'Batch size: {batch.batch_size if hasattr(batch, "batch_size") else "Unavailable"}')
209
- # print(f'Prediction shape: {pred.shape}')
210
- # ibatch += 1
211
- # scores.append(finish_fn(pred).detach().cpu().numpy())
212
- # labels.append(label.detach().cpu().numpy())
213
- # tracking_info.append(track.detach().cpu().numpy())
214
- # exit()
215
-
216
- score_size = scores[0].shape[1]
217
- scores = np.concatenate(scores)
218
- labels = np.concatenate(labels)
219
- tracking_info = np.concatenate(tracking_info)
220
- end = time.time()
221
-
222
- print('Inference finished in {:.2f} seconds'.format(end - start))
223
-
224
- if args.write:
225
- # ROOT.SetErrorHandler(my_error_handler)
226
- ROOT.gErrorIgnoreLevel = ROOT.kFatal
227
- # ROOT.gSystem.RedirectOutput("/dev/null", "w")
228
-
229
- # Open the original ROOT file
230
- infile = ROOT.TFile.Open(args.target)
231
- tree = infile.Get(dset_config['args']['tree_name'])
232
-
233
- # Create the destination directory if it doesn't exist
234
- os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
235
-
236
- # Create a new ROOT file to write the modified tree
237
- outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
238
-
239
- # Clone the original tree, including data
240
- outtree = tree.CloneTree(0) # Clone all entries
241
-
242
- # Determine if scores is a list of single values or vectors
243
- from ROOT import std
244
- if isinstance(scores[0], (list, tuple, np.ndarray)): # Check if scores contains vectors
245
- # Create a new branch for scores as a vector of floats
246
- scores_branch_vec = std.vector('float')()
247
- outtree.Branch(args.branch_name, scores_branch_vec)
248
- is_vector = True
249
- else: # Scores contains single values
250
- # Create a new branch for scores as a single float
251
- score_branch_arr = array('f', [0])
252
- outtree.Branch(args.branch_name, score_branch_arr, f'{args.branch_name}/F')
253
- is_vector = False
254
-
255
- # Write scores to the new branch
256
- print(f'Writing {len(scores)} scores to tree')
257
-
258
- for i in range(tree.GetEntries()):
259
- tree.GetEntry(i)
260
-
261
- if is_vector:
262
- # Clear the vector
263
- scores_branch_vec.clear()
264
-
265
- # Add all elements from scores[i] to the vector
266
- for value in scores[i]:
267
- scores_branch_vec.push_back(float(value)) # Use push_back to add elements one by one
268
- else:
269
- # Fill the score branch with the current single score
270
- score_branch_arr[0] = float(scores[i]) # Ensure the value is a float
271
-
272
- # Fill the output tree with all branches, including the new scores branch
273
- outtree.Fill()
274
-
275
- # Write the modified tree to the new file
276
- print(f'Writing to file {args.destination}')
277
- print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
278
- outtree.Write()
279
- outfile.Close()
280
- infile.Close()
281
- else:
282
- os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
283
- np.savez(args.destination, scores=scores, labels=labels, tracking_info=tracking_info)
284
-
285
- if __name__ == '__main__':
286
- main()
287
-
288
-
289
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/prep_data.py DELETED
@@ -1,43 +0,0 @@
1
- import sys
2
- import os
3
- file_path = os.getcwd()
4
- sys.path.append(file_path)
5
-
6
- import root_gnn_base.utils as utils
7
- import argparse
8
- from root_gnn_base.batched_dataset import PreBatchedDataset
9
- from root_gnn_base.batched_dataset import LazyPreBatchedDataset
10
-
11
- def main():
12
- parser = argparse.ArgumentParser()
13
- add_arg = parser.add_argument
14
- add_arg('--config', type=str, required=True)
15
- add_arg('--dataset', type=str, required=True)
16
- add_arg('--chunk', type=int, default=0)
17
- add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
18
- args = parser.parse_args()
19
-
20
- config = utils.load_config(args.config)
21
- dset_config = config['Datasets'][args.dataset]
22
- batch_size = config['Training']['batch_size']
23
- if not args.shuffle_mode:
24
- dset = utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk,]})
25
- else:
26
- dset = utils.buildFromConfig(dset_config)
27
- if 'batch_size' in dset_config:
28
- batch_size = dset_config['batch_size']
29
-
30
- shuffle_chunks = dset_config.get('shuffle_chunks', 10)
31
- padding_mode = dset_config.get('padding_mode', 'STEPS')
32
- fold_conf = dset_config["folding"]
33
- print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
34
- if dset_config["class"] == "LazyMultiLabelDataset":
35
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
36
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
37
-
38
- else:
39
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
40
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
41
-
42
- if __name__ == "__main__":
43
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/training_script.py DELETED
@@ -1,755 +0,0 @@
1
- import argparse
2
- import time
3
- import datetime
4
- import yaml
5
- import os
6
-
7
- start_time = time.time()
8
-
9
- import dgl
10
- import torch
11
- import torch.nn as nn
12
-
13
- import sys
14
- file_path = os.getcwd()
15
- sys.path.append(file_path)
16
-
17
- import root_gnn_base.batched_dataset as datasets
18
- from root_gnn_base import utils
19
- import root_gnn_base.custom_scheduler as lr_utils
20
- from models import GCN
21
-
22
- import numpy as np
23
- from sklearn.metrics import roc_auc_score
24
- import resource
25
- import gc
26
-
27
- import torch.distributed as dist
28
- import torch.multiprocessing as mp
29
- from torch.utils.data.distributed import DistributedSampler
30
- from torch.nn.parallel import DistributedDataParallel as DDP
31
-
32
- print("import time: {:.4f} s".format(time.time() - start_time))
33
-
34
- def mem():
35
- print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')
36
-
37
- def gpu_mem():
38
- print()
39
- print('GPU Memory Usage:')
40
- sum = 0
41
- # for obj in gc.get_objects():
42
- # try:
43
- # if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
44
- # print(obj.numel() if len(obj.size()) > 0 else 0, type(obj), obj.size())
45
- # sum += obj.numel() if len(obj.size()) > 0 else 0
46
- # except:
47
- # pass
48
- print(f'Current GPU memory usage: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} GB')
49
- print(f'Current GPU cache usage: {torch.cuda.memory_cached() / 1024 / 1024 / 1024} GB')
50
- print(f'Current GPU max memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024} GB')
51
- print(f'Current GPU max cache usage: {torch.cuda.max_memory_cached() / 1024 / 1024 / 1024} GB')
52
- print(f'Numel in current tensors: {sum}')
53
- mem()
54
-
55
-
56
- ## epoch stores the epoch number I want to evaluate the model at
57
- def evaluate(val_loaders, model, config, device, epoch = -1):
58
- print("Evaluating")
59
-
60
- if (epoch != -1) :
61
- print(f"Evalulating at epoch {epoch}")
62
- last_ep, checkpoint = utils.get_specific_epoch(config, epoch, from_ryan=False)
63
- print(f"Evaluating at epoch = {last_ep}")
64
- else:
65
- starting_epoch = 0
66
- last_ep, checkpoint = utils.get_last_epoch(config)
67
-
68
- if checkpoint != None:
69
- ep = last_ep
70
- state_dict = checkpoint['model_state_dict']
71
- new_state_dict = {}
72
- for k, v in state_dict.items():
73
- new_key = k.replace('module.', '')
74
- new_state_dict[new_key] = v
75
- model.load_state_dict(new_state_dict)
76
- starting_epoch = checkpoint['epoch'] + 1
77
- print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
78
-
79
- if 'Loss' not in config:
80
- loss_fcn = nn.BCEWithLogitsLoss()
81
- else:
82
- loss_fcn = utils.buildFromConfig(config['Loss'])
83
- if len(val_loaders) == 0:
84
- return "No validation data"
85
- start = time.time()
86
- scores = []
87
- labels = []
88
- weights = []
89
- before_decoder = []
90
- after_decoder = []
91
- tracking = []
92
-
93
- batch_size = config["Training"]["batch_size"]
94
-
95
- batch_limit = int(np.ceil(1e5 / batch_size))
96
-
97
- model.eval()
98
- with torch.no_grad():
99
- for loader in val_loaders:
100
- batch_count = 0
101
- for batch, label, track, global_feats in loader:
102
- #Don't use compiled model for testing since we can't control the batch size.
103
- #We could before, but it assumes each dataset has the same number of batches...
104
- before_global_decoder, after_global_decoder, after_classify = model.representation(batch.to(device), global_feats.to(device))
105
-
106
- scores.append(after_classify.to("cpu"))
107
- before_decoder.append(before_global_decoder.to("cpu"))
108
- after_decoder.append(after_global_decoder.to("cpu"))
109
- labels.append(label.to("cpu"))
110
- weights.append(track[:,1].to("cpu"))
111
- tracking.append(track.to("cpu"))
112
-
113
- batch_count += 1
114
- if batch_count >= batch_limit:
115
- break
116
-
117
- if scores == []: #If validation set is empty.
118
- return
119
- logits = torch.concatenate(scores)
120
- scores = torch.sigmoid(logits)
121
- labels = torch.concatenate(labels)
122
- weights = torch.concatenate(weights)
123
- before_decoder = torch.concatenate(before_decoder)
124
- after_decoder = torch.concatenate(after_decoder)
125
- tracking = torch.concatenate(tracking)
126
-
127
- logits = logits.to("cpu").numpy()
128
- scores = scores.to("cpu").numpy()
129
- labels = labels.to("cpu").numpy()
130
- before_decoder = before_decoder.to("cpu").numpy()
131
- after_decoder = after_decoder.to("cpu").numpy()
132
- tracking = tracking.to("cpu").numpy()
133
-
134
- # Save the NumPy arrays to a .npz file
135
- outfile = f"{config['Training_Directory']}/evaluation_{epoch}.npz"
136
-
137
- np.savez(outfile, logits=logits, scores=scores, labels=labels, before_decoder=before_decoder, after_decoder=after_decoder, tracking=tracking)
138
-
139
- print(f"saved scores to {outfile}")
140
- return
141
-
142
-
143
- def train(train_loaders, test_loaders, model, device, config, args, rank):
144
- nocompile = args.nocompile
145
- restart = args.restart
146
- # define train/val samples, loss function and optimizer
147
- if 'Loss' not in config:
148
- loss_fcn = nn.BCEWithLogitsLoss()
149
- finish_fn = torch.nn.Sigmoid()
150
- else:
151
- loss_fcn = utils.buildFromConfig(config['Loss'])
152
- finish_fn = utils.buildFromConfig(config['Loss']['finish'])
153
-
154
- optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
155
- if 'gamma' in config['Training']:
156
- gamma = config['Training']['gamma']
157
- else:
158
- gamma = 1
159
-
160
- if 'dynamic_lr' in config['Training']:
161
- factor = config['Training']['dynamic_lr']['factor']
162
- patience = config['Training']['dynamic_lr']['patience']
163
- else:
164
- factor = 1
165
- patience = 1
166
-
167
- early_termination = utils.EarlyStop()
168
- if 'early_termination' in config['Training']:
169
- early_termination.patience = config['Training']['early_termination']['patience']
170
- early_termination.threshold = config['Training']['early_termination']['threshold']
171
- early_termination.mode = config['Training']['early_termination']['mode']
172
-
173
- scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)
174
- #scheduler_reset = custom_scheduler.Dynamic_LR(optimizer, 'max', factor = factor, patience = patience)
175
- custom_scheduler = None
176
- if ('custom_scheduler' in config['Training']):
177
- run_time_args = {}
178
- scheduler_class = config['Training']['custom_scheduler']['class']
179
- if (scheduler_class == 'Dynamic_LR' or
180
- scheduler_class == 'Dynamic_LR_AND_Partial_Reset' or
181
- scheduler_class == 'Dynamic_LR_AND_Full_Reset'):
182
-
183
- run_time_args={'optimizer': optimizer}
184
-
185
- custom_scheduler = utils.buildFromConfig(config['Training']['custom_scheduler'], run_time_args=run_time_args)
186
-
187
- starting_epoch = 0
188
- if not restart:
189
- last_ep, checkpoint = utils.get_last_epoch(config)
190
- if checkpoint != None:
191
- ep = starting_epoch - 1
192
- if nocompile:
193
- new_state_dict = {}
194
- for k, v in checkpoint['model_state_dict'].items():
195
- new_key = k.replace('module.', '')
196
- new_state_dict[new_key] = v
197
- checkpoint['model_state_dict'] = new_state_dict
198
- if (args.multinode or args.multigpu):
199
- new_state_dict = {}
200
- for k, v in checkpoint['model_state_dict'].items():
201
- new_key = 'module.' + k
202
- new_state_dict[new_key] = v
203
- checkpoint['model_state_dict'] = new_state_dict
204
- model.load_state_dict(checkpoint['model_state_dict'])
205
- else:
206
- model._orig_mod.load_state_dict(checkpoint['model_state_dict'])
207
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
208
- starting_epoch = checkpoint['epoch'] + 1
209
- if 'early_stop' in checkpoint:
210
- early_termination = utils.EarlyStop.load_from_dict(checkpoint['early_stop'])
211
- print(early_termination.to_str())
212
- print("EarlyStop state restored successfully.")
213
- if early_termination.should_stop:
214
- print(f"Early Termination at Epoch {epoch}")
215
- return
216
- else:
217
- print("'early_stop' not found in checkpoint. Initializing a new EarlyStop instance.")
218
- early_termination = utils.EarlyStop()
219
- print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
220
- log = open(config['Training_Directory'] + '/training.log', 'a', buffering=1)
221
- else:
222
- log = open(config['Training_Directory'] + '/training.log', 'w', buffering=1)
223
-
224
- train_cyclers = []
225
- for loader in train_loaders:
226
- train_cyclers.append(utils.cycler((loader)))
227
-
228
- if args.savecache:
229
- max_batch = [None,] * len(train_loaders)
230
- for dset_i, loader in enumerate(train_loaders):
231
- mbs = 0
232
- for batch_i, batch in enumerate(loader):
233
- if batch[0].num_nodes() > mbs:
234
- mbs = batch[0].num_nodes()
235
- max_batch[dset_i] = batch[0]
236
- print(f'Max batch size for dataset {dset_i}: {mbs}')
237
- big_batch = dgl.batch(max_batch).to(device)
238
- with torch.no_grad():
239
- model(big_batch)
240
-
241
- cumulative_times = [0,0,0,0,0]
242
- log.write(f'Training {config["Training_Name"]} {datetime.datetime.now()} \n')
243
- print(f"Starting training for {config['Training']['epochs']} epochs")
244
-
245
- if hasattr(train_loaders[0].dataset, 'padding_mode'):
246
- is_padded = train_loaders[0].dataset.padding_mode != 'NONE'
247
- if (train_loaders[0].dataset.padding_mode == 'NODE'):
248
- is_padded = False
249
- else:
250
- is_padded = False
251
-
252
- lr_utils.print_LR(optimizer)
253
-
254
- # torch.save({
255
- # 'epoch': 0,
256
- # 'model_state_dict': model.state_dict(),
257
- # 'optimizer_state_dict': optimizer.state_dict(),
258
- # }, os.path.join(config['Training_Directory'], f"model_epoch_{0}.pt"))
259
- # exit()
260
-
261
-
262
- # training loop
263
- # gpu_mem()
264
- for epoch in range(starting_epoch, config['Training']['epochs']):
265
- start = time.time()
266
- run = start
267
- if (args.multigpu or args.multinode):
268
- dist.barrier()
269
- if (epoch == 2):
270
- # torch.cuda.cudart().cudaProfilerStart()
271
- pass
272
-
273
- # training
274
- model.train()
275
- ibatch = 0
276
- total_loss = 0
277
- for batched_graph, labels, _, global_feats in train_loaders[0]:
278
- # # need to fix padded case
279
- # if is_padded:
280
- # tglobals.append(torch.zeros(1, len(global_feats[0])))
281
-
282
- batch_start = time.time()
283
- logits = torch.tensor([])
284
- tlabels = torch.tensor([])
285
- batch_lengths = []
286
- for cycler in train_cyclers:
287
- graph, label, _, global_feats = next(cycler)
288
- graph = graph.to(device)
289
- label = label.to(device)
290
- global_feats = global_feats.to(device)
291
- if is_padded: #Padding the globals to match padded graphs.
292
- global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
293
- load = time.time()
294
- if (len(logits) == 0):
295
- logits = model(graph, global_feats)
296
- tlabels = label
297
- else:
298
- logits = torch.concatenate((logits, model(graph, global_feats)), dim=0)
299
- tlabels = torch.concatenate((tlabels, label), dim=0)
300
- batch_lengths.append(logits.shape[0] - 1)
301
-
302
- if is_padded:
303
- keepmask = torch.full_like(logits[:,0], True, dtype=torch.bool)
304
- keepmask[batch_lengths] = False
305
- logits = logits[keepmask]
306
- tlabels = tlabels.to(torch.float)
307
- if logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'BCEWithLogitsLoss':
308
- logits = logits[:,0]
309
- tlabels = tlabels.to(torch.float)
310
- if loss_fcn.__class__.__name__ == 'CrossEntropyLoss':
311
- tlabels = tlabels.to(torch.long)
312
- loss = loss_fcn(logits, tlabels.to(device)) # changed logits from logits[:,0] and left labels as int for multiclass. Does this break binary? Yes.
313
- optimizer.zero_grad()
314
- loss.backward()
315
- optimizer.step()
316
- total_loss += loss.detach().cpu().item()
317
- ibatch += 1
318
- cumulative_times[0] += batch_start - run
319
- cumulative_times[1] += load - batch_start
320
- run = time.time()
321
- cumulative_times[2] += run - load
322
- if ibatch % 1000 == 0:
323
- print(f'Batch {ibatch} out of {len(train_loaders[0])}', end='\r')
324
- # gpu_mem()
325
-
326
- if (args.multigpu):
327
- print(f'Rank {rank} Epoch Done.')
328
- elif (args.multinode):
329
- print(f'Rank {args.global_rank} Epoch Done.')
330
- else:
331
- print("Epoch Done.")
332
- # validation
333
-
334
- scores = []
335
- labels = []
336
- weights = []
337
- model.eval()
338
- with torch.no_grad():
339
- for loader in test_loaders:
340
- for batch, label, track, global_feats in loader:
341
- #Don't use compiled model for testing since we can't control the batch size.
342
- #We could before, but it assumes each dataset has the same number of batches...
343
- if is_padded:
344
- global_feats = torch.cat([global_feats, torch.zeros(1, len(global_feats[0]))])
345
- if nocompile:
346
- batch_scores = model(batch.to(device), global_feats.to(device))
347
- else:
348
- batch_scores = model._orig_mod(batch.to(device), global_feats.to(device))
349
- if is_padded:
350
- scores.append(batch_scores[:-1,:])
351
- else:
352
- scores.append(batch_scores)
353
- labels.append(label)
354
- weights.append(track[:,1])
355
- eval_end = time.time()
356
- cumulative_times[3] += eval_end - run
357
-
358
- if scores == []: #If validation set is empty.
359
- continue
360
- logits = torch.concatenate(scores).to(device)
361
- labels = torch.concatenate(labels).to(device)
362
- weights = torch.concatenate(weights).to(device)
363
-
364
- if (args.multigpu or args.multinode):
365
- gathered_logits = [torch.zeros_like(logits) for _ in range(dist.get_world_size())]
366
- gathered_labels = [torch.zeros_like(labels) for _ in range(dist.get_world_size())]
367
- gathered_weights = [torch.zeros_like(weights) for _ in range(dist.get_world_size())]
368
-
369
- if (args.multigpu or args.multinode):
370
- dist.barrier()
371
- if (args.multigpu and rank != 0) or (args.multinode and args.global_rank != 0):
372
- dist.gather(logits, dst=0)
373
- dist.gather(labels, dst=0)
374
- dist.gather(weights, dst=0)
375
- continue
376
- else:
377
- dist.gather(logits, gather_list=gathered_logits)
378
- dist.gather(labels, gather_list=gathered_labels)
379
- dist.gather(weights, gather_list=gathered_weights)
380
-
381
- logits = torch.concatenate(gathered_logits)
382
- labels = torch.concatenate(gathered_labels)
383
- weights = torch.concatenate(gathered_weights)
384
-
385
- wgt_mask = weights > 0
386
-
387
- print(f"Num batches trained = {ibatch}")
388
-
389
- #Note: This section is a bit ugly. Very conditional. Should maybe config defined behavior?
390
- if (loss_fcn.__class__.__name__ == "ContrastiveClusterLoss"):
391
- scores = logits
392
- preds = scores
393
- accuracy = 0
394
- test_auc = 0
395
- acc = 0
396
- contrastive_cluster_loss = finish_fn(logits)
397
-
398
- elif (loss_fcn.__class__.__name__ == "MultiLabelLoss"):
399
- scores = finish_fn(logits)
400
- preds = torch.round(scores)
401
- multilabel_accuracy = []
402
- threshold = 0.1 # 10% threshold
403
-
404
- for i in range(len(labels[0])):
405
- # accurate_count = torch.sum(torch.abs(preds[:, i].to("cpu") - labels[:, i].to("cpu")) / labels[:, i].to("cpu") <= threshold)
406
- # multilabel_accruacy.append(accurate_count / len(labels))
407
- multilabel_accuracy.append(torch.sum(preds[:, i].to("cpu") == labels[:, i].to("cpu")) / len(labels))
408
- test_auc = 0
409
- acc = np.mean(multilabel_accuracy)
410
-
411
- elif logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'BCEWithLogitsLoss': #Proxy for binary classification.
412
- test_auc = 0
413
- acc = 0
414
- logits = logits[:,0]
415
- scores = finish_fn(logits)
416
- labels =labels.to(torch.float)
417
- preds = scores > 0.5
418
- test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), sample_weight=weights[wgt_mask].to("cpu"))
419
- acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels)
420
-
421
- elif logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'MSELoss':
422
- logits = logits[:,0]
423
- scores = finish_fn(logits)
424
- labels = labels.to(torch.float)
425
- acc = 0
426
- test_auc = 0
427
-
428
- else:
429
- preds = torch.argmax(logits, dim=1)
430
- scores = finish_fn(logits)
431
- if labels.dim() == 1: #Multi-class
432
- acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels) #TODO: Make each class weighted equally?
433
-
434
- labels = labels.to("cpu")
435
- weights = weights.to("cpu")
436
- logits = logits.to("cpu")
437
- wgt_mask = wgt_mask.to("cpu")
438
-
439
- labels_onehot = np.zeros((len(labels), len(scores[0])))
440
- labels_onehot[np.arange(len(labels)), labels] = 1
441
-
442
- try:
443
- #test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
444
- test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
445
- except ValueError:
446
- test_auc = np.nan
447
- else: #Multi-loss
448
- acc = torch.sum(preds.to("cpu") == labels[:,0].to("cpu")) / len(labels)
449
- try:
450
- test_auc = roc_auc_score(labels[:,0][wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
451
- except ValueError:
452
- test_auc = np.nan
453
-
454
-
455
- # print(f"logits = {logits[:10]}")
456
- # print(f"preds = {preds[:2]}")
457
- # print(f"labels = {labels[:10]}")
458
-
459
- # print(f"len(Unique logits) = {len(torch.unique(logits))}")
460
- # print(f"Average of labels = {torch.mean(labels)}")
461
- # print(f"unique logits = {torch.unique(logits)[0]:.4f}, {torch.unique(logits)[-1]:.4f}")
462
-
463
-
464
- if (loss_fcn.__class__.__name__ == "MultiLabelLoss"):
465
- multilabel_log_str = "MultiLabel_Accuracy "
466
- for accuracy in multilabel_accuracy:
467
- multilabel_log_str += f" | {accuracy:.4f}"
468
- log.write(multilabel_log_str + '\n')
469
- print(multilabel_log_str, flush=True)
470
- elif (loss_fcn.__class__.__name__ == "ContrastiveClusterLoss"):
471
- contrastive_cluster_log_str = "ContrastiveClusterLoss "
472
- contrastive_cluster_log_str += f"Contrastive Loss: {contrastive_cluster_loss[0]:.4f}, Clustering Loss: {contrastive_cluster_loss[1]:.4f}, Variance Loss: {contrastive_cluster_loss[2]:.4f}"
473
- log.write(contrastive_cluster_log_str + '\n')
474
- print(contrastive_cluster_log_str, flush=True)
475
-
476
- # test_loss = loss_fcn(logits, labels.to(device))
477
- test_loss = loss_fcn(logits, labels)
478
- end = time.time()
479
- log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
480
- epoch, optimizer.param_groups[0]['lr'], total_loss/ibatch, acc, test_loss, test_auc, end - start
481
- )
482
- log.write(log_str + '\n')
483
- print(log_str, flush=True)
484
-
485
- state_dict = model.state_dict()
486
- if not nocompile:
487
- state_dict = model._orig_mod.state_dict()
488
-
489
- new_state_dict = {}
490
- for k, v in state_dict.items():
491
- new_key = k.replace('module.', '')
492
- new_state_dict[new_key] = v
493
- state_dict = new_state_dict
494
-
495
- # print('Testing done')
496
- # gpu_mem()
497
-
498
- if epoch == 2:
499
- # torch.cuda.cudart().cudaProfilerStop()
500
- pass
501
-
502
- torch.save({
503
- 'epoch': epoch,
504
- 'model_state_dict': state_dict,
505
- 'optimizer_state_dict': optimizer.state_dict(),
506
- 'early_stop': early_termination.to_dict()
507
- }, os.path.join(config['Training_Directory'], f"model_epoch_{epoch}.pt"))
508
- np.savez(os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.npz'), scores=scores.to("cpu"), labels=labels.to("cpu"))
509
- save_end = time.time()
510
- cumulative_times[4] += save_end - eval_end
511
-
512
- early_termination.update(test_loss)
513
- if early_termination.should_stop:
514
- log_str = f"Early Termination at Epoch {epoch}"
515
- log.write(log_str + "\n")
516
- print(log_str)
517
- log_str = early_termination.to_str()
518
- log.write(log_str + "\n")
519
- print(log_str)
520
- break
521
-
522
- if (custom_scheduler):
523
- custom_scheduler.step(model, {'test_auc':test_auc})
524
- scheduler.step()
525
-
526
- print(f"Load: {cumulative_times[0]:.4f} s")
527
- print(f"Batch: {cumulative_times[1]:.4f} s")
528
- print(f"Train: {cumulative_times[2]:.4f} s")
529
- print(f"Eval: {cumulative_times[3]:.4f} s")
530
- print(f"Save: {cumulative_times[4]:.4f} s")
531
- log.close()
532
-
533
def find_free_port():
    """Return an OS-assigned free TCP port number as a string.

    Binds an ephemeral socket to port 0 so the kernel picks an unused port,
    reads the chosen port back, and releases the socket.

    BUGFIX: SO_REUSEADDR is now set *before* bind() — setting it after bind,
    as the original did, has no effect on the address selection.

    NOTE(review): there is an inherent TOCTOU race — the port is free at
    return time but could be grabbed by another process before the caller
    (the DDP rendezvous) binds it.
    """
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Must precede bind(): REUSEADDR influences bind-time address reuse.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return str(s.getsockname()[1])
541
-
542
def init_process_group(world_size, rank, port):
    """Join the NCCL process group via an env:// rendezvous on localhost.

    Parameters
    ----------
    world_size : total number of participating processes.
    rank : this process's rank within the group.
    port : TCP port (as a string) the master listens on.
    """
    # init_method='env://' reads the rendezvous endpoint from these variables.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port

    dist.init_process_group(
        backend="nccl",  # NCCL for GPU collectives (gloo would be the CPU fallback)
        init_method='env://',
        rank=rank,
        world_size=world_size,
        # Generous join timeout so slow-starting ranks don't abort the group.
        timeout=datetime.timedelta(seconds=300),
    )
554
-
555
def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
    """Entry point for one training process.

    Loads the YAML config, builds train/test/validation loaders for every
    dataset in the config, constructs the model (optionally ``torch.compile``d
    and/or DDP-wrapped), and runs training — or plotting / evaluation when the
    corresponding CLI flags are set.

    Parameters
    ----------
    rank : local process rank (GPU index) for distributed runs.
    args : parsed CLI namespace (see the argparse setup under ``__main__``).
    world_size : total number of distributed processes.
    port : rendezvous port for single-node multi-GPU runs.
    seed : seed forwarded to the model builder.
    """
    # Load config file
    config = utils.load_config(args.config)

    if (args.directory):
        print(f"New training directory: { config['Training_Directory'] + args.directory}")
        config['Training_Directory'] = config['Training_Directory'] + args.directory

    if not os.path.exists(config['Training_Directory']):
        os.makedirs(config['Training_Directory'], exist_ok=True)
    # Snapshot the effective config next to the training outputs.
    with open(config['Training_Directory'] + '/config.yaml', 'w') as f:
        yaml.dump(config, f)
    batch_size = config["Training"]["batch_size"]

    if(args.plot):
        # Plot-only mode: render the existing training log and exit.
        rl = utils.read_log(config)
        utils.plot_log(rl, config['Training_Directory'] + '/training.png')
        print('Log at ' + config['Training_Directory'] + '/training.log')
        print('Plotted at ' + config['Training_Directory'] + '/training.png')
        exit()

    if (args.multigpu):
        print(f"Setting up multigpu")
        start_time = time.time()
        init_process_group(world_size, rank, port)
        print("multigpu setup time: {:.4f} s".format(time.time() - start_time))
        device = torch.device(f'cuda:{rank}')
        # BUGFIX: bare torch.cuda.device(...) is a context manager and a
        # no-op when not used in a `with`; set_device actually selects the GPU.
        torch.cuda.set_device(device)
    elif (args.multinode):
        device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(device)  # BUGFIX: was a no-op torch.cuda.device(...)
        print(f"global rank = {args.global_rank}, local rank = {rank}, device = {device}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if (args.cpu):
        print(f"Using CPU")
        device = "cpu"

    train_loaders = []
    test_loaders = []
    val_loaders = []
    load_start = time.time()

    # Allow TF32 matmuls for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True

    ldr_type = datasets.LazyPreBatchedDataset if args.lazy else datasets.PreBatchedDataset

    # Load datasets.
    # BUGFIX: use `args` (the namespace passed in), not the module-global
    # `pargs` — that global is not defined in children spawned by mp.spawn,
    # so the original crashed with NameError under --multigpu.
    if (args.statistics):
        args.statistics = int(args.statistics)
        print(f"Training Dataset Size: {args.statistics}")
        num_batches = int(np.ceil(args.statistics / batch_size))
        np.random.seed(args.seed)

    for dset_conf in config["Datasets"]:
        dset = utils.buildFromConfig(config["Datasets"][dset_conf])
        if 'batch_size' in config["Datasets"][dset_conf]:
            # Per-dataset override of the global batch size.
            batch_size = config["Datasets"][dset_conf]['batch_size']
        fold_conf = config["Datasets"][dset_conf]["folding"]
        shuffle_chunks = config["Datasets"][dset_conf].get("shuffle_chunks", 10)
        padding_mode = config["Datasets"][dset_conf].get("padding_mode", "STEPS")
        mask_fn = utils.fold_selection(fold_conf, "train")
        if args.preshuffle:
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode)
            # Sample graph / globals used later to size the model's input layers.
            gsamp, _, _, global_samp = ldr[0]
            sampler = None

            if (args.statistics):
                # Reduced-statistics run: subsample a fixed number of batches.
                sampler = np.random.choice(range(len(ldr)), size=num_batches)

            if (args.multigpu):
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if (args.multinode):
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)
            # batch_size=None: the dataset already yields pre-batched graphs.
            train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
            sampler = None
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode)
            if (args.multigpu):
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if (args.multinode):
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)

            test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))

            if "validation" in fold_conf:
                val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
            else:
                print("No validation set for dataset ", dset_conf)
        else:
            train_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "train")))
            gsamp, _, _, global_samp = dset[0]
            test_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "test")))
            if "validation" in fold_conf:
                val_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "validation")))
            else:
                print("No validation set for dataset ", dset_conf)

    load_end = time.time()
    print("Load time: {:.4f} s".format(load_end - load_start))
    model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
    if not args.nocompile:
        model = torch.compile(model)
    # BUGFIX: the original wrapped in DDP under two independent ifs, which
    # would double-wrap the model if both flags were ever set; wrap once.
    if args.multigpu or args.multinode:
        print(f"Trying to create DDP model")
        start_time = time.time()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
        print("model creation time: {:.4f} s".format(time.time() - start_time))

    # NOTE(review): after torch.compile / DDP wrapping this type check can
    # no longer match; kept as-is to preserve behavior — confirm intent.
    if(type(model) == GCN.Clustering):
        print("clustering")

    if args.evaluate is not None:
        # Evaluation-only mode: run the test loaders at the given epoch and exit.
        evaluate(test_loaders, model, config, device, args.evaluate)
        exit()

    # model training
    print("Training...")
    gpu_mem()
    train(train_loaders, test_loaders, model, device, config, args, rank)
714
-
715
if __name__ == "__main__":
    # Handle CLI arguments
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg("--config", type=str, help="Config file.", required=True)
    add_arg("--restart", action="store_true", help="Restart training from scratch.")
    add_arg("--preshuffle", action="store_true", help="Shuffle data before training.")
    add_arg("--lazy", action="store_true", help="Lazy loading of data.")
    add_arg("--nocompile", action="store_true", help="Disable JIT compilation.")
    add_arg("--evaluate", type = int, help="Skip training and go to evaluation.")
    add_arg("--plot", action="store_true", help="Plot training logs.")
    add_arg("--multigpu", action="store_true", help="Use multiple GPUs.")
    add_arg("--multinode", action="store_true", help="Use multiple nodes.")
    add_arg("--savecache", action="store_true", help="")
    add_arg("--cpu", action="store_true", help="Uses the cpu only")
    add_arg("--statistics", type=float, help="Size of training data")
    add_arg("--directory", type=str, help="Append to Training Directory")
    add_arg("--seed", type=int, default=2, help="Sets random seed")

    pargs = parser.parse_args()

    # BUGFIX: the three launch modes are mutually exclusive.  The original
    # used two independent `if`s with an `else` bound only to the second, so
    # a --multigpu run fell through after mp.spawn returned and trained a
    # second time, single-process, on the parent.
    if pargs.multigpu:
        # Single-node multi-GPU: one spawned process per GPU, each calling
        # main(rank, pargs, world_size=4, port).
        port = find_free_port()
        torch.backends.cudnn.enabled = False
        mp.spawn(main, args=(pargs, 4, port), nprocs=4, join=True)
    elif pargs.multinode:
        # torchrun-style launch: ranks and world size come from the environment.
        global_rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"global_rank = {global_rank}, local_rank = {local_rank}, world_size = {world_size}")

        dist.init_process_group(backend="nccl")
        torch.backends.cudnn.enabled = False

        pargs.global_rank = global_rank

        main(rank = local_rank, args=pargs, world_size=world_size)
    else:
        # Plain single-process run.
        main(0, pargs)
754
-
755
-