Title: Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations
URL Source: https://arxiv.org/html/2303.02536
Published Time: Fri, 23 Feb 2024 01:12:49 GMT
Atticus Geiger∗♢, Zhengxuan Wu∗, Christopher Potts, Thomas Icard, and Noah D. Goodman (∗Equal contribution.)
Pr(Ai)²R Group; ♢Stanford University
{atticusg, wuzhengx, cgpotts, icard, ngoodman}@stanford.edu
Abstract
Causal abstraction is a promising theoretical framework for explainable artificial intelligence that defines when an interpretable high-level causal model is a faithful simplification of a low-level deep learning system. However, existing causal abstraction methods have two major limitations: they require a brute-force search over alignments between the high-level model and the low-level one, and they presuppose that variables in the high-level model will align with disjoint sets of neurons in the low-level one. In this paper, we present distributed alignment search (DAS), which overcomes these limitations. In DAS, we find the alignment between high-level and low-level models using gradient descent rather than conducting a brute-force search, and we allow individual neurons to play multiple distinct roles by analyzing representations in non-standard bases, that is, distributed representations. Our experiments show that DAS can discover internal structure that prior approaches miss. Overall, DAS removes previous obstacles to uncovering conceptual structure in trained neural nets.
1 Introduction
Can an interpretable symbolic algorithm be used to faithfully explain a complex neural network model? This is a key question for interpretability; a positive answer can provide guarantees about how the model will behave, and a negative answer could lead to fundamental concerns about whether the model will be safe and trustworthy.
Causal abstraction provides a mathematical framework for precisely characterizing what it means for any complex causal system (e.g., a deep learning model) to implement a simpler causal system (e.g., a symbolic algorithm) (Rubenstein et al., 2017; Beckers et al., 2019; Massidda et al., 2023). For modern AI models, the fundamental operation for assessing whether this relationship holds in practice has been the interchange intervention (also known as activation patching), in which a neural network is provided a ‘base’ input, and sets of neurons are forced to take on the values they would have if different ‘source’ inputs were processed (Geiger et al., 2020; Vig et al., 2020; Finlayson et al., 2021; Meng et al., 2022). The counterfactuals that these interventions create are the basis for causal inferences about model behavior.
Geiger et al. (2021) show that the relevant causal abstraction relation obtains when interchange interventions on aligned high-level variables and low-level variables have equivalent effects. This ideal relationship rarely obtains in practice, but the proportion of interchange interventions with the same effect (interchange intervention accuracy; IIA) provides a graded notion, and Geiger et al. (2023) formally ground this metric in the theory of approximate causal abstraction. They also use causal abstraction theory as a unified framework for a wide range of recent intervention-based analysis methods (Vig et al., 2020; Csordás et al., 2021; Feder et al., 2021; Ravfogel et al., 2020; Elazar et al., 2020; De Cao et al., 2021; Abraham et al., 2022; Olah et al., 2020; Olsson et al., 2022; Chan et al., 2022).
Causal abstraction techniques have been applied to diverse problems (Geiger et al., 2019, 2020; Li et al., 2021; Huang et al., 2022). However, previous applications have faced two central challenges. First, causal abstraction requires a computationally intensive brute-force search process to find optimal alignments between the variables in the high-level model and the states of the low-level one. Where exhaustive search is intractable, we risk missing the best alignment entirely. Second, these prior methods are localist: they artificially limit the space of possible alignments by presupposing that high-level causal variables will be aligned with disjoint groups of neurons. There is no reason to assume this a priori, and indeed much recent work in model explanation (see especially Ravfogel et al. 2020, 2022; Elazar et al. 2020; Olah et al. 2020; Olsson et al. 2022) is converging on the insight of Smolensky (1986), Rumelhart et al. (1986), and McClelland et al. (1986) that individual neurons can play multiple conceptual roles. Smolensky (1986) identified distributed neural representations as “patterns” consisting of linear combinations of unit vectors.
In the current paper, we propose distributed alignment search (DAS), which overcomes the above limitations of prior causal abstraction work. In DAS, we find the best alignment via gradient descent rather than conducting a brute-force search. In addition, we use distributed interchange interventions, which are “soft” interventions in which the causal mechanisms of a group of neurons are edited such that (1) their values are rotated with a change-of-basis matrix, (2) the targeted dimensions of the rotated neural representation are fixed to the corresponding values in the rotated neural representation created for the source inputs, and (3) the representation is rotated back to the standard neuron-aligned basis. The key insight is that viewing a neural representation through an alternative basis that is not aligned with individual neurons can reveal interpretable dimensions (Smolensky, 1986).
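The three steps of a distributed interchange intervention can be illustrated on a single pair of hidden vectors. The following is a minimal NumPy sketch, not the authors' implementation: the function name `distributed_interchange` and its `dims` argument (indices of the targeted rotated dimensions) are hypothetical, and the change-of-basis matrix is assumed to be orthogonal so that its transpose is its inverse.

```python
import numpy as np


def distributed_interchange(base_h, source_h, R, dims):
    """Single-source distributed interchange intervention on hidden vectors.

    base_h, source_h: hidden vectors produced by the base and source inputs.
    R: orthogonal change-of-basis matrix (rotated = R @ h).
    dims: indices of the rotated dimensions to take from the source.
    """
    rotated_base = R @ base_h                   # (1) rotate the base representation
    rotated_source = R @ source_h               #     and the source representation
    rotated_base[dims] = rotated_source[dims]   # (2) fix targeted dims to source values
    return R.T @ rotated_base                   # (3) rotate back (R.T equals R^-1 here)


# Usage: overwrite the first rotated dimension under a 30-degree rotation.
theta = np.deg2rad(30.0)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
print(distributed_interchange(np.array([1.0, 0.0]), np.array([0.0, 1.0]), R, [0]))
```

With `R` set to the identity matrix, the sketch reduces to an ordinary localist interchange intervention that copies the selected neurons from the source run.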
In our experiments, we evaluate the capabilities of DAS to provide faithful and interpretable explanations on two tasks that have obvious, interpretable high-level algorithmic solutions with two intermediate variables. In both tasks, the distributed alignment learned by DAS is as good as or better than both the closest localist alignment and the best localist alignment found in a brute-force search.
In our first set of experiments, we focus on a hierarchical equality task that has been used extensively in developmental and cognitive psychology as a test of relational reasoning (Premack, 1983; Thompson et al., 1997; Geiger et al., 2022a): the inputs are sequences $[w,x,y,z]$, and the label is given by $(w=x)=(y=z)$. We train a simple feed-forward neural network on this task and show that it perfectly solves the task. Our key question: does this model implement a program that computes $w=x$ and $y=z$ as intermediate values, as we might hypothesize humans do? Using DAS, we find a distributed alignment with 100% IIA. In other words, the network is perfectly abstracted by the high-level model; the distinction between the learned neural model and the symbolic algorithm is thus one of implementation.
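The hypothesized high-level algorithm for hierarchical equality, with its two intermediate variables, can be written down directly. This is a minimal sketch that assumes the inputs are arbitrary hashable Python values; the names `equality_model` and `interchange` are hypothetical, and the sketch covers only the symbolic model, not the trained network.

```python
def equality_model(w, x, y, z, intervention=None):
    """High-level model: V1 = (w == x), V2 = (y == z), O = (V1 == V2).

    `intervention` optionally fixes V1 and/or V2, e.g. {"V1": True}.
    """
    intervention = intervention or {}
    v1 = intervention.get("V1", w == x)
    v2 = intervention.get("V2", y == z)
    return {"V1": v1, "V2": v2, "O": v1 == v2}


def interchange(base, source, var):
    """Output on `base` when `var` takes the value it has on `source`."""
    source_value = equality_model(*source)[var]
    return equality_model(*base, intervention={var: source_value})["O"]


base = ("a", "a", "b", "c")      # V1 = True, V2 = False, so O = False
source = ("d", "e", "f", "f")    # V1 = False
print(equality_model(*base)["O"], interchange(base, source, "V1"))  # False True
```

The interchange on $V_1$ flips the label in this example because the source makes the first pair unequal, so the two intermediate values now agree.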
Our second task models a natural language inference dataset (Geiger et al., 2020) where the inputs are premise and hypothesis sentences $(p,h)$ that are identical but for the words $w_p$ and $w_h$; the label is either entails ($p$ makes $h$ true) or contradicts/neutral ($p$ does not make $h$ true). We fine-tune a pretrained language model to perfectly solve the task. With DAS, we find a perfect alignment (100% IIA) to a causal model with a binary variable for the entailment relation between the words $w_p$ and $w_h$ (e.g., dog entails mammal).
In both our sets of experiments, the DAS analyses reveal perfect abstraction relations. However, we also identify an important difference between them. In the NLI case, the entailment relation can be decomposed into representations of $w_p$ and $w_h$. What appears to be a representation of lexical entailment is, in this case, a “data structure” containing two representations of word identity, rather than an encoding of their entailment relation. By contrast, the hierarchical equality models learn representations of $w=x$ and $y=z$ that cannot be decomposed into representations of $w$, $x$, $y$, and $z$. In other words, these relations are entirely abstracted from the entities participating in the relation; DAS reveals that the neural network truly implements a symbolic, tree-structured algorithm.
2 Related Work
A theory of causal abstraction specifies exactly when a ‘high-level causal model’ can be seen as an abstract characterization of some ‘low-level causal model’ (Iwasaki and Simon, 1994; Chalupka et al., 2017; Rubenstein et al., 2017; Beckers et al., 2019). The basic idea is that high-level variables are associated with (potentially overlapping) sets of low-level variables that summarize their causal mechanisms with respect to a set of hard or soft interventions (Massidda et al., 2023). In practice, a graded notion of approximate causal abstraction is often more useful (Beckers et al., 2019; Rischel and Weichwald, 2021; Geiger et al., 2023).
Geiger et al. (2023) argue that causal abstraction is a generic theoretical framework for providing faithful (Jacovi and Goldberg, 2020; Lyu et al., 2022) and interpretable (Lipton, 2018) explanations of AI models and show that LIME (Ribeiro et al., 2016), causal effect estimation (Abraham et al., 2022; Feder et al., 2021), causal mediation analysis (Vig et al., 2020; Csordás et al., 2021; De Cao et al., 2021), iterated nullspace projection (Ravfogel et al., 2020; Elazar et al., 2020), and circuit-based explanations (Olah et al., 2020; Olsson et al., 2022; Wang et al., 2022; Chan et al., 2022) can all be understood as causal abstraction analysis.
Interchange intervention training (IIT) objectives are minimized when a high-level causal model is an abstraction of a neural network under a given alignment (Geiger et al., 2022b; Wu et al., 2022; Huang et al., 2022). In this paper, we use IIT objectives to learn an alignment between a high-level causal model and a deep learning model.
3 Methods
We focus on acyclic causal models (Pearl, 2001; Spirtes et al., 2000) and seek to provide an intuitive overview of our method. An acyclic causal model consists of input, intermediate, and output variables, where each variable has an associated set of values it can take on and a causal mechanism that determines the value of the variable based on the values of its causal parents. For a simple running example, we modify the boolean conjunction models of Geiger et al. (2022b) to reveal key properties of DAS. A causal model $\mathcal{B}$ for this problem can be defined as below, where the inputs and outputs are the booleans t and f. Alongside $\mathcal{B}$, we also define a causal model $\mathcal{N}$ of a linear feed-forward neural network that solves the task. Here we show $\mathcal{B}$, $\mathcal{N}$, and the parameters of $\mathcal{N}$:
$$W_1 = \begin{bmatrix} \cos(20^{\circ}) & -\sin(20^{\circ}) \end{bmatrix} \qquad \mathbf{w} = \begin{bmatrix} 1 & 1 \end{bmatrix} \qquad W_2 = \begin{bmatrix} \sin(20^{\circ}) & \cos(20^{\circ}) \end{bmatrix} \qquad b = -1.8$$
The model $\mathcal{N}$ predicts t if $O>0$ and f otherwise. This network solves the boolean conjunction problem perfectly in that all pairs of input boolean values are mapped to the intended output.
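The network $\mathcal{N}$ is small enough to check directly. The sketch below assumes that $h_1 = W_1\mathbf{x}$, $h_2 = W_2\mathbf{x}$, and $O = \mathbf{w}\cdot[h_1,h_2] + b$; this wiring is inferred from the parameters above and is consistent with the intermediate values reported in Section 3.1. The function name `network` is illustrative.

```python
import numpy as np

THETA = np.deg2rad(20.0)
W1 = np.array([np.cos(THETA), -np.sin(THETA)])  # weights into h1
W2 = np.array([np.sin(THETA), np.cos(THETA)])   # weights into h2
w = np.array([1.0, 1.0])                        # hidden-to-output weights
b = -1.8                                        # output bias


def network(x):
    """Forward pass of the linear network N; returns (h1, h2, O)."""
    x = np.asarray(x, dtype=float)
    h = np.array([W1 @ x, W2 @ x])   # no activation function, so N is linear
    return h[0], h[1], float(w @ h + b)


for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    h1, h2, out = network(x)
    print(x, round(float(h1), 2), round(float(h2), 2), round(out, 2), "t" if out > 0 else "f")
```

Only the input (1, 1) yields a positive output, and its margin is small (about 0.08), which is why the rotated-basis analysis later in this section is sensitive to exactly which values are swapped.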
An input $\mathbf{x}$ of a model $\mathcal{M}$ determines a unique total setting $\mathcal{M}(\mathbf{x})$ of all the variables in the model. The inputs are fixed to $\mathbf{x}$, and the causal mechanisms of the model determine the values of the remaining variables. We denote the values that $\mathcal{M}(\mathbf{x})$ assigns to the variable or variables $\mathbf{Z}$ as $\mathsf{GetVals}_{\mathbf{Z}}(\mathcal{M}(\mathbf{x}))$. For example, $\mathsf{GetVals}_{V_3}(\mathcal{B}([\textsc{t},\textsc{f}])) = \textsc{f}$.
3.1 Interventions
Interventions are a fundamental building block of causal models, and of causal abstraction analysis in particular. An intervention $\mathbf{I}\leftarrow\mathbf{i}$ is a setting $\mathbf{i}$ of variables $\mathbf{I}$. Together, an intervention and an input setting $\mathbf{x}$ of a model $\mathcal{M}$ determine a unique total setting that we denote as $\mathcal{M}_{\mathbf{I}\leftarrow\mathbf{i}}(\mathbf{x})$. The inputs are fixed to $\mathbf{x}$, and the causal mechanisms of the model determine the values of the non-intervened variables, with the intervened variables $\mathbf{I}$ fixed to $\mathbf{i}$.
We can define interventions on both our causal model $\mathcal{B}$ and our neural model $\mathcal{N}$. For example, $\mathcal{B}_{V_1\leftarrow\textsc{t}}([\textsc{f},\textsc{t}])$ is our boolean model when it processes input $[\textsc{f},\textsc{t}]$ but with variable $V_1$ set to t. This has the effect of changing the output value to t. Similarly, whereas $\mathcal{N}([0,1])$ leads to intermediate values $h_1=-0.34$ and $h_2=0.94$ and output value $-1.2$, if we compute $\mathcal{N}_{h_1\leftarrow 1.34}([0,1])$, then the output value is $0.48$. This has the effect of changing the predicted value to t, because $0.48>0$.
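Both interventions above, and the $\mathsf{GetVals}$ example before them, can be reproduced in a few lines. This sketch represents a total setting as a Python dictionary and assumes that $\mathcal{B}$ has mechanisms $V_1 = P$, $V_2 = Q$, $V_3 = V_1 \wedge V_2$ (a reading inferred from the conjunction task and the examples in the text) and that $\mathcal{N}$ computes $h = W\mathbf{x}$ and $O = h_1 + h_2 - 1.8$. The helpers `run_B`, `run_N`, and `get_vals` are hypothetical names; `get_vals` plays the role of $\mathsf{GetVals}$.

```python
import numpy as np

T, F = True, False  # stand-ins for the boolean values t and f


def run_B(p, q, intervention=None):
    """Total setting of B on [p, q]; assumes V1 = P, V2 = Q, V3 = V1 and V2."""
    iv = intervention or {}
    s = {"P": p, "Q": q}
    s["V1"] = iv.get("V1", s["P"])
    s["V2"] = iv.get("V2", s["Q"])
    s["V3"] = iv.get("V3", s["V1"] and s["V2"])
    return s


def get_vals(setting, variables):
    """GetVals_Z: the values that a total setting assigns to the variables Z."""
    return {v: setting[v] for v in variables}


THETA = np.deg2rad(20.0)
W = np.array([[np.cos(THETA), -np.sin(THETA)],   # row W1 -> h1
              [np.sin(THETA), np.cos(THETA)]])   # row W2 -> h2


def run_N(x, intervention=None):
    """Total setting of N on x, optionally clamping h1 and/or h2."""
    iv = intervention or {}
    h = W @ np.asarray(x, dtype=float)
    h1, h2 = iv.get("h1", h[0]), iv.get("h2", h[1])
    return {"h1": h1, "h2": h2, "O": float(h1 + h2 - 1.8)}  # w = [1, 1], b = -1.8


print(get_vals(run_B(T, F), ["V3"]))               # {'V3': False}
print(get_vals(run_B(F, T, {"V1": T}), ["V3"]))    # {'V3': True}
print(round(run_N([0, 1])["O"], 2))                # -1.2
print(round(run_N([0, 1], {"h1": 1.34})["O"], 2))  # 0.48
```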
3.2 Alignment
In causal abstraction analysis, we ask whether a specific low-level model like $\mathcal{N}$ implements a high-level algorithm like $\mathcal{B}$. This is always relative to a specific alignment of variables between the two models. An alignment $\Pi = (\{\Pi_X\}_X, \{\tau_X\}_X)$ assigns to each high-level variable $X$ a set of low-level variables $\Pi_X$ and a function $\tau_X$ that maps from values of the low-level variables in $\Pi_X$ to values of the aligned high-level variable $X$. One possible alignment between $\mathcal{B}$ and $\mathcal{N}$ is shown in the diagram above: $\Pi$ is depicted by the dashed lines connecting $\mathcal{B}$ and $\mathcal{N}$.
We immediately know what the functions for the high-level input and output variables are. For the inputs, t is encoded as $1$ and f is encoded as $0$, meaning $\tau_P(1)=\tau_Q(1)=\textsc{t}$ and $\tau_P(0)=\tau_Q(0)=\textsc{f}$. For the output, the network predicts t only if $O>0$, meaning $\tau_{V_3}(o)=\textsc{t}$ if $o>0$, else f. This is simply a consequence of how a neural network is used and trained. The functions $\tau_{V_1}$ and $\tau_{V_2}$ for the high-level intermediate variables must be discovered and verified experimentally.
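The alignment functions for the input and output variables can be written out and checked against the behavior of both models. This is a minimal sketch that assumes the same network wiring as before ($O = \mathbf{w}\cdot[W_1\mathbf{x}, W_2\mathbf{x}] + b$) and the conjunction mechanism for $\mathcal{B}$; `tau_input`, `tau_V3`, and `network_output` are hypothetical names. It checks only the input and output parts of the alignment, since $\tau_{V_1}$ and $\tau_{V_2}$ are exactly what remains to be discovered.

```python
import numpy as np

THETA = np.deg2rad(20.0)
W = np.array([[np.cos(THETA), -np.sin(THETA)],   # row W1 -> h1
              [np.sin(THETA), np.cos(THETA)]])   # row W2 -> h2


def network_output(x):
    """Output O of N, with w = [1, 1] and b = -1.8."""
    return float(np.ones(2) @ (W @ np.asarray(x, dtype=float)) - 1.8)


tau_input = {1: True, 0: False}   # tau_P = tau_Q: 1 -> t, 0 -> f


def tau_V3(o):
    """Output alignment: t exactly when O > 0."""
    return o > 0


for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    p, q = tau_input[x[0]], tau_input[x[1]]
    # tau applied to N's output must match B's output on the translated input.
    assert tau_V3(network_output(x)) == (p and q)
print("N and B agree on all four inputs under tau")
```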
3.3 Constructive Causal Abstraction
Relative to an alignment like this, we can define abstraction:
Definition 3.1.
(Constructive Causal Abstraction) A high-level causal model $\mathcal{H}$ is a constructive abstraction of a low-level causal model $\mathcal{L}$ under alignment $\Pi$ exactly when the following holds for every low-level input setting $\mathbf{x}$ and low-level intervention $\mathbf{I}\leftarrow\mathbf{i}$:
$$\tau\big(\mathcal{L}_{\mathbf{I}\leftarrow\mathbf{i}}(\mathbf{x})\big) = \mathcal{H}_{\tau(\mathbf{I}\leftarrow\mathbf{i})}\big(\tau(\mathbf{x})\big)$$
$\mathcal{H}$ being a causal abstraction of $\mathcal{L}$ under $\Pi$ guarantees that the causal mechanism for each high-level variable $X$ is a faithful rendering of the causal mechanisms for the low-level variables in $\Pi_X$.
To assess the degree to which a high-level model is a constructive causal abstraction of a low-level model, we perform interchange interventions:
Definition 3.2.
(Interchange Interventions) Given source input settings $\{\mathbf{s}_j\}_1^k$ and non-overlapping sets of intermediate variables $\{\mathbf{X}_j\}_1^k$ for model $\mathcal{M}$, define the interchange intervention as the model
$$\textsc{II}\big(\mathcal{M}, \{\mathbf{s}_j\}_1^k, \{\mathbf{X}_j\}_1^k\big) = \mathcal{M}_{\bigwedge_{j=1}^{k} \langle \mathbf{X}_j \leftarrow \mathsf{GetVals}_{\mathbf{X}_j}(\mathcal{M}(\mathbf{s}_j)) \rangle}$$
where $\bigwedge_{j=1}^{k}\langle\cdot\rangle$ concatenates a set of interventions.
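Definition 3.2 can be implemented for any model whose mechanisms are evaluated in topological order. In this minimal sketch, a causal model is a list of `(variable, mechanism)` pairs, each mechanism reads the partial total setting, and the interchange intervention is returned as a function of the base input; this representation and the names `solve` and `interchange_intervention` are assumptions made for illustration. The example instantiates $\mathcal{B}$ with the inferred mechanisms $V_1 = P$, $V_2 = Q$, $V_3 = V_1 \wedge V_2$.

```python
def solve(model, inputs, intervention=None):
    """Total setting M_{I <- i}(x): inputs fixed, intervened variables clamped."""
    intervention = intervention or {}
    setting = dict(inputs)
    for var, mechanism in model:            # mechanisms in topological order
        if var in intervention:
            setting[var] = intervention[var]
        else:
            setting[var] = mechanism(setting)
    return setting


def interchange_intervention(model, sources, var_sets):
    """II(M, {s_j}, {X_j}): clamp each X_j to GetVals_{X_j}(M(s_j))."""
    seen = set()
    intervention = {}
    for source, variables in zip(sources, var_sets):
        assert seen.isdisjoint(variables), "variable sets must not overlap"
        seen |= set(variables)
        source_setting = solve(model, source)                      # M(s_j)
        intervention.update({v: source_setting[v] for v in variables})
    # The intervened model is returned as a function of the base input.
    return lambda base: solve(model, base, intervention)


B = [
    ("V1", lambda s: s["P"]),
    ("V2", lambda s: s["Q"]),
    ("V3", lambda s: s["V1"] and s["V2"]),
]

counterfactual = interchange_intervention(B, [{"P": True, "Q": True}], [{"V1"}])
print(counterfactual({"P": False, "Q": True})["V3"])  # True
```

Returning a function of the base input mirrors the text: the interchange intervention defines a new model, and a base input is fed into it afterwards.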
A base input setting can be fed into the resulting model to compute the counterfactual output value. Consider the following interchange intervention:
$$\textsc{II}\big(\mathcal{B}, \{[\textsc{t},\textsc{t}]\}, \{\{V_1\}\}\big) = \mathcal{B}_{\{V_1\} \leftarrow \mathsf{GetVals}_{\{V_1\}}(\mathcal{B}([\textsc{t},\textsc{t}]))}$$
We process a base input and a source input, and then we intervene on a target variable, replacing its value with the value obtained by processing the source. Our causal model is fully known, so we know ahead of time that this interchange intervention yields t on a base input such as $[\textsc{f},\textsc{t}]$. For our neural network, the corresponding behavior is not known ahead of time. The interchange intervention corresponding to the above (according to the alignment we are exploring) is as follows:
$$\textsc{II}\big(\mathcal{N}, \{[1,1]\}, \{\{H_1\}\}\big) = \mathcal{N}_{\{H_1\} \leftarrow \mathsf{GetVals}_{\{H_1\}}(\mathcal{N}([1,1]))}$$
And, indeed, the counterfactual behaviors of $\mathcal{B}$ and the network $\mathcal{N}$ are unequal.
Under the given alignment, the interchange interventions at the low and high levels have different effects. Thus, we have a counterexample to constructive abstraction as given in Definition 3.1. Although $\mathcal{N}$ has perfect behavioral accuracy, its accuracy under the counterfactuals created by our interventions is not perfect, and thus $\mathcal{B}$ is not a constructive abstraction of $\mathcal{N}$ under this alignment.
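The failure of the localist alignment can be quantified with interchange intervention accuracy, the proportion of interchange interventions that have the same effect at both levels. This sketch enumerates all 16 (base, source) pairs for a single-neuron swap; weighting every pair equally is an assumption, as are the network wiring ($h = W\mathbf{x}$, $O = h_1 + h_2 - 1.8$) and the mechanisms of $\mathcal{B}$ used in the earlier examples. The function name `localist_iia` is hypothetical.

```python
from itertools import product

import numpy as np

THETA = np.deg2rad(20.0)
W = np.array([[np.cos(THETA), -np.sin(THETA)],   # row W1 -> H1
              [np.sin(THETA), np.cos(THETA)]])   # row W2 -> H2
INPUTS = [(0, 0), (0, 1), (1, 0), (1, 1)]


def localist_iia(unit):
    """IIA of aligning V1 with H1 (unit=0) or V2 with H2 (unit=1).

    Returns the fraction of (base, source) pairs on which the low- and
    high-level interchange interventions agree, plus the failing pairs.
    """
    matches, failures = 0, []
    for base, source in product(INPUTS, INPUTS):
        h = W @ np.array(base, dtype=float)
        h[unit] = (W @ np.array(source, dtype=float))[unit]  # swap one neuron
        low = h.sum() - 1.8 > 0                 # O = w . h + b, then tau_V3
        v = [bool(base[0]), bool(base[1])]      # B: V1 = P, V2 = Q
        v[unit] = bool(source[unit])            # high-level interchange
        high = v[0] and v[1]                    # V3 = V1 and V2
        if low == high:
            matches += 1
        else:
            failures.append((base, source))
    return matches / len(INPUTS) ** 2, failures


print(localist_iia(0))  # (0.9375, [((0, 1), (1, 1))])
```

The single mismatch is the base $[0,1]$ with source $[1,1]$: the high-level model yields t, but the swapped network output stays negative (about $-0.26$). Swapping $H_2$ for $V_2$ likewise fails on exactly one pair.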
3.4 Distributed Interventions
The above conclusion is based on the kind of localist causal abstraction explored in the literature to date. As noted in Section 1, there are two risks associated with this conclusion: (1) we may have chosen a suboptimal alignment, and (2) we may be wrong to assume that the relevant structure will be encoded in the standard basis we have implicitly assumed throughout.
If we simply rotate the representation $[H_1,H_2]$ by $-20^{\circ}$ to get a new representation $[Y_1,Y_2]$, then the resulting network has perfect behavioral and counterfactual accuracy when we align $V_1$ and $V_2$ with $Y_1$ and $Y_2$. What this reveals is that there is an alignment, but not in the basis we chose. Since the choice of basis was arbitrary, our negative conclusion about the causal abstraction relation was spurious.
This rotation localizes the information about the first and second arguments into separate dimensions. To understand this, observe that the weight matrix of the linear network rotates a two-dimensional vector by $20^{\circ}$, while the rotation matrix rotates the representation by $340^{\circ}$ (equivalently, $-20^{\circ}$). The two matrices are therefore inverses. Because the network has no activation function, rotating the hidden representation “undoes” the transformation of the input by the weight matrix. Under this non-standard basis, the first hidden dimension is equal to the first input argument and the second hidden dimension is equal to the second input argument.
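The inverse relationship between the weight matrix and the $-20^{\circ}$ rotation can be confirmed numerically. In this sketch, `rotation` is a hypothetical helper that builds a standard 2x2 counterclockwise rotation matrix, and the stacked weight matrix with rows $W_1$ and $W_2$ is assumed to act on column vectors as $H = W\mathbf{x}$.

```python
import numpy as np


def rotation(deg):
    """2x2 matrix rotating column vectors counterclockwise by `deg` degrees."""
    t = np.deg2rad(deg)
    return np.array([[np.cos(t), -np.sin(t)],
                     [np.sin(t), np.cos(t)]])


W = rotation(20.0)    # rows are W1 and W2, so the hidden layer is H = W x
R = rotation(-20.0)   # the change of basis; same matrix as a 340-degree rotation

assert np.allclose(rotation(340.0), R)
assert np.allclose(R @ W, np.eye(2))          # the two matrices are inverses

for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    Y = R @ (W @ np.array(x, dtype=float))    # hidden representation in the new basis
    print(x, np.round(Y, 6))                  # Y1 equals x1 and Y2 equals x2
```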
This reveals an essential aspect of distributed neural representations: there is a many-to-many mapping between neurons and concepts, and thus multiple high-level causal variables might be encoded in structures from overlapping groups of neurons (Rumelhart et al., 1986; McClelland et al., 1986). In particular, Smolensky (1986) proposes that viewing a neural representation under a basis that is not aligned with individual neurons can reveal the interpretable distributed structure of the neural representations.
Figure 1: A generic multi-source distributed interchange intervention. The base input and two source inputs create three total settings of a model. The top left (green) and top right (blue) total model settings are determined by the two source inputs, and the middle total model setting (red) is determined by the base input. Three hidden units from each total setting are rotated with an orthogonal matrix $\mathbf{R}:\mathbf{X}\to\mathbf{Y}$. Then we intervene on the rotated representation for the base input and fix two dimensions to the values they take on for each source input, respectively. Then we unrotate the representation with $\mathbf{R}^{-1}$ and compute a counterfactual total model setting for the base input. In DAS, the orthogonal matrix is found with gradient descent, using a high-level causal model to guide the search process.
To make good on this intuition, we define a distributed interchange intervention, which first transforms a set of variables into a vector space, then performs interchanges on orthogonal subspaces, and finally transforms the result back to the original representation space.
Definition 3.3.
(Distributed Interchange Interventions) We begin with a causal model $\mathcal{M}$ with input variables $\mathbf{S}$ and source input settings $\{\mathbf{s}_j\}_{j=1}^{k}$. Let $\mathbf{N}$ be a subset of variables in $\mathcal{M}$, the target variables. Let $\mathbf{Y}$ be a vector space with subspaces $\{\mathbf{Y}_j\}_0^k$ that form an orthogonal decomposition, i.e., $\mathbf{Y}=\bigoplus_{j=0}^{k}\mathbf{Y}_j$. Let $\mathbf{R}$ be an invertible function $\mathbf{R}:\mathbf{N}\to\mathbf{Y}$. Write $\mathsf{Proj}_{\mathbf{Y}_j}$ for the orthogonal projection operator of a vector in $\mathbf{Y}$ onto subspace $\mathbf{Y}_j$.¹ (¹Thus, $\mathsf{Proj}$ generalizes $\mathsf{GetVals}$ to arbitrary vector spaces.)
A distributed interchange intervention yields a new model $\textsc{DII}(\mathcal{M},\mathbf{R},\{\mathbf{s}_j\}_1^k,\{\mathbf{Y}_j\}_0^k)$, which is identical to $\mathcal{M}$ except that the mechanisms $F_{\mathbf{N}}$ (which yield values of $\mathbf{N}$ from a total setting) are replaced by:
$$F^{*}_{\mathbf{N}}(\mathbf{v}) = \mathbf{R}^{-1}\Big(\mathsf{Proj}_{\mathbf{Y}_0}\big(\mathbf{R}(F_{\mathbf{N}}(\mathbf{v}))\big) + \sum_{j=1}^{k}\mathsf{Proj}_{\mathbf{Y}_j}\big(\mathbf{R}(F_{\mathbf{N}}(\mathcal{M}(\mathbf{s}_j)))\big)\Big).$$
Notice that in this definition the base setting is partially preserved through the intervention (in subspace 𝐘 0 subscript 𝐘 0\mathbf{Y}_{0}bold_Y start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT) and hence this is a soft intervention on 𝐍 𝐍\mathbf{N}bold_N that rewrites causal mechanisms while maintaining a causal dependence between parent and child.
Under this new alignment, the high-level interchange intervention $\textsc{II}(\mathcal{B},\{[\textsc{t},\textsc{t}]\},\{\{V_1\}\})=\mathcal{B}_{\{V_1\}\leftarrow\mathsf{GetVals}_{\{V_1\}}(\mathcal{B}([\textsc{t},\textsc{t}]))}$ is aligned with the low-level distributed interchange intervention
$$\textsc{DII}\left(\mathcal{N},\begin{bmatrix}\cos(-20^{\circ}) & -\sin(-20^{\circ})\\ \sin(-20^{\circ}) & \cos(-20^{\circ})\end{bmatrix},\{[1,1]\},\{\{Y_1\}\}\right)$$
and the counterfactual output behaviors of $\mathcal{B}$ and $\mathcal{N}$ are equal.
In what follows, we will assume that $\mathbf{X}$ are already vector spaces (which is true for neural nets) and that the functions $\mathbf{R}$ are rotation operators. In this case, the subspaces $\mathbf{Y}_j$ can be identified, without loss of generality, with those spanned by the first $|\mathbf{Y}_0|$ basis vectors for $\mathbf{Y}_0$, the next $|\mathbf{Y}_1|$ basis vectors for $\mathbf{Y}_1$, and so on. (The following methods would be well-defined for non-linear transformations, as long as they were invertible and differentiable, but efficient implementation becomes harder.)
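The replacement mechanism $F^{*}_{\mathbf{N}}$ under this block convention fits in a few lines of dependency-free Python. This is a minimal sketch, not the authors' implementation: it assumes $\mathbf{R}$ is orthogonal (so $\mathbf{R}^{-1}=\mathbf{R}^{\top}$), and the function names (`distributed_interchange`, `rotation_2d`) and the hidden vectors in the usage lines are hypothetical.

```python
import math
from typing import List

# Sketch only: names are illustrative, not taken from the authors' code.
Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product m @ v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def rotation_2d(degrees: float) -> Matrix:
    t = math.radians(degrees)
    return [[math.cos(t), -math.sin(t)], [math.sin(t), math.cos(t)]]


def distributed_interchange(r: Matrix, base: Vector, sources: List[Vector],
                            sizes: List[int]) -> Vector:
    """F*_N for one base representation.

    The first sizes[0] = |Y_0| rotated coordinates are kept from the base; the
    next sizes[j] = |Y_j| coordinates are copied from sources[j - 1].
    Concatenating the blocks equals summing the projections Proj_{Y_j}; since
    r is assumed orthogonal, its transpose undoes the rotation.
    """
    if len(sources) != len(sizes) - 1 or sum(sizes) != len(base):
        raise ValueError("need one source per Y_j (j >= 1) and sizes summing to dim N")
    rotated = [matvec(r, base)] + [matvec(r, s) for s in sources]
    mixed: Vector = []
    for j, size in enumerate(sizes):
        start = len(mixed)
        mixed.extend(rotated[j][start:start + size])
    return matvec(transpose(r), mixed)


# Usage with the -20 degree rotation from the example and made-up 2-d hidden vectors:
R = rotation_2d(-20)
print(distributed_interchange(R, base=[0.3, -1.2], sources=[[1.0, 1.0]], sizes=[1, 1]))
```

Setting `sizes=[len(base), 0]` leaves the base untouched and `sizes=[0, len(base)]` reproduces the source, which makes a quick sanity check on any implementation.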
3.5 Distributed Alignment Search
The question then arises of how to find good rotations. As discussed above, previous causal abstraction analyses of neural networks have performed brute-force search through a discrete space of hand-picked alignments. In distributed alignment search (DAS), we find an alignment between one or more high-level variables and disjoint subspaces (not necessarily subsets of neurons) of a large neural representation. We define a distributed interchange intervention training objective, use a differentiable parameterization of the space of orthogonal matrices (such as the one provided by PyTorch), and then optimize the objective with stochastic gradient descent. Crucially, the low-level and high-level models are frozen during learning, so only the alignment changes.
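DAS relies on an off-the-shelf differentiable parameterization of orthogonal matrices; in PyTorch, one such utility is `torch.nn.utils.parametrizations.orthogonal`. To keep the example dependency-free, the sketch below swaps in a different, simple parameterization: a product of Givens rotations, one per coordinate pair, whose angles are unconstrained real parameters. Every output is orthogonal by construction and every entry is a smooth function of the angles, so gradient-based training can update the angles directly. The function names are hypothetical.

```python
import math
from typing import List

# Sketch only: a Givens-product parameterization, not PyTorch's implementation.
Matrix = List[List[float]]


def givens_orthogonal(n: int, angles: List[float]) -> Matrix:
    """Map n*(n-1)/2 unconstrained angles to an n x n rotation matrix.

    Starts from the identity and right-multiplies by one Givens rotation per
    coordinate pair (p, q); each factor is orthogonal, so the product is too.
    """
    pairs = [(p, q) for p in range(n) for q in range(p + 1, n)]
    if len(angles) != len(pairs):
        raise ValueError(f"expected {len(pairs)} angles, got {len(angles)}")
    m = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for (p, q), theta in zip(pairs, angles):
        c, s = math.cos(theta), math.sin(theta)
        for row in m:  # right-multiplication only changes columns p and q
            a, b = row[p], row[q]
            row[p], row[q] = c * a + s * b, -s * a + c * b
    return m


def max_orthogonality_error(m: Matrix) -> float:
    """Largest |(M M^T - I)_ij|; close to machine precision for any angles."""
    n = len(m)
    return max(abs(sum(m[i][k] * m[j][k] for k in range(n)) - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))


Q = givens_orthogonal(3, [0.3, -1.1, 2.0])
print(max_orthogonality_error(Q))
```

In a DAS-style training loop these angles would be the only trainable parameters, while both models stay frozen.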
In the following definition, we assume that a neural network specifies an output distribution for a given input, which can then be pushed forward via an alignment function $\tau$ to a distribution over output values of the high-level model. Similarly, we can interpret even a deterministic high-level model as defining a distribution (e.g., a delta distribution) over output values. We use these distributions, after interchange intervention, to define a differentiable loss for the rotation matrix that aligns intermediate variables.
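As a concrete reading of the push-forward, here is a minimal sketch assuming finite output spaces: each high-level value receives the total probability of the low-level outputs that $\tau$ maps onto it, and a deterministic high-level output becomes a delta distribution. The three-way classifier and the map passed as `tau` in the usage lines are hypothetical.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable

# Sketch only: finite output spaces; the example classifier below is hypothetical.
Dist = Dict[Hashable, float]


def push_forward(low: Dist, tau: Callable[[Hashable], Hashable]) -> Dist:
    """P^tau(y) = sum of P(o) over low-level outputs o with tau(o) == y."""
    high: Dict[Hashable, float] = defaultdict(float)
    for value, prob in low.items():
        high[tau(value)] += prob
    return dict(high)


def delta(value: Hashable) -> Dist:
    """A deterministic high-level output viewed as a delta distribution."""
    return {value: 1.0}


# Hypothetical 3-way classifier whose classes 1 and 2 both map to True.
low_dist = {0: 0.2, 1: 0.3, 2: 0.5}
print(push_forward(low_dist, lambda c: c != 0), delta(True))
```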
Definition 3.4.
Distributed Interchange Intervention Training Objective. Begin with a low-level neural network $\mathcal{L}$ with low-level input settings $\mathbf{Inputs}_L$, a high-level algorithm $\mathcal{H}$ with high-level output settings $\mathbf{Out}_H$, and an alignment $\tau$ for their input and output variables. Suppose we want to align intermediate high-level variables $X_j\in\mathbf{Vars}_{\mathcal{H}}$ with rotated subspaces $\mathbf{Y}_j$ of a neural representation $\mathbf{N}\subset\mathbf{Vars}_{\mathcal{L}}$, using a learned rotation matrix $\mathbf{R}^{\theta}:\mathbf{N}\to\mathbf{Y}$.
In general, we can define a training objective using any differentiable loss function $\mathsf{Loss}$ that quantifies the distance between two total high-level settings:
$$\sum_{\mathbf{b},\mathbf{s}_1,\dots,\mathbf{s}_k\in\mathbf{Inputs}_L}\mathsf{Loss}\Big(\textsc{DII}(\mathcal{L},\mathbf{R}^{\theta},\{\mathbf{s}_j\}_1^k,\{\mathbf{Y}_j\}_0^k)(\mathbf{b}),\ \textsc{II}(\mathcal{H},\{\tau(\mathbf{s}_j)\}_1^k,\{\mathbf{X}_j\}_1^k)(\tau(\mathbf{b}))\Big)$$
For our experiments, we compute the cross-entropy loss $\mathsf{CE}(\cdot,\cdot)$ between the high-level output distribution $\mathbb{P}(\mathbf{out}_H\mid\mathcal{H}(\tau(\mathbf{b})))$ and the push-forward under $\tau$ of the low-level output distribution, $\mathbb{P}^{\tau}(\mathbf{out}_H\mid\mathcal{L}(\mathbf{b}))$. The overall objective is:
$$\sum_{\mathbf{b},\mathbf{s}_1,\dots,\mathbf{s}_k\in\mathbf{Inputs}_L}\mathsf{CE}\Big(\mathbb{P}\big(\mathbf{out}_H\mid\textsc{II}(\mathcal{H},\{\tau(\mathbf{s}_j)\}_1^k,\{\mathbf{X}_j\}_1^k)(\tau(\mathbf{b}))\big),\ \mathbb{P}^{\tau}\big(\mathbf{out}_H\mid\textsc{DII}(\mathcal{L},\mathbf{R}^{\theta},\{\mathbf{s}_j\}_1^k,\{\mathbf{Y}_j\}_0^k)(\mathbf{b})\big)\Big)$$
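When the high-level model is deterministic, the target in each term is a delta distribution on the counterfactual gold label, so each cross-entropy term reduces to the negative log-probability that the intervened, pushed-forward network assigns to that label. The sketch below sums the objective over precomputed distribution pairs; it assumes finite output spaces, and the clipping constant `eps` and the example numbers are our own.

```python
import math
from typing import Dict, Hashable, Iterable, Tuple

# Sketch only: distributions are precomputed dicts; eps and the numbers are illustrative.
Dist = Dict[Hashable, float]


def cross_entropy(target: Dist, predicted: Dist, eps: float = 1e-12) -> float:
    """CE(P, Q) = -sum_y P(y) log Q(y); Q is clipped at eps to avoid log(0)."""
    return -sum(p * math.log(max(predicted.get(y, 0.0), eps))
                for y, p in target.items() if p > 0)


def das_objective(pairs: Iterable[Tuple[Dist, Dist]]) -> float:
    """Sum of CE(high-level counterfactual dist, pushed-forward low-level dist)."""
    return sum(cross_entropy(high, low) for high, low in pairs)


# Two hypothetical (base, sources) examples, already intervened and pushed forward:
examples = [
    ({True: 1.0}, {True: 0.8, False: 0.2}),   # gold counterfactual label: True
    ({False: 1.0}, {True: 0.1, False: 0.9}),  # gold counterfactual label: False
]
print(das_objective(examples))  # -log(0.8) - log(0.9)
```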
While we still have discrete hyperparameters $(\mathbf{N},|\mathbf{Y}_0|,\dots,|\mathbf{Y}_k|)$, namely the target population and the dimensionality of the subspace used for each high-level variable, we can use stochastic gradient descent to determine the rotation that minimizes the loss, thus yielding the best distributed alignment between $\mathcal{L}$ and $\mathcal{H}$ that the optimization finds for those hyperparameters.
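To make the optimization concrete, here is a toy sketch under loudly stated assumptions: a 2-dimensional representation, a hypothetical frozen readout that depends only on the projection onto a direction `U`, a single high-level variable $V$, and $|\mathbf{Y}_0|=|\mathbf{Y}_1|=1$, so the discrete hyperparameters are fixed and only the rotation angle $\theta$ (in radians) is learned. It swaps SGD with autograd for full-batch gradient descent with a finite-difference gradient and backtracking; none of this is the authors' implementation.

```python
import math
from itertools import product
from typing import List, Sequence, Tuple

Vector = List[float]

# Hypothetical frozen toy model (not from the paper): N is the 2-d input itself
# and the readout logit is BETA * <h, U>. The high-level model has one variable
# V = [<x, U> > 0], and its output is V.
U = (math.cos(math.radians(20)), math.sin(math.radians(20)))
BETA = 4.0


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def intervene(theta: float, base: Vector, source: Vector) -> Vector:
    """DII with |Y_0| = |Y_1| = 1 and R = rotation by theta: keep rotated
    coordinate 0 of the base, take rotated coordinate 1 from the source,
    and rotate back with R^T."""
    r0 = (math.cos(theta), -math.sin(theta))
    r1 = (math.sin(theta), math.cos(theta))
    a, b = dot(r0, base), dot(r1, source)
    return [r0[0] * a + r1[0] * b, r0[1] * a + r1[1] * b]


def loss(theta: float, pairs: List[Tuple[Vector, Vector]]) -> float:
    """Sum over (base, source) of -log P(gold counterfactual label)."""
    total = 0.0
    for base, source in pairs:
        label = dot(source, U) > 0                     # II on V outputs V(source)
        z = BETA * dot(intervene(theta, base, source), U)
        total += softplus(-z) if label else softplus(z)
    return total


def train(theta: float, pairs, lr: float = 0.05, steps: int = 200, h: float = 1e-5) -> float:
    """Gradient descent on the angle only; the toy models stay frozen.
    Finite differences and backtracking stand in for autograd and SGD."""
    current = loss(theta, pairs)
    for _ in range(steps):
        grad = (loss(theta + h, pairs) - loss(theta - h, pairs)) / (2 * h)
        step = lr
        while step > 1e-8:                             # accept only improving steps
            candidate = theta - step * grad
            candidate_loss = loss(candidate, pairs)
            if candidate_loss < current:
                theta, current = candidate, candidate_loss
                break
            step /= 2
    return theta


points = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [0.7, -0.4], [-0.3, 0.9]]
pairs = list(product(points, repeat=2))
theta0 = 0.0
theta1 = train(theta0, pairs)
print(loss(theta0, pairs), loss(theta1, pairs), math.degrees(theta1))
```

In this toy, the intervention transfers $V$ exactly when row 1 of the rotation equals $\pm$`U` ($\theta$ = 70° or 250°). Backtracking guarantees only that the loss never increases; it does not guarantee reaching such an angle.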
3.6 Approximate Causal Abstraction
Perfect causal abstraction relationships are unlikely to arise for neural networks trained to solve complex empirical tasks. We use a graded notion of accuracy:
Definition 3.5.
Distributed Interchange Intervention Accuracy. Consider low-level and high-level causal models $\mathcal{L}$ and $\mathcal{H}$ with alignment $(\Pi,\tau)$, a rotation $\mathbf{R}:\mathbf{N}\to\mathbf{Y}$, and an orthogonal decomposition $\{\mathbf{Y}_j\}_0^k$. Let $\mathbf{Inputs}_L$ be the low-level input settings and $\{\mathbf{X}_j\}_1^k$ the high-level intermediate variables. With $[\cdot]$ equal to 1 if the enclosed equality holds and 0 otherwise, the interchange intervention accuracy (IIA) is:
$$\sum_{\mathbf{b},\mathbf{s}_1,\dots,\mathbf{s}_k\in\mathbf{Inputs}_L}\frac{1}{|\mathbf{Inputs}_L|^{k+1}}\Big[\tau\big(\textsc{DII}(\mathcal{L},\mathbf{R},\{\mathbf{s}_j\}_1^k,\{\mathbf{Y}_j\}_0^k)(\mathbf{b})\big)=\textsc{II}(\mathcal{H},\{\tau(\mathbf{s}_j)\}_1^k,\{\mathbf{X}_j\}_1^k)(\tau(\mathbf{b}))\Big]$$
IIA is the proportion of aligned interchange interventions that have equivalent high-level and low-level effects. In our example with $\mathcal{N}$ and $\mathcal{A}$, IIA is 100% and the high-level model is a perfect abstraction of the low-level model (Def. 3.1). When IIA is $\alpha<100\%$, we rely on the graded notion of $\alpha$-on-average approximate causal abstraction (Geiger et al., 2023), which coincides with IIA.
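For $k=1$ and a small input set, Def. 3.5 can be computed by exhaustive enumeration of (base, source) pairs. The sketch below assumes a 2-dimensional representation, $\tau$ equal to the identity, and a hypothetical toy whose frozen readout thresholds the projection onto the second row of the $-20^{\circ}$ rotation; this toy is ours, not the network $\mathcal{N}$ from the running example.

```python
import math
from itertools import product
from typing import Callable, List, Sequence

# Sketch only: k = 1, tau = identity, and a hypothetical 2-d toy model.
Vector = List[float]
Matrix = List[List[float]]


def rotation_2d(degrees: float) -> Matrix:
    t = math.radians(degrees)
    return [[math.cos(t), -math.sin(t)], [math.sin(t), math.cos(t)]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def dii_2d(r: Matrix, base: Vector, source: Vector) -> Vector:
    """Keep rotated coordinate 0 (Y_0) of the base, take rotated coordinate 1
    (Y_1) from the source, and rotate back with r^T (r assumed orthogonal)."""
    a, b = dot(r[0], base), dot(r[1], source)
    return [r[0][i] * a + r[1][i] * b for i in range(2)]


def iia(r: Matrix, inputs: List[Vector], low_readout: Callable[[Vector], bool],
        high_ii: Callable[[Vector, Vector], bool]) -> float:
    """Fraction of (base, source) pairs on which the aligned interventions agree."""
    pairs = list(product(inputs, repeat=2))
    hits = sum(low_readout(dii_2d(r, b, s)) == high_ii(b, s) for b, s in pairs)
    return hits / len(pairs)


R = rotation_2d(-20)
U = R[1]  # hypothetical readout direction: the second row of R


def readout(h: Vector) -> bool:
    return dot(h, U) > 0


def high(base: Vector, source: Vector) -> bool:
    return dot(source, U) > 0  # II on V returns V(source); the base is ignored


inputs = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
print(iia(R, inputs, readout, high), iia(rotation_2d(0), inputs, readout, high))
```

With the aligned rotation every intervention transfers $V$, while the identity rotation (a localist choice) mixes the base into the readout and loses some pairs.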
3.7 General Experimental Setup
We illustrate the value of DAS by analyzing feed-forward networks trained on a hierarchical equality task and pretrained Transformer-based language models (Vaswani et al., 2017) fine-tuned on a natural language inference task. Our evaluation paradigm is as follows:
1. Train the neural network $\mathcal{N}$ to solve the task. In all experiments, the neural models achieve perfect accuracy on both training and testing data.
2. Create interchange intervention training datasets using a high-level causal model. Each example consists of a base input, one or more source inputs, the high-level causal variables targeted for intervention, and a counterfactual gold label that the network will output if the interchange intervention has the hypothesized effect on model behavior. This gold label is a counterfactual output of the high-level model we will align with the network. (See Appendix A.1 for details.)
3. Optimize an orthogonal matrix to learn, for each high-level model, a distributed alignment that maximizes IIA, using the training objective in Def. 3.4. We experiment with different hidden dimension sizes for our low-level model, and with different intervention site sizes (the dimensionality of the low-level subspaces) and locations (the layer where the intervention happens). (See Appendix A.2 for details.)
4. Evaluate a baseline that brute-force searches through a discrete space of alignments and selects the alignment with the highest IIA. We search the space of alignments by aligning each high-level variable with groups of neurons in disjoint sliding windows. (See Appendix A.3 for details.)
5. Evaluate the localist alignment “closest” to the learned distributed alignment. The rotation matrix for the localist alignment will be axis-aligned with the standard basis, possibly permuting and reflecting unit axes. (See Appendix A.4 for details.)
6. Determine whether each distributed representation aligned with high-level variables can be decomposed into multiple representations that encode the identity of the input values to the variable’s causal mechanism. We do this by learning a second rotation matrix that decomposes the learned distributed representation while holding the first rotation matrix fixed. (See Appendix A.5 for details.)
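The evaluation steps leave the precise notion of the “closest” localist alignment to Appendix A.4. One natural reading, which we assume here, is the signed permutation matrix nearest to the learned rotation in Frobenius norm; a signed permutation is exactly an axis-aligned orthogonal matrix that permutes and reflects unit axes. The sketch below is ours, not the authors' procedure, and the function names are hypothetical.

```python
import math
from itertools import permutations
from typing import List

# Sketch only: one assumed reading of "closest", not the appendix's procedure.
Matrix = List[List[float]]


def rotation_2d(degrees: float) -> Matrix:
    t = math.radians(degrees)
    return [[math.cos(t), -math.sin(t)], [math.sin(t), math.cos(t)]]


def closest_signed_permutation(r: Matrix) -> Matrix:
    """Signed permutation matrix P minimizing ||P - R||_F for orthogonal R.

    ||P - R||_F^2 = 2n - 2 <P, R>, so we maximize sum_i |R[i][perm[i]]| over
    permutations and copy the sign of each chosen entry. Brute force over all
    n! permutations, so this is only practical for small n.
    """
    n = len(r)
    best = max(permutations(range(n)),
               key=lambda perm: sum(abs(r[i][perm[i]]) for i in range(n)))
    p = [[0.0] * n for _ in range(n)]
    for i, j in enumerate(best):
        p[i][j] = 1.0 if r[i][j] >= 0 else -1.0
    return p


print(closest_signed_permutation(rotation_2d(-20)))  # the identity
print(closest_signed_permutation(rotation_2d(70)))   # a swap with one reflection
```

For realistic hidden sizes, the permutation search is an assignment problem; a Hungarian-style solver such as SciPy's `scipy.optimize.linear_sum_assignment` (with `maximize=True` on the absolute entries) should recover the same permutation in polynomial time.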
4 Hierarchical Equality Experiment
We now illustrate the power of DAS for analyzing networks designed to solve a hierarchical equality task. We concentrate on analyzing a trained feed-forward network.
A basic equality task is to determine whether a pair of objects are the same ($x=y$). A hierarchical equality task is to determine whether a pair of pairs of objects have identical relations: $(w=x)=(y=z)$. Specifically, the input to the task is two pairs of objects, and the output is $\mathsf{True}$ if both pairs are equal or both pairs are unequal, and $\mathsf{False}$ otherwise. For example, $(A,A,B,B)$ and $(A,B,C,D)$ are both assigned $\mathsf{True}$, while $(A,B,C,C)$ is assigned $\mathsf{False}$.
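The labeling rule is a single comparison of comparisons. The sketch below encodes objects as strings for readability (the networks in this section instead receive random vectors), and the function name is hypothetical.

```python
from typing import Hashable, Tuple

# Sketch only: objects are strings here; the function name is illustrative.


def hierarchical_equality(objects: Tuple[Hashable, Hashable, Hashable, Hashable]) -> bool:
    """(w = x) = (y = z): True iff both pairs are equal or both are unequal."""
    w, x, y, z = objects
    return (w == x) == (y == z)


for example in [("A", "A", "B", "B"), ("A", "B", "C", "D"), ("A", "B", "C", "C")]:
    print(example, hierarchical_equality(example))  # True, True, False
```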
4.1 Low-Level Neural Model
We train a three-layer feed-forward network with ReLU activations to perform the hierarchical equality task. Each input object is represented by a randomly initialized vector. Specifically, our model has the following architecture, where $k$ is the number of layers:
$$h_1=\mathsf{ReLU}([x_1;x_2;x_3;x_4]W_1+b_1)\qquad h_j=\mathsf{ReLU}(h_{j-1}W_j+b_j)\qquad y=\mathbf{softmax}(h_kW_{k+1}+b_{k+1})$$
The input vectors are in $\mathbb{R}^{n}$, the biases are in $\mathbb{R}^{4n}$, and the weights are in $\mathbb{R}^{4n\times 4n}$. We evaluate our model on held-out random vectors unseen during training, as in Geiger et al. (2022a).
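To make the shapes concrete, here is a dependency-free forward pass for $k=3$ hidden layers of width $4n$ with $n=4$, so each hidden layer has 16 units as in the $|\mathbf{N}|=16$ rows of Table 1. The weights are untrained random draws, the output layer mapping $\mathbb{R}^{4n}$ to the two labels is our assumption about its shape, and the returned per-layer activations are the candidate intervention sites (Layer 1 to Layer 3). All names are hypothetical.

```python
import math
import random
from typing import List, Tuple

# Sketch only: untrained random weights; the 4n -> 2 output layer shape is assumed.
Vector = List[float]
Matrix = List[List[float]]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    scale = 1.0 / math.sqrt(rows)
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def affine(h: Vector, w: Matrix, b: Vector) -> Vector:
    """Row-vector convention, as in the text: h W + b."""
    return [sum(h[i] * w[i][j] for i in range(len(h))) + b[j] for j in range(len(b))]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def softmax(v: Vector) -> Vector:
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def forward(objects: List[Vector], hidden: List[Tuple[Matrix, Vector]],
            output: Tuple[Matrix, Vector]) -> Tuple[List[Vector], Vector]:
    """Return the activations of every hidden layer and the output distribution."""
    h = [x for obj in objects for x in obj]   # [x1; x2; x3; x4] in R^{4n}
    layers = []
    for w, b in hidden:                        # h_j = ReLU(h_{j-1} W_j + b_j)
        h = relu(affine(h, w, b))
        layers.append(h)
    w_out, b_out = output                      # output layer: R^{4n} -> 2 labels (assumed)
    return layers, softmax(affine(h, w_out, b_out))


n, k = 4, 3                                    # object dimension, number of hidden layers
rng = random.Random(0)
hidden = [(random_matrix(4 * n, 4 * n, rng), [0.0] * (4 * n)) for _ in range(k)]
output = (random_matrix(4 * n, 2, rng), [0.0, 0.0])
objects = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(4)]
layers, probs = forward(objects, hidden, output)
print([len(h) for h in layers], probs)
```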
| Hidden size | Intervention size | Both Equality Relations, Layer 1 | Both Equality Relations, Layer 2 | Both Equality Relations, Layer 3 | Left Equality Relation, Layer 1 | Left Equality Relation, Layer 2 | Left Equality Relation, Layer 3 | Identity of First Argument, Layer 1 | Identity of First Argument, Layer 2 | Identity of First Argument, Layer 3 | Identity Subspace of Left Equality, Layer 1 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 16 | 1 | 0.88 | 0.51 | 0.50 | 0.85 | 0.54 | 0.50 | 0.51 | 0.52 | 0.50 | 0.51 |
| 16 | 2 | 0.97 | 0.54 | 0.50 | 0.85 | 0.55 | 0.50 | 0.50 | 0.52 | 0.51 | 0.50 |
| 16 | 8 | 1.00 | 0.57 | 0.50 | 0.90 | 0.56 | 0.50 | 0.52 | 0.53 | 0.51 | 0.51 |
| 32 | 2 | 0.93 | 0.63 | 0.49 | 0.92 | 0.65 | 0.50 | 0.52 | 0.55 | 0.52 | 0.50 |
| 32 | 4 | 0.97 | 0.63 | 0.49 | 0.94 | 0.65 | 0.50 | 0.51 | 0.55 | 0.52 | 0.51 |
| 32 | 16 | 0.99 | 0.67 | 0.53 | 0.99 | 0.65 | 0.50 | 0.49 | 0.55 | 0.52 | 0.51 |
| Brute-Force Search | | 0.60 | 0.56 | 0.52 | 0.64 | 0.64 | 0.57 | 0.50 | 0.51 | 0.54 | - |
| Localist Alignment | | 0.73 | 0.56 | 0.48 | 0.60 | 0.50 | 0.49 | 0.46 | 0.47 | 0.48 | - |
Table 1: Hierarchical equality alignment learning results. Layer 1, Layer 2, and Layer 3 indicate which layer of neurons is targeted; the hidden size $|\mathbf{N}|$ is the number of neurons in a layer; the intervention size $k$ is the dimensionality of the low-level subspace aligned with each intermediate variable, and our subspace model (‘Identity Subspace of Left Equality’) occupies $\lceil k/2\rceil$ dimensions; each cell is the interchange intervention accuracy of the learned alignment on training data. We report the best result from three runs with distinct random seeds for training the rotation matrix (the same frozen low-level model is used for each seed).
4.2 High-Level Models
We use DAS to evaluate whether trained neural networks have achieved the natural solution to the hierarchical equality task, in which the left and right equality relations are computed and then used to predict the final label (Figure 2).
Figure 2: A causal model that computes the hierarchical equality task.
However, evaluating this high-level model alone is insufficient, as there are many other high-level models of this task. To further contextualize our results, we also consider two alternatives: a high-level model in which only the equality relation of the first pair is represented, and a high-level model in which the lone intermediate variable encodes the identity of the first input object (leaving all computation for the final step). These alternative high-level models also solve the task perfectly.
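The three candidate high-level models, and the counterfactual labels used to train and evaluate alignments with them, can be sketched as follows. Objects are strings here, the model and function names are ours, and `make_dataset` is a simplified, hypothetical stand-in for the dataset construction described in Appendix A.1 (uniform sampling, no label balancing).

```python
import random
from itertools import product
from typing import Callable, Dict, List, Tuple

# Sketch only: hypothetical encodings of the three high-level models.
Objects = Tuple[str, str, str, str]
Values = Dict[str, object]
# A model is (compute intermediate variables, compute output from base + intermediates).
Model = Tuple[Callable[[Objects], Values], Callable[[Objects, Values], bool]]


def both_vars(o: Objects) -> Values:
    return {"V1": o[0] == o[1], "V2": o[2] == o[3]}


def both_out(o: Objects, v: Values) -> bool:
    return v["V1"] == v["V2"]


def left_vars(o: Objects) -> Values:
    return {"V1": o[0] == o[1]}


def left_out(o: Objects, v: Values) -> bool:
    return v["V1"] == (o[2] == o[3])


def first_vars(o: Objects) -> Values:
    return {"V1": o[0]}


def first_out(o: Objects, v: Values) -> bool:
    return (v["V1"] == o[1]) == (o[2] == o[3])


MODELS: Dict[str, Model] = {
    "Both Equality Relations": (both_vars, both_out),
    "Left Equality Relation": (left_vars, left_out),
    "Identity of First Argument": (first_vars, first_out),
}


def interchange(model: Model, base: Objects, sources: List[Objects], targets: List[str]) -> bool:
    """II(H, {s_j}, {X_j})(b): overwrite each targeted variable with its source value."""
    compute_vars, compute_out = model
    values = compute_vars(base)
    for source, var in zip(sources, targets):
        values[var] = compute_vars(source)[var]
    return compute_out(base, values)


def make_dataset(model: Model, targets: List[str], vocab: str, size: int, seed: int = 0):
    """(base, sources, targets, counterfactual gold label) tuples."""
    rng = random.Random(seed)

    def draw() -> Objects:
        return tuple(rng.choice(vocab) for _ in range(4))

    data = []
    for _ in range(size):
        base, sources = draw(), [draw() for _ in targets]
        data.append((base, sources, targets, interchange(model, base, sources, targets)))
    return data


# With source = base (a no-op intervention), each model reproduces the task label.
assert all(interchange(m, o, [o], ["V1"]) == ((o[0] == o[1]) == (o[2] == o[3]))
           for m in MODELS.values() for o in product("ABC", repeat=4))
print(make_dataset(MODELS["Both Equality Relations"], ["V1"], "ABCD", size=3))
```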
4.3 Discussion
The IIA results achieved by the best alignment for each high-level model are shown in Table 1. The best alignments found are with the ‘Both Equality Relations’ model that is widely assumed in the cognitive science literature. For all causal models, DAS learns a more faithful alignment (higher IIA) than a brute-force search through localist alignments. This result is most pronounced for ‘Both Equality Relations’, where DAS learns perfect or near-perfect alignments under a number of settings, whereas the best brute-force alignment achieves only 0.60 and the best localist alignment achieves only 0.73. Finally, the distributed representation of left equality could not be decomposed into a representation of the first argument's identity: the ‘Identity Subspace of Left Equality’ alignments stay at chance (IIA of 0.50 to 0.51). This indicates that the model is truly learning to encode an abstract equality relation, rather than merely storing the identities of the inputs.
4.4 Analyzing a Randomly Initialized Network
Figure 3: DAS on a random network with a 16-dimensional input. An oversized hidden dimension allows DAS to manipulate the model behavior by searching through a large space of random mechanisms.
| Hidden size | Intervention size | Both Equality Relations, Layer 1 |
|---|---|---|
| 16 | 8 | 0.50 |
| 64 | 32 | 0.50 |
| 256 | 128 | 0.51 |
| 1024 | 512 | 0.55 |
| 4096 | 2048 | 0.64 |
To calibrate intuitions about our method, we evaluate the ability of DAS to optimize for interchange intervention accuracy on frozen, randomly initialized networks that achieve chance accuracy (50%) on the hierarchical equality task. This investigates the degree to which random causal structure can be used to systematically manipulate the counterfactual behavior of the network. We evaluate networks with different hidden representation sizes while holding the four input vectors fixed at 4 dimensions each (16 input dimensions in total), under the hypothesis that more hidden neurons create more random structure that DAS can search through. These results are summarized in the table above. Observe that, in small networks, DAS cannot raise interchange intervention accuracy above chance. However, as we increase the size of the hidden representation to be orders of magnitude larger than the input dimension of 16, the interchange intervention accuracy increases. This confirms our hypothesis. It also serves as a check that DAS cannot construct entirely new behaviors from random structure: even with 4096 hidden units, IIA reaches only 0.64, far below the near-perfect IIA that DAS achieves on the trained network with 16 hidden units.
(a) Two MoNLI examples. (b) A simple program that solves MoNLI:

```
MoNLI(p, h)
    lexrel ← get-lexrel(p, h)
    neg ← contains-not(p, h)
    if neg:
        return reverse(lexrel)
    return lexrel
```
Figure 4: Monotonicity NLI task examples and high-level model.
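The MoNLI program in Figure 4 leaves `get-lexrel` and `contains-not` abstract. A runnable version needs concrete stand-ins, so the sketch below uses hypothetical ones: a two-entry toy lexicon, the assumption that premise and hypothesis differ in exactly one word, a whitespace-token check for “not”, and two labels (entailment and neutral). It illustrates the control flow of the program, not how the dataset was built.

```python
from typing import Dict, Tuple

# Sketch only: the lexicon and helper implementations are hypothetical stand-ins.
LEXICON: Dict[Tuple[str, str], str] = {
    ("dog", "animal"): "entailment",   # the hypothesis word is a hypernym
    ("animal", "dog"): "neutral",      # the hypothesis word is a hyponym
}


def get_lexrel(p: str, h: str) -> str:
    """Lexical relation between the single pair of words that differ (assumed)."""
    diffs = [(a, b) for a, b in zip(p.split(), h.split()) if a != b]
    if len(diffs) != 1:
        raise ValueError("this sketch assumes p and h differ in exactly one word")
    return LEXICON[diffs[0]]


def contains_not(p: str, h: str) -> bool:
    return "not" in p.split() or "not" in h.split()


def reverse(lexrel: str) -> str:
    """Negation flips the direction of the lexical entailment."""
    return {"entailment": "neutral", "neutral": "entailment"}[lexrel]


def monli(p: str, h: str) -> str:
    lexrel = get_lexrel(p, h)
    neg = contains_not(p, h)
    if neg:
        return reverse(lexrel)
    return lexrel


print(monli("she holds the dog", "she holds the animal"))
print(monli("she does not hold the dog", "she does not hold the animal"))
```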