Title: Do Language Models Plan Ahead for Future Tokens?
URL Source: https://arxiv.org/html/2404.00859
License: CC BY-SA 4.0. arXiv:2404.00859v2 [cs.LG], 01 Aug 2024

Do Language Models Plan Ahead for Future Tokens?

Wilson Wu (Department of Mathematics, University of Colorado Boulder; wiwu2390@colorado.edu), John X. Morris (Department of Computer Science, Cornell University; jxm3@cornell.edu), and Lionel Levine (Department of Mathematics, Cornell University; levine@math.cornell.edu)

Abstract
Do transformers "think ahead" during inference at a given position? It is known that transformers prepare information in the hidden states of the forward pass at time step $t$ that is then used in future forward passes $t+\tau$. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present during training result in the model computing features at $t$ irrelevant to the present inference task but useful for the future, and breadcrumbs, in which features most relevant to time step $t$ are already the same as those that would most benefit inference at time $t+\tau$. We test these hypotheses by training language models without propagating gradients to past timesteps, a scheme we formalize as myopic training. In a constructed synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis, though pre-caching increases with model scale.
1 Introduction
Humans are known to think ahead while speaking; decades of linguistics research (Huettig, 2015; Miller, 1951) have shown evidence that human language users internally predict upcoming language input, words and sometimes sentences ahead (Barthel et al., 2016).
Unlike humans, contemporary language models allocate a fixed amount of information processing for each token when "speaking" (Vaswani et al., 2017). Do language models, like humans, think ahead? Recent work (Pal et al., 2023; Hernandez et al., 2024) has shown that tokens beyond the immediate next token can be predicted by probing the hidden state of the language model. Model outputs at future tokens can be predicted to some extent using linear probes on model hidden states, and interventions on hidden states can predictably alter future outputs.
These findings indicate that model activations at a given timestep are at least somewhat predictive of future outputs. However, it remains unclear why this might be: is this just a happenstance property of the data, or because the model is deliberately preparing information for future timesteps, at the expense of degrading performance on the current position?
We observe that gradients during training optimize weights for both the loss at the current token position and the losses at later positions in the sequence. We ask to what extent current transformer weights dedicate resources to the current token versus allocating them to future tokens.
We consider two possibilities: the pre-caching hypothesis, in which the transformer learns to compute features at time step $t$ that are irrelevant to the inference task at that current time step but may be useful for future time steps $t+\tau$, and the breadcrumbs hypothesis, in which the features most relevant to time step $t$ are already identical to those that would most benefit inference at time $t+\tau$. To evaluate which hypothesis might be correct, we propose a myopic training scheme that does not propagate gradients from the loss at the current position to hidden states from previous positions. We then evaluate the myopia gap in performance between myopically trained and vanilla transformers as a measure of pre-caching.
To consider whether language models might directly implement pre-caching, we design a synthetic scenario where the task can only be completed via explicit pre-caching. We configure a task where the model must precompute information for the next token, because otherwise the correct answer could not be accurately computed in a single forward pass. In this synthetic scenario, we find clear evidence that the transformer learns to pre-cache. When transformer-based sequence models must precompute information to minimize loss, they do so.
We then consider whether breadcrumbs or pre-caching is demonstrated in natural language models. Our experiments with myopic training suggest that, with small language models like GPT-2 (Radford et al., 2019), much less pre-caching occurs in this setting, pointing towards the breadcrumbs hypothesis. That is, we claim language models on this scale do not intentionally prepare information for the future to a significant extent. Instead, they compute features that are useful to predicting the immediate next token, which turn out to then be helpful at future steps; there is not a significant tradeoff between greedily optimizing for next token loss and ensuring future predictive performance.
However, we also find evidence that the importance of pre-caching increases with scale, becoming non-negligible with larger models, e.g. Pythia 2.8B (Biderman et al., 2023). This suggests that these larger models are "planning for the future" in a way that small models cannot.
Figure 1: At which position is the computation required to correctly answer this math problem taking place? Cognitive science tells us that humans think ahead while speaking; we investigate the extent to which language models do the same.

2 Related work

Future token meta-prediction.
Several recent works (nostalgebraist, 2020; Belrose et al., 2023; Pal et al., 2023; Cai et al., 2024) observe that transformer hidden states can be used to predict current and future tokens in a sequence, typically via linear probing. Notably, Hernandez et al. (2024) show that more complicated relationships are encoded linearly in hidden states, such as subject-object relations, implying that future tokens can also be predicted in specific cases. This future token predictivity has also been applied to speeding up inference by decoding future tokens in parallel (Stern et al., 2018; Cai et al., 2024). Unlike these works, we focus on the question of how the model learns to prepare hidden states that are useful for future prediction, possibly at the expense of current-token predictivity.
Probing.
Our synthetic data experiments make use of probing, a technique where a simple auxiliary model is used to predict properties from target models' representations (Belinkov & Glass, 2019; Shi et al., 2016; Hewitt & Liang, 2019; Pimentel et al., 2020; Belinkov, 2021). Probing-based approaches are known to overestimate latent information if the classifier learns to do a task on its own (Belinkov, 2021), and probing analyses may only be informative when compared to probing a reasonable baseline (Hewitt & Liang, 2019). In our probing experiments, we avoid these pitfalls by ensuring that the function to be learned cannot possibly be computed by the probe itself.
Mechanistic interpretability.
Our analysis of transformer models in a synthetic setting relates to the subfield of mechanistic interpretability, which seeks to understand models by isolating and explaining the behavior of their components (Olah et al., 2020; Bau et al., 2020; Meng et al., 2023; Nanda et al., 2023). Some of these works (Nanda et al., 2023; Li et al., 2023; Zhong et al., 2023) practice mechanistic interpretability by studying models trained on synthetic data. We apply some mechanistic interpretability techniques in a synthetic setting to study the problem of whether language models "think ahead" for future tokens. However, our approach also differs from that of mechanistic interpretability by analyzing the effect of the training procedure on the learned model.
3 Theory: Pre-caching or breadcrumbs?
Consider a generic causal sequence-to-sequence prediction task
$$(\mathbf{x}_1, \ldots, \mathbf{x}_N, \mathbf{y}_1, \ldots, \mathbf{y}_N) \sim \mathcal{D},$$
where $\mathcal{D}$ is a data distribution supported on $\mathcal{X}^N \times \mathcal{Y}^N$ for some domains $\mathcal{X}, \mathcal{Y}$. The task is to estimate the conditional expectations $\mathbb{E}_{\mathcal{D}}(\mathbf{y}_t \mid \mathbf{x}_1, \ldots, \mathbf{x}_t)$ for $1 \le t \le N$. Note that we recover the autoregressive setting by setting $\mathcal{Y} = \mathcal{X}$ and $\mathbf{y}_t = \mathbf{x}_{t+1}$.

Transformer models trained on such tasks have been observed (Pal et al., 2023) to store information in hidden states during inference at position $t$ that is then used in future inference at positions $t' > t$. However, since the loss associated with each step $t$ depends only on how well the model does at the immediate task of predicting $\mathbf{y}_t$, it may not be immediately obvious how this preparation for the future arises. We give names to two competing explanations:
- Pre-caching: The model "deliberately" computes and stores features that are expected to be useful for the future, even if they are irrelevant to the present.
- Breadcrumbs: The features that most benefit the present inference task are the same as those that are most useful to the future. When the model performs the present forward pass, it "unintentionally" leaves a trace ("breadcrumbs") that is then picked up by future passes.
To disentangle these two explanations, we introduce a notion of myopic transformer models, which we show to be incapable of deliberate pre-caching: for these models, the extent to which past features are beneficial to the future is determined purely by the breadcrumbs explanation. Thus, the gap between vanilla and myopic transformer models is a quantitative measure of how much pre-caching is taking place.
3.1 Causal sequence modeling
Suppose, for the sake of exposition, that the transformer model $G$ uses independent parameters for each position. Let $k$ be the parameter count of each forward pass of $G$. Then, letting $\theta_t \in \Theta \subseteq \mathbb{R}^k$ be all parameters used by $G$ at position $t$, a transformer $G$ is a parameterized function
$$G : \mathcal{X}^N \times \Theta^N \to \mathcal{Y}^N, \qquad (\mathbf{x}_1, \ldots, \mathbf{x}_N; \theta_1, \ldots, \theta_N) \mapsto (\hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_N).$$
For $1 \le t \le N$, let $G_t(\mathbf{x}_1, \ldots, \mathbf{x}_N) \in \mathcal{Y}$ be the output of $G$'s $t$-th forward pass. Because of the causal masking within $G$, this depends only on $\mathbf{x}_1, \ldots, \mathbf{x}_t$ and $\theta_1, \ldots, \theta_t$. That is, with slight abuse of notation, we may write
$$\hat{\mathbf{y}}_t = G_t(\mathbf{x}_1, \ldots, \mathbf{x}_N; \theta_1, \ldots, \theta_N) = G_t(\mathbf{x}_1, \ldots, \mathbf{x}_t; \theta_1, \ldots, \theta_t).$$
3.2 Off-diagonal gradient terms
Now, letting $\ell : \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_+$ be some choice of loss function, the expected loss $\mathcal{L}$ of a transformer model with parameters $\theta_1, \ldots, \theta_N$ is
$$\mathcal{L}(\theta_1, \ldots, \theta_N) := \mathbb{E}_{(\mathbf{x}_*, \mathbf{y}_*) \sim \mathcal{D}} \sum_{t=1}^{N} \ell\big(G_t(\mathbf{x}_1, \ldots, \mathbf{x}_t; \theta_1, \ldots, \theta_t), \mathbf{y}_t\big) =: \sum_{t=1}^{N} \mathcal{L}_t(\theta_1, \ldots, \theta_t),$$
the sum over $1 \le t \le N$ of the expected loss $\mathcal{L}_t$ at position $t$. (We suppress the dependence on $G$ and $\mathcal{D}$ for concision.) In practice, we always tie the weights across position. That is, all $\theta_t$ are set equal to the same $\theta \in \Theta$. Then, by the chain rule,
$$\nabla_\theta \mathcal{L}(\theta, \ldots, \theta) = \sum_{t=1}^{N} \nabla_{\theta_t} \mathcal{L}(\theta_1, \ldots, \theta_N) \Big|_{\theta_1 = \cdots = \theta_N = \theta} = \sum_{t=1}^{N} \sum_{s=1}^{t} \nabla_{\theta_s} \mathcal{L}_t(\theta_1, \ldots, \theta_t) \Big|_{\theta_1 = \cdots = \theta_t = \theta},$$
a sum over an upper-triangular expected Jacobian "matrix". The off-diagonal terms $s < t$, corresponding to the expected gradient of the model's future loss at position $t$ with respect to its weights at position $s$, are the training signals that encourage pre-caching.
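As a toy illustration (our own, not the paper's), this decomposition can be seen directly with autograd: in a two-position model with a shared scalar parameter, detaching the past hidden state removes exactly the off-diagonal term. All names and values below are illustrative.

```python
import torch

# Two-position toy "transformer" with a tied scalar parameter theta.
# Position 2 reads position 1's hidden state, so its loss has both a
# diagonal term (through theta_2) and an off-diagonal term (through theta_1).
x1, x2 = torch.tensor(0.5), torch.tensor(-0.3)
theta = torch.tensor(1.0, requires_grad=True)

def total_loss(theta, myopic):
    h1 = theta * x1                       # hidden state at position 1
    past = h1.detach() if myopic else h1  # myopic: block gradient to the past
    l1 = (h1 - 1.0) ** 2                  # loss at position 1
    l2 = (theta * x2 + past - 1.0) ** 2   # loss at position 2 reads past state
    return l1 + l2

(full_grad,) = torch.autograd.grad(total_loss(theta, myopic=False), theta)
(myopic_grad,) = torch.autograd.grad(total_loss(theta, myopic=True), theta)

# full_grad - myopic_grad is exactly the off-diagonal term dL2/dtheta_1,
# the training signal that encourages pre-caching.
off_diag = full_grad - myopic_grad
print(full_grad.item(), myopic_grad.item(), off_diag.item())
```
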
3.3 Measuring pre-caching: The myopia gap
We say a model is myopic when each forward pass $G_t$ optimizes only $\mathcal{L}_t$, without regard for future $\mathcal{L}_{t'}$ at $t' > t$. In the untied-weights case, the right definition is then apparent.
Definition 1. The parameters $(\tilde{\theta}_1, \ldots, \tilde{\theta}_N) \in \Theta^N$ are untied-myopic if they satisfy
$$\tilde{\theta}_t \in \arg\min_{\theta_t} \mathcal{L}_t(\tilde{\theta}_1, \ldots, \tilde{\theta}_{t-1}, \theta_t) \quad \forall t \in \{1, \ldots, N\}. \tag{1}$$

Definition 2. Let $\mathcal{M}$ be the feasible set of the constraints in Equation 1. The untied myopia gap is the smallest possible gap between the expected loss attained by a myopic model and the optimal model:
$$g^* := \min_{(\tilde{\theta}_1, \ldots, \tilde{\theta}_N) \in \mathcal{M}} \mathcal{L}(\tilde{\theta}_1, \ldots, \tilde{\theta}_N) - \min_{(\theta_1, \ldots, \theta_N) \in \Theta^N} \mathcal{L}(\theta_1, \ldots, \theta_N) \ge 0. \tag{2}$$
In the tied-weights case, it is perhaps not immediately clear what the right definition of myopia should be. It does not suffice to simply constrain the minimizations in Equation 1 to $\tilde{\theta}_1 = \cdots = \tilde{\theta}_N$, since $\min_\theta \mathcal{L}_t(\theta, \ldots, \theta)$ is optimizing for pre-caching (the dependence on arguments $s < t$) as well as the present inference (the dependence on argument $t$). Instead, the right notion is a choice of tied parameters such that the model is, aggregated over positions, optimal for the present task when conditioned on a fixed past. That is, forward passes do not compute features for the future if they can compute other features more beneficial to the present.
Definition 3. The parameters $\tilde{\theta} \in \Theta$ are (tied-)myopic if they satisfy
$$\tilde{\theta} \in \arg\min_{\theta \in \Theta} \sum_{t=1}^{N} \mathcal{L}_t(\tilde{\theta}, \ldots, \tilde{\theta}, \theta). \tag{3}$$

The (tied) myopia gap is then defined analogously to Definition 2.
The breadcrumbs hypothesis states that the myopia gap is small: near-optimal performance can be attained even when each forward pass computes features relevant only to its own immediate inference task, with no regard to pre-caching for the future.
If the breadcrumbs hypothesis does not hold, we say that the model is pre-caching. It is important to remember that the $\mathcal{L}_t$ depend on a choice of transformer model $G$ and dataset $\mathcal{D}$. That is, breadcrumbs and pre-caching are properties of the model architecture and the data considered as a whole.
Although a small myopia gap reveals that one can do just as well without pre-caching, it does not say much about any specific model. To measure pre-caching within a given model, we examine the extent to which its parameters violate the myopia constraints.
Definition 4. The (tied) local myopia bonus at $\theta^* \in \Theta$ is
$$\hat{b}(\theta^*) := \max_{\theta \in \Theta} \sum_{t=1}^{N} \big( \mathcal{L}_t(\theta^*, \ldots, \theta^*) - \mathcal{L}_t(\theta^*, \ldots, \theta^*, \theta) \big) \ge 0.$$
For further interpretation of the myopia gap and myopia bonus, see Appendix A.
3.4 Myopic gradient descent
Our heuristic remark in Section 3.2, that the off-diagonal gradient terms are responsible for pre-caching, is justified by Theorem 13 below. It states that, given certain regularity conditions on the loss terms β π , performing gradient descent with the off-diagonal terms removed results in a myopic model in the sense of Definition 3. We call this myopic descent.
For myopic descent to be stable in the tied-weights case, we need, roughly speaking, for the model to depend more on the parameters associated with the present forward pass than those from the past. This is a plausible condition: dependence on the past is mediated purely by the attention mechanism, while the present forward pass depends on both attention and feedforward parameters. The precise condition we use is forward bias (Definition 11); see Appendix C for details and the proof of the theorem.
Theorem 13. Let $f(\tilde{\theta}, \theta) := \sum_{t=1}^{N} \mathcal{L}_t(\tilde{\theta}, \ldots, \tilde{\theta}, \theta)$. If $f$ is forward-biased, $\mu$-strongly convex, and $L$-smooth, then, for some step size $\eta > 0$, the iterates of myopic descent with tied weights
$$\theta^{(t+1)} = \theta^{(t)} - \eta \, \nabla_\theta f(\tilde{\theta}, \theta) \Big|_{\tilde{\theta} = \theta = \theta^{(t)}}$$
converge to $\tilde{\theta} \in \Theta$ satisfying the myopia constraints of Equation 3.
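The stop-gradient idea behind myopic descent can be sketched as follows for a single attention operation: queries at each position read gradient-blocked copies of past keys and values. This is a minimal illustration, not the paper's exact implementation (which is described in Appendix B); a full transformer would also need to block gradients flowing through the past residual stream.

```python
import torch
import torch.nn.functional as F

def myopic_causal_attention(x, w_q, w_k, w_v):
    """Single-head causal attention in which each position's loss sends no
    gradient to keys/values computed at strictly earlier positions.

    Sketch only: weights and shapes are illustrative."""
    T, d = x.shape
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    k_past, v_past = k.detach(), v.detach()  # gradient-blocked copies

    out = []
    for t in range(T):
        # Past positions (s < t) use detached keys/values; the present
        # position (s = t) keeps its live, differentiable copies.
        kt = torch.cat([k_past[:t], k[t : t + 1]])
        vt = torch.cat([v_past[:t], v[t : t + 1]])
        att = F.softmax(q[t] @ kt.T / d**0.5, dim=-1)
        out.append(att @ vt)
    return torch.stack(out)
```

Under this scheme, the forward pass at position $t$ still reads past activations (so breadcrumbs remain usable at inference), but the loss at $t$ contributes no off-diagonal gradient terms to earlier positions.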
4 Synthetic data experiments

4.1 The $\mathcal{D}_p$ task
To demonstrate a simple example where the transformer model learns to pre-cache (and thus the myopia gap is large), we construct the following synthetic dataset.
Definition 5. The data distribution $\mathcal{D}_{p,C,M}^N$ is defined as the joint distribution of real-valued random variables $(x_t)_{t=1}^N$, $(y_t)_{t=1}^N$, $(z_t)_{t=1}^N$ where, for each $t$,

- $x_t \sim \mathcal{N}(0, 1)$ (standard Gaussian),
- $z_t \sim \operatorname{Ber}(p)$ (Bernoulli with probability $p$),
- $y_t = z_t \sum_{i=1}^{M} \sin(C x_{t-i}) + (1 - z_t)\, x_t$,

and $\{x_t\}_{t \in \mathbb{N}} \cup \{z_t\}_{t \in \mathbb{N}}$ are mutually independent. In our experiments, we always set the parameters $C = M = 10$ and $N = 64$, so for convenience we notate $\mathcal{D}_p := \mathcal{D}_{p,10,10}^{64}$.
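For concreteness, sampling from this distribution might look as follows. This is a sketch based on our reading of Definition 5; the boundary convention that $\sin(C x_{t-i})$ contributes zero when $t - i < 1$ is our assumption.

```python
import numpy as np

def sample_Dp(p, C=10, M=10, N=64, batch=1, seed=0):
    """Sample (x, z, y) arrays of shape (batch, N) from D_{p,C,M}^N."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, N))      # x_t ~ N(0, 1)
    z = rng.binomial(1, p, size=(batch, N))  # z_t ~ Ber(p)
    s = np.sin(C * x)
    # Sum of sin(C * x_{t-i}) over the M previous positions, zero-padded
    # at the start of the sequence (our boundary assumption).
    pad = np.concatenate([np.zeros((batch, M)), s], axis=1)
    window = sum(pad[:, M - i : M - i + N] for i in range(1, M + 1))
    y = z * window + (1 - z) * x
    return x, z, y
```

With $p = 0$ every target reduces to $y_t = x_t$, while with $p = 1$ every target requires the sum of past sine terms.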
The intuition is that a transformer regression model $G$ trained on $\mathcal{D}_p$ would benefit from pre-caching $\sin(C x_t)$ during its forward pass at position $t$, even though this computation is irrelevant to its task of predicting $y_t$. One simple strategy that makes use of this pre-caching is Algorithm 1.

The motivation for the Bernoulli variables $z_t$ is that, as $p$ decreases, the expected first time when $\sin(C x_t)$ becomes useful advances further into the future. In addition, when $p$ is sufficiently small, the probability $(1-p)^M$ that the value $\sin(C x_t)$ is never useful at all becomes non-negligible. We will show that, even in this case, the transformer model learns to pre-cache.
Investigating myopia.
Suppose that we train a myopic model (Section 3.4) on the same task. Since this model lacks off-diagonal gradient terms, we do not expect it to learn to pre-cache $\sin(C x_t)$ at position $t$. One possible strategy that does not use pre-caching is Algorithm 2. We expect this brute-force algorithm to perform significantly worse given the same parameter count: it computes an $M$-dimensional nonlinear function within a single layer, while each layer of Algorithm 1 computes only scalar nonlinear functions.
Algorithm 1 (Pre-caching algorithm). At position $t$:

1. input: $x_t, z_t$
2. layer 1 compute: $s_t := \sin(C x_t)$
3. layer 2 read: $s_{t-i}$ for $i = 1, \ldots, M$
4. layer 2 compute: $\hat{y}_t := z_t \sum_{i=1}^{M} s_{t-i} + (1 - z_t)\, x_t$
5. return $\hat{y}_t$

Algorithm 2 (Brute-force algorithm). At position $t$:

1. input: $x_t, z_t$
2. layer 1 compute: nothing
3. layer 2 read: $x_{t-i}$ for $i = 1, \ldots, M$
4. layer 2 compute: $\hat{y}_t := z_t \sum_{i=1}^{M} \sin(C x_{t-i}) + (1 - z_t)\, x_t$
5. return $\hat{y}_t$

4.1.1 Evaluation: linear probing
To determine whether the transformer model is computing $\sin(C x_t)$ at position $t$, we fit linear probes on the hidden states. We additionally compute the correlations between $\sin(C x_t)$ and each individual dimension (i.e., each neuron) of each hidden state. See Section D.1.1 for details.
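A linear probe of this kind can be sketched as follows; the probe family (ridge regression), penalty, and train/test split below are illustrative choices rather than the paper's exact setup (see Section D.1.1).

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def probe_r2(hidden, target, train_frac=0.8, alpha=1.0):
    """Fit a linear probe from hidden states (n_samples, d_model) to a
    scalar target such as sin(C * x_{t-i}), and report held-out R^2.

    alpha and train_frac are illustrative hyperparameters."""
    n_train = int(len(hidden) * train_frac)
    probe = Ridge(alpha=alpha).fit(hidden[:n_train], target[:n_train])
    return r2_score(target[n_train:], probe.predict(hidden[n_train:]))
```

A hidden state that linearly encodes the target yields held-out $R^2$ near 1, while an unrelated target yields $R^2$ near 0.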
4.1.2 Results for $\mathcal{D}_p$
For varying $p$, we train two-layer transformer models with embedding dimension 128 on $\mathcal{D}_p$ using both ordinary and myopic gradient descent. Full architecture and training details are provided in Section D.1.
Examining the performance of each linear probe against $\sin(C x_{t-i})$ for varying $i$, we find strong evidence that the transformer model with vanilla training is indeed pre-caching $\sin(C x_t)$, possibly in order to implement Algorithm 1. Indeed, in Figure 2:

- The zeroth hidden state (i.e., the sum of the input and position embeddings) at position $t$ is correlated with only $x_t$.
- The first hidden state is correlated with $\sin(C x_t)$ but not with any $\sin(C x_{t-i})$ for $i > 0$.
- The second hidden state (immediately before the output unembedding) is correlated with $\sin(C x_{t-i})$ for each $0 \le i \le M$.
| $p$ | Vanilla | Myopic |
|---|---|---|
| 0.01 | 0.096 | 1.10 |
| 0.1 | 0.016 | 0.97 |
| 0.3 | 0.0030 | 1.03 |
| 1.0 | 0.0074 | 1.26 |

Table 1: Normalized Huber loss $\mathcal{L}/N$ for vanilla and myopic models trained and evaluated on $\mathcal{D}_p$ for each $p$ in our synthetic setting. For reference, the trivial model that always outputs zero attains a Huber loss of 1.26.
Further, looking at the per-neuron correlations in Figure 5, we see that $\sin(C x_{t-i})$ for $1 \le i \le M$ are all correlated with a single 1-d subspace of the second hidden state (they share the same striping pattern); this is the subspace storing $\sum_{i=1}^{M} \sin(C x_{t-i})$. Meanwhile, $\sin(C x_t)$, as well as many of the $x_{t-i}$, are located in various other 1-d subspaces of the second hidden state; these terms are all left over in the residual stream from previous layers, and are cleaned up only by the output unembedding.
On the other hand, in Table 1, the myopic models perform significantly worse. The per-neuron correlations in Figure 5 suggest that the myopic model may be implementing a crude approximation of Algorithm 2. This suggests that the synthetic setting has an inherently high myopia gap: it is impossible for the transformer model to do well without pre-caching.
Figure 2: Empirical $R^2$ between linear probes fit on each layer of vanilla transformer models trained on $\mathcal{D}_p$ for $p \in \{0.01, 0.1, 0.3, 1\}$ and targets $\sin(C x_{t-i})$. Computed over 50,000 samples from $\mathcal{D}_1$.

Figure 5: Empirical correlations between each hidden-state neuron and $x_{t-i}$ or $\sin(C x_{t-i})$. Models are vanilla (left two columns; Figure 3) and myopic (right two columns; Figure 4) transformers trained on $\mathcal{D}_{0.3}$.

4.2 Multiplication
In addition to the above π π synthetic data setting, we also measure the myopia gap on the task of natural number multiplication. In particular, we find evidence suggesting that pre-caching is responsible for model computation on filler tokens, in the sense of Pfau et al. (2024). See Appendix D.3 for details.
5 Natural language experiments

5.1 GPT-2's myopia gap
In order to measure the extent to which transformer models learn to pre-cache on natural language data, we estimate both the myopia gap (Definition 3) in this setting as well as the local myopia bonus (Definition 4) of a transformer model with vanilla pre-training. Experiments in this subsection use the 124M-parameter GPT-2 architecture; see Table 4 in Appendix D for configuration details.
We train all models (vanilla and myopic) from random initialization for one epoch on 4.6M sequences from the MS MARCO dataset (Nguyen et al., 2016), truncated to length 64. To estimate the local myopia bonus of the vanilla model, we train another model from random initialization with the same architecture, but with past hidden states sourced from the frozen vanilla model during both training and evaluation. See Appendix B for implementation details.
As a baseline, we also train a "transformer bigram" model: a model with an identical architecture but all off-diagonal key/value states zeroed out.
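One way to realize such a baseline (a sketch, not necessarily the paper's exact implementation) is to mask attention so that each position attends only to itself:

```python
import torch

def bigram_attention_mask(T):
    """Additive attention mask for a 'transformer bigram' baseline:
    all off-diagonal positions are blocked, so each token attends only
    to itself (our sketch of zeroing off-diagonal key/value use)."""
    mask = torch.full((T, T), float("-inf"))
    mask.fill_diagonal_(0.0)
    return mask
```

Adding this mask to the attention scores makes every attention pattern the identity, so each position can condition only on the current token, as in a bigram-style predictor.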
5.1.1 GPT-2 results
From Table 2, the estimated myopia gap in this setting is $3.40 - 3.28 = 0.12$ cross-entropy, while the local myopia bonus of the vanilla model is $3.28 - 3.26 = 0.02$.
The nonzero myopia gap suggests that pre-caching may provide a small positive benefit. Indeed, in Figure 6, we see that the myopic model outperforms the vanilla model at the beginning of the sequence, since it can allocate all compute to next-token prediction, but quickly falls behind as the length of the past increases, since it suffers from a lack of pre-cached information from earlier forward passes.
| Model | Cross-entropy |
|---|---|
| Vanilla | 3.28 |
| Myopic | 3.40 |
| Local myopic | 3.26 |
| Transformer bigram | 5.33 |

Table 2: Validation cross-entropy loss obtained by GPT-2 with vanilla and myopic training.
However, note that this gap is much smaller than that between the vanilla model and the transformer bigram model (Table 2). That is, the myopic model is still able to leverage past information (breadcrumbs) to a significant extent, even though it optimized only for the present inference task. That the local myopia bonus is near zero further supports this direction: the model learned through vanilla training does not trade off significantly between features useful for the present and pre-caching for the future.
Figure 6: Cross-entropy loss of vanilla and myopic GPT-2 models by token position, and their difference. Evaluated on a sliding window over a 100K-token sample text from the PG-19 dataset (Rae et al., 2019). Aggregate cross-entropy losses on this sample are 4.67 (vanilla) and 4.77 (myopic).

5.2 Myopia gap scaling
One might suppose that the relatively small myopia gap of GPT-2 is due to the relative simplicity of the small architecture we consider, and that larger language models might exhibit a more pronounced myopia gap.
To test this, we train both vanilla and myopic transformers from the Pythia LLM suite (Biderman et al., 2023), ranging in size from 14M to 2.8B parameters, on one epoch of 10M sequences of 64 tokens each subsampled from the Pile dataset (Gao et al., 2020). (We use the same subsampled dataset for every training run.) We report validation cross-entropy loss (Figure 9 in Appendix D) as well as performance on a variety of natural language experiments (Figure 7). Note that, unlike in the GPT-2 experiments (Section 5.1), which start from random initialization, we start all training for Pythia models from the pre-trained checkpoints provided by Biderman et al. (2023); for the larger architectures, the 10M-sequence dataset we use is not sufficiently large for pre-training from random initialization.
Figure 7: Benchmark performance of Pythia models fine-tuned on the Pile dataset using vanilla and myopic descent: LAMBADA (Paperno et al., 2016), PIQA (Bisk et al., 2020), SciQ (Welbl et al., 2017), and ARC-Easy (Zhang et al., 2018).

6 Discussion and future work
Using a synthetic dataset, we demonstrate that pre-caching can indeed be learned by a transformer model. On the other hand, our experiments with natural language suggest that the breadcrumbs hypothesis is more explanatory for that setting, especially with smaller models, but that the importance of pre-caching increases with scale.
If the myopia gap is indeed not large in practice, there may be several applications of myopic training. We hypothesize that myopic transformers may have advantages in terms of safety and/or interpretability: it may be easier to understand what a model is doing if we know that everything it computes on each forward pass is directed toward the goal of predicting the immediate next token. For example, as seen in Appendix D.3, myopic models may not be able to make use of computation on the forward passes of filler tokens in the sense of Pfau et al. (2024).
Another possibility is that of automatically swapping in a locally myopic model (Section 5.1) on forward passes where we detect it is beneficial to sacrifice future performance in favor of immediate next-token accuracy (for example, on especially important tokens, or near the end of a text). We leave these possible applications to future work.
Acknowledgments
LL thanks Lukas Berglund and David Schneider-Joseph for inspiring conversations. This research was partly supported by Open Philanthropy and the Berkeley Existential Risk Initiative.
References

- Barthel et al. (2016): Maik Barthel, Sebastian Sauppe, Stephen C. Levinson, and Antje S. Meyer. The timing of utterance planning in task-oriented dialogue: Evidence from a novel list-completion paradigm. Frontiers in Psychology, 7:1858, December 2016. doi: 10.3389/fpsyg.2016.01858.
- Bau et al. (2020): David Bau, Jun-Yan Zhu, Hendrik Strobelt, Agata Lapedriza, Bolei Zhou, and Antonio Torralba. Understanding the role of individual units in a deep neural network. Proceedings of the National Academy of Sciences, 117(48):30071–30078, September 2020. doi: 10.1073/pnas.1907375117. URL http://dx.doi.org/10.1073/pnas.1907375117.
- Beck (2017): Amir Beck. First-Order Methods in Optimization. Society for Industrial and Applied Mathematics, Philadelphia, PA, 2017. doi: 10.1137/1.9781611974997. URL https://epubs.siam.org/doi/abs/10.1137/1.9781611974997.
- Belinkov (2021): Yonatan Belinkov. Probing classifiers: Promises, shortcomings, and advances, 2021.
- Belinkov & Glass (2019): Yonatan Belinkov and James Glass. Analysis methods in neural language processing: A survey, 2019.
- Belrose et al. (2023): Nora Belrose, Zach Furman, Logan Smith, Danny Halawi, Igor Ostrovsky, Lev McKinney, Stella Biderman, and Jacob Steinhardt. Eliciting latent predictions from transformers with the tuned lens, 2023.
- Biderman et al. (2023): Stella Biderman, Hailey Schoelkopf, Quentin Anthony, Herbie Bradley, Kyle O'Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, Aviya Skowron, Lintang Sutawika, and Oskar van der Wal. Pythia: A suite for analyzing large language models across training and scaling. In Proceedings of the 40th International Conference on Machine Learning, ICML'23. JMLR.org, 2023.
- Bisk et al. (2020): Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. PIQA: Reasoning about physical commonsense in natural language. Proceedings of the AAAI Conference on Artificial Intelligence, 34(05):7432–7439, April 2020. doi: 10.1609/aaai.v34i05.6239. URL https://ojs.aaai.org/index.php/AAAI/article/view/6239.
- Cai et al. (2024): Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads, 2024.
- Gao et al. (2020): Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, Shawn Presser, and Connor Leahy. The Pile: An 800GB dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020.
- Hernandez et al. (2024): Evan Hernandez, Arnab Sen Sharma, Tal Haklay, Kevin Meng, Martin Wattenberg, Jacob Andreas, Yonatan Belinkov, and David Bau. Linearity of relation decoding in transformer language models, 2024.
- Hewitt & Liang (2019): John Hewitt and Percy Liang. Designing and interpreting probes with control tasks, 2019.
- Huettig (2015): Falk Huettig. Four central questions about prediction in language processing. Brain Research, 1626:118–135, 2015. doi: 10.1016/j.brainres.2015.02.014. URL https://www.sciencedirect.com/science/article/pii/S0006899315001146.
- Li et al. (2023): Kenneth Li, Aspen K. Hopkins, David Bau, Fernanda Viégas, Hanspeter Pfister, and Martin Wattenberg. Emergent world representations: Exploring a sequence model trained on a synthetic task, 2023.
- Meng et al. (2023): Kevin Meng, David Bau, Alex Andonian, and Yonatan Belinkov. Locating and editing factual associations in GPT, 2023.
- Miller (1951): George A. Miller. Language and Communication. McGraw-Hill, New York, NY, US, 1951. doi: 10.1037/11135-000.
- Nanda et al. (2023): Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability, 2023.
- Nesterov (2018): Yurii Nesterov. Lectures on Convex Optimization. Springer, 2nd edition, 2018. ISBN 3319915770.
- Nguyen et al. (2016): Tri Nguyen, Mir Rosenberg, Xia Song, Jianfeng Gao, Saurabh Tiwary, Rangan Majumder, and Li Deng. MS MARCO: A human generated machine reading comprehension dataset. In Tarek Richard Besold, Antoine Bordes, Artur S. d'Avila Garcez, and Greg Wayne (eds.), Proceedings of the Workshop on Cognitive Computation: Integrating Neural and Symbolic Approaches, co-located with NIPS 2016, Barcelona, Spain, December 9, 2016, volume 1773 of CEUR Workshop Proceedings. CEUR-WS.org, 2016. URL https://ceur-ws.org/Vol-1773/CoCoNIPS_2016_paper9.pdf.
- nostalgebraist (2020): nostalgebraist. Interpreting GPT: The logit lens, 2020.
- Olah et al. (2020): Chris Olah, Nick Cammarata, Ludwig Schubert, Gabriel Goh, Michael Petrov, and Shan Carter. Zoom in: An introduction to circuits. Distill, 2020. doi: 10.23915/distill.00024.001. URL https://distill.pub/2020/circuits/zoom-in.
- Pal et al. (2023): Koyena Pal, Jiuding Sun, Andrew Yuan, Byron Wallace, and David Bau. Future Lens: Anticipating subsequent tokens from a single hidden state. In Proceedings of the 27th Conference on Computational Natural Language Learning (CoNLL). Association for Computational Linguistics, 2023. doi: 10.18653/v1/2023.conll-1.37. URL http://dx.doi.org/10.18653/v1/2023.conll-1.37.
- Paperno et al. (2016): Denis Paperno, Germán Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernández. The LAMBADA dataset: Word prediction requiring a broad discourse context. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 1525–1534, Berlin, Germany, August 2016. doi: 10.18653/v1/P16-1144. URL https://aclanthology.org/P16-1144.
- Pfau et al. (2024): Jacob Pfau, William Merrill, and Samuel R. Bowman. Let's think dot by dot: Hidden computation in transformer language models, 2024.
- Pimentel et al. (2020): Tiago Pimentel, Josef Valvoda, Rowan Hall Maudslay, Ran Zmigrod, Adina Williams, and Ryan Cotterell. Information-theoretic probing for linguistic structure, 2020.
- Radford et al. (2019): Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. 2019. URL https://api.semanticscholar.org/CorpusID:160025533.
- Rae et al. (2019): Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, Chloe Hillier, and Timothy P. Lillicrap. Compressive transformers for long-range sequence modelling. arXiv preprint, 2019. URL https://arxiv.org/abs/1911.05507.
- Shen et al. (2023): Ruoqi Shen, Sébastien Bubeck, Ronen Eldan, Yin Tat Lee, Yuanzhi Li, and Yi Zhang. Positional description matters for transformers arithmetic, 2023.
- Shen et al. (2022): Zuowei Shen, Haizhao Yang, and Shijun Zhang. Optimal approximation rate of ReLU networks in terms of width and depth. Journal de Mathématiques Pures et Appliquées, 157:101–135, 2022. doi: 10.1016/j.matpur.2021.07.009. URL https://www.sciencedirect.com/science/article/pii/S0021782421001124.
- Shi et al. (2016): Xing Shi, Inkit Padhi, and Kevin Knight. Does string-based neural MT learn source syntax? pp. 1526–1534, 2016. doi: 10.18653/v1/D16-1159.
- Stern et al. (2018): Mitchell Stern, Noam M. Shazeer, and Jakob Uszkoreit. Blockwise parallel decoding for deep autoregressive models. In Neural Information Processing Systems, 2018. URL https://api.semanticscholar.org/CorpusID:53208380.
- Vaswani et al. (2017): Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
- Welbl et al. (2017): Johannes Welbl, Nelson F. Liu, and Matt Gardner. Crowdsourcing multiple choice science questions. ArXiv, abs/1707.06209, 2017. URL https://api.semanticscholar.org/CorpusID:1553193.
- Xu et al. (2020): Yilun Xu, Shengjia Zhao, Jiaming Song, Russell Stewart, and Stefano Ermon. A theory of usable information under computational constraints, 2020.
- Zhang et al. (2018): Yuyu Zhang, Hanjun Dai, Kamil Toraman, and Le Song. $KG^2$: Learning to reason science exam questions with contextual knowledge graph embeddings. ArXiv, abs/1805.12393, 2018. URL https://api.semanticscholar.org/CorpusID:44098100.
- Zhong et al. (2023): Ziqian Zhong, Ziming Liu, Max Tegmark, and Jacob Andreas. The clock and the pizza: Two stories in mechanistic explanation of neural networks, 2023.

Appendix A Myopia bonus and malus
Notice that the myopia gap consists of two pieces: a myopia bonus, the improvement that can be obtained at the current forward pass by ignoring the future forward passes; and a myopia malus, the cost to the future forward passes that is incurred by not pre-caching for them. To be precise, in the untied case, given a choice of myopic $\tilde\theta_1, \dots, \tilde\theta_T$ satisfying constraints (1) of Definition 1, write
$$\ell(\tilde\theta_1,\dots,\tilde\theta_T) - \min_{\theta_1,\dots,\theta_T}\ell(\theta_1,\dots,\theta_T) = \sum_{t=1}^{T}\Big(\min_{\theta_{t+1},\dots,\theta_T}\ell(\tilde\theta_1,\dots,\tilde\theta_{t-1},\tilde\theta_t,\theta_{t+1},\dots,\theta_T) - \min_{\theta_t,\dots,\theta_T}\ell(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t,\theta_{t+1},\dots,\theta_T)\Big)$$

$$= \sum_{t=1}^{T}\Big(\ell_t(\tilde\theta_1,\dots,\tilde\theta_{t-1},\tilde\theta_t) - \ell_t(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t^{*t})\Big) + \sum_{t=1}^{T-1}\sum_{s=t+1}^{T}\Big(\ell_s(\tilde\theta_1,\dots,\tilde\theta_{t-1},\tilde\theta_t,\theta_{t+1}^{*t+1},\dots,\theta_s^{*t+1}) - \ell_s(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t^{*t},\theta_{t+1}^{*t},\dots,\theta_s^{*t})\Big),$$

where we define
$$\theta_t^{*t},\dots,\theta_T^{*t} := \operatorname*{arg\,min}_{\theta_t,\dots,\theta_T}\ell(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t,\theta_{t+1},\dots,\theta_T).$$
That is, the myopia gap is the sum $b + m = \sum_t b_t + \sum_t m_t$ of the myopia bonuses
$$b_t(\tilde\theta_1,\dots,\tilde\theta_t) := \ell_t(\tilde\theta_1,\dots,\tilde\theta_{t-1},\tilde\theta_t) - \ell_t(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t^{*t}) \le 0$$
(the inequality following from the myopia constraints (1)), and the myopia maluses
$$m_t(\tilde\theta_1,\dots,\tilde\theta_t) := \sum_{s=t+1}^{T}\Big(\ell_s(\tilde\theta_1,\dots,\tilde\theta_{t-1},\tilde\theta_t,\theta_{t+1}^{*t+1},\dots,\theta_s^{*t+1}) - \ell_s(\tilde\theta_1,\dots,\tilde\theta_{t-1},\theta_t^{*t},\theta_{t+1}^{*t},\dots,\theta_s^{*t})\Big) \ge 0,$$
with $b_t + m_t \ge 0$ for each $1 \le t \le T$ by the definition of the $\theta_t^{*t}$. A priori, a small myopia gap does not necessarily imply a small (in magnitude) myopia bonus $b$ and malus $m$. Indeed, when the myopia gap is small, a large value for $m$ (and thus a correspondingly large magnitude for $b$) means precisely that the transformer model is committing significant resources to pre-caching that could otherwise have been used to improve inference at the current position. On the other hand, it is possible that both $b$ and $m$ are small; that is, there is not much cost associated with pre-caching for the future, as the present forward pass already produces information (breadcrumbs) useful for that purpose.

In practice, the myopia bonus may be difficult to estimate, as it depends on the results of $O(T)$ separate optimization problems (each of which is, in practice, a full transformer model training run). Thus, we instead compute the local myopia bonus of Definition 4.
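A minimal two-position numerical example may help fix ideas. The quadratic losses below are our own toy choice, not from the experiments; for $T = 2$ the only nontrivial terms are $b_1$ and $m_1$, and the decomposition above says they must sum to the myopia gap:

```python
# Toy T=2 instance of the myopia-gap decomposition: gap = b_1 + m_1
# (b_2 = m_2 = 0 here). The quadratic losses are illustrative choices only.

def l1(t1):
    return (t1 - 1.0) ** 2

def l2(t1, t2):
    return (t2 - 2.0 * t1) ** 2 + t1 ** 2

def loss(t1, t2):
    return l1(t1) + l2(t1, t2)

# Myopic solution: each theta_t minimizes its own l_t given its predecessors.
t1_my = 1.0                 # argmin_{t1} l1(t1)
t2_my = 2.0 * t1_my         # argmin_{t2} l2(t1_my, t2)

# Non-myopic optima (closed form for these quadratics):
t1_s1, t2_s1 = 0.5, 1.0     # theta^{*1}: joint minimizer of the total loss
t2_s2 = 2.0 * t1_my         # theta_2^{*2}: best t2 given the myopic t1

gap = loss(t1_my, t2_my) - loss(t1_s1, t2_s1)
b1 = l1(t1_my) - l1(t1_s1)                    # myopia bonus  (<= 0)
m1 = l2(t1_my, t2_s2) - l2(t1_s1, t2_s1)      # myopia malus  (>= 0)

print(gap, b1, m1)  # gap == b1 + m1
```

Here the myopic model "wins" a bonus of $-1/4$ on the first position but pays a malus of $3/4$ on the second, for a total gap of $1/2$.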
A.1 Explicit gradient paths

The dependence of forward pass $f_t$ on previous forward passes $f_s$ for $s < t$ is mediated through the hidden states $h_1, \dots, h_{t-1}$:
$$f_t(x_1,\dots,x_t;\theta_1,\dots,\theta_t) = \tilde f_t(h_1,\dots,h_{t-1},x_t;\theta_t),$$
where the $h_s$ are themselves recursively defined parameterized functions $h_s = g_s(h_1,\dots,h_{s-1},x_s;\theta_s)$. (Note that we are making a choice here to consider transformers as functions of hidden states, and not of their key/value states. This has implications for how myopic descent is defined: when hidden state $h_s$ is attended to by forward pass $t > s$, we consider the key and value weights $W_K$ and $W_V$, respectively, to belong to forward pass $t$. Thus, they are updated by the gradient with respect to $\ell_t$.)
With the hidden states explicitly written out, the gradient of the loss is a sum over all possible paths to the present: for $s < t$,
$$\frac{\partial f_t}{\partial\theta_s}(x_1,\dots,x_t;\theta_1,\dots,\theta_t) = \sum_{s = s_1 < \cdots < s_k < t} \frac{\partial g_{s_1}}{\partial\theta_s}\cdot\frac{\partial \tilde f_t}{\partial h_{s_k}}\cdot\prod_{i=1}^{k-1}\frac{\partial g_{s_{i+1}}}{\partial h_{s_i}},$$
where the sum is over all increasing paths $s = s_1 < \cdots < s_k < t$.
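As a sanity check on this path decomposition, a scalar recurrence makes the off-diagonal terms explicit. The $\tanh$ recurrence below is our own toy stand-in for a transformer forward pass ($h_s = \tanh(\theta_s h_{s-1} + x_s)$, loss read off $h_t$); it accumulates $\partial h_t/\partial\theta_s$ forward along the chain of hidden states and compares the off-diagonal term to a finite difference:

```python
import math

# Toy untied 'forward passes': h_s = tanh(theta_s * h_{s-1} + x_s).
# The full derivative dh_t/dtheta_s for s < t flows through the chain of
# hidden states; myopic training would keep only the diagonal (t == s) term.

H0 = 0.5  # initial hidden state (nonzero so the first theta matters)

def forward(thetas, xs):
    hs, h = [], H0
    for th, x in zip(thetas, xs):
        h = math.tanh(th * h + x)
        hs.append(h)
    return hs

def dh_dtheta_s(thetas, xs, s):
    """Forward-mode derivative of every h_t with respect to theta_s."""
    hs = forward(thetas, xs)
    grads = [0.0] * len(thetas)
    for t in range(s, len(thetas)):
        hprev = hs[t - 1] if t > 0 else H0
        sech2 = 1.0 - hs[t] ** 2              # d tanh(u)/du at u_t
        if t == s:
            grads[t] = sech2 * hprev          # direct (diagonal) term
        else:
            grads[t] = sech2 * thetas[t] * grads[t - 1]  # path via h_{t-1}
    return grads

thetas = [0.7, -0.4, 0.9]
xs = [1.0, 0.5, -0.3]
g = dh_dtheta_s(thetas, xs, 0)

# Finite-difference check of the off-diagonal term dh_2/dtheta_0.
eps = 1e-6
fd = (forward([thetas[0] + eps] + thetas[1:], xs)[2]
      - forward([thetas[0] - eps] + thetas[1:], xs)[2]) / (2 * eps)
print(g[2], fd)
```

Because this recurrence only feeds $h_{t-1}$ into step $t$, the path sum collapses to a single chain; in a transformer, every earlier hidden state is attended to, so the sum ranges over all increasing paths.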
Appendix B The myopic attention mechanism
An important primitive that we use when implementing the myopic gradient descent of Section 3.4, the local myopic bonus of Definition 4, and the transformer bigram in Section 5.1 is an attention mechanism that uses key/value states for past forward passes differing from those it uses for the current pass, while still computing all forward passes in parallel. We call our construction the myopic attention mechanism. We use it to implement several distinct transformer training methodologies:
- When training with myopic descent, the past key/value states are the result of the key and value weights $W_K$ and $W_V$, respectively, multiplying a cloned and detached copy of the previous hidden states.
- During both training and inference of a local myopic model, the past key/value states come from a separate frozen pre-trained transformer model.
- During both training and inference of a transformer bigram model, the past key/value states are simply zeroed out.
Let $X = (h_1,\dots,h_T)^\top \in \mathbb{R}^{T\times d}$ be the sequence of residual-stream hidden states, with each row representing one position's hidden state in $\mathbb{R}^d$. Let $W_Q, W_K, W_V \in \mathbb{R}^{d\times h}$ be the query, key, and value weight matrices for one attention head, of dimensionality $h$, in the transformer $f$. Denote
$$Q := XW_Q,\qquad K := XW_K,\qquad V := XW_V,$$
and let $\tilde Q, \tilde K, \tilde V$ be the alternate states we wish to use for off-diagonal attention terms. We adopt the convention that lowercase letters with subscripts represent rows of matrices; e.g., $q_i$ is the $i$-th row of $Q$. For simplicity of presentation, we omit causal masking; the modifications that should be made in the presence of a mask are straightforward.
Recall that the vanilla attention mechanism for $f$ computes
$$O = \sigma(QK^\top)\,V,$$
where $\sigma$ is the row-wise softmax. Writing this out token-wise,
$$o_i = Z_i^{-1}\sum_{j=1}^{T}\exp(q_i^\top k_j)\,v_j,$$
where $Z_i$ is the partition function
$$Z_i := \sum_{j=1}^{T}\exp(q_i^\top k_j).$$
The myopic attention mechanism, on the other hand, is written token-wise as
$$\tilde o_i = \tilde Z_i^{-1}\Big(\exp(q_i^\top k_i)\,v_i + \sum_{j\neq i}\exp(q_i^\top \tilde k_j)\,\tilde v_j\Big) = \tilde Z_i^{-1}\sum_{j=1}^{T}\exp\big(q_i^\top \tilde k_j + \delta_{ij}\,q_i^\top(k_j - \tilde k_j)\big)\,\big(\tilde v_j + \delta_{ij}(v_j - \tilde v_j)\big) = \sum_{j=1}^{T} A_{ij}\,\tilde v_j + \sum_{j=1}^{T}\delta_{ij} A_{ij}\,(v_j - \tilde v_j),$$
where
$$\tilde Z_i := \exp(q_i^\top k_i) + \sum_{j\neq i}\exp(q_i^\top \tilde k_j) = \sum_{j=1}^{T}\exp\big(q_i^\top \tilde k_j + \delta_{ij}\,q_i^\top(k_j - \tilde k_j)\big),$$
$$A_{ij} := \tilde Z_i^{-1}\exp\big(q_i^\top \tilde k_j + \delta_{ij}\,q_i^\top(k_j - \tilde k_j)\big),$$
and $\delta_{ij}$ is the Kronecker delta. Now, notate
$$(\operatorname{diag} A)_{ij} := \delta_{ij}\,A_{ij}.$$
That is, $\operatorname{diag} A$ is the diagonal matrix that has the same entries as $A$ along the diagonal and is zero elsewhere. We are now able to write the myopic attention mechanism in matrix form:
$$\tilde O = A\tilde V + (\operatorname{diag} A)(V - \tilde V),$$
where
$$A := \sigma\big(Q\tilde K^\top + \operatorname{diag}(Q(K^\top - \tilde K^\top))\big).$$
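The matrix form can be checked directly against the token-wise definition on random inputs. The NumPy sketch below is our own minimal implementation (omitting causal masking and the usual $1/\sqrt{h}$ scaling, as above); the current position attends to its own key/value states while all other positions use the alternate states:

```python
import numpy as np

def softmax_rows(S):
    S = S - S.max(axis=-1, keepdims=True)   # numerical stability
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def myopic_attention(Q, K, V, K_alt, V_alt):
    """Matrix form: A = softmax(Q K_alt^T + diag(Q (K - K_alt)^T)),
    output = A @ V_alt + diag(A) @ (V - V_alt)."""
    S = Q @ K_alt.T
    # Diagonal correction q_i . (k_i - k_alt_i): the current pass keeps its own key.
    corr = np.einsum('ij,ij->i', Q, K - K_alt)
    idx = np.arange(len(S))
    S[idx, idx] += corr
    A = softmax_rows(S)
    return A @ V_alt + np.diag(A)[:, None] * (V - V_alt)

rng = np.random.default_rng(0)
T, d, h = 5, 8, 4
X = rng.normal(size=(T, d))
Wq, Wk, Wv = (rng.normal(size=(d, h)) for _ in range(3))
Q, K, V = X @ Wq, X @ Wk, X @ Wv
K_alt, V_alt = rng.normal(size=(T, h)), rng.normal(size=(T, h))

out = myopic_attention(Q, K, V, K_alt, V_alt)

# Token-wise reference: own k_i, v_i at position i; alternate states elsewhere.
ref = np.zeros_like(out)
for i in range(T):
    logits = np.array([Q[i] @ (K[j] if j == i else K_alt[j]) for j in range(T)])
    w = np.exp(logits - logits.max())
    w /= w.sum()
    ref[i] = sum(w[j] * (V[j] if j == i else V_alt[j]) for j in range(T))
```

In a myopic-descent implementation, `K_alt` and `V_alt` would be computed from cloned and detached copies of the hidden states (or from a frozen model, or zeroed out), matching the three training modes listed above.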
Appendix C Proofs
We use two assumptions on the loss function to be minimized. These are standard in the first-order methods literature.
Definition 6. A function $f:\mathbb{R}^n \to \mathbb{R}$ is called $L$-smooth if it is continuously differentiable with $L$-Lipschitz gradient. That is, for all $x, y$ in the domain,
$$\|\nabla f(x) - \nabla f(y)\|_2 \le L\,\|x - y\|_2.$$

Definition 7. A function $f:\mathbb{R}^n \to \mathbb{R}$ is called $\mu$-strongly convex if, for all $x, y$ in the domain,
$$f(y) \ge f(x) + \nabla f(x)^\top (y - x) + \frac{\mu}{2}\,\|y - x\|_2^2.$$

In particular, recall that strong convexity implies the existence of a unique minimum. Hence, it makes sense to write, for example, $x^* = \operatorname*{arg\,min}_x f(x)$ without ambiguity.
C.1 Gradient descent with untied weights

Theorem 8. Assume $\ell:\Theta^T \to \mathbb{R}$ is $\mu$-strongly convex and $L$-smooth for some $\mu, L > 0$. Consider ordinary gradient descent with untied weights
$$\theta_i^{(t)} = \theta_i^{(t-1)} - \eta\,\nabla_{\theta_i}\ell\big(\theta_1^{(t-1)},\dots,\theta_T^{(t-1)}\big) \qquad \forall i \in \{1,\dots,T\}.$$
Then, for $(\theta_1^*,\dots,\theta_T^*) = \operatorname*{arg\,min}_{\theta_1,\dots,\theta_T}\ell(\theta_1,\dots,\theta_T)$ and small enough $\eta > 0$,
$$\|\theta_i^{(t)} - \theta_i^*\|_2^2 \le \Big(1 - \frac{2\eta\mu L}{\mu + L}\Big)^{t}\,\|\theta_i^{(0)} - \theta_i^*\|_2^2 \qquad \forall i \in \{1,\dots,T\}.$$

Proof. This is a standard convergence result for gradient descent on strongly convex functions. For example, see Nesterov (2018). ∎
C.2 Gradient descent with tied weights

Theorem 9. Assume $\ell$ is $\mu$-strongly convex and $L$-smooth, and consider ordinary gradient descent with tied weights
$$\theta_1^{(0)} = \cdots = \theta_T^{(0)}, \qquad \theta_i^{(t+1)} = \theta_i^{(t)} - \eta\sum_{j=1}^{T}\nabla_{\theta_j}\ell\big(\theta_1^{(t)},\dots,\theta_T^{(t)}\big) \qquad \forall i \in \{1,\dots,T\}.$$
There exists $\eta > 0$ such that
$$\|\theta_i^{(t)} - \theta^*\|_2^2 \le \Big(1 - \frac{2\eta\mu T L}{\mu + L}\Big)^{t}\,\|\theta_i^{(0)} - \theta^*\|_2^2 \qquad \forall i \in \{1,\dots,T\},$$
where $\theta^* = \operatorname*{arg\,min}_\theta \ell(\theta,\dots,\theta)$.

Proof. This is again the standard convergence result, now applied to descent with step size $\eta T$ on the $\mu$-strongly convex, $L$-smooth function $\theta \mapsto \ell(\theta/T,\dots,\theta/T)$. Alternatively, one may think of this as projected gradient descent constrained to the subspace $\theta_1 = \cdots = \theta_T$ with a step size of $\eta T$. Projected gradient descent inherits the same convergence properties as unconstrained gradient descent (Beck, 2017). ∎
C.3 Myopic descent with untied weights

Theorem 10. Assume each of $\ell_1,\dots,\ell_T$ is $\mu$-strongly convex and $L$-smooth. Consider myopic gradient descent with untied weights
$$\theta_i^{(t)} = \theta_i^{(t-1)} - \eta\,\nabla_{\theta_i}\ell_i\big(\theta_1^{(t-1)},\dots,\theta_T^{(t-1)}\big) \qquad \forall i \in \{1,\dots,T\}.$$
There exists $\eta > 0$ such that $\theta_i^{(t)} \to \tilde\theta_i$ as $t \to \infty$ for all $i$, where
$$\tilde\theta_1 = \operatorname*{arg\,min}_{\theta_1}\ell_1(\theta_1), \qquad \tilde\theta_2 = \operatorname*{arg\,min}_{\theta_2}\ell_2(\tilde\theta_1,\theta_2), \qquad \dots, \qquad \tilde\theta_T = \operatorname*{arg\,min}_{\theta_T}\ell_T(\tilde\theta_1,\tilde\theta_2,\dots,\tilde\theta_{T-1},\theta_T).$$

Proof. Let $\tilde\theta_1,\dots,\tilde\theta_T$ be as in the theorem statement. We proceed by induction. For the base case, note that the myopic descent iterates for $\theta_1^{(t)}$ are independent of $\theta_i^{(t)}$ for $i > 1$. Thus, the standard convergence theorem gives that $\theta_1^{(t)} \to \tilde\theta_1$ as $t \to \infty$.

Now, assume $\theta_j^{(t)} \to \tilde\theta_j$ as $t \to \infty$ for all $j < i$. Then, for any $\varepsilon > 0$ and sufficiently large $t$,
$$\big\|(\theta_j^{(t)})_{j=1}^{i-1} - (\tilde\theta_j)_{j=1}^{i-1}\big\|_2 < \varepsilon.$$
Hence, since $\ell_i$ is $L$-smooth, for any $\theta_i$,
$$\big\|\nabla_{\theta_i}\ell_i(\theta_1^{(t)},\dots,\theta_{i-1}^{(t)},\theta_i) - \nabla_{\theta_i}\ell_i(\tilde\theta_1,\dots,\tilde\theta_{i-1},\theta_i)\big\|_2 < L\varepsilon.$$
Expanding and rearranging,
$$\big\langle \nabla_{\theta_i}\ell_i(\theta_1^{(t)},\dots,\theta_{i-1}^{(t)},\theta_i),\ \nabla_{\theta_i}\ell_i(\tilde\theta_1,\dots,\tilde\theta_{i-1},\theta_i)\big\rangle \ge \tfrac{1}{2}\big\|\nabla_{\theta_i}\ell_i(\tilde\theta_1,\dots,\tilde\theta_{i-1},\theta_i)\big\|_2^2 - \tfrac{1}{2}L^2\varepsilon^2$$
is bounded away from zero as long as, say,
$$\big\|\nabla_{\theta_i}\ell_i(\tilde\theta_1,\dots,\tilde\theta_{i-1},\theta_i)\big\|_2 \ge \mu\,\|\theta_i - \tilde\theta_i\|_2 \ge (L+1)\varepsilon,$$
using the $\mu$-strong convexity of $\ell_i$. That is, as long as $\|\theta_i - \tilde\theta_i\|_2 \ge (L+1)\varepsilon/\mu$, it is guaranteed that $-\nabla_{\theta_i}\ell_i(\theta_1^{(t)},\dots,\theta_{i-1}^{(t)},\theta_i)$ is a descent direction for $\ell_i(\tilde\theta_1,\dots,\tilde\theta_{i-1},\theta_i)$. This is a sufficient condition for $\theta_i$ to converge to an $((L+1)\varepsilon/\mu)$-neighborhood of $\tilde\theta_i$ given a small enough step size (Beck, 2017). Since $\varepsilon > 0$ is arbitrary, this completes the inductive step. ∎
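The limit point in Theorem 10 can be observed numerically on toy quadratic losses (our own choice of $\ell_1, \ell_2$, not from the paper): each parameter follows only the gradient of its own per-position loss, and the iterates settle on the sequentially defined minimizers $\tilde\theta_1, \tilde\theta_2$:

```python
# Numerical check of the Theorem 10 limit on toy quadratics:
# myopic descent on untied weights converges to
# theta~_1 = argmin l_1 and theta~_2 = argmin_{t2} l_2(theta~_1, t2).

def grad_l1(t1):                  # l_1(t1) = (t1 - 3)^2
    return 2.0 * (t1 - 3.0)

def grad2_l2(t1, t2):             # l_2(t1, t2) = (t2 - 0.5*t1)^2, partial in t2
    return 2.0 * (t2 - 0.5 * t1)

t1, t2 = 0.0, 0.0
eta = 0.1
for _ in range(2000):
    # Myopic update: theta_i uses only the gradient of its own l_i, so the
    # dependence of l_2 on t1 is deliberately ignored when updating t1.
    t1, t2 = t1 - eta * grad_l1(t1), t2 - eta * grad2_l2(t1, t2)

print(t1, t2)  # expected limit: theta~_1 = 3, theta~_2 = 1.5
```

By contrast, non-myopic descent on $\ell_1 + \ell_2$ would pull $t_1$ away from $3$, since $\ell_2$ also depends on $t_1$; the myopic iterates ignore exactly those off-diagonal terms.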
C.4 Myopic descent with tied weights

Definition 11. A function $f(x,y):\mathbb{R}^n\times\mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives is $\kappa$-forward-biased if, for all $y \in \mathbb{R}^n$ and $x = y$,
$$H_{y,x}f(x,y) + H_{y,y}f(x,y) \succ 0 \qquad\text{and}\qquad \operatorname{cond}\big(H_{y,x}f(x,y) + H_{y,y}f(x,y)\big) < \kappa,$$
where $\operatorname{cond}$ denotes the condition number, and we write the Hessian of $f$ as a block matrix:
$$Hf(x,y) = \begin{bmatrix} H_{x,x}f(x,y) & H_{x,y}f(x,y)\\ H_{y,x}f(x,y) & H_{y,y}f(x,y)\end{bmatrix} = \begin{bmatrix} \Big(\frac{\partial^2 f(x,y)}{\partial x_i\,\partial x_j}\Big)_{1\le i,j\le n} & \Big(\frac{\partial^2 f(x,y)}{\partial x_i\,\partial y_j}\Big)_{1\le i,j\le n}\\ \Big(\frac{\partial^2 f(x,y)}{\partial y_i\,\partial x_j}\Big)_{1\le i,j\le n} & \Big(\frac{\partial^2 f(x,y)}{\partial y_i\,\partial y_j}\Big)_{1\le i,j\le n}\end{bmatrix}.$$
Lemma 12. If $A \in \mathbb{R}^{n\times n}$ is positive definite with $\operatorname{cond}(A) < \kappa$, then there exists some $\eta > 0$ such that $I - \eta A$ is a $(1 - \kappa^{-1})$-contraction. Explicitly, for any $v \in \mathbb{R}^n$,
$$\|(I - \eta A)\,v\|_2 \le (1 - \kappa^{-1})\,\|v\|_2.$$

Proof. Set $\eta = 1/\lambda_{\max}(A)$. Then $I - \eta A \succeq 0$ with
$$\lambda_{\max}(I - \eta A) = 1 - \eta\,\lambda_{\min}(A) = 1 - \frac{\lambda_{\min}(A)}{\lambda_{\max}(A)} < 1 - \kappa^{-1}.$$
It immediately follows that $I - \eta A$ is a $(1 - \kappa^{-1})$-contraction. ∎
Theorem 13. Let $f(x,y):\mathbb{R}^n\times\mathbb{R}^n \to \mathbb{R}$ be $\kappa$-forward-biased, $\mu$-strongly convex, and $L$-smooth with continuous second derivatives, for some $\kappa, \mu, L > 0$. Then, there exists $\tilde y \in \mathbb{R}^n$ such that
$$\tilde y = \operatorname*{arg\,min}_{y \in \mathbb{R}^n} f(\tilde y, y). \tag{4}$$
Further, for some step size $\eta > 0$, the iterates of myopic gradient descent with tied weights
$$y^{(t+1)} = y^{(t)} - \eta\,\nabla_y f(x,y)\big|_{x = y = y^{(t)}}$$
converge to a $\tilde y \in \mathbb{R}^n$ satisfying (4). We call such a $\tilde y$ a myopic solution.

We then recover the original sequential modeling setting by defining $f(\tilde\theta, \theta) := \sum_{t=1}^{T}\ell_t(\tilde\theta,\dots,\tilde\theta,\theta)$.

Proof. First, note that if myopic descent does converge, it must be to a point $\tilde y$ such that $\tilde y = \operatorname*{arg\,min}_y f(\tilde y, y)$. Indeed, at convergence we must have $\nabla_y f(\tilde y, y)\big|_{y=\tilde y} = 0$, so strong convexity tells us that $\tilde y$ is optimal. Thus, if we establish that myopic descent with small enough step size $\eta > 0$ converges to some $\tilde y$, we automatically get the existence of a myopic solution. For this, it suffices to show that the gradient descent step
$$g_\eta(y) = y - \eta\,\nabla_y f(x,y)\big|_{x=y}$$
is strictly contractive, so that the iterates of $g_\eta$ converge to a fixed point.

Consider arbitrary $y$ and $y'$. Then, by the chain rule and the mean value theorem applied to the map $y' \mapsto \nabla_y f(x,y)\big|_{x=y=y'}$,
$$\nabla_y f(y',y') - \nabla_y f(y,y) = \big(H_{y,x}f(y'',y'') + H_{y,y}f(y'',y'')\big)\,(y' - y) \tag{5}$$
for some $y'' \in [y, y']$. Using the definition of $g_\eta$ and substituting in (5),
$$\|g_\eta(y') - g_\eta(y)\|_2 = \big\|(y' - y) - \eta\big(\nabla_y f(y',y') - \nabla_y f(y,y)\big)\big\|_2 = \big\|\big(I - \eta(H_{y,x}f(y'',y'') + H_{y,y}f(y'',y''))\big)(y' - y)\big\|_2 \le (1 - \kappa^{-1})\,\|y' - y\|_2,$$
where the last line is by $\kappa$-forward bias and Lemma 12. That is, $g_\eta$ is $(1 - \kappa^{-1})$-contractive, completing the proof. ∎
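A scalar example makes Theorem 13 concrete. The quadratic $f$ below is our own illustrative choice: $f(x,y) = x^2 - xy + y^2$ is strongly convex, and on the diagonal $H_{y,x}f + H_{y,y}f = -1 + 2 = 1 > 0$, so it is forward-biased and tied myopic descent contracts to the myopic solution:

```python
# Scalar illustration of Theorem 13 with f(x, y) = x^2 - x*y + y^2 (toy choice).
# Myopic descent: y <- y - eta * d/dy f(x, y) evaluated at x = y (tied weights).

def dfdy(x, y):
    return -x + 2.0 * y

y = 5.0
eta = 0.5
for _ in range(200):
    y = y - eta * dfdy(y, y)   # the x-slot is tied to the current iterate

# The fixed point is y~ = 0, and indeed argmin_y f(0, y) = argmin_y y^2 = 0,
# so y~ satisfies the myopic-solution condition (4).
print(y)
```

Note that the update ignores the dependence of $f$ on its first slot, exactly as myopic descent ignores gradients flowing to past forward passes; the contraction factor here is $|1 - \eta \cdot 1| = 1/2$ per step.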
C.5 Properties of sine

Lemma 14. Let $\mathsf x \sim \mathcal{N}(0,1)$. Then
$$\operatorname{Var}(\sin(B\mathsf x)) = \frac{1}{2} - \frac{e^{-2B^2}}{2}.$$

Proof.
$$\operatorname{Var}(\sin(B\mathsf x)) = (2\pi)^{-1/2}\int_{-\infty}^{\infty}\sin^2(Bx)\,e^{-x^2/2}\,dx = \frac{1}{2} - \frac{e^{-2B^2}}{2}. \qquad ∎$$
Lemma 15. Let $\mathsf x \sim \mathcal{N}(0,1)$. Then
$$\rho(\mathsf x, \sin(B\mathsf x)) = \frac{\sqrt{2}\,B\,e^{-B^2/2}}{\sqrt{1 - e^{-2B^2}}} = \Theta\big(B\,e^{-B^2/2}\big),$$
where $\rho$ is the Pearson correlation coefficient.

Proof. By symmetry, $\mathbb{E}[\sin(B\mathsf x)] = 0$. We calculate
$$\operatorname{Cov}(\mathsf x, \sin(B\mathsf x)) = (2\pi)^{-1/2}\int_{-\infty}^{\infty} x\sin(Bx)\,e^{-x^2/2}\,dx = B\,e^{-B^2/2}.$$
We already computed the variance in Lemma 14, so
$$\rho(\mathsf x, \sin(B\mathsf x)) = \frac{\operatorname{Cov}(\mathsf x, \sin(B\mathsf x))}{\sqrt{\operatorname{Var}(\mathsf x)\operatorname{Var}(\sin(B\mathsf x))}} = \frac{\sqrt{2}\,B\,e^{-B^2/2}}{\sqrt{1 - e^{-2B^2}}}.$$
In our experiments we set $B = 10$, so $\rho(\mathsf x, \sin(B\mathsf x)) < 10^{-20}$. ∎
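The closed forms in Lemmas 14 and 15 are easy to verify by direct quadrature against the standard normal density; the sketch below uses a small $B$ so the quantities are non-negligible:

```python
import math

# Numerical check of the closed forms in Lemmas 14 and 15 by Riemann-sum
# quadrature against the standard normal density (small B keeps values tame).
B = 0.5

def gauss_expect(g, lo=-10.0, hi=10.0, n=200001):
    """Approximate E[g(x)] for x ~ N(0, 1) on a uniform grid."""
    step = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * step
        w = math.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)
        total += g(x) * w * step
    return total

var_num = gauss_expect(lambda x: math.sin(B * x) ** 2)   # mean of sin is 0
var_cf = 0.5 - math.exp(-2.0 * B * B) / 2.0               # Lemma 14

cov_num = gauss_expect(lambda x: x * math.sin(B * x))
cov_cf = B * math.exp(-B * B / 2.0)                       # Lemma 15

rho_cf = (math.sqrt(2.0) * B * math.exp(-B * B / 2.0)
          / math.sqrt(1.0 - math.exp(-2.0 * B * B)))
print(var_num, var_cf, cov_num, cov_cf, rho_cf)
```

For $B = 0.5$ the correlation is close to 1 (since $\sin(Bx) \approx Bx$ there), while at $B = 10$ the closed form collapses below $10^{-20}$, which is the point of the construction: the target is essentially linearly unreadable from the raw input.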
Appendix D Details and additional experiments

D.1 Synthetic setting: the $\mathcal{D}_c$ task

We use a smaller version of the GPT-2 architecture, adapted to regression tasks: the token embedding and unembedding layers are replaced with trained linear maps from the input space to the embedding space and from the embedding space to the output space, respectively. For each $c \in \{0.01, 0.1, 0.3, 1\}$, models are trained using ordinary and myopic descent on one epoch of 30M sequences of length 64 sampled from $\mathcal{D}_{c,10,10}^{64}$. See Table 3 for architecture details.
| Configuration key | Value |
| --- | --- |
| num_layers | 2 |
| num_heads | 2 |
| embd_dim | 128 |
| n_inner | 512 |
| input_dim | 2 |
| output_dim | 1 |
| activation | relu |
| attn_pdrop | 0 |
| embd_pdrop | 0 |
| resid_pdrop | 0 |
| lr | 1e-3 |
| optimizer | AdamW |
| weight_decay | 0.01 |
| betas | (0.9, 0.999) |
| scheduler | cosine |
| warmup | 0.01 |
| batch_size | 512 |
| seq_length | 64 |
| loss_fn | HuberLoss |

Table 3: Transformer configuration used when training on the synthetic data distribution $\mathcal{D}_c$.

D.1.1 Probe details
Given a transformer model trained on $\mathcal{D}_c$, we sample the hidden states at each layer when the model is given as input 50,000 evaluation sequences from the same distribution $\mathcal{D}_c$. For each layer and each target $\sin(B\,\mathsf x_{t-\tau})$ for varying $\tau \ge 0$, we fit a linear regression model from the hidden state of that layer to the target. The in-sample $R^2$ of each linear model is then reported in Figure 2. Figure 8 visualizes the linear probe's performance on the vanilla transformer.

Figure 8: Estimate of $\sin(B\,\mathsf x_t)$ by a linear probe fit on layer 1 of the transformer with vanilla training on $\mathcal{D}_{0.3}$. Computed over 50,000 samples from $\mathcal{D}_1$.

D.2 Natural language setting

D.2.1 GPT-2
For both vanilla and myopic training, we train the GPT2-small architecture from random initialization for one epoch on 4.6M sequences from the MS MARCO dataset (Nguyen et al., 2016), truncated to length 64. To estimate the local myopia bonus of the vanilla model, we train another model from random initialization with the same architecture, but with past hidden states provided by the vanilla model. Note that this "local myopic" model attains slightly better performance than the vanilla model: each forward pass can focus purely on next-token prediction, since past hidden states are supplied by a separate model. As a baseline, we also train a "transformer bigram" model, a transformer whose key/value states are zeroed out during training and evaluation. See Table 4 for configuration details.
| Configuration key | Value |
| --- | --- |
| num_layers | 12 |
| num_heads | 12 |
| embd_dim | 768 |
| n_inner | 3072 |
| vocab_size | 50257 |
| activation | gelu_new |
| attn_pdrop | 0.1 |
| embd_pdrop | 0.1 |
| resid_pdrop | 0.1 |
| lr | 6.0 × 10⁻⁴ |
| optimizer | AdamW |
| weight_decay | 0.01 |
| betas | (0.9, 0.999) |
| scheduler | cosine |
| warmup | 0.01 |
| batch_size | 512 |
| seq_length | 64 |
| loss_fn | CrossEntropy |

Table 4: GPT-2 small configuration used when training on natural language data.

D.2.2 Pythia
For experiments on the Pythia suite, we finetune with either vanilla or myopic descent on 10M sequences of 64 tokens each subsampled from The Pile (Gao et al., 2020). Learning rates and batch sizes for each model are presented in Table 5; they are the same between vanilla and myopic descent. All other training and architectural hyperparameters are the same as those used by Biderman et al. (2023).
| Model | Learning rate | Batch size |
| --- | --- | --- |
| Pythia 14M | 4.0 × 10⁻⁴ | 512 |
| Pythia 31M | 4.0 × 10⁻⁴ | 512 |
| Pythia 70M | 1.2 × 10⁻⁴ | 512 |
| Pythia 160M | 1.2 × 10⁻⁴ | 512 |
| Pythia 410M | 1.2 × 10⁻⁴ | 256 |
| Pythia 1B | 1.2 × 10⁻⁴ | 128 |
| Pythia 1.4B | 1.2 × 10⁻⁴ | 128 |
| Pythia 2.8B | 8.0 × 10⁻⁵ | 64 |

Table 5: Pythia suite hyperparameters for finetuning. Batch size is measured in sequences of 64 tokens each.

Figure 9: Cross-entropy loss of Pythia models fine-tuned on The Pile dataset using vanilla and myopic gradient descent. Starting from the 70M model, we see that the gap increases with parameter count.

D.3 Multiplication
In addition to the natural language experiments, we also measure the myopia gap on the task of natural number multiplication. We use the same GPT2-small architecture as in the natural language experiments starting from random initialization; see Table 4. See Figure 10 for an example input sequence.
3 7 0 0 * 5 4 0 0 = 5 8 2 3 0 0 0 0
Figure 10: Example multiplication input sequence.
We use several formatting tricks found by Shen et al. (2023) to improve performance on the multiplication task:
- Characters are delimited by spaces, such that each digit is tokenized into a separate token.
- All numbers are written in the reverse of the standard order, i.e. such that the least significant digits come first.
- All inputs are zero-padded to the same length.
Note that for each multiplicand we first sample the number of digits $n \sim \mathrm{Unif}(N)$ uniformly from some range $N$, then uniformly sample a natural number $x \sim \mathrm{Unif}(\{0,\dots,10^n - 1\})$ with no more than $n$ digits. This distribution allocates increased weight to smaller numbers, and was found to result in superior performance.
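The three formatting tricks and the digit-count sampling scheme can be sketched as follows (a minimal illustration of the described format; the helper names are our own, and the example reproduces the sequence in Figure 10):

```python
import random

# Formatting tricks for the multiplication task: space-delimited digits,
# reversed (least-significant-first) digit order, zero-padding to a fixed
# length. A sketch of the described format, not the exact training pipeline.

def fmt_number(x, pad):
    digits = str(x).zfill(pad)[::-1]   # zero-pad, then least significant first
    return ' '.join(digits)

def fmt_example(a, b, pad):
    # The product is padded to twice the multiplicand length.
    return f"{fmt_number(a, pad)} * {fmt_number(b, pad)} = {fmt_number(a * b, 2 * pad)}"

def sample_multiplicand(max_digits, rng):
    # First draw a digit count n, then a number below 10^n: this up-weights
    # small numbers relative to sampling uniformly from [0, 10^max_digits).
    n = rng.randint(1, max_digits)
    return rng.randrange(10 ** n)

print(fmt_example(73, 45, pad=4))  # 3 7 0 0 * 5 4 0 0 = 5 8 2 3 0 0 0 0
```

Reading the example back: "3 7 0 0" is 73 zero-padded to four digits and reversed, "5 4 0 0" is 45, and "5 8 2 3 0 0 0 0" is their product 3285 padded to eight digits.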
We train both vanilla and myopic transformers on one epoch of 10M examples with no more than 8 digits, then measure 0/1 accuracy (that is, the model is provided with an input sample up to the "=" token, and scored 1 if it completes the rest of the sequence exactly correctly and 0 otherwise) on 1024 independent random validation examples for each of the 8 × 8 possible pairs of digit counts for the two multiplicands. See Figure 11.
Figure 11: Accuracy of vanilla and myopic transformers trained on multiplication of up-to-8-digit inputs (panels: vanilla multiplication accuracy, myopic multiplication accuracy, vanilla-myopic accuracy gap). Rows and columns correspond to the number of digits in the first and second multiplicands, respectively.

D.3.1 Filler tokens
We further hypothesize that, as in Pfau et al. (2024), the vanilla transformer may learn to perform computation even on forward passes corresponding to filler tokens, thus attaining better performance when trained on examples zero-padded to longer lengths. We expect that myopic transformers, on the other hand, are not incentivized to do this, since this extra computation holds no relevance towards the immediate task of predicting the filler zero token. To test this hypothesis, we train vanilla and myopic transformers on each of two different multiplication datasets:
1. Both multiplicands have at most 5 digits, and are zero-padded to exactly 5 digits.
2. Both multiplicands have at most 5 digits, and are zero-padded to exactly 10 digits.
Again, all training runs consist of one epoch of 10M examples.
See Figure 12 for results. Note that the vanilla transformer indeed performs better when trained and evaluated on input sequences zero-padded to a longer length. The myopic transformer, however, performs substantially worse with increased zero-padding. Our explanation is that not only does the myopic transformer fail to learn intermediate computation during filler-token forward passes, but the increased input length also makes it more difficult for the attention mechanism to correctly attend to the relevant tokens.

D.4 Gradient angles
Using the publicly available training checkpoints for Pythia-410M (Biderman et al., 2023), we measure the sizes of both the myopic component and the future component of the gradient of the loss with respect to the parameters over the course of training. (The future component is the difference between the total vanilla gradient and the myopic gradient.) We also measure the cosine similarity between the myopic and future components. See Figure 13. One observation is that the norm of the myopic gradient is consistently larger than that of the future gradient; thus, training is dominated by the effect of each forward pass's parameters on the immediate next-token prediction.
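A minimal sketch of how these quantities can be computed from a total and a myopic gradient (a plain-Python helper of our own; in practice the inputs would be flattened model-parameter vectors):

```python
import math

# The future component is (total gradient) - (myopic gradient); this helper
# returns the three norms and the myopic/future cosine similarity.

def gradient_stats(g_total, g_myopic):
    g_future = [t - m for t, m in zip(g_total, g_myopic)]
    norm = lambda v: math.sqrt(sum(x * x for x in v))
    dot = sum(m * f for m, f in zip(g_myopic, g_future))
    cos = dot / (norm(g_myopic) * norm(g_future))
    return norm(g_total), norm(g_myopic), norm(g_future), cos

stats = gradient_stats([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
```

A dominant myopic norm, as observed in Figure 13, means the total update direction is mostly determined by the next-token term, with the future component acting as a comparatively small correction.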
Figure 12: Multiplication accuracy of GPT-2 with either vanilla or myopic training, and with input multiplicands zero-padded to either length 5 or 10 (panels: vanilla and myopic accuracy at each padding length). Rows and columns correspond to the number of digits in the first and second multiplicands, respectively. Observe that padding improves performance of the vanilla model, but decreases performance of the myopic model.
Figure 13: Myopic and future gradients of Pythia-410M during training (panels: norms of the myopic, future, and total gradients; cosine similarity between the myopic and future gradients).