Title: Gradient Norm Aware Minimization Seeks First-Order Flatness and Improves Generalization

URL Source: https://arxiv.org/html/2303.03108

Gradient Norm Aware Minimization Seeks First-Order Flatness and Improves Generalization

Xingxuan Zhang†, Renzhe Xu†, Han Yu, Hao Zou, Peng Cui*
Department of Computer Science, Tsinghua University
xingxuanzhang@hotmail.com, xrz199721@gmail.com, yuh21@mails.tsinghua.edu.cn, zouh18@mails.tsinghua.edu.cn, cuip@tsinghua.edu.cn

Abstract

Recently, flat minima have been shown to be effective for improving generalization, and sharpness-aware minimization (SAM) achieves state-of-the-art performance. Yet the definition of flatness discussed in SAM and its follow-ups is limited to zeroth-order flatness (i.e., the worst-case loss within a perturbation radius). We show that zeroth-order flatness can be insufficient to discriminate minima with low generalization error from those with high generalization error, both when there is a single minimum and when there are multiple minima within the given perturbation radius. Thus we present first-order flatness, a stronger measure of flatness focusing on the maximal gradient norm within a perturbation radius, which bounds both the maximal eigenvalue of the Hessian at local minima and the regularization function of SAM. We also present a novel training procedure named Gradient norm Aware Minimization (GAM) to seek minima with uniformly small curvature across all directions. Experimental results show that GAM improves the generalization of models trained with current optimizers such as SGD and AdamW on various datasets and networks. Furthermore, we show that GAM can help SAM find flatter minima and achieve better generalization. The code is available at https://github.com/xxgege/GAM.

† Equal contribution, *Corresponding author

1 Introduction

Current neural networks have achieved promising results in a wide range of fields [57, 39, 59, 81, 73, 80, 84, 79], yet they are typically heavily over-parameterized [2, 4]. Such heavy overparameterization leads to severe overfitting and poor generalization to unseen data when the model is learned simply with common loss functions (e.g., cross-entropy) [29]. Thus effective training algorithms are required to limit the negative effects of overfitting training data and find generalizable solutions.

Figure 1: Comparison of zeroth-order flatness (ZOF) and first-order flatness (FOF). Given a perturbation radius 𝜌, ZOF can fail to indicate generalization error both when there are multiple minima (Figure 1(a)) and when there is a single minimum (Figure 1(b)) within the radius, while FOF remains discriminative. The height of the blue rectangles in curly brackets is the value of ZOF, and the height of the gray triangles (which indicates the slope) is the value of FOF. In Figure 1(a), when 𝜌 is large enough to cover multiple minima, ZOF cannot measure the fluctuation frequency, while FOF prefers the flatter valley, which has a smaller gradient norm. When 𝜌 is small and covers only a single minimum, the maximum loss within 𝜌 can be misleading, as it can be misaligned with the uptrend of the loss. As shown in Figure 1(b), ZOF prefers the valley on the right, which has a larger generalization error (the orange dotted line), while FOF prefers the left one.

Many studies try to improve model generalization by modifying the training procedure, such as batch normalization [28], dropout [25], and data augmentation [78, 74, 14]. In particular, some works discuss the connection between the geometry of the loss landscape and generalization [29, 23, 20]. A branch of effective approaches, Sharpness-Aware Minimization (SAM) [20] and its variants [17, 47, 18, 52, 83, 37], minimizes the worst-case loss within a perturbation radius, which we call zeroth-order flatness. Optimizing zeroth-order flatness is proven to lower generalization error and achieves state-of-the-art performance on various image classification tasks [20, 86, 42].

Optimizing the worst case, however, relies on a reasonable choice of perturbation radius 𝜌. As a prefixed hyperparameter in SAM, or a hyperparameter under parameter re-scaling in its variants such as ASAM [42], 𝜌 cannot always be a perfect choice throughout the whole training process. We show that zeroth-order flatness may fail to indicate the generalization error with a given 𝜌. As shown in Figure 1(a), when 𝜌 covers multiple minima, the zeroth-order flatness (SAM) cannot measure the fluctuation frequency. When there is a single minimum within 𝜌, as in Figure 1(b), the observation radius is limited and the maximum loss in 𝜌 can be misaligned with the uptrend of the loss. So zeroth-order flatness can be misleading, and knowledge of the loss gradient is required for generalization error minimization.

To address this problem, we introduce first-order flatness, which controls the maximum gradient norm in the neighborhood of minima. We show that first-order flatness is stronger than zeroth-order flatness, since the intensity of the loss fluctuation can be bounded by the maximum gradient. When the perturbation radius covers multiple minima, which we show is quite common in practice, first-order flatness discriminates drastic jitters from truly flat valleys, as in Figure 1(a). When the perturbation radius is small and covers only one minimum, first-order flatness captures the trend of the loss gradient and can help indicate generalization error. We further show that first-order flatness directly controls the maximal eigenvalue of the Hessian of the training loss, which is a proper sharpness/flatness measure indicating the loss uptrend under an adversarial perturbation to the weights [36, 34, 35].

To optimize the first-order flatness in deep model training, we propose Gradient norm Aware Minimization (GAM), which approximates the maximum gradient norm with stochastic gradient ascent and Hessian-vector products to avoid the materialization of the Hessian matrix.

We summarize our contributions as follows.

We present first-order flatness, which measures the largest gradient norm in the neighborhood of minima. We show that the first-order flatness is stronger than current zeroth-order flatness and it controls the maximum eigenvalue of Hessian.

We propose a novel training procedure, GAM, to simultaneously optimize prediction loss and first-order flatness. We analyze the generalization error and the convergence of GAM.

We empirically show that GAM considerably improves model generalization when combined with current optimizers such as SGD and AdamW across a wide range of datasets and networks. We show that GAM further improves the generalization of models trained with SAM.

We empirically validate that GAM indeed finds flatter optima with lower Hessian spectra.

2 Related Works

Optimizer

Some studies [68, 20] have demonstrated that current optimization approaches, such as SGD [53], Adam [38], AdamW [49], and others [19, 46], affect generalization. Previous literature finds that Adam is more vulnerable to sharp minima than SGD [64], which results in worse generalization [67, 22, 26]. Subsequent works [50, 10, 68, 76] propose generalizable optimizers to address this problem. However, there can be a trade-off between generalization ability and convergence speed [36, 68, 46, 76, 19]. Different tasks and network architectures may favor different optimizers (e.g., SGD is often chosen for ResNet [24], while AdamW [49] is preferred for ViTs [16]). Thus selecting a proper optimizer is critical, while the understanding of its relationship to model generalization remains nascent [20].

Flat Minima and Generalization

Many recent works show that flatter minima lead to better generalization [36, 86, 32, 55]. Recently, [35] thoroughly reviewed the literature relating generalization to the sharpness of minima, highlighting the role of the maximum Hessian eigenvalue in determining sharpness [36, 63]. There are also several simple strategies to achieve a smaller maximum Hessian eigenvalue, such as choosing a larger learning rate [44, 12, 31] or a smaller batch size [61, 44, 30]. Sharpness-Aware Minimization (SAM) [20] and its variants [86, 42, 17, 47, 18, 52, 83, 37] are representative training algorithms that seek flat minima for better generalization. However, their definition of flatness is limited to zeroth-order flatness. In this paper, we present first-order flatness, a stronger flatness measure for better generalization. It has been shown that discrete steps of gradient descent regularize deep models implicitly by penalizing gradient descent trajectories with large loss gradients, and that this implicit regularization helps to find flat minima [7]. [82] proposes to directly control the gradient norm. They focus on the gradient norm at each training step, while we propose to penalize the maximum gradient norm in the neighborhood of minima and show the connection between our regularizer, the largest eigenvalue of the Hessian, and generalization error.

3 Preliminaries

Notations

Let $\mathcal{X}$ and $\mathcal{Y}$ be the sample space and label space, respectively. Let $\mathcal{D}$ denote the training distribution on $\mathcal{X} \times \mathcal{Y}$ and $S = \{(x_i, y_i)\}_{i=1}^{n}$ denote the training dataset with $n$ data points drawn independently from $\mathcal{D}$. Let $\theta \in \Theta \subseteq \mathbb{R}^d$ denote the parameters of the model. In addition, we use $B(\theta, \rho)$ to denote the open ball of radius $\rho > 0$ centered at the point $\theta$ in Euclidean space, i.e., $B(\theta, \rho) = \{\theta' : \|\theta - \theta'\| < \rho\}$. (We use $\|\cdot\|$ to denote the L2 norm throughout the paper.)

Let $\ell : \Theta \times \mathcal{X} \times \mathcal{Y} \to \mathbb{R}$ be the per-data-point loss function. Let $\hat{L}(\theta) = \frac{1}{n}\sum_{i=1}^{n} \ell(\theta, x_i, y_i)$ and $L(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}[\ell(\theta, x, y)]$ denote the empirical loss function and the population-level loss function, respectively. We assume $\hat{L}(\theta)$ and $L(\theta)$ are twice differentiable throughout the paper. $\nabla L(\theta)$ and $\nabla^2 L(\theta)$ ($\nabla \hat{L}(\theta)$ and $\nabla^2 \hat{L}(\theta)$) are the gradient and Hessian matrix of the function $L(\cdot)$ ($\hat{L}(\cdot)$) at the point $\theta$, respectively. Besides, for any $\theta \in \Theta$, we use $\nabla \|\nabla \hat{L}(\theta)\|$ to denote the gradient of the function $\|\nabla \hat{L}(\cdot)\|$ at the point $\theta$. In addition, we use $L^{\text{oracle}}(\theta)$ to denote an oracle loss function, which can be chosen as the empirical loss function $\hat{L}(\theta)$, $\hat{L}(\theta)$ with weight decay regularization, or another common loss function.

3.1 Zeroth-Order Flatness

The most popular mathematical definitions of flatness consider the maximal loss value within a radius [36, 20], which we call zeroth-order flatness. We follow the loss function proposed in SAM:

$$L^{\text{sam}}(\theta) \triangleq \hat{L}(\theta) + \max_{\theta' \in B(\theta, \rho)} \left( \hat{L}(\theta') - \hat{L}(\theta) \right). \quad (1)$$

The second term on the right-hand side of Equation (1) can be considered a measure of zeroth-order flatness.

Definition 3.1 ($\rho$-zeroth-order flatness). For any $\rho > 0$, the $\rho$-zeroth-order flatness $R^{(0)}_{\rho}(\theta)$ of the function $\hat{L}(\theta)$ at a point $\theta$ is defined as

$$R^{(0)}_{\rho}(\theta) \triangleq \max_{\theta' \in B(\theta, \rho)} \left( \hat{L}(\theta') - \hat{L}(\theta) \right), \quad \forall \theta \in \Theta. \quad (2)$$

Here $\rho$ is the perturbation radius that controls the magnitude of the neighborhood.

Intuitively, we name this term zeroth-order flatness because it measures the gap between the maximum loss value in the neighborhood and the loss at the current point. As a measure of the accumulation of gradients, zeroth-order flatness can be insufficient to indicate the generalization loss, as shown in Section 4.2. In this paper, we propose a novel first-order flatness measure and compare the two flatness notions in Section 4.2.

4 First-order Flatness and Optimization

In this section, we introduce first-order flatness and the corresponding minimizer for optimization. In Section 4.1, we formulate first-order flatness and show its connection with the maximal eigenvalue of the Hessian. Afterward, we discuss the relationship between zeroth-order and first-order flatness in Section 4.2. In Section 4.3, we present the optimization framework based on first-order flatness, as shown in Algorithm 1. We further provide a generalization bound with respect to the empirical loss, the first-order flatness, and higher-order terms, indicating that optimizing first-order flatness improves generalization. We then prove the convergence of the algorithm.

4.1 First-order Flatness

We first introduce the formulation of first-order flatness, which measures the maximal gradient norm in the neighborhood of a point 𝜽 ∈ Θ.

Definition 4.1 ($\rho$-first-order flatness). For any $\rho > 0$, the $\rho$-first-order flatness $R^{(1)}_{\rho}(\theta)$ of the function $\hat{L}(\theta)$ at a point $\theta$ is defined as

$$R^{(1)}_{\rho}(\theta) \triangleq \rho \cdot \max_{\theta' \in B(\theta, \rho)} \left\| \nabla \hat{L}(\theta') \right\|, \quad \forall \theta \in \Theta. \quad (3)$$

Here $\rho$ is the perturbation radius that controls the magnitude of the neighborhood.

Intuitively, first-order flatness requires that the loss function $\hat{L}(\theta)$ not change drastically in the neighborhood of $\theta$, so that the largest gradient norm of the loss is constrained.
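To make Definition 4.1 concrete, a toy 1-D sketch can approximate $R^{(1)}_{\rho}$ by grid search over the ball. The quadratic losses below are our own illustrative choice, not the paper's experimental setup; on a quadratic $L(w) = a w^2 / 2$ centered at its minimum, the measure evaluates to $a\rho^2$, so sharper curvature gives larger first-order flatness.

```python
# Toy illustration of Definition 4.1: R^(1)_rho(w) = rho * max |grad| over B(w, rho),
# approximated by a dense grid search (the max is attained at the ball boundary here).
def quad_grad(a):
    # gradient of the quadratic loss L(w) = a * w**2 / 2
    return lambda w: a * w

def first_order_flatness(grad, w, rho, steps=10_000):
    grid = [w - rho + 2 * rho * i / steps for i in range(steps + 1)]
    return rho * max(abs(grad(p)) for p in grid)

# A sharper quadratic (larger curvature a) has larger first-order flatness:
flat = first_order_flatness(quad_grad(1.0), w=0.0, rho=0.5)   # a * rho**2 = 0.25
sharp = first_order_flatness(quad_grad(4.0), w=0.0, rho=0.5)  # a * rho**2 = 1.00
print(flat, sharp)
```

In a deep network, the analytic gradient of a toy function would be replaced by a minibatch gradient, and the grid search by the ascent-based approximation of Section 4.3.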

We then discuss the relationship between the first-order flatness and the maximal eigenvalue of the Hessian matrix ∇ 2 𝐿 ^ ⁢ ( 𝜽 * ) (denoted as 𝜆 max ⁢ ( ∇ 2 𝐿 ^ ⁢ ( 𝜽 * ) ) ). 𝜆 max is proven to be a proper measure of the curvature of minima [36, 35] and is closely related to generalization abilities [30, 63, 11]. As another definition of flatness in related works [44, 9], 𝜆 max is widely accepted yet hard to calculate. We show in the following lemma that given a radius 𝜌 , the first-order flatness controls 𝜆 max , which reinforces the validity of the first-order flatness.

Lemma 4.1. Let $\theta^*$ be a local minimum of $\hat{L}$. Suppose $\hat{L}$ admits a second-order Taylor approximation in the neighborhood $B(\theta^*, \rho)$ (an assumption commonly adopted in optimization-related literature [51, 77, 66, 68] to analyze properties near critical points), i.e., $\forall \theta \in B(\theta^*, \rho)$,

$$\hat{L}(\theta) = \hat{L}(\theta^*) + \frac{1}{2} (\theta - \theta^*)^{\top} \nabla^2 \hat{L}(\theta^*) (\theta - \theta^*).$$

Then

$$\lambda_{\max}\left( \nabla^2 \hat{L}(\theta^*) \right) = \frac{R^{(1)}_{\rho}(\theta^*)}{\rho^2}. \quad (4)$$

Since the maximal eigenvalue of Hessian matrices is usually difficult to approximate and optimize directly [72, 71], the first-order flatness becomes a proper surrogate of 𝜆 max .
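Lemma 4.1 can be checked numerically on a toy quadratic of our own construction: with Hessian $H = \mathrm{diag}(5, 1)$ and minimum at the origin, sweeping the sphere of radius $\rho$ (where the maximum is attained for a convex quadratic) recovers $R^{(1)}_{\rho}/\rho^2 = \lambda_{\max} = 5$.

```python
import math

# Toy check of Lemma 4.1: for L(w) = w^T H w / 2 with minimum at 0,
# R^(1)_rho(0) / rho^2 equals the largest Hessian eigenvalue.
H = [[5.0, 0.0], [0.0, 1.0]]   # eigenvalues 5 and 1, so lambda_max = 5

def grad(w):                    # gradient of the quadratic: H w
    return [H[0][0] * w[0] + H[0][1] * w[1], H[1][0] * w[0] + H[1][1] * w[1]]

def first_order_flatness(rho, steps=100_000):
    # The max gradient norm over the ball is attained on the boundary,
    # so it suffices to sweep the circle of radius rho around the minimum.
    best = 0.0
    for i in range(steps):
        t = 2 * math.pi * i / steps
        g = grad([rho * math.cos(t), rho * math.sin(t)])
        best = max(best, math.hypot(g[0], g[1]))
    return rho * best

rho = 0.3
lam_max_estimate = first_order_flatness(rho) / rho ** 2
print(lam_max_estimate)   # close to 5
```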

4.2 Comparison with Zeroth-order Flatness

We compare the first-order flatness with the zeroth-order flatness. We first show that 𝑅 𝜌 ( 0 ) ⁢ ( 𝜽 ) in Equation (2) is bounded by 𝑅 𝜌 ( 1 ) ⁢ ( 𝜽 ) in Equation (3).

Proposition 4.2.

For any 𝛉 ∈ Θ , 𝑅 𝜌 ( 0 ) ⁢ ( 𝛉 ) is bounded by 𝑅 𝜌 ( 1 ) ⁢ ( 𝛉 ) , i.e., 𝑅 𝜌 ( 1 ) ⁢ ( 𝛉 ) ≥ 𝑅 𝜌 ( 0 ) ⁢ ( 𝛉 ) .

Thus a smaller $R^{(1)}_{\rho}$ also leads to a smaller $R^{(0)}_{\rho}$, indicating that $R^{(1)}_{\rho}$ is a stronger flatness measure than $R^{(0)}_{\rho}$. Proposition 4.2 explains why first-order flatness covers wider scenarios than zeroth-order flatness.
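The inequality of Proposition 4.2 can be spot-checked numerically on a jittery 1-D loss of our own choosing, computing both flatness measures of Equations (2) and (3) by grid search:

```python
import math

# Numeric spot-check of Proposition 4.2: R^(1)_rho >= R^(0)_rho on a toy loss.
def loss(w):
    return w ** 2 + 0.3 * math.sin(25 * w)   # jittery toy loss

def grad(w):
    return 2 * w + 7.5 * math.cos(25 * w)

def both_flatness(w, rho, steps=20_000):
    grid = [w - rho + 2 * rho * i / steps for i in range(steps + 1)]
    r0 = max(loss(p) - loss(w) for p in grid)       # zeroth-order, Eq. (2)
    r1 = rho * max(abs(grad(p)) for p in grid)      # first-order, Eq. (3)
    return r0, r1

r0, r1 = both_flatness(0.1, rho=0.5)
print(r0, r1)   # r1 dominates r0
```

The intuition matches the proof idea: the loss gap to any point in the ball is an integral of gradients along a path of length at most $\rho$, hence bounded by $\rho$ times the maximum gradient norm.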

We present scenarios where the zeroth-order flatness fails to indicate generalization error while the first-order flatness remains discriminative in Figure 1. The gap between a local minimum and the largest loss in 𝜌 can be considered as an accumulation of gradients across the trajectory while the largest gradient norm measures the maximum ascent rate, which may indicate the trends of loss outside of 𝜌 .

When 𝜌 is large, there probably exist several other local minima in the neighborhood 𝐵(𝜽*, 𝜌), as shown in Figure 1(a). This case is common in practice, as shown in Section 5.1. In addition, when the number of local minima in 𝐵(𝜽*, 𝜌) becomes larger, 𝜽* is expected to become sharper, since the valley of 𝜽* becomes narrower. However, the zeroth-order flatness 𝑅𝜌(0) only measures the maximal gap of the loss function in 𝐵(𝜽*, 𝜌) and fails to distinguish cases where the number of local minima varies. By contrast, the maximal gradient norm in 𝐵(𝜽*, 𝜌) increases with the number of local minima, indicating that first-order flatness can successfully characterize the sharpness in this case.

When 𝜌 covers only a single minimum, as shown in Figure 1(b), the zeroth-order flatness within 𝜌 can be misleading, since the observation radius is insufficient to measure the loss trend via the maximum loss. First-order flatness helps to capture more of the loss trend.

From the perspective of flatness, zeroth-order flatness reflects the average gradient within a radius, while first-order flatness measures the maximum gradient. Intuitively, the combination of zeroth-order and first-order flatness captures a more comprehensive picture of the loss landscape. Furthermore, as discussed in Section 4.3, the minimizers for both flatness measures adopt a first-order Taylor approximation to calculate the maxima within a radius. This may explain why the combination of the two flatness measures achieves the best performance, as shown in Section 5.

4.3 Gradient Norm Aware Minimization

In this subsection, we propose a novel Gradient norm Aware Minimization (GAM) framework to incorporate the first-order flatness 𝑅 𝜌 ( 1 ) ⁢ ( 𝜽 ) into optimization procedures.

Specifically, suppose we can obtain an oracle loss function $L^{\text{oracle}}(\theta)$ and calculate its gradient $\nabla L^{\text{oracle}}(\theta)$. $L^{\text{oracle}}(\theta)$ can be chosen as the empirical loss function $\hat{L}(\theta)$ or the empirical loss function with other regularizations (such as weight decay or the zeroth-order flatness of Definition 3.1).

Generalization analysis

We first derive a generalization bound w.r.t. the first-order flatness in Proposition 4.3.

Proposition 4.3. Suppose the per-data-point loss function $\ell$ is differentiable and bounded by $M$. Fix $\rho > 0$ and $\theta \in \Theta$. Then with probability at least $1 - \delta$ over a training set $S$ generated from the distribution $\mathcal{D}$,

$$\mathbb{E}_{\epsilon_i \sim N\left(0,\, \rho^2 / (\sqrt{d} + \sqrt{\log n})^2\right)}\left[ L(\theta + \epsilon) \right] \le \hat{L}(\theta) + R^{(1)}_{\rho}(\theta) + \frac{M}{\sqrt{n}} + \sqrt{\frac{\frac{1}{4} d \log\left( 1 + \frac{\|\theta\|^2 (\sqrt{d} + \sqrt{\log n})^2}{d \rho^2} \right) + \frac{1}{4} + \log\frac{n}{\delta} + 2 \log(6n + 3d)}{n - 1}}. \quad (5)$$

Remark.

The left-hand side of Equation (5) is close to the population-level loss function $L(\theta)$, since the number of samples $n$ and the number of parameters $d$ are typically large. As a result, ignoring higher-order terms, the population-level loss $L(\theta)$ is bounded by the empirical loss $\hat{L}(\theta)$ plus the first-order flatness $R^{(1)}_{\rho}(\theta)$, which motivates us to use $R^{(1)}_{\rho}(\theta)$ as a regularizer to improve the generalization of models.

Inspired by Lemma 4.1 and Proposition 4.3, the overall loss function is given by

$$L^{\text{overall}}(\theta) \triangleq L^{\text{oracle}}(\theta) + \alpha R^{(1)}_{\rho}(\theta), \quad (6)$$

where $\alpha$ is a hyperparameter that determines the strength of the regularization. The gradient of the loss function $L^{\text{overall}}(\theta)$ is given by $\nabla L^{\text{overall}}(\theta) = \nabla L^{\text{oracle}}(\theta) + \alpha \nabla R^{(1)}_{\rho}(\theta)$. Using techniques similar to [20], GAM approximates $\nabla R^{(1)}_{\rho}(\theta)$ by

$$\nabla R^{(1)}_{\rho}(\theta) \approx \rho \cdot \nabla \left\| \nabla \hat{L}(\theta^{\text{adv}}) \right\|, \quad \theta^{\text{adv}} = \theta + \rho \cdot \frac{\boldsymbol{f}}{\|\boldsymbol{f}\|}, \quad \boldsymbol{f} = \nabla \left\| \nabla \hat{L}(\theta) \right\|. \quad (7)$$

Details of the derivation of $\nabla R^{(1)}_{\rho}(\theta)$ can be found in Appendix A. Notice that

$$\forall \theta \in \Theta, \quad \nabla \left\| \nabla \hat{L}(\theta) \right\| = \frac{\nabla^2 \hat{L}(\theta) \cdot \nabla \hat{L}(\theta)}{\left\| \nabla \hat{L}(\theta) \right\|}. \quad (8)$$

As a result, Equation (7) can be calculated efficiently via Hessian-vector products. The pseudocode of the whole optimization procedure is shown in Algorithm 1.
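The identity in Equation (8) can be verified numerically. Below is a pure-Python check on a toy 2-D quadratic of our own construction, comparing the finite-difference gradient of $\|\nabla L\|$ against the Hessian-vector product $H \nabla L / \|\nabla L\|$ (a deep-learning implementation would form the same product via double backpropagation instead of an explicit Hessian):

```python
import math

H = [[3.0, 1.0], [1.0, 2.0]]                 # symmetric toy Hessian of L(w) = w^T H w / 2

def matvec(M, v):
    return [M[0][0] * v[0] + M[0][1] * v[1], M[1][0] * v[0] + M[1][1] * v[1]]

def grad_norm(w):                            # ||grad L(w)|| = ||H w||
    g = matvec(H, w)
    return math.hypot(g[0], g[1])

w = [0.4, -0.2]

# Right-hand side of Eq. (8): Hessian-vector product H * grad L / ||grad L||
g = matvec(H, w)
n = math.hypot(g[0], g[1])
hvp = [c / n for c in matvec(H, g)]

# Left-hand side via central finite differences
eps = 1e-6
fd = [(grad_norm([w[0] + eps, w[1]]) - grad_norm([w[0] - eps, w[1]])) / (2 * eps),
      (grad_norm([w[0], w[1] + eps]) - grad_norm([w[0], w[1] - eps])) / (2 * eps)]
print(fd, hvp)   # the two gradients agree
```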

Convergence analysis

We further analyze the convergence properties of GAM. First, we introduce Lipschitz smoothness, which is commonly adopted in optimization-related literature [1, 70, 86].

Definition 4.2.

A function 𝐽 : Θ → ℝ is 𝛾 -Lipschitz smooth if

∀ 𝜽 1 , 𝜽 2 ∈ Θ , ‖ ∇ 𝐽 ⁢ ( 𝜽 1 ) − ∇ 𝐽 ⁢ ( 𝜽 2 ) ‖ ≤ 𝛾 ⁢ ‖ 𝜽 1 − 𝜽 2 ‖ . (9)

With Definition 4.2, we could prove the convergence property of GAM as shown in Theorem 4.4.

Theorem 4.4. Suppose $L^{\text{oracle}}(\theta)$ is $\gamma_1$-Lipschitz smooth and $\hat{L}(\theta)$ is $\gamma_2$-Lipschitz smooth. Suppose $|L^{\text{oracle}}(\theta)|$ is bounded by $M$. For any timestamp $t \in \{0, 1, \dots, T\}$ and any $\theta \in \Theta$, suppose we can obtain noisy and bounded observations $g_t^{\text{loss}}(\theta)$, $g_t^{\text{norm}}(\theta)$, and $\tilde{g}_t^{\text{loss}}(\theta)$ of $\nabla \hat{L}(\theta)$, $\nabla \|\nabla \hat{L}(\theta)\|$, and $\nabla L^{\text{oracle}}(\theta)$, respectively, such that

$$\mathbb{E}\left[ g_t^{\text{loss}}(\theta) \right] = \nabla \hat{L}(\theta), \quad \left\| g_t^{\text{loss}}(\theta) \right\| \le G^{\text{loss}}, \quad \left\| g_t^{\text{norm}}(\theta) \right\| \le G^{\text{norm}}, \quad (10)$$
$$\mathbb{E}\left[ \tilde{g}_t^{\text{loss}}(\theta) \right] = \nabla L^{\text{oracle}}(\theta), \quad \left\| \tilde{g}_t^{\text{loss}}(\theta) \right\| \le \tilde{G}^{\text{loss}}.$$

Then with learning rate $\eta_t = \eta_0 / \sqrt{t}$ and perturbation radius $\rho_t = \rho_0 / \sqrt{t}$, GAM obtains

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[ \left\| \nabla L^{\text{overall}}(\theta_t) \right\|^2 \right] \le \frac{C_1 + C_2 \log T}{\sqrt{T}}, \quad (11)$$

for some constants $C_1$ and $C_2$ that depend only on $\gamma_1$, $\gamma_2$, $G^{\text{loss}}$, $G^{\text{norm}}$, $\tilde{G}^{\text{loss}}$, $M$, $\eta_0$, $\rho_0$, and $\alpha$. Here $\nabla L^{\text{overall}}(\theta_t) = \nabla L^{\text{oracle}}(\theta_t) + \alpha \nabla R^{(1)}_{\rho}(\theta_t)$, and $\nabla R^{(1)}_{\rho}(\theta_t)$ is approximated as in Equation (7).

Remark.

The assumptions in Theorem 4.4 are common and standard when analyzing the convergence of non-convex functions via SGD-based methods [38, 56, 86]. In addition, the requirements on $L^{\text{oracle}}(\theta)$ (i.e., that $L^{\text{oracle}}(\theta)$ is Lipschitz smooth and that we can obtain unbiased and bounded observations of $\nabla L^{\text{oracle}}(\theta)$) are mild and common. For example, when the empirical loss function $\hat{L}(\theta)$ satisfies the constraints, it is easy to check that $\hat{L}(\theta)$ with weight decay regularization also meets the requirements.
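Putting Equations (6)–(8) together, a single GAM update can be condensed into a short pure-Python sketch on a 2-D toy quadratic (all names and constants here are our own illustrative choices; with an explicit toy Hessian the Hessian-vector products can be formed directly, whereas real deep-learning code would use double backpropagation, and the oracle loss is taken to be the empirical loss itself):

```python
import math

# Minimal GAM sketch on L(w) = w^T H w / 2, with decaying eta_t and rho_t
# as in Theorem 4.4. ALPHA is the trade-off coefficient, XI the small constant.
H = [[4.0, 0.0], [0.0, 1.0]]
ALPHA, XI = 0.5, 1e-12

def mv(v):                       # Hessian-vector product H v
    return [sum(H[i][j] * v[j] for j in range(2)) for i in range(2)]

def norm(v):
    return math.sqrt(sum(x * x for x in v))

def gam_step(w, lr, rho):
    g = mv(w)                                               # oracle gradient (here: grad L = H w)
    f = [x / (norm(g) + XI) for x in mv(g)]                 # f = H g / ||g||, Eq. (8)
    w_adv = [w[i] + rho * f[i] / (norm(f) + XI) for i in range(2)]   # ascent step of Eq. (7)
    g_adv = mv(w_adv)
    h_norm = [rho * x / (norm(g_adv) + XI) for x in mv(g_adv)]       # ~ rho * grad ||grad L|| at w_adv
    return [w[i] - lr * (g[i] + ALPHA * h_norm[i]) for i in range(2)]

w = [1.0, 1.0]
for t in range(1, 201):
    w = gam_step(w, lr=0.1 / math.sqrt(t), rho=0.1 / math.sqrt(t))
print(w)   # converges toward the minimum at the origin
```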

Algorithm 1 Gradient norm Aware Minimization (GAM)

1. Input: batch size $b$, learning rate $\eta_t$, perturbation radius $\rho_t$, trade-off coefficient $\alpha$, small constant $\xi$
2. $t \leftarrow 0$, $\theta_0 \leftarrow$ initial parameters
3. while $\theta_t$ not converged do
4.   Sample a batch $W_t$ of $b$ instances from the training data
5.   $h_t^{\text{loss}} \leftarrow \nabla L^{\text{oracle}}(\theta_t)$  ▷ calculate the oracle loss gradient
6.   $f_t \leftarrow \nabla^2 \hat{L}_{W_t}(\theta_t) \cdot \dfrac{\nabla \hat{L}_{W_t}(\theta_t)}{\|\nabla \hat{L}_{W_t}(\theta_t)\| + \xi}$
7.   $\theta_t^{\text{adv}} \leftarrow \theta_t + \rho_t \cdot \dfrac{f_t}{\|f_t\| + \xi}$
8.   $h_t^{\text{norm}} \leftarrow \rho_t \cdot \nabla^2 \hat{L}_{W_t}(\theta_t^{\text{adv}}) \cdot \dfrac{\nabla \hat{L}_{W_t}(\theta_t^{\text{adv}})}{\|\nabla \hat{L}_{W_t}(\theta_t^{\text{adv}})\| + \xi}$  ▷ calculate the norm gradient $\nabla R^{(1)}_{\rho_t}(\theta_t)$
9.   $\theta_{t+1} \leftarrow \theta_t - \eta_t (h_t^{\text{loss}} + \alpha h_t^{\text{norm}})$
10.  $t \leftarrow t + 1$
11. end while
12. return $\theta_t$

5 Experiments

We empirically show that the case discussed in Section 4.2 is common in practice. Then we evaluate GAM on various state-of-the-art models trained from random initialization, and in the transfer learning setting on various datasets. We present the Hessian spectra of GAM at convergence and discuss the computational overhead of GAM relative to its considerable improvement in model generalization.

Figure 2: The distribution of the numbers of local minima and maxima within the perturbation radius 𝜌 after convergence.

5.1 The Density of Local Minima

To investigate the number of local minima within the perturbation radius, we train 3 ResNet-18 models with SAM on CIFAR-100 with proper hyperparameters for 200 epochs. The perturbation radius is set to 0.1, as suggested by [20]. We load the checkpoints at convergence for evaluation. For each model, we randomly generate 100 perturbation directions with the same size as the model weights. For each direction, we repeatedly add a perturbation with norm 0.01 along the selected direction 10 times. We calculate the training loss after each addition and report the distribution of the number of local maxima and minima along each perturbation direction within the perturbation radius 𝜌 of 0.1. As shown in Figure 2, we find more than one local minimum within 𝜌 for most directions, indicating that the case is common in practice. As discussed in Section 4.2, zeroth-order flatness fails to reflect the sharpness caused by multiple minima, while the first-order flatness measure increases as the sharpness grows.
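The probing procedure above can be sketched in a few lines. The jittery toy loss below stands in for a trained network's loss surface, and the function names are our own; the logic mirrors the experiment: walk along a random unit direction in norm-0.01 steps, record the loss, and count interior local extrema.

```python
import math
import random

def loss(w):
    # toy stand-in for a network's training loss along weight space
    return sum(x * x for x in w) + 0.05 * math.sin(40 * w[0])

def count_extrema(w0, direction, step=0.01, n_steps=10):
    d = math.sqrt(sum(x * x for x in direction))
    unit = [x / d for x in direction]
    losses = [loss([w0[i] + k * step * unit[i] for i in range(len(w0))])
              for k in range(n_steps + 1)]
    minima = sum(1 for k in range(1, n_steps)
                 if losses[k] < losses[k - 1] and losses[k] < losses[k + 1])
    maxima = sum(1 for k in range(1, n_steps)
                 if losses[k] > losses[k - 1] and losses[k] > losses[k + 1])
    return minima, maxima

random.seed(0)
w0 = [0.0, 0.0]                                   # "converged" weights
direction = [random.gauss(0.0, 1.0) for _ in w0]  # one random probe direction
print(count_extrema(w0, direction))
```

In the paper's experiment this count is aggregated over 100 random directions per model to form the histogram of Figure 2.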

5.2 Training from Scratch

5.2.1 CIFAR-10 and CIFAR-100

Table 1: Results of GAM with state-of-the-art models on CIFAR-10 (left four result columns) and CIFAR-100 (right four result columns).

| Model | Aug | SGD | SGD + GAM | SAM | SAM + GAM | SGD | SGD + GAM | SAM | SAM + GAM |
|---|---|---|---|---|---|---|---|---|---|
| ResNet18 | Basic | 95.32 ± 0.13 | 96.17 ± 0.21 | 96.10 ± 0.20 | 96.75 ± 0.18 | 78.32 ± 0.32 | 79.53 ± 0.30 | 79.27 ± 0.16 | 80.45 ± 0.25 |
| ResNet18 | Cutout | 95.99 ± 0.13 | 96.46 ± 0.20 | 96.64 ± 0.13 | 96.99 ± 0.23 | 78.73 ± 0.13 | 79.89 ± 0.31 | 79.43 ± 0.15 | 80.80 ± 0.14 |
| ResNet18 | RA | 96.07 ± 0.07 | 96.52 ± 0.09 | 96.64 ± 0.17 | 97.06 ± 0.13 | 78.62 ± 0.32 | 79.82 ± 0.24 | 79.71 ± 0.15 | 80.97 ± 0.29 |
| ResNet18 | AA | 96.13 ± 0.05 | 96.71 ± 0.07 | 96.75 ± 0.08 | 97.17 ± 0.08 | 78.88 ± 0.15 | 80.56 ± 0.21 | 80.58 ± 0.25 | 81.59 ± 0.24 |
| ResNet101 | Basic | 96.35 ± 0.08 | 96.98 ± 0.11 | 96.82 ± 0.16 | 97.20 ± 0.15 | 80.47 ± 0.13 | 82.21 ± 0.40 | 82.03 ± 0.12 | 83.13 ± 0.07 |
| ResNet101 | Cutout | 96.56 ± 0.18 | 97.22 ± 0.05 | 97.07 ± 0.08 | 97.36 ± 0.24 | 80.53 ± 0.30 | 82.36 ± 0.24 | 81.60 ± 0.35 | 83.40 ± 0.13 |
| ResNet101 | RA | 96.68 ± 0.25 | 97.33 ± 0.30 | 97.12 ± 0.18 | 97.40 ± 0.23 | 80.60 ± 0.28 | 82.40 ± 0.31 | 82.19 ± 0.34 | 83.28 ± 0.20 |
| ResNet101 | AA | 96.78 ± 0.14 | 97.39 ± 0.18 | 97.18 ± 0.11 | 97.42 ± 0.1 | 81.83 ± 0.37 | 83.19 ± 0.15 | 82.44 ± 0.47 | 83.94 ± 0.23 |
| WRN28_2 | Basic | 94.82 ± 0.07 | 95.69 ± 0.13 | 95.47 ± 0.08 | 95.85 ± 0.08 | 75.45 ± 0.25 | 77.21 ± 0.31 | 77.04 ± 0.18 | 77.69 ± 0.20 |
| WRN28_2 | Cutout | 95.70 ± 0.20 | 96.41 ± 0.18 | 96.22 ± 0.13 | 96.39 ± 0.22 | 76.80 ± 0.45 | 78.58 ± 0.24 | 78.04 ± 0.43 | 79.33 ± 0.12 |
| WRN28_2 | RA | 95.75 ± 0.16 | 96.35 ± 0.13 | 96.22 ± 0.08 | 96.49 ± 0.20 | 76.73 ± 0.27 | 78.66 ± 0.03 | 77.88 ± 0.29 | 78.96 ± 0.13 |
| WRN28_2 | AA | 95.44 ± 0.06 | 95.98 ± 0.09 | 96.07 ± 0.08 | 96.44 ± 0.09 | 77.35 ± 0.02 | 79.05 ± 0.10 | 78.64 ± 0.23 | 79.50 ± 0.21 |
| WRN28_10 | Basic | 95.73 ± 0.10 | 96.61 ± 0.15 | 96.78 ± 0.80 | 97.29 ± 0.11 | 81.40 ± 0.13 | 83.45 ± 0.09 | 83.41 ± 0.04 | 84.31 ± 0.06 |
| WRN28_10 | Cutout | 96.74 ± 0.03 | 96.97 ± 0.05 | 97.35 ± 0.16 | 97.56 ± 0.12 | 81.53 ± 0.40 | 83.69 ± 0.08 | 82.38 ± 0.15 | 84.43 ± 0.13 |
| WRN28_10 | RA | 97.14 ± 0.04 | 96.83 ± 0.03 | 97.58 ± 0.07 | 97.49 ± 0.03 | 81.65 ± 0.18 | 83.84 ± 0.09 | 82.79 ± 0.06 | 84.68 ± 0.13 |
| WRN28_10 | AA | 96.93 ± 0.12 | 97.05 ± 0.04 | 97.48 ± 0.06 | 97.67 ± 0.08 | 81.99 ± 0.11 | 84.02 ± 0.18 | 83.84 ± 0.30 | 84.81 ± 0.21 |
| PyramidNet110 | Basic | 96.19 ± 0.11 | 97.11 ± 0.14 | 97.26 ± 0.05 | 97.51 ± 0.09 | 82.74 ± 0.12 | 84.91 ± 0.09 | 85.01 ± 0.09 | 85.25 ± 0.06 |
| PyramidNet110 | Cutout | 96.82 ± 0.09 | 97.32 ± 0.21 | 97.49 ± 0.06 | 97.91 ± 0.14 | 83.31 ± 0.21 | 85.20 ± 0.19 | 84.90 ± 0.03 | 85.46 ± 0.10 |
| PyramidNet110 | RA | 97.15 ± 0.21 | 97.80 ± 0.22 | 97.60 ± 0.09 | 98.01 ± 0.10 | 84.04 ± 0.19 | 86.47 ± 0.14 | 85.33 ± 0.27 | 85.64 ± 0.20 |
| PyramidNet110 | AA | 97.11 ± 0.01 | 97.85 ± 0.02 | 97.61 ± 0.14 | 97.95 ± 0.10 | 84.48 ± 0.03 | 85.92 ± 0.03 | 85.69 ± 0.17 | 86.35 ± 0.18 |

We conduct experiments on CIFAR-10 and CIFAR-100 [41] with ResNets [24], WideResNet [75], ResNeXt [65], PyramidNet [21], and Vision Transformers (ViTs) [16]. All models are trained for 200 epochs from scratch. We evaluate GAM both with basic data augmentations (i.e., horizontal flip, padding by four pixels, and random crop) and with advanced data augmentations, including cutout regularization [15], RandAugment [13], and AutoAugment [14].

GAM has two hyperparameters, 𝜌 and 𝛼. We conduct a grid search over {0.05, 0.1, 0.2, 0.5, 1.0, 2.0} for 𝜌 and {0.1, 0.2, 0.5, 1.0, 2.0, 3.0, …, 10.0} for 𝛼, using 10% of the training data as a validation set. The selected hyperparameters are listed in Appendix C.5.
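A hypothetical sketch of this hyperparameter search: every (𝜌, 𝛼) pair on the grid is scored on a held-out 10% validation split, and the best pair is kept. `train_and_eval` below is a placeholder of our own (a synthetic score peaking at 𝜌 = 0.2, 𝛼 = 2.0), not the paper's training code.

```python
import itertools

RHO_GRID = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
ALPHA_GRID = [0.1, 0.2, 0.5] + [float(k) for k in range(1, 11)]

def train_and_eval(rho, alpha):
    # placeholder validation accuracy; a real run would train with GAM(rho, alpha)
    # on 90% of the data and evaluate on the remaining 10%
    return 1.0 - (rho - 0.2) ** 2 - 0.01 * (alpha - 2.0) ** 2

best_rho, best_alpha = max(itertools.product(RHO_GRID, ALPHA_GRID),
                           key=lambda p: train_and_eval(*p))
print(best_rho, best_alpha)   # 0.2 2.0
```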

As a gradient regularizer, GAM can be integrated with current optimizers such as SGD and Adam [38]. We also show that GAM can be combined with sharpness-aware training procedures such as SAM. As shown in Section 4.2, the GAM term bounds the regularization term of SAM. Yet the practical implementations of GAM and SAM rely on first-order Taylor expansions of different objective functions (GAM approximates the maximum gradient norm, while SAM approximates the maximum loss). We empirically show that the combination of GAM and SAM outperforms both, indicating that the terms omitted by each approximation may complement each other.

As shown in Table 1, GAM improves generalization for all models on CIFAR-10 and CIFAR-100. When combined with SGD, GAM achieves considerably higher test accuracy than SGD alone. Moreover, GAM further improves generalization when combined with SAM. For example, GAM improves SAM's performance by 1.18% and 1.10% on CIFAR-100 with ResNet-18 and ResNet-101, respectively, which are noticeable margins. Additional experimental results are in Appendix C.1.

5.2.2 ImageNet

Table 2: Results of GAM on ImageNet.

| Model | Metric | Base Opt | Base + GAM | SAM | SAM + GAM |
|---|---|---|---|---|---|
| ResNet50 | Top-1 | 76.01 ± 0.19 | 76.59 ± 0.15 | 76.47 ± 0.11 | 76.86 ± 0.15 |
| ResNet50 | Top-5 | 92.75 ± 0.08 | 93.10 ± 0.08 | 93.07 ± 0.05 | 93.22 ± 0.06 |
| ResNet101 | Top-1 | 77.69 ± 0.08 | 78.45 ± 0.10 | 78.35 ± 0.12 | 78.70 ± 0.12 |
| ResNet101 | Top-5 | 93.76 ± 0.09 | 94.09 ± 0.12 | 94.02 ± 0.06 | 94.15 ± 0.12 |
| ViT-S/32 | Top-1 | 68.26 ± 0.22 | 69.95 ± 0.16 | 69.73 ± 0.05 | 70.15 ± 0.18 |
| ViT-S/32 | Top-5 | 87.39 ± 0.19 | 88.11 ± 0.26 | 87.91 ± 0.30 | 88.23 ± 0.18 |
| ViT-B/32 | Top-1 | 71.15 ± 0.14 | 73.58 ± 0.06 | 73.10 ± 0.18 | 73.70 ± 0.10 |
| ViT-B/32 | Top-5 | 90.12 ± 0.07 | 91.15 ± 0.19 | 91.03 ± 0.06 | 91.50 ± 0.16 |

We use ResNet50, ResNet101 [24], ViT-S/32, and ViT-B/32 [16] on ImageNet [58] to evaluate GAM on large-scale data. For ResNets, we use SGD with momentum 0.9 as the base optimizer for both GAM and SAM. For ViTs, we use the AdamW optimizer with 𝛽1 = 0.9 and 𝛽2 = 0.999. We train ResNets for 90 epochs and ViTs for 300 epochs, following [16]. We set the batch size to 256, the learning rate to 0.1, and the weight decay to 0.0001. The learning rate is decayed with a cosine schedule.

As shown in Table 2, GAM consistently improves SGD performance on ImageNet for both ResNets and ViTs. GAM also further improves the model generalization compared with SAM. The combination of GAM and SAM outperforms both SGD and SAM by a noticeable margin.

5.3 Transfer Learning

Table 3: Results of GAM for finetuning EfficientNet-b0 (left four result columns) and Swin-t (right four result columns) on various datasets.

| Dataset | SGD | SGD + GAM | SAM | SAM + GAM | AdamW | AdamW + GAM | SAM | SAM + GAM |
|---|---|---|---|---|---|---|---|---|
| Stanford Cars | 82.14 | 83.50 | 83.21 | 83.98 | 83.50 | 84.90 | 83.55 | 85.29 |
| CIFAR-10 | 86.26 | 87.37 | 86.95 | 87.97 | 91.32 | 92.06 | 91.77 | 92.55 |
| CIFAR-100 | 63.75 | 64.85 | 64.29 | 65.03 | 72.88 | 73.78 | 73.99 | 74.30 |
| Oxford_IIIT_Pets | 91.03 | 91.80 | 91.65 | 91.96 | 93.49 | 93.87 | 93.59 | 94.03 |
| Food101 | 82.54 | 82.69 | 82.57 | 83.01 | 86.38 | 86.89 | 86.64 | 87.03 |

Transfer learning tests the generalization of models pretrained on sufficient labeled data and finetuned on a small dataset [85]. We show that GAM improves generalization on all datasets in this setting.

We consider Stanford Cars [40], CIFAR-10, CIFAR-100 [41], Oxford_IIIT_Pets [54], and Food101 [8] for this setting. We apply SGD, SAM, and GAM to finetune EfficientNet-b0 [62] and Swin-Transformer-t [48] on these datasets. Both EfficientNet-b0 and Swin-Transformer-t are pretrained on ImageNet.

We use the ImageNet-pretrained weights of EfficientNet-b0 and Swin-t, except for the last linear classification layer. Following previous works, we train for 40k steps with a batch size of 128. The initial learning rate is set to 2e-3 with cosine learning rate decay, and the weight decay is set to 1e-5. We do not use data augmentation for Stanford Cars, Oxford_IIIT_Pets, or Food101. For the CIFAR datasets, we employ the same data augmentations as in the previous experiments.

As seen in Table 3, GAM once again brings generalization improvements for SGD, AdamW, and SAM on both EfficientNet-b0 and Swin-t. For example, GAM improves AdamW by 1.4% on Stanford Cars with Swin-t and improves SGD by 1.11% on CIFAR-10 with EfficientNet-b0.

Moreover, we leave the experiments on robustness to label noise to Appendix C.2.

5.4 Top Eigenvalues of Hessian and Hessian Trace

Figure 3: The distribution of top eigenvalues and the trace of the Hessian at epochs 100 and 200 on CIFAR-100 with SGD, SGD + GAM, SAM, or SAM + GAM.

Figure 4: Accuracy and training speed when applying GAM to different ratios ([0, 0.05, 0.1, 0.5, 1], from upper left to lower right; see details in Appendix C.3) of iterations. Numbers in parentheses indicate the training speed relative to the vanilla base optimizer (SGD/SAM).

Lemma 4.1 shows that the GAM term is an equivalent measure of the maximum eigenvalue of the Hessian, a well-known measure of flatness/sharpness; thus optimizing the GAM term decreases the maximum eigenvalue of the Hessian and leads to flatter minima. To empirically validate that GAM finds optima with low curvature, we examine the Hessian spectra of models trained with SGD, SAM, and GAM. As measures of flatness, we consider the maximum eigenvalue of the Hessian and the Hessian trace, which measures the expected loss increase under random perturbations of the weights [35]. We show empirically that, compared with SGD and SAM, GAM significantly decreases both the maximum eigenvalue and the trace of the Hessian during training, and thus finds flatter minima.

We compute the Hessian spectra of ResNet-18 trained on CIFAR-100 for 200 epochs with SGD, SAM, SGD + GAM, and SAM + GAM. We use power iteration [72] to compute the top eigenvalues of Hessian and Hutchinson’s method [5, 6, 71] to compute the Hessian trace. We report the histogram of the distribution of the top-50 Hessian eigenvalues for each method.
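Both estimators need only Hessian-vector products, never the full Hessian. The following NumPy sketch illustrates them on a small explicit matrix standing in for the Hessian; with a network one would obtain `hvp` via double backpropagation (e.g., as in PyHessian [71]) rather than from an explicit matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 4x4 "Hessian"; only the matrix-vector product below is ever needed.
H = np.diag([10.0, 4.0, 1.0, 0.5])

def hvp(v):
    """Hessian-vector product oracle (stand-in for double backprop)."""
    return H @ v

def top_eigenvalue(hvp, dim, iters=100):
    """Power iteration [72]: repeatedly apply the HVP and renormalize."""
    v = rng.standard_normal(dim)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        w = hvp(v)
        v = w / np.linalg.norm(w)
    return v @ hvp(v)  # Rayleigh quotient at convergence

def hutchinson_trace(hvp, dim, samples=1000):
    """Hutchinson's estimator [5, 6]: E[z^T H z] = tr(H) for Rademacher z."""
    est = 0.0
    for _ in range(samples):
        z = rng.choice([-1.0, 1.0], size=dim)
        est += z @ hvp(z)
    return est / samples

print(top_eigenvalue(hvp, 4))    # ≈ 10.0 (largest eigenvalue)
print(hutchinson_trace(hvp, 4))  # ≈ 15.5 (trace of H)
```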

As shown in Figure 3, the model trained with SGD has a higher maximum Hessian eigenvalue and Hessian trace at convergence than in the middle of training, indicating that optimizing the cross-entropy loss directly does not lower the Hessian spectrum. In contrast, GAM leads to lower Hessian spectra and thus flatter minima. Moreover, GAM helps to reduce both the top eigenvalues and the Hessian trace when combined with SAM, whose Hessian spectra at convergence are lower than those of all other methods. We show visualizations of the loss landscapes of SGD, SAM, and GAM in Section 5.6.

5.5 Computation Overhead

As discussed in Section 4.3, the GAM term can be easily calculated via the Hessian-vector product, an efficient way to compute the product of the Hessian and a vector without forming the entire Hessian. It can still introduce extra computation when evaluated at every iteration, however. To accelerate training with GAM, we investigate applying GAM to only a few iterations in each epoch. Surprisingly, applying GAM to only a small fraction of iterations (with a higher 𝛼 than when applying GAM to all iterations) improves model generalization considerably. As shown in Figure 4, with approximately 1/20 of the iterations, GAM considerably improves test accuracy for both SGD and SAM on CIFAR-10 and CIFAR-100. When applied to 1/10 of the training iterations, GAM is about as effective as when applied to all iterations, while its extra computational cost is less than 25% of the original cost. GAM outperforms SAM with lower computation overhead and achieves significant improvement when combined with SGD (the red line in the figure). When combined with SAM, GAM also improves generalization at low computational cost. Thus the computation overhead of GAM can be easily controlled. The optimization of first-order flatness can be further accelerated by approximating the second-order gradient with first-order gradients; details are in Appendix D.
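A minimal sketch of one GAM update on a toy quadratic loss, where the Hessian-vector product needed for the ascent direction $f = \nabla\|\nabla \hat{L}(\theta)\|$ is exact. The learning rate, `rho`, and `alpha` below are illustrative placeholders, not the tuned values from our experiments, and the quadratic stands in for a network loss.

```python
import numpy as np

# Toy quadratic loss L(theta) = 0.5 * theta^T A theta, so grad L = A theta and
# the Hessian-vector product is simply A @ v (a network would use double backprop).
A = np.diag([8.0, 2.0, 0.5])
loss = lambda th: 0.5 * th @ A @ th
grad = lambda th: A @ th
hvp = lambda th, v: A @ v

def gam_step(theta, lr=0.1, rho=0.05, alpha=0.3):
    """One GAM update: ascend along f = grad ||grad L(theta)|| = H g / ||g||,
    then descend on the loss gradient plus the gradient-norm penalty gradient."""
    g = grad(theta)
    f = hvp(theta, g / np.linalg.norm(g))          # HVP with the normalized gradient
    theta_adv = theta + rho * f / np.linalg.norm(f)
    g_adv = grad(theta_adv)
    norm_grad = hvp(theta_adv, g_adv / np.linalg.norm(g_adv))
    return theta - lr * (g + alpha * rho * norm_grad)

theta = np.array([1.0, 1.0, 1.0])
for _ in range(200):
    theta = gam_step(theta)
print(loss(theta))  # the loss decreases toward the (regularized) minimum
```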

5.6 Visualization of Landscapes

Figure 5: Visualization of loss landscapes for (a) SGD, (b) SGD + GAM, (c) SAM, and (d) SAM + GAM.

We visualize the loss landscapes of ResNet-18 models trained on CIFAR-100 with SGD, SGD + GAM, SAM, and SAM + GAM, following [45]. All the models are trained for 200 epochs with the same hyperparameters as described in Section 5.2.1. As shown in Figure 5, GAM consistently helps SGD and SAM find flatter minima.

6 Discussions

We show that the most popular definitions of flatness, which we collectively call zeroth-order flatness, can be insufficient to indicate generalization error. We thus propose first-order flatness, a stronger flatness measure that bounds both the maximum eigenvalue of the Hessian and the zeroth-order flatness. We also propose a novel training procedure, Gradient norm Aware Minimization (GAM), to optimize the first-order flatness, and empirically show that GAM considerably improves generalization for SGD, AdamW, and SAM.

Despite the empirical effectiveness of GAM, adopting first-order flatness for generalization has the following limitations, which suggest directions for future work. First, a theoretical explanation of whether a stronger flatness measure is better for generalization is vital for selecting flatness measures in practice. Second, the contribution to generalization of combining zeroth-order and first-order flatness requires a thorough theoretical analysis.

Acknowledgement

This work was supported in part by National Key R&D Program of China (No. 2018AAA0102004, No. 2020AAA0106300), National Natural Science Foundation of China (No. U1936219, 62141607), Beijing Academy of Artificial Intelligence (BAAI).

References

[1] Zeyuan Allen-Zhu and Yuanzhi Li. Neon2: Finding local minima via first-order oracles. Advances in Neural Information Processing Systems, 31, 2018.
[2] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pages 242–252. PMLR, 2019.
[3] Eric Arazo, Diego Ortego, Paul Albert, Noel O'Connor, and Kevin McGuinness. Unsupervised label noise modeling and loss correction. In International Conference on Machine Learning, pages 312–321. PMLR, 2019.
[4] Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep networks: Implicit acceleration by overparameterization. In International Conference on Machine Learning, pages 244–253. PMLR, 2018.
[5] Haim Avron and Sivan Toledo. Randomized algorithms for estimating the trace of an implicit symmetric positive semi-definite matrix. Journal of the ACM (JACM), 58(2):1–34, 2011.
[6] Zhaojun Bai, Gark Fahey, and Gene Golub. Some large-scale matrix computation problems. Journal of Computational and Applied Mathematics, 74(1-2):71–89, 1996.
[7] David GT Barrett and Benoit Dherin. Implicit gradient regularization. arXiv preprint arXiv:2009.11162, 2020.
[8] Lukas Bossard, Matthieu Guillaumin, and Luc Van Gool. Food-101 – mining discriminative components with random forests. In European Conference on Computer Vision, pages 446–461. Springer, 2014.
[9] Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann LeCun, Carlo Baldassi, Christian Borgs, Jennifer Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-sgd: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124018, 2019.
[10] Jinghui Chen, Dongruo Zhou, Yiqi Tang, Ziyan Yang, Yuan Cao, and Quanquan Gu. Closing the generalization gap of adaptive gradient methods in training deep neural networks. arXiv preprint arXiv:1806.06763, 2018.
[11] Xiangning Chen, Cho-Jui Hsieh, and Boqing Gong. When vision transformers outperform resnets without pre-training or strong data augmentations. arXiv preprint arXiv:2106.01548, 2021.
[12] Jeremy M. Cohen, Simran Kaur, Yuanzhi Li, J. Zico Kolter, and Ameet Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations, 2021.
[13] Ekin D Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V Le. Autoaugment: Learning augmentation policies from data. arXiv preprint arXiv:1805.09501, 2018.
[14] Ekin D Cubuk, Barret Zoph, Jonathon Shlens, and Quoc V Le. Randaugment: Practical automated data augmentation with a reduced search space. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pages 702–703, 2020.
[15] Terrance DeVries and Graham W Taylor. Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552, 2017.
[16] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2020.
[17] Jiawei Du, Hanshu Yan, Jiashi Feng, Joey Tianyi Zhou, Liangli Zhen, Rick Siow Mong Goh, and Vincent YF Tan. Efficient sharpness-aware minimization for improved training of neural networks. arXiv preprint arXiv:2110.03141, 2021.
[18] Jiawei Du, Daquan Zhou, Jiashi Feng, Vincent YF Tan, and Joey Tianyi Zhou. Sharpness-aware training for free. arXiv preprint arXiv:2205.14083, 2022.
[19] John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(7), 2011.
[20] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.
[21] Dongyoon Han, Jiwhan Kim, and Junmo Kim. Deep pyramidal residual networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 5927–5935, 2017.
[22] Moritz Hardt, Ben Recht, and Yoram Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International Conference on Machine Learning, pages 1225–1234. PMLR, 2016.
[23] Haowei He, Gao Huang, and Yang Yuan. Asymmetric valleys: Beyond sharp and flat local minima. Advances in Neural Information Processing Systems, 32, 2019.
[24] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[25] Geoffrey E Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R Salakhutdinov. Improving neural networks by preventing co-adaptation of feature detectors. arXiv preprint arXiv:1207.0580, 2012.
[26] Sepp Hochreiter and Jürgen Schmidhuber. Simplifying neural nets by discovering flat minima. Advances in Neural Information Processing Systems, 7, 1994.
[27] Gao Huang, Zhuang Liu, Geoff Pleiss, Laurens Van Der Maaten, and Kilian Weinberger. Convolutional networks with dense connectivity. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2019.
[28] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456. PMLR, 2015.
[29] Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. arXiv preprint arXiv:1803.05407, 2018.
[30] Stanisław Jastrzębski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in sgd. arXiv preprint arXiv:1711.04623, 2017.
[31] Stanislaw Jastrzebski, Maciej Szymczak, Stanislav Fort, Devansh Arpit, Jacek Tabor, Kyunghyun Cho, and Krzysztof J. Geras. The break-even point on optimization trajectories of deep neural networks. In International Conference on Learning Representations, 2020.
[32] Zhiwei Jia and Hao Su. Information-theoretic local minima characterization and regularization. In International Conference on Machine Learning, pages 4773–4783. PMLR, 2020.
[33] Lu Jiang, Di Huang, Mason Liu, and Weilong Yang. Beyond synthetic noise: Deep learning on controlled noisy labels. In International Conference on Machine Learning, pages 4804–4815. PMLR, 2020.
[34] Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
[35] Simran Kaur, Jeremy Cohen, and Zachary C Lipton. On the maximum hessian eigenvalue and generalization. arXiv preprint arXiv:2206.10654, 2022.
[36] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In International Conference on Learning Representations, 2017.
[37] Taero Kim, Sungjun Lim, and Kyungwoo Song. Sharpness-aware minimization for worst case optimization. arXiv preprint arXiv:2210.13533, 2022.
[38] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
[39] Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.
[40] Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei. 3d object representations for fine-grained categorization. In Proceedings of the IEEE International Conference on Computer Vision Workshops, pages 554–561, 2013.
[41] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. Citeseer, 2009.
[42] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning, pages 5905–5914. PMLR, 2021.
[43] Beatrice Laurent and Pascal Massart. Adaptive estimation of a quadratic functional by model selection. Annals of Statistics, pages 1302–1338, 2000.
[44] Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy Gur-Ari. The large learning rate phase of deep learning: the catapult mechanism. arXiv preprint arXiv:2003.02218, 2020.
[45] Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, and Tom Goldstein. Visualizing the loss landscape of neural nets. Advances in Neural Information Processing Systems, 31, 2018.
[46] Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han. On the variance of the adaptive learning rate and beyond. In International Conference on Learning Representations, 2020.
[47] Yong Liu, Siqi Mai, Xiangning Chen, Cho-Jui Hsieh, and Yang You. Towards efficient and scalable sharpness-aware minimization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12360–12370, 2022.
[48] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 10012–10022, 2021.
[49] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
[50] Liangchen Luo, Yuanhao Xiong, Yan Liu, and Xu Sun. Adaptive gradient methods with dynamic bound of learning rate. arXiv preprint arXiv:1902.09843, 2019.
[51] Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. Journal of Machine Learning Research, 18:1–35, 2017.
[52] Peng Mi, Li Shen, Tianhe Ren, Yiyi Zhou, Xiaoshuai Sun, Rongrong Ji, and Dacheng Tao. Make sharpness-aware minimization stronger: A sparsified perturbation approach. arXiv preprint arXiv:2210.05177, 2022.
[53] Yu E Nesterov. A method for solving the convex programming problem with convergence rate O(1/k²). In Dokl. Akad. Nauk SSSR, volume 269, pages 543–547, 1983.
[54] Omkar M Parkhi, Andrea Vedaldi, Andrew Zisserman, and CV Jawahar. Cats and dogs. In 2012 IEEE Conference on Computer Vision and Pattern Recognition, pages 3498–3505. IEEE, 2012.
[55] Henning Petzka, Michael Kamp, Linara Adilova, Cristian Sminchisescu, and Mario Boley. Relative flatness and generalization. Advances in Neural Information Processing Systems, 34:18420–18432, 2021.
[56] Sashank J Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of adam and beyond. In International Conference on Learning Representations, 2018.
[57] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster r-cnn: Towards real-time object detection with region proposal networks. Advances in Neural Information Processing Systems, 28, 2015.
[58] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. Imagenet large scale visual recognition challenge. International Journal of Computer Vision, 115(3):211–252, 2015.
[59] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
[60] Sidak Pal Singh and Dan Alistarh. Woodfisher: Efficient second-order approximation for neural network compression. Advances in Neural Information Processing Systems, 33:18098–18109, 2020.
[61] Samuel L. Smith and Quoc V. Le. A bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
[62] Mingxing Tan and Quoc Le. Efficientnet: Rethinking model scaling for convolutional neural networks. In International Conference on Machine Learning, pages 6105–6114. PMLR, 2019.
[63] Yeming Wen, Kevin Luk, Maxime Gazeau, Guodong Zhang, Harris Chan, and Jimmy Ba. An empirical study of large-batch stochastic gradient descent with structured covariance noise. arXiv preprint arXiv:1902.08234, 2019.
[64] Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nati Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. Advances in Neural Information Processing Systems, 30, 2017.
[65] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, and Kaiming He. Aggregated residual transformations for deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1492–1500, 2017.
[66] Zeke Xie, Issei Sato, and Masashi Sugiyama. A diffusion theory for deep learning dynamics: Stochastic gradient descent exponentially favors flat minima. In International Conference on Learning Representations, 2021.
[67] Zeke Xie, Qian-Yuan Tang, Yunfeng Cai, Mingming Sun, and Ping Li. On the power-law spectrum in deep learning: A bridge to protein science. arXiv preprint arXiv:2201.13011, 2022.
[68] Zeke Xie, Xinrui Wang, Huishuai Zhang, Issei Sato, and Masashi Sugiyama. Adaptive inertia: Disentangling the effects of adaptive learning rate and momentum. In International Conference on Machine Learning, pages 24430–24459. PMLR, 2022.
[69] Haoyi Xiong, Ruosi Wan, Jian Zhao, Zeyu Chen, Xingjian Li, Zhanxing Zhu, and Jun Huan. Grod: Deep learning with gradients orthogonal decomposition for knowledge transfer, distillation, and adversarial training. ACM Transactions on Knowledge Discovery from Data (TKDD), 16(6):1–25, 2022.
[70] Yi Xu, Rong Jin, and Tianbao Yang. First-order stochastic algorithms for escaping from saddle points in almost linear time. Advances in Neural Information Processing Systems, 31, 2018.
[71] Zhewei Yao, Amir Gholami, Kurt Keutzer, and Michael W Mahoney. Pyhessian: Neural networks through the lens of the hessian. In 2020 IEEE International Conference on Big Data (Big Data), pages 581–590. IEEE, 2020.
[72] Zhewei Yao, Amir Gholami, Qi Lei, Kurt Keutzer, and Michael W Mahoney. Hessian-based analysis of large batch training and robustness to adversaries. Advances in Neural Information Processing Systems, 31, 2018.
[73] Tom Young, Devamanyu Hazarika, Soujanya Poria, and Erik Cambria. Recent trends in deep learning based natural language processing. IEEE Computational Intelligence Magazine, 13(3):55–75, 2018.
[74] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. Cutmix: Regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 6023–6032, 2019.
[75] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
[76] Manzil Zaheer, Sashank Reddi, Devendra Sachan, Satyen Kale, and Sanjiv Kumar. Adaptive methods for nonconvex optimization. Advances in Neural Information Processing Systems, 31, 2018.
[77] Guodong Zhang, Lala Li, Zachary Nado, James Martens, Sushant Sachdeva, George Dahl, Chris Shallue, and Roger B Grosse. Which algorithmic choices matter at which batch sizes? Insights from a noisy quadratic model. Advances in Neural Information Processing Systems, 32, 2019.
[78] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. In International Conference on Learning Representations, 2018.
[79] Xingxuan Zhang, Feng Cheng, and Shilin Wang. Spatio-temporal fusion based convolutional sequence learning for lip reading. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 713–722, 2019.
[80] Xingxuan Zhang, Peng Cui, Renzhe Xu, Linjun Zhou, Yue He, and Zheyan Shen. Deep stable learning for out-of-distribution generalization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 5372–5382, 2021.
[81] Xingxuan Zhang, Linjun Zhou, Renzhe Xu, Peng Cui, Zheyan Shen, and Haoxin Liu. Towards unsupervised domain generalization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4910–4920, 2022.
[82] Yang Zhao, Hao Zhang, and Xiuyuan Hu. Penalizing gradient norm for efficiently improving generalization in deep learning. In International Conference on Machine Learning, pages 26982–26992. PMLR, 2022.
[83] Qihuang Zhong, Liang Ding, Li Shen, Peng Mi, Juhua Liu, Bo Du, and Dacheng Tao. Improving sharpness-aware minimization with fisher mask for better generalization on language models. arXiv preprint arXiv:2210.05497, 2022.
[84] Daquan Zhou, Bingyi Kang, Xiaojie Jin, Linjie Yang, Xiaochen Lian, Zihang Jiang, Qibin Hou, and Jiashi Feng. Deepvit: Towards deeper vision transformer. arXiv preprint arXiv:2103.11886, 2021.
[85] Fuzhen Zhuang, Zhiyuan Qi, Keyu Duan, Dongbo Xi, Yongchun Zhu, Hengshu Zhu, Hui Xiong, and Qing He. A comprehensive survey on transfer learning. Proceedings of the IEEE, 109(1):43–76, 2020.
[86] Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha C Dvornek, James S. Duncan, Ting Liu, et al. Surrogate gap minimization improves sharpness-aware training. In International Conference on Learning Representations, 2022.
Appendix A Omitted Details in Section 4

A.1 Derivation of Equation (7)

We follow the steps in [20] to approximate

$$\nabla R^{(1)}(\boldsymbol{\theta}) = \rho \cdot \nabla_{\boldsymbol{\theta}} \max_{\boldsymbol{\epsilon} \in B(0,\rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right\|. \tag{12}$$

We first conduct the first-order Taylor expansion of ‖ ∇ 𝐿 ^ ⁢ ( 𝜽 + 𝜖 ) ‖ and get that

$$\begin{aligned} \boldsymbol{\epsilon}^*(\boldsymbol{\theta}) &= \arg\max_{\boldsymbol{\epsilon} \in B(0,\rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right\| \approx \arg\max_{\boldsymbol{\epsilon} \in B(0,\rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\| + \left( \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\| \right)^\top \boldsymbol{\epsilon} \\ &= \arg\max_{\boldsymbol{\epsilon} \in B(0,\rho)} \left( \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\| \right)^\top \boldsymbol{\epsilon} = \rho \cdot \frac{\boldsymbol{f}}{\|\boldsymbol{f}\|}, \end{aligned} \tag{13}$$

where $\boldsymbol{f} = \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\|$. As a result, by letting $\boldsymbol{\theta}_{\mathrm{adv}} = \boldsymbol{\theta} + \boldsymbol{\epsilon}^*(\boldsymbol{\theta})$,

$$\nabla R^{(1)}(\boldsymbol{\theta}) \approx \rho \cdot \nabla_{\boldsymbol{\theta}} \left\| \nabla \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}^*(\boldsymbol{\theta})) \right\| = \rho \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}_{\mathrm{adv}}) \right\| + \rho \cdot \frac{\mathrm{d} \boldsymbol{\epsilon}^*(\boldsymbol{\theta})}{\mathrm{d} \boldsymbol{\theta}} \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}_{\mathrm{adv}}) \right\|. \tag{14}$$

In addition, similar to [20], we further drop the second-order term to accelerate the computation. Finally, the derivative ∇ 𝑅 ( 1 ) ⁢ ( 𝜽 ) is given by

$$\nabla R^{(1)}(\boldsymbol{\theta}) \approx \rho \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}_{\mathrm{adv}}) \right\|, \quad \boldsymbol{\theta}_{\mathrm{adv}} = \boldsymbol{\theta} + \rho \cdot \frac{\boldsymbol{f}}{\|\boldsymbol{f}\|}, \quad \boldsymbol{f} = \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\|. \tag{15}$$

Appendix B Proofs

B.1 Proof of Lemma 4.1

Proof.

By assumption, we have that for all $\boldsymbol{\theta} \in B(\boldsymbol{\theta}^*, \rho)$,

$$\hat{L}(\boldsymbol{\theta}) = \hat{L}(\boldsymbol{\theta}^*) + \frac{1}{2} (\boldsymbol{\theta} - \boldsymbol{\theta}^*)^\top \left( \nabla^2 \hat{L}(\boldsymbol{\theta}^*) \right) (\boldsymbol{\theta} - \boldsymbol{\theta}^*). \tag{16}$$

In addition,

$$\nabla \hat{L}(\boldsymbol{\theta}) = \left( \nabla^2 \hat{L}(\boldsymbol{\theta}^*) \right) (\boldsymbol{\theta} - \boldsymbol{\theta}^*). \tag{17}$$

As a result,

$$\max_{\boldsymbol{\theta} \in B(\boldsymbol{\theta}^*, \rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta}) \right\| = \max_{\boldsymbol{\theta} \in B(\boldsymbol{\theta}^*, \rho)} \left\| \left( \nabla^2 \hat{L}(\boldsymbol{\theta}^*) \right) (\boldsymbol{\theta} - \boldsymbol{\theta}^*) \right\| = \rho \left\| \nabla^2 \hat{L}(\boldsymbol{\theta}^*) \right\| = \rho \, \lambda_{\max}\!\left( \nabla^2 \hat{L}(\boldsymbol{\theta}^*) \right). \tag{18}$$

Now the claim follows. ∎
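Equation (18) is easy to verify numerically for a toy quadratic: over perturbations of norm $\rho$ around the minimum, the gradient norm is capped by $\rho \lambda_{\max}$ and approaches that cap along the top eigenvector. The diagonal Hessian below is an arbitrary illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Quadratic around a minimum theta*: grad L(theta* + d) = H d.
H = np.diag([5.0, 2.0, 1.0])
rho = 0.1
lam_max = np.linalg.eigvalsh(H).max()

# Sample many perturbations d with ||d|| = rho and record ||H d||.
d = rng.standard_normal((100_000, 3))
d = rho * d / np.linalg.norm(d, axis=1, keepdims=True)
grad_norms = np.linalg.norm(d @ H, axis=1)  # rows equal (H d)^T since H is symmetric

print(grad_norms.max(), rho * lam_max)  # sampled max approaches rho * lambda_max
```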

B.2 Proof of Proposition 4.2

Proof.

Suppose $\boldsymbol{\epsilon}^* = \arg\max_{\boldsymbol{\epsilon} \in B(0,\rho)} \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon})$. Then $R^{(0)}(\boldsymbol{\theta}) = \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}^*) - \hat{L}(\boldsymbol{\theta})$. According to the mean value theorem, there exists a constant $0 \le c \le 1$ such that

$$\hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}^*) - \hat{L}(\boldsymbol{\theta}) = \left( \nabla \hat{L}(\boldsymbol{\theta} + c \cdot \boldsymbol{\epsilon}^*) \right)^\top \boldsymbol{\epsilon}^*. \tag{19}$$

As a result, by the Cauchy–Schwarz inequality,

$$\begin{aligned} R^{(0)}(\boldsymbol{\theta}) &= \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}^*) - \hat{L}(\boldsymbol{\theta}) = \left( \nabla \hat{L}(\boldsymbol{\theta} + c \cdot \boldsymbol{\epsilon}^*) \right)^\top \boldsymbol{\epsilon}^* \le \left\| \nabla \hat{L}(\boldsymbol{\theta} + c \cdot \boldsymbol{\epsilon}^*) \right\| \left\| \boldsymbol{\epsilon}^* \right\| \\ &\le \max_{\boldsymbol{\epsilon} \in B(0,\rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right\| \cdot \rho = R^{(1)}(\boldsymbol{\theta}). \end{aligned} \tag{20}$$

∎
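Proposition 4.2 can be sanity-checked on any smooth loss: within the perturbation ball, the worst-case loss increase $R^{(0)}$ never exceeds $\rho$ times the worst-case gradient norm, i.e. $R^{(1)}$. The 2D test function below is an arbitrary smooth example, not a quantity from the paper.

```python
import numpy as np

# An arbitrary smooth 2D loss and its exact partial derivatives.
L  = lambda x, y: np.sin(x) * np.cos(y) + 0.1 * (x**2 + y**2)
Lx = lambda x, y: np.cos(x) * np.cos(y) + 0.2 * x
Ly = lambda x, y: -np.sin(x) * np.sin(y) + 0.2 * y

theta = np.array([0.3, -0.5])
rho = 0.4

# Dense grid over the perturbation ball B(0, rho).
u = np.linspace(-rho, rho, 401)
ex, ey = np.meshgrid(u, u)
mask = ex**2 + ey**2 <= rho**2
px, py = theta[0] + ex[mask], theta[1] + ey[mask]

R0 = L(px, py).max() - L(*theta)                   # zeroth-order flatness
R1 = rho * np.hypot(Lx(px, py), Ly(px, py)).max()  # first-order flatness

print(R0, R1, R0 <= R1)
```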

B.3 Proof of Proposition 4.3

Proof.

Define $h(\boldsymbol{\theta}) = \max_{\boldsymbol{\theta}' \in B(\boldsymbol{\theta}, \rho)} \left\| \nabla \hat{L}(\boldsymbol{\theta}') \right\|$. Fix $\sigma = \rho / (\sqrt{d} + \sqrt{\log n})$. Following the proof of Theorem 1 in [20], we can obtain that with probability at least $1 - \delta$,

$$\mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ L(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right] \le \mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right] + \sqrt{ \frac{ \frac{1}{4} d \log\left( 1 + \frac{\|\boldsymbol{\theta}\|_2^2}{d \sigma^2} \right) + \frac{1}{4} + \log \frac{n}{\delta} + 2 \log (6n + 3d) }{ n - 1 } }. \tag{21}$$

Since $\epsilon_i \sim N(0, \sigma^2)$, $\|\boldsymbol{\epsilon}\|^2 / \sigma^2$ follows a chi-square distribution with $d$ degrees of freedom. As a result, according to [43, Lemma 1], we have that for any $t > 0$,

$$\mathbb{P}\left( \|\boldsymbol{\epsilon}\|^2 / \sigma^2 - d \ge 2\sqrt{dt} + 2t \right) \le \exp(-t). \tag{22}$$

By letting $t = \frac{1}{2} \log n$, we can get that with probability at least $1 - 1/\sqrt{n}$,

$$\|\boldsymbol{\epsilon}\|^2 \le \sigma^2 \left( d + 2\sqrt{d \log n} + \log n \right) \le \sigma^2 \left( \sqrt{d} + \sqrt{\log n} \right)^2 = \rho^2. \tag{23}$$

As a result,

$$\begin{aligned} \mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \right] &= \mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \mid \|\boldsymbol{\epsilon}\| \le \rho \right] \mathbb{P}(\|\boldsymbol{\epsilon}\| \le \rho) + \mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \mid \|\boldsymbol{\epsilon}\| > \rho \right] \mathbb{P}(\|\boldsymbol{\epsilon}\| > \rho) \\ &\le \mathbb{E}_{\epsilon_i \sim N(0, \sigma^2)}\left[ \hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \mid \|\boldsymbol{\epsilon}\| \le \rho \right] + \frac{M}{\sqrt{n}}. \end{aligned} \tag{24}$$

According to the mean value theorem and the Cauchy–Schwarz inequality, for any $\boldsymbol{\epsilon}$ such that $\|\boldsymbol{\epsilon}\| < \rho$, there exists a constant $0 \le c \le 1$ such that

$$\hat{L}(\boldsymbol{\theta} + \boldsymbol{\epsilon}) = \hat{L}(\boldsymbol{\theta}) + \left( \nabla \hat{L}(\boldsymbol{\theta} + c \boldsymbol{\epsilon}) \right)^\top \boldsymbol{\epsilon} \le \hat{L}(\boldsymbol{\theta}) + \left\| \nabla \hat{L}(\boldsymbol{\theta} + c \boldsymbol{\epsilon}) \right\| \cdot \|\boldsymbol{\epsilon}\| \le \hat{L}(\boldsymbol{\theta}) + h(\boldsymbol{\theta}) \rho = \hat{L}(\boldsymbol{\theta}) + R^{(1)}(\boldsymbol{\theta}). \tag{25}$$

Now the claim follows from Equations (21), (24), and (25). ∎
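The chi-square tail bound of [43] used in Equation (22) can be checked by simulation; the dimension, variance, and value of $t$ below are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

d, sigma, t = 20, 0.5, 1.5
n_samples = 200_000

# eps ~ N(0, sigma^2 I_d), so ||eps||^2 / sigma^2 is chi-square with d dof.
eps = sigma * rng.standard_normal((n_samples, d))
stat = (eps**2).sum(axis=1) / sigma**2

# Laurent-Massart: P(chi2_d - d >= 2 sqrt(d t) + 2 t) <= exp(-t).
threshold = d + 2 * np.sqrt(d * t) + 2 * t
empirical_tail = (stat >= threshold).mean()
print(empirical_tail, np.exp(-t))  # the empirical tail sits below the bound
```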

B.4 Proof of Theorem 4.4

Proof.

Observe that

$$\left\| \nabla L_{\mathrm{overall}}(\boldsymbol{\theta}_t) \right\|^2 = \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) + \alpha \rho_t \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right\| \right\|^2 \le 2 \left( \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 + \left\| \alpha \rho_t \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right\| \right\|^2 \right). \tag{26}$$

The claim follows from Propositions B.1 and B.2. ∎

Proposition B.1. Assume the conditions in Theorem 4.4 hold (with parameters $\gamma_1, \gamma_2, G_{\mathit{loss}}, G_{\mathit{norm}}, \tilde{G}_{\mathit{loss}}, M, \eta_0, \rho_0, \alpha$). Then with learning rate $\eta_t = \eta_0 / \sqrt{t}$ and perturbation radius $\rho_t = \rho_0 / \sqrt{t}$, Algorithm 1 obtains

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[ \left\| \nabla L_{\mathit{oracle}}(\boldsymbol{\theta}_t) \right\|^2 \right] \le \frac{C_1' + C_2' \log T}{\sqrt{T}} \tag{27}$$

for some constants $C_1'$ and $C_2'$ that only depend on $\gamma_1, \gamma_2, G_{\mathit{loss}}, G_{\mathit{norm}}, \tilde{G}_{\mathit{loss}}, M, \eta_0, \rho_0, \alpha$.

Proof.

By definition, we have $\boldsymbol{h}_t^{\mathrm{loss}} = \tilde{g}_t^{\mathrm{loss}}(\boldsymbol{\theta}_t)$ and $\boldsymbol{h}_t^{\mathrm{norm}} = g_t^{\mathrm{norm}}(\boldsymbol{\theta}_t^{\mathrm{adv}})$. By assumption,

$$\begin{aligned} L_{\mathrm{oracle}}(\boldsymbol{\theta}_{t+1}) &\le L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) + \left( \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right)^\top (\boldsymbol{\theta}_{t+1} - \boldsymbol{\theta}_t) + \frac{\gamma_1}{2} \left\| \boldsymbol{\theta}_{t+1} - \boldsymbol{\theta}_t \right\|^2 \\ &= L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) - \eta_t \left( \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right)^\top \left( \boldsymbol{h}_t^{\mathrm{loss}} + \alpha \rho_t \boldsymbol{h}_t^{\mathrm{norm}} \right) + \frac{\gamma_1 \eta_t^2}{2} \left\| \boldsymbol{h}_t^{\mathrm{loss}} + \alpha \rho_t \boldsymbol{h}_t^{\mathrm{norm}} \right\|^2. \end{aligned} \tag{28}$$

Take the expectation conditioned on the observations up to timestamp $t$. By the assumptions $\mathbb{E}[\boldsymbol{h}_t^{\mathrm{loss}}] = \mathbb{E}[\tilde{g}_t^{\mathrm{loss}}(\boldsymbol{\theta}_t)] = \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t)$ and $\mathbb{E}[\boldsymbol{h}_t^{\mathrm{norm}}] = \mathbb{E}[g_t^{\mathrm{norm}}(\boldsymbol{\theta}_t^{\mathrm{adv}})]$, we can obtain that

$$\mathbb{E}[L_{\mathrm{oracle}}(\boldsymbol{\theta}_{t+1})] - L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \le - \eta_t \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 - \eta_t \rho_t \alpha \left( \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right)^\top \mathbb{E}\left[ g_t^{\mathrm{norm}}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right] + \frac{\gamma_1 \eta_t^2}{2} \mathbb{E}\left[ \left\| \boldsymbol{h}_t^{\mathrm{loss}} + \alpha \rho_t \boldsymbol{h}_t^{\mathrm{norm}} \right\|^2 \right]. \tag{29}$$

We have

$$- \eta_t \rho_t \alpha \left( \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right)^\top \mathbb{E}\left[ g_t^{\mathrm{norm}}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right] \le \eta_t \rho_t \alpha \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\| \left\| \mathbb{E}\left[ g_t^{\mathrm{norm}}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right] \right\| \le \eta_t \rho_t \alpha \tilde{G}_{\mathrm{loss}} G_{\mathrm{norm}}. \tag{30}$$

In addition,

$$\mathbb{E}\left[ \left\| \boldsymbol{h}_t^{\mathrm{loss}} + \alpha \boldsymbol{h}_t^{\mathrm{norm}} \right\|^2 \right] \le 2\, \mathbb{E}\left[ \left\| \boldsymbol{h}_t^{\mathrm{loss}} \right\|^2 \right] + 2 \alpha^2\, \mathbb{E}\left[ \left\| \boldsymbol{h}_t^{\mathrm{norm}} \right\|^2 \right] \le 2 (\tilde{G}_{\mathrm{loss}})^2 + 2 \alpha^2 (G_{\mathrm{norm}})^2. \tag{31}$$

Combining Equations (29), (30), and (31), we can get that

$$\eta_t \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 \le - \mathbb{E}[L_{\mathrm{oracle}}(\boldsymbol{\theta}_{t+1})] + L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) + \eta_t \rho_t Z_1 + \eta_t^2 Z_2 \tag{32}$$

for some constants $Z_1$ and $Z_2$ that only depend on $\gamma_1, \gamma_2, G_{\mathrm{loss}}, G_{\mathrm{norm}}, \tilde{G}_{\mathrm{loss}}, \alpha$. Now perform a telescoping sum and take the expectations at each step; we can obtain that

$$\sum_{t=1}^{T} \eta_t \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 \le - \mathbb{E}[L_{\mathrm{oracle}}(\boldsymbol{\theta}_{T+1})] + L_{\mathrm{oracle}}(\boldsymbol{\theta}_1) + Z_1 \sum_{t=1}^{T} \eta_t \rho_t + Z_2 \sum_{t=1}^{T} \eta_t^2. \tag{33}$$

By letting $\eta_t = \eta_0 / \sqrt{t}$ and $\rho_t = \rho_0 / \sqrt{t}$, we can get that

$$\begin{aligned} \frac{\eta_0}{\sqrt{T}} \sum_{t=1}^{T} \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 &\le \sum_{t=1}^{T} \eta_t \left\| \nabla L_{\mathrm{oracle}}(\boldsymbol{\theta}_t) \right\|^2 \\ &\le - \mathbb{E}[L_{\mathrm{oracle}}(\boldsymbol{\theta}_{T+1})] + L_{\mathrm{oracle}}(\boldsymbol{\theta}_1) + Z_1 \sum_{t=1}^{T} \eta_t \rho_t + Z_2 \sum_{t=1}^{T} \eta_t^2 \\ &\le 2M + Z_1 \eta_0 \rho_0 \sum_{t=1}^{T} \frac{1}{t} + Z_2 \eta_0^2 \sum_{t=1}^{T} \frac{1}{t} \\ &\le Z_4 + Z_5 \log T \end{aligned} \tag{34}$$

for some constants $Z_4$ and $Z_5$ that only depend on $\gamma_1, \gamma_2, G_{\mathrm{loss}}, G_{\mathrm{norm}}, \tilde{G}_{\mathrm{loss}}, M, \eta_0, \rho_0, \alpha$. Divide the two sides of the equation by $\eta_0 \sqrt{T}$ and the claim follows. ∎
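The last step relies only on the standard harmonic-sum bound $\sum_{t=1}^{T} 1/t \le 1 + \log T$, which is why both penalty sums grow logarithmically in $T$. A quick numerical check:

```python
import math

# sum_{t=1}^{T} 1/t <= 1 + log T, the bound behind the Z_4 + Z_5 log T step.
for T in (10, 1_000, 100_000):
    harmonic = sum(1.0 / t for t in range(1, T + 1))
    assert harmonic <= 1 + math.log(T)
```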

Proposition B.2. Assume the conditions in Theorem 4.4 hold (with parameters $\gamma_1, \gamma_2, G_{\mathit{loss}}, G_{\mathit{norm}}, \tilde{G}_{\mathit{loss}}, M, \eta_0, \rho_0, \alpha$). Then with perturbation radius $\rho_t = \rho_0 / \sqrt{t}$, Algorithm 1 obtains

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[ \left\| \alpha \rho_t \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}^{\mathit{adv}}) \right\| \right\|^2 \right] \le \frac{C_1'' + C_2'' \log T}{T} \tag{35}$$

for some constants $C_1''$ and $C_2''$ that only depend on $\gamma_2, \rho_0, \alpha$.

Proof.

For any $t \in \{1, 2, \ldots, T\}$,

$$\begin{aligned} \mathbb{E}\left[ \left\| \alpha \rho_t \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}^{\mathrm{adv}}) \right\| \right\|^2 \right] &= \alpha^2 \rho_t^2\, \mathbb{E}\left[ \left\| \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}^{\mathrm{adv}}) \right\| \right\|^2 \right] = \alpha^2 \rho_t^2\, \mathbb{E}\left[ \left\| \frac{ \nabla^2 \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \cdot \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) }{ \left\| \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right\| } \right\|^2 \right] \\ &\le \alpha^2 \rho_t^2\, \mathbb{E}\left[ \left\| \nabla^2 \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right\|^2 \left\| \frac{ \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) }{ \left\| \nabla \hat{L}(\boldsymbol{\theta}_t^{\mathrm{adv}}) \right\| } \right\|^2 \right] \le \alpha^2 \rho_t^2\, \mathbb{E}\left[ \gamma_2^2 \right] = \alpha^2 \rho_t^2 \gamma_2^2. \end{aligned} \tag{36}$$

By letting $\rho_t = \rho_0 / \sqrt{t}$,

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[ \left\| \alpha \rho_t \cdot \nabla \left\| \nabla \hat{L}(\boldsymbol{\theta}^{\mathrm{adv}}) \right\| \right\|^2 \right] \le \frac{\alpha^2 \gamma_2^2 \rho_0^2}{T} \sum_{t=1}^{T} \frac{1}{t} \le \frac{C_1'' + C_2'' \log T}{T} \tag{37}$$

for some constants $C_1''$ and $C_2''$ that only depend on $\gamma_2, \rho_0, \alpha$. ∎

Appendix C More Experimental Results and Details

C.1 More Results on Training from Scratch

Due to space limitations, we omitted the experimental results on CIFAR-10 and CIFAR-100 [41] with ResNeXt [65], DenseNet [27], and ViTs [16] from Section 5.2.1 in the main paper; we report them in Section C.1.1. We then report the results on robustness to label noise in Section C.2.

C.1.1 CIFAR-10 and CIFAR-100

We report the omitted results on CIFAR-10 and CIFAR-100 with ResNeXt, DenseNet, and ViTs. As described in the main paper, all models are trained for 200 epochs from scratch. We evaluate GAM both with basic data augmentations (i.e., horizontal flip, padding by four pixels, and random crop) and with advanced data augmentations, including cutout regularization [15], RandAugment [14], and AutoAugment [13]. The hyperparameters $\rho$ and $\alpha$ are searched with the same approach described in the main paper.
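Of the advanced augmentations, cutout [15] is the simplest to sketch: zero out a random square patch of the image. In the sketch below, the patch size and the clipping behavior at image borders are our assumptions for illustration, not necessarily the exact settings of [15].

```python
import numpy as np

rng = np.random.default_rng(0)

def cutout(image, size, rng):
    """Zero out a size x size patch at a random center, clipped at the borders."""
    h, w = image.shape[:2]
    cy, cx = int(rng.integers(0, h)), int(rng.integers(0, w))
    y0, y1 = max(0, cy - size // 2), min(h, cy + size // 2)
    x0, x1 = max(0, cx - size // 2), min(w, cx + size // 2)
    out = image.copy()
    out[y0:y1, x0:x1] = 0.0
    return out

img = np.ones((32, 32, 3))  # a dummy CIFAR-sized image
aug = cutout(img, size=16, rng=rng)
```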

Table 4: Results of GAM with state-of-the-art models on CIFAR-10 and CIFAR-100. The best results are highlighted in bold font.

CIFAR-10:

| Model | Aug | SGD | SGD + GAM | SAM | SAM + GAM |
|---|---|---|---|---|---|
| DenseNet121 | Basic | 91.16 ± 0.13 | 92.35 ± 0.14 | 92.19 ± 0.20 | **92.72 ± 0.30** |
| DenseNet121 | Cutout | 91.85 ± 0.17 | 92.93 ± 0.26 | 92.35 ± 0.16 | **93.30 ± 0.17** |
| DenseNet121 | RA | 91.59 ± 0.16 | 92.37 ± 0.20 | 92.32 ± 0.29 | **92.97 ± 0.27** |
| DenseNet121 | AA | 92.65 ± 0.10 | **94.17 ± 0.25** | 92.96 ± 0.19 | 94.05 ± 0.22 |
| ResNeXt29-32x4d | Basic | 95.75 ± 0.31 | 96.46 ± 0.25 | 96.32 ± 0.36 | **96.90 ± 0.24** |
| ResNeXt29-32x4d | Cutout | 96.20 ± 0.37 | 97.82 ± 0.24 | 96.44 ± 0.21 | **97.85 ± 0.27** |
| ResNeXt29-32x4d | RA | 95.86 ± 0.28 | 97.17 ± 0.26 | 96.75 ± 0.35 | **97.79 ± 0.19** |
| ResNeXt29-32x4d | AA | 96.58 ± 0.18 | 97.46 ± 0.15 | 97.38 ± 0.25 | **97.58 ± 0.16** |
| ViT-S/16 | Basic | 95.27 ± 0.23 | 97.21 ± 0.14 | 96.85 ± 0.25 | **97.58 ± 0.20** |
| ViT-S/16 | Cutout | 95.36 ± 0.22 | 97.53 ± 0.17 | 97.10 ± 0.25 | **97.85 ± 0.10** |
| ViT-S/16 | RA | 95.59 ± 0.19 | 97.44 ± 0.31 | 97.18 ± 0.12 | **97.59 ± 0.11** |
| ViT-S/16 | AA | 96.40 ± 0.30 | 97.82 ± 0.21 | 97.52 ± 0.25 | **97.97 ± 0.13** |

CIFAR-100:

| Model | Aug | SGD | SGD + GAM | SAM | SAM + GAM |
|---|---|---|---|---|---|
| DenseNet121 | Basic | 69.25 ± 0.40 | 70.48 ± 0.27 | 70.44 ± 0.19 | **71.16 ± 0.25** |
| DenseNet121 | Cutout | 70.17 ± 0.31 | 71.47 ± 0.23 | 70.89 ± 0.15 | **71.80 ± 0.07** |
| DenseNet121 | RA | 69.65 ± 0.36 | 70.10 ± 0.26 | 70.49 ± 0.16 | **71.43 ± 0.17** |
| DenseNet121 | AA | 70.53 ± 0.17 | 72.25 ± 0.19 | 71.34 ± 0.20 | **72.90 ± 0.17** |
| ResNeXt29-32x4d | Basic | 79.45 ± 0.29 | 81.67 ± 0.26 | 81.35 ± 0.12 | **82.93 ± 0.25** |
| ResNeXt29-32x4d | Cutout | 80.56 ± 0.20 | 82.62 ± 0.33 | 82.49 ± 0.25 | **83.58 ± 0.09** |
| ResNeXt29-32x4d | RA | 79.88 ± 0.12 | 81.75 ± 0.23 | 82.26 ± 0.15 | **83.02 ± 0.22** |
| ResNeXt29-32x4d | AA | 80.47 ± 0.13 | 82.02 ± 0.19 | 81.52 ± 0.26 | **83.35 ± 0.09** |
| ViT-S/16 | Basic | 79.52 ± 0.36 | 83.35 ± 0.28 | 82.77 ± 0.29 | **84.30 ± 0.25** |
| ViT-S/16 | Cutout | 79.36 ± 0.21 | 83.59 ± 0.28 | 82.86 ± 0.14 | **84.53 ± 0.16** |
| ViT-S/16 | RA | 79.96 ± 0.22 | 83.80 ± 0.20 | 83.36 ± 0.13 | **84.66 ± 0.23** |
| ViT-S/16 | AA | 80.35 ± 0.06 | 84.02 ± 0.18 | 83.54 ± 0.19 | **85.20 ± 0.26** |

Results are shown in Table 4. GAM consistently improves generalization for all models, and we observe the same trends as in the main paper: GAM achieves considerably higher test accuracy when combined with SGD, and it also yields further improvement when combined with SAM.

Comparison with GNP

GNP can be considered a special case of GAM where 𝜌 is set to 0. We compare GAM with GNP [82] on CIFAR-100 in Table 5. We follow the hyperparameter search and choices of GNP in its original paper. GAM consistently outperforms GNP by noticeable margins.

Table 5: Comparison with GNP on CIFAR-100. -BA indicates the basic data augmentation, -CU indicates cutout regularization, -RA indicates RandAugment, and -AA indicates AutoAugment.

| | Res18-BA | Res18-CU | Res18-RA | Res18-AA | Res101-BA | Res101-CU | Res101-RA | Res101-AA |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SGD | 78.32 ± 0.32 | 78.73 ± 0.13 | 78.62 ± 0.32 | 78.88 ± 0.15 | 80.47 ± 0.13 | 80.53 ± 0.30 | 80.60 ± 0.28 | 81.83 ± 0.37 |
| GNP + SGD | 78.80 ± 0.40 | 79.29 ± 0.15 | 79.21 ± 0.27 | 80.29 ± 0.05 | 81.17 ± 0.29 | 81.10 ± 0.14 | 81.31 ± 0.88 | 82.53 ± 0.25 |
| GAM + SGD | **79.53 ± 0.30** | **79.89 ± 0.31** | **79.82 ± 0.24** | **80.56 ± 0.21** | **82.21 ± 0.40** | **82.36 ± 0.24** | **82.40 ± 0.31** | **83.19 ± 0.15** |

| | WRN10-BA | WRN10-CU | WRN10-RA | WRN10-AA | Pyr110-BA | Pyr110-CU | Pyr110-RA | Pyr110-AA |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SGD | 81.40 ± 0.13 | 81.53 ± 0.40 | 81.65 ± 0.18 | 81.99 ± 0.11 | 82.74 ± 0.12 | 83.31 ± 0.21 | 84.04 ± 0.19 | 84.48 ± 0.03 |
| GNP + SGD | 82.30 ± 0.05 | 82.54 ± 0.19 | 82.99 ± 0.39 | 83.58 ± 0.32 | 83.99 ± 0.27 | 84.46 ± 0.16 | 84.47 ± 0.08 | 84.83 ± 0.21 |
| GAM + SGD | **83.45 ± 0.09** | **83.69 ± 0.08** | **83.84 ± 0.09** | **84.02 ± 0.18** | **84.91 ± 0.09** | **85.20 ± 0.19** | **86.47 ± 0.14** | **85.92 ± 0.03** |

C.2 Robustness to Label Noise

It has been observed that sharpness-aware minimization methods are robust to label noise [20, 42]. Here we assess the degree of robustness that GAM provides to label noise.

Following [20, 42], we measure the effectiveness of GAM in the classical noisy-label setting on CIFAR-10: a fraction of the training labels is randomly flipped [33] while the test data remains unmodified. We train a ResNet32 for 200 epochs following [33]. Hyperparameter settings for all models are the same as in the previous CIFAR experiments. Following [3, 42], we report the best results during training instead of the results at the end of training.
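As a minimal sketch of this setup (the `flip_labels` helper, the class count, and the seed are illustrative assumptions, not the paper's code), symmetric label flipping can be implemented as:

```python
import numpy as np

# Symmetric label flipping: a `noise_rate` fraction of labels is reassigned
# to a different, uniformly chosen class; the rest are left untouched.
def flip_labels(labels, noise_rate, num_classes, seed=0):
    rng = np.random.default_rng(seed)
    noisy = labels.copy()
    n_flip = int(noise_rate * len(labels))
    idx = rng.choice(len(labels), size=n_flip, replace=False)
    # adding an offset in [1, num_classes - 1] modulo num_classes guarantees
    # that every flipped label differs from the original
    offsets = rng.integers(1, num_classes, size=n_flip)
    noisy[idx] = (noisy[idx] + offsets) % num_classes
    return noisy

labels = np.zeros(1000, dtype=int)  # toy "clean" labels
noisy = flip_labels(labels, noise_rate=0.2, num_classes=10)
print((noisy != labels).mean())  # → 0.2
```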

We report test accuracies for SGD, SAM, SGD+GAM, and SAM+GAM obtained from 3 independent runs for each label noise level in Table 6. As seen in Table 6, GAM shows a high degree of robustness to label noise. GAM consistently improves the robustness to label noise for both SGD and SAM.

Table 6: Test accuracy of ResNet32 on CIFAR-10 with label noise.

| Noise Rate (%) | Base Opt | Base + GAM | SAM | SAM + GAM |
| --- | --- | --- | --- | --- |
| 0 | 95.71 | 96.55 | 96.25 | 96.88 |
| 20 | 92.05 | 94.20 | 93.85 | 94.74 |
| 40 | 88.89 | 92.17 | 90.85 | 92.57 |
| 60 | 83.17 | 88.49 | 87.37 | 89.65 |
| 80 | 63.16 | 73.82 | 70.65 | 76.33 |

C.3 Detailed Results and Discussions about Computation Overhead

Table 7: Accuracy and training speed with different ratios of iterations using GAM. Superscripts indicate the ratio of iterations in each epoch trained with GAM (e.g., GAM^0.05 indicates that 5% of iterations are trained with GAM, while the remaining iterations are trained with the basic optimizer). Numbers in parentheses indicate the training speed relative to the vanilla base optimizer (SGD/SAM). Note that SAM runs at about 50% of SGD's speed, so when combined with SGD, any GAM ratio whose relative speed exceeds 50% is still faster than SAM.

CIFAR-10:

| | SGD | SGD + GAM^0.05 | SGD + GAM^0.1 | SGD + GAM^0.5 | SGD + GAM^1 |
| --- | --- | --- | --- | --- | --- |
| Accuracy | 95.32 | 96.08 | 96.15 | 96.17 | 96.17 |
| Images/s | 2,593 (100%) | 2,258 (87%) | 1,996 (77%) | 1,023 (39%) | 658 (25%) |

| | SAM | SAM + GAM^0.05 | SAM + GAM^0.1 | SAM + GAM^0.5 | SAM + GAM^1 |
| --- | --- | --- | --- | --- | --- |
| Accuracy | 96.10 | 96.54 | 96.62 | 96.65 | 96.58 |
| Images/s | 1,314 (100%) | 1,247 (95%) | 1,184 (90%) | 858 (65%) | 629 (48%) |

CIFAR-100:

| | SGD | SGD + GAM^0.05 | SGD + GAM^0.1 | SGD + GAM^0.5 | SGD + GAM^1 |
| --- | --- | --- | --- | --- | --- |
| Accuracy | 78.32 | 79.25 | 79.42 | 79.50 | 79.53 |
| Images/s | 2,609 (100%) | 2,243 (86%) | 1,955 (75%) | 1,011 (39%) | 655 (25%) |

| | SAM | SAM + GAM^0.05 | SAM + GAM^0.1 | SAM + GAM^0.5 | SAM + GAM^1 |
| --- | --- | --- | --- | --- | --- |
| Accuracy | 79.27 | 80.08 | 80.44 | 80.40 | 80.45 |
| Images/s | 1,318 (100%) | 1,251 (95%) | 1,172 (89%) | 848 (64%) | 628 (48%) |

Here we report detailed results on the trade-off between computation overhead and test accuracy for GAM. As discussed in Sections 4.3 and 5.5 of the main paper, the GAM term can be calculated efficiently via the Hessian-vector product, which computes the product of the Hessian and a vector without materializing the entire Hessian. We also observe that applying GAM in only a fraction of the iterations (with a higher 𝛼 than when applying GAM to all iterations) already improves generalization considerably. As seen in Table 7, applying GAM to 1/10 of the training iterations yields generalization performance similar to applying GAM to all iterations, while the extra computational cost is less than 25% of the original cost. Thus the computation overhead of GAM can be easily controlled.
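To illustrate why the Hessian-vector product avoids forming the Hessian, here is a minimal sketch on an assumed toy quadratic loss (the finite-difference form and all names are illustrative, not the paper's implementation):

```python
import numpy as np

# Toy quadratic loss L(w) = 0.5 * w^T A w, so grad(w) = A w and the Hessian is A.
rng = np.random.default_rng(0)
B = rng.standard_normal((5, 5))
A = B @ B.T  # symmetric positive semi-definite Hessian

def grad(w):
    return A @ w

def hvp(w, v, eps=1e-5):
    # Hessian-vector product from two gradient calls, never forming A itself:
    # (grad(w + eps * v) - grad(w)) / eps  ≈  H v
    return (grad(w + eps * v) - grad(w)) / eps

w = rng.standard_normal(5)
g = grad(w)
v = g / np.linalg.norm(g)  # the normalized-gradient direction used by the GAM term
print(np.allclose(hvp(w, v), A @ v, atol=1e-6))  # → True
```

In a deep-learning framework the same product is obtained with one extra backward pass, which is why the overhead stays a small multiple of a standard gradient step.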

C.4 Ablation Study

GAM has two hyperparameters, 𝜌 and 𝛼. We analyze the influence of their choice in Sections C.4.1 and C.4.2 below. The results are shown in Figure 6.

Figure 6: The influence of hyperparameters 𝜌 (a) and 𝛼 (b) on the performance of ResNet101 on CIFAR-100.

C.4.1 Influence of 𝜌

𝜌 controls the step length of gradient ascent in GAM. When 𝜌 is set to 0, GAM degenerates into a naive regularizer that constrains the gradient norm at each training step. We plot the performance of ResNet101 on CIFAR-100 with varying 𝜌. For experiments where GAM is combined with SAM, we apply the same 𝜌 to both GAM and SAM. As shown in Figure 6(a), GAM with 𝜌 larger than 0 outperforms GAM without gradient ascent, showing that the gradient ascent step is necessary. Moreover, GAM consistently outperforms SGD across different 𝜌, and GAM also consistently improves SAM's performance across various 𝜌, indicating that first-order flatness can improve generalization beyond zeroth-order flatness.

C.4.2 Influence of 𝛼

𝛼 controls the strength of the GAM penalty. When 𝛼 is set to 0, GAM degenerates into the basic optimizer (SGD or SAM). We show the performance of GAM with varying 𝛼 in Figure 6(b). Compared with SGD, GAM shows considerable improvement across varying 𝛼. The improvement under various 𝛼 is also observed when GAM is combined with SAM.

C.5 Training Details and Selection of Hyperparameters

C.5.1 Training Details of Training from Scratch Experiments

Table 8: Hyperparameters for Algorithm 1 on CIFAR-10 and CIFAR-100 datasets.

| Model | Learning Rate | Weight Decay | Base Optimizer | Epochs | LR Schedule |
| --- | --- | --- | --- | --- | --- |
| ResNet18 | 0.1 | 0.0005 | SGD | 200 | Cosine |
| ResNet101 | 0.1 | 0.0005 | SGD | 200 | Cosine |
| WRN28_2 | 0.1 | 0.0005 | SGD | 200 | Cosine |
| WRN28_10 | 0.1 | 0.0005 | SGD | 200 | Cosine |
| PyramidNet110 | 0.05 | 0.0005 | SGD | 200 | Cosine |
| DenseNet121 | 0.1 | 0.001 | SGD | 200 | Cosine |
| ResNeXt29-32x4d | 0.1 | 0.0005 | SGD | 200 | Cosine |

Table 9: Hyperparameters for Algorithm 1 on ImageNet.

| Model | 𝜌 | 𝛼 | Learning Rate | Weight Decay | Base Optimizer | Epochs | LR Schedule |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50 | 0.2 | 0.1 | 0.1 | 0.0001 | SGD | 90 | Cosine |
| ResNet101 | 0.2 | 0.1 | 0.1 | 0.0001 | SGD | 90 | Cosine |
| ViT-S/32 | 0.3 | 0.5 | 0.0003 | 0.3 | AdamW | 300 | Cosine |
| ViT-B/32 | 0.3 | 0.5 | 0.0003 | 0.3 | AdamW | 300 | Cosine |

We search hyperparameters, including the learning rate and weight decay, for all models unless otherwise noted. For ResNets, we conduct a grid search of the learning rate in {0.01, 0.1, 1.0} and the weight decay in {0.0001, 0.0005, 0.001, 0.01, 0.1}. The batch size is set to 128 for all models. For ViTs, we search the learning rate in {1e-3, 3e-3, 1e-2, 3e-3} and the weight decay in {0.001, 0.01, 0.1}. We adopt SGD with momentum = 0.9 for ResNets and AdamW with 𝛽1 = 0.9, 𝛽2 = 0.999 for ViTs. We train ResNets for 90 epochs and ViTs for 300 epochs following [11, 86]. We first search for the optimal learning rate and weight decay for training with the basic optimizers and keep them fixed for SAM and GAM. We then search 𝜌 in {0.05, 0.1, 0.2, 0.5, 1.0, 2.0} for both SAM and GAM and 𝛼 in {0.1, 0.2, 0.5, 1.0, 2.0, 3.0, …, 10.0} for GAM. We set 𝜌 to 0.04 for CIFAR-10 and 0.1 for CIFAR-100, and 𝛼 to 0.3 for ResNet-18 and 0.1 for other models. We report the best hyperparameters for each individual model in Table 8 and Table 9.
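The two-stage search above (base-optimizer hyperparameters first, then GAM's 𝜌 and 𝛼 with the base values fixed) can be sketched as follows; `validate` is a hypothetical stand-in for training a model and returning validation accuracy, and its peak location is arbitrary:

```python
from itertools import product

# Hypothetical stand-in for "train with this config and return validation
# accuracy"; the peak at lr = 0.1, rho = 0.1 is purely for illustration.
def validate(lr, wd, rho=0.1, alpha=1.0):
    return -abs(lr - 0.1) - abs(rho - 0.1)

lrs = [0.01, 0.1, 1.0]
wds = [0.0001, 0.0005, 0.001, 0.01, 0.1]
rhos = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
alphas = [0.1, 0.2, 0.5] + [float(i) for i in range(1, 11)]

# Stage 1: grid-search (lr, wd) for the base optimizer.
lr, wd = max(product(lrs, wds), key=lambda c: validate(lr=c[0], wd=c[1]))
# Stage 2: keep (lr, wd) fixed and grid-search GAM's rho and alpha.
rho, alpha = max(product(rhos, alphas),
                 key=lambda c: validate(lr=lr, wd=wd, rho=c[0], alpha=c[1]))
print(lr, rho)  # the toy score peaks at lr = 0.1, rho = 0.1
```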

C.5.2 Training Details of Transfer Learning Experiments

We finetune the models on downstream datasets including Stanford Cars [40], CIFAR-10, CIFAR-100 [41], Oxford_IIIT_Pets [54] and Food101 [8] from the weights pretrained on ImageNet. For EfficientNet-b0, we adopt SGD with momentum = 0.9. For Swin-t, we adopt AdamW with 𝛽 1 = 0.9, 𝛽 2 = 0.999. We train all the models for 40k steps and the batch size is 128. The initial learning rate is 2e-3 and the cosine learning rate decay is used. Weight decay is set to 1e-5.

Appendix D Further Acceleration of GAM

Acceleration

Computing the gradient of $R_\rho^{(1)}(\boldsymbol{\theta})$ according to Equations (7) and (8) requires the Hessian-vector product operation, which can still introduce considerable extra computation when the model is large. Inspired by [60, 82], one can approximate $\nabla\|\nabla \hat{L}(\boldsymbol{\theta})\|$ with first-order gradients as follows.

$$\forall \boldsymbol{\theta} \in \Theta, \quad \nabla\left\|\nabla \hat{L}(\boldsymbol{\theta})\right\| \approx \frac{\nabla \hat{L}\!\left(\boldsymbol{\theta} + \rho' \cdot \frac{\nabla \hat{L}(\boldsymbol{\theta})}{\|\nabla \hat{L}(\boldsymbol{\theta})\|}\right) - \nabla \hat{L}(\boldsymbol{\theta})}{\rho'}, \tag{38}$$

where $\rho'$ is a small constant. Thus we can further accelerate GAM by applying Equation (38) to the $\boldsymbol{\theta}_{\text{adv}}$ term in Equation (7) as follows,

$$\boldsymbol{\theta}_{\text{adv}} \approx \boldsymbol{\theta} + \rho \cdot \frac{\nabla \hat{L}\!\left(\boldsymbol{\theta} + \rho' \cdot \frac{\nabla \hat{L}(\boldsymbol{\theta})}{\|\nabla \hat{L}(\boldsymbol{\theta})\|}\right) - \nabla \hat{L}(\boldsymbol{\theta})}{\left\|\nabla \hat{L}\!\left(\boldsymbol{\theta} + \rho' \cdot \frac{\nabla \hat{L}(\boldsymbol{\theta})}{\|\nabla \hat{L}(\boldsymbol{\theta})\|}\right) - \nabla \hat{L}(\boldsymbol{\theta})\right\|} = \boldsymbol{\theta} + \rho \cdot \frac{\boldsymbol{g}_1 - \boldsymbol{g}_0}{\|\boldsymbol{g}_1 - \boldsymbol{g}_0\|}, \tag{39}$$

where

$$\boldsymbol{g}_0 \triangleq \nabla \hat{L}(\boldsymbol{\theta}), \quad \boldsymbol{g}_1 \triangleq \nabla \hat{L}(\tilde{\boldsymbol{\theta}}_1), \quad \tilde{\boldsymbol{\theta}}_1 \triangleq \boldsymbol{\theta} + \rho' \cdot \frac{\nabla \hat{L}(\boldsymbol{\theta})}{\|\nabla \hat{L}(\boldsymbol{\theta})\|} = \boldsymbol{\theta} + \rho' \cdot \frac{\boldsymbol{g}_0}{\|\boldsymbol{g}_0\|}. \tag{40}$$

We let $\tilde{\boldsymbol{\theta}}_2 \triangleq \boldsymbol{\theta}_{\text{adv}}$. Applying Equation (38) to the calculation of $\nabla\|\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)\|$, we get

$$\nabla R_\rho^{(1)}(\boldsymbol{\theta}) \approx \rho \cdot \nabla\left\|\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)\right\| \approx \rho \cdot \frac{\nabla \hat{L}\!\left(\tilde{\boldsymbol{\theta}}_2 + \rho' \cdot \frac{\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)}{\|\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)\|}\right) - \nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)}{\rho'} = \frac{\rho}{\rho'}\left(\boldsymbol{g}_3 - \boldsymbol{g}_2\right), \tag{41}$$

where

$$\boldsymbol{g}_2 \triangleq \nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2), \quad \boldsymbol{g}_3 \triangleq \nabla \hat{L}(\tilde{\boldsymbol{\theta}}_3), \quad \tilde{\boldsymbol{\theta}}_3 \triangleq \tilde{\boldsymbol{\theta}}_2 + \rho' \cdot \frac{\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)}{\|\nabla \hat{L}(\tilde{\boldsymbol{\theta}}_2)\|} = \tilde{\boldsymbol{\theta}}_2 + \rho' \cdot \frac{\boldsymbol{g}_2}{\|\boldsymbol{g}_2\|}. \tag{42}$$

Accelerated GAM

Based on the above approximations, we obtain the accelerated version of GAM shown in Algorithm 2. In addition, the following modifications further accelerate GAM and suppress the effects of the above approximations.

Algorithm 2 Accelerated GAM

1: Input: batch size $b$, learning rate $\eta_t$, perturbation radii $\rho_t$, $\rho'_t$, trade-off coefficients $\alpha$, $\beta$, $\gamma$, small constant $\xi$
2: $t \leftarrow 0$, $\boldsymbol{\theta}_0 \leftarrow$ initial parameters
3: while $\boldsymbol{\theta}_t$ not converged do
4:  Sample $W_t$ from the training data with $b$ instances
5:  $\boldsymbol{g}_{t,0} \leftarrow \nabla \hat{L}_{W_t}(\boldsymbol{\theta}_t)$
6:  $\tilde{\boldsymbol{\theta}}_{t,1} \leftarrow \boldsymbol{\theta}_t + \rho'_t \cdot \boldsymbol{g}_{t,0} / (\|\boldsymbol{g}_{t,0}\| + \xi)$
7:  $\boldsymbol{g}_{t,1} \leftarrow \nabla \hat{L}_{W_t}(\tilde{\boldsymbol{\theta}}_{t,1})$
8:  $\boldsymbol{h}_{t,0} \leftarrow \boldsymbol{g}_{t,1} - \boldsymbol{g}_{t,0}$
9:  $\tilde{\boldsymbol{\theta}}_{t,2} \leftarrow \boldsymbol{\theta}_t + \rho_t \cdot \boldsymbol{h}_{t,0} / (\|\boldsymbol{h}_{t,0}\| + \xi)$
10:  $\boldsymbol{g}_{t,2} \leftarrow \nabla \hat{L}_{W_t}(\tilde{\boldsymbol{\theta}}_{t,2})$
11:  $\tilde{\boldsymbol{\theta}}_{t,3} \leftarrow \tilde{\boldsymbol{\theta}}_{t,2} + \rho'_t \cdot \boldsymbol{g}_{t,2} / (\|\boldsymbol{g}_{t,2}\| + \xi)$
12:  $\boldsymbol{g}_{t,3} \leftarrow \nabla \hat{L}_{W_t}(\tilde{\boldsymbol{\theta}}_{t,3})$
13:  $\boldsymbol{h}_{t,+} \leftarrow \alpha \boldsymbol{g}_{t,1} + (1-\alpha)\boldsymbol{g}_{t,3}$
14:  $\boldsymbol{h}_{t,-} \leftarrow \beta \boldsymbol{g}_{t,0} + (1-\beta)\boldsymbol{g}_{t,2}$
15:  $\boldsymbol{h}_{t,-}^{\parallel}, \boldsymbol{h}_{t,-}^{\perp} \leftarrow \mathrm{decompose}(\boldsymbol{h}_{t,-}; \boldsymbol{h}_{t,+})$ ▷ decompose $\boldsymbol{h}_{t,-}$ into components parallel and orthogonal to $\boldsymbol{h}_{t,+}$
16:  $\boldsymbol{\theta}_{t+1} \leftarrow \boldsymbol{\theta}_t - \eta_t \left(\boldsymbol{h}_{t,+} - \gamma \boldsymbol{h}_{t,-}^{\perp}\right)$
17:  $t \leftarrow t + 1$
18: end while
19: return $\boldsymbol{\theta}_t$
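A minimal sketch of one update of Algorithm 2 on an assumed toy quadratic loss (the `grad` oracle and all hyperparameter values are illustrative, not tuned settings from the paper):

```python
import numpy as np

# One update of the accelerated GAM scheme on L(w) = 0.5 * w^T A w,
# whose gradient A w stands in for the mini-batch gradient.
rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T
grad = lambda w: A @ w

def accelerated_gam_step(theta, eta=0.05, rho=0.2, rho_p=0.01,
                         alpha=0.5, beta=0.5, gamma=0.1, xi=1e-12):
    g0 = grad(theta)                                     # line 5
    t1 = theta + rho_p * g0 / (np.linalg.norm(g0) + xi)  # line 6
    g1 = grad(t1)                                        # line 7
    h0 = g1 - g0                                         # line 8: HVP-like direction
    t2 = theta + rho * h0 / (np.linalg.norm(h0) + xi)    # line 9: theta_adv
    g2 = grad(t2)                                        # line 10
    t3 = t2 + rho_p * g2 / (np.linalg.norm(g2) + xi)     # line 11
    g3 = grad(t3)                                        # line 12
    h_pos = alpha * g1 + (1 - alpha) * g3                # line 13
    h_neg = beta * g0 + (1 - beta) * g2                  # line 14
    # line 15: split h_neg into parallel/orthogonal parts w.r.t. h_pos
    h_par = (h_neg @ h_pos) / (h_pos @ h_pos + xi) * h_pos
    h_perp = h_neg - h_par
    return theta - eta * (h_pos - gamma * h_perp)        # line 16

theta = rng.standard_normal(4)
theta_next = accelerated_gam_step(theta)
```

Note that only four gradient evaluations per step are required, with no Hessian-vector product.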

We find that the gradient of the SAM regularization term, $\boldsymbol{g}_1 - \boldsymbol{g}_0$, is already computed during the approximation steps. As a result, we directly optimize the objective $\hat{L}(\boldsymbol{\theta}) + \alpha' R_\rho^{(1)}(\boldsymbol{\theta}) + \beta' R_{\rho'}^{(0)}(\boldsymbol{\theta})$ with hyperparameters $\alpha'$ and $\beta'$. The gradient of this objective can be approximated as

$$\begin{aligned} \nabla\!\left(\hat{L}(\boldsymbol{\theta}) + \alpha' R_\rho^{(1)}(\boldsymbol{\theta}) + \beta' R_{\rho'}^{(0)}(\boldsymbol{\theta})\right) &\approx \boldsymbol{g}_0 + \alpha' \frac{\rho}{\rho'}\left(\boldsymbol{g}_3 - \boldsymbol{g}_2\right) + \beta'\left(\boldsymbol{g}_1 - \boldsymbol{g}_0\right) \\ &= \beta' \boldsymbol{g}_1 + \alpha' \frac{\rho}{\rho'} \boldsymbol{g}_3 - (\beta' - 1)\boldsymbol{g}_0 - \alpha' \frac{\rho}{\rho'} \boldsymbol{g}_2, \end{aligned} \tag{43}$$

which means that the gradient of our target is a linear combination of $\boldsymbol{g}_0$, $\boldsymbol{g}_1$, $\boldsymbol{g}_2$, and $\boldsymbol{g}_3$. We further set three hyperparameters to control the importance of these parts while preserving their signs, i.e.,

$$\nabla\!\left(\hat{L}(\boldsymbol{\theta}) + \alpha' R_\rho^{(1)}(\boldsymbol{\theta}) + \beta' R_{\rho'}^{(0)}(\boldsymbol{\theta})\right) \approx \alpha \boldsymbol{g}_1 + (1-\alpha)\boldsymbol{g}_3 - \gamma\left(\beta \boldsymbol{g}_0 + (1-\beta)\boldsymbol{g}_2\right) \tag{44}$$

with $0 \le \alpha, \beta \le 1$ and $\gamma \ge 0$.

We find that the negative parts of Equation (44) may have side effects on the model's convergence and performance, and thus require careful tuning of hyperparameters. Inspired by [69, 86], we decompose $\beta \boldsymbol{g}_0 + (1-\beta)\boldsymbol{g}_2$ into components parallel and orthogonal to $\alpha \boldsymbol{g}_1 + (1-\alpha)\boldsymbol{g}_3$. Specifically, we first let

$$\boldsymbol{h}_+ \triangleq \alpha \boldsymbol{g}_1 + (1-\alpha)\boldsymbol{g}_3, \qquad \boldsymbol{h}_- \triangleq \beta \boldsymbol{g}_0 + (1-\beta)\boldsymbol{g}_2 \tag{45}$$

denote the positive and negative parts of Equation (44), respectively. We then decompose the negative part $\boldsymbol{h}_-$ into a component $\boldsymbol{h}_-^{\parallel}$ parallel to $\boldsymbol{h}_+$ and a component $\boldsymbol{h}_-^{\perp}$ orthogonal to it. As a result, the final gradient is given by $\boldsymbol{h}_+ - \gamma \boldsymbol{h}_-^{\perp}$.
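The decomposition step is a standard vector projection; a minimal sketch with illustrative vectors:

```python
import numpy as np

# Split h_neg into a component parallel to h_pos and a component
# orthogonal to it, via projection onto h_pos.
def decompose(h_neg, h_pos):
    h_par = (h_neg @ h_pos) / (h_pos @ h_pos) * h_pos  # projection onto h_pos
    return h_par, h_neg - h_par

h_pos = np.array([1.0, 0.0])
h_neg = np.array([2.0, 3.0])
h_par, h_perp = decompose(h_neg, h_pos)
print(h_par, h_perp)  # → [2. 0.] [0. 3.]
```

Subtracting only the orthogonal component $\gamma \boldsymbol{h}_-^{\perp}$ leaves the descent direction $\boldsymbol{h}_+$ untouched along its own axis, which is the motivation for the decomposition.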

