
Title: Universal In-Context Approximation By Prompting Fully Recurrent Models

URL Source: https://arxiv.org/html/2406.01424

Published Time: Fri, 11 Oct 2024 01:20:46 GMT

Markdown Content: Aleksandar Petrov, Tom A. Lamb, Alasdair Paren, Philip H.S. Torr, Adel Bibi

Department of Engineering Science

University of Oxford

aleks@robots.ox.ac.uk

Abstract

Zero-shot and in-context learning enable solving tasks without model fine-tuning, making them essential for developing generative model solutions. Therefore, it is crucial to understand whether a pretrained model can be prompted to approximate any function, i.e., whether it is a universal in-context approximator. While it was recently shown that transformer models do possess this property, these results rely on their attention mechanism. Hence, these findings do not apply to fully recurrent architectures like RNNs, LSTMs, and the increasingly popular SSMs. We demonstrate that RNNs, LSTMs, GRUs, Linear RNNs, and linear gated architectures such as Mamba and Hawk/Griffin can also serve as universal in-context approximators. To streamline our argument, we introduce a programming language called LSRL that compiles to these fully recurrent architectures. LSRL may be of independent interest for further studies of fully recurrent models, such as constructing interpretability benchmarks. We also study the role of multiplicative gating and observe that architectures incorporating such gating (e.g., LSTMs, GRUs, Hawk/Griffin) can implement certain operations more stably, making them more viable candidates for practical in-context universal approximation.

1 Introduction

Until recently, solving a task with machine learning required training or fine-tuning a model on a dataset matching the task at hand. However, large foundation models exhibit the ability to solve new tasks without being specifically fine-tuned or trained for them: often it is sufficient to simply prompt them in the right way. This has made prompting a key method for steering a model towards a specific behaviour or task (liu2023pretrain). Prompting has been especially successful because of in-context learning: the ability to modify the model's behaviour with information provided within the input sequence, without changing the underlying model parameters (brown2020language). As a result, the art and skill of constructing a successful prompt (prompt engineering) has become extremely important (liu2022design; sahoo2024systematic). Yet, we know little about the theoretical properties of prompting. It is not even clear whether there are limits to what can be achieved with prompting or, conversely, whether it is possible to prompt your way into any behaviour or task.

This can be framed as a universal approximation question. Classically, universal approximation results show how a class of tractable functions, such as neural networks, approximates another class of concept functions, e.g., all continuous functions on a bounded domain, with arbitrary accuracy. This is often done by showing that one can choose model parameters that approximate the target function. However, in-context learning poses a different challenge as the model parameters are fixed. Instead, a part of the input (the prompt) is modified to cause the model to approximate the target function. Hence, we define universal in-context approximation to be the property that there exist fixed weights such that the resulting model can be prompted to approximate any function from a concept class. Understanding whether a model can be a universal in-context approximator is especially important as most commercial models are accessible exclusively via a prompting interface (lamalfa2023language).

In-context learning has been almost exclusively studied in conjunction with the transformer architecture (vaswani2017attention). This is likely because in-context abilities appear once models are large enough (wei2021finetuned) and most large models have been transformer-based. On the subject of universal in-context approximation, wang2024universality were the first to show that a transformer possesses this property, by discretising and memorising all possible functions in the model weights. Memorisation is not needed, though, and even small transformers can be universal approximators when prompted (petrov2024universal). Both results, however, critically depend on the attention mechanism of the transformer architecture (bahdanau2014neural).

Still, generative models are not restricted to attention-based architectures: there are the “classic” recurrent neural networks (RNNs, amari1992rnns), long short-term memory models (LSTMs, hochreiter1997long) and gated recurrent units (GRUs, cho2014learning). Recently, Linear RNN models (also known as state-space models or SSMs) were proposed as a scalable alternative to the transformer architecture (orvieto2023resurrecting; fu2023hungry) and have started to outperform similarly-sized transformers when multiplicative gating is added (gu2023mamba; de2024griffin; botev2024recurrentgemma). Furthermore, despite in-context learning being associated with the transformer, recent empirical results show in-context learning in SSMs, RNNs, LSTMs and even convolutional models (xie2021explanation; akyurek2024context; lee2023exploring).

Yet, despite their ability to be in-context learners, little is known about the theoretical properties of these fully recurrent architectures. As these architectures become more widely used, understanding their in-context approximation abilities is increasingly important for their safety, security and alignment. We show that, in fact, many of these architectures, similarly to transformers, can be universal in-context approximators. Concretely, our contributions are as follows:

  1. We develop Linear State Recurrent Language (LSRL): a programming language that compiles to different fully recurrent models. Programming in LSRL is akin to “thinking like a recurrent model”. LSRL programs can then be implemented exactly as model weights.
  2. Using LSRL, we construct Linear RNN models that can be prompted to act as any token-to-token function over finite token sequences, or to approximate any continuous function. These results also hold for RNNs, LSTMs, GRUs and Hawk/Griffin models (de2024griffin).
  3. We present constructions with and without multiplicative gating. However, we observe that the constructions without these gates depend on numerically unstable conditional logic.
  4. Nevertheless, we show that multiplicative gates lead to more compact and numerically stable models, making it more likely that universal in-context approximation properties arise in models utilising them, such as LSTMs, GRUs and the latest generation of Linear RNNs.

2 Preliminaries

Fully recurrent architectures.

In this work, we focus exclusively on fully recurrent neural network architectures. Recurrent models operate over sequences. Concretely, consider an input sequence $(\bm{x}_1,\ldots,\bm{x}_N)$ with $\bm{x}_t\in\mathcal{X}$, where $\mathcal{X}$ is some input space. We will refer to the elements of the input sequence as tokens even if they are real-valued vectors. A recurrent model $g:\mathcal{X}^{\star}\to\mathcal{Y}$ maps a sequence of inputs to an output in some output space $\mathcal{Y}$. These models are always causal, namely:

$$\bm{y}_t = g(\bm{x}_1,\ldots,\bm{x}_t). \tag{1}$$

We will abuse notation and refer to $(\bm{y}_1,\ldots,\bm{y}_t) = (g(\bm{x}_1),\ldots,g(\bm{x}_1,\ldots,\bm{x}_t))$ as simply $g(\bm{x}_1,\ldots,\bm{x}_t)$. We will also separate the input sequence into a query $(\bm{q}_1,\ldots,\bm{q}_n)$ and a prompt $(\bm{p}_1,\ldots,\bm{p}_N)$. The prompt specifies the target function $f$ that we approximate, while the query designates the input at which we evaluate it. Contrary to the typical setting, we will place the query before the prompt.¹

¹ That is necessitated by the limited capacity of the state variables. As the model is fixed, the only way to increase the precision of the approximation is to increase the prompt length. If the prompt came before the query, it would have to be compressed into a fixed-size state, limiting the approximation precision even with increased prompt lengths. The query, on the other hand, has a fixed size, so it can be stored in a fixed-size state variable exactly.
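To make the setting concrete, the causal map in Eq. 1 can be sketched as a state-passing scan. The snippet below is our own minimal illustration (not from the paper): `step` and `readout` are placeholder functions that each recurrent architecture in this section instantiates differently.

```python
# A minimal sketch of the causal recurrent map in Eq. (1): every output
# y_t depends only on the prefix x_1, ..., x_t, with all history
# summarised in a state. `step` and `readout` are placeholders.
def run_recurrent(step, readout, s0, xs):
    """Return (y_1, ..., y_N) where y_t depends only on x_1..x_t."""
    s, ys = s0, []
    for x in xs:
        s = step(s, x)          # state summarises the prefix x_1..x_t
        ys.append(readout(s))   # y_t = g(x_1, ..., x_t)
    return ys

# Toy instantiation: running sum as state, identity readout.
ys = run_recurrent(lambda s, x: s + x, lambda s: s, 0.0, [1.0, 2.0, 3.0])
# ys == [1.0, 3.0, 6.0]
```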

There are various neural network architectures that fall under the general framework of Eq. 1. The quintessential one is the RNN. It processes inputs one by one, with only a non-linear state being passed from one time step to the next. A model $g$ can thus be stacked RNN layers, each one being:

$$\bm{s}_t = \sigma(\bm{A}\bm{s}_{t-1} + \bm{B}\bm{x}_t + \bm{b}), \qquad \bm{y}_t = \phi(\bm{s}_t), \qquad \text{(Classic RNN)} \tag{2}$$

with $\bm{A}$, $\bm{B}$, $\bm{b}$ and the initial state value $\bm{s}_0$ being model parameters, $\sigma$ a non-linear activation function, and $\phi$ a multi-layer perceptron (MLP) with ReLU activations. We assume that $\sigma$ is always a ReLU to keep the analysis simpler. The non-linearity in the state update can make the model difficult to train (vanishing and exploding gradients, bengio1994learning). Therefore, Linear RNNs have been proposed, as regularising the eigenvalues of $\bm{A}$ can stabilise the training dynamics (orvieto2023resurrecting). Linear RNNs also admit a convolutional representation, making them trainable in parallel (gu2021efficiently; fu2023hungry). Linear RNNs drop the non-linearity from the state update in Eq. 2:

$$\bm{s}_t = \bm{A}\bm{s}_{t-1} + \bm{B}\bm{x}_t + \bm{b}, \qquad \bm{y}_t = \phi(\bm{s}_t). \qquad \text{(Linear RNN)} \tag{3}$$

The fully linear state updates do not affect the expressivity of the models, as non-linear activations are nevertheless present in the MLP layers $\phi$ between the linear state update layers (wang2024state; boyd1985fading). State-of-the-art Linear RNN models also utilise some form of multiplicative gating (gu2023mamba; de2024griffin; botev2024recurrentgemma). While specific implementations differ, we can abstract it as the following Gated Linear RNN architecture:

$$\bm{s}_t = \bm{A}\bm{s}_{t-1} + \bm{B}\bm{x}_t + \bm{b}, \qquad \bm{y}_t = \gamma(\bm{x}_t) \odot \phi(\bm{s}_t), \qquad \text{(Gated Linear RNN)} \tag{4}$$

with $\gamma$ being another MLP and $\odot$ the element-wise multiplication operation (Hadamard product). Eq. 4 encompasses a range of recently proposed models. For example, one can show that any model consisting of $L$ stacked Gated Linear RNN layers, with $\gamma$ and $\phi$ having $k$ layers each, can be represented as an $L(k{+}2)$-layer Hawk or Griffin model (de2024griffin). The conversions are described in detail in the appendix. We can similarly add multiplicative gating to the classic RNN architecture:

$$\bm{s}_t = \sigma(\bm{A}\bm{s}_{t-1} + \bm{B}\bm{x}_t + \bm{b}), \qquad \bm{y}_t = \gamma(\bm{x}_t) \odot \phi(\bm{s}_t), \qquad \text{(Gated RNN)} \tag{5}$$

Eq. 5 may appear unusual, but it is related to the well-known GRU (cho2014learning) and LSTM (hochreiter1997long) architectures. As with Griffin/Hawk, any Gated RNN can be represented as an $L(k{+}2)$-layer GRU or LSTM model (details in the appendix). As a result, if there exists a Gated RNN model that is a universal in-context approximator (which we later show to be the case), then there also exist GRU and LSTM models with the same property.
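The four recurrences in Eqs. 2–5 differ only in whether the state update is linear and whether the readout is gated. The sketch below is our own illustrative summary (placeholder maps stand in for the ReLU MLPs $\phi$ and $\gamma$; the weights are not the paper's constructions):

```python
import numpy as np

# Illustrative sketch of the four layer types in Eqs. (2)-(5).
# linear=True selects Eqs. (3)/(4); gamma (if given) adds the
# multiplicative gating of Eqs. (4)/(5).
def recurrent_layer(A, B, b, phi, s0, xs, linear=True, gamma=None):
    s, ys = s0, []
    for x in xs:
        z = A @ s + B @ x + b
        s = z if linear else np.maximum(0.0, z)        # (non-)linear state
        y = phi(s)
        ys.append(y if gamma is None else gamma(x) * y)  # optional gating
    return ys

# Tiny check: with A = B = I and identity readout, the linear layer's
# state is the running sum of the inputs.
d = 2
I, zero = np.eye(d), np.zeros(d)
xs = [np.ones(d)] * 3
ys = recurrent_layer(I, I, zero, lambda s: s, zero, xs)
# ys[-1] == [3., 3.]
# A gate gamma(x) = 0 silences the output regardless of the state.
ys_gated = recurrent_layer(I, I, zero, lambda s: s, zero, xs,
                           gamma=lambda x: 0.0 * x)
# ys_gated[-1] == [0., 0.]
```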

All the models above can be boiled down to compositions of a few building blocks: linear layers, ReLU activations, (non-)linear state updates, and multiplicative operations (in the case of gated models). These four building blocks will be the primitives of LSRL, the programming language we introduce in Sec. 3 as a tool to write programs that directly compile to these architectures. In practice, a number of additional elements might be present, such as residual connections (he2016deep), positional embeddings (su2024roformer) and normalisation layers (ba2016layer; zhang2019root). However, as these are not necessary for showing the in-context universal approximation abilities of the four architectures above, we will not consider them in this work.

Theoretical understanding of in-context learning.

Beyond the question of universal in-context approximation, there have been attempts to theoretically understand in-context learning from various perspectives. The ability to learn linear functions and perform optimization in-context has been extensively explored in the context of linear regression (garg2022can; akyurek2022learning; oswald23transformers; fu2023transformers; zhang2023trained; ahn2023transformers), kernel regression (han2023context) and dynamical systems (li2023transformers). Furthermore, studies have explored how in-context learning identifies and applies the appropriate pretraining skill (xie2021explanation; coda2023meta; bai2024transformers). It has also been shown that transformers can construct internal learning objectives and optimize them during the forward pass (von2023uncovering; dai2023gpt). However, these studies almost exclusively focus on the transformer architecture, and the applicability of their findings to fully recurrent models remains unclear.

Approximation theory.

Let $\mathcal{X}$ and $\mathcal{Y}$ be normed vector spaces. Take a set of functions $\mathcal{C} \subseteq \mathcal{Y}^{\mathcal{X}}$ from $\mathcal{X}$ to $\mathcal{Y}$, called a concept space. Take also a set of nicely behaved functions $\mathcal{H} \subset \mathcal{Y}^{\mathcal{X}}$, called a hypothesis space. $\mathcal{H}$ could be any set that we have tools to construct and analyse, e.g., all polynomials or all neural networks of a particular architectural type. Approximation theory is concerned with how well functions in $\mathcal{H}$ approximate functions in $\mathcal{C}$. We say that $\mathcal{H}$ universally approximates $\mathcal{C}$ over a compact domain $\mathcal{D}$ (or that $\mathcal{H}$ is dense in $\mathcal{C}$) if for every $f \in \mathcal{C}$ and $\epsilon > 0$ there exists an $h \in \mathcal{H}$ such that $\sup_{\bm{x}\in\mathcal{D}} |f(\bm{x}) - h(\bm{x})| \leq \epsilon$. There is a long history of studying the concept class of continuous functions with hypothesis classes of single hidden layer neural networks (cybenko1989approximation; barron1993universal) or deeper models (hornik1989multilayer; telgarsky2015representation). The concept class of sequence-to-sequence functions has been shown to be universally approximated by the hypothesis classes of transformers (yun2019transformers), RNNs (schafer2006recurrent) and Linear RNNs (wang2024state).
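As a concrete instance of this definition (our own example, not from the paper), a one-hidden-layer ReLU network $h$ can approximate the concept $f(x) = x^2$ on $[0,1]$: piecewise-linear interpolation at $K$ evenly spaced knots, expressed in the ReLU basis, achieves sup-error at most $1/(4K^2)$, which can be driven below any $\epsilon$ by increasing $K$:

```python
import numpy as np

# One-hidden-layer ReLU network h(x) = sum_i c_i * ReLU(x - t_i) that
# linearly interpolates f(x) = x^2 at K knots on [0, 1]. For a convex
# f with |f''| <= 2, the sup-error is at most 1 / (4 K^2).
K = 10
knots = np.arange(K) / K                           # t_0, ..., t_{K-1}
f = lambda x: x ** 2
m = (f(np.arange(1, K + 1) / K) - f(knots)) * K    # per-segment slopes
c = np.diff(m, prepend=0.0)                        # ReLU coefficients
h = lambda x: np.maximum(0.0, x[:, None] - knots[None, :]) @ c

xs = np.linspace(0.0, 1.0, 1001)
eps = np.max(np.abs(f(xs) - h(xs)))                # empirical sup-error
assert eps <= 1 / (4 * K ** 2) + 1e-12
```

Increasing `K` shrinks `eps`, which is exactly the density statement above for this particular pair of $f$ and hypothesis family.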

The hypothesis spaces in this work are different. The model is fixed and only the prompt part of the input is changed, i.e., all learnable parameters are in the prompt. Take a recurrent model $g$ as in Eq. 1 with fixed model parameters and a query length $n$. The hypothesis class is all functions that result from calling $g$ on the user query followed by the prompt and taking the last $n'$ outputs:

$$\mathcal{H}_g^{\mathcal{D}^n} = \left\{ (\bm{q}_1,\ldots,\bm{q}_n) \mapsto g(\bm{q}_1,\ldots,\bm{q}_n,\bm{p}_1,\ldots,\bm{p}_N)[-n'{:}] \;\middle|\; \bm{p}_i \in \mathcal{D},\; N > 0 \right\}. \tag{6}$$

The domain $\mathcal{D}$ of the $\bm{p}_i$ and $\bm{q}_i$ can be continuous embeddings in $\mathbb{R}^d$ or discrete tokens $\mathcal{V} = \{1,\ldots,V\}$.
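The structure of Eq. 6 can be sketched with a toy model (ours, purely illustrative): the model $g$ stays fixed, and each hypothesis is the function obtained by appending a chosen prompt after the query and keeping the last $n'$ outputs, so picking a hypothesis is the same as picking a prompt.

```python
# Sketch of the in-context hypothesis class in Eq. (6) with a toy g.
def make_hypothesis(g, prompt, n_prime):
    """Each hypothesis is identified by its prompt; g never changes."""
    return lambda query: g(list(query) + list(prompt))[-n_prime:]

def g(seq):
    # Toy causal model: y_t is the mean of x_1, ..., x_t.
    out, total = [], 0.0
    for t, x in enumerate(seq, 1):
        total += x
        out.append(total / t)
    return out

# The prompt [4.0, 4.0] pulls the final output towards 4.
h = make_hypothesis(g, prompt=[4.0, 4.0], n_prime=1)
# h([2.0]) == [mean(2, 4, 4)] == [10/3]
```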

Note that each $h \in \mathcal{H}_g$ is identified by a prompt $(\bm{p}_1,\ldots,\bm{p}_N)$ but is a function whose domain is all possible queries $(\bm{q}_1,\ldots,\bm{q}_n)$. Therefore, finding a hypothesis $h \in \mathcal{H}_g$ that approximates a target function $f$ is equivalent to finding the prompt of that hypothesis. The approximation properties of $\mathcal{H}_g$ in Eq. 6 depend on the architecture of $g$, as well as on its specific parameters. This makes it challenging to do approximation in the context window: the possibilities for interaction between the inputs are limited, and the effects of the fixed model weights can be difficult to study (petrov2023prompting). To the best of our knowledge, this has only been studied in the case where $g$ is a transformer model. wang2024universality showed that in-context universal approximation is possible with a transformer by discretising and memorising all possible functions in the model weights, while petrov2024universal argue that no memorisation is needed and that a transformer with $n{+}2$ layers can be a universal approximator for sequence-to-sequence functions of input length $n$ with a prompt of length $\mathcal{O}(\epsilon^{-10-14d-4d^{2}})$.

We study the recurrent architectures in Eqs. 2, 3, 4 and 5 and their ability to approximate continuous functions over real-valued vectors and to represent discrete maps over tokens (which corresponds to how language models are used in practice). We consider the following classes of functions. $\mathcal{C}^{\text{vec}} = (\mathbb{R}^{d_{\text{out}}})^{[0,1]^{d_{\text{in}}}}$ contains all continuous functions from the unit hypercube to $\mathbb{R}^{d_{\text{out}}}$, while $\mathcal{C}^{\text{tok}} = \{h \in (\mathcal{V}^l)^{\mathcal{V}^l} \mid h \text{ causal}\}$ contains all causal functions from $l$ tokens to $l$ tokens. The hypothesis classes are $\mathcal{H}^{\text{vec}}(g)$, corresponding to Eq. 6 with $\mathcal{D} = [0,1]^{d_{\text{in}}}$, $n = n' = 1$ and $g$ some fixed model of one of the four architectures in Eqs. 2, 3, 4 and 5, and $\mathcal{H}^{\text{tok}}(g)$ with $\mathcal{D} = \mathcal{V}$ and $n = n' = l$.

3 Linear State Recurrent Language (LSRL)


Figure 1: Compilation of an LSRL program to a Linear RNN. An example of a simple LSRL program that takes a sequence of 0s and 1s as input and outputs 1 if there have been more 1s than 0s, and 0 otherwise. The LSRL compiler follows the de-branching rules (detailed in the appendix) to simplify the computation DAG into a path graph. The resulting path graph can be represented as a single-layer Linear RNN.

The LSRL program from Fig. 1:

```
ForEach:
    input = Input(dim=1)
    ctr1s = LinState(input,
        A=[[1]], B=[[1]],
        init_state=[[0]])

    is_zero = f_not(input)
    ctr0s = LinState(is_zero,
        A=[[1]], B=[[1]],
        init_state=[[0]])

    output = f_larger(
        ctr1s, ctr0s, mu=10)
    return output
```
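The behaviour of the Fig. 1 program can be mimicked numerically. The sketch below assumes plausible semantics for the library functions: `f_not(x) = 1 - x` for binary inputs, and `f_larger(a, b, mu)` as a saturated-ReLU approximation of the indicator $\mathbb{1}[a > b]$ with sharpness `mu`; the paper's exact definitions of these helpers may differ.

```python
def simulate(bits, mu=10.0):
    """Numerically mimic the Fig. 1 LSRL program on a 0/1 sequence."""
    ctr1s = ctr0s = 0.0
    out = 0.0
    for x in bits:
        ctr1s += x            # LinState with A=B=[[1]]: count of 1s so far
        ctr0s += 1.0 - x      # f_not(input) feeds a counter of 0s
        # assumed f_larger(a, b, mu): clip(mu * (a - b), 0, 1)
        out = min(1.0, max(0.0, mu * (ctr1s - ctr0s)))
    return out

# More 1s than 0s -> 1; more 0s than 1s -> 0.
assert simulate([1, 1, 0]) == 1.0
assert simulate([0, 0, 1]) == 0.0
```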

We can construct the weights for universal in-context models with the architectures in Eqs. 2, 3, 4 and 5 by hand, but this is labour-intensive, error-prone, difficult to interpret, and the specific weights would be architecture-dependent. Working at such a low level of abstraction can also obfuscate common mechanisms and design patterns, making it more difficult to appreciate both the capabilities and the constraints of fully recurrent architectures. Instead, we propose a new programming language: Linear State Recurrent Language (LSRL).² LSRL programs compile to the four architectures in Eqs. 2, 3, 4 and 5. Conversely, any Linear RNN can be represented as an LSRL program, making LSRL a versatile tool for studying the capabilities of recurrent models. Later, in Sec. 4 and the sections on gated and non-linear RNNs, we make use of LSRL to develop programs that are universal approximators for $\mathcal{C}^{\text{vec}}$ and $\mathcal{C}^{\text{tok}}$, thus showing that all four architectures can be universal in-context approximators.

² Our implementation of LSRL is available at https://github.com/AleksandarPetrov/LSRL

LSRL syntax.

An LSRL program specifies how a single element is processed and how the recurrent states are updated for the next element. LSRL programs always start with an $\texttt{Input}(\bm{x}) = \bm{x}$ with $\bm{x}$ of a fixed dimension. Only one Input can be declared in a program. Linear layers and ReLUs are also supported: $\texttt{Lin}[\bm{A},\bm{b}](\bm{x}) := \bm{A}\bm{x} + \bm{b}$ and $\texttt{ReLU}(\bm{x}) := \max(\bm{0},\bm{x})$. The unique component of LSRL, however, is its LinState operation, implementing the linear state update in Linear RNNs (Eq. 3): $\texttt{LinState}[\bm{A},\bm{B},\bm{b},\bm{s}_0](\bm{x}_t) := \bm{A}\bm{s}_{t-1} + \bm{B}\bm{x}_t + \bm{b}$, where the state $\bm{s}_{t-1}$ is the output of the call to this node at step $t-1$. LinState is the only way information can be passed from previous tokens to the current one. We also provide a Concat operation that combines variables: $\texttt{Concat}(\bm{x},\bm{y}) := (\bm{x}_1,\ldots,\bm{x}_{|\bm{x}|},\bm{y}_1,\ldots,\bm{y}_{|\bm{y}|})$. Finally, to support gating architectures, we also implement a rudimentary Multi operation that splits its input into two sub-arrays and returns their element-wise multiplication: $\texttt{Multi}(\bm{x}) := \bm{x}[{:}|\bm{x}|/2] \odot \bm{x}[|\bm{x}|/2{:}]$. Naturally, Multi requires that $\bm{x}$ has even length. These six operations can be composed into a directed acyclic graph (DAG) with a single source node (the Input variable) and a single sink node (marked with a return statement).

Such a program operates over a single token $\bm{x}_t$ passed to Input, while a recurrent model needs to operate over sequences. Thus, we wrap the program in a ForEach loop that feeds the elements to the DAG one at a time; the variable marked with the return statement is the output for that step. Each element is processed by the exact same program, the only difference being that the states of the LinState variables change between iterations. An example of a small LSRL program is shown in Fig. 1.
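To make the execution model concrete, here is a minimal sketch (illustrative only, not the authors' implementation) of the per-token semantics in scalar Python: Lin and ReLU are pure functions of the current token, while a LinState object is the only thing that carries information across iterations of the ForEach loop.

```python
# Minimal sketch (illustrative, not the authors' implementation) of
# LSRL's per-token semantics in scalar Python.

def lin(a, b, x):
    # Lin[A, b](x) := A x + b (scalar case)
    return a * x + b

def relu(x):
    # ReLU(x) := max(0, x)
    return max(0.0, x)

class LinState:
    # LinState[A, B, b, s0](x_t) := A s_{t-1} + B x_t + b
    def __init__(self, a, b_mat, bias, s0):
        self.a, self.b_mat, self.bias, self.s = a, b_mat, bias, s0

    def __call__(self, x_t):
        self.s = self.a * self.s + self.b_mat * x_t + self.bias
        return self.s

def for_each(tokens, program):
    # The same program runs on every token; only LinState objects
    # carry information between iterations.
    return [program(x) for x in tokens]

# Example: accumulate ReLU(2 x_t) over the sequence
# (state uses A = B = 1, b = 0, s0 = 0; the Lin doubles the input).
acc = LinState(1.0, 1.0, 0.0, 0.0)
outputs = for_each([1.0, -2.0, 3.0], lambda x: acc(relu(lin(2.0, 0.0, x))))
# outputs == [2.0, 2.0, 8.0]
```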

Expressiveness limitations.

ForEach does not behave like a typical for loop: only the states are accessible between iterations, i.e., you cannot use the output of a linear layer at step $t$ in any computation at step $t+1$. Furthermore, as the program is a DAG and only the states of LinState nodes are passed between iterations, variables computed in later operations of a previous time step (with respect to the topological sorting of the computation graph) are not accessible as inputs to earlier layers. This leads to a key programming paradigm in LSRL: a LinState update cannot depend non-linearly on its own state. That includes depending on a variable that itself depends on the LinState, as well as conditional updates to the state. Such a dependency would break the DAG property of the program. For example, we cannot implement an operation that adds one to the state and squares it at each time step, $s_{t+1}=(s_t+1)^2$, or an operation that performs a conditional assignment, $s_{t+1}=0$ if $s_t>5$ else $s_t$. This poses serious limitations on what algorithms can be expressed in a Linear RNN and makes programming them challenging. Still, in Sec. 4 we show how carefully constructed state updates and auxiliary variables can nevertheless express some limited conditional behaviours.
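The restriction can be seen in a small sketch (illustrative Python, not LSRL itself): the recurrence may be non-linear in the current token, since non-linearities can sit on the input path, but it must stay linear in the previous state.

```python
# Illustrative sketch of the allowed pattern: the LinState recurrence
# is linear in s_{t-1}; any non-linearity acts only on the input path.

def relu(x):
    return max(0.0, x)

def run(tokens):
    s = 0.0
    out = []
    for x in tokens:
        u = relu(x)      # OK: non-linear in the current token
        s = 1.0 * s + u  # OK: linear in the previous state
        # NOT expressible as a single LinState node:
        #   s = (s + 1.0) ** 2         (non-linear in the state)
        #   s = 0.0 if s > 5 else s    (conditional on the state)
        out.append(s)
    return out
```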

Compilation.

Any LSRL program without Multi nodes can be compiled to a Linear RNN (Eq. 3) or to a Gated Linear RNN (Eq. 4). If the program has Multi nodes, it cannot be compiled to a Linear RNN, as the multiplicative gating cannot be implemented exactly; it can, however, be compiled to a Gated Linear RNN. To compile an LSRL program to a (Gated) Linear RNN, we first parse the program to build a computation graph. This is a DAG with a single source (the Input node) and a single sink (the return statement of the ForEach loop). At the same time, a (Gated) Linear RNN can be represented as a path graph (no branching) with the six basic operations as nodes. Therefore, the compilation step needs to transform this DAG into a path graph. We achieve that by iteratively collapsing the first branching point into a single node; the exact rules are described in LABEL:sec:debranching_rules. Later, in LABEL:sec:ua_nonlinear_RNNs, we show how any (Gated) Linear RNN can be converted into a non-linear (Gated) RNN, and hence how LSRL programs can be compiled to these architectures as well.
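The simplest such rule can be sketched as follows (an assumed, simplified version of the debranching rules in the referenced appendix): two parallel Lin nodes fed by the same parent collapse into a single Lin computing the Concat of their outputs, removing one branching point from the DAG.

```python
# Sketch (assumed, simplified) of one debranching rule: two parallel
# Lin nodes with the same parent merge into a single Lin whose rows
# are the stacked rows of the originals.

def lin(A, b, x):
    # Lin[A, b](x) := A x + b, with A given as a list of rows
    return [sum(a * xj for a, xj in zip(row, x)) + bi
            for row, bi in zip(A, b)]

def collapse(A1, b1, A2, b2):
    # One Lin producing Concat(Lin[A1, b1](x), Lin[A2, b2](x))
    return A1 + A2, b1 + b2

A1, b1 = [[1.0, 0.0]], [0.5]
A2, b2 = [[0.0, 2.0]], [0.0]
x = [3.0, 4.0]
A, b = collapse(A1, b1, A2, b2)
merged = lin(A, b, x)                       # single node after collapsing
separate = lin(A1, b1, x) + lin(A2, b2, x)  # the two original branches
```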

```
ForEach:
    input = Input(dim=1 + d_in + d_out)

    const_1 = f_constant(input, 1)
    counter_vector = LinState(input=const_1, A=ones(d_in, d_in), B=ones(d_in, 1), init_state=zeros(d_in, 1))

    q_update = f_ifelse(cond=f_smaller(counter_vector, 1.5), t=input[:d_in], f=input[:d_in] * 0)
    q = LinState(input=q_update, A=eye(d_in), B=eye(d_in), init_state=zeros(d_in, 1))

    step_size = Linear(input=input[0], A=ones(d_in, 1), b=zeros(d_in, 1))

    lb = input[1:1 + d_in]
    ub = lb + step_size

    q_in_bump_componentwise = f_bump(q, lb, ub)
    bump_sum = Linear(input=q_in_bump_componentwise, A=ones(1, d_in), b=zeros(1, 1))
    in_cell = f_larger(bump_sum, d_in - 0.5)
    in_and_processing = f_and(in_cell, f_larger(counter_vector[0], 0.5))

    update = f_ifelse(cond=f_larger(in_and_processing, 0.5), t=input[-d_out:], f=input[-d_out:] * 0)
    y = LinState(input=update, A=eye(d_out), B=eye(d_out), init_state=zeros(d_out, 1))
    return y
```

Listing 1: LSRL program for in-context universal approximation of continuous functions. The inputs are $\bm{q}=[\bm{q}'^\top,\bm{0}_{d_\text{out}+1}^\top]^\top$ with $\bm{q}'\in[0,1]^{d_\text{in}}$ being the query value at which we want to evaluate the function, followed by prompt tokens describing the target function as in LABEL:eq:continous_prompt_def.

Syntactic sugar.

To make programming easier, we define several convenience functions. For instance, we can Slice variables $\bm{x}[l{:}u]$ via sparse Lin layers. We can also sum variables and multiply them element-wise with scalars (both implemented as Lin layers). For logical operations we also need step functions, which can be approximated with ReLUs: $\texttt{f\_step}[\mu](\bm{x}):=\texttt{ReLU}(\mu\bm{x})-\mu\,\texttt{ReLU}(\bm{x}-\nicefrac{1}{\mu})$, where $\mu$ is a positive constant controlling the quality of the approximation. We can also approximate bump functions (1 between $\bm{l}$ and $\bm{u}$ and 0 otherwise): $\texttt{f\_bump}[\bm{l},\bm{u},\mu](\bm{x}):=\texttt{f\_step}[\mu](\bm{x}-\bm{l})-\texttt{f\_step}[\mu](\bm{x}-\bm{u})$. Similarly, we can approximate conjunction (f_and), disjunction (f_or), negation (f_not), and comparison operators (f_larger and f_smaller). See LABEL:sec:sugar_in_lsrl for the definitions.
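Both approximations can be checked directly; a scalar Python sketch (with names mirroring the text) shows the ramp of f_step and the plateau of f_bump:

```python
# Scalar sketch of the f_step and f_bump approximations; mu controls
# how sharp the transitions are.

def relu(x):
    return max(0.0, x)

def f_step(x, mu):
    # f_step[mu](x) := ReLU(mu x) - mu ReLU(x - 1/mu):
    # ~0 for x <= 0, exactly 1 for x >= 1/mu, linear ramp in between
    return relu(mu * x) - mu * relu(x - 1.0 / mu)

def f_bump(x, l, u, mu):
    # f_bump[l, u, mu](x) := f_step[mu](x - l) - f_step[mu](x - u):
    # ~1 on [l, u] and ~0 away from it
    return f_step(x - l, mu) - f_step(x - u, mu)
```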

Critically, we also need a conditional operator that returns a value $\texttt{t}(\bm{x})$ if a certain condition is met and another value $\texttt{f}(\bm{x})$ otherwise. One way to implement this is:

$$\begin{aligned}
\texttt{f\_ifelse}[\texttt{cond},\texttt{t},\texttt{f},\lambda](\bm{x}) :={}& \texttt{ReLU}(-\lambda\,\texttt{cond}(\bm{x})+\texttt{f}(\bm{x})) + \texttt{ReLU}(-\lambda\,\texttt{f\_not}(\texttt{cond}(\bm{x}))+\texttt{t}(\bm{x})) \\
&- \texttt{ReLU}(-\lambda\,\texttt{cond}(\bm{x})-\texttt{f}(\bm{x})) - \texttt{ReLU}(-\lambda\,\texttt{f\_not}(\texttt{cond}(\bm{x}))-\texttt{t}(\bm{x})),
\end{aligned}\tag{7}$$

where $\lambda$ is a constant larger than any absolute value that $\texttt{t}(\bm{x})$ and $\texttt{f}(\bm{x})$ can attain. This construction, however, is not numerically stable (consider the case where $\texttt{cond}(\bm{x})$ is not exactly 0 but a small positive number), and we will study alternatives in LABEL:sec:ua_gated_linear_RNNs. We provide both numerical (SciPy.sparse, SciPy) and symbolic (SymPy, SymPy) backends, with the latter being crucial for programs that are not numerically stable.
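A scalar Python sketch of Eq. 7 illustrates both the construction and its instability: with $\lambda$ exceeding $|\texttt{t}|$ and $|\texttt{f}|$, a crisp 0/1 condition selects the right branch, but a condition only slightly above zero leaks the wrong value through the ReLUs.

```python
# Scalar sketch of the f_ifelse construction in Eq. 7.  lam must be
# larger than any value |t| or |f| can attain; with a crisp 0/1
# condition the wrong branch is pushed below zero and clipped away.

def relu(x):
    return max(0.0, x)

def f_not(c):
    return 1.0 - c

def f_ifelse(cond, t, f, lam):
    return (relu(-lam * cond + f) + relu(-lam * f_not(cond) + t)
            - relu(-lam * cond - f) - relu(-lam * f_not(cond) - t))

# With cond exactly 1 or 0 the operator selects t or f respectively,
# but a condition that is only slightly positive corrupts the output.
exact_true = f_ifelse(1.0, 3.0, -2.0, 10.0)   # selects t = 3.0
exact_false = f_ifelse(0.0, 3.0, -2.0, 10.0)  # selects f = -2.0
leaky = f_ifelse(0.1, 3.0, 2.0, 10.0)         # neither t nor f
```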


Constant and dynamic variables.

It is also important to distinguish between variables that can be dynamically assigned and ones that must be “baked into” the model weights as constants. Some operations can only be performed when one of the operands is a constant. For example, with a Linear RNN we cannot exactly compute the product of two variables, such as $\texttt{Lin}[\bm{A}_1,\bm{b}_1](\bm{x})\odot\texttt{Lin}[\bm{A}_2,\bm{b}_2](\bm{x})$, but we can compute the product with a fixed vector, $\bm{v}\odot\texttt{Lin}[\bm{A}_2,\bm{b}_2](\bm{x})$. This is also why $\lambda$ in Eq. 7 cannot be dynamically computed from the input $\bm{x}$. This is not the case for the gated architectures, where products of variables are possible, something we will leverage to construct more numerically stable conditional operators in LABEL:sec:ua_gated_linear_RNNs.
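The constant-product case folds directly into the Lin layer's weights, since $\bm{v}\odot(\bm{A}\bm{x}+\bm{b})=\mathrm{diag}(\bm{v})\bm{A}\bm{x}+\mathrm{diag}(\bm{v})\bm{b}$; a small Python sketch of this folding:

```python
# Sketch: an element-wise product with a *constant* vector v folds
# into the Lin layer itself, since
#   v ⊙ (A x + b) = diag(v) A x + diag(v) b.

def lin(A, b, x):
    # Lin[A, b](x) := A x + b, with A as a list of rows
    return [sum(a * xj for a, xj in zip(row, x)) + bi
            for row, bi in zip(A, b)]

def fold_constant_product(v, A, b):
    # Absorb v into the weights: row i of A and entry i of b scale by v_i.
    A2 = [[vi * a for a in row] for vi, row in zip(v, A)]
    b2 = [vi * bi for vi, bi in zip(v, b)]
    return A2, b2

v = [2.0, -1.0]
A, b = [[1.0, 0.0], [0.0, 1.0]], [1.0, 1.0]
x = [3.0, 4.0]
A2, b2 = fold_constant_product(v, A, b)
folded = lin(A2, b2, x)                           # one Lin node
direct = [vi * yi for vi, yi in zip(v, lin(A, b, x))]
```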

Prior work on encoding algorithms in model weights.

A similar approach of developing a programming language that compiles to model weights already exists for the transformer architecture: the RASP language (weiss2021thinking) and the Tracr compiler (lindner2024tracr). They were predominantly created as tools for interpretability research. In a sense, RASP is to a transformer what LSRL is to a (Gated) Linear RNN. Hence, LSRL can be used to develop benchmarks for interpretability methods for fully recurrent architectures. However, while RASP can only express a subset of transformer models, LSRL is isomorphic to the set of all (Gated) Linear RNNs (though not to the non-linear ones). That means any (Gated) Linear RNN can be represented and analysed as an LSRL program and vice versa. Hence, the limitations of what can be expressed in LSRL are also limitations of what a (Gated) Linear RNN can do. Namely: (i) we cannot have exact multiplicative interactions between inputs without multiplicative gates, and (ii) we cannot have state variable updates that depend non-linearly on their previous values, or in any way on a variable that depends on them.

4 Universal In-Context Approximation with Linear RNNs


Figure 2: Intuition behind the LSRL program for universal in-context approximation of continuous functions in Lst. 1. Our target function $f$ has input dimension $d_\text{in}=2$ and output dimension $d_\text{out}=1$. Each input dimension is split into two parts, hence $\delta=\nicefrac{1}{2}$. We illustrate an example input sequence of length 5: one query token and four prompt tokens corresponding to each of the discretisation cells. The query $(q_1,q_2)$ falls in the cell corresponding to the third prompt token. We show how the two LinState variables in the program are updated after each step, most notably how the state holding the output $y$ is updated after $\bm{p}_3$ is processed.
