Title: Do Language Models Plan Ahead for Future Tokens?

URL Source: https://arxiv.org/html/2404.00859


License: CC BY-SA 4.0 | arXiv:2404.00859v2 [cs.LG] 01 Aug 2024

Do Language Models Plan Ahead for Future Tokens?

Wilson Wu (Department of Mathematics, University of Colorado Boulder, wiwu2390@colorado.edu), John X. Morris (Department of Computer Science, Cornell University, jxm3@cornell.edu), and Lionel Levine (Department of Mathematics, Cornell University, levine@math.cornell.edu)

Abstract

Do transformers “think ahead” during inference at a given position? It is known that transformers prepare information in the hidden states of the forward pass at time step $t$ that is then used in future forward passes $t + \tau$. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present during training result in the model computing features at $t$ irrelevant to the present inference task but useful for the future, and breadcrumbs, in which features most relevant to time step $t$ are already the same as those that would most benefit inference at time $t + \tau$. We test these hypotheses by training language models without propagating gradients to past timesteps, a scheme we formalize as myopic training. In a constructed synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis, though pre-caching increases with model scale.

1Introduction

Humans are known to think ahead while speaking; decades of linguistics research (Huettig, 2015; Miller, 1951) have shown evidence that human language users internally predict upcoming language input, words and sometimes sentences ahead (Barthel et al., 2016).

Unlike humans, contemporary language models allocate a fixed amount of information processing to each token when “speaking” (Vaswani et al., 2017). Do language models, like humans, think ahead? Recent work (Pal et al., 2023; Hernandez et al., 2024) has shown that tokens beyond the immediate next token can be predicted by probing the hidden state of the language model: model outputs at future tokens can be predicted to some extent using linear probes on model hidden states, and interventions on hidden states can predictably alter future outputs.

These findings indicate that model activations at a given timestep are at least somewhat predictive of future outputs. However, it remains unclear why this might be: is this just a happenstance property of the data, or because the model is deliberately preparing information for future timesteps, at the expense of degrading performance on the current position?

We observe that gradients during training optimize weights for the loss at the current token position as well as for tokens later in the sequence. We ask to what extent transformer weights dedicate resources to the current token versus allocating them to future tokens.

We consider two possibilities: the pre-caching hypothesis, in which the transformer learns to compute features at time step $t$ that are irrelevant to the inference task at that time step but may be useful for future time steps $t + \tau$, and the breadcrumbs hypothesis, in which the features most relevant to time step $t$ are already identical to those that would most benefit inference at time $t + \tau$. To evaluate which hypothesis might be correct, we propose a myopic training scheme that does not propagate gradients from the loss at the current position to hidden states from previous positions. We then evaluate the myopia gap in performance between myopically trained and vanilla transformers as a measure of pre-caching.

To consider whether language models might directly implement pre-caching, we design a synthetic scenario where the task can only be completed via explicit pre-caching. We configure a task where the model must precompute information for the next token, because otherwise the correct answer could not be accurately computed in a single forward pass. In this synthetic scenario, we find clear evidence that the transformer learns to pre-cache. When transformer-based sequence models must precompute information to minimize loss, they do so.

We then consider whether breadcrumbs or pre-caching is demonstrated in natural language models. Our experiments with myopic training suggest that, with small language models like GPT-2 (Radford et al., 2019), much less pre-caching occurs in this setting, pointing towards the breadcrumbs hypothesis. That is, we claim language models on this scale do not intentionally prepare information for the future to a significant extent. Instead, they compute features that are useful to predicting the immediate next token, which turn out to then be helpful at future steps; there is not a significant tradeoff between greedily optimizing for next token loss and ensuring future predictive performance.

However, we also find evidence that the importance of pre-caching increases with scale, becoming non-negligible with larger models, e.g. Pythia 2.8B (Biderman et al., 2023). This suggests that these larger models are “planning for the future” in a way that small models cannot.

Figure 1: At which position is the computation required to correctly answer this math problem taking place? Cognitive science tells us that humans think ahead while speaking; we investigate the extent to which language models do the same.

2 Related work

Future token meta-prediction.

Several recent works (nostalgebraist, 2020; Belrose et al., 2023; Pal et al., 2023; Cai et al., 2024) observe that transformer hidden states can be used to predict current and future tokens in a sequence, typically via linear probing. Notably, Hernandez et al. (2024) show that more complicated relationships are encoded linearly in hidden states, such as subject-object relations, implying that future tokens can also be predicted in specific cases. This future token predictivity has also been applied to speeding up inference by decoding future tokens in parallel (Stern et al., 2018; Cai et al., 2024). Unlike these works, we focus on the question of how the model learns to prepare hidden states that are useful for future prediction, possibly at the expense of current-token predictivity.

Probing.

Our synthetic data experiments make use of probing, a technique where a simple auxiliary model is used to predict properties from target models’ representations (Belinkov & Glass, 2019; Shi et al., 2016; Hewitt & Liang, 2019; Pimentel et al., 2020; Belinkov, 2021). Probing-based approaches are known to overestimate latent information if the classifier learns to do a task on its own (Belinkov, 2021), and probing analyses may only be informative when compared to probing a reasonable baseline (Hewitt & Liang, 2019). In our probing experiments, we avoid these pitfalls by ensuring that the function to be learned cannot possibly be computed by the probe itself.

Mechanistic interpretability.

Our analysis of transformer models in a synthetic setting relates to the subfield of mechanistic interpretability, which seeks to understand models by isolating and explaining the behavior of their components (Olah et al., 2020; Bau et al., 2020; Meng et al., 2023; Nanda et al., 2023). Some of these works (Nanda et al., 2023; Li et al., 2023; Zhong et al., 2023) practice mechanistic interpretability by studying models trained on synthetic data. We apply some mechanistic interpretability techniques in a synthetic setting to study the problem of whether language models “think ahead” for future tokens. However, our approach also differs from that of mechanistic interpretability by analyzing the effect of the training procedure on the learned model.

3 Theory: Pre-caching or breadcrumbs?

Consider a generic causal sequence-to-sequence prediction task

$$(\mathbf{x}_1, \dots, \mathbf{x}_n, \mathbf{y}_1, \dots, \mathbf{y}_n) \sim \mathcal{D},$$

where $\mathcal{D}$ is a data distribution supported on $\mathbb{X}^n \times \mathbb{Y}^n$ for some domains $\mathbb{X}, \mathbb{Y}$. The task is to estimate the conditional expectations $\mathbb{E}_{\mathcal{D}}(\mathbf{y}_i \mid \mathbf{x}_1, \dots, \mathbf{x}_i)$ for $1 \le i \le n$. Note that we recover the autoregressive setting by setting $\mathbb{Y} = \mathbb{X}$ and $\mathbf{y}_i = \mathbf{x}_{i+1}$.

Transformer models trained on such tasks have been observed (Pal et al., 2023) to store information in hidden states during inference at position $i$ that is then used in future inference at $j > i$. However, since the loss associated with each step $i$ depends only on how well the model does at the immediate task of predicting $\mathbf{y}_i$, it may not be immediately obvious how this preparation for the future arises. We give names to two competing explanations:

- Pre-caching: The model “deliberately” computes and stores features that are expected to be useful for the future, even if they are irrelevant to the present.
- Breadcrumbs: The features that most benefit the present inference task are the same as those that are most useful to the future. When the model performs the present forward pass, it “unintentionally” leaves a trace (“breadcrumbs”) that is then picked up by future passes.

To disentangle these two explanations, we introduce a notion of myopic transformer models, which we show to be incapable of deliberate pre-caching: for these models, the extent to which past features are beneficial to the future is determined purely by the breadcrumbs mechanism. Thus, the gap between vanilla and myopic transformer models is a quantitative measure of how much pre-caching is taking place.

3.1 Causal sequence modeling

Suppose, for the sake of exposition, that the transformer model $G$ uses independent parameters for each position. Let $p$ be the parameter count of each forward pass of $G$. Then, letting $\boldsymbol{\theta}_i \in \Theta = \mathbb{R}^p$ be all parameters used by $G$ at position $i$, a transformer $G$ is a parameterized function

$$G : \mathbb{X}^n \times \Theta^n \to \mathbb{Y}^n, \quad (\mathbf{x}_1, \dots, \mathbf{x}_n; \boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n) \mapsto (\hat{\mathbf{y}}_1, \dots, \hat{\mathbf{y}}_n).$$

For $1 \le i \le n$, let $G_i(\mathbf{x}_1, \dots, \mathbf{x}_n) \in \mathbb{Y}$ be the output of $G$'s $i$th forward pass. Because of the causal masking within $G$, this depends only on $\mathbf{x}_1, \dots, \mathbf{x}_i$ and $\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_i$. That is, with slight abuse of notation, we may write

$$\hat{\mathbf{y}}_i = G_i(\mathbf{x}_1, \dots, \mathbf{x}_n; \boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n) = G_i(\mathbf{x}_1, \dots, \mathbf{x}_i; \boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_i).$$

3.2 Off-diagonal gradient terms

Now, letting $\mathcal{L} : \mathbb{Y} \times \mathbb{Y} \to \mathbb{R}_+$ be some choice of loss function, the expected loss $\ell$ of a transformer model with parameters $\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n$ is

$$\ell(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n) := \mathbb{E}_{(\vec{\mathbf{x}}, \vec{\mathbf{y}}) \sim \mathcal{D}} \sum_{i=1}^{n} \mathcal{L}\big(G_i(\mathbf{x}_1, \dots, \mathbf{x}_i; \boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_i), \mathbf{y}_i\big) =: \sum_{i=1}^{n} \ell_i(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_i),$$

the sum over $1 \le i \le n$ of the expected loss $\ell_i$ at position $i$. (We suppress the dependence on $G$ and $\mathcal{D}$ for concision.) In practice, we always tie the weights across position. That is, all $\boldsymbol{\theta}_i$ are set equal to the same $\boldsymbol{\theta} \in \Theta$. Then, by the chain rule,

$$\nabla_{\boldsymbol{\theta}}\, \ell(\boldsymbol{\theta}, \dots, \boldsymbol{\theta}) = \sum_{i=1}^{n} \nabla_{\boldsymbol{\theta}_i}\, \ell(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n)\Big|_{\boldsymbol{\theta}_1 = \dots = \boldsymbol{\theta}_n = \boldsymbol{\theta}} = \sum_{i=1}^{n} \sum_{j=i}^{n} \nabla_{\boldsymbol{\theta}_i}\, \ell_j(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_j)\Big|_{\boldsymbol{\theta}_1 = \dots = \boldsymbol{\theta}_n = \boldsymbol{\theta}},$$

a sum over an upper-triangular expected Jacobian “matrix”. The off-diagonal terms $i < j$, corresponding to the expected gradient of the model's future loss at position $j$ with respect to its weights at position $i$, are the training signals that encourage pre-caching.
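This decomposition can be checked numerically on a toy model. The sketch below (our own illustration, not the paper's code) uses a two-position model with a tied scalar parameter, where position 2 reads a feature computed at position 1; the gradient of $\ell_2$ through that feature is the off-diagonal, pre-caching term that myopic training removes.

```python
# Toy two-position model with a tied scalar parameter theta.
# Position 1: h1 = theta * x1, prediction y1_hat = h1.
# Position 2: y2_hat = theta * x2 + h1  (reads the past hidden state h1).
# Loss: l1 + l2 with l_i = 0.5 * (yi_hat - yi)^2.

x1, x2 = 0.7, -1.3      # inputs (arbitrary values for the sketch)
y1, y2 = 0.2, 0.5       # targets (arbitrary values for the sketch)
theta = 0.4

h1 = theta * x1
y1_hat = h1
y2_hat = theta * x2 + h1

# Diagonal terms: gradient of each l_i through the parameter copy
# used at its own position.
g11 = (y1_hat - y1) * x1   # d l1 / d theta_1
g22 = (y2_hat - y2) * x2   # d l2 / d theta_2

# Off-diagonal term: gradient of l2 through h1, i.e. through the
# parameter copy used at position 1. This is the pre-caching signal.
g12 = (y2_hat - y2) * x1   # d l2 / d theta_1 (via h1 = theta * x1)

full_grad = g11 + g22 + g12   # gradient of l(theta, theta)
myopic_grad = g11 + g22       # off-diagonal term removed

# Check the full gradient against a central finite difference.
def loss(t):
    h = t * x1
    return 0.5 * (h - y1) ** 2 + 0.5 * (t * x2 + h - y2) ** 2

eps = 1e-6
fd = (loss(theta + eps) - loss(theta - eps)) / (2 * eps)
assert abs(full_grad - fd) < 1e-5
assert abs((full_grad - myopic_grad) - g12) < 1e-12
```

Dropping `g12` is exactly the removal of the off-diagonal Jacobian term for this two-position case.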

3.3 Measuring pre-caching: The myopia gap

We say a model is myopic when each forward pass $G_i$ optimizes only $\ell_i$ without regard for future $\ell_j$ at $j > i$. In the untied-weights case, the right definition is then apparent.

Definition 1.

The parameters $(\tilde{\boldsymbol{\theta}}_1, \dots, \tilde{\boldsymbol{\theta}}_n) \in \Theta^n$ are untied-myopic if they satisfy

$$\tilde{\boldsymbol{\theta}}_i \in \arg\min_{\boldsymbol{\theta}_i} \ell_i(\tilde{\boldsymbol{\theta}}_1, \dots, \tilde{\boldsymbol{\theta}}_{i-1}, \boldsymbol{\theta}_i) \quad \forall\, i \in \{1, \dots, n\}. \tag{1}$$

Definition 2.

Let $\mathbb{M}$ be the feasible set of the constraints in Equation 1. The untied myopia gap is the smallest possible gap between the expected loss attained by a myopic model and that attained by the optimal model:

$$p^* := \min_{(\tilde{\boldsymbol{\theta}}_1, \dots, \tilde{\boldsymbol{\theta}}_n) \in \mathbb{M}} \ell(\tilde{\boldsymbol{\theta}}_1, \dots, \tilde{\boldsymbol{\theta}}_n) \,-\, \min_{(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n) \in \Theta^n} \ell(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_n) \,\ge\, 0. \tag{2}$$

In the tied-weights case, it is perhaps not immediately clear what the right definition of myopia should be. It does not suffice to simply constrain the minimizations in Equation 1 to $\tilde{\boldsymbol{\theta}}_1 = \dots = \tilde{\boldsymbol{\theta}}_n$, since $\min_{\boldsymbol{\theta}} \ell_i(\boldsymbol{\theta}, \dots, \boldsymbol{\theta})$ optimizes for pre-caching (the dependence on arguments $j < i$) as well as for the present inference (the dependence on argument $i$). Instead, the right notion is a choice of tied parameters such that the model is, aggregated over positions, optimal for the present task when conditioned on a fixed past. That is, forward passes do not compute features for the future if they could instead compute other features more beneficial to the present.

Definition 3.

The parameters $\tilde{\boldsymbol{\theta}} \in \Theta$ are (tied-)myopic if they satisfy

$$\tilde{\boldsymbol{\theta}} \in \arg\min_{\boldsymbol{\theta} \in \Theta} \sum_{i=1}^{n} \ell_i(\tilde{\boldsymbol{\theta}}, \dots, \tilde{\boldsymbol{\theta}}, \boldsymbol{\theta}). \tag{3}$$

The (tied) myopia gap is then defined analogously to Definition 2.

The breadcrumbs hypothesis states that the myopia gap is small: near-optimal performance can be attained even when each forward pass computes features relevant only to its own immediate inference task, with no regard to pre-caching for the future.

If the breadcrumbs hypothesis does not hold, we say that the model is pre-caching. It is important to remember that the $\ell_i$ depend on a choice of transformer model $G$ and dataset $\mathcal{D}$. That is, breadcrumbs and pre-caching are properties of the model architecture and the data considered as a whole.

Although a small myopia gap reveals that one can do just as well without pre-caching, it does not say much about any specific model. To measure pre-caching within a given model, we examine the extent to which its parameters violate the myopia constraints.

Definition 4.

The (tied) local myopia bonus at $\boldsymbol{\theta}^* \in \Theta$ is

$$\hat{c}(\boldsymbol{\theta}^*) := \max_{\boldsymbol{\theta} \in \Theta} \sum_{i=1}^{n} \big(\ell_i(\boldsymbol{\theta}^*, \dots, \boldsymbol{\theta}^*) - \ell_i(\boldsymbol{\theta}^*, \dots, \boldsymbol{\theta}^*, \boldsymbol{\theta})\big) \ge 0.$$

For further interpretation of the myopia gap and myopia bonus, see Appendix A.

3.4 Myopic gradient descent

Our heuristic remark in Section 3.2, that the off-diagonal gradient terms are responsible for pre-caching, is justified by Theorem 13 below. It states that, given certain regularity conditions on the loss terms $\ell_i$, performing gradient descent with the off-diagonal terms removed results in a myopic model in the sense of Definition 3. We call this myopic descent.

For myopic descent to be stable in the tied-weights case, we need, roughly speaking, the model to depend more on the parameters associated with the present forward pass than on those from the past. This is a plausible condition: dependence on the past is mediated purely by the attention mechanism, while the present forward pass depends on both attention and feedforward parameters. The precise condition we use is forward bias (Definition 11); see Appendix C for details and the proof of the theorem.

Theorem 13.

Let $f(\tilde{\boldsymbol{\theta}}, \boldsymbol{\theta}) := \sum_{i=1}^{n} \ell_i(\tilde{\boldsymbol{\theta}}, \dots, \tilde{\boldsymbol{\theta}}, \boldsymbol{\theta})$. If $f$ is forward-biased, $\sigma$-strongly convex, and $L$-smooth, then, for some step size $\eta > 0$, the iterates of myopic descent with tied weights

$$\boldsymbol{\theta}^{(t+1)} = \boldsymbol{\theta}^{(t)} - \eta\, \nabla_{\boldsymbol{\theta}} f(\tilde{\boldsymbol{\theta}}, \boldsymbol{\theta})\Big|_{\tilde{\boldsymbol{\theta}} = \boldsymbol{\theta} = \boldsymbol{\theta}^{(t)}}$$

converge to $\tilde{\boldsymbol{\theta}} \in \Theta$ satisfying the myopia constraints of Equation 3.
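To illustrate, here is myopic descent on a tiny hand-built objective (our own toy example; the quadratic losses and constants are assumptions for the sketch, not from the paper). The iterates converge to a fixed point satisfying the first-order version of Equation 3, which differs from the minimizer of the full tied loss:

```python
# Myopic descent on a toy two-position objective with a tied scalar
# parameter. The losses below are made up for the sketch:
#   l1(theta_1)          = 0.5 * (theta_1 - a)^2
#   l2(theta_1, theta_2) = 0.5 * (theta_2 + c * theta_1 - b)^2
#   f(theta_tilde, theta) = l1(theta) + l2(theta_tilde, theta)

a, b, c = 1.0, 2.0, 0.5   # |c| < 1: the loss depends more on the present
                          # parameter than the past one (forward bias, roughly)
eta = 0.1
theta = 0.0
for _ in range(500):
    # Gradient in the *second* slot of f only, evaluated at
    # theta_tilde = theta: the past copies are treated as frozen.
    grad = (theta - a) + (theta + c * theta - b)
    theta -= eta * grad

# The iterates converge to the myopic fixed point (a + b) / (2 + c),
# where theta minimizes f(theta_tilde, .) given its own frozen past.
fixed_point = (a + b) / (2 + c)
assert abs(theta - fixed_point) < 1e-6

# For comparison, the non-myopic minimizer of l(theta, theta) solves
# (theta - a) + (1 + c) * ((1 + c) * theta - b) = 0:
vanilla = (a + b * (1 + c)) / (1 + (1 + c) ** 2)
assert abs(vanilla - fixed_point) > 1e-3  # the two optima differ
```

The gap between `vanilla` and `fixed_point` is this toy objective's analogue of the myopia gap.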

4 Synthetic data experiments

4.1 The $\mathcal{D}_p$ task

To demonstrate a simple example where the transformer model learns to pre-cache (and thus the myopia gap is large), we construct the following synthetic dataset.

Definition 5.

The data distribution $\mathcal{D}^N_{p,a,b}$ is defined as the joint distribution of real-valued random variables $(\mathrm{x}_n)_{n=1}^N$, $(\mathrm{y}_n)_{n=1}^N$, $(\mathrm{z}_n)_{n=1}^N$ where, for each $n$,

- $\mathrm{x}_n \sim \mathcal{N}(0,1)$ (standard Gaussian),
- $\mathrm{z}_n \sim \mathrm{Ber}(p)$ (Bernoulli with probability $p$),
- $\mathrm{y}_n = \mathrm{z}_n \sum_{i=1}^{a} \sin(b\,\mathrm{x}_{n-i}) + (1-\mathrm{z}_n)\,\mathrm{x}_n$,

and $\{\mathrm{x}_n\}_{n\in\mathbb{N}} \cup \{\mathrm{z}_n\}_{n\in\mathbb{N}}$ are mutually independent. In our experiments, we always set the parameters $a = b = 10$ and $N = 64$, so for convenience we notate $\mathcal{D}_p := \mathcal{D}^{64}_{p,10,10}$.
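A minimal sampler for this distribution can be sketched as follows (our own illustration, not the paper's code; the function name and the handling of early positions $n \le a$, which have fewer than $a$ predecessors, are assumptions of the sketch):

```python
import numpy as np

def sample_Dp(p, a=10, b=10, N=64, batch=1, rng=None):
    """Draw (x, y, z) from the synthetic distribution D_p of Definition 5.

    Positions with fewer than `a` predecessors sum only over the
    available past terms (a boundary choice of this sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.standard_normal((batch, N))                # x_n ~ N(0, 1)
    z = (rng.random((batch, N)) < p).astype(float)     # z_n ~ Ber(p)
    y = np.empty((batch, N))
    for n in range(N):
        past = x[:, max(0, n - a):n]                   # x_{n-a}, ..., x_{n-1}
        s = np.sin(b * past).sum(axis=1)
        y[:, n] = z[:, n] * s + (1 - z[:, n]) * x[:, n]
    return x, y, z

x, y, z = sample_Dp(p=0.3, batch=4, rng=np.random.default_rng(0))
# Wherever z_n = 0, the target is just the current input x_n.
assert np.allclose(y[z == 0], x[z == 0])
```

When $z_n = 1$ at a position $n \ge a$, the target is the sum $\sum_{i=1}^{a} \sin(b\,\mathrm{x}_{n-i})$, which depends only on past inputs, never the current one.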

The intuition is that a transformer regression model $G$ trained on $\mathcal{D}_p$ would benefit from pre-caching $\sin(b\,\mathrm{x}_n)$ during its forward pass at position $n$, even though this computation is irrelevant to its task of predicting $\mathrm{y}_n$. One simple strategy that makes use of this pre-caching is Algorithm 1.

The motivation for the Bernoulli variables $\mathrm{z}_i$ is that, as $p$ decreases, the expected first time at which $\sin(b\,\mathrm{x}_n)$ becomes useful advances further into the future. In addition, when $p$ is sufficiently small, the probability $(1-p)^a$ that the value $\sin(b\,\mathrm{x}_n)$ is never useful at all becomes non-negligible. We will show that, even in this case, the transformer model learns to pre-cache.

Investigating myopia.

Suppose that we train a myopic model (Section 3.4) on the same task. Since this model lacks off-diagonal gradient terms, we do not expect it to learn to pre-cache $\sin(b\,\mathrm{x}_n)$ at position $n$. One possible strategy that does not use pre-caching is Algorithm 2. We expect this brute-force algorithm to perform significantly worse given the same parameter count: it computes an $a$-dimensional nonlinear function within a single layer, while each layer of Algorithm 1 computes only scalar nonlinear functions.

Algorithm 1: Pre-caching algorithm

At position $n$:
- input: $\mathrm{x}_n$, $\mathrm{z}_n$
- layer 1 compute: $F_n := \sin(b\,\mathrm{x}_n)$
- layer 2 read: $F_{n-i}$ for $i = 1, \dots, a$
- layer 2 compute: $\hat{\mathrm{y}}_n := \mathrm{z}_n \sum_{i=1}^{a} F_{n-i} + (1-\mathrm{z}_n)\,\mathrm{x}_n$
- return: $\hat{\mathrm{y}}_n$

Algorithm 2: Brute-force algorithm

At position $n$:
- input: $\mathrm{x}_n$, $\mathrm{z}_n$
- layer 1 compute: $\varnothing$
- layer 2 read: $\mathrm{x}_{n-i}$ for $i = 1, \dots, a$
- layer 2 compute: $\hat{\mathrm{y}}_n := \mathrm{z}_n \sum_{i=1}^{a} \sin(b\,\mathrm{x}_{n-i}) + (1-\mathrm{z}_n)\,\mathrm{x}_n$
- return: $\hat{\mathrm{y}}_n$

4.1.1 Evaluation: linear probing

To determine whether the transformer model is computing $\sin(b\,\mathrm{x}_n)$ at position $n$, we fit linear probes on the hidden states. We additionally compute the correlations between $\sin(b\,\mathrm{x}_n)$ and each individual dimension (i.e., each neuron) of each hidden state. See Section D.1.1 for details.
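A concrete sketch of what fitting such a probe involves (our own illustration; `probe_r2` and the planted-feature sanity check are assumptions of the sketch, not the paper's code): fit ordinary least squares from hidden states to the target feature and report $R^2$.

```python
import numpy as np

def probe_r2(hidden, target):
    """Fit a linear probe (least squares with intercept) from hidden
    states to a scalar target and return the in-sample R^2."""
    X = np.column_stack([hidden, np.ones(len(hidden))])
    coef, *_ = np.linalg.lstsq(X, target, rcond=None)
    pred = X @ coef
    ss_res = ((target - pred) ** 2).sum()
    ss_tot = ((target - target.mean()) ** 2).sum()
    return 1 - ss_res / ss_tot

# Sanity check on fake "hidden states": if sin(b * x_n) is stored
# linearly in one direction of the state, the probe recovers it.
rng = np.random.default_rng(0)
x = rng.standard_normal(5000)
feature = np.sin(10 * x)
H = rng.standard_normal((5000, 128)) * 0.01   # small background noise
H[:, 7] += feature                            # plant the feature in one neuron
assert probe_r2(H, feature) > 0.99
# A state that never computed the feature probes near zero.
H_none = rng.standard_normal((5000, 128))
assert probe_r2(H_none, feature) < 0.1
```

The second check mirrors the baseline concern above: a probe on states that do not contain the feature should score near chance, so a high $R^2$ cannot come from the probe alone.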

4.1.2 Results for $\mathcal{D}_p$

For varying $p$, we train two-layer transformer models with embedding dimension 128 on $\mathcal{D}_p$ using both ordinary and myopic gradient descent. Full architecture and training details are provided in Section D.1.

Examining the performance of each linear probe against $\sin(b\,\mathrm{x}_{n-i})$ for varying $i$, we find strong evidence that the transformer model with vanilla training is indeed pre-caching $\sin(b\,\mathrm{x}_n)$, possibly in order to implement Algorithm 1. Indeed, in Figure 2:

- The zeroth hidden state (i.e., the sum of the input and position embeddings) at position $n$ is correlated with only $\mathrm{x}_n$.
- The first hidden state is correlated with $\sin(b\,\mathrm{x}_n)$ but not with any $\sin(b\,\mathrm{x}_{n-i})$ for $i > 0$.
- The second hidden state (immediately before the output unembedding) is correlated with $\sin(b\,\mathrm{x}_{n-i})$ for each $0 \le i \le a$.

| $p$ | Vanilla | Myopic |
|---|---|---|
| 0.01 | 0.096 | 1.10 |
| 0.1 | 0.016 | 0.97 |
| 0.3 | 0.0030 | 1.03 |
| 1.0 | 0.0074 | 1.26 |

Table 1: Normalized Huber loss $\mathcal{L}/p$ for vanilla and myopic models trained and evaluated on $\mathcal{D}_p$ for each $p$ in our synthetic setting. For reference, the trivial model that always outputs zero attains a Huber loss of 1.26.

Further, looking at the per-neuron correlations in Figure 5, we see that $\sin(b\,\mathrm{x}_{n-i})$ for $1 \le i \le a$ are all correlated with a single 1-d subspace of the second hidden state (they share the same striping pattern); this is the subspace storing $\sum_{i=1}^{a} \sin(b\,\mathrm{x}_{n-i})$. Meanwhile, $\sin(b\,\mathrm{x}_n)$, as well as many of the $\mathrm{x}_{n-i}$, are located in various other 1-d subspaces of the second hidden state; these terms are all left over in the residual stream from previous layers, and are cleaned up only by the output unembedding.

On the other hand, in Table 1, the myopic models perform significantly worse. The per-neuron correlations in Figure 5 suggest that the myopic model may be implementing a crude approximation of Algorithm 2. This suggests that the synthetic setting has an inherently high myopia gap: it is impossible for the transformer model to do well without pre-caching.

Figure 2: Empirical $R^2$ between linear probes fit on each layer of vanilla transformer models trained on $\mathcal{D}_p$ for $p \in \{0.01, 0.1, 0.3, 1\}$ and targets $\sin(b\,\mathrm{x}_{n-i})$. Computed over 50,000 samples from $\mathcal{D}_1$.

Figure 3: Vanilla model.

Figure 4: Myopic model.

Figure 5: Empirical correlations between each hidden state neuron and $\mathrm{x}_{n-i}$ or $\sin(b\,\mathrm{x}_{n-i})$. Models are vanilla (left two columns) and myopic (right two columns) transformers trained on $\mathcal{D}_{0.3}$.

4.2 Multiplication

In addition to the above $\mathcal{D}_p$ synthetic data setting, we also measure the myopia gap on the task of natural number multiplication. In particular, we find evidence suggesting that pre-caching is responsible for model computation on filler tokens, in the sense of Pfau et al. (2024). See Appendix D.3 for details.

5 Natural language experiments

5.1 GPT-2's myopia gap

In order to measure the extent to which transformer models learn to pre-cache on natural language data, we estimate both the myopia gap (Definition 3) in this setting as well as the local myopia bonus (Definition 4) of a transformer model with vanilla pre-training. Experiments in this subsection use the 124M-parameter GPT-2 architecture; see Table 4 in Appendix D for configuration details.

We train all models (vanilla and myopic) from random initialization for one epoch on 4.6M sequences from the MS MARCO dataset (Nguyen et al., 2016), truncated to length 64. To estimate the local myopia bonus of the vanilla model, we train another model from random initialization with the same architecture, but with past hidden states sourced from the frozen vanilla model during both training and evaluation. See Appendix B for implementation details.

As a baseline, we also train a “transformer bigram” model: a model with an identical architecture but all off-diagonal key/value states zeroed out.
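A minimal numpy sketch of this baseline (our own illustration, not the paper's code): restricting the attention mask to the diagonal, which mimics zeroing out off-diagonal key/value states, turns attention into the identity on values, so each output depends only on its own token.

```python
import numpy as np

def attention(Q, K, V, diagonal_only=False):
    """Single-head causal self-attention (a minimal sketch). With
    diagonal_only=True, each position may attend only to itself,
    mimicking the "transformer bigram" baseline."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    mask = np.tril(np.ones((n, n), dtype=bool))   # causal mask
    if diagonal_only:
        mask = np.eye(n, dtype=bool)              # self-attention only
    scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 16))
out = attention(X, X, X, diagonal_only=True)
# Diagonal-only attention reduces to the identity map on values: no
# information flows from the past, hence the bigram-level loss in Table 2.
assert np.allclose(out, X)
# Changing an earlier token does not change later outputs.
X2 = X.copy(); X2[0] += 1.0
assert np.allclose(attention(X2, X2, X2, diagonal_only=True)[1:], out[1:])
```

Since each position sees only its own token, such a model can do no better than a bigram-style predictor, which is why it serves as the floor against which the myopic model's use of breadcrumbs is compared.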

5.1.1 GPT-2 results

From Table 2, the estimated myopia gap in this setting is $3.40 - 3.28 = 0.12$ cross-entropy, while the local myopia bonus of the vanilla model is $3.28 - 3.26 = 0.02$.

The nonzero myopia gap suggests that pre-caching may provide a small positive benefit. Indeed, in Figure 6, we see that the myopic model outperforms the vanilla model at the beginning of the sequence, since it can allocate all compute to next-token prediction, but quickly falls behind as the length of the past increases, since it suffers from a lack of pre-cached information from earlier forward passes.

| Model | Cross-entropy |
|---|---|
| Vanilla | 3.28 |
| Myopic | 3.40 |
| Local myopic | 3.26 |
| Transformer bigram | 5.33 |

Table 2: Validation cross-entropy loss obtained by GPT-2 with vanilla and myopic training.

However, note that this gap is much smaller than the gap between the vanilla model and the transformer bigram model (Table 2). That is, the myopic model is still able to leverage past information (breadcrumbs) to a significant extent, even though it was optimized only for the present inference task. That the local myopia bonus is near zero further supports this interpretation: the model learned through vanilla training does not trade off significantly between features useful for the present and pre-caching for the future.

Figure 6: Cross-entropy loss of vanilla and myopic GPT-2 models by token position, and their difference. Evaluated on a sliding window over a 100K-token sample text from the PG-19 dataset (Rae et al., 2019). Aggregate cross-entropy losses on this sample are 4.67 (vanilla) and 4.77 (myopic).

5.2 Myopia gap scaling

One might suppose that the relatively small myopia gap of GPT-2 is due to the relative simplicity of the small architecture we consider, and that larger language models might exhibit a more pronounced myopia gap.

To test this, we train both vanilla and myopic transformers from the Pythia LLM suite (Biderman et al., 2023), ranging in size from 14M to 2.8B parameters, on one epoch of 10M sequences of 64 tokens each, subsampled from the Pile dataset (Gao et al., 2020). (We use the same subsampled dataset for every training run.) We report validation cross-entropy loss (Figure 9 in Appendix D) as well as performance on a variety of natural language benchmarks (Figure 7). Note that, unlike in the GPT-2 experiments (Section 5.1), which start from random initialization, we start all training for Pythia models from the pre-trained checkpoints provided by Biderman et al. (2023), since for the larger architectures the 10M-sequence dataset we use is not large enough for pre-training from random initialization.

Figure 7: Benchmarks (LAMBADA (Paperno et al., 2016), PIQA (Bisk et al., 2020), SciQ (Welbl et al., 2017), ARC-Easy (Zhang et al., 2018)) of Pythia models fine-tuned on the Pile dataset using vanilla and myopic descent.

6 Discussion and future work

Using a synthetic dataset, we demonstrate that pre-caching can indeed be learned by a transformer model. On the other hand, our experiments with natural language suggest that the breadcrumbs hypothesis is more explanatory for that setting, especially with smaller models, but that the importance of pre-caching increases with scale.

If the myopia gap is indeed not large in practice, there may be several applications of myopic training. We hypothesize that myopic transformers may have advantages in terms of safety and/or interpretability: it may be easier to understand what a model is doing if we know that everything it computes on each forward pass is directed toward the goal of predicting the immediate next token. For example, as seen in Appendix D.3, myopic models may not be able to make use of computation on the forward passes of filler tokens in the sense of Pfau et al. (2024).

Another possibility is that of automatically swapping in a locally myopic model (Section 5.1) on forward passes where we detect it is beneficial to sacrifice future performance in favor of immediate next-token accuracy (for example, on especially important tokens, or near the end of a text). We leave these possible applications to future work.

Acknowledgments

LL thanks Lukas Berglund and David Schneider-Joseph for inspiring conversations. This research was partly supported by Open Philanthropy and the Berkeley Existential Risk Initiative.

References

- Maik Barthel, Sebastian Sauppe, Stephen C. Levinson, and Antje S. Meyer. The timing of utterance planning in task-oriented dialogue: Evidence from a novel list-completion paradigm. Frontiers in Psychology, 7:1858, December 2016. doi: 10.3389/fpsyg.2016.01858.
- David Bau, Jun-Yan Zhu, Hendrik Strobelt, Agata Lapedriza, Bolei Zhou, and Antonio Torralba. Understanding the role of individual units in a deep neural network. Proceedings of the National Academy of Sciences, 117(48):30071–30078, September 2020. doi: 10.1073/pnas.1907375117.
- Amir Beck. First-Order Methods in Optimization. Society for Industrial and Applied Mathematics, Philadelphia, PA, 2017. doi: 10.1137/1.9781611974997.
- Yonatan Belinkov. Probing classifiers: Promises, shortcomings, and advances, 2021.
- Yonatan Belinkov and James Glass. Analysis methods in neural language processing: A survey, 2019.
- Nora Belrose, Zach Furman, Logan Smith, Danny Halawi, Igor Ostrovsky, Lev McKinney, Stella Biderman, and Jacob Steinhardt. Eliciting latent predictions from transformers with the tuned lens, 2023.
- Stella Biderman, Hailey Schoelkopf, Quentin Anthony, Herbie Bradley, Kyle O'Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, Aviya Skowron, Lintang Sutawika, and Oskar van der Wal. Pythia: A suite for analyzing large language models across training and scaling. In Proceedings of the 40th International Conference on Machine Learning (ICML), 2023.
- Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. PIQA: Reasoning about physical commonsense in natural language. Proceedings of the AAAI Conference on Artificial Intelligence, 34(05):7432–7439, April 2020. doi: 10.1609/aaai.v34i05.6239.
- Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads, 2024.
- Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, Shawn Presser, and Connor Leahy. The Pile: An 800GB dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020.
- Evan Hernandez, Arnab Sen Sharma, Tal Haklay, Kevin Meng, Martin Wattenberg, Jacob Andreas, Yonatan Belinkov, and David Bau. Linearity of relation decoding in transformer language models, 2024.
- John Hewitt and Percy Liang. Designing and interpreting probes with control tasks, 2019.
- Falk Huettig. Four central questions about prediction in language processing. Brain Research, 1626:118–135, 2015. doi: 10.1016/j.brainres.2015.02.014.
- Kenneth Li, Aspen K. Hopkins, David Bau, Fernanda Viégas, Hanspeter Pfister, and Martin Wattenberg. Emergent world representations: Exploring a sequence model trained on a synthetic task, 2023.
- Kevin Meng, David Bau, Alex Andonian, and Yonatan Belinkov. Locating and editing factual associations in GPT, 2023.
- George A. Miller. Language and Communication. McGraw-Hill, New York, NY, 1951. doi: 10.1037/11135-000.
- Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability, 2023.
- Yurii Nesterov. Lectures on Convex Optimization. Springer, 2nd edition, 2018.
- Tri Nguyen, Mir Rosenberg, Xia Song, Jianfeng Gao, Saurabh Tiwary, Rangan Majumder, and Li Deng. MS MARCO: A human generated machine reading comprehension dataset. In Proceedings of the Workshop on Cognitive Computation: Integrating Neural and Symbolic Approaches (CoCo@NIPS 2016), CEUR Workshop Proceedings, volume 1773, 2016.
- nostalgebraist. Interpreting GPT: The logit lens, 2020.
- Chris Olah, Nick Cammarata, Ludwig Schubert, Gabriel Goh, Michael Petrov, and Shan Carter. Zoom in: An introduction to circuits. Distill, 2020. doi: 10.23915/distill.00024.001.
- Koyena Pal, Jiuding Sun, Andrew Yuan, Byron Wallace, and David Bau. Future Lens: Anticipating subsequent tokens from a single hidden state. In Proceedings of the 27th Conference on Computational Natural Language Learning (CoNLL), 2023. doi: 10.18653/v1/2023.conll-1.37.
- Denis Paperno, Germán Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernández. The LAMBADA dataset: Word prediction requiring a broad discourse context. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 1525–1534, 2016. doi: 10.18653/v1/P16-1144.
- Jacob Pfau, William Merrill, and Samuel R. Bowman. Let's think dot by dot: Hidden computation in transformer language models, 2024.
- Tiago Pimentel, Josef Valvoda, Rowan Hall Maudslay, Ran Zmigrod, Adina Williams, and Ryan Cotterell. Information-theoretic probing for linguistic structure, 2020.
- Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners, 2019.
- Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, Chloe Hillier, and Timothy P. Lillicrap. Compressive transformers for long-range sequence modelling. arXiv preprint arXiv:1911.05507, 2019.
- Ruoqi Shen, Sébastien Bubeck, Ronen Eldan, Yin Tat Lee, Yuanzhi Li, and Yi Zhang. Positional description matters for transformers arithmetic, 2023.
- Zuowei Shen, Haizhao Yang, and Shijun Zhang. Optimal approximation rate of ReLU networks in terms of width and depth. Journal de Mathématiques Pures et Appliquées, 157:101–135, 2022. doi: 10.1016/j.matpur.2021.07.009.
- Xing Shi, Inkit Padhi, and Kevin Knight. Does string-based neural MT learn source syntax? pp. 1526–1534, 2016. doi: 10.18653/v1/D16-1159.
- Mitchell Stern, Noam Shazeer, and Jakob Uszkoreit. Blockwise parallel decoding for deep autoregressive models. In Advances in Neural Information Processing Systems, 2018.
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30, 2017.
- Johannes Welbl, Nelson F. Liu, and Matt Gardner. Crowdsourcing multiple choice science questions. arXiv preprint arXiv:1707.06209, 2017.
- Yilun Xu, Shengjia Zhao, Jiaming Song, Russell Stewart, and Stefano Ermon. A theory of usable information under computational constraints, 2020.
- Yuyu Zhang, Hanjun Dai, Kamil Toraman, and Le Song. KG^2: Learning to reason science exam questions with contextual knowledge graph embeddings. arXiv preprint arXiv:1805.12393, 2018.
- Ziqian Zhong, Ziming Liu, Max Tegmark, and Jacob Andreas. The clock and the pizza: Two stories in mechanistic explanation of neural networks, 2023.

Appendix A Myopia bonus and malus

Notice that the myopia gap consists of two pieces: a myopia bonus, the improvement that can be obtained at the current forward pass by ignoring the future forward passes; and a myopia malus, the cost to the future forward passes incurred by not pre-caching for them. To be precise, in the untied case, given a choice of myopic $\tilde{\theta}_1, \dots, \tilde{\theta}_n$ satisfying constraints (1) of Definition 1, write

$$\ell(\tilde{\theta}_1, \dots, \tilde{\theta}_n) - \min_{\theta_1, \dots, \theta_n} \ell(\theta_1, \dots, \theta_n)$$

$$= \sum_{i=1}^{n} \left( \min_{\theta_{i+1}, \dots, \theta_n} \ell(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \tilde{\theta}_i, \theta_{i+1}, \dots, \theta_n) - \min_{\theta_i, \dots, \theta_n} \ell(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i, \theta_{i+1}, \dots, \theta_n) \right)$$

$$= \sum_{i=1}^{n} \left( \ell_i(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \tilde{\theta}_i) - \ell_i(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i^{*i}) \right) + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \left( \ell_j(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \tilde{\theta}_i, \theta_{i+1}^{*\,i+1}, \dots, \theta_n^{*\,i+1}) - \ell_j(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i^{*i}, \theta_{i+1}^{*i}, \dots, \theta_n^{*i}) \right),$$

where we define

$$\theta_i^{*i}, \dots, \theta_n^{*i} \in \operatorname*{arg\,min}_{\theta_i, \dots, \theta_n} \ell(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i, \theta_{i+1}, \dots, \theta_n).$$

That is, the myopia gap is the sum $c + d = \sum_i c_i + \sum_i d_i$ of the myopia bonuses

$$c_i(\tilde{\theta}_1, \dots, \tilde{\theta}_n) := \ell_i(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \tilde{\theta}_i) - \ell_i(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i^{*i}) \le 0$$

(the inequality following from the myopia constraints (1)), and the myopia maluses

$$d_i(\tilde{\theta}_1, \dots, \tilde{\theta}_n) := \sum_{j=i+1}^{n} \left( \ell_j(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \tilde{\theta}_i, \theta_{i+1}^{*\,i+1}, \dots, \theta_n^{*\,i+1}) - \ell_j(\tilde{\theta}_1, \dots, \tilde{\theta}_{i-1}, \theta_i^{*i}, \theta_{i+1}^{*i}, \dots, \theta_n^{*i}) \right) \ge 0,$$

with $d_i + c_i \ge 0$ for each $1 \le i \le n$ by the definition of the $\theta_j^{*i}$. A priori, a small myopia gap does not necessarily imply a small (in magnitude) myopia bonus $c$ and malus $d$. Indeed, when the myopia gap is small, a large value for $c$ (and thus a correspondingly large value for $d$) means precisely that the transformer is committing significant resources to pre-caching that could otherwise have been used to improve inference at the current position. On the other hand, it is possible that both $c$ and $d$ are small; that is, there is not much cost associated with pre-caching for the future, as the present forward pass already produces information (breadcrumbs) useful for that purpose.

In practice, the myopia bonus may be difficult to estimate, as it depends on the result of $O(n)$ separate optimization problems (each of which, in practice, is a full transformer training run). We therefore instead compute the local myopia bonus of Definition 4.
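As a concrete illustration, consider a hypothetical two-position problem with one-dimensional parameters and quadratic losses (a toy construction of ours, not one from the paper's experiments), in which position 2 benefits when $\theta_1$ moves away from position 1's own optimum. The gap, bonus, and malus can then be computed in closed form:

```python
# Hypothetical two-position toy losses (not from the paper): position 2
# benefits when theta_1 moves away from position 1's own optimum, so the
# myopia gap is strictly positive.
l1 = lambda t1: (t1 - 1.0) ** 2
l2 = lambda t1, t2: (t2 - 2.0 * t1) ** 2 + t1 ** 2
loss = lambda t1, t2: l1(t1) + l2(t1, t2)

# Myopic solution: optimize each position greedily, past parameters frozen.
theta1_m = 1.0             # argmin of l1 alone
theta2_m = 2.0 * theta1_m  # argmin of l2 given theta1_m

# Joint optimum (closed form for this quadratic): theta2 = 2*theta1 and
# theta1 minimizes (t1 - 1)^2 + t1^2, i.e. theta1 = 1/2.
theta1_j, theta2_j = 0.5, 1.0

gap = loss(theta1_m, theta2_m) - loss(theta1_j, theta2_j)

# Decomposition: bonuses c_i <= 0, maluses d_i >= 0, gap = sum(c) + sum(d).
c1 = l1(theta1_m) - l1(theta1_j)                 # -0.25
c2 = l2(theta1_m, theta2_m) - l2(theta1_m, 2.0)  #  0.0
d1 = l2(theta1_m, 2.0) - l2(theta1_j, theta2_j)  #  0.75
print(gap, c1 + c2 + d1)  # 0.5 0.5
```

Here the gap of 0.5 splits into a bonus $c_1 = -0.25$ (the myopic model does strictly better on position 1) and a larger malus $d_1 = 0.75$ (it does worse on position 2).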

A.1 Explicit gradient paths

The dependence of forward pass $G_i$ on previous forward passes $G_j$ for $j < i$ is mediated through the hidden states $h_1, \dots, h_{i-1}$:

$$G_i(x_1, \dots, x_i; \theta_1, \dots, \theta_i) = \tilde{G}_i(h_1, \dots, h_{i-1}, x_i; \theta_i),$$

where the $h_i$ are themselves recursively defined parameterized functions $h_i = h_i(h_1, \dots, h_{i-1}, x_i; \theta_i)$. (Note that we are making a choice here to consider transformers as functions of hidden states, not of their key/value states. This has implications for how myopic descent is defined: when hidden state $h_j$ is attended to by forward pass $i > j$, we consider the key and value weights $W_K$ and $W_V$, respectively, to belong to forward pass $i$. Thus, they are updated by the gradient with respect to $\ell_i$.)

With the hidden states explicitly written out, the gradient of the loss is a sum over all possible paths to the present: for $j < i$,

$$\frac{\partial G_i}{\partial \theta_j}(x_1, \dots, x_i; \theta_1, \dots, \theta_i) = \sum_{j = i_1 < \dots < i_m < i} \frac{\partial \tilde{G}_i}{\partial h_{i_m}} \left( \prod_{k=2}^{m} \frac{\partial h_{i_k}}{\partial h_{i_{k-1}}} \right) \frac{\partial h_j}{\partial \theta_j},$$

where the sum is over all increasing index sequences $j = i_1 < \dots < i_m < i$.

Appendix B The myopic attention mechanism

An important primitive that we use when implementing the myopic gradient descent of Section 3.4, the local myopia bonus of Definition 4, and the transformer bigram in Section 5.1 is an attention mechanism that uses key/value states for past forward passes differing from those it uses for the current pass, while still computing all forward passes in parallel. We call our construction the myopic attention mechanism. We use it to implement several distinct transformer training methodologies:

- When training with myopic descent, the past key/value states are the result of the key and value weights $W_K$ and $W_V$, respectively, multiplying a cloned and detached copy of the previous hidden states.
- During both training and inference of a local myopic model, the past key/value states come from a separate frozen pre-trained transformer model.
- During both training and inference of a transformer bigram model, the past key/value states are simply zeroed out.

Let $X = (x_1, \dots, x_n)^\top \in \mathbb{R}^{n \times d}$ be the sequence of residual stream hidden states per position, with each row representing one position's hidden state in $\mathbb{R}^d$. Let $W_Q, W_K, W_V \in \mathbb{R}^{d \times h}$ be the query, key, and value weight matrices for one attention head, of dimensionality $h$, in the transformer $G$. Denote

$$Q := X W_Q, \qquad K := X W_K, \qquad V := X W_V,$$

and let $\tilde{Q}, \tilde{K}, \tilde{V}$ be the alternate states we wish to use for off-diagonal attention terms. We adopt the convention that lowercase letters with subscripts denote rows of matrices; e.g., $q_i$ is the $i$th row of $Q$. For simplicity of presentation we omit causal masking; the modifications needed in the presence of a mask are straightforward.

Recall that the vanilla attention mechanism for $G$ is

$$Y = \sigma(Q K^\top) V,$$

where $\sigma$ is row-wise softmax. Written token-wise,

$$y_i = Z_i^{-1} \sum_{j=1}^{n} \exp(q_i^\top k_j)\, v_j,$$

where $Z_i$ is the partition function

$$Z_i := \sum_{j=1}^{n} \exp(q_i^\top k_j).$$

The myopic attention mechanism, on the other hand, is written token-wise as

$$\tilde{y}_i = \tilde{Z}_i^{-1} \Big( \exp(q_i^\top k_i)\, v_i + \sum_{j \ne i} \exp(q_i^\top \tilde{k}_j)\, \tilde{v}_j \Big) = \tilde{Z}_i^{-1} \sum_{j=1}^{n} \exp\big(q_i^\top \tilde{k}_j + \delta_{ij}\, q_i^\top (k_j - \tilde{k}_j)\big)\, \big(\tilde{v}_j + \delta_{ij} (v_j - \tilde{v}_j)\big) = \sum_{j=1}^{n} a_{ij}\, \tilde{v}_j + \sum_{j=1}^{n} \delta_{ij}\, a_{ij}\, (v_j - \tilde{v}_j),$$

where

$$\tilde{Z}_i := \exp(q_i^\top k_i) + \sum_{j \ne i} \exp(q_i^\top \tilde{k}_j) = \sum_{j=1}^{n} \exp\big(q_i^\top \tilde{k}_j + \delta_{ij}\, q_i^\top (k_j - \tilde{k}_j)\big),$$

$$a_{ij} := \tilde{Z}_i^{-1} \exp\big(q_i^\top \tilde{k}_j + \delta_{ij}\, q_i^\top (k_j - \tilde{k}_j)\big),$$

and $\delta_{ij}$ is the Kronecker delta. Now, notate

$$(\operatorname{diag} A)_{ij} := \delta_{ij} A_{ij}.$$

That is, $\operatorname{diag} A$ is the diagonal matrix that has the same entries as $A$ along the diagonal and is zero elsewhere. We are now able to write the myopic attention mechanism in matrix form:

$$\tilde{Y} = A \tilde{V} + (\operatorname{diag} A)(V - \tilde{V}),$$

where

$$A := \sigma\big(Q \tilde{K}^\top + \operatorname{diag}(Q (K^\top - \tilde{K}^\top))\big).$$
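The matrix form above is straightforward to implement. The following NumPy sketch (single head, no causal mask, illustrative shapes) checks that the construction reduces to vanilla attention when the alternate states coincide with the true ones, and shows the transformer-bigram variant obtained by zeroing the past key/value states:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def myopic_attention(Q, K, V, K_alt, V_alt):
    """One-head myopic attention, no causal mask: the diagonal
    (current-position) term uses K, V, while off-diagonal terms use the
    alternate states K_alt, V_alt."""
    A = softmax(Q @ K_alt.T + np.diag(np.diag(Q @ (K.T - K_alt.T))))
    return A @ V_alt + np.diag(np.diag(A)) @ (V - V_alt)

rng = np.random.default_rng(0)
n, h = 5, 4
Q, K, V = (rng.standard_normal((n, h)) for _ in range(3))

# Sanity check: with K_alt = K and V_alt = V this reduces to vanilla attention.
vanilla = softmax(Q @ K.T) @ V
assert np.allclose(myopic_attention(Q, K, V, K, V), vanilla)

# Transformer-bigram variant: past key/value states zeroed out.
Y_bigram = myopic_attention(Q, K, V, np.zeros_like(K), np.zeros_like(V))
```

Note that `np.diag(np.diag(M))` builds the diagonal matrix $\operatorname{diag} M$ of the construction above.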

Appendix C Proofs

We use two assumptions on the loss function to be minimized. These are standard in the first-order methods literature.

Definition 6.

A function $f : \mathbb{R}^n \to \mathbb{R}$ is called $L$-smooth if it is continuously differentiable with $L$-Lipschitz gradient. That is, for all $x, y$ in the domain,

$$\lVert \nabla f(x) - \nabla f(y) \rVert_2 \le L \lVert x - y \rVert_2.$$

Definition 7.

A function $f : \mathbb{R}^n \to \mathbb{R}$ is called $\sigma$-strongly convex if, for all $x, y$ in the domain,

$$f(y) \ge f(x) + \nabla f(x)^\top (y - x) + \frac{\sigma}{2} \lVert y - x \rVert_2^2.$$

In particular, recall that strong convexity implies the existence of a unique minimum. Hence, it makes sense to write, for example, $x^* = \operatorname*{arg\,min}_x f(x)$ without ambiguity.

C.1 Gradient descent with untied weights

Theorem 8.

Assume $\ell : \Theta^n \to \mathbb{R}$ is $\sigma$-strongly convex and $L$-smooth for some $\sigma, L > 0$. Consider ordinary gradient descent with untied weights

$$\theta_i^{(t)} = \theta_i^{(t-1)} - \eta\, \nabla_{\theta_i} \ell(\theta_1^{(t-1)}, \dots, \theta_n^{(t-1)}) \qquad \forall i \in \{1, \dots, n\}.$$

Then, for $\theta_1^*, \dots, \theta_n^* = \operatorname*{arg\,min}_{\theta_1, \dots, \theta_n} \ell(\theta_1, \dots, \theta_n)$ and small enough $\eta > 0$,

$$\lVert \theta_i^{(t)} - \theta_i^* \rVert_2^2 \le \left(1 - \frac{2 \eta \sigma L}{\sigma + L}\right)^t \lVert \theta_i^{(0)} - \theta_i^* \rVert_2^2 \qquad \forall i \in \{1, \dots, n\}.$$

Proof.

This is a standard convergence result for gradient descent on strongly convex functions. See, for example, Nesterov (2018). ∎
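The rate in Theorem 8 can be checked numerically on a strongly convex quadratic, where $\sigma$ and $L$ are the extreme eigenvalues of the Hessian (a toy sanity check of ours, not part of the paper's experiments):

```python
import numpy as np

# Numerical check of the Theorem 8 rate on f(x) = 0.5 x^T A x, whose
# minimizer is x* = 0; sigma and L are the extreme eigenvalues of A.
rng = np.random.default_rng(1)
eigs = np.array([0.5, 1.0, 2.0, 4.0])      # sigma = 0.5, L = 4.0
U, _ = np.linalg.qr(rng.standard_normal((4, 4)))
A = U @ np.diag(eigs) @ U.T
sigma, L = eigs.min(), eigs.max()

eta = 2.0 / (sigma + L)                    # classical step size choice
rate = 1.0 - 2.0 * eta * sigma * L / (sigma + L)

x = rng.standard_normal(4)
d0 = np.dot(x, x)                          # squared distance to x* = 0
for t in range(1, 51):
    x = x - eta * (A @ x)                  # gradient step on f
    assert np.dot(x, x) <= rate ** t * d0 + 1e-12
```

With $\eta = 2/(\sigma + L)$ the bound is tight for a quadratic, which is why the assertion holds with essentially no slack.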

C.2 Gradient descent with tied weights

Theorem 9.

Assume $\ell$ is $\sigma$-strongly convex and $L$-smooth, and consider ordinary gradient descent with tied weights

$$\theta_1^{(0)} = \dots = \theta_n^{(0)}, \qquad \theta_i^{(t+1)} = \theta_i^{(t)} - \eta \sum_{j=1}^{n} \nabla_{\theta_j} \ell(\theta_1^{(t)}, \dots, \theta_n^{(t)}) \qquad \forall i \in \{1, \dots, n\}.$$

There exists $\eta > 0$ such that

$$\lVert \theta_i^{(t)} - \theta^* \rVert_2^2 \le \left(1 - \frac{2 n \eta \sigma L}{\sigma + L}\right)^t \lVert \theta_i^{(0)} - \theta^* \rVert_2^2 \qquad \forall i \in \{1, \dots, n\},$$

where $\theta^* = \operatorname*{arg\,min}_\theta \ell(\theta, \dots, \theta)$.

Proof.

This is again the standard convergence result, now applied to descent with step size $n\eta$ on the $\sigma$-strongly convex, $L$-smooth function $\theta \mapsto \ell(\theta/n, \dots, \theta/n)$. Alternatively, one may think of this as projected gradient descent constrained to the subspace $\theta_1 = \dots = \theta_n$ with step size $n\eta$; projected gradient descent inherits the same convergence properties as unconstrained gradient descent (Beck, 2017). ∎

C.3 Myopic descent with untied weights

Theorem 10.

Assume each of $\ell_1, \dots, \ell_n$ is $\sigma$-strongly convex and $L$-smooth. Consider myopic gradient descent with untied weights

$$\theta_i^{(t)} = \theta_i^{(t-1)} - \eta\, \nabla_{\theta_i} \ell_i(\theta_1^{(t-1)}, \dots, \theta_i^{(t-1)}) \qquad \forall i \in \{1, \dots, n\}.$$

There exists $\eta > 0$ such that $\theta_i^{(t)} \to \tilde{\theta}_i$ as $t \to \infty$ for all $i$, where

$$\tilde{\theta}_1 = \operatorname*{arg\,min}_{\theta_1} \ell_1(\theta_1), \qquad \tilde{\theta}_2 = \operatorname*{arg\,min}_{\theta_2} \ell_2(\tilde{\theta}_1, \theta_2), \qquad \dots, \qquad \tilde{\theta}_n = \operatorname*{arg\,min}_{\theta_n} \ell_n(\tilde{\theta}_1, \tilde{\theta}_2, \dots, \tilde{\theta}_{n-1}, \theta_n).$$

Proof.

Let $\tilde{\theta}_1, \dots, \tilde{\theta}_n$ be as in the theorem statement. We proceed by induction. For the base case, note that the myopic descent iterates for $\theta_1^{(t)}$ are independent of the $\theta_j^{(t)}$ for $j > 1$. Thus, the standard convergence theorem gives $\theta_1^{(t)} \to \tilde{\theta}_1$ as $t \to \infty$.

Now, assume $\theta_i^{(t)} \to \tilde{\theta}_i$ as $t \to \infty$ for all $i < k$. Then, for any $\varepsilon > 0$ and sufficiently large $t$,

$$\big\lVert (\theta_i^{(t)})_{i=1}^{k-1} - (\tilde{\theta}_i)_{i=1}^{k-1} \big\rVert_2 < \varepsilon.$$

Hence, since $\ell_k$ is $L$-smooth, for any $\theta_k$,

$$\lVert \nabla_{\theta_k} \ell_k(\theta_1^{(t)}, \dots, \theta_{k-1}^{(t)}, \theta_k) - \nabla_{\theta_k} \ell_k(\tilde{\theta}_1, \dots, \tilde{\theta}_{k-1}, \theta_k) \rVert_2 < L \varepsilon.$$

Expanding and rearranging,

$$\big\langle \nabla_{\theta_k} \ell_k(\theta_1^{(t)}, \dots, \theta_{k-1}^{(t)}, \theta_k),\ \nabla_{\theta_k} \ell_k(\tilde{\theta}_1, \dots, \tilde{\theta}_{k-1}, \theta_k) \big\rangle \ge \frac{1}{2} \lVert \nabla_{\theta_k} \ell_k(\tilde{\theta}_1, \dots, \tilde{\theta}_{k-1}, \theta_k) \rVert_2^2 - \frac{1}{2} L^2 \varepsilon^2$$

is bounded away from zero as long as, say,

$$\lVert \nabla_{\theta_k} \ell_k(\tilde{\theta}_1, \dots, \tilde{\theta}_{k-1}, \theta_k) \rVert_2 \ge \sigma \lVert \theta_k - \tilde{\theta}_k \rVert_2 \ge (L+1)\varepsilon,$$

using the $\sigma$-strong convexity of $\ell_k$. That is, as long as $\lVert \theta_k - \tilde{\theta}_k \rVert_2 \ge (L+1)\varepsilon / \sigma$, it is guaranteed that $-\nabla_{\theta_k} \ell_k(\theta_1^{(t)}, \dots, \theta_{k-1}^{(t)}, \theta_k)$ is a descent direction for $\ell_k(\tilde{\theta}_1, \dots, \tilde{\theta}_{k-1}, \theta_k)$. This is a sufficient condition for $\theta_k$ to converge to an $(L+1)\varepsilon / \sigma$-neighborhood of $\tilde{\theta}_k$ given a small enough step size (Beck, 2017). Since $\varepsilon > 0$ is arbitrary, this completes the inductive step. ∎
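A minimal numerical illustration of Theorem 10, using two one-dimensional "positions" and hypothetical quadratic losses (the same toy construction as in Appendix A, not from the paper's experiments): each parameter descends only its own loss, and the iterates converge to the sequential argmins rather than the joint optimum.

```python
# Two 1-D positions with quadratic losses:
#   l1(t1) = (t1 - 1)^2,   l2(t1, t2) = (t2 - 2 t1)^2 + t1^2.
# Sequential argmins (Theorem 10): t1 -> 1, then t2 -> 2 * t1 = 2.
def grad_l1(t1):
    return 2.0 * (t1 - 1.0)

def grad_l2_wrt_t2(t1, t2):
    return 2.0 * (t2 - 2.0 * t1)

t1, t2, eta = 0.0, 0.0, 0.1
for _ in range(500):
    # Myopic constraint: t1 never sees a gradient from l2.
    g1 = grad_l1(t1)
    g2 = grad_l2_wrt_t2(t1, t2)
    t1 -= eta * g1
    t2 -= eta * g2

assert abs(t1 - 1.0) < 1e-6 and abs(t2 - 2.0) < 1e-6
```

The joint optimum of $\ell_1 + \ell_2$ would instead sit at $\theta_1 = 1/2$, so the myopic limit is visibly different from the ordinary-descent limit.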

C.4 Myopic descent with tied weights

Definition 11.

A function $f(x, y) : \mathbb{R}^a \times \mathbb{R}^a \to \mathbb{R}$ with continuous second derivatives is $\rho$-forward-biased if, for all $y \in \mathbb{R}^a$ and $x = y$,

$$H_{y,y} f(x, y) + H_{x,y} f(x, y) \succ 0$$

and

$$\kappa\big(H_{y,y} f(x, y) + H_{x,y} f(x, y)\big) < \rho,$$

where $\kappa$ is the condition number, and we write the Hessian of $f$ as a block matrix:

$$H f(x, y) = \begin{bmatrix} H_{x,x} f(x, y) & H_{x,y} f(x, y) \\ H_{y,x} f(x, y) & H_{y,y} f(x, y) \end{bmatrix} = \begin{bmatrix} \left(\dfrac{\partial^2 f(x, y)}{\partial x_i\, \partial x_j}\right)_{ij} & \left(\dfrac{\partial^2 f(x, y)}{\partial x_i\, \partial y_j}\right)_{ij} \\ \left(\dfrac{\partial^2 f(x, y)}{\partial y_i\, \partial x_j}\right)_{ij} & \left(\dfrac{\partial^2 f(x, y)}{\partial y_i\, \partial y_j}\right)_{ij} \end{bmatrix}, \qquad 1 \le i, j \le a.$$

Lemma 12.

If $A \in \mathbb{R}^{a \times a}$ is positive definite with $\kappa(A) < \rho$, then there exists some $\eta > 0$ such that $I - \eta A$ is a $(1 - \rho^{-1})$-contraction. Explicitly, for any $y \in \mathbb{R}^a$,

$$\lVert (I - \eta A)\, y \rVert_2 \le (1 - \rho^{-1}) \lVert y \rVert_2.$$

Proof.

Set $\eta = 1 / \lambda_{\max}(A)$. Then $I - \eta A \succeq 0$ with

$$\lambda_{\max}(I - \eta A) = 1 - \eta\, \lambda_{\min}(A) = 1 - \frac{\lambda_{\min}(A)}{\lambda_{\max}(A)} < 1 - \rho^{-1}.$$

It immediately follows that $I - \eta A$ is a $(1 - \rho^{-1})$-contraction. ∎
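Lemma 12 is easy to verify numerically: with $\eta = 1/\lambda_{\max}(A)$, the spectral norm of $I - \eta A$ equals exactly $1 - 1/\kappa(A)$, which is below $1 - \rho^{-1}$ whenever $\kappa(A) < \rho$ (a quick sanity check of ours):

```python
import numpy as np

# Random positive definite A; check the contraction factor of I - eta*A.
rng = np.random.default_rng(2)
M = rng.standard_normal((6, 6))
A = M @ M.T + 0.1 * np.eye(6)          # positive definite by construction

lam = np.linalg.eigvalsh(A)            # ascending eigenvalues
kappa = lam[-1] / lam[0]               # condition number
eta = 1.0 / lam[-1]                    # the step size from the proof

factor = np.linalg.norm(np.eye(6) - eta * A, 2)   # spectral norm
assert np.isclose(factor, 1.0 - 1.0 / kappa)
```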

Theorem 13.

Let $f(x, y) : \mathbb{R}^a \times \mathbb{R}^a \to \mathbb{R}$ be $\rho$-forward-biased, $\sigma$-strongly convex, and $L$-smooth with continuous second derivatives, for some $\rho, \sigma, L > 0$. Then there exists $\tilde{y} \in \mathbb{R}^a$ such that

$$\tilde{y} = \operatorname*{arg\,min}_{y \in \mathbb{R}^a} f(\tilde{y}, y). \tag{4}$$

Further, for some step size $\eta > 0$, the iterates of myopic gradient descent with tied weights

$$y^{(t+1)} = y^{(t)} - \eta\, \nabla_y f(x, y)\big|_{x = y = y^{(t)}}$$

converge to a $\tilde{y} \in \mathbb{R}^a$ satisfying (4). We call such a $\tilde{y}$ a myopic solution.

We then recover the original sequential modeling setting by defining $f(\tilde{\theta}, \theta) := \sum_{i=1}^{n} \ell_i(\tilde{\theta}, \dots, \tilde{\theta}, \theta)$.

Proof.

First, note that if myopic descent converges, it must converge to a point $\tilde{y}$ such that $\tilde{y} = \operatorname*{arg\,min}_y f(\tilde{y}, y)$. Indeed, at convergence we must have $\nabla_y f(\tilde{y}, y)\big|_{y = \tilde{y}} = 0$, so strong convexity tells us that $\tilde{y}$ is optimal. Thus, if we establish that myopic descent with small enough step size $\eta > 0$ converges to some $\tilde{y}$, we automatically get the existence of a myopic solution. For this, it suffices to show that the gradient descent step

$$g_\eta(y) = y - \eta\, \nabla_y f(x, y)\big|_{x = y}$$

is strictly contractive, so that iterates of $g_\eta$ converge to a fixed point.

Consider arbitrary $y$ and $y'$. By the chain rule and the mean value theorem applied to the map $y' \mapsto \nabla_y f(x, y)\big|_{x = y = y'}$,

$$\nabla_y f(y', y') - \nabla_y f(y, y) = \big(H_{x,y} f(y'', y'') + H_{y,y} f(y'', y'')\big)(y' - y) \tag{5}$$

for some $y'' \in [y, y']$. Using the definition of $g_\eta$ and substituting in (5),

$$\lVert g_\eta(y') - g_\eta(y) \rVert_2 = \lVert (y' - y) - \eta\, (\nabla_y f(y', y') - \nabla_y f(y, y)) \rVert_2 = \big\lVert \big(I - \eta\, (H_{x,y} f(y'', y'') + H_{y,y} f(y'', y''))\big)(y' - y) \big\rVert_2 \le (1 - \rho^{-1}) \lVert y' - y \rVert_2,$$

where the last inequality is by $\rho$-forward bias and Lemma 12. That is, $g_\eta$ is $(1 - \rho^{-1})$-contractive, completing the proof. ∎
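The fixed-point iteration of Theorem 13 can be illustrated on a hypothetical quadratic (our own toy example) for which $H_{y,y} + H_{x,y} = I - M$ is positive definite and well conditioned:

```python
import numpy as np

# Toy f(x, y) = 0.5*||y - M x - c||^2 + 0.5*||x||^2, so that the myopic
# update direction at x = y is  -(y - M y - c)  and  H_{y,y} + H_{x,y} = I - M.
M = np.array([[0.5, 0.0], [0.0, 0.2]])
c = np.array([1.0, 1.0])

y = np.zeros(2)
eta = 0.5
for _ in range(200):
    # Myopic step: gradient in the second slot, evaluated at x = y.
    y = y - eta * (y - M @ y - c)

# The myopic solution satisfies y = argmin_{y'} f(y, y') = M y + c,
# i.e. y = (I - M)^{-1} c.
y_tilde = np.linalg.solve(np.eye(2) - M, c)
assert np.allclose(y, y_tilde)
```

Here the iteration map is $y \mapsto (I - \eta(I - M))y + \eta c$, a contraction exactly as in the proof, so the iterates converge geometrically to the myopic solution.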

C.5 Properties of sine

Lemma 14.

Let $\mathsf{x} \sim \mathcal{N}(0, 1)$. Then

$$\operatorname{Var}(\sin(b\mathsf{x})) = \frac{1}{2} - \frac{e^{-2b^2}}{2}.$$

Proof.

Since $\mathbb{E}[\sin(b\mathsf{x})] = 0$ by symmetry,

$$\operatorname{Var}(\sin(b\mathsf{x})) = (2\pi)^{-1/2} \int_{-\infty}^{\infty} \sin^2(bx)\, e^{-x^2/2}\, dx = \frac{1}{2} - \frac{e^{-2b^2}}{2}.$$

∎

Lemma 15.

Let $\mathsf{x} \sim \mathcal{N}(0, 1)$. Then

$$\rho(\mathsf{x}, \sin(b\mathsf{x})) = \frac{\sqrt{2}\, b\, e^{b^2/2}}{\sqrt{e^{2b^2} - 1}} \in O(e^{-b^2/2}),$$

where $\rho$ is the Pearson correlation coefficient.

Proof.

By symmetry, $\mathbb{E}[\sin(b\mathsf{x})] = 0$. We calculate

$$\operatorname{Cov}(\mathsf{x}, \sin(b\mathsf{x})) = (2\pi)^{-1/2} \int_{-\infty}^{\infty} x \sin(bx)\, e^{-x^2/2}\, dx = b\, e^{-b^2/2}.$$

We already computed the variance in Lemma 14, so

$$\rho(\mathsf{x}, \sin(b\mathsf{x})) = \frac{\operatorname{Cov}(\mathsf{x}, \sin(b\mathsf{x}))}{\sqrt{\operatorname{Var}(\mathsf{x})\, \operatorname{Var}(\sin(b\mathsf{x}))}} = \frac{\sqrt{2}\, b\, e^{b^2/2}}{\sqrt{e^{2b^2} - 1}}.$$

∎

In our experiments we set $b = 10$, so $\rho(\mathsf{x}, \sin(b\mathsf{x})) < 10^{-20}$.
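Both lemmas are easy to check by Monte Carlo at a small value of $b$ (at $b = 10$ the quantities involved are numerically indistinguishable from their limits, so a smaller $b$ is used for the check):

```python
import numpy as np

# Monte Carlo check of Lemmas 14 and 15 at b = 0.5, where the variance,
# covariance, and correlation are all far from their large-b limits.
rng = np.random.default_rng(3)
b = 0.5
x = rng.standard_normal(2_000_000)
s = np.sin(b * x)

var_exact = 0.5 - np.exp(-2.0 * b ** 2) / 2.0      # Lemma 14
cov_exact = b * np.exp(-b ** 2 / 2.0)              # from the Lemma 15 proof

assert abs(s.var() - var_exact) < 5e-3
assert abs(np.mean(x * s) - cov_exact) < 5e-3

rho = np.mean(x * s) / np.sqrt(s.var())            # Var(x) = 1
rho_exact = np.sqrt(2.0) * b * np.exp(b ** 2 / 2.0) / np.sqrt(np.exp(2.0 * b ** 2) - 1.0)
assert abs(rho - rho_exact) < 5e-3
```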

Appendix D Details and additional experiments

D.1 Synthetic setting: $\mathcal{D}_p$ task

We use a smaller version of the GPT-2 architecture, adapted to regression tasks: the token embedding and unembedding layers are replaced with trained linear maps from the input space to the embedding space and from the embedding space to the output space, respectively. For each $p \in \{0.01, 0.1, 0.3, 1\}$, models are trained using ordinary and myopic descent on one epoch of 30M sequences of length 64 sampled from $\mathcal{D}_{p,10,10}$. See Table 3 for architecture details.

| Configuration key | Value |
| --- | --- |
| num_layers | 2 |
| num_heads | 2 |
| embd_dim | 128 |
| n_inner | 512 |
| input_dim | 2 |
| output_dim | 1 |
| activation | relu |
| attn_pdrop | 0 |
| embd_pdrop | 0 |
| resid_pdrop | 0 |
| lr | 1e-3 |
| optimizer | AdamW |
| weight_decay | 0.01 |
| betas | (0.9, 0.999) |
| scheduler | cosine |
| warmup | 0.01 |
| batch_size | 512 |
| seq_length | 64 |
| loss_fn | HuberLoss |

Table 3: Transformer configuration used when training on the synthetic data distribution $\mathcal{D}_p$.

D.1.1 Probe details

Given a transformer model trained on $\mathcal{D}_p$, we sample the hidden states at each layer when the model is given 50,000 evaluation sequences from the same distribution $\mathcal{D}_p$ as input. For each layer and each target $\sin(b\mathsf{x}_{n-i})$ for varying $i \ge 0$, we fit a linear regression model from the hidden state of that layer to the target. The in-sample $R^2$ of each linear model is then reported in Figure 2. Figure 8 visualizes the linear probe's performance on the vanilla transformer.
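The probing procedure can be sketched as follows. The "hidden states" here are synthetic stand-ins for real model activations (names and shapes are illustrative, not from the paper's code), with a linearly decodable copy of the target planted in one coordinate:

```python
import numpy as np

# Illustrative linear probe: least-squares fit from "hidden states" to the
# target sin(b * x), reporting in-sample R^2. Real activations would be
# collected from the trained model; here we fabricate a noisy feature.
rng = np.random.default_rng(4)
n_samples, d_hidden, b = 50_000, 16, 10.0

x = rng.standard_normal(n_samples)
target = np.sin(b * x)
hidden = rng.standard_normal((n_samples, d_hidden))
hidden[:, 0] += target                    # plant a linearly decodable feature

# Least-squares probe with a bias column.
X = np.hstack([hidden, np.ones((n_samples, 1))])
w, *_ = np.linalg.lstsq(X, target, rcond=None)
pred = X @ w
r2 = 1.0 - np.sum((target - pred) ** 2) / np.sum((target - target.mean()) ** 2)
# In this synthetic setup the planted feature has unit noise, so R^2 is
# roughly Var(target) / (Var(target) + 1), i.e. about 1/3.
```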

Figure 8: Estimate of $\sin(b\mathsf{x}_n)$ by a linear probe fit on layer 1 of the transformer with vanilla training on $\mathcal{D}_{0.3}$. Computed over 50,000 samples from $\mathcal{D}_1$.

D.2 Natural language setting

D.2.1 GPT-2

For both vanilla and myopic training, we train the GPT-2 small architecture from random initialization for one epoch on 4.6M sequences from the MS MARCO dataset (Nguyen et al., 2016), truncated to length 64. To estimate the local myopia bonus of the vanilla model, we train another model from random initialization with the same architecture, but with past hidden states provided by the vanilla model. Note that this "local myopic" model attains slightly better performance than the vanilla model: each forward pass can focus purely on next-token prediction, since past hidden states are supplied by a separate model. As a baseline, we also train a "transformer bigram" model, a transformer whose past key/value states are zeroed out during training and evaluation. See Table 4 for configuration details.

| Configuration key | Value |
| --- | --- |
| num_layers | 12 |
| num_heads | 12 |
| embd_dim | 768 |
| n_inner | 3072 |
| vocab_size | 50257 |
| activation | gelu_new |
| attn_pdrop | 0.1 |
| embd_pdrop | 0.1 |
| resid_pdrop | 0.1 |
| lr | 6.0e-4 |
| optimizer | AdamW |
| weight_decay | 0.01 |
| betas | (0.9, 0.999) |
| scheduler | cosine |
| warmup | 0.01 |
| batch_size | 512 |
| seq_length | 64 |
| loss_fn | CrossEntropy |

Table 4: GPT-2 small configuration used when training on natural language data.

D.2.2 Pythia

For experiments on the Pythia suite, we finetune with either vanilla or myopic descent on 10M sequences of 64 tokens each, subsampled from The Pile (Gao et al., 2020). Learning rates and batch sizes for each model are presented in Table 5; they are the same between vanilla and myopic descent. All other training and architectural hyperparameters are the same as those used by Biderman et al. (2023).

| Model | Learning rate | Batch size |
| --- | --- | --- |
| Pythia 14M | 4.0e-4 | 512 |
| Pythia 31M | 4.0e-4 | 512 |
| Pythia 70M | 1.2e-4 | 512 |
| Pythia 160M | 1.2e-4 | 512 |
| Pythia 410M | 1.2e-4 | 256 |
| Pythia 1B | 1.2e-4 | 128 |
| Pythia 1.4B | 1.2e-4 | 128 |
| Pythia 2.8B | 8.0e-5 | 64 |

Table 5: Pythia suite hyperparameters for finetuning. Batch size is measured in sequences of 64 tokens each.

Figure 9: Cross-entropy loss of Pythia models fine-tuned on The Pile using vanilla and myopic gradient descent. Starting from the 70M model, the gap increases with parameter count.

D.3 Multiplication

In addition to the natural language experiments, we also measure the myopia gap on the task of natural number multiplication. We use the same GPT-2 small architecture as in the natural language experiments, starting from random initialization; see Table 4. See Figure 10 for an example input sequence.

                3 7 0 0 * 5 4 0 0 = 5 8 2 3 0 0 0 0

Figure 10: Example multiplication input sequence.

We use several formatting tricks found by Shen et al. (2023) to improve performance on the multiplication task:

- Characters are delimited by spaces, so that each digit is tokenized into a separate token.
- All numbers are written in reverse of the standard order, i.e., with the least significant digits first.
- All inputs are zero-padded to the same length.

Note that for each multiplicand we first sample the number of digits $d \sim \mathrm{Unif}\{1, \dots, n\}$ for some range $n$, then uniformly sample a natural number $x \sim \mathrm{Unif}\{0, \dots, 10^d - 1\}$ with no more than $d$ digits. This distribution allocates increased weight to smaller numbers, which we found to result in superior performance.
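The formatting and sampling described above can be sketched as follows (helper names are ours, not from the paper's code):

```python
import random

# Sketch of the multiplication input format: digits are space-delimited,
# numbers are written least-significant-digit first, and operands are
# zero-padded to a fixed length.
def fmt(x: int, width: int) -> str:
    digits = str(x)[::-1].ljust(width, "0")   # reverse, then zero-pad
    return " ".join(digits)

def sample_example(n_max: int = 8, rng: random.Random = random.Random(0)) -> str:
    operands = []
    for _ in range(2):
        d = rng.randint(1, n_max)                # number of digits first ...
        operands.append(rng.randrange(10 ** d))  # ... then a number with <= d digits
    a, b = operands
    return f"{fmt(a, n_max)} * {fmt(b, n_max)} = {fmt(a * b, 2 * n_max)}"

print(fmt(73, 4))  # "3 7 0 0"
```

Under this encoding the example sequence of Figure 10 reads as a product of two reversed, zero-padded operands.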

We train both vanilla and myopic transformers on one epoch of 10M examples with no more than 8 digits per multiplicand, then measure 0/1 accuracy (that is, the model is provided with an input sample up to the '=' token, and scored 1 if it completes the rest of the sequence exactly correctly and 0 otherwise) on 1024 independent random validation examples for each of the 8 × 8 possible pairs of digit counts for the two multiplicands. See Figure 11.

Figure 11: Accuracy of vanilla and myopic transformers trained on multiplication of up to 8-digit inputs (panels: vanilla multiplication accuracy, myopic multiplication accuracy, and the vanilla-myopic accuracy gap). Rows and columns correspond to the number of digits in the first and second multiplicands, respectively.

D.3.1 Filler tokens

We further hypothesize that, as in Pfau et al. (2024), the vanilla transformer may learn to perform computation even on forward passes corresponding to filler tokens, thus attaining better performance when trained on examples zero-padded to longer lengths. We expect that myopic transformers, on the other hand, are not incentivized to do this, since this extra computation is irrelevant to the immediate task of predicting the filler zero token. To test this hypothesis, we train vanilla and myopic transformers on each of two multiplication datasets:

1. Both multiplicands have at most 5 digits, and are zero-padded to exactly 5 digits.
2. Both multiplicands have at most 5 digits, and are zero-padded to exactly 10 digits.

Again, all training runs consist of one epoch of 10M examples.

See Figure 12 for results. The vanilla transformer indeed performs better when trained and evaluated on input sequences zero-padded to a longer length. However, the myopic transformer performs substantially worse with increased zero-padding. Our explanation is that not only does the myopic transformer fail to learn to perform intermediate computation during filler-token forward passes, but the increased input length also makes it more difficult for the attention mechanism to correctly attend to the relevant tokens.

D.4 Gradient angles

Using the publicly available training checkpoints for Pythia-410M (Biderman et al., 2023), we measure the norms of both the myopic component and the future component of the gradient of the loss with respect to the parameters over the course of training. (The future component is the difference between the total vanilla gradient and the myopic gradient.) We also measure the cosine similarity between the myopic and future components. See Figure 13. One observation is that the norm of the myopic gradient is consistently larger than that of the future gradient; thus, training is dominated by the effect of each forward pass's parameters on the immediate next-token prediction.

Figure 12: Multiplication accuracy of GPT-2 with either vanilla or myopic training and with input multiplicands zero-padded to either length 5 or 10 (panels: vanilla and myopic accuracy at each padded length). Rows and columns correspond to the number of digits in the first and second multiplicands, respectively. Observe that padding improves performance of the vanilla model but decreases performance of the myopic model.

Figure 13: Myopic and future gradients of Pythia-410M during training (panels: norms of the myopic, future, and total gradients; cosine similarity between the myopic and future gradients).
