ho22joshua commited on
Commit
843d449
·
1 Parent(s): d646e7f

testing demo

Browse files
Files changed (47) hide show
  1. README.md +1 -1
  2. nemo/configs/physicsnemo/physicsnemo.yaml +0 -54
  3. nemo/configs/stats_100K/ttH_CP_even_vs_odd.yaml +0 -57
  4. nemo/models/Edge_Network.py +0 -163
  5. nemo/models/GCN.py +0 -1818
  6. nemo/models/loss.py +0 -311
  7. nemo/models/meshgraphnet.py +0 -33
  8. nemo/root_gnn_base/batched_dataset.py +0 -191
  9. nemo/root_gnn_base/custom_scheduler.py +0 -565
  10. nemo/root_gnn_base/dataset.py +0 -678
  11. nemo/root_gnn_base/photon_ID_dataset.py +0 -44
  12. nemo/root_gnn_base/similarity.py +0 -158
  13. nemo/root_gnn_base/uproot_dataset.py +0 -54
  14. nemo/root_gnn_base/utils.py +0 -393
  15. nemo/scripts/check_dataset_files.py +0 -130
  16. nemo/scripts/find_free_port.py +0 -12
  17. nemo/scripts/inference.py +0 -289
  18. nemo/scripts/prep_data.py +0 -44
  19. nemo/scripts/training_script.py +0 -463
  20. nemo/setup/Dockerfile +0 -25
  21. nemo/setup/build_image.sh +0 -4
  22. nemo/setup/environment.yml +0 -391
  23. nemo/setup/setup/Dockerfile +0 -29
  24. nemo/setup/setup/build_image.sh +0 -4
  25. nemo/setup/setup/environment.yml +0 -391
  26. nemo/setup/setup/test_setup.py +0 -48
  27. nemo/setup/test_setup.py +0 -48
  28. root_gnn_dgl/README.md +39 -30
  29. root_gnn_dgl/configs/attention/ttH_CP_even_vs_odd.yaml +0 -58
  30. root_gnn_dgl/configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml +2 -2
  31. root_gnn_dgl/configs/stats_100K/pretraining_multiclass.yaml +2 -2
  32. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd.yaml +2 -2
  33. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml +0 -57
  34. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml +0 -57
  35. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192.yaml +0 -57
  36. root_gnn_dgl/configs/stats_all/finetuning_ttH_CP_even_vs_odd.yaml +2 -2
  37. root_gnn_dgl/configs/stats_all/pretraining_multiclass.yaml +2 -2
  38. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd.yaml +2 -2
  39. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml +0 -57
  40. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml +0 -57
  41. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml +0 -57
  42. root_gnn_dgl/jobs/interactive.sh +1 -1
  43. root_gnn_dgl/run_demo.sh +3 -3
  44. root_gnn_dgl/setup/Dockerfile +1 -1
  45. root_gnn_dgl/setup/build_image.sh +2 -4
  46. root_gnn_dgl/setup/environment.yml +2 -3
  47. root_gnn_dgl/setup/launch_image.sh +9 -0
README.md CHANGED
@@ -1,3 +1,3 @@
1
  ---
2
  license: mit
3
- ---
 
1
  ---
2
  license: mit
3
+ ---
nemo/configs/physicsnemo/physicsnemo.yaml DELETED
@@ -1,54 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd
2
- Training_Directory: trainings/physicsnemo/test
3
- Model:
4
- module: models.meshgraphnet
5
- class: MeshGraphNet
6
- args:
7
- input_dim_nodes: 7
8
- input_dim_edges: 3
9
- output_dim: 64
10
- Training:
11
- epochs: 500
12
- batch_size: 1024
13
- learning_rate: 0.0001
14
- gamma: 0.99
15
- Datasets:
16
- ttH_CP_even: &dataset_defn
17
- module: root_gnn_base.dataset
18
- class: LazyDataset
19
- shuffle_chunks: 3
20
- batch_size: 1024
21
- padding_mode: NONE #one of STEPS, FIXED, or NONE
22
- args: &dataset_args
23
- name: ttH_CP_even
24
- label: 0
25
- # weight_var: weight
26
- chunks: 3
27
- buffer_size: 2
28
- file_names: ttH_NLO.root
29
- tree_name: output
30
- fold_var: Number
31
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
32
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
33
- node_branch_names:
34
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
35
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
36
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
37
- - CALC_E
38
- - [jet_btag, 0, 0, 0, 0]
39
- - [0, ele_charge, mu_charge, 0, 0]
40
- - NODE_TYPE
41
- node_branch_types: [vector, vector, vector, vector, single]
42
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
43
- folding:
44
- n_folds: 4
45
- test: [0]
46
- # validation: 1
47
- train: [1, 2, 3]
48
- ttH_CP_odd:
49
- <<: *dataset_defn
50
- args:
51
- <<: *dataset_args
52
- name: ttH_CP_odd
53
- label: 1
54
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/configs/stats_100K/ttH_CP_even_vs_odd.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 1024
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 1024
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/models/Edge_Network.py DELETED
@@ -1,163 +0,0 @@
1
- import dgl
2
- import dgl.nn as dglnn
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- import sys
9
- import os
10
- file_path = os.getcwd()
11
- sys.path.append(file_path)
12
-
13
- def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
14
- layers = []
15
- layers.append(nn.Linear(in_size, out_size))
16
- layers.append(activation())
17
- layers.append(nn.Dropout(dropout))
18
- return layers
19
-
20
- def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
21
- layers = []
22
- if n_layers > 1:
23
- layers += Make_SLP(in_size, hid_size, activation, dropout)
24
- for i in range(n_layers-2):
25
- layers += Make_SLP(hid_size, hid_size, activation, dropout)
26
- layers += Make_SLP(hid_size, out_size, activation, dropout)
27
- else:
28
- layers += Make_SLP(in_size, out_size, activation, dropout)
29
- layers.append(torch.nn.LayerNorm(out_size))
30
- return nn.Sequential(*layers)
31
-
32
- class Edge_Network(nn.Module):
33
- def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
34
- super().__init__()
35
- print(f'Unused args while creating GCN: {kwargs}')
36
- self.n_layers = n_layers
37
- self.n_proc_steps = n_proc_steps
38
- self.layers = nn.ModuleList()
39
- if (len(sample_global) == 0):
40
- self.has_global = False
41
- else:
42
- self.has_global = sample_global.shape[1] != 0
43
- gl_size = sample_global.shape[1] if self.has_global else 1
44
-
45
- #encoder
46
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
47
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
48
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
49
-
50
- #GNN
51
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
52
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
53
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
54
-
55
- #decoder
56
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
57
- self.classify = nn.Linear(hid_size, out_size)
58
-
59
- def forward(self, g, global_feats):
60
- h = self.node_encoder(g.ndata['features'])
61
- e = self.edge_encoder(g.edata['features'])
62
-
63
- g.ndata['h'] = h
64
- g.edata['e'] = e
65
- if not self.has_global:
66
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
67
-
68
- batch_num_nodes = None
69
- sum_weights = None
70
- if "w" in g.ndata:
71
- batch_indices = g.batch_num_nodes()
72
- # Find non-zero rows (non-padded nodes)
73
- non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
74
- # Split the mask according to the batch indices
75
- batch_num_nodes = []
76
- start_idx = 0
77
- for num_nodes in batch_indices:
78
- end_idx = start_idx + num_nodes
79
- non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
80
- batch_num_nodes.append(non_padded_count)
81
- start_idx = end_idx
82
- batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
83
- sum_weights = batch_num_nodes[:, None].repeat(1, 64)
84
- global_feats = batch_num_nodes[:, None].to(torch.float)
85
-
86
- h_global = self.global_encoder(global_feats)
87
- for i in range(self.n_proc_steps):
88
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
89
- g.apply_edges(copy_v)
90
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
91
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
92
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
93
- if "w" in g.ndata:
94
- mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
95
- h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
96
- else:
97
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
98
- h_global = self.global_decoder(h_global)
99
- return self.classify(h_global)
100
-
101
- def representation(self, g, global_feats):
102
- h = self.node_encoder(g.ndata['features'])
103
- e = self.edge_encoder(g.edata['features'])
104
-
105
- g.ndata['h'] = h
106
- g.edata['e'] = e
107
- if not self.has_global:
108
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
109
-
110
- batch_num_nodes = None
111
- sum_weights = None
112
- if "w" in g.ndata:
113
- batch_indices = g.batch_num_nodes()
114
- # Find non-zero rows (non-padded nodes)
115
- non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
116
- # Split the mask according to the batch indices
117
- batch_num_nodes = []
118
- start_idx = 0
119
- for num_nodes in batch_indices:
120
- end_idx = start_idx + num_nodes
121
- non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
122
- batch_num_nodes.append(non_padded_count)
123
- start_idx = end_idx
124
- batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
125
- sum_weights = batch_num_nodes[:, None].repeat(1, 64)
126
- global_feats = batch_num_nodes[:, None].to(torch.float)
127
-
128
- h_global = self.global_encoder(global_feats)
129
- for i in range(self.n_proc_steps):
130
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
131
- g.apply_edges(copy_v)
132
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
133
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
134
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
135
- if "w" in g.ndata:
136
- mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
137
- h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
138
- else:
139
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
140
- before_global_decoder = h_global
141
- after_global_decoder = self.global_decoder(before_global_decoder)
142
- after_classify = self.classify(after_global_decoder)
143
- return before_global_decoder, after_global_decoder, after_classify
144
-
145
- def __str__(self):
146
- layer_names = ["node_encoder", "edge_encoder", "global_encoder",
147
- "node_update", "edge_update", "global_update", "global_decoder"]
148
-
149
- layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
150
- self.node_update, self.edge_update, self.global_update, self.global_decoder]
151
-
152
- for i in range(len(layers)):
153
- print(layer_names[i])
154
- for layer in layers[i].children():
155
- if isinstance(layer, nn.Linear):
156
- print(layer.state_dict())
157
-
158
- print("classify")
159
- print(self.classify.weight)
160
- return ""
161
-
162
- def __name__():
163
- return "Edge_Network"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/models/GCN.py DELETED
@@ -1,1818 +0,0 @@
1
- import dgl
2
- import dgl.nn as dglnn
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- import sys
9
- import os
10
- file_path = os.getcwd()
11
- sys.path.append(file_path)
12
-
13
- import root_gnn_base.dataset as datasets
14
- from root_gnn_base import utils
15
-
16
- import gc
17
-
18
- def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
19
- layers = []
20
- layers.append(nn.Linear(in_size, out_size))
21
- layers.append(activation())
22
- layers.append(nn.Dropout(dropout))
23
- return layers
24
-
25
- def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
26
- layers = []
27
- if n_layers > 1:
28
- layers += Make_SLP(in_size, hid_size, activation, dropout)
29
- for i in range(n_layers-2):
30
- layers += Make_SLP(hid_size, hid_size, activation, dropout)
31
- layers += Make_SLP(hid_size, out_size, activation, dropout)
32
- else:
33
- layers += Make_SLP(in_size, out_size, activation, dropout)
34
- layers.append(torch.nn.LayerNorm(out_size))
35
- return nn.Sequential(*layers)
36
-
37
- class MLP(nn.Module):
38
- def __init__(self, in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0, **kwargs):
39
- super().__init__()
40
- print(f'Unused args while creating MLP: {kwargs}')
41
- self.layers = Make_MLP(in_size, hid_size, hid_size, n_layers-1, activation, dropout)
42
- self.linear = nn.Linear(hid_size, out_size)
43
-
44
- def forward(self, x):
45
- return self.linear(self.layers(x))
46
-
47
- def broadcast_global_to_nodes(g, globals):
48
- boundaries = g.batch_num_nodes()
49
- return torch.repeat_interleave(globals, boundaries, dim=0)
50
-
51
- def broadcast_global_to_edges(g, globals):
52
- boundaries = g.batch_num_edges()
53
- return torch.repeat_interleave(globals, boundaries, dim=0)
54
-
55
- def copy_v(edges):
56
- return {'m_v': edges.dst['h']}
57
-
58
- def partial_reset(model : nn.Module):
59
- in_size = len(model.classify.weight[0])
60
- out_size = len(model.classify.weight)
61
- device = next(model.classify.parameters()).device
62
- torch.manual_seed(2)
63
- model.classify = nn.Linear(in_size, out_size)
64
- model.classify.to(device)
65
- print(model.classify.weight)
66
-
67
- def print_model(model: nn.Module):
68
- print(model)
69
-
70
- def print_mlp(layer):
71
- for l in layer.children():
72
- if isinstance(l, nn.Linear):
73
- print(l.state_dict())
74
- else:
75
- print(l)
76
-
77
-
78
- # This function returns a model with the whole GNN completely reset
79
- def full_reset(model : nn.Module):
80
- mlp_list = [model.node_encoder, model.edge_encoder, model.global_encoder,
81
- model.node_update, model.edge_update, model.global_update,
82
- model.global_decoder]
83
-
84
- for mlp in mlp_list:
85
- for layer in mlp.children():
86
- if hasattr(layer, 'reset_parameters'):
87
- layer.reset_parameters()
88
- partial_reset(model)
89
-
90
-
91
- class GCN(nn.Module):
92
- def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
93
- super().__init__()
94
- print(f'Unused args while creating GCN: {kwargs}')
95
- self.n_layers = n_layers
96
- self.layers = nn.ModuleList()
97
-
98
- # two-layer GCN
99
- self.layers.extend(
100
- [nn.Linear(in_size, hid_size),] +
101
- [nn.Linear(hid_size, hid_size) for i in range(n_layers)] +
102
- [dglnn.GraphConv(hid_size, hid_size) for i in range(n_layers)] +
103
- [nn.Linear(hid_size, hid_size) for i in range(n_layers)]
104
- )
105
- self.classify = nn.Linear(hid_size, out_size)
106
- #self.dropout = nn.Dropout(0.05)
107
-
108
- def forward(self, g):
109
- h = g.ndata['features']
110
- for i, layer in enumerate(self.layers):
111
- if i >= self.n_layers + 1 and i < self.n_layers * 2 + 1:
112
- h = layer(g, h)
113
- else:
114
- h = layer(h)
115
- h = F.relu(h)
116
- with g.local_scope():
117
- g.ndata['h'] = h
118
- # Calculate graph representation by average readout.
119
- hg = dgl.mean_nodes(g, 'h')
120
- return self.classify(hg)
121
-
122
- class GCN_global(nn.Module):
123
- def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
124
- super().__init__()
125
- print(f'Unused args while creating GCN: {kwargs}')
126
- self.n_layers = n_layers
127
-
128
- #encoder
129
- self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
130
- self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
131
-
132
- #GCN
133
- self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
134
- self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
135
- self.conv = dglnn.GraphConv(hid_size, hid_size)
136
-
137
- #decoder
138
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
139
- self.classify = nn.Linear(hid_size, out_size)
140
-
141
- def forward(self, g):
142
- h = self.node_encoder(g.ndata['features'])
143
- h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
144
- for i in range(self.n_layers):
145
- h = self.node_update(h)
146
- h = self.conv(g, h)
147
- g.ndata['h'] = h
148
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
149
- h_global = self.global_decoder(h_global)
150
- return self.classify(h_global)
151
-
152
- class GCN_global_2way(nn.Module):
153
- def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
154
- super().__init__()
155
- print(f'Unused args while creating GCN: {kwargs}')
156
- self.n_layers = n_layers
157
-
158
- #encoder
159
- self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
160
- self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
161
-
162
- #GCN
163
- self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
164
- self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
165
- self.conv = dglnn.GraphConv(hid_size, hid_size)
166
-
167
- #decoder
168
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
169
- self.classify = nn.Linear(hid_size, out_size)
170
-
171
- def forward(self, g):
172
- h = self.node_encoder(g.ndata['features'])
173
- h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
174
- for i in range(self.n_layers):
175
- h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
176
- h = self.conv(g, h)
177
- g.ndata['h'] = h
178
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
179
- h_global = self.global_decoder(h_global)
180
- return self.classify(h_global)
181
-
182
- class Transferred_Learning(nn.Module):
183
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
184
- super().__init__()
185
- print(f'Unused args while creating GCN: {kwargs}')
186
- self.n_layers = n_layers
187
- self.n_proc_steps = n_proc_steps
188
- self.layers = nn.ModuleList()
189
-
190
- if (len(sample_global) == 0):
191
- self.has_global = False
192
- else:
193
- self.has_global = sample_global.shape[1] != 0
194
- gl_size = sample_global.shape[1] if self.has_global else 1
195
-
196
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
197
-
198
- checkpoint = torch.load(pretraining_path)
199
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
200
- pretrained_layers = list(self.pretrained_model.children())
201
- pretrained_layers = pretrained_layers[:-1]
202
- self.pretrained_model = nn.Sequential(*pretrained_layers)
203
-
204
- # Freeze Weights
205
- for param in self.pretrained_model.parameters():
206
- param.requires_grad = False # Freeze all layers
207
-
208
- self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
209
- self.classify = nn.Linear(hid_size, out_size)
210
-
211
- def TL_node_encoder(self, x):
212
- for layer in self.pretrained_model[1]:
213
- x = layer(x)
214
- return x
215
-
216
- def TL_edge_encoder(self, x):
217
- for layer in self.pretrained_model[2]:
218
- x = layer(x)
219
- return x
220
-
221
- def TL_global_encoder(self, x):
222
- for layer in self.pretrained_model[3]:
223
- x = layer(x)
224
- return x
225
-
226
- def TL_node_update(self, x):
227
- for layer in self.pretrained_model[4]:
228
- x = layer(x)
229
- return x
230
-
231
- def TL_edge_update(self, x):
232
- for layer in self.pretrained_model[5]:
233
- x = layer(x)
234
- return x
235
-
236
- def TL_global_update(self, x):
237
- for layer in self.pretrained_model[6]:
238
- x = layer(x)
239
- return x
240
-
241
- def TL_global_decoder(self, x):
242
- for layer in self.pretrained_model[7]:
243
- x = layer(x)
244
- return x
245
-
246
- def forward(self, g, global_feats):
247
- h = self.TL_node_encoder(g.ndata['features'])
248
- e = self.TL_edge_encoder(g.edata['features'])
249
- g.ndata['h'] = h
250
- g.edata['e'] = e
251
- if not self.has_global:
252
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
253
- h_global = self.TL_global_encoder(global_feats)
254
- for i in range(self.n_proc_steps):
255
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
256
- g.apply_edges(copy_v)
257
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
258
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
259
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
260
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
261
- h_global = self.TL_global_decoder(h_global)
262
- return self.classify(self.global_decoder(h_global))
263
-
264
- class Transferred_Learning_Graph(nn.Module):
265
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
266
- super().__init__()
267
- print(f'Unused args while creating GCN: {kwargs}')
268
- self.n_layers = n_layers
269
- self.n_proc_steps = n_proc_steps
270
- self.layers = nn.ModuleList()
271
-
272
- if (len(sample_global) == 0):
273
- self.has_global = False
274
- else:
275
- self.has_global = sample_global.shape[1] != 0
276
- gl_size = sample_global.shape[1] if self.has_global else 1
277
-
278
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
279
-
280
- checkpoint = torch.load(pretraining_path)
281
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
282
- pretrained_layers = list(self.pretrained_model.children())
283
- pretrained_layers = pretrained_layers[:-1]
284
- self.pretrained_model = nn.Sequential(*pretrained_layers)
285
-
286
- self.additional_proc_steps = additional_proc_steps
287
-
288
- # Freeze Weights
289
- for param in self.pretrained_model.parameters():
290
- param.requires_grad = False # Freeze all layers
291
-
292
- #GNN
293
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
294
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
295
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
296
-
297
- #decoder
298
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
299
- self.classify = nn.Linear(hid_size, out_size)
300
-
301
- def TL_node_encoder(self, x):
302
- for layer in self.pretrained_model[1]:
303
- x = layer(x)
304
- return x
305
-
306
- def TL_edge_encoder(self, x):
307
- for layer in self.pretrained_model[2]:
308
- x = layer(x)
309
- return x
310
-
311
- def TL_global_encoder(self, x):
312
- for layer in self.pretrained_model[3]:
313
- x = layer(x)
314
- return x
315
-
316
- def TL_node_update(self, x):
317
- for layer in self.pretrained_model[4]:
318
- x = layer(x)
319
- return x
320
-
321
- def TL_edge_update(self, x):
322
- for layer in self.pretrained_model[5]:
323
- x = layer(x)
324
- return x
325
-
326
- def TL_global_update(self, x):
327
- for layer in self.pretrained_model[6]:
328
- x = layer(x)
329
- return x
330
-
331
- def forward(self, g, global_feats):
332
- h = self.TL_node_encoder(g.ndata['features'])
333
- e = self.TL_edge_encoder(g.edata['features'])
334
- g.ndata['h'] = h
335
- g.edata['e'] = e
336
- if not self.has_global:
337
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
338
- h_global = self.TL_global_encoder(global_feats)
339
- for i in range(self.n_proc_steps):
340
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
341
- g.apply_edges(copy_v)
342
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
343
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
344
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
345
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
346
- for j in range(self.additional_proc_steps):
347
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
348
- g.apply_edges(copy_v)
349
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
350
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
351
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
352
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
353
-
354
- h_global = self.global_decoder(h_global)
355
- return self.classify(h_global)
356
-
357
- class Transferred_Learning_Parallel(nn.Module):
358
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
359
- super().__init__()
360
- print(f'Unused args while creating GCN: {kwargs}')
361
- self.n_layers = n_layers
362
- self.n_proc_steps = n_proc_steps
363
- self.layers = nn.ModuleList()
364
- self.has_global = sample_global.shape[1] != 0
365
- gl_size = sample_global.shape[1] if self.has_global else 1
366
-
367
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
368
- checkpoint = torch.load(pretraining_path)
369
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
370
- pretrained_layers = list(self.pretrained_model.children())
371
- pretrained_layers = pretrained_layers[:-1]
372
- self.pretrained_model = nn.Sequential(*pretrained_layers)
373
-
374
- # Freeze Weights
375
- for param in self.pretrained_model.parameters():
376
- param.requires_grad = False # Freeze all layers
377
-
378
- #encoder
379
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
380
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
381
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
382
-
383
- #GNN
384
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
385
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
386
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
387
-
388
- #decoder
389
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
390
- self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)
391
-
392
- def TL_node_encoder(self, x):
393
- for layer in self.pretrained_model[1]:
394
- x = layer(x)
395
- return x
396
-
397
- def TL_edge_encoder(self, x):
398
- for layer in self.pretrained_model[2]:
399
- x = layer(x)
400
- return x
401
-
402
- def TL_global_encoder(self, x):
403
- for layer in self.pretrained_model[3]:
404
- x = layer(x)
405
- return x
406
-
407
- def TL_node_update(self, x):
408
- for layer in self.pretrained_model[4]:
409
- x = layer(x)
410
- return x
411
-
412
- def TL_edge_update(self, x):
413
- for layer in self.pretrained_model[5]:
414
- x = layer(x)
415
- return x
416
-
417
- def TL_global_update(self, x):
418
- for layer in self.pretrained_model[6]:
419
- x = layer(x)
420
- return x
421
-
422
- def TL_global_decoder(self, x):
423
- for layer in self.pretrained_model[7]:
424
- x = layer(x)
425
- return x
426
-
427
- def Pretrained_Output(self, g):
428
- h = self.TL_node_encoder(g.ndata['features'])
429
- e = self.TL_edge_encoder(g.edata['features'])
430
- g.ndata['h'] = h
431
- g.edata['e'] = e
432
- if not self.has_global:
433
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
434
- h_global = self.TL_global_encoder(global_feats)
435
- for i in range(self.n_proc_steps):
436
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
437
- g.apply_edges(copy_v)
438
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
439
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
440
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
441
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
442
- h_global = self.TL_global_decoder(h_global)
443
- return h_global
444
-
445
- def forward(self, g, global_feats):
446
- pretrained_global = self.Pretrained_Output(g.clone())
447
- h = self.node_encoder(g.ndata['features'])
448
- e = self.edge_encoder(g.edata['features'])
449
- g.ndata['h'] = h
450
- g.edata['e'] = e
451
- if not self.has_global:
452
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
453
- h_global = self.global_encoder(global_feats)
454
- for i in range(self.n_proc_steps):
455
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
456
- g.apply_edges(copy_v)
457
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
458
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
459
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
460
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
461
- h_global = self.global_decoder(h_global)
462
-
463
- return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
464
-
465
- class Transferred_Learning_Sequential(nn.Module):
466
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
467
- super().__init__()
468
- print(f'Unused args while creating GCN: {kwargs}')
469
- self.n_layers = n_layers
470
- self.n_proc_steps = n_proc_steps
471
- self.layers = nn.ModuleList()
472
- self.has_global = sample_global.shape[1] != 0
473
- gl_size = sample_global.shape[1] if self.has_global else 1
474
-
475
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
476
- checkpoint = torch.load(pretraining_path)
477
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
478
- pretrained_layers = list(self.pretrained_model.children())
479
- pretrained_layers = pretrained_layers[:-1]
480
- self.pretrained_model = nn.Sequential(*pretrained_layers)
481
-
482
- # Freeze Weights
483
- for param in self.pretrained_model.parameters():
484
- param.requires_grad = False # Freeze all layers
485
-
486
- #encoder
487
- self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
488
-
489
- self.classify = nn.Linear(hid_size, out_size)
490
-
491
- def TL_node_encoder(self, x):
492
- for layer in self.pretrained_model[1]:
493
- x = layer(x)
494
- return x
495
-
496
- def TL_edge_encoder(self, x):
497
- for layer in self.pretrained_model[2]:
498
- x = layer(x)
499
- return x
500
-
501
- def TL_global_encoder(self, x):
502
- for layer in self.pretrained_model[3]:
503
- x = layer(x)
504
- return x
505
-
506
- def TL_node_update(self, x):
507
- for layer in self.pretrained_model[4]:
508
- x = layer(x)
509
- return x
510
-
511
- def TL_edge_update(self, x):
512
- for layer in self.pretrained_model[5]:
513
- x = layer(x)
514
- return x
515
-
516
- def TL_global_update(self, x):
517
- for layer in self.pretrained_model[6]:
518
- x = layer(x)
519
- return x
520
-
521
- def TL_global_decoder(self, x):
522
- for layer in self.pretrained_model[7]:
523
- x = layer(x)
524
- return x
525
-
526
- def Pretrained_Output(self, g):
527
- h = self.TL_node_encoder(g.ndata['features'])
528
- e = self.TL_edge_encoder(g.edata['features'])
529
- g.ndata['h'] = h
530
- g.edata['e'] = e
531
- if not self.has_global:
532
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
533
- h_global = self.TL_global_encoder(global_feats)
534
- for i in range(self.n_proc_steps):
535
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
536
- g.apply_edges(copy_v)
537
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
538
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
539
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
540
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
541
- h_global = self.TL_global_decoder(h_global)
542
- return h_global
543
-
544
- def forward(self, g, global_feats):
545
- pretrained_global = self.Pretrained_Output(g.clone())
546
- global_features = self.mlp(pretrained_global)
547
- return self.classify(global_features)
548
-
549
-
550
- class Transferred_Learning_Message_Passing(nn.Module):
551
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
552
- super().__init__()
553
- print(f'Unused args while creating GCN: {kwargs}')
554
- self.n_layers = n_layers
555
- self.n_proc_steps = n_proc_steps
556
- self.layers = nn.ModuleList()
557
- self.has_global = sample_global.shape[1] != 0
558
- gl_size = sample_global.shape[1] if self.has_global else 1
559
-
560
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
561
- checkpoint = torch.load(pretraining_path)
562
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
563
- pretrained_layers = list(self.pretrained_model.children())
564
- pretrained_layers = pretrained_layers[:-1]
565
- self.pretrained_model = nn.Sequential(*pretrained_layers)
566
-
567
- # Freeze Weights
568
- for param in self.pretrained_model.parameters():
569
- param.requires_grad = False # Freeze all layers
570
-
571
- #encoder
572
- self.mlp = Make_MLP(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)
573
-
574
- self.classify = nn.Linear(hid_size, out_size)
575
-
576
- def TL_node_encoder(self, x):
577
- for layer in self.pretrained_model[1]:
578
- x = layer(x)
579
- return x
580
-
581
- def TL_edge_encoder(self, x):
582
- for layer in self.pretrained_model[2]:
583
- x = layer(x)
584
- return x
585
-
586
- def TL_global_encoder(self, x):
587
- for layer in self.pretrained_model[3]:
588
- x = layer(x)
589
- return x
590
-
591
- def TL_node_update(self, x):
592
- for layer in self.pretrained_model[4]:
593
- x = layer(x)
594
- return x
595
-
596
- def TL_edge_update(self, x):
597
- for layer in self.pretrained_model[5]:
598
- x = layer(x)
599
- return x
600
-
601
- def TL_global_update(self, x):
602
- for layer in self.pretrained_model[6]:
603
- x = layer(x)
604
- return x
605
-
606
- def TL_global_decoder(self, x):
607
- for layer in self.pretrained_model[7]:
608
- x = layer(x)
609
- return x
610
-
611
- def Pretrained_Output(self, g):
612
- message_passing = None
613
- h = self.TL_node_encoder(g.ndata['features'])
614
- e = self.TL_edge_encoder(g.edata['features'])
615
- g.ndata['h'] = h
616
- g.edata['e'] = e
617
- if not self.has_global:
618
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
619
- h_global = self.TL_global_encoder(global_feats)
620
- for i in range(self.n_proc_steps):
621
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
622
- g.apply_edges(copy_v)
623
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
624
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
625
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
626
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
627
- if (message_passing is None):
628
- message_passing = h_global.clone()
629
- else:
630
- message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
631
- h_global = self.TL_global_decoder(h_global)
632
- return message_passing
633
-
634
- def forward(self, g, global_feats):
635
- pretrained_global = self.Pretrained_Output(g.clone())
636
- #print(f"message_passing layers have size = {pretrained_global.shape}")
637
- #print(pretrained_global)
638
- global_features = self.mlp(pretrained_global)
639
- return self.classify(global_features)
640
-
641
- class Transferred_Learning_Message_Passing_Parallel(nn.Module):
642
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
643
- super().__init__()
644
- print(f'Unused args while creating GCN: {kwargs}')
645
- self.n_layers = n_layers
646
- self.n_proc_steps = n_proc_steps
647
- self.layers = nn.ModuleList()
648
- self.has_global = sample_global.shape[1] != 0
649
- gl_size = sample_global.shape[1] if self.has_global else 1
650
-
651
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
652
- checkpoint = torch.load(pretraining_path)
653
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
654
- pretrained_layers = list(self.pretrained_model.children())
655
- pretrained_layers = pretrained_layers[:-1]
656
- self.pretrained_model = nn.Sequential(*pretrained_layers)
657
-
658
- #encoder
659
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
660
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
661
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
662
-
663
- #GNN
664
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
665
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
666
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
667
-
668
- #decoder
669
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
670
-
671
- # Freeze Weights
672
- for param in self.pretrained_model.parameters():
673
- param.requires_grad = False # Freeze all layers
674
-
675
- self.classify = nn.Linear(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'] + hid_size, out_size)
676
-
677
- def TL_node_encoder(self, x):
678
- for layer in self.pretrained_model[1]:
679
- x = layer(x)
680
- return x
681
-
682
- def TL_edge_encoder(self, x):
683
- for layer in self.pretrained_model[2]:
684
- x = layer(x)
685
- return x
686
-
687
- def TL_global_encoder(self, x):
688
- for layer in self.pretrained_model[3]:
689
- x = layer(x)
690
- return x
691
-
692
- def TL_node_update(self, x):
693
- for layer in self.pretrained_model[4]:
694
- x = layer(x)
695
- return x
696
-
697
- def TL_edge_update(self, x):
698
- for layer in self.pretrained_model[5]:
699
- x = layer(x)
700
- return x
701
-
702
- def TL_global_update(self, x):
703
- for layer in self.pretrained_model[6]:
704
- x = layer(x)
705
- return x
706
-
707
- def TL_global_decoder(self, x):
708
- for layer in self.pretrained_model[7]:
709
- x = layer(x)
710
- return x
711
-
712
- def Pretrained_Output(self, g):
713
- message_passing = None
714
- h = self.TL_node_encoder(g.ndata['features'])
715
- e = self.TL_edge_encoder(g.edata['features'])
716
- g.ndata['h'] = h
717
- g.edata['e'] = e
718
- if not self.has_global:
719
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
720
- h_global = self.TL_global_encoder(global_feats)
721
- for i in range(self.n_proc_steps):
722
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
723
- g.apply_edges(copy_v)
724
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
725
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
726
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
727
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
728
- if (message_passing is None):
729
- message_passing = h_global.clone()
730
- else:
731
- message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
732
- h_global = self.TL_global_decoder(h_global)
733
- return message_passing
734
-
735
- def forward(self, g, global_feats):
736
- pretrained_message = self.Pretrained_Output(g.clone())
737
- h = self.node_encoder(g.ndata['features'])
738
- e = self.edge_encoder(g.edata['features'])
739
- g.ndata['h'] = h
740
- g.edata['e'] = e
741
- if not self.has_global:
742
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
743
- h_global = self.global_encoder(global_feats)
744
- for i in range(self.n_proc_steps):
745
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
746
- g.apply_edges(copy_v)
747
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
748
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
749
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
750
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
751
- h_global = self.global_decoder(h_global)
752
- return self.classify(torch.cat((pretrained_message, h_global), dim = 1))
753
-
754
- class Transferred_Learning_Finetuning(nn.Module):
755
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=False, **kwargs):
756
- super().__init__()
757
- print(f'Unused args while creating GCN: {kwargs}')
758
- self.n_layers = n_layers
759
- self.n_proc_steps = n_proc_steps
760
- self.layers = nn.ModuleList()
761
-
762
- if (len(sample_global) == 0):
763
- self.has_global = False
764
- else:
765
- self.has_global = sample_global.shape[1] != 0
766
- gl_size = sample_global.shape[1] if self.has_global else 1
767
-
768
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
769
-
770
- checkpoint = torch.load(pretraining_path)
771
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
772
- pretrained_layers = list(self.pretrained_model.children())
773
- pretrained_layers = pretrained_layers[:-1]
774
- self.pretrained_model = nn.Sequential(*pretrained_layers)
775
-
776
- print(f"Freeze Pretraining = {frozen_pretraining}")
777
- if (frozen_pretraining):
778
- for param in self.pretrained_model.parameters():
779
- param.requires_grad = False # Freeze all layers
780
- for param in self.pretrained_model[7]:
781
- param.requires_grad = True
782
-
783
- torch.manual_seed(2)
784
- self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)
785
-
786
- def TL_node_encoder(self, x):
787
- for layer in self.pretrained_model[1]:
788
- x = layer(x)
789
- return x
790
-
791
- def TL_edge_encoder(self, x):
792
- for layer in self.pretrained_model[2]:
793
- x = layer(x)
794
- return x
795
-
796
- def TL_global_encoder(self, x):
797
- for layer in self.pretrained_model[3]:
798
- x = layer(x)
799
- return x
800
-
801
- def TL_node_update(self, x):
802
- for layer in self.pretrained_model[4]:
803
- x = layer(x)
804
- return x
805
-
806
- def TL_edge_update(self, x):
807
- for layer in self.pretrained_model[5]:
808
- x = layer(x)
809
- return x
810
-
811
- def TL_global_update(self, x):
812
- for layer in self.pretrained_model[6]:
813
- x = layer(x)
814
- return x
815
-
816
- def TL_global_decoder(self, x):
817
- for layer in self.pretrained_model[7]:
818
- x = layer(x)
819
- return x
820
-
821
- def Pretrained_Output(self, g):
822
- h = self.TL_node_encoder(g.ndata['features'])
823
- e = self.TL_edge_encoder(g.edata['features'])
824
- g.ndata['h'] = h
825
- g.edata['e'] = e
826
- if not self.has_global:
827
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
828
- h_global = self.TL_global_encoder(global_feats)
829
- for i in range(self.n_proc_steps):
830
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
831
- g.apply_edges(copy_v)
832
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
833
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
834
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
835
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
836
- h_global = self.TL_global_decoder(h_global)
837
- return h_global
838
-
839
- def forward(self, g, global_feats):
840
- h_global = self.Pretrained_Output(g.clone())
841
- return self.classify(h_global)
842
-
843
- def representation(self, g, global_feats):
844
- h = self.TL_node_encoder(g.ndata['features'])
845
- e = self.TL_edge_encoder(g.edata['features'])
846
- g.ndata['h'] = h
847
- g.edata['e'] = e
848
- if not self.has_global:
849
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
850
- h_global = self.TL_global_encoder(global_feats)
851
- for i in range(self.n_proc_steps):
852
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
853
- g.apply_edges(copy_v)
854
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
855
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
856
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
857
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
858
-
859
- before_global_decoder = h_global
860
- after_global_decoder = self.TL_global_decoder(before_global_decoder)
861
- after_classify = self.classify(after_global_decoder)
862
- return before_global_decoder, after_global_decoder, after_classify
863
-
864
- def __str__(self):
865
- layer_names = ["node_encoder", "edge_encoder", "global_encoder",
866
- "node_update", "edge_update", "global_update", "global_decoder"]
867
-
868
- layers = [self.pretrained_model[1], self.pretrained_model[2], self.pretrained_model[3],
869
- self.pretrained_model[4], self.pretrained_model[5], self.pretrained_model[6],
870
- self.pretrained_model[7]]
871
-
872
- for i in range(len(layers)):
873
- print(layer_names[i])
874
- for layer in layers[i].children():
875
- if isinstance(layer, nn.Linear):
876
- print(layer.state_dict())
877
-
878
- print("classify")
879
- print(self.classify.weight)
880
- return ""
881
-
882
-
883
- class Transferred_Learning_Parallel_Finetuning(nn.Module):
884
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, learning_rate=0.0001, **kwargs):
885
- super().__init__()
886
- print(f'Unused args while creating GCN: {kwargs}')
887
-
888
- self.learning_rate = learning_rate
889
-
890
- self.parallel_params = []
891
- self.finetuning_params = []
892
-
893
-
894
- self.n_layers = n_layers
895
- self.n_proc_steps = n_proc_steps
896
- self.layers = nn.ModuleList()
897
- self.has_global = sample_global.shape[1] != 0
898
- gl_size = sample_global.shape[1] if self.has_global else 1
899
-
900
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
901
- checkpoint = torch.load(pretraining_path)
902
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
903
- pretrained_layers = list(self.pretrained_model.children())
904
- pretrained_layers = pretrained_layers[:-1]
905
- self.pretrained_model = nn.Sequential(*pretrained_layers)
906
-
907
- self.finetuning_params.append(self.pretrained_model)
908
-
909
- #encoder
910
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
911
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
912
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
913
-
914
- #GNN
915
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
916
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
917
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
918
-
919
- #decoder
920
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
921
- self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)
922
-
923
- self.parallel_params.append(self.node_encoder)
924
- self.parallel_params.append(self.edge_encoder)
925
- self.parallel_params.append(self.global_encoder)
926
- self.parallel_params.append(self.node_update)
927
- self.parallel_params.append(self.edge_update)
928
- self.parallel_params.append(self.global_update)
929
- self.parallel_params.append(self.global_decoder)
930
- self.parallel_params.append(self.classify)
931
-
932
- def TL_node_encoder(self, x):
933
- for layer in self.pretrained_model[1]:
934
- x = layer(x)
935
- return x
936
-
937
- def TL_edge_encoder(self, x):
938
- for layer in self.pretrained_model[2]:
939
- x = layer(x)
940
- return x
941
-
942
- def TL_global_encoder(self, x):
943
- for layer in self.pretrained_model[3]:
944
- x = layer(x)
945
- return x
946
-
947
- def TL_node_update(self, x):
948
- for layer in self.pretrained_model[4]:
949
- x = layer(x)
950
- return x
951
-
952
- def TL_edge_update(self, x):
953
- for layer in self.pretrained_model[5]:
954
- x = layer(x)
955
- return x
956
-
957
- def TL_global_update(self, x):
958
- for layer in self.pretrained_model[6]:
959
- x = layer(x)
960
- return x
961
-
962
- def TL_global_decoder(self, x):
963
- for layer in self.pretrained_model[7]:
964
- x = layer(x)
965
- return x
966
-
967
- def Pretrained_Output(self, g):
968
- h = self.TL_node_encoder(g.ndata['features'])
969
- e = self.TL_edge_encoder(g.edata['features'])
970
- g.ndata['h'] = h
971
- g.edata['e'] = e
972
- if not self.has_global:
973
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
974
- h_global = self.TL_global_encoder(global_feats)
975
- for i in range(self.n_proc_steps):
976
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
977
- g.apply_edges(copy_v)
978
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
979
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
980
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
981
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
982
- h_global = self.TL_global_decoder(h_global)
983
- return h_global
984
-
985
- def forward(self, g, global_feats):
986
- pretrained_global = self.Pretrained_Output(g.clone())
987
- h = self.node_encoder(g.ndata['features'])
988
- e = self.edge_encoder(g.edata['features'])
989
- g.ndata['h'] = h
990
- g.edata['e'] = e
991
- if not self.has_global:
992
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
993
- h_global = self.global_encoder(global_feats)
994
- for i in range(self.n_proc_steps):
995
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
996
- g.apply_edges(copy_v)
997
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
998
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
999
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
1000
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
1001
- h_global = self.global_decoder(h_global)
1002
-
1003
- return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
1004
-
1005
- def parameters(self, recurse: bool = True):
1006
- params = []
1007
- for model_section in self.parallel_params:
1008
- if (type(self.learning_rate) == dict and self.learning_rate["trainable_lr"]):
1009
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["trainable_lr"]})
1010
- else:
1011
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1012
- for model_section in self.finetuning_params:
1013
- if (type(self.learning_rate) == dict and self.learning_rate["finetuning_lr"]):
1014
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["finetuning_lr"]})
1015
- else:
1016
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1017
- return params
1018
-
1019
- class Attention(nn.Module):
1020
- def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
1021
- super().__init__()
1022
- print(f'Unused args while creating GCN: {kwargs}')
1023
- self.n_layers = n_layers
1024
- self.n_proc_steps = n_proc_steps
1025
- self.layers = nn.ModuleList()
1026
- self.has_global = sample_global.shape[1] != 0
1027
- self.hid_size = hid_size
1028
- gl_size = sample_global.shape[1] if self.has_global else 1
1029
-
1030
- #encoder
1031
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1032
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
1033
-
1034
- #GNN
1035
- self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1036
- self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1037
-
1038
- #decoder
1039
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1040
- self.classify = nn.Linear(hid_size, out_size)
1041
-
1042
- #attention
1043
- self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
1044
- self.queries = nn.Linear(hid_size, hid_size)
1045
- self.keys = nn.Linear(hid_size, hid_size)
1046
- self.values = nn.Linear(hid_size, hid_size)
1047
-
1048
- def forward(self, g, global_feats):
1049
- h = self.node_encoder(g.ndata['features'])
1050
- g.ndata['h'] = h
1051
-
1052
- if not self.has_global:
1053
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1054
-
1055
- batch_num_nodes = None
1056
- sum_weights = None
1057
- if "w" in g.ndata:
1058
- batch_indices = g.batch_num_nodes()
1059
- # Find non-zero rows (non-padded nodes)
1060
- non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
1061
- # Split the mask according to the batch indices
1062
- batch_num_nodes = []
1063
- start_idx = 0
1064
- for num_nodes in batch_indices:
1065
- end_idx = start_idx + num_nodes
1066
- non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
1067
- batch_num_nodes.append(non_padded_count)
1068
- start_idx = end_idx
1069
- batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
1070
- sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
1071
- global_feats = batch_num_nodes[:, None].to(torch.float)
1072
-
1073
- h_global = self.global_encoder(global_feats)
1074
-
1075
- h_original_shape = h.shape
1076
- num_graphs = len(dgl.unbatch(g))
1077
- num_nodes = g.batch_num_nodes()[0].item()
1078
- padding_mask = g.ndata['padding_mask'] > 0
1079
- padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
1080
-
1081
- h = g.ndata['h']
1082
- query = self.queries(h)
1083
- key = self.keys(h)
1084
- value = self.values(h)
1085
- query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
1086
- key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
1087
- value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
1088
- h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
1089
- h = torch.reshape(h, h_original_shape)
1090
-
1091
- h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
1092
- g.ndata['h'] = h
1093
- mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
1094
- h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
1095
- h_global = self.global_decoder(h_global)
1096
- return self.classify(h_global)
1097
-
1098
- class Attention_Edge_Network(nn.Module):
1099
- def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
1100
- super().__init__()
1101
- print(f'Unused args while creating GCN: {kwargs}')
1102
- self.n_layers = n_layers
1103
- self.n_proc_steps = n_proc_steps
1104
- self.layers = nn.ModuleList()
1105
- self.has_global = sample_global.shape[1] != 0
1106
- gl_size = sample_global.shape[1] if self.has_global else 1
1107
-
1108
- #encoder
1109
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1110
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1111
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
1112
-
1113
- #GNN
1114
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1115
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1116
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1117
-
1118
- #decoder
1119
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1120
- self.classify = nn.Linear(hid_size, out_size)
1121
-
1122
-
1123
- #attention
1124
- self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
1125
- self.queries = nn.Linear(hid_size, hid_size)
1126
- self.keys = nn.Linear(hid_size, hid_size)
1127
- self.values = nn.Linear(hid_size, hid_size)
1128
-
1129
- def forward(self, g, global_feats):
1130
- h = self.node_encoder(g.ndata['features'])
1131
- e = self.edge_encoder(g.edata['features'])
1132
- g.ndata['h'] = h
1133
- g.edata['e'] = e
1134
-
1135
- if not self.has_global:
1136
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1137
- h_global = self.global_encoder(global_feats)
1138
-
1139
- h = g.ndata['h']
1140
- h_original_shape = h.shape
1141
- num_graphs = len(dgl.unbatch(g))
1142
- num_nodes = g.batch_num_nodes()[0].item()
1143
- padding_mask = g.ndata['padding_mask'].bool()
1144
-
1145
- padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
1146
-
1147
- for i in range(self.n_proc_steps):
1148
-
1149
- h = g.ndata['h']
1150
- query = self.queries(h)
1151
- key = self.keys(h)
1152
- value = self.values(h)
1153
- query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
1154
- key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
1155
- value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
1156
- h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
1157
- h = torch.reshape(h, h_original_shape)
1158
-
1159
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1160
- g.apply_edges(copy_v)
1161
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
1162
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1163
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
1164
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim = 1))
1165
- h_global = self.global_decoder(h_global)
1166
- return self.classify(h_global)
1167
-
1168
- class Attention_Unbatched(nn.Module):
1169
- def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
1170
- super().__init__()
1171
- print(f'Unused args while creating GCN: {kwargs}')
1172
- self.n_layers = n_layers
1173
- self.n_proc_steps = n_proc_steps
1174
- self.layers = nn.ModuleList()
1175
- self.has_global = sample_global.shape[1] != 0
1176
- gl_size = sample_global.shape[1] if self.has_global else 1
1177
-
1178
- #encoder
1179
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1180
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1181
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
1182
-
1183
- #GNN
1184
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1185
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1186
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1187
-
1188
- #decoder
1189
- self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1190
- self.classify = nn.Linear(hid_size, out_size)
1191
-
1192
-
1193
- #attention
1194
- self.multihead_attn = nn.MultiheadAttention(hid_size, 1, dropout=dropout)
1195
- self.queries = nn.Linear(hid_size, hid_size)
1196
- self.keys = nn.Linear(hid_size, hid_size)
1197
- self.values = nn.Linear(hid_size, hid_size)
1198
-
1199
-
1200
-
1201
- def forward(self, g, global_feats):
1202
-
1203
- h = self.node_encoder(g.ndata['features'])
1204
- e = self.edge_encoder(g.edata['features'])
1205
- g.ndata['h'] = h
1206
- g.edata['e'] = e
1207
-
1208
- if not self.has_global:
1209
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1210
- h_global = self.global_encoder(global_feats)
1211
-
1212
- for i in range(self.n_proc_steps):
1213
-
1214
- unbatched_g = dgl.unbatch(g)
1215
- for graph in unbatched_g:
1216
- h = graph.ndata['h']
1217
- h, _ = self.multihead_attn(self.queries(h), self.keys(h), self.values(h))
1218
- graph.ndata['h'] = h
1219
- g = dgl.batch(unbatched_g)
1220
-
1221
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1222
- g.apply_edges(copy_v)
1223
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
1224
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1225
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
1226
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
1227
- h_global = self.global_decoder(h_global)
1228
- return self.classify(h_global)
1229
-
1230
- class Transferred_Learning_Attention(nn.Module):
1231
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, num_heads, dropout=0, learning_rate=0.0001, **kwargs):
1232
- super().__init__()
1233
- print(f'Unused args while creating GCN: {kwargs}')
1234
- self.n_layers = n_layers
1235
- self.n_proc_steps = n_proc_steps
1236
- self.layers = nn.ModuleList()
1237
- self.has_global = sample_global.shape[1] != 0
1238
- self.hid_size = hid_size
1239
- gl_size = sample_global.shape[1] if self.has_global else 1
1240
-
1241
- self.learning_rate = learning_rate
1242
-
1243
- self.pretraining_params = []
1244
- self.attention_params = []
1245
-
1246
- self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
1247
-
1248
- checkpoint = torch.load(pretraining_path)
1249
- self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
1250
- pretrained_layers = list(self.pretrained_model.children())
1251
- pretrained_layers = pretrained_layers[:-1]
1252
- self.pretrained_model = nn.Sequential(*pretrained_layers)
1253
-
1254
- self.pretraining_params.append(self.pretrained_model[1])
1255
- self.pretraining_params.append(self.pretrained_model[3])
1256
- self.pretraining_params.append(self.pretrained_model[7])
1257
-
1258
- #attention
1259
- self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
1260
- self.queries = nn.Linear(hid_size, hid_size)
1261
- self.keys = nn.Linear(hid_size, hid_size)
1262
- self.values = nn.Linear(hid_size, hid_size)
1263
-
1264
- self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1265
- self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1266
-
1267
- self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)
1268
-
1269
- self.attention_params.append(self.multihead_attn)
1270
-
1271
- self.attention_params.append(self.queries)
1272
- self.attention_params.append(self.keys)
1273
- self.attention_params.append(self.values)
1274
- self.attention_params.append(self.classify)
1275
- self.attention_params.append(self.node_update)
1276
- self.attention_params.append(self.global_update)
1277
-
1278
- def TL_node_encoder(self, x):
1279
- for layer in self.pretrained_model[1]:
1280
- x = layer(x)
1281
- return x
1282
-
1283
- def TL_global_encoder(self, x):
1284
- for layer in self.pretrained_model[3]:
1285
- x = layer(x)
1286
- return x
1287
-
1288
- def TL_global_decoder(self, x):
1289
- for layer in self.pretrained_model[7]:
1290
- x = layer(x)
1291
- return x
1292
-
1293
- def forward(self, g, global_feats):
1294
- h = self.TL_node_encoder(g.ndata['features'])
1295
- g.ndata['h'] = h
1296
-
1297
- if not self.has_global:
1298
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1299
-
1300
- batch_num_nodes = None
1301
- sum_weights = None
1302
- if "w" in g.ndata:
1303
- batch_indices = g.batch_num_nodes()
1304
- # Find non-zero rows (non-padded nodes)
1305
- non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
1306
- # Split the mask according to the batch indices
1307
- batch_num_nodes = []
1308
- start_idx = 0
1309
- for num_nodes in batch_indices:
1310
- end_idx = start_idx + num_nodes
1311
- non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
1312
- batch_num_nodes.append(non_padded_count)
1313
- start_idx = end_idx
1314
- batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
1315
- sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
1316
- global_feats = batch_num_nodes[:, None].to(torch.float)
1317
-
1318
- h_global = self.TL_global_encoder(global_feats)
1319
-
1320
- h_original_shape = h.shape
1321
- num_graphs = len(dgl.unbatch(g))
1322
- num_nodes = g.batch_num_nodes()[0].item()
1323
- padding_mask = g.ndata['padding_mask'] > 0
1324
- padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
1325
-
1326
- h = g.ndata['h']
1327
- query = self.queries(h)
1328
- key = self.keys(h)
1329
- value = self.values(h)
1330
- query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
1331
- key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
1332
- value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
1333
- h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
1334
- h = torch.reshape(h, h_original_shape)
1335
-
1336
- h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
1337
- g.ndata['h'] = h
1338
- mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
1339
- h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
1340
- h_global = self.TL_global_decoder(h_global)
1341
- return self.classify(h_global)
1342
-
1343
- def parameters(self, recurse: bool = True):
1344
- params = []
1345
- for model_section in self.pretraining_params:
1346
- if (type(self.learning_rate) == dict and self.learning_rate["pretraining_lr"]):
1347
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
1348
- else:
1349
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1350
- for model_section in self.attention_params:
1351
- if (type(self.learning_rate) == dict and self.learning_rate["attention_lr"]):
1352
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["attention_lr"]})
1353
- else:
1354
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1355
- return params
1356
-
1357
- class Multimodel_Transferred_Learning(nn.Module):
1358
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
1359
- super().__init__()
1360
- print(f'Unused args while creating GCN: {kwargs}')
1361
- self.n_layers = n_layers
1362
- self.n_proc_steps = n_proc_steps
1363
- self.layers = nn.ModuleList()
1364
- self.has_global = sample_global.shape[1] != 0
1365
- gl_size = sample_global.shape[1] if self.has_global else 1
1366
-
1367
- self.learning_rate = learning_rate
1368
- input_size = 0
1369
-
1370
- self.pretraining_params = []
1371
- self.model_params = []
1372
-
1373
- self.pretrained_models = []
1374
- for model, path in zip(pretraining_model, pretraining_path):
1375
- input_size += model['args']['hid_size']
1376
- model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})
1377
-
1378
- checkpoint = torch.load(path)['model_state_dict']
1379
- new_state_dict = {}
1380
- for k, v in checkpoint.items():
1381
- new_key = k.replace('module.', '')
1382
- new_state_dict[new_key] = v
1383
- model.load_state_dict(new_state_dict)
1384
- pretrained_layers = list(model.children())
1385
- pretrained_layers = pretrained_layers[:-1]
1386
-
1387
- model = nn.Sequential(*pretrained_layers)
1388
-
1389
- # Freeze Weights
1390
- print(f"Freeze Pretraining = {frozen_pretraining}")
1391
- if (frozen_pretraining):
1392
- for param in model.parameters():
1393
- param.requires_grad = False # Freeze all layers
1394
- self.pretraining_params.append(model)
1395
- self.pretrained_models.append(model)
1396
-
1397
- print(f"len(pretrained_models) = {len(self.pretrained_models)}")
1398
- print(f"input size = {input_size}")
1399
-
1400
- self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
1401
- self.classify = nn.Linear(hid_size, out_size)
1402
-
1403
- self.model_params.append(self.final_mlp)
1404
- self.model_params.append(self.classify)
1405
-
1406
- def TL_node_encoder(self, x, model_idx):
1407
- try:
1408
- for layer in self.pretrained_models[model_idx][1]:
1409
- x = layer(x)
1410
- return x
1411
- except (NotImplementedError, IndexError):
1412
- for layer in self.pretrained_models[model_idx][1][1]:
1413
- x = layer(x)
1414
- return x
1415
-
1416
- def TL_edge_encoder(self, x, model_idx):
1417
- try:
1418
- for layer in self.pretrained_models[model_idx][2]:
1419
- x = layer(x)
1420
- return x
1421
- except (NotImplementedError, IndexError):
1422
- for layer in self.pretrained_models[model_idx][1][2]:
1423
- x = layer(x)
1424
- return x
1425
-
1426
- def TL_global_encoder(self, x, model_idx):
1427
- try:
1428
- for layer in self.pretrained_models[model_idx][3]:
1429
- x = layer(x)
1430
- return x
1431
- except (NotImplementedError, IndexError):
1432
- for layer in self.pretrained_models[model_idx][1][3]:
1433
- x = layer(x)
1434
- return x
1435
-
1436
- def TL_node_update(self, x, model_idx):
1437
- try:
1438
- for layer in self.pretrained_models[model_idx][4]:
1439
- x = layer(x)
1440
- return x
1441
- except (NotImplementedError, IndexError):
1442
- for layer in self.pretrained_models[model_idx][1][4]:
1443
- x = layer(x)
1444
- return x
1445
-
1446
- def TL_edge_update(self, x, model_idx):
1447
- try:
1448
- for layer in self.pretrained_models[model_idx][5]:
1449
- x = layer(x)
1450
- return x
1451
- except (NotImplementedError, IndexError):
1452
- for layer in self.pretrained_models[model_idx][1][5]:
1453
- x = layer(x)
1454
- return x
1455
-
1456
- def TL_global_update(self, x, model_idx):
1457
- try:
1458
- for layer in self.pretrained_models[model_idx][6]:
1459
- x = layer(x)
1460
- return x
1461
- except (NotImplementedError, IndexError):
1462
- for layer in self.pretrained_models[model_idx][1][6]:
1463
- x = layer(x)
1464
- return x
1465
-
1466
- def TL_global_decoder(self, x, model_idx):
1467
- try:
1468
- for layer in self.pretrained_models[model_idx][7]:
1469
- x = layer(x)
1470
- return x
1471
- except (NotImplementedError, IndexError):
1472
- for layer in self.pretrained_models[model_idx][1][7]:
1473
- x = layer(x)
1474
- return x
1475
-
1476
- def Pretrained_Output(self, g, model_idx):
1477
- h = self.TL_node_encoder(g.ndata['features'], model_idx)
1478
- e = self.TL_edge_encoder(g.edata['features'], model_idx)
1479
- g.ndata['h'] = h
1480
- g.edata['e'] = e
1481
- if not self.has_global:
1482
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1483
- h_global = self.TL_global_encoder(global_feats, model_idx)
1484
- for i in range(self.n_proc_steps):
1485
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1486
- g.apply_edges(copy_v)
1487
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
1488
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1489
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
1490
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
1491
- # h_global = self.TL_global_decoder(h_global, model_idx)
1492
- return h_global
1493
-
1494
- def forward(self, g, global_feats):
1495
- h_global = []
1496
- for i in range(len(self.pretrained_models)):
1497
- h_global.append(self.Pretrained_Output(g.clone(), i))
1498
- h_global = torch.concatenate(h_global, dim=1)
1499
- return self.classify(self.final_mlp(h_global))
1500
-
1501
- def to(self, device):
1502
- for i in range(len(self.pretrained_models)):
1503
- self.pretrained_models[i].to(device)
1504
- self.classify.to(device)
1505
- self.final_mlp.to(device)
1506
- return self
1507
-
1508
- def parameters(self, recurse: bool = True):
1509
- params = []
1510
- for model_section in self.pretraining_params:
1511
- if (type(self.learning_rate) == dict and self.learning_rate["pretraining_lr"]):
1512
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
1513
- else:
1514
- params.append({'params': model_section.parameters(), 'lr': 0.00001})
1515
- for model_section in self.model_params:
1516
- if (type(self.learning_rate) == dict and self.learning_rate["model_lr"]):
1517
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
1518
- else:
1519
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1520
- return params
1521
-
1522
-
1523
- class MultiModel(nn.Module):
1524
- def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
1525
- super().__init__()
1526
- print(f'Unused args while creating GCN: {kwargs}')
1527
- self.n_layers = n_layers
1528
- self.n_proc_steps = n_proc_steps
1529
- self.layers = nn.ModuleList()
1530
- self.has_global = sample_global.shape[1] != 0
1531
- gl_size = sample_global.shape[1] if self.has_global else 1
1532
-
1533
- self.learning_rate = learning_rate
1534
- input_size = 0
1535
-
1536
- self.model_params = []
1537
- self.pretraining_params = []
1538
-
1539
- self.pretrained_models = []
1540
- for model, path in zip(pretraining_model, pretraining_path):
1541
- input_size += model['args']['hid_size']
1542
- model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})
1543
-
1544
- checkpoint = torch.load(path)['model_state_dict']
1545
- new_state_dict = {}
1546
- for k, v in checkpoint.items():
1547
- new_key = k.replace('module.', '')
1548
- new_state_dict[new_key] = v
1549
- model.load_state_dict(new_state_dict)
1550
- pretrained_layers = list(model.children())
1551
- pretrained_layers = pretrained_layers[:-1]
1552
-
1553
- model = nn.Sequential(*pretrained_layers)
1554
-
1555
- # Freeze Weights
1556
- print(f"Freeze Pretraining = {frozen_pretraining}")
1557
- if (frozen_pretraining):
1558
- for param in model.parameters():
1559
- param.requires_grad = False # Freeze all layers
1560
- self.pretraining_params.append(model)
1561
- self.pretrained_models.append(model)
1562
-
1563
- print(f"len(pretrained_models) = {len(self.pretrained_models)}")
1564
- print(f"input size = {input_size}")
1565
-
1566
- #encoder
1567
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1568
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1569
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
1570
-
1571
- #GNN
1572
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1573
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1574
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1575
-
1576
- self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1577
- self.classify = nn.Linear(hid_size, out_size)
1578
-
1579
- self.model_params.append(self.final_mlp)
1580
- self.model_params.append(self.classify)
1581
-
1582
- def TL_node_encoder(self, x, model_idx):
1583
- try:
1584
- for layer in self.pretrained_models[model_idx][1]:
1585
- x = layer(x)
1586
- return x
1587
- except (NotImplementedError, IndexError):
1588
- for layer in self.pretrained_models[model_idx][1][1]:
1589
- x = layer(x)
1590
- return x
1591
-
1592
- def TL_edge_encoder(self, x, model_idx):
1593
- try:
1594
- for layer in self.pretrained_models[model_idx][2]:
1595
- x = layer(x)
1596
- return x
1597
- except (NotImplementedError, IndexError):
1598
- for layer in self.pretrained_models[model_idx][1][2]:
1599
- x = layer(x)
1600
- return x
1601
-
1602
- def TL_global_encoder(self, x, model_idx):
1603
- try:
1604
- for layer in self.pretrained_models[model_idx][3]:
1605
- x = layer(x)
1606
- return x
1607
- except (NotImplementedError, IndexError):
1608
- for layer in self.pretrained_models[model_idx][1][3]:
1609
- x = layer(x)
1610
- return x
1611
-
1612
- def TL_node_update(self, x, model_idx):
1613
- try:
1614
- for layer in self.pretrained_models[model_idx][4]:
1615
- x = layer(x)
1616
- return x
1617
- except (NotImplementedError, IndexError):
1618
- for layer in self.pretrained_models[model_idx][1][4]:
1619
- x = layer(x)
1620
- return x
1621
-
1622
- def TL_edge_update(self, x, model_idx):
1623
- try:
1624
- for layer in self.pretrained_models[model_idx][5]:
1625
- x = layer(x)
1626
- return x
1627
- except (NotImplementedError, IndexError):
1628
- for layer in self.pretrained_models[model_idx][1][5]:
1629
- x = layer(x)
1630
- return x
1631
-
1632
- def TL_global_update(self, x, model_idx):
1633
- try:
1634
- for layer in self.pretrained_models[model_idx][6]:
1635
- x = layer(x)
1636
- return x
1637
- except (NotImplementedError, IndexError):
1638
- for layer in self.pretrained_models[model_idx][1][6]:
1639
- x = layer(x)
1640
- return x
1641
-
1642
- def TL_global_decoder(self, x, model_idx):
1643
- try:
1644
- for layer in self.pretrained_models[model_idx][7]:
1645
- x = layer(x)
1646
- return x
1647
- except (NotImplementedError, IndexError):
1648
- for layer in self.pretrained_models[model_idx][1][7]:
1649
- x = layer(x)
1650
- return x
1651
-
1652
- def Pretrained_Output(self, g, model_idx):
1653
- h = self.TL_node_encoder(g.ndata['features'], model_idx)
1654
- e = self.TL_edge_encoder(g.edata['features'], model_idx)
1655
- g.ndata['h'] = h
1656
- g.edata['e'] = e
1657
- if not self.has_global:
1658
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1659
- h_global = self.TL_global_encoder(global_feats, model_idx)
1660
- for i in range(self.n_proc_steps):
1661
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1662
- g.apply_edges(copy_v)
1663
- g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
1664
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1665
- g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
1666
- h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
1667
- # h_global = self.TL_global_decoder(h_global, model_idx)
1668
- return h_global
1669
-
1670
- def forward(self, g, global_feats):
1671
- h = self.node_encoder(g.ndata['features'])
1672
- e = self.edge_encoder(g.edata['features'])
1673
- g.ndata['h'] = h
1674
- g.edata['e'] = e
1675
- if not self.has_global:
1676
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1677
- h_global = self.global_encoder(global_feats)
1678
- for i in range(self.n_proc_steps):
1679
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1680
- g.apply_edges(copy_v)
1681
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
1682
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1683
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
1684
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
1685
- h_global = [h_global]
1686
- for i in range(len(self.pretrained_models)):
1687
- h_global.append(self.Pretrained_Output(g.clone(), i))
1688
- h_global = torch.concatenate(h_global, dim=1)
1689
- return self.classify(self.final_mlp(h_global))
1690
-
1691
- def to(self, device):
1692
- for i in range(len(self.pretrained_models)):
1693
- self.pretrained_models[i].to(device)
1694
- self.classify.to(device)
1695
- self.final_mlp.to(device)
1696
- self.node_encoder.to(device)
1697
- self.edge_encoder.to(device)
1698
- self.global_encoder.to(device)
1699
-
1700
- self.node_update.to(device)
1701
- self.edge_update.to(device)
1702
- self.global_update.to(device)
1703
- return self
1704
-
1705
- def parameters(self, recurse: bool = True):
1706
- params = []
1707
- for i, model_section in enumerate(self.pretraining_params):
1708
- if (type(self.learning_rate) == dict and self.learning_rate["pretraining_lr"]):
1709
- print(f"Pretraining LR = {self.learning_rate['pretraining_lr'][i]}")
1710
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"][i]})
1711
- else:
1712
- print(f"Pretraining LR = 0.00001")
1713
- params.append({'params': model_section.parameters(), 'lr': 0.00001})
1714
- for model_section in self.model_params:
1715
- if (type(self.learning_rate) == dict and self.learning_rate["model_lr"]):
1716
- print(f"Model LR = {self.learning_rate['model_lr']}")
1717
- params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
1718
- else:
1719
- print(f"Model LR = 0.0001")
1720
- params.append({'params': model_section.parameters(), 'lr': 0.0001})
1721
- return params
1722
-
1723
-
1724
- class Clustering(nn.Module):
1725
- def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
1726
- super().__init__()
1727
- print(f'Unused args while creating GCN: {kwargs}')
1728
- self.n_layers = n_layers
1729
- self.n_proc_steps = n_proc_steps
1730
- self.layers = nn.ModuleList()
1731
- self.hid_size = hid_size
1732
- if (len(sample_global) == 0):
1733
- self.has_global = False
1734
- else:
1735
- self.has_global = sample_global.shape[1] != 0
1736
- gl_size = sample_global.shape[1] if self.has_global else 1
1737
-
1738
- #encoder
1739
- self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1740
- self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
1741
- self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
1742
-
1743
- #GNN
1744
- self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1745
- self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1746
- self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
1747
-
1748
- #decoder
1749
- self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout)
1750
-
1751
- def model_forward(self, g, global_feats, features = 'features'):
1752
- h = self.node_encoder(g.ndata[features])
1753
- e = self.edge_encoder(g.edata[features])
1754
-
1755
- g.ndata['h'] = h
1756
- g.edata['e'] = e
1757
- if not self.has_global:
1758
- global_feats = g.batch_num_nodes()[:, None].to(torch.float)
1759
-
1760
- batch_num_nodes = None
1761
- sum_weights = None
1762
- if "w" in g.ndata:
1763
- batch_indices = g.batch_num_nodes()
1764
- # Find non-zero rows (non-padded nodes)
1765
- non_padded_nodes_mask = torch.any(g.ndata[features] != 0, dim=1)
1766
- # Split the mask according to the batch indices
1767
- batch_num_nodes = []
1768
- start_idx = 0
1769
- for num_nodes in batch_indices:
1770
- end_idx = start_idx + num_nodes
1771
- non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
1772
- batch_num_nodes.append(non_padded_count)
1773
- start_idx = end_idx
1774
- batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
1775
- sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
1776
- global_feats = batch_num_nodes[:, None].to(torch.float)
1777
-
1778
- h_global = self.global_encoder(global_feats)
1779
- for i in range(self.n_proc_steps):
1780
- g.apply_edges(dgl.function.copy_u('h', 'm_u'))
1781
- g.apply_edges(copy_v)
1782
- g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
1783
- g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
1784
- g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
1785
- if "w" in g.ndata:
1786
- mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
1787
- h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
1788
- else:
1789
- h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
1790
- h_global = self.global_decoder(h_global)
1791
- return h_global
1792
-
1793
- def forward(self, g, global_feats):
1794
- h_global = self.model_forward(g, global_feats, 'features')
1795
- h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
1796
- return torch.cat((h_global, h_global_augmented), dim=1)
1797
-
1798
- def representation(self, g, global_feats):
1799
- h_global = self.model_forward(g, global_feats, 'features')
1800
- h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
1801
- return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1)
1802
-
1803
- def __str__(self):
1804
- layer_names = ["node_encoder", "edge_encoder", "global_encoder",
1805
- "node_update", "edge_update", "global_update", "global_decoder"]
1806
-
1807
- layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
1808
- self.node_update, self.edge_update, self.global_update, self.global_decoder]
1809
-
1810
- for i in range(len(layers)):
1811
- print(layer_names[i])
1812
- for layer in layers[i].children():
1813
- if isinstance(layer, nn.Linear):
1814
- print(layer.state_dict())
1815
-
1816
- print("classify")
1817
- print(self.classify.weight)
1818
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/models/loss.py DELETED
@@ -1,311 +0,0 @@
1
- from torch import nn
2
- import torch
3
- from root_gnn_base import utils
4
- import numpy as np
5
-
6
- class MaskedLoss():
7
- def __init__(self, mask = []):
8
- self.mask = mask
9
-
10
- def make_mask(self, targets):
11
- mask = torch.ones_like(targets[:,0])
12
- for m in self.mask:
13
- if m['op'] == 'eq':
14
- mask[targets[:,m['idx']] == m['val']] = 0
15
- elif m['op'] == 'gt':
16
- mask[targets[:,m['idx']] > m['val']] = 0
17
- elif m['op'] == 'lt':
18
- mask[targets[:,m['idx']] < m['val']] = 0
19
- elif m['op'] == 'ge':
20
- mask[targets[:,m['idx']] >= m['val']] = 0
21
- elif m['op'] == 'le':
22
- mask[targets[:,m['idx']] <= m['val']] = 0
23
- elif m['op'] == 'ne':
24
- mask[targets[:,m['idx']] != m['val']] = 0
25
- else:
26
- raise ValueError(f'Unknown mask op {m["op"]}')
27
- return mask == 1
28
-
29
- class MaskedL1Loss(MaskedLoss):
30
- def __init__(self, mask = [], index = 0):
31
- super().__init__(mask)
32
- self.index = index
33
- self.loss = nn.L1Loss()
34
-
35
- def __call__(self, logits, targets):
36
- mask = self.make_mask(targets)
37
- return self.loss(logits[mask], targets[mask][:,self.index])
38
-
39
- class BCEWithLogitsLoss():
40
- def __init__(self, weight=None, reduction='mean'):
41
- self.loss = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)
42
-
43
- def __call__(self, logits, targets):
44
- return self.loss(logits[:,0], targets.float())
45
-
46
- class MultiScore():
47
- def __init__(self, scores):
48
- self. score_fcns = []
49
- self.start_idx = []
50
- self.end_idx = []
51
- for score in scores:
52
- self.score_fcns.append(utils.buildFromConfig(score))
53
- self.start_idx.append(score['start_idx'])
54
- self.end_idx.append(score['end_idx'])
55
-
56
- def __call__(self, last_layer):
57
- scores = []
58
- for i in range(len(self.score_fcns)):
59
- scores.append(self.score_fcns[i](last_layer[:, self.start_idx[i]:self.end_idx[i]]))
60
- return torch.cat(scores, dim=1)
61
-
62
- class MultiLoss():
63
- def __init__(self, losses):
64
- self.loss_fcns = []
65
- self.label_start_idx = []
66
- self.label_end_idx = []
67
- self.output_start_idx = []
68
- self.output_end_idx = []
69
- self.weights = []
70
- self.label_types = []
71
- for loss in losses:
72
- self.loss_fcns.append(utils.buildFromConfig(loss))
73
- self.label_start_idx.append(loss['label_start_idx'])
74
- self.label_end_idx.append(loss['label_end_idx'])
75
- self.output_start_idx.append(loss['output_start_idx'])
76
- self.output_end_idx.append(loss['output_end_idx'])
77
- self.weights.append(loss.get('weight', 1.0))
78
- self.label_types.append(loss.get('label_type', 'float'))
79
-
80
- def __call__(self, logits, targets):
81
- loss = 0
82
- # print(logits.shape, targets.shape)
83
- for i in range(len(self.loss_fcns)):
84
- if self.label_types[i] == 'int':
85
- # print('loss', i, self.label_start_idx[i], self.label_end_idx[i], self.output_start_idx[i], self.output_end_idx[i])
86
- # print(logits[:, self.output_start_idx[i]:self.output_end_idx[i]].shape, targets[:, self.label_start_idx[i]].shape)
87
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]].to(int))
88
- elif self.label_end_idx[i] - self.label_start_idx[i] == 1:
89
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]])
90
- else:
91
- # print('loos', i, self.label_start_idx[i], self.label_end_idx[i], self.output_start_idx[i], self.output_end_idx[i])
92
- # print(logits[:, self.output_start_idx[i]:self.output_end_idx[i]].shape, targets[:, self.label_start_idx[i]:self.label_end_idx[i]].shape)
93
- loss += self.weights[i] * self.loss_fcns[i](logits[:, self.output_start_idx[i]:self.output_end_idx[i]], targets[:, self.label_start_idx[i]:self.label_end_idx[i]])
94
- return loss
95
-
96
- class AdvLoss():
97
- def __init__(self, loss, adv_loss, adv_weight=1.0):
98
- self.loss_fcn = utils.buildFromConfig(loss)
99
- self.adv_loss_fcn = utils.buildFromConfig(adv_loss)
100
- self.adv_weight = adv_weight
101
-
102
- def __call__(self, logits, targets):
103
- mask = targets[:,0] == 0
104
- loss = self.loss_fcn(logits[:,0], targets[:,0])
105
- adv_loss = self.adv_loss_fcn(logits[mask][:,1], targets[mask])
106
- return loss - self.adv_weight * adv_loss
107
-
108
- class MassWindowAdvLoss(AdvLoss):
109
- def __call__(self, logits, targets):
110
- mask = (targets[:,0] == 0) & (targets[:,1] > 5) & (targets[:,1] < 25)
111
- print(mask, mask.shape, mask.sum())
112
- loss = self.loss_fcn(logits[:,0], targets[:,0])
113
- print(loss)
114
- adv_loss = self.adv_loss_fcn(logits[mask][:,1], targets[mask][:,1])
115
- print(adv_loss)
116
- return loss - self.adv_weight * adv_loss
117
-
118
- class KDELoss(MaskedLoss):
119
- def __init__(self, mask = [], index = 0):
120
- self.index = index
121
- super().__init__(mask)
122
-
123
- def __call__(self, logits, targets):
124
- mask = self.make_mask(targets)
125
- logits = logits[mask]
126
- targets = targets[mask][:,self.index]
127
- N = logits.shape[0]
128
- masses = targets / torch.sqrt(torch.mean(targets**2))
129
- scores = logits[:,0] / torch.sqrt(torch.mean(logits**2))
130
-
131
- factor_2d = (1.0*N) ** (-2/6)
132
- covs = (factor_2d * torch.var(masses), factor_2d * torch.var(scores))
133
-
134
- m_diffs = torch.unsqueeze(masses, 1) - torch.unsqueeze(masses, 0)
135
- s_diffs = torch.unsqueeze(scores, 1) - torch.unsqueeze(scores, 0)
136
-
137
- ymm = torch.exp(- (m_diffs**2) / (4 * covs[0]))
138
- yss = torch.exp(- (s_diffs**2) / (4 * covs[1]))
139
-
140
- integral_rho_2d_rho_2d = torch.einsum('ij,ij->', ymm, yss)
141
- integral_rho_1d_rho_1d = torch.einsum('ij,kl->', ymm, yss)
142
- integral_rho_2d_rho_1d = torch.einsum('ij,ik->', ymm, yss)
143
- raw_integral = integral_rho_2d_rho_2d - 2 * integral_rho_2d_rho_1d / N + integral_rho_1d_rho_1d / N**2
144
- return raw_integral / (4 * torch.pi * N**2)
145
-
146
- class MultiLabelLoss():
147
- def __init__(self, label_names, label_types, label_weights = None):
148
- self.loss_fcn = []
149
- if (label_weights):
150
- self.weights = torch.tensor(label_weights)
151
- else:
152
- self.weights = torch.ones(len(label_types))
153
- for type in label_types:
154
- if (type == "r"):
155
- self.loss_fcn.append(torch.nn.MSELoss(reduce=False))
156
- elif (type == "c"):
157
- self.loss_fcn.append(torch.nn.BCEWithLogitsLoss())
158
- print(f"self.weights = {self.weights}")
159
-
160
- def __call__(self, logits, targets):
161
- targets = targets.float()
162
- loss = torch.zeros(len(logits[:, 0]), device = logits.get_device())
163
- for i in range(len(self.loss_fcn)):
164
- loss += self.weights[i] * self.loss_fcn[i](logits[:, i], targets[:, i])
165
- return torch.mean(loss)
166
-
167
-
168
- class MultiLabelFinish():
169
- def __init__(self, label_names, label_types):
170
- self.finish_fcn = []
171
- for type in label_types:
172
- if (type == "r"):
173
- self.finish_fcn.append(None)
174
- elif (type == "c"):
175
- self.finish_fcn.append(torch.special.expit)
176
-
177
- def __call__(self, logits):
178
- for i in range(len(self.finish_fcn)):
179
- if (self.finish_fcn[i]):
180
- logits[:, i] = self.finish_fcn[i](logits[:, i].to(torch.long))
181
- return logits
182
-
183
- class ContrastiveClusterLoss():
184
- def __init__(self, k=10, temperature=1, alpha=1):
185
- self.k = k
186
- self.temperature = temperature
187
- self.alpha = alpha
188
-
189
- def __call__(self, logits, targets):
190
- targets = targets.float()
191
- logits_combined = logits.float()
192
-
193
- hid_size = int(len(logits[0]) / 2)
194
-
195
- logits = normalize_embeddings(logits_combined[:, :hid_size])
196
- logits_augmented = normalize_embeddings(logits_combined[:, hid_size:])
197
-
198
- contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
199
- clustering, _ = clustering_loss(logits, self.k)
200
-
201
- variance_loss = variance_regularization(logits) + variance_regularization(logits_augmented)
202
-
203
- return torch.mean(contrastive + clustering + self.alpha * variance_loss)
204
-
205
- class ContrastiveClusterFinish():
206
- def __init__(self, k = 10, temperature = 1, max_cluster_iterations = 10):
207
- self.k = k
208
- self.temperature = temperature
209
- self.max_cluster_iterations = max_cluster_iterations
210
-
211
- print(f"ContrastiveClusterFinish: k = {k}, temperature = {temperature}")
212
-
213
- def __call__(self, logits):
214
- logits_combined = logits.float()
215
-
216
- hid_size = int(len(logits[0]) / 2)
217
-
218
- logits = logits_combined[:, :hid_size]
219
- logits_augmented = logits_combined[:, hid_size:]
220
-
221
- contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
222
- clustering, _ = clustering_loss(logits, self.k, self.max_cluster_iterations)
223
- variance = variance_regularization(logits) + variance_regularization(logits_augmented)
224
-
225
- return contrastive, clustering, variance
226
-
227
- def s(z_i, z_j):
228
- z_i = torch.tensor(z_i) if not isinstance(z_i, torch.Tensor) else z_i
229
- z_j = torch.tensor(z_j) if not isinstance(z_j, torch.Tensor) else z_j
230
-
231
- return torch.cdist(z_i, z_j, p=2)
232
- # dot_product = torch.dot(z_i, z_j)
233
- # norm_i = torch.linalg.norm(z_i)
234
- # norm_j = torch.linalg.norm(z_j)
235
-
236
- # return dot_product / (norm_i * norm_j)
237
-
238
- def contrastive_loss(logits, logits_augmented, temperature=1, margin=1.0):
239
- logits = torch.tensor(logits) if not isinstance(logits, torch.Tensor) else logits
240
- logits_augmented = torch.tensor(logits_augmented) if not isinstance(logits_augmented, torch.Tensor) else logits_augmented
241
-
242
- z = torch.cat((logits, logits_augmented), dim=0)
243
- similarity_matrix = torch.mm(z, z.t()) / temperature
244
- norms = torch.linalg.norm(z, dim=1)
245
- norm_matrix = torch.ger(norms, norms)
246
- similarity_matrix = similarity_matrix / norm_matrix
247
- mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool)
248
-
249
- loss = 0
250
- for k in range(len(logits)):
251
- numerator = torch.exp(similarity_matrix[k, k + len(logits)])
252
- denominator = torch.sum(torch.exp(similarity_matrix[k, ~mask[k]]))
253
-
254
- loss += -torch.log(numerator / denominator)
255
-
256
- return loss
257
-
258
-
259
- def clustering_loss(logits, k=10, max_iterations=10):
260
- # Step 1: Initialize cluster means
261
- indices = torch.randperm(logits.size(0))[:k]
262
- cluster_means = logits[indices]
263
-
264
- prev_assignments = None
265
- assignment_history = []
266
- iteration = 0
267
-
268
- while iteration < max_iterations:
269
- iteration += 1
270
-
271
- # Step 2: Assign each data point to the nearest cluster mean
272
- distances = torch.cdist(logits, cluster_means, p=2) # Compute distances between logits and cluster means
273
- cluster_assignments = torch.argmin(distances, dim=1) # Assign each point to the nearest cluster mean
274
-
275
- # Check for convergence: if assignments do not change, break the loop
276
- if prev_assignments is not None and torch.equal(cluster_assignments, prev_assignments):
277
- break
278
-
279
- # Check for cycles: if assignments have been seen before, break the loop
280
- if any(torch.equal(cluster_assignments, prev) for prev in assignment_history):
281
- break
282
-
283
- assignment_history.append(cluster_assignments.clone())
284
- prev_assignments = cluster_assignments.clone()
285
-
286
- # Step 3: Update cluster means based on assignments
287
- new_cluster_means = torch.zeros_like(cluster_means)
288
- for i in range(k):
289
- assigned_points = logits[cluster_assignments == i]
290
- if assigned_points.size(0) > 0:
291
- new_cluster_means[i] = assigned_points.mean(dim=0)
292
- else:
293
- # If no points are assigned to the cluster, reinitialize the mean randomly
294
- new_cluster_means[i] = logits[torch.randint(0, logits.size(0), (1,)).item()]
295
- cluster_means = new_cluster_means
296
-
297
- # Step 4: Compute the clustering loss
298
- distances = torch.cdist(logits, cluster_means, p=2)
299
- min_distances = torch.min(distances, dim=1)[0]
300
- loss = torch.sum(min_distances ** 2)
301
-
302
- return loss, cluster_means
303
-
304
- def normalize_embeddings(embeddings):
305
- return embeddings / embeddings.norm(dim=1, keepdim=True)
306
-
307
- def variance_regularization(embeddings):
308
- mean_embedding = embeddings.mean(dim=0)
309
- variance = ((embeddings - mean_embedding) ** 2).mean()
310
- return variance
311
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/models/meshgraphnet.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import dgl
4
-
5
- # Import the PhysicsNemo MeshGraphNet model
6
- from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
7
-
8
- class MeshGraphNet(nn.Module):
9
- def __init__(self, *args, out_dim=1, **kwargs):
10
- super().__init__()
11
- # Initialize the PhysicsNemo MeshGraphNet
12
- self.base_gnn = PhysicsNemoMeshGraphNet(*args, **kwargs)
13
- # Assume node_output_dim is known or infer from args/kwargs
14
- node_output_dim = 64
15
- self.mlp = nn.Linear(node_output_dim, out_dim)
16
-
17
- def forward(self, node_feats, edge_feats, batched_graph):
18
- """
19
- Args:
20
- batched_graph: DGLGraph, batched graphs
21
- node_feats: Tensor [total_num_nodes, node_feat_dim]
22
- edge_feats: Tensor [total_num_edges, edge_feat_dim]
23
- Returns:
24
- graph_pred: Tensor [num_graphs, out_dim]
25
- """
26
- # 1. Node-level prediction from PhysicsNemo GNN
27
- node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)
28
- batched_graph.ndata['h'] = node_pred
29
- graph_feat = dgl.readout_nodes(batched_graph, 'h', op='mean') # [num_graphs, node_output_dim]
30
-
31
- # 3. Final MLP for graph-level prediction
32
- graph_pred = self.mlp(graph_feat) # [num_graphs, out_dim]
33
- return graph_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/batched_dataset.py DELETED
@@ -1,191 +0,0 @@
1
- from dgl.dataloading import GraphDataLoader
2
- from torch.utils.data.sampler import SubsetRandomSampler
3
- from torch.utils.data.sampler import SequentialSampler
4
- from dgl.data import DGLDataset
5
- import torch
6
- import time
7
- import os
8
- import dgl
9
- from root_gnn_base import utils
10
-
11
- def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwargs):
12
- if mask_fn == None:
13
- mask_fn = lambda x: torch.ones(len(x), dtype=torch.bool)
14
- dloader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(torch.arange(len(dataset))[mask_fn(dataset)]), batch_size=batch_size, drop_last=drop_last, num_workers = 0)
15
- return dloader
16
-
17
- #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
18
- class PreBatchedDataset(DGLDataset):
19
- def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', hidden_size=64, **kwargs):
20
- print(f'Unused kwargs: {kwargs}')
21
- self.start_dataset = start_dataset
22
- self.start_dataset.load()
23
-
24
- self.batch_size = batch_size
25
- self.chunks = chunks
26
- self.chunkno = chunkno
27
- self.mask_fn = mask_fn
28
- self.drop_last = drop_last
29
- self.graphs = []
30
- self.label = []
31
- self.padding_mode = padding_mode
32
- self.save_to_disk = save_to_disk
33
- self.shuffle = shuffle
34
- self.suffix = suffix
35
- self.current_chunk = None
36
- self.current_chunk_idx = -1
37
- self.hid_size = hidden_size
38
- super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)
39
-
40
- def process(self):
41
- first = 0
42
- last = len(self.start_dataset)
43
- if self.chunks > 1 and self.chunkno >= 0:
44
- first = int(self.chunkno / self.chunks * len(self.start_dataset))
45
- last = int((self.chunkno + 1) / self.chunks * len(self.start_dataset))
46
- print(f'Processing chunk {self.chunkno} of {self.chunks} from {first} to {last} of {len(self.start_dataset)}')
47
- mask = torch.logical_and(torch.logical_and(self.mask_fn(self.start_dataset), torch.arange(len(self.start_dataset)) >= first), torch.arange(len(self.start_dataset)) < last)
48
- if self.shuffle:
49
- dloader = GraphDataLoader(self.start_dataset, sampler=SubsetRandomSampler(torch.arange(len(self.start_dataset))[mask]), batch_size=self.batch_size, drop_last=self.drop_last)
50
- else: #Only don't shuffle if we're doing inference. Then we want all of the events anyways?
51
- dloader = GraphDataLoader(self.start_dataset, sampler=SequentialSampler(self.start_dataset), batch_size=self.batch_size, drop_last=self.drop_last)
52
- self.graphs = []
53
- self.labels = []
54
- self.tracking = []
55
- self.globals = []
56
- self.batch_num_nodes = []
57
- self.batch_num_edges = []
58
- max_edges = 0
59
- max_nodes = 0
60
- load_batch_start = time.time()
61
- for batch, label, tracking, global_feat in dloader:
62
- if batch.num_edges() > max_edges:
63
- max_edges = batch.num_edges()
64
- if batch.num_nodes() > max_nodes:
65
- max_nodes = batch.num_nodes()
66
- self.graphs.append(batch)
67
- self.labels.append(label)
68
- self.tracking.append(tracking)
69
- self.globals.append(global_feat)
70
- load_batch_end = time.time()
71
- print(f'Loaded {len(self.graphs)} batches in {load_batch_end - load_batch_start} seconds')
72
- if self.padding_mode == 'STEPS':
73
- pad_node, pad_edge = utils.pad_size(self.batch_size, max_edges, max_nodes)
74
- elif self.padding_mode == 'FIXED':
75
- print('Padding to fixed size. This is currently hardcoded.')
76
- pad_node = 16000
77
- pad_edge = 104000
78
- elif self.padding_mode == 'NONE':
79
- pad_node = 0
80
- pad_edge = 0
81
- else:
82
- pad_node = 0
83
- pad_edge = 0
84
- print(f'Max edges: {max_edges}, Max nodes: {max_nodes}, Padding to {pad_edge} edges and {pad_node} nodes')
85
- pad_start = time.time()
86
- if self.padding_mode == 'NODE':
87
- for i in range(len(self.graphs)):
88
- unbatched_g = dgl.unbatch(self.graphs[i])
89
- max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
90
- self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes, hid_size=self.hid_size)
91
- self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
92
- self.batch_num_edges.append(self.graphs[i].batch_num_edges())
93
- else:
94
- for i in range(len(self.graphs)):
95
- self.graphs[i] = utils.pad_batch(self.graphs[i], pad_edge, pad_node)
96
- self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
97
- self.batch_num_edges.append(self.graphs[i].batch_num_edges())
98
- pad_end = time.time()
99
- print(f'Padded {len(self.graphs)} batches in {pad_end - pad_start} seconds')
100
-
101
- def save(self):
102
- if not self.save_to_disk:
103
- return
104
- graph_path = os.path.join(self.save_dir, f'{self.name}_{self.chunkno}_{self.suffix}.bin')
105
- print(f'Saving dataset to {graph_path}')
106
- if len(self.graphs) == 0:
107
- return
108
- dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.stack(self.labels), 'batch_num_nodes': torch.stack(self.batch_num_nodes), 'batch_num_edges': torch.stack(self.batch_num_edges), 'tracking': torch.stack(self.tracking), 'globals': torch.stack(self.globals)})
109
-
110
- def has_cache(self):
111
- if not self.save_to_disk:
112
- return False
113
- for ch in range(self.chunks):
114
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
115
- if not os.path.exists(graph_path):
116
- print(f'Cache file {graph_path} does not exist, not loading from cache.')
117
- return False
118
- return True
119
-
120
- def load(self):
121
- if not self.save_to_disk:
122
- return
123
- self.graphs = []
124
- label_chunks = []
125
- tracking_chunks = []
126
- global_chunks = []
127
- for ch in range(self.chunks):
128
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
129
- print(f'Loading dataset from {graph_path}')
130
- graphs, label_dict = dgl.load_graphs(graph_path)
131
- label_chunks.append(label_dict['labels'])
132
- tracking_chunks.append(label_dict['tracking'])
133
- global_chunks.append(label_dict['globals'])
134
- for g, bnn, bne in zip(graphs, label_dict['batch_num_nodes'], label_dict['batch_num_edges']):
135
- g.set_batch_num_nodes(bnn)
136
- g.set_batch_num_edges(bne)
137
- self.graphs.extend(graphs)
138
- self.labels = torch.cat(label_chunks)
139
- self.tracking = torch.cat(tracking_chunks)
140
- self.globals = torch.cat(global_chunks)
141
-
142
- def __getitem__(self, idx):
143
- return self.graphs[idx], self.labels[idx], self.tracking[idx], self.globals[idx]
144
-
145
- def __len__(self):
146
- return len(self.graphs)
147
-
148
- #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
149
- class LazyPreBatchedDataset(PreBatchedDataset):
150
- def __init__(self, **kwargs):
151
- # print(f'Unused kwargs: {kwargs}')
152
- self.current_chunk = None
153
- self.current_chunk_idx = -10
154
- self.label_chunks = []
155
- super().__init__(**kwargs)
156
-
157
- def load(self):
158
- if not self.save_to_disk:
159
- return
160
- self.label_chunks = []
161
- for ch in range(self.chunks):
162
- graph_path = os.path.join(self.save_dir, f'{self.name}_{ch}_{self.suffix}.bin')
163
- print(f'Loading dataset from {graph_path}')
164
- label_dict = dgl.data.graph_serialize.load_labels_v2(graph_path)
165
- self.label_chunks.append(label_dict)
166
-
167
- def __getitem__(self, idx):
168
- chunk_idx = -1
169
- sum = 0
170
- ev_idx = -999
171
- for i in range(len(self.label_chunks)):
172
- count = len(self.label_chunks[i]['labels'])
173
- if idx < sum + count:
174
- chunk_idx = i
175
- ev_idx = idx - sum
176
- break
177
- sum += count
178
- if chunk_idx != self.current_chunk_idx:
179
- # print(f"rank {self.rank} getting data from {self.name}_{chunk_idx}_{self.suffix}.bin")
180
- self.current_chunk, _ = dgl.load_graphs(os.path.join(self.save_dir, f'{self.name}_{chunk_idx}_{self.suffix}.bin'))
181
- self.current_chunk_idx = chunk_idx
182
- g = self.current_chunk[ev_idx]
183
- g.set_batch_num_nodes(self.label_chunks[chunk_idx]['batch_num_nodes'][ev_idx])
184
- g.set_batch_num_edges(self.label_chunks[chunk_idx]['batch_num_edges'][ev_idx])
185
- return g, self.label_chunks[chunk_idx]['labels'][ev_idx], self.label_chunks[chunk_idx]['tracking'][ev_idx], self.label_chunks[chunk_idx]['globals'][ev_idx]
186
-
187
- def __len__(self):
188
- l = 0
189
- for chunk in self.label_chunks:
190
- l += len(chunk['labels'])
191
- return l
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/custom_scheduler.py DELETED
@@ -1,565 +0,0 @@
1
- import types
2
- import math
3
- import torch
4
- from torch import inf
5
- from functools import wraps, partial
6
- import warnings
7
- import weakref
8
- from collections import Counter
9
- from bisect import bisect_right
10
-
11
- from models import GCN
12
-
13
-
14
-
15
-
16
- ### Code from: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau
17
-
18
- Optimizer = torch.optim.Optimizer
19
-
20
- __all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
21
- 'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
22
- 'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
23
-
24
- EPOCH_DEPRECATION_WARNING = (
25
- "The epoch parameter in `scheduler.step()` was not necessary and is being "
26
- "deprecated where possible. Please use `scheduler.step()` to step the "
27
- "scheduler. During the deprecation, if epoch is different from None, the "
28
- "closed form is used instead of the new chainable form, where available. "
29
- "Please open an issue if you are unable to replicate your use case: "
30
- "https://github.com/pytorch/pytorch/issues/new/choose."
31
- )
32
-
33
-
34
- def update_LR(opt, lr):
35
- for param_group in opt.param_groups:
36
- param_group['lr'] = lr
37
-
38
- def print_LR(opt):
39
- for param_group in opt.param_groups:
40
- print(f"LR = {param_group['lr']}")
41
-
42
- def _check_verbose_deprecated_warning(verbose):
43
- """Raises a warning when verbose is not the default value."""
44
- if verbose != "deprecated":
45
- warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
46
- "to access the learning rate.", UserWarning)
47
- return verbose
48
- return False
49
-
50
- class LRScheduler:
51
-
52
- def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):
53
-
54
- # Attach optimizer
55
- if not isinstance(optimizer, Optimizer):
56
- raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
57
- self.optimizer = optimizer
58
-
59
- # Initialize epoch and base learning rates
60
- if last_epoch == -1:
61
- for group in optimizer.param_groups:
62
- group.setdefault('initial_lr', group['lr'])
63
- else:
64
- for i, group in enumerate(optimizer.param_groups):
65
- if 'initial_lr' not in group:
66
- raise KeyError("param 'initial_lr' is not specified "
67
- f"in param_groups[{i}] when resuming an optimizer")
68
- self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
69
- self.last_epoch = last_epoch
70
-
71
- # Following https://github.com/pytorch/pytorch/issues/20124
72
- # We would like to ensure that `lr_scheduler.step()` is called after
73
- # `optimizer.step()`
74
- def with_counter(method):
75
- if getattr(method, '_with_counter', False):
76
- # `optimizer.step()` has already been replaced, return.
77
- return method
78
-
79
- # Keep a weak reference to the optimizer instance to prevent
80
- # cyclic references.
81
- instance_ref = weakref.ref(method.__self__)
82
- # Get the unbound method for the same purpose.
83
- func = method.__func__
84
- cls = instance_ref().__class__
85
- del method
86
-
87
- @wraps(func)
88
- def wrapper(*args, **kwargs):
89
- instance = instance_ref()
90
- instance._step_count += 1
91
- wrapped = func.__get__(instance, cls)
92
- return wrapped(*args, **kwargs)
93
-
94
- # Note that the returned function here is no longer a bound method,
95
- # so attributes like `__func__` and `__self__` no longer exist.
96
- wrapper._with_counter = True
97
- return wrapper
98
-
99
- self.optimizer.step = with_counter(self.optimizer.step)
100
- self.verbose = _check_verbose_deprecated_warning(verbose)
101
-
102
- self._initial_step()
103
-
104
- def _initial_step(self):
105
- """Initialize step counts and performs a step"""
106
- self.optimizer._step_count = 0
107
- self._step_count = 0
108
- self.step()
109
-
110
- def state_dict(self):
111
- """Returns the state of the scheduler as a :class:`dict`.
112
-
113
- It contains an entry for every variable in self.__dict__ which
114
- is not the optimizer.
115
- """
116
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
117
-
118
- def load_state_dict(self, state_dict):
119
- """Loads the schedulers state.
120
-
121
- Args:
122
- state_dict (dict): scheduler state. Should be an object returned
123
- from a call to :meth:`state_dict`.
124
- """
125
- self.__dict__.update(state_dict)
126
-
127
- def get_last_lr(self):
128
- """ Return last computed learning rate by current scheduler.
129
- """
130
- return self._last_lr
131
-
132
- def get_lr(self):
133
- # Compute learning rate using chainable form of the scheduler
134
- raise NotImplementedError
135
-
136
- def print_lr(self, is_verbose, group, lr, epoch=None):
137
- """Display the current learning rate.
138
- """
139
- if is_verbose:
140
- if epoch is None:
141
- print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
142
- else:
143
- epoch_str = ("%.2f" if isinstance(epoch, float) else
144
- "%.5d") % epoch
145
- print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')
146
-
147
-
148
- def step(self, epoch=None):
149
- # Raise a warning if old pattern is detected
150
- # https://github.com/pytorch/pytorch/issues/20124
151
- if self._step_count == 1:
152
- if not hasattr(self.optimizer.step, "_with_counter"):
153
- warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
154
- "initialization. Please, make sure to call `optimizer.step()` before "
155
- "`lr_scheduler.step()`. See more details at "
156
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
157
-
158
- # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
159
- elif self.optimizer._step_count < 1:
160
- warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
161
- "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
162
- "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
163
- "will result in PyTorch skipping the first value of the learning rate schedule. "
164
- "See more details at "
165
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
166
- self._step_count += 1
167
-
168
- with _enable_get_lr_call(self):
169
- if epoch is None:
170
- self.last_epoch += 1
171
- values = self.get_lr()
172
- else:
173
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
174
- self.last_epoch = epoch
175
- if hasattr(self, "_get_closed_form_lr"):
176
- values = self._get_closed_form_lr()
177
- else:
178
- values = self.get_lr()
179
-
180
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
181
- param_group, lr = data
182
- param_group['lr'] = lr
183
-
184
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
185
-
186
-
187
- # Including _LRScheduler for backwards compatibility
188
- # Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
189
- class _LRScheduler(LRScheduler):
190
- pass
191
-
192
-
193
- class _enable_get_lr_call:
194
-
195
- def __init__(self, o):
196
- self.o = o
197
-
198
- def __enter__(self):
199
- self.o._get_lr_called_within_step = True
200
- return self
201
-
202
- def __exit__(self, type, value, traceback):
203
- self.o._get_lr_called_within_step = False
204
-
205
-
206
- class Dynamic_LR(LRScheduler):
207
- """Reduce learning rate when a metric has stopped improving.
208
- Models often benefit from reducing the learning rate by a factor
209
- of 2-10 once learning stagnates. This scheduler reads a metrics
210
- quantity and if no improvement is seen for a 'patience' number
211
- of epochs, the learning rate is reduced.
212
-
213
- Args:
214
- optimizer (Optimizer): Wrapped optimizer.
215
- mode (str): One of `min`, `max`. In `min` mode, lr will
216
- be reduced when the quantity monitored has stopped
217
- decreasing; in `max` mode it will be reduced when the
218
- quantity monitored has stopped increasing. Default: 'min'.
219
- factor (float): Factor by which the learning rate will be
220
- reduced. new_lr = lr * factor. Default: 0.1.
221
- patience (int): Number of epochs with no improvement after
222
- which learning rate will be reduced. For example, if
223
- `patience = 2`, then we will ignore the first 2 epochs
224
- with no improvement, and will only decrease the LR after the
225
- 3rd epoch if the loss still hasn't improved then.
226
- Default: 10.
227
- threshold (float): Threshold for measuring the new optimum,
228
- to only focus on significant changes. Default: 1e-4.
229
- threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
230
- dynamic_threshold = best * ( 1 + threshold ) in 'max'
231
- mode or best * ( 1 - threshold ) in `min` mode.
232
- In `abs` mode, dynamic_threshold = best + threshold in
233
- `max` mode or best - threshold in `min` mode. Default: 'rel'.
234
- cooldown (int): Number of epochs to wait before resuming
235
- normal operation after lr has been reduced. Default: 0.
236
- min_lr (float or list): A scalar or a list of scalars. A
237
- lower bound on the learning rate of all param groups
238
- or each group respectively. Default: 0.
239
- eps (float): Minimal decay applied to lr. If the difference
240
- between new and old lr is smaller than eps, the update is
241
- ignored. Default: 1e-8.
242
- verbose (bool): If ``True``, prints a message to stdout for
243
- each update. Default: ``False``.
244
-
245
- .. deprecated:: 2.2
246
- ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
247
- learning rate.
248
-
249
- Example:
250
- >>> # xdoctest: +SKIP
251
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
252
- >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
253
- >>> for epoch in range(10):
254
- >>> train(...)
255
- >>> val_loss = validate(...)
256
- >>> # Note that step should be called after validate()
257
- >>> scheduler.step(val_loss)
258
- """
259
-
260
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
261
- plateau_var = "test_auc",
262
- threshold=1e-4, threshold_mode='rel', cooldown=0,
263
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
264
-
265
- """
266
- if factor >= 1.0:
267
- raise ValueError('Factor should be < 1.0.')
268
- """
269
- self.factor = factor
270
-
271
- # Attach optimizer
272
- if not isinstance(optimizer, Optimizer):
273
- raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
274
- self.optimizer = optimizer
275
-
276
- if isinstance(min_lr, (list, tuple)):
277
- if len(min_lr) != len(optimizer.param_groups):
278
- raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
279
- self.min_lrs = list(min_lr)
280
- self.max_lrs = list(max_lr)
281
- else:
282
- self.min_lrs = [min_lr] * len(optimizer.param_groups)
283
- self.max_lrs = [max_lr] * len(optimizer.param_groups)
284
-
285
- self.patience = patience
286
- self.plateau_var = plateau_var
287
-
288
- self.verbose = verbose
289
- self.cooldown = cooldown
290
- self.cooldown_counter = 0
291
- self.mode = mode
292
- self.threshold = threshold
293
- self.threshold_mode = threshold_mode
294
- self.best = None
295
- self.num_bad_epochs = None
296
- self.mode_worse = None # the worse value for the chosen mode
297
- self.eps = eps
298
- self.last_epoch = 0
299
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
300
- self._init_is_better(mode=mode, threshold=threshold,
301
- threshold_mode=threshold_mode)
302
- self._reset()
303
-
304
- def _reset(self):
305
- """Resets num_bad_epochs counter and cooldown counter."""
306
- self.best = self.mode_worse
307
- self.cooldown_counter = 0
308
- self.num_bad_epochs = 0
309
-
310
- def step(self, model, metrics, epoch=None):
311
- # convert `metrics` to float, in case it's a zero-dim Tensor
312
- current = float(metrics[self.plateau_var])
313
- if epoch is None:
314
- epoch = self.last_epoch + 1
315
- else:
316
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
317
- self.last_epoch = epoch
318
-
319
- if self.is_better(current, self.best):
320
- if(self.verbose):
321
- print("Model is improving!")
322
- self.best = current
323
- self.num_bad_epochs = 0
324
- else:
325
- if(self.verbose):
326
- print(f"Model is not improving :( best = {self.best}, current = {current}")
327
- self.num_bad_epochs += 1
328
-
329
- if self.in_cooldown:
330
- self.cooldown_counter -= 1
331
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
332
-
333
- if self.num_bad_epochs > self.patience:
334
- self._reduce_lr(epoch)
335
- self.cooldown_counter = self.cooldown
336
- self.num_bad_epochs = 0
337
-
338
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
339
-
340
- def _reduce_lr(self, epoch):
341
- print("Adjusting Learning Rate")
342
- self._reset()
343
- for i, param_group in enumerate(self.optimizer.param_groups):
344
- old_lr = float(param_group['lr'])
345
- new_lr = max(old_lr * self.factor, self.min_lrs[i])
346
- new_lr = min(new_lr, self.max_lrs[i])
347
- if abs(old_lr - new_lr) > self.eps:
348
- param_group['lr'] = new_lr
349
-
350
- def get_last_lr(self):
351
- return self._last_lr
352
- @property
353
- def in_cooldown(self):
354
- return self.cooldown_counter > 0
355
-
356
- def is_better(self, a, best):
357
- if self.mode == 'min' and self.threshold_mode == 'rel':
358
- rel_epsilon = 1. - self.threshold
359
- return a < best * rel_epsilon
360
-
361
- elif self.mode == 'min' and self.threshold_mode == 'abs':
362
- return a < best - self.threshold
363
-
364
- elif self.mode == 'max' and self.threshold_mode == 'rel':
365
- rel_epsilon = self.threshold + 1.
366
- return a > best * rel_epsilon
367
-
368
- else: # mode == 'max' and epsilon_mode == 'abs':
369
- return a > best + self.threshold
370
-
371
- def _init_is_better(self, mode, threshold, threshold_mode):
372
- if mode not in {'min', 'max'}:
373
- raise ValueError('mode ' + mode + ' is unknown!')
374
- if threshold_mode not in {'rel', 'abs'}:
375
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
376
-
377
- if mode == 'min':
378
- self.mode_worse = inf
379
- else: # mode == 'max':
380
- self.mode_worse = -inf
381
-
382
- self.mode = mode
383
- self.threshold = threshold
384
- self.threshold_mode = threshold_mode
385
-
386
- def state_dict(self):
387
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
388
-
389
- def load_state_dict(self, state_dict):
390
- self.__dict__.update(state_dict)
391
- self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
392
-
393
- class Action_On_Plateau():
394
-
395
- def __init__(self, mode = 'max', patience=10,
396
- plateau_var = "test_auc",
397
- threshold=1e-4, threshold_mode='rel', cooldown=0,
398
- eps=1e-8, verbose=False):
399
-
400
- self.patience = patience
401
- self.plateau_var = plateau_var
402
-
403
- self.verbose = verbose
404
- self.cooldown = cooldown
405
- self.cooldown_counter = 0
406
- self.mode = mode
407
- self.threshold = threshold
408
- self.threshold_mode = threshold_mode
409
- self.best = None
410
- self.num_bad_epochs = None
411
- self.mode_worse = None # the worse value for the chosen mode
412
- self.eps = eps
413
- self.last_epoch = 0
414
- self._init_is_better(mode=mode, threshold=threshold,
415
- threshold_mode=threshold_mode)
416
- self._reset()
417
-
418
- def _reset(self):
419
- """Resets num_bad_epochs counter and cooldown counter."""
420
- self.best = self.mode_worse
421
- self.cooldown_counter = 0
422
- self.num_bad_epochs = 0
423
-
424
- def step(self, model, metrics, epoch=None):
425
- # convert `metrics` to float, in case it's a zero-dim Tensor
426
- current = float(metrics[self.plateau_var])
427
- if epoch is None:
428
- epoch = self.last_epoch + 1
429
- else:
430
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
431
- self.last_epoch = epoch
432
-
433
- if self.is_better(current, self.best):
434
- if(self.verbose):
435
- print("Model is improving!")
436
- self.best = current
437
- self.num_bad_epochs = 0
438
- else:
439
- if(self.verbose):
440
- print(f"Model is not improving :( best = {self.best}, current = {current}")
441
- self.num_bad_epochs += 1
442
-
443
- if self.in_cooldown:
444
- self.cooldown_counter -= 1
445
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
446
-
447
- if self.num_bad_epochs > self.patience:
448
- self.action(model, metrics, epoch)
449
-
450
- def action(self, model, metrics, epoch=None):
451
- if(self.verbose):
452
- print("Doing my action")
453
-
454
- @property
455
- def in_cooldown(self):
456
- return self.cooldown_counter > 0
457
-
458
- def is_better(self, a, best):
459
- if self.mode == 'min' and self.threshold_mode == 'rel':
460
- rel_epsilon = 1. - self.threshold
461
- return a < best * rel_epsilon
462
-
463
- elif self.mode == 'min' and self.threshold_mode == 'abs':
464
- return a < best - self.threshold
465
-
466
- elif self.mode == 'max' and self.threshold_mode == 'rel':
467
- rel_epsilon = self.threshold + 1.
468
- return a > best * rel_epsilon
469
-
470
- else: # mode == 'max' and epsilon_mode == 'abs':
471
- return a > best + self.threshold
472
-
473
- def _init_is_better(self, mode, threshold, threshold_mode):
474
- if mode not in {'min', 'max'}:
475
- raise ValueError('mode ' + mode + ' is unknown!')
476
- if threshold_mode not in {'rel', 'abs'}:
477
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
478
-
479
- if mode == 'min':
480
- self.mode_worse = inf
481
- else: # mode == 'max':
482
- self.mode_worse = -inf
483
-
484
- self.mode = mode
485
- self.threshold = threshold
486
- self.threshold_mode = threshold_mode
487
-
488
- class Partial_Reset(Action_On_Plateau):
489
-
490
- def __init__(self, mode='max', patience=10, plateau_var="test_auc",
491
- threshold=0.0001, threshold_mode='rel', cooldown=0,
492
- eps=1e-8, verbose=False):
493
-
494
- super().__init__(mode, patience, plateau_var, threshold,
495
- threshold_mode, cooldown, eps, verbose)
496
-
497
- def action(self, model, metrics, epoch=None):
498
- print("Partial Reset!!")
499
- GCN.partial_reset(model)
500
- self._reset()
501
- self.cooldown_counter = self.cooldown
502
- self.num_bad_epochs = 0
503
-
504
-
505
- class Full_Reset(Action_On_Plateau):
506
-
507
- def __init__(self, mode='max', patience=10, plateau_var="test_auc",
508
- threshold=0.0001, threshold_mode='rel', cooldown=0,
509
- eps=1e-8, verbose=False):
510
-
511
- super().__init__(mode, patience, plateau_var, threshold,
512
- threshold_mode, cooldown, eps, verbose)
513
-
514
- def action(self, model, metrics, epoch=None):
515
- print("Full Reset!!")
516
- GCN.full_reset(model)
517
- self._reset()
518
- self.cooldown_counter = self.cooldown
519
- self.num_bad_epochs = 0
520
-
521
- class Dynamic_LR_AND_Partial_Reset():
522
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
523
- plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
524
- threshold=1e-4, threshold_mode='rel', cooldown=0,
525
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
526
-
527
- if (reset_patience == None):
528
- reset_patience = patience
529
- if(reset_plateau_var == None):
530
- reset_plateau_var = plateau_var
531
-
532
- self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience = patience,
533
- plateau_var=plateau_var, threshold=threshold, threshold_mode =threshold_mode,
534
- cooldown=cooldown, min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)
535
-
536
- self.partial_reset = Partial_Reset(mode=mode, patience=reset_patience, plateau_var=reset_plateau_var,
537
- threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown,
538
- eps=eps)
539
-
540
- def step(self, model, metrics, epoch=None):
541
- self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
542
- self.partial_reset.step(model=model, metrics=metrics, epoch=epoch)
543
-
544
- class Dynamic_LR_AND_Full_Reset():
545
- def __init__(self, optimizer, mode = 'max', factor=0.1, patience=10,
546
- plateau_var = "test_auc", reset_patience=None, reset_plateau_var=None,
547
- threshold=1e-4, threshold_mode='rel', cooldown=0,
548
- min_lr=0, max_lr=1e-4, eps=1e-8, verbose=False):
549
-
550
- if (reset_patience == None):
551
- reset_patience = patience
552
- if(reset_plateau_var == None):
553
- reset_plateau_var = plateau_var
554
-
555
- self.dynamic_lr = Dynamic_LR(optimizer, mode=mode, factor=factor, patience = patience,
556
- plateau_var=plateau_var, threshold=threshold, threshold_mode =threshold_mode,
557
- cooldown=cooldown, min_lr=min_lr, max_lr=max_lr, eps=eps, verbose=verbose)
558
-
559
- self.full_reset = Full_Reset(mode=mode, patience=reset_patience, plateau_var=reset_plateau_var,
560
- threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown,
561
- eps=eps)
562
-
563
- def step(self, model, metrics, epoch=None):
564
- self.dynamic_lr.step(model=model, metrics=metrics, epoch=epoch)
565
- self.full_reset.step(model=model, metrics=metrics, epoch=epoch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/dataset.py DELETED
@@ -1,678 +0,0 @@
1
- from dgl.data import DGLDataset
2
- import dgl
3
- import uproot
4
- import awkward as ak
5
- import torch
6
- import os
7
- import glob
8
- import time
9
- import numpy as np
10
- from root_gnn_base import utils
11
-
12
- def node_features_from_tree(ch, node_branch_names, node_branch_types, node_feature_scales):
13
- lengths = []
14
- for branch, node_type in zip(node_branch_names[0], node_branch_types):
15
- if node_type == 'single':
16
- lengths.append(1)
17
- elif node_type == 'vector':
18
- lengths.append(len(ch[branch]))
19
- else:
20
- print('Unknown node branch type: {}'.format(node_type))
21
- features = []
22
- for node_feat in node_branch_names:
23
- if node_feat == 'CALC_E':
24
- features.append(features[0]*torch.cosh(features[1]))
25
- continue
26
- elif node_feat == 'NODE_TYPE':
27
- feat = []
28
- for i, length in enumerate(lengths):
29
- feat.extend([i,]*length)
30
- features.append(torch.tensor(feat))
31
- continue
32
- feat = []
33
- itype = 0
34
- for length, branch, node_type in zip(lengths, node_feat, node_branch_types):
35
- if isinstance(branch, (int, float, complex)):
36
- feat.extend([branch,]*length)
37
- elif branch == 'CALC_E':
38
- this_type_starts_at = sum(lengths[:itype])
39
- this_type_ends_at = sum(lengths[:itype+1])
40
- feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
41
- elif node_type == 'single':
42
- feat.append(ch[branch])
43
- elif node_type == 'vector':
44
- feat.extend(ch[branch])
45
- itype += 1
46
- features.append(torch.tensor(feat))
47
- return torch.stack(features, dim=1) * node_feature_scales, lengths
48
-
49
- def full_connected_graph(n_nodes, self_loops=True):
50
- senders = np.arange(n_nodes*n_nodes) // n_nodes
51
- receivers = np.arange(n_nodes*n_nodes) % n_nodes
52
- if not self_loops and n_nodes > 1:
53
- mask = senders != receivers
54
- senders = senders[mask]
55
- receivers = receivers[mask]
56
- return dgl.graph((senders, receivers))
57
-
58
- def check_selection(ch, selection):
59
- var, cut, op = selection
60
- if op == '>':
61
- return ch[var] > cut
62
- elif op == '<':
63
- return ch[var] < cut
64
- elif op == '==':
65
- return ch[var] == cut
66
-
67
- def check_selections(ch, selections):
68
- for selection in selections:
69
- if not check_selection(ch, selection):
70
- return False
71
- return True
72
-
73
- class RootDataset(DGLDataset):
74
- def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
75
- selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
76
- print(f'Unused args while creating RootDataset: {kwargs}')
77
- self.label = label
78
- self.counts = []
79
- self.selections = selections
80
- self.save_to_disk = save
81
- self.file_names = file_names
82
- self.node_branch_names = node_branch_names
83
- self.node_branch_types = node_branch_types
84
- self.node_feature_scales = torch.tensor([float(sf) for sf in node_feature_scales])
85
- self.tree_name = tree_name
86
- self.fold_var = fold_var
87
- self.tracking_info = tracking_info
88
- self.tracking_info.insert(0, fold_var)
89
- if weight_var is None:
90
- weight_var = 1
91
- self.tracking_info.insert(1, weight_var)
92
- self.global_features = global_features
93
- self.chunks = chunks
94
- self.process_chunks = process_chunks
95
- if self.process_chunks is None:
96
- self.process_chunks = [i for i in range(self.chunks)]
97
- self.times = [0, 0]
98
- super().__init__(name=name, raw_dir=raw_dir, save_dir=save_dir)
99
-
100
- def get_list_of_branches(self):
101
- branches = []
102
- for feat in self.node_branch_names:
103
- if isinstance(feat, list):
104
- for branch in feat:
105
- if branch == 'CALC_E':
106
- continue
107
- if isinstance(branch, str):
108
- branches.append(branch)
109
- for feat in self.global_features:
110
- if isinstance(feat, str):
111
- branches.append(feat)
112
- for feat in self.tracking_info:
113
- if isinstance(feat, str):
114
- branches.append(feat)
115
- for selection in self.selections:
116
- branches.append(selection[0])
117
- return list(set(branches)) # Remove duplicates
118
-
119
- def make_graph(self, ch):
120
- t1 = time.time()
121
- features, _ = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
122
- features = features[features[:,0] != 0]
123
- t2 = time.time()
124
- g = full_connected_graph(features.shape[0], self_loops=False)
125
- g.ndata['features'] = features
126
- t3 = time.time()
127
- self.times[0] += t2 - t1
128
- self.times[1] += t3 - t2
129
- return g
130
-
131
- def process(self):
132
- times = [0, 0, 0]
133
- oldtime = time.time()
134
- if isinstance(self.file_names, str):
135
- self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
136
- else:
137
- self.files = []
138
- for file_name in self.file_names:
139
- self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
140
- branches = self.get_list_of_branches()
141
-
142
- # Read all files and concatenate arrays
143
- arrays = []
144
- for file in self.files:
145
- with uproot.open(file) as f:
146
- arrays.append(f[self.tree_name].arrays(branches, library="ak"))
147
- if len(arrays) == 0:
148
- print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
149
- return
150
- data = ak.concatenate(arrays, axis=0)
151
- n_entries = len(data[branches[0]])
152
- newtime = time.time()
153
- times[0] += newtime - oldtime
154
- chunks = np.array_split(np.arange(n_entries), self.chunks)
155
- chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
156
-
157
- self.graph_chunks = []
158
- self.label_chunks = []
159
- self.tracking_chunks = []
160
- self.global_chunks = []
161
- chunk_id = -1
162
- for chunk in chunks:
163
- print('Processing chunk {}/{}'.format(chunk_id + 1, len(chunks)), flush=True)
164
- chunk_id += 1
165
- graphs = []
166
- labels = []
167
- tracking = []
168
- globals = []
169
- for ientry in chunk:
170
- if (ientry % 10000 == 0):
171
- print('Processing event {}/{}'.format(ientry, n_entries), flush=True)
172
- ch = {b: data[b][ientry] for b in branches}
173
- passed = True
174
- for selection in self.selections:
175
- if not check_selection(ch, selection):
176
- passed = False
177
- continue
178
- oldtime = newtime
179
- newtime = time.time()
180
- times[1] += newtime - oldtime
181
- if passed:
182
- graphs.append(self.make_graph(ch))
183
- labels.append(self.label)
184
- tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
185
- globals.append(torch.zeros(len(self.global_features)))
186
- for i_ti, tr_branch in enumerate(self.tracking_info):
187
- if isinstance(tr_branch, str):
188
- tracking[-1][i_ti] = ch[tr_branch]
189
- else:
190
- tracking[-1][i_ti] = tr_branch
191
- for i_gl, gl_branch in enumerate(self.global_features):
192
- globals[-1][i_gl] = ch[gl_branch]
193
- oldtime = newtime
194
- newtime = time.time()
195
- times[2] += newtime - oldtime
196
-
197
- labels = torch.tensor(labels)
198
- tracking = torch.stack(tracking)
199
- globals = torch.stack(globals)
200
-
201
- self.graph_chunks.append(graphs)
202
- self.label_chunks.append(labels)
203
- self.tracking_chunks.append(tracking)
204
- self.global_chunks.append(globals)
205
- self.counts.append(len(graphs))
206
-
207
- if (self.chunks > 1):
208
- self.save_chunk(chunk_id, graphs, labels, tracking, globals)
209
- else:
210
- self.labels = labels
211
- self.tracking = tracking
212
- self.global_features = globals
213
- self.graphs = graphs
214
- self.save()
215
- return
216
-
217
- def save(self):
218
- if not self.save_to_disk:
219
- return
220
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
221
- if self.chunks == 1:
222
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
223
- dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
224
- else:
225
- for i in range(len(self.process_chunks)):
226
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
227
-
228
- dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
229
-
230
- def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
231
- if not self.save_to_disk:
232
- return
233
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
234
- print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
235
- dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
236
-
237
- def has_cache(self):
238
- print(f'Checking for cache of {self.name}')
239
- if not self.save_to_disk:
240
- print('Skipping load.')
241
- return False
242
- if self.chunks == 1:
243
- graph_path = os.path.join(self.save_dir, self.name + '.bin')
244
- return os.path.exists(graph_path)
245
- else:
246
- for i in range(len(self.process_chunks)):
247
- graph_path = os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin')
248
- if not os.path.exists(graph_path):
249
- print(f'File {graph_path} does not exist, processing.')
250
- return False
251
- return True
252
-
253
- def load(self):
254
- if self.chunks == 1:
255
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + ".bin")}')
256
- graphs, label_dict = dgl.load_graphs(os.path.join(self.save_dir, self.name + '.bin'))
257
- self.graphs = graphs
258
- self.labels = label_dict['labels']
259
- self.tracking = label_dict['tracking']
260
- self.global_features = label_dict['global']
261
- else:
262
- self.graphs = []
263
- self.labels = []
264
- self.tracking = []
265
- self.global_features = []
266
- for i in range(self.chunks):
267
- try:
268
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
269
- graphs, label = dgl.load_graphs(os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin'))
270
- self.graphs.extend(graphs)
271
- self.labels.append(label['labels'])
272
- self.tracking.append(label['tracking'])
273
- self.global_features.append(label['global'])
274
- except Exception as e:
275
- print(e)
276
- self.labels = torch.cat(self.labels)
277
- self.tracking = torch.cat(self.tracking)
278
- self.global_features = torch.cat(self.global_features)
279
-
280
- def __getitem__(self, idx):
281
- return self.graphs[idx], self.labels[idx], self.tracking[idx], self.global_features[idx]
282
-
283
- def __len__(self):
284
- return len(self.graphs)
285
-
286
- #Dataset with edge features added (deta, dphi, dR)
287
- class EdgeDataset(RootDataset):
288
- def make_graph(self, ch):
289
- g = super().make_graph(ch)
290
- u, v = g.edges()
291
- deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
292
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
293
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
294
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
295
- dR = torch.sqrt(deta**2 + dphi**2)
296
- g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
297
- return g
298
-
299
- class tHbbEdgeDataset(RootDataset):
300
- def __init__(self, exclude_branches=None, **kwargs):
301
- self.exclude_branches = exclude_branches
302
- super().__init__(**kwargs)
303
-
304
- def get_list_of_branches(self):
305
- br = super().get_list_of_branches()
306
- for sector in self.exclude_branches:
307
- if sector == None:
308
- continue
309
- for excl in sector:
310
- if type(excl) == str:
311
- br.append(excl)
312
- return br
313
-
314
- def make_graph(self, ch):
315
- features, lengths = node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
316
-
317
- include_mask = torch.ones(features.shape[0], dtype=torch.bool)
318
- node_idx = 0
319
- for sector, length in zip(self.exclude_branches, lengths):
320
- if sector == None:
321
- node_idx += length
322
- continue
323
- for excl in sector:
324
- if type(excl) == int:
325
- include_mask[excl + node_idx] = False
326
- elif type(excl) == str:
327
- include_mask[getattr(self.chain, excl) + node_idx] = False
328
- g = full_connected_graph(features[include_mask].shape[0], self_loops=False)
329
- g.ndata['features'] = features[include_mask]
330
-
331
- u, v = g.edges()
332
- deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
333
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
334
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
335
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
336
- dR = torch.sqrt(deta**2 + dphi**2)
337
- g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
338
- return g
339
-
340
- class LazyDataset(EdgeDataset):
341
- def __init__(self, buffer_size = 2, **kwargs):
342
- self.buffer = [None,] * buffer_size
343
- self.buffer_ptr = 0
344
- self.get_item_calls = 0
345
- self.buffer_indices = [-1,] * buffer_size
346
- super().__init__(**kwargs)
347
-
348
- def __getitem__(self, idx):
349
- self.get_item_calls += 1
350
- chunk_idx = -1
351
- sum = 0
352
- ev_idx = -999
353
- for i, count in enumerate(self.counts):
354
- sum += count
355
- if idx < sum:
356
- chunk_idx = i
357
- ev_idx = idx - sum + count
358
- break
359
- buf_idx = self.buffer_get(chunk_idx)
360
- if ev_idx >= len(self.buffer[buf_idx][0]):
361
- print(f'Getting event {ev_idx} from chunk {chunk_idx} from buffer {buf_idx}. Calls: {self.get_item_calls}')
362
- print(len(self.buffer))
363
- print(self.counts)
364
- print(len(self.buffer[buf_idx][0]))
365
- return self.buffer[buf_idx][0][ev_idx], self.buffer[buf_idx][1]['labels'][ev_idx], self.buffer[buf_idx][1]['tracking'][ev_idx], self.buffer[buf_idx][1]['global'][ev_idx]
366
-
367
- def buffer_get(self, buffer_idx):
368
- if buffer_idx in self.buffer_indices:
369
- for i in range(len(self.buffer)):
370
- if self.buffer_indices[i] == buffer_idx:
371
- return i
372
- else:
373
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{buffer_idx}.bin")}', flush=True)
374
- self.buffer_ptr = (self.buffer_ptr + 1) % len(self.buffer)
375
- self.buffer[self.buffer_ptr] = dgl.load_graphs(os.path.join(self.save_dir, self.name + f'_{buffer_idx}.bin'))
376
- self.buffer_indices[self.buffer_ptr] = buffer_idx
377
- return self.buffer_ptr
378
-
379
- def load(self):
380
- self.counts = []
381
- self.tracking = []
382
- try:
383
- for i in range(self.chunks):
384
- print(f'Loading dataset from {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
385
- l = dgl.data.graph_serialize.load_labels_v2(os.path.join(self.save_dir, self.name + f'_{self.process_chunks[i]}.bin'))
386
- self.counts.append(len(l['tracking']))
387
- self.tracking.append(l['tracking'])
388
- self.tracking = torch.cat(self.tracking)
389
- except Exception as e:
390
- print(e)
391
-
392
- def __len__(self):
393
- return sum(self.counts)
394
-
395
- class MultiLabelDataset(EdgeDataset):
396
- def __init__(self, **kwargs):
397
- super().__init__(**kwargs)
398
-
399
- def get_list_of_branches(self):
400
- br = super().get_list_of_branches()
401
- for l in self.label:
402
- if isinstance(l, str):
403
- br.append(l)
404
- if isinstance(l, dict):
405
- br.append(l['branch'])
406
- return br
407
-
408
- def get_label(self, ch):
409
- label = []
410
- for l in self.label:
411
- if isinstance(l, str):
412
- label.append((getattr(ch, l)))
413
- if isinstance(l, dict):
414
- label.append(getattr(ch, l['branch'])*float(l['scale']))
415
- if isinstance(l, float) or isinstance(l, int):
416
- label.append(l)
417
-
418
- return torch.tensor(label)
419
-
420
- def process(self):
421
- times = [0, 0, 0]
422
- oldtime = time.time()
423
- if isinstance(self.file_names, str):
424
- self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
425
- else:
426
- self.files = []
427
- for file_name in self.file_names:
428
- self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
429
- self.chain = ROOT.TChain(self.tree_name)
430
- if len(self.files) == 0:
431
- print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
432
- for file in self.files:
433
- utils.set_timeout(60*2)
434
- self.chain.Add(file)
435
- utils.unset_timeout()
436
- branches = self.get_list_of_branches()
437
- self.chain.SetBranchStatus('*', 0)
438
- for branch in branches:
439
- self.chain.SetBranchStatus(branch, 1)
440
- newtime = time.time()
441
- times[0] += newtime - oldtime
442
- chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
443
- chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
444
- self.graph_chunks = []
445
- self.label_chunks = []
446
- self.tracking_chunks = []
447
- self.global_chunks = []
448
- chunk_id = -1
449
- for chunk in chunks:
450
- chunk_id += 1
451
- graphs = []
452
- labels = []
453
- tracking = []
454
- globals = []
455
- for ientry in chunk:
456
- if (ientry % 10000 == 0):
457
- print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
458
- self.chain.GetEntry(ientry)
459
- passed = True
460
- for selection in self.selections:
461
- if not check_selection(self.chain, selection):
462
- passed = False
463
- continue
464
- oldtime = newtime
465
- newtime = time.time()
466
- times[1] += newtime - oldtime
467
- if passed:
468
- graphs.append(self.make_graph(self.chain))
469
- labels.append(self.get_label(self.chain))
470
- tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
471
- globals.append(torch.zeros(len(self.global_features)))
472
- for i_ti, tr_branch in enumerate(self.tracking_info):
473
- if isinstance(tr_branch, str):
474
- tracking[-1][i_ti] = getattr(self.chain, tr_branch)
475
- else:
476
- tracking[-1][i_ti] = tr_branch
477
- for i_gl, gl_branch in enumerate(self.global_features):
478
- globals[-1][i_gl] = getattr(self.chain, gl_branch)
479
- oldtime = newtime
480
- newtime = time.time()
481
- times[2] += newtime - oldtime
482
-
483
- labels = torch.stack(labels)
484
- self.save_chunk(chunk_id, graphs, labels, torch.stack(tracking), torch.stack(globals))
485
- # self.graph_chunks.append(graphs)
486
- # self.label_chunks.append(labels)
487
- # self.tracking_chunks.append(torch.stack(tracking))
488
- # self.global_chunks.append(torch.stack(globals))
489
- # self.counts.append(len(graphs))
490
- return
491
- self.graphs = self.graph_chunks[0]
492
- for chunk in self.graph_chunks[1:]:
493
- self.graphs += chunk
494
-
495
- self.labels = torch.cat(self.label_chunks)
496
- self.tracking = torch.cat(self.tracking_chunks)
497
- self.global_features = torch.cat(self.global_chunks)
498
- print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
499
- print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
500
-
501
- class LazyMultiLabelDataset(MultiLabelDataset, LazyDataset):
502
- def __init__(self, buffer_size = 2, **kwargs):
503
- LazyDataset.__init__(self, buffer_size=buffer_size, **kwargs)
504
-
505
- class MultiLabeltHbbDataset(MultiLabelDataset, tHbbEdgeDataset):
506
- def __init__(self, **kwargs):
507
- super().__init__(**kwargs)
508
-
509
- def get_list_of_branches(self):
510
- br = super().get_list_of_branches()
511
- for sector in self.exclude_branches:
512
- if sector == None:
513
- continue
514
- for excl in sector:
515
- if type(excl) == str:
516
- br.append(excl)
517
- return br
518
-
519
-
520
- class AugmentedDataset(RootDataset):
521
-
522
- def __init__(self, seed = 2, feature_index = None, node_mapping = None, **kwargs):
523
- self.seed = seed
524
- np.random.seed(seed)
525
- if(feature_index == None):
526
- self.feature_index = {"pt": 0, "eta": 1, "phi": 2, "energy": 3, "btag": 4, "charge": 5, "node_type": 6}
527
- if (node_mapping == None):
528
- self.node_mapping = {"jet": 0, "ele": 1, "mu": 2, "ph": 3, "MET": 4}
529
- super().__init__(**kwargs)
530
-
531
- def detector_noise(self, node_features):
532
- noise = np.zeros_like(node_features)
533
-
534
- node_types = node_features[:, self.feature_index["node_type"]]
535
- pts = node_features[:, self.feature_index["pt"]]
536
- etas = node_features[:, self.feature_index["eta"]]
537
- energies = node_features[:, self.feature_index["energy"]]
538
-
539
- # Noise calculation for jets
540
- jet_mask = (node_types == self.node_mapping["jet"])
541
- jet_pts = pts[jet_mask]
542
- jet_etas = etas[jet_mask]
543
-
544
- if (jet_mask.sum() > 0):
545
- jet_resolutions = np.where(
546
- jet_pts <= 0.1, 0.0,
547
- np.where(
548
- np.abs(jet_etas) <= 0.5, np.sqrt(0.06**2 + jet_pts**2 * 1.3e-3**2),
549
- np.where(
550
- np.abs(jet_etas) <= 1.5, np.sqrt(0.10**2 + jet_pts**2 * 1.7e-3**2),
551
- np.where(
552
- np.abs(jet_etas) <= 2.5, np.sqrt(0.25**2 + jet_pts**2 * 3.1e-3**2),
553
- 0.0
554
- )
555
- )
556
- )
557
- )
558
- noise[jet_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=jet_resolutions)
559
-
560
- # Noise calculation for electrons
561
- ele_mask = (node_types == self.node_mapping["ele"])
562
- ele_pts = pts[ele_mask]
563
- ele_etas = etas[ele_mask]
564
-
565
- if (ele_mask.sum() > 0):
566
- ele_resolutions = np.where(
567
- np.abs(ele_etas) <= 0.5, np.sqrt(0.03**2 + ele_pts**2 * 1.3e-3**2),
568
- np.where(
569
- np.abs(ele_etas) <= 1.5, np.sqrt(0.05**2 + ele_pts**2 * 1.7e-3**2),
570
- np.where(
571
- np.abs(ele_etas) <= 2.5, np.sqrt(0.15**2 + ele_pts**2 * 3.1e-3**2),
572
- 0.0
573
- )
574
- )
575
- )
576
- noise[ele_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=ele_resolutions)
577
-
578
- # Noise calculation for muons
579
- mu_mask = (node_types == self.node_mapping["mu"])
580
- mu_pts = pts[mu_mask]
581
- mu_etas = etas[mu_mask]
582
-
583
- if (mu_mask.sum() > 0):
584
- mu_resolutions = np.where(
585
- np.abs(mu_etas) <= 0.5, np.sqrt(0.01**2 + mu_pts**2 * 1.0e-4**2),
586
- np.where(
587
- np.abs(mu_etas) <= 1.5, np.sqrt(0.015**2 + mu_pts**2 * 1.5e-4**2),
588
- np.where(
589
- np.abs(mu_etas) <= 2.5, np.sqrt(0.025**2 + mu_pts**2 * 3.5e-4**2),
590
- 0.0
591
- )
592
- )
593
- )
594
- noise[mu_mask, self.feature_index["pt"]] = np.random.normal(loc=0.0, scale=mu_resolutions)
595
-
596
- # Noise calculation for photons
597
- ph_mask = (node_types == self.node_mapping["ph"])
598
- ph_etas = etas[ph_mask]
599
- ph_energies = energies[ph_mask]
600
-
601
- if (ph_mask.sum() > 0):
602
- ph_resolutions = np.where(
603
- np.abs(ph_etas) <= 3.2, np.sqrt(ph_energies**2 * 0.0017**2 + ph_energies * 0.101**2),
604
- np.where(
605
- np.abs(ph_etas) <= 4.9, np.sqrt(ph_energies**2 * 0.0350**2 + ph_energies * 0.285**2),
606
- 0.0
607
- )
608
- )
609
- noise[ph_mask, self.feature_index["energy"]] = np.random.normal(loc=0.0, scale=ph_resolutions)
610
- return noise
611
-
612
- def make_graph(self, ch):
613
- g = super().make_graph(ch)
614
-
615
- g.ndata['augmented_features'] = g.ndata['features']
616
-
617
- num_nodes = len(g.ndata['features'][:, 0])
618
-
619
- # Rotations: phi -> phi + delta_phi
620
- phi_index = self.feature_index["phi"]
621
- # Generate a single delta_phi for all nodes
622
- delta_phi = np.random.uniform(low=-np.pi, high=np.pi)
623
-
624
- # Apply the same delta_phi to all nodes
625
- g.ndata['augmented_features'][:, phi_index] = (g.ndata['augmented_features'][:, phi_index] + delta_phi + np.pi) % (2 * np.pi) - np.pi
626
-
627
- # Reflections: eta -> -1 * eta, phi -> -1 * phi
628
- eta_index = self.feature_index["eta"]
629
-
630
- eta_reflection = np.random.choice([-1, 1])
631
- phi_reflection = np.random.choice([-1, 1])
632
-
633
- g.ndata['augmented_features'][:, eta_index] = g.ndata['augmented_features'][:, eta_index] * eta_reflection
634
- g.ndata['augmented_features'][:, phi_index] = g.ndata['augmented_features'][:, phi_index] * phi_reflection
635
-
636
-
637
- # Detector Noise: pt -> pt + normal(pt, noise(pt))
638
- noise = self.detector_noise(g.ndata['augmented_features'])
639
- g.ndata['augmented_features'] = g.ndata['augmented_features'] + noise
640
-
641
- pt_index = self.feature_index["pt"]
642
- if (g.ndata['augmented_features'][-1][self.feature_index["node_type"]] == self.node_mapping["MET"]):
643
- # Initialize sums of px and py
644
- sum_px = 0
645
- sum_py = 0
646
-
647
- # Loop over all nodes except the last one (MET node)
648
- for i in range(len(g.ndata['augmented_features']) - 1):
649
- pt = g.ndata['augmented_features'][i][pt_index]
650
- phi = g.ndata['augmented_features'][i][phi_index]
651
-
652
- # Compute px and py
653
- px = pt * np.cos(phi)
654
- py = pt * np.sin(phi)
655
-
656
- # Sum px and py
657
- sum_px += px
658
- sum_py += py
659
-
660
- # Calculate MET
661
- g.ndata['augmented_features'][-1][pt_index] = np.sqrt(sum_px**2 + sum_py**2)
662
-
663
- u, v = g.edges()
664
- deta = g.ndata['features'][u, 1] - g.ndata['features'][v, 1]
665
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
666
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
667
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
668
- dR = torch.sqrt(deta**2 + dphi**2)
669
- g.edata['features'] = torch.stack([deta, dphi, dR], dim=1)
670
-
671
- deta = g.ndata['augmented_features'][u, 1] - g.ndata['augmented_features'][v, 1]
672
- dphi = g.ndata['augmented_features'][u, 2] - g.ndata['augmented_features'][v, 2]
673
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
674
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
675
- dR = torch.sqrt(deta**2 + dphi**2)
676
- g.edata['augmented_features'] = torch.stack([deta, dphi, dR], dim=1)
677
-
678
- return g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/photon_ID_dataset.py DELETED
@@ -1,44 +0,0 @@
1
- from root_gnn_base import dataset
2
- import dgl
3
- import torch
4
- import numpy as np
5
-
6
- def radius_graph(features, radii, self_loops=False):
7
- senders = []
8
- receivers = []
9
- n_nodes = features.shape[0]
10
- senders = np.arange(n_nodes*n_nodes) // n_nodes
11
- receivers = np.arange(n_nodes*n_nodes) % n_nodes
12
- if not self_loops and n_nodes > 1:
13
- mask = senders != receivers
14
- senders = senders[mask]
15
- receivers = receivers[mask]
16
- for k, r in radii.items():
17
- d = features[senders, k] - features[receivers, k]
18
- mask = np.abs(d) < r
19
- senders = senders[mask]
20
- receivers = receivers[mask]
21
- return dgl.graph((senders, receivers))
22
-
23
- class PhotonIDDataset(dataset.LazyMultiLabelDataset):
24
- def __init__(self, eta_radius, phi_radius, **kwargs):
25
- self.eta_radius = eta_radius
26
- self.phi_radius = phi_radius
27
- super().__init__(**kwargs)
28
- def make_graph(self, ch):
29
- features, _ = dataset.node_features_from_tree(ch, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
30
- features = features[features[:,0] != 0]
31
- #Delta Eta, Delta Phi, Adjacent Layer
32
- g = radius_graph(features, {1: self.eta_radius, 2: self.phi_radius, 6: 1.1}, self_loops=True) #Self loops ensure last cell is included even if disconnected
33
- g.ndata['features'] = features
34
- u, v = g.edges()
35
- deta = features[u, 1] - features[v, 1]
36
- dphi = g.ndata['features'][u, 2] - g.ndata['features'][v, 2]
37
- dphi = torch.where(dphi > np.pi, dphi - 2*np.pi, dphi)
38
- dphi = torch.where(dphi < -np.pi, dphi + 2*np.pi, dphi)
39
- dR = torch.sqrt(deta**2 + dphi**2)
40
- dx = features[u, 3] - features[v, 3]
41
- dy = features[u, 4] - features[v, 4]
42
- dz = features[u, 5] - features[v, 5]
43
- g.edata['features'] = torch.stack([deta, dphi, dR, dx, dy, dz], dim=1)
44
- return g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/similarity.py DELETED
@@ -1,158 +0,0 @@
1
- import numpy as np
2
- import scipy
3
- from sklearn.decomposition import PCA
4
- from sklearn.metrics.pairwise import cosine_similarity
5
- from sklearn.metrics.pairwise import euclidean_distances
6
- from sklearn.preprocessing import StandardScaler
7
-
8
- from scipy.stats import wasserstein_distance
9
-
10
- def cka(rep_a, rep_b, size=None):
11
- """
12
- Computes the Centered Kernel Alignment (CKA) between two large representation matrices rep_a and rep_b.
13
- If size is provided, it performs CKA on a randomly selected subset of the data.
14
-
15
- Parameters:
16
- rep_a : np.ndarray
17
- First representation matrix of size (n_samples, n_features_a).
18
- rep_b : np.ndarray
19
- Second representation matrix of size (n_samples, n_features_b).
20
- size : int, optional
21
- Number of samples to use for the CKA calculation. If None, use the full dataset.
22
-
23
- Returns:
24
- float
25
- CKA similarity between rep_a and rep_b.
26
- """
27
-
28
- def gram_linear(x):
29
- """Compute the Gram (kernel) matrix using a linear kernel."""
30
- return x @ x.T
31
-
32
- def center_gram(gram):
33
- """Center the Gram matrix."""
34
- n = gram.shape[0]
35
- identity = np.eye(n)
36
- ones = np.ones((n, n)) / n
37
- return gram - ones @ gram - gram @ ones + ones @ gram @ ones
38
-
39
- # If sample_size is specified, randomly sample a subset of the data
40
- if size is not None and size < rep_a.shape[0]:
41
- indices = np.random.choice(rep_a.shape[0], size, replace=False)
42
- rep_a = rep_a[indices]
43
- rep_b = rep_b[indices]
44
-
45
- # Compute the Gram matrices
46
- gram_a = gram_linear(rep_a)
47
- gram_b = gram_linear(rep_b)
48
-
49
- # Center the Gram matrices
50
- centered_gram_a = center_gram(gram_a)
51
- centered_gram_b = center_gram(gram_b)
52
-
53
- # Compute the CKA similarity
54
- numerator = np.sum(centered_gram_a * centered_gram_b)
55
- denominator = np.sqrt(np.sum(centered_gram_a**2) * np.sum(centered_gram_b**2))
56
-
57
- return numerator / denominator if denominator != 0 else 0
58
-
59
- def cca(X, Y, size = None, num_components=10):
60
- """
61
- Perform Canonical Correlation Analysis (CCA) between two datasets.
62
-
63
- Parameters:
64
- X : np.ndarray
65
- First dataset, shape (n_samples, n_features_X).
66
- Y : np.ndarray
67
- Second dataset, shape (n_samples, n_features_Y).
68
- num_components : int
69
- Number of CCA components to return.
70
-
71
- Returns:
72
- w_X : np.ndarray
73
- Canonical weights for the first dataset, shape (n_features_X, num_components).
74
- w_Y : np.ndarray
75
- Canonical weights for the second dataset, shape (n_features_Y, num_components).
76
- corrs : np.ndarray
77
- Array of canonical correlations for each component.
78
- """
79
-
80
- # If sample size is specified, randomly sample a subset of the data
81
- if size is not None and size < X.shape[0]:
82
- indices = np.random.choice(X.shape[0], size, replace=False)
83
- X = X[indices]
84
- Y = Y[indices]
85
-
86
- # Standardize both datasets (mean = 0, variance = 1)
87
- scaler_X = StandardScaler()
88
- scaler_Y = StandardScaler()
89
-
90
- X = scaler_X.fit_transform(X)
91
- Y = scaler_Y.fit_transform(Y)
92
-
93
- # Covariance matrices
94
- C_XX = np.cov(X, rowvar=False) # Covariance of X
95
- C_YY = np.cov(Y, rowvar=False) # Covariance of Y
96
- C_XY = np.cov(X, Y, rowvar=False)[:X.shape[1], X.shape[1]:] # Cross-covariance of X and Y
97
-
98
- # Regularization term to avoid singular matrices
99
- reg = 1e-6
100
- inv_C_XX = np.linalg.inv(C_XX + reg * np.eye(C_XX.shape[0]))
101
- inv_C_YY = np.linalg.inv(C_YY + reg * np.eye(C_YY.shape[0]))
102
-
103
- # Solve the generalized eigenvalue problem for CCA
104
- # (inv_C_XX @ C_XY @ inv_C_YY @ C_XY.T) and vice versa for Y
105
- A = inv_C_XX @ C_XY @ inv_C_YY @ C_XY.T
106
- B = inv_C_YY @ C_XY.T @ inv_C_XX @ C_XY
107
-
108
- # Perform eigenvalue decomposition
109
- eigvals_X, eigvecs_X = np.linalg.eigh(A)
110
- eigvals_Y, eigvecs_Y = np.linalg.eigh(B)
111
-
112
- # Sort the eigenvalues and eigenvectors in descending order
113
- idx_X = np.argsort(eigvals_X)[::-1]
114
- idx_Y = np.argsort(eigvals_Y)[::-1]
115
-
116
- eigvecs_X = eigvecs_X[:, idx_X]
117
- eigvecs_Y = eigvecs_Y[:, idx_Y]
118
-
119
- # Canonical weights (the first `num_components` components)
120
- w_X = eigvecs_X[:, :num_components]
121
- w_Y = eigvecs_Y[:, :num_components]
122
-
123
- # Canonical correlations (square root of the eigenvalues, constrained to [0,1])
124
- corrs = np.sqrt(np.clip(eigvals_X[:num_components], 0, 1))
125
-
126
- return np.mean(corrs)
127
- return w_X, w_Y, corrs
128
-
129
- def pca(X, Y, size=1000, n_components=3, bins=30):
130
-
131
- pca_X = PCA(n_components=n_components)
132
- X_pca = pca_X.fit_transform(X)
133
-
134
- pca_Y = PCA(n_components=n_components)
135
- Y_pca = pca_Y.fit_transform(Y)
136
-
137
- # Step 2: Determine common bin edges based on the range of PCA components
138
- min_value = min(X_pca.min(), Y_pca.min())
139
- max_value = max(X_pca.max(), Y_pca.max())
140
- bin_edges = np.linspace(min_value, max_value, bins + 1)
141
-
142
- # Step 3: Calculate histograms for each PCA component using the same bins
143
- histograms_X = [np.histogram(X_pca[:, i], bins=bin_edges, density=True)[0] for i in range(n_components)]
144
- histograms_Y = [np.histogram(Y_pca[:, i], bins=bin_edges, density=True)[0] for i in range(n_components)]
145
-
146
- # Step 4: Calculate Wasserstein distance between corresponding histograms
147
- total_distance = 0
148
- for i in range(n_components):
149
- total_distance += wasserstein_distance(histograms_X[i], histograms_Y[i])
150
-
151
- # Step 5: Normalize the total distance for a similarity score
152
- # Calculate the maximum possible distance (theoretical max could be based on histogram size)
153
- # This could be replaced with a more complex calculation if necessary.
154
- max_distance = 1.0 # Replace this with a suitable maximum based on your dataset properties.
155
-
156
- similarity_score = 1 - (total_distance / max_distance)
157
-
158
- return max(0, min(1, similarity_score)) # Ensure the score stays in [0, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/uproot_dataset.py DELETED
@@ -1,54 +0,0 @@
1
- from root_gnn_base import dataset
2
- import torch
3
- import uproot
4
- import glob
5
- import os
6
- import awkward as ak
7
- import numpy as np
8
- import time
9
-
10
- def node_features_from_ak(ch, node_branch_names, node_branch_types, node_feature_scales):
11
- node_types = []
12
- n_types = len(node_branch_names[0])
13
- for i in range(n_types):
14
- features = []
15
- branch_type = node_branch_types[i]
16
- for j in range(len(node_branch_names)):
17
- if node_branch_names[j] == 'CALC_E':
18
- features.append(features[0] * np.cosh(features[1]))
19
- elif node_branch_names[j] == 'NODE_TYPE':
20
- features.append(ak.full_like(features[0], i))
21
- elif isinstance(node_branch_names[j][i], str):
22
- features.append(ch[node_branch_names[j][i]])
23
- elif isinstance(node_branch_names[j][i], (int, float)):
24
- features.append(ak.full_like(features[0], node_branch_names[j][i]))
25
- if branch_type == 'single':
26
- features = [f[:,np.newaxis] for f in features]
27
- node_types.append(ak.Array(features))
28
- node_features = ak.concatenate(node_types, axis=2) * node_feature_scales #axis order at this point is (feature, event, node)
29
- return node_features
30
-
31
- class UprootDataset(dataset.RootDataset):
32
- def process(self):
33
- starttime = time.time()
34
- self.files = glob.glob(os.path.join(self.raw_dir, self.file_names))
35
- branches = self.get_list_of_branches()
36
- self.chain = uproot.concatenate([f + ':' + self.tree_name for f in self.files], branches, num_workers=4)
37
- node_features = node_features_from_ak(self.chain, self.node_branch_names, self.node_branch_types, self.node_feature_scales)
38
- loadtime = time.time()
39
- n_nodes = ak.num(node_features[0], axis=1) #number of nodes for each event
40
- ftime = time.time()
41
- self.graphs = [dataset.full_connected_graph(n, False) for n in n_nodes]
42
- itime = time.time()
43
- for i in range(len(self.graphs)):
44
- if i % 10000 == 0:
45
- print(f'Processing event {i}/{len(self.graphs)}')
46
- self.graphs[i].ndata['features'] = torch.transpose(torch.tensor(node_features[:,i,:]),0,1).to(torch.float)
47
- self.label = torch.stack([torch.full((len(self.graphs),),torch.tensor(self.label)), torch.tensor(ak.values_astype(self.chain[self.fold_var], np.int64))], dim=1)
48
- gtime = time.time()
49
- print()
50
- print(f'load time: {loadtime - starttime} s')
51
- print(f'feature time: {ftime - loadtime} s')
52
- print(f'graph time: {itime - ftime} s')
53
- print(f'graph data time: {gtime - itime} s')
54
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/root_gnn_base/utils.py DELETED
@@ -1,393 +0,0 @@
1
- import importlib
2
- import yaml
3
- import os
4
- import torch
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import dgl
8
- import signal
9
-
10
- def buildFromConfig(conf, run_time_args = {}):
11
- device = run_time_args.get('device', 'cpu')
12
- if 'module' in conf:
13
- module = importlib.import_module(conf['module'])
14
- cls = getattr(module, conf['class'])
15
- args = conf['args'].copy()
16
- if 'weight' in args and isinstance(args['weight'], list):
17
- args['weight'] = torch.tensor(args['weight'], dtype=torch.float, device=device)
18
- # Remove device from run_time_args to not pass it to the class
19
- run_time_args = {k: v for k, v in run_time_args.items() if k != 'device'}
20
- return cls(**args, **run_time_args)
21
- else:
22
- print('No module specified in config. Returning None.')
23
-
24
- def cycler(iterable):
25
- while True:
26
- #print('Cycler is cycling...')
27
- for i in iterable:
28
- yield i
29
-
30
- def include_config(conf):
31
- if 'include' in conf:
32
- for i in conf['include']:
33
- with open(i) as f:
34
- conf.update(yaml.load(f, Loader=yaml.FullLoader))
35
- del conf['include']
36
-
37
- def load_config(config_file):
38
- with open(config_file) as f:
39
- conf = yaml.load(f, Loader=yaml.FullLoader)
40
- include_config(conf)
41
- return conf
42
-
43
- #Timeout function from https://stackoverflow.com/questions/492519/timeout-on-a-function-call
44
- class TimeoutException(Exception):
45
- pass
46
-
47
- def timeout_handler(signum, frame):
48
- raise TimeoutException()
49
-
50
- def set_timeout(timeout):
51
- signal.signal(signal.SIGALRM, timeout_handler)
52
- signal.alarm(timeout)
53
-
54
- def unset_timeout():
55
- signal.alarm(0)
56
- signal.signal(signal.SIGALRM, signal.SIG_DFL)
57
-
58
- def make_padding_graph(batch, pad_nodes, pad_edges):
59
- senders = []
60
- receivers = []
61
- senders = torch.arange(0,pad_edges) // pad_nodes
62
- receivers = torch.arange(1,pad_edges+1) % pad_nodes
63
- if pad_nodes < 0 or pad_edges < 0 or pad_edges > pad_nodes * pad_nodes / 2:
64
- print('Batch is larger than padding size or e > n^2/2. Repeating edges as necessary.')
65
- print(f'Batch nodes: {batch.num_nodes()}, Batch edges: {batch.num_edges()}, Padding nodes: {pad_nodes}, Padding edges: {pad_edges}')
66
- senders = senders % pad_nodes
67
- padg = dgl.graph((senders[:pad_edges], receivers[:pad_edges]), num_nodes = pad_nodes)
68
- for k in batch.ndata.keys():
69
- padg.ndata[k] = torch.zeros( (pad_nodes, batch.ndata[k].shape[1]) )
70
- for k in batch.edata.keys():
71
- padg.edata[k] = torch.zeros( (pad_edges, batch.edata[k].shape[1]) )
72
- return dgl.batch([batch, padg.to(batch.device)])
73
-
74
- def pad_size(graphs, edges, nodes, edge_per_graph=3, node_per_graph=14):
75
- pad_nodes = ((nodes // (node_per_graph * graphs))+1) * graphs * node_per_graph
76
- pad_edges = ((edges // (edge_per_graph * graphs))+1) * graphs * edge_per_graph
77
- return pad_nodes, pad_edges
78
-
79
- def pad_batch_to_step_per_graph(batch, edge_per_graph=3, node_per_graph=14):
80
- n_graphs = batch.batch_num_nodes().shape[0]
81
- pad_nodes = (batch.num_nodes() + node_per_graph * n_graphs) % int(n_graphs * node_per_graph)
82
- pad_edges = (batch.num_edges() + edge_per_graph * n_graphs) % int(n_graphs * edge_per_graph)
83
- return make_padding_graph(batch, pad_nodes, pad_edges)
84
-
85
- def pad_batch(batch, edges = 104000, nodes = 16000):
86
- if edges == 0 and nodes == 0:
87
- return batch
88
- pad_nodes = 0
89
- pad_edges = 0
90
- pad_nodes = nodes - batch.num_nodes()
91
- pad_edges = edges - batch.num_edges()
92
- return make_padding_graph(batch, pad_nodes, pad_edges)
93
-
94
- def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
95
- print(f"Padding each graph to have {max_num_nodes} nodes. Using hidden size {hid_size}.")
96
-
97
- unbatched = dgl.unbatch(batch)
98
- for g in unbatched:
99
- num_nodes_to_add = max_num_nodes - g.number_of_nodes()
100
- if num_nodes_to_add > 0:
101
- g.add_nodes(num_nodes_to_add) # Add isolated nodes
102
-
103
- batch = dgl.batch(unbatched)
104
-
105
- padding_mask = torch.zeros((batch.ndata['features'].shape[0]), dtype=torch.bool)
106
- global_update_weights = torch.ones((batch.ndata['features'].shape[0], hid_size))
107
-
108
- for i in range(len(batch.ndata['features'])):
109
- if (torch.count_nonzero(batch.ndata['features'][i]) == 0):
110
- padding_mask[i] = True
111
- global_update_weights[i] = 0
112
-
113
- batch.ndata['w'] = global_update_weights
114
- batch.ndata['padding_mask'] = padding_mask
115
-
116
- return batch
117
-
118
-
119
- def fold_selection(fold_config, sample):
120
- n_folds = fold_config['n_folds']
121
- folds_opt = fold_config[sample]
122
- folds = []
123
- if type(folds_opt) == int:
124
- return lambda x : x.tracking[:,0] % n_folds == folds_opt
125
- elif type(folds_opt) == list:
126
- print("fold type is list")
127
- print(f"fold_config = {fold_config}")
128
- print(f"folds_opt = {folds_opt}")
129
- return lambda x : sum([x.tracking[:,0] % n_folds == f for f in folds_opt]) == 1
130
- else:
131
- raise ValueError("Invalid fold selection option with type {}".format(type(folds_opt)))
132
-
133
- def fold_selection_name(fold_config, sample):
134
- n_folds = fold_config['n_folds']
135
- folds_opt = fold_config[sample]
136
- if type(folds_opt) == int:
137
- return f'n_{n_folds}_f_{folds_opt}'
138
- elif type(folds_opt) == list:
139
- return f'n_{n_folds}_f_{"_".join([str(f) for f in folds_opt])}'
140
- else:
141
- raise ValueError("Invalid fold selection option with type {}".format(type(folds_opt)))
142
-
143
- #Return the index and checkpoint of the last epoch.
144
- def get_last_epoch(config, max_ep = -1, device = None):
145
- last_epoch = -1
146
- checkpoint = None
147
- if max_ep < 0:
148
- max_ep = config['Training']['epochs']
149
- for ep in range(max_ep):
150
- if os.path.exists(os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt')):
151
- last_epoch = ep
152
- else:
153
- print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
154
- print('File not found: ', os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt'))
155
- break
156
- if last_epoch >= 0:
157
- checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
158
- return last_epoch, checkpoint
159
-
160
- #Return the index and checkpoint of the last epoch.
161
- def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
162
- last_epoch = -1
163
- checkpoint = None
164
- for ep in range(target_epoch + 1):
165
- if (from_ryan):
166
- if os.path.exists(os.path.join('/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/' + config['Training_Directory'], f'model_epoch_{ep}.pt')):
167
- last_epoch = ep
168
- else:
169
- print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
170
- print('File not found: ', os.path.join('/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/' + config['Training_Directory'], f'model_epoch_{ep}.pt'))
171
- break
172
- else:
173
- if os.path.exists(os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt')):
174
- last_epoch = ep
175
- else:
176
- print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
177
- print('File not found: ', os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt'))
178
- break
179
- if last_epoch >= 0:
180
- if (from_ryan):
181
- checkpoint = torch.load('/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/' + os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
182
- else:
183
- checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
184
- return last_epoch, checkpoint
185
-
186
- #Return the index and checkpoint of the nest epoch.
187
- def get_best_epoch(config, var='Test_AUC', mode='max', device=None, from_ryan=False):
188
- # Read the training log
189
- log = read_log(config)
190
-
191
- # Ensure the specified variable exists in the log
192
- if var not in log:
193
- raise ValueError(f"Variable '{var}' not found in the training log.")
194
-
195
- # Determine the target epoch based on the mode ('max' or 'min')
196
- if mode == 'max':
197
- target_epoch = int(np.argmax(log[var]))
198
- print(f"Best epoch based on '{var}' (max): {target_epoch} with value: {log[var][target_epoch]}")
199
- elif mode == 'min':
200
- target_epoch = int(np.argmin(log[var]))
201
- print(f"Best epoch based on '{var}' (min): {target_epoch} with value: {log[var][target_epoch]}")
202
- else:
203
- raise ValueError(f"Invalid mode '{mode}'. Expected 'max' or 'min'.")
204
-
205
- # Initialize checkpoint retrieval variables
206
- last_epoch = -1
207
- checkpoint = None
208
-
209
- # Iterate through epochs up to the target epoch to find the corresponding checkpoint
210
- for ep in range(target_epoch + 1):
211
- if from_ryan:
212
- checkpoint_path = os.path.join(
213
- '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
214
- config['Training_Directory'],
215
- f'model_epoch_{ep}.pt'
216
- )
217
- else:
218
- checkpoint_path = os.path.join(
219
- config['Training_Directory'],
220
- f'model_epoch_{ep}.pt'
221
- )
222
-
223
- if os.path.exists(checkpoint_path):
224
- last_epoch = ep
225
- else:
226
- print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
227
- print('File not found: ', checkpoint_path)
228
- break
229
-
230
- # Load the checkpoint for the last valid epoch
231
- if last_epoch >= 0:
232
- if from_ryan:
233
- checkpoint_path = os.path.join(
234
- '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
235
- config['Training_Directory'],
236
- f'model_epoch_{last_epoch}.pt'
237
- )
238
- else:
239
- checkpoint_path = os.path.join(
240
- config['Training_Directory'],
241
- f'model_epoch_{last_epoch}.pt'
242
- )
243
-
244
- checkpoint = torch.load(checkpoint_path, map_location=device)
245
-
246
- return last_epoch, checkpoint
247
-
248
- def read_log(config):
249
- lines = []
250
- with open(config['Training_Directory'] + '/training.log', 'r') as f:
251
- lines = f.readlines()
252
- lines = [l for l in lines if 'Epoch' in l]
253
-
254
- labels = []
255
- for field in lines[0].split('|'):
256
- labels.append(field.split()[0])
257
-
258
- # Initialize log as a dictionary with empty lists
259
- log = {label: [] for label in labels}
260
-
261
- for line in lines:
262
- valid_row = True # Flag to check if the row is valid
263
- temp_row = {} # Temporary row to store values before adding to log
264
-
265
- for field in line.split('|'):
266
- spl = field.split()
267
- try:
268
- temp_row[spl[0]] = float(spl[1])
269
- except (ValueError, IndexError):
270
- valid_row = False # Mark row as invalid if conversion fails
271
- break
272
-
273
- if valid_row: # Only add the row if all fields are valid
274
- for label in labels:
275
- log[label].append(temp_row.get(label, np.nan)) # Handle missing labels gracefully
276
-
277
- # Convert lists to numpy arrays for consistency
278
- for label in labels:
279
- log[label] = np.array(log[label])
280
-
281
- return log
282
-
283
- #Plot training logs.
284
- def plot_log(log, output_file):
285
- fig, ax = plt.subplots(2, 2, figsize=(10,10))
286
- #Time
287
-
288
- ax[0][0].plot(log['Epoch'], np.cumsum(log['Time']), label='Time')
289
- ax[0][0].set_xlabel('Epoch')
290
- ax[0][0].set_ylabel('Time (s)')
291
- ax[0][0].legend()
292
-
293
- """
294
- ax[0][0].plot(log['Epoch'], log['LR'], label='Learning Rate')
295
- ax[0][0].set_xlabel('Epoch')
296
- ax[0][0].set_ylabel('Learning Rate')
297
- ax[0][0].set_yscale('log')
298
- ax[0][0].legend()
299
- """
300
-
301
- #Loss
302
- ax[0][1].plot(log['Epoch'], log['Loss'], label='Train Loss')
303
- ax[0][1].plot(log['Epoch'], log['Test_Loss'], label='Test Loss')
304
- ax[0][1].set_xlabel('Epoch')
305
- ax[0][1].set_ylabel('Loss')
306
- ax[0][1].legend()
307
-
308
- #Accuracy
309
- ax[1][0].plot(log['Epoch'], log['Accuracy'], label='Test Accuracy')
310
- ax[1][0].set_xlabel('Epoch')
311
- ax[1][0].set_ylabel('Accuracy')
312
- ax[1][0].set_ylim((0.44, 0.56))
313
- ax[1][0].legend()
314
-
315
- #AUC
316
- ax[1][1].plot(log['Epoch'], log['Test_AUC'], label='Test AUC')
317
- ax[1][1].set_xlabel('Epoch')
318
- ax[1][1].set_ylabel('AUC')
319
- ax[1][1].legend()
320
-
321
- fig.savefig(output_file)
322
-
323
- class EarlyStop():
324
- def __init__(self, patience=15, threshold=1e-8, mode='min'):
325
- self.patience = patience
326
- self.threshold = threshold
327
- self.mode = mode
328
- self.count = 0
329
- self.current_best = np.inf if mode == 'min' else -np.inf
330
- self.should_stop = False
331
-
332
- def update(self, value):
333
- if self.mode == 'min': # Minimizing loss
334
- if value < self.current_best - self.threshold:
335
- self.current_best = value
336
- self.count = 0
337
- else:
338
- self.count += 1
339
- elif self.mode == 'max': # Maximizing metric
340
- if value > self.current_best + self.threshold:
341
- self.current_best = value
342
- self.count = 0
343
- else:
344
- self.count += 1
345
-
346
- # Check if patience is exceeded
347
- if self.count >= self.patience:
348
- self.should_stop = True
349
-
350
- def reset(self):
351
- self.count = 0
352
- self.current_best = np.inf if self.mode == 'min' else -np.inf
353
- self.should_stop = False
354
-
355
- def to_str(self):
356
- status = (
357
- f"EarlyStop Status:\n"
358
- f" Mode: {'Minimize' if self.mode == 'min' else 'Maximize'}\n"
359
- f" Patience: {self.patience}\n"
360
- f" Threshold: {self.threshold:.3e}\n"
361
- f" Current Best: {self.current_best:.6f}\n"
362
- f" Consecutive Epochs Without Improvement: {self.count}\n"
363
- f" Stopping Triggered: {'Yes' if self.should_stop else 'No'}"
364
- )
365
- return status
366
-
367
- def to_dict(self):
368
-
369
- return {
370
- 'patience': self.patience,
371
- 'threshold': self.threshold,
372
- 'mode': self.mode,
373
- 'count': self.count,
374
- 'current_best': self.current_best,
375
- 'should_stop': self.should_stop,
376
- }
377
-
378
- @classmethod
379
- def load_from_dict(cls, state_dict):
380
- instance = cls(
381
- patience=state_dict['patience'],
382
- threshold=state_dict['threshold'],
383
- mode=state_dict['mode']
384
- )
385
- instance.count = state_dict['count']
386
- instance.current_best = state_dict['current_best']
387
- instance.should_stop = state_dict['should_stop']
388
- return instance
389
-
390
-
391
- def graph_augmentation(graph):
392
- print("Augmenting Graph")
393
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/scripts/check_dataset_files.py DELETED
@@ -1,130 +0,0 @@
1
- import yaml
2
- import os
3
- import subprocess
4
- import argparse
5
-
6
- def check_dataset_files(yaml_file, rerun=False):
7
- """
8
- Check if all required .bin files exist for each dataset in the YAML file.
9
- """
10
- try:
11
- # Open and parse the YAML file
12
- with open(yaml_file, 'r') as file:
13
- config = yaml.safe_load(file)
14
-
15
- # Check if 'Datasets' exists in the YAML file
16
- if 'Datasets' not in config:
17
- print(f"No 'Datasets' section found in {yaml_file}.")
18
- return
19
-
20
- datasets = config['Datasets']
21
- all_files_exist = True
22
-
23
- for dataset_name, dataset_config in datasets.items():
24
- # Extract required information
25
- save_dir = dataset_config['args']['save_dir']
26
- chunks = dataset_config['args']['chunks']
27
- folding = dataset_config.get('folding', {})
28
- n_folds = folding.get('n_folds', 0)
29
- test_folds = folding.get('test', [])
30
- train_folds = folding.get('train', [])
31
-
32
- print(f"\n== Checking dataset: {dataset_name} ==")
33
- print(f" save_dir: {save_dir}")
34
- print(f" chunks: {chunks}")
35
- print(f" n_folds: {n_folds}")
36
- print(f" test_folds: {test_folds}")
37
- print(f" train_folds: {train_folds}")
38
-
39
- missing_files = []
40
-
41
- # 1. Check for chunk files
42
- for chunk in range(chunks):
43
- chunk_file = os.path.join(save_dir, f"{dataset_name}_{chunk}.bin")
44
- if not os.path.exists(chunk_file):
45
- missing_files.append(chunk_file)
46
-
47
- # 2. Check for prebatched fold files (test and train)
48
- # Naming: dataset_name_prebatched_padded_{fold}_n_{n_folds}_f_{foldlist}.bin
49
- fold_types = [('test', test_folds), ('train', train_folds)]
50
- for fold_type, folds in fold_types:
51
- if not folds:
52
- continue
53
- foldlist_str = '_'.join(map(str, folds))
54
- for i in range(chunks):
55
- prebatched_file = os.path.join(
56
- save_dir,
57
- f"{dataset_name}_prebatched_padded_{i}_n_{n_folds}_f_{foldlist_str}.bin"
58
- )
59
- if not os.path.exists(prebatched_file):
60
- missing_files.append(prebatched_file)
61
-
62
- # Print results for the current dataset
63
- if missing_files:
64
- all_files_exist = False
65
- print(f" Missing files for dataset '{dataset_name}':")
66
- for missing_file in missing_files:
67
- print(f" - {missing_file}")
68
-
69
- # Optionally rerun data prep
70
- if rerun:
71
- print(f" Reprocessing dataset '{dataset_name}' ...")
72
- prep_command = f"bash/prep_data.sh {yaml_file} {dataset_name} {chunks}"
73
- try:
74
- subprocess.run(prep_command, shell=True, check=True)
75
- except subprocess.CalledProcessError as e:
76
- print(f" Could NOT reprocess '{dataset_name}': {e}")
77
- else:
78
- print(f" All files exist for dataset '{dataset_name}'.")
79
-
80
- # Final summary
81
- if all_files_exist:
82
- print("\nAll required files exist for all datasets.")
83
- else:
84
- print("\nSome files are missing.")
85
-
86
- except Exception as e:
87
- print(f"Error processing {yaml_file}: {e}")
88
-
89
- def main(pargs):
90
- # Base directory containing the YAML files
91
- base_directory = os.getcwd() + "/configs/"
92
-
93
- if pargs.configs:
94
- configs = [p.strip() for p in pargs.configs.split(',')]
95
- else:
96
- configs = [
97
- "attention/ttH_CP_even_vs_odd.yaml",
98
-
99
- "stats_100K/finetuning_ttH_CP_even_vs_odd.yaml",
100
- "stats_100K/pretraining_multiclass.yaml",
101
- "stats_100K/ttH_CP_even_vs_odd.yaml",
102
-
103
- "stats_all/finetuning_ttH_CP_even_vs_odd.yaml",
104
- "stats_all/pretraining_multiclass.yaml",
105
- "stats_all/ttH_CP_even_vs_odd.yaml",
106
- ]
107
-
108
- for config in configs:
109
- yaml_file = os.path.join(base_directory, config)
110
- if os.path.exists(yaml_file):
111
- print(f"\nProcessing file: {config}")
112
- check_dataset_files(yaml_file, pargs.rerun)
113
- else:
114
- print(f"File not found: {yaml_file}")
115
-
116
- if __name__ == "__main__":
117
- parser = argparse.ArgumentParser(description="Check YAML config files")
118
- parser.add_argument(
119
- "--configs", "-c",
120
- type=str,
121
- required=False,
122
- help="Comma-separated list of YAML config paths relative to base directory"
123
- )
124
- parser.add_argument(
125
- "--rerun", "-r",
126
- action='store_true', # Correct way for a boolean flag
127
- help="Automatically re-run data processing to fix missing files"
128
- )
129
- args = parser.parse_args()
130
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/scripts/find_free_port.py DELETED
@@ -1,12 +0,0 @@
1
- # find_free_port.py
2
- def find_free_port():
3
- import socket
4
- from contextlib import closing
5
-
6
- with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
7
- s.bind(('', 0))
8
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
9
- return str(s.getsockname()[1])
10
-
11
- if __name__ == "__main__":
12
- print(find_free_port())
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/scripts/inference.py DELETED
@@ -1,289 +0,0 @@
1
- import sys
2
- import os
3
- file_path = os.getcwd()
4
- sys.path.append(file_path)
5
- import os
6
- import argparse
7
- import yaml
8
- import gc
9
-
10
- import torch
11
- import dgl
12
- from dgl.data import DGLDataset
13
- from dgl.dataloading import GraphDataLoader
14
- from torch.utils.data import SubsetRandomSampler, SequentialSampler
15
-
16
- class CustomPreBatchedDataset(DGLDataset):
17
- def __init__(self, start_dataset, batch_size, chunkno=0, chunks=1, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
18
- self.start_dataset = start_dataset
19
- self.batch_size = batch_size
20
- self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
21
- self.drop_last = drop_last
22
- self.shuffle = shuffle
23
- self.chunkno = chunkno
24
- self.chunks = chunks
25
- super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)
26
-
27
- def process(self):
28
- mask = self.mask_fn(self.start_dataset)
29
- indices = torch.arange(len(self.start_dataset))[mask]
30
- print(f"Number of elements after masking: {len(indices)}") # Debugging print
31
-
32
- # --- CHUNK SPLITTING ---
33
- total = len(indices)
34
- if self.chunks == 1:
35
- chunk_indices = indices
36
- print(f"Chunks=1, using all {total} indices.")
37
- else:
38
- chunk_size = (total + self.chunks - 1) // self.chunks
39
- start = self.chunkno * chunk_size
40
- end = min((self.chunkno + 1) * chunk_size, total)
41
- chunk_indices = indices[start:end]
42
- print(f"Working on chunk {self.chunkno}/{self.chunks}: indices {start}:{end} (total {len(chunk_indices)})")
43
-
44
- if self.shuffle:
45
- sampler = SubsetRandomSampler(chunk_indices)
46
- else:
47
- sampler = SequentialSampler(chunk_indices)
48
-
49
- self.dataloader = GraphDataLoader(
50
- self.start_dataset,
51
- sampler=sampler,
52
- batch_size=self.batch_size,
53
- drop_last=self.drop_last
54
- )
55
-
56
- def __getitem__(self, idx):
57
- if isinstance(idx, int):
58
- idx = [idx]
59
- sampler = SequentialSampler(idx)
60
- dloader = GraphDataLoader(self.start_dataset, sampler=sampler, batch_size=self.batch_size, drop_last=False)
61
- return next(iter(dloader))
62
-
63
- def __len__(self):
64
- mask = self.mask_fn(self.start_dataset)
65
- indices = torch.arange(len(self.start_dataset))[mask]
66
- total = len(indices)
67
- if self.chunks == 1:
68
- return total
69
- chunk_size = (total + self.chunks - 1) // self.chunks
70
- start = self.chunkno * chunk_size
71
- end = min((self.chunkno + 1) * chunk_size, total)
72
- return end - start
73
-
74
- def include_config(conf):
75
- if 'include' in conf:
76
- for i in conf['include']:
77
- with open(i) as f:
78
- conf.update(yaml.load(f, Loader=yaml.FullLoader))
79
- del conf['include']
80
-
81
- def load_config(config_file):
82
- with open(config_file) as f:
83
- conf = yaml.load(f, Loader=yaml.FullLoader)
84
- include_config(conf)
85
- return conf
86
-
87
- def main():
88
-
89
- parser = argparse.ArgumentParser()
90
- add_arg = parser.add_argument
91
- add_arg('--config', type=str, nargs='+', required=True, help="List of config files")
92
- add_arg('--target', type=str, required=True)
93
- add_arg('--destination', type=str, default='')
94
- add_arg('--chunkno', type=int, default=0)
95
- add_arg('--chunks', type=int, default=1)
96
- add_arg('--write', action='store_true')
97
- add_arg('--ckpt', type=int, default=-1)
98
- add_arg('--var', type=str, default='Test_AUC')
99
- add_arg('--mode', type=str, default='max')
100
- add_arg('--clobber', action='store_true')
101
- add_arg('--tree', type=str, default='')
102
- add_arg('--branch_name', type=str, nargs='+', required=True, help="List of branch names corresponding to configs")
103
- args = parser.parse_args()
104
-
105
- if(len(args.config) != len(args.branch_name)):
106
- print(f"configs and branch names do not match")
107
- return
108
-
109
- config = load_config(args.config[0])
110
-
111
- # --- OUTPUT DESTINATION LOGIC ---
112
- if args.destination == '':
113
- base_dest = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
114
- else:
115
- base_dest = args.destination
116
-
117
- base_dest = base_dest.replace('.root', '').replace('.npz', '')
118
- if args.chunks > 1:
119
- chunked_dest = f"{base_dest}_chunk{args.chunkno}"
120
- else:
121
- chunked_dest = base_dest
122
- chunked_dest += '.root' if args.write else '.npz'
123
- args.destination = chunked_dest
124
-
125
- # --- FILE EXISTENCE CHECK ---
126
- if os.path.exists(args.destination):
127
- print(f'File {args.destination} already exists.')
128
- if args.clobber:
129
- print('Clobbering.')
130
- else:
131
- print('Exiting.')
132
- return
133
- else:
134
- print(f'Writing to {args.destination}')
135
-
136
- import time
137
- start = time.time()
138
- import torch
139
- from array import array
140
- import numpy as np
141
- from root_gnn_base import batched_dataset as dataset
142
- from root_gnn_base import utils
143
- end = time.time()
144
- print('Imports finished in {:.2f} seconds'.format(end - start))
145
-
146
- start = time.time()
147
- dset_config = config['Datasets'][list(config['Datasets'].keys())[0]]
148
- if dset_config['class'] == 'LazyDataset':
149
- dset_config['class'] = 'EdgeDataset'
150
- elif dset_config['class'] == 'LazyMultiLabelDataset':
151
- dset_config['class'] = 'MultiLabelDataset'
152
- elif dset_config['class'] == 'PhotonIDDataset':
153
- dset_config['class'] = 'UnlazyPhotonIDDataset'
154
- elif dset_config['class'] == 'kNNDataset':
155
- dset_config['class'] = 'UnlazyKNNDataset'
156
- dset_config['args']['raw_dir'] = os.path.split(args.target)[0]
157
- dset_config['args']['file_names'] = os.path.split(args.target)[1]
158
- dset_config['args']['save'] = False
159
- dset_config['args']['chunks'] = args.chunks
160
- dset_config['args']['process_chunks'] = [args.chunkno,]
161
- dset_config['args']['selections'] = []
162
-
163
- dset_config['args']['save_dir'] = os.path.dirname(args.destination)
164
-
165
- if args.tree != '':
166
- dset_config['args']['tree_name'] = args.tree
167
-
168
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
169
-
170
- dstart = time.time()
171
- dset = utils.buildFromConfig(dset_config)
172
- dend = time.time()
173
- print('Dataset finished in {:.2f} seconds'.format(dend - dstart))
174
-
175
- print(dset)
176
-
177
- batch_size = config['Training']['batch_size']
178
- lstart = time.time()
179
- loader = CustomPreBatchedDataset(
180
- dset,
181
- batch_size,
182
- chunkno=args.chunkno,
183
- chunks=args.chunks
184
- )
185
- loader.process()
186
- lend = time.time()
187
- print('Loader finished in {:.2f} seconds'.format(lend - lstart))
188
- sample_graph, _, _, global_sample = loader[0]
189
-
190
- print('dset length =', len(dset))
191
- print('loader length =', len(loader))
192
-
193
- all_scores = {}
194
- all_labels = {}
195
- all_tracking = {}
196
- with torch.no_grad():
197
- for config_file, branch in zip(args.config, args.branch_name):
198
- config = load_config(config_file)
199
- model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
200
- if args.ckpt < 0:
201
- ep, checkpoint = utils.get_best_epoch(config, var=args.var, mode='max', device=device)
202
- else:
203
- ep, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)
204
- # Remove distributed/compiled prefixes if present
205
- mds_copy = {}
206
- for key in checkpoint['model_state_dict'].keys():
207
- newkey = key.replace('module.', '')
208
- newkey = newkey.replace('_orig_mod.', '')
209
- mds_copy[newkey] = checkpoint['model_state_dict'][key]
210
- model.load_state_dict(mds_copy)
211
- model.eval()
212
-
213
- end = time.time()
214
- print('Model and dataset finished in {:.2f} seconds'.format(end - start))
215
- print('Starting inference')
216
- start = time.time()
217
-
218
- finish_fn = torch.nn.Sigmoid()
219
- if 'Loss' in config:
220
- finish_fn = utils.buildFromConfig(config['Loss']['finish'])
221
-
222
- scores = []
223
- labels = []
224
- tracking_info = []
225
- ibatch = 0
226
-
227
- for batch, label, track, globals in loader.dataloader:
228
- batch = batch.to(device)
229
- pred = model(batch, globals.to(device))
230
- ibatch += 1
231
- if (finish_fn.__class__.__name__ == "ContrastiveClusterFinish"):
232
- scores.append(pred.detach().cpu().numpy())
233
- else:
234
- scores.append(finish_fn(pred).detach().cpu().numpy())
235
- labels.append(label.detach().cpu().numpy())
236
- tracking_info.append(track.detach().cpu().numpy())
237
-
238
- score_size = scores[0].shape[1] if len(scores[0].shape) > 1 else 1
239
- scores = np.concatenate(scores)
240
- labels = np.concatenate(labels)
241
- tracking_info = np.concatenate(tracking_info)
242
- end = time.time()
243
-
244
- print('Inference finished in {:.2f} seconds'.format(end - start))
245
- all_scores[branch] = scores
246
- all_labels[branch] = labels
247
- all_tracking[branch] = tracking_info
248
-
249
-
250
- if args.write:
251
- import uproot
252
- import awkward as ak
253
-
254
- # Open the original ROOT file and get the tree
255
- infile = uproot.open(args.target)
256
- tree = infile[dset_config['args']['tree_name']]
257
-
258
- # Read the original tree as an awkward array
259
- original_data = tree.arrays(library="ak")
260
-
261
- # Prepare new branches as dicts of arrays
262
- new_branches = {}
263
- n_entries = len(original_data)
264
- for branch, scores in all_scores.items():
265
- # Ensure the scores array is the right length
266
- scores = np.asarray(scores)
267
- if scores.shape[0] != n_entries:
268
- raise ValueError(f"Branch '{branch}' has {scores.shape[0]} entries, but tree has {n_entries}")
269
- new_branches[branch] = scores
270
-
271
- # Merge all arrays (original + new branches)
272
- # Convert awkward to dict of numpy arrays for uproot
273
- out_dict = {k: np.asarray(v) for k, v in ak.to_numpy(original_data).items()}
274
- out_dict.update(new_branches)
275
-
276
- # Write to new ROOT file
277
- os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
278
- with uproot.recreate(args.destination) as outfile:
279
- outfile.mktree(dset_config['args']['tree_name'], {k: v.dtype for k, v in out_dict.items()})
280
- outfile[dset_config['args']['tree_name']].extend(out_dict)
281
-
282
- print(f"Wrote new ROOT file {args.destination} with new branches {list(new_branches.keys())}")
283
-
284
- else:
285
- os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
286
- np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
287
-
288
- if __name__ == '__main__':
289
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/scripts/prep_data.py DELETED
@@ -1,44 +0,0 @@
1
- import sys
2
- import os
3
- file_path = os.getcwd()
4
- sys.path.append(file_path)
5
-
6
- import root_gnn_base.utils as utils
7
- import argparse
8
- from root_gnn_base.batched_dataset import PreBatchedDataset
9
- from root_gnn_base.batched_dataset import LazyPreBatchedDataset
10
-
11
- def main():
12
- parser = argparse.ArgumentParser()
13
- add_arg = parser.add_argument
14
- add_arg('--config', type=str, required=True)
15
- add_arg('--dataset', type=str, required=True)
16
- add_arg('--chunk', type=int, default=0)
17
- add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
18
- add_arg('--drop_last', action='store_false', help='Set drop_last to False if the flag is provided. Defaults to True.')
19
- args = parser.parse_args()
20
-
21
- config = utils.load_config(args.config)
22
- dset_config = config['Datasets'][args.dataset]
23
- batch_size = config['Training']['batch_size']
24
- if not args.shuffle_mode:
25
- dset = utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk,]})
26
- else:
27
- dset = utils.buildFromConfig(dset_config)
28
- if 'batch_size' in dset_config:
29
- batch_size = dset_config['batch_size']
30
-
31
- shuffle_chunks = dset_config.get('shuffle_chunks', 10)
32
- padding_mode = dset_config.get('padding_mode', 'STEPS')
33
- fold_conf = dset_config["folding"]
34
- print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
35
- if dset_config["class"] == "LazyMultiLabelDataset":
36
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'] )
37
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
38
-
39
- else:
40
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last,hidden_size=config['Model']['args']['hid_size'])
41
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last,hidden_size=config['Model']['args']['hid_size'] )
42
-
43
- if __name__ == "__main__":
44
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/scripts/training_script.py DELETED
@@ -1,463 +0,0 @@
1
- import argparse
2
- import time
3
- import datetime
4
- import yaml
5
- import os
6
- import glob
7
-
8
- start_time = time.time()
9
-
10
- import dgl
11
- import torch
12
- import torch.nn as nn
13
-
14
- import sys
15
- file_path = os.getcwd()
16
- sys.path.append(file_path)
17
- import root_gnn_base.batched_dataset as datasets
18
- from root_gnn_base import utils
19
- import root_gnn_base.custom_scheduler as lr_utils
20
- from models import GCN
21
-
22
- import numpy as np
23
- from sklearn.metrics import roc_auc_score
24
- import resource
25
- import gc
26
-
27
- import torch.distributed as dist
28
- import torch.multiprocessing as mp
29
- from torch.utils.data.distributed import DistributedSampler
30
- from torch.nn.parallel import DistributedDataParallel as DDP
31
-
32
- from physicsnemo.models.module import Module
33
- from physicsnemo.models.meta import ModelMetaData
34
- from dataclasses import dataclass
35
-
36
- print("import time: {:.4f} s".format(time.time() - start_time))
37
-
38
- def mem():
39
- print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024:.2f} GB')
40
-
41
- def gpu_mem():
42
- print()
43
- print('GPU Memory Usage:')
44
- print(f'Current GPU memory usage: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024:.2f} GB')
45
- print(f'Current GPU cache usage: {torch.cuda.memory_cached() / 1024 / 1024 / 1024:.2f} GB')
46
- print(f'Current GPU max memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024:.2f} GB')
47
- print(f'Current GPU max cache usage: {torch.cuda.memory_reserved() / 1024 / 1024 / 1024:.2f} GB')
48
- mem()
49
-
50
- def train(train_loaders, test_loaders, model, device, config, args):
51
- restart = args.restart
52
-
53
- if ('Loss' in config):
54
- loss_fcn = utils.buildFromConfig(config['Loss'], {'reduction':'none'})
55
- finish_fn = utils.buildFromConfig(config['Loss']['finish'])
56
- else:
57
- loss_fcn = torch.nn.BCEWithLogitsLoss(reduction='none')
58
- finish_fn = torch.nn.Sigmoid()
59
- optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
60
- scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = config['Training']['gamma'])
61
-
62
- early_termination = utils.EarlyStop()
63
- if 'early_termination' in config['Training']:
64
- early_termination.patience = config['Training']['early_termination']['patience']
65
- early_termination.threshold = config['Training']['early_termination']['threshold']
66
- early_termination.mode = config['Training']['early_termination']['mode']
67
-
68
- starting_epoch = 0
69
- if not restart:
70
- last_ep, checkpoint = utils.get_last_epoch(config)
71
- if (last_ep >= 0):
72
- model.load_state_dict(checkpoint['model_state_dict'])
73
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
74
- starting_epoch = checkpoint['epoch'] + 1
75
- early_termination = utils.EarlyStop.load_from_dict(checkpoint['early_stop'])
76
- if early_termination.should_stop:
77
- print(f"Early Termination at Epoch {last_ep}")
78
- return
79
- print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
80
- else:
81
- print("Starting from scratch")
82
- log = open(config['Training_Directory'] + '/training.log', 'a', buffering=1)
83
- else:
84
- # Remove all *.pt and *.npz files in the Training_Directory
85
- for pattern in ('*.pt', '*.npz'):
86
- for file in glob.glob(os.path.join(config['Training_Directory'], pattern)):
87
- os.remove(file)
88
- log = open(config['Training_Directory'] + '/training.log', 'w', buffering=1)
89
-
90
- train_cyclers = []
91
- for loader in train_loaders:
92
- train_cyclers.append(utils.cycler((loader)))
93
-
94
- log.write(f'Training {config["Training_Name"]} {datetime.datetime.now()} \n')
95
- print(f"Starting training for {config['Training']['epochs']} epochs")
96
-
97
- if hasattr(train_loaders[0].dataset, 'padding_mode'):
98
- is_padded = train_loaders[0].dataset.padding_mode != 'NONE'
99
- if (train_loaders[0].dataset.padding_mode == 'NODE'):
100
- is_padded = False
101
- else:
102
- is_padded = False
103
-
104
- lr_utils.print_LR(optimizer)
105
-
106
- # training loop
107
- for epoch in range(starting_epoch, config['Training']['epochs']):
108
- start = time.time()
109
- run = start
110
- if (args.profile):
111
- if (epoch == 0):
112
- torch.cuda.cudart().cudaProfilerStart()
113
- torch.cuda.nvtx.range_push("Epoch Start")
114
-
115
- # training
116
-
117
- model.train()
118
-
119
- ibatch = 0
120
- total_loss = 0
121
- for batched_graph, labels, _, global_feats in train_loaders[0]:
122
- batch_start = time.time()
123
- logits = torch.tensor([])
124
- tlabels = torch.tensor([])
125
- weights = torch.tensor([])
126
- batch_lengths = []
127
- for cycler in train_cyclers:
128
- graph, label, track, global_feats = next(cycler)
129
- graph = graph.to(device)
130
- label = label.to(device)
131
- track = track.to(device)
132
- global_feats = global_feats.to(device)
133
- if is_padded: #Padding the globals to match padded graphs.
134
- global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
135
- load = time.time()
136
- if (args.profile):
137
- torch.cuda.nvtx.range_push("Model Forward")
138
- if (len(logits) == 0):
139
- logits = model(graph.ndata['features'], graph.edata['features'], graph)
140
- tlabels = label
141
- weights = track[:,1]
142
- else:
143
- logits = torch.concatenate((logits, model(graph.ndata['features'], graph.edata['features'], graph)), dim=0)
144
- tlabels = torch.concatenate((tlabels, label), dim=0)
145
- weights = torch.concatenate((weights, track[:,1]), dim=0)
146
- batch_lengths.append(logits.shape[0] - 1)
147
-
148
- if (args.profile):
149
- torch.cuda.nvtx.range_pop() # popping model forward
150
-
151
- if is_padded:
152
- keepmask = torch.full_like(logits[:,0], True, dtype=torch.bool)
153
- keepmask[batch_lengths] = False
154
- logits = logits[keepmask]
155
- tlabels = tlabels.to(torch.float)
156
- if logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'BCEWithLogitsLoss':
157
- logits = logits[:,0]
158
- tlabels = tlabels.to(torch.float)
159
- if loss_fcn.__class__.__name__ == 'CrossEntropyLoss':
160
- tlabels = tlabels.to(torch.long)
161
-
162
- if args.abs:
163
- weights = torch.abs(weights)
164
-
165
- loss = loss_fcn(logits, tlabels.to(device))
166
- # Normalize loss within each label
167
- unique_labels = torch.unique(tlabels) # Get unique labels
168
- normalized_loss = 0.0
169
- for label in unique_labels:
170
- # Mask for samples belonging to the current label
171
- label_mask = (tlabels == label)
172
- # Extract weights and losses for the current label
173
- label_weights = weights[label_mask]
174
- label_losses = loss[label_mask]
175
- # Compute normalized loss for the current label
176
- label_loss = torch.sum(label_weights * label_losses) / torch.sum(label_weights)
177
- # Add to the total normalized loss
178
- normalized_loss += label_loss
179
- loss = normalized_loss / len(unique_labels)
180
-
181
- if (args.profile):
182
- torch.cuda.nvtx.range_push("Model Backward")
183
-
184
- optimizer.zero_grad()
185
- loss.backward()
186
- optimizer.step()
187
- total_loss += loss.detach().cpu().item()
188
-
189
- if (args.profile):
190
- torch.cuda.nvtx.range_pop() # pop model backward
191
-
192
- ibatch += 1
193
-
194
- if ibatch % 1000 == 0:
195
- print(f'Batch {ibatch} out of {len(train_loaders[0])}', end='\r')
196
- # gpu_mem()
197
- else:
198
- print("Epoch Done.")
199
-
200
- # validation
201
-
202
- scores = []
203
- labels = []
204
- weights = []
205
- model.eval()
206
-
207
- if (args.profile):
208
- torch.cuda.nvtx.range_push("Model Evaluation")
209
-
210
- with torch.no_grad():
211
- for loader in test_loaders:
212
- for batch, label, track, global_feats in loader:
213
- #Don't use compiled model for testing since we can't control the batch size.
214
- #We could before, but it assumes each dataset has the same number of batches...
215
- if is_padded:
216
- global_feats = torch.cat([global_feats, torch.zeros(1, len(global_feats[0]))])
217
-
218
- # batch_scores = model(batch.to(device), global_feats.to(device))
219
- batch_scores = model(graph.ndata['features'].to(device), graph.edata['features'].to(device), graph)
220
-
221
- if is_padded:
222
- scores.append(batch_scores[:-1,:])
223
- else:
224
- scores.append(batch_scores)
225
- labels.append(label)
226
- weights.append(track[:,1])
227
-
228
- if (args.profile):
229
- torch.cuda.nvtx.range_pop() # pop evaluation
230
-
231
- if scores == []: #If validation set is empty.
232
- continue
233
-
234
- logits = torch.concatenate(scores).to(device)
235
- labels = torch.concatenate(labels).to(device)
236
- weights = torch.concatenate(weights).to(device)
237
-
238
- wgt_mask = weights > 0
239
-
240
- if args.abs:
241
- weights = torch.abs(weights)
242
-
243
- print(f"Num batches trained = {ibatch}")
244
-
245
- if (loss_fcn.__class__.__name__ == "ContrastiveClusterLoss"):
246
- scores = logits
247
- preds = scores
248
- accuracy = 0
249
- test_auc = 0
250
- acc = 0
251
- contrastive_cluster_loss = finish_fn(logits)
252
-
253
- elif (loss_fcn.__class__.__name__ == "MultiLabelLoss"):
254
- scores = finish_fn(logits)
255
- preds = torch.round(scores)
256
- multilabel_accuracy = []
257
- threshold = 0.1
258
- for i in range(len(labels[0])):
259
- multilabel_accuracy.append(torch.sum(preds[:, i].to("cpu") == labels[:, i].to("cpu")) / len(labels))
260
- test_auc = 0
261
- acc = np.mean(multilabel_accuracy)
262
-
263
- elif logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'BCEWithLogitsLoss':
264
- test_auc = 0
265
- acc = 0
266
- logits = logits[:,0]
267
- scores = finish_fn(logits)
268
- labels =labels.to(torch.float)
269
- preds = scores > 0.5
270
- test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), sample_weight=weights[wgt_mask].to("cpu"))
271
- acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels)
272
-
273
- elif logits.shape[1] == 1 and loss_fcn.__class__.__name__ == 'MSELoss':
274
- logits = logits[:,0]
275
- scores = finish_fn(logits)
276
- labels = labels.to(torch.float)
277
- acc = 0
278
- test_auc = 0
279
-
280
- else:
281
- preds = torch.argmax(logits, dim=1)
282
- scores = finish_fn(logits)
283
- if labels.dim() == 1: #Multi-class
284
- acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels) #TODO: Make each class weighted equally?
285
-
286
- labels = labels.to("cpu")
287
- weights = weights.to("cpu")
288
- logits = logits.to("cpu")
289
- wgt_mask = wgt_mask.to("cpu")
290
-
291
- labels_onehot = np.zeros((len(labels), len(scores[0])))
292
- labels_onehot[np.arange(len(labels)), labels] = 1
293
-
294
- try:
295
- #test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
296
- if (len(scores[0]) != config["Model"]["args"]["out_size"]):
297
- print("ERROR: The out_size and the number of class labels don't match! Please check config.")
298
- test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
299
- except ValueError:
300
- test_auc = np.nan
301
- else: #Multi-loss
302
- acc = torch.sum(preds.to("cpu") == labels[:,0].to("cpu")) / len(labels)
303
- try:
304
- test_auc = roc_auc_score(labels[:,0][wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
305
- except ValueError:
306
- test_auc = np.nan
307
-
308
- if (loss_fcn.__class__.__name__ == "MultiLabelLoss"):
309
- multilabel_log_str = "MultiLabel_Accuracy "
310
- for accuracy in multilabel_accuracy:
311
- multilabel_log_str += f" | {accuracy:.4f}"
312
- log.write(multilabel_log_str + '\n')
313
- print(multilabel_log_str, flush=True)
314
- elif (loss_fcn.__class__.__name__ == "ContrastiveClusterLoss"):
315
- contrastive_cluster_log_str = "ContrastiveClusterLoss "
316
- contrastive_cluster_log_str += f"Contrastive Loss: {contrastive_cluster_loss[0]:.4f}, Clustering Loss: {contrastive_cluster_loss[1]:.4f}, Variance Loss: {contrastive_cluster_loss[2]:.4f}"
317
- log.write(contrastive_cluster_log_str + '\n')
318
- print(contrastive_cluster_log_str, flush=True)
319
-
320
- test_loss = loss_fcn(logits, labels)
321
- # Normalize loss within each label
322
- unique_labels = torch.unique(labels) # Get unique labels
323
- normalized_loss = 0.0
324
- for label in unique_labels:
325
- # Mask for samples belonging to the current label
326
- label_mask = (labels == label)
327
- # Extract weights and losses for the current label
328
- label_weights = weights[label_mask]
329
- label_losses = test_loss[label_mask]
330
- # Compute normalized loss for the current label
331
- label_loss = torch.sum(label_weights * label_losses) / torch.sum(label_weights)
332
- # Add to the total normalized loss
333
- normalized_loss += label_loss
334
- test_loss = normalized_loss / len(unique_labels)
335
-
336
-
337
- end = time.time()
338
- log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
339
- epoch, optimizer.param_groups[0]['lr'], total_loss/ibatch, acc, test_loss, test_auc, end - start
340
- )
341
- log.write(log_str + '\n')
342
- print(log_str, flush=True)
343
-
344
- state_dict = model.state_dict()
345
-
346
- torch.save({
347
- 'epoch': epoch,
348
- 'model_state_dict': state_dict,
349
- 'optimizer_state_dict': optimizer.state_dict(),
350
- 'early_stop': early_termination.to_dict()
351
- }, os.path.join(config['Training_Directory'], f"model_epoch_{epoch}.pt"))
352
- np.savez(os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.npz'), scores=scores.to("cpu"), labels=labels.to("cpu"))
353
-
354
- early_termination.update(test_loss)
355
- if early_termination.should_stop:
356
- log_str = f"Early Termination at Epoch {epoch}"
357
- log.write(log_str + "\n")
358
- print(log_str)
359
- log_str = early_termination.to_str()
360
- log.write(log_str + "\n")
361
- print(log_str)
362
- break
363
-
364
- scheduler.step()
365
-
366
- if (args.profile):
367
- torch.cuda.nvtx.range_pop() # pop epoch
368
-
369
- log.close()
370
-
371
- def main(args=None):
372
- config = utils.load_config(args.config)
373
-
374
- if not os.path.exists(config['Training_Directory']):
375
- os.makedirs(config['Training_Directory'], exist_ok=True)
376
- with open(config['Training_Directory'] + '/config.yaml', 'w') as f:
377
- yaml.dump(config, f)
378
- batch_size = config["Training"]["batch_size"]
379
-
380
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
381
-
382
- train_loaders = []
383
- test_loaders = []
384
- val_loaders = []
385
- load_start = time.time()
386
-
387
- torch.backends.cuda.matmul.allow_tf32 = True
388
-
389
- # ldr_type = datasets.LazyPreBatchedDataset if args.lazy else datasets.PreBatchedDataset
390
- ldr_type = datasets.LazyPreBatchedDataset
391
-
392
- for dset_conf in config["Datasets"]:
393
- dset = utils.buildFromConfig(config["Datasets"][dset_conf])
394
- if 'batch_size' in config["Datasets"][dset_conf]:
395
- batch_size = config["Datasets"][dset_conf]['batch_size']
396
- fold_conf = config["Datasets"][dset_conf]["folding"]
397
- shuffle_chunks = config["Datasets"][dset_conf].get("shuffle_chunks", 10)
398
- padding_mode = config["Datasets"][dset_conf].get("padding_mode", "STEPS")
399
- mask_fn = utils.fold_selection(fold_conf, "train")
400
- if args.preshuffle:
401
- # ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, use_ddp = args.multigpu, rank=rank, world_size=world_size)
402
- ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size = 128)
403
- gsamp, _, _, global_samp = ldr[0]
404
- sampler = None
405
- train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
406
-
407
- sampler = None
408
- ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size=128)
409
- test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))
410
-
411
- if "validation" in fold_conf:
412
- val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, hid_size=128, padding_mode = padding_mode)), batch_size = None, num_workers = 0, sampler = sampler))
413
- else:
414
- print("No validation set for dataset ", dset_conf)
415
- else:
416
- train_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "train")))
417
- gsamp, _, _, global_samp = dset[0]
418
- test_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "test")))
419
- if "validation" in fold_conf:
420
- val_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "validation")))
421
- else:
422
- print("No validation set for dataset ", dset_conf)
423
-
424
- load_end = time.time()
425
- print("Load time: {:.4f} s".format(load_end - load_start))
426
-
427
- # model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': args.seed}).to(device)
428
-
429
- model = utils.buildFromConfig(config["Model"]).to(device)
430
-
431
- # @dataclass
432
- # class MetaData(ModelMetaData):
433
- # name: str = "Edge_Network"
434
-
435
- # physicsnemo_model = Module.from_torch(model)
436
- # print(f"physicsnemo_model = {physicsnemo_model}")
437
- # model = physicsnemo_model()
438
-
439
-
440
- pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
441
- print(f"Number of trainable parameters = {pytorch_total_params}")
442
-
443
- # model training
444
- print("Training...")
445
- gpu_mem()
446
- train(train_loaders, test_loaders, model, device, config, args)
447
-
448
- if __name__ == "__main__":
449
- #Handle CLI arguments
450
- parser = argparse.ArgumentParser()
451
- add_arg = parser.add_argument
452
- add_arg("--config", type=str, help="Config file.", required=True)
453
- add_arg("--restart", action="store_true", help="Restart training from scratch.")
454
- add_arg("--preshuffle", action="store_true", help="Shuffle data before training.")
455
- add_arg("--seed", type=int, default=2, help="Sets random seed")
456
- add_arg("--abs", action="store_true", help="Use abs value of per-event weight")
457
- add_arg("--profile", action="store_true", help="use nsight systems profiler")
458
-
459
- pargs = parser.parse_args()
460
-
461
- main(pargs)
462
-
463
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/Dockerfile DELETED
@@ -1,25 +0,0 @@
1
- FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06
2
-
3
- WORKDIR /global/cfs/projectdirs/atlas/joshua/GNN4Colliders
4
-
5
- LABEL maintainer.name="Joshua Ho"
6
- LABEL maintainer.email="ho22joshua@berkeley.edu"
7
-
8
- ENV LANG=C.UTF-8
9
-
10
- # Install system dependencies: vim, OpenMPI, and build tools
11
- RUN apt-get update -qq \
12
- && apt-get install -y --no-install-recommends \
13
- wget lsb-release gnupg software-properties-common \
14
- vim \
15
- g++-11 gcc-11 libstdc++-11-dev \
16
- openmpi-bin openmpi-common libopenmpi-dev \
17
- && rm -rf /var/lib/apt/lists/*
18
-
19
- # Install Python packages: mpi4py and jupyter
20
- RUN pip install --no-cache-dir mpi4py jupyter uproot
21
-
22
- # (Optional) Expose Jupyter port
23
- EXPOSE 8888
24
-
25
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/build_image.sh DELETED
@@ -1,4 +0,0 @@
1
- tag=$1
2
- echo $tag
3
- podman-hpc build -t joshuaho/nemo:$tag --platform linux/amd64 .
4
- podman-hpc migrate joshuaho/nemo:$tag
 
 
 
 
 
nemo/setup/environment.yml DELETED
@@ -1,391 +0,0 @@
1
- name: dgl
2
- channels:
3
- - pytorch
4
- - dglteam/label/cu118
5
- - nvidia
6
- - conda-forge
7
- - defaults
8
- dependencies:
9
- - _libgcc_mutex=0.1
10
- - _openmp_mutex=4.5
11
- - _sysroot_linux-64_curr_repodata_hack=3
12
- - afterimage=1.21
13
- - anyio=3.7.1
14
- - appdirs=1.4.4
15
- - argon2-cffi=21.3.0
16
- - argon2-cffi-bindings=21.2.0
17
- - arrow=1.2.3
18
- - asttokens=2.2.1
19
- - async-lru=2.0.4
20
- - atk-1.0=2.38.0
21
- - attrs=23.1.0
22
- - awkward-pandas=2023.8.0
23
- - aws-c-auth=0.7.0
24
- - aws-c-cal=0.6.0
25
- - aws-c-common=0.8.23
26
- - aws-c-compression=0.2.17
27
- - aws-c-event-stream=0.3.1
28
- - aws-c-http=0.7.11
29
- - aws-c-io=0.13.28
30
- - aws-c-mqtt=0.8.14
31
- - aws-c-s3=0.3.13
32
- - aws-c-sdkutils=0.1.11
33
- - aws-checksums=0.1.16
34
- - aws-crt-cpp=0.20.3
35
- - aws-sdk-cpp=1.10.57
36
- - babel=2.12.1
37
- - backcall=0.2.0
38
- - backports=1.0
39
- - backports.functools_lru_cache=1.6.5
40
- - beautifulsoup4=4.12.2
41
- - binutils=2.38
42
- - binutils_impl_linux-64=2.38
43
- - binutils_linux-64=2.38.0
44
- - blas=1.0
45
- - bleach=6.0.0
46
- - brotlipy=0.7.0
47
- - bzip2=1.0.8
48
- - c-ares=1.19.1
49
- - c-compiler=1.5.2
50
- - ca-certificates=2025.4.26
51
- - cached-property=1.5.2
52
- - cached_property=1.5.2
53
- - cairo=1.16.0
54
- - certifi=2024.8.30
55
- - cffi=1.15.1
56
- - cfitsio=4.2.0
57
- - charset-normalizer=2.0.4
58
- - comm=0.1.4
59
- - compilers=1.5.2
60
- - cryptography=41.0.2
61
- - cuda-cudart=11.8.89
62
- - cuda-cupti=11.8.87
63
- - cuda-libraries=11.8.0
64
- - cuda-nvrtc=11.8.89
65
- - cuda-nvtx=11.8.86
66
- - cuda-runtime=11.8.0
67
- - cxx-compiler=1.5.2
68
- - davix=0.8.4
69
- - debugpy=1.6.8
70
- - decorator=5.1.1
71
- - defusedxml=0.7.1
72
- - dgl=1.1.1.cu118
73
- - entrypoints=0.4
74
- - exceptiongroup=1.1.3
75
- - executing=1.2.0
76
- - expat=2.5.0
77
- - ffmpeg=4.3
78
- - fftw=3.3.10
79
- - filelock=3.9.0
80
- - flit-core=3.9.0
81
- - font-ttf-dejavu-sans-mono=2.37
82
- - font-ttf-inconsolata=3.000
83
- - font-ttf-source-code-pro=2.038
84
- - font-ttf-ubuntu=0.83
85
- - fontconfig=2.14.2
86
- - fonts-conda-ecosystem=1
87
- - fonts-conda-forge=1
88
- - fortran-compiler=1.5.2
89
- - fqdn=1.5.1
90
- - freetype=2.12.1
91
- - fribidi=1.0.10
92
- - ftgl=2.4.0
93
- - gcc=11.2.0
94
- - gcc_impl_linux-64=11.2.0
95
- - gcc_linux-64=11.2.0
96
- - gdk-pixbuf=2.42.8
97
- - gettext=0.21.1
98
- - gflags=2.2.2
99
- - gfortran=11.2.0
100
- - gfortran_impl_linux-64=11.2.0
101
- - gfortran_linux-64=11.2.0
102
- - giflib=5.2.1
103
- - gl2ps=1.4.2
104
- - glew=2.1.0
105
- - glog=0.6.0
106
- - gmp=6.2.1
107
- - gmpy2=2.1.2
108
- - gnutls=3.6.15
109
- - graphite2=1.3.13
110
- - graphviz=6.0.2
111
- - gsl=2.7
112
- - gsoap=2.8.123
113
- - gtk2=2.24.33
114
- - gts=0.7.6
115
- - gxx=11.2.0
116
- - gxx_impl_linux-64=11.2.0
117
- - gxx_linux-64=11.2.0
118
- - harfbuzz=7.3.0
119
- - icu=72.1
120
- - idna=3.4
121
- - importlib-metadata=6.8.0
122
- - importlib-resources=6.0.1
123
- - importlib_metadata=6.8.0
124
- - importlib_resources=6.0.1
125
- - intel-openmp=2023.1.0
126
- - ipykernel=6.25.1
127
- - ipyparallel=8.6.1
128
- - ipython=8.12.2
129
- - isoduration=20.11.0
130
- - jedi=0.19.0
131
- - jinja2=3.1.2
132
- - jpeg=9e
133
- - json5=0.9.14
134
- - jsonpointer=2.0
135
- - jsonschema=4.19.0
136
- - jsonschema-specifications=2023.7.1
137
- - jsonschema-with-format-nongpl=4.19.0
138
- - jupyter-lsp=2.2.0
139
- - jupyter_client=8.3.0
140
- - jupyter_core=5.3.0
141
- - jupyter_events=0.7.0
142
- - jupyter_server=2.7.0
143
- - jupyter_server_terminals=0.4.4
144
- - jupyterlab=4.0.5
145
- - jupyterlab_pygments=0.2.2
146
- - jupyterlab_server=2.24.0
147
- - kernel-headers_linux-64=3.10.0
148
- - keyutils=1.6.1
149
- - krb5=1.20.1
150
- - lame=3.100
151
- - lcms2=2.12
152
- - ld_impl_linux-64=2.38
153
- - lerc=3.0
154
- - libabseil=20230125.3
155
- - libarrow=12.0.1
156
- - libblas=3.9.0
157
- - libbrotlicommon=1.0.9
158
- - libbrotlidec=1.0.9
159
- - libbrotlienc=1.0.9
160
- - libcblas=3.9.0
161
- - libcrc32c=1.1.2
162
- - libcublas=11.11.3.6
163
- - libcufft=10.9.0.58
164
- - libcufile=1.7.1.12
165
- - libcurand=10.3.3.129
166
- - libcurl=8.1.2
167
- - libcusolver=11.4.1.48
168
- - libcusparse=11.7.5.86
169
- - libcxx=15.0.7
170
- - libcxxabi=15.0.7
171
- - libdeflate=1.12
172
- - libedit=3.1.20191231
173
- - libev=4.33
174
- - libevent=2.1.12
175
- - libexpat=2.5.0
176
- - libffi=3.4.4
177
- - libgcc-devel_linux-64=11.2.0
178
- - libgcc-ng=13.1.0
179
- - libgd=2.3.3
180
- - libgfortran-ng=11.2.0
181
- - libgfortran5=11.2.0
182
- - libglib=2.76.4
183
- - libglu=9.0.0
184
- - libgomp=13.1.0
185
- - libgoogle-cloud=2.12.0
186
- - libgrpc=1.56.2
187
- - libiconv=1.17
188
- - libidn2=2.3.4
189
- - libllvm13=13.0.1
190
- - libllvm14=14.0.6
191
- - libnghttp2=1.52.0
192
- - libnpp=11.8.0.86
193
- - libnsl=2.0.0
194
- - libnuma=2.0.18
195
- - libnvjpeg=11.9.0.86
196
- - libpng=1.6.39
197
- - libprotobuf=4.23.3
198
- - librsvg=2.54.4
199
- - libsodium=1.0.18
200
- - libsqlite=3.42.0
201
- - libssh2=1.11.0
202
- - libstdcxx-devel_linux-64=11.2.0
203
- - libstdcxx-ng=13.1.0
204
- - libtasn1=4.19.0
205
- - libthrift=0.18.1
206
- - libtiff=4.4.0
207
- - libtool=2.4.7
208
- - libunistring=0.9.10
209
- - libutf8proc=2.8.0
210
- - libuuid=2.38.1
211
- - libwebp=1.2.4
212
- - libwebp-base=1.2.4
213
- - libxcb=1.15
214
- - libxml2=2.10.4
215
- - libzlib=1.2.13
216
- - llvmlite=0.40.1
217
- - lz4-c=1.9.4
218
- - markupsafe=2.1.1
219
- - matplotlib-inline=0.1.6
220
- - metakernel=0.29.5
221
- - mistune=3.0.0
222
- - mkl=2023.1.0
223
- - mkl-service=2.4.0
224
- - mkl_fft=1.3.6
225
- - mkl_random=1.2.2
226
- - mpc=1.1.0
227
- - mpfr=4.0.2
228
- - mpmath=1.3.0
229
- - nbclient=0.8.0
230
- - nbconvert-core=7.7.3
231
- - nbformat=5.9.2
232
- - ncurses=6.4
233
- - nest-asyncio=1.5.6
234
- - nettle=3.7.3
235
- - networkx=3.1
236
- - nlohmann_json=3.11.2
237
- - notebook=7.0.2
238
- - notebook-shim=0.2.3
239
- - numba=0.57.1
240
- - numpy=1.24.3
241
- - numpy-base=1.24.3
242
- - openh264=2.1.1
243
- - openssl=3.3.1
244
- - orc=1.9.0
245
- - overrides=7.4.0
246
- - packaging=23.0
247
- - pandas=2.0.3
248
- - pandocfilters=1.5.0
249
- - pango=1.50.14
250
- - parso=0.8.3
251
- - pcre=8.45
252
- - pcre2=10.40
253
- - pexpect=4.8.0
254
- - pickleshare=0.7.5
255
- - pillow=9.4.0
256
- - pip=23.2.1
257
- - pixman=0.40.0
258
- - pkgutil-resolve-name=1.3.10
259
- - platformdirs=2.6.0
260
- - pooch=1.4.0
261
- - portalocker=2.7.0
262
- - prometheus_client=0.17.1
263
- - prompt-toolkit=3.0.39
264
- - prompt_toolkit=3.0.39
265
- - psutil=5.9.0
266
- - pthread-stubs=0.4
267
- - ptyprocess=0.7.0
268
- - pure_eval=0.2.2
269
- - pyarrow=12.0.1
270
- - pycparser=2.21
271
- - pygments=2.16.1
272
- - pyopenssl=23.2.0
273
- - pysocks=1.7.1
274
- - pythia8=8.309
275
- - python=3.8.17
276
- - python-dateutil=2.8.2
277
- - python-fastjsonschema=2.18.0
278
- - python-json-logger=2.0.7
279
- - python-tzdata=2024.2
280
- - python_abi=3.8
281
- - pytorch=2.0.1
282
- - pytorch-cuda=11.8
283
- - pytorch-mutex=1.0
284
- - pytz=2023.3
285
- - pyyaml=6.0
286
- - pyzmq=25.1.1
287
- - rdma-core=28.9
288
- - re2=2023.03.02
289
- - readline=8.2
290
- - referencing=0.30.2
291
- - requests=2.31.0
292
- - rfc3339-validator=0.1.4
293
- - rfc3986-validator=0.1.1
294
- - root=6.28.0
295
- - root_base=6.28.0
296
- - rpds-py=0.9.2
297
- - s2n=1.3.46
298
- - scipy=1.10.1
299
- - scitokens-cpp=0.7.3
300
- - send2trash=1.8.2
301
- - setuptools=68.0.0
302
- - six=1.16.0
303
- - snappy=1.1.10
304
- - sniffio=1.3.0
305
- - soupsieve=2.3.2.post1
306
- - sqlite=3.41.2
307
- - stack_data=0.6.2
308
- - sympy=1.11.1
309
- - sysroot_linux-64=2.17
310
- - tbb=2021.8.0
311
- - terminado=0.17.1
312
- - tinycss2=1.2.1
313
- - tk=8.6.12
314
- - tomli=2.0.1
315
- - torchaudio=2.0.2
316
- - torchtriton=2.0.0
317
- - torchvision=0.15.2
318
- - tornado=6.3.2
319
- - tqdm=4.65.0
320
- - traitlets=5.9.0
321
- - typing_extensions=4.12.2
322
- - typing_utils=0.1.0
323
- - ucx=1.14.1
324
- - uri-template=1.3.0
325
- - urllib3=1.26.16
326
- - vdt=0.4.3
327
- - vector-classes=1.4.3
328
- - wcwidth=0.2.6
329
- - webcolors=1.13
330
- - webencodings=0.5.1
331
- - websocket-client=1.6.1
332
- - wheel=0.38.4
333
- - xorg-fixesproto=5.0
334
- - xorg-kbproto=1.0.7
335
- - xorg-libice=1.1.1
336
- - xorg-libsm=1.2.4
337
- - xorg-libx11=1.8.6
338
- - xorg-libxau=1.0.11
339
- - xorg-libxcursor=1.2.0
340
- - xorg-libxdmcp=1.1.3
341
- - xorg-libxext=1.3.4
342
- - xorg-libxfixes=5.0.3
343
- - xorg-libxft=2.3.8
344
- - xorg-libxpm=3.5.16
345
- - xorg-libxrender=0.9.11
346
- - xorg-libxt=1.3.0
347
- - xorg-renderproto=0.11.1
348
- - xorg-xextproto=7.3.0
349
- - xorg-xproto=7.0.31
350
- - xrootd=5.5.4
351
- - xxhash=0.8.1
352
- - xz=5.2.6
353
- - yaml=0.2.5
354
- - zeromq=4.3.4
355
- - zipp=3.16.2
356
- - zlib=1.2.13
357
- - zstd=1.5.2
358
- - pip:
359
- - awkward==2.6.4
360
- - awkward-cpp==33
361
- - contourpy==1.1.0
362
- - cramjam==2.8.3
363
- - cycler==0.11.0
364
- - fonttools==4.42.0
365
- - fsspec==2024.3.1
366
- - h5py==3.9.0
367
- - pip-install==1.3.5
368
- - joblib==1.3.2
369
- - kiwisolver==1.4.4
370
- - matplotlib==3.7.2
371
- - nvidia-cublas-cu12==12.1.3.1
372
- - nvidia-cuda-cupti-cu12==12.1.105
373
- - nvidia-cuda-nvrtc-cu12==12.1.105
374
- - nvidia-cuda-runtime-cu12==12.1.105
375
- - nvidia-cudnn-cu12==8.9.2.26
376
- - nvidia-cufft-cu12==11.0.2.54
377
- - nvidia-curand-cu12==10.3.2.106
378
- - nvidia-cusolver-cu12==11.4.5.107
379
- - nvidia-cusparse-cu12==12.1.0.106
380
- - nvidia-nccl-cu12==2.20.5
381
- - nvidia-nvjitlink-cu12==12.4.127
382
- - nvidia-nvtx-cu12==12.1.105
383
- - pyparsing==3.0.9
384
- - scikit-learn==1.3.0
385
- - threadpoolctl==3.2.0
386
- - torch==2.3.0
387
- - triton==2.3.0
388
- - typing-extensions==4.11.0
389
- - tzdata==2024.1
390
- - uproot==5.3.7
391
- prefix: /global/homes/j/joshuaho/.conda/envs/dgl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/setup/Dockerfile DELETED
@@ -1,29 +0,0 @@
1
- FROM nvcr.io/nvidia/dgl:25.05-py3
2
-
3
- WORKDIR /global/cfs/projectdirs/atlas/joshua/GNN4Colliders
4
-
5
- LABEL maintainer.name="Joshua Ho"
6
- LABEL maintainer.email="ho22joshua@berkeley.edu"
7
-
8
- ENV LANG=C.UTF-8
9
-
10
- # Install system dependencies: vim, OpenMPI, and build tools
11
- RUN apt-get update -qq \
12
- && apt-get install -y --no-install-recommends \
13
- wget lsb-release gnupg software-properties-common \
14
- vim \
15
- g++-11 gcc-11 libstdc++-11-dev \
16
- openmpi-bin openmpi-common libopenmpi-dev \
17
- && rm -rf /var/lib/apt/lists/*
18
-
19
- # Install Python packages: mpi4py and jupyter
20
- RUN pip install --no-cache-dir mpi4py jupyter uproot
21
-
22
- RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
23
-
24
- RUN pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html
25
- i
26
- # (Optional) Expose Jupyter port
27
- EXPOSE 8888
28
-
29
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/setup/build_image.sh DELETED
@@ -1,4 +0,0 @@
1
- tag=$1
2
- echo $tag
3
- podman-hpc build -t joshuaho/pytorch:$tag --platform linux/amd64 .
4
- podman-hpc migrate joshuaho/pytorch:$tag
 
 
 
 
 
nemo/setup/setup/environment.yml DELETED
@@ -1,391 +0,0 @@
1
- name: dgl
2
- channels:
3
- - pytorch
4
- - dglteam/label/cu118
5
- - nvidia
6
- - conda-forge
7
- - defaults
8
- dependencies:
9
- - _libgcc_mutex=0.1
10
- - _openmp_mutex=4.5
11
- - _sysroot_linux-64_curr_repodata_hack=3
12
- - afterimage=1.21
13
- - anyio=3.7.1
14
- - appdirs=1.4.4
15
- - argon2-cffi=21.3.0
16
- - argon2-cffi-bindings=21.2.0
17
- - arrow=1.2.3
18
- - asttokens=2.2.1
19
- - async-lru=2.0.4
20
- - atk-1.0=2.38.0
21
- - attrs=23.1.0
22
- - awkward-pandas=2023.8.0
23
- - aws-c-auth=0.7.0
24
- - aws-c-cal=0.6.0
25
- - aws-c-common=0.8.23
26
- - aws-c-compression=0.2.17
27
- - aws-c-event-stream=0.3.1
28
- - aws-c-http=0.7.11
29
- - aws-c-io=0.13.28
30
- - aws-c-mqtt=0.8.14
31
- - aws-c-s3=0.3.13
32
- - aws-c-sdkutils=0.1.11
33
- - aws-checksums=0.1.16
34
- - aws-crt-cpp=0.20.3
35
- - aws-sdk-cpp=1.10.57
36
- - babel=2.12.1
37
- - backcall=0.2.0
38
- - backports=1.0
39
- - backports.functools_lru_cache=1.6.5
40
- - beautifulsoup4=4.12.2
41
- - binutils=2.38
42
- - binutils_impl_linux-64=2.38
43
- - binutils_linux-64=2.38.0
44
- - blas=1.0
45
- - bleach=6.0.0
46
- - brotlipy=0.7.0
47
- - bzip2=1.0.8
48
- - c-ares=1.19.1
49
- - c-compiler=1.5.2
50
- - ca-certificates=2025.4.26
51
- - cached-property=1.5.2
52
- - cached_property=1.5.2
53
- - cairo=1.16.0
54
- - certifi=2024.8.30
55
- - cffi=1.15.1
56
- - cfitsio=4.2.0
57
- - charset-normalizer=2.0.4
58
- - comm=0.1.4
59
- - compilers=1.5.2
60
- - cryptography=41.0.2
61
- - cuda-cudart=11.8.89
62
- - cuda-cupti=11.8.87
63
- - cuda-libraries=11.8.0
64
- - cuda-nvrtc=11.8.89
65
- - cuda-nvtx=11.8.86
66
- - cuda-runtime=11.8.0
67
- - cxx-compiler=1.5.2
68
- - davix=0.8.4
69
- - debugpy=1.6.8
70
- - decorator=5.1.1
71
- - defusedxml=0.7.1
72
- - dgl=1.1.1.cu118
73
- - entrypoints=0.4
74
- - exceptiongroup=1.1.3
75
- - executing=1.2.0
76
- - expat=2.5.0
77
- - ffmpeg=4.3
78
- - fftw=3.3.10
79
- - filelock=3.9.0
80
- - flit-core=3.9.0
81
- - font-ttf-dejavu-sans-mono=2.37
82
- - font-ttf-inconsolata=3.000
83
- - font-ttf-source-code-pro=2.038
84
- - font-ttf-ubuntu=0.83
85
- - fontconfig=2.14.2
86
- - fonts-conda-ecosystem=1
87
- - fonts-conda-forge=1
88
- - fortran-compiler=1.5.2
89
- - fqdn=1.5.1
90
- - freetype=2.12.1
91
- - fribidi=1.0.10
92
- - ftgl=2.4.0
93
- - gcc=11.2.0
94
- - gcc_impl_linux-64=11.2.0
95
- - gcc_linux-64=11.2.0
96
- - gdk-pixbuf=2.42.8
97
- - gettext=0.21.1
98
- - gflags=2.2.2
99
- - gfortran=11.2.0
100
- - gfortran_impl_linux-64=11.2.0
101
- - gfortran_linux-64=11.2.0
102
- - giflib=5.2.1
103
- - gl2ps=1.4.2
104
- - glew=2.1.0
105
- - glog=0.6.0
106
- - gmp=6.2.1
107
- - gmpy2=2.1.2
108
- - gnutls=3.6.15
109
- - graphite2=1.3.13
110
- - graphviz=6.0.2
111
- - gsl=2.7
112
- - gsoap=2.8.123
113
- - gtk2=2.24.33
114
- - gts=0.7.6
115
- - gxx=11.2.0
116
- - gxx_impl_linux-64=11.2.0
117
- - gxx_linux-64=11.2.0
118
- - harfbuzz=7.3.0
119
- - icu=72.1
120
- - idna=3.4
121
- - importlib-metadata=6.8.0
122
- - importlib-resources=6.0.1
123
- - importlib_metadata=6.8.0
124
- - importlib_resources=6.0.1
125
- - intel-openmp=2023.1.0
126
- - ipykernel=6.25.1
127
- - ipyparallel=8.6.1
128
- - ipython=8.12.2
129
- - isoduration=20.11.0
130
- - jedi=0.19.0
131
- - jinja2=3.1.2
132
- - jpeg=9e
133
- - json5=0.9.14
134
- - jsonpointer=2.0
135
- - jsonschema=4.19.0
136
- - jsonschema-specifications=2023.7.1
137
- - jsonschema-with-format-nongpl=4.19.0
138
- - jupyter-lsp=2.2.0
139
- - jupyter_client=8.3.0
140
- - jupyter_core=5.3.0
141
- - jupyter_events=0.7.0
142
- - jupyter_server=2.7.0
143
- - jupyter_server_terminals=0.4.4
144
- - jupyterlab=4.0.5
145
- - jupyterlab_pygments=0.2.2
146
- - jupyterlab_server=2.24.0
147
- - kernel-headers_linux-64=3.10.0
148
- - keyutils=1.6.1
149
- - krb5=1.20.1
150
- - lame=3.100
151
- - lcms2=2.12
152
- - ld_impl_linux-64=2.38
153
- - lerc=3.0
154
- - libabseil=20230125.3
155
- - libarrow=12.0.1
156
- - libblas=3.9.0
157
- - libbrotlicommon=1.0.9
158
- - libbrotlidec=1.0.9
159
- - libbrotlienc=1.0.9
160
- - libcblas=3.9.0
161
- - libcrc32c=1.1.2
162
- - libcublas=11.11.3.6
163
- - libcufft=10.9.0.58
164
- - libcufile=1.7.1.12
165
- - libcurand=10.3.3.129
166
- - libcurl=8.1.2
167
- - libcusolver=11.4.1.48
168
- - libcusparse=11.7.5.86
169
- - libcxx=15.0.7
170
- - libcxxabi=15.0.7
171
- - libdeflate=1.12
172
- - libedit=3.1.20191231
173
- - libev=4.33
174
- - libevent=2.1.12
175
- - libexpat=2.5.0
176
- - libffi=3.4.4
177
- - libgcc-devel_linux-64=11.2.0
178
- - libgcc-ng=13.1.0
179
- - libgd=2.3.3
180
- - libgfortran-ng=11.2.0
181
- - libgfortran5=11.2.0
182
- - libglib=2.76.4
183
- - libglu=9.0.0
184
- - libgomp=13.1.0
185
- - libgoogle-cloud=2.12.0
186
- - libgrpc=1.56.2
187
- - libiconv=1.17
188
- - libidn2=2.3.4
189
- - libllvm13=13.0.1
190
- - libllvm14=14.0.6
191
- - libnghttp2=1.52.0
192
- - libnpp=11.8.0.86
193
- - libnsl=2.0.0
194
- - libnuma=2.0.18
195
- - libnvjpeg=11.9.0.86
196
- - libpng=1.6.39
197
- - libprotobuf=4.23.3
198
- - librsvg=2.54.4
199
- - libsodium=1.0.18
200
- - libsqlite=3.42.0
201
- - libssh2=1.11.0
202
- - libstdcxx-devel_linux-64=11.2.0
203
- - libstdcxx-ng=13.1.0
204
- - libtasn1=4.19.0
205
- - libthrift=0.18.1
206
- - libtiff=4.4.0
207
- - libtool=2.4.7
208
- - libunistring=0.9.10
209
- - libutf8proc=2.8.0
210
- - libuuid=2.38.1
211
- - libwebp=1.2.4
212
- - libwebp-base=1.2.4
213
- - libxcb=1.15
214
- - libxml2=2.10.4
215
- - libzlib=1.2.13
216
- - llvmlite=0.40.1
217
- - lz4-c=1.9.4
218
- - markupsafe=2.1.1
219
- - matplotlib-inline=0.1.6
220
- - metakernel=0.29.5
221
- - mistune=3.0.0
222
- - mkl=2023.1.0
223
- - mkl-service=2.4.0
224
- - mkl_fft=1.3.6
225
- - mkl_random=1.2.2
226
- - mpc=1.1.0
227
- - mpfr=4.0.2
228
- - mpmath=1.3.0
229
- - nbclient=0.8.0
230
- - nbconvert-core=7.7.3
231
- - nbformat=5.9.2
232
- - ncurses=6.4
233
- - nest-asyncio=1.5.6
234
- - nettle=3.7.3
235
- - networkx=3.1
236
- - nlohmann_json=3.11.2
237
- - notebook=7.0.2
238
- - notebook-shim=0.2.3
239
- - numba=0.57.1
240
- - numpy=1.24.3
241
- - numpy-base=1.24.3
242
- - openh264=2.1.1
243
- - openssl=3.3.1
244
- - orc=1.9.0
245
- - overrides=7.4.0
246
- - packaging=23.0
247
- - pandas=2.0.3
248
- - pandocfilters=1.5.0
249
- - pango=1.50.14
250
- - parso=0.8.3
251
- - pcre=8.45
252
- - pcre2=10.40
253
- - pexpect=4.8.0
254
- - pickleshare=0.7.5
255
- - pillow=9.4.0
256
- - pip=23.2.1
257
- - pixman=0.40.0
258
- - pkgutil-resolve-name=1.3.10
259
- - platformdirs=2.6.0
260
- - pooch=1.4.0
261
- - portalocker=2.7.0
262
- - prometheus_client=0.17.1
263
- - prompt-toolkit=3.0.39
264
- - prompt_toolkit=3.0.39
265
- - psutil=5.9.0
266
- - pthread-stubs=0.4
267
- - ptyprocess=0.7.0
268
- - pure_eval=0.2.2
269
- - pyarrow=12.0.1
270
- - pycparser=2.21
271
- - pygments=2.16.1
272
- - pyopenssl=23.2.0
273
- - pysocks=1.7.1
274
- - pythia8=8.309
275
- - python=3.8.17
276
- - python-dateutil=2.8.2
277
- - python-fastjsonschema=2.18.0
278
- - python-json-logger=2.0.7
279
- - python-tzdata=2024.2
280
- - python_abi=3.8
281
- - pytorch=2.0.1
282
- - pytorch-cuda=11.8
283
- - pytorch-mutex=1.0
284
- - pytz=2023.3
285
- - pyyaml=6.0
286
- - pyzmq=25.1.1
287
- - rdma-core=28.9
288
- - re2=2023.03.02
289
- - readline=8.2
290
- - referencing=0.30.2
291
- - requests=2.31.0
292
- - rfc3339-validator=0.1.4
293
- - rfc3986-validator=0.1.1
294
- - root=6.28.0
295
- - root_base=6.28.0
296
- - rpds-py=0.9.2
297
- - s2n=1.3.46
298
- - scipy=1.10.1
299
- - scitokens-cpp=0.7.3
300
- - send2trash=1.8.2
301
- - setuptools=68.0.0
302
- - six=1.16.0
303
- - snappy=1.1.10
304
- - sniffio=1.3.0
305
- - soupsieve=2.3.2.post1
306
- - sqlite=3.41.2
307
- - stack_data=0.6.2
308
- - sympy=1.11.1
309
- - sysroot_linux-64=2.17
310
- - tbb=2021.8.0
311
- - terminado=0.17.1
312
- - tinycss2=1.2.1
313
- - tk=8.6.12
314
- - tomli=2.0.1
315
- - torchaudio=2.0.2
316
- - torchtriton=2.0.0
317
- - torchvision=0.15.2
318
- - tornado=6.3.2
319
- - tqdm=4.65.0
320
- - traitlets=5.9.0
321
- - typing_extensions=4.12.2
322
- - typing_utils=0.1.0
323
- - ucx=1.14.1
324
- - uri-template=1.3.0
325
- - urllib3=1.26.16
326
- - vdt=0.4.3
327
- - vector-classes=1.4.3
328
- - wcwidth=0.2.6
329
- - webcolors=1.13
330
- - webencodings=0.5.1
331
- - websocket-client=1.6.1
332
- - wheel=0.38.4
333
- - xorg-fixesproto=5.0
334
- - xorg-kbproto=1.0.7
335
- - xorg-libice=1.1.1
336
- - xorg-libsm=1.2.4
337
- - xorg-libx11=1.8.6
338
- - xorg-libxau=1.0.11
339
- - xorg-libxcursor=1.2.0
340
- - xorg-libxdmcp=1.1.3
341
- - xorg-libxext=1.3.4
342
- - xorg-libxfixes=5.0.3
343
- - xorg-libxft=2.3.8
344
- - xorg-libxpm=3.5.16
345
- - xorg-libxrender=0.9.11
346
- - xorg-libxt=1.3.0
347
- - xorg-renderproto=0.11.1
348
- - xorg-xextproto=7.3.0
349
- - xorg-xproto=7.0.31
350
- - xrootd=5.5.4
351
- - xxhash=0.8.1
352
- - xz=5.2.6
353
- - yaml=0.2.5
354
- - zeromq=4.3.4
355
- - zipp=3.16.2
356
- - zlib=1.2.13
357
- - zstd=1.5.2
358
- - pip:
359
- - awkward==2.6.4
360
- - awkward-cpp==33
361
- - contourpy==1.1.0
362
- - cramjam==2.8.3
363
- - cycler==0.11.0
364
- - fonttools==4.42.0
365
- - fsspec==2024.3.1
366
- - h5py==3.9.0
367
- - pip-install==1.3.5
368
- - joblib==1.3.2
369
- - kiwisolver==1.4.4
370
- - matplotlib==3.7.2
371
- - nvidia-cublas-cu12==12.1.3.1
372
- - nvidia-cuda-cupti-cu12==12.1.105
373
- - nvidia-cuda-nvrtc-cu12==12.1.105
374
- - nvidia-cuda-runtime-cu12==12.1.105
375
- - nvidia-cudnn-cu12==8.9.2.26
376
- - nvidia-cufft-cu12==11.0.2.54
377
- - nvidia-curand-cu12==10.3.2.106
378
- - nvidia-cusolver-cu12==11.4.5.107
379
- - nvidia-cusparse-cu12==12.1.0.106
380
- - nvidia-nccl-cu12==2.20.5
381
- - nvidia-nvjitlink-cu12==12.4.127
382
- - nvidia-nvtx-cu12==12.1.105
383
- - pyparsing==3.0.9
384
- - scikit-learn==1.3.0
385
- - threadpoolctl==3.2.0
386
- - torch==2.3.0
387
- - triton==2.3.0
388
- - typing-extensions==4.11.0
389
- - tzdata==2024.1
390
- - uproot==5.3.7
391
- prefix: /global/homes/j/joshuaho/.conda/envs/dgl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/setup/test_setup.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- import importlib.util
3
- import sys
4
-
5
- def test_imports(directories):
6
- """
7
- Test importing all Python files in the specified directories.
8
-
9
- Parameters:
10
- - directories: List of directory paths to test.
11
- """
12
- print("Testing Conda environment...")
13
-
14
- for directory in directories:
15
- print(f"\nChecking directory: {directory}")
16
-
17
- # Check if the directory exists
18
- if not os.path.isdir(directory):
19
- print(f"Directory not found: {directory}")
20
- continue
21
-
22
- # Iterate through all files in the directory
23
- for filename in os.listdir(directory):
24
- # Only consider Python files
25
- if filename.endswith(".py"):
26
- filepath = os.path.join(directory, filename)
27
- module_name = os.path.splitext(filename)[0] # Remove .py extension
28
-
29
- try:
30
- # Dynamically import the module
31
- spec = importlib.util.spec_from_file_location(module_name, filepath)
32
- module = importlib.util.module_from_spec(spec)
33
- spec.loader.exec_module(module)
34
- print(f"Successfully imported: {filepath}")
35
- except Exception as e:
36
- # Print the file and the error message if import fails
37
- print(f"Failed to import: {filepath}")
38
- print(f"Error: {e}")
39
-
40
- if __name__ == "__main__":
41
- # Automatically append the current directory to sys.path
42
- current_directory = os.getcwd()
43
- sys.path.append(current_directory)
44
- print(f"Current directory added to sys.path: {current_directory}")
45
-
46
- # List of directories to check
47
- directories = ["scripts", "root_gnn_base", "models"]
48
- test_imports(directories)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nemo/setup/test_setup.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- import importlib.util
3
- import sys
4
-
5
- def test_imports(directories):
6
- """
7
- Test importing all Python files in the specified directories.
8
-
9
- Parameters:
10
- - directories: List of directory paths to test.
11
- """
12
- print("Testing Conda environment...")
13
-
14
- for directory in directories:
15
- print(f"\nChecking directory: {directory}")
16
-
17
- # Check if the directory exists
18
- if not os.path.isdir(directory):
19
- print(f"Directory not found: {directory}")
20
- continue
21
-
22
- # Iterate through all files in the directory
23
- for filename in os.listdir(directory):
24
- # Only consider Python files
25
- if filename.endswith(".py"):
26
- filepath = os.path.join(directory, filename)
27
- module_name = os.path.splitext(filename)[0] # Remove .py extension
28
-
29
- try:
30
- # Dynamically import the module
31
- spec = importlib.util.spec_from_file_location(module_name, filepath)
32
- module = importlib.util.module_from_spec(spec)
33
- spec.loader.exec_module(module)
34
- print(f"Successfully imported: {filepath}")
35
- except Exception as e:
36
- # Print the file and the error message if import fails
37
- print(f"Failed to import: {filepath}")
38
- print(f"Error: {e}")
39
-
40
- if __name__ == "__main__":
41
- # Automatically append the current directory to sys.path
42
- current_directory = os.getcwd()
43
- sys.path.append(current_directory)
44
- print(f"Current directory added to sys.path: {current_directory}")
45
-
46
- # List of directories to check
47
- directories = ["scripts", "root_gnn_base", "models"]
48
- test_imports(directories)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/README.md CHANGED
@@ -1,53 +1,62 @@
1
- # root_gnn_dgl
2
-
3
- ## Data Directory (for Hackathon)
4
- `/global/cfs/projectdirs/trn007/lbl_atlas/data/`
5
 
6
- * `stats_all`: full statistics sample, ~10M events per process
7
- * `stats_100K`: reduced statistics sample, 100K events per process
8
- * `processed_graphs`: graphs that have already been processed
9
- * `scores`: a copy of the samples along with the GNN scores for each event
10
 
11
- ## Environment Setup
12
 
13
- The environment dependencies for this project are listed in `setup/environment.yml`. Follow the steps below to set up the environment:
 
14
 
15
- ### Step 1: Install Conda
16
- If you don’t already have Conda installed, install either Miniconda (lightweight) or Anaconda (full version):
17
 
18
- - **Miniconda**: Download and install from [https://docs.conda.io/en/latest/miniconda.html](https://docs.conda.io/en/latest/miniconda.html).
19
- - **Anaconda**: Download and install from [https://www.anaconda.com/products/distribution](https://www.anaconda.com/products/distribution).
20
 
21
- ### Step 2: Clone the Repository
22
- Clone this repository to your local machine:
23
  ```bash
24
- git init
25
- git lfs install
26
- git clone https://huggingface.co/HWresearch/GNN4Colliders
 
 
27
  ```
28
- If you want to clone without large files - just their pointers
 
 
 
 
 
 
 
29
  ```bash
30
- GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/HWresearch/GNN4Colliders
31
  ```
32
 
33
- ### Step 3: Create the Conda Environment
34
- Use the `environment.yml` file to create the Conda environment:
35
  ```bash
36
- conda env create -f setup/environment.yml -n <environment_name>
 
37
  ```
38
 
39
- ### Step 4: Activate the Environment
40
- Activate the newly created environment:
41
  ```bash
42
- conda activate <environment_name>
43
- ```
44
- Replace <environment_name> with the name of the environment specified in Step 4.
45
 
46
- ### Step 5: Test the Environment
 
 
 
 
 
 
 
 
 
 
 
47
  Run the `setup/test_setup.py` script to confirm that all packages needed for training are properly set up.
48
  ```bash
49
  python setup/test_setup.py
50
  ```
 
 
51
  ## Running the Demo
52
The demo training is an example of our ML workflow, consisting of training a pretrained model, then finetuning it for an analysis task, while also training a model for the analysis task from scratch. The config files for the demo are located in the directory `configs/stats_100K/`. The demo can be run on a login node on Perlmutter (if enough GPU memory is available).
53
 
 
 
 
 
 
1
 
2
+ # root_gnn_dgl
 
 
 
3
 
4
+ Pretrained DGL-based ROOT graph neural network.
5
 
6
+ ## Overview
7
+ - Stable release with pretrained model weights.
8
 
9
+ Pretrained model location: ``
 
10
 
11
+ ## Conda setup
 
12
 
 
 
13
  ```bash
14
+ cd setup
15
+ conda env create -f environment.yml
16
+ conda activate pytorch
17
+ cd ..
18
+ python setup/test_setup.py
19
  ```
20
+
21
+ ## Container Setup (Podman-HPC)
22
+
23
+ - NERSC Perlmutter environment with `podman-hpc` available.
24
+ - Access to `joshuaho/pytorch:1.0` on Docker Hub [https://hub.docker.com/r/joshuaho/pytorch](https://hub.docker.com/r/joshuaho/pytorch)
25
+
26
+ ### Pull the Prebuilt Image
27
+
28
  ```bash
29
+ podman-hpc pull docker.io/joshuaho/pytorch:1.0
30
  ```
31
 
32
+ Or, you can build your own container here:
33
+
34
  ```bash
35
+ cd setup
36
+ source build_image.sh
37
  ```
38
 
39
+ Run the image and mount the paths you need, replacing `<source>` with the source directory path and `<target>` with the path for when you are inside the container.
 
40
  ```bash
 
 
 
41
 
42
+ podman-hpc run \
43
+ -it \
44
+ --mount type=bind,source=<source>,target=<target> \
45
+ --rm \
46
+ --network host \
47
+ --gpu \
48
+ --userns keep-id \
49
+ --shm-size=32g \
50
+ joshuaho/pytorch:1.0
51
+ ```
52
+
53
+ ### Test the Environment
54
  Run the `setup/test_setup.py` script to confirm that all packages needed for training are properly set up.
55
  ```bash
56
  python setup/test_setup.py
57
  ```
58
+
59
+
60
  ## Running the Demo
61
The demo training is an example of our ML workflow, consisting of training a pretrained model, then finetuning it for an analysis task, while also training a model for the analysis task from scratch. The config files for the demo are located in the directory `configs/stats_100K/`. The demo can be run on a login node on Perlmutter (if enough GPU memory is available).
62
 
root_gnn_dgl/configs/attention/ttH_CP_even_vs_odd.yaml DELETED
@@ -1,58 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd
2
- Training_Directory: trainings/attention/ttH_CP_even_vs_odd
3
- Model:
4
- module: models.GCN
5
- class: Attention_Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- num_heads: 2
14
- Training:
15
- epochs: 500
16
- batch_size: 1024
17
- learning_rate: 0.0001
18
- gamma: 0.99
19
- Datasets:
20
- ttH_CP_even: &dataset_defn
21
- module: root_gnn_base.dataset
22
- class: LazyDataset
23
- shuffle_chunks: 3
24
- batch_size: 1024
25
- padding_mode: NODE
26
- args: &dataset_args
27
- name: ttH_CP_even
28
- label: 0
29
- # weight_var: weight
30
- chunks: 3
31
- buffer_size: 2
32
- file_names: ttH_NLO.root
33
- tree_name: output
34
- fold_var: Number
35
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
36
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/attention/ttH_CP_even_vs_odd/
37
- node_branch_names:
38
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
39
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
40
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
41
- - CALC_E
42
- - [jet_btag, 0, 0, 0, 0]
43
- - [0, ele_charge, mu_charge, 0, 0]
44
- - NODE_TYPE
45
- node_branch_types: [vector, vector, vector, vector, single]
46
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
47
- folding:
48
- n_folds: 4
49
- test: [0]
50
- # validation: 1
51
- train: [1, 2, 3]
52
- ttH_CP_odd:
53
- <<: *dataset_defn
54
- args:
55
- <<: *dataset_args
56
- name: ttH_CP_odd
57
- label: 1
58
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml CHANGED
@@ -41,8 +41,8 @@ Datasets:
41
  file_names: ttH_NLO.root
42
  tree_name: output
43
  fold_var: Number
44
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
45
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
46
  node_branch_names:
47
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
48
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
41
  file_names: ttH_NLO.root
42
  tree_name: output
43
  fold_var: Number
44
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_100K/
45
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
46
  node_branch_names:
47
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
48
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_100K/pretraining_multiclass.yaml CHANGED
@@ -38,8 +38,8 @@ Datasets:
38
  file_names: ttH_NLO_inc.root
39
  tree_name: output
40
  fold_var: Number
41
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
42
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/pretraining_multiclass/
43
  node_branch_names:
44
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
45
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
38
  file_names: ttH_NLO_inc.root
39
  tree_name: output
40
  fold_var: Number
41
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_100K/
42
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_100K/pretraining_multiclass/
43
  node_branch_names:
44
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
45
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd.yaml CHANGED
@@ -31,8 +31,8 @@ Datasets:
31
  file_names: ttH_NLO.root
32
  tree_name: output
33
  fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
36
  node_branch_names:
37
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
31
  file_names: ttH_NLO.root
32
  tree_name: output
33
  fold_var: Number
34
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_100K/
35
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_100K/ttH_CP_even_vs_odd/
36
  node_branch_names:
37
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_2048
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_2048
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_4096
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_4096
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 1024
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 4096
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_8192
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_8192
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_all/finetuning_ttH_CP_even_vs_odd.yaml CHANGED
@@ -41,8 +41,8 @@ Datasets:
41
  file_names: ttH_NLO.root
42
  tree_name: output
43
  fold_var: Number
44
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
45
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd/
46
  node_branch_names:
47
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
48
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
41
  file_names: ttH_NLO.root
42
  tree_name: output
43
  fold_var: Number
44
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_all/
45
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_all/ttH_CP_even_vs_odd/
46
  node_branch_names:
47
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
48
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_all/pretraining_multiclass.yaml CHANGED
@@ -38,8 +38,8 @@ Datasets:
38
  file_names: ttH_NLO_inc.root
39
  tree_name: output
40
  fold_var: Number
41
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
42
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/pretraining_multiclass/
43
  node_branch_names:
44
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
45
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
38
  file_names: ttH_NLO_inc.root
39
  tree_name: output
40
  fold_var: Number
41
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_all/
42
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_all/pretraining_multiclass/
43
  node_branch_names:
44
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
45
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd.yaml CHANGED
@@ -31,8 +31,8 @@ Datasets:
31
  file_names: ttH_NLO.root
32
  tree_name: output
33
  fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd/
36
  node_branch_names:
37
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
 
31
  file_names: ttH_NLO.root
32
  tree_name: output
33
  fold_var: Number
34
+ raw_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/stats_all/
35
+ save_dir: /global/cfs/projectdirs/atlas/joshua/gnn_data/processed_graphs/stats_all/ttH_CP_even_vs_odd/
36
  node_branch_names:
37
  - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
  - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_2048
2
- Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_2048
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 10
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 10
30
- buffer_size: 3
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_2048/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_4096
2
- Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_4096
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 4096
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 10
23
- batch_size: 4096
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 10
30
- buffer_size: 3
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_4096/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_8192
2
- Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_8192
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 8192
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 10
23
- batch_size: 8192
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 10
30
- buffer_size: 3
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_8192/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root_gnn_dgl/jobs/interactive.sh CHANGED
@@ -1 +1 @@
1
- salloc --nodes 1 --qos shared_interactive --time 04:00:00 --constraint gpu --account=trn007 --gres=gpu:1
 
1
+ salloc --nodes 1 --qos shared_interactive --time 04:00:00 --constraint gpu --account=atlas --gres=gpu:1
root_gnn_dgl/run_demo.sh CHANGED
@@ -31,7 +31,7 @@ python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd
31
 
32
  python scripts/training_script.py --config configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy
33
 
34
- # Inference
35
  files=(
36
  "ttH_NLO.root"
37
  "ttH_CPodd.root"
@@ -50,8 +50,8 @@ branch_name=(
50
  for ((j=0; j<${#files[@]}; j++))
51
  do
52
  python scripts/inference.py \
53
- --target "/global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/${files[j]}" \
54
- --destination "/global/cfs/projectdirs/trn007/lbl_atlas/data/scores/stats_100K/${files[j]}" \
55
  --config "${config[@]}" \
56
  --branch_name "${branch_name[@]}" \
57
  --chunks 1 \
 
31
 
32
  python scripts/training_script.py --config configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy
33
 
34
+ # Inference: Writing GNN Scores for from-scratch training and finetuned training to root files
35
  files=(
36
  "ttH_NLO.root"
37
  "ttH_CPodd.root"
 
50
  for ((j=0; j<${#files[@]}; j++))
51
  do
52
  python scripts/inference.py \
53
+ --target "/global/cfs/projectdirs/atlas/joshua/gnn_data/stats_100K/${files[j]}" \
54
+ --destination "/global/cfs/projectdirs/atlas/joshua/gnn_data/scores/stats_100K/${files[j]}" \
55
  --config "${config[@]}" \
56
  --branch_name "${branch_name[@]}" \
57
  --chunks 1 \
root_gnn_dgl/setup/Dockerfile CHANGED
@@ -1,6 +1,6 @@
1
  FROM nvcr.io/nvidia/dgl:25.05-py3
2
 
3
- WORKDIR /global/cfs/projectdirs/atlas/joshua/GNN4Colliders
4
 
5
  LABEL maintainer.name="Joshua Ho"
6
  LABEL maintainer.email="ho22joshua@berkeley.edu"
 
1
  FROM nvcr.io/nvidia/dgl:25.05-py3
2
 
3
+ WORKDIR /workspace
4
 
5
  LABEL maintainer.name="Joshua Ho"
6
  LABEL maintainer.email="ho22joshua@berkeley.edu"
root_gnn_dgl/setup/build_image.sh CHANGED
@@ -1,4 +1,2 @@
1
- tag=$1
2
- echo $tag
3
- podman-hpc build -t joshuaho/pytorch:$tag --platform linux/amd64 .
4
- podman-hpc migrate joshuaho/pytorch:$tag
 
1
+ podman-hpc build -t joshuaho/pytorch:1.0 --platform linux/amd64 .
2
+ podman-hpc migrate joshuaho/pytorch:1.0
 
 
root_gnn_dgl/setup/environment.yml CHANGED
@@ -1,4 +1,4 @@
1
- name: dgl
2
  channels:
3
  - pytorch
4
  - dglteam/label/cu118
@@ -387,5 +387,4 @@ dependencies:
387
  - triton==2.3.0
388
  - typing-extensions==4.11.0
389
  - tzdata==2024.1
390
- - uproot==5.3.7
391
- prefix: /global/homes/j/joshuaho/.conda/envs/dgl
 
1
+ name: pytorch
2
  channels:
3
  - pytorch
4
  - dglteam/label/cu118
 
387
  - triton==2.3.0
388
  - typing-extensions==4.11.0
389
  - tzdata==2024.1
390
+ - uproot==5.3.7
 
root_gnn_dgl/setup/launch_image.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ podman-hpc run \
2
+ -it \
3
+ --mount type=bind,source=/pscratch/sd/j/joshuaho/,target=/pscratch/sd/j/joshuaho/ \
4
+ --mount type=bind,source=/global/cfs/projectdirs/atlas/joshua/,target=/global/cfs/projectdirs/atlas/joshua/ \
5
+ --rm \
6
+ --network host \
7
+ --gpu \
8
+ --shm-size=32g \
9
+ joshuaho/pytorch:1.0