Yajur's First Pull

#2
by ypreetham - opened
Files changed (31)
  1. .gitattributes +0 -1
  2. README.md +3 -354
  3. root_gnn_dgl/README.md +0 -2
  4. root_gnn_dgl/configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml +1 -1
  5. root_gnn_dgl/configs/stats_100K/pretraining_multiclass.yaml +1 -1
  6. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml +0 -57
  7. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml +0 -57
  8. root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192.yaml +0 -57
  9. root_gnn_dgl/configs/stats_all/finetuning_ttH_CP_even_vs_odd.yaml +0 -5
  10. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml +0 -57
  11. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml +0 -57
  12. root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml +0 -57
  13. root_gnn_dgl/jobs/prep_data/run_processing.py +1 -4
  14. root_gnn_dgl/jobs/training/singlegpu/run_job.sh +3 -3
  15. root_gnn_dgl/jobs/training/singlegpu/run_job_image.sh +11 -2
  16. root_gnn_dgl/jobs/training/singlegpu/submit.sh +1 -4
  17. root_gnn_dgl/models/GCN.py +3 -6
  18. root_gnn_dgl/profile.sh +0 -35
  19. root_gnn_dgl/root_gnn_base/batched_dataset.py +2 -3
  20. root_gnn_dgl/root_gnn_base/dataset.py +49 -41
  21. root_gnn_dgl/root_gnn_base/utils.py +8 -94
  22. root_gnn_dgl/root_gnn_base/visualize_input_distributions.py +0 -582
  23. root_gnn_dgl/run_demo.sh +1 -3
  24. root_gnn_dgl/scripts/check_dataset_files.py +0 -130
  25. root_gnn_dgl/scripts/inference.py +45 -30
  26. root_gnn_dgl/scripts/prep_data.py +5 -5
  27. root_gnn_dgl/scripts/training_script.py +11 -40
  28. root_gnn_dgl/setup/Dockerfile +0 -25
  29. root_gnn_dgl/setup/build_image.sh +0 -4
  30. root_gnn_dgl/setup/environment_torch24.yaml +0 -249
  31. training_time.png +0 -3
.gitattributes CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- training_time.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,354 +1,3 @@
- ---
- license: mit
- language:
- - en
- pipeline_tag: graph-ml
- tags:
- - physics
- - GNN
- - FoundationModel
- ---
-
-
- ## Abstract
-
- We introduce a foundation model for event classification in high-energy physics, built on a **Graph Neural Network** architecture and trained on **120 million simulated proton-proton collision events** spanning 12 distinct physics processes. The model is *pretrained* to learn a general and robust representation of collision data using challenging multiclass and multilabel classification tasks.
-
- Its performance is evaluated across five event classification tasks, which include both physics processes used during pretraining and new processes not encountered during pretraining. Fine-tuning the pretrained model significantly improves classification performance, particularly in scenarios with limited training data, demonstrating gains in both accuracy and computational efficiency.
-
- To investigate the underlying mechanisms behind these performance improvements, we employ a representational similarity evaluation framework based on *Centered Kernel Alignment*. This analysis reveals notable differences in the learned representations of fine-tuned pretrained models compared to baseline models trained from scratch.
-
- ## Introduction
-
- Machine learning has become a ubiquitous tool in particle physics, employed in a variety of tasks including triggering, simulation, reconstruction, and offline analysis. While its utility spans classification, regression, and generative tasks, the current paradigm of developing machine learning models from scratch for each specific application presents several challenges. This approach not only demands specialized expertise and substantial computing resources but can also result in suboptimal performance due to limited training data. The from-scratch development of models necessitates individual validation studies to ensure that neural networks utilize well-modeled information from training samples, whether derived from Monte Carlo simulations or control samples from experimental data.
-
- Foundation models offer a promising direction to address these limitations. These models, pre-trained on large, diverse datasets across various tasks, provide robust and general representations of underlying data structures. Notable examples in other fields include GPT-4 [OpenAI et al., 2024](#ref-openai-2024-gpt4) and BERT [Devlin et al., 2018](#ref-devlin-2018-bert) in natural language processing, Stable Diffusion [Rombach et al., 2021](#ref-rombach-2021-latentdiffusion) in image processing, and AlphaFold [Jumper et al., 2021](#ref-jumper-2021-alphafold) in structural biology. The foundation model approach offers several advantages for particle physics applications: reduced computing resources for fine-tuning [Yosinski et al., 2014](#ref-yosinski-2014-transfer) compared to training from scratch, superior performance on specific tasks (particularly with limited training data), and potentially simplified validation procedures, as downstream tasks inherit verified representations from the pre-trained model.
-
- Current literature on pretrained models for particle physics can be categorized by the data representation they handle. Models operating on particle- or event-level numerical data use features like particle four-momenta or jets, leveraging self-supervised or generative methods to learn versatile representations. Detector-focused models operate on high-dimensional responses such as calorimeter deposits or pixel hits, employing geometry-aware techniques for accurate simulation and analysis. Finally, models using textual or code representations apply large language model architectures to integrate domain knowledge, enabling tasks like question answering and code generation.
-
- Recent studies have begun exploring foundation models tailored to particle physics data, which has a variety of distinct structures and properties across many experiments and data processing stages, including:
-
- - particle-level & event-level numeric data [Wildridge et al., 2024](#ref-wildridge-2024-bumblebee), [Katel et al., 2024](#ref-katel-2024-jet), [Golling et al., 2024](#ref-golling-2024-maskedset), [Mikuni & Nachman, 2024](#ref-mikuni-2024-omnilearn), [Harris et al., 2024](#ref-harris-2024-resimulation), [Birk et al., 2024](#ref-birk-2024-omnijet), [Vigl et al., 2024](#ref-vigl-2024-finetune),
- - detector-level & geometry-aware data [Araz et al., 2024](#ref-araz-2024-pointcloud), [Liu et al., 2023](#ref-liu-2023-gaam), [Hashemi et al., 2024](#ref-hashemi-2024-gen), [Huang et al., 2024](#ref-huang-2024-lmtracking),
- - textual or code data [Zhang et al., 2024](#ref-zhang-2024-xiwu).
-
- This paper presents a foundation model designed specifically for collider event-level data. In modern collider experiments, final-stage analysis processes information from reconstructed objects that either directly correspond to particles in collision final states (such as leptons and photons) or serve as proxies (such as jets and missing transverse energy). While traditional approaches often relied on "high-level" variables calculated from object features, recent trends favor direct input of event objects and their features into neural networks for analysis tasks. A notable example is [ATLAS Collaboration, 2023](#ref-atlas-2023-4top), which established the observation of simultaneous production of four top quarks with the ATLAS experiment by employing a graph neural network (GNN) architecture to process event-level object information.
-
- We present foundation models that adopt an architecture similar to that used in [ATLAS Collaboration, 2023](#ref-atlas-2023-4top). Our models are pre-trained using either multiclass classification or multi-label learning tasks across 12 distinct physics processes. We evaluate these models through fine-tuning and testing on five classification tasks, including both familiar and novel processes not seen during pre-training. Our analysis benchmarks the models' performance improvements, their scaling behavior with training sample size, and computational efficiency, representing the first prototype of a foundation model operating on collider final-state object data.
-
- ## Data Samples
-
- To provide a diverse set of physics processes for the pretraining, we use MadGraph5_aMC@NLO 2.7.3 [Alwall et al., 2014](#ref-alwall-2014hca) to generate proton-proton collision events at next-to-leading order (NLO) in Quantum Chromodynamics (QCD). We generate 12 distinct Standard Model (SM) physics processes, including six major Higgs boson production mechanisms: gluon fusion production (\\(ggF\\)), vector boson fusion (\\(VBF\\)), associated production of the Higgs boson with a W boson (\\(WH\\)) or a Z boson (\\(ZH\\)), associated production of the Higgs boson with a top-quark pair (\\(t\bar{t}H\\)), and associated production of the Higgs boson with a single top quark and a forward quark (\\(tHq\\)). Additionally, we simulate six top quark production processes: single top production, top-quark pair production (\\(t\bar{t}\\)), top-quark pair production in association with a pair of photons (\\(t\bar{t}\gamma\gamma\\)), associated production of a top-quark pair with a W boson (\\(t\bar{t}W\\)), simultaneous production of three top quarks (\\(t\bar{t}t\\)), and simultaneous production of four top quarks (\\(t\bar{t}t\bar{t}\\)). In these samples, the Higgs boson and top quarks decay inclusively. These 12 Higgs and top quark production processes constitute the pretraining dataset.
-
- To test the pretrained model, we further generated four processes, three of which are beyond-the-SM (BSM) processes: an SM \\(t\bar{t}H\\) production where the Higgs boson decays exclusively to a pair of photons; a \\(t\bar{t}H\\) production with the Higgs boson decaying to a pair of photons, where the top-Yukawa coupling is CP-odd, implemented using the Higgs Characterization model [Artoisenet et al., 2013](#ref-artoisinet-2013puc); the production of a pair of superpartners of the top quark (s-top) using the Minimal Supersymmetric Standard Model (MSSM) [Rosiek, 1990](#ref-rosiek-1990), [Allanach et al., 2009](#ref-allanach-2009); and flavor-changing neutral current (FCNC) processes [Degrande et al., 2015](#ref-degrande-2015), [Durieux et al., 2015](#ref-durieux-2015). For the s-top process, we simulate the production of heavier s-top pairs (\\(t_2\bar{t_2}\\)), where each heavier s-top (mass 582 GeV) decays into a lighter s-top (\\(t_1\\) or \\(\bar{t_1}\\), mass 400 GeV) and a Higgs boson. The FCNC process involves \\(t\bar{t}\\) production where one top quark decays to a Higgs boson and a light quark. We generate 10 million events for each process, except for \\(tHq\\) and \\(t\bar{t}t\bar{t}\\), where 5 million events were produced.
-
- In all simulation samples, the center-of-mass energy of the proton-proton collision is set to 13 TeV. The Higgs boson, top quarks, and vector bosons are set to decay inclusively (except in the \\(t\bar{t}H \rightarrow \gamma\gamma\\) samples), with MadSpin [Artoisenet et al., 2012](#ref-artoisinet-2012st) handling the decays of top quarks and W bosons. The generated events are processed through Pythia 8.235 [Sjostrand et al., 2015](#ref-sjostrand-2015) for parton showering and heavy particle decays, followed by Delphes 3.4.2 [de Favereau et al., 2014](#ref-defavereau-2014) configured to emulate the ATLAS detector [ATLAS Collaboration, 2008](#ref-atlas-2008) for fast detector simulation.
-
- The detector-level object selection criteria are defined to align with typical experimental conditions. Photons are required to have transverse momentum \\(p_T \geq 20~\mathrm{GeV}\\) and pseudorapidity \\(|\eta| \leq 2.37\\), excluding the electromagnetic calorimeter crack region (\\(1.37 < |\eta| < 1.52\\)). Electrons must have \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.47\\) (excluding the same crack region), while muons are selected with \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.7\\). Jets are reconstructed using the anti-\\(k_t\\) algorithm [Cacciari et al., 2008](#ref-cacciari-2008gp) with radius parameter \\(\Delta R=0.4\\), where \\(\Delta R\\) is defined as \\(\sqrt{\Delta\eta ^2 + \Delta\phi^2}\\), with \\(\Delta\eta\\) being the difference in pseudorapidity and \\(\Delta\phi\\) the difference in azimuthal angle. Jets must satisfy \\(p_T \geq 25~\mathrm{GeV}\\) and \\(|\eta| \leq 2.5\\). To avoid double-counting, jets are removed if they are within \\(\Delta R < 0.4\\) of a photon or lepton. The identification of jets originating from b-quark decays (b-tagging) is performed by matching jets within \\(\Delta R = 0.4\\) of a b-quark, with efficiency corrections applied to match the performance of the ATLAS experiment's b-tagging algorithm [ATLAS Collaboration, 2019](#ref-atlas-2019bwq).
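For illustration, the \\(\Delta R\\)-based overlap removal described above can be sketched in a few lines of Python; the object format (dicts with `eta`/`phi` keys) and the function names are illustrative, not the repository's actual code.

```python
import math

def delta_phi(phi1, phi2):
    """Azimuthal angle difference wrapped into (-pi, pi]."""
    dphi = (phi1 - phi2) % (2 * math.pi)
    if dphi > math.pi:
        dphi -= 2 * math.pi
    return dphi

def delta_r(eta1, phi1, eta2, phi2):
    """Angular distance sqrt(deta^2 + dphi^2)."""
    return math.hypot(eta1 - eta2, delta_phi(phi1, phi2))

def remove_overlapping_jets(jets, leptons_photons, r_max=0.4):
    """Drop any jet within dR < r_max of a photon or lepton.

    Each object is a dict with 'eta' and 'phi' keys (illustrative format).
    """
    return [
        j for j in jets
        if all(delta_r(j["eta"], j["phi"], o["eta"], o["phi"]) >= r_max
               for o in leptons_photons)
    ]
```

For example, a jet at \\((\eta, \phi) = (0.1, 0.1)\\) lies at \\(\Delta R \approx 0.14\\) from a photon at the origin and would be discarded.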
-
- ## Methods
-
- ### Overview
-
- We present a methodology for developing and evaluating a foundation model for particle collision event analysis. The approach centers on pretraining a Graph Neural Network (GNN) architecture using a comprehensive dataset that spans multiple physics tasks, enabling the model to learn robust and transferable features. For task-specific applications, we employ a fine-tuning strategy that combines output layer adaptation with carefully calibrated learning rates for updating the pretrained parameters.
-
- Given the prevalence of classification problems in particle physics data analysis, we evaluate the model's efficacy through a systematic assessment across five binary classification tasks:
-
- - \\(t\bar{t}H(\rightarrow \gamma\gamma)\\) with CP-even versus CP-odd t-H interaction
- - \\(t\bar{t}\\) with FCNC top quark decays versus \\(tHq\\) processes
- - \\(t\bar{t}W\\) versus \\(t\bar{t}t\\) processes
- - S-top pair production with Higgs bosons in the decay chain versus \\(t\bar{t}H\\) processes
- - \\(WH\\) versus \\(ZH\\) production modes
-
- Our evaluation metrics encompass classification performance, computational efficiency, and model interpretability. The investigation extends to analyzing the model's scaling behavior with respect to training dataset size, benchmarked against models trained without pretraining. Although we explored transfer learning through parameter freezing of pretrained layers, this approach did not yield performance improvements, leading us to focus our detailed analysis on fine-tuning strategies.
-
- This methodological framework demonstrates the potential of foundation models to enhance the efficiency of particle physics analyses while improving task-specific performance, offering a promising direction for future high-energy physics research.
66
-
67
- ---
68
-
69
- ### GNN Architecture
70
-
71
- We implement a Graph Neural Network (GNN) architecture that naturally accommodates the point-cloud structure of particle physics data, employing the DGL framework with a PyTorch backend [Wang et al., 2019](#ref-dgl-2019), [Paszke et al., 2019](#ref-pytorch-2019). A fully connected graph is constructed for each event, with nodes corresponding to reconstructed jets, electrons, muons, photons, and \\(\vec{E}_T^{\text{miss}}\\). The features of each node include the four-momentum \\((p_T, \eta, \phi, E)\\) of the object with a massless assumption (\\(E = p_T \cosh \eta\\)), the b-tagging label (for jets), the charge (for leptons), and an integer labeling the type of object represented by the node. We use a placeholder value of 0 for features which are not defined for every node type such as the b-jet tag, lepton charge, or the pseudorapidity of \\(\vec{E}_T^{\text{miss}}\\). We assign the angular distances (\\(\Delta \eta, \Delta \phi, \Delta R\\)) as edge features and the number of nodes \\(N\\) in the graph as a global feature. We denote the node features \\(\{\vec x_i\}\\), edge features \\(\{\vec y_{ij}\}\\), and global features \\(\{\vec z\}\\).
72
-
73
- The GNN model is based on the graph network architecture described in [Battaglia et al., 2018](#ref-graphnets-2018) using simple multilayer perceptron (MLP) feature functions and summation aggregation. The model is comprised of three primary components: an encoder, the graph network, and a decoder. In the encoder, three MLPs embed the nodes, edges, and global features into a latent space of dimension 64. The graph network block, which is designed to facilitate message passing between different domains of the graph, performs an edge update \\(f_e\\), followed by a node update \\(f_n\\), and finally a global update \\(f_g\\), all defined below. The inputs to each update MLP are concatenated.
74
-
75
- $$
76
- \vec {y'}_{ij} = f_e\left(\{\vec x_k\},\vec y_{ij},\vec z\right) = \mathrm{MLP}\left(\vec x_i,\vec x_j,\vec y_{ij},\vec z\right)
77
- $$
78
-
79
- $$
80
- \vec{x'}_{i} = f_n\left(\vec x_i,\{\vec{y'}_{jk}\},\vec z\right) = \mathrm{MLP}\left(\vec x_i,\sum_j\vec{y'}_{ij},\vec z\right)
81
- $$
82
-
83
- $$
84
- \vec{z'} = f_g\left(\{\vec{x'}_i\},\{\vec{y'}_{ij}\},\vec z\right) = \mathrm{MLP}\left(\sum_i\vec{x'}_i,\sum_{i,j}\vec{y'}_{ij},\vec z\right)
85
- $$
86
-
87
- This graph block is iterated four times with the same update MLPs. Finally, the global features are passed through a decoder MLP and a final layer linear to produce the desired model outputs. Each MLP consists of 4 linear layers, each with an output width of 64, with the `ReLU` activation function. The output of the MLP is then passed through a `LayerNorm` layer [Ba et al., 2016](#ref-layernorm-2016). The total number of trainable parameters in this model is about 400,000.
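As an illustration of the three updates, a single graph-block iteration can be sketched in numpy. One random linear layer with `ReLU` stands in for each 4-layer MLP, and the dimensions and initialization are illustrative stand-ins, not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 64  # latent dimension used in the paper

def mlp(in_dim, out_dim=D):
    """Stand-in for a 4-layer MLP: one random linear layer plus ReLU."""
    W = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda v: np.maximum(v @ W, 0.0)

f_e = mlp(4 * D)  # inputs: x_i, x_j, y_ij, z
f_n = mlp(3 * D)  # inputs: x_i, sum_j y'_ij, z
f_g = mlp(3 * D)  # inputs: sum_i x'_i, sum_ij y'_ij, z

def graph_block(x, y, z):
    """One message-passing iteration on a fully connected graph.

    x: (N, D) node features, y: (N, N, D) edge features, z: (D,) global.
    """
    N = x.shape[0]
    zt = np.broadcast_to(z, (N, N, D))
    xi = np.broadcast_to(x[:, None, :], (N, N, D))
    xj = np.broadcast_to(x[None, :, :], (N, N, D))
    # Edge update: y'_ij = MLP(x_i, x_j, y_ij, z)
    y_new = f_e(np.concatenate([xi, xj, y, zt], axis=-1))
    # Node update: x'_i = MLP(x_i, sum_j y'_ij, z)
    x_new = f_n(np.concatenate(
        [x, y_new.sum(axis=1), np.broadcast_to(z, (N, D))], axis=-1))
    # Global update: z' = MLP(sum_i x'_i, sum_ij y'_ij, z)
    z_new = f_g(np.concatenate(
        [x_new.sum(axis=0), y_new.sum(axis=(0, 1)), z]))
    return x_new, y_new, z_new
```

Iterating `graph_block` four times with the same functions mirrors the weight-shared iteration described above.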
-
- As a performance benchmark, a baseline GNN model is trained from scratch for each classification task. The initial learning rate is set to \\(10^{-4}\\) with an exponential decay following \\(LR(x) = LR_{\text{initial}}\cdot(0.99)^x\\), where \\(x\\) represents the epoch number.
-
- ---
-
- ### Pretraining Strategy
-
- We explore two complementary pretraining approaches to develop robust representations of collision events: (1) multi-class classification, which trains the model to distinguish between different physics processes, and (2) multi-label classification, which predicts the existence and kinematics of heavy particles with prompt decays. The pretraining dataset consists of approximately 120 million events, evenly distributed across 12 distinct physics processes, including all major Higgs boson production mechanisms and top quark processes as described in [Data Samples](#sec-data). This large-scale pretraining effort was conducted on the Perlmutter supercomputer at NERSC.
-
- #### Multi-class Classification
-
- For Monte Carlo simulated events, the underlying physics process that generated each event is known precisely, providing natural labels for supervised learning. However, the challenge lies in the complexity of collision events: different physics processes can produce similar kinematics and event topologies, particularly in certain regions of phase space. No single observable can unambiguously identify the underlying process. By training the model to distinguish between 12 different processes simultaneously, we challenge it to learn subtle differences in kinematics and topology that collectively characterize each process. The model is trained using categorical cross-entropy as the loss function. The output layer of the multiclass classification model has 832 trainable parameters.
-
- #### Multi-label Classification
-
- This approach combines both classification and regression tasks to characterize collision events. For discrete properties like particle presence in specific kinematic regions, we employ classification labels with binary cross-entropy loss. For count-like quantities such as particle multiplicities, we use regression labels with mean-squared error loss. This hybrid approach enables the model to learn both categorical and continuous aspects of the physics processes simultaneously.
-
- We develop a comprehensive set of 41 labels that capture both particle multiplicities and kinematic properties. This approach increases prediction granularity and enhances model interpretability. By training the model to predict event kinematics rather than event identification, we create a task-independent framework that can potentially generalize better to novel scenarios not seen during pretraining.
-
- The particle multiplicity labels count the number of Higgs bosons (\\(n_{\text{higgs}}\\)), top quarks (\\(n_{\text{tops}}\\)), vector bosons (\\(n_V\\)), \\(W\\) bosons (\\(n_W\\)), and \\(Z\\) bosons (\\(n_Z\\)). The kinematic labels characterize the transverse momentum (\\(p_T\\)), pseudorapidity (\\(\eta\\)), and azimuthal angle (\\(\phi\\)) of Higgs bosons and top quarks through binned classifications.
-
- For Higgs bosons, \\(p_T\\) is categorized into three ranges: (0, 30) GeV, (30, 200) GeV, and (200, \\(\infty\\)) GeV, with the upper range particularly sensitive to potential BSM effects. Similarly, both leading and subleading top quarks have \\(p_T\\) classifications spanning (0, 30) GeV, (30, 300) GeV, and (300, \\(\infty\\)) GeV. When no particle exists within a specific \\(p_T\\) range, the corresponding label is set to \\([0, 0, 0]\\). For all particles, \\(\eta\\) measurements are divided into 4 bins with boundaries at \\([-1.5, 0, 1.5]\\), while \\(\phi\\) measurements use 4 bins with boundaries at \\([-\frac{\pi}{2}, 0, \frac{\pi}{2}]\\). As with \\(p_T\\), both \\(\eta\\) and \\(\phi\\) labels default to \\([0, 0, 0, 0]\\) in the absence of a particle. This comprehensive labeling schema enables fine-grained learning of kinematic distributions and particle multiplicities, essential for characterizing complex collision events.
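A minimal sketch of the binned \\(p_T\\) labels, using the Higgs boson bin edges given above; the function name and the `None` convention for an absent particle are illustrative, not the repository's code.

```python
def pt_label(pt_gev, edges=(0.0, 30.0, 200.0)):
    """One-hot label for a particle's pT over (0,30), (30,200), (200,inf) GeV.

    Returns [0, 0, 0] when no particle is present (pt_gev is None),
    matching the all-zero convention described in the text.
    """
    if pt_gev is None:
        return [0, 0, 0]
    label = [0, 0, 0]
    if pt_gev >= edges[2]:
        label[2] = 1
    elif pt_gev >= edges[1]:
        label[1] = 1
    else:
        label[0] = 1
    return label
```

The top-quark labels would use edges (0, 30, 300) instead, and the \\(\eta\\) and \\(\phi\\) labels follow the same pattern with four bins.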
-
- The loss function combines individual losses from all 41 labels through weighted averaging. Binary cross-entropy is applied to classification labels, while mean-squared error is used for regression labels. The model generates predictions for all labels simultaneously, with individual losses calculated according to their respective types. The final loss is computed as an equally weighted average across all labels, with all weights set to 1 to ensure uniform contribution to the optimization process. The output layer of the multilabel model has 2,688 trainable parameters.
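The equally weighted combination can be sketched as follows; the scalar labels and helper names are illustrative, and the model applies the same idea per label vector.

```python
import math

def bce(p, t, eps=1e-12):
    """Binary cross-entropy for one predicted probability p and target t."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

def combined_loss(cls_preds, cls_targets, reg_preds, reg_targets):
    """Equally weighted average of per-label losses (all weights = 1).

    Classification labels use binary cross-entropy; regression labels
    use mean-squared error, as described in the text.
    """
    losses = [bce(p, t) for p, t in zip(cls_preds, cls_targets)]
    losses += [(p - t) ** 2 for p, t in zip(reg_preds, reg_targets)]
    return sum(losses) / len(losses)
```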
-
- #### Pretraining
-
- During pre-training, the initial learning rate is \\(10^{-4}\\), and the learning rate decays by 1% each epoch following the exponential function \\(LR(x) = 10^{-4}\cdot(0.99)^x\\), where \\(x\\) is the number of epochs. Both pre-trained models reach a plateau in loss by epoch 50, at which point the training is stopped.
-
- ---
-
- ### Fine-tuning Methodology
-
- For downstream tasks, we adjust the model architecture for fine-tuning by replacing the original output layer (final linear layer) with a newly initialized linear layer while retaining the pre-trained weights for all other layers. This modification allows the model to specialize in the specific downstream task while leveraging the general features learned during pretraining.
-
- The fine-tuning process begins with distinct learning rate setups for different parts of the model. The newly initialized linear layer is trained with an initial learning rate of \\(10^{-4}\\), matching the rate used for models trained from scratch. Meanwhile, the pre-trained layers are fine-tuned more cautiously with a lower initial learning rate of \\(10^{-5}\\). This approach ensures that the pre-trained layers adapt gradually without losing their general features, while the new layer learns effectively from scratch. Both learning rates decay over time following the same exponential function, \\(LR(x) = LR_{\text{initial}} \cdot (0.99)^x\\), to promote stable convergence as training progresses.
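The two schedules can be sketched directly from the formula; the helper names are illustrative, and in PyTorch the same effect could be obtained with optimizer parameter groups plus `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)`.

```python
def lr_schedule(epoch, lr_initial, decay=0.99):
    """Exponentially decayed learning rate: LR(x) = LR_initial * 0.99**x."""
    return lr_initial * decay ** epoch

def finetune_lrs(epoch):
    """Per-epoch rates for the new output layer vs. the pretrained layers."""
    return {
        "new_output_layer": lr_schedule(epoch, 1e-4),   # same as from-scratch
        "pretrained_layers": lr_schedule(epoch, 1e-5),  # 10x more cautious
    }
```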
-
- We also evaluated a transfer learning setup in which either the decoder MLP or the final linear layer was replaced with a newly initialized component. During this process, all other model parameters remained frozen, leveraging the pre-trained features without further updating them. However, we did not observe performance improvements using the transfer learning setup. Consequently, we focus on reporting results obtained with the fine-tuning approach.
-
- ---
-
- ### Performance Evaluation
-
- We assess model performance using two figures of merit: the classification accuracy and the Area Under the Curve (AUC) of the Receiver Operating Characteristic (ROC) curve. The accuracy is defined as the fraction of correctly classified events when applying a threshold of 0.5 to the neural network output score. Both metrics demonstrate consistent trends in our analysis.
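Both figures of merit can be sketched generically; this rank-based AUC equals the Mann-Whitney statistic and is an illustrative implementation, not the analysis code.

```python
def accuracy(scores, labels, threshold=0.5):
    """Fraction of events whose thresholded score matches the label."""
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

def roc_auc(scores, labels):
    """AUC as the probability that a random signal event outscores a
    random background event, with ties counted as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```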
-
- To obtain reliable performance estimates and uncertainties, we employ an ensemble training approach in which 5 independent models are trained for each configuration with random weight initialization and random subsets of the training dataset. This enables us both to evaluate the models' sensitivity to initial parameters and to quantify uncertainties in their performance.
-
- To investigate how model performance scales with training data, we conducted training runs using sample sizes ranging from \\(10^3\\) to \\(10^7\\) events per class (\\(10^3\\), \\(10^4\\), \\(10^5\\), \\(10^6\\), and \\(10^7\\)) for each model setup: the from-scratch baseline and models fine-tuned from multi-class or multi-label pretrained models. For the \\(10^7\\) case, only the initialization was randomized due to dataset size limitations. All models were evaluated on the same testing dataset, consisting of 2 million events per class, which remained separate from the training process.
-
- | **Name of Task** | **Pretraining Task** | \\(10^3\\) | \\(10^4\\) | \\(10^5\\) | \\(10^6\\) | \\(10^7\\) |
- |----------------------|----------------------|----------------|----------------|----------------|----------------|----------------|
- | **ttH CP Even vs Odd** | Baseline Accuracy | 56.5 \\(\pm\\) 1.1 | 62.2 \\(\pm\\) 0.1 | 64.3 \\(\pm\\) 0.0 | 65.7 \\(\pm\\) 0.0 | 66.2 \\(\pm\\) 0.0 |
- | | Multiclass (%) | +4.8 \\(\pm\\) 1.1 | +3.4 \\(\pm\\) 0.1 | +1.3 \\(\pm\\) 0.0 | +0.2 \\(\pm\\) 0.0 | -0.0 \\(\pm\\) 0.0 |
- | | Multilabel (%) | +2.1 \\(\pm\\) 1.2 | +1.9 \\(\pm\\) 0.1 | +0.8 \\(\pm\\) 0.1 | +0.0 \\(\pm\\) 0.0 | -0.1 \\(\pm\\) 0.0 |
- | **FCNC vs tHq** | Baseline Accuracy | 63.6 \\(\pm\\) 0.7 | 67.8 \\(\pm\\) 0.4 | 68.4 \\(\pm\\) 0.3 | 69.3 \\(\pm\\) 0.3 | 67.9 \\(\pm\\) 0.0 |
- | | Multiclass (%) | +5.8 \\(\pm\\) 0.8 | +1.2 \\(\pm\\) 0.4 | +1.4 \\(\pm\\) 0.3 | +0.5 \\(\pm\\) 0.3 | -0.0 \\(\pm\\) 0.0 |
- | | Multilabel (%) | -5.3 \\(\pm\\) 0.8 | -1.3 \\(\pm\\) 0.4 | +0.9 \\(\pm\\) 0.4 | +0.3 \\(\pm\\) 0.3 | +0.4 \\(\pm\\) 0.1 |
- | **ttW vs ttt** | Baseline Accuracy | 75.8 \\(\pm\\) 0.1 | 77.6 \\(\pm\\) 0.1 | 78.9 \\(\pm\\) 0.0 | 79.8 \\(\pm\\) 0.0 | 80.3 \\(\pm\\) 0.0 |
- | | Multiclass (%) | +3.7 \\(\pm\\) 0.1 | +2.7 \\(\pm\\) 0.1 | +1.3 \\(\pm\\) 0.0 | +0.4 \\(\pm\\) 0.0 | +0.0 \\(\pm\\) 0.0 |
- | | Multilabel (%) | +2.2 \\(\pm\\) 0.1 | +1.1 \\(\pm\\) 0.1 | +0.5 \\(\pm\\) 0.0 | +0.0 \\(\pm\\) 0.0 | -0.1 \\(\pm\\) 0.0 |
- | **stop vs ttH** | Baseline Accuracy | 83.0 \\(\pm\\) 0.2 | 86.3 \\(\pm\\) 0.1 | 87.6 \\(\pm\\) 0.0 | 88.5 \\(\pm\\) 0.0 | 88.8 \\(\pm\\) 0.0 |
- | | Multiclass (%) | +0.4 \\(\pm\\) 0.2 | +1.9 \\(\pm\\) 0.1 | +1.0 \\(\pm\\) 0.0 | +0.3 \\(\pm\\) 0.0 | +0.0 \\(\pm\\) 0.0 |
- | | Multilabel (%) | +2.8 \\(\pm\\) 0.2 | +1.0 \\(\pm\\) 0.1 | +0.5 \\(\pm\\) 0.0 | +0.0 \\(\pm\\) 0.0 | -0.0 \\(\pm\\) 0.0 |
- | **WH vs ZH** | Baseline Accuracy | 51.4 \\(\pm\\) 0.1 | 53.9 \\(\pm\\) 0.1 | 55.8 \\(\pm\\) 0.0 | 57.5 \\(\pm\\) 0.0 | 58.0 \\(\pm\\) 0.0 |
- | | Multiclass (%) | +5.2 \\(\pm\\) 0.1 | +5.3 \\(\pm\\) 0.1 | +3.1 \\(\pm\\) 0.0 | +0.6 \\(\pm\\) 0.0 | +0.1 \\(\pm\\) 0.0 |
- | | Multilabel (%) | -1.1 \\(\pm\\) 0.1 | -0.9 \\(\pm\\) 0.2 | +0.5 \\(\pm\\) 0.1 | +0.1 \\(\pm\\) 0.0 | -0.1 \\(\pm\\) 0.0 |
-
- > **Table 1**: Accuracy of the baseline model versus the accuracy increase due to fine-tuning from various pretraining tasks.
- > The accuracies are averaged over 5 independently trained models with randomly initialized weights, each trained on a random subset of the data. One exception is the \\(10^7\\) training, where all models use the same dataset due to limitations on our dataset size. The random subsets are allowed to overlap, but this overlap should be minimal because each model takes an independent random subset of \\(10^7\\) events. The testing accuracy is calculated from the same testing set of 2 million events per class across all models for a specific training task. The errors are the propagated errors (root sum of squares) of the standard deviations of accuracies for each model.
-
- ## Results
-
- ### Classification Performance
-
- Since AUC and accuracy show similar trends, for conciseness we present the results in terms of accuracy in Table 1.
-
- In general, the fine-tuned pretrained model achieves at least the same level of classification performance as the baseline model. Notably, there are significant improvements, particularly when the sample size is small, ranging from \\(10^3\\) to \\(10^4\\) events. In some cases, the accuracy improvements exceed five percentage points, demonstrating that pretrained models provide a strong initial representation that compensates for limited data. The numerical values of the improvements in accuracy may not fully capture the impact on the sensitivity of the measurements for which the neural network classifier is used, and the final sensitivity improvement is likely to be greater.
-
- As the training sample size grows to \\(10^5\\), \\(10^6\\), and eventually \\(10^7\\) events, the added benefit of pretraining diminishes. With abundant data, models trained from scratch approach or even match the accuracy of fine-tuned pretrained models. This suggests that large datasets enable effective learning from scratch, rendering the advantage of pretraining negligible in such scenarios.
-
- Although both pretraining approaches offer benefits, multiclass pretraining tends to provide more consistent improvements across tasks, especially in the low-data regime. In contrast, multilabel pretraining can sometimes lead to neutral or even slightly negative effects for certain tasks and data sizes. This highlights the importance of the pretraining task design, as the similarity between pretraining and fine-tuning tasks in the multiclass approach appears to yield better-aligned representations.
-
- Finally, the spread of accuracy across the five tasks for the baseline model is quite large, offering a robust test of fine-tuning across tasks of varying difficulty. The consistent observation of these trends across tasks confirms the reliability and robustness of the findings.
171
-
172
- ---
173
-
174
- ### Model Interpretability
175
-
176
- We aim to understand whether pretrained and baseline models learn the same underlying representations. If the two models exhibit high similarity, a plausible interpretation is that pretraining provides the pretrained model with an advantageous initialization, allowing it to converge to a similar state as the baseline model more efficiently. Conversely, significant differences between the models would indicate that pretraining facilitates the development of a more general and robust latent space, which serves as a foundation for fine-tuning to effectively adapt to the downstream task. To investigate this, we analyzed the representational similarity between a pretrained model fine-tuned for the downstream task and a baseline model trained directly on the downstream task without pretraining.
177
-
178
- We use Centered Kernel Alignment (CKA) [Kornblith et al., 2019](#ref-kornblith-2019-cka) to analyze model similarity and interpretability. CKA quantifies the similarity between the internal representations of neural networks by comparing their feature matrices in a manner that is invariant to isotropic scaling and to orthogonal transformations such as rotations and permutations. This invariance makes CKA particularly effective for studying relationships between network layers, even across networks of different sizes or those trained from different initializations.
179
-
180
- The similarity is evaluated using a 64-dimensional latent representation after the decoder stage of the GNN model. This choice allows us to compare the internal states of the models at a fine-grained level and understand how training strategies impact the representations directly used for the output task.
181
-
182
- To provide an intuitive understanding of CKA values, we construct a table of the CKA scores for various transformations performed on a set of dummy data.
183
-
184
- - **A:** randomly initialized matrix with shape (1000, 64), following a normal distribution ( \\(\sigma = 1, \mu=0\\) )
185
- - **B:** matrix with shape (1000, 64) constructed via various transformations performed on \\(A\\)
186
- - **Noise:** randomly initialized noise matrix with shape (1000, 64), following a normal distribution ( \\(\sigma = 1, \mu=0\\) )
187
-
188
- | Dataset | CKA Score |
189
- |---------|-----------|
190
- | \\(A, B = A\\) | 1.00 |
191
- | \\(A, B =\\) permutation on columns of \\(A\\) | 1.00 |
192
- | \\(A, B = A + \mathrm{Noise}(0.1)\\) | 0.99 |
193
- | \\(A, B = A + \mathrm{Noise}(0.5)\\) | 0.80 |
194
- | \\(A, B = A + \mathrm{Noise}(0.75)\\) | 0.77 |
195
- | \\(A, B = A\cdot \mathrm{Noise}(1)\\) (Linear Transformation) | 0.76 |
196
- | \\(A, B = A + \mathrm{Noise}(1)\\) | 0.69 |
197
- | \\(A, B = A + \mathrm{Noise}(2)\\) | 0.51 |
198
- | \\(A, B = A + \mathrm{Noise}(5)\\) | 0.39 |
199
-
200
- **Table 2:** CKA scores for dummy datasets \\(A\\) and \\(B\\), where \\(B\\) is created via various transformations performed on \\(A\\).
201
-
202
- As seen in Table 2 and from the definition of CKA, the CKA score is permutation-invariant. We use the CKA score to evaluate the similarity between models and gain insight into the representation of detector events that each model learns.
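To make these scores concrete, linear CKA fits in a few lines. The snippet below is an illustrative sketch (not the exact implementation used in this study) and reproduces the qualitative behavior of Table 2: exact invariance under column permutation and a decreasing score as the noise grows.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between feature matrices of shape (n_samples, n_features)."""
    X = X - X.mean(axis=0)  # center each feature column
    Y = Y - Y.mean(axis=0)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    num = np.linalg.norm(Y.T @ X, ord="fro") ** 2
    den = np.linalg.norm(X.T @ X, ord="fro") * np.linalg.norm(Y.T @ Y, ord="fro")
    return num / den

rng = np.random.default_rng(0)
A = rng.normal(0.0, 1.0, size=(1000, 64))

same = linear_cka(A, A)                                   # identical matrices: 1.0
permuted = linear_cka(A, A[:, rng.permutation(64)])       # column permutation: still 1.0
noisy = linear_cka(A, A + rng.normal(0.0, 0.5, A.shape))  # added noise lowers the score
```

Because linear CKA depends on the representations only through inner products of their features, any orthogonal transformation of the feature axes, including a permutation of columns, leaves the score unchanged.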
203
-
204
- We train ensembles of models for each training task to observe how the CKA score changes due to the random initialization of our models. The CKA score between two models is then defined to be:
205
-
206
- $$
207
- \mathrm{CKA}(A, B) = \frac{1}{n^2}\sum_{i=1}^{n} \sum_{j=1}^{n} \mathrm{CKA}(A_i, B_j)
208
- $$
209
-
210
- where \\(A_i\\) is the representation learned by the \\(i^{th}\\) model in an ensemble of \\(n\\) models. The quoted uncertainty on CKA is the standard deviation of the pairwise scores \\(\mathrm{CKA}(A_i, B_j)\\).
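Given any pairwise CKA implementation, the ensemble score above and its quoted uncertainty follow directly. The helper below is a sketch: `cka` stands in for a pairwise CKA function, and each ensemble is a list of representation matrices, one per independently trained model.

```python
import numpy as np

def ensemble_cka(reps_a, reps_b, cka):
    """Mean and standard deviation of pairwise CKA between two model ensembles.

    reps_a, reps_b: lists of (n_samples, n_features) representation matrices,
    one per independently trained model; cka: any pairwise CKA function.
    """
    scores = np.array([[cka(a, b) for b in reps_b] for a in reps_a])
    return scores.mean(), scores.std()
```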
211
-
212
- Here we present results for the CKA similarity between the final model in each setup with the final model in the baseline, shown in Table 3.
213
-
214
- | Training Task | Baseline | Multiclass | Multilabel |
215
- |----------------------|------------------|-----------------|-----------------|
216
- | ttH CP Even vs Odd | 0.94 \\(\pm\\) 0.05 | 0.82 \\(\pm\\) 0.01 | 0.77 \\(\pm\\) 0.06 |
217
- | FCNC vs tHq | 0.96 \\(\pm\\) 0.03 | 0.76 \\(\pm\\) 0.01 | 0.81 \\(\pm\\) 0.01 |
218
- | ttW vs ttt | 0.91 \\(\pm\\) 0.08 | 0.75 \\(\pm\\) 0.10 | 0.72 \\(\pm\\) 0.05 |
219
- | stop vs ttH | 0.87 \\(\pm\\) 0.11 | 0.79 \\(\pm\\) 0.12 | 0.71 \\(\pm\\) 0.08 |
220
- | WH vs ZH | 0.90 \\(\pm\\) 0.07 | 0.53 \\(\pm\\) 0.03 | 0.44 \\(\pm\\) 0.06 |
221
-
222
- **Table 3:** CKA similarity of the latent representation before the decoder with the baseline model, averaged over 3 models per training setup, with all models trained on the full dataset (\\(10^7\\) events). The baseline column is not guaranteed to be 1.0 because of the random initialization of the models: each baseline model converges to a slightly different representation, as reflected in the CKA values in that column.
223
-
224
- The baseline models with different initializations exhibit high similarity values, ranging from approximately 0.87 to 0.96, indicating that independently trained baseline models tend to converge on similar internal representations despite random initialization. Across the considered tasks, models pretrained as multiclass or multilabel classifiers exhibit noticeably lower CKA similarity scores when compared to the baseline model. For example, in the WH vs ZH task, two independently trained baseline models have a high similarity of 0.90, whereas the multiclass and multilabel models show significantly reduced similarities (0.53 and 0.44, respectively). This pattern suggests that the representational spaces developed by the pretrained models differ substantially from those learned by a baseline model trained directly on the downstream classification task.
225
-
226
- ### Computational Efficiency
227
-
228
- To estimate the computational resources required for each approach, we measured the wall time needed for a model to reach its final performance. For baseline models, this is defined as the wall time from the start of training until the loss of the model plateaus. For the foundation model approach, the estimate includes both the pretraining time and the fine-tuning time, each measured from the start of training until the loss plateaus. This approach ensures a consistent and comprehensive evaluation of the computational demands.
229
-
230
- ![The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.](training_time.png)
231
- *Fig. 1: The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.*
232
-
233
- Figure 1 shows the fine-tuning time for the model pretrained with multiclass classification, relative to the time required for the baseline model, as a function of training sample size. In general, the fine-tuning time is significantly shorter than the training time required by the baseline model approach. For smaller training sets, on the order of \\(10^5\\) events, tasks such as FCNC vs. tHq and ttW vs. ttt benefit substantially from the pretrained model’s “head start,” achieving their final performance in only about 1% of the baseline time. For large training datasets, the fine-tuning time relative to the baseline training time becomes larger; however, given that the large training sample typically requires longer training time, fine-tuning still yields much faster training convergence. The ttH CP-even vs. ttH CP-odd task, with a training sample size of \\(10^7\\) events, is an exception where the fine-tuning time exceeds the training time required for the baseline model. This is likely because the processes involved in this task include photon objects in the final states, which are absent from the events used during pretraining.
234
-
235
- To accurately evaluate the total time consumption, it is necessary to include the pretraining time required for the foundation model approach. The pretraining times are as follows:
236
-
237
- - **Multi-class pretraining:** 45.5 GPU hours
238
- - **Multi-label pretraining:** 60.0 GPU hours
239
-
240
- The GPU hours recorded for the multi-label model represent the total time required when training the model in parallel on 16 GPUs. This includes a model synchronization step, which results in higher GPU hours compared to the multi-class pretraining model.
241
-
242
- The foundation model approach becomes increasingly efficient as more tasks are fine-tuned from the same pretrained model, compared to training each task independently from scratch. To illustrate this, we evaluate the computational time required for a scenario where the training sample contains \\(10^7\\) events. For the five tasks tested in this study, the baseline training time (training from scratch) ranges from 1.68 GPU hours (WH vs. ZH) to 5.30 GPU hours (ttW vs. ttt), with an average of 2.94 GPU hours. In contrast, the average fine-tuning time for the foundation model approach is 38% of the baseline training time for \\(10^7\\) events. Based on these averages, we estimate that the foundation model approach becomes more computationally efficient than the baseline approach when fine-tuning is performed for more than 41 tasks.
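The break-even point follows from a simple cost comparison: training \\(N\\) tasks from scratch costs \\(N \cdot T_{\text{base}}\\) GPU hours, while the foundation model approach costs the one-time pretraining \\(T_{\text{pre}}\\) plus \\(N \cdot T_{\text{ft}}\\) for fine-tuning. The latter is cheaper once

$$
T_{\text{pre}} + N \cdot T_{\text{ft}} < N \cdot T_{\text{base}}
\quad\Longleftrightarrow\quad
N > \frac{T_{\text{pre}}}{T_{\text{base}} - T_{\text{ft}}} .
$$

The exact threshold depends on which per-task baseline and fine-tuning times enter the comparison.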
243
-
244
- As a practical example, the ATLAS measurement of Higgs boson couplings using the \\(H \rightarrow \gamma\gamma\\) decay channel [ATLAS Collaboration, 2023](#ref-atlas-2023-higg) involved training 42 classifiers for event categorization. This is consistent with our estimate, suggesting that the foundation model approach can reduce computational costs even for a single high-energy physics measurement.
245
-
246
- ## Conclusions
247
-
248
- We presented an in-depth study of a particle physics foundation model designed to operate on the four-momentum and identification properties of event final-state objects. This model is built on a Graph Neural Network (GNN) architecture and trained on a dataset comprising 120 million simulated proton-proton collision events across 12 distinct physics processes. The pretraining phase explored both multiclass and multilabel classification tasks, providing a robust foundation for downstream applications. Notably, the pretrained models demonstrated significant improvements in event classification performance when fine-tuned, particularly for tasks with limited training samples.
249
-
250
- The foundation model approach also offers substantial computational advantages. By leveraging fine-tuning, this methodology reduces the computational resources required for large-scale applications across multiple tasks. Our estimates indicate that significant resource savings can be achieved even for single particle physics measurements, making this approach both scalable and efficient.
251
-
252
- To better understand the learned representations of the pretrained model and guide future optimization efforts, we employed a representational similarity evaluation framework using Centered Kernel Alignment (CKA). This metric allowed us to investigate the source of the performance gains observed in the foundation model. Our analysis revealed notable differences in the learned representations between the fine-tuned pretrained model and a baseline model trained from scratch. In deep learning, it is well-established that multiple equally valid solutions can exist. Future studies are necessary to determine whether the low similarity in latent representations reflects complementary information uniquely captured by the foundation and baseline models, or if it can simply be attributed to connected local minima in the loss landscape.
253
-
254
- ## Acknowledgments
255
-
256
- This work is supported by the U.S. National Science Foundation under Award No. 2046280 and by the U.S. Department of Energy, Office of Science, under contract DE-AC02-05CH11231.
257
-
258
- ## References
259
-
260
- - <span id="ref-openai-2024-gpt4"></span> **OpenAI et al.** GPT-4 Technical Report. arXiv:2303.08774 (2024). [https://arxiv.org/abs/2303.08774](https://arxiv.org/abs/2303.08774)
261
-
262
- - <span id="ref-yosinski-2014-transfer"></span> **Jason Yosinski, Jeff Clune, Yoshua Bengio, Hod Lipson.** How transferable are features in deep neural networks? CoRR abs/1411.1792 (2014). [http://arxiv.org/abs/1411.1792](http://arxiv.org/abs/1411.1792)
263
-
264
- - <span id="ref-rombach-2021-latentdiffusion"></span> **Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.** High-Resolution Image Synthesis with Latent Diffusion Models. CoRR abs/2112.10752 (2021). [https://arxiv.org/abs/2112.10752](https://arxiv.org/abs/2112.10752)
265
-
266
- - <span id="ref-podell-2023-sdxl"></span> **Dustin Podell, Zion English, Kyle Lacey et al.** SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis. arXiv:2307.01952 (2023). [https://arxiv.org/abs/2307.01952](https://arxiv.org/abs/2307.01952)
267
-
268
- - <span id="ref-jumper-2021-alphafold"></span> **John Jumper, Richard Evans, Alexander Pritzel et al.** Highly accurate protein structure prediction with AlphaFold. Nature 596, 583-589 (2021). [https://doi.org/10.1038/s41586-021-03819-2](https://doi.org/10.1038/s41586-021-03819-2)
269
-
270
- - <span id="ref-devlin-2018-bert"></span> **Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova.** BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. CoRR abs/1810.04805 (2018). [http://arxiv.org/abs/1810.04805](http://arxiv.org/abs/1810.04805)
271
-
272
- - <span id="ref-atlas-2023-higg"></span> **ATLAS Collaboration.** Measurement of the properties of Higgs boson production at \\(\sqrt{s} = 13\,\text{TeV}\\) in the \\(H \to \gamma\gamma\\) channel using \\(139\,\text{fb}^{-1}\\) of \\(pp\\) collision data with the ATLAS experiment. JHEP 07 (2023) 088. [arXiv:2207.00348](https://arxiv.org/abs/2207.00348), [https://doi.org/10.1007/JHEP07(2023)088](https://doi.org/10.1007/JHEP07(2023)088)
273
-
274
- - <span id="ref-atlas-2023-4top"></span> **ATLAS Collaboration.** Observation of four-top-quark production in the multilepton final state with the ATLAS detector. Eur. Phys. J. C 83 (2023) 496. [arXiv:2303.15061](https://arxiv.org/abs/2303.15061), [https://doi.org/10.1140/epjc/s10052-023-11573-0](https://doi.org/10.1140/epjc/s10052-023-11573-0)
275
-
276
- - <span id="ref-kornblith-2019-cka"></span> **Simon Kornblith, Mohammad Norouzi, Honglak Lee, Geoffrey Hinton.** Similarity of Neural Network Representations Revisited. CoRR abs/1905.00414 (2019). [http://arxiv.org/abs/1905.00414](http://arxiv.org/abs/1905.00414)
277
-
278
- ---
-
314
- <!-- Prototypical collider software and tools -->
315
-
316
- - <span id="ref-pytorch-2019"></span> **Adam Paszke et al.** PyTorch: An Imperative Style, High-Performance Deep Learning Library. arXiv:1912.01703 (2019). [http://arxiv.org/abs/1912.01703](http://arxiv.org/abs/1912.01703)
317
-
318
- - <span id="ref-dgl-2019"></span> **Minjie Wang et al.** Deep Graph Library: Towards Efficient and Scalable Deep Learning on Graphs. arXiv:1909.01315 (2019). [http://arxiv.org/abs/1909.01315](http://arxiv.org/abs/1909.01315)
319
-
320
- - <span id="ref-graphnets-2018"></span> **Peter W. Battaglia et al.** Relational inductive biases, deep learning, and graph networks. arXiv:1806.01261 (2018). [http://arxiv.org/abs/1806.01261](http://arxiv.org/abs/1806.01261)
321
-
322
- - <span id="ref-layernorm-2016"></span> **Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton.** Layer Normalization. arXiv:1607.06450 (2016). [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
323
-
324
- ---
325
-
326
- <!-- Recent & foundation models in HEP ML -->
327
-
328
- - <span id="ref-wildridge-2024-bumblebee"></span> **Andrew J. Wildridge et al.** Bumblebee: Foundation Model for Particle Physics Discovery. arXiv:2412.07867 (2024). [https://arxiv.org/abs/2412.07867](https://arxiv.org/abs/2412.07867)
329
-
330
- - <span id="ref-katel-2024-jet"></span> **Subash Katel et al.** Learning Symmetry-Independent Jet Representations via Jet-Based Joint Embedding Predictive Architecture. arXiv:2412.05333 (2024). [https://arxiv.org/abs/2412.05333](https://arxiv.org/abs/2412.05333)
331
-
332
- - <span id="ref-araz-2024-pointcloud"></span> **Jack Y. Araz et al.** Point cloud-based diffusion models for the Electron-Ion Collider. arXiv:2410.22421 (2024). [https://arxiv.org/abs/2410.22421](https://arxiv.org/abs/2410.22421)
333
-
334
- - <span id="ref-leigh-2024-maskedparticle"></span> **Matthew Leigh et al.** Is Tokenization Needed for Masked Particle Modelling? arXiv:2409.12589 (2024). [https://arxiv.org/abs/2409.12589](https://arxiv.org/abs/2409.12589)
335
-
336
- - <span id="ref-mikuni-2024-omnilearn"></span> **Vinicius Mikuni, Benjamin Nachman.** OmniLearn: A Method to Simultaneously Facilitate All Jet Physics Tasks. arXiv:2404.16091 (2024). [https://arxiv.org/abs/2404.16091](https://arxiv.org/abs/2404.16091)
337
-
338
- - <span id="ref-zhang-2024-xiwu"></span> **Zhengde Zhang et al.** Xiwu: A Basis Flexible and Learnable LLM for High Energy Physics. arXiv:2404.08001 (2024). [https://arxiv.org/abs/2404.08001](https://arxiv.org/abs/2404.08001)
339
-
340
- - <span id="ref-harris-2024-resimulation"></span> **Philip Harris et al.** Re-Simulation-based Self-Supervised Learning for Pre-Training Foundation Models. arXiv:2403.07066 (2024). [https://arxiv.org/abs/2403.07066](https://arxiv.org/abs/2403.07066)
341
-
342
- - <span id="ref-birk-2024-omnijet"></span> **Joschka Birk, Anna Hallin, Gregor Kasieczka.** OmniJet-$\alpha$: the first cross-task foundation model for particle physics. Machine Learning: Science and Technology. 5(3), 035031 (Aug 2024). [https://doi.org/10.1088/2632-2153/ad66ad](https://doi.org/10.1088/2632-2153/ad66ad)
343
-
344
- - <span id="ref-huang-2024-lmtracking"></span> **Andris Huang et al.** A Language Model for Particle Tracking. arXiv:2402.10239 (2024). [https://arxiv.org/abs/2402.10239](https://arxiv.org/abs/2402.10239)
345
-
346
- - <span id="ref-golling-2024-maskedset"></span> **Tobias Golling et al.** Masked Particle Modeling on Sets: Towards Self-Supervised High Energy Physics Foundation Models. arXiv:2401.13537 (2024). [https://arxiv.org/abs/2401.13537](https://arxiv.org/abs/2401.13537)
347
-
348
- - <span id="ref-liu-2023-gaam"></span> **Junze Liu et al.** Generalizing to new geometries with Geometry-Aware Autoregressive Models (GAAMs) for fast calorimeter simulation. Journal of Instrumentation 18(11), P11003 (Nov 2023). [https://doi.org/10.1088/1748-0221/18/11/p11003](https://doi.org/10.1088/1748-0221/18/11/p11003)
349
-
350
- - <span id="ref-hashemi-2024-gen"></span> **Baran Hashemi et al.** Ultra-high-granularity detector simulation with intra-event aware generative adversarial network and self-supervised relational reasoning. Nature Communications 15(1) (June 2024). [https://doi.org/10.1038/s41467-024-49104-4](https://doi.org/10.1038/s41467-024-49104-4)
351
-
352
- - <span id="ref-vigl-2024-finetune"></span> **Matthias Vigl et al.** Finetuning Foundation Models for Joint Analysis Optimization. arXiv:2401.13536 (2024). [https://arxiv.org/abs/2401.13536](https://arxiv.org/abs/2401.13536)
353
-
354
- - <span id="ref-li-2024-refine"></span> **Chen Li, Hao Cai, Xianyang Jiang.** Refine neutrino events reconstruction with BEiT-3. Journal of Instrumentation 19(6), T06003 (Jun 2024). [https://doi.org/10.1088/1748-0221/19/06/t06003](https://doi.org/10.1088/1748-0221/19/06/t06003)
 
1
+ ---
2
+ license: mit
3
+ ---
root_gnn_dgl/README.md CHANGED
@@ -89,8 +89,6 @@ dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'
89
  IndexError: list index out of range
90
  ```
91
 
92
- To make sure you have all the necessary graphs to train, you can use the `scripts/check_dataset_files.py` script to ensure all graphs are properly processed. Using the `--rerun` runtime argument tells the script to automatically re-process any missing files.
93
-
94
  ## Training
95
  Training is run by `scripts/training_script`. `--preshuffle` tells it to use the preshuffled and batched graphs rather than shuffling and batching on the fly, and `--restart` can be used to force the training to start from the beginning rather than from the last available checkpoint.
96
 
 
89
  IndexError: list index out of range
90
  ```
91
 
 
 
92
  ## Training
93
  Training is run by `scripts/training_script`. `--preshuffle` tells it to use the preshuffled and batched graphs rather than shuffling and batching on the fly, and `--restart` can be used to force the training to start from the beginning rather than from the last available checkpoint.
94
 
root_gnn_dgl/configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml CHANGED
@@ -23,7 +23,7 @@ Model:
23
  Training:
24
  epochs: 500
25
  batch_size: 1024
26
- learning_rate: 0.00001
27
  gamma: 0.99
28
  Datasets:
29
  ttH_CP_even: &dataset_defn
 
23
  Training:
24
  epochs: 500
25
  batch_size: 1024
26
+ learning_rate: 0.0001
27
  gamma: 0.99
28
  Datasets:
29
  ttH_CP_even: &dataset_defn
root_gnn_dgl/configs/stats_100K/pretraining_multiclass.yaml CHANGED
@@ -94,7 +94,7 @@ Datasets:
94
  <<: *dataset_defn
95
  args:
96
  <<: *dataset_args
97
- name: ttyy
98
  label: 6
99
  file_names: 'ttyy.root'
100
  tttt:
 
94
  <<: *dataset_defn
95
  args:
96
  <<: *dataset_args
97
+ name: ttyy_ch
98
  label: 6
99
  file_names: 'ttyy.root'
100
  tttt:
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_2048
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_2048
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_2048/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_4096
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_4096
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 1024
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 4096
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_4096/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
root_gnn_dgl/configs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_8192
2
- Training_Directory: trainings/stats_100K/ttH_CP_even_vs_odd_batch_size_8192
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 3
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 3
30
- buffer_size: 2
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_100K/ttH_CP_even_vs_odd_batch_size_8192/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
root_gnn_dgl/configs/stats_all/finetuning_ttH_CP_even_vs_odd.yaml CHANGED
@@ -20,11 +20,6 @@ Model:
20
  n_layers: 4
21
  n_proc_steps: 4
22
  dropout: 0
23
- Training:
24
- epochs: 500
25
- batch_size: 1024
26
- learning_rate: 0.00001
27
- gamma: 0.99
28
  Datasets:
29
  ttH_CP_even: &dataset_defn
30
  module: root_gnn_base.dataset
 
20
  n_layers: 4
21
  n_proc_steps: 4
22
  dropout: 0
 
 
 
 
 
23
  Datasets:
24
  ttH_CP_even: &dataset_defn
25
  module: root_gnn_base.dataset
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml DELETED
@@ -1,57 +0,0 @@
1
- Training_Name: ttH_CP_even_vs_odd_batch_size_2048
2
- Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_2048
3
- Model:
4
- module: models.GCN
5
- class: Edge_Network
6
- args:
7
- hid_size: 64
8
- in_size: 7
9
- out_size: 1
10
- n_layers: 4
11
- n_proc_steps: 4
12
- dropout: 0
13
- Training:
14
- epochs: 500
15
- batch_size: 2048
16
- learning_rate: 0.0001
17
- gamma: 0.99
18
- Datasets:
19
- ttH_CP_even: &dataset_defn
20
- module: root_gnn_base.dataset
21
- class: LazyDataset
22
- shuffle_chunks: 10
23
- batch_size: 2048
24
- padding_mode: NONE #one of STEPS, FIXED, or NONE
25
- args: &dataset_args
26
- name: ttH_CP_even
27
- label: 0
28
- # weight_var: weight
29
- chunks: 10
30
- buffer_size: 3
31
- file_names: ttH_NLO.root
32
- tree_name: output
33
- fold_var: Number
34
- raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
35
- save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_2048/
36
- node_branch_names:
37
- - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
38
- - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
39
- - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
40
- - CALC_E
41
- - [jet_btag, 0, 0, 0, 0]
42
- - [0, ele_charge, mu_charge, 0, 0]
43
- - NODE_TYPE
44
- node_branch_types: [vector, vector, vector, vector, single]
45
- node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
46
- folding:
47
- n_folds: 4
48
- test: [0]
49
- # validation: 1
50
- train: [1, 2, 3]
51
- ttH_CP_odd:
52
- <<: *dataset_defn
53
- args:
54
- <<: *dataset_args
55
- name: ttH_CP_odd
56
- label: 1
57
- file_names: ttH_CPodd.root
root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml DELETED
@@ -1,57 +0,0 @@
-Training_Name: ttH_CP_even_vs_odd_batch_size_4096
-Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_4096
-Model:
-  module: models.GCN
-  class: Edge_Network
-  args:
-    hid_size: 64
-    in_size: 7
-    out_size: 1
-    n_layers: 4
-    n_proc_steps: 4
-    dropout: 0
-Training:
-  epochs: 500
-  batch_size: 4096
-  learning_rate: 0.0001
-  gamma: 0.99
-Datasets:
-  ttH_CP_even: &dataset_defn
-    module: root_gnn_base.dataset
-    class: LazyDataset
-    shuffle_chunks: 10
-    batch_size: 4096
-    padding_mode: NONE #one of STEPS, FIXED, or NONE
-    args: &dataset_args
-      name: ttH_CP_even
-      label: 0
-      # weight_var: weight
-      chunks: 10
-      buffer_size: 3
-      file_names: ttH_NLO.root
-      tree_name: output
-      fold_var: Number
-      raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
-      save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_4096/
-      node_branch_names:
-        - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
-        - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
-        - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
-        - CALC_E
-        - [jet_btag, 0, 0, 0, 0]
-        - [0, ele_charge, mu_charge, 0, 0]
-        - NODE_TYPE
-      node_branch_types: [vector, vector, vector, vector, single]
-      node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
-      folding:
-        n_folds: 4
-        test: [0]
-        # validation: 1
-        train: [1, 2, 3]
-  ttH_CP_odd:
-    <<: *dataset_defn
-    args:
-      <<: *dataset_args
-      name: ttH_CP_odd
-      label: 1
-      file_names: ttH_CPodd.root

root_gnn_dgl/configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml DELETED
@@ -1,57 +0,0 @@
-Training_Name: ttH_CP_even_vs_odd_batch_size_8192
-Training_Directory: trainings/stats_all/ttH_CP_even_vs_odd_batch_size_8192
-Model:
-  module: models.GCN
-  class: Edge_Network
-  args:
-    hid_size: 64
-    in_size: 7
-    out_size: 1
-    n_layers: 4
-    n_proc_steps: 4
-    dropout: 0
-Training:
-  epochs: 500
-  batch_size: 8192
-  learning_rate: 0.0001
-  gamma: 0.99
-Datasets:
-  ttH_CP_even: &dataset_defn
-    module: root_gnn_base.dataset
-    class: LazyDataset
-    shuffle_chunks: 10
-    batch_size: 8192
-    padding_mode: NONE #one of STEPS, FIXED, or NONE
-    args: &dataset_args
-      name: ttH_CP_even
-      label: 0
-      # weight_var: weight
-      chunks: 10
-      buffer_size: 3
-      file_names: ttH_NLO.root
-      tree_name: output
-      fold_var: Number
-      raw_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_all/
-      save_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/processed_graphs/stats_all/ttH_CP_even_vs_odd_batch_size_8192/
-      node_branch_names:
-        - [jet_pt, ele_pt, mu_pt, ph_pt, MET_met]
-        - [jet_eta, ele_eta, mu_eta, ph_eta, 0]
-        - [jet_phi, ele_phi, mu_phi, ph_phi, MET_phi]
-        - CALC_E
-        - [jet_btag, 0, 0, 0, 0]
-        - [0, ele_charge, mu_charge, 0, 0]
-        - NODE_TYPE
-      node_branch_types: [vector, vector, vector, vector, single]
-      node_feature_scales: [1e-1, 1, 1, 1e-1, 1, 1, 1]
-      folding:
-        n_folds: 4
-        test: [0]
-        # validation: 1
-        train: [1, 2, 3]
-  ttH_CP_odd:
-    <<: *dataset_defn
-    args:
-      <<: *dataset_args
-      name: ttH_CP_odd
-      label: 1
-      file_names: ttH_CPodd.root

root_gnn_dgl/jobs/prep_data/run_processing.py CHANGED
@@ -79,10 +79,7 @@ def main():
         # "configs/stats_100K/ttH_CP_even_vs_odd.yaml",
         # "configs/stats_all/pretraining_multiclass.yaml",
         # "configs/stats_all/ttH_CP_even_vs_odd.yaml",
-        # "configs/attention/ttH_CP_even_vs_odd.yaml",
-        "configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml",
-        "configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml",
-        "configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml",
+        "configs/attention/ttH_CP_even_vs_odd.yaml",
     ]
 
     # Path to the bash script to be called
root_gnn_dgl/jobs/training/singlegpu/run_job.sh CHANGED
@@ -5,11 +5,11 @@
 #SBATCH --mail-user=ho22joshua@berkeley.edu
 #SBATCH --mail-type=ALL
 #SBATCH -t 15:00:00
-#SBATCH -A trn007
-#SBATCH -o /global/cfs/projectdirs/atlas/joshua/GNN4Colliders/root_gnn_dgl/jobs/slurm/%j.out # STDOUT
+#SBATCH -A atlas
+#SBATCH -o /global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/jobs/slurm/%j.out # STDOUT
 
 ARGUEMENTS="$*"
 
 echo "Arguements: $ARGUEMENTS"
 echo "launching image"
-source /global/homes/j/joshuaho/launch_image.sh "--entrypoint /global/cfs/projectdirs/atlas/joshua/GNN4Colliders/root_gnn_dgl/jobs/training/singlegpu/run_job_image.sh" $ARGUEMENTS
+source launch_image.sh "--entrypoint /global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/jobs/run_job_image.sh" $ARGUEMENTS
root_gnn_dgl/jobs/training/singlegpu/run_job_image.sh CHANGED
@@ -4,8 +4,17 @@ CONFIG=$1
 shift
 ARGUEMENTS="$*"
 
-DIRECTORY="/global/cfs/projectdirs/atlas/joshua/GNN4Colliders/root_gnn_dgl/"
-COMMAND="$DIRECTORY"scripts/training_script.py $ARGUEMENTS --preshuffle --nocompile --lazy --config $DIRECTORY$CONFIG
+DIRECTORY="/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/configs/model_configs/"
+BASE_COMMAND="/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/scripts/training_script.py $ARGUEMENTS --preshuffle --nocompile --lazy --config $DIRECTORY"
+
+echo "launched image"
+cd /global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/
+
+COMMAND="$BASE_COMMAND$CONFIG"
+
+eval "$(conda shell.bash hook)"
+conda init bash
+conda activate /opt/conda/envs/dgl
 
 echo "Running my script now"
 echo $COMMAND
root_gnn_dgl/jobs/training/singlegpu/submit.sh CHANGED
@@ -3,10 +3,7 @@ date
 DIRECTORY="/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl/configs/model_configs/"
 
 configs=(
-    "configs/stats_all/ttH_CP_even_vs_odd.yaml"
-    "configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml"
-    "configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml"
-    "configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml"
+    "run_3_ttH/v05/sb_yukawa_cp_abs_weights.yaml --abs"
 )
 
 counter=0
root_gnn_dgl/models/GCN.py CHANGED
@@ -1154,7 +1154,6 @@ class Attention(nn.Module):
         self.n_proc_steps = n_proc_steps
         self.layers = nn.ModuleList()
         self.has_global = sample_global.shape[1] != 0
-        self.hid_size = hid_size
         gl_size = sample_global.shape[1] if self.has_global else 1
 
         #encoder
@@ -1197,7 +1196,7 @@
         batch_num_nodes.append(non_padded_count)
         start_idx = end_idx
         batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
-        sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
+        sum_weights = batch_num_nodes[:, None].repeat(1, 64)
         global_feats = batch_num_nodes[:, None].to(torch.float)
 
         h_global = self.global_encoder(global_feats)
@@ -1365,7 +1364,6 @@ class Transferred_Learning_Attention(nn.Module):
         self.n_proc_steps = n_proc_steps
         self.layers = nn.ModuleList()
         self.has_global = sample_global.shape[1] != 0
-        self.hid_size = hid_size
         gl_size = sample_global.shape[1] if self.has_global else 1
 
         self.learning_rate = learning_rate
@@ -1442,7 +1440,7 @@
         batch_num_nodes.append(non_padded_count)
         start_idx = end_idx
         batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
-        sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
+        sum_weights = batch_num_nodes[:, None].repeat(1, 64)
         global_feats = batch_num_nodes[:, None].to(torch.float)
 
         h_global = self.TL_global_encoder(global_feats)
@@ -1858,7 +1856,6 @@ class Clustering(nn.Module):
         self.n_layers = n_layers
         self.n_proc_steps = n_proc_steps
         self.layers = nn.ModuleList()
-        self.hid_size = hid_size
         if (len(sample_global) == 0):
             self.has_global = False
         else:
@@ -1902,7 +1899,7 @@
         batch_num_nodes.append(non_padded_count)
         start_idx = end_idx
         batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
-        sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
+        sum_weights = batch_num_nodes[:, None].repeat(1, 64)
         global_feats = batch_num_nodes[:, None].to(torch.float)
 
         h_global = self.global_encoder(global_feats)
root_gnn_dgl/profile.sh DELETED
@@ -1,35 +0,0 @@
-nsys profile \
-    -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_1028 \
-    --capture-range=cudaProfilerApi \
-    --duration=100 \
-    --force-overwrite true \
-    --trace=nvtx \
-    --cudabacktrace=all \
-    python scripts/training_script.py --config configs/stats_all/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy --restart --profile
-
-nsys profile \
-    -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_2048 \
-    --capture-range=cudaProfilerApi \
-    --duration=100 \
-    --force-overwrite true \
-    --trace=nvtx \
-    --cudabacktrace=all \
-    python scripts/training_script.py --config configs/stats_all/ttH_CP_even_vs_odd_batch_size_2048.yaml --preshuffle --nocompile --lazy --restart --profile
-
-nsys profile \
-    -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_4096 \
-    --capture-range=cudaProfilerApi \
-    --duration=100 \
-    --force-overwrite=true \
-    --trace=nvtx \
-    --cudabacktrace=all \
-    python scripts/training_script.py --config configs/stats_all/ttH_CP_even_vs_odd_batch_size_4096.yaml --preshuffle --nocompile --lazy --restart --profile
-
-nsys profile \
-    -o /pscratch/sd/j/joshuaho/full_stats_profile_1_gpu_batch_size_8192 \
-    --capture-range=cudaProfilerApi \
-    --duration=100 \
-    --force-overwrite true \
-    --trace=nvtx \
-    --cudabacktrace=all \
-    python scripts/training_script.py --config configs/stats_all/ttH_CP_even_vs_odd_batch_size_8192.yaml --preshuffle --nocompile --lazy --restart --profile

root_gnn_dgl/root_gnn_base/batched_dataset.py CHANGED
@@ -16,7 +16,7 @@ def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwar
 
 #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
 class PreBatchedDataset(DGLDataset):
-    def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', hidden_size=64, **kwargs):
+    def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', **kwargs):
         print(f'Unused kwargs: {kwargs}')
         self.start_dataset = start_dataset
         self.start_dataset.load()
@@ -34,7 +34,6 @@ class PreBatchedDataset(DGLDataset):
         self.suffix = suffix
         self.current_chunk = None
         self.current_chunk_idx = -1
-        self.hid_size = hidden_size
         super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)
 
     def process(self):
@@ -87,7 +86,7 @@ class PreBatchedDataset(DGLDataset):
         for i in range(len(self.graphs)):
             unbatched_g = dgl.unbatch(self.graphs[i])
             max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
-            self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes, hid_size=self.hid_size)
+            self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes)
             self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
             self.batch_num_edges.append(self.graphs[i].batch_num_edges())
         else:
root_gnn_dgl/root_gnn_base/dataset.py CHANGED
@@ -1,7 +1,6 @@
 from dgl.data import DGLDataset
 import dgl
-import uproot
-import awkward as ak
+import ROOT
 import torch
 import os
 import glob
@@ -15,7 +14,7 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
         if node_type == 'single':
             lengths.append(1)
         elif node_type == 'vector':
-            lengths.append(len(ch[branch]))
+            lengths.append(len(getattr(ch, branch)))
         else:
             print('Unknown node branch type: {}'.format(node_type))
     features = []
@@ -39,14 +38,16 @@ def node_features_from_tree(ch, node_branch_names, node_branch_types, node_featu
             this_type_ends_at = sum(lengths[:itype+1])
             feat.extend(features[0][this_type_starts_at:this_type_ends_at]*torch.cosh(features[1][this_type_starts_at:this_type_ends_at]))
         elif node_type == 'single':
-            feat.append(ch[branch])
+            feat.append(getattr(ch, branch))
         elif node_type == 'vector':
-            feat.extend(ch[branch])
+            feat.extend(getattr(ch, branch))
         itype += 1
         features.append(torch.tensor(feat))
     return torch.stack(features, dim=1) * node_feature_scales, lengths
 
 def full_connected_graph(n_nodes, self_loops=True):
+    senders = []
+    receivers = []
     senders = np.arange(n_nodes*n_nodes) // n_nodes
     receivers = np.arange(n_nodes*n_nodes) % n_nodes
     if not self_loops and n_nodes > 1:
@@ -58,18 +59,19 @@ def full_connected_graph(n_nodes, self_loops=True):
 def check_selection(ch, selection):
     var, cut, op = selection
     if op == '>':
-        return ch[var] > cut
+        return getattr(ch, var) > cut
     elif op == '<':
-        return ch[var] < cut
+        return getattr(ch, var) < cut
    elif op == '==':
-        return ch[var] == cut
-
+        return getattr(ch, var) == cut
+
 def check_selections(ch, selections):
     for selection in selections:
         if not check_selection(ch, selection):
             return False
     return True
 
+#Base dataset class for making graphs from ROOT ntuples.
 class RootDataset(DGLDataset):
     def __init__(self, name=None, raw_dir=None, save_dir=None, label=1, file_names = '*.root', node_branch_names=None, node_branch_types=None, node_feature_scales=None,
                  selections=[], save=True, tree_name = 'nominal_Loose', fold_var = 'eventNumber', weight_var = None, chunks = 1, process_chunks = None, global_features = [], tracking_info = [], **kwargs):
@@ -86,7 +88,7 @@ class RootDataset(DGLDataset):
         self.fold_var = fold_var
         self.tracking_info = tracking_info
         self.tracking_info.insert(0, fold_var)
-        if weight_var is None:
+        if weight_var == None:
             weight_var = 1
         self.tracking_info.insert(1, weight_var)
         self.global_features = global_features
@@ -114,7 +116,7 @@ class RootDataset(DGLDataset):
             branches.append(feat)
         for selection in self.selections:
             branches.append(selection[0])
-        return list(set(branches)) # Remove duplicates
+        return branches
 
     def make_graph(self, ch):
         t1 = time.time()
@@ -127,7 +129,7 @@ class RootDataset(DGLDataset):
         self.times[0] += t2 - t1
         self.times[1] += t3 - t2
         return g
-
+
     def process(self):
         times = [0, 0, 0]
         oldtime = time.time()
@@ -137,21 +139,21 @@ class RootDataset(DGLDataset):
         self.files = []
        for file_name in self.file_names:
             self.files.extend(glob.glob(os.path.join(self.raw_dir, file_name)))
-        branches = self.get_list_of_branches()
+        self.chain = ROOT.TChain(self.tree_name)
 
-        # Read all files and concatenate arrays
-        arrays = []
-        for file in self.files:
-            with uproot.open(file) as f:
-                arrays.append(f[self.tree_name].arrays(branches, library="ak"))
-        if len(arrays) == 0:
+        if len(self.files) == 0:
             print('No files found in {}'.format(os.path.join(self.raw_dir, self.file_names)))
-            return
-        data = ak.concatenate(arrays, axis=0)
-        n_entries = len(data[branches[0]])
+        for file in self.files:
+            utils.set_timeout(60*2)
+            self.chain.Add(file)
+            utils.unset_timeout()
+        branches = self.get_list_of_branches()
+        self.chain.SetBranchStatus('*', 0)
+        for branch in branches:
+            self.chain.SetBranchStatus(branch, 1)
         newtime = time.time()
         times[0] += newtime - oldtime
-        chunks = np.array_split(np.arange(n_entries), self.chunks)
+        chunks = np.array_split(np.arange(self.chain.GetEntries()), self.chunks)
         chunks = [chunk for i, chunk in enumerate(chunks) if i in self.process_chunks]
 
         self.graph_chunks = []
@@ -160,7 +162,6 @@ class RootDataset(DGLDataset):
         self.global_chunks = []
         chunk_id = -1
         for chunk in chunks:
-            print('Processing chunk {}/{}'.format(chunk_id + 1, len(chunks)), flush=True)
            chunk_id += 1
             graphs = []
             labels = []
@@ -168,28 +169,28 @@ class RootDataset(DGLDataset):
             globals = []
            for ientry in chunk:
                 if (ientry % 10000 == 0):
-                    print('Processing event {}/{}'.format(ientry, n_entries), flush=True)
-                ch = {b: data[b][ientry] for b in branches}
+                    print('Processing event {}/{}'.format(ientry, self.chain.GetEntries()), flush=True)
+                self.chain.GetEntry(ientry)
                 passed = True
                 for selection in self.selections:
-                    if not check_selection(ch, selection):
+                    if not check_selection(self.chain, selection):
                         passed = False
                         continue
                 oldtime = newtime
                 newtime = time.time()
                 times[1] += newtime - oldtime
                 if passed:
-                    graphs.append(self.make_graph(ch))
-                    labels.append(self.label)
+                    graphs.append(self.make_graph(self.chain))
+                    labels.append( self.label )
                     tracking.append(torch.zeros(len(self.tracking_info), dtype=torch.double))
                     globals.append(torch.zeros(len(self.global_features)))
                     for i_ti, tr_branch in enumerate(self.tracking_info):
                         if isinstance(tr_branch, str):
-                            tracking[-1][i_ti] = ch[tr_branch]
+                            tracking[-1][i_ti] = getattr(self.chain, tr_branch)
                         else:
                             tracking[-1][i_ti] = tr_branch
                     for i_gl, gl_branch in enumerate(self.global_features):
-                        globals[-1][i_gl] = ch[gl_branch]
+                        globals[-1][i_gl] = getattr(self.chain, gl_branch)
                 oldtime = newtime
                 newtime = time.time()
                 times[2] += newtime - oldtime
@@ -197,12 +198,6 @@ class RootDataset(DGLDataset):
             labels = torch.tensor(labels)
             tracking = torch.stack(tracking)
             globals = torch.stack(globals)
-
-            self.graph_chunks.append(graphs)
-            self.label_chunks.append(labels)
-            self.tracking_chunks.append(tracking)
-            self.global_chunks.append(globals)
-            self.counts.append(len(graphs))
 
             if (self.chunks > 1):
                 self.save_chunk(chunk_id, graphs, labels, tracking, globals)
@@ -213,18 +208,31 @@ class RootDataset(DGLDataset):
             self.graphs = graphs
             self.save()
             return
-
+        self.graphs = self.graph_chunks[0]
+        for chunk in self.graph_chunks[1:]:
+            self.graphs += chunk
+        self.labels = torch.cat(self.label_chunks)
+        self.tracking = torch.cat(self.tracking_chunks)
+        self.global_features = torch.cat(self.global_chunks)
+        print('Time spent: Creating TChain: {}s, Getting Entries and Selection: {}s, Graph Creation: {}s'.format(*times))
+        print('Time spent in node_features_from_tree: {}s, full_connected_graph: {}s'.format(*self.times))
+
     def save(self):
+        """save the graph list and the labels"""
         if not self.save_to_disk:
             return
         graph_path = os.path.join(self.save_dir, self.name + '.bin')
         if self.chunks == 1:
+            # print(len(self.graphs))
+            # print(len(self.labels))
+            # print(len(self.tracking))
+            # print(len(self.globals))
             print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
             dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
         else:
+            print(len(self.graph_chunks))
             for i in range(len(self.process_chunks)):
                 print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
-
                 dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
 
     def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
@@ -233,7 +241,7 @@ class RootDataset(DGLDataset):
         graph_path = os.path.join(self.save_dir, self.name + '.bin')
         print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[chunk_id]}.bin")}')
         dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[chunk_id]}.bin'), graphs, {'labels': labels, 'tracking': tracking, 'global': globals})
-
+
     def has_cache(self):
         print(f'Checking for cache of {self.name}')
         if not self.save_to_disk:
@@ -282,7 +290,7 @@ class RootDataset(DGLDataset):
 
     def __len__(self):
         return len(self.graphs)
-
+
 #Dataset with edge features added (deta, dphi, dR)
 class EdgeDataset(RootDataset):
     def make_graph(self, ch):
root_gnn_dgl/root_gnn_base/utils.py CHANGED
@@ -8,16 +8,10 @@ import dgl
 import signal
 
 def buildFromConfig(conf, run_time_args = {}):
-    device = run_time_args.get('device', 'cpu')
     if 'module' in conf:
         module = importlib.import_module(conf['module'])
         cls = getattr(module, conf['class'])
-        args = conf['args'].copy()
-        if 'weight' in args and isinstance(args['weight'], list):
-            args['weight'] = torch.tensor(args['weight'], dtype=torch.float, device=device)
-        # Remove device from run_time_args to not pass it to the class
-        run_time_args = {k: v for k, v in run_time_args.items() if k != 'device'}
-        return cls(**args, **run_time_args)
     else:
         print('No module specified in config.
 
@@ -92,7 +86,7 @@ def pad_batch(batch, edges = 104000, nodes = 16000):
     return make_padding_graph(batch, pad_nodes, pad_edges)
 
 def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
-    print(f"Padding each graph to have {max_num_nodes} nodes. Using hidden size {hid_size}.")
 
     unbatched = dgl.unbatch(batch)
     for g in unbatched:
@@ -183,101 +177,21 @@ def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
     checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
     return last_epoch, checkpoint
 
-#Return the index and checkpoint of the nest epoch.
-def get_best_epoch(config, var='Test_AUC', mode='max', device=None, from_ryan=False):
-    # Read the training log
-    log = read_log(config)
-
-    # Ensure the specified variable exists in the log
-    if var not in log:
-        raise ValueError(f"Variable '{var}' not found in the training log.")
-
-    # Determine the target epoch based on the mode ('max' or 'min')
-    if mode == 'max':
-        target_epoch = int(np.argmax(log[var]))
-        print(f"Best epoch based on '{var}' (max): {target_epoch} with value: {log[var][target_epoch]}")
-    elif mode == 'min':
-        target_epoch = int(np.argmin(log[var]))
-        print(f"Best epoch based on '{var}' (min): {target_epoch} with value: {log[var][target_epoch]}")
-    else:
-        raise ValueError(f"Invalid mode '{mode}'. Expected 'max' or 'min'.")
-
-    # Initialize checkpoint retrieval variables
-    last_epoch = -1
-    checkpoint = None
-
-    # Iterate through epochs up to the target epoch to find the corresponding checkpoint
-    for ep in range(target_epoch + 1):
-        if from_ryan:
-            checkpoint_path = os.path.join(
-                '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
-                config['Training_Directory'],
-                f'model_epoch_{ep}.pt'
-            )
-        else:
-            checkpoint_path = os.path.join(
-                config['Training_Directory'],
-                f'model_epoch_{ep}.pt'
-            )
-
-        if os.path.exists(checkpoint_path):
-            last_epoch = ep
-        else:
-            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
-            print('File not found: ', checkpoint_path)
-            break
-
-    # Load the checkpoint for the last valid epoch
-    if last_epoch >= 0:
-        if from_ryan:
-            checkpoint_path = os.path.join(
-                '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
-                config['Training_Directory'],
-                f'model_epoch_{last_epoch}.pt'
-            )
-        else:
-            checkpoint_path = os.path.join(
-                config['Training_Directory'],
-                f'model_epoch_{last_epoch}.pt'
-            )
-
-        checkpoint = torch.load(checkpoint_path, map_location=device)
-
-    return last_epoch, checkpoint
-
 def read_log(config):
     lines = []
     with open(config['Training_Directory'] + '/training.log', 'r') as f:
         lines = f.readlines()
-    lines = [l for l in lines if 'Epoch' in l]
-
     labels = []
     for field in lines[0].split('|'):
         labels.append(field.split()[0])
-
-    # Initialize log as a dictionary with empty lists
-    log = {label: [] for label in labels}
-
-    for line in lines:
-        valid_row = True # Flag to check if the row is valid
-        temp_row = {} # Temporary row to store values before adding to log
-
         for field in line.split('|'):
             spl = field.split()
-            try:
-                temp_row[spl[0]] = float(spl[1])
-            except (ValueError, IndexError):
-                valid_row = False # Mark row as invalid if conversion fails
-                break
-
-        if valid_row: # Only add the row if all fields are valid
-            for label in labels:
-                log[label].append(temp_row.get(label, np.nan)) # Handle missing labels gracefully
-
-        # Convert lists to numpy arrays for consistency
-        for label in labels:
279
- log[label] = np.array(log[label])
280
-
281
  return log
282
 
283
  #Plot training logs.
 
8
  import signal
9
 
10
  def buildFromConfig(conf, run_time_args = {}):
 
11
  if 'module' in conf:
12
  module = importlib.import_module(conf['module'])
13
  cls = getattr(module, conf['class'])
14
+ return cls(**conf['args'], **run_time_args)
 
 
 
 
 
15
  else:
16
  print('No module specified in config. Returning None.')
17
 
 
86
  return make_padding_graph(batch, pad_nodes, pad_edges)
87
 
88
  def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
89
+ print(f"Padding each graph to have {max_num_nodes} nodes")
90
 
91
  unbatched = dgl.unbatch(batch)
92
  for g in unbatched:
 
177
  checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
178
  return last_epoch, checkpoint
179
 
180
+ #Convert training logs into dict for plotting.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def read_log(config):
182
  lines = []
183
  with open(config['Training_Directory'] + '/training.log', 'r') as f:
184
  lines = f.readlines()
185
+ lines = [ l for l in lines if 'Epoch' in l ]
186
+ nlines = len(lines)
187
  labels = []
188
  for field in lines[0].split('|'):
189
  labels.append(field.split()[0])
190
+ log = {label : np.zeros(nlines) for label in labels}
191
+ for i, line in enumerate(lines):
 
 
 
 
 
 
192
  for field in line.split('|'):
193
  spl = field.split()
194
+ log[spl[0]][i] = float(spl[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  return log
196
 
197
  #Plot training logs.
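The simplified `read_log` introduced by this PR assumes every line containing "Epoch" is a well-formed `Key value | Key value | ...` record. A standalone sketch of that parsing logic, with a two-line illustrative log (the values below are made up, not from a real training run):

```python
import numpy as np

def parse_training_log(lines):
    # Keep only epoch records, as the PR's read_log does.
    lines = [l for l in lines if 'Epoch' in l]
    # Column names come from the first record's "|"-separated fields.
    labels = [field.split()[0] for field in lines[0].split('|')]
    log = {label: np.zeros(len(lines)) for label in labels}
    for i, line in enumerate(lines):
        for field in line.split('|'):
            spl = field.split()
            log[spl[0]][i] = float(spl[1])
    return log

lines = [
    "Epoch 0 | Loss 0.9 | Test_AUC 0.61",
    "Epoch 1 | Loss 0.7 | Test_AUC 0.68",
]
log = parse_training_log(lines)
```

Note the trade-off the diff makes: the removed version tolerated malformed rows, while this version will raise on any record whose fields do not split into exactly `name value` pairs.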
root_gnn_dgl/root_gnn_base/visualize_input_distributions.py DELETED
@@ -1,582 +0,0 @@
-import numpy as np
-import matplotlib.pyplot as plt
-import pandas as pd
-import uproot
-import yaml
-import argparse
-import sys
-from pathlib import Path
-from array import array
-import os
-import awkward as ak
-import math
-
-def tree_to_dataframe(tree_filepath, sort_by="", branches=[]):
-    """
-    Convert a ROOT tree to a Pandas DataFrame (assuming data is columnar).
-    Depends on the uproot and pandas libraries.
-    """
-    data_dict = {}
-
-    with uproot.open(tree_filepath) as file:
-        if not branches:  # If branches list is empty, load everything
-            for key in file.keys():
-                try:
-                    data_dict[key] = file[key].array(library="pd")
-                except Exception as e:
-                    print(f"Warning: Could not load branch '{key}': {e}")
-        else:  # If specific branches are requested
-            for branch in branches:
-                try:
-                    data_dict[branch] = file[branch].array(library="pd")
-                except KeyError:
-                    print(f"Warning: Branch '{branch}' not found in ROOT file")
-                except Exception as e:
-                    print(f"Warning: Could not load branch '{branch}': {e}")
-
-    data = pd.DataFrame(data_dict)
-
-    if sort_by == "":
-        return data
-    else:
-        if sort_by in data.columns:
-            data.sort_values(by=[sort_by], inplace=True)
-            data.reset_index(inplace=True, drop=True)
-        else:
-            print(f"Warning: Sort column '{sort_by}' not found in DataFrame")
-        return data
-
-def extract_dataset_info(yaml_file_path):
-    with open(yaml_file_path, 'r') as file:
-        config = yaml.safe_load(file)
-
-    datasets_info = {}
-    if "Datasets" in config:
-        for dset_name, dset_config in config['Datasets'].items():
-            if 'args' not in dset_config:
-                continue
-            args = dset_config["args"]
-            dset_info = {}
-            if "raw_dir" in args:
-                dset_info["raw_dir"] = args["raw_dir"]
-            if "file_names" in args:
-                dset_info["file_names"] = args["file_names"]
-            if "node_branch_names" in args:
-                dset_info["node_branch_names"] = args["node_branch_names"]
-            if "name" in args:
-                dset_info["name"] = args["name"]
-            if "node_feature_scales" in args:
-                dset_info["node_feature_scales"] = args["node_feature_scales"]
-            if "tree_name" in args:
-                dset_info["tree_name"] = args["tree_name"]
-            if "label" in args:
-                dset_info["label"] = args["label"]
-            if dset_info:
-                datasets_info[dset_name] = dset_info
-    return datasets_info
-
-def adaptive_bins(data, method='auto'):
-    """Choose an optimal number of bins based on data characteristics."""
-    data = np.array([x for x in data if x is not None and not np.isnan(x)])
-
-    if len(data) == 0:
-        return 10
-
-    if method == 'sturges':
-        return int(np.ceil(np.log2(len(data)) + 1))
-    elif method == 'scott':
-        h = 3.5 * np.std(data) / (len(data) ** (1/3))
-        return int(np.ceil((np.max(data) - np.min(data)) / h))
-    elif method == 'freedman':
-        iqr = np.percentile(data, 75) - np.percentile(data, 25)
-        h = 2 * iqr / (len(data) ** (1/3))
-        return int(np.ceil((np.max(data) - np.min(data)) / h)) if h > 0 else 50
-    elif method == 'sqrt':
-        return int(np.ceil(np.sqrt(len(data))))
-    else:  # 'auto'
-        return 'auto'  # Let matplotlib decide
-
-def safe_clean_data(data, exclude_zeros=(), observable_name=""):
-    """Safely clean data, handling different dtypes and optionally ignoring zeros for specific observables."""
-    if data is None or len(data) == 0:
-        return []
-
-    # Convert to numpy array if it isn't already
-    if not isinstance(data, np.ndarray):
-        data = np.array(data)
-
-    # Check if we should ignore zeros for this observable
-    ignore_zeros = bool(exclude_zeros) and observable_name.lower().endswith(tuple(exclude_zeros))
-
-    # Handle different data types
-    if data.dtype.kind in ['i', 'f']:  # integer or float
-        if data.dtype.kind == 'f':  # float: drop NaN/inf
-            mask = ~np.isnan(data) & np.isfinite(data)
-            clean_data = data[mask]
-        else:  # integers don't have NaN/inf issues
-            clean_data = data
-
-        # Drop sentinel values
-        clean_data = clean_data[(clean_data != -999) & (clean_data != -1)]
-
-        if ignore_zeros:
-            clean_data = clean_data[clean_data != 0]
-
-        return clean_data
-    else:
-        # Non-numeric data - filter manually
-        clean_list = []
-        for item in data:
-            if item is None:
-                continue
-            try:
-                # Try to convert to float to check if it's numeric
-                float_val = float(item)
-                if not (np.isnan(float_val) or np.isinf(float_val)):
-                    if ignore_zeros and float_val == 0:
-                        continue
-                    clean_list.append(float_val)
-            except (ValueError, TypeError):
-                # Not convertible to float, skip
-                continue
-        return np.array(clean_list) if clean_list else np.array([])
-
-def make_distributions(dset_info, output_dir, exclude_zeros):
-    os.makedirs(output_dir, exist_ok=True)
-    awk_type = ak.Array
-    list_type = type([])
-
-    for dset_name in dset_info:
-        curr_dset_info = dset_info[dset_name]
-        curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_dset_info['file_names']}:{curr_dset_info['tree_name']}")
-
-        # Collect all observables and their data for this dataset
-        observables_data = {}
-
-        for branch in curr_dset_info["node_branch_names"]:
-            if type(branch) != list_type:
-                continue
-            for observable in branch:
-                if type(observable) != type("str"):
-                    continue
-                try:
-                    data = curr_df[observable]
-                    if type(data.iloc[0]) == awk_type or type(data.iloc[0]) == list_type:
-                        appended_data = []
-                        for i in range(len(data.iloc[0])):
-                            try:
-                                ith_obs_data = np.array([x[i] if x is not None and len(x) > i else None for x in data])
-                                # Filter out None values
-                                ith_obs_data = ith_obs_data[ith_obs_data != None]
-                                if len(ith_obs_data) > 0:
-                                    appended_data.append(ith_obs_data)
-                            except (IndexError, TypeError):
-                                continue
-                        if appended_data:
-                            plot_data = np.concatenate(appended_data)
-                            observables_data[observable] = plot_data
-                    else:
-                        observables_data[observable] = data
-                except KeyError:
-                    continue
-
-        # Create subplot grid for all observables in this dataset
-        if not observables_data:
-            print(f"No data found for {dset_name}")
-            continue
-
-        n_observables = len(observables_data)
-
-        # Calculate grid dimensions (try to make it roughly square)
-        n_cols = math.ceil(math.sqrt(n_observables))
-        n_rows = math.ceil(n_observables / n_cols)
-
-        # Create the figure with subplots
-        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows))
-        fig.suptitle(f'All Distributions for {dset_name}', fontsize=16, y=0.98)
-
-        # Handle case where there's only one subplot
-        if n_observables == 1:
-            axes = [axes]
-        elif n_rows == 1:
-            axes = axes.reshape(1, -1)
-        elif n_cols == 1:
-            axes = axes.reshape(-1, 1)
-
-        # Flatten axes for easy iteration
-        axes_flat = axes.flatten() if n_observables > 1 else axes
-
-        # Plot each observable
-        for idx, (observable, plot_data) in enumerate(observables_data.items()):
-            ax = axes_flat[idx]
-
-            # Clean data safely
-            clean_data = safe_clean_data(plot_data, exclude_zeros, observable)
-
-            if len(clean_data) > 0:
-                try:
-                    bins = adaptive_bins(clean_data, method="freedman")
-                    # Plot histogram with label including event count
-                    ax.hist(clean_data, histtype="step", density=True, bins=bins,
-                            label=f'N = {len(clean_data):,}')
-                    if observable.lower().endswith(exclude_zeros):
-                        ax.set_title(f'{observable} (zeros excluded)', fontsize=10)
-                    else:
-                        ax.set_title(f'{observable}', fontsize=10)
-                    ax.set_xlabel(f'{observable}', fontsize=8)
-                    ax.set_ylabel('Density', fontsize=8)
-                    ax.tick_params(axis='both', which='major', labelsize=7)
-                    ax.grid(True, alpha=0.3)
-
-                    # Add legend with event count
-                    ax.legend(fontsize=8, loc='upper right')
-
-                except Exception as e:
-                    print(f"Error plotting {observable}: {e}")
-                    ax.text(0.5, 0.5, f'Plot error:\n{str(e)[:50]}...', ha='center', va='center',
-                            transform=ax.transAxes, fontsize=8)
-                    ax.set_title(f'{observable} (Error)', fontsize=10)
-            else:
-                ax.text(0.5, 0.5, 'No valid data\nN = 0', ha='center', va='center',
-                        transform=ax.transAxes)
-                ax.set_title(f'{observable} (No Data)', fontsize=10)
-
-        # Hide unused subplots
-        for idx in range(n_observables, len(axes_flat)):
-            axes_flat[idx].set_visible(False)
-
-        # Adjust layout and save
-        plt.tight_layout()
-        plt.subplots_adjust(top=0.93)  # Make room for suptitle
-        plt.savefig(f"{output_dir}/{dset_name}_all_distributions.png",
-                    dpi=300, bbox_inches='tight')
-        plt.close()
-
-        print(f"Created combined plot for {dset_name} with {n_observables} observables")
-
-def make_distributions_comparison_grid_by_label(dset_info, output_dir, output_filename, label_names=None, use_percentile_for_xlims=False, xlim_adjustment=False):
-    """Create comparison plots grouped by label instead of dataset.
-
-    Args:
-        dset_info: Dictionary containing dataset information
-        output_dir: Directory to save output plots
-        label_names: Optional list of strings to use as label names in legends.
-                     If provided, must have length equal to number of unique labels.
-                     Index corresponds to label number.
-    """
-    os.makedirs(output_dir, exist_ok=True)
-    awk_type = ak.Array
-    list_type = type([])
-
-    label_to_datasets = {}
-    for dset_name, curr_dset_info in dset_info.items():
-        dataset_label = curr_dset_info.get('label', 'Unknown')
-        if dataset_label not in label_to_datasets:
-            label_to_datasets[dataset_label] = []
-        label_to_datasets[dataset_label].append(dset_name)
-
-    # First, collect all data organized by observable and then by label
-    observables_by_variable = {}
-
-    for dset_name in dset_info:
-        print(f"Processing dataset: {dset_name}")
-        curr_dset_info = dset_info[dset_name]
-
-        # Get the label for this dataset
-        dataset_label = curr_dset_info.get('label', 'Unknown')
-        print(f"  Label: {dataset_label}")
-
-        if type(curr_dset_info['file_names']) == type("str"):
-            curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_dset_info['file_names']}:{curr_dset_info['tree_name']}")
-        else:
-            curr_df_list = []
-            for i in range(len(curr_dset_info['file_names'])):
-                curr_name = curr_dset_info['file_names'][i]
-                curr_curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_name}:{curr_dset_info['tree_name']}")
-                curr_df_list.append(curr_curr_df)
-            curr_df = pd.concat(curr_df_list, ignore_index=True)
-
-        for branch in curr_dset_info["node_branch_names"]:
-            if type(branch) != list_type:
-                continue
-            for observable in branch:
-                if type(observable) != type("str"):
-                    continue
-                try:
-                    data = curr_df[observable]
-
-                    # Initialize observable dict if not exists
-                    if observable not in observables_by_variable:
-                        observables_by_variable[observable] = {}
-
-                    # Initialize label dict if not exists
-                    if dataset_label not in observables_by_variable[observable]:
-                        observables_by_variable[observable][dataset_label] = []
-
-                    if type(data.iloc[0]) == awk_type or type(data.iloc[0]) == list_type:
-                        appended_data = []
-                        for x in data:
-                            row_data = []
-                            for i in range(len(x)):
-                                if x[i] == 0 or x[i] == 0.0:
-                                    continue
-                                row_data.append(x[i])
-                            row_data = np.array(row_data)
-                            row_data = row_data[row_data != None]
-                            if len(row_data) > 0:
-                                appended_data.append(row_data)
-
-                        if appended_data:
-                            plot_data = np.concatenate(appended_data)
-                            observables_by_variable[observable][dataset_label].append(plot_data)
-                    else:
-                        observables_by_variable[observable][dataset_label].append(data)
-
-                except KeyError:
-                    continue
-
-    # Combine data for each label (since multiple datasets might have the same label)
-    observables_by_label = {}
-    for observable, labels_data in observables_by_variable.items():
-        observables_by_label[observable] = {}
-        for label, data_list in labels_data.items():
-            if data_list:
-                # Concatenate all data for this label
-                combined_data = []
-                for data in data_list:
-                    clean_data = safe_clean_data(data, observable_name=observable)
-                    if len(clean_data) > 0:
-                        combined_data.extend(clean_data)
-
-                if combined_data:
-                    observables_by_label[observable][label] = np.array(combined_data)
-
-    # Filter out observables with no data
-    observables_by_label = {k: v for k, v in observables_by_label.items() if v}
-
-    if not observables_by_label:
-        print("No observables found!")
-        return
-
-    # Get consistent colors for labels across all plots
-    all_labels = set()
-    for labels_data in observables_by_label.values():
-        all_labels.update(labels_data.keys())
-    all_labels = sorted(list(all_labels))  # Sort for consistency
-
-    print(f"Found labels: {all_labels}")
-
-    # Validate label_names parameter if provided
-    if label_names is not None:
-        if len(label_names) != len(all_labels):
-            raise ValueError(f"label_names must have length {len(all_labels)} to match number of unique labels, but got {len(label_names)}")
-        print(f"Using custom label names: {label_names}")
-
-    # Calculate grid dimensions
-    n_observables = len(observables_by_label)
-    n_cols = math.ceil(math.sqrt(n_observables))
-    n_rows = math.ceil(n_observables / n_cols)
-
-    print(f"Creating comparison grid for {n_observables} observables ({n_rows}x{n_cols})")
-
-    # Create the big figure
-    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
-    fig.suptitle('Distribution Comparisons Across All Labels', fontsize=20, y=0.98)
-
-    # Handle different subplot configurations
-    if n_observables == 1:
-        axes = [axes]
-    elif n_rows == 1:
-        axes = axes.reshape(1, -1)
-    elif n_cols == 1:
-        axes = axes.reshape(-1, 1)
-
-    # Flatten axes for easy iteration
-    axes_flat = axes.flatten() if n_observables > 1 else axes
-
-    # Create color map for labels
-    colors = plt.cm.tab10(np.linspace(0, 1, len(all_labels)))
-    label_colors = dict(zip(all_labels, colors))
-
-    # Plot each observable
-    for idx, (observable, labels_data) in enumerate(observables_by_label.items()):
-        ax = axes_flat[idx]
-
-        # Calculate consistent bins based on ALL data for this observable
-        all_combined_data = []
-        for label_data in labels_data.values():
-            all_combined_data.extend(label_data)
-
-        if not all_combined_data:
-            ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
-            ax.set_title(f'{observable} (No Data)', fontsize=12)
-            continue
-
-        combined_array = np.array(all_combined_data)
-        if observable == "ph_phi" or observable == "ph_eta":
-            n_bins = 10
-        elif observable == "m_jet_btag77":
-            n_bins = 4
-        else:
-            n_bins = adaptive_bins(combined_array, method="freedman")
-            if n_bins > 35:  # Control the fineness of the binning here
-                n_bins = 35
-        bin_edges = np.histogram_bin_edges(combined_array, bins=n_bins)
-
-        print(f"{observable}: Using {len(bin_edges)-1} consistent bins for {len(labels_data)} labels")
-
-        # Plot each label's distribution for this observable
-        for label, plot_data in labels_data.items():
-            try:
-                # Determine label for legend
-                if label_names is not None:
-                    # Use custom label name based on label index
-                    label_idx = all_labels.index(label)
-                    legend_label = f'{label_names[label_idx]} (N={len(plot_data):,})'
-                else:
-                    # Use original format
-                    legend_label = f'Label {label} (N={len(plot_data):,})'
-
-                ax.hist(plot_data, bins=bin_edges, histtype="step", density=True,
-                        label=legend_label,
-                        color=label_colors[label], linewidth=1.5, alpha=0.8)
-            except Exception as e:
-                print(f"Error plotting {observable} for label {label}: {e}")
-                continue
-
-        # Add title and axis labels
-        title = f'{observable}'
-
-        if use_percentile_for_xlims and xlim_adjustment:
-            print("ERROR: Only provide one of the flags at a time, either --use_percentile_for_xlims or --xlim_adjustment")
-            return
-        if not use_percentile_for_xlims and not xlim_adjustment:
-            ax.set_xlim(bin_edges[0], bin_edges[-1])
-        elif use_percentile_for_xlims:
-            ax.set_xlim(bin_edges[0], np.percentile(combined_array, 98))
-        elif xlim_adjustment:
-            min_edge = max(bin_edges[0], np.mean(combined_array) - 3*np.std(combined_array))
-            max_edge = min(bin_edges[-1], np.mean(combined_array) + 3*np.std(combined_array))
-            ax.set_xlim(min_edge, max_edge)
-
-        ax.set_title(title, fontsize=12, pad=10)
-        ax.set_xlabel(f'{observable}', fontsize=10)
-        ax.set_ylabel('Density', fontsize=10)
-        ax.tick_params(axis='both', which='major', labelsize=8)
-        ax.grid(True, alpha=0.3)
-
-        # Create legend
-        if len(labels_data) <= 5:
-            if label_names is not None:
-                # Simple legend with just custom names and counts
-                ax.legend(fontsize=8, loc='best')
-            else:
-                # Create custom legend labels with dataset information
-                legend_labels = []
-                for label in labels_data.keys():
-                    datasets = label_to_datasets.get(label, [])
-
-                    if len(datasets) == 1:
-                        # Single dataset
-                        dataset_info = datasets[0]
-                    elif len(datasets) <= 2:
-                        # Few datasets - show all names
-                        dataset_info = ', '.join(datasets)
-                    else:
-                        # Many datasets - show count
-                        dataset_info = f"{datasets[0]}, +{len(datasets)-1} more"
-
-                    legend_labels.append(f'Label {label} (N={len(labels_data[label]):,})\n{dataset_info}')
-
-                # Get the legend handles and update their labels
-                handles, _ = ax.get_legend_handles_labels()
-                ax.legend(handles, legend_labels, fontsize=6, loc='best')
-        else:
-            total_events = sum(len(data) for data in labels_data.values())
-            ax.set_title(f'{title}\n(Total N={total_events:,})', fontsize=11)
-
-    # Hide unused subplots
-    for idx in range(n_observables, len(axes_flat)):
-        axes_flat[idx].set_visible(False)
-
-    # Adjust layout and save
-    plt.tight_layout()
-    plt.subplots_adjust(top=0.94, right=0.85 if len(all_labels) > 5 else 0.95)
-
-    output_path = f"{output_dir}/{output_filename}"
-    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
-    plt.close()
-
-    print(f"Created comparison grid by label: {output_path}")
-    print(f"Grid contains {n_observables} observables across {len(all_labels)} labels")
-
-    # Print summary of what was combined
-    print("\nLabel summary:")
-    for label in all_labels:
-        datasets_with_label = [dset for dset, info in dset_info.items() if info.get('label') == label]
-        if label_names is not None:
-            label_idx = all_labels.index(label)
-            display_name = label_names[label_idx]
-        else:
-            display_name = f"Label {label}"
-        print(f"  {display_name}: {len(datasets_with_label)} datasets ({', '.join(datasets_with_label)})")
-
-def main():
-    parser = argparse.ArgumentParser()
-    add_arg = parser.add_argument
-
-    add_arg("--config", type=str, required=True, help="The path to the config.")
-    add_arg("--output_dir", type=str, required=True, help="The path of the directory where you want the plots to be written.")
-    add_arg('--label_names', nargs='+', default=["None"], help="A list of the names associated with each label to be displayed in the legends of the histograms.")
-    add_arg("--output_filename", type=str, default="input_var_distribution_comparisons.png", help="The name of the file you want the plots to be written to.")
-    add_arg("--use_percentile_for_xlims", action="store_true", help="If this flag is provided, the xlims will be set as [first bin edge, 98th percentile] rather than [first bin edge, last bin edge].")
-    add_arg("--xlim_adjustment", action="store_true", help="If this flag is provided, the xlims will be set using the mean and std of the data.")
-
-    args = parser.parse_args()
-
-    config_filepath = args.config
-    output_dir = args.output_dir
-    label_names = args.label_names
-    output_filename = args.output_filename
-    use_percentile = args.use_percentile_for_xlims
-    xlim_adjustment = args.xlim_adjustment
-
-    dset = extract_dataset_info(config_filepath)
-
-    if label_names[0] == "None":
-        make_distributions_comparison_grid_by_label(dset, output_dir, output_filename, use_percentile_for_xlims=use_percentile, xlim_adjustment=xlim_adjustment)
-    else:
-        make_distributions_comparison_grid_by_label(dset, output_dir, output_filename, label_names, use_percentile, xlim_adjustment)
-
-if __name__ == "__main__":
-    main()
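The deleted script's `adaptive_bins(method="freedman")` implements the Freedman–Diaconis rule: bin width h = 2·IQR/n^(1/3), with a fallback of 50 bins when the IQR is zero. A minimal standalone sketch of just that rule (the random data below is illustrative):

```python
import numpy as np

def freedman_bins(data):
    # Freedman-Diaconis: bin width h = 2*IQR / n^(1/3);
    # number of bins = ceil(range / h), falling back to 50 when IQR is 0.
    data = np.asarray(data, dtype=float)
    data = data[np.isfinite(data)]
    iqr = np.percentile(data, 75) - np.percentile(data, 25)
    h = 2 * iqr / (len(data) ** (1 / 3))
    return int(np.ceil((data.max() - data.min()) / h)) if h > 0 else 50

rng = np.random.default_rng(0)
n = freedman_bins(rng.normal(size=1000))
```

The rule is robust to outliers because it uses the interquartile range rather than the standard deviation; for a constant sample it degenerates (IQR = 0), which is why the fallback exists.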
root_gnn_dgl/run_demo.sh CHANGED
@@ -13,7 +13,7 @@ done
 
 python scripts/training_script.py --config configs/stats_100K/pretraining_multiclass.yaml --preshuffle --nocompile --lazy
 
-# From Scratch Training
+# Finetuning
 
 datasets=("ttH_CP_even" "ttH_CP_odd")
 chunks=3
@@ -27,8 +27,6 @@ done
 
 python scripts/training_script.py --config configs/stats_100K/ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy
 
-# Finetuning Training
-
 python scripts/training_script.py --config configs/stats_100K/finetuning_ttH_CP_even_vs_odd.yaml --preshuffle --nocompile --lazy
 
 # Inference
root_gnn_dgl/scripts/check_dataset_files.py DELETED
@@ -1,130 +0,0 @@
1
- import yaml
2
- import os
3
- import subprocess
4
- import argparse
5
-
6
- def check_dataset_files(yaml_file, rerun=False):
7
- """
8
- Check if all required .bin files exist for each dataset in the YAML file.
9
- """
10
- try:
11
- # Open and parse the YAML file
12
- with open(yaml_file, 'r') as file:
13
- config = yaml.safe_load(file)
14
-
15
- # Check if 'Datasets' exists in the YAML file
16
- if 'Datasets' not in config:
17
- print(f"No 'Datasets' section found in {yaml_file}.")
18
- return
19
-
20
- datasets = config['Datasets']
21
- all_files_exist = True
22
-
23
- for dataset_name, dataset_config in datasets.items():
24
- # Extract required information
25
- save_dir = dataset_config['args']['save_dir']
26
- chunks = dataset_config['args']['chunks']
27
- folding = dataset_config.get('folding', {})
28
- n_folds = folding.get('n_folds', 0)
29
- test_folds = folding.get('test', [])
30
- train_folds = folding.get('train', [])
31
-
32
- print(f"\n== Checking dataset: {dataset_name} ==")
33
-            print(f"  save_dir: {save_dir}")
-            print(f"  chunks: {chunks}")
-            print(f"  n_folds: {n_folds}")
-            print(f"  test_folds: {test_folds}")
-            print(f"  train_folds: {train_folds}")
-
-            missing_files = []
-
-            # 1. Check for chunk files
-            for chunk in range(chunks):
-                chunk_file = os.path.join(save_dir, f"{dataset_name}_{chunk}.bin")
-                if not os.path.exists(chunk_file):
-                    missing_files.append(chunk_file)
-
-            # 2. Check for prebatched fold files (test and train)
-            # Naming: dataset_name_prebatched_padded_{fold}_n_{n_folds}_f_{foldlist}.bin
-            fold_types = [('test', test_folds), ('train', train_folds)]
-            for fold_type, folds in fold_types:
-                if not folds:
-                    continue
-                foldlist_str = '_'.join(map(str, folds))
-                for i in range(chunks):
-                    prebatched_file = os.path.join(
-                        save_dir,
-                        f"{dataset_name}_prebatched_padded_{i}_n_{n_folds}_f_{foldlist_str}.bin"
-                    )
-                    if not os.path.exists(prebatched_file):
-                        missing_files.append(prebatched_file)
-
-            # Print results for the current dataset
-            if missing_files:
-                all_files_exist = False
-                print(f"  Missing files for dataset '{dataset_name}':")
-                for missing_file in missing_files:
-                    print(f"    - {missing_file}")
-
-                # Optionally rerun data prep
-                if rerun:
-                    print(f"  Reprocessing dataset '{dataset_name}' ...")
-                    prep_command = f"bash/prep_data.sh {yaml_file} {dataset_name} {chunks}"
-                    try:
-                        subprocess.run(prep_command, shell=True, check=True)
-                    except subprocess.CalledProcessError as e:
-                        print(f"  Could NOT reprocess '{dataset_name}': {e}")
-            else:
-                print(f"  All files exist for dataset '{dataset_name}'.")
-
-        # Final summary
-        if all_files_exist:
-            print("\nAll required files exist for all datasets.")
-        else:
-            print("\nSome files are missing.")
-
-    except Exception as e:
-        print(f"Error processing {yaml_file}: {e}")
-
-def main(pargs):
-    # Base directory containing the YAML files
-    base_directory = os.getcwd() + "/configs/"
-
-    if pargs.configs:
-        configs = [p.strip() for p in pargs.configs.split(',')]
-    else:
-        configs = [
-            "attention/ttH_CP_even_vs_odd.yaml",
-
-            "stats_100K/finetuning_ttH_CP_even_vs_odd.yaml",
-            "stats_100K/pretraining_multiclass.yaml",
-            "stats_100K/ttH_CP_even_vs_odd.yaml",
-
-            "stats_all/finetuning_ttH_CP_even_vs_odd.yaml",
-            "stats_all/pretraining_multiclass.yaml",
-            "stats_all/ttH_CP_even_vs_odd.yaml",
-        ]
-
-    for config in configs:
-        yaml_file = os.path.join(base_directory, config)
-        if os.path.exists(yaml_file):
-            print(f"\nProcessing file: {config}")
-            check_dataset_files(yaml_file, pargs.rerun)
-        else:
-            print(f"File not found: {yaml_file}")
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Check YAML config files")
-    parser.add_argument(
-        "--configs", "-c",
-        type=str,
-        required=False,
-        help="Comma-separated list of YAML config paths relative to base directory"
-    )
-    parser.add_argument(
-        "--rerun", "-r",
-        action='store_true',  # boolean flag
-        help="Automatically re-run data processing to fix missing files"
-    )
-    args = parser.parse_args()
-    main(args)
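The chunk-file check deleted above is easy to exercise on its own. A self-contained sketch (the helper name `missing_chunk_files` is ours, not from the script) reproducing the `{dataset_name}_{chunk}.bin` lookup:

```python
import os
import tempfile

# Hypothetical helper mirroring the chunk-file check in check_dataset_files.py:
# report which "{dataset_name}_{chunk}.bin" files are absent from save_dir.
def missing_chunk_files(save_dir, dataset_name, chunks):
    missing = []
    for chunk in range(chunks):
        chunk_file = os.path.join(save_dir, f"{dataset_name}_{chunk}.bin")
        if not os.path.exists(chunk_file):
            missing.append(chunk_file)
    return missing

with tempfile.TemporaryDirectory() as d:
    # Create chunks 0 and 2, leave chunk 1 missing.
    for i in (0, 2):
        open(os.path.join(d, f"ttH_{i}.bin"), "w").close()
    print(missing_chunk_files(d, "ttH", 3))  # only ttH_1.bin is reported
```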
root_gnn_dgl/scripts/inference.py CHANGED
@@ -135,6 +135,7 @@ def main():
 
     import time
     start = time.time()
+    import ROOT
     import torch
     from array import array
     import numpy as np
@@ -246,41 +247,55 @@ def main():
         all_labels[branch] = labels
         all_tracking[branch] = tracking_info
 
-
     if args.write:
-        import uproot
-        import awkward as ak
-
-        # Open the original ROOT file and get the tree
-        infile = uproot.open(args.target)
-        tree = infile[dset_config['args']['tree_name']]
+        from ROOT import std
+        # Open the original ROOT file
+        infile = ROOT.TFile.Open(args.target)
+        tree = infile.Get(dset_config['args']['tree_name'])
 
-        # Read the original tree as an awkward array
-        original_data = tree.arrays(library="ak")
-
-        # Prepare new branches as dicts of arrays
-        new_branches = {}
-        n_entries = len(original_data)
-        for branch, scores in all_scores.items():
-            # Ensure the scores array is the right length
-            scores = np.asarray(scores)
-            if scores.shape[0] != n_entries:
-                raise ValueError(f"Branch '{branch}' has {scores.shape[0]} entries, but tree has {n_entries}")
-            new_branches[branch] = scores
-
-        # Merge all arrays (original + new branches)
-        # Convert awkward to dict of numpy arrays for uproot
-        out_dict = {k: np.asarray(v) for k, v in ak.to_numpy(original_data).items()}
-        out_dict.update(new_branches)
-
-        # Write to new ROOT file
+        # Create the destination directory if it doesn't exist
         os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
-        with uproot.recreate(args.destination) as outfile:
-            outfile.mktree(dset_config['args']['tree_name'], {k: v.dtype for k, v in out_dict.items()})
-            outfile[dset_config['args']['tree_name']].extend(out_dict)
 
-        print(f"Wrote new ROOT file {args.destination} with new branches {list(new_branches.keys())}")
+        # Create a new ROOT file to write the modified tree
+        outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
+
+        # Clone the original tree structure
+        outtree = tree.CloneTree(0)
+
+        # Create branches for all scores
+        branch_vectors = {}
+        for branch, scores in all_scores.items():
+            if isinstance(scores[0], (list, tuple, np.ndarray)) and len(scores[0]) > 1:
+                # Create a new branch for vectors
+                branch_vectors[branch] = std.vector('float')()
+                outtree.Branch(branch, branch_vectors[branch])
+            else:
+                # Create a new branch for single floats
+                branch_vectors[branch] = array('f', [0])
+                outtree.Branch(branch, branch_vectors[branch], f'{branch}/F')
+
+        # Fill the tree
+        for i in range(tree.GetEntries()):
+            tree.GetEntry(i)
+
+            for branch, scores in all_scores.items():
+                branch_data = branch_vectors[branch]
+                if isinstance(branch_data, array):  # Single float branch
+                    branch_data[0] = float(scores[i])
+                else:  # Assume it's a std::vector<float>
+                    branch_data.clear()
+                    for value in scores[i]:
+                        branch_data.push_back(float(value))
+
+            outtree.Fill()
+
+        # Write the modified tree to the new file
+        print(f'Writing to file {args.destination}')
+        print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
+        print(f'Wrote scores to {args.branch_name}')
+        outtree.Write()
+        outfile.Close()
+        infile.Close()
     else:
         os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
         np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
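The new write path dispatches on the shape of each per-event score: multi-value entries become a `std::vector<float>` branch, single values a flat `/F` branch. A stand-alone sketch of that dispatch rule (the helper name `branch_kind` is ours, purely illustrative, and mirrors the `isinstance(...)/len(...) > 1` check in the diff):

```python
# Hypothetical helper mirroring the branch-type dispatch in inference.py:
# decide whether a per-event score should be written as a vector branch
# (each entry holds more than one value) or a flat float branch.
def branch_kind(scores):
    first = scores[0]
    if isinstance(first, (list, tuple)) and len(first) > 1:
        return 'vector'
    return 'scalar'

print(branch_kind([0.1, 0.9, 0.4]))           # one float per event -> scalar
print(branch_kind([[0.1, 0.9], [0.7, 0.3]]))  # multiclass scores -> vector
```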
root_gnn_dgl/scripts/prep_data.py CHANGED
@@ -33,12 +33,12 @@ def main():
     fold_conf = dset_config["folding"]
     print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
     if dset_config["class"] == "LazyMultiLabelDataset":
-        LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
-        LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
+        LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
+        LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
 
     else:
-        PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
-        PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
+        PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
+        PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
 
 if __name__ == "__main__":
     main()
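prep_data.py delegates the train/test split to `utils.fold_selection`. As a toy illustration only — assuming a common fold scheme where an event lands in fold `event_number % n_folds`, which is our assumption and not necessarily the actual `utils` implementation — disjoint train/test masks fall out like this:

```python
# Hypothetical sketch of fold-based selection (assumed scheme:
# fold = event_number % n_folds; keep only the listed folds).
def fold_mask(event_numbers, n_folds, keep_folds):
    return [(evt % n_folds) in keep_folds for evt in event_numbers]

events = list(range(10))
test_mask = fold_mask(events, n_folds=5, keep_folds=[0])         # events 0, 5
train_mask = fold_mask(events, n_folds=5, keep_folds=[1, 2, 3, 4])
print(sum(test_mask), sum(train_mask))
```

Because the two fold lists are disjoint, no event can be selected by both masks.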
root_gnn_dgl/scripts/training_script.py CHANGED
@@ -45,10 +45,10 @@ def gpu_mem():
     # except:
     #     pass
     print(f'Current GPU memory usage: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} GB')
-    # print(f'Current GPU cache usage: {torch.cuda.memory_cached() / 1024 / 1024 / 1024} GB')
-    # print(f'Current GPU max memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024} GB')
-    # print(f'Current GPU max cache usage: {torch.cuda.max_memory_cached() / 1024 / 1024 / 1024} GB')
-    # print(f'Numel in current tensors: {sum}')
+    print(f'Current GPU cache usage: {torch.cuda.memory_cached() / 1024 / 1024 / 1024} GB')
+    print(f'Current GPU max memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024} GB')
+    print(f'Current GPU max cache usage: {torch.cuda.max_memory_cached() / 1024 / 1024 / 1024} GB')
+    print(f'Numel in current tensors: {sum}')
     mem()
 
 
@@ -263,16 +263,11 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
     for epoch in range(starting_epoch, config['Training']['epochs']):
         start = time.time()
         run = start
-        if (args.profile):
-            if (epoch == 0):
-                torch.cuda.cudart().cudaProfilerStart()
-            torch.cuda.nvtx.range_push("Epoch Start")
-
         if (args.multigpu or args.multinode):
             dist.barrier()
-
-        if (epoch == 5):
-            exit
+        if (epoch == 2):
+            # torch.cuda.cudart().cudaProfilerStart()
+            pass
 
         # training
         model.train()
@@ -297,8 +292,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
             if is_padded: #Padding the globals to match padded graphs.
                 global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
             load = time.time()
-            if (args.profile):
-                torch.cuda.nvtx.range_push("Model Forward")
             if (len(logits) == 0):
                 logits = model(graph, global_feats)
                 tlabels = label
@@ -309,9 +302,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
                 weights = torch.concatenate((weights, track[:,1]), dim=0)
             batch_lengths.append(logits.shape[0] - 1)
 
-            if (args.profile):
-                torch.cuda.nvtx.range_pop() # popping model forward
-
             if is_padded:
                 keepmask = torch.full_like(logits[:,0], True, dtype=torch.bool)
                 keepmask[batch_lengths] = False
@@ -350,15 +340,11 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
                 normalized_loss += label_loss
             loss = normalized_loss / len(unique_labels)
 
-            if (args.profile):
-                torch.cuda.nvtx.range_push("Model Backward")
+
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             total_loss += loss.detach().cpu().item()
-
-            if (args.profile):
-                torch.cuda.nvtx.range_pop() # pop model backward
             ibatch += 1
             cumulative_times[0] += batch_start - run
             cumulative_times[1] += load - batch_start
@@ -380,10 +366,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
         labels = []
         weights = []
         model.eval()
-
-        if (args.profile):
-            torch.cuda.nvtx.range_push("Model Evaluation")
-
         with torch.no_grad():
             for loader in test_loaders:
                 for batch, label, track, global_feats in loader:
@@ -404,9 +386,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
         eval_end = time.time()
         cumulative_times[3] += eval_end - run
 
-        if (args.profile):
-            torch.cuda.nvtx.range_pop() # pop evaluation
-
         if scores == []: #If validation set is empty.
             continue
         logits = torch.concatenate(scores).to(device)
@@ -496,8 +475,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
 
         try:
             #test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
-            if (len(scores[0]) != config["Model"]["args"]["out_size"]):
-                print("ERROR: The out_size and the number of class labels don't match! Please check config.")
             test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
         except ValueError:
             test_auc = np.nan
@@ -602,9 +579,6 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
         custom_scheduler.step(model, {'test_auc':test_auc})
         scheduler.step()
 
-        if (args.profile):
-            torch.cuda.nvtx.range_pop() # pop epoch
-
     print(f"Load: {cumulative_times[0]:.4f} s")
     print(f"Batch: {cumulative_times[1]:.4f} s")
     print(f"Train: {cumulative_times[2]:.4f} s")
@@ -704,7 +678,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
             mask_fn = utils.fold_selection(fold_conf, "train")
             if args.preshuffle:
                 # ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, use_ddp = args.multigpu, rank=rank, world_size=world_size)
-                ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size = config["Model"]["args"]["hid_size"])
+                ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode)
                 gsamp, _, _, global_samp = ldr[0]
                 sampler = None
@@ -721,7 +695,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
                     sampler = DistributedSampler(ldr, num_replicas=world_size, rank=pargs.global_rank, shuffle=False, drop_last=True)
                 train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
             sampler = None
-            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size= config['Model']['args']['hid_size'])
+            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode)
             if (args.multigpu):
                 sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
                 # num_batches = len(ldr)
@@ -734,7 +708,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
             test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))
 
             if "validation" in fold_conf:
-                val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, hidden_size=config['Model']['args']['hid_size'], padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
+                val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
             else:
                 print("No validation set for dataset ", dset_conf)
         else:
@@ -750,8 +724,6 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
     print("Load time: {:.4f} s".format(load_end - load_start))
 
     model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
-    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"Number of trainable parameters = {pytorch_total_params}")
     if not args.nocompile:
         model = torch.compile(model)
     if args.multigpu:
@@ -816,7 +788,6 @@ if __name__ == "__main__":
     add_arg("--directory", type=str, help="Append to Training Directory")
     add_arg("--seed", type=int, default=2, help="Sets random seed")
     add_arg("--abs", action="store_true", help="Use abs value of per-event weight")
-    add_arg("--profile", action="store_true", help="use nsight systems profiler")
 
     pargs = parser.parse_args()
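The train() loop above accumulates wall-clock time into four buckets (Load / Batch / Train / Eval) and prints them at the end. A minimal stdlib sketch of the same bucketed-timing pattern, with `time.sleep` standing in for real work (the eval bucket is left unused here):

```python
import time

# Cumulative timing buckets, mirroring the Load/Batch/Train/Eval report
# printed by train(); the sleeps are stand-ins for actual work.
cumulative_times = [0.0, 0.0, 0.0, 0.0]
run = time.time()
for _ in range(3):
    batch_start = time.time()      # data loading finished
    cumulative_times[0] += batch_start - run
    time.sleep(0.001)              # stand-in for host-side batch prep
    load = time.time()
    cumulative_times[1] += load - batch_start
    time.sleep(0.001)              # stand-in for forward/backward pass
    run = time.time()
    cumulative_times[2] += run - load
print(f"Load: {cumulative_times[0]:.4f} s")
print(f"Batch: {cumulative_times[1]:.4f} s")
print(f"Train: {cumulative_times[2]:.4f} s")
```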
root_gnn_dgl/setup/Dockerfile DELETED
@@ -1,25 +0,0 @@
-FROM nvcr.io/nvidia/dgl:25.05-py3
-
-WORKDIR /global/cfs/projectdirs/atlas/joshua/GNN4Colliders
-
-LABEL maintainer.name="Joshua Ho"
-LABEL maintainer.email="ho22joshua@berkeley.edu"
-
-ENV LANG=C.UTF-8
-
-# Install system dependencies: vim, OpenMPI, and build tools
-RUN apt-get update -qq \
-    && apt-get install -y --no-install-recommends \
-        wget lsb-release gnupg software-properties-common \
-        vim \
-        g++-11 gcc-11 libstdc++-11-dev \
-        openmpi-bin openmpi-common libopenmpi-dev \
-    && rm -rf /var/lib/apt/lists/*
-
-# Install Python packages: mpi4py and jupyter
-RUN pip install --no-cache-dir mpi4py jupyter uproot
-
-# (Optional) Expose Jupyter port
-EXPOSE 8888
root_gnn_dgl/setup/build_image.sh DELETED
@@ -1,4 +0,0 @@
-tag=$1
-echo $tag
-podman-hpc build -t joshuaho/pytorch:$tag --platform linux/amd64 .
-podman-hpc migrate joshuaho/pytorch:$tag
root_gnn_dgl/setup/environment_torch24.yaml DELETED
@@ -1,249 +0,0 @@
-name: pytorch_dgl_cf
-channels:
-  - conda-forge
-dependencies:
-  - _libgcc_mutex=0.1=conda_forge
-  - _openmp_mutex=4.5=3_kmp_llvm
-  - absl-py=2.3.1=pyhd8ed1ab_0
-  - annotated-types=0.7.0=pyhd8ed1ab_1
-  - astunparse=1.6.3=pyhd8ed1ab_3
-  - brotli-python=1.1.0=py310hf71b8c6_3
-  - bzip2=1.0.8=h4bc722e_7
-  - c-ares=1.34.5=hb9d3cd8_0
-  - ca-certificates=2025.7.14=hbd8a1cb_0
-  - cached-property=1.5.2=hd8ed1ab_1
-  - cached_property=1.5.2=pyha770c72_1
-  - certifi=2025.7.14=pyhd8ed1ab_0
-  - cffi=1.17.1=py310h8deb56e_0
-  - charset-normalizer=3.4.2=pyhd8ed1ab_0
-  - colorama=0.4.6=pyhd8ed1ab_1
-  - cpython=3.10.18=py310hd8ed1ab_0
-  - cuda-version=11.8=h70ddcb2_3
-  - cudatoolkit=11.8.0=h4ba93d1_13
-  - cudnn=8.9.7.29=hbc23b4c_3
-  - dgl=2.3.0=cuda118py310h5c70fa1_0
-  - filelock=3.18.0=pyhd8ed1ab_0
-  - flatbuffers=24.3.25=h59595ed_0
-  - fsspec=2025.7.0=pyhd8ed1ab_0
-  - gast=0.6.0=pyhd8ed1ab_0
-  - giflib=5.2.2=hd590300_0
-  - gmp=6.3.0=hac33072_2
-  - gmpy2=2.2.1=py310he8512ff_0
-  - google-pasta=0.2.0=pyhd8ed1ab_2
-  - grpcio=1.62.2=py310h1b8f574_0
-  - h2=4.2.0=pyhd8ed1ab_0
-  - h5py=3.14.0=nompi_py310hea1e86d_100
-  - hdf5=1.14.6=nompi_h6e4c0c1_102
-  - hpack=4.1.0=pyhd8ed1ab_0
-  - hyperframe=6.1.0=pyhd8ed1ab_0
-  - icu=75.1=he02047a_0
-  - idna=3.10=pyhd8ed1ab_1
-  - importlib-metadata=8.7.0=pyhe01879c_1
-  - jinja2=3.1.6=pyhd8ed1ab_0
-  - keras=3.10.0=pyh753f3f9_0
-  - kernel-headers_linux-64=4.18.0=he073ed8_8
-  - keyutils=1.6.1=h166bdaf_0
-  - krb5=1.21.3=h659f571_0
-  - ld_impl_linux-64=2.44=h1423503_1
-  - libabseil=20240116.2=cxx17_he02047a_1
-  - libaec=1.1.4=h3f801dc_0
-  - libblas=3.9.0=32_h59b9bed_openblas
-  - libcblas=3.9.0=32_he106b2a_openblas
-  - libcurl=8.14.1=h332b0f4_0
-  - libedit=3.1.20250104=pl5321h7949ede_0
-  - libev=4.33=hd590300_2
-  - libexpat=2.7.0=h5888daf_0
-  - libffi=3.4.6=h2dba641_1
-  - libgcc=15.1.0=h767d61c_3
-  - libgcc-ng=15.1.0=h69a702a_3
-  - libgfortran=15.1.0=h69a702a_3
-  - libgfortran5=15.1.0=hcea5267_3
-  - libgomp=15.1.0=h767d61c_3
-  - libgrpc=1.62.2=h15f2491_0
-  - libhwloc=2.11.2=default_h3d81e11_1002
-  - libiconv=1.18=h4ce23a2_1
-  - libjpeg-turbo=3.1.0=hb9d3cd8_0
-  - liblapack=3.9.0=32_h7ac8fdf_openblas
-  - liblapacke=3.9.0=32_he2f377e_openblas
-  - liblzma=5.8.1=hb9d3cd8_2
-  - libmagma=2.7.2=h09b5827_2
-  - libmagma_sparse=2.7.2=h09b5827_3
-  - libnghttp2=1.64.0=h161d5f1_0
-  - libnsl=2.0.1=hb9d3cd8_1
-  - libopenblas=0.3.30=pthreads_h94d23a6_0
-  - libpng=1.6.50=h943b412_0
-  - libprotobuf=4.25.3=hd5b35b9_1
-  - libre2-11=2023.09.01=h5a48ba9_2
-  - libsqlite=3.50.2=hee844dc_2
-  - libssh2=1.11.1=hcf80075_0
-  - libstdcxx=15.1.0=h8f9b012_3
-  - libstdcxx-ng=15.1.0=h4852527_3
-  - libtorch=2.3.1=cuda118_h7aef8b2_300
-  - liburing=2.7=h434a139_0
-  - libuuid=2.38.1=h0b41bf4_0
-  - libuv=1.51.0=hb9d3cd8_0
-  - libxcrypt=4.4.36=hd590300_1
-  - libxml2=2.13.8=h4bc477f_0
-  - libzlib=1.3.1=hb9d3cd8_2
-  - llvm-openmp=20.1.8=h4922eb0_0
-  - markdown=3.8.2=pyhd8ed1ab_0
-  - markdown-it-py=3.0.0=pyhd8ed1ab_1
-  - markupsafe=3.0.2=py310h89163eb_1
-  - mdurl=0.1.2=pyhd8ed1ab_1
-  - metis=5.1.1=h59595ed_2
-  - mkl=2023.2.0=h84fe81f_50496
-  - ml_dtypes=0.4.0=py310h5eaa309_2
-  - mpc=1.3.1=h24ddda3_1
-  - mpfr=4.2.1=h90cbb55_3
-  - mpmath=1.3.0=pyhd8ed1ab_1
-  - namex=0.1.0=pyhd8ed1ab_0
-  - nccl=2.27.3.1=h03a54cd_0
-  - ncurses=6.5=h2d0b736_3
-  - networkx=3.4.2=pyh267e887_2
-  - numpy=1.26.4=py310hb13e2d6_0
-  - openssl=3.5.1=h7b32b05_0
-  - opt_einsum=3.4.0=pyhd8ed1ab_1
-  - optree=0.16.0=py310h3788b33_0
-  - packaging=25.0=pyh29332c3_1
-  - pandas=2.3.1=py310h0158d43_0
-  - pip=25.1.1=pyh8b19718_0
-  - protobuf=4.25.3=py310h0e2eeba_1
-  - psutil=7.0.0=py310ha75aee5_0
-  - pycparser=2.22=pyh29332c3_1
-  - pydantic=2.11.7=pyh3cfb1c2_0
-  - pydantic-core=2.33.2=py310hbcd0ec0_0
-  - pygments=2.19.2=pyhd8ed1ab_0
-  - pysocks=1.7.1=pyha55dd90_7
-  - python=3.10.18=hd6af730_0_cpython
-  - python-dateutil=2.9.0.post0=pyhe01879c_2
-  - python-flatbuffers=25.2.10=pyhbc23db3_0
-  - python-tzdata=2025.2=pyhd8ed1ab_0
-  - python_abi=3.10=7_cp310
-  - pytorch=2.3.1=cuda118_py310he8d5cbe_300
-  - pytz=2025.2=pyhd8ed1ab_0
-  - pyyaml=6.0.2=py310h89163eb_2
-  - re2=2023.09.01=h7f4b329_2
-  - readline=8.2=h8c095d6_2
-  - requests=2.32.4=pyhd8ed1ab_0
-  - rich=14.0.0=pyh29332c3_0
-  - scipy=1.15.2=py310h1d65ade_0
-  - setuptools=80.9.0=pyhff2d567_0
-  - six=1.17.0=pyhd8ed1ab_0
-  - sleef=3.8=h1b44611_0
-  - snappy=1.2.1=h8bd8927_1
-  - sympy=1.14.0=pyh2585a3b_105
-  - sysroot_linux-64=2.28=h4ee821c_8
-  - tbb=2021.13.0=hceb3a55_1
-  - tensorboard=2.17.1=pyhd8ed1ab_0
-  - tensorboard-data-server=0.7.0=py310h6c63255_2
-  - tensorflow=2.17.0=cpu_py310h42475c5_2
-  - tensorflow-base=2.17.0=cpu_py310h98e3cc3_2
-  - tensorflow-estimator=2.17.0=cpu_py310heba74a3_2
-  - termcolor=3.1.0=pyhd8ed1ab_0
-  - tk=8.6.13=noxft_hd72426e_102
-  - tqdm=4.67.1=pyhd8ed1ab_1
-  - typing-extensions=4.14.1=h4440ef1_0
-  - typing-inspection=0.4.1=pyhd8ed1ab_0
-  - typing_extensions=4.14.1=pyhe01879c_0
-  - tzdata=2025b=h78e105d_0
-  - urllib3=2.5.0=pyhd8ed1ab_0
-  - werkzeug=3.1.3=pyhd8ed1ab_1
-  - wheel=0.45.1=pyhd8ed1ab_1
-  - wrapt=1.17.2=py310ha75aee5_0
-  - yaml=0.2.5=h7f98852_2
-  - zipp=3.23.0=pyhd8ed1ab_0
-  - zstandard=0.23.0=py310ha75aee5_2
-  - zstd=1.5.7=hb8e6e7a_2
-  - pip:
-    - anyio==4.9.0
-    - argon2-cffi==25.1.0
-    - argon2-cffi-bindings==21.2.0
-    - arrow==1.3.0
-    - asttokens==3.0.0
-    - async-lru==2.0.5
-    - attrs==25.3.0
-    - awkward==2.8.5
-    - awkward-cpp==47
-    - babel==2.17.0
-    - beautifulsoup4==4.13.4
-    - bleach==6.2.0
-    - comm==0.2.2
-    - contourpy==1.3.2
-    - cramjam==2.10.0
-    - cycler==0.12.1
-    - debugpy==1.8.15
-    - decorator==5.2.1
-    - defusedxml==0.7.1
-    - exceptiongroup==1.3.0
-    - executing==2.2.0
-    - fastjsonschema==2.21.1
-    - fonttools==4.59.0
-    - fqdn==1.5.1
-    - h11==0.16.0
-    - httpcore==1.0.9
-    - httpx==0.28.1
-    - ipykernel==6.29.5
-    - ipython==8.37.0
-    - isoduration==20.11.0
-    - jedi==0.19.2
-    - joblib==1.5.1
-    - json5==0.12.0
-    - jsonpointer==3.0.0
-    - jsonschema==4.24.0
-    - jsonschema-specifications==2025.4.1
-    - jupyter-client==8.6.3
-    - jupyter-core==5.8.1
-    - jupyter-events==0.12.0
-    - jupyter-lsp==2.2.5
-    - jupyter-server==2.16.0
-    - jupyter-server-terminals==0.5.3
-    - jupyterlab==4.4.4
-    - jupyterlab-pygments==0.3.0
-    - jupyterlab-server==2.27.3
-    - kiwisolver==1.4.8
-    - matplotlib==3.10.3
-    - matplotlib-inline==0.1.7
-    - mistune==3.1.3
-    - nbclient==0.10.2
-    - nbconvert==7.16.6
-    - nbformat==5.10.4
-    - nest-asyncio==1.6.0
-    - notebook-shim==0.2.4
-    - overrides==7.7.0
-    - pandocfilters==1.5.1
-    - parso==0.8.4
-    - pexpect==4.9.0
-    - pillow==11.3.0
-    - platformdirs==4.3.8
-    - prometheus-client==0.22.1
-    - prompt-toolkit==3.0.51
-    - ptyprocess==0.7.0
-    - pure-eval==0.2.3
-    - pyparsing==3.2.3
-    - python-json-logger==3.3.0
-    - pyzmq==27.0.0
-    - referencing==0.36.2
-    - rfc3339-validator==0.1.4
-    - rfc3986-validator==0.1.1
-    - rpds-py==0.26.0
-    - scikit-learn==1.7.0
-    - send2trash==1.8.3
-    - sniffio==1.3.1
-    - soupsieve==2.7
-    - stack-data==0.6.3
-    - terminado==0.18.1
-    - threadpoolctl==3.6.0
-    - tinycss2==1.4.0
-    - tomli==2.2.1
-    - torchdata==0.9.0
-    - tornado==6.5.1
-    - traitlets==5.14.3
-    - types-python-dateutil==2.9.0.20250708
-    - uproot==5.6.3
-    - uri-template==1.3.0
-    - wcwidth==0.2.13
-    - webcolors==24.11.1
-    - webencodings==0.5.1
-    - websocket-client==1.8.0
-    - xxhash==3.5.0
-prefix: /global/homes/c/chult/.conda/envs/pytorch_dgl_cf
training_time.png DELETED

Git LFS Details

  • SHA256: 64dcae074d1349722669ea8754c2e1fb59401a5535db119115a734d5798adb97
  • Pointer size: 131 Bytes
  • Size of remote file: 293 kB