ho22joshua commited on
Commit
cbcb0be
·
1 Parent(s): e1f04e9

updated training and inference script

Browse files
root_gnn_dgl/scripts/inference.py CHANGED
@@ -1,10 +1,10 @@
1
  import sys
2
- import os
3
- file_path = os.getcwd()
4
  sys.path.append(file_path)
5
-
6
  import argparse
7
  import yaml
 
8
 
9
  import torch
10
  import dgl
@@ -12,26 +12,15 @@ from dgl.data import DGLDataset
12
  from dgl.dataloading import GraphDataLoader
13
  from torch.utils.data import SubsetRandomSampler, SequentialSampler
14
 
15
-
16
- def my_error_handler(level, abort, location, msg):
17
- # Log the error message to a file instead of printing
18
- with open("error_log.txt", "a") as log_file:
19
- log_file.write(f"Error in {location}: {msg}\n")
20
-
21
- # Optionally, print the error message to the console
22
- # print(f"Error in {location}: {msg}")
23
-
24
- # Decide whether to abort based on the error level
25
- if abort:
26
- raise RuntimeError(f"Fatal error in {location}: {msg}")
27
-
28
  class CustomPreBatchedDataset(DGLDataset):
29
- def __init__(self, start_dataset, batch_size, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
30
  self.start_dataset = start_dataset
31
  self.batch_size = batch_size
32
  self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
33
  self.drop_last = drop_last
34
  self.shuffle = shuffle
 
 
35
  super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)
36
 
37
  def process(self):
@@ -39,18 +28,29 @@ class CustomPreBatchedDataset(DGLDataset):
39
  indices = torch.arange(len(self.start_dataset))[mask]
40
  print(f"Number of elements after masking: {len(indices)}") # Debugging print
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  if self.shuffle:
43
- sampler = SubsetRandomSampler(indices)
44
  else:
45
- sampler = SequentialSampler(indices)
46
 
47
  self.dataloader = GraphDataLoader(
48
- self.start_dataset,
49
- sampler=sampler,
50
- batch_size=self.batch_size,
51
  drop_last=self.drop_last
52
  )
53
- print(f"Batch size set in DataLoader: {self.batch_size}") # Debugging print
54
 
55
  def __getitem__(self, idx):
56
  if isinstance(idx, int):
@@ -60,7 +60,15 @@ class CustomPreBatchedDataset(DGLDataset):
60
  return next(iter(dloader))
61
 
62
  def __len__(self):
63
- return len(self.start_dataset)
 
 
 
 
 
 
 
 
64
 
65
  def include_config(conf):
66
  if 'include' in conf:
@@ -76,28 +84,44 @@ def load_config(config_file):
76
  return conf
77
 
78
  def main():
 
79
  parser = argparse.ArgumentParser()
80
  add_arg = parser.add_argument
81
- add_arg('--config', type=str, required=True)
82
  add_arg('--target', type=str, required=True)
83
  add_arg('--destination', type=str, default='')
84
  add_arg('--chunkno', type=int, default=0)
85
  add_arg('--chunks', type=int, default=1)
86
  add_arg('--write', action='store_true')
87
  add_arg('--ckpt', type=int, default=-1)
 
 
88
  add_arg('--clobber', action='store_true')
89
  add_arg('--tree', type=str, default='')
90
- add_arg('--branch_name', type=str, default='score')
91
  args = parser.parse_args()
92
 
93
- config = load_config(args.config)
 
 
 
 
 
 
94
  if args.destination == '':
95
- args.destination = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
 
 
 
 
 
 
96
  else:
97
- args.destination = args.destination
98
- if not args.write:
99
- args.destination = args.destination.replace('.root', '') + f'_chunk{args.chunkno}.npz'
100
 
 
101
  if os.path.exists(args.destination):
102
  print(f'File {args.destination} already exists.')
103
  if args.clobber:
@@ -137,7 +161,7 @@ def main():
137
  dset_config['args']['selections'] = []
138
 
139
  dset_config['args']['save_dir'] = os.path.dirname(args.destination)
140
-
141
  if args.tree != '':
142
  dset_config['args']['tree_name'] = args.tree
143
 
@@ -152,9 +176,13 @@ def main():
152
 
153
  batch_size = config['Training']['batch_size']
154
  lstart = time.time()
155
- loader = CustomPreBatchedDataset(dset, batch_size)
 
 
 
 
 
156
  loader.process()
157
- # loader = dataset.PreBatchedDataset(dset, batch_size, shuffle=False, drop_last=False, save_to_disk=False, chunks = 1, num_workers=0)
158
  lend = time.time()
159
  print('Loader finished in {:.2f} seconds'.format(lend - lstart))
160
  sample_graph, _, _, global_sample = loader[0]
@@ -162,70 +190,64 @@ def main():
162
  print('dset length =', len(dset))
163
  print('loader length =', len(loader))
164
 
165
- model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
166
- if args.ckpt < 0:
167
- ep, checkpoint = utils.get_last_epoch(config, args.ckpt, device=device)
168
- else:
169
- ep, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)
170
- #Bad filler for models which were compiled. Have to remove this prefix.
171
- mds_copy = {}
172
- for key in checkpoint['model_state_dict'].keys():
173
- newkey = key.replace('module.', '')
174
- newkey = newkey.replace('_orig_mod.', '')
175
- mds_copy[newkey] = checkpoint['model_state_dict'][key]
176
- model.load_state_dict(mds_copy)
177
- model.eval()
178
-
179
- end = time.time()
180
- print('Model and dataset finished in {:.2f} seconds'.format(end - start))
181
- print('Starting inference')
182
- start = time.time()
183
-
184
- finish_fn = torch.nn.Sigmoid()
185
- if 'Loss' in config:
186
- finish_fn = utils.buildFromConfig(config['Loss']['finish'])
187
-
188
- scores = []
189
- labels = []
190
- tracking_info = []
191
- ibatch = 0
192
-
193
- for batch, label, track, globals in loader.dataloader:
194
- batch = batch.to(device)
195
- pred = model(batch, globals.to(device))
196
- ibatch += 1
197
- # scores.append(finish_fn(pred).detach().cpu().numpy())
198
- if (finish_fn.__class__.__name__ == "ContrastiveClusterFinish"):
199
- scores.append(pred.detach().cpu().numpy())
200
- else:
201
- scores.append(finish_fn(pred).detach().cpu().numpy())
202
- labels.append(label.detach().cpu().numpy())
203
- tracking_info.append(track.detach().cpu().numpy())
204
-
205
- # for batch, label, track, globals in loader:
206
- # batch = batch.to(device)
207
- # pred = model(batch, globals.to(device))
208
- # print(f'Batch size: {batch.batch_size if hasattr(batch, "batch_size") else "Unavailable"}')
209
- # print(f'Prediction shape: {pred.shape}')
210
- # ibatch += 1
211
- # scores.append(finish_fn(pred).detach().cpu().numpy())
212
- # labels.append(label.detach().cpu().numpy())
213
- # tracking_info.append(track.detach().cpu().numpy())
214
- # exit()
215
-
216
- score_size = scores[0].shape[1]
217
- scores = np.concatenate(scores)
218
- labels = np.concatenate(labels)
219
- tracking_info = np.concatenate(tracking_info)
220
- end = time.time()
221
-
222
- print('Inference finished in {:.2f} seconds'.format(end - start))
223
 
224
  if args.write:
225
- # ROOT.SetErrorHandler(my_error_handler)
226
- ROOT.gErrorIgnoreLevel = ROOT.kFatal
227
- # ROOT.gSystem.RedirectOutput("/dev/null", "w")
228
-
229
  # Open the original ROOT file
230
  infile = ROOT.TFile.Open(args.target)
231
  tree = infile.Get(dset_config['args']['tree_name'])
@@ -236,54 +258,46 @@ def main():
236
  # Create a new ROOT file to write the modified tree
237
  outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
238
 
239
- # Clone the original tree, including data
240
- outtree = tree.CloneTree(0) # Clone all entries
241
 
242
- # Determine if scores is a list of single values or vectors
243
- from ROOT import std
244
- if isinstance(scores[0], (list, tuple, np.ndarray)): # Check if scores contains vectors
245
- # Create a new branch for scores as a vector of floats
246
- scores_branch_vec = std.vector('float')()
247
- outtree.Branch(args.branch_name, scores_branch_vec)
248
- is_vector = True
249
- else: # Scores contains single values
250
- # Create a new branch for scores as a single float
251
- score_branch_arr = array('f', [0])
252
- outtree.Branch(args.branch_name, score_branch_arr, f'{args.branch_name}/F')
253
- is_vector = False
254
-
255
- # Write scores to the new branch
256
- print(f'Writing {len(scores)} scores to tree')
257
 
 
258
  for i in range(tree.GetEntries()):
259
  tree.GetEntry(i)
260
-
261
- if is_vector:
262
- # Clear the vector
263
- scores_branch_vec.clear()
264
-
265
- # Add all elements from scores[i] to the vector
266
- for value in scores[i]:
267
- scores_branch_vec.push_back(float(value)) # Use push_back to add elements one by one
268
- else:
269
- # Fill the score branch with the current single score
270
- score_branch_arr[0] = float(scores[i]) # Ensure the value is a float
271
 
272
- # Fill the output tree with all branches, including the new scores branch
 
 
 
 
 
 
 
 
273
  outtree.Fill()
274
 
275
  # Write the modified tree to the new file
276
  print(f'Writing to file {args.destination}')
277
  print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
 
278
  outtree.Write()
279
  outfile.Close()
280
  infile.Close()
281
  else:
282
  os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
283
- np.savez(args.destination, scores=scores, labels=labels, tracking_info=tracking_info)
284
 
285
  if __name__ == '__main__':
286
- main()
287
-
288
-
289
-
 
1
  import sys
2
+ file_path = "/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl"
 
3
  sys.path.append(file_path)
4
+ import os
5
  import argparse
6
  import yaml
7
+ import gc
8
 
9
  import torch
10
  import dgl
 
12
  from dgl.dataloading import GraphDataLoader
13
  from torch.utils.data import SubsetRandomSampler, SequentialSampler
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class CustomPreBatchedDataset(DGLDataset):
16
+ def __init__(self, start_dataset, batch_size, chunkno=0, chunks=1, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
17
  self.start_dataset = start_dataset
18
  self.batch_size = batch_size
19
  self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
20
  self.drop_last = drop_last
21
  self.shuffle = shuffle
22
+ self.chunkno = chunkno
23
+ self.chunks = chunks
24
  super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)
25
 
26
  def process(self):
 
28
  indices = torch.arange(len(self.start_dataset))[mask]
29
  print(f"Number of elements after masking: {len(indices)}") # Debugging print
30
 
31
+ # --- CHUNK SPLITTING ---
32
+ total = len(indices)
33
+ if self.chunks == 1:
34
+ chunk_indices = indices
35
+ print(f"Chunks=1, using all {total} indices.")
36
+ else:
37
+ chunk_size = (total + self.chunks - 1) // self.chunks
38
+ start = self.chunkno * chunk_size
39
+ end = min((self.chunkno + 1) * chunk_size, total)
40
+ chunk_indices = indices[start:end]
41
+ print(f"Working on chunk {self.chunkno}/{self.chunks}: indices {start}:{end} (total {len(chunk_indices)})")
42
+
43
  if self.shuffle:
44
+ sampler = SubsetRandomSampler(chunk_indices)
45
  else:
46
+ sampler = SequentialSampler(chunk_indices)
47
 
48
  self.dataloader = GraphDataLoader(
49
+ self.start_dataset,
50
+ sampler=sampler,
51
+ batch_size=self.batch_size,
52
  drop_last=self.drop_last
53
  )
 
54
 
55
  def __getitem__(self, idx):
56
  if isinstance(idx, int):
 
60
  return next(iter(dloader))
61
 
62
  def __len__(self):
63
+ mask = self.mask_fn(self.start_dataset)
64
+ indices = torch.arange(len(self.start_dataset))[mask]
65
+ total = len(indices)
66
+ if self.chunks == 1:
67
+ return total
68
+ chunk_size = (total + self.chunks - 1) // self.chunks
69
+ start = self.chunkno * chunk_size
70
+ end = min((self.chunkno + 1) * chunk_size, total)
71
+ return end - start
72
 
73
  def include_config(conf):
74
  if 'include' in conf:
 
84
  return conf
85
 
86
  def main():
87
+
88
  parser = argparse.ArgumentParser()
89
  add_arg = parser.add_argument
90
+ add_arg('--config', type=str, nargs='+', required=True, help="List of config files")
91
  add_arg('--target', type=str, required=True)
92
  add_arg('--destination', type=str, default='')
93
  add_arg('--chunkno', type=int, default=0)
94
  add_arg('--chunks', type=int, default=1)
95
  add_arg('--write', action='store_true')
96
  add_arg('--ckpt', type=int, default=-1)
97
+ add_arg('--var', type=str, default='Test_AUC')
98
+ add_arg('--mode', type=str, default='max')
99
  add_arg('--clobber', action='store_true')
100
  add_arg('--tree', type=str, default='')
101
+ add_arg('--branch_name', type=str, nargs='+', required=True, help="List of branch names corresponding to configs")
102
  args = parser.parse_args()
103
 
104
+ if(len(args.config) != len(args.branch_name)):
105
+ print(f"configs and branch names do not match")
106
+ return
107
+
108
+ config = load_config(args.config[0])
109
+
110
+ # --- OUTPUT DESTINATION LOGIC ---
111
  if args.destination == '':
112
+ base_dest = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
113
+ else:
114
+ base_dest = args.destination
115
+
116
+ base_dest = base_dest.replace('.root', '').replace('.npz', '')
117
+ if args.chunks > 1:
118
+ chunked_dest = f"{base_dest}_chunk{args.chunkno}"
119
  else:
120
+ chunked_dest = base_dest
121
+ chunked_dest += '.root' if args.write else '.npz'
122
+ args.destination = chunked_dest
123
 
124
+ # --- FILE EXISTENCE CHECK ---
125
  if os.path.exists(args.destination):
126
  print(f'File {args.destination} already exists.')
127
  if args.clobber:
 
161
  dset_config['args']['selections'] = []
162
 
163
  dset_config['args']['save_dir'] = os.path.dirname(args.destination)
164
+
165
  if args.tree != '':
166
  dset_config['args']['tree_name'] = args.tree
167
 
 
176
 
177
  batch_size = config['Training']['batch_size']
178
  lstart = time.time()
179
+ loader = CustomPreBatchedDataset(
180
+ dset,
181
+ batch_size,
182
+ chunkno=args.chunkno,
183
+ chunks=args.chunks
184
+ )
185
  loader.process()
 
186
  lend = time.time()
187
  print('Loader finished in {:.2f} seconds'.format(lend - lstart))
188
  sample_graph, _, _, global_sample = loader[0]
 
190
  print('dset length =', len(dset))
191
  print('loader length =', len(loader))
192
 
193
+ all_scores = {}
194
+ all_labels = {}
195
+ all_tracking = {}
196
+ with torch.no_grad():
197
+ for config_file, branch in zip(args.config, args.branch_name):
198
+ config = load_config(config_file)
199
+ model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
200
+ if args.ckpt < 0:
201
+ ep, checkpoint = utils.get_best_epoch(config, var=args.var, mode='max', device=device)
202
+ else:
203
+ ep, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)
204
+ # Remove distributed/compiled prefixes if present
205
+ mds_copy = {}
206
+ for key in checkpoint['model_state_dict'].keys():
207
+ newkey = key.replace('module.', '')
208
+ newkey = newkey.replace('_orig_mod.', '')
209
+ mds_copy[newkey] = checkpoint['model_state_dict'][key]
210
+ model.load_state_dict(mds_copy)
211
+ model.eval()
212
+
213
+ end = time.time()
214
+ print('Model and dataset finished in {:.2f} seconds'.format(end - start))
215
+ print('Starting inference')
216
+ start = time.time()
217
+
218
+ finish_fn = torch.nn.Sigmoid()
219
+ if 'Loss' in config:
220
+ finish_fn = utils.buildFromConfig(config['Loss']['finish'])
221
+
222
+ scores = []
223
+ labels = []
224
+ tracking_info = []
225
+ ibatch = 0
226
+
227
+ for batch, label, track, globals in loader.dataloader:
228
+ batch = batch.to(device)
229
+ pred = model(batch, globals.to(device))
230
+ ibatch += 1
231
+ if (finish_fn.__class__.__name__ == "ContrastiveClusterFinish"):
232
+ scores.append(pred.detach().cpu().numpy())
233
+ else:
234
+ scores.append(finish_fn(pred).detach().cpu().numpy())
235
+ labels.append(label.detach().cpu().numpy())
236
+ tracking_info.append(track.detach().cpu().numpy())
237
+
238
+ score_size = scores[0].shape[1] if len(scores[0].shape) > 1 else 1
239
+ scores = np.concatenate(scores)
240
+ labels = np.concatenate(labels)
241
+ tracking_info = np.concatenate(tracking_info)
242
+ end = time.time()
243
+
244
+ print('Inference finished in {:.2f} seconds'.format(end - start))
245
+ all_scores[branch] = scores
246
+ all_labels[branch] = labels
247
+ all_tracking[branch] = tracking_info
 
 
 
248
 
249
  if args.write:
250
+ from ROOT import std
 
 
 
251
  # Open the original ROOT file
252
  infile = ROOT.TFile.Open(args.target)
253
  tree = infile.Get(dset_config['args']['tree_name'])
 
258
  # Create a new ROOT file to write the modified tree
259
  outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
260
 
261
+ # Clone the original tree structure
262
+ outtree = tree.CloneTree(0)
263
 
264
+ # Create branches for all scores
265
+ branch_vectors = {}
266
+ for branch, scores in all_scores.items():
267
+ if isinstance(scores[0], (list, tuple, np.ndarray)) and len(scores[0]) > 1:
268
+ # Create a new branch for vectors
269
+ branch_vectors[branch] = std.vector('float')()
270
+ outtree.Branch(branch, branch_vectors[branch])
271
+ else:
272
+ # Create a new branch for single floats
273
+ branch_vectors[branch] = array('f', [0])
274
+ outtree.Branch(branch, branch_vectors[branch], f'{branch}/F')
 
 
 
 
275
 
276
+ # Fill the tree
277
  for i in range(tree.GetEntries()):
278
  tree.GetEntry(i)
 
 
 
 
 
 
 
 
 
 
 
279
 
280
+ for branch, scores in all_scores.items():
281
+ branch_data = branch_vectors[branch]
282
+ if isinstance(branch_data, array): # Check if it's a single float array
283
+ branch_data[0] = float(scores[i])
284
+ else: # Assume it's a std::vector<float>
285
+ branch_data.clear()
286
+ for value in scores[i]:
287
+ branch_data.push_back(float(value))
288
+
289
  outtree.Fill()
290
 
291
  # Write the modified tree to the new file
292
  print(f'Writing to file {args.destination}')
293
  print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
294
+ print(f'Wrote scores to {args.branch_name}')
295
  outtree.Write()
296
  outfile.Close()
297
  infile.Close()
298
  else:
299
  os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
300
+ np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
301
 
302
  if __name__ == '__main__':
303
+ main()
 
 
 
root_gnn_dgl/scripts/prep_data.py CHANGED
@@ -1,6 +1,5 @@
1
  import sys
2
- import os
3
- file_path = os.getcwd()
4
  sys.path.append(file_path)
5
 
6
  import root_gnn_base.utils as utils
@@ -15,6 +14,7 @@ def main():
15
  add_arg('--dataset', type=str, required=True)
16
  add_arg('--chunk', type=int, default=0)
17
  add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
 
18
  args = parser.parse_args()
19
 
20
  config = utils.load_config(args.config)
@@ -32,12 +32,12 @@ def main():
32
  fold_conf = dset_config["folding"]
33
  print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
34
  if dset_config["class"] == "LazyMultiLabelDataset":
35
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
36
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
37
 
38
  else:
39
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
40
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode)
41
 
42
  if __name__ == "__main__":
43
  main()
 
1
  import sys
2
+ file_path = "/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl"
 
3
  sys.path.append(file_path)
4
 
5
  import root_gnn_base.utils as utils
 
14
  add_arg('--dataset', type=str, required=True)
15
  add_arg('--chunk', type=int, default=0)
16
  add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
17
+ add_arg('--drop_last', action='store_false', help='Set drop_last to False if the flag is provided. Defaults to True.')
18
  args = parser.parse_args()
19
 
20
  config = utils.load_config(args.config)
 
32
  fold_conf = dset_config["folding"]
33
  print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
34
  if dset_config["class"] == "LazyMultiLabelDataset":
35
+ LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
36
+ LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
37
 
38
  else:
39
+ PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
40
+ PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
41
 
42
  if __name__ == "__main__":
43
  main()
root_gnn_dgl/scripts/training_script.py CHANGED
@@ -11,9 +11,8 @@ import torch
11
  import torch.nn as nn
12
 
13
  import sys
14
- file_path = os.getcwd()
15
  sys.path.append(file_path)
16
-
17
  import root_gnn_base.batched_dataset as datasets
18
  from root_gnn_base import utils
19
  import root_gnn_base.custom_scheduler as lr_utils
@@ -29,6 +28,8 @@ import torch.multiprocessing as mp
29
  from torch.utils.data.distributed import DistributedSampler
30
  from torch.nn.parallel import DistributedDataParallel as DDP
31
 
 
 
32
  def mem():
33
  print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')
34
 
@@ -75,9 +76,9 @@ def evaluate(val_loaders, model, config, device, epoch = -1):
75
  print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
76
 
77
  if 'Loss' not in config:
78
- loss_fcn = nn.BCEWithLogitsLoss()
79
  else:
80
- loss_fcn = utils.buildFromConfig(config['Loss'])
81
  if len(val_loaders) == 0:
82
  return "No validation data"
83
  start = time.time()
@@ -143,10 +144,10 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
143
  restart = args.restart
144
  # define train/val samples, loss function and optimizer
145
  if 'Loss' not in config:
146
- loss_fcn = nn.BCEWithLogitsLoss()
147
  finish_fn = torch.nn.Sigmoid()
148
  else:
149
- loss_fcn = utils.buildFromConfig(config['Loss'])
150
  finish_fn = utils.buildFromConfig(config['Loss']['finish'])
151
 
152
  optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
@@ -280,11 +281,13 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
280
  batch_start = time.time()
281
  logits = torch.tensor([])
282
  tlabels = torch.tensor([])
 
283
  batch_lengths = []
284
  for cycler in train_cyclers:
285
- graph, label, _, global_feats = next(cycler)
286
  graph = graph.to(device)
287
  label = label.to(device)
 
288
  global_feats = global_feats.to(device)
289
  if is_padded: #Padding the globals to match padded graphs.
290
  global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
@@ -292,9 +295,11 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
292
  if (len(logits) == 0):
293
  logits = model(graph, global_feats)
294
  tlabels = label
 
295
  else:
296
  logits = torch.concatenate((logits, model(graph, global_feats)), dim=0)
297
  tlabels = torch.concatenate((tlabels, label), dim=0)
 
298
  batch_lengths.append(logits.shape[0] - 1)
299
 
300
  if is_padded:
@@ -307,7 +312,35 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
307
  tlabels = tlabels.to(torch.float)
308
  if loss_fcn.__class__.__name__ == 'CrossEntropyLoss':
309
  tlabels = tlabels.to(torch.long)
310
- loss = loss_fcn(logits, tlabels.to(device)) # changed logits from logits[:,0] and left labels as int for multiclass. Does this break binary? Yes.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  optimizer.zero_grad()
312
  loss.backward()
313
  optimizer.step()
@@ -382,6 +415,9 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
382
 
383
  wgt_mask = weights > 0
384
 
 
 
 
385
  print(f"Num batches trained = {ibatch}")
386
 
387
  #Note: This section is a bit ugly. Very conditional. Should maybe config defined behavior?
@@ -472,7 +508,29 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
472
  print(contrastive_cluster_log_str, flush=True)
473
 
474
  # test_loss = loss_fcn(logits, labels.to(device))
 
 
 
475
  test_loss = loss_fcn(logits, labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  end = time.time()
477
  log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
478
  epoch, optimizer.param_groups[0]['lr'], total_loss/ibatch, acc, test_loss, test_auc, end - start
@@ -664,6 +722,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
664
 
665
  load_end = time.time()
666
  print("Load time: {:.4f} s".format(load_end - load_start))
 
667
  model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
668
  if not args.nocompile:
669
  model = torch.compile(model)
@@ -728,6 +787,7 @@ if __name__ == "__main__":
728
  add_arg("--statistics", type=float, help="Size of training data")
729
  add_arg("--directory", type=str, help="Append to Training Directory")
730
  add_arg("--seed", type=int, default=2, help="Sets random seed")
 
731
 
732
  pargs = parser.parse_args()
733
 
 
11
  import torch.nn as nn
12
 
13
  import sys
14
+ file_path = "/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/"
15
  sys.path.append(file_path)
 
16
  import root_gnn_base.batched_dataset as datasets
17
  from root_gnn_base import utils
18
  import root_gnn_base.custom_scheduler as lr_utils
 
28
  from torch.utils.data.distributed import DistributedSampler
29
  from torch.nn.parallel import DistributedDataParallel as DDP
30
 
31
+ print("import time: {:.4f} s".format(time.time() - start_time))
32
+
33
  def mem():
34
  print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')
35
 
 
76
  print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
77
 
78
  if 'Loss' not in config:
79
+ loss_fcn = nn.BCEWithLogitsLoss(reduction='none')
80
  else:
81
+ loss_fcn = utils.buildFromConfig(config['Loss'], {'reduction': 'none'})
82
  if len(val_loaders) == 0:
83
  return "No validation data"
84
  start = time.time()
 
144
  restart = args.restart
145
  # define train/val samples, loss function and optimizer
146
  if 'Loss' not in config:
147
+ loss_fcn = nn.BCEWithLogitsLoss(reduction='none')
148
  finish_fn = torch.nn.Sigmoid()
149
  else:
150
+ loss_fcn = utils.buildFromConfig(config['Loss'], {'reduction':'none'})
151
  finish_fn = utils.buildFromConfig(config['Loss']['finish'])
152
 
153
  optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
 
281
  batch_start = time.time()
282
  logits = torch.tensor([])
283
  tlabels = torch.tensor([])
284
+ weights = torch.tensor([])
285
  batch_lengths = []
286
  for cycler in train_cyclers:
287
+ graph, label, track, global_feats = next(cycler)
288
  graph = graph.to(device)
289
  label = label.to(device)
290
+ track = track.to(device)
291
  global_feats = global_feats.to(device)
292
  if is_padded: #Padding the globals to match padded graphs.
293
  global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
 
295
  if (len(logits) == 0):
296
  logits = model(graph, global_feats)
297
  tlabels = label
298
+ weights = track[:,1]
299
  else:
300
  logits = torch.concatenate((logits, model(graph, global_feats)), dim=0)
301
  tlabels = torch.concatenate((tlabels, label), dim=0)
302
+ weights = torch.concatenate((weights, track[:,1]), dim=0)
303
  batch_lengths.append(logits.shape[0] - 1)
304
 
305
  if is_padded:
 
312
  tlabels = tlabels.to(torch.float)
313
  if loss_fcn.__class__.__name__ == 'CrossEntropyLoss':
314
  tlabels = tlabels.to(torch.long)
315
+ # loss = loss_fcn(logits, tlabels.to(device)) # changed logits from logits[:,0] and left labels as int for multiclass. Does this break binary? Yes.
316
+ # loss = torch.sum(weights * loss) / torch.sum(weights)
317
+
318
+
319
+ if args.abs:
320
+ weights = torch.abs(weights)
321
+
322
+ loss = loss_fcn(logits, tlabels.to(device))
323
+ # Normalize loss within each label
324
+ unique_labels = torch.unique(tlabels) # Get unique labels
325
+ normalized_loss = 0.0
326
+
327
+ for label in unique_labels:
328
+ # Mask for samples belonging to the current label
329
+ label_mask = (tlabels == label)
330
+
331
+ # Extract weights and losses for the current label
332
+ label_weights = weights[label_mask]
333
+ label_losses = loss[label_mask]
334
+
335
+
336
+ # Compute normalized loss for the current label
337
+ label_loss = torch.sum(label_weights * label_losses) / torch.sum(label_weights)
338
+
339
+ # Add to the total normalized loss
340
+ normalized_loss += label_loss
341
+ loss = normalized_loss / len(unique_labels)
342
+
343
+
344
  optimizer.zero_grad()
345
  loss.backward()
346
  optimizer.step()
 
415
 
416
  wgt_mask = weights > 0
417
 
418
+ if args.abs:
419
+ weights = torch.abs(weights)
420
+
421
  print(f"Num batches trained = {ibatch}")
422
 
423
  #Note: This section is a bit ugly. Very conditional. Should maybe config defined behavior?
 
508
  print(contrastive_cluster_log_str, flush=True)
509
 
510
  # test_loss = loss_fcn(logits, labels.to(device))
511
+ # test_loss = loss_fcn(logits, labels)
512
+ # test_loss = torch.sum(weights * test_loss) / torch.sum(weights)
513
+
514
  test_loss = loss_fcn(logits, labels)
515
+ # Normalize loss within each label
516
+ unique_labels = torch.unique(labels) # Get unique labels
517
+ normalized_loss = 0.0
518
+
519
+ for label in unique_labels:
520
+ # Mask for samples belonging to the current label
521
+ label_mask = (labels == label)
522
+
523
+ # Extract weights and losses for the current label
524
+ label_weights = weights[label_mask]
525
+ label_losses = test_loss[label_mask]
526
+ # Compute normalized loss for the current label
527
+ label_loss = torch.sum(label_weights * label_losses) / torch.sum(label_weights)
528
+
529
+ # Add to the total normalized loss
530
+ normalized_loss += label_loss
531
+ test_loss = normalized_loss / len(unique_labels)
532
+
533
+
534
  end = time.time()
535
  log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
536
  epoch, optimizer.param_groups[0]['lr'], total_loss/ibatch, acc, test_loss, test_auc, end - start
 
722
 
723
  load_end = time.time()
724
  print("Load time: {:.4f} s".format(load_end - load_start))
725
+
726
  model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
727
  if not args.nocompile:
728
  model = torch.compile(model)
 
787
  add_arg("--statistics", type=float, help="Size of training data")
788
  add_arg("--directory", type=str, help="Append to Training Directory")
789
  add_arg("--seed", type=int, default=2, help="Sets random seed")
790
+ add_arg("--abs", action="store_true", help="Use abs value of per-event weight")
791
 
792
  pargs = parser.parse_args()
793