Buckets:

|
download
raw
210 kB

Title: Filtering with Self-Attention and Storing with MLP: One-Layer Transformers Can Provably Acquire and Extract Knowledge

URL Source: https://arxiv.org/html/2508.00901

Markdown Content: Back to arXiv

This is experimental HTML to improve accessibility. We invite you to report rendering errors. Use Alt+Y to toggle on accessible reporting links and Alt+Shift+Y to toggle off. Learn more about this project and help improve conversions.

Why HTML? Report Issue Back to Abstract Download PDF Abstract 1Introduction 2Simplified One-Layer Transformer Architecture 3Data 4Problem Setting 5Conditions for Out-of-Distribution Generalization Post Fine-Tuning 6Mechanics of Knowledge Acquisition and Extraction 7Experiments 8Conclusions

HTML conversions sometimes display errors due to content that did not convert correctly from the source. This paper uses the following packages that are not yet supported by the HTML conversion tool. Feedback on these issues are not necessary; they are known and are being worked on.

failed: mdframed.sty

Authors: achieve the best HTML results from your LaTeX submissions by following these best practices.

License: arXiv.org perpetual non-exclusive license arXiv:2508.00901v3 [cs.LG] 25 Nov 2025 Filtering with Self-Attention and Storing with MLP: One-Layer Transformers Can Provably Acquire and Extract Knowledge Ruichen Xu  Kexin Chen Correspondence to: rcxu642@gmail.com; Random order Abstract

Modern large language models (LLMs) demonstrate exceptional performance on knowledge-intensive tasks, yet the theoretical mechanisms underlying knowledge acquisition (storage and memorization) during pre-training and extraction (retrieval and recall) during inference after fine-tuning remain poorly understood. Although prior theoretical studies have explored these processes through analyses of training dynamics, they overlook critical components essential for a comprehensive theory: (1) the multi-layer perceptron (MLP), empirically identified as the primary module for knowledge storage; (2) out-of-distribution (OOD) adaptivity, which enables LLMs to generalize to unseen scenarios post-pre-training; and (3) next-token prediction, the standard autoregressive objective that encodes knowledge as conditional probabilities. In this work, we introduce, to the best of our knowledge, the first theoretical framework that addresses these limitations by examining the training dynamics of one-layer transformers. Under regularity assumptions, we establish that: (i) transformers attain near-optimal training loss during pre-training, demonstrating effective knowledge acquisition; (ii) given a sufficiently large fine-tuning dataset and appropriate data multiplicity conditions, transformers achieve low generalization error on factual knowledge acquired during pre-training but not revisited in fine-tuning, indicating robust knowledge extraction; and (iii) violation of these conditions leads to elevated generalization error, manifesting as hallucinations. Our analysis encompasses both full fine-tuning and low-rank fine-tuning, yielding insights into the efficacy of practical low-rank adaptation methods. Additionally, it elucidates several empirical observations, including the impact of learning rate schedules on training dynamics. We validate our theoretical findings through experiments on synthetic datasets and the real-world PopQA benchmark, employing GPT-2 and Llama-3.2-1B models.

1Introduction

Modern Large Language Models (LLMs) function as powerful knowledge repositories, demonstrating a remarkable ability to perform a wide range of tasks, from answering complex questions to writing creative text [brown2020language]. These capabilities are not explicitly programmed but emerge from the models’ training process. During their training phase, these models implicitly encode vast amounts of information from massive datasets directly into their billions of internal parameters. When a user provides a prompt, it acts as a key to navigate this complex parameter space, causing the model to synthesize the relevant information by predicting a statistically probable sequence of text based on the prompt’s context and its encoded knowledge.

A critical challenge, however, is that knowledge encoded during training is not always accessible at inference time. Even when an LLM has successfully learned a piece of information, it can fail to retrieve or correctly apply that knowledge when prompted [allen2023physics1, kandpal2023large]. This retrieval failure often leads the model to generate plausible but incorrect information, a phenomenon widely known as “hallucination.” Therefore, demystifying the precise mechanics of how these models acquire and extract knowledge is fundamental to building genuinely reliable and trustworthy AI systems.

Research into the knowledge mechanisms of large language models has progressed from empirical observation to theoretical modeling. Empirically, studies have identified the MLP layers as the primary location for knowledge storage, where they act as key-value memories [geva2020transformer, dong2025attention, yao2024knowledge]. Others have explored and established approaches for probing and steering the factual knowledge stored within an LLM’s parameters [allen2023physics1, allen2023physics2, meng2022locating]. However, while these empirical findings reveal the mechanisms at play in knowledge storage, they don’t fully explain why these knowledge mechanisms emerge from the training process. To bridge this gap, a growing body of theoretical research now seeks to formally characterize the knowledge acquisition process by analyzing knowledge acquisition mechanisms with simplified models at distinct phases. For example, Nichani et al. [nichani2024understanding] investigated the storage capacities of one-layer transformers in conjunction with associative memories. Ghosal et al. [ghosal2024understanding] studied the gradients of one-layer attention-only transformers during fine-tuning with subject-answer pairs.

Despite these advances, a comprehensive theoretical framework linking knowledge acquisition during pre-training to out-of-distribution (OOD) generalization following fine-tuning remains elusive. Such a framework must incorporate several key components that are absent from current analyses:

Interplay of self-attention and MLP: Since analyzing a full transformer is analytically challenging, the majority of existing theoretical analyses on training dynamics concentrate on attention-only transformer architectures.1 However, this simplification overlooks the critical role of the MLP, which extensive empirical evidence identifies as the primary component for knowledge storage in transformer-based language models (e.g., [geva2020transformer, meng2022locating, dong2025attention, yao2024knowledge]). Furthermore, our empirical results, detailed in Table 3 of Appendix A, demonstrate that attention-only transformers often fail to acquire knowledge effectively, as evidenced by their inability to achieve near-optimal training loss on simple next-token prediction datasets. To address this limitation, we examine a simplified one-layer transformer comprising both self-attention and multi-layer perceptron (MLP) modules. Although this model is simplified from a full transformer, elucidating its knowledge acquisition mechanisms remains highly nontrivial due to the intricate interplay between self-attention and MLP: the output of the self-attention module directly feeds into the MLP, resulting in tightly coupled parameters and entangled gradients.

Fine-tuning for out-of-distribution (OOD) adaptivity: The capacity for transformers to generalize their knowledge to OOD tasks [bubeck2023sparks, dubey2024llama] is typically unlocked via fine-tuning, a process encompassing methods from task-specific adaptation [devlin2019bert, radford2018improving] to large-scale instruction tuning [ouyang2022training]. However, a central challenge lies in characterizing the dynamics of this adaptation process. In contrast to training from scratch, which is extensively studied, fine-tuning dynamics are critically dependent on the rich features inherited from the pre-trained model. This interdependency necessitates a joint theoretical consideration of both the pre-training and fine-tuning phases.

Next-token prediction: next token prediction (NTP) is the standard loss function for large language models pre-training [achiam2023gpt, dubey2024llama]. Exploring the implicit bias introduced by NTP is necessary for understanding modern LLMs. However, analyzing NTP is challenging because of its non-i.i.d. data structure. NTP diverges from simpler, well-studied learning paradigms in two fundamental ways: first, the strong correlations among data samples drawn from the same sequence; and second, the multi-label nature of the samples, where a single context may admit multiple valid subsequent tokens.

In this paper, we seek to fill the above gap to answer the following key open research questions:

1.

How do transformers acquire knowledge during pre-training (PT) via NTP and extract it after fine-tuning (FT)?

2.

Under which conditions can pre-trained transformers perform OOD generalization after full and low-rank FT?

To resolve these questions, we examine the feature learning dynamics of one-layer “self-attention + MLP” transformers during pre-training and fine-tuning. We derive convergence guarantees for the pre-training phase and establish post-fine-tuning OOD generalization bounds that delineate conditions for successful generalization versus hallucination.

1.1Main Results

Inspired by [allen2023physics1], we examine a typical training diagram where the model is pre-trained on general knowledge and then fine-tuned for the downstream question-answering task. The formal data settings and their justifications are detailed in Section 3. Here, we provide a concise overview of these settings.

For pre-training, the transformer is trained on sentences structured as “noisesubject relation answer” (e.g., The following list details each person’s birthplace. Alice was born in California) using the NTP objective. This formulation naturally extends the conventional subject-relation-answer modeling by incorporating potentially irrelevant contextual information, thereby enhancing realism and encouraging the model to develop appropriate attention mechanisms to filter noise.

Drawing further inspiration from the “Multi- 𝐾 ” knowledge augmentation in [allen2023physics1], we replicate certain subject-answer pairs 𝐾 times during pre-training, each instance featuring a semantically equivalent but rephrased relation (e.g., “was born in” is equivalent to “comes from” or “hails from”) sampled from a universal set ℛ .

In the fine-tuning phase, the model is adapted using a random subset of 𝑁 ~ 𝑓 subject-answer pairs, reformatted into a question-and-answer (Q&A) style: “subject Q&A format answer” (e.g., Q: Alice’s hometown is? A: California). Finally, we assess the model’s post-FT OOD generalization on unseen subject-answer pairs, employing prompts in the same Q&A format.

Under these data settings, we provide the first theoretical characterizations of transformers’ ability to achieve post-FT OOD generalization following NTP PT and full and low-rank FT.

Theorem 1 (Informal, see Theorem 2 for the formal statement).

For any constant 𝛿 , 𝜅

0 , under certain regularity conditions and proper choices of learning rates, with probability at least 1 − 𝛿 :

1.

During PT over 𝑇 𝑝 iterations, there exists 0 ≤ 𝑡 𝑝 ≤ 𝑇 𝑝 such that the transformer achieves near-optimal pre-training loss:

Pre-training loss at iteration  ​ 𝑡 𝑝 < 0.001 + ( 1 + 𝜅 ) ​ 𝐻 ,

(1)

where 𝐻 is the entropy of the NTP training dataset, representing the optimal achievable cross-entropy loss.

2.

Under conditions 𝑁 ~ 𝑓

Ω ​ ( | ℛ | ) (or 𝑁 ~ 𝑓

Ω ​ ( | ℛ | 2 ) ), after 𝑇 𝑓 full (or low-rank) FT iterations, the fine-tuned model extracts the embedded knowledge effectively, yielding:

Post-FT OOD generalization error ≤ exp ⁡ ( − 𝑁 ~ 𝑓 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) 2 ​ | ℛ | 2 ) .

(2) 3.

When 𝑁 ~ 𝑓 ​ 𝐾

𝒪 ​ ( | ℛ | ) , the full and low-rank fine-tuned model may hallucinate, with:

Post-FT OOD generalization error ≥ 0.1 .

(3) How do the self-attention and MLP modules evolve during pre-training and fine-tuning?

We proved that in the first stage, self-attention quickly learns to “unattend” to the noise tokens and focus on subject and relation tokens. Afterward, the attention scores on subject and relation tokens remain comparable during the whole PT, enabling the MLP module to learn balanced features for subjects and relations. As a result, even with NTP, the MLP learn features to associate the answer with the subject at strength after pre-training. Then, during fine-tuning, MLP integrates the Q&A format into the learned relation feature, enabling the model to trigger accurate answer generation upon detecting the Q&A format alongside the subject.

Why is low-rank fine-tuning effective?

The effectiveness stems from the fact that FT gradients closely approximate their best rank-1 form, with the approximation error bounded by 𝒪 ​ ( 𝜆 / 𝑚 ​ 𝑁 ~ 𝑓 ) , where 𝑚 is the MLP width and 𝜆 is a normalization factor. Moreover, the right singular vector corresponding to the largest singular value of the MLP gradient aligns closely with the Q&A format embedding, preserving essential format information. Thus, for sufficiently large 𝑁 ~ 𝑓 , low-rank FT adapts the pre-trained model with the Q&A format, facilitating OOD generalization.

Our experimental contributions.

To validate our theoretical insights, we conduct targeted experiments. First, we replicate our theoretical settings using synthetic data, thereby verifying our theory and demonstrating their congruence with empirical trends in large-scale language models like GPT-2. We further extend this validation to real-world scenarios by applying GPT-2 and Llama-3.2-1B on the PopQA dataset [mallen2022not].

1.2Related Work

Our work is most closely related to theoretical analyses of transformer training dynamics. We review the recent literature on this topic below. Due to the space constraints, a broader discussion of related topics, including transformer memorization, feature learning, and empirical studies of knowledge acquisition and extraction, is deferred to Appendix B.

Table 1:Comparison with existing theoretical works on transformer training dynamics. Given a sequence ’ 𝑎 , 𝑏 , 𝑐 ’, NTP constructs a training pair for every adjacent position, 𝑎 → 𝑏 and 𝑎 ​ 𝑏 → 𝑐 , whereas “seq”-prefix prediction treats the entire prefix as context, producing a single pair 𝑎 ​ 𝑏 → 𝑐 . Task Characterization Obj. PT FT OOD Architectures Our Paper Knowledge ​​Full dynamics​​ NTP ​1-layer self-attention + MLP [nichani2024understanding, zhu2024towards] Knowledge ​​Full dynamics​​ seq - - ​1-layer linear or attention-only [cabannes2024learning, tian2023scan] [ghosal2024understanding] Knowledge ​​Gradients​​ seq - - ​1-layer attention-only​ [wang2025learning] In-context ​​Layerwise dynamics​​ seq - - ​ 𝐿 -layer ( 𝐿 ≥ 2 ) attention-only​ [nichani2024transformers] In-context ​​Layerwise dynamics​​ seq - ​2-layer attention-only​ [ahn2023transformers, mahankali2023one] In-context ​​Critical points​​ seq - - Linear​ [bietti2023birth] In-context ​​Gradients​​ seq - - ​2-layer attention-only​ [zhang2024trained] In-context ​​Full dynamics​​ seq - ​1-layer linear attention-only​ [huang2023context, zhang2024trained] In-context ​​Gradients​​ seq - - ​1-layer attention-only​ [jelassi2022vision] Img. cls. ​​Full dynamics​​ - - ​1-layer attention-only​ [sakamoto2024benign] Img. cls. ​​Full dynamics​​ - - - ​1-layer attention-only​ [wang2024transformers] Others ​​Full dynamics​​ - - - ​1-layer attention-only​ [tian2023joma] Others ​​Full dynamics​​ - - - ​ 1-layer uniform attention + MLP​

Extensive theoretical research has examined the (pre-)training and fine-tuning of transformers, often focusing on training dynamics within simplified architectural settings. For example, the dynamics of one-layer attention-only transformers have been explored for tasks such as next-token prediction in long sequences [tian2023scan] and image classification [jelassi2022vision]. Additionally, the full training dynamics of linear one-layer attention-only transformers have been studied for factual recall [nichani2024understanding] and in-context learning [zhang2024trained]. Other architectures have also been explored. For example, Cabannes et al. [cabannes2024learning] analyzed the training dynamics of learning associative memories with a single linear layer.

A complementary line of research examines transformer training through gradients and critical points. For example, Bietti et al. [bietti2023birth] investigated gradients in two-layer attention-only transformers to elucidate the emergence of induction heads, while Ghosal et al. [ghosal2024understanding] analyzed single-step gradients in one-layer attention-only transformers during fine-tuning. However, for non-convex transformers, analyses limited to critical points and gradients are insufficient to capture the full training dynamics or the properties of the converged solution.

Our work distinguishes itself by characterizing the full joint training dynamics of self-attention and MLP components during NTP-based PT, as well as the full and low-rank joint dynamics during subsequent FT. Moreover, we evaluate the out-of-distribution generalization performance of the fine-tuned transformers. A detailed comparison of our approach and characterizations with prior work is presented in Table 1.

1.3Notation

We use lowercase letters for scalars, lowercase boldface letters for vectors, and uppercase boldface letters for matrices. The set { 1 , ⋯ , 𝑚 } is denoted by [ 𝑚 ] . Given two sequences { 𝑥 𝑛 } and { 𝑦 𝑛 } , we denote 𝑥 𝑛

𝒪 ​ ( 𝑦 𝑛 ) if | 𝑥 𝑛 | ≤ 𝐶 1 ​ | 𝑦 𝑛 | for some positive constant 𝐶 1 and 𝑥 𝑛

Ω ​ ( 𝑦 𝑛 ) if | 𝑥 𝑛 | ≥ 𝐶 2 ​ | 𝑦 𝑛 | for some positive constant 𝐶 2 . We use 𝑥 𝑛

Θ ​ ( 𝑦 𝑛 ) if 𝑥 𝑛

𝒪 ​ ( 𝑦 𝑛 ) and 𝑥 𝑛

Ω ​ ( 𝑦 𝑛 ) both hold. The notations 𝒪 ~ ​ ( ⋅ ) , Θ ~ ​ ( ⋅ ) , and Ω ~ ​ ( ⋅ ) are used to suppress logarithmic factors. For a matrix 𝐀 , both 𝐴 𝑖 , 𝑗 and ( 𝐀 ) 𝑖 , 𝑗 denote the element of 𝐀 located at the 𝑖 𝑡 ​ ℎ row and the 𝑗 𝑡 ​ ℎ column.

2Simplified One-Layer Transformer Architecture Figure 1:Illustration of the simplified one-layer transformer.

Standard one-layer transformers typically comprise seven weight matrices and two normalization operations. Analyzing their learning dynamics is challenging due to the complex interactions among these matrices and the non-smooth, non-convex loss landscape. To address these issues and enhance analytical tractability, we introduce several simplifications to the standard architecture. The resulting model, as empirically validated (Table 3 in Appendix A), retains the core mechanisms of knowledge acquisition and extraction. An overview of the architecture is depicted in Figure 1. We describe its components below.

Self-Attention.

The attention mechanism computes scores for an input sequence 𝐗 ∈ ℝ 𝑑 × 𝐿 (where 𝑑 is the embedding dimension and 𝐿 is the sequence length), using the query derived from the last token. The attention scores are computed by:

𝜶 ​ ( 𝐙 , 𝐗 )

softmax ​ ( 𝐗 ⊤ ​ 𝐖 𝐾 ⊤ ​ 𝐖 𝑄 ⏞ 𝐙 ​ 𝐗 ​ [ − 1 ] 𝑑 ) ,

(4)

where 𝐗 ​ [ − 1 ] ∈ ℝ 𝑑 × 1 denotes the embedding of the last token.2 Following established simplifications [tian2023scan], we reparameterize 𝐖 𝐾 ⊤ ​ 𝐖 𝑄 as a single matrix 𝐙 ∈ ℝ 𝑑 × 𝑑 . The self-attention output 𝐱 𝑎 ∈ ℝ 𝑑 is then:

𝐱 𝑎 ​ ( 𝐙 , 𝐗 )

𝐗 ​ 𝜶 ​ ( 𝐙 , 𝐗 ) + 𝐗 ​ [ − 1 ] .

(5) Multi-Layer Perceptron.

The MLP processes the self-attention output 𝐱 𝑎 . With trainable parameters 𝐖 ∈ ℝ 𝑑 ​ 𝑚 × 𝑑 (first layer) and fixed parameters 𝐖 2 ∈ ℝ 𝑑 × 𝑑 ​ 𝑚 (second layer), the MLP output is:

𝐱 𝑓 ​ ( 𝐖 , 𝐙 , 𝐗 )

𝐖 2 ​ 𝜎 ​ ( 𝐖𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ,

(6)

where 𝜎 ​ ( ⋅ ) denotes the ReLU activation function. We fix the matrix 𝐖 2 as:

𝐖 2

[ 1 𝑚 ​ 𝟏 ⊤

𝟎 ⊤

𝟎 ⊤

𝟎 ⊤

𝟎 ⊤

1 𝑚 ​ 𝟏 ⊤ ] ,

(7)

where 𝟎 , 𝟏 ∈ ℝ 𝑚 are all-zeros and all-ones vectors, respectively. This block-diagonal structure, commonly employed in MLP analyses [allen2020towards, cao2022benign], decouples neurons across output dimensions, thereby simplifying the theoretical treatment.

Linear Layer.

Then, the MLP output is fed into a final linear layer:

𝐱 output ​ ( 𝐖 , 𝐙 , 𝐗 )

𝜆 ​ 𝐈𝐱 𝑓 ​ ( 𝐖 , 𝐙 , 𝐗 ) ,

(8)

where 𝜆 is a constant scaling factor and 𝐈 is the identity matrix. This scaling mimics learnable gains or fixed normalization in layers like RMSNorm, which typically scales inputs to have an L2 norm of 𝑑 .

Remark 1.

Our simplified transformer omits positional encodings, compelling the model to rely exclusively on semantic content for next-token prediction.

3Data

In this section, we define a synthetic data generation process designed to capture the essential characteristics of sentences, including the representation of factual knowledge. Our synthetic data aims to mimic structures commonly observed in real-world text. A pre-training example is:

  XXX (Irrelevant information) Here is a record ⏟ Context ​ . ⏟ Ending ​ Alice ⏟ Subject ​   ​ was born in ⏟ Relation phrase ​   ​ California ⏟ Answer ,

and a corresponding FT example is:

Question : ⏟ Q&A format 1 ​ Alice ⏟ Subject ​ ’s hometown is      ?   Answer :  ⏟ Q&A format 2 ​ California ⏟ Answer

To model these structures, we introduce the following data distributions.

Pre-Training Sentence Generation (Sentence Set 𝒯 ).

We begin with a subject-answer set 𝒜

{ ( 𝐬 𝑗 , 𝐚 𝑗 ) } 𝑗

1 𝑁 comprising 𝑁 unique pairs, where 𝐬 𝑗 ∈ ℝ 𝑑 is the subject embedding and 𝐚 𝑗 ∈ ℝ 𝑑 is the answer embedding. This set is partitioned into 𝑁 𝑓 “frequent” pairs (freq) and 𝑁 𝑟 “rare” pairs (rare). Frequent pairs are associated with 𝐾

Θ ​ ( 1 ) semantically equivalent relation phrases,3 while rare pairs are linked to a single unique relation phrase. This setup emulates data augmentation techniques such as “Multi- 𝐾 ” in [allen2023physics1], where facts are presented via multiple templates (i.e., semantically equivalent relation phrases).

The sentence generation process for 𝒯 proceeds as follows:

Given context token embedding 𝐨 ∈ ℝ 𝑑 and ending token embedding 𝐝 ∈ ℝ 𝑑 , let ℛ denote the universal set of semantically equivalent relation phrase embeddings. For each subject-answer pair ( 𝐬 𝑗 , 𝐚 𝑗 ) ∈ 𝒜 : 1. Determine the number of relation phrases 𝐾 ~ ​ ( 𝑗 )

𝐾 ⋅ 𝕀 ​ ( 𝑗 ∈ freq ) + 𝕀 ​ ( 𝑗 ∈ rare ) . 2. Sample a subset ℛ ~ ⊆ ℛ of 𝐾 ~ ​ ( 𝑗 ) relation phrase embeddings uniformly at random without replacement. 3. For each 𝑖 ∈ ℛ , construct a corresponding tuple of the form ( 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) .

As in prior works [allen2020towards, tian2023scan], we assume all token embeddings— 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 for all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℛ —are orthogonal.4 We further assume ‖ 𝐬 𝑗 ‖ 2

‖ 𝐫 𝑖 ‖ 2

‖ 𝐚 𝑗 ‖ 2

‖ 𝐝 ‖ 2

1 , while ‖ 𝐨 ‖ 2

Θ ​ ( 𝑑 ) to embed subject-relation-answer tokens within a “rich” contextual environment. In this paper, we consider the simplest case with a single context; our analysis extends to Θ ​ ( 1 ) contexts assigned uniformly at random.

Next-Token Prediction Dataset Generation (Dataset 𝒫 ).

The transformer is pre-trained using an NTP objective. We construct the NTP dataset 𝒫 by creating input-target pairs from the sentence set 𝒯 :

From each sentence ( 𝐨 , 𝐝 , 𝐬 , 𝐫 , 𝐚 ) ∈ 𝒯 , generate ( 𝐨 , ℐ ​ ( 𝐝 ) ) , ( [ 𝐨 ​ 𝐬 ] , ℐ ​ ( 𝐫 ) ) , ( [ 𝐨 ​ 𝐬 ​ 𝐫 ] , ℐ ​ ( 𝐚 ) ) , where ℐ ​ ( ⋅ ) maps the embeddings to their indices.

Two simplifications are applied here. First, we omit the prediction of the subject 𝐬 from 𝐨𝐝 , i.e., ( [ 𝐨 ​ 𝐝 ] , ℐ ​ ( 𝐬 ) ) , as it resembles a random guess. Second, since 𝐝 invariably follows 𝐨 and its magnitude is dominated by that of 𝐨 , we ignore the ending token 𝐝 in ( [ 𝐨 ​ 𝐝 ​ 𝐬 ] , ℐ ​ ( 𝐫 ) ) and ( [ 𝐨 ​ 𝐝 ​ 𝐬 ​ 𝐫 ] , ℐ ​ ( 𝐚 ) ) . Therefore, the total number of NTP data samples is 𝑛

3 × 𝑁 𝑓 × 𝐾 + 3 × 𝑁 𝑟 .

Remark 2.

Despite these simplifications, 𝒫 preserves the core attributes of autoregressive NTP. The training requires the transformer to process the subject ( [ 𝐨 , 𝐬 ] ) before predicting the subsequent relation ( 𝐫 ). This sequential dependency instructs the model to interpret the relation as prerequisite to determining the answer, thereby discouraging direct subject-answer memorization.

Consequently, the model relies heavily on the relation token. Absent this token during inference, the model falters in producing accurate answers due to the missing relational cue. Fine-tuning is thus essential to adapt the model for answer prediction using only the subject, enabling robust performance without explicit relations.

Q&A Dataset Generation (Dataset 𝒬 ).

For FT and evaluation in a question-answering paradigm, we generate dataset 𝒬 , where each sample entails predicting answer 𝐚 given subject embedding 𝐬 and format embedding 𝐩 .5 The generation process is:

Given a format token embedding 𝐩 ∈ ℝ 𝑑 , for each subject-answer pair ( 𝐬 , 𝐚 ) ∈ 𝒜 , output ( [ 𝐬    𝐩 ] , ℐ ( 𝐚 ) ) .

Without loss of generality, we assume that the embedding of format token 𝐩 is orthogonal to all other token embeddings 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 for all 𝑗 ∈ [ 𝑁 ] , 𝑖 ∈ ℛ and has unit magnitude, i.e., ‖ 𝐩 ‖ 2

1 .

Fine-Tuning Q&A Dataset Generation (Dataset 𝒬 𝑠 ).

To examine FT’s impact across varying PT multiplicities, we form 𝒬 𝑠 as a random subset comprising proportion 𝛽 of Q&A pairs from 𝒬 corresponding to freq pairs: Let 𝒬 freq

{ ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬 | 𝑗 ∈ freq } . Then, 𝒬 𝑠 is a randomly selected subset of proportion 𝛽 from 𝒬 freq . Our PT and FT datasets naturally extend conventional subject ( 𝑠 )-relation ( 𝑟 )- answer ( 𝑎 ) knowledge modeling [ghosal2024understanding, meng2022locating, haviv2022understanding, meng2023mass] in two key ways:

Practical Relevance: They better represent knowledge embedded in longer contexts. Given NTP’s autoregressive nature, knowledge triples ( 𝑠 , 𝑟 , 𝑎 ) are typically prefixed with irrelevant information.

Technical Challenge: The inclusion of noise ( 𝐨 , 𝐝 ) precludes learning facts via MLPs with uniform attention alone. Transformers must develop targeted attention patterns to filter noise and attain near-optimal loss.

4Problem Setting

In this section, we formalize the optimization problem, encompassing the loss function, training algorithms, and parameter initialization.

Loss Function.

We employ the cross-entropy loss to train the model parameters 𝐖 and 𝐙 on the next-token prediction dataset with input-target pairs 𝒫

{ ( 𝐗 𝑖 , 𝑦 𝑖 ) } 𝑖

1 𝑛 :

ℒ 𝒫 ​ ( 𝐖 , 𝐙 )

1 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) ,

(9)

where the per-sample loss is ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

− log ⁡ ( logit 𝑦 ​ ( 𝐖 , 𝐙 , 𝐗 ) ) , and 𝐥𝐨𝐠𝐢𝐭 ​ ( 𝐖 , 𝐙 , 𝐗 )

softmax ⁡ ( 𝐱 output ​ ( 𝐖 , 𝐙 , 𝐗 ) ) .

Pre-Training Algorithm.

We train the model with gradient descent (GD):

𝐙 ( 𝑡 + 1 )

𝐙 ( 𝑡 ) − 𝜂 ( 𝑡 ) 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∇ 𝐙 ( 𝑡 ) ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ,

𝐖 ( 𝑡 + 1 )

𝐖 ( 𝑡 ) − 𝜂 ( 𝑡 ) 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∇ 𝐖 ( 𝑡 ) ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ,

where 𝜂 ( 𝑡 ) is the learning rate at iteration 𝑡 . For the pre-training phase, we adopt a three-stage learning-rate schedule with two pre-defined thresholds ℎ 1 and ℎ 2 : 𝜂 ( 𝑡 )

𝜂 1 for 𝑡 ≤ ℎ 1 , 𝜂 2 for ℎ 1 < 𝑡 ≤ ℎ 2 , and 𝜂 3 for 𝑡

ℎ 2 . As shown in Section 6, this schedule facilitates convergence to near-optimal training loss.

Fine-Tuning Algorithm.

For full FT, we adopt the same GD procedure as in PT on dataset 𝒬 𝑠 , with a fixed learning rate 𝜂 𝑓 . For low-rank fine-tuning, we instead update using the best rank-1 approximation of the gradient: for a matrix 𝐀 , this is obtained by minimizing ‖ 𝐀 − 𝐚𝐛 ⊤ ‖ 𝐹 over 𝐚 , 𝐛 . Further details are provided in Appendix H. This method aligns with practical low-rank adaptation techniques, such as [zhang2023adalora], which prioritize components associated with dominant singular values.

Initialization.

The parameters are initialized independently from a Gaussian distribution: 𝐙 ( 0 ) , 𝐖 ( 0 ) ∼ 𝒩 ​ ( 𝟎 , 𝜎 0 2 ​ 𝐈 ) .

Building on this formulation, we present our main theoretical results in the subsequent section, providing formal guarantees on convergence and generalization.

5Conditions for Out-of-Distribution Generalization Post Fine-Tuning

In this section, we address the key question: Under what conditions do transformers achieve out-of-distribution (OOD) generalization following fine-tuning (FT)? We provide an answer by deriving convergence bounds for the PT phase over 𝑇 𝑝 iterations and generalization bounds after FT over 𝑇 𝑓 iterations.

Our theoretical analyses rely on the following conditions:

Condition 1.

Assume there exist sufficiently large positive constants 𝐶 1 ′ , 𝐶 2 ′ , 𝐶 3 ′ , 𝐶 4 ′ , 𝐶 5 ′ , 𝐶 6 ′ , 𝐶 7 ′ . For a probability parameter 𝛿 ∈ ( 0 , 1 ) , the following hold:

1.

The learning rates satisfy:

{ 𝐶 1 ′ ​ 𝑚 ​ 𝑛 / 𝜆 ​ 𝑑 1 / 4 ≤ 𝜂 1 ≤ 𝐶 2 ′ ​ 𝑚 ​ log ⁡ ( 𝑑 ) / 𝜆 2 ,

𝜂 2 ≤ 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) / ( 𝐶 3 ′ ​ 𝜆 2 ) ,

𝜂 3 ≤ 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) / ( 𝐶 4 ′ ​ 𝜆 2 ) .

(10) 2.

The embedding dimension satisfies:

𝑑 ≥ 𝐶 5 ′ ⋅ max ⁡ { 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) 2 𝜆 2 , 𝜆 2 ( 𝑛 2 ​ 𝛿 2 ) } .

(11) 3.

The initialization magnitude satisfies: 𝜎 0

Θ ​ ( 1 / 𝑑 ) .

4.

The width (number of neurons) of the MLP satisfies: 𝑚 ≥ 𝐶 6 ′ ⋅ log ⁡ ( ( 𝑁 ​ 𝐾 ) / 𝛿 ) .

5.

The size of training data satisfies: 𝑁 ≥ 𝐶 7 ′ ​ log ⁡ ( 𝑚 / 𝛿 ) .

Condition 1 ensures learning rates are appropriately bounded to minimize training loss. Condition 2 mandates a sufficiently large embedding dimension for stable per-sample learning during PT, akin to assumptions in prior work [kou2023benign, frei2022benign, chatterji2021finite]. Condition 3 specifies an initialization scale common in practical LLMs, such as GPT-2 [radford2019language] and Llama-2 [touvron2023llama]. Condition 4 imposes a mild logarithmic lower bound on MLP width, ensuring a non-negligible number of neurons ( Ω ​ ( 𝑚 ) neurons) activate per sample. Condition 5 is a mild logarithmic lower bound on training data size, ensuring all the 𝑚 neurons are covered.

We define the knowledge extraction loss (or generalization loss) as the misclassification rate on unseen Q&A pairs during FT ( ( 1 − 𝛽 ) ​ 𝑁 𝑓 + 𝑁 𝑟 pairs):

ℒ 𝑒 ​ ( 𝐖 , 𝐙 )

∑ ( 𝐗 , 𝑦 ) ∈ 𝒬
𝒬 𝑠 | 𝒬
𝒬 𝑠 | ​ ℙ ​ [ ( 𝐱 output ​ ( 𝐖 , 𝐙 , 𝐗 ) ) 𝑦 ≤ max 𝑖 ≠ 𝑦 ⁡ { ( 𝐱 output ​ ( 𝐖 , 𝐙 , 𝐗 ) ) 𝑖 } ] .

(12)

We now analyze the PT training loss and post-FT knowledge extraction loss for full and low-rank FT.

Theorem 2.

For any constant 𝛿 > 0 , under Condition 1, select the third-stage pre-training learning rate 𝜂 3 ≤ 𝜅 / 64 (with 𝜅 ≤ 1 / 2 as a constant) and fine-tuning learning rate 𝜂 𝑓

Θ ~ ​ ( 𝑚 ​ | ℛ | / 𝜆 2 ) . Then, with probability at least 1 − 𝛿 :

1.

During PT over 𝑇 𝑝

𝜂 2 − 1 ​ poly ⁡ ( 𝑛 , 𝑚 , 𝜆 − 1 , 𝐾 − 1 ) + 𝜂 3 − 1 ​ poly ⁡ ( 𝑚 , 𝑁 , | ℛ | , 𝜆 − 1 ) iterations, there exists 0 ≤ 𝑡 𝑝 ≤ 𝑇 𝑝 where the model acquires knowledge:

ℒ 𝒫 ​ ( 𝐖 ( 𝑡 𝑝 ) , 𝐙 ( 𝑡 𝑝 ) ) < 0.001 + ( 1 + 𝜅 ) ​ 𝐻 .

(13) 2.

If 𝛽 ​ 𝑁 𝑓 ≥ 𝐶 1 ​ | ℛ | (full FT) or 𝛽 ​ 𝑁 𝑓 ≥ 𝐶 2 ​ 𝑚 ​ | ℛ | 2 (low-rank FT), then after 𝑇 𝑓

𝒪 ​ ( 𝛽 ​ 𝑁 𝑓 ​ 𝑑 − 0.01 / | ℛ | ) iterations, the model extracts knowledge:

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≤ exp ⁡ ( − 𝛽 ​ 𝑁 𝑓 ​ 𝐾 2 ​ log ⁡ | ℛ | 2 ​ | ℛ | 2 ) .

(14) 3.

If 𝛽 ​ 𝑁 𝑓 ​ 𝐾 / | ℛ | ≤ 𝐶 3 , the full or low-rank fine-tuned model may hallucinate:

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≥ 𝑐 1 .

Here, 𝐶 1 , 𝐶 2 , 𝐶 3 , 𝑐 1

0 are positive constants, and 𝐻 is the entropy of the NTP training dataset (optimal cross-entropy loss).

Theorem 2 demonstrates that FT on frequent facts enhances OOD generalization, even to rare facts, aligning with empirical findings [ghosal2024understanding, allen2023physics1]. The bound in (14) reveals that higher multiplicity 𝐾 (diverse fact presentations in PT) reduces generalization error, particularly for large | ℛ | , explaining the success of knowledge augmentation [allen2023physics1].6

Remark 3.

A novel theory-derived insight is that post-FT OOD accuracy declines with increasing | ℛ | , a phenomenon obscured in empirical studies by natural language complexity.

6Mechanics of Knowledge Acquisition and Extraction

This section provides a detailed characterization of the training dynamics during PT with NTP and subsequent knowledge extraction via FT, serving as a proof sketch for Theorem 2.

Pre-Training

As outlined in Section 4, our analysis of pre-training dynamics employs a three-stage learning rate schedule:

In Stage I ( 0 ≤ 𝑡 < 𝑇 1

2 ), the learning rate is 𝜂 1 ;

In Stage II ( 𝑇 1 ≤ 𝑡 ≤ 𝑇 2

Θ ​ ( 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) / ( 𝜆 2 ​ 𝜂 2 ​ 𝐾 ) ) ), the learning rate is 𝜂 2 ;

In Stage III ( 𝑇 2 < 𝑡 ≤ 𝑇 𝑝 ), the learning rate is 𝜂 3 .

We characterize the training dynamics of self-attention and MLP separately.

Proposition 1 (Training dynamics of self-attention).

Under Condition 1, with probability at least 1 − 𝛿 , after Stage I ( 𝑇 1 ≤ 𝑡 ≤ 𝑇 𝑝 ):

1.

The transformer learns to filter out the irrelevant context token 𝐨 : For all ( 𝐨 , 𝐝 , 𝐬 , 𝐫 , 𝐚 ) ∈ 𝒯 , the attention scores7 satisfy:

𝛼 1 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 ] ) ≤ 1 𝑑 and 𝛼 1 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 ​ 𝐫 ] ) ≤ 1 𝑑 .

(15) 2.

The transformer’s attention scores on 𝐬 𝑗 and 𝐫 𝑖 are comparable: For all ( 𝐨 , 𝐝 , 𝐬 , 𝐫 , 𝐚 ) ∈ 𝒯 :

0.5 ≤ 𝛼 2 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 ​ 𝐫 ] ) 𝛼 3 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 ​ 𝐫 ] ) ≤ 2 .

(16) Proposition 2 (Training dynamics of MLP).

During PT, under Condition 1, with probability 1 − 𝛿 , the followings hold:

1.

After Stage I ( 𝑇 1 ≤ 𝑡 ≤ 𝑇 𝑝 ), the transformer learns to predict 𝐨 𝑙 → 𝐝 :

𝐱 ℐ ​ ( 𝐝 ) output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐨 )

Ω ~ ​ ( 1 ) .

(17) 2.

During Stage I and II ( 0 ≤ 𝑡 ≤ 𝑇 2 ), the transformer rapidly learns to predict 𝐨𝐬 𝑗 → 𝐫 𝑖 and 𝐨𝐬 𝑗 ​ 𝐫 𝑖 → 𝐚 𝑗 : For all ( 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) ∈ 𝒯 ,

1 − logit ℐ ​ ( 𝐚 𝑗 ) ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Θ ​ ( 1 ) ​  and  ​ 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) .

(18) 3.

At the end of Stage III, the pre-training loss converges to nearly optimal:

1 𝑇 𝑝 − 𝑇 2 ​ ∑ 𝑡

𝑇 2 + 1 𝑇 𝑝 ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) ≤ 𝐴 ( 2 − 𝜅 ) ​ 𝜂 3 ​ ( 𝑇 𝑝 − 𝑇 2 ) + 0.0005 + ( 1 + 𝜅 ) ​ 𝐻 ,

(19)

where 𝐴

Θ ~ ​ ( ( 𝑚 + 𝑁 2 / 𝜆 2 + | ℛ | 2 / 𝜆 2 ) / 𝜂 3 ) .

Propositions 1 and 2 elucidate the transformer’s functional division: self-attention filters irrelevant context (visualized in Appendix C), while the MLP memorizes the filtered features. The stages unfold as follows: Stage I rapidly establishes attention patterns to filter out noise with a relatively large learning rate 𝜂 1 ; Stage II develops subject and relation features with a moderate learning rate 𝜂 2 ; Stage III refines the loss to near-optimality with a small learning rate 𝜂 3 (and thus small 𝜅 ). This progression underscores the importance of learning rate scheduling in optimization.

Fine-Tuning.

During the FT phase, the transformer integrates the Q&A format embedding 𝐩 into its pre-learned features. We formalize this process below.

Proposition 3.

For all 𝑖 ∈ ℛ and 𝑗 ∈ [ 𝑁 ] , when 𝛽 ​ 𝑁 𝑓

𝐶 1 ​ | ℛ | , the gradient for 𝐖 in the first iteration of FT satisfies:

1 ​ ⟨ ∇ ¯ ℐ ​ ( 𝐫 𝑖 ) , 𝐩 ⟩

Θ ​ ( 𝜆 𝑚 ​ | ℛ | ) ,
2 ​ ⟨ ∇ ¯ ℐ ​ ( 𝐚 𝑗 ) , 𝐩 ⟩

− Θ ​ ( 𝜆 𝑚 ​ 𝛽 ​ 𝑁 𝑓 ) ,

3 ​ ⟨ ∇ ¯ ℐ ​ ( 𝐫 𝑖 ) , 𝐬 𝑗 ⟩

Θ ​ ( 𝜆 𝑚 ​ 𝛽 ​ 𝑁 𝑓 ) ,
4 ​ ⟨ ∇ ¯ ℐ ​ ( 𝐚 𝑗 ) , 𝐬 𝑗 ⟩

− Θ ​ ( 𝜆 𝑚 ​ 𝛽 ​ 𝑁 𝑓 ) ,

where ∇ ¯ ℐ ​ ( ⋅ )

1 / 𝑚 ​ ∑ 𝑙

1 𝑚 ∇ 𝐰 ℐ ​ ( ⋅ ) , 𝑙 ( 𝑇 𝑝 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 ) , 𝐙 ( 𝑇 𝑝 ) ) .

Proposition 3 reveals that when the FT dataset size 𝛽 ​ 𝑁 𝑓 is much larger than | ℛ | , the gradient exhibits low-rank structure, with 𝐩 components ( 1 , 2 ) dominating others ( 3 , 4 ), explaining low-rank FT efficacy. As gradients prioritize 𝐩 , relation neurons swiftly incorporate the format ( 1 prevails), enabling OOD generalization on unseen Q&A pairs.

Remark 4.

FT may induce overfitting to the format token 𝐩 : When the number of FT iterations satisfies 𝑇 𝑓

Ω ~ ​ ( 𝜆 ​ 𝛽 ​ 𝑁 𝑓 / | ℛ | ) , feature increments ⟨ ∇ ¯ ℐ ​ ( 𝐚 𝑗 ) , 𝐩 ⟩ accumulate, potentially yielding ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐩 ⟩ / 𝑚

Ω ​ ( ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐬 𝑗 ⟩ / 𝑚 ) for some unseen 𝑖 ∈ [ 𝑁 ] . Appendix D visualizes OOD accuracy versus FT steps.

7Experiments

In this section, we empirically validate our theoretical results through experiments on both synthetic and real-world datasets.

(a)Impact of the number of subject-answer pairs 𝑁 . (b)Impact of the size of relation phrases | ℛ | . Figure 2:OOD generalization accuracy of our simplified one-layer transformers with 5-token data. (a)Impact of the number of subject-answer pairs 𝑁 . (b)Impact of the size of relation phrases | ℛ | . Figure 3:OOD generalization accuracy of GPT-2 with 5-token data.

We generate a synthetic dataset adhering strictly to the process outlined in Section 3. Experiments are conducted using two architectures: (1) our simplified one-layer transformer (Section 2) and (2) a standard 12-layer, 12-head GPT-2 model [radford2019language]. Additional experimental details are provided in Appendix J. Notably, both models display similar performance trends, indicating that our simplified framework captures key dynamics of modern large language models (LLMs).

As illustrated in Figures 2(a) and 3(a), OOD generalization accuracy improves with increasing pre-training multiplicity 𝐾 and fine-tuning dataset size 𝛽 ​ 𝑁 𝑓 . Furthermore, Figures 2(b) and 3(b) show that accuracy declines as the size of the universal relation set | ℛ | grows. These trends align precisely with our theoretical predictions (Theorem 2).

7.1Real-World Dataset

To further assess the applicability of our theory to practical scenarios, we conducted experiments on modern LLMs (GPT-2 and Llama-3.2-1B [grattafiori2024llama]) with PopQA dataset [mallen2022not].

Pre-Train and Fine-Tune GPT-2.

We pre-train GPT-2 from scratch and subsequently fine-tune it on a curated subset of PopQA focused on movies and their directors. Pre-training data incorporates multiple phrasings per fact (e.g., “[movie] was directed by [director]” or “[movie] hails from [director]”), simulating multiplicity 𝐾 . Fine-tuning uses a question-answering format (e.g., “Question: Who is the director of [movie]? Answer:”). Data examples and training details are in Appendices K and J.3, respectively. Results in Figure 4 show that exact match accuracy on unseen facts rises with both fine-tuning dataset size and multiplicity 𝐾 , corroborating Theorem 2.

Figure 4:Exact match accuracy of fine-tuned Llama-3.2-1B on PopQA. Table 2:Exact match accuracy and F1 scores of Llama-3.2-1B fine-tuned with frequent and rare facts. FT Data Exact Match F1 score Frequent 26.36 28.68 Rare 20.93 22.74 Fine-Tune Llama-3.2-1B.

We further test our insights by fine-tuning the pre-trained Llama-3.2-1B model on a PopQA subset concerning capitals, formatted as questions (e.g., “Question: Where is the capital of [location]? Answer:”). Details are in Appendices L and J.4. Table 7.1 reveals superior exact match accuracy and F1 scores when fine-tuning on high-frequency data (higher Wikipedia page views) compared to low-frequency data, affirming our theoretical conditions for effective knowledge extraction.

8Conclusions

In this paper, we develop a theoretical framework to elucidate the mechanisms by which transformers acquire and extract knowledge. Our analysis demonstrates that transformers can attain near-optimal training loss, thereby enabling effective knowledge acquisition during pre-training. Furthermore, we delineate precise conditions under which transformers exhibit strong out-of-distribution generalization following full or low-rank fine-tuning, thus facilitating robust knowledge extraction. Empirical validations on synthetic datasets, as well as the real-world PopQA dataset using large language models such as GPT-2 and Llama-3.2-1B, corroborate our theoretical insights.

Although this work establishes foundational principles for knowledge acquisition and extraction in transformers, our derivations rely on idealized assumptions, including orthogonal embeddings and a block-diagonal structure for the second MLP layer, to mitigate feature interference. Contemporary multi-layer architectures likely encompass more intricate dynamics beyond those examined here. As such, investigating the roles of advanced components, such as multiple layers, attention heads, and optimizers like AdamW, in these processes remains a compelling avenue for future research.

Contents Appendix AAn NTP experiment

In this appendix section, we present additional experimental results for one-layer attention-only transformers. To replicate the realistic next-token prediction and isolate the effect of the contextual noise, we use a 3-token ’subject-relation-answer’ dataset as follows.

We conduct experiments with the following data distribution. We begin with a subject-answer set 𝒜 containing 𝑁 unique subject-answer pairs. Each pair associates with 𝐾 relation phrases. This setup mirrors data augmentation techniques like “Multi- 𝐾 ” in [allen2023physics1], where each fact is presented through multiple templates (relation phrases). The generation process for sentences in 𝒯 is as follows:

Let ℛ be a predefined set of relation phrase embeddings. For each subject-answer pair ( 𝐬 𝑗 , 𝐚 𝑗 ) ∈ 𝒜 :

  1. Sample a subset ℛ ~ ⊆ ℛ of 𝐾 relation phrase embeddings uniformly at random without replacement.
  2. Generate a total of 𝐾 tuples by iterating through each index 𝑖 in the set ℛ ~ and constructing a corresponding tuple of the form ( 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) .
  3. Convert each tuple ( 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) to NTP pre-training samples, ( 𝐬 𝑗 , ℐ ​ ( 𝐫 𝑖 ) ) and ( [ 𝐬 𝑗 ​ 𝐫 𝑖 ] , ℐ ​ ( 𝐚 𝑗 ) ) .
  4. Randomly select half tuples from 𝒯 and convert each of them to fine-tuning samples, ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) .

Based on the sentence set 𝒯 , we construct the pre-training dataset as follows: For each ( 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) ∈ 𝒯 :

  1. Convert each tuple ( 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) to NTP pre-training samples, ( 𝐬 𝑗 , ℐ ​ ( 𝐫 𝑖 ) ) and ( [ 𝐬 𝑗 ​ 𝐫 𝑖 ] , ℐ ​ ( 𝐚 𝑗 ) ) . We construct the fine-tuning dataset and the test dataset as: For each ( 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) ∈ 𝒯 :
  2. Randomly select half tuples from 𝒯 and convert each of them to fine-tuning samples, ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) .
  3. Convert each of the remaining tuples to test samples, ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) .

We assume all token embeddings— 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 , for all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℛ are generated from random Gaussian distributions 𝒩 ​ ( 0 , 𝜎 2 ​ 𝐈 ) . We set the parameters as 𝑁

1000 , 𝐾

5 , | ℛ |

5 , 𝜎

1 , 𝑑

50 . We pre-train the transformers (‘standard’ and ’Uniform attention + MLP’) with AdamW optimizer with learning rate 𝜂 ∈ { 1 ​ 𝑒 − 5 , 1 ​ 𝑒 − 4 , 1 ​ 𝑒 − 3 } , weight decay parameter 0.1, and full batch. We fine-tune the pre-trained transformers (‘standard’ and ’Uniform attention + MLP’) with Adam optimizer, using a learning rate 𝜂 ∈ { 1 ​ 𝑒 − 5 , 1 ​ 𝑒 − 4 } and a full batch. We pre-train the simplified transformer with the GD optimizer with a learning rate of 0.5. We fine-tune the pre-trained simplified transformer using the GD optimizer with a learning rate of 𝜂 𝑓

8 .

Performance of the trained models (until convergence) is shown in Table 3. Table 3 shows that attention-only transformers fail to acquire knowledge (achieve near-optimal training loss), implying that MLP is necessary for transformers to acquire knowledge during pre-training. Notably, even ’uniform attention + MLP’ can acquire and extract knowledge under this three-token dataset. To explore the function of the self-attention module, we extend the above three-token dataset to a five-token ’context-subject-relation-answer’ dataset, described in Section 3, under which ’uniform attention + MLP’ fails to acquire and extract knowledge while our simplified transformer can.

Table 3:Training loss and OOD generalization arg max accuracy on the above three-token ’subject-relation-answer’ dataset. We put the cross-entropy loss lower bound, the data entropy, in the bracket. The simplified model details are provided in Sections 2. We do not report the OOD accuracy when the training loss is large, as the extraction accuracy is somehow meaningless when the model cannot acquire the knowledge. Architecture Training loss OOD accuracy Self-attention + MLP (Standard) 0.8064 (0.8047) 100% Uniform attention + MLP 0.8068 (0.8047) 100% Attention-only 5.5875 (0.8047) - Simplified self-attention + MLP (Ours) 0.8200 (0.8047) 99% Appendix BAdditional related works Memorization of transformers.

Many recent works have studied the memorization capability of transformers. A line of work studied the memorization capacity of transformers. Yun et al. [yun2019transformers] demonstrated that transformers are universal approximators. Kim et al. [kim2023provable] proved that transformers can memorize sequence mappings of length- 𝑛

𝑑 -dimensional inputs with 𝑂 ~ ​ ( 𝑑 + 𝑛 + 𝑛 ​ 𝑁 ) parameters. Later, Kajitsuka and Sato [kajitsuka2024optimal] proved lower and upper bounds of the memorization capacity of transformers in next token prediction and sequence-to-sequence settings. Additionally, Mahdavi et al. [mahdavi2023memorization] studied the memorization capacity of multi-head attention-only transformers. Madden et al. [madden2024next] proved upper and lower bounds of memorization capacity of one-layer transformers in next token prediction. Another line of work studied associative memories in transformers. Some works showed that associative memories scale with model sizes for linear models [cabannes2023scaling] and one-layer transformers [nichani2024understanding]. Despite the powerful capacity shown in the above papers, it is unclear how the transformer learns to store and extract knowledge. In this paper, we answer this question by analyzing the training dynamics of transformers.

Feature learning.

A line of work theoretically studied the training dynamics of neural networks by analyzing the feature learning process. Feature learning demonstrates benign overfitting over several data structures [kou2023benign, xu2025rethinking] and reveals the benefits of knowledge distillation [allen2020towards], data augmentations [zou2023benefits, shen2022data], and mixture of experts [chen2022towards]. However, existing feature learning studies are currently limited to two-layer neural networks and attention-only transformers [jiang2024unveil, jelassi2022vision] in image classification tasks. In this paper, we develop a new theoretical framework that incorporates self-attention and MLP modules.

Knowledge acquisition and extraction of transformers.

Many studies explored the mechanistic interpretability of knowledge acquisition and extraction of transformers. A line of work tries to identify knowledge circuits in pre-trained LLMs. For example, Geva et al. [geva2020transformer] studied the function of MLP and highlighted that MLPs are key-value memories capturing input patterns. Based on this, Meng et al. [meng2022locating] localized factual associations in MLPs and proposed a knowledge editing approach. In contrast to knowledge circuits, Ghosal et al. [ghosal2024understanding] investigated the impact of subject entity frequency on knowledge extraction through fine-tuning, demonstrating that increasing frequency enhances the factuality of downstream tasks. Allen-Zhu and Li [allen2023physics1] designed controlled experiments to empirically identify the conditions under which pre-trained transformer models can perform OOD generalizations, highlighting the importance of knowledge augmentations. In our paper, we theoretically prove some of their findings and characterize conditions for transformers to perform OOD generalization.

Appendix CAttention score visualization

We visualize the attention heatmap in Figure 5. It verifies the characterization of the training dynamics in Section 6, which shows that attention scores on 𝐨 are small and attention scores on subject and relation tokens are comparable.

Figure 5:Attention heatmap after pre-training. Appendix DOOD generalization accuracy and FT steps

The visualization results for OOD generalization accuracy and FT steps are shown in Figure 6. The results show that the OOD accuracy after FT may decrease after a large number of training steps due to overfitting, indicating the effectiveness of early stopping in FT.

Figure 6:OOD generalization accuracy versus FT steps. Appendix EPreliminaries

This section provides the preliminaries essential for the proofs of our main results. We begin by deriving the explicit form of the model gradients with respect to its parameters and then present several supporting lemmas.

E.1Gradient computation

For simplicity, we first rewrite the model output 𝐱 output in (8) in the following form:

𝑥 𝑖 output ​ ( 𝐖 , 𝐙 , 𝐗 )

𝜆 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 𝑖 , 𝑘 , 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ⟩ ) ,

(20)

where 𝐰 𝑖 , 𝑘 is the 𝑘 𝑡 ​ ℎ neuron associated with the 𝑖 𝑡 ​ ℎ element of the final output in the MLP.

We denote

logit ( 𝑡 ) ​ ( 𝐗 )

softmax ​ ( 𝐱 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) ) .

(21)

Next, we present the model gradients in the following lemma.

Lemma 1.

Under the simplified model architecture, the gradients of model parameters 𝐖 and 𝐙 with respect to an example ( 𝐗 , 𝑦 ) with ‖ 𝐗 ​ [ − 1 ] ‖ 2

1 are

∇ 𝐰 𝑖 , 𝑘 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

− 𝜆 𝑚 ​ ( 𝕀 ​ ( 𝑖

𝑦 ) − logit 𝑖 ​ ( 𝐗 ) ) ​ 𝜎 ′ ​ ( 𝐰 𝑖 , 𝑘 ⊤ ​ 𝐱 𝑎 ) ​ 𝐱 𝑎 ,

(22)

and

( ∇ 𝐙 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) ​ 𝐗 ​ [ − 1 ] ) ⊤

(23)

=

− 𝜆 𝑑 ​ ( 𝐞 𝑦 − logit ​ ( 𝐗 ) ) ⊤ ​ 𝐖 2 ​ [ 𝜎 ′ ​ ( 𝐰 1 , 1 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 1 , 1

𝜎 ′ ​ ( 𝐰 1 , 𝑚 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 1 , 𝑚

𝜎 ′ ​ ( 𝐰 𝑑 , 𝑚 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 𝑑 , 𝑚 ] ​ 𝐗 ​ ( diag ​ ( 𝜶 ​ ( 𝐙 , 𝐗 ) ) − 𝜶 ​ ( 𝐙 , 𝐗 ) ​ 𝜶 ​ ( 𝐙 , 𝐗 ) ⊤ ) ​ 𝐗 ⊤ .

Proof.

For notational simplicity, we treat all derivatives in this proof as Jacobian matrices. Consequently, the gradient of a scalar function is represented as a row vector. To match the conclusions presented in Lemma 1 (which assumes column vectors), the final gradient vectors derived here must be transposed.

Our goal is to compute the gradients of the loss function ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) with respect to 𝐰 𝑖 , 𝑘 and 𝐙 . Using the following transformation:

∇ 𝐙 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

𝐗 ​ [ − 1 ] ​ ∇ 𝐙𝐗 ​ [ − 1 ] ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) .

(24)

and chain rule, we can break down the gradients into several components:

∇ 𝐰 𝑖 , 𝑘 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

∇ 𝐱 output ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) ​ ∇ 𝐰 𝑖 , 𝑘 𝐱 output ,

(25)

and

∇ 𝐙 ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

𝐗 ​ [ − 1 ] ​ ∇ 𝐱 output ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 ) ​ ∇ 𝐑 𝐱 output ​ ∇ 𝐱 𝑎 𝐑 ​ ∇ 𝐙𝐗 ​ [ − 1 ] 𝐱 𝑎 ,

(26)

where 𝐑 ≜ 𝜎 ​ ( 𝐖 1 ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) .

We then compute the components in (25) and (26) as follows:

∇ 𝐱 output ℒ ​ ( 𝐖 , 𝐙 , 𝐗 , 𝑦 )

− ( 𝐞 𝑦 − logit ​ ( 𝐗 ) ) ⊤ ,
(27)
∇ 𝐰 𝑖 , 𝑘 𝑥 𝑖 output

𝜆 𝑚 ​ 𝜎 ′ ​ ( ⟨ 𝐰 𝑖 , 𝑘 , 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ⟩ ) ​ 𝐱 𝑎 ⊤ ​ ( 𝐙 , 𝐗 ) ,
(28)
∇ 𝐑 𝐱 output

𝜆 ​ 𝐖 2 ,
(29)
∇ 𝐱 𝑎 𝐑

[ 𝜎 ′ ​ ( 𝐰 1 , 1 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 1 , 1

𝜎 ′ ​ ( 𝐰 1 , 𝑚 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 1 , 𝑚

𝜎 ′ ​ ( 𝐰 𝑑 , 𝑚 ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ) ​ 𝐰 𝑑 , 𝑚 ] ,
(30)
∇ 𝐙𝐗 ​ [ − 1 ] 𝐱 𝑎 ​ ( 𝐙 , 𝐗 )

∇ 𝜶 ​ ( 𝐙 , 𝐗 ) 𝐱 𝑎 ​ ( 𝐙 , 𝐗 ) ​ ∇ 𝐙𝐗 ​ [ − 1 ] 𝜶 ​ ( 𝐙 , 𝐗 )

(31)

=

1 𝑑 ​ 𝐗 ​ ( diag ​ ( 𝜶 ​ ( 𝐙 , 𝐗 ) ) − 𝜶 ​ ( 𝐙 , 𝐗 ) ⋅ 𝜶 ​ ( 𝐙 , 𝐗 ) ⊤ ) ​ 𝐗 ⊤ .

(32)

Combining them and transposing the derived final gradient vectors completes the proof. ∎

E.2Notations

For notational convenience in the subsequent proofs, we define the following shorthand at iteration 𝑡 . First, let:

𝜻 ( 𝑡 ) ​ ( 𝐗 ) := 𝐖 2 ​ [ 𝜎 ′ ​ ( ( 𝐰 1 , 1 ( 𝑡 ) ) ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) ) ​ 𝐰 1 , 1 ( 𝑡 )

𝜎 ′ ​ ( ( 𝐰 1 , 𝑚 ( 𝑡 ) ) ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) ) ​ 𝐰 1 , 𝑚 ( 𝑡 )

𝜎 ′ ​ ( ( 𝐰 𝑑 , 𝑚 ( 𝑡 ) ) ⊤ ​ 𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) ) ​ 𝐰 𝑑 , 𝑚 ( 𝑡 ) ] .

(33)

Then, for all 𝑗 ∈ [ 𝑁 ] , 𝑖 ∈ ℛ , we define:

Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) :=

𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ] ) ,

(34)

Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) :=

𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ] ) − 𝐬 𝑗 ,

(35)

Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) :=

𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ,

(36)

Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) :=

𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) − 𝐫 𝑖 ,

(37)

𝜶 ( 𝑡 ) ​ ( 𝐗 ) :=

𝜶 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) .

(38)

For any 𝐛 1 , 𝐛 2 ∈ { 𝐨 } ∪ { 𝐝 } ∪ { 𝐬 𝑗 } 𝑗

1 𝑁 ∪ { 𝐫 𝑖 } 𝑖 ∈ ℛ ∪ { 𝐚 𝑗 } 𝑗

1 𝑁 , we define:

𝑍 ~ ℐ ​ ( 𝐛 𝟏 ) , ℐ ​ ( 𝐛 𝟐 ) ( 𝑡 )

𝐛 1 ⊤ ​ 𝐙 ( 𝑡 ) ​ 𝐛 2 .

(39) E.3Lemmas

To facilitate the analysis, we present the following useful lemmas.

Lemma 2.

A variable 𝑥 ∼ 𝒩 ​ ( 0 , 𝜎 0 ) satisfies

ℙ ​ [ | 𝑥 |

𝑞 ] ≤ 2 ​ exp ⁡ ( − 𝑞 2 2 ​ 𝜎 0 2 ) .

(40) Lemma 3.

For a variable 𝑥 ∼ 𝒩 ​ ( 0 , 𝜎 0 2 ) and 𝑡

0 , we have

ℙ ​ [ | 𝑥 |

𝑡 ] ≥ 1 − 2 ​ 𝑡 2 ​ 𝜋 ​ 𝜎 0 .

(41) Proof.

As the Gaussian probability density function 𝑓 ​ ( ⋅ ) for 𝒩 ​ ( 0 , 𝜎 0 2 ) satisfies 𝑓 ​ ( 𝑥 ) ≤ 1 / 2 ​ 𝜋 ​ 𝜎 0 , we have

ℙ ​ [ | 𝑥 | ≤ 𝑡 ] ≤ 2 ​ 𝑡 2 ​ 𝜋 ​ 𝜎 0 .

(42)

As a result, we have

ℙ ​ [ | 𝑥 |

𝑡 ] ≥ 1 − 2 ​ 𝑡 2 ​ 𝜋 ​ 𝜎 0 .

(43)

This finishes the proof. ∎

Lemma 4.

Suppose 𝛿

0 and each element in a matrix 𝐙 ∈ ℝ 𝑑 × 𝑑 is generated following 𝒩 ​ ( 0 , 𝜎 0 2 ) . With probability 1 − 𝛿 , for each 𝑖 , 𝑗 ∈ [ 𝑑 ] , we have

| 𝑍 𝑖 , 𝑗 | ≤ 2 ​ log ⁡ ( 2 ​ 𝑑 2 / 𝛿 ) ​ 𝜎 0 .

(44) Proof.

By Lemma 2, for each 𝑖 , 𝑗 ∈ [ 𝑑 ] , with probability 1 − 𝛿 / 𝑑 2 , we have that

| 𝑍 𝑖 , 𝑗 | < 2 ​ log ⁡ ( 2 ​ 𝑑 2 / 𝛿 ) ​ 𝜎 0 .

(45)

Using the Union bound over all 𝑖 , 𝑗 ∈ [ 𝑑 ] , we finish the proof. ∎

Lemma 5.

Suppose 𝑑

0 and there are two variables 𝑥 and 𝑦 satisfying 𝑦 ≥ 2 ​ 𝑥 . The following holds:

exp ⁡ ( 𝑥 ) exp ⁡ ( 𝑥 ) + exp ⁡ ( 𝑦 ) + 𝑑 ≤ 1 𝑑 .

(46) Proof.

If exp ⁡ ( 𝑦 ) ≤ 𝑑 , we have

exp ⁡ ( 𝑥 ) ≤ 𝑑 ,

(47)

resulting in

exp ⁡ ( 𝑥 ) exp ⁡ ( 𝑥 ) + exp ⁡ ( 𝑦 ) + 𝑑 ≤ 1 𝑑 .

(48)

If exp ⁡ ( 𝑦 )

𝑑 , we have

exp ⁡ ( 𝑥 ) exp ⁡ ( 𝑥 ) + exp ⁡ ( 𝑦 ) + 𝑑 ≤ exp ⁡ ( 𝑥 ) exp ⁡ ( 𝑥 ) + exp ⁡ ( 𝑦 ) ≤ 1 𝑑 .

(49)

This completes the proof. ∎

Lemma 6.

Suppose at each time 𝑖 ∈ [ 𝑛 ] , the set 𝒜 𝑖 contains 𝐾 entries that are i.i.d. and uniformly selected without replacement from the set [ 𝑞 ] . For all 𝑗 ∈ [ 𝑞 ] , the following holds:

ℙ ​ [ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑗 ∈ 𝒜 𝑖 ) ≥ 𝐾 2 ​ 𝑞 ​ 𝑛 ] ≥ 1 − exp ⁡ ( − 𝑛 ​ 𝐾 2 2 ​ 𝑞 2 ) .

(50) Proof.

In each time 𝑖 ∈ [ 𝑛 ] , a number 𝑗 ∈ [ 𝑞 ] has probability 𝐾 / 𝑞 to selected. By Hoeffding’s inequality, for any constant 𝑡

0 , we have:

ℙ ​ [ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑗 ∈ 𝒜 𝑖 ) − 𝐾 𝑞 ​ 𝑛 ≤ − 𝑡 ] ≤ exp ⁡ ( − 2 ​ 𝑡 2 𝑛 ) .

(51)

Setting − 𝑡 to − 𝐾 2 ​ 𝑞 ​ 𝑛 , we have:

ℙ ​ [ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑗 ∈ 𝒜 𝑖 ) − 𝐾 𝑞 ​ 𝑛 ≤ − 𝐾 2 ​ 𝑞 ​ 𝑛 ] ≤ exp ⁡ ( − 𝑛 ​ 𝐾 2 2 ​ 𝑞 2 ) .

(52)

Reorganizing the results, we prove:

ℙ ​ [ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑗 ∈ 𝒜 𝑖 ) ≥ 𝐾 2 ​ 𝑞 ​ 𝑛 ] ≥ 1 − exp ⁡ ( − 𝑛 ​ 𝐾 2 2 ​ 𝑞 2 ) .

(53)

Lemma 7.

With probability 1 − 𝛿 , for all 𝑖 ∈ [ 𝑚 ​ 𝑑 ] and 𝑗 ∈ [ 𝑑 ] , we have

| 𝑤 𝑖 , 𝑗 ( 0 ) | ≤ 2 ​ log ⁡ ( 2 ​ 𝑚 ​ 𝑑 2 / 𝛿 ) ​ 𝜎 0 .

(54) Proof.

For all 𝑖 ∈ [ 𝑚 ​ 𝑑 ] and 𝑗 ∈ [ 𝑑 ] , we have

ℙ ​ [ | 𝑤 𝑖 , 𝑗 ( 0 ) |

𝑞 ] < 2 ​ exp ⁡ ( − 𝑞 2 2 ​ 𝜎 0 2 ) .

(55)

By union bound, with probability 1 − 𝛿 , we have

| 𝑤 𝑖 , 𝑗 ( 0 ) | < 2 ​ log ⁡ ( 2 ​ 𝑚 ​ 𝑑 2 / 𝛿 ) ​ 𝜎 0 .

(56)

Lemma 8.

For a variable 𝑥 ∼ 𝒩 ​ ( 0 , 𝜎 0 2 ) , we have

ℙ ​ [ 0 < 𝑥 < 𝑞 ] ≤ 𝑞 2 ​ 𝜋 ​ 𝜎 0 .

(57) Proof.

As the probability density function of 𝑥 is

𝑓 ​ ( 𝑥 )

1 2 ​ 𝜋 ​ 𝜎 0 ​ exp ⁡ ( − 𝑥 2 2 ​ 𝜎 0 2 ) ≤ 1 2 ​ 𝜋 ​ 𝜎 0 .

(58)

Therefore, we have

ℙ ​ [ 0 < 𝑥 < 𝑞 ]

∫ 0 𝑞 𝑓 ​ ( 𝑥 ) ​ 𝑑 𝑥 ≤ 𝑞 2 ​ 𝜋 ​ 𝜎 0 .

(59)

Lemma 9.

In each of a total of 𝑛 independent trials, we randomly select a subset 𝒳 𝑖 of 0.4 ​ 𝑚 distinct elements from [ 𝑚 ] uniformly at random (without replacement within a trial). When 𝑛 ≥ 50 ​ log ⁡ ( 𝑚 / 𝛿 ) , with probability 1 − 𝛿 , for all 𝑙 ∈ [ 𝑚 ] we have

∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑙 ∈ 𝒳 𝑖 ) ≥ 0.3 ​ 𝑛 .

(60) Proof.

For all 𝑙 ∈ [ 𝑚 ] , we have

𝔼 ​ [ 1 𝑛 ​ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑙 ∈ 𝒳 𝑖 ) ]

0.4 .

(61)

By Hoeffding’s inequality, we have

ℙ ​ [ 1 𝑛 ​ ∑ 𝑖

1 𝑛 𝕀 ​ ( 𝑙 ∈ 𝒳 𝑖 ) ≤ 0.3 ] ≤ exp ⁡ ( − 𝑛 50 ) .

(62)

Using union bound for all 𝑙 ∈ [ 𝑚 ] , we finish the proof. ∎

Appendix FProofs for the pre-training phase

For the sake of indexing the appeared subjects and answers with relations, we first define the following sets:

ℬ 𝑗 := { 𝑖 ∈ ℛ : ( 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑖 , 𝐚 𝑗 ) ∈ 𝒯 } ,

(63)

and

𝒟 𝑙 := { 𝑗 ∈ [ 𝑁 ] : ( 𝐨 , 𝐝 , 𝐬 𝑗 , 𝐫 𝑙 , 𝐚 𝑗 ) ∈ 𝒯 } .

(64) F.1Activation patterns at initialization

In this section, we characterize the activation patterns, i.e., which neurons in the MLP are activated at initialization. The latter results are built upon conclusions in this subsection, which hold with high probability.

For convenience, we define the following sets for indexing the activated neurons:

𝒮 s , 𝑘 , 𝑗 ( 𝑡 ) := { 𝑖 ∈ [ 𝑚 ] : ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 𝑡 ) , 𝐬 𝑗 ⟩

0 } ,

(65)

𝒮 r , 𝑗 , 𝑘 ( 𝑡 ) := { 𝑖 ∈ [ 𝑚 ] : ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑖 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 } ,

(66)

and

𝒮 o ( 𝑡 ) := { 𝑖 ∈ [ 𝑚 ] : ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑖 ( 𝑡 ) , 2 ​ 𝐨 ⟩

0 } .

(67)

Now, we characterize the activation pattern at initialization.

Lemma 10.

Suppose that 𝛿

0 and 𝑚 ≥ 50 ​ log ⁡ ( 6 ​ 𝑁 ​ 𝐾 / 𝛿 ) . For all 𝑗 ∈ [ 𝑁 ] and 𝑘 ∈ ℬ 𝑗 , with probability at least 1 − 𝛿 , the followings hold:

| 𝒮 s , 𝑘 , 𝑗 ( 0 ) | ≥ 0.4 ​ 𝑚 , | 𝒮 r , 𝑗 , 𝑘 ( 0 ) | ≥ 0.4 ​ 𝑚 , and | 𝒮 o ( 0 ) | ≥ 0.4 ​ 𝑚 .

(68) Proof.

At initialization, we have

ℙ ​ [ ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 0 ) , 𝐬 𝑗 ⟩ > 0 ]

1 / 2 , ℙ ​ [ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑖 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩ ]

1 / 2 , and ℙ ​ [ ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑖 ( 0 ) , 2 ​ 𝐨 ⟩ ]

1 / 2 .

(69)

By Hoeffding’s inequality, with probability 1- 𝛿 / ( 3 ​ 𝑁 ​ 𝐾 ) , we have

| | 𝒮 s , 𝑘 , 𝑗 ( 0 ) | 𝑚 − 1 2 | ≤ log ⁡ ( 6 ​ 𝑁 ​ 𝐾 / 𝛿 ) 2 ​ 𝑚 .

(70)

Similarly, with probability 1- 𝛿 / ( 6 ​ 𝑁 ​ 𝐾 ) , we have

| | 𝒮 r , 𝑗 , 𝑘 ( 0 ) | 𝑚 − 1 2 | ≤ log ⁡ ( 6 ​ 𝑁 ​ 𝐾 / 𝛿 ) 2 ​ 𝑚 .

(71)

Additionally, with probability 1- 𝛿 / 3 , we have

| | 𝒮 o ( 0 ) | 𝑚 − 1 2 | ≤ log ⁡ ( 6 / 𝛿 ) 2 ​ 𝑚 .

(72)

As long as 𝑚 ≥ 50 ​ log ⁡ ( 6 ​ 𝑁 ​ 𝐾 / 𝛿 ) , by union bound over all the combinations of 𝑗 ∈ [ 𝑁 ] and 𝑘 ∈ ℬ 𝑗 , we have

| 𝒮 s , 𝑘 , 𝑗 ( 0 ) | ≥ 0.4 ​ 𝑚 , | 𝒮 r , 𝑗 , 𝑘 ( 0 ) | ≥ 0.4 ​ 𝑚 , and | 𝒮 o ( 0 ) | ≥ 0.4 ​ 𝑚 ,

(73)

for all 𝑗 ∈ [ 𝑁 ] and 𝑘 ∈ ℬ 𝑗 , with probability at least 1 − 𝛿 . This finishes the proof. ∎

The analysis of activation patterns throughout the entire pre-training dynamics is not a separate proof but is integrated into the proofs for Stages 1, 2, and 3, which are detailed in the following subsections.

F.2Proof of pre-training Stage 1

In this subsection, we prove several key properties of pre-training Stage 1, including the attention scores, activation patterns, and MLP feature learning.

Lemma 11.

Under Condition 1, with probability 1 − 𝛿 , the followings hold in pre-training Stage 1 ( 𝑡 ≤ 𝑇 1 ) with 𝑇 1

2 .

1.

We have 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩

Θ ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) / 𝜆 ) .

2.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Θ ​ ( 1 ) , and 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) .

3.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have 0.8 ≤ 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) / 𝛼 3 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1.2 .

4.

At the end of Stage 1, for all 𝑗 ∈ [ 𝑁 ] , 𝑖 ∈ ℬ 𝑗 , we have

𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 4 ​ 𝑑 5 / 4 ​  and  ​ 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 4 ​ 𝑑 5 / 4 .

5.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , we have ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ) ⟩

0 .

6.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 r , 𝑗 , 𝑘 ( 0 ) , we have ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 .

Proof.

In this proof, we establish each of the six statements in Lemma 11 sequentially.

Proof of the first statement: During the first iteration, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 1 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , 2 ​ 𝐨 ⟩

(74)

= ( 𝑎 )
Θ ( 1 ) ⋅ ( 𝜂 1 𝑛 𝑛 3 ​ 𝑚 2 ∑ 𝑘

1 𝑚 ∑ 𝑖

1 𝑛 ( 1 − logit ℐ ​ ( 𝐝 ) ( 0 ) ( 2 𝐨 ) ) 𝜎 ′ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , 2 𝐨 ⟩ ) 𝜆 ∥ 2 𝐨 ∥ 2 2

− 2 ​ 𝜂 1 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑚 2 ​ 𝑛 ​ ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 1 ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ​ 𝜆 ​ ‖ 𝐨 ‖ 2 2

− 𝜂 1 𝑚 2 ​ 𝑛 ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐝 ) ( 0 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) 𝛼 1 ( 0 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) 𝜎 ′ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , Ξ ( 0 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ⟩ ) 𝜆 ∥ 𝐨 ∥ 2 2 )

( 𝑏 )

Θ ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 ) ,

where ( 𝑎 ) follows Lemmas 1, and ( 𝑏 ) holds because the initial logit terms satisfy 1 − logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( 𝐨 )

Θ ​ ( 1 ) , logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

𝒪 ​ ( 1 / 𝑑 ) , and ​ logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

𝒪 ​ ( 1 / 𝑑 ) and Lemma 10 which ensures the number of activated neurons are Θ ​ ( 𝑚 ) .

In the second iteration, we consider two cases.

Case 1: If 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 2 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 1 ) , 2 ​ 𝐨 ⟩ ≥ 0 , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 2 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 1 ) , 2 ​ 𝐨 ⟩

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 ) ,

(75)

with similar derivations as (74).

Case 2: While if 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 2 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 1 ) , 2 ​ 𝐨 ⟩ < 0 , by Lemma 5, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 2 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 1 ) , 2 ​ 𝐨 ⟩

(76)


− 2 ​ 𝜂 1 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑚 2 ​ 𝑛 ​ Θ ​ ( 1 ) ​ ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 1 ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ​ 𝜆 ​ ‖ 𝐨 ‖ 2 2

− 𝜂 1 𝑚 2 ​ 𝑛 ​ Θ ​ ( 1 ) ​ ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐝 ) ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 1 ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ​ 𝜆 ​ ‖ 𝐨 ‖ 2 2

− 𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 ) .

Combining (74), (75) and (76), we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 0 ) , 2 ​ 𝐨 ⟩

Θ ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 )

Θ ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) 𝜆 ) .

(77)

Therefore, we can conclude that

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩

− 𝒪 ~ ​ ( 𝑑 ​ 𝜎 0 / 𝑚 ) + Θ ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) 𝜆 )

Θ ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) 𝜆 ) .

(78)

Proof of the second statement: For 𝑡 ≤ 𝑇 1 , for all 𝑘 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] , and 𝑖 ∈ ℬ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

(79)

= ( 𝑎 )
𝒪 ( 𝜂 1 𝑚 ​ 𝑛 𝜆 ∥ 𝐨 + 2 𝐬 𝑗 ∥ 2 2 ) + Θ ( 𝑑 ) ⋅ 𝜆 ​ 𝜂 1 𝑛 ​ 𝑚 ( 1 − 𝐾 ~ ( 𝑗 ) logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] )

( 𝑏 )

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 𝑚 ​ 𝑛 ) ,

where ( 𝑎 ) is obtained by Lemmas 1 and 15, Cauchy–Schwarz inequality and the fact that | ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 + 1 ) ( [ 𝐨 𝐬 𝑗 ] ) − Ξ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ⟩ |

𝒪 ( 𝜆 𝜂 1 ∥ 𝐨 + 2 𝐬 𝑗 ∥ 2 2 / ( 𝑛 𝑚 ) ) . ( 𝑏 ) is due to the fact that ‖ 𝐨 + 2 ​ 𝐬 𝑗 ‖ 2 2

𝒪 ​ ( 𝑑 ) and 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 .

As there are at most 𝑇 1 iterations during Stage 1, for 𝑡 < 𝑇 1 , using (79), we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 𝑚 ​ 𝑛 ​ 𝑇 1 )

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 𝑚 ​ 𝑛 ) .

(80)

By (74) and (80), the results 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩

Θ ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 / 𝑚 ) and 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

𝒪 ​ ( 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ / 𝑛 ) hold. Thus, we have

1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

1 − 𝒪 ​ ( 𝐾 ~ ​ ( 𝑗 ) 𝑑 1 − 1 / 𝑛 )

Θ ​ ( 1 ) .

(81)

In addition, the increment of ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ at iterations 𝑡 satisfies

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

(82)

= ( 𝑎 )
𝒪 ​ ( 𝜂 1 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 𝜆 ​ ‖ 𝐨 + 𝐬 𝑗 + 2 ​ 𝐫 𝑖 ‖ 2 2 ) + 𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 ( 1 − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) )

( 𝑏 )

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ) ,

where ( 𝑎 ) is by Lemma 1 and the facts that 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 and | ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) − Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ |

𝒪 ​ ( 𝜂 1 / ( 𝑛 ​ 𝑚 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 ‖ 𝐨 + 𝐬 𝑗 + 2 ​ 𝐫 𝑖 ‖ 2 2 ) for 𝑡 ≤ 𝑇 1 . ( 𝑏 ) is because of 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 and ‖ 𝐨 + 𝐬 𝑗 + 2 ​ 𝐫 𝑖 ‖ 2 2

𝒪 ​ ( 𝑑 ) .

After 𝑇 1 iterations, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ) .

(83)

Then, as 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩

Θ ​ ( 𝜆 ​ 𝜂 1 ​ 𝑑 / 𝑚 ) and 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

= 𝒪 ​ ( 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ / 𝑛 ) , we have

1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

1 − 𝒪 ​ ( 𝑑 𝐾 ~ ​ ( 𝑗 ) / 𝑛 𝑑 )

Θ ​ ( 1 ) .

(84)

Proof of the third statement: For all 𝑡 ∈ [ 𝑇 1 ] , we have

𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 , 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 , 𝐫 𝑖 ] )

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 ) exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 )

exp ⁡ ( ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ) / 𝑑 ) .

(85)

For all 𝑡 ∈ [ 𝑇 1 ] , the increments of the terms 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) and 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) satisfy

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(86)

=
𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) ​  by ( 77 )

𝒪 ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 𝑛 ​ 𝑚 ) ,

and

| 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(87)

=
𝜂 1 ​ 𝜆 𝑛 ​ 𝑑

∑ 𝑗 ∈ 𝒟 𝑖 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐫 𝑖 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) ​  by ( 77 )

𝒪 ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 𝑚 ) .

Furthermore, combining (86), (87), and the fact that | 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) |

𝒪 ~ ​ ( 𝜎 0 ) by Lemma 7, we have

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) | 𝑑

(88)


∑ 𝑡

1 2 | 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 − 1 ) | + | 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) | + ∑ 𝑡

1 2 | 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 − 1 ) | 𝑑

𝒪 ​ ( 𝜂 1 2 ​ 𝜆 2 / 𝑚 ) .

Under Condition 1, we have

exp ⁡ ( ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) ) / 𝑑 )

exp ⁡ ( 𝒪 ​ ( 𝜂 1 2 ​ 𝜆 2 ​ 𝑇 1 / 𝑚 ) ) ≤ 1.2 .

(89)

Using similar proof, we also have

exp ⁡ ( ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 1 ) ) / 𝑑 ) ≥ 0.8 .

(90)

Proof of the fourth statement: First, we prove the first part that 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 / 4 ​ 𝑑 5 / 4 .

The increments of 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) in the first and the second iterations satisfy

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 0 ) |

(91)


𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 1 ( 0 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟ ≤ 1 ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐨 − Ξ ~ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ~ ​ ( 𝜎 0 ​ 𝑑 / 𝑚 ) ​  by Lemma  7 ⋅ Θ ​ ( 𝑑 )

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝜎 0 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) ,

and

𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 1 )

(92)

=
𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 1 ( 1 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟

Θ ​ ( 1 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐨 − Ξ ~ ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⋅ Θ ​ ( 𝑑 )

( 𝑎 )

− Θ ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) ,

where ( 𝑎 ) holds because according to (77), 𝜁 ℐ ​ ( 𝐝 ) ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Ω ​ ( log ⁡ ( 𝑑 ) ) holds, resulting in logit ℐ ​ ( 𝐝 ) ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) , which inturn yields ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐨 − Ξ ~ ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) )

− Θ ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) .

Additionally, the increment of 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) for 𝑡 ∈ [ 𝑇 1 ] satisfies

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) |

(93)


𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟

Θ ​ ( 1 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) ​  by ( 77 )

𝒪 ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) .

Combining (91), (92) (93) and the fact that | 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 0 ) |

𝒪 ~ ​ ( 𝜎 0 ) by Lemma 7, we have

𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 2 ) 𝑑

(94)


∑ 𝑡

1 2 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 − 1 ) + | 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 0 ) | + ∑ 𝑡

1 2 | 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 − 1 ) | 𝑑

− Θ ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) .

Under Condition 1, we have

exp ⁡ ( ( 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 2 ) ) / 𝑑 ) ≤ exp ⁡ ( − Θ ​ ( 𝜂 1 2 ​ 𝜆 2 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) ) ≤ 1 4 ​ 𝑑 5 / 4 .

(95)

Next, we prove the second conclusion that 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 / 4 ​ 𝑑 5 / 4 . The increments of 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) in the first and the second iterations satisfy

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) |

(96)

=
𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 0 ) ​ ( [ 𝐨 , 𝐬 𝑗 , 𝐫 𝑖 ] ) ⏟

Θ ​ ( 1 )

⋅ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ~ ​ ( 𝜎 0 ​ 𝑑 / 𝑚 ) ​  by Lemma  7 ⋅ Θ ​ ( 𝑑 )

𝒪 ​ ( 𝜆 ​ 𝜂 1 ​ 𝜎 0 ​ 𝑑 𝑛 ​ 𝑚 ) ,

and

𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 1 )

(97)

=
𝜂 1 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟

Θ ​ ( 1 )

⋅ ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⋅ Θ ​ ( 𝑑 )

( 𝑎 )

− Ω ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 𝑛 ​ 𝑚 ) ,

where ( 𝑎 ) holds because according to (77), 𝜁 ℐ ​ ( 𝐝 ) ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Ω ​ ( log ⁡ ( 𝑑 ) ) holds, resulting in logit ℐ ​ ( 𝐝 ) ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Θ ​ ( 1 ) , which in turn yields ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) )

− Θ ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) .

Consequently, combining (96), (97), (86), and the fact that | 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑗 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) |

𝒪 ~ ​ ( 𝜎 0 ) by Lemma 7, we have

𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) 𝑑

(98)


∑ 𝑡

1 2 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 − 1 ) + | 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 0 ) | + ∑ 𝑡

1 2 | 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 − 1 ) | 𝑑

− Θ ​ ( 𝜆 2 ​ 𝜂 1 2 ​ 𝑑 ​ 𝐾 𝑛 ​ 𝑚 ) .

Under condition 1, we have

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 2 ) / 𝑑 )

exp ⁡ ( − Θ ​ ( 𝜂 1 2 ​ 𝜆 2 ​ 𝑑 𝑛 ​ 𝑚 ) ) ≤ 1 4 ​ 𝑑 5 / 4 .

(99)

Combining (99) with (89), (90) yields the conclusion.

Proof of the fifth statement: By Lemma 3, for 𝑘 ∈ ℬ 𝑗 and 𝑖 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , we have

ℙ ​ [ ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 0 ) , 𝐬 𝑗 ⟩ ≥ Θ ​ ( 𝜆 𝑛 ​ 𝑚 ​ 𝑑 ) ] ≥ 1 − Θ ​ ( 𝜆 𝑛 ​ 𝑚 ​ 𝑑 ​ 𝜎 0 ) .

(100)

In Stage 1, by the gradient update, if a neuron 𝑖 ∈ [ 𝑚 ] satisfies

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ≤ 0 ​  or  ​ ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 0 ) , Ξ ( 0 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥ 0 ,

(101)

we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 𝑡 ) , 𝐬 𝑗 ⟩

− 𝒪 ​ ( 𝜆 ​ 𝜂 1 𝑛 ​ 𝑚 ​ 𝑑 ) .

(102)

Using the union bound over 𝑖 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , with probability 1 − Θ ​ ( 𝜆 / ( 𝑛 ​ 𝑑 ) ) , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 0 ) , 𝐬 𝑗 ⟩ ≥ Θ ​ ( 𝜆 𝑛 ​ 𝑚 ​ 𝑑 ) .

(103)

Combining (102) and (103), for all 𝑡 ∈ [ 𝑇 1 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑖 ( 𝑡 ) , 𝐬 𝑗 ⟩ ≥ 0 .

(104)

Proof of the sixth statement: Based on Lemma 10, for all 𝑙 ∈ 𝒮 r , 𝑗 , 𝑘 ( 0 ) , the updates of ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩ and ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ satisfy:

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐫 𝑘 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩

(105)

  • 𝜆 ​ 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ) − 2 ​ 𝜆 ​ ∑ 𝑖 ≠ 𝑗 , 𝑖 ∈ 𝒟 𝑘 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑖 ​ 𝐫 𝑘 ] ) ⏟ Δ 1 , 𝑡 ,

and

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩

(106)

− 𝜆 ​ 𝜂 ( 𝑡 ) ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) + 𝜆 ​ Θ ​ ( 1 ) ⋅ 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ ∑ 𝑘 ∈ ℬ 𝑗 ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ) ⏟ Δ 2 , 𝑡 ,

by the fact that ‖ 𝐫 𝑘 ‖ 2 2 ≤ ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) , 𝐫 𝑘 ⟩ ≤ 2 ​ ‖ 𝐫 𝑘 ‖ 2 2 .

Similarly, for ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ ≥ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩

(107)

− 𝜆 ​ 𝜂 ( 𝑡 ) ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝑑 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) + 𝜆 ​ Θ ​ ( 1 ) ⋅ 𝜂 ( 𝑡 ) ​ 𝑑 𝑛 ​ 𝑚 ​ ∑ 𝑘 ∈ ℬ 𝑗 ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ) − 𝒪 ​ ( 𝜆 ​ 𝜂 ( 𝑡 ) 𝑚 ​ 𝑑 1 / 2 ) ⏟ Δ 𝑡 𝑜 ,

When 1 ≤ 𝑡 ≤ 𝑇 1 , ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩ , since 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] )

Θ ​ ( 1 ) , we have Δ 1 , 𝑡

0 , Δ 2 , 𝑡

0 , and Δ 𝑡 𝑜

0 . Therefore, the inner products ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ and ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ remain positive. As a result, for all 𝑡 ∈ [ 𝑇 1 ] and 𝑙 ∈ 𝒮 r , 𝑗 , 𝑘 ( 0 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 .

(108)

This completes the proof. ∎

F.3Proof of pre-training Stage 2

In this subsection, we will prove several key properties of pre-training Stage 1, including the attention scores, activation patterns, and MLP feature learning. First, we prove the following lemma.

Lemma 12.

For all 𝑡 ∈ [ 0 , 𝑇 𝑝 ] , the following holds:

For all 𝑗 ∈ [ 𝑁 ] , 𝑙 ∈ [ 𝑚 ] and 𝑘 ≠ 𝑗 , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(109) •

For all 𝑗 ∈ [ 𝑁 ] , 𝑙 ∈ [ 𝑚 ] and 𝑖 ∉ ℬ 𝑗 , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑖 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(110) •

For all 𝑖 , 𝑘 ∈ ℛ , 𝑙 ∈ [ 𝑚 ] , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(111) Proof.

We prove these three statements sequentially.

Proof of the first statement: During pre-training 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , for all 𝑙 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] , and 𝑘 ≠ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑘 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑘 ⟩

(112)

=

− ∑ 𝑖 ∈ ℬ 𝑘 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ​ 𝐫 𝑖 ] ) ​ 𝜂 ( 𝑡 ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ​ 𝐫 𝑖 ] ) ⟩

0 ) ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ​ 𝐫 𝑖 ] )

− ∑ 𝑖 ∈ ℬ 𝑘 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ] ) ​ 𝜂 ( 𝑡 ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ] ) ⟩

0 ) ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ] )

0 .

Therefore, for all 𝑙 ∈ [ 𝑚 ] , 𝑡 ∈ [ 𝑇 𝑝 ] and 𝑗 ∈ [ 𝑁 ] , 𝑘 ≠ 𝑗 , by Lemma 7, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑘 ⟩ ≤ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 0 ) , 𝐬 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(113)

Proof of the second statement: During pre-training 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , for all 𝑙 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] , and 𝑖 ∉ ℬ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐫 𝑖 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑖 ⟩

(114)

=

− ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝜂 ( 𝑡 ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

0 ) ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

0 .

Hence, for all 𝑙 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] and 𝑖 ∉ ℬ 𝑗 , by Lemma 7, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑖 ⟩ ≤ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 0 ) , 𝐫 𝑖 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(115)

Proof of the third statement: During Pre-training 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , for all 𝑖 , 𝑘 ∈ ℛ and 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐫 𝑘 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩

(116)

=

− ∑ 𝑗 ∈ 𝒟 𝑘 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ​ 𝜂 ( 𝑡 ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] )

0 .

Hence, for all 𝑙 ∈ [ 𝑚 ] and 𝑖 , 𝑘 ∈ ℛ , by Lemma 7, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑘 ⟩ ≤ ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 0 ) , 𝐫 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(117)

This completes the proof. ∎

Now, we proceed to derive the key properties in Stage 2.

Lemma 13.

Under Condition 1, with probability 1 − 𝛿 , in pre-training Stage 2 ( 𝑇 1 ≤ 𝑡 ≤ 𝑇 2

Θ ​ ( 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) / ( 𝜆 2 ​ 𝜂 2 ​ 𝐾 ) ) ), the followings hold:

1.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Θ ​ ( 1 ) and 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) .

2.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have 0.7 ≤ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) / 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1.6 .

3.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have

𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 2 3 ​ 𝑑 5 / 4 ​  and  ​ 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 3 ​ 𝑑 5 / 4 .

(118) 4.

We have 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ) , 2 ​ 𝐨 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) and 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ) , 2 ​ 𝐨 ⟩

𝒪 ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) / 𝜆 ) .

5.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , we have ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ) ⟩

0 .

6.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 r , 𝑗 , 𝑘 ( 0 ) . we have ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 .

7.

For all 𝑗 ∈ [ 𝑁 ] , 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) , ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝑛 ​ log ⁡ ( 𝑑 ) 𝜆 ) ,

(119) Proof.

In this proof, we establish each of the seven statements in Lemma 13 by induction. It is straightforward to demonstrate that the conclusions hold at 𝑇 1 using Lemma 11. Suppose that the conclusions hold at iteration 𝑇 1 ≤ 𝑡 ≤ 𝑇 2 − 1 . We now prove that the conclusions hold at iteration 𝑡 + 1 sequentially.

Proof of the first statement: For any 𝑗 ∈ [ 𝑁 ] , 𝑖 ∈ ℬ 𝑗 , the model updates satisfy

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

(120)

≤ ( 𝑎 )
Θ ( 1 ) ⋅ 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ( 1 − 𝐾 ~ ( 𝑗 ) logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] )

( 𝑏 )

𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ) ,

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

(121)

= ( 𝑐 )
Θ ( 1 ) ⋅ ( 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) ⟨ Ξ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) , Ξ ( 𝑡 + 1 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ⟩

+ ∑ 𝑘 ≠ 𝑖 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ) ​ ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

− 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ⟨ Ξ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) , Ξ ( 𝑡 + 1 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ⟩ )

𝒪 ​ ( 𝜆 ​ 𝜂 2 ​ 𝐾 𝑛 ​ 𝑚 ) ,

where ( 𝑎 ) is because of the gradient form in Lemma 1, the fifth statement and Lemma 10, ( 𝑏 ) is due to 1 − 𝐾 ~ ( 𝑗 ) logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ]

Θ ( 1 ) , and ( 𝑐 ) is because of Lemmas 1, the fifth statement, Lemma 10 and the fact that under Condition 1, Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) is close to Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) , i.e.,

‖ Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) − Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ‖ 2 ≤ 4 ​ ( exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) ) − 1 )

𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) .

(122)

Inequality (122) is derived using the change of the attention score at iteration 𝑡 , which is formally given in a later proof (130).

At iteration 𝑡 + 1 , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ 𝑇 2 )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ​ 𝐾 ~ ​ ( 𝑗 ) ) ,

(123)

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 1 ) , Ξ ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

𝒪 ​ ( 𝜆 ​ 𝜂 2 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ​ 𝑇 2 )

(124)

=

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

Therefore, we have

1 − logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

Θ ​ ( 1 ) .

(125)

Similarly, we have

1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) .

(126)

The first statement holds at iteration 𝑡 + 1 .

Proof of the second statement: We have

𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 ) exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 )

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) 𝑑 ) .

(127)

The terms 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) and 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) satisfy

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(128)

=
𝜂 2 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ​ ( log ⁡ ( 𝑑 ) / 𝜆 )

( 𝑎 )

𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑛 ​ 𝑑 1 / 2 ) ,

and

| 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(129)

=
𝜂 2 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐫 𝑖 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ​ ( log ⁡ ( 𝑑 ) / 𝜆 )

( 𝑏 )

𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 1 / 2 ) .

where ( 𝑎 ) follows the conclusion from the third statement

𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 / ( 3 ​ 𝑑 5 / 4 ) and the conclusion from the first statement which implies that | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) ⊤ 𝜻 ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) |

𝒪 ( log ⁡ ( 𝑑 ) / 𝜆 ) . ( 𝑏 ) follows the conclusion from the third statement 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 / ( 3 ​ 𝑑 5 / 4 ) and the conclusion from the first statement which implies that | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) ⊤ 𝜻 ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ( 𝐫 𝑖 − Ξ ~ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) |

𝒪 ( log ⁡ ( 𝑑 ) / 𝜆 ) .

Therefore, we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) ) .

(130)

Recursively using (130) from 𝑇 1 to 𝑡 , we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 3 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( ( 𝑡 + 1 − 𝑇 1 ) ​ 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) )

(131)

=
exp ⁡ ( 𝒪 ​ ( ( 𝑇 2 − 𝑇 1 ) ​ 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ​ 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) 𝜆 2 ​ 𝜂 2 ​ 𝐾 ) )

exp ⁡ ( 𝒪 ​ ( 𝑛 ​ 𝑚 ​ log 2 ⁡ ( 𝑑 ) 𝜆 2 ​ 𝑑 ​ 𝐾 ) ) .

Under Condition 1, by 0.8 ≤ 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) / 𝛼 3 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1.2 , we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1.6 .

(132)

Similarly, we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≥ 0.7 .

(133)

Hence, the second statement holds at iteration 𝑡 + 1 .

Proof of the third statement: First, we prove the first conclusion in the third statement.

For 𝑇 1 ≤ 𝑡 ≤ 𝑇 2 , the increment of the attention matrix satisfies

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) |

(134)


𝜂 2 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟ ≤ 1 / 𝑑 5 / 4 ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐨 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) / 𝜆 ) ⋅ Θ ​ ( 𝑑 )

( 𝑎 )

𝒪 ​ ( 𝜂 2 ​ 𝐾 ​ log ⁡ ( 𝑑 ) 𝑛 ​ 𝑑 ) ,

and

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) |

(135)

=
𝜂 2 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟

Θ ​ ( 1 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ​ ( log ⁡ ( 𝑑 ) / 𝜆 )

( 𝑏 )

𝒪 ​ ( 𝜂 2 ​ 𝐾 ​ log ⁡ ( 𝑑 ) 𝑛 ​ 𝑑 1 / 2 ) ,

where ( 𝑎 ) follows the conclusion from the third statement 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 2 / ( 3 ​ 𝑑 5 / 4 ) and the conclusions from the first and the fourth statements which imply that | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ) ⊤ 𝜻 ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ( 𝐨 − Ξ ~ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ) |

𝒪 ( 𝑑 1 / 4 log ⁡ ( 𝑑 ) / 𝜆 ) . ( 𝑏 ) follows the conclusion from the third statement 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 2 / ( 3 ​ 𝑑 5 / 4 ) and the conclusion from the first statement which implies that

| ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) |

𝒪 ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) .

Therefore, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ 𝐾 ​ log ⁡ ( 𝑑 ) 𝑛 ​ 𝑑 ) ) .

(136)

As a result, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ 𝐾 ​ log ⁡ ( 𝑑 ) 𝑛 ​ 𝑑 ​ ( 𝑇 2 − 𝑇 1 ) ) )

exp ⁡ ( 𝒪 ​ ( 𝑚 ​ log 2 ⁡ ( 𝑑 ) 𝑑 ​ 𝜆 2 ) ) .

(137)

Under Condition 1, using 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) / 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 / ( 1.8 ​ 𝑑 5 / 4 ) , we have

𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 2 3 ​ 𝑑 5 / 4 .

(138)

Next, we prove the second conclusion in the third statement. For 𝑇 1 ≤ 𝑡 ≤ 𝑇 2 , we have

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(139)

=
𝜂 2 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 / 𝑑 5 / 4

⋅ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟ Θ ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) / 𝜆 ) ⋅ Θ ​ ( 𝑑 )

( 𝑎 )

𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) ,

where ( 𝑎 ) follows the conclusion from the third statement 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 / ( 3 ​ 𝑑 5 / 4 ) and the conclusions from the first and the fourth statements which imply that | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) ⊤ 𝜻 ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ] ) ) |

𝒪 ( 𝑑 1 / 4 log ⁡ ( 𝑑 ) / 𝜆 ) .

Therefore, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ) ) .

(140)

As a result, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 2 ​ log ⁡ ( 𝑑 ) 𝑑 ​ ( 𝑇 2 − 𝑇 1 ) ) )

(141)

=

exp ⁡ ( 𝒪 ​ ( 𝑚 ​ 𝑛 ​ log 2 ⁡ ( 𝑑 ) 𝑑 ​ 𝜆 2 ​ 𝐾 ) ) .

Under Condition 1, using 𝛼 1 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) / 𝛼 2 ( 𝑇 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 0.625 / 𝑑 5 / 4 , we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 0.8 𝑑 5 / 4 .

(142)

Therefore, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 2 3 ​ 𝑑 5 / 4 , 𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 3 ​ 𝑑 5 / 4 .

(143)

The third statement holds for iteration 𝑡 + 1 .

Proof of the fourth statement: For 𝑡 ′ ∈ [ 𝑇 1 , 𝑇 2 ] , by the gradient form in Lemma 1, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ + 1 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , 2 ​ 𝐨 ⟩

(144)


− 2 ​ 𝜂 2 ​ 𝐾 ~ ​ ( 𝑗 ) 𝑚 2 ​ 𝑛 ​ Θ ​ ( 1 ) ​ ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 1 ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , Ξ ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ​ 𝜆 ​ ‖ 𝐨 ‖ 2 2

− 𝜂 2 𝑚 2 ​ 𝑛 ​ Θ ​ ( 1 ) ​ ∑ 𝑘

1 𝑚 ∑ 𝑗

1 𝑁 ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 1 ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , Ξ ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ​ 𝜆 ​ ‖ 𝐨 ‖ 2 2

( 𝑎 )

− 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑑 3 / 4 ) ,

where ( 𝑎 ) is obtained by the third statement, logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 , logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 , and 𝜎 ′ ​ ( ⋅ ) ≤ 1 .

As a result, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 + 1 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩ ≥ 𝐴 ,

(145)

where 𝐴 ≤ 0 and

| 𝐴 |

𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑑 3 / 4 ​ ( 𝑇 2 − 𝑇 1 ) )

𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑑 3 / 4 ​ 𝑛 ​ 𝑚 ​ log ⁡ ( 𝑑 ) 𝜆 2 ​ 𝜂 2 ​ 𝐾 )

𝒪 ​ ( 𝑛 ​ log ⁡ ( 𝑑 ) 𝜆 ​ 𝐾 ​ 𝑑 3 / 4 ) .

(146)

As 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑇 1 ) , 2 ​ 𝐨 ⟩

Θ ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 / 𝑚 ) , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 + 1 ) , 2 ​ 𝐨 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(147)

Furthermore, for 𝑡 ′ ∈ [ 𝑇 1 , 𝑇 2 ] , by the gradient form in Lemma 1, we also have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ + 1 ) , 2 ​ 𝐨 ⟩ − 1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , 2 ​ 𝐨 ⟩

(148)


Θ ​ ( 1 ) ⋅ 𝜂 2 𝑛 ​ 𝑛 3 ​ 𝑚 2 ​ ∑ 𝑘

1 𝑚 ∑ 𝑖

1 𝑛 ( 1 − logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( 2 ​ 𝐨 ) ) ​ 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , 2 ​ 𝐨 ⟩ ) ​ 𝜆 ​ ‖ 2 ​ 𝐨 ‖ 2 2

( 𝑎 )

𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑑 ) ,

where ( 𝑎 ) holds because 1 − logit ℐ ​ ( 𝐝 ) ( 𝑡 ′ ) ​ ( 2 ​ 𝐨 ) ≤ 𝑑 − 2 as 1 / 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 ′ ) , 2 ​ 𝐨 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) .

Hence, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐝 ) , 𝑘 ( 𝑡 + 1 ) , 2 ​ 𝐨 ⟩

𝒪 ​ ( 𝑑 1 / 4 ​ log ⁡ ( 𝑑 ) 𝜆 ) .

(149)

Thus, the fourth statement holds at iteration 𝑡 + 1 .

Proof of the fifth statement: After Stage 1 with 𝑡 ≥ 𝑇 1 , for 𝑙 ∈ [ 𝑚 ] satisfying ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩

0 , the update of ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ satisfies

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩

(150)

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ + 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) − 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ,

by the fact that ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) , 𝐬 𝑗 ⟩ ≥ ‖ 𝐬 𝑗 ‖ 2 2

1 and ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) , 𝐬 𝑗 ⟩ ≤ ‖ 𝐬 𝑗 ‖ 2 2

1 . Similarly, for all 𝑙 ∈ [ 𝑚 ] , the update of ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ satisfies

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ − 𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

(151)

− 𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ​ ∑ 𝑗 ∈ 𝒟 𝑘 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) − 𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑚 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( 2 ​ 𝐨 )

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ − 𝒪 ​ ( 𝜆 ​ 𝜂 2 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) − 𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑚 ​ 𝑑 1 / 4 ) − 𝒪 ​ ( 𝜆 ​ 𝜂 2 ​ 𝐾 𝑚 ​ 𝑑 3 / 2 ) ,

As a result, by the third statement, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

(152)

0.9 ​ 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) − 1.1 ​ 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

− 𝒪 ​ ( 𝜆 ​ 𝜂 2 𝑚 ​ 𝑑 3 / 2 ) − 𝒪 ​ ( 𝜂 2 ​ log 2 ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 ) .

Case 1: If 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≤ 2 ​ log ⁡ ( 𝑑 ) / ( 3 ​ 𝜆 ) , for all 𝑖 ∈ ℬ 𝑗 , we have

1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) , logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 𝑑 ,

(153)

by Lemma 12. As a result, 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ )

Θ ​ ( 1 ) holds.

Case 2: If 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ )

2 ​ log ⁡ ( 𝑑 ) / ( 3 ​ 𝜆 ) , for all 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ≥ − 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑛 ​ 𝑚 ) .

(154)

Then, we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≥ log ⁡ ( 𝑑 ) 2 ​ 𝜆 .

(155)

Therefore, the fifth statement holds for iteration 𝑡 + 1 .

Proof of the sixth statement: We first prove that during 𝑡 ∈ [ 𝑇 1 , 𝑇 2 − 1 ] , we have

1 𝑚 ∑ 𝑙

1 𝑚 𝜎 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) } ≥ 0 .

(156)

We prove (156) by considering the following cases.

Case 1: 1 𝑚 ∑ 𝑙

1 𝑚 𝜎 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) } ≤ log ⁡ ( 𝑑 ) / ( 3 𝜆 ) . For all 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 𝐫 , 𝑗 , 𝑖 ( 𝑡 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ + 𝜆 ​ 𝜂 2 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 𝕀 ​ ( 𝑙 ∈ 𝒮 𝐫 , 𝑗 , 𝑖 ( 𝑡 ) ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) − 𝜆 ​ 𝜂 2 𝑚 ​ 𝑛 ​ 𝐾 ~ ​ ( 𝑗 ) ​ logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

(157)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ + Θ ​ ( 𝜂 2 ​ 𝜆 𝑛 ​ 𝑚 ) − 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑛 ​ 𝑚 ​ 𝑑 )

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ .

Case 2: 1 𝑚 ∑ 𝑙

1 𝑚 𝜎 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) }

log ⁡ ( 𝑑 ) / ( 3 𝜆 ) . We have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ − 𝒪 ​ ( 𝜂 2 ​ 𝜆 ​ 𝐾 𝑚 ​ 𝑛 ) ≥ log ⁡ ( 𝑑 ) 4 ​ 𝜆 .

(158)

As a result, we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) ≥ 0 .

(159)

Based on (159), we now prove the conclusion in the sixth statement. We have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

(160)

− 1.1 ​ 𝜆 ​ 𝜂 ( 𝑡 ) ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) + 𝜆 ​ Θ ​ ( 1 ) ⋅ 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ ∑ 𝑘 ∈ ℬ 𝑗 ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) )

− 2 ​ 𝜆 ​ ∑ 𝑘 ≠ 𝑗 , 𝑘 ∈ 𝒟 𝑖 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ​ 𝐫 𝑖 ] ) − 𝒪 ​ ( 𝜆 ​ 𝜂 ( 𝑡 ) 𝑚 ​ 𝑑 3 / 2 ) − 𝒪 ​ ( 𝜂 ( 𝑡 ) ​ log 2 ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 ) ,

We consider two cases.

Case 1: 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ≤ log ⁡ ( 𝑑 ) / 6 ​ 𝜆 . By (159), we also have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) ≤ log ⁡ ( 𝑑 ) 2 ​ 𝜆 , 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑖 ⟩ ) ≤ log ⁡ ( 𝑑 ) 6 ​ 𝜆 .

(161)

A direct result is

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ + Θ ​ ( 𝜆 ​ 𝜂 2 𝑚 ​ 𝑛 ) − 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑛 ​ 𝑚 ​ 𝑑 )

(162)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ .

Case 2: 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ≥ log ⁡ ( 𝑑 ) / 6 ​ 𝜆 . In this case, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ − 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑛 ​ 𝑚 )

(163)

log ⁡ ( 𝑑 ) 12 ​ 𝜆 .

Thus, the sixth statement holds.

Proof of the seventh statement: First, prove the first part. For iteration 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝜂 ( 𝑡 ) ​ 𝜆 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ‖ 𝐨 ‖ 2 2

(164)

After pre-training Stage 1, for all 𝑗 ∈ [ 𝑁 ] and 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 0 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝐾 ​ 𝑑 𝑛 ​ 𝑚 ) .

(165)

Then, at pre-training Stage 2 with iteration 𝑡 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 2 ​ 𝜆 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ,

(166)

due to the attention score bound in the third statement. Then, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 0 ) , 𝐨 ⟩ ≤

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝐾 ​ 𝑑 𝑛 ​ 𝑚 ) + 𝒪 ​ ( 𝜂 2 ​ 𝜆 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ​ 𝑇 2 )

(167)

=
𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝐾 ​ 𝑑 𝑛 ​ 𝑚 + log ⁡ ( 𝑑 ) ​ 𝜆 𝑑 1 / 4 )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

Next, we prove the second part. For iteration 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝜂 ( 𝑡 ) ​ 𝜆 𝑛 ​ 𝑚 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 1 − logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ​ ‖ 𝐨 ‖ 2 2

(168)

During pre-training Stage 1, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 0 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 ) ,

(169)

Then, at pre-training Stage 2 with iteration 𝑡 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑑 1 / 4 ) ,

(170)

Then, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 0 ) , 𝐨 ⟩ ≤

𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝐾 ​ 𝑑 𝑛 ​ 𝑚 ) + 𝒪 ​ ( 𝜂 2 ​ 𝜆 𝑚 ​ 𝑇 2 )

(171)

=
𝒪 ​ ( 𝜂 1 ​ 𝜆 ​ 𝑑 𝑚 + 𝑛 ​ log ⁡ ( 𝑑 ) / ( 𝜆 ​ 𝑑 1 / 4 ) )

𝒪 ​ ( 𝑛 ​ log ⁡ ( 𝑑 ) / ( 𝜆 ) ) .

Thus the seventh statement holds at iteration 𝑡 + 1 .

This finishes the proof of the induction. ∎

F.4Proof of pre-training Stage 3

In this subsection, we prove the convergence of the model pre-training. We first derive the following key properties during Stage 3.

Lemma 14.

Under Condition 1, during pre-training Stage 3 ( 𝑇 2 < 𝑡 ≤ 𝑇 𝑝 ), the following holds:

1.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have

0.5 ≤ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 .

(172) 2.

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have

𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 𝑑 5 / 4 , 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 𝑑 5 / 4 .

(173) 3.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , we have ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ( [ 𝐨 𝐬 𝑗 ] ) ) ⟩

0 .

4.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 r , 𝑗 , 𝑘 ( 0 ) . we have ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) ⟩

0 .

5.

For each 𝑗 ∈ [ 𝑁 ] and all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) , we have ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) .

6.

For all 𝑗 ∈ [ 𝑁 ] , 𝑖 ∈ ℛ and 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ~ ​ ( 𝑁 ​ ( 𝑁 + | ℛ | ) 𝜆 ​ 𝑑 1 / 4 ) , ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ​ 𝑑 1 / 4 ) .

(174) 7.

For any 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) , ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) .

(175) Proof.

In this proof, we establish each of the seven statements in Lemma 14 by induction. It is easy to show that the conclusions hold at 𝑇 2 by Lemma 13. Suppose that the conclusions hold at iteration 𝑇 2 ≤ 𝑡 ≤ 𝑇 𝑝 − 1 . We now prove that the conclusions hold at iteration 𝑡 + 1 sequentially.

Proof of the first statement: We have

𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 ) exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) / 𝑑 )

exp ⁡ ( 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) 𝑑 ) .

(176)

For the term 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) , we bound it by bounding the increments of 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) and 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) . We have

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(177)

=
𝜂 3 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⏟

𝒪 ~ ​ ( 𝑁 + | ℛ | / 𝜆 )

𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 1 / 2 ) ,

and

| 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐫 𝑖 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(178)

=
𝜂 3 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 ​ ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 𝐫 𝑖 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⏟

𝒪 ~ ​ ( 𝑁 + | ℛ | / 𝜆 )

𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑑 1 / 2 ) .

Therefore, we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ( [ 𝐨 𝐬 𝑗 𝐫 𝑖 ) ​ 𝛼 3 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑑 ) ) .

(179)

Using (179), for 𝑡 + 1 , we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 3 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

(180)

=
exp ⁡ ( 𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑑 ​ ( 𝑇 𝑝 − 𝑇 2 ) ) )

exp ⁡ ( 𝒪 ~ ​ ( ( 𝑁 + | ℛ | ) 𝑑 ​ ( 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) ) 𝜆 2 ) ) .

Under Condition 1, we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 2 .

(181)

Similarly, we have

𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 3 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≥ 0.5 .

(182)

Thus, the first statement holds at iteration 𝑡 + 1 .

Proof of the second statement: For 𝑇 2 ≤ 𝑡 ≤ 𝑇 𝑝 , the increment of the attention matrix satisfies

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) |

(183)


𝜂 3 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟ ≤ 1 / 𝑑 5 / 4 ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐨 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ~ ​ ( 𝑁 ​ ( 𝑁 + | ℛ | ) / 𝜆 ​ 𝑑 1 / 4 ) ⋅ Θ ​ ( 𝑑 )

𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑑 3 / 2 ) ,

and

| 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐬 𝑗 ) ( 𝑡 ) |

(184)

=
𝜂 3 ​ 𝜆 𝑛 ​ 𝑑 ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 , 𝐬 𝑗 ] ) ⏟

Θ ​ ( 1 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 | ( 𝐞 ℐ ​ ( 𝐫 𝑖 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 𝐬 𝑗 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) | ⏟

𝒪 ~ ​ ( ( 𝑁 + | ℛ | ) / 𝜆 )

𝒪 ​ ( 𝜂 3 ​ 𝐾 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 1 / 2 ) .

So, for all 𝑇 2 < 𝑡 ≤ 𝑇 𝑝 , we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 3 ​ 𝐾 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 ) ) .

(185)

As a result, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ 𝛼 2 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 1 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 3 ​ 𝐾 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 ​ ( 𝑇 𝑝 − 𝑇 2 ) ) )

(186)

=

exp ⁡ ( 𝒪 ​ ( 𝜂 3 ​ 𝐾 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 ​ ( 𝑚 𝑁 ( 𝑁 + | ℛ | ) 𝜆 2 ) ) ,

resulting in

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 𝑑 5 / 4 .

(187)

Moreover, we have

| 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 + 1 ) − 𝑍 ~ ℐ ​ ( 𝐨 ) , ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) |

(188)

=
𝜂 3 ​ 𝜆 𝑛 ​ 𝑑 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⏟ ≤ 1 / 𝑑 5 / 4

⋅ | ( 𝐞 ℐ ​ ( 𝐚 𝑗 ) − logit ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ⊤ ​ 𝜻 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⋅ ( 𝐨 − Ξ ~ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) | ⏟

𝒪 ~ ​ ( ( 𝑁 + | ℛ | ) / 𝜆 ​ 𝑑 1 / 4 ) ⋅ Θ ​ ( 𝑑 )

𝒪 ​ ( 𝜂 3 ​ ( 𝑁 + | ℛ | ) 𝑛 ​ 𝑑 3 / 2 ) .

So, for all 𝑇 2 ≤ 𝑡 ≤ 𝑇 𝑝 , we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 2 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

exp ⁡ ( 𝒪 ​ ( 𝜂 3 ( 𝑁 + | ℛ | ) ) 𝑛 ​ 𝑑 ) ) .

(189)

As a result, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ 𝛼 2 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 1 ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

(190)

=
exp ⁡ ( 𝒪 ​ ( 𝜂 3 ( 𝑁 + | ℛ | ) ) 𝑛 ​ 𝑑 ​ ( 𝑇 𝑝 − 𝑇 2 ) ) )

exp ⁡ ( 𝒪 ~ ​ ( 𝑁 + | ℛ | 𝑛 ​ 𝑑 ​ 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) 𝜆 2 ) ) ,

resulting in

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) 𝛼 2 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 𝑑 5 / 4 .

(191)

Therefore, we have

𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≤ 1 𝑑 5 / 4 , 𝛼 1 ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 𝑑 5 / 4 .

(192)

The second statement holds at iteration 𝑡 + 1 .

Proof of the third statement:

After Stage 1 with 𝑡 ≥ 𝑇 2 , for 𝑙 ∈ [ 𝑚 ] satisfying ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩

0 , the update of ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ satisfies

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩

(193)

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ + 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) − 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ,

by the fact that ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) , 𝐬 𝑗 ⟩ ≥ ‖ 𝐬 𝑗 ‖ 2 2

1 and ⟨ Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) , 𝐬 𝑗 ⟩ ≤ ‖ 𝐬 𝑗 ‖ 2 2

1 . Similarly, for all 𝑙 ∈ [ 𝑚 ] , the update of ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ satisfies

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ − 𝒪 ​ ( 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

(194)

− 𝒪 ​ ( 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ​ ∑ 𝑗 ∈ 𝒟 𝑘 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) − 𝒪 ​ ( 𝜆 ​ 𝜂 3 𝑚 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( 2 ​ 𝐨 )

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ − 𝒪 ​ ( 𝜆 ​ 𝜂 3 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) − 𝒪 ​ ( 𝜆 ​ 𝜂 3 ​ 𝐾 𝑚 ​ 𝑑 1 / 4 ) − 𝒪 ​ ( 𝜆 ​ 𝜂 3 ​ 𝐾 𝑚 ​ 𝑑 3 / 2 ) ,

As a result, by the sixth statement, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

(195)

0.9 ​ 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) − 1.1 ​ 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] )

− 𝒪 ​ ( 𝜆 ​ 𝜂 3 𝑚 ​ 𝑑 3 / 2 ) − 𝒪 ​ ( 𝜂 3 ​ log 2 ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 ) .

Case 1: If 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≤ 2 ​ log ⁡ ( 𝑑 ) / ( 3 ​ 𝜆 ) , for all 𝑖 ∈ ℬ 𝑗 , we have

1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] )

Θ ​ ( 1 ) , logit ℐ ​ ( 𝐫 𝑘 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ≤ 1 𝑑 ,

(196)

by Lemma 12. As a result, 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ )

Θ ​ ( 1 ) holds.

Case 2: If 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ )

2 ​ log ⁡ ( 𝑑 ) / ( 3 ​ 𝜆 ) , for all 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ≥ − 𝒪 ​ ( 𝜂 3 ​ 𝜆 𝑛 ​ 𝑚 ) .

(197)

Then, we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≥ log ⁡ ( 𝑑 ) 2 ​ 𝜆 .

(198)

Therefore, the third statement holds for iteration 𝑡 + 1 .

Proof of the fourth statement: Based on the fifth statement, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

(199)

− 1.1 ​ 𝜆 ​ 𝜂 ( 𝑡 ) ​ 𝐾 ~ ​ ( 𝑗 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) + 𝜆 ​ Θ ​ ( 1 ) ⋅ 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ ∑ 𝑘 ∈ ℬ 𝑗 ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑘 ] ) )

− 2 ​ 𝜆 ​ ∑ 𝑘 ≠ 𝑗 , 𝑘 ∈ 𝒟 𝑖 𝜂 ( 𝑡 ) 𝑛 ​ 𝑚 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑘 ​ 𝐫 𝑖 ] ) − 𝒪 ​ ( 𝜆 ​ 𝜂 ( 𝑡 ) 𝑚 ​ 𝑑 3 / 2 ) − 𝒪 ​ ( 𝜂 ( 𝑡 ) ​ log 2 ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 ) ,

We consider two cases.

Case 1: 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ≤ log ⁡ ( 𝑑 ) / 6 ​ 𝜆 . By the fifth statement, we also have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ) ≤ log ⁡ ( 𝑑 ) 2 ​ 𝜆 , 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐫 𝑖 ⟩ ) ≤ log ⁡ ( 𝑑 ) 6 ​ 𝜆 .

(200)

A direct result is

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ + Θ ​ ( 𝜆 ​ 𝜂 3 𝑚 ​ 𝑛 ) − 𝒪 ​ ( 𝜂 3 ​ 𝜆 𝑛 ​ 𝑚 ​ 𝑑 )

(201)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ .

Case 2: 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ) ≥ log ⁡ ( 𝑑 ) / 6 ​ 𝜆 . In this case, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ ≥

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩ − 𝒪 ​ ( 𝜂 3 ​ 𝜆 𝑛 ​ 𝑚 )

(202)

log ⁡ ( 𝑑 ) 12 ​ 𝜆 .

Thus, the fourth statement holds for iteration 𝑡 + 1 .

Proof of the fifth statement: With the derivation of the fifth and the sixth statements in Lemma 13, after 𝑇 2 iterations, we have 1 / 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑇 2 ) , Ξ ( 𝑇 2 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≥ log ⁡ ( 𝑑 ) / 2 ​ 𝜆 and 1 / 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 2 ) , 𝐬 𝑗 ⟩ ≥ log ⁡ ( 𝑑 ) / 4 ​ 𝜆 . If there exists an iteration 𝑡 ∈ [ 𝑇 2 , 𝑇 𝑝 ] that 1 / 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ < log ⁡ ( 𝑑 ) / 3 ​ 𝜆 , for all 𝑘 ∈ ℬ 𝑗 , 𝑙 ∈ 𝒮 s , 𝑘 , 𝑗 ( 0 ) and for 𝑡 ′

𝑡 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ≥

log ⁡ ( 𝑑 ) / 3 ​ 𝜆 − 𝒪 ​ ( 𝜆 𝑛 ​ 𝑚 ​ 1 𝑑 3 ​ ( 𝑇 𝑝 − 𝑇 2 ) )

(203)

=

log ⁡ ( 𝑑 ) / 3 ​ 𝜆 − 𝒪 ​ ( 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ 1 𝑑 ​ ( 𝑁 ​ 𝑚 ​ ( 𝑁 + | ℛ | ) ) )

log ⁡ ( 𝑑 ) / 6 ​ 𝜆 .

Thus, the conclusion holds for iteration 𝑡 + 1 .

Proof of the sixth statement: First, we prove the first part. For iteration 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝜂 ( 𝑡 ) ​ 𝜆 𝑛 ​ 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) ​ ‖ 𝐨 ‖ 2 2

(204)

At pre-training Stage 3 with iteration 𝑡 ∈ [ 𝑇 2 , 𝑇 𝑝 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 3 ​ 𝜆 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ) ,

(205)

Then, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ ≤

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 2 ) , 𝐨 ⟩ + 𝒪 ​ ( 𝜂 3 ​ 𝜆 ​ 𝐾 𝑛 ​ 𝑚 ​ 𝑑 1 / 4 ​ ( 𝑇 𝑝 − 𝑇 2 ) )

(206)

=
𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 1 / 4 ) + 𝑂 ~ ​ ( 𝑁 + | ℛ | 𝜆 ​ 𝑑 1 / 4 )

𝑂 ~ ​ ( 𝑁 + | ℛ | 𝜆 ​ 𝑑 1 / 4 ) .

Next, we prove the second part. For iteration 𝑡 ∈ [ 0 , 𝑇 𝑝 − 1 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝜂 ( 𝑡 ) ​ 𝜆 𝑛 ​ 𝑚 ​ ∑ 𝑗 ∈ 𝒟 𝑖 𝛼 1 ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ​ ( 1 − logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ​ ‖ 𝐨 ‖ 2 2

(207)

During pre-training Stage 3 with iteration 𝑡 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐨 ⟩ ≤ 𝒪 ​ ( 𝜂 3 ​ 𝜆 𝑚 ​ 𝑑 1 / 4 ) ,

(208)

Then, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐨 ⟩ ≤

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 2 ) , 𝐨 ⟩ + 𝒪 ​ ( 𝜂 3 ​ 𝜆 𝑚 ​ 𝑑 1 / 4 ​ ( 𝑇 𝑝 − 𝑇 2 ) )

(209)

=

𝒪 ~ ​ ( 𝑁 ​ ( 𝑁 + | ℛ | ) 𝜆 ​ 𝑑 1 / 4 ) .

Thus the conclusion holds for iteration 𝑡 + 1 .

Proof of the seventh statement: Based on the gradient form in Lemma 1, in Stage 3 𝑡 ∈ [ 𝑇 2 , 𝑇 𝑝 − 1 ] , for all 𝑙 ∈ [ 𝑚 ] and 𝑖 ∈ ℬ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ≤ 2 ​ 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ,

(210)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ) , 𝐬 𝑗 ⟩ ≤ 2 ​ ∑ 𝑖 ∈ ℬ 𝑗 𝜆 ​ 𝜂 3 𝑛 ​ 𝑚 ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ) .

(211)

At the end of stage 2, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 2 ) , 𝐬 𝑗 ⟩

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) , ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 2 ) , 𝐬 𝑗 ⟩

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(212)

As 𝑇 𝑝 − 𝑇 2

Θ ~ ​ ( ( 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) ) / ( 𝜆 2 ​ 𝜂 3 ) ) , for all 𝑡 + 1 , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) ,

(213)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐬 𝑗 ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) ,

(214)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , 𝐫 𝑖 ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) ,

(215)

As a result, we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) ,

(216)

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 + 1 ) , Ξ ( 𝑡 + 1 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] ) ⟩

𝒪 ~ ​ ( 𝑁 + | ℛ | 𝜆 ) .

(217)

Thus, the conclusion holds at iteration 𝑡 + 1 .

This completes the proof of the induction. ∎

Combined with Lemma 10, we can characterize the activation patterns during PT.

Lemma 15.

For each iteration 𝑡 ∈ [ 𝑇 𝑝 ] , for all 𝑗 ∈ [ 𝑁 ] and 𝑘 ∈ ℬ 𝑗 , with probability 1 − 𝛿 , the followings hold:

| 𝒮 s , 𝑘 , 𝑗 ( 𝑡 ) | ≥ 0.4 ​ 𝑚 and | 𝒮 r , 𝑗 , 𝑘 ( 𝑡 ) | ≥ 0.4 ​ 𝑚 .

(218)

Finally, we prove the convergence of pre-training in Stage 3.

Lemma 16.

Under condition 1, for any constant 𝜅 , there exist an iteration 0 < 𝑡 𝑝 < 𝑇 𝑝

Θ ~ ( 𝑚 𝑁 ( 𝑁 + | ℛ | ) / ( 𝜆 2 𝜂 3 ) ) + 𝑛 𝑚 log ⁡ ( 𝑑 ) / ( 𝜆 2 𝜂 2 𝐾 ) ) such that

ℒ 𝒫 ​ ( 𝐖 ( 𝑡 𝑝 ) , 𝐙 ( 𝑡 𝑝 ) ) ≤ 0.001 + ( 1 + 𝜅 ) ​ 𝐻 .

(219) Proof.

Let 𝐖 ∗ be

𝐰 𝑗 , 𝑘 ∗

𝐰 𝑗 , 𝑘 ( 0 ) + 6 𝜆 ​ log ⁡ ( 𝑑 ​ 𝜖 ( 1 − 𝜖 ) ) ​ ∑ 𝑖

1 𝑁 ( ∑ 𝑙 ∈ ℬ 𝑖 𝐫 𝑙 − 6 ​ ∑ 𝑙 ≠ 𝑖 𝐬 𝑙 ) ​ 𝕀 ​ ( 𝑗

ℐ ​ ( 𝐚 𝑖 ) )

(220)

  • 5 𝜆 ​ log ⁡ ( 𝑑 ​ 𝜖 ( 1 − 𝜖 ) ) ​ ∑ 𝑙

    1 𝐾 ( ∑ 𝑖 ∈ 𝒟 𝑙 𝐬 𝑖 − ∑ 𝑙

    1 𝐾 𝐫 𝑙 ) ​ 𝕀 ​ ( 𝑗

    ℐ ​ ( 𝐫 𝑙 ) )

  • 15 2 ​ 𝑑 ​ 𝜆 ​ log ⁡ ( 𝑑 ​ 𝜖 ( 1 − 𝜖 ) ) ​ ( 𝐨 − ∑ 𝑖

    1 𝑁 𝐬 𝑖 − ∑ 𝑖

    1 𝐾 𝐫 𝑖 ) ​ 𝕀 ​ ( 𝑗

    ℐ ​ ( 𝐝 ) ) ,

where 𝜖

0.0001 . With 𝐖 ∗ , we have

‖ 𝐖 ( 𝑡 + 1 ) − 𝐖 ∗ ‖ 𝐹 2

‖ 𝐖 ( 𝑡 ) − 𝐖 ∗ ‖ 𝐹 2 + 2 ​ 𝜂 3 ​ ⟨ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) , 𝐖 ( 𝑡 ) − 𝐖 ∗ ⟩ ⏟ 𝐴

(221)

  • 𝜂 3 2 ​ ‖ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) ‖ 𝐹 2 .

For the term ⟨ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) , 𝐖 ( 𝑡 ) − 𝐖 ∗ ⟩ , by the homogeneity of the ReLU function and the convexity of the cross-entropy loss, we have

𝐴

2 ​ 𝜂 3 ​ ⟨ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) , 𝐖 ( 𝑡 ) − 𝐖 ∗ ⟩

(222)

=
2 ​ 𝜂 3 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ⟨ ∇ 𝐖 ( 𝑡 ) ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) , 𝐖 ( 𝑡 ) − 𝐖 ∗ ⟩

2 ​ 𝜂 3 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∑ 𝑘 ∈ [ 𝑑 ] ⟨ ∂ ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ∂ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) ​ ∇ 𝐖 ( 𝑡 ) 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) , 𝐖 ( 𝑡 ) − 𝐖 ∗ ⟩

2 ​ 𝜂 3 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∑ 𝑘 ∈ [ 𝑑 ] ∂ ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ∂ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) ​ ⟨ ∇ 𝐖 ( 𝑡 ) 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) , 𝐖 ( 𝑡 ) ⟩

− 2 ​ 𝜂 3 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∑ 𝑘 ∈ [ 𝑑 ] ∂ ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ∂ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) ​ ⟨ ∇ 𝐖 ( 𝑡 ) 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) , 𝐖 ∗ ⟩

2 ​ 𝜂 3 𝑛 ​ ∑ ( 𝐗 , 𝑦 ) ∈ 𝒫 ∑ 𝑘 ∈ [ 𝑑 ] ∂ ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 , 𝑦 ) ∂ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 )

( 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) − ⟨ ∇ 𝐖 ( 𝑡 ) 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) , 𝐖 ∗ ⟩ ⏟ 𝐵 ) ,

Before deriving the bound of the term 𝐵 , we first recall that by Lemma 15, in each iteration, there is Θ ​ ( 𝑚 ) amount of neurons are activated and by Lemma 14 statements 1 and 2, the attention scores of tokens 𝐬 𝑗 and 𝐫 are at least 0.3 for all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 . Additionally, by the form of the gradient, the gradient of data ( 𝐗 , 𝑦 ) for the MLP given by

∇ 𝐰 𝑖 , 𝑙 ( 𝑡 ) 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 )

{ 𝜆 𝑚 ​ 𝜎 ′ ​ ( ⟨ 𝐰 𝑖 , 𝑙 ( 𝑡 ) , 𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) ⟩ ) ​ 𝐱 𝑎 ​ ( 𝐙 ( 𝑡 ) , 𝐗 ) ,
 if  ​ 𝑖

𝑦 ,

0

 if  ​ 𝑖 ≠ 𝑦 .

(223)

For the term 𝐵 , we consider three cases.

Case 1: ( 𝐗 , 𝑦 ) is in the form of ( [ 𝐨 ​ 𝐬 𝑗 ​ 𝐫 𝑖 ] , ℐ ​ ( 𝐚 ) ) . By the form of the gradient for the MLP in (223), the construction of 𝐖 ∗ in (220) and the number of activated neurons is at least 0.4 ​ 𝑚 , we have

𝐵 ​ { ≥ 1.2 ​ log ⁡ ( 𝜖 ​ ( 𝑑 − 1 ) / ( 1 − 𝜖 ) ) ,
if  ​ 𝑘

ℐ ​ ( 𝐚 𝑗 ) ,

= 𝒪 ~ ​ ( 𝜎 0 ) ,

if  ​ 𝑘 ≠ ℐ ​ ( 𝐚 𝑗 ) .

(224)

Case 2: ( 𝐗 , 𝑦 ) is in the form of ( [ 𝐨 ​ 𝐬 𝑗 ] , ℐ ​ ( 𝐫 𝑖 ) ) , we have

𝐵 ​ {

𝒪 ~ ​ ( 𝜎 0 ) ,
if  ​ 𝑘

ℐ ​ ( 𝐚 𝑗 ) ,

≥ 1.2 ​ log ⁡ ( 𝜖 ​ ( 𝑑 − 1 ) / ( 1 − 𝜖 ) ) ,
if  ​ 𝑘

ℐ ​ ( 𝐫 𝑖 ) , ∀ 𝑖 ∈ 𝒟 𝑗 ,

= 𝒪 ~ ​ ( 𝜎 0 ) ,

if  ​ 𝑘 ≠ ℐ ​ ( 𝐚 𝑗 ) , ℐ ​ ( 𝐫 𝑖 ) , ∀ 𝑖 ∈ 𝒟 𝑗 .

(225)

Case 3: If ( 𝐗 , 𝑦 ) is in the form of ( 𝐨 , ℐ ​ ( 𝐝 ) ) , we have

𝐵 ​ { ≥ 1.2 𝑑 log ( ( 1 − 𝜖 ) ) ,
if  ​ 𝑘

ℐ ​ ( 𝐝 ) ,

= 𝒪 ~ ​ ( 𝑑 ​ 𝜎 0 ) ,

if  ​ 𝑘 ≠ ℐ ​ ( 𝐝 ) .

(226)

Combining the three cases (224), (225), (226) and (222), we have

𝐴 ≤ 2 ​ 𝜂 3 ​ ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) − 2 ​ 𝜂 3 ​ 𝜖 − 2 ​ 𝜂 3 3 ​ 𝐾 𝐾 + 1 ​ log ⁡ ( 𝐾 ) .

(227)

Additionally, for the term 𝜂 3 2 ​ ‖ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) ‖ 𝐹 2 in (221), we have

𝜂 3 2 ​ ‖ ∇ 𝐖 ( 𝑡 ) ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) ‖ 𝐹 2 ≤
𝜂 3 2 ​ [ 1 𝑛 ​ ∑ 𝑖

1 𝑛 ∑ 𝑘 ∈ [ 𝑑 ] | ∂ ℒ ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐱 , 𝑦 ) ∂ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) | ​ ‖ ∇ 𝑥 𝑘 output ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) , 𝐗 ) ‖ 𝐹 ] 2

(228)


𝜂 3 2 ​ 4 2 ⋅ 4 ​ [ 1 𝑛 ​ ∑ 𝑖

1 𝑛 ( 1 − logit 𝑦 ( 𝑡 ) ​ ( 𝐗 ) ) ] 2

64 ​ 𝜂 3 2 ​ ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) .

Substituting (227) and (228) into (221), we have

‖ 𝐖 ( 𝑡 + 1 ) − 𝐖 ∗ ‖ 𝐹 2 ≤

‖ 𝐖 ( 𝑡 ) − 𝐖 ∗ ‖ 𝐹 2 − 2 ​ 𝜂 3 ​ ( ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) − 𝜖 − 𝐾 𝐾 + 1 ​ log ⁡ ( 𝐾 ) / 3 )

(229)

  • 64 ​ 𝜂 3 2 ​ ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) .

Telescoping over 𝑡 and take 𝜂 3 ≤ 𝜅 / 64 ​ ( 𝜅 ≤ 1 / 2 ) , we have

‖ 𝐖 ( 𝑇 𝑝 ) − 𝐖 ∗ ‖ 𝐹 2 ≤
‖ 𝐖 ( 𝑇 2 + 1 ) − 𝐖 ∗ ‖ 𝐹 2 − ( 2 − 𝜅 ) ​ 𝜂 3 ​ ∑ 𝑡

𝑇 2 + 1 𝑇 𝑝 ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) + 2 ​ 𝜂 3 ​ ( 𝑇 𝑝 − 𝑇 2 ) ​ 𝜖

(230)

  • 2 ​ 𝜂 3 ​ ( 𝑇 𝑝 − 𝑇 2 ) ​ 𝐾 𝐾
  • 1 ​ log ⁡ ( 𝐾 ) 3 ⏟ 𝐻 .

Rearranging (230) , we have

1 𝑇 𝑝 − 𝑇 2 ​ ∑ 𝑡

𝑇 2 + 1 𝑇 𝑝 ℒ 𝒫 ​ ( 𝐖 ( 𝑡 ) , 𝐙 ( 𝑡 ) ) ≤ ‖ 𝐖 ( 𝑇 2 ) − 𝐖 ∗ ‖ 𝐹 2 ( 2 − 𝜅 ) ​ 𝜂 3 ​ ( 𝑇 𝑝 − 𝑇 2 ) + 2 ​ ( 1 + 𝜅 ) ​ 𝜖 + ( 1 + 𝜅 ) ​ 𝐻 .

(231)

For the term ‖ 𝐖 ( 𝑇 2 ) − 𝐖 ∗ ‖ 𝐹 2 , we have

‖ 𝐖 ( 𝑇 2 ) − 𝐖 ∗ ‖ 𝐹 2 ≤

‖ 𝐖 ( 𝑇 2 ) − 𝐖 ( 0 ) ‖ 𝐹 2 + ‖ 𝐖 ( 0 ) − 𝐖 ∗ ‖ 𝐹 2

(232)

=
𝒪 ~ ​ ( 𝑚 / 𝜆 2 ) + 𝒪 ~ ​ ( 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) / 𝜆 2 )

𝒪 ~ ​ ( 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) / 𝜆 2 ) .

Setting 𝑇 𝑝 − 𝑇 2

Θ ~ ​ ( 𝑚 ​ 𝑁 ​ ( 𝑁 + | ℛ | ) / ( 𝜆 2 ​ 𝜂 3 ) ) completes the proof. ∎

Appendix GProofs of full fine-tuning

In this part, we provide the proofs of the results for full fine-tuning. We show some key properties of the activation patterns, attention scores during full FT in H.2 and G.1, and prove the results for full FT in statements of Theorem 2 in G.4 and G.5. For each FT iteration 𝑡 𝑓 ∈ [ 0 , 𝑇 𝑓 ] , we represent the iteration index with 𝑇 𝑝 + 𝑡 𝑓 .

G.1Attention scores during full FT Lemma 17.

Under Condition 1, during full FT, the following holds:

1 2 ≤ 𝛼 1 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) 𝛼 2 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ≤ 2 .

(233) Proof.

As the attention matrix 𝐙 has never been updated during PT with respect to the query of 𝐩 ,

𝛼 1 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) 𝛼 2 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] )

𝛼 1 ( 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) 𝛼 2 ( 0 ) ( [ 𝐬 𝑗 𝐩 ] ] )

exp ⁡ ( 𝒪 ~ ​ ( 𝜎 0 ) ) exp ⁡ ( 𝒪 ~ ​ ( 𝜎 0 ) ) ,

(234)

by Lemma 4. As a result, we have

0.95 ≤ 𝛼 1 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) 𝛼 2 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ≤ 1.05 .

(235)

This finishes the proof. ∎

G.2Key property of pre-trained transformers on the fine-tuning data

We highlight the following key property of the transformer after pre-training.

Lemma 18.

With probability 1 − 𝛿 , the following holds:

For all 𝑗 ∈ [ 𝑁 ] , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ [ 𝐬 𝑗 ​ 𝐩 ] ⟩ )

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(236) •

For all 𝑗 ∈ [ 𝑁 ] and 𝑖 ∈ ℬ 𝑗 , we have

max 𝑖 ∈ ℬ 𝑗 ⁡ 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(237) Proof.

By the fifth statement of Lemma 14 and the fact that for all 𝑙 ∈ [ 𝑚 ] ,

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(238)

the first statement holds.

Next, prove the second statement. We first derive the increment properties for iteration 𝑡 ′ satisfying

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ) ≥ log ⁡ ( 𝑑 ) 𝜆 .

(239)

For all 𝑖 ∈ ℬ 𝑗 and arbitrary 𝑙 𝑖 ∈ [ 𝑚 ] , we have

max 𝑖 ∈ ℬ 𝑗 ⁡ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 𝑖 ( 𝑡 ′ + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 𝑖 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ) ≤ 2 ​ ∑ 𝑖 ∈ ℬ 𝑗 ( 1 − 𝐾 ~ ​ ( 𝑗 ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ) ≤ 2 ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝜂 ( 𝑡 ′ ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝑑 .

(240)

We also have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ′ + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ≥ − 2 ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝜂 ( 𝑡 ′ ) ​ 𝜆 𝑚 ​ 𝑛 ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑡 ′ ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ≥ − 2 ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝜂 ( 𝑡 ′ ) ​ 𝜆 𝑚 ​ 𝑛 ​ 𝑑 ,

(241)

for all 𝑙 ∈ [ 𝑚 ] . Therefore, we have

1 𝑚 ​ ∑ 𝑖 ∈ ℬ 𝑗 ∑ 𝑙

1 𝑚 ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 𝑖 ( 𝑇 𝑝 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 𝑖 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑡 ′ ) , 𝐬 𝑗 ⟩ )

(242)


2 ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝜂 2 ​ 𝜆 𝑚 ​ 𝑛 ​ 𝑑 ​ 𝑇 2 − 2 ​ 𝐾 ~ ​ ( 𝑗 ) ​ 𝜂 3 ​ 𝜆 𝑚 ​ 𝑛 ​ 𝑑 ​ ( 𝑇 𝑝 − 𝑇 2 )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

Combing with (238) and (239), we can conclude that

max 𝑖 ∈ ℬ 𝑗 ⁡ 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ ) − 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(243)

This finishes the proof. ∎

G.3Useful lemmas in full FT

In this subsection, we characterize the feature learning during full FT. We first define the following set for indexing the subjects in 𝒬 𝑠 :

𝒥 := { 𝑗 ∈ [ 𝑁 ] : ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬 𝑠 } .

(244) Lemma 19.

During full FT, 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] , the following holds:

For all 𝑘 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ 𝒥 and 𝑙 ∈ [ 𝑚 ] ,

⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐩 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(245) •

For all 𝑗 ∈ [ 𝑁 ] , 𝑘 ≠ 𝑗 and 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(246) Proof.

We prove the statements sequentially.

Proof of the first statement: For all 𝑘 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ 𝒥 and 𝑙 ∈ [ 𝑚 ] , we have

⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 ) , 𝐩 ⟩ − ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐩 ⟩

(247)

=

− ∑ 𝑞 ∈ 𝒥 Θ ​ ( 1 ) ​ 𝜂 𝑓 ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑞 ​ 𝐩 ] ) ⟩

0 ) ​ logit 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑞 ​ 𝐩 ] )

0 .

Therefore, we have

⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐩 ⟩ ≤ ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(248)

This completes the proof.

Proof of the second statement: During fine-tuning 𝑡 𝑓 ∈ [ 0 , 𝑇 𝑓 − 1 ] , for all 𝑙 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] , and 𝑘 ≠ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 ) , 𝐬 𝑘 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑘 ⟩

(249)

=

− Θ ​ ( 1 ) ​ 𝜂 𝑓 ​ 𝜆 𝑚 ​ 𝑛 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑘 ​ 𝐩 ] ) ⟩

0 ) ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑘 ​ 𝐩 ] )

0 .

Therefore, by Lemma 12, for all 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] , 𝑙 ∈ [ 𝑚 ] , and 𝑗 ∈ [ 𝑁 ] , 𝑘 ≠ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑘 ⟩ ≤ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑘 ⟩

𝒪 ~ ​ ( 𝜎 0 ) .

(250)

This completes the proof. ∎

Lemma 20.

For all 𝑙 ∈ 𝒮 s , 𝑖 , 𝑗 ( 0 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(251)

and for all 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑗 , 𝑖 ( 0 ) ,

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(252) Proof.

By Lemma 18, at the end of pre-training, for all 𝑙 ∈ 𝒮 s , 𝑖 , 𝑗 ( 0 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(253)

and

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩

𝑂 ~ ​ ( 𝜎 0 ) .

(254)

Then, by Lemma 17 we have

⟨ 𝐰 ℐ ​ ( 𝐫 𝑘 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ [ 𝐬 𝑗 ​ 𝐩 ] ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(255)

Additionally, under Condition 1, for all 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑗 , 𝑖 ( 0 ) ,

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(256)

and

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩

𝑂 ~ ​ ( 𝜎 0 ) ,

(257)

As a result, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(258)

This completes the proof. ∎

Lemma 21.

For all 𝑗 ∈ [ 𝑁 ] and 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑖 , 𝑗 ( 0 ) , during full FT, 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] , the following holds:

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(259) Proof.

For all 𝑡 𝑓 ∈ [ 0 , 𝑇 𝑓 − 1 ] , 𝑙 ∈ [ 𝑚 ] and 𝑗 ∈ [ 𝑁 ] , the update of 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) satisfies

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩

(260)

=

Θ ​ ( 1 ) ​ 𝜆 ​ 𝜂 𝑓 𝛽 ​ 𝑁 𝑓 ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑓 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩

0 ) ​ 𝕀 ​ ( 𝑗 ∈ 𝒥 ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) )

0 .

Then, for all 𝑡 𝑓 ∈ [ 0 , 𝑇 𝑓 − 1 ] , 𝑙 ∈ [ 𝑚 ] and 𝑗 ∈ [ 𝑁 ] , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩ ≥ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑗 ⟩ .

(261)

Based on Lemma 14, we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(262)

for all 𝑗 ∈ [ 𝑁 ] and 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑖 , 𝑗 ( 0 ) .

This completes the proof. ∎

Lemma 22.

When 𝛽 ​ 𝑁 𝑓 ≥ 𝐶 1 ​ | ℛ | , under Condition 1, with probability 1 − 𝛿 , for all 𝑖 ∈ ℛ , 𝑡 ∈ [ 𝑇 1 , 𝑇 𝑝 ] , and 𝑙 ∈ [ 𝑚 ] , we have

ℙ ​ [ ∑ 𝑗

1 𝑁 𝜎 ′ ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 , Ξ ( 𝑡 ) ​ ( [ 𝐨 ​ 𝐬 𝑗 ] ) ⟩ ) ≥ 0.15 ​ 𝑁 | ℛ | ] ≥ 1 − exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) 2 ​ | ℛ | 2 ) .

(263) Proof.

By Lemma 6, with probability 1 − exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) / ( 2 ​ | ℛ | 2 ) ) , each 𝑖 ∈ ℛ satisfies

∑ 𝑗

1 𝑁 𝕀 ​ ( 𝑖 ∈ ℬ 𝑗 ) ≥ 0.5 ​ 𝑁 | ℛ | .

(264)

Then, by Lemma 9, under Condition 1, with probability 1 − 𝛿 , each 𝑙 ∈ [ 𝑚 ] satisfies

∑ 𝑖 ∈ ℛ , 𝑖 ∈ ℬ 𝑗 𝕀 ​ ( 𝑙 ∈ 𝒮 s , 𝑖 , 𝑗 ( 0 ) ) ≥ 0.15 ​ 𝑁 | ℛ | .

(265)

Combining with the activation patterns finishes the proof. ∎

G.4Proof of the result for full FT in Statement 2 of Theorem 2

We denote

𝒥 𝑘 := { 𝑗 ∈ [ 𝑁 ] : ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬 𝑠 , 𝑘 ∈ ℬ 𝑗 } .

(266)

After pre-training, the first iteration of full FT satisfies

∑ 𝑗 ∈ 𝒥 𝑖 logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] )

Θ ​ ( 𝑁 𝑓 ​ 𝛽 𝐾 ⋅ 𝐾 | ℛ | )

Θ ​ ( 𝑁 𝑓 ​ 𝛽 | ℛ | ) ,

(267)

and

1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] )

Θ ​ ( 1 ) .

(268)

By Lemmas 6 and 22, with probability 1 − exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) / ( 2 ​ | ℛ | 2 ) ) , for all 𝑖 ∈ ℛ and 𝑘 ∈ [ 𝑚 ] , the model updates satisfy

𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

(269)

=
− Θ ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) ​ ∑ 𝑗 ∈ 𝒥 logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , Ξ ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ > 0 )

− Θ ​ ( 𝜂 𝑓 ​ 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝐾 | ℛ | ​ 1 𝐾 )

− Θ ​ ( 𝜆 ​ 𝜂 𝑓 | ℛ | ​ 𝑚 ) ,

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

(270)

=
Θ ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) )

Θ ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) .

Similarly, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐬 𝑖 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑖 ⟩ )

Θ ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) .

(271)

When 𝜂 𝑓

Θ ​ ( 𝑚 ​ | ℛ | ​ log ⁡ ( 𝑑 ) / 𝜆 2 ) , we have that

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

− Θ ​ ( 𝜂 𝑓 ​ 𝜆 | ℛ | ​ 𝑚 )

− Θ ~ ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(272)

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

Θ ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 )

Θ ​ ( | ℛ | ​ log ⁡ ( 𝑑 ) 𝜆 ​ 𝑁 𝑓 ​ 𝛽 ) .

(273)

Consequently, when 𝛽 ​ 𝑁 𝑓

𝐶 1 ​ | ℛ | , after one iteration, we already have

max 𝑘 ∈ ℛ ⁡ 𝑥 ℐ ​ ( 𝐫 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(274)

max 𝑘 ∈ [ 𝑁 ]
{ 𝑗 } ⁡ 𝑥 ℐ ​ ( 𝐚 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(275)

and

max 𝑙 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ [ 𝑁 ] ∪ { ℐ ​ ( 𝐫 𝑖 ) } 𝑖 ∈ ℛ ⁡ 𝑥 𝑙 output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(276)

Within 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐩 ⟩ )

𝒪 ​ ( | ℛ | ​ log ⁡ ( 𝑑 ) 𝜆 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑇 𝑓 )

𝒪 ​ ( 𝑑 − 0.01 ​ log ⁡ ( 𝑑 ) 𝜆 ) .

(277)

Hence, by Lemmas 19 and 21, within 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) the outputs of the transformer also satisfy

max 𝑘 ∈ ℛ ⁡ 𝑥 ℐ ​ ( 𝐫 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(278)

max 𝑘 ∈ [ 𝑁 ]
{ 𝑗 } ⁡ 𝑥 ℐ ​ ( 𝐚 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(279)

and

max 𝑙 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ [ 𝑁 ] ∪ { ℐ ​ ( 𝐫 𝑖 ) } 𝑖 ∈ ℛ ⁡ 𝑥 𝑙 output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(280)

for all 𝑡 𝑓 ∈ [ 1 , 𝑇 𝑓 ] .

This finishes the proof and concludes that

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≤ exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) 2 ​ | ℛ | 2 ) .

(281) G.5Proof of the result for full FT in Statement 3 of Theorem 2

When 𝛽 ​ 𝑁 𝑓 ​ 𝐾 / | ℛ | ≤ 𝐶 3 with 0 < 𝐶 3 < 1 , there are more than ( 1 − 𝐶 3 ) ​ | ℛ | relation elements, denoted as set 𝒞 , not been covered in the FT training set. For the elements not covered, 𝑗 ∈ 𝒞 , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐩 ⟩ ) ≤ 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

𝒪 ~ ​ ( 𝜎 0 𝑚 ) .

(282)

For all 𝑖 ∈ 𝒞 and 𝒟 𝑖 and any 𝑡 𝑓 , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ )

Θ ​ ( 1 ) ,

(283)

resulting in

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ ) > 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ ) .

(284)

For the remaining ( 1 − 𝛽 ) ​ 𝑁 𝑓 + 𝑁 𝑟 data ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬
𝒬 𝑠 , we have

ℙ ​ [ 𝑗 ∈ 𝒟 𝑖 , ∀ 𝑖 ∈ 𝒞 ] ≥ 1 − 𝐶 3 𝐾 ~ ​ ( 𝑗 ) .

(285)

Therefore, we have

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≥ 1 − 𝐶 3 .

(286)

This completes the proof.

Appendix HProofs of low-rank fine-tuning

We consider a low-rank fine-tuning approach, where the model update is based on the best rank-1 approximation of the full gradient. Specifically, in FT iteration 𝑡 𝑓 , the model update is

𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 )

𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) − 𝜂 𝑓 𝑛 ​ RA 1 ​ ( ∑ ( 𝐗 , 𝑦 ) ∈ 𝒬 𝑠 ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐗 , 𝑦 ) ) ,

𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 )

𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) − 𝜂 𝑓 𝑛 ​ RA 1 ​ ( ∑ ( 𝐗 , 𝑦 ) ∈ 𝒬 𝑠 ∇ 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐗 , 𝑦 ) ) ,

where RA 1 ​ ( 𝐁 ) denotes the best rank-1 approximation of 𝐁 .

For a matrix 𝐀 , its best rank-1 approximation 𝐀 ~ is found by solving

𝐀 ~

𝐚𝐛 ⊤ ,

(287)

where

𝐚 , 𝐛

arg ⁡ min ⁡ ‖ 𝐀 − 𝐚𝐛 ⊤ ‖ 𝐹 .

(288)

Notably, by the Eckart–Young–Mirsky theorem,

𝐀 ~

𝜎 1 ​ 𝐮 1 ​ 𝐯 1 ⊤ ,

(289)

where 𝜎 1 is the largest eigenvalue of 𝐀 and 𝐮 1 , 𝐯 1 are the corresponding left and right singular vectors.

H.1Useful lemmas

First, we present two lemmas that characterize some properties of the best rank-1 approximation. Denote the full gradient at FT iteration 𝑡 𝑓 as

𝐆 ( 𝑡 𝑓 )

[ ∇ 𝐰 1 , 1 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ ⋯ ​ ∇ 𝐰 𝑑 , 𝑚 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ] ⊤ .

(290) Lemma 23.

The best rank-1 approximation of the full gradient 𝐆 ( 𝑡 𝑓 ) can be decomposed as

𝐆 ( 𝑡 𝑓 )

[ 𝐯 𝑝 ​ 𝐯 1 ​ ⋯ ​ 𝐯 𝑁 ] ​ [ 𝐩 ​ 𝐬 1 ​ ⋯ ​ 𝐬 𝑁 ] ⊤ .

(291)

The matrix 𝐆 ( 𝑡 𝑓 ) satisfies

‖ RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹

𝒪 ​ ( 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ) .

(292) Proof.

By the Eckart–Young–Mirsky theorem, we derive the best rank-1 approximation of the full gradient 𝐆 ( 𝑡 𝑓 ) based on SVD. The gradient can be represented as follows. For clarity, we add | to seprate the elements in a vector in (293) and (295). For all 𝑖 ∈ ℛ and 𝑙 ∈ [ 𝑚 ] , we rearrange the gradient as follows.

∇ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) )

(293)

=

𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ [ 𝐩 ​ | 𝐬 1 | ​ ⋯ | 𝐬 𝑁 ]

[ − ∑ ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬 𝑠 , 𝑗 ∈ 𝒥 𝑖 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⏟ 𝑢 1 |

− 𝕀 ​ ( 1 ∈ 𝒥 ) ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ⟩ ) ​ logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ​ | ⋯ |

− 𝕀 ( 𝑁 ∈ 𝒥 ) 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ⟩ ) logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ] ⊤ .

Here, it is worth noting that the norm of the components corresponding to 𝐬 𝑗 for all 𝑗 ∈ 𝒬 𝑠 is bounded by

‖ ∇ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) + 𝑢 1 ​ 𝐩 ‖ 2

𝒪 ​ ( 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ) .

(294)

Then, for each 𝑗 ∈ 𝒥 and 𝑙 ∈ [ 𝑚 ] , we rearrange the gradient as follows.

∇ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) )

(295)

=
𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 [ 𝐩 ​ 𝐬 1 ​ ⋯ ​ 𝐬 𝑁 ] ( [ 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) ⟩ ) ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) )

− ∑ 𝑖 ≠ 𝑗 , 𝑖 ∈ 𝒥 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ⟩ ) logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ) |

− 𝕀 ​ ( 1 ∈ 𝒥 ) ​ 𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ⟩ ) ​ logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ​ | ⋯ |

𝕀 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ⟩ ) ​ ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ) ​ | ⋯ |

− 𝕀 ( 𝑁 ∈ 𝒥 ) 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ⟩ ) logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ]

𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 [ 𝐩 ​ 𝐬 1 ​ ⋯ ​ 𝐬 𝑁 ] ( [ 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) ⟩ ) ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) )

− ∑ 𝑖 ≠ 𝑗 , 𝑖 ∈ 𝒥 𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ⟩ ) logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ) |  0 | ⋯ |

𝕀 ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) ⟩ ) ( 1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑗 𝐩 ] ) ) | ⋯ |  0 ) ] + 𝜖 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ) ,

where 𝜖 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 is the vector containing the components except those for 𝐩 and 𝐬 𝑗 . It is worth noting that the norm of 𝜖 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 satisfies

‖ 𝜖 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ‖ 2

𝒪 ​ ( 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑑 ) ,

(296)

by Lemma 19.

For all 𝑘 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ 𝒥 ∪ { ℐ ​ ( 𝐫 𝑖 ) } 𝑖 ∈ ℛ and 𝑘 ∈ [ 𝑑 ] , 𝑙 ∈ [ 𝑚 ] , we have

∇ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) )

(297)

=
𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ [ 𝐩 ​ 𝐬 1 ​ ⋯ ​ 𝐬 𝑁 ]

− ∑ 𝑖 ∈ 𝒥 𝕀 ( ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ⟩ ) logit 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑖 𝐩 ] ) ) |

− 𝕀 ​ ( 1 ∈ 𝒥 ) ​ 𝕀 ​ ( ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ⟩ ) ​ logit 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ ( [ 𝐬 1 ​ 𝐩 ] ) ​ | ⋯ |

− 𝕀 ( 𝑁 ∈ 𝒥 ) 𝕀 ( ⟨ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , Ξ ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ⟩ ) logit 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) ( [ 𝐬 𝑁 𝐩 ] ) ]

𝜖 𝑘 , 𝑙 .

The norm of ∇ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) satisfies

‖ ∇ 𝐰 𝑘 , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ ℱ ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ‖ 2

𝒪 ​ ( 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑑 ) .

(298)

Compactly, we decompose the full gradient as

𝐆 ( 𝑡 𝑓 )

𝐯 𝑝 ​ 𝐩 ⊤ + 𝐯 1 ​ 𝐬 1 ⊤ + ⋯ + 𝐯 𝑁 ​ 𝐬 𝑁 ⊤ + 𝜖 ,

(299)

where 𝜖 ∈ ℝ 𝑚 ​ 𝑑 × 𝑑 is a concatenation of 𝜖 𝑘 , 𝑙 for all 𝑘 ∈ [ 𝑑 ] and 𝑙 ∈ [ 𝑚 ] with 𝜖 ℐ ​ ( 𝐫 𝑖 ) , 𝑙

𝟎 . As RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) is the best rank-1 approximation of 𝐆 ( 𝑡 𝑓 ) , we have that

‖ 𝐆 ( 𝑡 𝑓 ) − RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) ‖ 𝐹 ≤ ‖ 𝐆 ( 𝑡 𝑓 ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹 ,

(300)

by definition. By the triangle inequality, we have

‖ RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹 ≤

‖ 𝐆 ( 𝑡 𝑓 ) − RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) ‖ 𝐹 + ‖ 𝐆 ( 𝑡 𝑓 ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹

(301)

2 ​ ‖ 𝐆 ( 𝑡 𝑓 ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹 .

The term ‖ 𝐆 ( 𝑡 𝑓 ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹 satisfies

‖ 𝐆 ( 𝑡 𝑓 ) − 𝐯 𝑝 ​ 𝐩 ⊤ ‖ 𝐹

‖ ∑ 𝑖

1 𝑁 𝐯 𝑖 ​ 𝐬 𝑖 ‖ 𝐹 + ‖ 𝜖 ‖ 𝐹

𝒪 ​ ( 𝜆 𝑁 𝑓 ​ 𝑚 ​ 𝛽 ) + 𝒪 ​ ( 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑑 )

(302)

=

𝒪 ​ ( 𝜆 𝑁 𝑓 ​ 𝑚 ​ 𝛽 ) .

This completes the proof. ∎

Lemma 24.

For all 𝑖 ∈ [ 𝑚 ​ 𝑑 ] , the full gradient 𝐆 ( 𝑡 𝑓 ) satisfies

‖ 𝐞 𝑖 ⊤ ​ RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) ‖ 2 ≤ ‖ 𝐞 𝑖 ⊤ ​ 𝐆 ( 𝑡 𝑓 ) ‖ 2 .

(303) Proof.

Suppose that the SVD decomposition of 𝐆 ( 𝑡 𝑓 ) as 𝐔 ​ 𝚺 ​ 𝐕 ⊤ . By the Eckart–Young–Mirsky theorem, we have that

RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) )

𝜎 1 ​ 𝐮 1 ​ 𝐯 1 ⊤ ,

(304)

where 𝜎 1 ≥ 𝜎 2 ≥ ⋯ ≥ 𝜎 𝑑 are the singular values of 𝐆 ( 𝑡 𝑓 ) . Then, we have

𝐞 𝑖 ⊤ ​ 𝐆 ( 𝑡 𝑓 )

𝐞 𝑖 ⊤ ​ ∑ 𝑗

1 𝑑 𝜎 𝑗 ​ 𝐮 𝑗 ​ 𝐯 𝑗 ⊤ , 𝐞 𝑖 ⊤ ​ RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) )

𝐞 𝑖 ⊤ ​ 𝜎 1 ​ 𝐮 1 ​ 𝐯 1 ⊤ .

(305)

Taking the L2 norms, we have

‖ 𝐞 𝑖 ⊤ ​ 𝐆 ( 𝑡 𝑓 ) ‖ 2 2

‖ 𝐞 𝑖 ⊤ ​ ∑ 𝑗

1 𝑑 𝜎 𝑗 ​ 𝐮 𝑗 ​ 𝐯 𝑗 ⊤ ‖ 2 2

‖ ∑ 𝑗

1 𝑑 𝜎 𝑗 ​ 𝑢 𝑖 , 𝑗 ​ 𝐯 𝑗 ⊤ ‖ 2 2

∑ 𝑗

1 𝑑 ( 𝜎 𝑗 ​ 𝑢 𝑖 , 𝑗 ) 2 ,

(306)

and

‖ 𝐞 𝑖 ⊤ ​ RA 1 ​ ( 𝐆 ( 𝑡 𝑓 ) ) ‖ 2 2

‖ 𝐞 𝑖 ⊤ ​ 𝜎 1 ​ 𝐮 1 ​ 𝐯 1 ⊤ ‖ 2 2

‖ 𝜎 1 ​ 𝑢 𝑖 , 1 ​ 𝐯 1 ⊤ ‖ 2 2

( 𝜎 1 2 ​ 𝑢 𝑖 , 1 ) 2 ≤ ‖ 𝐞 𝑖 ⊤ ​ 𝐆 ( 𝑡 𝑓 ) ‖ 2 2 .

(307)

This completes the proof. ∎

Next, we present a lemma that characterizes the property of ∇ 𝐙 ( 𝑡 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) )

Lemma 25.

The gradient ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) satisfies

| 𝐛 1 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐛 2 | ≤ | 𝐛 1 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐛 2 | ,

(308)

for any 𝐛 1 , 𝐛 2 ∈ { 𝐬 𝑗 : [ 𝐬 𝑗 ​ 𝐩 ] ∈ 𝒬 𝑠 } ∪ { 𝐩 } .

Proof.

Suppose that 𝐌 is a unitary matrix, and each 𝑖 𝑡 ​ ℎ column corresponds to the normalized embedding whose index is 𝑖 . By the Eckart–Young–Mirsky theorem, we have that

RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) )

𝜎 1 ​ 𝐮 1 ​ 𝐯 1 ⊤ ,

(309)

where 𝜎 1 ≥ 𝜎 2 ≥ ⋯ ≥ 𝜎 𝑑 are the singular values of ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) . We have

𝐌 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐌

∑ 𝑖

1 𝑑 𝜎 𝑖 ​ 𝐌 ⊤ ​ 𝐮 𝑖 ​ 𝐯 𝑖 ⊤ ​ 𝐌 ,

(310)

and

𝐌 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐌

𝜎 1 ​ 𝐌 ⊤ ​ 𝐮 1 ​ 𝐯 1 ⊤ ​ 𝐌 .

(311)

Then, we have

‖ 𝐞 𝑖 ⊤ ​ 𝐌 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐌 ‖ 2 2

(312)

=
‖ 𝐞 𝑖 ⊤ ​ ∑ 𝑗

1 𝑑 𝜎 𝑗 ​ 𝐌 ⊤ ​ 𝐮 𝑗 ​ 𝐯 𝑗 ⊤ ​ 𝐌 ‖ 2 2

‖ ∑ 𝑗

1 𝑑 𝜎 𝑗 ​ 𝐞 𝑖 ⊤ ​ 𝐌 ⊤ ​ 𝐮 𝑗 ​ 𝐯 𝑗 ⊤ ​ 𝐌 ‖ 2 2

∑ 𝑗

1 𝑑 ( 𝜎 𝑗 ​ 𝐞 𝑖 ⊤ ​ 𝐌 ⊤ ​ 𝐮 𝑗 ) 2 ,

and

‖ 𝐞 𝑖 ⊤ ​ 𝐌 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐌 ‖ 2 2

(313)

=
‖ 𝐞 𝑖 ⊤ ​ 𝜎 1 ​ 𝐌 ⊤ ​ 𝐮 1 ​ 𝐯 1 ⊤ ​ 𝐌 ‖ 2 2

‖ 𝐞 𝑖 ⊤ ​ 𝐌 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐌 ‖ 2 2 .

During fine-tuning, only ( 𝐌 ⊤ ​ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ 𝐌 ) ℐ ​ ( 𝐬 𝑗 ) , ℐ ​ ( 𝐩 ) , for all [ 𝐬 𝑗 ​ 𝐩 ] ∈ 𝒬 𝑠 and ( 𝐌 ⊤ ​ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ​ 𝐌 ) ℐ ​ ( 𝐩 ) , ℐ ​ ( 𝐩 ) are updated. So, we have

‖ 𝐞 ℐ ​ ( 𝐬 𝑗 ) ⊤ ​ 𝐌 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐌 ‖ 2

| 𝐬 𝑗 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐩 | ,

(314)

‖ 𝐞 ℐ ​ ( 𝐩 ) ⊤ ​ 𝐌 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐌 ‖ 2

| 𝐩 ⊤ ​ ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ​ 𝐩 | ,

‖ 𝐞 ℐ ​ ( 𝐬 𝑗 ) ⊤ ​ 𝐌 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐌 ‖ 2

| 𝐬 𝑗 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐩 | ,

‖ 𝐞 ℐ ​ ( 𝐩 ) ⊤ ​ 𝐌 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐌 ‖ 2

| 𝐩 ⊤ ​ RA 1 ​ ( ∇ 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ℒ 𝒬 𝑠 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) ) ) ​ 𝐩 | .

This completes the proof. ∎

H.2Features in low-tank FT

In this subsection, we characterize the features during low-rank FT.

Lemma 26.

For all 𝑗 ∈ [ 𝑁 ] , 𝑘 ≠ 𝑗 , during low-rank FT, 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] , the following holds:

1 𝑚 ​ ∑ 𝑙

1 𝑚 ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑘 ⟩

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 0.01 ) .

(315) Proof.

During low-rank fine-tuning with 𝑡 ∈ [ 0 , 𝑇 𝑓 − 1 ] , for all 𝑙 ∈ [ 𝑚 ] , 𝑗 ∈ [ 𝑁 ] , and 𝑘 ≠ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 + 1 ) , 𝐬 𝑘 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 ) , 𝐬 𝑘 ⟩ ≤

𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ) ,

(316)

by Lemmas 24 and 25. Therefore, by Lemma 12, for all 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] with 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) and for all 𝑙 ∈ [ 𝑚 ] , and 𝑗 ∈ [ 𝑁 ] , 𝑘 ≠ 𝑗 , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑘 ⟩ ≤ ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑘 ⟩

𝒪 ​ ( 𝑇 𝑓 ​ 𝜆 ​ 𝜂 𝑓 𝑚 ​ 𝑁 𝑓 ​ 𝛽 )

𝒪 ​ ( log ⁡ ( 𝑑 ) 𝜆 ​ 𝑑 0.01 ) .

(317)

This completes the proof. ∎

Lemma 27.

Under Condition 1, for all 𝑗 ∈ [ 𝑁 ] , 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑗 , 𝑖 ( 0 ) , during low-rank FT, 𝑡 𝑓 ∈ [ 𝑇 𝑓 ] , the following holds:

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) 𝜆 ) .

(318) Proof.

According to Lemma 14, for all 𝑗 ∈ [ 𝑁 ] and 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑗 , 𝑖 ( 0 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) .

(319)

By Lemma 24, for all 𝑙 ∈ [ 𝑚 ] , the update of 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) satisfies

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 + 1 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩ ≥ − 𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ) .

(320)

For 𝑡 𝑓 ≤ 𝑇 𝑓

𝒪 ​ ( 𝑁 𝑓 ​ 𝛽 ​ 𝑑 − 0.01 / | ℛ | ) and 𝜂 𝑓

Θ ​ ( 𝑚 ​ | ℛ | ​ log ⁡ ( 𝑑 ) / 𝜆 2 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩ − ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑗 ⟩ ≥ − 𝒪 ​ ( 𝑑 − 0.01 ​ log ⁡ ( 𝑑 ) / 𝜆 ) .

(321)

Hence, for all 𝑙 ∈ ∪ 𝑖 ∈ ℬ 𝑗 𝒮 r , 𝑗 , 𝑖 ( 0 ) , we have

⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐬 𝑗 ⟩

Ω ​ ( log ⁡ ( 𝑑 ) / 𝜆 ) .

(322)

Averaging over all 𝑙 ∈ [ 𝑚 ] completes the proof. ∎

H.3Attention scores during low-rank FT

With the same proof as Lemma 17, we have the following lemma.

Lemma 28.

Under Condition 1, during low-rank FT, the following holds:

1 2 ≤ 𝛼 1 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) 𝛼 2 ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] ) ≤ 2 .

(323) H.4Proof of the result for low-rank FT in Statement 2 of Theorem 2

Recall that at the first FT iteration, the pre-trained model satisfies

∑ 𝑗 ∈ 𝒥 𝑖 logit ℐ ​ ( 𝐫 𝑖 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] )

Θ ​ ( 𝑁 𝑓 ​ 𝛽 𝐾 ⋅ 𝐾 | ℛ | )

Θ ​ ( 𝑁 𝑓 ​ 𝛽 | ℛ | ) ,

(324)

and

1 − logit ℐ ​ ( 𝐚 𝑗 ) ( 𝑇 𝑝 + 0 ) ​ ( [ 𝐬 𝑗 ​ 𝐩 ] )

Θ ​ ( 1 ) .

(325)

By Lemma 23 and Lemma 6, with probability 1 − exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) / ( 2 ​ | ℛ | 2 ) ) , for all 𝑖 ∈ ℛ , the updates with low-rank FT satisfy

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ ) ≤

− Θ ​ ( 𝜂 𝑓 ​ 𝜆 𝑁 𝑓 ​ 𝛽 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝐾 | ℛ | ​ 1 𝐾 ​ 𝑚 ) + 𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑚 ​ 𝑁 𝑓 ​ 𝛽 )

(326)

=

− Θ ​ ( 𝜆 ​ 𝜂 𝑓 | ℛ | ​ 𝑚 ) .

By Lemma 24, for all 𝑗 ∈ 𝒥 , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) ,

(327)

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐬 𝑖 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐬 𝑖 ⟩ )

𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 ) .

(328)

When 𝜂 𝑓

Θ ​ ( 𝑚 ​ | ℛ | ​ log ⁡ ( 𝑑 ) / 𝜆 2 ) , within a single low-rank FT iteration, we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

− Θ ​ ( log ⁡ ( 𝑑 ) 𝜆 ) ,

(329)

and

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 1 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ )

𝒪 ​ ( 𝜆 ​ 𝜂 𝑓 𝑁 𝑓 ​ 𝛽 ​ 𝑚 )

𝒪 ​ ( | ℛ | ​ log ⁡ ( 𝑑 ) 𝜆 ​ 𝑁 𝑓 ​ 𝛽 ) .

(330)

Consequently, when 𝛽 ​ 𝑁 𝑓

𝐶 1 ​ | ℛ | , after one iteration, we already have

max 𝑘 ∈ ℛ ⁡ 𝑥 ℐ ​ ( 𝐫 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(331)

max 𝑘 ∈ [ 𝑁 ]
{ 𝑗 } ⁡ 𝑥 ℐ ​ ( 𝐚 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(332)

and

max 𝑙 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ [ 𝑁 ] ∪ { ℐ ​ ( 𝐫 𝑖 ) } 𝑖 ∈ ℛ ⁡ 𝑥 𝑙 output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 1 ) , 𝐙 ( 𝑇 𝑝 + 1 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(333)

Within 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) , we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐩 ⟩ ) − 1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑖 ) , 𝑘 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐩 ⟩ )

𝒪 ​ ( | ℛ | ​ log ⁡ ( 𝑑 ) 𝜆 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑇 𝑓 )

𝒪 ​ ( 𝑑 − 0.01 ​ log ⁡ ( 𝑑 ) 𝜆 ) .

(334)

Hence, by Lemmas 26 and 27, within 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) the outputs of the transformer also satisfy

max 𝑘 ∈ ℛ ⁡ 𝑥 ℐ ​ ( 𝐫 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(335)

max 𝑘 ∈ [ 𝑁 ]
{ 𝑗 } ⁡ 𝑥 ℐ ​ ( 𝐚 𝑘 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(336)

and

max 𝑙 ∉ { ℐ ​ ( 𝐚 𝑗 ) } 𝑗 ∈ [ 𝑁 ] ∪ { ℐ ​ ( 𝐫 𝑖 ) } 𝑖 ∈ ℛ ⁡ 𝑥 𝑙 output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) < 𝑥 ℐ ​ ( 𝐚 𝑗 ) output ​ ( 𝐖 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ) ,

(337)

for all 𝑡 𝑓 ∈ [ 1 , 𝑇 𝑓 ] .

This finishes the proof and concludes that

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≤ exp ⁡ ( − 𝑁 𝑓 ​ 𝛽 ​ 𝐾 2 ​ log ⁡ ( | ℛ | ) 2 ​ | ℛ | 2 ) .

(338)

This completes the proof.

H.5Proof of the result for low-rank FT in Statement 3 of Theorem 2

When 𝛽 ​ 𝑁 𝑓 ​ 𝐾 / | ℛ | ≤ 𝐶 3 with 0 < 𝐶 3 < 1 , there are more than ( 1 − 𝐶 3 ) ​ | ℛ | relation elements, denoted as set 𝒞 , not been covered in the FT training set. For the elements not covered, 𝑗 ∈ 𝒞 , by Lemma 24, with 𝜂 𝑓

Θ ​ ( 𝑚 ​ | ℛ | ​ log ⁡ ( 𝑑 ) / 𝜆 2 ) and 𝑇 𝑓

𝒪 ​ ( 𝑑 − 0.01 ​ 𝑁 𝑓 ​ 𝛽 / | ℛ | ) we have

1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 𝑡 𝑓 ) , 𝐩 ⟩ ) ≤
1 𝑚 ​ ∑ 𝑘

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑗 ) , 𝑘 ( 𝑇 𝑝 + 0 ) , 𝐩 ⟩ ) + 𝒪 ​ ( 𝑡 𝑓 ​ 𝜂 𝑓 ​ 𝜆 𝑚 ​ 𝑁 𝑓 ​ 𝛽 ​ 𝑑 )

(339)

=

𝒪 ~ ​ ( 𝜎 0 𝑚 ) + 𝒪 ​ ( log ⁡ ( 𝑑 ) 𝑑 ​ 𝜆 ) .

(340)

For all 𝑖 ∈ 𝒞 and 𝒟 𝑖 and any 𝑡 𝑓 , we have

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ )

Θ ​ ( 1 ) ,

(341)

resulting in

1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐫 𝑖 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ ) > 1 𝑚 ​ ∑ 𝑙

1 𝑚 𝜎 ​ ( ⟨ 𝐰 ℐ ​ ( 𝐚 𝑗 ) , 𝑙 ( 𝑇 𝑝 + 𝑡 𝑓 ) , [ 𝐬 𝑗 ​ 𝐩 ] ⟩ ) .

(342)

For the remaining ( 1 − 𝛽 ) ​ 𝑁 𝑓 + 𝑁 𝑟 data ( [ 𝐬 𝑗 ​ 𝐩 ] , ℐ ​ ( 𝐚 𝑗 ) ) ∈ 𝒬
𝒬 𝑠 , we have

ℙ ​ [ 𝑗 ∈ 𝒟 𝑖 , ∀ 𝑖 ∈ 𝒞 ] ≥ 1 − 𝐶 3 𝐾 ~ ​ ( 𝑗 ) .

(343)

Therefore, we have

ℒ 𝑒 ​ ( 𝐖 ( 𝑇 𝑝 + 𝑇 𝑓 ) , 𝐙 ( 𝑇 𝑝 + 𝑇 𝑓 ) ) ≥ 1 − 𝐶 3 .

(344)

This completes the proof.

Appendix IClaim of 𝛿

In total, using union bound, with probability 1 − Θ ​ ( 𝛿 ′ ) , the high probability lemmas in the above proofs including Lemmas 4, 7, 9, 10, 11, 13, 18 hold. In the formal Statement of Theorem 2, we use 𝛿

Θ ​ ( 𝛿 ′ ) .

Appendix JExperimental settings J.1Settings for Figures 2a and 2b

We train a one-layer transformer based on the architecture presented in Section 2. The architecture and training parameters are as follows.

Number of neurons: 𝑚

100

Scaling factor: 𝜆

200

Optimizer: SGD with learning rates 𝜂 1

0.1 , 𝜂 2

0.05 and 𝜂 3

0.01

To ensure the model fully converges, during pre-training, we implement ℎ

3 for training stage 1, ℎ 2

5003 for training stage 2, and a total of 10003 iterations. During fine-tuning, we implement a total of 2000 iterations to determine the optimal number of iterations for OOD generalization.

To avoid the dominance of the gradient from frequent facts over those of rare facts as discussed in Section 5, we reweight the gradients from freq and rare by multiplying the gradients from rare with multiplicity 𝐾 .

All experiments are run over 5 random seeds.

J.2Settings for Figures 2c and 2d

We train a standard GPT-2 with 12 layers, 12 heads, and 786 hidden dimensions without implementing any positional encoding. For pre-training, we employ full fine-tuning and the AdamW optimizer with a learning rate of 1 × 10 − 5 and a weight decay of 0.1 .

For fine-tuning, we employ the Adam optimizer with learning rate 5 × 10 − 6 . We implement a total of 2000 iterations and find the best number of iterations for OOD generalization.

To avoid the dominance of the gradient from frequent facts over those of rare facts, as discussed in Section 5, we reweight the gradients from frequent facts and rare facts by multiplying the gradients from rare facts by multiplicity 𝐾 .

All experiments are run over 5 random seeds.

J.3Settings for Figure 2e

We train a standard GPT-2 with 12 layers, 12 heads, and 786 hidden dimensions.

We employ the AdamW optimizer for pre-training, using a learning rate of 2 × 10 − 4 and a weight decay of 0.1 . We also employ a learning rate warmup with 50 steps and a cosine learning rate schedule. We implement a batch size of 128 and a total of 150 epochs.

For fine-tuning, we employ full fine-tuning with the AdamW optimizer, using a learning rate of 1 × 10 − 4 and a weight decay of 0.01. We employ a learning rate warmup with 20 steps and a constant learning rate schedule. We implement a batch size of 128 and a total of 50 epochs.

J.4Settings for Table 3

We employ full fine-tuning and the AdamW optimizer with a learning rate of 1 × 10 − 4 and weight decay of 0.01. We employ a constant learning rate schedule. We implement a batch size of 128 and a total of 50 epochs.

Appendix KPre-training and fine-tuning question answering examples constructed from PopQA ’director’ relation

We show some pre-training and fine-tuning examples of the constructed dataset from PopQA ’director’ relation.

Table 4:Examples of constructed pre-training and fine-tuning data with PopQA under the setting of Figure 2e. Pre-training examples

Question prompts

Answers

Atlantic was helmed by Ewald André Dupont [EOS]

Question: Who was the director of Atlantic? Answer:

Ewald André Dupont

Bhoopathi Ranga was directed by Geethapriya [EOS]

Question: Who was the director of Bhoopathi Ranga?

Geethapriya

The Cell was crafted by Tarsem Singh [EOS]

Question: Who was the director of Atlantic? Answer:

Tarsem Singh

The Good Earth was shaped by Sidney Franklin [EOS]

Question: Who was the director of The Good Earth? Answer:

Sidney Franklin Appendix LFine-tuning question answering examples constructed from PopQA ’capital’ relation

We show fine-tuning examples of the constructed dataset from PopQA ’capital’ relation.

Table 5:Examples of constructed fine-tuning data with PopQA under the setting of Table 3. Question prompts

Answers

Question: Where is the capital of Ilocos Region? Answer:

San Fernando

Question: Where is the capital of Cayman Islands? Answer:

George Town

Question: Where is the capital of Massachusetts? Answer:

Boston

Question: Where is the capital of Cherokee County? Answer:

Canton Report Issue Report Issue for Selection Generated by L A T E xml Instructions for reporting errors

We are continuing to improve HTML versions of papers, and your feedback helps enhance accessibility and mobile support. To report errors in the HTML that will help us improve conversion and rendering, choose any of the methods listed below:

Click the "Report Issue" button. Open a report feedback form via keyboard, use "Ctrl + ?". Make a text selection and click the "Report Issue for Selection" button near your cursor. You can use Alt+Y to toggle on and Alt+Shift+Y to toggle off accessible reporting links at each section.

Our team has already identified the following issues. We appreciate your time reviewing and reporting rendering errors we may not have found yet. Your efforts will help us improve the HTML versions for all readers, because disability should not be a barrier to accessing research. Thank you for your continued support in championing open access for all.

Have a free development cycle? Help support accessibility at arXiv! Our collaborators at LaTeXML maintain a list of packages that need conversion, and welcome developer contributions.

Xet Storage Details

Size:
210 kB
·
Xet hash:
cd7367d934acf42eae7acc07874dc697721a5c233d93580fa8192c52583f9ce8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.