root
commited on
Commit
·
ce1f4fd
1
Parent(s):
eb039b5
changing ROOT to uproot
Browse files
root_gnn_dgl/jobs/prep_data/run_processing.py
CHANGED
|
@@ -77,9 +77,12 @@ def main():
|
|
| 77 |
configs = [
|
| 78 |
# "configs/stats_100K/pretraining_multiclass.yaml",
|
| 79 |
# "configs/stats_100K/ttH_CP_even_vs_odd.yaml",
|
| 80 |
-
"configs/stats_all/pretraining_multiclass.yaml",
|
| 81 |
-
"configs/stats_all/ttH_CP_even_vs_odd.yaml",
|
| 82 |
# "configs/attention/ttH_CP_even_vs_odd.yaml",
|
|
|
|
|
|
|
|
|
|
| 83 |
]
|
| 84 |
|
| 85 |
# Path to the bash script to be called
|
|
|
|
| 77 |
configs = [
|
| 78 |
# "configs/stats_100K/pretraining_multiclass.yaml",
|
| 79 |
# "configs/stats_100K/ttH_CP_even_vs_odd.yaml",
|
| 80 |
+
# "configs/stats_all/pretraining_multiclass.yaml",
|
| 81 |
+
# "configs/stats_all/ttH_CP_even_vs_odd.yaml",
|
| 82 |
# "configs/attention/ttH_CP_even_vs_odd.yaml",
|
| 83 |
+
"configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml",
|
| 84 |
+
"configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml",
|
| 85 |
+
"configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml",
|
| 86 |
]
|
| 87 |
|
| 88 |
# Path to the bash script to be called
|
root_gnn_dgl/profile.sh
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
nsys profile \
|
| 2 |
-
-o /pscratch/sd/j/joshuaho/
|
| 3 |
--capture-range=cudaProfilerApi \
|
| 4 |
--capture-range-end=stop-shutdown \
|
| 5 |
--force-overwrite true \
|
|
@@ -8,7 +8,7 @@ nsys profile \
|
|
| 8 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 9 |
|
| 10 |
nsys profile \
|
| 11 |
-
-o /pscratch/sd/j/joshuaho/
|
| 12 |
--capture-range=cudaProfilerApi \
|
| 13 |
--capture-range-end=stop-shutdown \
|
| 14 |
--force-overwrite true \
|
|
@@ -17,7 +17,7 @@ nsys profile \
|
|
| 17 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 18 |
|
| 19 |
nsys profile \
|
| 20 |
-
-o /pscratch/sd/j/joshuaho/
|
| 21 |
--capture-range=cudaProfilerApi \
|
| 22 |
--capture-range-end=stop-shutdown \
|
| 23 |
--force-overwrite true \
|
|
@@ -26,7 +26,7 @@ nsys profile \
|
|
| 26 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 27 |
|
| 28 |
nsys profile \
|
| 29 |
-
-o /pscratch/sd/j/joshuaho/
|
| 30 |
--capture-range=cudaProfilerApi \
|
| 31 |
--capture-range-end=stop-shutdown \
|
| 32 |
--force-overwrite true \
|
|
|
|
| 1 |
nsys profile \
|
| 2 |
+
-o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_1028 \
|
| 3 |
--capture-range=cudaProfilerApi \
|
| 4 |
--capture-range-end=stop-shutdown \
|
| 5 |
--force-overwrite true \
|
|
|
|
| 8 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 9 |
|
| 10 |
nsys profile \
|
| 11 |
+
-o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_2048 \
|
| 12 |
--capture-range=cudaProfilerApi \
|
| 13 |
--capture-range-end=stop-shutdown \
|
| 14 |
--force-overwrite true \
|
|
|
|
| 17 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 18 |
|
| 19 |
nsys profile \
|
| 20 |
+
-o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_4096 \
|
| 21 |
--capture-range=cudaProfilerApi \
|
| 22 |
--capture-range-end=stop-shutdown \
|
| 23 |
--force-overwrite true \
|
|
|
|
| 26 |
python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml --preshuffle --nocompile --lazy --restart --profile
|
| 27 |
|
| 28 |
nsys profile \
|
| 29 |
+
-o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_8192 \
|
| 30 |
--capture-range=cudaProfilerApi \
|
| 31 |
--capture-range-end=stop-shutdown \
|
| 32 |
--force-overwrite true \
|
root_gnn_dgl/root_gnn_base/dataset.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from dgl.data import DGLDataset
|
| 2 |
import dgl
|
| 3 |
-
import
|
|
|
|
| 4 |
import torch
|
| 5 |
import os
|
| 6 |
import glob
|
|
@@ -14,7 +15,7 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
|
|
| 14 |
if node_type == 'single':
|
| 15 |
lengths.append(1)
|
| 16 |
elif node_type == 'vector':
|
| 17 |
-
lengths.append(len(
|
| 18 |
else:
|
| 19 |
print('Unknown node branch type: {}'.format(node_type))
|
| 20 |
features = []
|
|
@@ -38,16 +39,14 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
|
|
| 38 |
this_type_ends_at = sum(lengths[:itype+1])
|
| 39 |
feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
|
| 40 |
elif node_type == 'single':
|
| 41 |
-
feat.append(
|
| 42 |
elif node_type == 'vector':
|
| 43 |
-
feat.extend(
|
| 44 |
itype += 1
|
| 45 |
features.append(torch.tensor(feat))
|
| 46 |
return torch.stack(features, dim=1) * node_feature_scales, lengths
|
| 47 |
|
| 48 |
def full_connected_graph(n_nodes, self_loops=True):
|
| 49 |
-
senders = []
|
| 50 |
-
receivers = []
|
| 51 |
senders = np.arange(n_nodes*n_nodes) // n_nodes
|
| 52 |
receivers = np.arange(n_nodes*n_nodes) % n_nodes
|
| 53 |
if not self_loops and n_nodes > 1:
|
|
@@ -59,19 +58,18 @@ def full_connected_graph(n_nodes, self_loops=True):
|
|
| 59 |
def check_selection(ch, selection):
|
| 60 |
var, cut, op = selection
|
| 61 |
if op == '>':
|
| 62 |
-
return
|
| 63 |
elif op == '<':
|
| 64 |
-
return
|
| 65 |
elif op == '==':
|
| 66 |
-
return
|
| 67 |
-
|
| 68 |
def check_selections(ch, selections):
|
| 69 |
for selection in selections:
|
| 70 |
if not check_selection(ch, selection):
|
| 71 |
return False
|
| 72 |
return True
|
| 73 |
|
| 74 |
-
#Base dataset class for making graphs from ROOT ntuples.
|
| 75 |
class RootDataset(DGLDataset):
|
| 76 |
def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
|
| 77 |
selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
|
|
@@ -88,7 +86,7 @@ class RootDataset(DGLDataset):
|
|
| 88 |
self.fold_var = fold_var
|
| 89 |
self.tracking_info = tracking_info
|
| 90 |
self.tracking_info.insert(0, fold_var)
|
| 91 |
-
if weight_var
|
| 92 |
weight_var = 1
|
| 93 |
self.tracking_info.insert(1, weight_var)
|
| 94 |
self.global_features = global_features
|
|
@@ -116,7 +114,7 @@ class RootDataset(DGLDataset):
|
|
| 116 |
branches.append(feat)
|
| 117 |
for selection in self.selections:
|
| 118 |
branches.append(selection[0])
|
| 119 |
-
return branches
|
| 120 |
|
| 121 |
def make_graph(self, ch):
|
| 122 |
t1 = time.time()
|
|
@@ -129,7 +127,7 @@ class RootDataset(DGLDataset):
|
|
| 129 |
self.times[0] += t2 - t1
|
| 130 |
self.times[1] += t3 - t2
|
| 131 |
return g
|
| 132 |
-
|
| 133 |
def process(self):
|
| 134 |
times = [0, 0, 0]
|
| 135 |
oldtime = time.time()
|
|
@@ -139,21 +137,21 @@ class RootDataset(DGLDataset):
|
|
| 139 |
self.files = []
|
| 140 |
for file_name in self.file_names:
|
| 141 |
self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
|
| 142 |
-
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
for file in self.files:
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
newtime = time.time()
|
| 155 |
times[0] += newtime - oldtime
|
| 156 |
-
chunks = np.array_split(np.arange(
|
| 157 |
chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
|
| 158 |
|
| 159 |
self.graph_chunks = []
|
|
@@ -169,28 +167,28 @@ class RootDataset(DGLDataset):
|
|
| 169 |
globals = []
|
| 170 |
for ientry in chunk:
|
| 171 |
if (ientry % 10000 == 0):
|
| 172 |
-
print('Processing event {}/{}'.format(ientry,
|
| 173 |
-
|
| 174 |
passed = True
|
| 175 |
for selection in self.selections:
|
| 176 |
-
if not check_selection(
|
| 177 |
passed = False
|
| 178 |
continue
|
| 179 |
oldtime = newtime
|
| 180 |
newtime = time.time()
|
| 181 |
times[1] += newtime - oldtime
|
| 182 |
if passed:
|
| 183 |
-
graphs.append(self.make_graph(
|
| 184 |
-
labels.append(
|
| 185 |
tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
|
| 186 |
globals.append(torch.zeros(len(self.global_features)))
|
| 187 |
for i_ti, tr_branch in enumerate(self.tracking_info):
|
| 188 |
if isinstance(tr_branch, str):
|
| 189 |
-
tracking[-1][i_ti] =
|
| 190 |
else:
|
| 191 |
tracking[-1][i_ti] = tr_branch
|
| 192 |
for i_gl, gl_branch in enumerate(self.global_features):
|
| 193 |
-
globals[-1][i_gl] =
|
| 194 |
oldtime = newtime
|
| 195 |
newtime = time.time()
|
| 196 |
times[2] += newtime - oldtime
|
|
@@ -198,7 +196,6 @@ class RootDataset(DGLDataset):
|
|
| 198 |
labels = torch.tensor(labels)
|
| 199 |
tracking = torch.stack(tracking)
|
| 200 |
globals = torch.stack(globals)
|
| 201 |
-
|
| 202 |
if (self.chunks > 1):
|
| 203 |
self.save_chunk(chunk_id, graphs, labels, tracking, globals)
|
| 204 |
else:
|
|
@@ -208,29 +205,15 @@ class RootDataset(DGLDataset):
|
|
| 208 |
self.graphs = graphs
|
| 209 |
self.save()
|
| 210 |
return
|
| 211 |
-
|
| 212 |
-
for chunk in self.graph_chunks[1:]:
|
| 213 |
-
self.graphs += chunk
|
| 214 |
-
self.labels = torch.cat(self.label_chunks)
|
| 215 |
-
self.tracking = torch.cat(self.tracking_chunks)
|
| 216 |
-
self.global_features = torch.cat(self.global_chunks)
|
| 217 |
-
print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
|
| 218 |
-
print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
|
| 219 |
-
|
| 220 |
def save(self):
|
| 221 |
-
"""save the graph list and the labels"""
|
| 222 |
if not self.save_to_disk:
|
| 223 |
return
|
| 224 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 225 |
if self.chunks == 1:
|
| 226 |
-
# print(len(self.graphs))
|
| 227 |
-
# print(len(self.labels))
|
| 228 |
-
# print(len(self.tracking))
|
| 229 |
-
# print(len(self.globals))
|
| 230 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
|
| 231 |
dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
|
| 232 |
else:
|
| 233 |
-
print(len(self.graph_chunks))
|
| 234 |
for i in range(len(self.process_chunks)):
|
| 235 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
|
| 236 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
|
|
@@ -241,7 +224,7 @@ class RootDataset(DGLDataset):
|
|
| 241 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 242 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
|
| 243 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
|
| 244 |
-
|
| 245 |
def has_cache(self):
|
| 246 |
print(f'Checking for cache of {self.name}')
|
| 247 |
if not self.save_to_disk:
|
|
@@ -290,7 +273,7 @@ class RootDataset(DGLDataset):
|
|
| 290 |
|
| 291 |
def __len__(self):
|
| 292 |
return len(self.graphs)
|
| 293 |
-
|
| 294 |
#Dataset with edge features added (deta, dphi, dR)
|
| 295 |
class EdgeDataset(RootDataset):
|
| 296 |
def make_graph(self, ch):
|
|
|
|
| 1 |
from dgl.data import DGLDataset
|
| 2 |
import dgl
|
| 3 |
+
import uproot
|
| 4 |
+
import awkward as ak
|
| 5 |
import torch
|
| 6 |
import os
|
| 7 |
import glob
|
|
|
|
| 15 |
if node_type == 'single':
|
| 16 |
lengths.append(1)
|
| 17 |
elif node_type == 'vector':
|
| 18 |
+
lengths.append(len(ch[branch]))
|
| 19 |
else:
|
| 20 |
print('Unknown node branch type: {}'.format(node_type))
|
| 21 |
features = []
|
|
|
|
| 39 |
this_type_ends_at = sum(lengths[:itype+1])
|
| 40 |
feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
|
| 41 |
elif node_type == 'single':
|
| 42 |
+
feat.append(ch[branch])
|
| 43 |
elif node_type == 'vector':
|
| 44 |
+
feat.extend(ch[branch])
|
| 45 |
itype += 1
|
| 46 |
features.append(torch.tensor(feat))
|
| 47 |
return torch.stack(features, dim=1) * node_feature_scales, lengths
|
| 48 |
|
| 49 |
def full_connected_graph(n_nodes, self_loops=True):
|
|
|
|
|
|
|
| 50 |
senders = np.arange(n_nodes*n_nodes) // n_nodes
|
| 51 |
receivers = np.arange(n_nodes*n_nodes) % n_nodes
|
| 52 |
if not self_loops and n_nodes > 1:
|
|
|
|
| 58 |
def check_selection(ch, selection):
|
| 59 |
var, cut, op = selection
|
| 60 |
if op == '>':
|
| 61 |
+
return ch[var] > cut
|
| 62 |
elif op == '<':
|
| 63 |
+
return ch[var] < cut
|
| 64 |
elif op == '==':
|
| 65 |
+
return ch[var] == cut
|
| 66 |
+
|
| 67 |
def check_selections(ch, selections):
|
| 68 |
for selection in selections:
|
| 69 |
if not check_selection(ch, selection):
|
| 70 |
return False
|
| 71 |
return True
|
| 72 |
|
|
|
|
| 73 |
class RootDataset(DGLDataset):
|
| 74 |
def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
|
| 75 |
selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
|
|
|
|
| 86 |
self.fold_var = fold_var
|
| 87 |
self.tracking_info = tracking_info
|
| 88 |
self.tracking_info.insert(0, fold_var)
|
| 89 |
+
if weight_var is None:
|
| 90 |
weight_var = 1
|
| 91 |
self.tracking_info.insert(1, weight_var)
|
| 92 |
self.global_features = global_features
|
|
|
|
| 114 |
branches.append(feat)
|
| 115 |
for selection in self.selections:
|
| 116 |
branches.append(selection[0])
|
| 117 |
+
return list(set(branches)) # Remove duplicates
|
| 118 |
|
| 119 |
def make_graph(self, ch):
|
| 120 |
t1 = time.time()
|
|
|
|
| 127 |
self.times[0] += t2 - t1
|
| 128 |
self.times[1] += t3 - t2
|
| 129 |
return g
|
| 130 |
+
|
| 131 |
def process(self):
|
| 132 |
times = [0, 0, 0]
|
| 133 |
oldtime = time.time()
|
|
|
|
| 137 |
self.files = []
|
| 138 |
for file_name in self.file_names:
|
| 139 |
self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
|
| 140 |
+
branches = self.get_list_of_branches()
|
| 141 |
|
| 142 |
+
# Read all files and concatenate arrays
|
| 143 |
+
arrays = []
|
| 144 |
for file in self.files:
|
| 145 |
+
with uproot.open(file) as f:
|
| 146 |
+
arrays.append(f[self.tree_name].arrays(branches, library="ak"))
|
| 147 |
+
if len(arrays) == 0:
|
| 148 |
+
print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
|
| 149 |
+
return
|
| 150 |
+
data = ak.concatenate(arrays, axis=0)
|
| 151 |
+
n_entries = len(data[branches[0]])
|
| 152 |
newtime = time.time()
|
| 153 |
times[0] += newtime - oldtime
|
| 154 |
+
chunks = np.array_split(np.arange(n_entries), self.chunks)
|
| 155 |
chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
|
| 156 |
|
| 157 |
self.graph_chunks = []
|
|
|
|
| 167 |
globals = []
|
| 168 |
for ientry in chunk:
|
| 169 |
if (ientry % 10000 == 0):
|
| 170 |
+
print('Processing event {}/{}'.format(ientry, n_entries), flush=True)
|
| 171 |
+
ch = {b: data[b][ientry] for b in branches}
|
| 172 |
passed = True
|
| 173 |
for selection in self.selections:
|
| 174 |
+
if not check_selection(ch, selection):
|
| 175 |
passed = False
|
| 176 |
continue
|
| 177 |
oldtime = newtime
|
| 178 |
newtime = time.time()
|
| 179 |
times[1] += newtime - oldtime
|
| 180 |
if passed:
|
| 181 |
+
graphs.append(self.make_graph(ch))
|
| 182 |
+
labels.append(self.label)
|
| 183 |
tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
|
| 184 |
globals.append(torch.zeros(len(self.global_features)))
|
| 185 |
for i_ti, tr_branch in enumerate(self.tracking_info):
|
| 186 |
if isinstance(tr_branch, str):
|
| 187 |
+
tracking[-1][i_ti] = ch[tr_branch]
|
| 188 |
else:
|
| 189 |
tracking[-1][i_ti] = tr_branch
|
| 190 |
for i_gl, gl_branch in enumerate(self.global_features):
|
| 191 |
+
globals[-1][i_gl] = ch[gl_branch]
|
| 192 |
oldtime = newtime
|
| 193 |
newtime = time.time()
|
| 194 |
times[2] += newtime - oldtime
|
|
|
|
| 196 |
labels = torch.tensor(labels)
|
| 197 |
tracking = torch.stack(tracking)
|
| 198 |
globals = torch.stack(globals)
|
|
|
|
| 199 |
if (self.chunks > 1):
|
| 200 |
self.save_chunk(chunk_id, graphs, labels, tracking, globals)
|
| 201 |
else:
|
|
|
|
| 205 |
self.graphs = graphs
|
| 206 |
self.save()
|
| 207 |
return
|
| 208 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
def save(self):
|
|
|
|
| 210 |
if not self.save_to_disk:
|
| 211 |
return
|
| 212 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 213 |
if self.chunks == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
|
| 215 |
dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
|
| 216 |
else:
|
|
|
|
| 217 |
for i in range(len(self.process_chunks)):
|
| 218 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
|
| 219 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
|
|
|
|
| 224 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 225 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
|
| 226 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
|
| 227 |
+
|
| 228 |
def has_cache(self):
|
| 229 |
print(f'Checking for cache of {self.name}')
|
| 230 |
if not self.save_to_disk:
|
|
|
|
| 273 |
|
| 274 |
def __len__(self):
|
| 275 |
return len(self.graphs)
|
| 276 |
+
|
| 277 |
#Dataset with edge features added (deta, dphi, dR)
|
| 278 |
class EdgeDataset(RootDataset):
|
| 279 |
def make_graph(self, ch):
|
root_gnn_dgl/scripts/inference.py
CHANGED
|
@@ -135,7 +135,6 @@ def main():
|
|
| 135 |
|
| 136 |
import time
|
| 137 |
start = time.time()
|
| 138 |
-
import ROOT
|
| 139 |
import torch
|
| 140 |
from array import array
|
| 141 |
import numpy as np
|
|
@@ -247,55 +246,41 @@ def main():
|
|
| 247 |
all_labels[branch] = labels
|
| 248 |
all_tracking[branch] = tracking_info
|
| 249 |
|
| 250 |
-
if args.write:
|
| 251 |
-
from ROOT import std
|
| 252 |
-
# Open the original ROOT file
|
| 253 |
-
infile = ROOT.TFile.Open(args.target)
|
| 254 |
-
tree = infile.Get(dset_config['args']['tree_name'])
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
|
|
|
| 258 |
|
| 259 |
-
#
|
| 260 |
-
|
|
|
|
| 261 |
|
| 262 |
-
#
|
| 263 |
-
|
| 264 |
|
| 265 |
-
#
|
| 266 |
-
|
|
|
|
| 267 |
for branch, scores in all_scores.items():
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
for value in scores[i]:
|
| 288 |
-
branch_data.push_back(float(value))
|
| 289 |
-
|
| 290 |
-
outtree.Fill()
|
| 291 |
-
|
| 292 |
-
# Write the modified tree to the new file
|
| 293 |
-
print(f'Writing to file {args.destination}')
|
| 294 |
-
print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
|
| 295 |
-
print(f'Wrote scores to {args.branch_name}')
|
| 296 |
-
outtree.Write()
|
| 297 |
-
outfile.Close()
|
| 298 |
-
infile.Close()
|
| 299 |
else:
|
| 300 |
os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
|
| 301 |
np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
|
|
|
|
| 135 |
|
| 136 |
import time
|
| 137 |
start = time.time()
|
|
|
|
| 138 |
import torch
|
| 139 |
from array import array
|
| 140 |
import numpy as np
|
|
|
|
| 246 |
all_labels[branch] = labels
|
| 247 |
all_tracking[branch] = tracking_info
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
if args.write:
|
| 251 |
+
import uproot
|
| 252 |
+
import awkward as ak
|
| 253 |
|
| 254 |
+
# Open the original ROOT file and get the tree
|
| 255 |
+
infile = uproot.open(args.target)
|
| 256 |
+
tree = infile[dset_config['args']['tree_name']]
|
| 257 |
|
| 258 |
+
# Read the original tree as an awkward array
|
| 259 |
+
original_data = tree.arrays(library="ak")
|
| 260 |
|
| 261 |
+
# Prepare new branches as dicts of arrays
|
| 262 |
+
new_branches = {}
|
| 263 |
+
n_entries = len(original_data)
|
| 264 |
for branch, scores in all_scores.items():
|
| 265 |
+
# Ensure the scores array is the right length
|
| 266 |
+
scores = np.asarray(scores)
|
| 267 |
+
if scores.shape[0] != n_entries:
|
| 268 |
+
raise ValueError(f"Branch '{branch}' has {scores.shape[0]} entries, but tree has {n_entries}")
|
| 269 |
+
new_branches[branch] = scores
|
| 270 |
+
|
| 271 |
+
# Merge all arrays (original + new branches)
|
| 272 |
+
# Convert awkward to dict of numpy arrays for uproot
|
| 273 |
+
out_dict = {k: np.asarray(v) for k, v in ak.to_numpy(original_data).items()}
|
| 274 |
+
out_dict.update(new_branches)
|
| 275 |
+
|
| 276 |
+
# Write to new ROOT file
|
| 277 |
+
os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
|
| 278 |
+
with uproot.recreate(args.destination) as outfile:
|
| 279 |
+
outfile.mktree(dset_config['args']['tree_name'], {k: v.dtype for k, v in out_dict.items()})
|
| 280 |
+
outfile[dset_config['args']['tree_name']].extend(out_dict)
|
| 281 |
+
|
| 282 |
+
print(f"Wrote new ROOT file {args.destination} with new branches {list(new_branches.keys())}")
|
| 283 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
else:
|
| 285 |
os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
|
| 286 |
np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
|