Update modeling_steerling.py

#4
by AyaGL - opened
Files changed (1)
  1. modeling_steerling.py +1574 -0
modeling_steerling.py ADDED
@@ -0,0 +1,1574 @@
1
+ from __future__ import annotations
2
+ # Auto-generated by scripts/build_hf_files_v3.py — do not edit manually.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import logging
8
+ import os
9
+ from functools import partial
10
+ from typing import TYPE_CHECKING, Any
11
+ from dataclasses import dataclass
12
+ from torch import Tensor
13
+ import math
14
+ import warnings
15
+
16
+
17
+ # ======================================================================
18
+ # steerling/models/layers/primitives.py
19
+ # ======================================================================
20
+
21
+ class RMSNorm(nn.Module):
22
+ """
23
+ Root Mean Square Layer Normalization.
24
+ """
25
+
26
+ def __init__(self, config, size: int | None=None):
27
+ super().__init__()
28
+ self.eps = getattr(config, 'norm_eps', 1e-05)
29
+ norm_size = size if size is not None else config.n_embd
30
+ self.weight = nn.Parameter(torch.ones(norm_size))
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+ og = x.dtype
34
+ x = x.float()
35
+ var = x.pow(2).mean(-1, keepdim=True)
36
+ x = x * torch.rsqrt(var + self.eps)
37
+ return (self.weight * x).to(og)
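+ # Note: RMSNorm computes y = weight * x / sqrt(mean(x**2) + eps), i.e. LayerNorm
+ # without mean-centering or bias. Worked example (toy values): x = [3.0, 4.0] gives
+ # rms = sqrt((9 + 16) / 2) ≈ 3.536, so the normalized values are roughly
+ # [0.85, 1.13] before scaling by `weight`.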
38
+
39
+ class BufferCache:
40
+ """Simple cache for storing tensors (used by RotaryEmbedding)."""
41
+
42
+ def __init__(self):
43
+ self._cache: dict[str, torch.Tensor] = {}
44
+
45
+ def get(self, key: str) -> torch.Tensor | None:
46
+ return self._cache.get(key)
47
+
48
+ def __setitem__(self, key: str, value: torch.Tensor):
49
+ self._cache[key] = value
50
+
51
+ def __getitem__(self, key: str) -> torch.Tensor:
52
+ return self._cache[key]
53
+
54
+ class RotaryEmbedding(nn.Module):
55
+ """
56
+ Rotary Position Embeddings (RoPE).
57
+
58
+ Applies rotary embeddings to queries and keys for position information.
59
+
60
+ Args:
61
+ dim: Dimension of the rotary embeddings (typically head_dim)
62
+ max_seq_len: Maximum sequence length to cache
63
+ base: Base for inverse frequency computation (theta)
64
+ rope_full_precision: Whether to compute RoPE in full precision
65
+ """
66
+
67
+ def __init__(self, dim: int, max_seq_len: int=2048, base: float=10000.0, rope_full_precision: bool=True):
68
+ super().__init__()
69
+ self.dim = dim
70
+ self.max_seq_len = max_seq_len
71
+ self.rope_theta = base
72
+ self.rope_full_precision = rope_full_precision
73
+ self.__cache = BufferCache()
74
+ self.get_rotary_embedding(max_seq_len, torch.device('cpu'))
75
+
76
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
77
+ """Get or compute rotary embeddings for given sequence length."""
78
+ pos_sin = self.__cache.get('rope_pos_sin')
79
+ pos_cos = self.__cache.get('rope_pos_cos')
80
+ if pos_sin is not None and pos_cos is not None and (pos_sin.shape[-2] >= seq_len) and (pos_cos.shape[-2] >= seq_len):
81
+ if pos_sin.device != device:
82
+ pos_sin = pos_sin.to(device)
83
+ self.__cache['rope_pos_sin'] = pos_sin
84
+ if pos_cos.device != device:
85
+ pos_cos = pos_cos.to(device)
86
+ self.__cache['rope_pos_cos'] = pos_cos
87
+ return (pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :])
88
+ with torch.autocast(device.type, enabled=False):
89
+ inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim))
90
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
91
+ freqs = torch.outer(seq, inv_freq)
92
+ positions = torch.cat((freqs, freqs), dim=-1)
93
+ pos_sin = positions.sin()[None, None, :, :]
94
+ pos_cos = positions.cos()[None, None, :, :]
95
+ self.__cache['rope_pos_sin'] = pos_sin
96
+ self.__cache['rope_pos_cos'] = pos_cos
97
+ return (pos_sin, pos_cos)
98
+
99
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
100
+ """Rotate half the hidden dims of the input."""
101
+ B, nh, T, hs = x.size()
102
+ x = x.view(B, nh, T, 2, hs // 2)
103
+ x1, x2 = x.unbind(dim=-2)
104
+ return torch.cat((-x2, x1), dim=-1)
105
+
106
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
107
+ """Apply rotary position embeddings to input tensor."""
108
+ return (t * pos_cos + self.rotate_half(t) * pos_sin).to(t.dtype)
109
+
110
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
111
+ """Apply rotary embeddings to queries and keys."""
112
+ if self.rope_full_precision:
113
+ q_, k_ = (q.float(), k.float())
114
+ else:
115
+ q_, k_ = (q, k)
116
+ with torch.autocast(q.device.type, enabled=False):
117
+ query_len, key_len = (q_.shape[-2], k_.shape[-2])
118
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
119
+ pos_sin = pos_sin.type_as(q_)
120
+ pos_cos = pos_cos.type_as(q_)
121
+ q_ = self.apply_rotary_pos_emb(pos_sin[:, :, key_len - query_len:key_len, :], pos_cos[:, :, key_len - query_len:key_len, :], q_)
122
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
123
+ return (q_.type_as(q), k_.type_as(k))
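+ # Note: rotate_half splits the head dimension into two contiguous halves (x1, x2)
+ # and returns (-x2, x1), so apply_rotary_pos_emb implements the standard rotation
+ # out = t * cos + rotate_half(t) * sin in the "half-split" (non-interleaved) layout.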
124
+
125
+ class MLP(nn.Module):
126
+ """
127
+ Multi-Layer Perceptron with SwiGLU or standard activation.
128
+
129
+ Args:
130
+ config: Model config with n_embd, mlp_ratio, use_bias, mlp_type, activation
131
+ """
132
+
133
+ def __init__(self, config):
134
+ super().__init__()
135
+ if hasattr(config, 'intermediate_size') and config.intermediate_size is not None:
136
+ intermediate_size = config.intermediate_size
137
+ else:
138
+ intermediate_size = getattr(config, 'mlp_ratio', 4) * config.n_embd
139
+ use_bias = config.use_bias
140
+ mlp_type = config.mlp_type
141
+ if mlp_type == 'swiglu':
142
+ self.c_fc = nn.Linear(config.n_embd, 2 * intermediate_size, bias=use_bias)
143
+ self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
144
+ self.activation = None
145
+ else:
146
+ self.c_fc = nn.Linear(config.n_embd, intermediate_size, bias=use_bias)
147
+ self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
148
+ act_map = {'gelu': nn.GELU(approximate='tanh'), 'relu': nn.ReLU(), 'silu': nn.SiLU()}
149
+ self.activation = act_map[config.activation]
150
+ self.c_proj.SCALE_INIT = 1
151
+ self.config = config
152
+
153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
154
+ mlp_type = getattr(self.config, 'mlp_type', 'swiglu')
155
+ if mlp_type == 'swiglu':
156
+ gate_up = self.c_fc(x)
157
+ up, gate = gate_up.chunk(2, dim=-1)
158
+ intermediate = F.silu(gate) * up
159
+ else:
160
+ intermediate = self.c_fc(x)
161
+ intermediate = self.activation(intermediate)
162
+ return self.c_proj(intermediate)
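+ # Note: in the SwiGLU path, c_fc packs the up and gate projections into a single
+ # matmul, so the forward pass computes c_proj(SiLU(gate) * up) with
+ # (up, gate) = c_fc(x).chunk(2). The non-SwiGLU path is a plain two-layer MLP with
+ # the configured activation.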
163
+
164
+ # ======================================================================
165
+ # steerling/models/layers/causal_diffusion_layers.py
166
+ # ======================================================================
167
+
168
+ logger = logging.getLogger(__name__)
169
+ try:
170
+ from torch.nn.attention.flex_attention import BlockMask, _dense_to_ordered, flex_attention
171
+ _FLEX_ATTN_AVAILABLE = True
172
+ except ImportError:
173
+ _FLEX_ATTN_AVAILABLE = False
174
+ BlockMask: Any = None
175
+ flex_attention: Any = None
176
+ _dense_to_ordered: Any = None
177
+ if os.environ.get('STEERLING_USE_FLEX_ATTN', '0') != '1':
178
+ _FLEX_ATTN_AVAILABLE = False
179
+ if TYPE_CHECKING:
180
+ from torch.nn.attention.flex_attention import BlockMask as BlockMaskType
181
+ from steerling.configs.causal_diffusion import CausalDiffusionConfig
182
+ if torch.cuda.is_available() and _FLEX_ATTN_AVAILABLE:
183
+ compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
184
+ else:
185
+ compiled_flex_attention = flex_attention
186
+
187
+ def block_causal_mask_mod(b: Any, h: Any, q_idx: torch.Tensor, kv_idx: torch.Tensor, *, block_size: int) -> torch.Tensor:
188
+ """Block-causal mask: causal across blocks, bidirectional within blocks."""
189
+ return q_idx // block_size >= kv_idx // block_size
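+ # Worked example (toy sizes, for illustration only): with block_size=2 and 4 positions,
+ # q_idx // 2 >= kv_idx // 2 produces the block-causal pattern
+ #   [[1, 1, 0, 0],
+ #    [1, 1, 0, 0],
+ #    [1, 1, 1, 1],
+ #    [1, 1, 1, 1]]
+ # i.e. bidirectional attention inside each 2-token block, causal across blocks.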
190
+
191
+ def fast_create_block_causal_mask(attn_block_size: int, seq_length: int, mask_block_size: int, device: torch.device) -> BlockMaskType:
192
+ """
193
+ Fast block-causal mask creation for flex_attention.
194
+
195
+ Analytically computes the sparse block structure instead of evaluating
196
+ the mask function at every position.
197
+ """
198
+ if not _FLEX_ATTN_AVAILABLE or _dense_to_ordered is None or BlockMask is None:
199
+ raise RuntimeError('flex_attention not available')
200
+ num_mask_blocks = -(-seq_length // mask_block_size)
201
+ attn_blocks_per_mask_block, rem = divmod(mask_block_size, attn_block_size)
202
+ if rem != 0:
203
+ raise ValueError(f'mask_block_size ({mask_block_size}) must be divisible by attn_block_size ({attn_block_size})')
204
+ num_attn_blocks = num_mask_blocks * attn_blocks_per_mask_block
205
+ lowres_attn_mask = torch.tril(torch.ones(num_attn_blocks, num_attn_blocks, dtype=torch.bool, device=device))
206
+ block_attn_count = lowres_attn_mask.reshape(num_mask_blocks, attn_blocks_per_mask_block, num_mask_blocks, attn_blocks_per_mask_block).permute(0, 2, 1, 3).sum(dim=[-2, -1])
207
+ max_count = attn_blocks_per_mask_block * attn_blocks_per_mask_block
208
+ full_block_mask = block_attn_count == max_count
209
+ if seq_length % mask_block_size > 0:
210
+ full_block_mask[-1, :] = False
211
+ normal_block_mask = (block_attn_count > 0) & ~full_block_mask
212
+ kv_num_blocks, kv_indices = _dense_to_ordered(normal_block_mask)
213
+ full_kv_num_blocks, full_kv_indices = _dense_to_ordered(full_block_mask)
214
+ q_num_blocks, q_indices = _dense_to_ordered(normal_block_mask.transpose(-2, -1))
215
+ full_q_num_blocks, full_q_indices = _dense_to_ordered(full_block_mask.transpose(-2, -1))
216
+ return BlockMask(seq_lengths=(seq_length, seq_length), kv_num_blocks=kv_num_blocks[None, None, ...], kv_indices=kv_indices[None, None, ...], full_kv_num_blocks=full_kv_num_blocks[None, None, ...], full_kv_indices=full_kv_indices[None, None, ...], q_num_blocks=q_num_blocks[None, None, ...], q_indices=q_indices[None, None, ...], full_q_num_blocks=full_q_num_blocks[None, None, ...], full_q_indices=full_q_indices[None, None, ...], mask_mod=partial(block_causal_mask_mod, block_size=attn_block_size), BLOCK_SIZE=(mask_block_size, mask_block_size))
217
+
218
+ def sdpa_with_block_causal_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff_block_size: int, mask_cache: dict[str, torch.Tensor], enable_gqa: bool=False) -> torch.Tensor:
219
+ """Fallback using SDPA with dense mask when flex_attention unavailable."""
220
+ B, H, T, D = q.shape
221
+ device = q.device
222
+ dtype = q.dtype
223
+ cache_key = f'sdpa_{T}_{device}_{dtype}'
224
+ if cache_key not in mask_cache:
225
+ q_idx = torch.arange(T, device=device).unsqueeze(1)
226
+ kv_idx = torch.arange(T, device=device).unsqueeze(0)
227
+ bool_mask = q_idx // diff_block_size >= kv_idx // diff_block_size
228
+ attn_mask = torch.zeros(T, T, device=device, dtype=dtype)
229
+ attn_mask.masked_fill_(~bool_mask, float('-inf'))
230
+ mask_cache[cache_key] = attn_mask
231
+ return F.scaled_dot_product_attention(q, k, v, attn_mask=mask_cache[cache_key], dropout_p=0.0, is_causal=False, enable_gqa=enable_gqa)
232
+
233
+ class BlockCausalAttention(nn.Module):
234
+ """Block-causal self-attention with FlexAttention and optional GQA."""
235
+ FLEX_MASK_BLOCK_SIZE = 128
236
+
237
+ def __init__(self, config: CausalDiffusionConfig) -> None:
238
+ super().__init__()
239
+ if not hasattr(config, 'diff_block_size'):
240
+ raise ValueError("BlockCausalAttention requires 'diff_block_size' in config.")
241
+ assert config.n_embd % config.n_head == 0
242
+ self.config = config
243
+ self.n_head = config.n_head
244
+ self.n_embd = config.n_embd
245
+ self.head_dim = config.n_embd // config.n_head
246
+ n_kv = getattr(config, 'n_kv_heads', None)
247
+ self.n_kv_heads = self.n_head if n_kv is None else int(n_kv)
248
+ if self.n_kv_heads <= 0:
249
+ raise ValueError(f'n_kv_heads must be >= 1 (got {self.n_kv_heads})')
250
+ if self.n_head % self.n_kv_heads != 0:
251
+ raise ValueError(f'n_head ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})')
252
+ self.kv_repeat = self.n_head // self.n_kv_heads
253
+ use_bias = getattr(config, 'use_bias', False)
254
+ kv_out = self.n_kv_heads * self.head_dim
255
+ attn_out = self.n_embd + 2 * kv_out
256
+ self.c_attn = nn.Linear(config.n_embd, attn_out, bias=use_bias)
257
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=use_bias)
258
+ self.c_proj.SCALE_INIT = 1
259
+ if getattr(config, 'use_qk_norm', False):
260
+ if getattr(config, 'use_rms_norm', True):
261
+ self.q_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
262
+ self.k_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
263
+ else:
264
+ self.q_norm = nn.LayerNorm(self.head_dim)
265
+ self.k_norm = nn.LayerNorm(self.head_dim)
266
+ else:
267
+ self.q_norm = None
268
+ self.k_norm = None
269
+ if getattr(config, 'use_rope', True):
270
+ self.rope: RotaryEmbedding | None = RotaryEmbedding(dim=self.head_dim, max_seq_len=config.block_size, base=getattr(config, 'rope_base', 500000.0), rope_full_precision=getattr(config, 'rope_full_precision', True))
271
+ else:
272
+ self.rope = None
273
+ self._mask_cache: dict = {}
274
+ self._sdpa_mask_cache: dict[str, torch.Tensor] = {}
275
+ self._logged_attention_mode = False
276
+
277
+ def _get_block_mask(self, T: int, device: torch.device):
278
+ cache_key = f'flex_{T}_{device}'
279
+ if cache_key not in self._mask_cache:
280
+ diff_block_size = self.config.diff_block_size
281
+ mask_block_size = self.FLEX_MASK_BLOCK_SIZE
282
+ if mask_block_size % diff_block_size != 0:
283
+ mask_block_size = diff_block_size * (mask_block_size // diff_block_size)
284
+ if mask_block_size == 0:
285
+ mask_block_size = diff_block_size
286
+ self._mask_cache[cache_key] = fast_create_block_causal_mask(attn_block_size=diff_block_size, seq_length=T, mask_block_size=mask_block_size, device=device)
287
+ return self._mask_cache[cache_key]
288
+
289
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
290
+ B, T, C = x.size()
291
+ device = x.device
292
+ use_flex = _FLEX_ATTN_AVAILABLE and x.is_cuda and (flex_attention is not None)
293
+ if not self._logged_attention_mode:
294
+ self._logged_attention_mode = True
295
+ mode = 'flex_attention' if use_flex else 'SDPA fallback'
296
+ logger.debug(f'[CausalDiffusion] Using {mode} with GQA (n_head={self.n_head}, n_kv_heads={self.n_kv_heads})')
297
+ qkv = self.c_attn(x)
298
+ clip_qkv = getattr(self.config, 'clip_qkv', None)
299
+ if clip_qkv is not None:
300
+ qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
301
+ kv_dim = self.n_kv_heads * self.head_dim
302
+ q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2)
303
+ q = q.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
304
+ k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
305
+ v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
306
+ if self.q_norm is not None and self.k_norm is not None:
307
+ q = self.q_norm(q)
308
+ k = self.k_norm(k)
309
+ if self.rope is not None:
310
+ q, k = self.rope(q, k)
311
+ if use_flex:
312
+ block_mask = self._get_block_mask(T, device)
313
+ assert flex_attention is not None and compiled_flex_attention is not None
314
+ if q.is_cuda:
315
+ y = compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
316
+ else:
317
+ y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
318
+ else:
319
+ y = sdpa_with_block_causal_mask(q, k, v, diff_block_size=self.config.diff_block_size, mask_cache=self._sdpa_mask_cache, enable_gqa=True)
320
+ y = y.transpose(1, 2).reshape(B, T, C)
321
+ y = self.c_proj(y)
322
+ return y
323
+
324
+ class CausalDiffusionBlock(nn.Module):
325
+ """Transformer block for CausalDiffusionLM (block-causal attention + MLP)."""
326
+
327
+ def __init__(self, config: CausalDiffusionConfig) -> None:
328
+ super().__init__()
329
+ use_rms_norm = getattr(config, 'use_rms_norm', True)
330
+ if use_rms_norm:
331
+ self.ln_1: nn.Module = RMSNorm(config)
332
+ self.ln_2: nn.Module = RMSNorm(config)
333
+ else:
334
+ self.ln_1 = nn.LayerNorm(config.n_embd)
335
+ self.ln_2 = nn.LayerNorm(config.n_embd)
336
+ self.norm_order = getattr(config, 'norm_order', 'post')
337
+ self.attn = BlockCausalAttention(config)
338
+ self.mlp = MLP(config)
339
+
340
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
341
+ if self.norm_order == 'pre':
342
+ x = x + self.attn(self.ln_1(x))
343
+ x = x + self.mlp(self.ln_2(x))
344
+ else:
345
+ x = x + self.ln_1(self.attn(x))
346
+ x = x + self.ln_2(self.mlp(x))
347
+ return x
348
+
349
+ # ======================================================================
350
+ # steerling/models/causal_diffusion.py
351
+ # ======================================================================
352
+
353
+ class CausalDiffusionLM(nn.Module):
354
+ """
355
+ CausalDiffusionLM transformer backbone with block-causal attention.
356
+
357
+ Pure compute graph — no training code, no loss logic.
358
+
359
+ Args:
360
+ config: CausalDiffusionConfig with model hyperparameters
361
+ vocab_size: Vocabulary size (including special tokens)
362
+ """
363
+
364
+ def __init__(self, config: CausalDiffusionConfig, vocab_size: int) -> None:
365
+ super().__init__()
366
+ self.config = config
367
+ self.vocab_size = vocab_size
368
+ self.tok_emb = nn.Embedding(vocab_size, config.n_embd)
369
+ self.blocks = nn.ModuleList([CausalDiffusionBlock(config) for _ in range(config.n_layers)])
370
+ if config.use_rms_norm:
371
+ self.ln_f: nn.Module = RMSNorm(config)
372
+ else:
373
+ self.ln_f = nn.LayerNorm(config.n_embd)
374
+ self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
375
+ if config.weight_sharing:
376
+ self.tok_emb.weight = self.lm_head.weight
377
+
378
+ def forward(self, input_ids: torch.Tensor | None=None, *, input_embeds: torch.Tensor | None=None, return_hidden: bool=False) -> torch.Tensor:
379
+ """
380
+ Forward pass.
381
+
382
+ Args:
383
+ input_ids: Token indices [B, T] (may contain mask tokens)
384
+ input_embeds: Pre-computed embeddings [B, T, D]. If provided, input_ids is ignored.
385
+ return_hidden: If True, return hidden states before lm_head.
386
+
387
+ Returns:
388
+ logits [B, T, vocab_size] or hidden_states [B, T, n_embd]
389
+ """
390
+ if input_embeds is not None:
391
+ x = input_embeds
392
+ elif input_ids is not None:
393
+ x = self.tok_emb(input_ids)
394
+ else:
395
+ raise ValueError('Either input_ids or input_embeds must be provided')
396
+ for block in self.blocks:
397
+ x = block(x)
398
+ x = self.ln_f(x)
399
+ if return_hidden:
400
+ return x
401
+ return self.lm_head(x)
402
+
403
+ def get_num_params(self, non_embedding: bool=True) -> int:
404
+ """Return number of parameters."""
405
+ n_params = sum((p.numel() for p in self.parameters()))
406
+ if non_embedding:
407
+ n_params -= self.tok_emb.weight.numel()
408
+ return n_params
409
+
410
+ def _restore_weight_tying(self) -> None:
411
+ """Re-establish weight tying after to_empty() or device transfer."""
412
+ if self.config.weight_sharing:
413
+ self.tok_emb.weight = self.lm_head.weight
414
+
415
+ def _init_weights(self, module: nn.Module) -> None:
416
+ """Initialize model weights (used for fresh models, not loaded checkpoints)."""
417
+ if isinstance(module, nn.Linear):
418
+ std = 0.02
419
+ if hasattr(module, 'SCALE_INIT'):
420
+ std *= (2 * self.config.n_layers) ** (-0.5)
421
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
422
+ if module.bias is not None:
423
+ torch.nn.init.zeros_(module.bias)
424
+ elif isinstance(module, nn.Embedding):
425
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
426
+ elif isinstance(module, RMSNorm):
427
+ torch.nn.init.ones_(module.weight)
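+ # Minimal usage sketch (illustrative only; assumes `config` is a CausalDiffusionConfig
+ # exposing the fields read above: n_embd, n_head, n_layers, diff_block_size, block_size,
+ # use_bias, mlp_type, activation, use_rms_norm, norm_order, weight_sharing, ...; the
+ # vocab size below is an arbitrary example value):
+ #
+ #   model = CausalDiffusionLM(config, vocab_size=50304)
+ #   model.apply(model._init_weights)
+ #   logits = model(input_ids)                       # (B, T, vocab_size)
+ #   hidden = model(input_ids, return_hidden=True)   # (B, T, n_embd)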
428
+
429
+ # ======================================================================
430
+ # steerling/models/interpretable/outputs.py
431
+ # ======================================================================
432
+
433
+ @dataclass
434
+ class InterpretableOutput:
435
+ """
436
+ Full output from InterpretableCausalDiffusionLM, containing all decomposition components for attribution and analysis.
437
+ """
438
+ hidden: Tensor
439
+ known_features: Tensor
440
+ known_logits: Tensor | None
441
+ known_gt_features: Tensor | None
442
+ known_predicted: Tensor
443
+ known_weights: Tensor | None
444
+ known_topk_indices: Tensor | None
445
+ known_topk_logits: Tensor | None
446
+ unk: Tensor
447
+ unk_hat: Tensor | None
448
+ unk_for_lm: Tensor
449
+ unknown_logits: Tensor | None
450
+ unknown_weights: Tensor | None
451
+ unknown_topk_indices: Tensor | None
452
+ unknown_topk_logits: Tensor | None
453
+ composed: Tensor
454
+ epsilon: Tensor | None
455
+ epsilon_true: Tensor | None
456
+
457
+ # ======================================================================
458
+ # steerling/models/interpretable/concept_head.py
459
+ # ======================================================================
460
+
461
+ logger = logging.getLogger(__name__)
462
+ LARGE_CONCEPT_THRESHOLD = 50000
463
+
464
+ @dataclass
465
+ class ConceptHeadOutput:
466
+ """Output from ConceptHead forward pass.
467
+
468
+ Attributes:
469
+ features: Final concept features after teacher forcing/intervention (B, T, D)
470
+ gt_features: Ground truth pooled features. None for unknown heads. (B, T, D) or None
471
+ logits: Full concept logits (B, T, C). Only set if return_logits=True. Usually None.
472
+ predicted: Predicted features before teacher forcing mixing (B, T, D)
473
+ weights: Full concept weights (B, T, C). Only set if return_logits=True. Usually None.
474
+ topk_indices: Top-k concept indices (B, T, k). Set when using streaming top-k.
475
+ topk_logits: Logits for top-k concepts (B, T, k). Set when using streaming top-k.
476
+ hidden: Hidden states passed to this head (B, T, D). Stored for attribution.
477
+ """
478
+ features: Tensor
479
+ gt_features: Tensor | None
480
+ logits: Tensor | None
481
+ predicted: Tensor
482
+ weights: Tensor | None = None
483
+ topk_indices: Tensor | None = None
484
+ topk_logits: Tensor | None = None
485
+ hidden: Tensor | None = None
486
+
487
+ class ConceptHead(nn.Module):
488
+ """
489
+ Concept decomposition head supporting both known and unknown concepts.
490
+ Memory-efficient implementation that avoids (B, T, C) allocations by default.
491
+
492
+ Modes:
493
+ - Known (is_unknown=False): Supports GT, teacher forcing, top-k, interventions
494
+ - Unknown (is_unknown=True): No GT, no teacher forcing
495
+
496
+ Architectures:
497
+ - use_attention=False: Linear predictor (n_embd -> n_concepts)
498
+ - use_attention=True: Query projection + sigmoid attention over embeddings
499
+
500
+ Factorization (for large unknown heads):
501
+ - factorize=False: Dense embeddings (C, D) and predictor (D, C)
502
+ - factorize=True: Factorized embeddings (C, r) @ (r, D) where r << D
503
+ Reduces memory by ~10-20x for large C
504
+
505
+ Memory Safety:
506
+ - Unknown heads with n_concepts > 50k cannot use dense operations
507
+ - Interventions are only supported for known heads
508
+ - return_logits=True is forbidden for large unknown heads
509
+ - All tensor indexing uses F.embedding for DTensor safety
510
+
511
+ Args:
512
+ n_concepts: Number of concepts (C)
513
+ concept_dim: Dimension of concept embeddings (should equal n_embd)
514
+ n_embd: Model hidden dimension
515
+ is_unknown: If True, skip GT pooling and teacher forcing
516
+ use_attention: If True, use attention; else use linear predictor
517
+ topk: Top-k sparsity for concept weights. None = no sparsity.
518
+ block_size: Block size for memory-efficient operations
519
+ pad_multiple: Pad n_concepts to a multiple of this for efficiency
520
+ store_unknown_weights: If True and use_attention & is_unknown, store logits/weights
521
+ apply_topk_to_unknown: If True, also apply top-k to unknown concepts
522
+ topk_on_logits: If True, apply top-k on logits (then sigmoid). If False, on weights.
523
+ teacher_force_alpha: If None, hard TF. If in [0,1], soft mixing.
524
+ factorize: If True, use low-rank factorized embeddings
525
+ factorize_rank: Rank for factorization (r). Lower = less memory, less expressivity.
526
+ """
527
+
528
+ class ConceptPooling(nn.Module):
529
+ """Memory-efficient sum pooling using scatter-add."""
530
+
531
+ def __init__(self, concept_dim: int):
532
+ super().__init__()
533
+ self.concept_dim = concept_dim
534
+
535
+ def forward(self, concept_ids: Tensor, concept_mask: Tensor, concept_embeddings: nn.Embedding) -> Tensor:
536
+ """
537
+ Pool concept embeddings based on ground truth IDs.
538
+ Uses scatter-add to avoid (B, T, K, D) allocation when K is sparse.
539
+
540
+ Args:
541
+ concept_ids: (B, T, K) concept indices, -1 for invalid
542
+ concept_mask: (B, T, K) boolean mask for valid concepts
543
+ concept_embeddings: Embedding layer to look up
544
+
545
+ Returns:
546
+ Pooled features (B, T, D)
547
+ """
548
+ B, T, K = concept_ids.shape
549
+ D = concept_embeddings.embedding_dim
550
+ device = concept_ids.device
551
+ valid_mask = concept_mask & (concept_ids != -1)
552
+ pooled = torch.zeros(B, T, D, device=device, dtype=concept_embeddings.weight.dtype)
553
+ if not valid_mask.any():
554
+ return pooled
555
+ b_idx, t_idx, k_idx = torch.where(valid_mask)
556
+ c_ids = concept_ids[b_idx, t_idx, k_idx].long()
557
+ emb = concept_embeddings(c_ids)
558
+ flat_idx = b_idx * T + t_idx
559
+ flat_idx = flat_idx.unsqueeze(-1).expand(-1, D)
560
+ pooled_flat = pooled.view(B * T, D)
561
+ pooled_flat.scatter_add_(0, flat_idx, emb)
562
+ return pooled.view(B, T, D)
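+ # Note: the scatter-add above is equivalent to embedding-and-summing only the valid
+ # entries. Worked example (toy values): for a single token with concept_ids = [3, 7, -1]
+ # and mask = [True, True, False], the pooled feature is simply E[3] + E[7].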
563
+
564
+ def __init__(self, n_concepts: int, concept_dim: int, n_embd: int, is_unknown: bool=False, use_attention: bool=False, topk: int | None=16, topk_features: int | None=None, block_size: int=8192, *, pad_multiple: int=16, store_unknown_weights: bool=False, apply_topk_to_unknown: bool=False, topk_on_logits: bool=False, factorize: bool=False, factorize_rank: int=256):
565
+ super().__init__()
566
+ self.n_concepts = n_concepts
567
+ self.concept_dim = concept_dim
568
+ self.n_embd = n_embd
569
+ self.is_unknown = is_unknown
570
+ self.use_attention = use_attention
571
+ self.topk = topk
572
+ self.topk_features = topk_features if topk_features is not None else topk
573
+ self.block_size = block_size
574
+ self.pad_multiple = pad_multiple
575
+ self.store_unknown_weights = store_unknown_weights
576
+ self.apply_topk_to_unknown = apply_topk_to_unknown
577
+ self.topk_on_logits = topk_on_logits
578
+ self.factorize = factorize
579
+ self.factorize_rank = factorize_rank
580
+ self._is_large = n_concepts > LARGE_CONCEPT_THRESHOLD
581
+ self.n_concepts_padded = (n_concepts + pad_multiple - 1) // pad_multiple * pad_multiple
582
+ if factorize:
583
+ self.embedding_coef = nn.Embedding(self.n_concepts_padded, factorize_rank)
584
+ self.embedding_basis = nn.Linear(factorize_rank, concept_dim, bias=False)
585
+ self.concept_embedding = None
586
+ if not use_attention:
587
+ self.predictor_down = nn.Linear(n_embd, factorize_rank, bias=False)
588
+ self.predictor_up = nn.Linear(factorize_rank, self.n_concepts_padded, bias=False)
589
+ self.concept_predictor = None
590
+ else:
591
+ self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
592
+ self.predictor_down = None
593
+ self.predictor_up = None
594
+ self.concept_predictor = None
595
+ dense_params = n_concepts * concept_dim * 2
596
+ factorized_params = n_concepts * factorize_rank + factorize_rank * concept_dim + (n_embd * factorize_rank + factorize_rank * n_concepts if not use_attention else 0)
597
+ logger.info(f'[ConceptHead] Factorized mode: {n_concepts} concepts, rank={factorize_rank}')
598
+ logger.info(f'[ConceptHead] Memory: {dense_params * 2 / 1000000000.0:.2f} GB (dense) -> {factorized_params * 2 / 1000000000.0:.2f} GB (factorized) = {(1 - factorized_params / dense_params) * 100:.1f}% reduction')
599
+ else:
600
+ self.concept_embedding = nn.Embedding(self.n_concepts_padded, concept_dim)
601
+ self.embedding_coef = None
602
+ self.embedding_basis = None
603
+ if use_attention:
604
+ self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
605
+ self.concept_predictor = None
606
+ else:
607
+ self.concept_predictor = nn.Linear(n_embd, self.n_concepts_padded, bias=False)
608
+ self.predictor_down = None
609
+ self.predictor_up = None
610
+ self.concept_pooling = self.ConceptPooling(concept_dim)
611
+ if self.topk_features != self.topk:
612
+ logger.info(f"[ConceptHead] {('Unknown' if is_unknown else 'Known')} head: topk={self.topk} (loss), topk_features={self.topk_features} (features)")
613
+ if is_unknown and apply_topk_to_unknown:
614
+ logger.info(f'[ConceptHead] Unknown head: apply_topk_to_unknown=True, topk={self.topk}')
615
+ self._init_weights()
616
+
617
+ def _init_weights(self):
618
+ """Initialize weights with small values."""
619
+ if self.factorize:
620
+ nn.init.normal_(self.embedding_coef.weight, mean=0.0, std=0.02)
621
+ nn.init.normal_(self.embedding_basis.weight, mean=0.0, std=0.02)
622
+ if self.predictor_down is not None:
623
+ nn.init.normal_(self.predictor_down.weight, mean=0.0, std=0.02)
624
+ if self.predictor_up is not None:
625
+ nn.init.normal_(self.predictor_up.weight, mean=0.0, std=0.02)
626
+ else:
627
+ if self.concept_embedding is not None:
628
+ nn.init.normal_(self.concept_embedding.weight, mean=0.0, std=0.02)
629
+ if self.concept_predictor is not None:
630
+ nn.init.normal_(self.concept_predictor.weight, mean=0.0, std=0.02)
631
+ if hasattr(self, 'concept_query_projection') and self.concept_query_projection is not None:
632
+ nn.init.normal_(self.concept_query_projection.weight, mean=0.0, std=0.02)
633
+
634
+ def _check_dense_allowed(self, operation: str) -> None:
635
+ """Raise error if dense operations are requested for large unknown heads."""
636
+ if self.is_unknown and self._is_large:
637
+ raise ValueError(f'{operation} requested for unknown head with {self.n_concepts} concepts. This would allocate multi-GB tensors. Use streaming mode instead. (Threshold: {LARGE_CONCEPT_THRESHOLD})')
638
+
639
+ @staticmethod
640
+ def _safe_index(weight: Tensor, indices: Tensor) -> Tensor:
641
+ """
642
+ DTensor-safe indexing using F.embedding.
643
+
644
+ Replaces weight[indices] which crashes under FSDP2/DTensor.
645
+
646
+ Args:
647
+ weight: (N, D) weight matrix
648
+ indices: (...) indices to select
649
+
650
+ Returns:
651
+ (..., D) selected embeddings
652
+ """
653
+ original_shape = indices.shape
654
+ flat_indices = indices.reshape(-1)
655
+ flat_result = F.embedding(flat_indices, weight)
656
+ return flat_result.reshape(*original_shape, -1)
657
+
658
+ def _get_embedding_weight(self) -> Tensor:
659
+ """
660
+ Get full embedding matrix.
661
+
662
+ For dense: returns concept_embedding.weight
663
+ For factorized: computes coef @ basis (materializes full matrix)
664
+
665
+ Returns:
666
+ (C, D) embedding matrix
667
+ """
668
+ if self.concept_embedding is not None:
669
+ return self.concept_embedding.weight
670
+ else:
671
+ return self.embedding_basis(self.embedding_coef.weight)
672
+
673
+ def _get_embedding(self, indices: Tensor) -> Tensor:
674
+ """
675
+ Get embeddings for specific indices (DTensor-safe).
676
+
677
+ For dense: uses F.embedding
678
+ For factorized: looks up coef, then applies basis
679
+
680
+ Args:
681
+ indices: (...) concept indices
682
+
683
+ Returns:
684
+ (..., D) embeddings
685
+ """
686
+ if self.concept_embedding is not None:
687
+ return self.concept_embedding(indices)
688
+ else:
689
+ coef = self.embedding_coef(indices)
690
+ return self.embedding_basis(coef)
691
+
692
+ def _get_predictor_weight(self) -> Tensor | None:
693
+ """
694
+ Get full predictor weight matrix (for linear path only).
695
+
696
+ Returns:
697
+ (C, D) predictor weight, or None if using attention
698
+ """
699
+ if self.concept_predictor is not None:
700
+ return self.concept_predictor.weight
701
+ elif self.predictor_down is not None and self.predictor_up is not None:
702
+ return self.predictor_up.weight @ self.predictor_down.weight
703
+ else:
704
+ return None
705
+
706
+ @staticmethod
707
+ def _merge_topk(topv: Tensor, topi: Tensor, v_blk: Tensor, i_blk: Tensor, k: int) -> tuple[Tensor, Tensor]:
708
+ """Efficient merge of two top-k sets. Memory: O(BT × 2k)."""
709
+ cand_v = torch.cat([topv, v_blk], dim=1)
710
+ cand_i = torch.cat([topi, i_blk], dim=1)
711
+ new_v, sel = torch.topk(cand_v, k, dim=1)
712
+ new_i = torch.gather(cand_i, 1, sel)
713
+ return (new_v, new_i)
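+ # Worked example (toy values): merging a running top-2 of {5.0, 3.0} with a block's
+ # top-2 of {4.0, 1.0} concatenates the four candidates, re-selects the top-2 values
+ # {5.0, 4.0}, and gathers the indices that produced them, so memory stays O(BT * 2k).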
714
+
715
+ @staticmethod
716
+ def linear_block_features(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
717
+ """
718
+ Memory-efficient linear prediction without materializing (B, T, C).
719
+
720
+ Args:
721
+ hidden: (B, T, D)
722
+ predictor_weight: (C, D)
723
+ embeddings: (C, D)
724
+ block_size: Concepts per block
725
+
726
+ Returns:
727
+ Features (B, T, D)
728
+ """
729
+ B, T, D = hidden.shape
730
+ C = predictor_weight.size(0)
731
+ output = torch.zeros(B, T, D, dtype=hidden.dtype, device=hidden.device)
732
+ flat_h = hidden.reshape(-1, D)
733
+ W_t = predictor_weight.t().contiguous()
734
+ for start in range(0, C, block_size):
735
+ end = min(start + block_size, C)
736
+ logits_block = (flat_h @ W_t[:, start:end]).to(torch.float32)
737
+ logits_block = logits_block.clamp(-15, 15)
738
+ weights_block = torch.sigmoid(logits_block)
739
+ E_block = embeddings[start:end].to(weights_block.dtype)
740
+ output.add_((weights_block @ E_block).reshape(B, T, D))
741
+ return output.to(hidden.dtype)
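+ # Note: the blocked loop above computes the same result as the dense expression
+ # sigmoid(hidden @ predictor_weight.T) @ embeddings (up to the ±15 logit clamp), but
+ # only ever materializes (B*T, block_size) logits at a time instead of the full
+ # (B, T, C) tensor.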
742
+
743
+ @staticmethod
744
+ def attention_block_features(query: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
745
+ """Memory-efficient attention features without materializing (B, T, C)."""
746
+ B, T, D = query.shape
747
+ C = embeddings.shape[0]
748
+ scale = 1.0 / math.sqrt(D)
749
+ flat_q = query.reshape(-1, D)
750
+ emb_T = embeddings.t().contiguous()
751
+ output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device)
752
+ for start in range(0, C, block_size):
753
+ end = min(start + block_size, C)
754
+ scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
755
+ scores = scores.clamp(-15, 15)
756
+ weights = torch.sigmoid(scores)
757
+ output.add_(weights @ embeddings[start:end].to(weights.dtype))
758
+ return output.reshape(B, T, D).to(query.dtype)
759
+
760
+ @staticmethod
761
+ def linear_features_topk_streaming(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
762
+ """
763
+ Memory-efficient linear prediction with streaming top-k.
764
+
765
+ Uses merge-k-with-k to keep memory O(BT × k), not O(BT × block_size).
766
+
767
+ Args:
768
+ hidden: (B, T, D)
769
+ predictor_weight: (C, D)
770
+ embeddings: (C, D)
771
+ k: Number of top concepts
772
+ block_size: Concepts per block
773
+ topk_on_logits: If True, select top-k by logits; else by sigmoid
774
+
775
+ Returns:
776
+ features: (B, T, D) weighted concept features
777
+ topk_indices: (B, T, k) indices of top-k concepts
778
+ topk_logits: (B, T, k) logits for top-k concepts
779
+ """
780
+ B, T, D = hidden.shape
781
+ C = predictor_weight.size(0)
782
+ BT = B * T
783
+ device = hidden.device
784
+ k = min(k, C)
785
+ flat_h = hidden.reshape(BT, D)
786
+ W_t = predictor_weight.t().contiguous()
787
+ topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
788
+ topi = torch.zeros((BT, k), device=device, dtype=torch.long)
789
+ for start in range(0, C, block_size):
790
+ end = min(start + block_size, C)
791
+ logits_blk = (flat_h @ W_t[:, start:end]).to(torch.float32).clamp_(-15, 15)
792
+ vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
793
+ blk_k = min(k, end - start)
794
+ v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
795
+ i_blk = idx_blk + start
796
+ if blk_k < k:
797
+ pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
798
+ pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
799
+ v_blk = torch.cat([v_blk, pad_v], dim=1)
800
+ i_blk = torch.cat([i_blk, pad_i], dim=1)
801
+ topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
802
+ W_sel = ConceptHead._safe_index(predictor_weight, topi)
803
+ logits_sel = torch.einsum('bd,bkd->bk', flat_h.to(torch.float32), W_sel.to(torch.float32))
804
+ logits_sel = logits_sel.clamp(-15, 15)
805
+ del W_sel
806
+ weights_sel = torch.sigmoid(logits_sel)
807
+ E_sel = ConceptHead._safe_index(embeddings, topi)
808
+ features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
809
+ return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
810
+
811
+ @staticmethod
812
+ def attention_features_topk_streaming(query: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
813
+ """Memory-efficient attention with streaming top-k."""
814
+ B, T, D = query.shape
815
+ C = embeddings.shape[0]
816
+ BT = B * T
817
+ device = query.device
818
+ scale = 1.0 / math.sqrt(D)
819
+ k = min(k, C)
820
+ flat_q = query.reshape(BT, D)
821
+ emb_T = embeddings.t().contiguous()
822
+ topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
823
+ topi = torch.zeros((BT, k), device=device, dtype=torch.long)
824
+ for start in range(0, C, block_size):
825
+ end = min(start + block_size, C)
826
+ logits_blk = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
827
+ logits_blk = logits_blk.clamp(-15, 15)
828
+ vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
829
+ blk_k = min(k, end - start)
830
+ v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
831
+ i_blk = idx_blk + start
832
+ if blk_k < k:
833
+ pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
834
+ pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
835
+ v_blk = torch.cat([v_blk, pad_v], dim=1)
836
+ i_blk = torch.cat([i_blk, pad_i], dim=1)
837
+ topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
838
+ E_sel = ConceptHead._safe_index(embeddings, topi)
839
+ logits_sel = torch.einsum('bd,bkd->bk', flat_q.to(torch.float32), E_sel.to(torch.float32)) * scale
840
+ logits_sel = logits_sel.clamp(-15, 15)
841
+ weights_sel = torch.sigmoid(logits_sel)
842
+ features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
843
+ return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
844
+
845
+ def attention_block_features_factorized(self, query: Tensor, block_size: int=4096) -> Tensor:
846
+ """
847
+ Memory-efficient factorized attention over ALL concepts.
848
+
849
+ Uses factorized scoring and feature computation:
850
+ - Scoring: (query @ basis.T) @ coef.T instead of query @ E.T
851
+ - Features: (weights @ coef) @ basis instead of weights @ E
852
+
853
+ FLOPs: O(BT * r * (D + C)) instead of O(BT * D * C)
854
+
855
+ Args:
856
+ query: (B, T, D) query vectors from concept_query_projection
857
+ block_size: Concepts per block for chunked processing
858
+
859
+ Returns:
860
+ (B, T, D) weighted concept features
861
+ """
862
+ assert self.factorize, 'Only valid for factorized head'
863
+ B, T, D = query.shape
864
+ BT = B * T
865
+ C = self.n_concepts
866
+ _ = self.factorize_rank
867
+ device = query.device
868
+ scale = 1.0 / math.sqrt(D)
869
+ flat_q = query.reshape(BT, D)
870
+ coef = self.embedding_coef.weight[:C]
871
+ basis_weight = self.embedding_basis.weight
872
+ q_compressed = flat_q @ basis_weight
873
+ output = torch.zeros(BT, D, dtype=query.dtype, device=device)
874
+ _ = (C + block_size - 1) // block_size
875
+ for _block_idx, start in enumerate(range(0, C, block_size)):
876
+ end = min(start + block_size, C)
877
+ coef_chunk = coef[start:end]
878
+ scores_chunk = (q_compressed @ coef_chunk.T).float() * scale
879
+ scores_chunk = scores_chunk.clamp(-15, 15)
880
+ weights_chunk = torch.sigmoid(scores_chunk)
881
+ weighted_coef = weights_chunk @ coef_chunk.float()
882
+ features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
883
+ output.add_(features_chunk)
884
+ return output.reshape(B, T, D).to(query.dtype)
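+ # Note: the full embedding matrix E = embedding_coef.weight @ embedding_basis.weight.T
+ # is never materialized; scores are computed as (q @ basis) @ coef.T, which equals
+ # q @ E.T, and features as (weights @ coef) @ basis.T, which equals weights @ E.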
885
+
886
+ def attention_features_topk_factorized(self, query: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
887
+ """
888
+ Memory-efficient factorized attention with streaming top-k.
889
+
890
+ Pass 1: Find top-k concepts using factorized scoring
891
+ Pass 2: Compute features using only top-k embeddings
892
+
893
+ Args:
894
+ query: (B, T, D) query vectors
895
+ k: Number of top concepts per token
896
+ block_size: Concepts per block
897
+
898
+ Returns:
899
+ features: (B, T, D) weighted concept features
900
+ topk_indices: (B, T, k) top-k concept indices
901
+ topk_logits: (B, T, k) logits for top-k concepts
902
+ """
903
+ assert self.factorize, 'Only valid for factorized head'
904
+ B, T, D = query.shape
905
+ BT = B * T
906
+ C = self.n_concepts
907
+ _ = self.factorize_rank
908
+ device = query.device
909
+ scale = 1.0 / math.sqrt(D)
910
+ k = min(k, C)
911
+ flat_q = query.reshape(BT, D)
912
+ coef = self.embedding_coef.weight[:C]
913
+ basis_weight = self.embedding_basis.weight
914
+ q_compressed = flat_q @ basis_weight
915
+ topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
916
+ topi = torch.zeros((BT, k), device=device, dtype=torch.long)
917
+ for start in range(0, C, block_size):
918
+ end = min(start + block_size, C)
919
+ coef_chunk = coef[start:end]
920
+ scores_chunk = q_compressed.float() @ coef_chunk.T.float() * scale
921
+ scores_chunk = scores_chunk.clamp(-15, 15)
922
+ blk_k = min(k, end - start)
923
+ v_chunk, idx_chunk = torch.topk(scores_chunk, blk_k, dim=1)
924
+ i_chunk = idx_chunk + start
925
+ if blk_k < k:
926
+ pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
927
+ pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
928
+ v_chunk = torch.cat([v_chunk, pad_v], dim=1)
929
+ i_chunk = torch.cat([i_chunk, pad_i], dim=1)
930
+ topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
931
+ coef_sel = self.embedding_coef(topi)
932
+ logits_sel = torch.einsum('br,bkr->bk', q_compressed.float(), coef_sel.float()) * scale
933
+ logits_sel = logits_sel.clamp(-15, 15)
934
+ weights_sel = torch.sigmoid(logits_sel)
935
+ weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
936
+ features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
937
+ return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
938
+
939
+ def linear_block_features_factorized(self, hidden: Tensor, block_size: int=4096) -> Tensor:
940
+ """
941
+ Memory-efficient factorized linear prediction over ALL concepts.
942
+
943
+ Uses factorized predictor: logits = hidden @ down @ up.T
944
+ Uses factorized embeddings: features = weights @ coef @ basis
945
+
946
+ Args:
947
+ hidden: (B, T, D) hidden states
948
+ block_size: Concepts per block
949
+
950
+ Returns:
951
+ (B, T, D) weighted concept features
952
+ """
953
+ assert self.factorize, 'Only valid for factorized head'
954
+ assert self.predictor_down is not None, 'Linear path requires predictor'
955
+ B, T, D = hidden.shape
956
+ BT = B * T
957
+ C = self.n_concepts
958
+ _ = self.factorize_rank
959
+ device = hidden.device
960
+ flat_h = hidden.reshape(BT, D)
961
+ coef = self.embedding_coef.weight[:C]
962
+ basis_weight = self.embedding_basis.weight
963
+ down_weight = self.predictor_down.weight
964
+ up_weight = self.predictor_up.weight[:C]
965
+ h_compressed = flat_h @ down_weight.T
966
+ output = torch.zeros(BT, D, dtype=hidden.dtype, device=device)
967
+ for start in range(0, C, block_size):
968
+ end = min(start + block_size, C)
969
+ up_chunk = up_weight[start:end]
970
+ coef_chunk = coef[start:end]
971
+ logits_chunk = h_compressed.float() @ up_chunk.T.float()
972
+ logits_chunk = logits_chunk.clamp(-15, 15)
973
+ weights_chunk = torch.sigmoid(logits_chunk)
974
+ weighted_coef = weights_chunk @ coef_chunk.float()
975
+ features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
976
+ output.add_(features_chunk)
977
+ return output.reshape(B, T, D).to(hidden.dtype)
978
+
979
+ def linear_features_topk_factorized(self, hidden: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
980
+ """
981
+ Memory-efficient factorized linear with streaming top-k.
982
+
983
+ Args:
984
+ hidden: (B, T, D) hidden states
985
+ k: Number of top concepts per token
986
+ block_size: Concepts per block
987
+
988
+ Returns:
989
+ features: (B, T, D) weighted concept features
990
+ topk_indices: (B, T, k) top-k concept indices
991
+ topk_logits: (B, T, k) logits for top-k concepts
992
+ """
993
+ assert self.factorize, 'Only valid for factorized head'
994
+ assert self.predictor_down is not None, 'Linear path requires predictor'
995
+ B, T, D = hidden.shape
996
+ BT = B * T
997
+ C = self.n_concepts
998
+ _ = self.factorize_rank
999
+ device = hidden.device
1000
+ k = min(k, C)
1001
+ flat_h = hidden.reshape(BT, D)
1002
+ down_weight = self.predictor_down.weight
1003
+ up_weight = self.predictor_up.weight[:C]
1004
+ basis_weight = self.embedding_basis.weight
1005
+ h_compressed = flat_h @ down_weight.T
1006
+ topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
1007
+ topi = torch.zeros((BT, k), device=device, dtype=torch.long)
1008
+ for start in range(0, C, block_size):
1009
+ end = min(start + block_size, C)
1010
+ up_chunk = up_weight[start:end]
1011
+ logits_chunk = h_compressed.float() @ up_chunk.T.float()
1012
+ logits_chunk = logits_chunk.clamp(-15, 15)
1013
+ blk_k = min(k, end - start)
1014
+ v_chunk, idx_chunk = torch.topk(logits_chunk, blk_k, dim=1)
1015
+ i_chunk = idx_chunk + start
1016
+ if blk_k < k:
1017
+ pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
1018
+ pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
1019
+ v_chunk = torch.cat([v_chunk, pad_v], dim=1)
1020
+ i_chunk = torch.cat([i_chunk, pad_i], dim=1)
1021
+ topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
1022
+ coef_sel = self.embedding_coef(topi)
1023
+ up_sel = self._safe_index(self.predictor_up.weight[:C], topi)
1024
+ logits_sel = torch.einsum('br,bkr->bk', h_compressed.float(), up_sel.float())
1025
+ logits_sel = logits_sel.clamp(-15, 15)
1026
+ weights_sel = torch.sigmoid(logits_sel)
1027
+ weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
1028
+ features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
1029
+ return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
1030
+
1031
+ def compute_logits_for_indices(self, hidden: Tensor, indices: Tensor) -> Tensor:
1032
+ """
1033
+ Compute logits for specific concept indices only (sparse).
1034
+
1035
+ Supports both dense and factorized heads.
1036
+
1037
+ IMPORTANT: This function materializes (M, K, D) where M is the number of
1038
+ tokens in hidden. Only call this with small M (e.g., masked tokens only).
1039
+
1040
+ Args:
1041
+ hidden: (M, D) or (B, T, D) hidden states
1042
+ indices: (M, K) or (B, T, K) concept indices
1043
+
1044
+ Returns:
1045
+ logits: Same shape as indices
1046
+ """
1047
+ if hidden.dim() == 2:
1048
+ M, D = hidden.shape
1049
+ K = indices.size(-1)
1050
+ flat_h = hidden
1051
+ flat_idx = indices
1052
+ output_shape = indices.shape
1053
+ else:
1054
+ B, T, D = hidden.shape
1055
+ K = indices.size(-1)
1056
+ M = B * T
1057
+ flat_h = hidden.reshape(M, D)
1058
+ flat_idx = indices.reshape(M, K)
1059
+ output_shape = indices.shape
1060
+ estimated_bytes = M * K * D * 2
1061
+ if estimated_bytes > 1000000000.0:
1062
+ warnings.warn(f'compute_logits_for_indices will allocate ~{estimated_bytes / 1000000000.0:.1f} GB. Consider reducing M={M} (use masked tokens only) or K={K}.')
1063
+ n_valid = self.n_concepts
1064
+ indices_safe = flat_idx.clamp(0, n_valid - 1)
1065
+ if self.use_attention:
1066
+ query = self.concept_query_projection(flat_h.unsqueeze(0)).squeeze(0)
1067
+ scale = 1.0 / math.sqrt(self.concept_dim)
1068
+ E_sel = self._get_embedding(indices_safe)
1069
+ logits = torch.einsum('md,mkd->mk', query.float(), E_sel.float()) * scale
1070
+ else:
1071
+ if self.factorize:
1072
+ W = self._get_predictor_weight()[:n_valid]
1073
+ W_sel = self._safe_index(W, indices_safe)
1074
+ else:
1075
+ W = self.concept_predictor.weight[:n_valid]
1076
+ W_sel = self._safe_index(W, indices_safe)
1077
+ logits = torch.einsum('md,mkd->mk', flat_h.float(), W_sel.float())
1078
+ return logits.clamp(-15, 15).reshape(output_shape)
1079
+
1080
+ def get_concept_weights(self, hidden: Tensor, concept_ids: Tensor) -> Tensor:
1081
+ """
1082
+ Get sigmoid weights for specific concepts (for attribution).
1083
+
1084
+ Args:
1085
+ hidden: (B, T, D) or (M, D) hidden states
1086
+ concept_ids: (B, T, K) or (M, K) or (K,) concept indices
1087
+
1088
+ Returns:
1089
+ weights: Same shape as concept_ids, values in [0, 1]
1090
+ """
1091
+ if concept_ids.dim() == 1:
1092
+ if hidden.dim() == 2:
1093
+ M = hidden.size(0)
1094
+ concept_ids = concept_ids.unsqueeze(0).expand(M, -1)
1095
+ else:
1096
+ B, T, _ = hidden.shape
1097
+ concept_ids = concept_ids.unsqueeze(0).unsqueeze(0).expand(B, T, -1)
1098
+ logits = self.compute_logits_for_indices(hidden, concept_ids)
1099
+ return torch.sigmoid(logits)
1100
+
1101
+ @staticmethod
1102
+ def blocked_logits(query: Tensor, embeddings: Tensor, block_size: int=8192, out_device: torch.device | None=None, out_dtype: torch.dtype=torch.float32) -> Tensor:
1103
+ """
1104
+ Compute concept logits in column blocks for memory efficiency.
1105
+
1106
+ logits = query @ embeddings.T / sqrt(D)
1107
+ """
1108
+ B, T, D = query.shape
1109
+ C = embeddings.size(0)
1110
+ scale = 1.0 / math.sqrt(D)
1111
+ dev = query.device if out_device is None else out_device
1112
+ logits = torch.empty(B, T, C, device=dev, dtype=out_dtype)
1113
+ q = query.reshape(-1, D).to(torch.float32)
1114
+ Et = embeddings.t().contiguous().to(torch.float32)
1115
+ for s in range(0, C, block_size):
1116
+ e = min(s + block_size, C)
1117
+ scores = q @ Et[:, s:e] * scale
1118
+ scores = scores.clamp(-15, 15)
1119
+ logits[:, :, s:e] = scores.reshape(B, T, e - s).to(out_dtype)
1120
+ return logits
1121
+
1122
+ @staticmethod
1123
+ def blocked_mix(weights: Tensor, embeddings: Tensor, block_size: int=8192) -> Tensor:
1124
+ """
1125
+ Compute weighted sum of embeddings in column blocks.
1126
+
1127
+ output = weights @ embeddings
1128
+ """
1129
+ B, T, C = weights.shape
1130
+ D = embeddings.size(1)
1131
+ out = torch.zeros(B, T, D, device=weights.device, dtype=weights.dtype)
1132
+ for s in range(0, C, block_size):
1133
+ e = min(s + block_size, C)
1134
+ w_blk = weights[:, :, s:e].to(torch.float32)
1135
+ V_blk = embeddings[s:e].to(w_blk.dtype)
1136
+ out.add_(w_blk @ V_blk)
1137
+ return out.to(weights.dtype)
1138
+
1139
+ @staticmethod
1140
+ def sigmoid_block_attention(query: Tensor, embeddings: Tensor, block_size: int=8192, return_logits: bool=False) -> Tensor | tuple[Tensor, Tensor]:
1141
+ """Memory-efficient sigmoid attention using block processing."""
1142
+ B, T, D = query.shape
1143
+ C = embeddings.shape[0]
1144
+ scale = 1.0 / math.sqrt(D)
1145
+ flat_q = query.reshape(-1, D)
1146
+ emb_T = embeddings.t().contiguous()
1147
+ output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device)
1148
+ logits: Tensor | None = None
1149
+ if return_logits:
1150
+ logits = torch.empty(B, T, C, dtype=torch.float32, device=query.device)
1151
+ for start in range(0, C, block_size):
1152
+ end = min(start + block_size, C)
1153
+ scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
1154
+ scores = scores.clamp(-15, 15)
1155
+ if logits is not None:
1156
+ logits[:, :, start:end] = scores.reshape(B, T, end - start)
1157
+ weights = torch.sigmoid(scores)
1158
+ output.add_(weights @ embeddings[start:end].to(weights.dtype))
1159
+ output = output.reshape(B, T, D).to(query.dtype)
1160
+ if return_logits:
1161
+ assert logits is not None
1162
+ return (output, logits)
1163
+ return output
1164
+
1165
+ def _apply_sparse_interventions(self, features: Tensor, hidden: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
1166
+ """
1167
+ Apply sparse interventions matching original dense behavior.
1168
+
1169
+ Original dense behavior:
1170
+ weights = sigmoid(logits) # (B, T, C)
1171
+ weights[..., c] = new_val # Override
1172
+ features = weights @ embeddings
1173
+
1174
+ Sparse equivalent:
1175
+ features += (new_val - current_weight) * embedding[c]
1176
+ """
1177
+ B, T, D = features.shape
1178
+ valid = intervene_ids != -1
1179
+ if not valid.any():
1180
+ return features
1181
+ ids_safe = intervene_ids.clamp(0, self.n_concepts - 1)
1182
+ current_logits = self.compute_logits_for_indices(hidden, ids_safe)
1183
+ current_weights = torch.sigmoid(current_logits)
1184
+ emb = self._get_embedding(ids_safe)
1185
+ delta = (intervene_vals - current_weights) * valid.float()
1186
+ correction = (delta.unsqueeze(-1) * emb).sum(dim=2)
1187
+ return features + correction
1188
+
1189
+ def _apply_dense_interventions(self, concept_weight: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
1190
+ """Apply interventions by overriding concept weights (dense path)."""
1191
+ n_valid = min(self.n_concepts, concept_weight.size(-1))
1192
+ valid_edit = intervene_ids != -1
1193
+ ids = intervene_ids.clamp(0, n_valid - 1).long()
1194
+ vals = intervene_vals.to(concept_weight.dtype)
1195
+ updates = torch.zeros_like(concept_weight)
1196
+ updates.scatter_add_(2, ids, torch.where(valid_edit, vals, torch.zeros_like(vals)))
1197
+ set_mask = torch.zeros_like(concept_weight, dtype=torch.bool)
1198
+ set_mask.scatter_(2, ids, valid_edit)
1199
+ return torch.where(set_mask, updates, concept_weight)
1200
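# --- Illustrative sketch, not part of the generated model file: scatter-based weight override ---
# Mirrors _apply_dense_interventions on toy tensors: ids of -1 are ignored, valid ids have
# their concept weight replaced by the supplied value. Shapes and values are made up.
import torch

B, T, C, K = 1, 2, 5, 2
weights = torch.full((B, T, C), 0.5)
ids = torch.tensor([[[1, -1], [3, 4]]])                        # (B, T, K); -1 = skip
vals = torch.tensor([[[0.9, 0.0], [0.1, 0.8]]])                # (B, T, K)

valid = ids != -1
safe_ids = ids.clamp(0, C - 1)
updates = torch.zeros_like(weights)
updates.scatter_add_(2, safe_ids, torch.where(valid, vals, torch.zeros_like(vals)))
mask = torch.zeros_like(weights, dtype=torch.bool)
mask.scatter_(2, safe_ids, valid)
out = torch.where(mask, updates, weights)
print(out)  # position (0,0): concept 1 -> 0.9, rest 0.5; position (0,1): concepts 3 and 4 overridden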
+
1201
+ def topk_with_cutoff(self, tensor: Tensor, dim: int=-1) -> Tensor:
1202
+ """
1203
+ Apply top-k sparsity along the given dim, zeroing out all but the top-k values.
1204
+
1205
+ Args:
1206
+ tensor: Input tensor, typically (B, T, C)
1207
+ dim: Dimension to apply top-k (default: last)
1208
+
1209
+ Returns:
1210
+ Sparse tensor with only top-k values preserved
1211
+ """
1212
+ assert dim == -1 or dim == tensor.dim() - 1
1213
+ if self.topk is None:
1214
+ return tensor
1215
+ padded = tensor.size(dim)
1216
+ n_valid = min(self.n_concepts, padded)
1217
+ if n_valid <= 0:
1218
+ return torch.zeros_like(tensor)
1219
+ x = tensor.narrow(dim, 0, n_valid)
1220
+ kk = min(self.topk, n_valid)
1221
+ topv, topi = torch.topk(x, kk, dim=dim)
1222
+ out = torch.zeros_like(x)
1223
+ out.scatter_(dim, topi, topv)
1224
+ if n_valid < padded:
1225
+ pad_shape = list(out.shape)
1226
+ pad_shape[dim] = padded - n_valid
1227
+ pad_zeros = out.new_zeros(pad_shape)
1228
+ out = torch.cat([out, pad_zeros], dim=dim)
1229
+ return out
1230
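# --- Illustrative sketch, not part of the generated model file: top-k cutoff on a toy row ---
# Same idea as topk_with_cutoff above: keep the k largest entries, zero the rest.
import torch

x = torch.tensor([[0.1, 2.0, -1.0, 0.7, 3.0]])
k = 2
topv, topi = torch.topk(x, k, dim=-1)
sparse = torch.zeros_like(x).scatter_(-1, topi, topv)
print(sparse)                                                   # tensor([[0., 2., 0., 0., 3.]])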
+
1231
+ def _compute_weights(self, concept_logits: Tensor, E: Tensor) -> Tensor:
1232
+ """Compute concept weights from logits, with optional top-k sparsity."""
1233
+ apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
1234
+ if apply_topk and self.topk_on_logits:
1235
+ logits_for_weights = self.topk_with_cutoff(concept_logits)
1236
+ weights = torch.sigmoid(logits_for_weights).to(E.dtype)
1237
+ return weights
1238
+ weights = torch.sigmoid(concept_logits).to(E.dtype)
1239
+ if apply_topk and (not self.topk_on_logits):
1240
+ weights = self.topk_with_cutoff(weights)
1241
+ return weights
1242
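# --- Illustrative sketch, not part of the generated model file: the two top-k placements in _compute_weights ---
# Top-k on the logits zeroes the discarded logits *before* the sigmoid (so those entries end
# up at sigmoid(0) = 0.5), while top-k on the weights zeroes the discarded entries *after*
# the sigmoid. Toy values only.
import torch

logits = torch.tensor([[3.0, -2.0, 1.0, 0.5]])
k = 2
topv, topi = torch.topk(logits, k, dim=-1)
cut_logits = torch.zeros_like(logits).scatter_(-1, topi, topv)

w_topk_on_logits = torch.sigmoid(cut_logits)                    # discarded entries -> 0.5
w_topk_on_weights = torch.zeros_like(logits).scatter_(
    -1, topi, torch.sigmoid(logits).gather(-1, topi))           # discarded entries -> 0.0
print(w_topk_on_logits, w_topk_on_weights)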
+
1243
+ @torch.compiler.disable
1244
+ def forward(self, hidden: Tensor, intervene_ids: Tensor | None=None, intervene_vals: Tensor | None=None, return_logits: bool=False, store_hidden: bool=False) -> ConceptHeadOutput:
1245
+ """
1246
+ Forward pass for concept decomposition (inference only, no teacher forcing).
1247
+
1248
+ Args:
1249
+ hidden: Transformer hidden states (B, T, n_embd)
1250
+ intervene_ids: Concept IDs to intervene on (B, T, K_int), -1 = skip
1251
+ intervene_vals: Intervention strength values (B, T, K_int)
1252
+ return_logits: If True, compute full (B, T, C) logits. Forbidden for large heads.
1253
+ store_hidden: If True, store hidden in output for later attribution.
1254
+
1255
+ Returns:
1256
+ ConceptHeadOutput with features, predicted, topk_indices, topk_logits
1257
+ """
1258
+ B, T, _ = hidden.shape
1259
+ has_interventions = intervene_ids is not None and intervene_vals is not None
1260
+ if return_logits:
1261
+ self._check_dense_allowed('return_logits=True')
1262
+ n_valid = self.n_concepts
1263
+ concept_logits: Tensor | None = None
1264
+ concept_weight: Tensor | None = None
1265
+ predicted: Tensor
1266
+ topk_indices: Tensor | None = None
1267
+ topk_logits: Tensor | None = None
1268
+ apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
1269
+ k_features = self.topk_features if self.topk_features is not None else self.topk
1270
+ use_dense_intervention = has_interventions and (not self._is_large)
1271
+ if use_dense_intervention:
1272
+ E = self._get_embedding_weight()[:n_valid]
1273
+ if self.use_attention:
1274
+ query = self.concept_query_projection(hidden)
1275
+ concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
1276
+ else:
1277
+ if self.factorize:
1278
+ W = self._get_predictor_weight()[:n_valid]
1279
+ raw_logits = hidden @ W.T
1280
+ else:
1281
+ raw_logits = self.concept_predictor(hidden)[..., :n_valid]
1282
+ concept_logits = raw_logits.float().clamp(-15, 15)
1283
+ concept_weight = self._compute_weights(concept_logits, E)
1284
+ assert intervene_ids is not None and intervene_vals is not None
1285
+ concept_weight = self._apply_dense_interventions(concept_weight, intervene_ids, intervene_vals)
1286
+ predicted = self.blocked_mix(concept_weight, E, block_size=self.block_size)
1287
+ elif self.factorize:
1288
+ if self.use_attention:
1289
+ query = self.concept_query_projection(hidden)
1290
+ if apply_topk:
1291
+ predicted, topk_indices, topk_logits = self.attention_features_topk_factorized(query, k=k_features, block_size=self.block_size)
1292
+ else:
1293
+ predicted = self.attention_block_features_factorized(query, block_size=self.block_size)
1294
+ elif apply_topk:
1295
+ predicted, topk_indices, topk_logits = self.linear_features_topk_factorized(hidden, k=k_features, block_size=self.block_size)
1296
+ else:
1297
+ predicted = self.linear_block_features_factorized(hidden, block_size=self.block_size)
1298
+ elif apply_topk:
1299
+ E = self._get_embedding_weight()[:n_valid]
1300
+ if self.use_attention:
1301
+ query = self.concept_query_projection(hidden)
1302
+ predicted, topk_indices, topk_logits = self.attention_features_topk_streaming(query, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
1303
+ else:
1304
+ W = self.concept_predictor.weight[:n_valid]
1305
+ predicted, topk_indices, topk_logits = self.linear_features_topk_streaming(hidden, W, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
1306
+ else:
1307
+ E = self._get_embedding_weight()[:n_valid]
1308
+ if self.use_attention:
1309
+ query = self.concept_query_projection(hidden)
1310
+ predicted = self.attention_block_features(query, E, block_size=self.block_size)
1311
+ else:
1312
+ W = self.concept_predictor.weight[:n_valid]
1313
+ predicted = self.linear_block_features(hidden, W, E, block_size=self.block_size)
1314
+ if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
1315
+ _, rerank_idx = torch.topk(topk_logits, self.topk, dim=-1)
1316
+ topk_indices = torch.gather(topk_indices, -1, rerank_idx)
1317
+ topk_logits = torch.gather(topk_logits, -1, rerank_idx)
1318
+ if return_logits and (not use_dense_intervention):
1319
+ E = self._get_embedding_weight()[:n_valid]
1320
+ if self.use_attention:
1321
+ query = self.concept_query_projection(hidden)
1322
+ concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
1323
+ else:
1324
+ if self.factorize:
1325
+ W = self._get_predictor_weight()[:n_valid]
1326
+ raw_logits = hidden @ W.T
1327
+ else:
1328
+ raw_logits = self.concept_predictor(hidden)[..., :n_valid]
1329
+ concept_logits = raw_logits.float().clamp(-15, 15)
1330
+ concept_weight = self._compute_weights(concept_logits, E)
1331
+ if not hasattr(self, '_logged_forward_path'):
1332
+ self._logged_forward_path = True
1333
+ path = ('dense_intervention' if use_dense_intervention
+         else 'factorized_topk' if self.factorize and apply_topk
+         else 'factorized_all' if self.factorize
+         else 'streaming_topk' if apply_topk
+         else 'dense_all')
1334
+ logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: path={path}, topk={self.topk}, topk_features={self.topk_features}, n_concepts={self.n_concepts}, factorize={self.factorize}, apply_topk={apply_topk}")
1335
+ if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
1336
+ if not hasattr(self, '_logged_topk_slice'):
1337
+ self._logged_topk_slice = True
1338
+ logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: Sliced topk: {self.topk_features} features -> {self.topk} for loss")
1339
+ if has_interventions and (not use_dense_intervention):
1340
+ assert intervene_ids is not None and intervene_vals is not None
1341
+ predicted = self._apply_sparse_interventions(predicted, hidden, intervene_ids, intervene_vals)
1342
+ return ConceptHeadOutput(features=predicted, gt_features=None, logits=concept_logits, predicted=predicted, weights=concept_weight, topk_indices=topk_indices, topk_logits=topk_logits, hidden=hidden.detach() if store_hidden else None)
1343
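# --- Illustrative sketch, not part of the generated model file: slicing topk_features candidates down to topk ---
# Mirrors the rerank step inside forward(): from a wider candidate set, keep the entries
# with the highest logits, keeping indices and logits aligned via gather. Toy values only.
import torch

topk_indices = torch.tensor([[[7, 3, 12, 5]]])                  # (B, T, topk_features)
topk_logits = torch.tensor([[[0.2, 1.5, -0.3, 0.9]]])
topk = 2
_, rerank_idx = torch.topk(topk_logits, topk, dim=-1)
kept_indices = torch.gather(topk_indices, -1, rerank_idx)       # -> [[[3, 5]]]
kept_logits = torch.gather(topk_logits, -1, rerank_idx)         # -> [[[1.5, 0.9]]]
print(kept_indices, kept_logits)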
+
1344
+ # ======================================================================
1345
+ # steerling/models/interpretable/interpretable_causal_diffusion.py
1346
+ # ======================================================================
1347
+
1348
+ logger = logging.getLogger(__name__)
1349
+
1350
+ class InterpretableCausalDiffusionLM(nn.Module):
1351
+ """
1352
+ Interpretable CausalDiffusionLM with concept decomposition heads.
1353
+
1354
+ Wraps a CausalDiffusionLM and adds:
1355
+ - Known concept head: predicts known concepts from hidden states
1356
+ - Unknown concept head: captures residual features (optional)
1357
+ - Steering via concept interventions
1358
+
1359
+ Args:
1360
+ config: CausalDiffusionConfig (model architecture)
1361
+ concept_config: ConceptConfig (concept decomposition)
1362
+ vocab_size: Vocabulary size
1363
+ """
1364
+
1365
+ def __init__(self, config: CausalDiffusionConfig, concept_config: ConceptConfig, vocab_size: int):
1366
+ super().__init__()
1367
+ self.config = config
1368
+ self.concept_config = concept_config
1369
+ self.vocab_size = vocab_size
1370
+ self.transformer = CausalDiffusionLM(config, vocab_size)
1371
+ self.known_head = ConceptHead(n_concepts=concept_config.n_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=False, use_attention=concept_config.use_attention_known, topk=concept_config.topk_known, topk_features=concept_config.topk_known_features, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=False, topk_on_logits=concept_config.topk_on_logits)
1372
+ if concept_config.use_unknown:
1373
+ if concept_config.n_unknown_concepts is None:
1374
+ raise ValueError('n_unknown_concepts must be set when use_unknown=True')
1375
+ self.unknown_head: ConceptHead | None = ConceptHead(n_concepts=concept_config.n_unknown_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=True, use_attention=concept_config.use_attention_unknown, topk=concept_config.unknown_topk, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=concept_config.apply_topk_to_unknown, topk_on_logits=concept_config.topk_on_logits, factorize=concept_config.factorize_unknown, factorize_rank=concept_config.factorize_rank)
1376
+ else:
1377
+ self.unknown_head = None
1378
+
1379
+ def forward(self, input_ids: Tensor, *, input_embeds: Tensor | None=None, intervene_known_ids: Tensor | None=None, intervene_known_vals: Tensor | None=None, intervene_unknown_ids: Tensor | None=None, intervene_unknown_vals: Tensor | None=None, minimal_output: bool=False, position_injection: Tensor | None=None, steering_inject_layer: int | None=None, steering_inject_alpha: float=1.0, unknown_topk: int=64) -> tuple[Tensor, InterpretableOutput]:
1380
+ """
1381
+ Forward pass with concept decomposition.
1382
+
1383
+ Args:
1384
+ input_ids: Token IDs (B, T). May contain mask tokens.
1385
+ input_embeds: Pre-computed embeddings (B, T, D). Overrides input_ids.
1386
+ intervene_known_ids: Known concept IDs to intervene (B, T, K_int)
1387
+ intervene_known_vals: Intervention values for known (B, T, K_int)
1388
+ intervene_unknown_ids: Unknown concept IDs to intervene (B, T, K_int)
1389
+ intervene_unknown_vals: Intervention values for unknown (B, T, K_int)
1390
+ minimal_output: If True, skip expensive extras (dense concept logits, unknown-head top-k attribution)
1391
+ position_injection: Per-position steering injection (B, T, D)
1392
+ steering_inject_layer: Inject after each block whose 1-indexed depth is >= this value
1393
+ steering_inject_alpha: Injection strength
1394
+ unknown_topk: Top-k for unknown head attribution
1395
+
1396
+ Returns:
1397
+ logits: LM logits (B, T, V)
1398
+ outputs: InterpretableOutput with all decomposition components
1399
+ """
1400
+ need_dense_logits = not minimal_output
1401
+ if position_injection is not None and steering_inject_layer is not None:
1402
+ hidden = self._forward_with_injection(input_ids, input_embeds, position_injection, steering_inject_layer, steering_inject_alpha)
1403
+ else:
1404
+ hidden = self.transformer(input_ids, input_embeds=input_embeds, return_hidden=True)
1405
+ known_out: ConceptHeadOutput = self.known_head(hidden, intervene_ids=intervene_known_ids, intervene_vals=intervene_known_vals, return_logits=need_dense_logits)
1406
+ known_features = known_out.features.to(hidden.dtype)
1407
+ unk = hidden - known_features.detach()
1408
+ unk_for_lm: Tensor = unk
1409
+ unknown_out: ConceptHeadOutput | None = None
1410
+ unk_hat: Tensor | None = None
1411
+ if self.unknown_head is not None:
1412
+ unknown_out = self.unknown_head(hidden.detach(), intervene_ids=intervene_unknown_ids, intervene_vals=intervene_unknown_vals, return_logits=not minimal_output and (not self.unknown_head._is_large))
1413
+ assert unknown_out is not None
1414
+ unk_hat = unknown_out.features.to(hidden.dtype)
1415
+ unk_for_lm = unk_hat.detach()
1416
+ epsilon_true = None
1417
+ if self.unknown_head is not None and unk_hat is not None:
1418
+ epsilon_true = hidden.detach() - (known_out.predicted + unk_hat)
1419
+ epsilon = None
1420
+ if self.concept_config.use_epsilon_correction and intervene_known_ids is None:
1421
+ epsilon = hidden - (unk_for_lm + known_features)
1422
+ unk_for_lm = unk_for_lm + epsilon
1423
+ composed = unk_for_lm + known_features
1424
+ logits = self.transformer.lm_head(composed)
1425
+ _unk_topk_indices = unknown_out.topk_indices if unknown_out else None
1426
+ _unk_topk_logits = unknown_out.topk_logits if unknown_out else None
1427
+ if not minimal_output and self.unknown_head is not None and (unknown_out is not None) and (_unk_topk_indices is None) and (unknown_topk > 0):
1428
+ with torch.no_grad():
1429
+ _unk_topk_indices, _unk_topk_logits = self._compute_unknown_topk(hidden, unknown_topk)
1430
+ outputs = InterpretableOutput(hidden=hidden, known_features=known_features, known_logits=known_out.logits, known_gt_features=known_out.gt_features, known_predicted=known_out.predicted, known_weights=known_out.weights, known_topk_indices=known_out.topk_indices, known_topk_logits=known_out.topk_logits, unk=unk, unk_hat=unk_hat, unk_for_lm=unk_for_lm, unknown_logits=unknown_out.logits if unknown_out else None, unknown_weights=unknown_out.weights if unknown_out else None, unknown_topk_indices=_unk_topk_indices, unknown_topk_logits=_unk_topk_logits, composed=composed, epsilon=epsilon, epsilon_true=epsilon_true)
1431
+ return (logits, outputs)
1432
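# --- Illustrative sketch, not part of the generated model file: the residual decomposition in forward() ---
# hidden is split into the known-concept reconstruction and a residual; the LM head then
# consumes known + (reconstructed residual [+ epsilon correction]). Toy tensors only.
import torch

hidden = torch.randn(2, 3, 8)
known_features = torch.randn(2, 3, 8)                           # stand-in for known_head output
unk = hidden - known_features                                   # residual, as in forward()
unk_hat = unk + 0.01 * torch.randn_like(unk)                    # stand-in for unknown_head reconstruction

epsilon = hidden - (unk_hat + known_features)                   # optional correction term
composed = unk_hat + epsilon + known_features                    # what lm_head would receive
assert torch.allclose(composed, hidden, atol=1e-6)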
+
1433
+ def _compute_unknown_topk(self, hidden: Tensor, unknown_topk: int) -> tuple[Tensor | None, Tensor | None]:
1434
+ """Compute unknown head top-k indices for attribution."""
1435
+ assert self.unknown_head is not None
1436
+ if self.unknown_head.factorize:
1437
+ if self.unknown_head.use_attention:
1438
+ _query = self.unknown_head.concept_query_projection(hidden.detach())
1439
+ _, indices, logits = self.unknown_head.attention_features_topk_factorized(_query, k=unknown_topk, block_size=self.unknown_head.block_size)
1440
+ else:
1441
+ _, indices, logits = self.unknown_head.linear_features_topk_factorized(hidden.detach(), k=unknown_topk, block_size=self.unknown_head.block_size)
1442
+ else:
1443
+ _E = self.unknown_head._get_embedding_weight()[:self.unknown_head.n_concepts]
1444
+ if self.unknown_head.use_attention:
1445
+ _query = self.unknown_head.concept_query_projection(hidden.detach())
1446
+ _, indices, logits = self.unknown_head.attention_features_topk_streaming(_query, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
1447
+ else:
1448
+ _W = self.unknown_head.concept_predictor.weight[:self.unknown_head.n_concepts]
1449
+ _, indices, logits = self.unknown_head.linear_features_topk_streaming(hidden.detach(), _W, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
1450
+ return (indices, logits)
1451
+
1452
+ def _forward_with_injection(self, input_ids: Tensor, input_embeds: Tensor | None, position_injection: Tensor, inject_layer: int, inject_alpha: float) -> Tensor:
1453
+ """Forward through transformer with steering injection at specified layers."""
1454
+ x = input_embeds if input_embeds is not None else self.transformer.tok_emb(input_ids)
1455
+ for i, block in enumerate(self.transformer.blocks):
1456
+ x = block(x)
1457
+ if i + 1 >= inject_layer:
1458
+ x = x + inject_alpha * position_injection
1459
+ x = self.transformer.ln_f(x)
1460
+ return x
1461
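# --- Illustrative sketch, not part of the generated model file: additive steering injection from a given layer ---
# Mirrors _forward_with_injection: after each block whose 1-indexed depth reaches
# inject_layer, a scaled per-position vector is added to the residual stream. Toy modules.
import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 3, 8)
position_injection = torch.randn(2, 3, 8)
inject_layer, inject_alpha = 3, 0.5

for i, block in enumerate(blocks):
    x = block(x)
    if i + 1 >= inject_layer:                                   # injects after layers 3 and 4 here
        x = x + inject_alpha * position_injection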
+
1462
+ @torch.no_grad()
1463
+ def intervene(self, input_ids: Tensor, known: dict[int, float] | None=None, unknown: dict[int, float] | None=None, positions: Tensor | None=None) -> tuple[Tensor, InterpretableOutput]:
1464
+ """
1465
+ Run inference with concept interventions.
1466
+
1467
+ Args:
1468
+ input_ids: Input token IDs (B, T)
1469
+ known: Dict mapping known concept IDs to intervention strengths
1470
+ unknown: Dict mapping unknown concept IDs to intervention strengths
1471
+ positions: Bool mask of positions to intervene (B, T). Default: all.
1472
+
1473
+ Returns:
1474
+ logits: LM logits (B, T, V)
1475
+ outputs: InterpretableOutput
1476
+ """
1477
+ B, T = input_ids.shape
1478
+ device = input_ids.device
1479
+ if positions is None:
1480
+ positions = torch.ones(B, T, dtype=torch.bool, device=device)
1481
+ int_known_ids, int_known_vals = (None, None)
1482
+ if known is not None and len(known) > 0:
1483
+ int_known_ids, int_known_vals = self._build_intervention_tensors(known, B, T, positions, device)
1484
+ int_unknown_ids, int_unknown_vals = (None, None)
1485
+ if unknown is not None and len(unknown) > 0:
1486
+ int_unknown_ids, int_unknown_vals = self._build_intervention_tensors(unknown, B, T, positions, device)
1487
+ return self(input_ids, intervene_known_ids=int_known_ids, intervene_known_vals=int_known_vals, intervene_unknown_ids=int_unknown_ids, intervene_unknown_vals=int_unknown_vals, minimal_output=False)
1488
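# --- Illustrative sketch, not part of the generated model file: hypothetical use of intervene() ---
# `model` is assumed to be an already-constructed InterpretableCausalDiffusionLM; the
# concept ids, strengths, and sequence length are made up. Only the calling convention is shown.
import torch

input_ids = torch.randint(0, 1000, (1, 16))
positions = torch.zeros(1, 16, dtype=torch.bool)
positions[:, 8:] = True                                         # steer only the second half
logits, outputs = model.intervene(
    input_ids,
    known={42: 5.0, 7: -5.0},                                   # push concept 42 up, concept 7 down
    positions=positions,
)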
+
1489
+ @staticmethod
1490
+ def _build_intervention_tensors(interventions: dict[int, float], B: int, T: int, positions: Tensor, device: torch.device) -> tuple[Tensor, Tensor]:
1491
+ """Build intervention tensors for concept steering."""
1492
+ K = len(interventions)
1493
+ concept_ids = list(interventions.keys())
1494
+ directions = list(interventions.values())
1495
+ ids = torch.full((B, T, K), -1, dtype=torch.long, device=device)
1496
+ vals = torch.zeros((B, T, K), dtype=torch.float32, device=device)
1497
+ concept_tensor = torch.tensor(concept_ids, device=device)
1498
+ direction_tensor = torch.tensor(directions, dtype=torch.float32, device=device)
1499
+ n_active = int(positions.sum().item())
1500
+ ids[positions] = concept_tensor.unsqueeze(0).expand(n_active, -1)
1501
+ vals[positions] = direction_tensor.unsqueeze(0).expand(n_active, -1)
1502
+ return (ids, vals)
1503
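# --- Illustrative sketch, not part of the generated model file: what _build_intervention_tensors produces ---
# For a dict {concept_id: strength} and a position mask, every active position gets the full
# list of (id, value) pairs and inactive positions keep id = -1 (skipped downstream). Toy sizes.
import torch

B, T = 1, 3
interventions = {10: 2.5, 4: -1.0}
positions = torch.tensor([[True, False, True]])

K = len(interventions)
ids = torch.full((B, T, K), -1, dtype=torch.long)
vals = torch.zeros((B, T, K))
ids[positions] = torch.tensor(list(interventions.keys()))
vals[positions] = torch.tensor(list(interventions.values()))
print(ids)   # [[[10, 4], [-1, -1], [10, 4]]]
print(vals)  # [[[2.5, -1.0], [0.0, 0.0], [2.5, -1.0]]]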
+
1504
+ def get_num_params(self, non_embedding: bool=True) -> int:
1505
+ n_params = sum((p.numel() for p in self.parameters()))
1506
+ if non_embedding and hasattr(self.transformer, 'tok_emb'):
1507
+ n_params -= self.transformer.tok_emb.weight.numel()
1508
+ return n_params
1509
+ from transformers import PreTrainedModel
1510
+ from .configuration_steerling import SteerlingConfig
1511
+
1512
+
1513
+ # CausalDiffusionLM is the backbone — alias to HF-friendly name
1514
+ SteerlingBackbone = CausalDiffusionLM
1515
+
1516
+
1517
+ class SteerlingForCausalLM(PreTrainedModel):
1518
+ config_class = SteerlingConfig
1519
+ supports_gradient_checkpointing = False
1520
+ _tied_weights_keys = ["transformer.lm_head.weight"]
1521
+
1522
+ def __init__(self, config: SteerlingConfig):
1523
+ super().__init__(config)
1524
+ # SteerlingConfig has all fields from both arch and concept configs
1525
+ self.concept_config = config
1526
+ self.transformer = SteerlingBackbone(config, config.vocab_size)
1527
+ self.known_head = ConceptHead(
1528
+ n_concepts=config.n_concepts,
1529
+ concept_dim=config.concept_dim,
1530
+ n_embd=config.n_embd,
1531
+ is_unknown=False,
1532
+ use_attention=config.use_attention_known,
1533
+ topk=config.topk_known,
1534
+ topk_features=config.topk_known_features,
1535
+ block_size=config.concept_block_size,
1536
+ pad_multiple=config.pad_multiple,
1537
+ store_unknown_weights=False,
1538
+ apply_topk_to_unknown=False,
1539
+ topk_on_logits=config.topk_on_logits,
1540
+ factorize=False,
1541
+ )
1542
+ if config.use_unknown:
1543
+ self.unknown_head = ConceptHead(
1544
+ n_concepts=config.n_unknown_concepts,
1545
+ concept_dim=config.concept_dim,
1546
+ n_embd=config.n_embd,
1547
+ is_unknown=True,
1548
+ use_attention=config.use_attention_unknown,
1549
+ topk=config.unknown_topk,
1550
+ block_size=config.concept_block_size,
1551
+ pad_multiple=config.pad_multiple,
1552
+ store_unknown_weights=config.store_unknown_weights,
1553
+ apply_topk_to_unknown=config.apply_topk_to_unknown,
1554
+ topk_on_logits=config.topk_on_logits,
1555
+ factorize=config.factorize_unknown,
1556
+ factorize_rank=config.factorize_rank,
1557
+ )
1558
+ else:
1559
+ self.unknown_head = None
1560
+ self.post_init()
1561
+
1562
+ def _init_weights(self, module):
1563
+ pass
1564
+
1565
+ def _tie_weights(self):
1566
+ if self.config.weight_sharing:
1567
+ self.transformer.lm_head.weight = self.transformer.tok_emb.weight
1568
+
1569
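# Note: both branches below call an unbound forward() with this instance as `self`; this
# relies on SteerlingForCausalLM exposing the attributes those implementations expect
# (transformer, known_head, unknown_head, concept_config for the interpretable path).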
+ def forward(self, input_ids=None, **kwargs):
1570
+ if self.config.interpretable:
1571
+ return InterpretableCausalDiffusionLM.forward(self, input_ids, **kwargs)
1572
+ else:
1573
+ kwargs.pop('minimal_output', None)
1574
+ return CausalDiffusionLM.forward(self, input_ids, **kwargs)
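# --- Illustrative sketch, not part of the generated model file: hypothetical use of the HF wrapper ---
# Loading from a local checkpoint directory with trust_remote_code; the path is a placeholder
# and the unpacked return shape assumes config.interpretable is enabled, as implied by the
# dispatch above. Only the calling convention is shown.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/steerling-checkpoint", trust_remote_code=True)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits, outputs = model(input_ids)                          # with config.interpretable=True: (logits, InterpretableOutput)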