root committed on
Commit
ce1f4fd
·
1 Parent(s): eb039b5

changing ROOT to uproot

Browse files
root_gnn_dgl/jobs/prep_data/run_processing.py CHANGED
@@ -77,9 +77,12 @@ def main():
77
  configs = [
78
  # "configs/stats_100K/pretraining_multiclass.yaml",
79
  # "configs/stats_100K/ttH_CP_even_vs_odd.yaml",
80
- "configs/stats_all/pretraining_multiclass.yaml",
81
- "configs/stats_all/ttH_CP_even_vs_odd.yaml",
82
  # "configs/attention/ttH_CP_even_vs_odd.yaml",
 
 
 
83
  ]
84
 
85
  # Path to the bash script to be called
 
77
  configs = [
78
  # "configs/stats_100K/pretraining_multiclass.yaml",
79
  # "configs/stats_100K/ttH_CP_even_vs_odd.yaml",
80
+ # "configs/stats_all/pretraining_multiclass.yaml",
81
+ # "configs/stats_all/ttH_CP_even_vs_odd.yaml",
82
  # "configs/attention/ttH_CP_even_vs_odd.yaml",
83
+ "configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml",
84
+ "configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml",
85
+ "configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml",
86
  ]
87
 
88
  # Path to the bash script to be called
root_gnn_dgl/profile.sh CHANGED
@@ -1,5 +1,5 @@
1
  nsys profile \
2
- -o /pscratch/sd/j/joshuaho/my_profile_report_1_gpu_batch_size_1028 \
3
  --capture-range=cudaProfilerApi \
4
  --capture-range-end=stop-shutdown \
5
  --force-overwrite true \
@@ -8,7 +8,7 @@ nsys profile \
8
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy --restart --profile
9
 
10
  nsys profile \
11
- -o /pscratch/sd/j/joshuaho/my_profile_report_1_gpu_batch_size_2048 \
12
  --capture-range=cudaProfilerApi \
13
  --capture-range-end=stop-shutdown \
14
  --force-overwrite true \
@@ -17,7 +17,7 @@ nsys profile \
17
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml --preshuffle --nocompile --lazy --restart --profile
18
 
19
  nsys profile \
20
- -o /pscratch/sd/j/joshuaho/my_profile_report_1_gpu_batch_size_4096 \
21
  --capture-range=cudaProfilerApi \
22
  --capture-range-end=stop-shutdown \
23
  --force-overwrite true \
@@ -26,7 +26,7 @@ nsys profile \
26
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml --preshuffle --nocompile --lazy --restart --profile
27
 
28
  nsys profile \
29
- -o /pscratch/sd/j/joshuaho/my_profile_report_1_gpu_batch_size_8192 \
30
  --capture-range=cudaProfilerApi \
31
  --capture-range-end=stop-shutdown \
32
  --force-overwrite true \
 
1
  nsys profile \
2
+ -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_1028 \
3
  --capture-range=cudaProfilerApi \
4
  --capture-range-end=stop-shutdown \
5
  --force-overwrite true \
 
8
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy --restart --profile
9
 
10
  nsys profile \
11
+ -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_2048 \
12
  --capture-range=cudaProfilerApi \
13
  --capture-range-end=stop-shutdown \
14
  --force-overwrite true \
 
17
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml --preshuffle --nocompile --lazy --restart --profile
18
 
19
  nsys profile \
20
+ -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_4096 \
21
  --capture-range=cudaProfilerApi \
22
  --capture-range-end=stop-shutdown \
23
  --force-overwrite true \
 
26
  python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml --preshuffle --nocompile --lazy --restart --profile
27
 
28
  nsys profile \
29
+ -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_8192 \
30
  --capture-range=cudaProfilerApi \
31
  --capture-range-end=stop-shutdown \
32
  --force-overwrite true \
root_gnn_dgl/root_gnn_base/dataset.py CHANGED
@@ -1,6 +1,7 @@
1
  from dgl.data import DGLDataset
2
  import dgl
3
- import ROOT
 
4
  import torch
5
  import os
6
  import glob
@@ -14,7 +15,7 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
14
  if node_type == 'single':
15
  lengths.append(1)
16
  elif node_type == 'vector':
17
- lengths.append(len(getattr(ch, branch)))
18
  else:
19
  print('Unknown node branch type: {}'.format(node_type))
20
  features = []
@@ -38,16 +39,14 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
38
  this_type_ends_at = sum(lengths[:itype+1])
39
  feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
40
  elif node_type == 'single':
41
- feat.append(getattr(ch, branch))
42
  elif node_type == 'vector':
43
- feat.extend(getattr(ch, branch))
44
  itype += 1
45
  features.append(torch.tensor(feat))
46
  return torch.stack(features, dim=1) * node_feature_scales, lengths
47
 
48
  def full_connected_graph(n_nodes, self_loops=True):
49
- senders = []
50
- receivers = []
51
  senders = np.arange(n_nodes*n_nodes) // n_nodes
52
  receivers = np.arange(n_nodes*n_nodes) % n_nodes
53
  if not self_loops and n_nodes > 1:
@@ -59,19 +58,18 @@ def full_connected_graph(n_nodes, self_loops=True):
59
  def check_selection(ch, selection):
60
  var, cut, op = selection
61
  if op == '>':
62
- return getattr(ch, var) > cut
63
  elif op == '<':
64
- return getattr(ch, var) < cut
65
  elif op == '==':
66
- return getattr(ch, var) == cut
67
-
68
  def check_selections(ch, selections):
69
  for selection in selections:
70
  if not check_selection(ch, selection):
71
  return False
72
  return True
73
 
74
- #Base dataset class for making graphs from ROOT ntuples.
75
  class RootDataset(DGLDataset):
76
  def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
77
  selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
@@ -88,7 +86,7 @@ class RootDataset(DGLDataset):
88
  self.fold_var = fold_var
89
  self.tracking_info = tracking_info
90
  self.tracking_info.insert(0, fold_var)
91
- if weight_var == None:
92
  weight_var = 1
93
  self.tracking_info.insert(1, weight_var)
94
  self.global_features = global_features
@@ -116,7 +114,7 @@ class RootDataset(DGLDataset):
116
  branches.append(feat)
117
  for selection in self.selections:
118
  branches.append(selection[0])
119
- return branches
120
 
121
  def make_graph(self, ch):
122
  t1 = time.time()
@@ -129,7 +127,7 @@ class RootDataset(DGLDataset):
129
  self.times[0] += t2 - t1
130
  self.times[1] += t3 - t2
131
  return g
132
-
133
  def process(self):
134
  times = [0, 0, 0]
135
  oldtime = time.time()
@@ -139,21 +137,21 @@ class RootDataset(DGLDataset):
139
  self.files = []
140
  for file_name in self.file_names:
141
  self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
142
- self.chain = ROOT.TChain(self.tree_name)
143
 
144
- if len(self.files) == 0:
145
- print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
146
  for file in self.files:
147
- utils.set_timeout(60*2)
148
- self.chain.Add(file)
149
- utils.unset_timeout()
150
- branches = self.get_list_of_branches()
151
- self.chain.SetBranchStatus('*', 0)
152
- for branch in branches:
153
- self.chain.SetBranchStatus(branch, 1)
154
  newtime = time.time()
155
  times[0] += newtime - oldtime
156
- chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
157
  chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
158
 
159
  self.graph_chunks = []
@@ -169,28 +167,28 @@ class RootDataset(DGLDataset):
169
  globals = []
170
  for ientry in chunk:
171
  if (ientry % 10000 == 0):
172
- print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
173
- self.chain.GetEntry(ientry)
174
  passed = True
175
  for selection in self.selections:
176
- if not check_selection(self.chain, selection):
177
  passed = False
178
  continue
179
  oldtime = newtime
180
  newtime = time.time()
181
  times[1] += newtime - oldtime
182
  if passed:
183
- graphs.append(self.make_graph(self.chain))
184
- labels.append( self.label )
185
  tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
186
  globals.append(torch.zeros(len(self.global_features)))
187
  for i_ti, tr_branch in enumerate(self.tracking_info):
188
  if isinstance(tr_branch, str):
189
- tracking[-1][i_ti] = getattr(self.chain, tr_branch)
190
  else:
191
  tracking[-1][i_ti] = tr_branch
192
  for i_gl, gl_branch in enumerate(self.global_features):
193
- globals[-1][i_gl] = getattr(self.chain, gl_branch)
194
  oldtime = newtime
195
  newtime = time.time()
196
  times[2] += newtime - oldtime
@@ -198,7 +196,6 @@ class RootDataset(DGLDataset):
198
  labels = torch.tensor(labels)
199
  tracking = torch.stack(tracking)
200
  globals = torch.stack(globals)
201
-
202
  if (self.chunks > 1):
203
  self.save_chunk(chunk_id, graphs, labels, tracking, globals)
204
  else:
@@ -208,29 +205,15 @@ class RootDataset(DGLDataset):
208
  self.graphs = graphs
209
  self.save()
210
  return
211
- self.graphs = self.graph_chunks[0]
212
- for chunk in self.graph_chunks[1:]:
213
- self.graphs += chunk
214
- self.labels = torch.cat(self.label_chunks)
215
- self.tracking = torch.cat(self.tracking_chunks)
216
- self.global_features = torch.cat(self.global_chunks)
217
- print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
218
- print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
219
-
220
  def save(self):
221
- """save the graph list and the labels"""
222
  if not self.save_to_disk:
223
  return
224
  graph_path = os.path.join(self.save_dir, self.name + '.bin')
225
  if self.chunks == 1:
226
- # print(len(self.graphs))
227
- # print(len(self.labels))
228
- # print(len(self.tracking))
229
- # print(len(self.globals))
230
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
231
  dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
232
  else:
233
- print(len(self.graph_chunks))
234
  for i in range(len(self.process_chunks)):
235
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
236
  dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
@@ -241,7 +224,7 @@ class RootDataset(DGLDataset):
241
  graph_path = os.path.join(self.save_dir, self.name + '.bin')
242
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
243
  dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
244
-
245
  def has_cache(self):
246
  print(f'Checking for cache of {self.name}')
247
  if not self.save_to_disk:
@@ -290,7 +273,7 @@ class RootDataset(DGLDataset):
290
 
291
  def __len__(self):
292
  return len(self.graphs)
293
-
294
  #Dataset with edge features added (deta, dphi, dR)
295
  class EdgeDataset(RootDataset):
296
  def make_graph(self, ch):
 
1
  from dgl.data import DGLDataset
2
  import dgl
3
+ import uproot
4
+ import awkward as ak
5
  import torch
6
  import os
7
  import glob
 
15
  if node_type == 'single':
16
  lengths.append(1)
17
  elif node_type == 'vector':
18
+ lengths.append(len(ch[branch]))
19
  else:
20
  print('Unknown node branch type: {}'.format(node_type))
21
  features = []
 
39
  this_type_ends_at = sum(lengths[:itype+1])
40
  feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
41
  elif node_type == 'single':
42
+ feat.append(ch[branch])
43
  elif node_type == 'vector':
44
+ feat.extend(ch[branch])
45
  itype += 1
46
  features.append(torch.tensor(feat))
47
  return torch.stack(features, dim=1) * node_feature_scales, lengths
48
 
49
  def full_connected_graph(n_nodes, self_loops=True):
 
 
50
  senders = np.arange(n_nodes*n_nodes) // n_nodes
51
  receivers = np.arange(n_nodes*n_nodes) % n_nodes
52
  if not self_loops and n_nodes > 1:
 
58
  def check_selection(ch, selection):
59
  var, cut, op = selection
60
  if op == '>':
61
+ return ch[var] > cut
62
  elif op == '<':
63
+ return ch[var] < cut
64
  elif op == '==':
65
+ return ch[var] == cut
66
+
67
  def check_selections(ch, selections):
68
  for selection in selections:
69
  if not check_selection(ch, selection):
70
  return False
71
  return True
72
 
 
73
  class RootDataset(DGLDataset):
74
  def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
75
  selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
 
86
  self.fold_var = fold_var
87
  self.tracking_info = tracking_info
88
  self.tracking_info.insert(0, fold_var)
89
+ if weight_var is None:
90
  weight_var = 1
91
  self.tracking_info.insert(1, weight_var)
92
  self.global_features = global_features
 
114
  branches.append(feat)
115
  for selection in self.selections:
116
  branches.append(selection[0])
117
+ return list(set(branches)) # Remove duplicates
118
 
119
  def make_graph(self, ch):
120
  t1 = time.time()
 
127
  self.times[0] += t2 - t1
128
  self.times[1] += t3 - t2
129
  return g
130
+
131
  def process(self):
132
  times = [0, 0, 0]
133
  oldtime = time.time()
 
137
  self.files = []
138
  for file_name in self.file_names:
139
  self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
140
+ branches = self.get_list_of_branches()
141
 
142
+ # Read all files and concatenate arrays
143
+ arrays = []
144
  for file in self.files:
145
+ with uproot.open(file) as f:
146
+ arrays.append(f[self.tree_name].arrays(branches, library="ak"))
147
+ if len(arrays) == 0:
148
+ print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
149
+ return
150
+ data = ak.concatenate(arrays, axis=0)
151
+ n_entries = len(data[branches[0]])
152
  newtime = time.time()
153
  times[0] += newtime - oldtime
154
+ chunks = np.array_split(np.arange(n_entries), self.chunks)
155
  chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
156
 
157
  self.graph_chunks = []
 
167
  globals = []
168
  for ientry in chunk:
169
  if (ientry % 10000 == 0):
170
+ print('Processing event {}/{}'.format(ientry, n_entries), flush=True)
171
+ ch = {b: data[b][ientry] for b in branches}
172
  passed = True
173
  for selection in self.selections:
174
+ if not check_selection(ch, selection):
175
  passed = False
176
  continue
177
  oldtime = newtime
178
  newtime = time.time()
179
  times[1] += newtime - oldtime
180
  if passed:
181
+ graphs.append(self.make_graph(ch))
182
+ labels.append(self.label)
183
  tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
184
  globals.append(torch.zeros(len(self.global_features)))
185
  for i_ti, tr_branch in enumerate(self.tracking_info):
186
  if isinstance(tr_branch, str):
187
+ tracking[-1][i_ti] = ch[tr_branch]
188
  else:
189
  tracking[-1][i_ti] = tr_branch
190
  for i_gl, gl_branch in enumerate(self.global_features):
191
+ globals[-1][i_gl] = ch[gl_branch]
192
  oldtime = newtime
193
  newtime = time.time()
194
  times[2] += newtime - oldtime
 
196
  labels = torch.tensor(labels)
197
  tracking = torch.stack(tracking)
198
  globals = torch.stack(globals)
 
199
  if (self.chunks > 1):
200
  self.save_chunk(chunk_id, graphs, labels, tracking, globals)
201
  else:
 
205
  self.graphs = graphs
206
  self.save()
207
  return
208
+
 
 
 
 
 
 
 
 
209
  def save(self):
 
210
  if not self.save_to_disk:
211
  return
212
  graph_path = os.path.join(self.save_dir, self.name + '.bin')
213
  if self.chunks == 1:
 
 
 
 
214
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
215
  dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
216
  else:
 
217
  for i in range(len(self.process_chunks)):
218
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
219
  dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
 
224
  graph_path = os.path.join(self.save_dir, self.name + '.bin')
225
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
226
  dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
227
+
228
  def has_cache(self):
229
  print(f'Checking for cache of {self.name}')
230
  if not self.save_to_disk:
 
273
 
274
  def __len__(self):
275
  return len(self.graphs)
276
+
277
  #Dataset with edge features added (deta, dphi, dR)
278
  class EdgeDataset(RootDataset):
279
  def make_graph(self, ch):
root_gnn_dgl/scripts/inference.py CHANGED
@@ -135,7 +135,6 @@ def main():
135
 
136
  import time
137
  start = time.time()
138
- import ROOT
139
  import torch
140
  from array import array
141
  import numpy as np
@@ -247,55 +246,41 @@ def main():
247
  all_labels[branch] = labels
248
  all_tracking[branch] = tracking_info
249
 
250
- if args.write:
251
- from ROOT import std
252
- # Open the original ROOT file
253
- infile = ROOT.TFile.Open(args.target)
254
- tree = infile.Get(dset_config['args']['tree_name'])
255
 
256
- # Create the destination directory if it doesn't exist
257
- os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
 
258
 
259
- # Create a new ROOT file to write the modified tree
260
- outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
 
261
 
262
- # Clone the original tree structure
263
- outtree = tree.CloneTree(0)
264
 
265
- # Create branches for all scores
266
- branch_vectors = {}
 
267
  for branch, scores in all_scores.items():
268
- if isinstance(scores[0], (list, tuple, np.ndarray)) and len(scores[0]) > 1:
269
- # Create a new branch for vectors
270
- branch_vectors[branch] = std.vector('float')()
271
- outtree.Branch(branch, branch_vectors[branch])
272
- else:
273
- # Create a new branch for single floats
274
- branch_vectors[branch] = array('f', [0])
275
- outtree.Branch(branch, branch_vectors[branch], f'{branch}/F')
276
-
277
- # Fill the tree
278
- for i in range(tree.GetEntries()):
279
- tree.GetEntry(i)
280
-
281
- for branch, scores in all_scores.items():
282
- branch_data = branch_vectors[branch]
283
- if isinstance(branch_data, array): # Check if it's a single float array
284
- branch_data[0] = float(scores[i])
285
- else: # Assume it's a std::vector<float>
286
- branch_data.clear()
287
- for value in scores[i]:
288
- branch_data.push_back(float(value))
289
-
290
- outtree.Fill()
291
-
292
- # Write the modified tree to the new file
293
- print(f'Writing to file {args.destination}')
294
- print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
295
- print(f'Wrote scores to {args.branch_name}')
296
- outtree.Write()
297
- outfile.Close()
298
- infile.Close()
299
  else:
300
  os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
301
  np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
 
135
 
136
  import time
137
  start = time.time()
 
138
  import torch
139
  from array import array
140
  import numpy as np
 
246
  all_labels[branch] = labels
247
  all_tracking[branch] = tracking_info
248
 
 
 
 
 
 
249
 
250
+ if args.write:
251
+ import uproot
252
+ import awkward as ak
253
 
254
+ # Open the original ROOT file and get the tree
255
+ infile = uproot.open(args.target)
256
+ tree = infile[dset_config['args']['tree_name']]
257
 
258
+ # Read the original tree as an awkward array
259
+ original_data = tree.arrays(library="ak")
260
 
261
+ # Prepare new branches as dicts of arrays
262
+ new_branches = {}
263
+ n_entries = len(original_data)
264
  for branch, scores in all_scores.items():
265
+ # Ensure the scores array is the right length
266
+ scores = np.asarray(scores)
267
+ if scores.shape[0] != n_entries:
268
+ raise ValueError(f"Branch '{branch}' has {scores.shape[0]} entries, but tree has {n_entries}")
269
+ new_branches[branch] = scores
270
+
271
+ # Merge all arrays (original + new branches)
272
+ # Convert awkward to dict of numpy arrays for uproot
273
+ out_dict = {k: np.asarray(v) for k, v in ak.to_numpy(original_data).items()}
274
+ out_dict.update(new_branches)
275
+
276
+ # Write to new ROOT file
277
+ os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
278
+ with uproot.recreate(args.destination) as outfile:
279
+ outfile.mktree(dset_config['args']['tree_name'], {k: v.dtype for k, v in out_dict.items()})
280
+ outfile[dset_config['args']['tree_name']].extend(out_dict)
281
+
282
+ print(f"Wrote new ROOT file {args.destination} with new branches {list(new_branches.keys())}")
283
+
 
 
 
 
 
 
 
 
 
 
 
 
284
  else:
285
  os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
286
  np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)