Commit d129ca0 ("fixing inference") by ho22joshua · parent 1845be0 · README.md

To provide a diverse set of physics processes for the pretraining, we use MadGraph5_aMC@NLO 2.7.3 [Alwall et al., 2014](#ref-alwall-2014hca) to generate proton-proton collision events at next-to-leading order (NLO) in Quantum Chromodynamics (QCD). We generate 12 distinct Standard Model (SM) physics processes, including six major Higgs boson production mechanisms: gluon fusion production \\(ggF\\), vector boson fusion \\(VBF\\), associated production of the Higgs boson with a W boson \\(WH\\) or a Z boson \\(ZH\\), associated production of the Higgs boson with a top-quark pair \\(t\bar{t}H\\), and associated production of the Higgs boson with a single top quark and a forward quark \\(tHq\\). Additionally, we simulate six top quark production processes: single top production, top-quark pair production \\(t\bar{t}\\), top-quark pair production in association with a pair of photons \\(t\bar{t}\gamma\gamma\\), associated production of a top-quark pair with a W boson \\(t\bar{t}W\\), simultaneous production of three top quarks \\(t\bar{t}t\\), and simultaneous production of four top quarks \\(t\bar{t}t\bar{t}\\). In these samples, the Higgs boson and top quarks decay inclusively. These 12 Higgs and top quark production processes constitute the pretraining dataset.

To test the pretrained model, we further generate four processes, three of which are beyond the Standard Model (BSM): an SM \\(t\bar{t}H\\) production in which the Higgs boson decays exclusively to a pair of photons; a \\(t\bar{t}H\\) production with the Higgs boson decaying to a pair of photons, where the top-Yukawa coupling is CP-odd, implemented using the Higgs Characterization model [Artoisenet et al., 2013](#ref-artoisinet-2013puc); the production of a pair of superpartners of the top quark (s-top) using the Minimal Supersymmetric Standard Model (MSSM) [Rosiek, 1990](#ref-rosiek-1990), [Allanach et al., 2009](#ref-allanach-2009); and flavor-changing neutral current (FCNC) processes [Degrande et al., 2015](#ref-degrande-2015), [Durieux et al., 2015](#ref-durieux-2015). For the s-top process, we simulate the production of heavier s-top pairs \\(t_2\bar{t}_2\\), where each heavier s-top (mass 582 GeV) decays into a lighter s-top (\\(t_1\\) or \\(\bar{t}_1\\), mass 400 GeV) and a Higgs boson. The FCNC process involves \\(t\bar{t}\\) production where one top quark decays to a Higgs boson and a light quark. We generate 10 million events for each process, except for \\(tHq\\) and \\(t\bar{t}t\bar{t}\\), for which 5 million events are produced.

In all simulation samples, the center-of-mass energy of the proton-proton collision is set to 13 TeV. The Higgs boson, top quarks, and vector bosons are set to decay inclusively (except in the \\(t\bar{t}H \rightarrow \gamma\gamma\\) samples), with MadSpin [Artoisenet et al., 2012](#ref-artoisinet-2012st) handling the decays of top quarks and W bosons. The generated events are processed through Pythia 8.235 [Sjostrand et al., 2015](#ref-sjostrand-2015) for parton showering and heavy-particle decays, followed by Delphes 3.4.2 [de Favereau et al., 2014](#ref-defavereau-2014), configured to emulate the ATLAS detector [ATLAS Collaboration, 2008](#ref-atlas-2008), for fast detector simulation.

The detector-level object selection criteria are defined to align with typical experimental conditions. Photons are required to have transverse momentum \\(p_T \geq 20~\mathrm{GeV}\\) and pseudorapidity \\(|\eta| \leq 2.37\\), excluding the electromagnetic calorimeter crack region \\(1.37 < |\eta| < 1.52\\). Electrons must have \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.47\\) (excluding the same crack region), while muons are selected with \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.7\\). Jets are reconstructed using the anti-\\(k_t\\) algorithm [Cacciari et al., 2008](#ref-cacciari-2008gp) with radius parameter \\(\Delta R=0.4\\), where \\(\Delta R\\) is defined as \\(\sqrt{\Delta\eta ^2 + \Delta\phi^2}\\), with \\(\Delta\eta\\) being the difference in pseudorapidity and \\(\Delta\phi\\) the difference in azimuthal angle. Jets must satisfy \\(p_T \geq 25~\mathrm{GeV}\\) and \\(|\eta| \leq 2.5\\). To avoid double-counting, jets are removed if they are within \\(\Delta R < 0.4\\) of a photon or lepton. The identification of jets originating from b-quark decays (b-tagging) is performed by matching jets within \\(\Delta R = 0.4\\) of a b-quark, with efficiency corrections applied to match the performance of the ATLAS experiment's b-tagging algorithm [ATLAS Collaboration, 2019](#ref-atlas-2019bwq).

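
As an illustration, the cuts and the \\(\Delta R\\)-based overlap removal above can be sketched as follows (a minimal sketch; the object containers and field names are hypothetical, not the actual analysis code):

```python
import math

def delta_r(eta1, phi1, eta2, phi2):
    """Angular distance sqrt(d_eta^2 + d_phi^2), with phi difference wrapped to (-pi, pi]."""
    d_eta = eta1 - eta2
    d_phi = (phi1 - phi2 + math.pi) % (2 * math.pi) - math.pi
    return math.hypot(d_eta, d_phi)

def in_crack(eta):
    """Electromagnetic calorimeter crack region."""
    return 1.37 < abs(eta) < 1.52

def select_photon(pt, eta):
    """Photon selection: pT >= 20 GeV, |eta| <= 2.37, outside the crack region."""
    return pt >= 20.0 and abs(eta) <= 2.37 and not in_crack(eta)

def select_jets(jets, photons_and_leptons):
    """Keep jets with pT >= 25 GeV and |eta| <= 2.5, removing any jet within
    dR < 0.4 of a photon or lepton. Objects are dicts with 'pt', 'eta', 'phi'."""
    selected = []
    for j in jets:
        if j["pt"] < 25.0 or abs(j["eta"]) > 2.5:
            continue
        if any(delta_r(j["eta"], j["phi"], o["eta"], o["phi"]) < 0.4
               for o in photons_and_leptons):
            continue  # overlap removal to avoid double-counting
        selected.append(j)
    return selected
```
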
## Methods

Given the prevalence of classification problems in particle physics data analysis, we evaluate the model's efficacy through a systematic assessment across five binary classification tasks:

- $t\bar{t}H(\rightarrow \gamma\gamma)$ with a CP-even versus CP-odd top-Higgs interaction
- $t\bar{t}$ with FCNC top quark decays versus $tHq$ processes
- $t\bar{t}W$ versus $t\bar{t}t$ processes
- Stop pair production with Higgs bosons in the decay chain versus $t\bar{t}H$ processes
- $WH$ versus $ZH$ production modes

Our evaluation metrics encompass classification performance, computational efficiency, and model interpretability. The investigation extends to analyzing the model's scaling behavior with respect to training dataset size, benchmarked against models trained without pretraining. Although we explored transfer learning through parameter freezing of pretrained layers, this approach did not yield performance improvements, leading us to focus our detailed analysis on fine-tuning strategies.

### GNN Architecture

We implement a Graph Neural Network (GNN) architecture that naturally accommodates the point-cloud structure of particle physics data, employing the DGL framework with a PyTorch backend [Wang et al., 2019][ref-dgl-2019], [Paszke et al., 2019][ref-pytorch-2019]. A fully connected graph is constructed for each event, with nodes corresponding to reconstructed jets, electrons, muons, photons, and $\vec{E}_T^{\text{miss}}$. The features of each node include the four-momentum $(p_T, \eta, \phi, E)$ of the object under a massless assumption ($E = p_T \cosh \eta$), the b-tagging label (for jets), the charge (for leptons), and an integer labeling the type of object represented by the node. We use a placeholder value of 0 for features that are not defined for every node type, such as the b-jet tag, lepton charge, or the pseudorapidity of $\vec{E}_T^{\text{miss}}$. We assign the angular distances ($\Delta \eta, \Delta \phi, \Delta R$) as edge features and the number of nodes $N$ in the graph as a global feature. We denote the node features $\{\vec x_i\}$, edge features $\{\vec y_{ij}\}$, and global features $\{\vec z\}$.

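
A minimal NumPy sketch of this graph construction (the per-object input layout is made up for illustration, not the actual data format):

```python
import numpy as np

def build_event_graph(objects):
    """Build a fully connected event graph.
    `objects` is a list of (pt, eta, phi, obj_type, btag, charge) tuples
    (hypothetical layout). Returns node features, directed edges (i != j),
    edge features (d_eta, d_phi, dR), and the global feature [N]."""
    pt, eta, phi, otype, btag, charge = (np.array(c, dtype=float) for c in zip(*objects))
    energy = pt * np.cosh(eta)  # massless assumption: E = pT * cosh(eta)
    nodes = np.stack([pt, eta, phi, energy, btag, charge, otype], axis=1)

    n = len(objects)
    src, dst = np.nonzero(~np.eye(n, dtype=bool))  # all ordered pairs i != j
    d_eta = eta[src] - eta[dst]
    d_phi = np.mod(phi[src] - phi[dst] + np.pi, 2 * np.pi) - np.pi  # wrap to (-pi, pi]
    d_r = np.hypot(d_eta, d_phi)
    edges = np.stack([d_eta, d_phi, d_r], axis=1)
    global_feat = np.array([n], dtype=float)
    return nodes, (src, dst), edges, global_feat
```

In DGL the `(src, dst)` pair would be passed to the graph constructor, with `nodes` and `edges` attached as node and edge data.
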
The GNN model is based on the graph network architecture described in [Battaglia et al., 2018][ref-graphnets-2018], using simple multilayer perceptron (MLP) feature functions and summation aggregation. The model comprises three primary components: an encoder, the graph network, and a decoder. In the encoder, three MLPs embed the node, edge, and global features into a latent space of dimension 64. The graph network block, which is designed to facilitate message passing between different domains of the graph, performs an edge update $f_e$, followed by a node update $f_n$, and finally a global update $f_g$, all defined below. The inputs to each update MLP are concatenated.

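
As a sketch, a form of the three updates consistent with the description above (concatenated inputs and summation aggregation, following the graph network formalism of [Battaglia et al., 2018][ref-graphnets-2018]) is:

$$
\vec y_{ij}' = f_e\left(\left[\vec x_i, \vec x_j, \vec y_{ij}, \vec z\right]\right), \qquad
\vec x_i' = f_n\left(\left[\vec x_i, \sum_j \vec y_{ij}', \vec z\right]\right), \qquad
\vec z' = f_g\left(\left[\sum_i \vec x_i', \sum_{i,j} \vec y_{ij}', \vec z\right]\right)
$$
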
This graph block is iterated four times with the same update MLPs. Finally, the global features are passed through a decoder MLP and a final linear layer to produce the desired model outputs. Each MLP consists of 4 linear layers, each with an output width of 64, with the `ReLU` activation function. The output of the MLP is then passed through a `LayerNorm` layer [Ba et al., 2016][ref-layernorm-2016]. The total number of trainable parameters in this model is about 400,000.

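
A NumPy sketch of one such MLP block (4 linear layers of width 64, `ReLU` between layers, `LayerNorm` on the output; the weight initialization and the omission of learned scale/shift in the normalization are our simplifications):

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(in_dim, out_dim):
    """One linear layer as a (weights, bias) pair."""
    return rng.normal(0.0, 0.1, (in_dim, out_dim)), np.zeros(out_dim)

def mlp_block(x, layers):
    """4 linear layers with ReLU in between, then LayerNorm over the feature axis."""
    for k, (w, b) in enumerate(layers):
        x = x @ w + b
        if k < len(layers) - 1:
            x = np.maximum(x, 0.0)  # ReLU
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + 1e-5)  # LayerNorm without learned affine parameters

layers = [linear(64, 64) for _ in range(4)]
out = mlp_block(rng.normal(size=(10, 64)), layers)
```
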
As a performance benchmark, a baseline GNN model is trained from scratch for each classification task. The initial learning rate is set to $10^{-4}$ with an exponential decay following $LR(x) = LR_{\text{initial}}\cdot(0.99)^x$, where $x$ represents the epoch number.

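
The decay schedule used throughout is simple enough to write as a plain function (in a PyTorch training loop this corresponds to an exponential learning-rate scheduler with decay factor 0.99 per epoch):

```python
import math

def learning_rate(epoch, lr_initial=1e-4, gamma=0.99):
    """Exponential decay: LR(x) = LR_initial * gamma**x."""
    return lr_initial * gamma ** epoch

def epochs_to_reach(fraction, gamma=0.99):
    """Number of epochs until the LR falls to `fraction` of its initial value."""
    return math.log(fraction) / math.log(gamma)
```
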
---

We develop a comprehensive set of 41 labels that capture both particle multiplicities and kinematic properties. This approach increases prediction granularity and enhances model interpretability. By training the model to predict event kinematics rather than event identification, we create a task-independent framework that can potentially generalize better to novel scenarios not seen during pretraining.

The particle multiplicity labels count the number of Higgs bosons ($n_{\text{higgs}}$), top quarks ($n_{\text{tops}}$), vector bosons ($n_V$), $W$ bosons ($n_W$), and $Z$ bosons ($n_Z$). The kinematic labels characterize the transverse momentum ($p_T$), pseudorapidity ($\eta$), and azimuthal angle ($\phi$) of Higgs bosons and top quarks through binned classifications.

For Higgs bosons, $p_T$ is categorized into three ranges: (0, 30) GeV, (30, 200) GeV, and (200, $\infty$) GeV, with the upper range particularly sensitive to potential BSM effects. Similarly, both leading and subleading top quarks have $p_T$ classifications spanning (0, 30) GeV, (30, 300) GeV, and (300, $\infty$) GeV. When no particle exists within a specific $p_T$ range, the corresponding label is set to $[0, 0, 0]$. For all particles, $\eta$ measurements are divided into 4 bins with boundaries at $[-1.5, 0, 1.5]$, while $\phi$ measurements use 4 bins with boundaries at $[-\frac{\pi}{2}, 0, \frac{\pi}{2}]$. As with $p_T$, both $\eta$ and $\phi$ labels default to $[0, 0, 0, 0]$ in the absence of a particle. This comprehensive labeling schema enables fine-grained learning of kinematic distributions and particle multiplicities, essential for characterizing complex collision events.

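
The binned labels above can be encoded as one-hot vectors; a sketch (the helper name is ours, not the paper's code):

```python
import math

def onehot_bin(value, edges):
    """One-hot label over len(edges)+1 bins; all zeros if the particle is absent."""
    n_bins = len(edges) + 1
    if value is None:
        return [0] * n_bins  # default label when no particle exists
    idx = sum(value >= e for e in edges)  # index of the bin containing `value`
    label = [0] * n_bins
    label[idx] = 1
    return label

# Kinematic labels for a Higgs boson, with the bin boundaries described above:
pt_label  = onehot_bin(250.0, [30.0, 200.0])
eta_label = onehot_bin(0.7, [-1.5, 0.0, 1.5])
phi_label = onehot_bin(-2.0, [-math.pi / 2, 0.0, math.pi / 2])
absent    = onehot_bin(None, [30.0, 200.0])
```
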
The loss function combines the individual losses from all 41 labels through weighted averaging. Binary cross-entropy is applied to classification labels, while mean-squared error is used for regression labels. The model generates predictions for all labels simultaneously, with individual losses calculated according to their respective types. The final loss is an equally weighted average across all labels (all weights set to 1), so that each label contributes uniformly to the optimization. The output layer of the multilabel model has 2,688 trainable parameters.

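
A NumPy sketch of this combined loss (equal weights; the split of labels into classification and regression lists is schematic):

```python
import numpy as np

def bce(p, y, eps=1e-7):
    """Binary cross-entropy, averaged over elements."""
    p = np.clip(p, eps, 1 - eps)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

def mse(p, y):
    """Mean-squared error."""
    return float(np.mean((p - y) ** 2))

def multilabel_loss(cls_preds, cls_targets, reg_preds, reg_targets):
    """Equally weighted average of per-label losses:
    BCE for classification labels, MSE for regression labels."""
    losses = [bce(p, y) for p, y in zip(cls_preds, cls_targets)]
    losses += [mse(p, y) for p, y in zip(reg_preds, reg_targets)]
    return sum(losses) / len(losses)
```
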
#### Pretraining

During pretraining, the initial learning rate is $10^{-4}$, and the learning rate decays by 1% each epoch following the exponential function $LR(x) = 10^{-4}\cdot(0.99)^x$, where $x$ is the number of epochs. Both pretrained models reach a plateau in loss by epoch 50, at which point the training is stopped.

---

### Fine-tuning Methodology

For downstream tasks, we adjust the model architecture for fine-tuning by replacing the original output layer (final linear layer) with a newly initialized linear layer while retaining the pre-trained weights for all other layers. This modification allows the model to specialize in the specific downstream task while leveraging the general features learned during pretraining.

The fine-tuning process begins with distinct learning rate setups for different parts of the model. The newly initialized linear layer is trained with an initial learning rate of $10^{-4}$, matching the rate used for models trained from scratch. Meanwhile, the pre-trained layers are fine-tuned more cautiously with a lower initial learning rate of $10^{-5}$. This approach ensures that the pre-trained layers adapt gradually without losing their general features, while the new layer learns effectively from scratch. Both learning rates decay over time following the same exponential function, $LR(x) = LR_{\text{initial}} \cdot (0.99)^x$, to promote stable convergence as training progresses.

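
The two-group schedule can be written as follows (plain Python; in PyTorch this would correspond to two optimizer parameter groups with different initial learning rates sharing one exponential scheduler):

```python
def group_learning_rates(epoch, gamma=0.99):
    """Per-epoch learning rates for the new output layer and the pretrained layers."""
    groups = {"new_output_layer": 1e-4, "pretrained_layers": 1e-5}
    return {name: lr0 * gamma ** epoch for name, lr0 in groups.items()}
```
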
We also evaluated a transfer learning setup in which either the decoder MLP or the final linear layer was replaced with a newly initialized component. During this process, all other model parameters remained frozen, leveraging the pre-trained features without further updating them. However, we did not observe performance improvements using the transfer learning setup. Consequently, we focus on reporting results obtained with the fine-tuning approach.

To obtain reliable performance estimates and uncertainties, we employ an ensemble training approach where 5 independent models are trained for each configuration with random weight initialization and random subsets of the training dataset. This enables us to evaluate both the models' sensitivity to initial parameters and to quantify uncertainties in their performance.

To investigate how model performance scales with training data, we conducted training runs using sample sizes ranging from $10^3$ to $10^7$ events per class ($10^3$, $10^4$, $10^5$, $10^6$, and $10^7$) for each model setup: the from-scratch baseline and models fine-tuned from multi-class or multi-label pretrained models. For the $10^7$ case, only the initialization was randomized due to dataset size limitations. All models were evaluated on the same testing dataset, consisting of 2 million events per class, which remained separate from the training process.

| **Name of Task** | **Pretraining Task** | $10^3$ | $10^4$ | $10^5$ | $10^6$ | $10^7$ |
|----------------------|----------------------|----------------|----------------|----------------|----------------|----------------|
| **ttH CP Even vs Odd** | Baseline Accuracy | 56.5 $\pm$ 1.1 | 62.2 $\pm$ 0.1 | 64.3 $\pm$ 0.0 | 65.7 $\pm$ 0.0 | 66.2 $\pm$ 0.0 |
| | Multiclass (%) | +4.8 $\pm$ 1.1 | +3.4 $\pm$ 0.1 | +1.3 $\pm$ 0.0 | +0.2 $\pm$ 0.0 | -0.0 $\pm$ 0.0 |
| | Multilabel (%) | +2.1 $\pm$ 1.2 | +1.9 $\pm$ 0.1 | +0.8 $\pm$ 0.1 | +0.0 $\pm$ 0.0 | -0.1 $\pm$ 0.0 |
| **FCNC vs tHq** | Baseline Accuracy | 63.6 $\pm$ 0.7 | 67.8 $\pm$ 0.4 | 68.4 $\pm$ 0.3 | 69.3 $\pm$ 0.3 | 67.9 $\pm$ 0.0 |
| | Multiclass (%) | +5.8 $\pm$ 0.8 | +1.2 $\pm$ 0.4 | +1.4 $\pm$ 0.3 | +0.5 $\pm$ 0.3 | -0.0 $\pm$ 0.0 |
| | Multilabel (%) | -5.3 $\pm$ 0.8 | -1.3 $\pm$ 0.4 | +0.9 $\pm$ 0.4 | +0.3 $\pm$ 0.3 | +0.4 $\pm$ 0.1 |
| **ttW vs ttt** | Baseline Accuracy | 75.8 $\pm$ 0.1 | 77.6 $\pm$ 0.1 | 78.9 $\pm$ 0.0 | 79.8 $\pm$ 0.0 | 80.3 $\pm$ 0.0 |
| | Multiclass (%) | +3.7 $\pm$ 0.1 | +2.7 $\pm$ 0.1 | +1.3 $\pm$ 0.0 | +0.4 $\pm$ 0.0 | +0.0 $\pm$ 0.0 |
| | Multilabel (%) | +2.2 $\pm$ 0.1 | +1.1 $\pm$ 0.1 | +0.5 $\pm$ 0.0 | +0.0 $\pm$ 0.0 | -0.1 $\pm$ 0.0 |
| **stop vs ttH** | Baseline Accuracy | 83.0 $\pm$ 0.2 | 86.3 $\pm$ 0.1 | 87.6 $\pm$ 0.0 | 88.5 $\pm$ 0.0 | 88.8 $\pm$ 0.0 |
| | Multiclass (%) | +0.4 $\pm$ 0.2 | +1.9 $\pm$ 0.1 | +1.0 $\pm$ 0.0 | +0.3 $\pm$ 0.0 | +0.0 $\pm$ 0.0 |
| | Multilabel (%) | +2.8 $\pm$ 0.2 | +1.0 $\pm$ 0.1 | +0.5 $\pm$ 0.0 | +0.0 $\pm$ 0.0 | -0.0 $\pm$ 0.0 |
| **WH vs ZH** | Baseline Accuracy | 51.4 $\pm$ 0.1 | 53.9 $\pm$ 0.1 | 55.8 $\pm$ 0.0 | 57.5 $\pm$ 0.0 | 58.0 $\pm$ 0.0 |
| | Multiclass (%) | +5.2 $\pm$ 0.1 | +5.3 $\pm$ 0.1 | +3.1 $\pm$ 0.0 | +0.6 $\pm$ 0.0 | +0.1 $\pm$ 0.0 |
| | Multilabel (%) | -1.1 $\pm$ 0.1 | -0.9 $\pm$ 0.2 | +0.5 $\pm$ 0.1 | +0.1 $\pm$ 0.0 | -0.1 $\pm$ 0.0 |

> **Table 1**: Accuracy of the traditional model versus the accuracy increase due to fine-tuning from various pretraining tasks.
> The accuracies are averaged over 5 independently trained models with randomly initialized weights, each trained on a random subset of the data. One exception is the $10^7$ training, where all models use the same dataset due to limitations on our dataset size. The random subsets are allowed to overlap, but the overlap should be minimal because each model takes an independent random subset of $10^7$ events. The testing accuracy is calculated from the same testing set of 2 million events per class across all models for a specific training task. The errors are the propagated errors (root sum of squares) of the standard deviation of accuracies for each model.

## Results

Since AUC and accuracy show similar trends, we present the results in terms of accuracy for conciseness in Table 1.

In general, the fine-tuned pretrained model achieves at least the same level of classification performance as the baseline model. Notably, there are significant improvements, particularly when the sample size is small, ranging from $10^3$ to $10^4$ events. In some cases, the accuracy improvements exceed five percentage points, demonstrating that pretrained models provide a strong initial representation that compensates for limited data. The numerical values of the improvements in accuracy may not fully capture the impact on the sensitivity of the measurements for which the neural network classifier is used, and the final sensitivity improvement is likely to be greater.

As the training sample size grows to $10^5$, $10^6$, and eventually $10^7$ events, the added benefit of pretraining diminishes. With abundant data, models trained from scratch approach or even match the accuracy of fine-tuned pretrained models. This suggests that large datasets enable effective learning from scratch, rendering the advantage of pretraining negligible in such scenarios.

Although both pretraining approaches offer benefits, multiclass pretraining tends to provide more consistent improvements across tasks, especially in the low-data regime. In contrast, multilabel pretraining can sometimes lead to neutral or even slightly negative effects for certain tasks and data sizes. This highlights the importance of the pretraining task design, as the similarity between pretraining and fine-tuning tasks in the multiclass approach appears to yield better-aligned representations.

To provide an intuitive understanding of CKA values, we construct a table of the CKA scores for various transformations performed on a set of dummy data.

- **A:** a randomly initialized matrix with shape (1000, 64), with entries drawn from a normal distribution ($\mu = 0$, $\sigma = 1$)
- **B:** a matrix with shape (1000, 64) constructed via various transformations performed on $A$
- **$\mathrm{Noise}(\sigma)$:** a randomly initialized noise matrix with shape (1000, 64), with entries drawn from a normal distribution with $\mu = 0$ and the indicated standard deviation $\sigma$

| Dataset | CKA Score |
|---------|-----------|
| $A, B = A$ | 1.00 |
| $A, B =$ permutation on columns of $A$ | 1.00 |
| $A, B = A + \mathrm{Noise}(0.1)$ | 0.99 |
| $A, B = A + \mathrm{Noise}(0.5)$ | 0.80 |
| $A, B = A + \mathrm{Noise}(0.75)$ | 0.77 |
| $A, B = A\cdot \mathrm{Noise}(1)$ (linear transformation) | 0.76 |
| $A, B = A + \mathrm{Noise}(1)$ | 0.69 |
| $A, B = A + \mathrm{Noise}(2)$ | 0.51 |
| $A, B = A + \mathrm{Noise}(5)$ | 0.39 |

**Table 2:** CKA scores for dummy datasets $A$ and $B$, where $B$ is created via various transformations performed on $A$.

As seen in Table 2 and from the definition of the CKA, the CKA score is permutation-invariant. We will use the CKA score to evaluate the similarity between various models and gain insight into the learned representation of detector events in each model (i.e., the information that each model learns).

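
The Table 2 experiment can be reproduced with a standard linear CKA (Kornblith et al., 2019); this implementation is our sketch, not the paper's code:

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between representations X, Y of shape (n_samples, n_features)."""
    X = X - X.mean(axis=0)  # center each feature
    Y = Y - Y.mean(axis=0)
    hsic = np.linalg.norm(Y.T @ X, "fro") ** 2
    norm_x = np.linalg.norm(X.T @ X, "fro")
    norm_y = np.linalg.norm(Y.T @ Y, "fro")
    return hsic / (norm_x * norm_y)

rng = np.random.default_rng(0)
A = rng.normal(0.0, 1.0, (1000, 64))      # dummy dataset A from Table 2
noise = rng.normal(0.0, 1.0, (1000, 64))  # Noise(1); scale it for Noise(sigma)
```

For example, `linear_cka(A, A)` is exactly 1 and a column permutation of `A` leaves the score unchanged, matching the first two rows of Table 2.
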
We train ensembles of models for each training task to observe how the CKA score changes under the random initialization of our models. The CKA score between two training setups $A$ and $B$ is then defined to be:

$$
CKA(A, B) = \frac{1}{n^2}\sum_{i=1}^{n} \sum_{j=1}^{n} CKA(A_i, B_j)
$$

where $A_i$ is the representation learned by the $i^{\text{th}}$ model in an ensemble with $n$ total models. The error in the CKA score is the standard deviation of $CKA(A_i, B_j)$ over the model pairs.

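
The ensemble average and its spread can be computed as follows (self-contained sketch; `similarity` stands in for the per-pair CKA):

```python
import statistics

def ensemble_score(reps_a, reps_b, similarity):
    """Mean and standard deviation of similarity(A_i, B_j) over all model pairs."""
    scores = [similarity(a, b) for a in reps_a for b in reps_b]
    return statistics.mean(scores), statistics.pstdev(scores)
```
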
Here we present results for the CKA similarity between the final model in each setup and the final model in the baseline, shown in Table 3.

| Training Task | Baseline | Multiclass | Multilabel |
|----------------------|------------------|-----------------|-----------------|
| ttH CP Even vs Odd | 0.94 ± 0.05 | 0.82 ± 0.01 | 0.77 ± 0.06 |
| FCNC vs tHq | 0.96 ± 0.03 | 0.76 ± 0.01 | 0.81 ± 0.01 |
| ttW vs ttt | 0.91 ± 0.08 | 0.75 ± 0.10 | 0.72 ± 0.05 |
| stop vs ttH | 0.87 ± 0.11 | 0.79 ± 0.12 | 0.71 ± 0.08 |
| WH vs ZH | 0.90 ± 0.07 | 0.53 ± 0.03 | 0.44 ± 0.06 |

**Table 3:** CKA similarity of the latent representation before the decoder with the baseline model, averaged over 3 models per training setup, with all models trained on the full dataset ($10^7$ events). The baseline column is not guaranteed to be 1.0 because of the random initialization of the model; each baseline model converges to a slightly different representation, as seen in the CKA values in that column.

The baseline models with different initializations exhibit high similarity values, ranging from approximately 0.87 to 0.96, which indicates that independently trained baseline models tend to converge on similar internal representations despite random initialization. Across the considered tasks, models trained as multi-class or multi-label classifiers exhibit noticeably lower CKA similarity scores when compared to the baseline model. For example, in the WH vs ZH task, the baseline model and another independently trained baseline model have a high similarity of 0.90, whereas the multi-class and multi-label models show significantly reduced similarities (0.53 and 0.44, respectively). This pattern suggests that the representational spaces developed by multi-class or multi-label models differ substantially from those learned by the baseline model that was trained directly on the downstream classification task.

![The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.](training_time.png)
*Fig. 1: The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.*

Figure 1 shows the fine-tuning time for the model pretrained with multiclass classification, relative to the time required for the baseline model, as a function of training sample size. In general, the fine-tuning time is significantly shorter than the training time required by the baseline model approach. For smaller training sets, on the order of $10^5$ events, tasks such as FCNC vs. tHq and ttW vs. ttt benefit substantially from the pretrained model’s “head start,” achieving their final performance in only about 1% of the baseline time. For large training datasets, the fine-tuning time relative to the baseline training time becomes larger; however, given that the large training sample typically requires longer training time, fine-tuning still yields much faster training convergence. The ttH CP-even vs. ttH CP-odd task, with a training sample size of $10^7$ events, is an exception where the fine-tuning time exceeds the training time required for the baseline model. This is likely because the processes involved in this task include photon objects in the final states, which are absent from the events used during pretraining.

To accurately evaluate the total time consumption, it is necessary to include the pretraining time required for the foundation model approach. The pretraining times are as follows:

The GPU hours recorded for the multi-label model represent the total time required when training the model in parallel on 16 GPUs. This includes a model synchronization step, which results in higher GPU hours compared to the multi-class pretraining model.

The foundation model approach becomes increasingly efficient when a large number of tasks are fine-tuned using the same pretrained model, compared to training each task independently from scratch. To illustrate this, we evaluate the computational time required for a scenario where the training sample contains $10^7$ events. For the five tasks tested in this study, the baseline training time (training from scratch) ranges from 1.68 GPU hours (WH vs. ZH) to 5.30 GPU hours (ttW vs. ttt), with an average baseline training time of 2.94 GPU hours. In contrast, the average fine-tuning time for the foundation model approach, relative to the baseline, is 38% of the baseline training time for $10^7$ events. Based on these averages, we estimate that the foundation model approach becomes more computationally efficient than the baseline approach when fine-tuning is performed for more than 41 tasks.

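
The break-even estimate follows from the averages quoted above (2.94 baseline GPU hours per task, fine-tuning at 38% of baseline); the pretraining budget passed in below is a hypothetical input, since the actual pretraining times are listed separately:

```python
import math

def breakeven_tasks(pretraining_hours, baseline_hours=2.94, finetune_fraction=0.38):
    """Smallest number of fine-tuned tasks for which pretraining + fine-tuning
    is cheaper than training every task from scratch:
    P + N * f * B < N * B  =>  N > P / ((1 - f) * B)."""
    savings_per_task = (1.0 - finetune_fraction) * baseline_hours
    return math.floor(pretraining_hours / savings_per_task) + 1
```
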
As a practical example, the ATLAS measurement of Higgs boson couplings using the $H \rightarrow \gamma\gamma$ decay channel [ATLAS Collaboration, 2023][ref-atlas-2023-higg] involved training 42 classifiers for event categorization. This coincides with our estimate, suggesting that the foundation model approach can reduce computational costs even for a single high-energy physics measurement.

## Conclusions

31
 
32
  To provide a diverse set of physics processes for the pretraining, we use Madgraph@NLO 2.7.3 [Alwall et al., 2014](#ref-alwall-2014hca) to generate proton-proton collision events at next-to-leading order (NLO) in Quantum Chromodynamics (QCD). We generate 12 distinct Standard Model (SM) physics processes, including six major Higgs boson production mechanisms: gluon fusion production \\(ggF\\), vector boson fusion \\(VBF\\), associated production of the Higgs boson with a W boson \\(WH\\) or a Z boson \\(ZH\\), associated production of the Higgs boson with a top-quark pair \\(t\bar{t}H\\), and associated production of the Higgs boson with a single top quark and a forward quark \\(tHq\\). Additionally, we simulate six top quark production processes: single top production, top-quark pair production \\(t\bar{t}\\), top quark pair production in association with a pair of photons \\(t\bar{t}\gamma\gamma\\), associated production of a top-quark pair with a W boson \\(t\bar{t}W\\), simultaneous production of three top quarks \\(t\bar{t}t\\), and simultaneous production of four top quarks \\(t\bar{t}t\bar{t}\\). In these samples, the Higgs boson and top quarks decay inclusively. These 12 Higgs and top quark production processes constitute the pretraining dataset.
33
 
34
+ To test the pretrained model, we further generated four processes including three beyond Standard Model (SM) processes: a SM \\(t\bar{t}H\\) production where the Higgs boson decays exclusively to a pair of photons, a \\(t\bar{t}H\\) production with the Higgs boson decaying to a pair of photons, where the top-Yukawa coupling is CP-odd, implemented using the Higgs Characterization model [Artoisenet et al., 2013](#ref-artoisinet-2013puc), the production of a pair of superpartners of the top quark (s-top) using the Minimal Supersymmetric Standard Model (MSSM) [Rosiek, 1990](#ref-rosiek-1990), [Allanach et al., 2009](#ref-allanach-2009), and flavor changing neutral current (FCNC) processes [Degrande et al., 2015](#ref-degrande-2015), [Durieux et al., 2015](#ref-durieux-2015). For the s-top process, we simulate the production of heavier s-top pairs \\(t_2\bar{t_2}\\), where each heavier s-top (mass 582 GeV) decays into a lighter s-top \\(t_1\\) or \\(\bar{t_1}\\), mass 400 GeV) and a Higgs boson. The FCNC process involves \\(t\bar{t}\\) production where one top quark decays to a Higgs boson and a light quark. We generate 10 million events for each process, except for \\(tHq\\) and \\(t\bar{t}t\bar{t}\\), where 5 million events were produced.
35
 
36
+ In all simulation samples, the center of mass energy of the proton-proton collision is set to 13 TeV. The Higgs boson, top quarks, and vector bosons are set to decay inclusively (except the \\(t\bar{t}H \rightarrow \gamma\gamma\\) samples), with MadSpin [Artoisenet et al., 2012](#ref-artoisinet-2012st) handling the decays of top quarks and W bosons. The generated events are processed through Pythia 8.235 [Sjostrand et al., 2015](#ref-sjostrand-2015) for parton showering and heavy particle decays, followed by Delphes 3.4.2 [de Favereau et al., 2014](#ref-defavereau-2014) configured to emulate the ATLAS detector [ATLAS Collaboration, 2008](#ref-atlas-2008) for fast detector simulation.
37
 
38
+ The detector-level object selection criteria are defined to align with typical experimental conditions. Photons are required to have transverse momentum \\(p_T \geq 20~\mathrm{GeV}\\) and pseudorapidity \\(|\eta| \leq 2.37\\), excluding the electromagnetic calorimeter crack region \\(1.37 < |\eta| < 1.52\\). Electrons must have \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.47\\) (excluding the same crack region), while muons are selected with \\(p_T \geq 10~\mathrm{GeV}\\) and \\(|\eta| \leq 2.7\\). Jets are reconstructed using the anti-\\(k_t\\) algorithm [Cacciari et al., 2008](#ref-cacciari-2008gp) with radius parameter \\(\Delta R=0.4\\), where \\(\Delta R\\) is defined as \\(\sqrt{\Delta\eta ^2 + \Delta\phi^2}\\), with \\(\Delta\eta\\) being the difference in pseudorapidity and \\(\Delta\phi\\) the difference in azimuthal angle. Jets must satisfy \\(p_T \geq 25~\mathrm{GeV}\\) and \\(|\eta| \leq 2.5\\). To avoid double-counting, jets are removed if they are within \\(\Delta R < 0.4\\) of a photon or lepton. The identification of jets originating from b-quark decays (b-tagging) is performed by matching jets within \\(\Delta R = 0.4\\) of a b-quark, with efficiency corrections applied to match the performance of the ATLAS experiment's b-tagging algorithm [ATLAS Collaboration, 2019](#ref-atlas-2019bwq).
39
 
40
  ## Methods
41
 
 
45
 
46
  Given the prevalence of classification problems in particle physics data analysis, we evaluate the model's efficacy through a systematic assessment across five binary classification tasks:
47
 
48
+ - \\(t\bar{t}H(\rightarrow \gamma\gamma)\\) with CP-even versus CP-odd t-H interaction
49
+ - \\(t\bar{t}\\) with FCNC top quark decays versus $tHq$ processes
50
+ - \\(t\bar{t}W\\) versus $ttt$ processes
51
+ - Stop pair production with Higgs bosons in the decay chain versus \\(t\bar{t}H\\) processes
52
+ - \\(WH\\) versus \\(ZH\\) production modes
53
 
54
  Our evaluation metrics encompass classification performance, computational efficiency, and model interpretability. The investigation extends to analyzing the model's scaling behavior with respect to training dataset size, benchmarked against models trained without pretraining. Although we explored transfer learning through parameter freezing of pretrained layers, this approach did not yield performance improvements, leading us to focus our detailed analysis on fine-tuning strategies.
  ### GNN Architecture
+ We implement a Graph Neural Network (GNN) architecture that naturally accommodates the point-cloud structure of particle physics data, employing the DGL framework with a PyTorch backend [Wang et al., 2019][ref-dgl-2019], [Paszke et al., 2019][ref-pytorch-2019]. A fully connected graph is constructed for each event, with nodes corresponding to reconstructed jets, electrons, muons, photons, and \\(\vec{E}_T^{\text{miss}}\\). The features of each node include the four-momentum \\((p_T, \eta, \phi, E)\\) of the object with a massless assumption (\\(E = p_T \cosh \eta\\)), the b-tagging label (for jets), the charge (for leptons), and an integer labeling the type of object represented by the node. We use a placeholder value of 0 for features that are not defined for every node type, such as the b-jet tag, lepton charge, or the pseudorapidity of \\(\vec{E}_T^{\text{miss}}\\). We assign the angular distances (\\(\Delta \eta, \Delta \phi, \Delta R\\)) as edge features and the number of nodes \\(N\\) in the graph as a global feature. We denote the node features \\(\{\vec x_i\}\\), edge features \\(\{\vec y_{ij}\}\\), and global features \\(\{\vec z\}\\).
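As an illustration, the per-event graph inputs described here can be assembled with plain NumPy (the actual pipeline uses DGL; the tuple field order and helper name below are our own):

```python
import numpy as np

def build_event_graph(objs):
    """Build node/edge/global features for a fully connected event graph.
    Each object is (pt, eta, phi, btag, charge, type_id); E = pt*cosh(eta) (massless).
    Returns node features, (src, dst) edge indices, edge features (d_eta, d_phi, dR), and N.
    """
    n = len(objs)
    nodes = []
    for pt, eta, phi, btag, charge, tid in objs:
        e = pt * np.cosh(eta)  # massless assumption
        nodes.append([pt, eta, phi, e, btag, charge, tid])
    nodes = np.array(nodes)
    src, dst, edges = [], [], []
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            d_eta = nodes[i, 1] - nodes[j, 1]
            d_phi = (nodes[i, 2] - nodes[j, 2] + np.pi) % (2 * np.pi) - np.pi
            src.append(i)
            dst.append(j)
            edges.append([d_eta, d_phi, np.hypot(d_eta, d_phi)])
    return nodes, (np.array(src), np.array(dst)), np.array(edges), n
```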
  The GNN model is based on the graph network architecture described in [Battaglia et al., 2018][ref-graphnets-2018], using simple multilayer perceptron (MLP) feature functions and summation aggregation. The model comprises three primary components: an encoder, the graph network, and a decoder. In the encoder, three MLPs embed the node, edge, and global features into a latent space of dimension 64. The graph network block, which is designed to facilitate message passing between different domains of the graph, performs an edge update \\(f_e\\), followed by a node update \\(f_n\\), and finally a global update \\(f_g\\), all defined below. The inputs to each update MLP are concatenated.
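The update order can be illustrated with a toy NumPy sketch: edge update, summation aggregation of incoming messages, node update, then global update, each fed the concatenation of its inputs. The single-layer stand-ins for the MLPs and the reduced width are our own simplifications for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # latent width (64 in the text; reduced here for illustration)

def mlp(in_dim, out_dim=D):
    """Single ReLU layer standing in for a full update MLP."""
    W = rng.normal(size=(in_dim, out_dim))
    return lambda v: np.maximum(v @ W, 0.0)

f_e = mlp(3 * D + D)  # inputs: [x_i, x_j, y_ij, z]
f_n = mlp(2 * D + D)  # inputs: [x_i, sum_j y'_ij, z]
f_g = mlp(2 * D + D)  # inputs: [sum_i x'_i, sum_ij y'_ij, z]

def graph_block(x, y, z, src, dst):
    """One message-passing iteration: edge update, node update, global update."""
    zs = np.broadcast_to(z, (len(src), D))
    y_new = f_e(np.concatenate([x[src], x[dst], y, zs], axis=1))
    agg = np.zeros_like(x)
    np.add.at(agg, dst, y_new)  # summation aggregation of incoming edge messages
    x_new = f_n(np.concatenate([x, agg, np.broadcast_to(z, (len(x), D))], axis=1))
    z_new = f_g(np.concatenate([x_new.sum(axis=0), y_new.sum(axis=0), z]))
    return x_new, y_new, z_new
```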
  This graph block is iterated four times with the same update MLPs. Finally, the global features are passed through a decoder MLP and a final linear layer to produce the desired model outputs. Each MLP consists of 4 linear layers, each with an output width of 64, with the `ReLU` activation function. The output of the MLP is then passed through a `LayerNorm` layer [Ba et al., 2016][ref-layernorm-2016]. The total number of trainable parameters in this model is about 400,000.
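For orientation, the parameter count of one such MLP (4 linear layers of width 64, plus the LayerNorm weight and bias) can be tallied as below. The 400,000 total arises from summing such counts over all encoder, update, decoder, and output layers with their respective input widths, which we do not restate here.

```python
def mlp_params(d_in, width=64, n_layers=4):
    """Trainable parameters of an MLP with n_layers linear layers of output width
    `width`, followed by a LayerNorm (weight + bias) over the output."""
    total, d = 0, d_in
    for _ in range(n_layers):
        total += d * width + width  # weights + biases
        d = width
    total += 2 * width  # LayerNorm gamma and beta
    return total
```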
+ As a performance benchmark, a baseline GNN model is trained from scratch for each classification task. The initial learning rate is set to \\(10^{-4}\\) with an exponential decay following \\(LR(x) = LR_{\text{initial}}\cdot(0.99)^x\\), where \\(x\\) represents the epoch number.
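The schedule \\(LR(x) = LR_{\text{initial}}\cdot(0.99)^x\\) can be written directly as:

```python
def lr_schedule(epoch, lr_initial=1e-4, decay=0.99):
    """Exponentially decaying learning rate: LR(x) = LR_initial * decay**x."""
    return lr_initial * decay ** epoch
```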
  ---
  We develop a comprehensive set of 41 labels that capture both particle multiplicities and kinematic properties. This approach increases prediction granularity and enhances model interpretability. By training the model to predict event kinematics rather than event identification, we create a task-independent framework that can potentially generalize better to novel scenarios not seen during pretraining.
+ The particle multiplicity labels count the number of Higgs bosons (\\(n_{\text{higgs}}\\)), top quarks (\\(n_{\text{tops}}\\)), vector bosons (\\(n_V\\)), \\(W\\) bosons (\\(n_W\\)), and \\(Z\\) bosons (\\(n_Z\\)). The kinematic labels characterize the transverse momentum (\\(p_T\\)), pseudorapidity (\\(\eta\\)), and azimuthal angle (\\(\phi\\)) of Higgs bosons and top quarks through binned classifications.
+ For Higgs bosons, \\(p_T\\) is categorized into three ranges: (0, 30) GeV, (30, 200) GeV, and (200, \\(\infty\\)) GeV, with the upper range particularly sensitive to potential BSM effects. Similarly, both leading and subleading top quarks have \\(p_T\\) classifications spanning (0, 30) GeV, (30, 300) GeV, and (300, \\(\infty\\)) GeV. When no particle exists within a specific \\(p_T\\) range, the corresponding label is set to \\([0, 0, 0]\\). For all particles, \\(\eta\\) measurements are divided into 4 bins with boundaries at \\([-1.5, 0, 1.5]\\), while \\(\phi\\) measurements use 4 bins with boundaries at \\([-\frac{\pi}{2}, 0, \frac{\pi}{2}]\\). As with \\(p_T\\), both \\(\eta\\) and \\(\phi\\) labels default to \\([0, 0, 0, 0]\\) in the absence of a particle. This comprehensive labeling schema enables fine-grained learning of kinematic distributions and particle multiplicities, essential for characterizing complex collision events.
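A sketch of this binning (edges taken from the text; the helper name is ours). An absent particle maps to an all-zero label:

```python
import numpy as np

PT_EDGES_HIGGS = [30.0, 200.0]            # bins: (0, 30), (30, 200), (200, inf) GeV
ETA_EDGES = [-1.5, 0.0, 1.5]              # 4 eta bins
PHI_EDGES = [-np.pi / 2, 0.0, np.pi / 2]  # 4 phi bins

def binned_label(value, edges):
    """One-hot label over len(edges)+1 bins; all zeros when the particle is absent (value=None)."""
    label = np.zeros(len(edges) + 1)
    if value is not None:
        label[np.searchsorted(edges, value)] = 1.0
    return label
```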
  The loss function combines individual losses from all 41 labels through weighted averaging. Binary cross-entropy is applied to classification labels, while mean-squared error is used for regression labels. The model generates predictions for all labels simultaneously, with individual losses calculated according to their respective types. The final loss is computed as an equally-weighted average across all labels, with weights set to 1 to ensure uniform contribution to the optimization process. The output layer of the multilabel model has 2,688 trainable parameters.
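A minimal NumPy sketch of this combination (two labels shown instead of 41; function names are ours):

```python
import numpy as np

def bce(p, t, eps=1e-7):
    """Binary cross-entropy for classification labels."""
    p = np.clip(p, eps, 1 - eps)
    return -np.mean(t * np.log(p) + (1 - t) * np.log(1 - p))

def mse(p, t):
    """Mean-squared error for regression labels."""
    return np.mean((p - t) ** 2)

def combined_loss(preds, targets, kinds):
    """Equal-weight average of per-label losses: BCE for classification, MSE otherwise."""
    losses = [bce(p, t) if k == "cls" else mse(p, t)
              for p, t, k in zip(preds, targets, kinds)]
    return sum(losses) / len(losses)
```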
  #### Pretraining
+ During pre-training, the initial learning rate is \\(10^{-4}\\), and the learning rate decays by 1% each epoch following the exponential function \\(LR(x) = 10^{-4}\cdot(0.99)^x\\), where \\(x\\) is the epoch number. Both pre-trained models reach a plateau in loss by epoch 50, at which point the training is stopped.
  ---
 
  ### Fine-tuning Methodology
  For downstream tasks, we adjust the model architecture for fine-tuning by replacing the original output layer (final linear layer) with a newly initialized linear layer while retaining the pre-trained weights for all other layers. This modification allows the model to specialize in the specific downstream task while leveraging the general features learned during pretraining.
+ The fine-tuning process begins with distinct learning rate setups for different parts of the model. The newly initialized linear layer is trained with an initial learning rate of \\(10^{-4}\\), matching the rate used for models trained from scratch. Meanwhile, the pre-trained layers are fine-tuned more cautiously with a lower initial learning rate of \\(10^{-5}\\). This approach ensures that the pre-trained layers adapt gradually without losing their general features, while the new layer learns effectively from scratch. Both learning rates decay over time following the same exponential function, \\(LR(x) = LR_{\text{initial}} \cdot (0.99)^x\\), to promote stable convergence as training progresses.
  We also evaluated a transfer learning setup in which either the decoder MLP or the final linear layer was replaced with a newly initialized component. During this process, all other model parameters remained frozen, leveraging the pre-trained features without further updating them. However, we did not observe performance improvements using the transfer learning setup. Consequently, we focus on reporting results obtained with the fine-tuning approach.
  To obtain reliable performance estimates and uncertainties, we employ an ensemble training approach where 5 independent models are trained for each configuration with random weight initialization and random subsets of the training dataset. This enables us to evaluate both the models' sensitivity to initial parameters and to quantify uncertainties in their performance.
+ To investigate how model performance scales with training data, we conducted training runs using sample sizes ranging from \\(10^3\\) to \\(10^7\\) events per class (\\(10^3\\), \\(10^4\\), \\(10^5\\), \\(10^6\\), and \\(10^7\\)) for each model setup: the from-scratch baseline and models fine-tuned from multi-class or multi-label pretrained models. For the \\(10^7\\) case, only the initialization was randomized due to dataset size limitations. All models were evaluated on the same testing dataset, consisting of 2 million events per class, which remained separate from the training process.
+
+ | **Name of Task** | **Pretraining Task** | \\(10^3\\) | \\(10^4\\) | \\(10^5\\) | \\(10^6\\) | \\(10^7\\) |
+ |----------------------|----------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
+ | **ttH CP Even vs Odd** | Baseline Accuracy | 56.5 ± 1.1 | 62.2 ± 0.1 | 64.3 ± 0.0 | 65.7 ± 0.0 | 66.2 ± 0.0 |
+ | | Multiclass (%) | +4.8 ± 1.1 | +3.4 ± 0.1 | +1.3 ± 0.0 | +0.2 ± 0.0 | 0.0 ± 0.0 |
+ | | Multilabel (%) | +2.1 ± 1.2 | +1.9 ± 0.1 | +0.8 ± 0.1 | +0.0 ± 0.0 | 0.1 ± 0.0 |
+ | **FCNC vs tHq** | Baseline Accuracy | 63.6 ± 0.7 | 67.8 ± 0.4 | 68.4 ± 0.3 | 69.3 ± 0.3 | 67.9 ± 0.0 |
+ | | Multiclass (%) | +5.8 ± 0.8 | +1.2 ± 0.4 | +1.4 ± 0.3 | +0.5 ± 0.3 | 0.0 ± 0.0 |
+ | | Multilabel (%) | 5.3 ± 0.8 | 1.3 ± 0.4 | +0.9 ± 0.4 | +0.3 ± 0.3 | +0.4 ± 0.1 |
+ | **ttW vs ttt** | Baseline Accuracy | 75.8 ± 0.1 | 77.6 ± 0.1 | 78.9 ± 0.0 | 79.8 ± 0.0 | 80.3 ± 0.0 |
+ | | Multiclass (%) | +3.7 ± 0.1 | +2.7 ± 0.1 | +1.3 ± 0.0 | +0.4 ± 0.0 | +0.0 ± 0.0 |
+ | | Multilabel (%) | +2.2 ± 0.1 | +1.1 ± 0.1 | +0.5 ± 0.0 | +0.0 ± 0.0 | 0.1 ± 0.0 |
+ | **stop vs ttH** | Baseline Accuracy | 83.0 ± 0.2 | 86.3 ± 0.1 | 87.6 ± 0.0 | 88.5 ± 0.0 | 88.8 ± 0.0 |
+ | | Multiclass (%) | +0.4 ± 0.2 | +1.9 ± 0.1 | +1.0 ± 0.0 | +0.3 ± 0.0 | +0.0 ± 0.0 |
+ | | Multilabel (%) | +2.8 ± 0.2 | +1.0 ± 0.1 | +0.5 ± 0.0 | +0.0 ± 0.0 | 0.0 ± 0.0 |
+ | **WH vs ZH** | Baseline Accuracy | 51.4 ± 0.1 | 53.9 ± 0.1 | 55.8 ± 0.0 | 57.5 ± 0.0 | 58.0 ± 0.0 |
+ | | Multiclass (%) | +5.2 ± 0.1 | +5.3 ± 0.1 | +3.1 ± 0.0 | +0.6 ± 0.0 | +0.1 ± 0.0 |
+ | | Multilabel (%) | 1.1 ± 0.1 | 0.9 ± 0.2 | +0.5 ± 0.1 | +0.1 ± 0.0 | 0.1 ± 0.0 |
  > **Table 1**: Accuracy of the baseline model versus the accuracy increase due to fine-tuning from various pretraining tasks.
+ > The accuracies are averaged over 5 independently trained models with randomly initialized weights and trained on a random subset of the data. One exception is the \\(10^7\\) training where all models use the same dataset due to limitations on our dataset size. The random subsets are allowed to overlap, but this overlap should be very minimal because all models take an independent random subset of \\(10^7\\) events. The testing accuracy is calculated from the same testing set of 2 million events per class across all models for a specific training task. The errors are the propagated errors (root sum of squares) of the standard deviation of accuracies for each model.
  ## Results
  Since AUC and accuracy exhibit similar trends, we present the results in terms of accuracy in Table 1 for conciseness.
+ In general, the fine-tuned pretrained model achieves at least the same level of classification performance as the baseline model. Notably, there are significant improvements, particularly when the sample size is small, ranging from \\(10^3\\) to \\(10^4\\) events. In some cases, the accuracy improvements exceed five percentage points, demonstrating that pretrained models provide a strong initial representation that compensates for limited data. The numerical values of the improvements in accuracy may not fully capture the impact on the sensitivity of the measurements for which the neural network classifier is used, and the final sensitivity improvement is likely to be greater.
+ As the training sample size grows to \\(10^5\\), \\(10^6\\), and eventually \\(10^7\\) events, the added benefit of pretraining diminishes. With abundant data, models trained from scratch approach or even match the accuracy of fine-tuned pretrained models. This suggests that large datasets enable effective learning from scratch, rendering the advantage of pretraining negligible in such scenarios.
  Although both pretraining approaches offer benefits, multiclass pretraining tends to provide more consistent improvements across tasks, especially in the low-data regime. In contrast, multilabel pretraining can sometimes lead to neutral or even slightly negative effects for certain tasks and data sizes. This highlights the importance of the pretraining task design, as the similarity between pretraining and fine-tuning tasks in the multiclass approach appears to yield better-aligned representations.
  To provide an intuitive understanding of CKA values, we construct a table of the CKA scores for various transformations performed on a set of dummy data.
+ - **A:** randomly initialized matrix with shape (1000, 64), following a normal distribution (\\(\sigma = 1, \mu = 0\\))
+ - **B:** matrix with shape (1000, 64) constructed via various transformations performed on \\(A\\)
+ - **Noise(\\(\sigma\\)):** randomly initialized noise matrix with shape (1000, 64), drawn from a normal distribution with mean 0 and standard deviation \\(\sigma\\)
  | Dataset | CKA Score |
  |---------|-----------|
+ | \\(A, B = A\\) | 1.00 |
+ | \\(A, B =\\) permutation on columns of \\(A\\) | 1.00 |
+ | \\(A, B = A + \mathrm{Noise}(0.1)\\) | 0.99 |
+ | \\(A, B = A + \mathrm{Noise}(0.5)\\) | 0.80 |
+ | \\(A, B = A + \mathrm{Noise}(0.75)\\) | 0.77 |
+ | \\(A, B = A \cdot \mathrm{Noise}(1)\\) (Linear Transformation) | 0.76 |
+ | \\(A, B = A + \mathrm{Noise}(1)\\) | 0.69 |
+ | \\(A, B = A + \mathrm{Noise}(2)\\) | 0.51 |
+ | \\(A, B = A + \mathrm{Noise}(5)\\) | 0.39 |
+
+ **Table 2:** CKA scores for dummy datasets \\(A\\) and \\(B\\), where \\(B\\) is created via various transformations performed on \\(A\\).
  As seen in Table 2 and in the definition of the CKA, the CKA score is permutation-invariant. We will use the CKA score to evaluate the similarity between various models and gain insight into the learned representation of detector events in each model (i.e., the information that each model learns).
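The invariances in Table 2 can be verified with a minimal linear-CKA implementation (we assume the standard linear-kernel formulation here; the exact estimator used for the tables is not restated in this section):

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two representations X, Y of shape (n_samples, n_features)."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    hsic = np.linalg.norm(Y.T @ X, "fro") ** 2
    return hsic / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))
```

Identical inputs give a score of 1, column permutations leave the score unchanged (the Frobenius norm is invariant under orthogonal transformations), and additive noise degrades it smoothly.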
  We train ensembles of models for each training task to observe how the CKA score changes under the random initialization of our models. The CKA score between two ensembles is then defined to be:
+ \\[
+ \mathrm{CKA}(A, B) = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} \mathrm{CKA}(A_i, B_j)
+ \\]
+ where \\(A_i\\) is the representation learned by the \\(i^{\text{th}}\\) model in an ensemble with \\(n\\) total models. The error in CKA is the standard deviation of \\(CKA(A_i, B_j)\\).
  Here we present results for the CKA similarity between the final model in each setup with the final model in the baseline, shown in Table 3.
  | Training Task | Baseline | Multiclass | Multilabel |
+ |-----------------------|------------------|-----------------|-----------------|
+ | ttH CP Even vs Odd | 0.94 ± 0.05 | 0.82 ± 0.01 | 0.77 ± 0.06 |
+ | FCNC vs tHq | 0.96 ± 0.03 | 0.76 ± 0.01 | 0.81 ± 0.01 |
+ | ttW vs ttt | 0.91 ± 0.08 | 0.75 ± 0.10 | 0.72 ± 0.05 |
+ | stop vs ttH | 0.87 ± 0.11 | 0.79 ± 0.12 | 0.71 ± 0.08 |
+ | WH vs ZH | 0.90 ± 0.07 | 0.53 ± 0.03 | 0.44 ± 0.06 |
+ **Table 3:** CKA similarity of the latent representation before the decoder with the baseline model, averaged over 3 models per training setup, with all models trained on the full dataset (\\(10^7\\) events). The baseline column is not guaranteed to be 1.0 because of the random initialization of the model: each baseline model converges to a slightly different representation, as seen in the CKA values in that column.
  The baseline models with different initializations exhibit high similarity values, ranging from approximately 0.87 to 0.96, which indicates that independently trained baseline models tend to converge on similar internal representations despite random initialization. Across the considered tasks, models trained as multi-class or multi-label classifiers exhibit noticeably lower CKA similarity scores when compared to the baseline model. For example, in the WH vs ZH task, the baseline model and another baseline trained model have a high similarity of 0.90, whereas the multi-class and multi-label models show significantly reduced similarities (0.53 and 0.44, respectively). This pattern suggests that the representational spaces developed by multi-class or multi-label models differ substantially from those learned by the baseline model that was trained directly on the downstream classification task.
  ![The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.](training_time.png)
  *Fig. 1: The ratio of the fine-tuning time required to achieve 99% of the baseline model's final classification accuracy to the total time spent training the baseline model.*
+ Figure 1 shows the fine-tuning time for the model pretrained with multiclass classification, relative to the time required for the baseline model, as a function of training sample size. In general, the fine-tuning time is significantly shorter than the training time required by the baseline model approach. For smaller training sets, on the order of \\(10^5\\) events, tasks such as FCNC vs. tHq and ttW vs. ttt benefit substantially from the pretrained model’s “head start,” achieving their final performance in only about 1% of the baseline time. For large training datasets, the fine-tuning time relative to the baseline training time becomes larger; however, given that the large training sample typically requires longer training time, fine-tuning still yields much faster training convergence. The ttH CP-even vs. ttH CP-odd task, with a training sample size of \\(10^7\\) events, is an exception where the fine-tuning time exceeds the training time required for the baseline model. This is likely because the processes involved in this task include photon objects in the final states, which are absent from the events used during pretraining.
  To accurately evaluate the total time consumption, it is necessary to include the pretraining time required for the foundation model approach. The pretraining times are as follows:
  The GPU hours recorded for the multi-label model represent the total time required when training the model in parallel on 16 GPUs. This includes a model synchronization step, which results in higher GPU hours compared to the multi-class pretraining model.
+ The foundation model approach becomes increasingly efficient when a large number of tasks are fine-tuned using the same pretrained model, compared to training each task independently from scratch. To illustrate this, we evaluate the computational time required for a scenario where the training sample contains \\(10^7\\) events. For the five tasks tested in this study, the baseline training time (training from scratch) ranges from 1.68 GPU hours (WH vs. ZH) to 5.30 GPU hours (ttW vs. ttt), with an average baseline training time of 2.94 GPU hours. In contrast, the average fine-tuning time for the foundation model approach, relative to the baseline, is 38% of the baseline training time for \\(10^7\\) events. Based on these averages, we estimate that the foundation model approach becomes more computationally efficient than the baseline approach when fine-tuning is performed for more than 41 tasks.
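This break-even estimate follows from simple arithmetic: training \\(N\\) tasks from scratch costs \\(N \cdot 2.94\\) GPU hours on average, while the foundation approach costs the one-off pretraining time plus \\(0.38 \cdot 2.94\\) GPU hours per task. A sketch, where the pretraining cost passed in is a hypothetical input rather than a number quoted in this section:

```python
def break_even_tasks(pretrain_hours, baseline_hours=2.94, finetune_fraction=0.38):
    """Number of tasks above which pretrain-once-then-fine-tune beats per-task training
    from scratch: N * baseline > pretrain + N * finetune_fraction * baseline."""
    return pretrain_hours / ((1.0 - finetune_fraction) * baseline_hours)
```

With these averages, a pretraining cost of roughly 75 GPU hours would reproduce a break-even point of about 41 tasks.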
+ As a practical example, the ATLAS measurement of Higgs boson couplings using the \\(H \rightarrow \gamma\gamma\\) decay channel [ATLAS Collaboration, 2023][ref-atlas-2023-higg] involved training 42 classifiers for event categorization. This is consistent with our estimate, suggesting that the foundation model approach can reduce computational costs even within a single high-energy physics measurement.
  ## Conclusions
 
root_gnn_dgl/README.md CHANGED
@@ -1,15 +1,19 @@
 
 # root_gnn_dgl
 
 Pretrained DGL-based ROOT graph neural network.
+
+Pretrained model location: `/global/cfs/projectdirs/atlas/joshua/Pretrained_GNN/multiclass_pretrained_model_12/`
+To use the pretrained model, take a look at a finetuning config in `configs`.
+Replace `pretraining_path:` with `/global/cfs/projectdirs/atlas/joshua/Pretrained_GNN/multiclass_pretrained_model_12/model_epoch_71.pt`.
 
 ## Overview
 - Stable release with pretrained model weights.
 
-Pretrained model location: ``
-
 ## Conda setup
 
+The conda environment is required for the inference step: applying the GNN to ROOT files and saving GNN scores as an additional branch. This is because the inference script uses PyROOT.
+
 ```bash
 cd setup
 conda env create -f environment.yml
@@ -23,6 +27,8 @@ python setup/test_setup.py
 - NERSC Perlmutter environment with `podman-hpc` available.
 - Access to `joshuaho/pytorch:1.0` on Docker Hub [https://hub.docker.com/r/joshuaho/pytorch](https://hub.docker.com/r/joshuaho/pytorch)
 
+The inference step requires the conda environment, since the container does not contain ROOT.
+
 ### Pull the Prebuilt Image
 
 ```bash
root_gnn_dgl/scripts/inference.py CHANGED
@@ -1,6 +1,5 @@
 import sys
-import os
-file_path = os.getcwd()
+file_path = "/global/cfs/projectdirs/atlas/joshua/root_gnn/root_gnn_dgl"
 sys.path.append(file_path)
 import os
 import argparse
@@ -135,6 +134,7 @@ def main():
 
     import time
     start = time.time()
+    import ROOT
     import torch
     from array import array
     import numpy as np
@@ -186,6 +186,7 @@ def main():
     lend = time.time()
     print('Loader finished in {:.2f} seconds'.format(lend - lstart))
     sample_graph, _, _, global_sample = loader[0]
+    global_sample = []
 
     print('dset length =', len(dset))
     print('loader length =', len(loader))
@@ -197,6 +198,7 @@ def main():
     for config_file, branch in zip(args.config, args.branch_name):
         config = load_config(config_file)
         model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
+
         if args.ckpt < 0:
             ep, checkpoint = utils.get_best_epoch(config, var=args.var, mode='max', device=device)
         else:
@@ -246,41 +248,55 @@ def main():
         all_labels[branch] = labels
         all_tracking[branch] = tracking_info
 
-
     if args.write:
-        import uproot
-        import awkward as ak
-
-        # Open the original ROOT file and get the tree
-        infile = uproot.open(args.target)
-        tree = infile[dset_config['args']['tree_name']]
-
-        # Read the original tree as an awkward array
-        original_data = tree.arrays(library="ak")
-
-        # Prepare new branches as dicts of arrays
-        new_branches = {}
-        n_entries = len(original_data)
-        for branch, scores in all_scores.items():
-            # Ensure the scores array is the right length
-            scores = np.asarray(scores)
-            if scores.shape[0] != n_entries:
-                raise ValueError(f"Branch '{branch}' has {scores.shape[0]} entries, but tree has {n_entries}")
-            new_branches[branch] = scores
-
-        # Merge all arrays (original + new branches)
-        # Convert awkward to dict of numpy arrays for uproot
-        out_dict = {k: np.asarray(v) for k, v in ak.to_numpy(original_data).items()}
-        out_dict.update(new_branches)
-
-        # Write to new ROOT file
+        from ROOT import std
+        # Open the original ROOT file
+        infile = ROOT.TFile.Open(args.target)
+        tree = infile.Get(dset_config['args']['tree_name'])
+
+        # Create the destination directory if it doesn't exist
         os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
-        with uproot.recreate(args.destination) as outfile:
-            outfile.mktree(dset_config['args']['tree_name'], {k: v.dtype for k, v in out_dict.items()})
-            outfile[dset_config['args']['tree_name']].extend(out_dict)
 
-        print(f"Wrote new ROOT file {args.destination} with new branches {list(new_branches.keys())}")
+        # Create a new ROOT file to write the modified tree
+        outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
+
+        # Clone the original tree structure
+        outtree = tree.CloneTree(0)
+
+        # Create branches for all scores
+        branch_vectors = {}
+        for branch, scores in all_scores.items():
+            if isinstance(scores[0], (list, tuple, np.ndarray)) and len(scores[0]) > 1:
+                # Create a new branch for vectors
+                branch_vectors[branch] = std.vector('float')()
+                outtree.Branch(branch, branch_vectors[branch])
+            else:
+                # Create a new branch for single floats
+                branch_vectors[branch] = array('f', [0])
+                outtree.Branch(branch, branch_vectors[branch], f'{branch}/F')
+
+        # Fill the tree
+        for i in range(tree.GetEntries()):
+            tree.GetEntry(i)
+
+            for branch, scores in all_scores.items():
+                branch_data = branch_vectors[branch]
+                if isinstance(branch_data, array):  # Check if it's a single float array
+                    branch_data[0] = float(scores[i])
+                else:  # Assume it's a std::vector<float>
+                    branch_data.clear()
+                    for value in scores[i]:
+                        branch_data.push_back(float(value))
+
+            outtree.Fill()
+
+        # Write the modified tree to the new file
+        print(f'Writing to file {args.destination}')
+        print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
+        print(f'Wrote scores to {args.branch_name}')
+        outtree.Write()
+        outfile.Close()
+        infile.Close()
     else:
         os.makedirs(os.path.split(args.destination)[0], exist_ok=True)
         np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
root_gnn_dgl/setup/Dockerfile CHANGED
@@ -7,19 +7,16 @@ LABEL maintainer.email="ho22joshua@berkeley.edu"
 
 ENV LANG=C.UTF-8
 
-# Install system dependencies: vim, OpenMPI, and build tools
+# System deps (with CA certs for HTTPS downloads)
 RUN apt-get update -qq \
     && apt-get install -y --no-install-recommends \
-    wget lsb-release gnupg software-properties-common \
+    wget curl ca-certificates lsb-release gnupg software-properties-common \
     vim \
     g++-11 gcc-11 libstdc++-11-dev \
     openmpi-bin openmpi-common libopenmpi-dev \
     && rm -rf /var/lib/apt/lists/*
 
-# Install Python packages: mpi4py and jupyter
+# Python packages
 RUN pip install --no-cache-dir mpi4py jupyter uproot
 
-# (Optional) Expose Jupyter port
 EXPOSE 8888
-
-