Title: Short-Cutting Transformers with Linear Transformations
URL Source: https://arxiv.org/html/2303.09435
Published Time: Fri, 21 Jun 2024 00:07:57 GMT
Jump to Conclusions: Short-Cutting Transformers with Linear Transformations
Abstract
Transformer-based language models create hidden representations of their inputs at every layer, but only use final-layer representations for prediction. This obscures the internal decision-making process of the model and the utility of its intermediate representations. One way to elucidate this is to cast the hidden representations as final representations, bypassing the transformer computation in-between. In this work, we suggest a simple method for such casting, using linear transformations. This approximation far exceeds the prevailing practice of inspecting hidden representations from all layers, in the space of the final layer. Moreover, in the context of language modeling, our method produces more accurate predictions from hidden layers, across various model scales, architectures, and data distributions. This allows “peeking” into intermediate representations, showing that GPT-2 and BERT often predict the final output already in early layers. We then demonstrate the practicality of our method to recent early exit strategies, showing that when aiming, for example, at retention of 95% accuracy, our approach saves additional 7.9% layers for GPT-2 and 5.4% layers for BERT. Last, we extend our method to linearly approximate sub-modules, finding that attention is most tolerant to this change. Our code and learned mappings are publicly available at https://github.com/sashayd/mat.
Keywords: interpretability, language models, efficiency, logit lens, linear lens, linear, early exit, shortcut, layer jump
Alexander Yom Din¹, Taelin Karidi¹, Leshem Choshen¹, Mor Geva²
¹Hebrew University of Jerusalem  ²Tel Aviv University
{alexander.yomdin, taelin.karidi, leshem.choshen}@mail.huji.ac.il, morgeva@tauex.tau.ac.il
1.Introduction
Transformer-based language models (LMs) process an input sequence of tokens by first representing it as a sequence of vectors and then repeatedly transforming it through a fixed number of attention and feed-forward network (FFN) layers Vaswani et al. (2017). While each transformation creates new representations, only the final representations are used to obtain model predictions. Correspondingly, LM loss minimization directly optimizes the final representations, while hidden representations are only optimized implicitly, thus making their interpretation and usefulness more obscure.
Figure 1: An illustration of our approach to enhance interpretability and utilization of hidden representations. We use linear mappings $A = A_{\ell',\ell}$ to short-cut transformer inference in-between layers ($\texttt{mat}_{\ell\rightarrow\ell'}$), instead of the prevalent baseline of propagating the hidden representation as-is to the further layer ($\texttt{id}_{\ell\rightarrow\ell'}$).
However, utilizing hidden representations is highly desirable; a successful interpretation of them can shed light on the “decision-making process” in the course of transformer inference Tenney et al. (2019); Voita et al. (2019); Slobodkin et al. (2021); Geva et al. (2022b), and obtaining predictions from them can substantially reduce computational cost Schwartz et al. (2020); Xu et al. (2021).
Previous attempts to harness hidden representations viewed the hidden representations of an input token as a sequence of approximations of its final representation Elhage et al. (2021); Geva et al. (2022b). This view is motivated by the additive updates induced via the residual connections He et al. (2016) around each layer in the network. Indeed, previous works Geva et al. (2021, 2022a); Ram et al. (2022); Alammar (2021) followed a simplifying assumption that representations at any layer can be transformed into a distribution over the output vocabulary via projection to the output embeddings. While this approach has proven to be surprisingly effective for interpretability Geva et al. (2022a); Dar et al. (2022) and computation efficiency Schuster et al. (2022); Xin et al. (2020); Schwartz et al. (2020), it oversimplifies the model’s computation and assumes that all the layers in the network operate in the same space.
A natural question that arises is whether there is a more accurate way to cast hidden representations into final-representation substitutes than interpreting them as they are. In this work, we tackle this question by learning linear transformations across layers in the network (illustrated in Fig. 1). For any two layers $\ell < \ell'$, we fit a linear regression to transform hidden representations from layer $\ell$ to layer $\ell'$. We show that this method, denoted as mat, produces substantially more accurate approximations than the above-discussed identity mapping, dubbed id, applied in previous works (§3). As mat is a non-contextual mapping that operates on single hidden representations, this suggests that there is more linearity to transformer inference than could be estimated by the id mapping.
Next, we test if these gains in approximating future representations also translate to better prediction estimations (§4). To this end, we measure how often language modeling predictions from final-representation substitutes produced by mat, and by alternations between mat and regular inference, agree with those of actual final representations. Through experiments with two data sources and various scales of GPT-2 Radford et al. (2019) and BERT Devlin et al. (2019), we observe large accuracy gains (15%-40% at most layers) in prediction estimation by mat over naive projections (id). Moreover, we show that our mappings generalize well across different data distributions (§5).
We leverage these findings for enhancing model efficiency and demonstrate our method’s utility in the setting of early exiting – a strategy for dynamically deciding at which layer to stop the inference pass and use that layer’s representation for prediction. While previous works have utilized these hidden representations intact (i.e. using id), we transform them using mat, showing that our method performs better than the baseline in this setting as well (§6), allowing for the saving of an additional 7.9% (resp. 5.4%) of the layers for GPT-2 (resp. BERT) when aiming at 95% accuracy.
Last, we analyze how well the different sub-modules of transformer computation – attention, FFN, and layer normalization – can be estimated linearly (§7), by applying the same methodology of linear mappings. We find that linearly approximating attention, the only sub-module that has contextual processing, results in the least reduction of precision. This hints at an interesting possibility of compute time reduction, as non-contextual inference is parallelizable.
To conclude, we propose a method for casting hidden representations across transformer layers, that is light to train, cheap to infer, and provides more accurate and robust representation approximations than the commonly-used baseline of identical propagation. Beyond interpretability, our method holds potential for enhancing efficiency.
2.Background and Notation
The input to a transformer-based LM Vaswani et al. (2017) is a sequence of tokens $t_1, \ldots, t_n$ from a vocabulary $\mathcal{V}$ of size $|\mathcal{V}| = d_v$. The tokens are first represented as vectors using an embedding matrix $E \in \mathbb{R}^{d_h \times d_v}$, where $d_h$ is the hidden dimension of the model, to create the initial hidden representations

$$H^0 = (h_1^0, \ldots, h_n^0) \in \mathbb{R}^{d_h \times n}.$$
These representations are then repeatedly transformed through $L$ transformer blocks, where each block outputs hidden representations that are the inputs to the next block:

$$\forall \ell \in [1, L]: \quad \texttt{b}^{\ell}(H^{\ell-1}) = H^{\ell},$$
where
$$H^{\ell} = (h_1^{\ell}, \ldots, h_n^{\ell}) \in \mathbb{R}^{d_h \times n}.$$
The $\ell$-th transformer block is constructed as a composition of two layers:

$$\texttt{b}^{\ell} = \texttt{b}^{\texttt{ffn}}_{\ell} \circ \texttt{b}^{\texttt{attn}}_{\ell},$$
where $\texttt{b}^{\texttt{attn}}_{\ell}$ (resp. $\texttt{b}^{\texttt{ffn}}_{\ell}$) is a multi-head self-attention (MHSA) layer (resp. FFN layer) enclosed by a residual connection, and potentially interjected with layer normalization Ba et al. (2016). The final representations,

$$H^L = (h_1^L, \ldots, h_n^L),$$
are the transformer stack’s output, used to form various predictions. In this work, we investigate whether and how hidden representations from earlier layers can be utilized for this purpose instead.
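The block structure above can be sketched in NumPy. This is a minimal single-head illustration only: the pre-layer-norm placement, the ReLU FFN, and the weight names (`Wq`, `Wk`, `Wv`, `Wo`, `W1`, `W2`) are our simplifying assumptions, not the exact GPT-2 or BERT architecture.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(H, eps=1e-5):
    """Normalize each column (token) over the hidden dimension."""
    mu = H.mean(axis=0, keepdims=True)
    var = H.var(axis=0, keepdims=True)
    return (H - mu) / np.sqrt(var + eps)

def block(H, params):
    """One transformer block b = b_ffn o b_attn acting on H in R^{d_h x n}.

    Single-head attention and unparameterized layer norm, for brevity.
    """
    # MHSA sub-layer with residual connection
    Hn = layer_norm(H)
    Q, K, V = params["Wq"] @ Hn, params["Wk"] @ Hn, params["Wv"] @ Hn
    att = softmax((Q.T @ K) / np.sqrt(Q.shape[0]), axis=1)  # n x n token mixing
    H = H + params["Wo"] @ (V @ att.T)
    # FFN sub-layer with residual connection
    Hn = layer_norm(H)
    H = H + params["W2"] @ np.maximum(params["W1"] @ Hn, 0.0)  # ReLU FFN
    return H

# Tiny demo with random weights (square FFN weights for simplicity)
rng = np.random.default_rng(0)
d_h, n = 8, 5
params = {k: 0.1 * rng.normal(size=(d_h, d_h))
          for k in ("Wq", "Wk", "Wv", "Wo", "W1", "W2")}
H1 = block(rng.normal(size=(d_h, n)), params)
```

Note how only the attention sub-layer mixes information across token positions; the FFN and layer norm act on each column independently, a point §7 returns to.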
3.Linear Shortcut Across Blocks
To use a hidden representation from layer $\ell < L$ as a final representation, we propose to cast it using linear regression, while skipping the computation in-between these layers. More generally, this approach can be applied to cast any $\ell$-th hidden representation to any subsequent layer $\ell' > \ell$.
3.1.Method
Given a source layer $\ell$ and a target layer $\ell'$ such that $0 \leq \ell < \ell' \leq L$, our goal is to learn a mapping from hidden representations at layer $\ell$ to those at layer $\ell'$. To this end, we first collect a set of corresponding hidden representation pairs $(h^{\ell}, h^{\ell'})$. Concretely, we run a set $\mathcal{T}$ of input sequences through the model, and for each input $s$, we extract the hidden representations $h_{i_s}^{\ell}, h_{i_s}^{\ell'}$, where $i_s$ is a random position in $s$.
Next, we learn a matrix $A_{\ell',\ell} \in \mathbb{R}^{d_h \times d_h}$ by fitting linear regression over $\mathcal{T}$, i.e., $A_{\ell',\ell}$ is a numerical minimizer of:

$$A \mapsto \sum_{s \in \mathcal{T}} \| A \cdot h_{i_s}^{\ell} - h_{i_s}^{\ell'} \|^2,$$

and define the mapping of a representation $h$ from layer $\ell$ to layer $\ell'$ as:

$$\texttt{mat}_{\ell\rightarrow\ell'}(h) \coloneqq A_{\ell',\ell} \cdot h. \tag{1}$$
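Fitting $A_{\ell',\ell}$ amounts to a multi-output least-squares problem. A minimal sketch, assuming the paired hidden states are stacked as rows; the function name `fit_mat` and the use of `np.linalg.lstsq` are our illustration (the paper specifies only the objective above, with no intercept term):

```python
import numpy as np

def fit_mat(H_src, H_tgt):
    """Fit A minimizing sum_s ||A . h_s^l - h_s^l'||^2.

    H_src, H_tgt: (num_examples, d_h) arrays whose rows are the paired
    hidden representations from layers l and l'. Returns A (d_h x d_h)
    such that mat(h) = A @ h. No intercept term, matching Eq. (1).
    """
    # lstsq solves min ||H_src @ W - H_tgt||_F; A is its transpose.
    W, *_ = np.linalg.lstsq(H_src, H_tgt, rcond=None)
    return W.T

# Sanity check on synthetic data: recover a known linear map.
rng = np.random.default_rng(0)
A_true = rng.normal(size=(8, 8))
X = rng.normal(size=(1000, 8))
A_hat = fit_mat(X, X @ A_true.T)
```

Once fitted, applying the mapping at inference time is a single matrix-vector product per token, which is why the method is cheap to deploy.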
3.2.Baseline
We evaluate the prevalent approach of “reading” hidden representations directly, without any transformation. Namely, the propagation of a hidden representation from layer $\ell$ to layer $\ell'$ is given by the identity function, dubbed id:

$$\texttt{id}_{\ell\rightarrow\ell'}(h) \coloneqq h.$$
This baseline assumes that representations at different layers operate in the same linear space.
3.3.Quality of Fit
We first evaluate our method by measuring how well the learned linear mappings approximate the representations at the target layer. To this end, we calculate the (coordinate-averaged) $r^2$-score of our mapping’s outputs with respect to the representations obtained from a full inference pass, and compare to the same for the id baseline.
Models.
We use GPT-2 Radford et al. (2019), a decoder-only auto-regressive LM, with $L = 48$, $d_h = 1600$, and BERT Devlin et al. (2019), an encoder-only model trained with masked language modeling, with $L = 24$, $d_h = 1024$.
Data.
We sample random sentences from Wikipedia, collecting 9,000 (resp. 3,000) sentences for the training set $\mathcal{T}$ (resp. validation set $\mathcal{V}$).¹ For each example $s$, we select a random position $i_s$ and extract the hidden representations $h_{i_s}^{\ell}$ at that position from all the layers. For BERT, we first replace the input token at position $i_s$ with a [MASK] token, as our motivation is interpreting predictions, which are obtained via masked tokens in BERT (see §4.2). Thus, in this case, the hidden representations we consider are of [MASK] tokens only.

¹ We use sentences rather than full documents to simplify the analysis.
Figure 2: The coordinate-averaged $r^2$-score of $\texttt{mat}_{\ell\rightarrow\ell'}$ (left) and $\texttt{id}_{\ell\rightarrow\ell'}$ (right) (GPT-2).
Figure 3: The coordinate-averaged $r^2$-score of $\texttt{mat}_{\ell\rightarrow\ell'}$ (left) and $\texttt{id}_{\ell\rightarrow\ell'}$ (right) (BERT).
Evaluation.
For every pair of layers $\ell, \ell'$, such that $0 \leq \ell < \ell' \leq L$, we use the training set $\mathcal{T}$ to fit linear regression as described in §3.1, and obtain a mapping $\texttt{mat}_{\ell\rightarrow\ell'}$. Next, we evaluate the quality of $\texttt{mat}_{\ell\rightarrow\ell'}$ as well as of $\texttt{id}_{\ell\rightarrow\ell'}$ using the $r^2$-coefficient, uniformly averaged over all coordinates.
Concretely, we compute the $r^2$-coefficient of each of the predicted representations $\texttt{mat}_{\ell\rightarrow\ell'}(h_{i_s}^{\ell})$ and $\texttt{id}_{\ell\rightarrow\ell'}(h_{i_s}^{\ell})$ versus the true representations $h_{i_s}^{\ell'}$, over all $s \in \mathcal{V}$.
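A sketch of the coordinate-averaged $r^2$-score, under the assumption (ours) that it means: one $r^2$-coefficient per hidden coordinate across validation examples, averaged uniformly over the $d_h$ coordinates (equivalent to scikit-learn's `r2_score` with `multioutput="uniform_average"`):

```python
import numpy as np

def r2_avg(pred, true):
    """Coordinate-averaged r^2: compute r^2 per coordinate (column)
    over the validation examples (rows), then average uniformly."""
    ss_res = ((true - pred) ** 2).sum(axis=0)
    ss_tot = ((true - true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))

rng = np.random.default_rng(1)
true = rng.normal(size=(500, 16))
perfect = r2_avg(true, true)        # exact predictions score 1.0
shifted = r2_avg(true + 1.0, true)  # a constant offset hurts the score
```

Unlike cosine similarity, this score penalizes any systematic offset or scaling mismatch per coordinate, which makes it a stricter test of the approximation.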
Results.
Results for GPT-2 and BERT are presented in Figs. 2 and 3, respectively. In both models, mat consistently yields better approximations than id, as it obtains higher $r^2$-scores (in blue) across the network. This gap between mat and id is especially evident in BERT, where id completely fails to map the representations between most layers, suggesting that hidden representations are modified substantially by every transformer block. Overall, this highlights the shortcoming of existing practices to inspect representations in the same linear space, and the gains from using our method to approximate future layers.
4.Linear Shortcut for Language Modeling
We saw that our method approximates future hidden representations substantially better than a naive propagation. In this section, we show that this improvement also translates to better predictive abilities from earlier layers. Concretely, we use our method to estimate the final prediction from intermediate representations, in the context of two fundamental LM tasks: next token prediction and masked token prediction.
Evaluation Metrics.
Let $h, h' \in \mathbb{R}^{d_h}$ be a final representation and its substitute obtained by some mapping, and denote by $\delta(h), \delta(h') \in \mathbb{R}^{d_v}$ their corresponding output probability distributions (see details below). We measure the prediction quality of $h'$ with respect to $h$ using two metrics:

- Precision@$k$ (↑ is better): This checks whether the token with the highest probability according to $\delta(h')$ appears in the top-$k$ tokens according to $\delta(h)$. Namely, we sort $\delta(h)$ and assign a score of 1 if $\arg\max(\delta(h'))$ appears in the top-$k$ tokens by $\delta(h)$, and 0 otherwise.
- Surprisal (↓ is better): We measure the negative log-likelihood, according to $\delta(h)$, of the highest-probability token according to $\delta(h')$. Intuitively, low values mean that the model sees the substitute result as probable and hence not surprising.

We report the average Precision@$k$ and Surprisal over the validation set $\mathcal{V}$.
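The two metrics above, for a single pair of distributions, can be sketched as follows (function names are ours):

```python
import numpy as np

def precision_at_k(p_sub, p_final, k):
    """1 if the substitute's top-1 token is among the top-k tokens of
    the final distribution, else 0."""
    topk = np.argsort(p_final)[::-1][:k]
    return int(np.argmax(p_sub) in topk)

def surprisal(p_sub, p_final):
    """Negative log-likelihood, under the final distribution, of the
    substitute's highest-probability token."""
    return float(-np.log(p_final[np.argmax(p_sub)]))

# Toy example: the substitute's top-1 token (index 1) is the final
# distribution's second-ranked token.
p_final = np.array([0.7, 0.2, 0.05, 0.05])
p_sub = np.array([0.1, 0.6, 0.2, 0.1])
```

In this toy case Precision@1 is 0 but Precision@2 is 1, and the Surprisal is $-\log 0.2$: both metrics compare the substitute's single top prediction against the final distribution, rather than comparing the full distributions.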
4.1.Next Token Prediction
Auto-regressive LMs output, for every position, a probability distribution over the vocabulary for the next token. Specifically, the output distribution for every position $i$ is given by $\delta(h_i^L)$, where

$$\delta(h) = \texttt{softmax}(E^{\top} \cdot h) \in \mathbb{R}^{d_v}. \tag{2}$$
For some LMs, including GPT-2, a layer normalization ln_f is applied to the final-layer representation before this conversion (i.e., computing $\delta(\texttt{ln\_f}(h))$ rather than $\delta(h)$).
Recall that our goal is to measure how well this distribution can be estimated from intermediate representations, i.e., estimating $\delta(h_i^L)$ from $h_i^{\ell}$ where $\ell < L$. Thus, we first run the validation set examples through the model, while extracting for each example $s$ and every layer the hidden representation at a random position $i_s$. Next, we apply our mappings $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$ to cast the hidden representations of every layer $\ell$ to final-layer substitutes (see §3). Last, we convert every final-layer substitute to an output distribution (Eq. 2) and compute for each layer the average Precision@$k$ (for $k = 1, 5, 10$) and Surprisal scores with respect to the final output distribution, over the validation set.
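The cast-then-readout pipeline can be sketched as below; this is a minimal illustration under our assumptions (random weights in place of a trained model, and GPT-2's final layer norm ln_f omitted for brevity):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def delta(h, E):
    """Eq. (2): delta(h) = softmax(E^T . h), with E in R^{d_h x d_v}."""
    return softmax(E.T @ h)

def predict_from_layer(h_ell, A, E):
    """Cast a layer-l hidden state to a final-layer substitute with
    mat (h -> A . h), then read out an output distribution. With
    A = I this reduces to the id baseline."""
    return delta(A @ h_ell, E)

rng = np.random.default_rng(2)
d_h, d_v = 16, 50
E = rng.normal(size=(d_h, d_v))       # stand-in embedding matrix
p = predict_from_layer(rng.normal(size=d_h), np.eye(d_h), E)
```

In practice $A$ would be the fitted $A_{L,\ell}$ from §3.1 and $E$ the model's output embedding matrix; the point of the sketch is that the readout from any layer is the same cheap matrix product pipeline.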
Figure 4: Precision@$k$ ($k = 1, 5, 10$) and Surprisal for $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$ (GPT-2 next token prediction task). 95% confidence intervals are shown for Surprisal.
Results.
Fig. 4 shows the average Precision@$k$ and Surprisal scores per layer in GPT-2. Across all layers, mat outperforms id in terms of both scores, often by a large margin (e.g., up to layer 44, the Precision@1 achieved by mat exceeds that of id by more than 20%). This shows that linear mappings enable not just better estimation of final-layer representations, but also of the predictions they induce. Moreover, the relatively high Precision@$k$ scores of mat in early layers (62%-82% for $k = 10$, 52%-74% for $k = 5$, and 28%-45% for $k = 1$) suggest that early representations often accurately approximate the final prediction. Also, the substantially lower Surprisal scores of mat compared to id imply that our method allows for a more representative reading into the layer-wise prediction formation of the model than allowed via direct projection to the vocabulary.
4.2.Masked Token Prediction
We conduct the same experiment as in §4.1 for masked language modeling, where the model predicts a probability distribution for a masked token in the input. Unlike next token prediction, where the output distribution is computed from representations of varying input tokens, in masked token prediction the output is always obtained from representations of the same input token (i.e. [MASK]).
For this experiment, we use BERT, on top of which we use a pretrained masked language model head $\delta$; given a token sequence $s$, a [MASK] token inside it, and its final representation $h$, $\delta(h) \in \mathbb{R}^{d_v}$ is a probability distribution over tokens giving the model’s assessment of the likelihood of each token to fit in place of the [MASK] token in $s$.
Figure 5: Precision@$k$ ($k = 1, 5, 10$) and Surprisal for $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$ (BERT masked token prediction task). 95% confidence intervals are shown for Surprisal.
Results.
Fig. 5 shows the average Precision@$k$ and Surprisal scores per layer in BERT, overall showing trends similar to those observed for next token prediction in GPT-2 (§4.1). This is despite the differences between the two tasks and the considerable architectural differences between the models. Notably, the superiority of mat over id in this setting is even more prominent; while, in the first ten layers, mat’s precision is between 8%-52% (Fig. 5), id’s precision for all values of $k$ is close to zero, again strongly indicating that our method allows for better reading into early-layer hidden representations. More generally, mat improves the Precision@1 of id by more than 17% at most layers, and unveils that a substantial amount of predictions (>25% starting from layer 3) appear already in the very first layers. Interestingly, the (rough) divide between the first and last halves of layers for id in Fig. 5 seems to align with the two-hump shape of the blue region for mat in Fig. 3.
[Table 1: top-5 predictions of $\texttt{id}_\ell$ and $\texttt{mat}_\ell$ at layers 4, 12 and 24 for the inputs "aldridge had shoulder surgery in [MASK]." and "on your next view you will be asked to [MASK] continue reading."]
Table 1: Examples of top-5 BERT masked token predictions at layers 4, 12 and 24, under the mappings $\texttt{mat}_{\ell\rightarrow L}$ (abbreviated $\texttt{mat}_\ell$) and $\texttt{id}_{\ell\rightarrow L}$ (abbreviated $\texttt{id}_\ell$). Plausible predictions (according to a human annotator) are highlighted. Note that for $\ell = L = 24$, the predictions of $\texttt{mat}_\ell$ and $\texttt{id}_\ell$ are the same.
Analysis.
We manually compare the predictions of our mapping $\texttt{mat}_{\ell\rightarrow L}$ with those of $\texttt{id}_{\ell\rightarrow L}$ for the BERT model. Concretely, we select 50 random sentences from the Leipzig dataset (see §5.2). Next, for each layer $\ell$, we manually analyze how many of the top-5 tokens according to $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$ fit into context. We consider a token to fit into context if it is grammatically plausible within the sentence (see Tab. 1 for examples). In the resulting 1,250 instances (i.e. 50 sentences × 25 representations), we observe a substantially higher plausibility rate of 85.4% for mat compared to 52.8% for id. In fact, in fewer than 4.3% of the instances are there more plausible tokens among the top-5 according to id than according to mat, further supporting the Surprisal results above.
4.3. Alternation Schemes
Thus far, we considered direct mappings to the last layer. We now check if mappings between intermediate layers can improve prediction estimations further. To this end, we obtain final-representation substitutes by alternating between transformer inference and linear mappings. For $\ell < \ell'$, let us abbreviate
$$\texttt{b}_{\ell\rightarrow\ell'} := \texttt{b}^{\ell'} \circ \cdots \circ \texttt{b}^{\ell+2} \circ \texttt{b}^{\ell+1},$$
i.e. $\texttt{b}_{\ell\rightarrow\ell'}$ is the application of transformer inference from layer $\ell$ to layer $\ell'$. For a sequence $0 = \ell_0 < \ell_1 < \ldots < \ell_n = L$, we consider either of
$$\cdots \circ \texttt{b}_{\ell_2\rightarrow\ell_3} \circ \texttt{mat}_{\ell_1\rightarrow\ell_2} \circ \texttt{b}_{\ell_0\rightarrow\ell_1}, \qquad (3)$$
$$\cdots \circ \texttt{mat}_{\ell_2\rightarrow\ell_3} \circ \texttt{b}_{\ell_1\rightarrow\ell_2} \circ \texttt{mat}_{\ell_0\rightarrow\ell_1}. \qquad (4)$$
In other words, those are inference modes alternating between transformer inference and application of our linear mappings in a prescribed manner. We then collect two sets of alternation schemes:
- $r^2$-informed (R2): We define the $r^2$-score of Eq. 3 (resp. Eq. 4) to be the product of the $r^2$-scores of $\texttt{mat}_{\ell_i\rightarrow\ell_{i+1}}$ (computed in §3.3), for $i = 1, 3, \ldots$ (resp. $i = 0, 2, \ldots$). For each $\ell$, we consider the scheme that employs $\ell$ transformer blocks with the maximal $r^2$-score.
- Weighted round-robin (WRR): For $a, b \geq 1$ such that $a + b$ divides $L$, we consider $(\ell_0, \ell_1, \ldots)$ given by $(0, a, a+b, 2a+b, 2a+2b, \ldots)$ and the two corresponding schemes of Eqs. 3 and 4. In other words, we alternate between performing $a$ transformer blocks and applying our linear mapping across $b$ layers, for some fixed values of $a$ and $b$.
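A minimal sketch of a WRR breakpoint schedule and of running an alternation scheme, where `block_fn` and `mat_fn` are hypothetical stand-ins for transformer inference between two layers and for a fitted linear mapping:

```python
def wrr_breakpoints(a, b, L):
    """Breakpoints (l0, l1, ...) = (0, a, a+b, 2a+b, 2a+2b, ...) up to L,
    alternating runs of a transformer blocks with linear jumps over b layers."""
    pts, cur, step_is_a = [0], 0, True
    while cur < L:
        cur += a if step_is_a else b
        pts.append(cur)
        step_is_a = not step_is_a
    return pts

def run_scheme(h, breakpoints, block_fn, mat_fn, start_with_blocks=True):
    """Alternately apply transformer inference (block_fn) and linear mappings
    (mat_fn) between consecutive breakpoints, as in Eqs. 3 and 4."""
    use_blocks = start_with_blocks
    for lo, hi in zip(breakpoints, breakpoints[1:]):
        h = block_fn(h, lo, hi) if use_blocks else mat_fn(h, lo, hi)
        use_blocks = not use_blocks
    return h
```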
For the experiment, we use GPT-2 with 24 layers (see §5.1) and measure Precision@1.
Figure 6: Precision@1 for various alternation schemes, with previous mappings for comparison (24-layer GPT-2 next token prediction task).
Results.
Fig. 6 presents Precision@1 for various alternation schemes. First, some alternation schemes provide better precision than the previously considered $\texttt{mat}_{\ell\rightarrow L}$. Second, the best-$r^2$-score tactic for choosing an alternation scheme works well for the first half of layers, but under-achieves (relative to other possible alternation schemes) for the second half. It would therefore be interesting to devise more clever tactics in the future, perhaps by weighting $r^2$-scores according to layer index.
5. Method Robustness
5.1. Robustness Across Model Scales
We repeat our experiments from §4 with three additional scales of GPT-2 and one additional scale of BERT. Overall, the models are gpt2 ($L = 12$, $d_h = 768$), gpt2-medium ($L = 24$, $d_h = 1024$), gpt2-large ($L = 36$, $d_h = 1280$) and gpt2-xl ($L = 48$, $d_h = 1600$), and bert-base-uncased ($L = 12$, $d_h = 768$) and bert-large-uncased ($L = 24$, $d_h = 1024$).
Fig. 7 (resp. Fig. 8) depicts the Precision@1 and Surprisal scores as functions of the relative depth of the model (i.e. $\ell/L$), for GPT-2 models (resp. BERT models). The plots show the same trends observed in §4 across various model scales, with mat exhibiting substantially higher predictive abilities from intermediate layers than id. Interestingly, there is a great overlap between the scores of GPT-2 models of different scales, but not between the scores of BERT models.
Figure 7: Precision@1 and Surprisal for $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$, for next token prediction with GPT-2. 95% confidence intervals are shown for Surprisal.
Figure 8: Precision@1 and Surprisal for $\texttt{mat}_{\ell\rightarrow L}$ and $\texttt{id}_{\ell\rightarrow L}$, for masked token prediction with BERT. 95% confidence intervals are shown for Surprisal.
5.2. Robustness Across Data Distributions
We test whether the linear mappings learned from one data distribution are useful for predictions on other data distributions. To this end, we use a second dataset of news article sentences, the 10K English 2020 news sentences corpus from the Leipzig Corpora Collection Goldhahn et al. (2012), which we randomly divide into a training set $\mathcal{T}$ of 9,000 examples and a validation set $\mathcal{V}$ of 1,000 examples. For our experiments, we use the 24-layer GPT-2 and BERT models. First, we replicate previous experiments on the Leipzig dataset, obtaining extremely similar results; for example, the average (across layers) difference between the Precision@1 scores on Wikipedia and Leipzig is 0.3% for GPT-2 and −1.4% for BERT. Next, we use Leipzig (resp. Wikipedia) samples to fit linear mappings $\texttt{mat}_{\ell\rightarrow L}$ (as described in §3), and then evaluate these mappings for next-token prediction on samples from Wikipedia (resp. Leipzig) (as in §4). When swapping the original mappings with those trained on the other dataset, we observe a decrease of 0.1% (resp. increase of 1.1%) relative to the original Precision@1 scores for BERT, and a decrease of 5.5% (resp. decrease of 8%) for GPT-2, on average across layers. Overall, this shows that our method generalizes well to out-of-distribution samples. Moreover, our linear mappings capture general, rather than domain-specific, features of the model's inference pass.
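Fitting a mapping $\texttt{mat}_{\ell\rightarrow L}$ amounts to ordinary least squares on pairs of hidden representations; a minimal sketch on synthetic data (the actual fitting uses representations collected from the model, and `fit_mat` is an illustrative name):

```python
import numpy as np

def fit_mat(H_src, H_tgt):
    """Fit a linear map A minimizing ||H_src @ A - H_tgt||^2 over a set of
    training representations (rows of H_src are layer-l hidden states,
    rows of H_tgt the corresponding layer-L hidden states)."""
    A, *_ = np.linalg.lstsq(H_src, H_tgt, rcond=None)
    return A  # shape (d_h, d_h); apply to a new representation as h @ A

# Noiseless toy check: recover a known linear map from samples.
rng = np.random.default_rng(1)
A_true = rng.normal(size=(4, 4))
H_src = rng.normal(size=(100, 4))
H_tgt = H_src @ A_true
A_hat = fit_mat(H_src, H_tgt)
```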
6. Implication to Early Exiting
The possibility of approximating the final prediction already in the early layers has important implications for efficiency; applying our linear mapping instead of executing transformer blocks of quadratic time complexity, could save a substantial portion of the computation. In this section, we demonstrate this in the context of early exiting.
When using an early exit strategy Schwartz et al. (2020); Xin et al. (2020); Schuster et al. (2022), one aims at deciding dynamically at which layer to stop the computation and "read" the prediction from the hidden representation of that layer. More precisely, under a confidence measure paradigm, one decides to stop the computation for a position $i$ at layer $\ell$ based on a confidence criterion derived from casting the hidden representation $h_i^\ell$ as a final-layer representation and converting it to an output probability distribution. Specifically, following Schuster et al. (2022), a decision to exit is made if the difference between the highest and the second highest probabilities is bigger than
$$0.9 \cdot \lambda + 0.1 \cdot \exp(-4i/N),$$
where $N$ is the average length of the input until position $i_s$ for $s \in \mathcal{V}$, and $\lambda$ is a hyper-parameter.
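A sketch of this exit criterion, with illustrative function names; the probabilities would come from casting $h_i^\ell$ to a final-layer distribution:

```python
import math

def exit_threshold(lam, i, N):
    """Decaying confidence threshold of Schuster et al. (2022):
    0.9 * lambda + 0.1 * exp(-4i/N)."""
    return 0.9 * lam + 0.1 * math.exp(-4 * i / N)

def should_exit(probs, lam, i, N):
    """Exit early if the top-1 minus top-2 probability gap exceeds the threshold."""
    top2 = sorted(probs, reverse=True)[:2]
    return (top2[0] - top2[1]) > exit_threshold(lam, i, N)
```

The exponential term loosens the threshold at early positions and decays as the sequence grows, so exits become easier later in the input.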
Figure 9: Precision@1 with early exit and "fixed exit", applied to the 24-layer GPT-2 for next token prediction (left) and the 24-layer BERT for masked token prediction (right). Varying the confidence parameter $\lambda$, the $x$-coordinate is the average number of layers processed before an early exit decision is reached.
Experiment.
We assess the utility of our mapping $\texttt{mat}_{\ell\rightarrow L}$ for early exit as a plug-and-play replacement for $\texttt{id}_{\ell\rightarrow L}$, through which intermediate representations are cast into final-layer representations. We use GPT-2 for next token prediction and BERT for masked token prediction (both with 24 layers). We run each model over the validation set examples, varying the confidence parameter $\lambda$ and using either $\texttt{id}_{\ell\rightarrow L}$ or $\texttt{mat}_{\ell\rightarrow L}$ for casting intermediate representations. Furthermore, we compare these early exit variants to the "fixed exit" strategy from §4, where the computation is stopped after a pre-defined number of layers rather than relying on a dynamic decision. We evaluate each variant in terms of both prediction accuracy, using the Precision@1 metric (see §4), and efficiency, measured as the average number of transformer layers processed during inference.
Results.
Fig. 9 plots the average Precision@1 score against the average number of layers processed, for the 24-layer GPT-2 and the 24-layer BERT. For both models, under an early exit strategy our mapping mat again provides a substantial improvement over id. For example, aiming at 95% average precision, mat saves ∼3.3 layers (13.8%) in GPT-2 compared to only ∼1.4 layers (5.9%) for id, and ∼4.8 layers (20%) in BERT versus ∼3.5 layers (14.6%) for id. These results highlight the potential gains prominent early exit methods can obtain by using our method. Notably, in both models and for each of the mapping methods, early exit obtains better results than fixed layer exit, as expected.
7. Linear Shortcut Across Sub-Modules
In this section, we investigate whether discrepancies across layers result from specific sub-modules or are a general behaviour of all sub-modules in the network. This is done by extending our approach to test how well particular components in transformer blocks can be linearly approximated.
Method.
Consider GPT-2 for definiteness, then:
$$\texttt{b}_\ell = \texttt{b}_\ell^{\texttt{ffn}} \circ \texttt{b}_\ell^{\texttt{attn}}$$
$$\texttt{b}_\ell^{\texttt{attn}}(H) = \texttt{attn}_\ell(\texttt{ln1}_\ell(H)) + H, \qquad (5)$$
where $\texttt{attn}_\ell$ is a multi-head self-attention (MHSA) layer and $\texttt{ln1}_\ell$ is a layer normalization (LN), and
$$\texttt{b}_\ell^{\texttt{ffn}}(H) = \texttt{ffn}_\ell(\texttt{ln2}_\ell(H)) + H,$$
where $\texttt{ffn}_\ell$ is an FFN layer and $\texttt{ln2}_\ell$ is a LN. Given a block $\texttt{b}_\ell$ and one of its sub-modules $\texttt{ln1}_\ell$, $\texttt{attn}_\ell$, $\texttt{ln2}_\ell$, or $\texttt{ffn}_\ell$, we fit a linear regression approximating the output of the sub-module given its input, and then use it to define the mappings $\texttt{mat\_attn}_{\ell\rightarrow\ell'}$, $\texttt{mat\_ln1\_ln2}_{\ell\rightarrow\ell'}$ and $\texttt{mat\_ffn}_{\ell\rightarrow\ell'}$. We provide the formal definitions of these mappings in App. A.
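A minimal sketch of fitting a linear surrogate for one sub-module and of a block where only the FFN is replaced, keeping the residual connection of Eq. 5's companion formula; `attn_block` and `ln2` are hypothetical callables standing in for $\texttt{b}_\ell^{\texttt{attn}}$ and $\texttt{ln2}_\ell$:

```python
import numpy as np

def fit_submodule_linear(inputs, outputs):
    """Least-squares linear surrogate for a sub-module (e.g. ffn_l):
    given recorded (input, output) row pairs, find A with outputs ≈ inputs @ A."""
    A, *_ = np.linalg.lstsq(inputs, outputs, rcond=None)
    return A

def block_with_linear_ffn(H, attn_block, A_ffn, ln2):
    """Run the attention sub-block normally, but replace ffn_l by its linear
    surrogate while keeping the residual connection."""
    H = attn_block(H)                    # b_l^attn, unchanged
    return ln2(H) @ A_ffn + H            # linear stand-in for ffn_l(ln2_l(H)) + H

# Sanity check: a zero surrogate with identity attention/LN reduces to the residual.
H = np.ones((3, 4))
out = block_with_linear_ffn(H, lambda X: X, np.zeros((4, 4)), lambda X: X)
```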
Evaluation.
We analyze the 24-layer GPT-2 and proceed completely analogously to §4.1, evaluating the Precision@1 and Surprisal metrics for the mappings $\texttt{mat\_attn}_{\ell\rightarrow L}$, $\texttt{mat\_ffn}_{\ell\rightarrow L}$ and $\texttt{mat\_ln1\_ln2}_{\ell\rightarrow L}$.
Figure 10: Precision@1 and Surprisal for the various sub-module linear mappings, with $\texttt{mat}_{\ell\rightarrow L}$ for comparison (24-layer GPT-2 next token prediction task). A 95% confidence interval surrounds the Surprisal lines.
Results.
Fig. 10 shows the average Precision@1 and Surprisal scores per layer. From a certain layer (∼7) onward, all sub-module mappings achieve better results than the full-block mapping $\texttt{mat}_{\ell\rightarrow L}$. Thus, it is not just the cumulative effect of all the sub-modules in the transformer block that is amenable to linear approximation; individual sub-modules can be linearly approximated as well. Furthermore, the linear approximation of attention sub-modules is less harmful than that of the FFN or LN sub-modules. A possible reason is that the linear replacement of FFN or LN "erodes" the self-attention computation after a few layers. Moreover, the good performance of $\texttt{mat\_attn}_{\ell\rightarrow L}$ suggests that contextualization often exhausts itself in early layers; speculatively, it is only in more delicate cases that the self-attention of late layers adds important information. Last, note the sharp ascent of the scores for layer normalization in layers 5-8, for which we do not currently see a particular reason. To conclude, the possibility of linear approximation permeates transformer components.
8. Related Work
There is a growing interest in utilizing intermediate representations of LMs for interpretability and efficiency. For interpretability, one seeks to understand the prediction construction process of the model Tenney et al. (2019); Voita et al. (2019); Geva et al. (2022b), or the features stored in its hidden representations Adi et al. (2017); Conneau et al. (2018); Liu et al. (2019). Our work is different as it converts intermediate representations into a final-layer form, which is interpretable by design.
Previous works on early exiting cut the computation at a dynamically decided earlier stage Schwartz et al. (2020); Xin et al. (2020); Schuster et al. (2022); Gera et al. (2023), or utilize a separate fixed network to parallelize inference (Leviathan et al., 2022; Chen et al., 2023). However, these methods propagate intermediate representations directly, which we show is substantially worse than our approach. Also, our method requires training considerably fewer parameters than methods such as Schuster et al. (2021), which learn a different output softmax for each layer.
Last, skipping transformer layers and analyzing the linearity properties of transformer components have been discussed in prior works Zhao et al. (2021); Mickus et al. (2022); Wang et al. (2022); Lamparth and Reuel (2023). Specifically, a concurrent work by Belrose et al. (2023) proposed to train affine transformations from hidden to final representations to increase model transparency. Our work differs in that we train linear transformations across all layers. Moreover, Belrose et al. (2023) train with SGD while minimizing KL divergence, whereas we use linear regression, which requires much less compute. It would be valuable to compare the accuracy of both methods.
9. Conclusion and Future Work
We present a simple and effective method for enhancing utilization of hidden representations in transformer-based LMs, that uses pre-fitted context-free and token-uniform linear mappings. Through a series of experiments on different data sources, model architectures and scales, we show that our method consistently outperforms the prevalent practice of interpreting representations in the final-layer space of the model, yielding better approximations of succeeding representations and the predictions they induce, thus allowing a more faithful interpretation of the model’s prediction-formation. We demonstrate the practicality of our method for improving computation efficiency, saving a substantial amount of compute on top of prominent early exiting approaches. Also, by extending our method to sub-modules, we observe that replacing a part of the transformer inference by a non-contextual linear computation often results in a small deterioration of the prediction. This opens new research directions for improving model efficiency, including breaking the computation into parallel tasks.
10. Limitations
First, while it is possible to define many different mappings in-between layers, for example, affine or non-linear transformations, we focus on the “simple” case of linear transformations. Our choice is motivated by the wide success of the simplest mapping (i.e. the identity baseline, of inspecting hidden representations in the same linear space), while we are asking if there is more linearity in transformer inference that can be exploited for interpretability.
Second, we find that there is more linear structure to parts of the transformer computation (both full layers and sub-modules) than could be explained solely by the residual connection. However, we do not elucidate a reason for that, leaving exploration of this interesting research question for future work.
Third, our experiments focus on post-hoc interpretability, that is, analyzing a trained model without changing its weights. Future work should also consider analyzing the utility of such linear mappings when those are integrated into the model training.
Last, in our experiments we use only data in English. Nonetheless, given the comprehensiveness of our experiments and the fact that our method does not rely on any language-specific features, we would expect our findings to hold in other languages as well.
11. Bibliographical References
- Adi et al. (2017) Yossi Adi, Einat Kermany, Yonatan Belinkov, Ofer Lavi, and Yoav Goldberg. 2017. Fine-grained analysis of sentence embeddings using auxiliary prediction tasks. In International Conference on Learning Representations.
- Alammar (2021) J Alammar. 2021. Ecco: An open source library for the explainability of transformer language models. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing: System Demonstrations, pages 249–257, Online. Association for Computational Linguistics.
- Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. 2016. Layer normalization. arXiv:1607.06450v1.
- Belrose et al. (2023) Nora Belrose, Zach Furman, Logan Smith, Danny Halawi, Lev McKinney, Igor Ostrovsky, Stella Biderman, and Jacob Steinhardt. 2023. Eliciting latent predictions from transformers with the tuned lens. to appear.
- Chen et al. (2023) Charlie Chen, Sebastian Borgeaud, Geoffrey Irving, Jean-Baptiste Lespiau, Laurent Sifre, and John Jumper. 2023. Accelerating large language model decoding with speculative sampling. arXiv:2302.01318v1.
- Conneau et al. (2018) Alexis Conneau, German Kruszewski, Guillaume Lample, Loïc Barrault, and Marco Baroni. 2018. What you can cram into a single $&!#* vector: Probing sentence embeddings for linguistic properties. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 2126–2136, Melbourne, Australia. Association for Computational Linguistics.
- Dar et al. (2022) Guy Dar, Mor Geva, Ankit Gupta, and Jonathan Berant. 2022. Analyzing transformers in embedding space. arXiv:2209.02535v2.
- Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, Minneapolis, Minnesota. Association for Computational Linguistics.
- Elhage et al. (2021) Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. 2021. A mathematical framework for transformer circuits. In Transformer Circuits Thread.
- Gera et al. (2023) Ariel Gera, Roni Friedman, Ofir Arviv, Chulaka Gunasekara, Benjamin Sznajder, Noam Slonim, and Eyal Shnarch. 2023. The benefits of bad advice: Autocontrastive decoding across model layers. In Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 10406–10420, Toronto, Canada. Association for Computational Linguistics.
- Geva et al. (2022a) Mor Geva, Avi Caciularu, Guy Dar, Paul Roit, Shoval Sadde, Micah Shlain, Bar Tamir, and Yoav Goldberg. 2022a. LM-debugger: An interactive tool for inspection and intervention in transformer-based language models. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 12–21, Abu Dhabi, UAE. Association for Computational Linguistics.
- Geva et al. (2022b) Mor Geva, Avi Caciularu, Kevin Wang, and Yoav Goldberg. 2022b. Transformer feed-forward layers build predictions by promoting concepts in the vocabulary space. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pages 30–45, Abu Dhabi, United Arab Emirates. Association for Computational Linguistics.
- Geva et al. (2021) Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. 2021. Transformer feed-forward layers are key-value memories. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 5484–5495, Online and Punta Cana, Dominican Republic. Association for Computational Linguistics.
- Goldhahn et al. (2012) Dirk Goldhahn, Thomas Eckart, and Uwe Quasthoff. 2012. Building large monolingual dictionaries at the Leipzig corpora collection: From 100 to 200 languages. In Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC’12), pages 759–765, Istanbul, Turkey. European Language Resources Association (ELRA).
- He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2016. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, Los Alamitos, CA, USA. IEEE Computer Society.
- Lamparth and Reuel (2023) Max Lamparth and Anka Reuel. 2023. Analyzing and editing inner mechanisms of backdoored language models. arXiv:2302.12461v1.
- Leviathan et al. (2022) Yaniv Leviathan, Matan Kalman, and Yossi Matias. 2022. Fast inference from transformers via speculative decoding. arXiv:2211.17192v1.
- Liu et al. (2019) Nelson F. Liu, Matt Gardner, Yonatan Belinkov, Matthew E. Peters, and Noah A. Smith. 2019. Linguistic knowledge and transferability of contextual representations. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 1073–1094, Minneapolis, Minnesota. Association for Computational Linguistics.
- Mickus et al. (2022) Timothee Mickus, Denis Paperno, and Mathieu Constant. 2022. How to dissect a Muppet: The structure of transformer embedding spaces. Transactions of the Association for Computational Linguistics, 10:981–996.
- Radford et al. (2019) Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. 2019. Language models are unsupervised multitask learners.
- Ram et al. (2022) Ori Ram, Liat Bezalel, Adi Zicher, Yonatan Belinkov, Jonathan Berant, and Amir Globerson. 2022. What are you token about? dense retrieval as distributions over the vocabulary. arXiv:2212.10380v1.
- Schuster et al. (2022) Tal Schuster, Adam Fisch, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, and Donald Metzler. 2022. Confident adaptive language modeling. In Advances in Neural Information Processing Systems.
- Schuster et al. (2021) Tal Schuster, Adam Fisch, Tommi Jaakkola, and Regina Barzilay. 2021. Consistent accelerated inference via confident adaptive transformers. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 4962–4979, Online and Punta Cana, Dominican Republic. Association for Computational Linguistics.
- Schwartz et al. (2020) Roy Schwartz, Gabriel Stanovsky, Swabha Swayamdipta, Jesse Dodge, and Noah A. Smith. 2020. The right tool for the job: Matching model and instance complexities. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 6640–6651, Online. Association for Computational Linguistics.
- Slobodkin et al. (2021) Aviv Slobodkin, Leshem Choshen, and Omri Abend. 2021. Mediators in determining what processing BERT performs first. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 86–93, Online. Association for Computational Linguistics.
- Tenney et al. (2019) Ian Tenney, Dipanjan Das, and Ellie Pavlick. 2019. BERT rediscovers the classical NLP pipeline. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4593–4601, Florence, Italy. Association for Computational Linguistics.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc.
- Voita et al. (2019) Elena Voita, Rico Sennrich, and Ivan Titov. 2019. The bottom-up evolution of representations in the transformer: A study with machine translation and language modeling objectives. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4396–4406, Hong Kong, China. Association for Computational Linguistics.
- Wang et al. (2022) Jue Wang, Ke Chen, Gang Chen, Lidan Shou, and Julian McAuley. 2022. SkipBERT: Efficient inference with shallow layer skipping. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 7287–7301, Dublin, Ireland. Association for Computational Linguistics.
- Xin et al. (2020) Ji Xin, Raphael Tang, Jaejun Lee, Yaoliang Yu, and Jimmy Lin. 2020. DeeBERT: Dynamic early exiting for accelerating BERT inference. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2246–2251, Online. Association for Computational Linguistics.
- Xu et al. (2021) Jingjing Xu, Wangchunshu Zhou, Zhiyi Fu, Hao Zhou, and Lei Li. 2021. A survey on green deep learning. arXiv:2111.05193v2.
- Zhao et al. (2021) Sumu Zhao, Damian Pascual, Gino Brunner, and Roger Wattenhofer. 2021. Of Non-Linearity and Commutativity in BERT. In International Joint Conference on Neural Networks (IJCNN), Virtual-only.
Appendix A Descriptions of mat_attn, mat_ffn and mat_ln1_ln2
Here we detail the definitions of the mappings $\texttt{mat\_attn}_{\ell\rightarrow\ell'}$, $\texttt{mat\_ffn}_{\ell\rightarrow\ell'}$, and $\texttt{mat\_ln1\_ln2}_{\ell\rightarrow\ell'}$ utilized in §7.

Description of $\texttt{mat\_attn}_{\ell\rightarrow\ell'}$.

For an input $s$, let $v^{\ell}_{i_s}$ be the vector at position $i_s$ in the output of $\texttt{attn}_{\ell}(\texttt{ln1}_{\ell}(H^{\ell-1}))$. We denote by $A_{\ell}^{\texttt{attn}} \in \mathbb{R}^{d_h \times d_h}$ the matrix numerically minimizing

$$A \mapsto \sum_{s \in \mathcal{T}} \left\| A \cdot \texttt{ln1}_{\ell}(h^{\ell-1}_{i_s}) - v^{\ell}_{i_s} \right\|^2 ,$$

and define an attention sub-module replacement (Eq. 5) by

$$\texttt{b}^{\overline{\texttt{attn}}}_{\ell}(h) \coloneqq A_{\ell}^{\texttt{attn}} \cdot \texttt{ln1}_{\ell}(h) + h .$$
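The "numerically minimizing" step is an ordinary least-squares problem with a closed-form solution. As a minimal sketch (not the paper's code), the snippet below fits such a matrix with NumPy, using random arrays as hypothetical stand-ins for the collected inputs $\texttt{ln1}_{\ell}(h^{\ell-1}_{i_s})$ and attention outputs $v^{\ell}_{i_s}$:

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, n = 16, 1000  # toy hidden size and number of training positions

# Hypothetical training pairs over the training set T:
# rows of X stand in for ln1_l(h^{l-1}_{i_s}), rows of V for the outputs v^l_{i_s}.
X = rng.normal(size=(n, d_h))
V = X @ rng.normal(size=(d_h, d_h)) + 0.01 * rng.normal(size=(n, d_h))

# Minimize sum_s ||A x_s - v_s||^2: lstsq solves X @ M ~= V, so A = M^T.
M, *_ = np.linalg.lstsq(X, V, rcond=None)
A_attn = M.T

def b_attn_bar(h, A, ln1):
    """Linear attention replacement: A . ln1(h) + h."""
    return A @ ln1(h) + h
```

In practice the pairs would be hidden states gathered from forward passes over the training corpus rather than random data.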
We then define a mapping between two layers $\ell \rightarrow \ell'$ by:

$$\texttt{mat\_attn}_{\ell\rightarrow\ell'}(h) \coloneqq \texttt{b}^{\texttt{ffn}}_{\ell'}(\texttt{b}^{\overline{\texttt{attn}}}_{\ell'}(\ldots(\texttt{b}^{\texttt{ffn}}_{\ell+1}(\texttt{b}^{\overline{\texttt{attn}}}_{\ell+1}(h)))\ldots)) .$$

Namely, when applying each $\ell''$-th block, $\ell < \ell'' \leq \ell'$, we replace its attention sub-module $\texttt{attn}_{\ell''}$ by its linear approximation. Importantly, unlike the original attention module, the approximation $\texttt{b}^{\overline{\texttt{attn}}}_{\ell}$ operates on each position independently, so applying $\texttt{mat\_attn}_{\ell\rightarrow\ell'}$ disables any contextualization between the layers $\ell$ and $\ell'$. Note that this is not the case for $\texttt{mat\_ffn}_{\ell\rightarrow\ell'}$ and $\texttt{mat\_ln1\_ln2}_{\ell\rightarrow\ell'}$, which retain the self-attention sub-modules and operate contextually.
Description of $\texttt{mat\_ffn}_{\ell\rightarrow\ell'}$.

Let $v^{\ell}_{i_s}$ be the vector at position $i_s$ in the output of $\texttt{ln2}_{\ell}(\texttt{b}^{\texttt{attn}}_{\ell}(H^{\ell-1}))$, for a given input $s$. We denote by $A_{\ell}^{\texttt{ffn}} \in \mathbb{R}^{d_h \times d_h}$ the matrix numerically minimizing

$$A \mapsto \sum_{s \in \mathcal{T}} \left\| A \cdot v^{\ell}_{i_s} - \texttt{ffn}_{\ell}(v^{\ell}_{i_s}) \right\|^2 ,$$

and define a replacement of the feed-forward sub-module $\texttt{b}^{\texttt{ffn}}_{\ell}$ by

$$\texttt{b}^{\overline{\texttt{ffn}}}_{\ell}(H) \coloneqq A_{\ell}^{\texttt{ffn}} \cdot \texttt{ln2}_{\ell}(H) + H .$$
We then define a mapping between two layers $\ell \rightarrow \ell'$ by:

$$\texttt{mat\_ffn}_{\ell\rightarrow\ell'}(H) \coloneqq \texttt{b}^{\overline{\texttt{ffn}}}_{\ell'}(\texttt{b}^{\texttt{attn}}_{\ell'}(\ldots(\texttt{b}^{\overline{\texttt{ffn}}}_{\ell+1}(\texttt{b}^{\texttt{attn}}_{\ell+1}(H)))\ldots)) .$$
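To make the alternating composition concrete, here is an illustrative loop under toy assumptions: `b_attn` and `ln2` are placeholder stand-ins for the exact attention block and LayerNorm (not the paper's implementations), and the `A_ffn` matrices are random stand-ins for fitted ones. Each skipped layer applies the exact attention block and then the linear feed-forward replacement $\texttt{b}^{\overline{\texttt{ffn}}}$:

```python
import numpy as np

d_h, n_layers = 8, 4  # toy hidden size; number of skipped blocks (l' - l)
rng = np.random.default_rng(1)
# One fitted matrix per skipped layer (random here, least-squares-fitted in practice).
A_ffn = [rng.normal(scale=0.1, size=(d_h, d_h)) for _ in range(n_layers)]

def b_attn(H):
    # Placeholder for the exact attention block, which is kept, not approximated.
    return H

def ln2(H):
    # Placeholder LayerNorm over the hidden dimension.
    mu, sigma = H.mean(-1, keepdims=True), H.std(-1, keepdims=True)
    return (H - mu) / (sigma + 1e-5)

def b_ffn_bar(H, A):
    """Linear feed-forward replacement: A . ln2(H) + H, applied row-wise."""
    return ln2(H) @ A.T + H

def mat_ffn(H):
    """Apply blocks l+1 .. l', each with its ffn sub-module linearly approximated."""
    for A in A_ffn:
        H = b_ffn_bar(b_attn(H), A)
    return H
```

Running `mat_ffn` on hidden states `H` of shape `(seq_len, d_h)` returns states of the same shape, mimicking a forward pass from layer $\ell$ to $\ell'$.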
Description of $\texttt{mat\_ln1\_ln2}_{\ell\rightarrow\ell'}$.

Let $v^{\ell}_{i_s}$ be the vector at position $i_s$ in the output of $\texttt{b}^{\texttt{attn}}_{\ell}(H^{\ell-1})$, for a given input $s$. We denote by $A_{\ell}^{\texttt{ln1}} \in \mathbb{R}^{d_h \times d_h}$ the matrix numerically minimizing

$$A \mapsto \sum_{s \in \mathcal{T}} \left\| A \cdot h^{\ell}_{i_s} - \texttt{ln1}_{\ell}(h^{\ell}_{i_s}) \right\|^2$$

and we denote by $A_{\ell}^{\texttt{ln2}} \in \mathbb{R}^{d_h \times d_h}$ the matrix numerically minimizing

$$A \mapsto \sum_{s \in \mathcal{T}} \left\| A \cdot v^{\ell}_{i_s} - \texttt{ln2}_{\ell}(v^{\ell}_{i_s}) \right\|^2 .$$

We define a replacement of the block $\texttt{b}^{\texttt{attn}}_{\ell}$ by

$$\texttt{b}^{\overline{\texttt{ln1}}}_{\ell}(H) \coloneqq \texttt{attn}_{\ell}(A_{\ell}^{\texttt{ln1}} \cdot H) + H \tag{6}$$

and we define a replacement of the block $\texttt{b}^{\texttt{ffn}}_{\ell}$ by

$$\texttt{b}^{\overline{\texttt{ln2}}}_{\ell}(H) \coloneqq \texttt{ffn}_{\ell}(A_{\ell}^{\texttt{ln2}} \cdot H) + H . \tag{7}$$
We then define a mapping between two layers $\ell \rightarrow \ell'$ by:

$$\texttt{mat\_ln1\_ln2}_{\ell\rightarrow\ell'}(H) \coloneqq \texttt{b}^{\overline{\texttt{ln2}}}_{\ell'}(\texttt{b}^{\overline{\texttt{ln1}}}_{\ell'}(\ldots(\texttt{b}^{\overline{\texttt{ln2}}}_{\ell+1}(\texttt{b}^{\overline{\texttt{ln1}}}_{\ell+1}(H)))\ldots)) .$$
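This variant keeps both non-linear sub-modules and swaps only the LayerNorms for their fitted matrices. A hedged sketch of the composition, where `attn` and `ffn` are toy placeholders and random matrices stand in for the fitted $A_{\ell}^{\texttt{ln1}}$ and $A_{\ell}^{\texttt{ln2}}$:

```python
import numpy as np

d_h = 8
rng = np.random.default_rng(2)
# Per skipped layer: fitted linear replacements for ln1 and ln2 (random stand-ins).
layers = [(rng.normal(scale=0.1, size=(d_h, d_h)),
           rng.normal(scale=0.1, size=(d_h, d_h))) for _ in range(3)]

def attn(H):  # placeholder for the exact attention sub-module (kept)
    return 0.1 * H

def ffn(H):   # placeholder for the exact feed-forward sub-module (kept)
    return 0.1 * H

def mat_ln1_ln2(H):
    """Blocks l+1 .. l' with both LayerNorms linearly approximated (Eqs. 6-7)."""
    for A_ln1, A_ln2 in layers:
        H = attn(H @ A_ln1.T) + H  # replacement of b_attn with ln1 linearized
        H = ffn(H @ A_ln2.T) + H   # replacement of b_ffn with ln2 linearized
    return H
```

Because `attn` and `ffn` are retained, this mapping, unlike `mat_attn`, can still mix information across positions in a real model.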