ho22joshua committed on
Commit
4a5b33b
·
1 Parent(s): fbef6ad

updated dataset class

Browse files
physicsnemo/Dataset.py DELETED
@@ -1,293 +0,0 @@
1
- import os
2
- import uproot
3
- import dgl
4
- import torch
5
- import numpy as np
6
- import awkward as ak
7
- from omegaconf import DictConfig
8
- from typing import List
9
- from typing import Union
10
- import math
11
- import random
12
- from concurrent.futures import ThreadPoolExecutor, as_completed
13
-
14
- from torch.utils.data import Dataset
15
-
16
- from dgl.dataloading import GraphDataLoader
17
-
18
-
19
- os.environ["TMPDIR"] = "/pscratch/sd/j/joshuaho/tmp"
20
-
21
def make_graph(node_features: np.array, dtype=torch.float32):
    """Build a fully connected DGL graph whose nodes carry `node_features`.

    Edge features are [dR, deta, dphi], computed from column 1 (eta) and
    column 2 (phi) of the node features; `graph.globals` stores the node
    count so batching can recover per-graph sizes.
    """
    feats = torch.tensor(node_features, dtype=dtype)
    n = feats.shape[0]

    if n == 0:
        # Degenerate event: empty graph with correctly shaped feature tensors.
        empty = dgl.graph(([], []))
        empty.ndata['features'] = feats
        empty.edata['features'] = torch.empty((0, 3), dtype=dtype)
        empty.globals = torch.tensor([0], dtype=dtype)
        return empty

    # All-pairs connectivity, self-loops included (np.meshgrid 'xy' ordering).
    senders, receivers = np.meshgrid(np.arange(n), np.arange(n))
    senders = senders.flatten()
    receivers = receivers.flatten()

    graph = dgl.graph((senders, receivers))
    graph.ndata['features'] = feats  # shape: (num_nodes, num_features)

    eta = feats[:, 1]
    phi = feats[:, 2]
    d_eta = eta[senders] - eta[receivers]
    # Wrap the azimuthal difference back into [-pi, pi).
    d_phi = torch.remainder(phi[senders] - phi[receivers] + np.pi, 2 * np.pi) - np.pi
    d_r = torch.sqrt(d_eta ** 2 + d_phi ** 2)
    graph.edata['features'] = torch.stack([d_r, d_eta, d_phi], dim=1)

    graph.globals = torch.tensor([n], dtype=dtype)

    return graph
51
-
52
def process_chunk(args):
    """Convert one chunk of awkward arrays into DGL graphs and save them to disk.

    `args` is a single tuple of (name, label, chunk_id, arrays, particles,
    features, branches, dtype, save_path) so the whole job can be submitted
    to an executor as one picklable argument.
    """
    name, label, chunk_id, arrays, particles, features, branches, dtype, save_path = args
    n_entries = len(arrays)

    # Materialize every requested branch, deriving the ones absent from the file.
    arrays_ordered = {}
    for b in branches:
        if b in arrays.fields:
            arrays_ordered[b] = arrays[b]
        elif b.endswith("_energy"):
            # Massless approximation: E = pt * cosh(eta).
            prefix = b[:-7]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            eta_name = f"{prefix}_eta"
            arrays_ordered[b] = arrays[pt_name] * np.cosh(arrays[eta_name])
        elif "node_type" in b:
            # Encode the particle species index as a per-node feature.
            prefix = b[:-10]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            index = particles.index(prefix)
            arrays_ordered[b] = ak.ones_like(arrays[pt_name]) * index
        else:
            # Unknown branch: zero-fill with the per-particle multiplicity.
            prefix = b.split("_")[0]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            arrays_ordered[b] = ak.zeros_like(arrays[pt_name])

    graphs = []
    for i in range(n_entries):
        if (i % 250 == 0):
            print(f"{name} chunk {chunk_id} processed {i} events")
        node_features_list = []
        for p in particles:
            feats = []
            for f in features:
                branch = f"{p}_{f}"
                if p == "MET" and f == "pt":
                    branch = "MET_met"
                value = ak.to_numpy(arrays_ordered[branch][i])
                feats.append(value)
            # Guard the empty feature list too, not just empty particle counts,
            # so an empty `features` config cannot raise IndexError.
            if not feats or len(feats[0]) == 0:
                continue
            node_array = np.stack(feats, axis=1)
            node_features_list.append(node_array)
        if node_features_list:
            node_features = np.concatenate(node_features_list, axis=0)
        else:
            node_features = np.empty((0, len(features)))
        graphs.append(make_graph(node_features, dtype=dtype))

    labels = torch.full((len(graphs),), label, dtype=dtype)
    out_file = f"{save_path}/{name}_{chunk_id:02d}.bin"
    dgl.save_graphs(out_file, graphs, {'label': labels})
    # Report the exact file that was written (the save used :02d but the old
    # message claimed :03d, pointing at a file that does not exist).
    print(f"Saved {name} chunk {chunk_id:02d} to {out_file}")
    return
109
-
110
class Root_Graph:
    """Converts a ROOT ntuple into chunked DGL graph files and serves
    cached train/val/test splits from disk."""

    def __init__(
        self,
        name: str,
        label: int,
        load_path: str,
        save_path: str,
        cfg: DictConfig
    ):
        """Store dataset identity, file paths, and graph-building options from `cfg`."""
        self.name = name
        self.label = label
        self.load_path = load_path
        self.save_path = save_path
        self.data = None

        self.ttree = cfg.ttree
        self.particles = cfg.particles
        self.features = cfg.features
        self.globals = cfg.globals
        self.chunks = cfg.chunks

        self.train_val_test_split = cfg.train_val_test_split
        # Tolerant comparison: exact float equality rejects valid splits such
        # as [0.75, 0.24, 0.01] whose floating-point sum is not exactly 1.0.
        assert np.isclose(np.sum(self.train_val_test_split), 1.0), "train_val_test_split must sum to 1"

        # cfg.type is a string like "torch.bfloat16"; anything else falls back to float32.
        dtype_str = getattr(cfg, "type", "torch.float32")
        if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
            self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
        else:
            self.dtype = torch.float32
        print(f"Initializing dataset {name} with dtype {self.dtype}")

    def get_branches(self) -> List[str]:
        """Return every ROOT branch needed, mapping the MET pt alias to MET_met."""
        branches = [f"{p}_{f}" for p in self.particles for f in self.features]
        branches += self.globals
        branches = ["MET_met" if b == "MET_pt" else b for b in branches]
        return branches

    def process(self, max_workers: int = 128):
        """Read the ntuple in `self.chunks` chunks and convert each to graphs in parallel."""
        branches = self.get_branches()
        with uproot.open(f"{self.load_path}:{self.ttree}") as tree:
            available_branches = set(tree.keys())
            num_entries = tree.num_entries

        print(f"Getting branches: {branches}")

        step_size = math.ceil(num_entries / self.chunks)

        # Prepare chunk arguments for each chunk
        chunk_args_list = []
        for chunk_id, arrays in enumerate(
            uproot.iterate(
                f"{self.load_path}:{self.ttree}",
                expressions=[b for b in branches if b in available_branches],
                step_size=step_size,
                library="ak"
            )
        ):
            # Pass everything needed for chunk processing
            chunk_args_list.append((self.name, self.label, chunk_id, arrays, self.particles, self.features, branches, self.dtype, self.save_path))

        # Parallel processing of chunks
        with ThreadPoolExecutor(max_workers) as executor:
            futures = [executor.submit(process_chunk, args) for args in chunk_args_list]
            for future in as_completed(futures):
                # Re-raise any worker exception here instead of losing it.
                future.result()

        return

    def train_val_test(self):
        """Randomly split all chunk files into train/val/test and save each split."""
        split = self.train_val_test_split
        # Collect all graphs and labels first
        all_graphs = []
        all_labels = []
        files = [f"{self.save_path}/{self.name}_{chunk_id:02d}.bin" for chunk_id in range(self.chunks)]
        for f in files:
            graphs, label_dict = dgl.load_graphs(f)
            all_graphs.extend(graphs)
            # Assuming label_dict['label'] is shape (num_graphs_in_file,)
            all_labels.extend(label_dict['label'].tolist())

        # Per-graph uniform draw; relies on the module-level np seed for reproducibility.
        n = len(all_graphs)
        rand = np.random.rand(n)
        train_idx = rand < split[0]
        val_idx = (rand >= split[0]) & (rand < split[0] + split[1])
        test_idx = rand >= split[0] + split[1]

        train_graphs = [g for g, flag in zip(all_graphs, train_idx) if flag]
        val_graphs = [g for g, flag in zip(all_graphs, val_idx) if flag]
        test_graphs = [g for g, flag in zip(all_graphs, test_idx) if flag]

        train_labels = [l for l, flag in zip(all_labels, train_idx) if flag]
        val_labels = [l for l, flag in zip(all_labels, val_idx) if flag]
        test_labels = [l for l, flag in zip(all_labels, test_idx) if flag]

        train_labels = torch.tensor(train_labels)
        val_labels = torch.tensor(val_labels)
        test_labels = torch.tensor(test_labels)

        dgl.save_graphs(f"{self.save_path}/{self.name}_train.bin", train_graphs, {'label': train_labels})
        dgl.save_graphs(f"{self.save_path}/{self.name}_val.bin", val_graphs, {'label': val_labels})
        dgl.save_graphs(f"{self.save_path}/{self.name}_test.bin", test_graphs, {'label': test_labels})

        print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")

    def load(self):
        """Process and split on demand, then return the three cached splits.

        Returns (train_graphs, train_labels, val_graphs, val_labels,
        test_graphs, test_labels).
        """
        # List of expected files
        files = [f"{self.save_path}/{self.name}_{chunk_id:02d}.bin" for chunk_id in range(self.chunks)]

        # Check if all files exist
        if not all(os.path.exists(f) for f in files):
            print("graphs not found, processing data...")
            self.process()
        else:
            print("graphs found, skipping processing...")

        # Check if train/val/test exist:
        files = [f"{self.save_path}/{self.name}_{split}.bin" for split in ["train", "val", "test"]]
        if not all(os.path.exists(f) for f in files):
            print("train/val/test split not found, splitting graphs...")
            self.train_val_test()
        else:
            print("train/val/test split found, skipping splitting...")

        print("loading graphs...")
        train_graphs, train_label_dict = dgl.load_graphs(f"{self.save_path}/{self.name}_train.bin")
        val_graphs, val_label_dict = dgl.load_graphs(f"{self.save_path}/{self.name}_val.bin")
        test_graphs, test_label_dict = dgl.load_graphs(f"{self.save_path}/{self.name}_test.bin")

        train_labels = train_label_dict['label']
        val_labels = val_label_dict['label']
        test_labels = test_label_dict['label']

        print(f"successfully loaded {self.name}")

        return train_graphs, train_labels, val_graphs, val_labels, test_graphs, test_labels
246
-
247
class GraphDataset(Dataset):
    """Thin map-style dataset over parallel lists of graphs and labels."""

    def __init__(self, graphs, labels):
        self.graphs = graphs
        self.labels = labels

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.labels[idx]
255
-
256
def get_dataset(cfg: DictConfig):
    """Build train/val/test GraphDataLoaders from every dataset listed in `cfg`."""

    # Seed all RNGs so the random split in train_val_test is reproducible.
    random.seed(cfg.random_seed)
    np.random.seed(cfg.random_seed)
    torch.manual_seed(cfg.random_seed)

    # Accumulators: split name -> (graphs, labels).
    pools = {key: ([], []) for key in ("train", "val", "test")}

    for entry in cfg.datasets:
        ds_name = entry['name']
        src_path = entry.get('load_path', f"{cfg.paths.data_dir}/{ds_name}.root")
        out_path = entry.get('save_path', f"{cfg.paths.save_dir}/")
        root_graph = Root_Graph(ds_name, entry.get('label'), src_path, out_path, cfg.root_graph)
        loaded = root_graph.load()
        # load() returns (train_g, train_l, val_g, val_l, test_g, test_l).
        for offset, key in enumerate(("train", "val", "test")):
            pools[key][0].extend(loaded[2 * offset])
            pools[key][1].extend(loaded[2 * offset + 1])

    train_dataset = GraphDataset(*pools["train"])
    val_dataset = GraphDataset(*pools["val"])
    test_dataset = GraphDataset(*pools["test"])

    batch_size = cfg.root_graph.batch_size

    # Only the training loader is shuffled.
    train_loader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = GraphDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print("all data loaded successfully")
    return train_loader, val_loader, test_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
physicsnemo/config.yaml CHANGED
@@ -42,9 +42,9 @@ architecture:
42
  out_dim: 1
43
 
44
  paths:
45
- data_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K
46
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/physicsnemo_graphs/stats_100K
47
- training_dir: ./training/
48
 
49
  datasets:
50
  - name: ttH_cp_even
@@ -54,7 +54,7 @@ datasets:
54
  load_path: ${paths.data_dir}/ttH_CPodd.root
55
  label: 1
56
 
57
- root_graph:
58
  ttree: output
59
  type: torch.bfloat16
60
  particles: ["jet", "ele", "mu", "ph", "MET"]
@@ -62,6 +62,6 @@ root_graph:
62
  globals: []
63
  weights: ""
64
  tracking: []
65
- chunks: 32
66
  batch_size: 8192
67
  train_val_test_split: [0.75, 0.24, 0.01]
 
42
  out_dim: 1
43
 
44
  paths:
45
+ data_dir: /global/cfs/projectdirs/atlas/joshua/hackathon_data/stats_100K
46
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/graphs/stats_100K
47
+ training_dir: ./training_stats_100K/
48
 
49
  datasets:
50
  - name: ttH_cp_even
 
54
  load_path: ${paths.data_dir}/ttH_CPodd.root
55
  label: 1
56
 
57
+ root_dataset:
58
  ttree: output
59
  type: torch.bfloat16
60
  particles: ["jet", "ele", "mu", "ph", "MET"]
 
62
  globals: []
63
  weights: ""
64
  tracking: []
65
+ step_size: 1024
66
  batch_size: 8192
67
  train_val_test_split: [0.75, 0.24, 0.01]
physicsnemo/config_stats_all.yaml CHANGED
@@ -42,8 +42,8 @@ architecture:
42
  out_dim: 1
43
 
44
  paths:
45
- data_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all
46
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/physicsnemo_graphs/stats_all
47
  training_dir: ./training_stats_all/
48
 
49
  datasets:
@@ -54,7 +54,7 @@ datasets:
54
  load_path: ${paths.data_dir}/ttH_CPodd.root
55
  label: 1
56
 
57
- root_graph:
58
  ttree: output
59
  type: torch.bfloat16
60
  particles: ["jet", "ele", "mu", "ph", "MET"]
@@ -62,6 +62,6 @@ root_graph:
62
  globals: []
63
  weights: ""
64
  tracking: []
65
- chunks: 128
66
  batch_size: 8192
67
  train_val_test_split: [0.75, 0.24, 0.01]
 
42
  out_dim: 1
43
 
44
  paths:
45
+ data_dir: /global/cfs/projectdirs/atlas/joshua/hackathon_data/stats_all
46
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/graphs/stats_all
47
  training_dir: ./training_stats_all/
48
 
49
  datasets:
 
54
  load_path: ${paths.data_dir}/ttH_CPodd.root
55
  label: 1
56
 
57
+ root_dataset:
58
  ttree: output
59
  type: torch.bfloat16
60
  particles: ["jet", "ele", "mu", "ph", "MET"]
 
62
  globals: []
63
  weights: ""
64
  tracking: []
65
+ step_size: 1024
66
  batch_size: 8192
67
  train_val_test_split: [0.75, 0.24, 0.01]
physicsnemo/dataset.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uproot
3
+ import dgl
4
+ import torch
5
+ import numpy as np
6
+ import awkward as ak
7
+ from omegaconf import DictConfig
8
+ from typing import List
9
+ from concurrent.futures import ProcessPoolExecutor, as_completed
10
+ from tqdm import tqdm
11
+
12
+ import dataset_utils as utils
13
+
14
+ from torch.utils.data import Dataset
15
+
16
+ from dgl.dataloading import GraphDataLoader
17
+
18
class RootDataset:
    """Streams a ROOT ntuple into chunked DGL graph files and indexes the
    saved graphs lazily as (file, index) tuples."""

    def __init__(
        self,
        name: str,
        label: int,
        load_path: str,
        save_path: str,
        device: str,
        cfg: DictConfig
    ):
        """Store dataset identity, file paths, and graph-building options from `cfg`."""
        self.name = name
        self.label = label
        self.load_path = load_path
        self.save_path = save_path
        self.data = None
        self.device = device

        self.ttree = cfg.ttree
        self.particles = cfg.particles
        self.features = cfg.features
        self.globals = cfg.globals
        self.step_size = cfg.step_size
        self.batch_size = cfg.batch_size

        self.train_val_test_split = cfg.train_val_test_split
        # Tolerant comparison: exact float equality rejects valid splits such
        # as [0.75, 0.24, 0.01] whose floating-point sum is not exactly 1.0.
        assert np.isclose(np.sum(self.train_val_test_split), 1.0), "train_val_test_split must sum to 1"

        # cfg.type is a string like "torch.bfloat16"; anything else falls back to float32.
        dtype_str = getattr(cfg, "type", "torch.float32")
        if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
            self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
        else:
            self.dtype = torch.float32
        print(f"initializing dataset {name} with dtype {self.dtype}")

    def get_branches(self) -> List[str]:
        """Return every ROOT branch needed, mapping the MET pt alias to MET_met."""
        branches = [f"{p}_{f}" for p in self.particles for f in self.features]
        branches += self.globals
        branches = ["MET_met" if b == "MET_pt" else b for b in branches]
        return branches

    def process(self):
        """Convert the ntuple to graph chunk files, one worker process per chunk."""
        branches = self.get_branches()
        with uproot.open(f"{self.load_path}:{self.ttree}") as tree:
            available_branches = set(tree.keys())
            num_entries = tree.num_entries

        print(f"getting branches: {branches}")

        num_cpus = os.cpu_count()
        # tqdm's `total` expects an integer; np.ceil alone returns a float.
        total_chunks = int(np.ceil(num_entries / self.step_size))

        with ProcessPoolExecutor(max_workers=num_cpus) as executor:
            futures = []
            for chunk_id, arrays in enumerate(
                tqdm(
                    uproot.iterate(
                        f"{self.load_path}:{self.ttree}",
                        expressions=[b for b in branches if b in available_branches],
                        step_size=self.step_size,
                        library="ak"
                    ),
                    desc="loading root file",
                    total=total_chunks,
                    position=0,
                    leave=False
                )
            ):
                # NOTE(review): each chunk's awkward arrays are pickled to the
                # worker process at submit time; large step_size values make
                # this submission itself expensive — confirm acceptable.
                cfg = utils.ChunkConfig(
                    name=self.name,
                    label=self.label,
                    chunk_id=chunk_id,
                    batch_size=self.batch_size,
                    arrays=arrays,
                    particles=self.particles,
                    features=self.features,
                    branches=branches,
                    dtype=self.dtype,
                    save_path=self.save_path,
                )

                futures.append(executor.submit(utils.process_chunk, cfg))

            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    # Surface worker tracebacks but keep draining the rest.
                    import traceback
                    print("Exception in worker process:")
                    traceback.print_exception(type(e), e, e.__traceback__)
        return

    def load(self):
        """Ensure chunk files exist, then return (file, index) tuples per split.

        Returns (train_tuples, val_tuples, test_tuples); graphs are not
        deserialized here — GraphTupleDataset loads them on access.
        """
        with uproot.open(f"{self.load_path}:{self.ttree}") as tree:
            num_entries = tree.num_entries
            total_chunks = int(np.ceil(num_entries / self.step_size))

        chunk_files = [f"{self.save_path}/{self.name}_{chunk_id:04d}.bin" for chunk_id in range(total_chunks)]
        if not all(os.path.exists(f) for f in chunk_files):
            print("graphs not found. processing root file...")
            self.process()

        graph_tuple_list = []

        for chunk_id, f in enumerate(chunk_files):
            # Every chunk holds step_size graphs except (possibly) the last one.
            if chunk_id < total_chunks - 1:
                n_graphs = self.step_size
            else:
                n_graphs = num_entries - self.step_size * (total_chunks - 1)
            graph_tuple_list.extend((f, idx) for idx in range(n_graphs))

        # NOTE(review): the split is sequential in file order, not shuffled, so
        # neighboring events always land in the same split — confirm intended.
        split = self.train_val_test_split
        n_total = len(graph_tuple_list)
        n_train = int(split[0] * n_total)
        n_val = int(split[1] * n_total)

        train_tuples = graph_tuple_list[:n_train]
        val_tuples = graph_tuple_list[n_train:n_train + n_val]
        test_tuples = graph_tuple_list[n_train + n_val:]
        return train_tuples, val_tuples, test_tuples
139
+
140
class GraphDataset(Dataset):
    """Map-style dataset pairing in-memory graphs with their labels."""

    def __init__(self, graphs, labels):
        self.graphs = graphs
        self.labels = labels

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.labels[idx]

    def shuffle(self):
        # TODO: implement graph shuffling
        return self.graphs
151
+
152
class GraphTupleDataset:
    """Lazy dataset over (chunk_file, index_within_chunk) tuples.

    Graphs stay on disk until accessed. Because consecutive indices usually
    point into the same chunk file, the most recently loaded file is cached
    so sequential iteration deserializes each chunk once instead of once
    per graph.
    """

    def __init__(self, tuple_list):
        self.tuples = tuple_list
        # Single-file cache: path of the last chunk loaded plus its contents.
        self._cached_path = None
        self._cached_graphs = None
        self._cached_labels = None

    def __len__(self):
        return len(self.tuples)

    def __getitem__(self, idx):
        filepath, graph_idx = self.tuples[idx]
        if filepath != self._cached_path:
            self._cached_graphs, self._cached_labels = utils.load_graphs(filepath)
            self._cached_path = filepath
        return self._cached_graphs[graph_idx], self._cached_labels[graph_idx]
161
+
162
def get_dataset(cfg: DictConfig, device):
    """Assemble train/val/test GraphDataLoaders over lazily loaded graph chunks."""

    # Accumulate (file, index) tuples per split across all configured datasets.
    split_tuples = {"train": [], "val": [], "test": []}

    for entry in cfg.datasets:
        ds_name = entry['name']
        src_path = entry.get('load_path', f"{cfg.paths.data_dir}/{ds_name}.root")
        out_path = entry.get('save_path', f"{cfg.paths.save_dir}/")
        root_dataset = RootDataset(ds_name, entry.get('label'), src_path, out_path, device, cfg.root_dataset)
        train_part, val_part, test_part = root_dataset.load()
        split_tuples["train"].extend(train_part)
        split_tuples["val"].extend(val_part)
        split_tuples["test"].extend(test_part)

    train_dataset = GraphTupleDataset(split_tuples["train"])
    val_dataset = GraphTupleDataset(split_tuples["val"])
    test_dataset = GraphTupleDataset(split_tuples["test"])

    # Shared loader options; only the training loader is shuffled.
    loader_kwargs = dict(batch_size=cfg.root_dataset.batch_size, pin_memory=True, num_workers=2)
    train_loader = GraphDataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = GraphDataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = GraphDataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print("all data loaded successfully")
    print(f"train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")
    return train_loader, val_loader, test_loader
physicsnemo/dataset_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
from dataclasses import dataclass
from typing import List, Any, Dict

import dgl
import numpy as np
import awkward as ak
import torch
from tqdm import tqdm
8
+
9
@dataclass
class ChunkConfig:
    """Everything process_chunk needs for one chunk, bundled into a single
    picklable object so it can be shipped to a worker process."""

    name: str           # dataset name, used in the output filename
    label: int          # numeric class label; used as the fill value of torch.full
    chunk_id: int       # zero-based chunk index, used in the output filename
    batch_size: int
    arrays: List[Any]   # this chunk's awkward record array
    particles: List[Any]
    features: List[Any]
    branches: List[Any]
    dtype: torch.dtype
    save_path: str
21
+
22
def process_chunk(cfg: ChunkConfig):
    """Turn one chunk of awkward arrays into DGL graphs and write them to disk.

    The output file is `{save_path}/{name}_{chunk_id:04d}.bin` with the
    per-graph labels stored under the 'labels' key (see load_graphs).
    """
    n_entries = len(cfg.arrays)

    # Materialize every requested branch, deriving the ones absent from the file.
    arrays_ordered = {}
    for b in cfg.branches:
        if b in cfg.arrays.fields:
            arrays_ordered[b] = cfg.arrays[b]
        elif b.endswith("_energy"):
            # Massless approximation: E = pt * cosh(eta).
            prefix = b[:-7]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            eta_name = f"{prefix}_eta"
            arrays_ordered[b] = cfg.arrays[pt_name] * np.cosh(cfg.arrays[eta_name])
        elif "node_type" in b:
            # Encode the particle species index as a per-node feature.
            prefix = b[:-10]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            index = cfg.particles.index(prefix)
            arrays_ordered[b] = ak.ones_like(cfg.arrays[pt_name]) * index
        else:
            # Unknown branch: zero-fill with the per-particle multiplicity.
            prefix = b.split("_")[0]
            pt_name = f"{prefix}_pt"
            if prefix == "MET":
                pt_name = f"{prefix}_met"
            arrays_ordered[b] = ak.zeros_like(cfg.arrays[pt_name])

    graphs = []

    for i in range(n_entries):
        node_features_list = []
        for p in cfg.particles:
            feats = []
            for f in cfg.features:
                branch = f"{p}_{f}"
                if p == "MET" and f == "pt":
                    branch = "MET_met"
                value = ak.to_numpy(arrays_ordered[branch][i])
                feats.append(value)
            # Guard the empty feature list too, not just empty particle counts,
            # so an empty cfg.features cannot raise IndexError on feats[0].
            if not feats or len(feats[0]) == 0:
                continue
            node_array = np.stack(feats, axis=1)
            node_features_list.append(node_array)
        if node_features_list:
            node_features = np.concatenate(node_features_list, axis=0)
        else:
            node_features = np.empty((0, len(cfg.features)))
        graphs.append(make_graph(node_features, dtype=cfg.dtype))

    labels = torch.full((len(graphs),), cfg.label, dtype=cfg.dtype)
    save_graphs(f"{cfg.save_path}/{cfg.name}_{cfg.chunk_id:04d}.bin", graphs, {'labels': labels})
    return
75
+
76
def save_graphs(f: str, g: List[dgl.DGLGraph], metadata: Dict) -> None:
    """Persist `g` and its metadata dict to path `f` via DGL's serializer."""
    dgl.save_graphs(f, g, metadata)
78
+
79
def load_graphs(f: str):
    """Load a chunk file written by save_graphs.

    Returns (graphs, labels) where labels is the tensor stored under the
    'labels' metadata key.
    """
    graphs, metadata = dgl.load_graphs(f)
    labels = metadata['labels']
    return graphs, labels
82
+
83
# Memo of flattened all-pairs index tensors, keyed by node count. Node counts
# repeat heavily across events, so each size is built only once.
src_dst_cache = {}


def get_src_dst(num_nodes):
    """Return flattened (src, dst) index tensors for a fully connected graph
    on `num_nodes` nodes (self-loops included), cached per size."""
    cached = src_dst_cache.get(num_nodes)
    if cached is None:
        idx = torch.arange(num_nodes)
        src, dst = torch.meshgrid(idx, idx, indexing='ij')
        cached = (src.flatten(), dst.flatten())
        src_dst_cache[num_nodes] = cached
    return cached
89
+
90
@torch.jit.script
def compute_edge_features(eta, phi, src, dst):
    """Return per-edge [dR, deta, dphi] from node eta/phi and edge endpoints.

    dphi is wrapped into [-pi, pi); dR = sqrt(deta^2 + dphi^2).
    """
    deta = eta[src] - eta[dst]
    dphi = phi[src] - phi[dst]
    # Use math.pi rather than np.pi: the math module is explicitly supported
    # by TorchScript, while numpy-module attribute access is not guaranteed
    # to resolve inside a @torch.jit.script function.
    dphi = torch.remainder(dphi + math.pi, 2 * math.pi) - math.pi
    dR = torch.sqrt(deta ** 2 + dphi ** 2)
    edge_features = torch.stack([dR, deta, dphi], dim=1)
    return edge_features
98
+
99
# TODO: normalize all features
def make_graph(node_features: np.ndarray, dtype=torch.float32):
    """Build a fully connected DGL graph (self-loops included) from node features.

    Column 1 of `node_features` is taken as eta and column 2 as phi when
    computing the [dR, deta, dphi] edge features; `g.globals` holds the node
    count so batching can recover per-graph sizes.
    """
    node_features = torch.tensor(node_features, dtype=dtype)
    num_nodes = node_features.shape[0]
    if num_nodes == 0:
        # Empty event: keep the feature tensors correctly shaped.
        g = dgl.graph(([], []))
        g.ndata['features'] = node_features
        g.edata['features'] = torch.empty((0, 3), dtype=dtype)
        g.globals = torch.tensor([0], dtype=dtype)
        return g

    # get_src_dst already returns flattened tensors — the extra .flatten()
    # calls the old code made were redundant no-ops.
    src, dst = get_src_dst(num_nodes)
    g = dgl.graph((src, dst))
    g.ndata['features'] = node_features

    eta = node_features[:, 1]
    phi = node_features[:, 2]
    g.edata['features'] = compute_edge_features(eta, phi, src, dst)

    g.globals = torch.tensor([num_nodes], dtype=dtype)
    return g
physicsnemo/models/Edge_Network.py ADDED
File without changes
physicsnemo/{MeshGraphNet.py → models/MeshGraphNet.py} RENAMED
File without changes
physicsnemo/setup/Dockerfile CHANGED
@@ -21,5 +21,3 @@ RUN pip install --no-cache-dir mpi4py jupyter uproot
21
 
22
  # (Optional) Expose Jupyter port
23
  EXPOSE 8888
24
-
25
-
 
21
 
22
  # (Optional) Expose Jupyter port
23
  EXPOSE 8888
 
 
physicsnemo/train.py CHANGED
@@ -2,8 +2,9 @@ import time, os
2
 
3
  start = time.time()
4
  import torch
 
5
  from dgl.dataloading import GraphDataLoader
6
- from torch.cuda.amp import GradScaler
7
  import numpy as np
8
  import hydra
9
  from omegaconf import DictConfig
@@ -13,15 +14,27 @@ from physicsnemo.launch.logging import (
13
  )
14
  from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
15
  from physicsnemo.distributed.manager import DistributedManager
16
- from Dataset import get_dataset
17
  import json
18
 
 
 
19
  from sklearn.metrics import roc_auc_score
20
 
21
- import MeshGraphNet
22
 
23
  import torch.nn.functional as F
24
 
 
 
 
 
 
 
 
 
 
 
25
  def weighted_bce(input, target, device=None, weights=None):
26
  """
27
  Compute a weighted and label-normalized binary cross entropy (BCE) loss.
@@ -47,12 +60,7 @@ def weighted_bce(input, target, device=None, weights=None):
47
  target = target.squeeze(-1)
48
 
49
  if weights is None:
50
- weights = torch.ones_like(target)
51
-
52
- if device is not None:
53
- input = input.to(device)
54
- target = target.to(device)
55
- weights = weights.to(device)
56
 
57
  # Compute per-element BCE loss (no reduction)
58
  loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
@@ -82,17 +90,17 @@ class MGNTrainer:
82
 
83
  params = {}
84
 
85
- norm_type = {"features": "normal", "labels": "normal"}
86
-
87
- self.dataloader, self.valloader, self.testloader = get_dataset(cfg)
88
 
89
- dtype_str = getattr(cfg.root_graph, "type", "torch.float32")
90
  if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
91
  self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
92
  else:
93
  self.dtype = torch.float32
94
 
95
- nodes_features = cfg.root_graph.features
96
  edges_features = ["dR", "deta", "dphi"]
97
  global_features = ["num_nodes"]
98
 
@@ -173,10 +181,9 @@ class MGNTrainer:
173
  loss: loss value.
174
 
175
  """
176
- graph = graph.to(self.device)
177
  self.optimizer.zero_grad()
178
  pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
179
- loss = weighted_bce(pred, label, device=self.device)
180
  self.backward(loss)
181
  return loss
182
 
@@ -192,7 +199,6 @@ class MGNTrainer:
192
  Returns:
193
  loss (Tensor): The computed loss value (scalar).
194
  """
195
-
196
  predictions = []
197
  labels = []
198
 
@@ -232,6 +238,9 @@ def do_training(cfg: DictConfig):
232
  cfg: Dictionary of parameters.
233
 
234
  """
 
 
 
235
 
236
  # initialize distributed manager
237
  DistributedManager.initialize()
@@ -244,6 +253,20 @@ def do_training(cfg: DictConfig):
244
  # initialize trainer
245
  trainer = MGNTrainer(logger, cfg, dist)
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  # training loop
248
  start = time.time()
249
  logger.info("Training started...")
 
2
 
3
  start = time.time()
4
  import torch
5
+ from torch.nn.parallel import DistributedDataParallel
6
  from dgl.dataloading import GraphDataLoader
7
+ from torch.amp import GradScaler
8
  import numpy as np
9
  import hydra
10
  from omegaconf import DictConfig
 
14
  )
15
  from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
16
  from physicsnemo.distributed.manager import DistributedManager
17
+ from dataset import get_dataset
18
  import json
19
 
20
+ import random
21
+
22
  from sklearn.metrics import roc_auc_score
23
 
24
+ import models.MeshGraphNet as MeshGraphNet
25
 
26
  import torch.nn.functional as F
27
 
28
+ def bce(input, target, device=None, weights=None):
29
+ if input.shape != target.shape:
30
+ if input.shape[-1] == 1 and input.shape[:-1] == target.shape:
31
+ input = input.squeeze(-1)
32
+ elif target.shape[-1] == 1 and target.shape[:-1] == input.shape:
33
+ target = target.squeeze(-1)
34
+
35
+ loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
36
+ return torch.mean(loss)
37
+
38
  def weighted_bce(input, target, device=None, weights=None):
39
  """
40
  Compute a weighted and label-normalized binary cross entropy (BCE) loss.
 
60
  target = target.squeeze(-1)
61
 
62
  if weights is None:
63
+ weights = torch.ones_like(target).to(device)
 
 
 
 
 
64
 
65
  # Compute per-element BCE loss (no reduction)
66
  loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
 
90
 
91
  params = {}
92
 
93
+ start = time.time()
94
+ self.dataloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
95
+ print(f"total time loading dataset: {time.time() - start:.2f} seconds")
96
 
97
+ dtype_str = getattr(cfg.root_dataset, "type", "torch.float32")
98
  if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
99
  self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
100
  else:
101
  self.dtype = torch.float32
102
 
103
+ nodes_features = cfg.root_dataset.features
104
  edges_features = ["dR", "deta", "dphi"]
105
  global_features = ["num_nodes"]
106
 
 
181
  loss: loss value.
182
 
183
  """
 
184
  self.optimizer.zero_grad()
185
  pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
186
+ loss = bce(pred, label, device=self.device)
187
  self.backward(loss)
188
  return loss
189
 
 
199
  Returns:
200
  loss (Tensor): The computed loss value (scalar).
201
  """
 
202
  predictions = []
203
  labels = []
204
 
 
238
  cfg: Dictionary of parameters.
239
 
240
  """
241
+ random.seed(cfg.random_seed)
242
+ np.random.seed(cfg.random_seed)
243
+ torch.manual_seed(cfg.random_seed)
244
 
245
  # initialize distributed manager
246
  DistributedManager.initialize()
 
253
  # initialize trainer
254
  trainer = MGNTrainer(logger, cfg, dist)
255
 
256
+ if dist.distributed:
257
+ ddps = torch.cuda.Stream()
258
+ with torch.cuda.stream(ddps):
259
+ trainer.model = DistributedDataParallel(
260
+ trainer.model,
261
+ device_ids=[dist.local_rank], # Set the device_id to be
262
+ # the local rank of this process on
263
+ # this node
264
+ output_device=dist.device,
265
+ broadcast_buffers=dist.broadcast_buffers,
266
+ find_unused_parameters=dist.find_unused_parameters,
267
+ )
268
+ torch.cuda.current_stream().wait_stream(ddps)
269
+
270
  # training loop
271
  start = time.time()
272
  logger.info("Training started...")