ho22joshua commited on
Commit
5ceead6
·
1 Parent(s): 4a5b33b

working physicsnemo

Browse files
physicsnemo/{config.yaml → configs/config.yaml} RENAMED
@@ -16,7 +16,7 @@
16
  random_seed: 2
17
 
18
  scheduler:
19
- lr: 1.E-4
20
  lr_decay: 1.E-3
21
 
22
  training:
@@ -24,21 +24,18 @@ training:
24
 
25
  checkpoints:
26
  ckpt_path: "checkpoints"
27
- ckpt_name: "model.pt"
28
 
29
  performance:
30
  amp: False
31
  jit: False
32
 
33
- testing:
34
- graph: "s0090_0001.21.0.grph"
35
-
36
  architecture:
37
- processor_size: 5
38
- hidden_dim_node_encoder: 64
39
- hidden_dim_edge_encoder: 64
40
- hidden_dim_processor: 64
41
- hidden_dim_node_decoder: 64
42
  out_dim: 1
43
 
44
  paths:
@@ -62,6 +59,6 @@ root_dataset:
62
  globals: []
63
  weights: ""
64
  tracking: []
65
- chunks: 10
66
  batch_size: 8192
67
  train_val_test_split: [0.75, 0.24, 0.01]
 
16
  random_seed: 2
17
 
18
  scheduler:
19
+ lr: 1.E-3
20
  lr_decay: 1.E-3
21
 
22
  training:
 
24
 
25
  checkpoints:
26
  ckpt_path: "checkpoints"
27
+ ckpt_name: "config"
28
 
29
  performance:
30
  amp: False
31
  jit: False
32
 
 
 
 
33
  architecture:
34
+ processor_size: 8
35
+ hidden_dim_node_encoder: 128
36
+ hidden_dim_edge_encoder: 128
37
+ hidden_dim_processor: 128
38
+ hidden_dim_node_decoder: 128
39
  out_dim: 1
40
 
41
  paths:
 
59
  globals: []
60
  weights: ""
61
  tracking: []
62
+ step_size: 8192
63
  batch_size: 8192
64
  train_val_test_split: [0.75, 0.24, 0.01]
physicsnemo/{config_stats_all.yaml → configs/config_stats_all.yaml} RENAMED
@@ -24,15 +24,12 @@ training:
24
 
25
  checkpoints:
26
  ckpt_path: "checkpoints"
27
- ckpt_name: "model.pt"
28
 
29
  performance:
30
  amp: False
31
  jit: False
32
 
33
- testing:
34
- graph: "s0090_0001.21.0.grph"
35
-
36
  architecture:
37
  processor_size: 5
38
  hidden_dim_node_encoder: 64
@@ -62,6 +59,7 @@ root_dataset:
62
  globals: []
63
  weights: ""
64
  tracking: []
65
- step_size: 1024
66
  batch_size: 8192
67
- train_val_test_split: [0.75, 0.24, 0.01]
 
 
24
 
25
  checkpoints:
26
  ckpt_path: "checkpoints"
27
+ ckpt_name: "config_stats_all"
28
 
29
  performance:
30
  amp: False
31
  jit: False
32
 
 
 
 
33
  architecture:
34
  processor_size: 5
35
  hidden_dim_node_encoder: 64
 
59
  globals: []
60
  weights: ""
61
  tracking: []
62
+ step_size: 81920
63
  batch_size: 8192
64
+ train_val_test_split: [0.75, 0.24, 0.01]
65
+ prebatch: True
physicsnemo/configs/tHjb_CP_0_vs_45.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore_header_test
2
+ # Copyright 2023 Stanford University
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ random_seed: 2
17
+
18
+ scheduler:
19
+ lr: 1.E-3
20
+ lr_decay: 1.E-3
21
+
22
+ training:
23
+ epochs: 100
24
+
25
+ checkpoints:
26
+ ckpt_path: "checkpoints"
27
+ ckpt_name: "config"
28
+
29
+ performance:
30
+ amp: False
31
+ jit: False
32
+
33
+ architecture:
34
+ processor_size: 8
35
+ hidden_dim_node_encoder: 128
36
+ hidden_dim_edge_encoder: 128
37
+ hidden_dim_processor: 128
38
+ hidden_dim_node_decoder: 128
39
+ global_emb_dim: 128
40
+ out_dim: 1
41
+
42
+ paths:
43
+ data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
44
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/ttHCP/graphs/tHjb_CP_0_vs_45/
45
+ training_dir: ./training_tHjb_CP_0_vs_45/
46
+
47
+ datasets:
48
+ - name: tHjb_cp_0_had
49
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_had_scaled.root
50
+ label: 0
51
+ - name: tHjb_cp_0_lep
52
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_lep_scaled.root
53
+ label: 0
54
+ - name: tHjb_cp_45_had
55
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_45_AF3_had_scaled.root
56
+ label: 1
57
+ - name: tHjb_cp_45_lep
58
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_45_AF3_lep_scaled.root
59
+ label: 1
60
+
61
+ root_dataset:
62
+ ttree: output
63
+ dtype: torch.bfloat16
64
+ features:
65
+ # pt, eta, phi, energy, btag, charge, node_type
66
+ jet: [m_jet_pt, m_jet_eta, m_jet_phi, CALC_E, m_jet_PCbtag, 0, 0]
67
+ electron: [m_el_pt, m_el_eta, m_el_phi, CALC_E, 0, m_el_charge, 1]
68
+ muon: [m_mu_pt, m_mu_eta, m_mu_phi, CALC_E, 0, m_mu_charge, 2]
69
+ photon: [ph_pt_myy, ph_eta, ph_phi, CALC_E, 0, 0, 3]
70
+ met: [m_met, 0, m_met_phi, CALC_E, 0, 0, 4]
71
+ globals: [NUM_NODES]
72
+ weights: m_weightXlumi
73
+ tracking: []
74
+ step_size: 16384
75
+ batch_size: 16384
76
+ train_val_test_split: [0.5, 0.25, 0.25]
77
+ prebatch:
78
+ enabled: True
79
+ chunk_size: 512
physicsnemo/configs/tHjb_CP_0_vs_90.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore_header_test
2
+ # Copyright 2023 Stanford University
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ random_seed: 2
17
+
18
+ scheduler:
19
+ lr: 1.E-3
20
+ lr_decay: 1.E-3
21
+
22
+ training:
23
+ epochs: 100
24
+
25
+ checkpoints:
26
+ ckpt_path: "checkpoints"
27
+ ckpt_name: "tHjb_CP_0_vs_90"
28
+
29
+ performance:
30
+ amp: False
31
+ jit: False
32
+
33
+ architecture:
34
+ processor_size: 8
35
+ hidden_dim_node_encoder: 128
36
+ hidden_dim_edge_encoder: 128
37
+ hidden_dim_processor: 128
38
+ hidden_dim_node_decoder: 128
39
+ global_emb_dim: 128
40
+ out_dim: 1
41
+
42
+ paths:
43
+ data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
44
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/ttHCP/graphs/tHjb_CP_0_vs_90/
45
+ training_dir: ./tHjb_CP_0_vs_90/
46
+
47
+ datasets:
48
+ - name: tHjb_cp_0_had
49
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_had_scaled.root
50
+ label: 0
51
+ - name: tHjb_cp_0_lep
52
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_lep_scaled.root
53
+ label: 0
54
+ - name: tHjb_cp_90_had
55
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_had_scaled.root
56
+ label: 1
57
+ - name: tHjb_cp_90_lep
58
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_lep_scaled.root
59
+ label: 1
60
+
61
+ root_dataset:
62
+ ttree: output
63
+ dtype: torch.bfloat16
64
+ features:
65
+ # pt, eta, phi, energy, btag, charge, node_type
66
+ jet: [m_jet_pt, m_jet_eta, m_jet_phi, CALC_E, m_jet_PCbtag, 0, 0]
67
+ electron: [m_el_pt, m_el_eta, m_el_phi, CALC_E, 0, m_el_charge, 1]
68
+ muon: [m_mu_pt, m_mu_eta, m_mu_phi, CALC_E, 0, m_mu_charge, 2]
69
+ photon: [ph_pt_myy, ph_eta, ph_phi, CALC_E, 0, 0, 3]
70
+ met: [m_met, 0, m_met_phi, CALC_E, 0, 0, 4]
71
+ globals: [NUM_NODES]
72
+ weights: 1
73
+ tracking: []
74
+ step_size: 16384
75
+ batch_size: 16384
76
+ train_val_test_split: [0.5, 0.25, 0.25]
77
+ prebatch:
78
+ enabled: True
79
+ chunk_size: 512
physicsnemo/configs/tHjb_CP_0_vs_90_globals.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore_header_test
2
+ # Copyright 2023 Stanford University
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ random_seed: 2
17
+
18
+ scheduler:
19
+ lr: 1.E-3
20
+ lr_decay: 1.E-3
21
+
22
+ training:
23
+ epochs: 100
24
+
25
+ checkpoints:
26
+ ckpt_path: "checkpoints"
27
+ ckpt_name: "tHjb_CP_0_vs_90_globals"
28
+
29
+ performance:
30
+ amp: False
31
+ jit: False
32
+
33
+ architecture:
34
+ base_gnn:
35
+ input_dim_nodes: 7
36
+ input_dim_edges: 3
37
+ output_dim: 128
38
+ processor_size: 8
39
+ hidden_dim_node_encoder: 128
40
+ hidden_dim_edge_encoder: 128
41
+ hidden_dim_processor: 128
42
+ hidden_dim_node_decoder: 128
43
+ global_emb_dim: 128
44
+ global_feat_dim: 5
45
+ out_dim: 1
46
+
47
+ paths:
48
+ data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
49
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/ttHCP/graphs/tHjb_CP_0_vs_90_globals/
50
+ training_dir: ./tHjb_CP_0_vs_90_globals/
51
+
52
+ datasets:
53
+ - name: tHjb_cp_0_had
54
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_had_scaled.root
55
+ label: 0
56
+ - name: tHjb_cp_0_lep
57
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_lep_scaled.root
58
+ label: 0
59
+ - name: tHjb_cp_90_had
60
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_had_scaled.root
61
+ label: 1
62
+ - name: tHjb_cp_90_lep
63
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_lep_scaled.root
64
+ label: 1
65
+
66
+ root_dataset:
67
+ ttree: output
68
+ dtype: torch.bfloat16
69
+ features:
70
+ # pt, eta, phi, energy, btag, charge, node_type
71
+ jet: [m_jet_pt, m_jet_eta, m_jet_phi, CALC_E, m_jet_PCbtag, 0, 0]
72
+ electron: [m_el_pt, m_el_eta, m_el_phi, CALC_E, 0, m_el_charge, 1]
73
+ muon: [m_mu_pt, m_mu_eta, m_mu_phi, CALC_E, 0, m_mu_charge, 2]
74
+ photon: [ph_pt_myy, ph_eta, ph_phi, CALC_E, 0, 0, 3]
75
+ met: [m_met, 0, m_met_phi, CALC_E, 0, 0, 4]
76
+ globals: [NUM_NODES, eta_H, pt_H, eta_recotop1, pT_recotop1]
77
+ weights: 1
78
+ tracking: []
79
+ step_size: 16384
80
+ batch_size: 16384
81
+ train_val_test_split: [0.5, 0.25, 0.25]
82
+ prebatch:
83
+ enabled: True
84
+ chunk_size: 512
physicsnemo/{dataset.py → dataset/Dataset.py} RENAMED
@@ -3,25 +3,25 @@ import uproot
3
  import dgl
4
  import torch
5
  import numpy as np
6
- import awkward as ak
7
  from omegaconf import DictConfig
8
  from typing import List
9
  from concurrent.futures import ProcessPoolExecutor, as_completed
10
  from tqdm import tqdm
11
 
12
- import dataset_utils as utils
13
-
14
- from torch.utils.data import Dataset
15
 
16
  from dgl.dataloading import GraphDataLoader
17
 
18
- class RootDataset:
19
  def __init__(
20
  self,
21
  name: str,
22
  label: int,
23
  load_path: str,
24
  save_path: str,
 
25
  device: str,
26
  cfg: DictConfig
27
  ):
@@ -29,31 +29,38 @@ class RootDataset:
29
  self.label = label
30
  self.load_path = load_path
31
  self.save_path = save_path
 
32
  self.data = None
33
  self.device = device
34
 
35
  self.ttree = cfg.ttree
36
- self.particles = cfg.particles
37
  self.features = cfg.features
 
38
  self.globals = cfg.globals
 
39
  self.step_size = cfg.step_size
40
  self.batch_size = cfg.batch_size
41
 
 
 
42
  self.train_val_test_split = cfg.train_val_test_split
43
  assert np.sum(self.train_val_test_split) == 1, "train_val_test_split must sum to 1"
44
 
45
- dtype_str = getattr(cfg, "type", "torch.float32")
46
- if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
47
- self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
48
- else:
49
- self.dtype = torch.float32
50
  print(f"initializing dataset {name} with dtype {self.dtype}")
51
 
52
  def get_branches(self) -> List[str]:
53
- branches = [f"{p}_{f}" for p in self.particles for f in self.features]
54
- branches += self.globals
55
- branches = ["MET_met" if b == "MET_pt" else b for b in branches]
56
- return branches
 
 
 
 
 
 
 
 
57
 
58
  def process(self):
59
  branches = self.get_branches()
@@ -68,43 +75,46 @@ class RootDataset:
68
 
69
  with ProcessPoolExecutor(max_workers=num_cpus) as executor:
70
  futures = []
71
- for chunk_id, arrays in enumerate(
72
- tqdm(
73
- uproot.iterate(
74
  f"{self.load_path}:{self.ttree}",
75
  expressions=[b for b in branches if b in available_branches],
76
  step_size=self.step_size,
77
  library="ak"
78
  ),
79
- desc="loading root file",
80
- total=total_chunks,
81
- position=0,
82
- leave=False
83
- )
84
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- cfg = utils.ChunkConfig(
87
- name=self.name,
88
- label=self.label,
89
- chunk_id=chunk_id,
90
- batch_size=self.batch_size,
91
- arrays=arrays,
92
- particles=self.particles,
93
- features=self.features,
94
- branches=branches,
95
- dtype=self.dtype,
96
- save_path=self.save_path,
97
- )
98
-
99
- futures.append(executor.submit(utils.process_chunk, cfg))
100
-
101
-
102
- for future in as_completed(futures):
103
  try:
104
  future.result()
105
  except Exception as e:
106
  import traceback
107
- print("Exception in worker process:")
108
  traceback.print_exception(type(e), e, e.__traceback__)
109
  return
110
 
@@ -119,12 +129,18 @@ class RootDataset:
119
  self.process()
120
 
121
  graph_tuple_list = []
122
-
123
  for chunk_id, f in enumerate(chunk_files):
124
  if chunk_id < total_chunks - 1:
125
- n_graphs = self.step_size
 
 
 
126
  else:
127
- n_graphs = num_entries - self.step_size * (total_chunks - 1)
 
 
 
128
  graph_tuple_list.extend((f, idx) for idx in range(n_graphs))
129
 
130
  split = self.train_val_test_split
@@ -136,54 +152,91 @@ class RootDataset:
136
  val_tuples = graph_tuple_list[n_train:n_train + n_val]
137
  test_tuples = graph_tuple_list[n_train + n_val:]
138
  return train_tuples, val_tuples, test_tuples
 
 
 
 
 
 
139
 
140
- class GraphDataset(Dataset):
141
- def __init__(self, graphs, labels):
142
- self.graphs = graphs
143
- self.labels = labels
144
  def __len__(self):
145
- return len(self.graphs)
146
- def __getitem__(self, idx):
147
- return self.graphs[idx], self.labels[idx]
148
- def shuffle(self):
149
- # TODO: implement graph shuffling
150
- return self.graphs
151
 
152
- class GraphTupleDataset:
153
- def __init__(self, tuple_list):
154
- self.tuples = tuple_list
155
- def __len__(self):
156
- return len(self.tuples)
157
  def __getitem__(self, idx):
158
- filepath, graph_idx = self.tuples[idx]
159
- graphs, labels = utils.load_graphs(filepath)
160
- return graphs[graph_idx], labels[graph_idx]
161
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def get_dataset(cfg: DictConfig, device):
163
 
164
  all_train = []
165
  all_val = []
166
  all_test = []
167
 
 
 
 
 
 
 
168
  for ds in cfg.datasets:
169
  name = ds['name']
170
  load_path = ds.get('load_path', f"{cfg.paths.data_dir}/{name}.root")
171
  save_path = ds.get('save_path', f"{cfg.paths.save_dir}/")
172
- datastet = RootDataset(name, ds.get('label'), load_path, save_path, device, cfg.root_dataset)
173
  train, val, test = datastet.load()
174
  all_train.extend(train)
175
  all_val.extend(val)
176
  all_test.extend(test)
177
 
178
- train_dataset = GraphTupleDataset(all_train)
179
- val_dataset = GraphTupleDataset(all_val)
180
- test_dataset = GraphTupleDataset(all_test)
 
 
181
 
182
- batch_size = cfg.root_dataset.batch_size
 
 
 
 
 
183
 
184
- train_loader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)
185
- val_loader = GraphDataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)
186
- test_loader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)
187
 
188
  print("all data loaded successfully")
189
  print(f"train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")
 
3
  import dgl
4
  import torch
5
  import numpy as np
 
6
  from omegaconf import DictConfig
7
  from typing import List
8
  from concurrent.futures import ProcessPoolExecutor, as_completed
9
  from tqdm import tqdm
10
 
11
+ from dataset import GraphBuilder
12
+ from dataset import Graphs
13
+ from dataset import Normalization
14
 
15
  from dgl.dataloading import GraphDataLoader
16
 
17
+ class Dataset:
18
  def __init__(
19
  self,
20
  name: str,
21
  label: int,
22
  load_path: str,
23
  save_path: str,
24
+ dtype: torch.dtype,
25
  device: str,
26
  cfg: DictConfig
27
  ):
 
29
  self.label = label
30
  self.load_path = load_path
31
  self.save_path = save_path
32
+ self.dtype = dtype
33
  self.data = None
34
  self.device = device
35
 
36
  self.ttree = cfg.ttree
 
37
  self.features = cfg.features
38
+ self.weights = cfg.weights
39
  self.globals = cfg.globals
40
+ self.tracking = cfg.tracking
41
  self.step_size = cfg.step_size
42
  self.batch_size = cfg.batch_size
43
 
44
+ self.prebatch = cfg.get('prebatch', {'enabled': False})
45
+
46
  self.train_val_test_split = cfg.train_val_test_split
47
  assert np.sum(self.train_val_test_split) == 1, "train_val_test_split must sum to 1"
48
 
 
 
 
 
 
49
  print(f"initializing dataset {name} with dtype {self.dtype}")
50
 
51
  def get_branches(self) -> List[str]:
52
+ node_branches = [
53
+ branches
54
+ for particle in self.features.values()
55
+ for branches in particle
56
+ if isinstance(branches, str) and (branches != "CALC_E" or branches != "NUM_NODES")
57
+ ]
58
+ global_branches = [x for x in self.globals if isinstance(x, str)]
59
+ weight_branch = [self.weights] if isinstance(self.weights, str) else []
60
+ tracking_branches = [x for x in self.tracking if isinstance(x, str)]
61
+ label_branch = [self.label] if isinstance(self.label, str) else []
62
+
63
+ return node_branches + global_branches + weight_branch + tracking_branches + label_branch
64
 
65
  def process(self):
66
  branches = self.get_branches()
 
75
 
76
  with ProcessPoolExecutor(max_workers=num_cpus) as executor:
77
  futures = []
78
+
79
+ with tqdm(
80
+ uproot.iterate(
81
  f"{self.load_path}:{self.ttree}",
82
  expressions=[b for b in branches if b in available_branches],
83
  step_size=self.step_size,
84
  library="ak"
85
  ),
86
+ desc="loading root file",
87
+ total=total_chunks,
88
+ position=0,
89
+ leave=True
90
+ ) as pbar:
91
+
92
+ for chunk_id, arrays in enumerate(pbar):
93
+
94
+ cfg = GraphBuilder.ChunkConfig(
95
+ name=self.name,
96
+ label=self.label,
97
+ chunk_id=chunk_id,
98
+ batch_size=self.batch_size,
99
+ arrays=arrays,
100
+ features=self.features,
101
+ globals=self.globals,
102
+ tracking=self.tracking,
103
+ weights=self.weights,
104
+ branches=branches,
105
+ dtype=self.dtype,
106
+ save_path=self.save_path,
107
+ prebatch = self.prebatch,
108
+ )
109
 
110
+ futures.append(executor.submit(GraphBuilder.process_chunk, cfg))
111
+
112
+ for idx, future in enumerate(as_completed(futures)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  try:
114
  future.result()
115
  except Exception as e:
116
  import traceback
117
+ print(f"exception in chunk: {idx}")
118
  traceback.print_exception(type(e), e, e.__traceback__)
119
  return
120
 
 
129
  self.process()
130
 
131
  graph_tuple_list = []
132
+
133
  for chunk_id, f in enumerate(chunk_files):
134
  if chunk_id < total_chunks - 1:
135
+ if (self.prebatch.enabled):
136
+ n_graphs = self.step_size // self.prebatch.chunk_size
137
+ else:
138
+ n_graphs = self.step_size
139
  else:
140
+ if (self.prebatch.enabled):
141
+ n_graphs = (num_entries - self.step_size * (total_chunks - 1)) // self.prebatch.chunk_size + 1
142
+ else:
143
+ n_graphs = num_entries - self.step_size * (total_chunks - 1)
144
  graph_tuple_list.extend((f, idx) for idx in range(n_graphs))
145
 
146
  split = self.train_val_test_split
 
152
  val_tuples = graph_tuple_list[n_train:n_train + n_val]
153
  test_tuples = graph_tuple_list[n_train + n_val:]
154
  return train_tuples, val_tuples, test_tuples
155
+
156
+ class GraphTupleDataset:
157
+ def __init__(self, tuple_list, stats):
158
+ self.tuple_list = tuple_list
159
+ self.stats = stats
160
+ self.cache = {}
161
 
 
 
 
 
162
  def __len__(self):
163
+ return len(self.tuple_list)
 
 
 
 
 
164
 
 
 
 
 
 
165
  def __getitem__(self, idx):
166
+ f, graph_idx = self.tuple_list[idx]
167
+ if f in self.cache:
168
+ g = self.cache[f]
169
+ else:
170
+ g = Graphs.load_graphs(f)
171
+ g.normalize(self.stats)
172
+ self.cache[f] = g
173
+ return g[graph_idx]
174
+
175
+ @staticmethod
176
+ def collate_fn(samples):
177
+ all_graphs = []
178
+ all_metadata = {}
179
+
180
+ # Initialize keys in all_metadata from the first sample
181
+ for k in samples[0][1]:
182
+ all_metadata[k] = []
183
+
184
+ for graph, metadata in samples:
185
+ all_graphs.append(graph)
186
+ for k, v in metadata.items():
187
+ all_metadata[k].append(v)
188
+
189
+ # Stack or concatenate metadata for each key
190
+ for k in all_metadata:
191
+ # If v is a tensor, stack or cat as appropriate
192
+ # Use torch.cat if v is already [N, ...] (e.g. labels, features)
193
+ # Use torch.stack if v is scalar or needs new dimension
194
+ try:
195
+ all_metadata[k] = torch.cat(all_metadata[k], dim=0)
196
+ except Exception:
197
+ all_metadata[k] = torch.stack(all_metadata[k], dim=0)
198
+
199
+ batched_graph = dgl.batch(all_graphs)
200
+ return batched_graph, all_metadata
201
+
202
  def get_dataset(cfg: DictConfig, device):
203
 
204
  all_train = []
205
  all_val = []
206
  all_test = []
207
 
208
+ dtype_str = getattr(cfg.root_dataset, "dtype", "torch.float32")
209
+ if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
210
+ dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
211
+ else:
212
+ dtype = torch.float32
213
+
214
  for ds in cfg.datasets:
215
  name = ds['name']
216
  load_path = ds.get('load_path', f"{cfg.paths.data_dir}/{name}.root")
217
  save_path = ds.get('save_path', f"{cfg.paths.save_dir}/")
218
+ datastet = Dataset(name, ds.get('label'), load_path, save_path, dtype, device, cfg.root_dataset)
219
  train, val, test = datastet.load()
220
  all_train.extend(train)
221
  all_val.extend(val)
222
  all_test.extend(test)
223
 
224
+ stats = Normalization.global_stats(f"{cfg.paths.save_dir}/stats/", dtype=dtype)
225
+
226
+ train_dataset = GraphTupleDataset(all_train, stats)
227
+ val_dataset = GraphTupleDataset(all_val, stats)
228
+ test_dataset = GraphTupleDataset(all_test, stats)
229
 
230
+ if (cfg.root_dataset.get('prebatch', False)):
231
+ batch_size = cfg.root_dataset.batch_size // cfg.root_dataset.prebatch.chunk_size
232
+ collate_fn = GraphTupleDataset.collate_fn
233
+ else:
234
+ batch_size = cfg.root_dataset.batch_size
235
+ collate_fn = None
236
 
237
+ train_loader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=5, drop_last=False, collate_fn=collate_fn)
238
+ val_loader = GraphDataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=5, drop_last=False, collate_fn=collate_fn)
239
+ test_loader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0, drop_last=False, collate_fn=collate_fn)
240
 
241
  print("all data loaded successfully")
242
  print(f"train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")
physicsnemo/dataset/GraphBuilder.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dgl
2
+ import torch
3
+ import numpy as np
4
+ import awkward as ak
5
+ from dataclasses import dataclass
6
+ from typing import List, Any, Union
7
+
8
+ from dataset.Graphs import Graphs, save_graphs
9
+ from dataset import Normalization
10
+
11
+ @dataclass
12
+ class ChunkConfig:
13
+ name: str
14
+ label: Union[str, int]
15
+ chunk_id: int
16
+ batch_size: int
17
+ arrays: List[Any]
18
+ features: List[Any]
19
+ globals: List[Any]
20
+ weights: Union[str, float]
21
+ tracking: List[Any]
22
+ branches: List[Any]
23
+ dtype: torch.dtype
24
+ save_path: str
25
+ prebatch: dict
26
+
27
+ def process_chunk(cfg: ChunkConfig):
28
+ # Collect everything as lists first
29
+ graph_list = []
30
+ meta_dict = {
31
+ 'globals': [],
32
+ 'label': [],
33
+ 'weight': [],
34
+ 'tracking': [],
35
+ 'batch_num_nodes': [],
36
+ 'batch_num_edges': [],
37
+ }
38
+
39
+ for i in range(len(cfg.arrays)):
40
+ g, meta = process_single_entry(cfg, i)
41
+ graph_list.append(g)
42
+ for k in meta_dict:
43
+ meta_dict[k].append(meta[k])
44
+
45
+ # Stack all metadata fields into tensors
46
+ for k in meta_dict:
47
+ meta_dict[k] = torch.stack(meta_dict[k])
48
+
49
+ graphs = Graphs(graphs=graph_list, metadata=meta_dict)
50
+ Normalization.save_stats(graphs, f"{cfg.save_path}/stats/{cfg.name}_{cfg.chunk_id:04d}.json")
51
+
52
+ if getattr(cfg.prebatch, "enabled", False):
53
+ graphs.shuffle()
54
+ graphs.batch(cfg.prebatch["chunk_size"])
55
+
56
+ save_graphs(graphs, f"{cfg.save_path}/{cfg.name}_{cfg.chunk_id:04d}.bin")
57
+
58
+ def process_single_entry(cfg, i):
59
+ # 1) node features
60
+ node_features: List[torch.Tensor] = []
61
+
62
+ for particle, branch_list in cfg.features.items():
63
+ feature_tensors: List[torch.Tensor] = []
64
+ for branch in branch_list:
65
+ if branch == "CALC_E":
66
+ pT = feature_tensors[0]
67
+ eta = feature_tensors[1]
68
+ val = pT * torch.cosh(eta)
69
+ elif isinstance(branch, str):
70
+ arr = cfg.arrays[branch][i]
71
+ val = torch.from_numpy(ak.to_numpy(arr)).to(cfg.dtype)
72
+ else:
73
+ length = feature_tensors[0].shape[0]
74
+ val = torch.full((length,), float(branch), dtype=cfg.dtype)
75
+ feature_tensors.append(val)
76
+
77
+ if feature_tensors and feature_tensors[0].numel() > 0:
78
+ block = torch.stack(feature_tensors, dim=1)
79
+ node_features.append(block)
80
+
81
+ node_features = torch.cat(node_features, dim=0) if node_features else torch.empty((0, len(cfg.features)), dtype=cfg.dtype)
82
+
83
+ # 2) global features
84
+ global_feat_list: List[torch.Tensor] = []
85
+ for b in cfg.globals:
86
+ if b == "NUM_NODES":
87
+ global_feat_list.append(torch.tensor([len(node_features)], dtype=cfg.dtype))
88
+ else:
89
+ arr = cfg.arrays[b][i]
90
+ global_feat_list.append(torch.from_numpy(ak.to_numpy(arr)).to(cfg.dtype))
91
+ global_feat = torch.cat(global_feat_list, dim=0) if global_feat_list else torch.zeros((1,), dtype=cfg.dtype)
92
+
93
+ # 3) tracking
94
+ tracking_list: List[torch.Tensor] = []
95
+ for b in cfg.tracking:
96
+ arr = cfg.arrays[b][i]
97
+ tracking_list.append(torch.from_numpy(ak.to_numpy(arr)).to(cfg.dtype))
98
+ tracking = torch.cat(tracking_list, dim=0) if tracking_list else torch.zeros((1,), dtype=cfg.dtype)
99
+
100
+ # 4) weight
101
+ weight = float(cfg.arrays[cfg.weights][i]) if isinstance(cfg.weights, str) else cfg.weights
102
+ weight = torch.tensor(weight, dtype=cfg.dtype)
103
+
104
+ # 5) label
105
+ label = float(cfg.arrays[cfg.label][i]) if isinstance(cfg.label, str) else cfg.label
106
+ label = torch.tensor(label, dtype=cfg.dtype)
107
+
108
+ # 6) make the DGLGraph
109
+ g = make_graph(node_features, dtype=cfg.dtype)
110
+
111
+ # 7) batch_num_nodes and batch_num_edges
112
+ batch_num_nodes = g.batch_num_nodes()
113
+ batch_num_edges = g.batch_num_edges()
114
+
115
+ meta = {
116
+ 'globals': global_feat,
117
+ 'label': label,
118
+ 'weight': weight,
119
+ 'tracking': tracking,
120
+ 'batch_num_nodes': batch_num_nodes,
121
+ 'batch_num_edges': batch_num_edges,
122
+ }
123
+ return g, meta
124
+
125
+ src_dst_cache = {}
126
+ def get_src_dst(num_nodes):
127
+ if num_nodes not in src_dst_cache:
128
+ src, dst = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing='ij')
129
+ src_dst_cache[num_nodes] = (src.flatten(), dst.flatten())
130
+ return src_dst_cache[num_nodes]
131
+
132
+ @torch.jit.script
133
+ def compute_edge_features(eta, phi, src, dst):
134
+ deta = eta[src] - eta[dst]
135
+ dphi = phi[src] - phi[dst]
136
+ dphi = torch.remainder(dphi + np.pi, 2 * np.pi) - np.pi
137
+ dR = torch.sqrt(deta ** 2 + dphi ** 2)
138
+ edge_features = torch.stack([dR, deta, dphi], dim=1)
139
+ return edge_features
140
+
141
+ def make_graph(node_features: torch.tensor, dtype=torch.float32):
142
+
143
+ num_nodes = node_features.shape[0]
144
+ if num_nodes == 0:
145
+ g = dgl.graph(([], []))
146
+ g.ndata['features'] = node_features
147
+ g.edata['features'] = torch.empty((0, 3), dtype=dtype)
148
+ g.globals = torch.tensor([0], dtype=dtype)
149
+ return g
150
+
151
+ src, dst = get_src_dst(num_nodes)
152
+ src = src.flatten()
153
+ dst = dst.flatten()
154
+ g = dgl.graph((src, dst))
155
+ g.ndata['features'] = node_features
156
+
157
+ eta = node_features[:, 1]
158
+ phi = node_features[:, 2]
159
+ edge_features = compute_edge_features(eta, phi, src, dst)
160
+ g.edata['features'] = edge_features
161
+
162
+ return g
physicsnemo/dataset/Graphs.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dgl
2
+ import torch
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Dict
5
+
6
+ @dataclass
7
+ class Graphs:
8
+ graphs: List[dgl.DGLGraph]
9
+ metadata: Dict[str, torch.Tensor]
10
+
11
+ def __len__(self):
12
+ return len(self.graphs)
13
+
14
+ def __getitem__(self, idx):
15
+ meta = {k: v[idx] for k, v in self.metadata.items()}
16
+ return self.graphs[idx], meta
17
+
18
+ def shuffle(self):
19
+ idx = torch.randperm(len(self.graphs))
20
+ self.graphs = [self.graphs[i] for i in idx]
21
+ for k in self.metadata:
22
+ self.metadata[k] = self.metadata[k][idx]
23
+
24
+ def batch(self, batch_size, node_feature_dim=None, dtype=None):
25
+ """
26
+ In-place batching: after this, self.graphs is a list of batched DGLGraphs,
27
+ and self.metadata[k] is a tensor of shape [num_batches, batch_size, ...].
28
+ """
29
+ batched_graphs = []
30
+ batched_meta = {k: [] for k in self.metadata}
31
+ N = len(self.graphs)
32
+
33
+ # Infer node_feature_dim and dtype if not specified
34
+ if node_feature_dim is None and N > 0:
35
+ feats = self.graphs[0].ndata['features']
36
+ node_feature_dim = feats.shape[1] if feats.ndim > 1 else 1
37
+ if dtype is None and N > 0:
38
+ dtype = self.graphs[0].ndata['features'].dtype
39
+
40
+ for start in range(0, N, batch_size):
41
+ end = start + batch_size
42
+ batch_graphs = self.graphs[start:end]
43
+ batch_meta = {k: v[start:end] for k, v in self.metadata.items()}
44
+
45
+ # Padding if needed
46
+ pad_count = batch_size - len(batch_graphs)
47
+ if pad_count > 0:
48
+ dummy_graph = dgl.graph(([], []))
49
+ dummy_graph.ndata['features'] = torch.empty((0, node_feature_dim), dtype=dtype)
50
+ dummy_graph.edata['features'] = torch.empty((0, 3), dtype=dtype) # assuming 3 edge features
51
+ batch_graphs += [dummy_graph] * pad_count
52
+
53
+ # Pad metadata with zeros
54
+ for k, v in batch_meta.items():
55
+ shape = list(v[0].shape) if len(v) > 0 else []
56
+ pad_tensor = torch.zeros([pad_count] + shape, dtype=v.dtype, device=v.device)
57
+ batch_meta[k] = torch.cat([v, pad_tensor], dim=0)
58
+ else:
59
+ for k, v in batch_meta.items():
60
+ batch_meta[k] = torch.stack(v, dim=0) if isinstance(v, list) else v
61
+
62
+ batched_graphs.append(dgl.batch(batch_graphs))
63
+ for k in batched_meta:
64
+ batched_meta[k].append(batch_meta[k])
65
+
66
+ # Now stack along a new axis: [num_batches, batch_size, ...]
67
+ for k in batched_meta:
68
+ self.metadata[k] = torch.stack(batched_meta[k], dim=0)
69
+
70
+ self.graphs = batched_graphs
71
+
72
+ def normalize(self, stats):
73
+ node_mean, node_std, _ = stats['node']
74
+ edge_mean, edge_std, _ = stats['edge']
75
+ for g in self.graphs:
76
+ g.ndata['features'] = (g.ndata['features'] - node_mean) / node_std
77
+ g.edata['features'] = (g.edata['features'] - edge_mean) / edge_std
78
+
79
+ def save_graphs(graphs: Graphs, f: str):
80
+ meta_to_save = {k: v for k, v in graphs.metadata.items()}
81
+ dgl.save_graphs(f, graphs.graphs, meta_to_save)
82
+
83
+ def load_graphs(f: str) -> Graphs:
84
+ g, meta = dgl.load_graphs(f)
85
+ for k in meta:
86
+ if not isinstance(meta[k], torch.Tensor):
87
+ meta[k] = torch.stack(meta[k])
88
+ return Graphs(graphs=g, metadata=meta)
physicsnemo/dataset/Normalization.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import os
4
+ from dataset.Graphs import Graphs
5
+ from typing import List, Dict, Tuple
6
+
7
def combine_feature_stats(chunks: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Merge per-chunk feature statistics into global statistics.

    Each chunk is a dict with per-feature 'mean' and 'std' lists and a
    'count'. The parallel (Chan/Welford) update is exact, so the result
    equals the population statistics over the concatenation of all chunks.
    Chunks with count == 0 are skipped.

    Returns:
        (mean, std, count); empty tensors and 0 when no data was seen.
    """
    total_count = 0
    running_mean = None
    running_m2 = None  # sum of squared deviations from the running mean

    for chunk in chunks:
        count = chunk['count']
        if count == 0:
            continue

        chunk_mean = torch.tensor(chunk['mean'])
        # Recover the chunk's M2 from its (population) std.
        chunk_m2 = torch.tensor(chunk['std']) ** 2 * count

        if total_count == 0:
            running_mean, running_m2, total_count = chunk_mean, chunk_m2, count
        else:
            combined = total_count + count
            delta = chunk_mean - running_mean
            running_mean = running_mean + delta * (count / combined)
            running_m2 = running_m2 + chunk_m2 + delta ** 2 * (total_count * count / combined)
            total_count = combined

    if total_count == 0:
        return torch.tensor([]), torch.tensor([]), 0

    return running_mean, torch.sqrt(running_m2 / total_count), total_count
41
+
42
def global_stats(dirpath: str, dtype: torch.dtype) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]:
    """
    Return combined node/edge normalization stats for a directory of chunks.

    On first use the per-chunk JSON stats files in `dirpath` are merged via
    `combine_feature_stats` and cached to `global_stats.json`; subsequent
    calls just read the cached file.

    Returns:
        {'node': (mean, std, count), 'edge': (mean, std, count)} with the
        mean/std tensors cast to `dtype`.
    """
    cache_file = os.path.join(dirpath, "global_stats.json")

    if not os.path.exists(cache_file):
        # First run: collect every per-chunk stats file in the directory.
        per_chunk = []
        for fname in os.listdir(dirpath):
            if fname.endswith('.json'):
                with open(os.path.join(dirpath, fname), 'r') as f:
                    per_chunk.append(json.load(f))

        merged = {
            'node': combine_feature_stats([s['node'] for s in per_chunk]),
            'edge': combine_feature_stats([s['edge'] for s in per_chunk]),
        }

        serializable = {}
        for key, (mean, std, count) in merged.items():
            serializable[key] = {
                'mean': mean.tolist() if mean.numel() > 0 else [],
                'std': std.tolist() if std.numel() > 0 else [],
                'count': count,
            }

        with open(cache_file, 'w') as f:
            json.dump(serializable, f, indent=4)

    with open(cache_file, 'r') as f:
        cached = json.load(f)

    def as_tensors(entry):
        # Empty lists become empty tensors of the requested dtype.
        mean = torch.tensor(entry['mean'], dtype=dtype) if entry['mean'] else torch.tensor([], dtype=dtype)
        std = torch.tensor(entry['std'], dtype=dtype) if entry['std'] else torch.tensor([], dtype=dtype)
        return mean, std, entry['count']

    return {
        'node': as_tensors(cached['node']),
        'edge': as_tensors(cached['edge']),
    }
89
+
90
def compute_stats(feats, eps=1e-6):
    """
    Per-feature mean and std (population) over the first dimension.

    A single sample has zero variance by definition; stds below `eps` are
    clamped to `eps` so later division never blows up.

    Args:
        feats: tensor of shape [num_samples, num_features].
        eps: element-wise lower bound applied to the std.

    Returns:
        (mean, std) tensors of shape [num_features].
    """
    mean = feats.mean(dim=0)
    if feats.size(0) <= 1:
        var = torch.zeros_like(mean)
    else:
        var = ((feats - mean) ** 2).mean(dim=0)
    # Clamp tiny/zero stds so normalization stays well-defined.
    std = torch.sqrt(var).clamp_min(eps)
    return mean, std
100
+
101
def save_stats(graphs: 'Graphs', filepath: str, categorical_unique_threshold=50):
    """
    Compute and save normalization stats (mean, std, counts) for node and edge features.

    Node feature columns with fewer than `categorical_unique_threshold`
    distinct values are treated as categorical and get mean=0 / std=1 so
    they pass through normalization unchanged.

    Args:
        graphs: iterable of (graph, label) pairs; each graph exposes
            ndata['features'] and edata['features'] tensors.
        filepath: destination JSON path; parent directories are created.
        categorical_unique_threshold: unique-value cutoff below which a node
            feature column is treated as categorical.

    Raises:
        ValueError: if `graphs` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("No graphs to compute stats from.")

    # Pool node and edge features across all graphs.
    all_node_feats = torch.cat([g.ndata['features'] for g, _ in graphs], dim=0)
    all_edge_feats = torch.cat([g.edata['features'] for g, _ in graphs], dim=0)

    counts = {
        'node': all_node_feats.size(0),
        'edge': all_edge_feats.size(0),
    }

    node_mean, node_std = compute_stats(all_node_feats)
    edge_mean, edge_std = compute_stats(all_edge_feats)

    # Disable normalization for categorical node columns.
    categorical_mask = torch.tensor([
        torch.unique(all_node_feats[:, i]).numel() < categorical_unique_threshold
        for i in range(node_mean.size(0))
    ], dtype=torch.bool)
    node_mean[categorical_mask] = 0.0
    node_std[categorical_mask] = 1.0

    stats = {
        'node': {
            'mean': node_mean.tolist(),
            'std': node_std.tolist(),
            'count': counts['node'],
        },
        'edge': {
            'mean': edge_mean.tolist(),
            'std': edge_std.tolist(),
            'count': counts['edge'],
        },
    }

    # BUGFIX: os.path.dirname returns '' for a bare filename, and
    # os.makedirs('') raises FileNotFoundError even with exist_ok=True.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=4)
physicsnemo/dataset_utils.py DELETED
@@ -1,121 +0,0 @@
1
- import dgl
2
- import torch
3
- import numpy as np
4
- import awkward as ak
5
- from dataclasses import dataclass
6
- from typing import List, Any, Dict
7
- from tqdm import tqdm
8
-
9
- @dataclass
10
- class ChunkConfig:
11
- name: str
12
- label: str
13
- chunk_id: int
14
- batch_size: int
15
- arrays: List[Any]
16
- particles: List[Any]
17
- features: List[Any]
18
- branches: List[Any]
19
- dtype: torch.dtype
20
- save_path: str
21
-
22
- def process_chunk(cfg: ChunkConfig):
23
- n_entries = len(cfg.arrays)
24
- arrays_ordered = {}
25
-
26
- for b in cfg.branches:
27
- if b in cfg.arrays.fields:
28
- arrays_ordered[b] = cfg.arrays[b]
29
- elif b.endswith("_energy"):
30
- prefix = b[:-7]
31
- pt_name = f"{prefix}_pt"
32
- if prefix == "MET":
33
- pt_name = f"{prefix}_met"
34
- eta_name = f"{prefix}_eta"
35
- arrays_ordered[b] = cfg.arrays[pt_name] * np.cosh(cfg.arrays[eta_name])
36
- elif "node_type" in b:
37
- prefix = b[:-10]
38
- pt_name = f"{prefix}_pt"
39
- if prefix == "MET":
40
- pt_name = f"{prefix}_met"
41
- index = cfg.particles.index(prefix)
42
- arrays_ordered[b] = ak.ones_like(cfg.arrays[pt_name]) * index
43
- else:
44
- prefix = b.split("_")[0]
45
- pt_name = f"{prefix}_pt"
46
- if prefix == "MET":
47
- pt_name = f"{prefix}_met"
48
- arrays_ordered[b] = ak.zeros_like(cfg.arrays[pt_name])
49
-
50
- graphs = []
51
-
52
- for i in range(n_entries):
53
- node_features_list = []
54
- for p in cfg.particles:
55
- feats = []
56
- for f in cfg.features:
57
- branch = f"{p}_{f}"
58
- if p == "MET" and f == "pt":
59
- branch = "MET_met"
60
- value = ak.to_numpy(arrays_ordered[branch][i])
61
- feats.append(value)
62
- if len(feats[0]) == 0:
63
- continue
64
- node_array = np.stack(feats, axis=1)
65
- node_features_list.append(node_array)
66
- if node_features_list:
67
- node_features = np.concatenate(node_features_list, axis=0)
68
- else:
69
- node_features = np.empty((0, len(cfg.features)))
70
- graphs.append(make_graph(node_features, dtype=cfg.dtype))
71
-
72
- labels = torch.full((len(graphs),), cfg.label, dtype=cfg.dtype)
73
- save_graphs(f"{cfg.save_path}/{cfg.name}_{cfg.chunk_id:04d}.bin", graphs, {'labels': labels})
74
- return
75
-
76
- def save_graphs(f: str, g: List[dgl.DGLGraph], metadata: Dict) -> None:
77
- dgl.save_graphs(f, g, metadata)
78
-
79
- def load_graphs(f: str):
80
- g, metadata = dgl.load_graphs(f)
81
- return g, metadata['labels']
82
-
83
- src_dst_cache = {}
84
- def get_src_dst(num_nodes):
85
- if num_nodes not in src_dst_cache:
86
- src, dst = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing='ij')
87
- src_dst_cache[num_nodes] = (src.flatten(), dst.flatten())
88
- return src_dst_cache[num_nodes]
89
-
90
- @torch.jit.script
91
- def compute_edge_features(eta, phi, src, dst):
92
- deta = eta[src] - eta[dst]
93
- dphi = phi[src] - phi[dst]
94
- dphi = torch.remainder(dphi + np.pi, 2 * np.pi) - np.pi
95
- dR = torch.sqrt(deta ** 2 + dphi ** 2)
96
- edge_features = torch.stack([dR, deta, dphi], dim=1)
97
- return edge_features
98
-
99
- # TODO: normalize all features
100
- def make_graph(node_features: np.array, dtype=torch.float32):
101
- node_features = torch.tensor(node_features, dtype=dtype)
102
- num_nodes = node_features.shape[0]
103
- if num_nodes == 0:
104
- g = dgl.graph(([], []))
105
- g.ndata['features'] = node_features
106
- g.edata['features'] = torch.empty((0, 3), dtype=dtype)
107
- g.globals = torch.tensor([0], dtype=dtype)
108
- return g
109
- src, dst = get_src_dst(num_nodes)
110
- src = src.flatten()
111
- dst = dst.flatten()
112
- g = dgl.graph((src, dst))
113
- g.ndata['features'] = node_features
114
-
115
- eta = node_features[:, 1]
116
- phi = node_features[:, 2]
117
- edge_features = compute_edge_features(eta, phi, src, dst)
118
- g.edata['features'] = edge_features
119
-
120
- g.globals = torch.tensor([num_nodes], dtype=dtype)
121
- return g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
physicsnemo/metrics.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
def bce(input, target, weights=None):
    """
    Plain (unweighted) binary cross entropy on logits, averaged over elements.

    A trailing singleton dimension on either tensor is squeezed away so that
    e.g. [N, 1] logits can be scored against [N] targets. The `weights`
    argument is accepted for signature compatibility but is not used here.
    """
    if input.shape != target.shape:
        input_has_extra_dim = input.shape[-1] == 1 and input.shape[:-1] == target.shape
        target_has_extra_dim = target.shape[-1] == 1 and target.shape[:-1] == input.shape
        if input_has_extra_dim:
            input = input.squeeze(-1)
        elif target_has_extra_dim:
            target = target.squeeze(-1)

    per_element = F.binary_cross_entropy_with_logits(input, target, reduction='none')
    return per_element.mean()
15
+
16
def weighted_bce(input, target, weights=None):
    """
    Weighted, label-balanced binary cross entropy on logits.

    The per-element BCE loss is grouped by the distinct label values in
    `target`; within each group the losses are combined as a weighted mean,
    and the final result is the plain mean of the per-group values, so each
    class contributes equally regardless of its frequency or total weight.

    Args:
        input (Tensor): predicted logits of shape (N, ...); a trailing
            singleton dim on either tensor is squeezed away.
        target (Tensor): ground-truth labels with discrete values.
        weights (Tensor or None): optional per-sample weights (defaults to ones).

    Returns:
        Tensor: scalar loss (0.0 when no group has positive total weight).
    """
    if input.shape != target.shape:
        if input.shape[-1] == 1 and input.shape[:-1] == target.shape:
            input = input.squeeze(-1)
        elif target.shape[-1] == 1 and target.shape[:-1] == input.shape:
            target = target.squeeze(-1)

    per_element = F.binary_cross_entropy_with_logits(input, target, reduction='none')

    if weights is None:
        weights = torch.ones_like(per_element)

    per_label_losses = []
    for label in torch.unique(target):
        mask = (target == label).bool()  # defensive cast, mask is already bool
        group_weights = weights[mask]
        group_losses = per_element[mask]
        total_weight = group_weights.sum()
        if total_weight > 0:
            per_label_losses.append((group_weights * group_losses).sum() / total_weight)

    if not per_label_losses:
        return torch.tensor(0.0, device=input.device)
    return torch.stack(per_label_losses).mean()
64
+
65
+
66
def roc_auc_score(classes : np.ndarray,
                  predictions : np.ndarray,
                  weights : np.ndarray = None) -> float:
    """
    Weighted ROC AUC computed as the probability of correct pair ordering.

    AUC = P(score of a class1 sample > score of a class0 sample), with tied
    predictions counted as half a correct ordering. All three arrays must be
    1-D and the same length; `weights` defaults to all ones.

    NOTE(review): unpacking `sorted(np.unique(classes))` into two names
    raises ValueError when `classes` has a number of distinct values other
    than two — callers (e.g. the eval loop) catch that ValueError.
    """

    if weights is None:
        weights = np.ones_like(predictions)

    assert len(classes) == len(predictions) == len(weights)
    assert classes.ndim == predictions.ndim == weights.ndim == 1
    class0, class1 = sorted(np.unique(classes))

    # Pack into a structured array so class, prediction and weight stay
    # aligned through the sorts below.
    data = np.empty(
        shape=len(classes),
        dtype=[('c', classes.dtype),
               ('p', predictions.dtype),
               ('w', weights.dtype)]
    )
    data['c'], data['p'], data['w'] = classes, predictions, weights

    # Sort by class first, then stably by prediction, so within a run of
    # equal predictions the class order is preserved.
    data = data[np.argsort(data['c'])]
    data = data[np.argsort(data['p'], kind='mergesort')] # here we're relying on stability as we need class orders preserved

    correction = 0.
    # mask1 - bool mask to highlight collision areas
    # mask2 - bool mask with collision areas' start points
    mask1 = np.empty(len(data), dtype=bool)
    mask2 = np.empty(len(data), dtype=bool)
    mask1[0] = mask2[-1] = False
    mask1[1:] = data['p'][1:] == data['p'][:-1]
    if mask1.any():
        mask2[:-1] = ~mask1[:-1] & mask1[1:]
        mask1[:-1] |= mask1[1:]
        ids, = mask2.nonzero()
        # Each run of tied predictions contributes (weighted class0 mass) *
        # (weighted class1 mass) tied pairs, each counted with factor 0.5.
        correction = sum([((dsplit['c'] == class0) * dsplit['w'] * msplit).sum() *
                          ((dsplit['c'] == class1) * dsplit['w'] * msplit).sum()
                          for dsplit, msplit in zip(np.split(data, ids), np.split(mask1, ids))]) * 0.5

    # cumsum of class0 weight below each position counts the correctly
    # ordered cross-class pairs; subtract the tie correction and normalize
    # by the total weighted number of cross-class pairs.
    weights_0 = data['w'] * (data['c'] == class0)
    weights_1 = data['w'] * (data['c'] == class1)
    cumsum_0 = weights_0.cumsum()

    return ((cumsum_0 * weights_1).sum() - correction) / (weights_1.sum() * cumsum_0[-1])
physicsnemo/models/MeshGraphNet.py CHANGED
@@ -6,25 +6,123 @@ import dgl
6
  from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
7
 
8
  class MeshGraphNet(nn.Module):
9
- def __init__(self, *args, out_dim=1, **kwargs):
10
  super().__init__()
11
- # Initialize the PhysicsNemo MeshGraphNet
12
- self.base_gnn = PhysicsNemoMeshGraphNet(*args, **kwargs)
13
- # Assume node_output_dim is known or infer from args/kwargs
14
- node_output_dim = kwargs.get('hidden_dim_node_decoder', 64)
15
- self.mlp = nn.Linear(node_output_dim, out_dim)
16
 
17
- def forward(self, node_feats, edge_feats, batched_graph):
 
 
 
 
 
 
 
 
 
 
18
  """
19
- Args:
20
- node_feats: [total_num_nodes, node_feat_dim]
21
- edge_feats: [total_num_edges, edge_feat_dim]
22
- batched_graph: DGLGraph, batched graphs
 
23
  Returns:
24
  graph_pred: [num_graphs, out_dim]
25
  """
26
  node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)
27
  batched_graph.ndata['h'] = node_pred
28
- graph_feat = dgl.readout_nodes(batched_graph, 'h', op='mean') # [num_graphs, node_output_dim]
29
- graph_pred = self.mlp(graph_feat) # [num_graphs, out_dim]
30
- return graph_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
7
 
8
  class MeshGraphNet(nn.Module):
9
+ def __init__(self, cfg):
10
  super().__init__()
11
+ base_gnn_cfg = cfg.base_gnn
12
+ self.base_gnn = PhysicsNemoMeshGraphNet(**base_gnn_cfg)
 
 
 
13
 
14
+ self.global_mlp = nn.Sequential(
15
+ nn.Linear(cfg.global_feat_dim, cfg.global_emb_dim),
16
+ nn.ReLU(),
17
+ )
18
+
19
+ self.mlp = nn.Linear(
20
+ base_gnn_cfg['output_dim'] + base_gnn_cfg['input_dim_edges'] + cfg.global_emb_dim,
21
+ cfg.out_dim
22
+ )
23
+
24
+ def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata={}):
25
  """
26
+ node_feats: [total_num_nodes, node_feat_dim]
27
+ edge_feats: [total_num_edges, edge_feat_dim]
28
+ global_feats: [num_graphs, global_feat_dim]
29
+ batched_graph: DGLGraph, representing the collection of graphs in a batch
30
+ metadata: dict, may contain 'batch_num_nodes', 'batch_num_edges', etc.
31
  Returns:
32
  graph_pred: [num_graphs, out_dim]
33
  """
34
  node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)
35
  batched_graph.ndata['h'] = node_pred
36
+ batched_graph.edata['e'] = edge_feats
37
+
38
+ graph_node_feat = mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
39
+ graph_edge_feat = mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
40
+
41
+ # Flatten global_feats if needed
42
+ if global_feats.ndim == 3:
43
+ global_feats = global_feats.view(-1, global_feats.shape[-1])
44
+ global_emb = self.global_mlp(global_feats) # [num_graphs, global_emb_dim]
45
+
46
+ combined_feat = torch.cat([graph_node_feat, graph_edge_feat, global_emb], dim=-1)
47
+ graph_pred = self.mlp(combined_feat)
48
+ return graph_pred
49
+
50
+ def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
51
+ """
52
+ Aggregates node features per disjoint graph in a batched DGLGraph.
53
+
54
+ Args:
55
+ batched_graph: DGLGraph
56
+ feat_key: str, node feature key
57
+ op: 'mean', 'sum', or 'max'
58
+ node_split: 1D tensor or list of ints (num nodes per graph)
59
+
60
+ Returns:
61
+ Tensor of shape [num_graphs, node_feat_dim]
62
+ """
63
+ h = batched_graph.ndata[feat_key]
64
+ if node_split is None or len(node_split) == 0:
65
+ if op == 'mean':
66
+ return dgl.mean_nodes(batched_graph, feat_key)
67
+ elif op == 'sum':
68
+ return dgl.sum_nodes(batched_graph, feat_key)
69
+ elif op == 'max':
70
+ return dgl.max_nodes(batched_graph, feat_key)
71
+ else:
72
+ raise ValueError(f"Unknown op: {op}")
73
+ else:
74
+ # Ensure node_split is a flat list of ints
75
+ if isinstance(node_split, torch.Tensor):
76
+ splits = node_split.view(-1).tolist()
77
+ else:
78
+ splits = [int(x) for x in node_split]
79
+ chunks = torch.split(h, splits, dim=0)
80
+ if op == 'mean':
81
+ out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
82
+ elif op == 'sum':
83
+ out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
84
+ elif op == 'max':
85
+ out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
86
+ else:
87
+ raise ValueError(f"Unknown op: {op}")
88
+ return out
89
+
90
+ def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
91
+ """
92
+ Aggregates edge features per disjoint graph in a batched DGLGraph.
93
+
94
+ Args:
95
+ batched_graph: DGLGraph
96
+ feat_key: str, edge feature key
97
+ op: 'mean', 'sum', or 'max'
98
+ edge_split: 1D tensor or list of ints (num edges per graph)
99
+
100
+ Returns:
101
+ Tensor of shape [num_graphs, edge_feat_dim]
102
+ """
103
+ e = batched_graph.edata[feat_key]
104
+ if edge_split is None or len(edge_split) == 0:
105
+ if op == 'mean':
106
+ return dgl.mean_edges(batched_graph, feat_key)
107
+ elif op == 'sum':
108
+ return dgl.sum_edges(batched_graph, feat_key)
109
+ elif op == 'max':
110
+ return dgl.max_edges(batched_graph, feat_key)
111
+ else:
112
+ raise ValueError(f"Unknown op: {op}")
113
+ else:
114
+ # Ensure edge_split is a flat list of ints
115
+ if isinstance(edge_split, torch.Tensor):
116
+ splits = edge_split.view(-1).tolist()
117
+ else:
118
+ splits = [int(x) for x in edge_split]
119
+ chunks = torch.split(e, splits, dim=0)
120
+ if op == 'mean':
121
+ out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
122
+ elif op == 'sum':
123
+ out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
124
+ elif op == 'max':
125
+ out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
126
+ else:
127
+ raise ValueError(f"Unknown op: {op}")
128
+ return out
physicsnemo/train.py CHANGED
@@ -14,73 +14,15 @@ from physicsnemo.launch.logging import (
14
  )
15
  from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
16
  from physicsnemo.distributed.manager import DistributedManager
17
- from dataset import get_dataset
18
- import json
19
 
 
 
20
  import random
21
 
22
- from sklearn.metrics import roc_auc_score
23
-
24
  import models.MeshGraphNet as MeshGraphNet
 
 
25
 
26
- import torch.nn.functional as F
27
-
28
- def bce(input, target, device=None, weights=None):
29
- if input.shape != target.shape:
30
- if input.shape[-1] == 1 and input.shape[:-1] == target.shape:
31
- input = input.squeeze(-1)
32
- elif target.shape[-1] == 1 and target.shape[:-1] == input.shape:
33
- target = target.squeeze(-1)
34
-
35
- loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
36
- return torch.mean(loss)
37
-
38
- def weighted_bce(input, target, device=None, weights=None):
39
- """
40
- Compute a weighted and label-normalized binary cross entropy (BCE) loss.
41
-
42
- For each unique label in the target tensor, the BCE loss is computed and weighted,
43
- then normalized by the sum of weights for that label. The final loss is the mean
44
- of these per-label normalized losses.
45
-
46
- Args:
47
- input (Tensor): Predicted logits of shape (N, ...).
48
- target (Tensor): Ground truth labels of shape (N, ...), with discrete label values.
49
- device (torch.device or None): Device to move tensors to (optional).
50
- weights (Tensor or None): Optional tensor of per-sample weights, same shape as input/target.
51
-
52
- Returns:
53
- Tensor: Scalar tensor representing the normalized weighted BCE loss.
54
- """
55
-
56
- if input.shape != target.shape:
57
- if input.shape[-1] == 1 and input.shape[:-1] == target.shape:
58
- input = input.squeeze(-1)
59
- elif target.shape[-1] == 1 and target.shape[:-1] == input.shape:
60
- target = target.squeeze(-1)
61
-
62
- if weights is None:
63
- weights = torch.ones_like(target).to(device)
64
-
65
- # Compute per-element BCE loss (no reduction)
66
- loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
67
-
68
- # Vectorized label normalization
69
- unique_labels = torch.unique(target)
70
- normalized_losses = []
71
- for label in unique_labels:
72
- label_mask = (target == label)
73
- label_weights = weights[label_mask]
74
- label_losses = loss[label_mask]
75
- weight_sum = label_weights.sum()
76
- if weight_sum > 0:
77
- label_loss = (label_weights * label_losses).sum() / weight_sum
78
- normalized_losses.append(label_loss)
79
-
80
- if normalized_losses:
81
- return torch.stack(normalized_losses).mean()
82
- else:
83
- return torch.tensor(0.0, device=input.device)
84
 
85
  class MGNTrainer:
86
  def __init__(self, logger, cfg, dist):
@@ -91,37 +33,28 @@ class MGNTrainer:
91
  params = {}
92
 
93
  start = time.time()
94
- self.dataloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
95
  print(f"total time loading dataset: {time.time() - start:.2f} seconds")
96
 
97
- dtype_str = getattr(cfg.root_dataset, "type", "torch.float32")
98
  if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
99
  self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
100
  else:
101
  self.dtype = torch.float32
102
 
103
- nodes_features = cfg.root_dataset.features
104
- edges_features = ["dR", "deta", "dphi"]
105
  global_features = ["num_nodes"]
106
 
107
- params["infeat_nodes"] = len(nodes_features)
108
- params["infeat_edges"] = len(edges_features)
109
  params["infeat_globals"] = len(global_features)
110
- params["out_dim"] = cfg.architecture.hidden_dim_node_encoder
111
- params["node_features"] = list(nodes_features)
112
- params["edges_features"] = edges_features
113
  params["global_features"] = global_features
114
 
115
- self.model = MeshGraphNet.MeshGraphNet(
116
- params["infeat_nodes"],
117
- params["infeat_edges"],
118
- params['out_dim'],
119
- processor_size=cfg.architecture.processor_size,
120
- hidden_dim_node_encoder=cfg.architecture.hidden_dim_node_encoder,
121
- hidden_dim_edge_encoder=cfg.architecture.hidden_dim_edge_encoder,
122
- hidden_dim_processor=cfg.architecture.hidden_dim_processor,
123
- hidden_dim_node_decoder=cfg.architecture.hidden_dim_node_decoder,
124
- )
125
  self.model = self.model.to(dtype=self.dtype, device=self.device)
126
 
127
  if cfg.performance.jit:
@@ -168,7 +101,7 @@ class MGNTrainer:
168
  loss.backward()
169
  self.optimizer.step()
170
 
171
- def train(self, graph, label):
172
  """
173
  Perform one training iteration over one graph. The training is performed
174
  over multiple timesteps, where the number of timesteps is specified in
@@ -181,11 +114,16 @@ class MGNTrainer:
181
  loss: loss value.
182
 
183
  """
 
 
 
 
 
184
  self.optimizer.zero_grad()
185
- pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
186
- loss = bce(pred, label, device=self.device)
187
  self.backward(loss)
188
- return loss
189
 
190
  @torch.no_grad()
191
  def eval(self):
@@ -201,18 +139,25 @@ class MGNTrainer:
201
  """
202
  predictions = []
203
  labels = []
204
-
205
- for graph, label in self.valloader:
206
-
207
- graph = graph.to(self.device)
208
- pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
 
 
 
 
 
209
  predictions.append(pred)
210
- labels.append(label)
 
211
 
212
  predictions = torch.cat(predictions, dim=0)
213
  labels = torch.cat(labels, dim=0)
 
214
 
215
- loss = weighted_bce(predictions, labels, device=self.device)
216
 
217
  # Convert logits to probabilities
218
  prob = torch.sigmoid(predictions)
@@ -223,13 +168,13 @@ class MGNTrainer:
223
 
224
  # Calculate AUC
225
  try:
226
- auc = roc_auc_score(labels_flat, prob_flat)
227
  except ValueError:
228
  auc = float('nan') # Not enough classes present for AUC
229
 
230
  return loss, auc
231
 
232
- @hydra.main(version_base=None, config_path=".", config_name="config")
233
  def do_training(cfg: DictConfig):
234
  """
235
  Perform training over all graphs in the dataset.
@@ -247,8 +192,9 @@ def do_training(cfg: DictConfig):
247
  dist = DistributedManager()
248
 
249
  # initialize loggers
 
250
  logger = PythonLogger("main")
251
- logger.file_logging()
252
 
253
  # initialize trainer
254
  trainer = MGNTrainer(logger, cfg, dist)
@@ -274,13 +220,14 @@ def do_training(cfg: DictConfig):
274
 
275
  # Training
276
  train_loss = []
277
- for graph, label in trainer.dataloader:
278
  trainer.model.train()
279
- train_loss.append(trainer.train(graph, label))
 
280
 
281
  val_loss, val_auc = trainer.eval()
282
 
283
- train_loss = torch.mean(torch.stack(train_loss)).item()
284
 
285
  logger.info(
286
  f"epoch: {epoch}, loss: {train_loss:10.3e}, val_loss: {val_loss:10.3e}, val_auc = {val_auc:10.3e}, time per epoch: {(time.time()-start):10.3e}"
 
14
  )
15
  from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
16
  from physicsnemo.distributed.manager import DistributedManager
 
 
17
 
18
+ import json
19
+ from tqdm import tqdm
20
  import random
21
 
 
 
22
  import models.MeshGraphNet as MeshGraphNet
23
+ from dataset.Dataset import get_dataset
24
+ import metrics
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  class MGNTrainer:
28
  def __init__(self, logger, cfg, dist):
 
33
  params = {}
34
 
35
  start = time.time()
36
+ self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
37
  print(f"total time loading dataset: {time.time() - start:.2f} seconds")
38
 
39
+ dtype_str = getattr(cfg.root_dataset, "dtype", "torch.float32")
40
  if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
41
  self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
42
  else:
43
  self.dtype = torch.float32
44
 
45
+ node_features = list(cfg.root_dataset.features.values())[0]
46
+ edge_features = ["dR", "deta", "dphi"]
47
  global_features = ["num_nodes"]
48
 
49
+ params["infeat_nodes"] = len(node_features)
50
+ params["infeat_edges"] = len(edge_features)
51
  params["infeat_globals"] = len(global_features)
52
+ params["out_dim"] = cfg.architecture.out_dim
53
+ params["node_features"] = list(node_features)
54
+ params["edge_features"] = edge_features
55
  params["global_features"] = global_features
56
 
57
+ self.model = MeshGraphNet.MeshGraphNet(cfg.architecture)
 
 
 
 
 
 
 
 
 
58
  self.model = self.model.to(dtype=self.dtype, device=self.device)
59
 
60
  if cfg.performance.jit:
 
101
  loss.backward()
102
  self.optimizer.step()
103
 
104
+ def train(self, graph, metadata):
105
  """
106
  Perform one training iteration over one graph. The training is performed
107
  over multiple timesteps, where the number of timesteps is specified in
 
114
  loss: loss value.
115
 
116
  """
117
+ graph = graph.to(self.device, non_blocking=True)
118
+ globals = metadata['globals'].to(self.device, non_blocking=True)
119
+ label = metadata['label'].to(self.device, non_blocking=True)
120
+ weight = metadata['weight'].to(self.device, non_blocking=True)
121
+
122
  self.optimizer.zero_grad()
123
+ pred = self.model(graph.ndata["features"], graph.edata["features"], globals, graph, metadata)
124
+ loss = metrics.weighted_bce(pred, label, weights=weight)
125
  self.backward(loss)
126
+ return loss.detach()
127
 
128
  @torch.no_grad()
129
  def eval(self):
 
139
  """
140
  predictions = []
141
  labels = []
142
+ weights = []
143
+
144
+ for graph, metadata in self.valloader:
145
+
146
+ graph = graph.to(self.device, non_blocking=True)
147
+ globals = metadata['globals'].to(self.device, non_blocking=True)
148
+ label = metadata['label'].to(self.device, non_blocking=True)
149
+ weight = metadata['weight'].to(self.device, non_blocking=True)
150
+
151
+ pred = self.model(graph.ndata["features"], graph.edata["features"], globals, graph, metadata)
152
  predictions.append(pred)
153
+ labels.append(label)
154
+ weights.append(weight)
155
 
156
  predictions = torch.cat(predictions, dim=0)
157
  labels = torch.cat(labels, dim=0)
158
+ weights = torch.cat(weights, dim=0)
159
 
160
+ loss = metrics.weighted_bce(predictions, labels, weights=weights)
161
 
162
  # Convert logits to probabilities
163
  prob = torch.sigmoid(predictions)
 
168
 
169
  # Calculate AUC
170
  try:
171
+ auc = metrics.roc_auc_score(labels_flat, prob_flat)
172
  except ValueError:
173
  auc = float('nan') # Not enough classes present for AUC
174
 
175
  return loss, auc
176
 
177
+ @hydra.main(version_base=None, config_path="./configs/", config_name="tHjb_CP_0_vs_45")
178
  def do_training(cfg: DictConfig):
179
  """
180
  Perform training over all graphs in the dataset.
 
192
  dist = DistributedManager()
193
 
194
  # initialize loggers
195
+ os.makedirs(cfg.checkpoints.ckpt_path, exist_ok=True)
196
  logger = PythonLogger("main")
197
+ logger.file_logging(os.path.join(cfg.checkpoints.ckpt_path, "train.log"))
198
 
199
  # initialize trainer
200
  trainer = MGNTrainer(logger, cfg, dist)
 
220
 
221
  # Training
222
  train_loss = []
223
+ for graph, metadata in tqdm(trainer.trainloader, desc=f"epoch {epoch} trianing"):
224
  trainer.model.train()
225
+ loss = trainer.train(graph, metadata)
226
+ train_loss.append(loss.item())
227
 
228
  val_loss, val_auc = trainer.eval()
229
 
230
+ train_loss = torch.tensor(train_loss).mean()
231
 
232
  logger.info(
233
  f"epoch: {epoch}, loss: {train_loss:10.3e}, val_loss: {val_loss:10.3e}, val_auc = {val_auc:10.3e}, time per epoch: {(time.time()-start):10.3e}"