Title: When to Learn What: Model-Adaptive Data Augmentation Curriculum
URL Source: https://arxiv.org/html/2309.04747
Published Time: Wed, 04 Oct 2023 01:00:32 GMT
3 Method 1. 3.1 When to Augment: Monotonic Curriculum 2. 3.2 What Augmentations to Apply: Model-Adaptive Data Augmentation 3. 3.3 Joint Training of Task model & MADAug Policy
4 Experiments 1. 4.1 Augmentation Operations 2. 4.2 Implementation Details 3. 4.3 Main Results 4. 4.4 Transferability of MADAug-Learned Policy 5. 4.5 Analysis of MADAug Augmentations 6. 4.6 Computing cost of MADAug 7. 4.7 Mean and variance of the experiment results
5 Ablation study 1. Magnitude perturbation. 2. Number of augmentation operations. 3. Structure of policy network. 4. Hyperparameter of $\tau$. 5. Analysis of optimization steps. 6. Effect of monotonic curriculum. 7. Strategy of MADAug.
A Appendix 1. Quality of augmentation policies on Reduced SVHN. 2. Advantages of MADAug over AdaAug. 3. Model hyperparameters.
When to Learn What: Model-Adaptive Data Augmentation Curriculum
Chengkai Hou$^{1}$, Jieyu Zhang$^{2}$, Tianyi Zhou$^{3}$
$^{1}$Jilin University, $^{2}$University of Washington, $^{3}$University of Maryland
houck20@mails.jlu.edu.cn, jieyuz2@cs.washington.edu, tianyi@umd.edu
Corresponding author
Abstract
Data augmentation (DA) is widely used to improve the generalization of neural networks by enforcing invariances and symmetries to pre-defined transformations applied to input data. However, a fixed augmentation policy may have different effects on each sample in different training stages, and existing approaches cannot adapt the policy to each sample and to the model being trained. In this paper, we propose “Model-Adaptive Data Augmentation (MADAug)”, which jointly trains an augmentation policy network to teach the model “when to learn what”. Unlike previous work, MADAug selects augmentation operators for each input image by a model-adaptive policy that varies between training stages, producing a data augmentation curriculum optimized for better generalization. In MADAug, we train the policy through a bi-level optimization scheme, which aims to minimize the validation-set loss of a model trained using the policy-produced data augmentations. We conduct an extensive evaluation of MADAug on multiple image classification tasks and network architectures with thorough comparisons to existing DA approaches. MADAug outperforms or is on par with other baselines and exhibits better fairness: it brings improvement to all classes and more to the difficult ones. Moreover, the policy learned by MADAug shows better performance when transferred to fine-grained datasets. In addition, the auto-optimized policy in MADAug gradually introduces increasing perturbations and naturally forms an easy-to-hard curriculum. Our code is available at https://github.com/JackHck/MADAug.
1 Introduction
Data augmentation is a widely used strategy to increase the diversity of training data, which improves the model generalization, especially in image recognition tasks[21, 35, 17]. Unlike previous works that apply manually-designed augmentation operations[6, 44, 47, 23, 4, 24], recent researchers have resorted to searching a data augmentation policy for a target dataset/samples. Despite the success of these learnable and dataset-dependent augmentation policies, they are fixed once learned and thus non-adaptive to either different samples or models at different training stages, resulting in biases across different data regions[2] or inefficient training.
In this paper, we study two fundamental problems towards developing a data-and-model-adaptive data augmentation policy that determines a curriculum of “when to learn what” to train a model: (1) when to apply data augmentation in training? (2) what data augmentations should be applied to each training sample at different training stages?
First, applying data augmentation does not always bring improvement over the whole course of training. For example, we observed that a model tends to learn faster during earlier training stages without data augmentation. We hypothesize that a model at an early stage of training cannot yet recognize even the original images, so excessively augmented images hinder its convergence. Motivated by this observation, we first design a strategy called monotonic curriculum to progressively introduce more augmented data into training. In particular, we gradually increase the probability of applying data augmentation to each sample by following the Tanh function (see Figure 1), so the model can improve quickly in earlier stages without distractions from augmentations while reaching better performance in later stages by learning from augmented data.
Figure 1: MADAug applies a monotonic curriculum to gradually introduce more data augmentations to the task model training and uses a policy network to choose augmentations for each training sample. MADAug trains the policy to minimize the validation loss of the task model, so the augmentations are model-adaptive and optimized for different training stages.
Second, a fixed augmentation policy is not optimal for every sample or for every training stage. Although the monotonic curriculum gradually increases the amount of augmented data as the model improves, it does not determine which augmentations, applied to each sample, bring the most improvement to model training. Intuitively, the model can learn more from diverse data augmentations. Moreover, the difficulty of augmented data also has a great impact on training, and it depends on both the augmentations and the sample they are applied to. For example, “simple” augmentation is preferred in the early stages to accelerate model convergence, whereas more challenging augmented data provide additional information for learning more robust features and generalizing better in the later stages. One plausible strategy is to leverage expert knowledge and advice to adjust the augmentation operations and their strengths[29, 47, 34, 14]. In this paper, instead of relying on human experts, we regard the evaluation of the current model on a validation set as an expert that guides the optimization of the augmentation policies applied to each sample in different training stages. As illustrated in Figure 1, we use a policy network to produce the augmentations for each sample (i.e., data-adaptive) used to train the task model, while the training objective of the policy network is to minimize the validation loss of the task model (i.e., model-adaptive). This is a challenging bi-level optimization problem[5]. To address it, we train the task model on adaptive augmentations of training data and update the policy network to minimize the validation loss in an online manner. Thereby, the policy network is dynamically adapted to different training stages of the task model and generates customized augmentations for each sample. This results in a curriculum of data augmentations optimized for improving the generalization performance of the task model.
Our main contributions can be summarized as follows:
- (a) A monotonic curriculum that gradually introduces more data augmentation into the training process.
- (b) MADAug, which trains a data augmentation policy network on the fly alongside the task model. The policy automatically selects augmentations for each training sample and for different training stages.
- (c) Experiments on CIFAR-10/100, SVHN, and ImageNet demonstrate that MADAug consistently brings greater improvement to task models than existing data augmentation methods in terms of test-set performance.
- (d) The augmentation policy network learned by MADAug on one dataset is transferable to unseen datasets and downstream tasks, producing better models than other baselines.
2 Related Work
Random crop and horizontal flip operations are commonly employed as standard data augmentation techniques for images in deep learning. Recently, more advanced data augmentation techniques have significantly increased the accuracy of image recognition tasks[46, 41, 43, 38, 9, 16, 17]. However, data augmentations may only be applicable to certain domains, and heuristically selected transformations, such as transplanting transformations that are effective in one domain into another, could have the opposite effect[2]. Thus, the exploration of optimal data augmentation policies necessitates specialized domain knowledge.
AutoAugment[6] adopts reinforcement learning to automatically find an available augmentation policy. However, AutoAugment requires thousands of GPU hours to find the policies on a reduced setting and limits randomness on the augmentation policies. To tackle these challenges, searching the optimal data augmentation strategies has become a prominent research topic and many methods have been proposed [46, 36, 13, 24, 14, 23, 25, 22, 44, 18, 45, 4].
These methods can be broadly classified into two distinct categories: fixed augmentation policies and online augmentation policies. The first category of methods[13, 23, 24, 47, 6, 45, 4] employs subsets of the training data and/or smaller models to efficiently discover fixed augmentation policies. However, the limited randomness in these policies makes it challenging to generate suitable samples for various stages of training. Thus, the fixed augmentation policies are suboptimal. The second category of methods[7, 36, 26, 22, 30, 44, 25] focuses on directly finding dynamic augmentation policies on the task model. This strategy is increasingly recognized as the primary choice for data augmentation search.
RandAugment[7] and TrivialAugment[30] are typical methods of the second type. They randomly select the augmentation parameters without relying on any external knowledge or prior information. Other methods, such as Adversarial AutoAugment[44], generate adversarial augmentations by maximizing the training loss. However, without appropriate constraints, adversarial augmentations are inherently unstable and risk distorting the intrinsic meaning of images. To avoid this collapse, TeachAugment[36] uses “teacher knowledge” to restrict adversarial augmentations. However, Adversarial AutoAugment[44] and TeachAugment[36] both offer “hard” augmentations rather than “adaptive” augmentations. These are not effective at improving generalization in the early training stage, because an early-stage model cannot yet recognize even the original images, and “hard” augmentations hinder its convergence. Thus, in our paper, we gradually apply data augmentations to samples and track the model performance on the validation set to adjust the policies through bi-level optimization during model training.
3 Method
In this section, we first propose a monotonic curriculum that progressively introduces more augmented samples as the training epoch increases. We then introduce the policy network that generates model-adaptive data augmentations and study how to train it jointly with the task model through bi-level optimization.
3.1 When to Augment: Monotonic Curriculum
Figure 2: Test accuracy on Reduced CIFAR-10. No Augmentation does not apply any augmentations. Human-designed Augmentation always applies human pre-defined augmentations. Monotonic Curriculum gradually increases the probability of applying human-designed augmentations.
Previous studies[6, 23, 24, 13] have applied data augmentations throughout the whole model training process. However, at the early stage of training, the model cannot yet recognize even the original images. In this case, is data augmentation effective? In Figure 2, the test accuracy of a model trained on the Reduced CIFAR-10 dataset drops in the first $\sim 70$ epochs when human-designed data augmentations are applied. To address this problem, at the beginning of model training, we only apply augmentations to a randomly sampled subset of training images while keeping the rest as original. In the later training stages, we apply a monotonic curriculum that gradually increases the proportion of images to be augmented or the probability of applying augmentation. Specifically, the proportion/probability $p(t)$ increases with the number of epochs by following a schedule defined by $\tanh$, i.e.,
$$p(t)=\tanh(t/\tau) \tag{1}$$
where $t$ is the current training epoch number and $\tau$ is a manually adjustable hyperparameter that controls how quickly the proportion changes. The early-stage model is therefore trained mainly on the original images without augmentations, which helps the immature model converge quickly. As training proceeds and the model has learned the original images, its training benefits more from the augmented images. Compared with training without augmentation or with a fixed human-designed augmentation policy, our method improves model performance across training stages (see Figure 2).
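The schedule in Eq. (1) can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the helper names `augmentation_probability` and `maybe_augment`, the `augment` callable, and the value `tau = 40` are assumptions made for the example.

```python
import math
import random


def augmentation_probability(epoch: int, tau: float) -> float:
    """Eq. (1): p(t) = tanh(t / tau)."""
    return math.tanh(epoch / tau)


def maybe_augment(image, augment, epoch: int, tau: float, rng: random.Random):
    """Return augment(image) with probability p(epoch), else the original image."""
    if rng.random() < augmentation_probability(epoch, tau):
        return augment(image)
    return image


# tau = 40 is an illustrative value, not one taken from the paper.
schedule = [augmentation_probability(t, tau=40.0) for t in (0, 10, 40, 200)]
```

With this form, no image is augmented at epoch 0, and the probability approaches 1 once the epoch number is several times larger than $\tau$.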
3.2 What Augmentations to Apply: Model-Adaptive Data Augmentation
Instead of constantly applying the same data augmentation policies to all samples over the whole training process, adjusting the policy for each sample and model in different training stages can provide better guidance to the task model and thus accelerate its training towards better validation accuracy.
Following AdaAug[4], we assign an augmentation probability vector $p$ and a magnitude vector $\lambda$ to each sample. The vector $p$ contains the probability $p_i$ of applying each augmentation $i$, with $\sum_{i=1}^{n} p_i = 1$, where $n$ is the number of possible augmentation operations. The vector $\lambda$ contains the associated augmentation strengths, with $\lambda_i \in [0,1]$. During training, for every training image $x$, we draw $k$ operations without replacement according to $p$ and build an augmentation policy from them and their magnitudes in $\lambda$. In particular, each sampled augmentation operator $j$ is applied to the image $x$ with magnitude $\lambda_j$, resulting in an augmented image $\Gamma^{j}(x) \triangleq \tau_j(x;\lambda_j)$. By applying the $k$ sampled augmentations, the final augmented image $\gamma(x)$ can be written as:
$$\begin{aligned}&\Gamma^{t}(x)=\tau_{j}(x;\lambda_{j}),\quad j\sim p\\ &\gamma(x)=\Gamma^{k}\circ\cdots\circ\Gamma^{1}(x),\end{aligned} \tag{2}$$
where $\circ$ is the composition operator.
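As a rough illustration of Eq. (2), the sketch below draws $k$ distinct operations according to $p$ and composes them. The toy operations (`brighten`, `invert_blend`, `identity`), the flat-list image representation, and the sequential renormalized sampling scheme are all assumptions made for this example; the paper does not specify how sampling without replacement is implemented. The sampler assumes at least $k$ operations have non-zero probability.

```python
import random


def sample_operations(p, k, rng):
    """Draw k distinct operation indices; each draw is proportional to the
    probabilities of the operations not yet chosen."""
    remaining = list(range(len(p)))
    chosen = []
    for _ in range(k):
        weights = [p[j] for j in remaining]
        j = rng.choices(remaining, weights=weights, k=1)[0]
        chosen.append(j)
        remaining.remove(j)
    return chosen


def apply_policy(x, ops, p, lam, k, rng):
    """gamma(x): compose the k sampled operations, each with its magnitude."""
    for j in sample_operations(p, k, rng):
        x = ops[j](x, lam[j])
    return x


# Hypothetical toy operations on a flat list of pixel intensities in [0, 1].
def brighten(x, m):
    return [min(1.0, v + 0.5 * m) for v in x]


def invert_blend(x, m):
    return [(1.0 - m) * v + m * (1.0 - v) for v in x]


def identity(x, m):
    return list(x)


ops = [brighten, invert_blend, identity]
p = [0.5, 0.3, 0.2]      # augmentation probabilities, summing to 1
lam = [0.4, 1.0, 0.0]    # magnitudes in [0, 1]
augmented = apply_policy([0.1, 0.6], ops, p, lam, k=2, rng=random.Random(0))
```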
An arbitrary augmentation policy is not guaranteed to improve the performance of a task model, and a brute-force search is not practically feasible. Hence, we optimize a policy model that produces the augmentation probability vector $p$ and magnitude vector $\lambda$ for each image at different training stages. For an image $x$, we define $f(x;w)$ as the task model with parameters $w$ and $g_w(x)$ as the intermediate-layer representation of $x$ extracted from the task model $f(x;w)$. The policy model $p(\cdot;\theta)$ with parameters $\theta$ takes the extracted feature $g_w(x)$ as input and outputs the probability vector $p$ and magnitude vector $\lambda$ for the image $x$. The parameters $w$ of the task model are optimized by minimizing the following training loss on the training set $\mathcal{D}^{tr}=\{x_i,y_i\}_{i=1}^{N^{tr}}$:
$$w=\mathop{\arg\min}\limits_{w}\mathcal{L}^{tr}(w;\theta)=\frac{1}{N^{tr}}\sum_{i=1}^{N^{tr}}\mathcal{L}_{CE}(f(\gamma(x_{i});w),y_{i}), \tag{3}$$
where the augmented training image $\gamma(x_i)$ is generated by the policy network $p(g_w(x_i);\theta)$ and $\mathcal{L}_{CE}(\cdot,\cdot)$ is the cross-entropy loss. The policy model produces the augmentation policies applied in training the task model, and its optimization objective is to minimize the trained task model's loss on a validation set $\mathcal{D}^{val}=\{x_i^{val},y_i^{val}\}_{i=1}^{N^{val}}$. The above problem can be formulated as the bi-level optimization[5] below:
$$\begin{aligned}&\min_{\theta}\quad\mathcal{L}^{val}(w^{\ast}(\theta))=\frac{1}{N^{val}}\sum_{i=1}^{N^{val}}\mathcal{L}^{val}_{i}(w^{\ast}(\theta))\\ &s.t.\quad w^{\ast}(\theta)=\arg\min_{w}\mathcal{L}^{tr}(w;\theta)\end{aligned} \tag{4}$$
where $\mathcal{L}^{val}_{i}(w^{\ast}(\theta))=\mathcal{L}_{CE}(f(x_i^{val};w^{\ast}(\theta)),y_i^{val})$. Bi-level optimization is challenging because the lower-level optimization (i.e., the optimization of $w$) does not have a closed-form solution that can be substituted into the higher-level optimization (i.e., the optimization of $\theta$). Recent work[33, 34, 47, 14, 29] addresses problem (4) by alternating minimization. In this paper, we employ the same strategy as[34, 47, 1].
3.3 Joint Training of Task model & MADAug Policy
To address the bi-level optimization, we alternately update $\theta$ and $w$ by first optimizing the policy network $\theta$ for a task model $\hat{w}$ obtained by one-step training and then updating $w$ using the augmentations produced by the new policy network $\theta$.
We split the original training set into two disjoint sets, i.e., a training set and a validation set. Each iteration trains the model on a mini-batch of $n^{tr}$ images $\mathcal{D}^{tr}_{m_i}=\{x_i,y_i\}_{i=1}^{n^{tr}}$ drawn from the training set. Let $\mathcal{L}^{tr}(w_t;\theta_t)=\mathcal{L}_{CE}(f(\gamma(x_i);w_t),y_i)$ denote the lower-level objective for optimizing $w_t$. We apply one-step gradient descent on $w_t$ to obtain a closed-form surrogate of the lower-level problem's solution, i.e.,
$$\hat{w_{t}}=w_{t}-\alpha\frac{1}{n^{tr}}\sum_{i=1}^{n^{tr}}\nabla_{w}\mathcal{L}^{tr}(w_{t};\theta_{t}) \tag{5}$$
where $\alpha$ is a learning rate. However, we cannot use back-propagation to optimize $\theta_t$ for the higher-level optimization because the sampling process of the $k$ augmentation operations in $\gamma(x_i)$ is non-differentiable. Hence, back-propagation cannot compute the partial derivatives w.r.t. the augmentation probability $p$ and magnitude $\lambda$. To address this problem, we relax the non-differentiable $\gamma(x_i)$ to a differentiable operator. Since the augmentation policy in most previous work[6, 18] consists of only two operations, for $k=2$, $\gamma(x_i)$ can be relaxed as
$$\gamma(x_{i})\approx\sum_{j_{1}=1}^{n}\sum_{j_{2}=1}^{n}p_{ij_{1}}\cdot p_{ij_{2}}\,\Gamma_{ij_{2}}^{2}(\Gamma_{ij_{1}}^{1}(x_{i})),\quad j_{1}\neq j_{2} \tag{6}$$
where $\Gamma_{ij_k}^{t}(x_i)=\tau_{j_k}(x_i;\lambda_{j_k})$ applies augmentation $j_k$ (with magnitude $\lambda_{j_k}$) to $x_i$ in the $t$-th augmentation operation. The relaxed $\gamma(x_i)$ is differentiable because it combines different augmentations using their probabilities as weights, so we can estimate the partial derivatives w.r.t. $p$ via back-propagation through Eq. 6. In our approach, the forward pass still uses the sampling-based $\gamma(x_i)$, whereas the backward pass uses its differentiable relaxation in Eq. 6.
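To make the relaxation in Eq. 6 concrete, the following sketch evaluates the probability-weighted mixture over ordered pairs $j_1 \neq j_2$ and its partial derivative with respect to one probability entry. It is a toy under stated assumptions: images are flat lists of floats, the operations `shift` and `scale` are hypothetical, and the entries of $p$ are treated as independent variables when differentiating. It only shows the backward-pass surrogate, not the sampling-based forward pass.

```python
def pair_output(x, ops, lam, j1, j2):
    """Apply operation j1 first, then operation j2, each with its magnitude."""
    return ops[j2](ops[j1](x, lam[j1]), lam[j2])


def relaxed_augment(x, ops, p, lam):
    """Eq. (6): sum over ordered pairs j1 != j2 of p[j1] * p[j2] * pair output."""
    n = len(ops)
    out = [0.0] * len(x)
    for j1 in range(n):
        for j2 in range(n):
            if j1 == j2:
                continue
            weight = p[j1] * p[j2]
            pair = pair_output(x, ops, lam, j1, j2)
            out = [o + weight * v for o, v in zip(out, pair)]
    return out


def grad_wrt_probability(x, ops, p, lam, a):
    """Partial derivative of the relaxed output w.r.t. p[a], per pixel."""
    grad = [0.0] * len(x)
    for j in range(len(ops)):
        if j == a:
            continue
        a_first = pair_output(x, ops, lam, a, j)
        a_second = pair_output(x, ops, lam, j, a)
        grad = [g + p[j] * (u + v) for g, u, v in zip(grad, a_first, a_second)]
    return grad


# Hypothetical differentiable toy operations.
def shift(x, m):
    return [v + m for v in x]


def scale(x, m):
    return [v * (1.0 + m) for v in x]


mixture = relaxed_augment([1.0], [shift, scale], [0.5, 0.5], [0.5, 1.0])
dmix_dp0 = grad_wrt_probability([1.0], [shift, scale], [0.5, 0.5], [0.5, 1.0], a=0)
```

Because the pairs with $j_1 = j_2$ are excluded, the weights in Eq. 6 as written sum to $1-\sum_j p_j^2$ rather than 1, which the sketch reproduces without renormalizing.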
For back-propagation through the augmentation magnitude vector $\lambda$, we apply the straight-through gradient estimator[3, 39] because the magnitudes of some operations such as “Posterize” and “Solarize” are discrete variables with only finitely many choices. In previous approaches[4, 13], the loss's gradient w.r.t. $\lambda_m$ is estimated by applying the chain rule to each pixel value $\gamma(x_{h,w})$ in the augmented image $\gamma(x)$, with $\frac{\partial\gamma(x_{h,w})}{\partial\lambda_m}=1$. Hence, the gradient of the loss $\mathcal{L}$ w.r.t. $\lambda_m$ can be computed as:
$$\frac{\partial\mathcal{L}}{\partial\lambda_{m}}=\sum_{h,w}\frac{\partial\mathcal{L}}{\partial\gamma(x_{h,w})}\frac{\partial\gamma(x_{h,w})}{\partial\lambda_{m}}=\sum_{h,w}\frac{\partial\mathcal{L}}{\partial\gamma(x_{h,w})} \tag{7}$$
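A small sketch of why the straight-through estimator is needed and what Eq. (7) reduces to. The `posterize` function here is a simplified stand-in written for this example (it keeps the top bits of an 8-bit pixel), and `upstream` is an assumed grid of per-pixel loss gradients $\partial\mathcal{L}/\partial\gamma(x_{h,w})$.

```python
def posterize(value: int, bits: int) -> int:
    """Keep the top `bits` bits of an 8-bit pixel. The output is piecewise
    constant in `bits`, so its true derivative carries no learning signal."""
    drop = 8 - bits
    return (value >> drop) << drop


def straight_through_magnitude_grad(upstream):
    """Eq. (7) with d gamma(x_{h,w}) / d lambda_m set to 1 for every pixel:
    the magnitude gradient is the sum of the per-pixel upstream gradients."""
    return sum(sum(row) for row in upstream)


# Assumed per-pixel gradients of the loss w.r.t. a 2x2 augmented image.
upstream = [[0.5, -0.25],
            [1.0, 0.25]]
grad_lambda = straight_through_magnitude_grad(upstream)
```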
Then, the policy network parameters $\theta_{t}$ can be updated by minimizing the validation loss computed by the current meta task model $\hat{w_{t}}$ on a mini-batch of the validation set $\mathcal{D}^{val}=\{x_{i}^{val},y_{i}^{val}\}_{i=1}^{n^{val}}$ with batch size $n^{val}$. Therefore, the outer-loop update of $\theta_{t}$ is formulated as:
$$\theta_{t+1}=\theta_{t}-\beta\frac{1}{n^{val}}\sum_{i=1}^{n^{val}}\nabla_{\theta}\mathcal{L}^{val}_{i}(\hat{w_{t}}(\theta_{t}))\tag{8}$$
where $\beta$ is a learning rate. The third step is to update the parameter $w_{t}$ based on the parameter $\theta_{t+1}$ of the policy model in the outer loop of iteration $t+1$:
$$w_{t+1}=w_{t}-\alpha\frac{1}{n^{tr}}\sum_{i=1}^{n^{tr}}\nabla_{w}\mathcal{L}^{tr}(w_{t};\theta_{t+1})\tag{9}$$
With these update rules, the policy and task networks can be trained alternately. Our proposed algorithm is summarized in Algorithm 1.
Algorithm 1 Model-Adaptive Data Augmentation
1: Input: Training set $\mathcal{D}_{train}=\{x_{i},y_{i}\}_{i\in[N_{train}]}$; validation set $\mathcal{D}_{valid}=\{x_{i},y_{i}\}_{i\in[N_{valid}]}$; batch sizes $n^{tr},n^{val}$; learning rates $\alpha,\beta$; iteration number $T$
2: Output: Task model $w_{T}$; policy network $\theta_{T}$
3: Initialize $w_{0}$, $\theta_{0}$
4: for $t=0$ to $T$ do
5: Sample a training set mini-batch $d_{train}\in\mathcal{D}_{train}$.
6: Draw Augment $\sim P(t)$ in Eq. 1.
7: if Augment then
8: Apply policy network $\theta_{t}$ to obtain augmentations $\gamma(x)$ for each sample $x\in d_{train}$.
9: end if
10: Update $\hat{w_{t}}$ on the augmented $d_{train}$ (Eq. 5).
11: Sample a validation set mini-batch $d_{valid}\in\mathcal{D}_{valid}$.
12: Update policy network $\theta_{t+1}$ on $d_{valid}$ (Eq. 8).
13: Apply policy network $\theta_{t+1}$ to obtain new augmentations $\gamma(x)$ for each sample $x\in d_{train}$.
14: Update task model $w_{t+1}$ on the newly augmented $d_{train}$ (Eq. 9).
15: end for
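The alternating updates of Algorithm 1 can be sketched on a scalar toy problem. Everything here is an assumption made for illustration and not the authors' implementation: the task model is a single weight `w` predicting `w * x`, the "policy" is a single shift parameter `theta` with `augment(x, theta) = x + theta`, and the meta-gradient of Eq. 8 is approximated by central finite differences rather than back-propagation. The sketch also reuses the same Augment draw for both task-model updates in an iteration.

```python
import random


def augment(x, theta):
    # Hypothetical one-parameter "policy": shift the input by theta.
    return x + theta


def train_grad(w, theta, batch, use_aug):
    # Gradient w.r.t. w of the mean squared error on the (optionally augmented) batch.
    total = 0.0
    for x, y in batch:
        xa = augment(x, theta) if use_aug else x
        total += 2.0 * (w * xa - y) * xa
    return total / len(batch)


def val_loss(w, batch):
    # Mean squared error on clean validation samples.
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def madaug_step(w, theta, train_batch, val_batch, alpha, beta, use_aug, h=1e-4):
    def w_hat(th):
        # Eq. 5: one virtual task-model step, viewed as a function of the policy.
        return w - alpha * train_grad(w, th, train_batch, use_aug)

    # Eq. 8: gradient of the validation loss w.r.t. theta, approximated here
    # by central finite differences (an assumption of this sketch).
    meta_grad = (val_loss(w_hat(theta + h), val_batch)
                 - val_loss(w_hat(theta - h), val_batch)) / (2.0 * h)
    theta_new = theta - beta * meta_grad
    # Eq. 9: real task-model step under the updated policy.
    w_new = w - alpha * train_grad(w, theta_new, train_batch, use_aug)
    return w_new, theta_new


def train(train_set, val_set, steps, alpha, beta, augment_prob, batch_size=2, seed=0):
    rng = random.Random(seed)
    w, theta = 0.0, 0.0
    for t in range(steps):
        train_batch = rng.sample(train_set, min(batch_size, len(train_set)))
        val_batch = rng.sample(val_set, min(batch_size, len(val_set)))
        use_aug = rng.random() < augment_prob(t)  # line 6: Augment ~ P(t)
        w, theta = madaug_step(w, theta, train_batch, val_batch, alpha, beta, use_aug)
    return w, theta
```

When the Augment draw is false, the virtual step does not depend on `theta`, so the meta-gradient is zero and the policy parameter is left unchanged for that iteration.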
4 Experiments
Table 1: Test error (%, average of 5 random trials) on CIFAR-10, CIFAR-100, SVHN, and ImageNet. Lower is better. “Simple” applies regular random crop, random horizontal flip, and Cutout. All other methods apply “Simple” on top of their proposed augmentations. We report the results of our re-implemented AdaAug (marked AdaAug†), while other baselines’ results are adapted from Zheng et al.[45], Cheung et al.[4], Tang et al.[37], and Suzuki et al.[36]. The best performance is highlighted in bold.
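| Dataset | Backbone | Simple | AA | PBA | Fast AA | DADA | Faster AA | RA | TA | DeepAA | Teach | OnlineAug | AdaAug | AdaAug† | MADAug |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Reduced CIFAR-10 | Wide-ResNet-28-10 | 18.9 | 14.1 | 12.8 | 14.6 | 15.6 | - | 15.1 | - | - | - | 14.3 | 13.6 | 15.0 | 12.5 |
| Reduced CIFAR-10 | Shake-Shake (26 2x96d) | 17.1 | 10.1 | 10.6 | - | - | - | - | - | - | - | - | 10.9 | 11.8 | 10.0 |
| CIFAR-10 | Wide-ResNet-40-2 | 5.3 | 3.7 | 3.9 | 3.6 | 3.6 | 3.7 | 4.1 | - | - | - | - | 3.6 | - | 3.3 |
| CIFAR-10 | Wide-ResNet-28-10 | 3.9 | 2.6 | 2.6 | 2.7 | 2.7 | 2.6 | 2.7 | 2.5 | 2.4 | 2.5 | 2.4 | 2.6 | - | 2.1 |
| CIFAR-10 | Shake-Shake (26 2x96d) | 2.9 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.9 | 1.9 | 2.0 | - | - | - | 1.8 |
| CIFAR-10 | PyramidNet (ShakeDrop) | 2.7 | 1.5 | 1.5 | 1.8 | 1.7 | - | 1.5 | - | - | 1.5 | - | - | - | 1.4 |
| CIFAR-100 | Wide-ResNet-40-2 | 26.0 | 20.6 | 22.3 | 20.7 | 20.9 | 21.4 | - | 19.4 | - | - | - | 19.8 | - | 19.3 |
| CIFAR-100 | Wide-ResNet-28-10 | 18.8 | 17.1 | 16.7 | 17.3 | 17.5 | 17.3 | 16.7 | 16.5 | 16.1 | 16.8 | 16.6 | 17.1 | - | 16.1 |
| CIFAR-100 | Shake-Shake (26 2x96d) | 17.1 | 14.3 | 15.3 | 14.9 | 15.3 | 15.6 | - | - | 14.8 | 14.5 | - | - | - | 14.1 |
| CIFAR-100 | PyramidNet (ShakeDrop) | 14.0 | 10.7 | 10.9 | 11.9 | 11.2 | - | - | - | - | 11.8 | - | - | - | 10.5 |
| Reduced SVHN | Wide-ResNet-28-10 | 13.2 | 8.2 | 7.8 | 8.1 | 7.6 | - | 9.4 | - | - | - | 6.7 | 8.2 | 9.1 | 8.4 |
| Reduced SVHN | Shake-Shake (26 2x96d) | 13.3 | 5.9 | 6.5 | - | - | - | - | - | - | - | - | 6.4 | 6.9 | 6.3 |
| SVHN | Wide-ResNet-28-10 | 1.5 | 1.1 | 1.2 | 1.1 | 1.2 | 1.2 | 1.0 | - | - | - | - | - | - | 1.0 |
| SVHN | Shake-Shake (26 2x96d) | 1.4 | 1.1 | 1.1 | 1.1 | 1.1 | - | - | - | - | - | - | - | - | 1.0 |
| ImageNet | ResNet-50 | 23.7 | 22.4 | - | 22.4 | 22.5 | 23.5 | 22.4 | 22.1 | 21.7 | 22.2 | 22.5 | 22.8 | - | 21.5 |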
In this section, following AutoAugment[6], we examine the performance of MADAug in two experiments: MADAug-direct and MADAug-transfer. In the first experiment, we directly evaluate MADAug on the benchmark datasets: CIFAR-10[20], CIFAR-100[20], SVHN[31], and ImageNet[8]. For CIFAR-10, CIFAR-100, and SVHN, we equally select 1,000 images from the dataset as the validation set to train the policy model. In addition, for SVHN, we use both the training images and the additional “extra” training images as the training set. For ImageNet, the validation set consists of 1,200 examples from 120 randomly selected classes. We compare the average test set error of our method with previous state-of-the-art methods: AutoAugment (AA)[6], Population Based Augmentation (PBA)[18], Fast AutoAugment (Fast AA)[24], DADA[23], Faster AutoAugment (Faster AA)[13], RandAugment (RA)[7], TrivialAugment (TA)[30], Deep AutoAugment (Deep AA)[45], TeachAugment (Teach)[36], OnlineAug[37], and AdaAug[4].
Our experiment results demonstrate that MADAug-direct considerably improves the accuracy of baselines and achieves state-of-the-art performance on these benchmark datasets. In the second experiment, we investigate the transferability of MADAug-learned policy network to unseen fine-grained datasets. To verify its effectiveness, we apply the augmentation policies learned by MADAug on the CIFAR-100 dataset to fine-grained classification datasets such as Oxford 102 Flowers[32], Oxford-IIIT Pets[10], FGVC Aircraft[28], and Stanford Cars[19]. Our findings demonstrate the remarkable transferability of MADAug-learned policy, which significantly outperforms the robust baseline models on fine-grained classification datasets.
4.1 Augmentation Operations
We follow the augmentation operations used by AutoAugment[6]. We adopt the 16 augmentation operations (ShearX, ShearY, TranslateX, TranslateY, Rotate, AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Color, Brightness, Sharpness, and Cutout) previously suggested for building augmentation policies. In addition, we add the Identity operation, which leaves the image unchanged. For the Simple baseline, we employ random horizontal flip, color jittering, color normalization, and Cutout with a $16\times 16$ patch size as basic augmentations. The policies learned by MADAug and other baselines are applied on top of these basic augmentations.
Table 2: Transferability of the MADAug-learned policy network. Test set error (%) of fine-tuning a pretrained ResNet-50 using the augmentations produced by the policy network on downstream tasks. Baseline results are adapted from Cheung et al.[4].
| Dataset | # of classes | Train number | Simple | AA | Fast AA | RA | AdaAug | MADAug |
|---|---|---|---|---|---|---|---|---|
| Oxford 102 Flowers | 102 | 2,040 | 5.0 | 6.1 | 4.8 | 3.9 | 2.8 | 2.5 |
| Oxford-IIIT Pets | 37 | 3,680 | 19.5 | 18.8 | 23.0 | 16.8 | 16.1 | 15.3 |
| FGVC Aircraft | 100 | 6,667 | 18.4 | 16.6 | 17.0 | 17.4 | 16.0 | 15.4 |
| Stanford Cars | 196 | 8,144 | 11.9 | 9.2 | 10.7 | 10.3 | 8.8 | 8.3 |
4.2 Implementation Details
In our experiments, the policy network of MADAug is a fully-connected layer that takes the representations produced by the penultimate layer of the task model as its inputs and outputs $p$ and $\lambda$. Following AdaAug[4], the policy network parameters are updated with the Adam optimizer with a learning rate of 0.001. For CIFAR-10, CIFAR-100, and SVHN, we evaluate our method on four models: Wide-ResNet-40-2[42], Wide-ResNet-28-10[42], Shake-Shake (26 2x96d)[11], and PyramidNet with ShakeDrop[40, 12]. We train all models using a batch size of 128 except for PyramidNet with ShakeDrop, which is trained with a batch size of 64. We train the Wide-ResNet models for 200 epochs and Shake-Shake/PyramidNet for 1,800 epochs. For Wide-ResNet models trained on SVHN, we follow PBA[18] and use the step learning rate schedule[9]; all other models use a cosine learning rate scheduler with one annealing cycle[27]. To align our results with other baselines, we train ResNet-50[15] from scratch on the full ImageNet using the ImageNet hyperparameters of AutoAugment[6]. For all models, we use gradient clipping with magnitude 5. We provide specific details about the learning rate and weight decay values in the supplementary materials.
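To make the shape of such a policy head concrete, here is a minimal sketch of a single fully-connected layer that maps penultimate-layer features to operation probabilities $p$ and magnitudes $\lambda$. The softmax over operations and the sigmoid over magnitudes are assumptions of this sketch (the text above only states that one layer outputs $p$ and $\lambda$), and all names and dimensions are hypothetical.

```python
import math


def linear(features, weights, biases):
    # One fully-connected layer: `weights` holds one row per output unit.
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, biases)]


def policy_head(features, w_p, b_p, w_lam, b_lam):
    # Operation probabilities p: softmax over one logit per operation (assumed).
    logits = linear(features, w_p, b_p)
    shift = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    p = [e / total for e in exps]
    # Magnitudes lambda: sigmoid squashes each output into (0, 1) (assumed).
    lam = [1.0 / (1.0 + math.exp(-z)) for z in linear(features, w_lam, b_lam)]
    return p, lam


# Toy example: 2-dimensional features, 3 candidate operations, zero-initialized layer.
features = [0.2, -0.1]
zeros_w = [[0.0, 0.0] for _ in range(3)]
zeros_b = [0.0, 0.0, 0.0]
p, lam = policy_head(features, zeros_w, zeros_b, zeros_w, zeros_b)
```

With zero weights, every operation is equally likely and every magnitude sits at the midpoint of its range, which is a neutral starting point for the policy.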
4.3 Main Results
Table 1 shows that the policies learned through bi-level optimization achieve better performance than most baselines for different models on Reduced CIFAR-10, CIFAR-10, CIFAR-100, Reduced SVHN, SVHN, and ImageNet. The Reduced CIFAR-10 (SVHN) dataset randomly selects 4,000 (1,000) images from CIFAR-10 (SVHN) as the training set and uses the remaining images as the validation set. On Reduced CIFAR-10, MADAug achieves state-of-the-art performance. On the Reduced SVHN dataset, compared to AdaAug, we achieve 0.7% and 0.6% improvement on Wide-ResNet-28-10 and Shake-Shake (26 2x96d), respectively. On ImageNet, a large and complex dataset, our method performs best among the compared baselines. Unlike prior work (AutoAugment, PBA, and Fast AutoAugment), which constructs fixed augmentation policies for the entire dataset, our method finds dynamic and model-adaptive policies for each image, which enhances the model’s generalization. We provide the average and variance of the experiment results in Section 4.7.
4.4 Transferability of MADAug-Learned Policy
Following AdaAug[4], we apply the augmentation policies learned from CIFAR-100 directly to the fine-grained datasets (MADAug-transfer). To evaluate the transferability of the policies found on CIFAR-100, we compare the test error rate with AutoAugment (AA)[6], Fast AutoAugment (Fast AA)[24], RandAugment (RA)[7], and AdaAug[4] using their published policies on CIFAR-100. For all the fine-grained datasets, we compare the transfer results by training the ResNet-50 model[15] pretrained on ImageNet. Following the experiment setting of AdaAug, we use cosine learning rate decay with one annealing cycle[27] and train the model for 100 epochs. We adjust the learning rate for each fine-grained dataset according to the validation performance. The weight decay is set to 1e-4 and the gradient clipping parameter is 5.
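The two optimization details mentioned above can be sketched in a few lines. This is only an illustration under stated assumptions: the single-cycle cosine decay follows the standard half-cosine form without warm restarts, and "gradient clipping parameter 5" is interpreted as clipping the global L2 norm of the gradient; the base learning rate of 0.1 used in the example is a placeholder, since the text says it is tuned per dataset.

```python
import math


def cosine_lr(base_lr, epoch, total_epochs):
    # Single annealing cycle: decay from base_lr at epoch 0 to 0 at total_epochs.
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


def clip_by_norm(grads, max_norm=5.0):
    # Rescale the gradient vector if its L2 norm exceeds max_norm (assumed interpretation).
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


# Placeholder base learning rate; 100 epochs as in the transfer setting above.
schedule = [cosine_lr(0.1, epoch, 100) for epoch in range(101)]
clipped = clip_by_norm([6.0, 8.0])  # L2 norm 10 is rescaled to norm 5
```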
Table 2 shows that our method outperforms the other baselines when training the pretrained ResNet-50 model on these fine-grained datasets. Previous methods (AutoAugment[6], Fast AutoAugment[24], and RandAugment[7]) apply fixed augmentation policies. This strategy does not help distinguish fine-grained samples from one another, so it is hard for the model to recognize their differences. In contrast, AdaAug and MADAug adapt the augmentation policies from the dataset level to the instance level. Because our method gradually augments the images and provides better augmentation policies for unseen fine-grained images according to their relationship to the classes learned on the CIFAR-100 dataset, the model achieves better performance than AdaAug, especially on the Pet dataset. The Pet dataset only contains “Cat” and “Dog” images. In Figure 3, we also show that MADAug significantly improves the model’s ability to recognize the “Cat” and “Dog” classes on the Reduced CIFAR-10 dataset.
Figure 3: Improvement that MADAug and AdaAug bring to different classes. MADAug consistently improves the test accuracy over all classes and brings greater improvements to more difficult classes (fairness), e.g., “Bird”, “Cat”, “Deer”, and “Dog”. In contrast, AdaAug has a negative impact on “Airplane” and “Automobile”.
4.5 Analysis of MADAug Augmentations
Figure 4: Augmentations of AdaAug and MADAug for different classes of images (operations and associated strengths). AdaAug only produces specific augmentations for different images, while MADAug adjusts the augmentations for each image to be adaptive to different training epochs. MADAug introduces less distortion than AdaAug.
We compare the per-class accuracy of MADAug with that of AdaAug on the Reduced CIFAR-10 dataset, which only has 4,000 training images. Figure 3 shows that the per-class accuracy of a model trained with MADAug is higher than that of models trained with AdaAug and the basic baseline, especially for the “Bird”, “Deer”, “Cat”, and “Dog” classes. Moreover, compared with the basic baseline, the augmentation policies trained by AdaAug have a negative impact on the “Airplane” and “Automobile” classes.
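For reference, per-class accuracy of the kind compared here can be computed as in the following minimal sketch. The labels and predictions are toy placeholders, not results from the paper.

```python
from collections import defaultdict


def per_class_accuracy(labels, predictions):
    # Fraction of correctly classified samples, computed separately for each true class.
    correct = defaultdict(int)
    total = defaultdict(int)
    for y, y_hat in zip(labels, predictions):
        total[y] += 1
        correct[y] += int(y == y_hat)
    return {cls: correct[cls] / total[cls] for cls in total}


# Toy placeholder data: one "cat" is misclassified as "dog".
labels = ["cat", "cat", "dog", "dog"]
predictions = ["cat", "dog", "dog", "dog"]
accuracy = per_class_accuracy(labels, predictions)
```

Comparing these per-class values between two training runs gives the per-class improvement that a figure such as Figure 3 reports.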
Figure 5: Similarity between the original images and MADAug-augmented images at different training epochs. MADAug starts from less perturbed images but generates more challenging augmentations in later training.
In Figure 4, we display some samples augmented by AdaAug and MADAug, randomly selected from the “Bird”, “Cat”, “Deer”, “Dog”, “Airplane”, and “Automobile” classes. For AdaAug, some augmented images lose semantic information because operations such as TranslateY, applied with an unreasonable magnitude, destroy the main content of the image. For example, in Figure 4, the selected “Cat” augmented image loses its face and retains only its legs, and the “Dog” augmented image also discards part of its key information. We think that these unreasonable data augmentation strategies lead to an imbalance across categories in the number of samples that contain sufficient information about their true label, although the original dataset is balanced[2]. This phenomenon reduces the classification accuracy of some categories. In contrast, Figure 5 shows that the augmentation policies generated by MADAug produce more “hard” samples for the model as training proceeds. This strategy improves model generalization because, in the early phase, “simple” samples help the model converge quickly, and once the model is able to recognize the original samples, “hard” samples help it learn more robust features. Figure 4 shows augmented images at different training phases: the model gradually receives more adversarial augmented images. These augmentation policies learned by MADAug increase the diversity of the training data and highlight the key information of the data.
The same analysis of the Reduced SVHN dataset is presented in Appendix A. Through analyses of the Reduced SVHN and Reduced CIFAR-10 datasets, we illustrate from an experimental perspective that MADAug consistently provides higher-quality data augmentation strategies for samples, leading to larger accuracy improvements of the task model across different categories than AdaAug. From a methodological perspective, we also provide a detailed account of the advantages of MADAug over AdaAug in Appendix A.
4.6 Computing cost of MADAug
To demonstrate the efficiency of MADAug, we compare the GPU hours needed to search the augmentation policy and train the task model across different baselines. The results are shown in Table 3. The search time of our method is regarded as the time to optimize Eq. 4. Thus, we do not need extra time to find data augmentation policies. Our approach is more efficient than AutoAugment[6], PBA[18], and AdaAug[4].
Table 3: Time consumption. Comparison of computing cost (GPU hours) in training Wide-ResNet-28-10 on Reduced CIFAR-10 datasets between AutoAugment, PBA, AdaAug, and MADAug.
| Method | Searching | Training |
|---|---|---|
| AutoAugment | 5000 | |
| PBA | 5 | |
| AdaAug | 2.9 | |
| MADAug | $\sim$0 | |
4.7 Mean and variance of the experiment results
Table 4 reports the average values and variances of the experimental results obtained from multiple trials on different benchmark datasets.
Table 4: Mean and variance of experiment results. Test error and variance (%) of MADAug on different benchmark datasets with Wide-ResNet-28-10 and ResNet-50.
| Dataset | Test error (%) |
|---|---|
| Reduced CIFAR-10 | 12.5 ± 0.05 |
| CIFAR-10 | 2.1 ± 0.11 |
| CIFAR-100 | 16.1 ± 0.10 |
| Reduced SVHN | 8.4 ± 0.09 |
| SVHN | 1.0 ± 0.10 |
| ImageNet | 21.5 ± 0.15 |
5 Ablation study
Magnitude perturbation.
Following AdaAug[4], we also add a magnitude perturbation $\delta$ to the augmentation policy. As Table 5 shows, the model performs best on the test dataset when the magnitude perturbation is set to 0.3. We conclude that the magnitude perturbation has a positive effect on the generalization of the model.
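One plausible way to realize such a perturbation is sketched below: noise drawn uniformly from $[-\delta, \delta]$ is added to a magnitude, and the result is clipped back to the valid range. The uniform distribution, the $[0, 1]$ magnitude range, and the function name are assumptions of this sketch rather than details confirmed by the text.

```python
import random


def perturb_magnitude(lam, delta, rng):
    # Add uniform noise in [-delta, delta], then clip back to the assumed range [0, 1].
    noisy = lam + rng.uniform(-delta, delta)
    return min(1.0, max(0.0, noisy))


rng = random.Random(0)
# delta = 0.3 is the best-performing setting reported in Table 5.
samples = [perturb_magnitude(0.9, 0.3, rng) for _ in range(100)]
```

Setting `delta` to 0 recovers the unperturbed magnitude, which corresponds to the first column of the $\delta$ ablation.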
Number of augmentation operations.
The number of operations $k$ is arbitrary. Table 5 shows the relationship between the number of operations and the final test accuracy on Reduced CIFAR-10 with the Wide-ResNet-28-10 model. The number of operations ranges from 1 to 5. When the number of augmentation operations is set to 2, we obtain the lowest error rate on the dataset. Policies learned by other methods (AutoAugment[6], PBA[18], and AdaAug[4]) also consist of two augmentation operations. This indicates that two-operation policies increase the diversity and amount of images without distorting them so much that the task model can no longer recognize them.
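The role of $k$ can be illustrated with a small sketch that samples $k$ operations according to the probabilities $p$ and applies them in sequence. The three operations below are heavily simplified stand-ins acting on a list of pixel intensities, and sampling with replacement is an assumption; neither is taken from the MADAug implementation.

```python
import random

# Hypothetical, simplified operations on a list of pixel intensities in [0, 1].
OPERATIONS = {
    "Identity": lambda img, mag: list(img),
    "Invert": lambda img, mag: [1.0 - v for v in img],
    "Brightness": lambda img, mag: [min(1.0, v * (1.0 + mag)) for v in img],
}


def sample_operations(names, probs, k, rng):
    # Draw k operation names according to p (with replacement, an assumption).
    return rng.choices(names, weights=probs, k=k)


def apply_policy(img, names, probs, mags, k, rng):
    # Apply the k sampled operations one after another with their magnitudes.
    chosen = sample_operations(names, probs, k, rng)
    for name in chosen:
        img = OPERATIONS[name](img, mags[name])
    return img, chosen


names = ["Identity", "Invert", "Brightness"]
mags = {"Identity": 0.0, "Invert": 0.0, "Brightness": 0.5}
# k = 2 is the best-performing setting reported in Table 5.
augmented, chosen = apply_policy([0.25, 0.75], names, [0.2, 0.3, 0.5], mags, 2, random.Random(0))
```

Larger values of `k` compose more operations per image, which matches the observation above that too many operations can distort images beyond recognition.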
Structure of policy network.
Does the use of a nonlinear projection deliver better performance? To answer this, we extend the policy model with multiple hidden layers and ReLU activations. Table 5 shows the influence of the number $h$ of hidden layers on model performance. A single linear layer is sufficient for the policy model, without adding extra hidden layers.
Hyperparameter $\tau$.
The hyperparameter $\tau$ controls the relationship between the epoch and the number of augmented samples. As shown in Table 5, the performance of the task model is quite robust to the hyperparameter $\tau\in\{10,20,30,40,50\}$. For Reduced CIFAR-10, $\tau$ is optimally set to 40.
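To give a feel for how $\tau$ shapes the curriculum, the sketch below assumes a tanh-shaped schedule $P(t)=\tanh(t/\tau)$ for the probability of augmenting a sample at epoch $t$. The exact schedule is defined by Eq. 1 in Section 3.1, which is not reproduced here, so this functional form and the function names should be read as assumptions.

```python
import math


def augment_probability(epoch, tau):
    # Assumed monotonic schedule: 0 at epoch 0, approaching 1 as training proceeds.
    return math.tanh(epoch / tau)


def expected_augmented(batch_size, epoch, tau):
    # Expected number of augmented samples in a batch at a given epoch.
    return batch_size * augment_probability(epoch, tau)


# Compare the candidate values of tau from Table 5 at epoch 20 with a batch of 128.
for tau in (10, 20, 30, 40, 50):
    print(tau, round(expected_augmented(128, 20, tau), 1))
```

Under this assumed form, a smaller $\tau$ ramps up augmentation earlier, while a larger $\tau$ keeps more samples unaugmented for longer.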
Analysis of optimization steps.
Table 5 illustrates the impact of optimizing the data augmentation strategies with different numbers of steps $s$ on the experiment results for Reduced CIFAR-10. The task model achieves its highest accuracy when $s$ is set to 1.
Table 5: Ablation study. Sensitivity analysis of the hyperparameters $\delta$, $k$, $h$, $\tau$, and $s$ on Reduced CIFAR-10 (Wide-ResNet-28-10).
| $\delta$ | 0 | 0.1 | 0.2 | 0.3 | 0.4 |
|---|---|---|---|---|---|
| ACC (%) | 86.7 | 87.0 | 87.3 | 87.5 | 87.2 |
| $k$ | 1 | 2 | 3 | 4 | 5 |
| ACC (%) | 86.6 | 87.5 | 87.3 | 86.8 | 86.2 |
| $h$ | 0 | 1 | 2 | 3 | 4 |
| ACC (%) | 87.5 | 87.2 | 87.1 | 86.9 | 86.7 |
| $\tau$ | 10 | 20 | 30 | 40 | 50 |
| ACC (%) | 87.0 | 87.3 | 87.4 | 87.5 | 87.3 |
| $s$ | 1 | 2 | 5 | 10 | 30 |
| ACC (%) | 87.5 | 87.0 | 86.6 | 86.3 | 85.9 |
Effect of monotonic curriculum.
We investigate the effect of the monotonic curriculum introduced in Section 3.1. We train Wide-ResNet-28-10 on the Reduced CIFAR-10 and Reduced SVHN datasets with and without this technique across different baselines. The results are shown in Table 6. For MADAug, the monotonic curriculum improves accuracy on both datasets. For other baselines, whether AutoAugment[6], which applies the same data augmentation policy to the entire dataset, or AdaAug, which offers different data augmentation policies to different samples, the monotonic curriculum is also found to be effective.
Table 6: Effect of monotonic curriculum. Test error (%) of MADAug and other baselines without/with monotonic curriculum.
| Method | Monotonic curriculum | Reduced CIFAR-10 | Reduced SVHN |
|---|---|---|---|
| AA | | 14.1 | 8.2 |
| AA | ✓ | 13.7 | 7.8 |
| AdaAug | | 15.0 | 9.1 |
| AdaAug | ✓ | 14.4 | 8.7 |
| MADAug | | 13.1 | 8.9 |
| MADAug | ✓ | 12.5 | 8.4 |
Strategy of MADAug.
MADAug not only dynamically adjusts the augmentation strategies to minimize the loss of the task model on the validation set, which we call the model-adaptive strategy, but also provides a different data augmentation policy for each sample, which we call the data-adaptive strategy. To verify the effectiveness of MADAug, we use only one of these two strategies to find the augmentation policies and train the task model. Table 7 shows that MADAug combines the two strategies well and provides higher-quality data augmentation policies for the dataset.
Table 7: Effect of model/data-adaptive augmentation strategy. Test error (%) of model-adaptive/data-adaptive only MADAug on two datasets.
| Model-adaptive | Data-adaptive | Reduced CIFAR-10 | Reduced SVHN |
|---|---|---|---|
| ✓ | | 14.5 | 9.1 |
| | ✓ | 14.0 | 9.6 |
| ✓ | ✓ | 12.5 | 8.3 |
6 Conclusion
In this paper, we propose a novel and general data augmentation method, MADAug, which produces instance-adaptive augmentations that adapt to different training stages. Compared to previous methods, MADAug features a monotonic curriculum that progressively increases the amount of augmented data and a policy network that generates augmentations optimized to minimize the validation loss of a task model. MADAug achieves SOTA performance on several benchmark datasets, and its learned augmentation policy network is transferable to unseen tasks and brings more improvement than other augmentations. We show that MADAug augmentations preserve the key information of images and change with the task model across training stages. Due to its data-and-model-adaptive property, MADAug has great potential to improve a rich class of machine learning tasks in different domains.
References
- [1] Antreas Antoniou, Harrison Edwards, and Amos Storkey. How to train your maml. arXiv preprint arXiv:1810.09502, 2018.
- [2] Randall Balestriero, Leon Bottou, and Yann LeCun. The effects of regularization and data augmentation are class dependent. arXiv preprint arXiv:2204.03632, 2022.
- [3] Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
- [4] Tsz-Him Cheung and Dit-Yan Yeung. Adaaug: Learning class-and instance-adaptive data augmentation policies. In International Conference on Learning Representations, 2021.
- [5] Benoît Colson, Patrice Marcotte, and Gilles Savard. An overview of bilevel optimization. Annals of operations research, 153:235–256, 2007.
- [6] Ekin D Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V Le. Autoaugment: Learning augmentation policies from data. arXiv preprint arXiv:1805.09501, 2018.
- [7] Ekin D Cubuk, Barret Zoph, Jonathon Shlens, and Quoc V Le. Randaugment: Practical automated data augmentation with a reduced search space. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops, pages 702–703, 2020.
- [8] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. Ieee, 2009.
- [9] Terrance DeVries and Graham W Taylor. Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552, 2017.
- [10] Yan Em, Feng Gag, Yihang Lou, Shiqi Wang, Tiejun Huang, and Ling-Yu Duan. Incorporating intra-class variance to fine-grained visual recognition. In 2017 IEEE International Conference on Multimedia and Expo (ICME), pages 1452–1457. IEEE, 2017.
- [11] Xavier Gastaldi. Shake-shake regularization. arXiv preprint arXiv:1705.07485, 2017.
- [12] Dongyoon Han, Jiwhan Kim, and Junmo Kim. Deep pyramidal residual networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 5927–5935, 2017.
- [13] Ryuichiro Hataya, Jan Zdenek, Kazuki Yoshizoe, and Hideki Nakayama. Faster autoaugment: Learning augmentation strategies using backpropagation. In European Conference on Computer Vision, pages 1–16. Springer, 2020.
- [14] Ryuichiro Hataya, Jan Zdenek, Kazuki Yoshizoe, and Hideki Nakayama. Meta approach to data augmentation optimization. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 2574–2583, 2022.
- [15] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
- [16] Dan Hendrycks, Norman Mu, Ekin D Cubuk, Barret Zoph, Justin Gilmer, and Balaji Lakshminarayanan. Augmix: A simple data processing method to improve robustness and uncertainty. arXiv preprint arXiv:1912.02781, 2019.
- [17] Alex Hernández-García and Peter König. Data augmentation instead of explicit regularization. arXiv preprint arXiv:1806.03852, 2018.
- [18] Daniel Ho, Eric Liang, Xi Chen, Ion Stoica, and Pieter Abbeel. Population based augmentation: Efficient learning of augmentation policy schedules. In International Conference on Machine Learning, pages 2731–2741. PMLR, 2019.
- [19] Jonathan Krause, Jia Deng, Michael Stark, and Li Fei-Fei. Collecting a large-scale dataset of fine-grained cars. 2013.
- [20] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
- [21] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. Communications of the ACM, 60(6):84–90, 2017.
- [22] Ruihui Li, Xianzhi Li, Pheng-Ann Heng, and Chi-Wing Fu. Pointaugment: an auto-augmentation framework for point cloud classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 6378–6387, 2020.
- [23] Yonggang Li, Guosheng Hu, Yongtao Wang, Timothy Hospedales, Neil M Robertson, and Yongxin Yang. Dada: Differentiable automatic data augmentation. arXiv preprint arXiv:2003.03780, 2020.
- [24] Sungbin Lim, Ildoo Kim, Taesup Kim, Chiheon Kim, and Sungwoong Kim. Fast autoaugment. Advances in Neural Information Processing Systems, 32, 2019.
- [25] Chen Lin, Minghao Guo, Chuming Li, Xin Yuan, Wei Wu, Junjie Yan, Dahua Lin, and Wanli Ouyang. Online hyper-parameter learning for auto-augmentation strategy. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 6579–6588, 2019.
- [26] Tom Ching LingChen, Ava Khonsari, Amirreza Lashkari, Mina Rafi Nazari, Jaspreet Singh Sambee, and Mario A Nascimento. Uniformaugment: A search-free probabilistic data augmentation approach. arXiv preprint arXiv:2003.14348, 2020.
- [27] Ilya Loshchilov and Frank Hutter. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
- [28] Subhransu Maji, Esa Rahtu, Juho Kannala, Matthew Blaschko, and Andrea Vedaldi. Fine-grained visual classification of aircraft. arXiv preprint arXiv:1306.5151, 2013.
- [29] Saypraseuth Mounsaveng, Issam Laradji, Ismail Ben Ayed, David Vazquez, and Marco Pedersoli. Learning data augmentation with online bilevel optimization for image classification. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 1691–1700, 2021.
- [30] Samuel G Müller and Frank Hutter. Trivialaugment: Tuning-free yet state-of-the-art data augmentation. In Proceedings of the IEEE/CVF international conference on computer vision, pages 774–782, 2021.
- [31] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y Ng. Reading digits in natural images with unsupervised feature learning. 2011.
- [32] Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, pages 722–729. IEEE, 2008.
- [33] Mengye Ren, Wenyuan Zeng, Bin Yang, and Raquel Urtasun. Learning to reweight examples for robust deep learning. In International conference on machine learning, pages 4334–4343. PMLR, 2018.
- [34] Jun Shu, Qi Xie, Lixuan Yi, Qian Zhao, Sanping Zhou, Zongben Xu, and Deyu Meng. Meta-weight-net: Learning an explicit mapping for sample weighting. Advances in neural information processing systems, 32, 2019.
- [35] Rupesh K Srivastava, Klaus Greff, and Jürgen Schmidhuber. Training very deep networks. Advances in neural information processing systems, 28, 2015.
- [36] Teppei Suzuki. Teachaugment: Data augmentation optimization using teacher knowledge. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10904–10914, 2022.
- [37] Zhiqiang Tang, Yunhe Gao, Leonid Karlinsky, Prasanna Sattigeri, Rogerio Feris, and Dimitris Metaxas. Onlineaugment: Online data augmentation with less domain knowledge. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part VII 16, pages 313–329. Springer, 2020.
- [38] Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada. Between-class learning for image classification. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 5486–5494, 2018.
- [39] Aaron Van Den Oord, Oriol Vinyals, et al. Neural discrete representation learning. Advances in neural information processing systems, 30, 2017.
- [40] Yoshihiro Yamada, Masakazu Iwamura, Takuya Akiba, and Koichi Kise. Shakedrop regularization for deep residual learning. IEEE Access, 7:186126–186136, 2019.
- [41] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. Cutmix: Regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF international conference on computer vision, pages 6023–6032, 2019.
- [42] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
- [43] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412, 2017.
- [44] Xinyu Zhang, Qiang Wang, Jian Zhang, and Zhao Zhong. Adversarial autoaugment. arXiv preprint arXiv:1912.11188, 2019.
- [45] Yu Zheng, Zhi Zhang, Shen Yan, and Mi Zhang. Deep autoaugment. arXiv preprint arXiv:2203.06172, 2022.
- [46] Zhun Zhong, Liang Zheng, Guoliang Kang, Shaozi Li, and Yi Yang. Random erasing data augmentation. In Proceedings of the AAAI conference on artificial intelligence, volume 34, pages 13001–13008, 2020.
- [47] Fengwei Zhou, Jiawei Li, Chuanlong Xie, Fei Chen, Lanqing Hong, Rui Sun, and Zhenguo Li. Metaaugment: Sample-aware data augmentation policy learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pages 11097–11105, 2021.
A Appendix
Quality of augmentation policies on Reduced SVHN.
On the Reduced SVHN dataset, the data augmentations learned by MADAug improve per-class accuracy more than those learned by AdaAug, especially for the “4”, “7”, “8”, and “9” classes (Figure 6).
Figure 6: Improvements brought by MADAug and AdaAug to different classes on the Reduced SVHN dataset. Compared with AdaAug, MADAug improves test accuracy across the classes, with a more notable gain on some classes such as “4”, “7”, “8”, and “9”.
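The per-class comparison in Figure 6 can be reproduced from model predictions with a few lines of code. The following is a minimal sketch, not the authors' evaluation script: the function names `per_class_accuracy` and `per_class_improvement` are hypothetical, and the labels in the usage note are made-up toy values rather than results from the paper.

```python
from collections import Counter


def per_class_accuracy(labels, predictions):
    """Return {class: accuracy} for every class that appears in `labels`."""
    total = Counter(labels)
    correct = Counter(y for y, p in zip(labels, predictions) if y == p)
    # Counter returns 0 for classes with no correct prediction.
    return {c: correct[c] / total[c] for c in total}


def per_class_improvement(labels, preds_baseline, preds_method):
    """Return {class: accuracy gain} of a method over a baseline."""
    base = per_class_accuracy(labels, preds_baseline)
    new = per_class_accuracy(labels, preds_method)
    return {c: new[c] - base[c] for c in base}
```

For example, calling `per_class_improvement([4, 4, 7, 7], [4, 1, 7, 1], [4, 4, 7, 1])` on these toy inputs reports a gain for class 4 and no change for class 7.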
Figure 7: Similarity between the original images and MADAug-augmented images at various training phases. As training progresses, MADAug gradually generates more “challenging” augmentation policies for images.
AdaAug [4] applies its augmentation policies to samples throughout the entire training stage. These policies make the samples differ from the original images, which can hinder model convergence during the early stages of training. In contrast, MADAug introduces data augmentations gradually. Figure 7 shows that as training progresses, MADAug provides the task model with more “adversarial” images. Figure 8 illustrates that during the early training phase the model receives the original images, while later in training MADAug learns and applies more “challenging” augmentation policies. These policies nevertheless avoid destroying the intrinsic meaning of the images and instead emphasize the crucial information within them.
Figure 8: Augmentations learned by AdaAug and MADAug applied to “4”, “7”, “8”, and “9” class images at various training epochs on the Reduced SVHN dataset. Each augmentation operation is labeled with its name and magnitude.
Advantages of MADAug over AdaAug.
AdaAug employs a two-stage training process. In the first stage, it learns data augmentation strategies for individual samples through alternating “exploration” and “exploitation” steps. In the second stage, the learned strategies are fixed while the task model is trained. This design has two limitations: the learned augmentation strategies cannot be updated according to the performance of the current task model, and the two-stage training process is itself suboptimal.
MADAug addresses both limitations. It dynamically optimizes the data augmentation strategy for each sample by leveraging the current model’s performance on the validation set, so the model is trained with the augmentations that are most effective for the given training stage. The task model is trained end-to-end, unlike AdaAug and AutoAugment, which use a two-stage training methodology.
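To make the idea of optimizing the policy from validation performance concrete, here is a deliberately tiny sketch of a bi-level update. It is not the MADAug implementation: the task model is a single scalar weight `w`, the "policy" is a single scalar shift `theta` applied to training inputs, the inner problem is approximated by one look-ahead gradient step, and the hypergradient is estimated by central finite differences instead of backpropagation. All function names are hypothetical.

```python
def train_loss_grad(w, theta, train):
    """Gradient w.r.t. w of the mean squared error on inputs shifted by theta."""
    return sum(2 * (w * (x + theta) - y) * (x + theta) for x, y in train) / len(train)


def val_loss(w, val):
    """Mean squared error of the task model on clean validation data."""
    return sum((w * x - y) ** 2 for x, y in val) / len(val)


def lookahead(w, theta, train, lr_w):
    """One gradient step of the task model on theta-augmented training data."""
    return w - lr_w * train_loss_grad(w, theta, train)


def hypergradient(w, theta, train, val, lr_w, eps=1e-4):
    """Central-difference estimate of d val_loss(lookahead(w, theta)) / d theta."""
    up = val_loss(lookahead(w, theta + eps, train, lr_w), val)
    down = val_loss(lookahead(w, theta - eps, train, lr_w), val)
    return (up - down) / (2 * eps)


def joint_step(w, theta, train, val, lr_w=0.05, lr_theta=0.05):
    """Update the policy from validation loss, then update the task model."""
    theta = theta - lr_theta * hypergradient(w, theta, train, val, lr_w)
    w = lookahead(w, theta, train, lr_w)
    return w, theta
```

Because the policy update and the task-model update alternate inside one loop, the policy always reflects the current state of the task model, which is the property that distinguishes joint training from a two-stage scheme.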
Additionally, we find that data augmentations do not noticeably improve model performance in the early stages of training. We therefore use the monotonic curriculum strategy, which gradually applies data augmentations to each sample as training progresses and thereby enhances the robustness of the task model.
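A monotonic curriculum of this kind can be sketched as a per-sample coin flip whose success probability grows with the epoch. The sketch below assumes a tanh-shaped schedule controlled by a temperature τ; the schedule actually used is the one defined in Section 3.1, and the names `augmentation_probability` and `maybe_augment` are hypothetical.

```python
import math
import random


def augmentation_probability(epoch, tau):
    """Probability of augmenting a sample; 0 at epoch 0 and approaching 1 later.

    The tanh form is an assumption made for this sketch.
    """
    return math.tanh(epoch / tau)


def maybe_augment(image, augment, epoch, tau, rng=random):
    """Apply `augment` with the curriculum probability, else return the original."""
    if rng.random() < augmentation_probability(epoch, tau):
        return augment(image)
    return image
```

With this schedule the model sees only original images at epoch 0, and a smaller `tau` makes augmentations appear earlier in training.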
Together, the dynamic optimization of data augmentation strategies and the monotonic curriculum strategy account for MADAug’s significant improvements over AdaAug.
Model hyperparameters.
Table 8 lists the important hyperparameters for the different benchmark datasets. Other details on the hyperparameters and implementation are available in the open-source code.
Table 8: Hyperparameters on benchmark datasets. We do not specifically tune these hyperparameters, and all of these are consistent with PBA and AutoAugment.
| Dataset | Model | Learning Rate | Weight Decay | Batch Size | Epochs |
|---|---|---|---|---|---|
| CIFAR-10 | Wide-ResNet-40-2 | 0.1 | 5e-4 | 128 | 200 |
| CIFAR-10 | Wide-ResNet-28-10 | 0.1 | 5e-4 | 128 | 200 |
| CIFAR-10 | Shake-Shake (26 2x96d) | 0.01 | 1e-3 | 128 | 1,800 |
| CIFAR-10 | PyramidNet+ShakeDrop | 0.05 | 5e-5 | 64 | 1,800 |
| CIFAR-100 | Wide-ResNet-40-2 | 0.1 | 5e-4 | 128 | 200 |
| CIFAR-100 | Wide-ResNet-28-10 | 0.1 | 5e-4 | 128 | 200 |
| CIFAR-100 | Shake-Shake (26 2x96d) | 0.01 | 2.5e-3 | 128 | 1,800 |
| CIFAR-100 | PyramidNet+ShakeDrop | 0.025 | 5e-4 | 64 | 1,800 |
| Reduced CIFAR-10 | Wide-ResNet-28-10 | 0.05 | 5e-3 | 128 | 200 |
| Reduced CIFAR-10 | Shake-Shake (26 2x96d) | 0.025 | 2.5e-3 | 128 | 1,800 |
| SVHN | Wide-ResNet-28-10 | 0.005 | 1e-3 | 128 | 200 |
| SVHN | Shake-Shake (26 2x96d) | 0.01 | 1.5e-4 | 128 | 1,800 |
| Reduced SVHN | Wide-ResNet-28-10 | 0.05 | 1e-2 | 128 | 200 |
| Reduced SVHN | Shake-Shake (26 2x96d) | 0.025 | 5e-3 | 128 | 1,800 |
| ImageNet | ResNet-50 | 1.6 | 1e-4 | 4096 | 270 |
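
For use in a training script, rows of Table 8 can be kept in a small lookup keyed by dataset and model. This is only an illustrative sketch: the `TrainConfig` class and `HYPERPARAMETERS` dictionary are hypothetical names, not part of the released code, and only three rows of the table are copied here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    learning_rate: float
    weight_decay: float
    batch_size: int
    epochs: int


# Values copied from Table 8; extend with the remaining rows as needed.
HYPERPARAMETERS = {
    ("CIFAR-10", "Wide-ResNet-28-10"): TrainConfig(0.1, 5e-4, 128, 200),
    ("Reduced SVHN", "Wide-ResNet-28-10"): TrainConfig(0.05, 1e-2, 128, 200),
    ("ImageNet", "ResNet-50"): TrainConfig(1.6, 1e-4, 4096, 270),
}

config = HYPERPARAMETERS[("Reduced SVHN", "Wide-ResNet-28-10")]
```

Freezing the dataclass keeps a configuration from being modified accidentally once training has started.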







