DigitalDaimyo commited on
Commit
9aaf11e
·
verified ·
1 Parent(s): 8f95e75

Upload analysis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. analysis.py +1493 -0
analysis.py ADDED
@@ -0,0 +1,1493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Addressed State Attention (ASA) - Analysis Harness
4
+
5
+ Research implementation with mechanistic intervention capabilities.
6
+ For efficient training without interventions, use asm_training.py instead.
7
+
8
+ Features:
9
+ - Slot-mask causal interventions (slot_mask, slot_mask_where, slot_mask_scope)
10
+ - Refinement decomposition (orthogonal/parallel gating)
11
+ - Per-head geometry logging
12
+ - Configurable information storage (info_level, info_cfg)
13
+
14
+ Checkpoint Compatibility:
15
+ All parameter/buffer names match asm_training.py for weight sharing.
16
+ Do NOT rename: slot_keys, Wk_write, Wv_write, Wq_read, out_proj,
17
+ _alibi_slopes, _alibi_strength_param, _content_read_gamma_raw,
18
+ slot_in/slot_q/slot_k/slot_v/slot_out, _slotspace_gate_raw,
19
+ rope/rope_slotspace buffers.
20
+
21
+ Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
22
+ Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/tree/main/paper_drafts
23
+ """
24
+
25
+ import math
26
+ from dataclasses import dataclass
27
+ from typing import Optional, Dict, Tuple, List
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+
34
+ __all__ = [
35
+ 'AddressedStateAttention',
36
+ 'ASMBlock',
37
+ 'ASMLanguageModel',
38
+ 'ASMTrainConfig',
39
+ 'build_model_from_cfg',
40
+ ]
41
+
42
+
43
+ # ------------------------------------------------------------------ helpers ---
44
+
45
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
+ x1 = x[..., ::2]
47
+ x2 = x[..., 1::2]
48
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
49
+
50
+
51
+ class RotaryEmbedding(nn.Module):
52
+ def __init__(self, dim: int, base: float = 10000.0):
53
+ super().__init__()
54
+ assert dim % 2 == 0, "RoPE requires even dim"
55
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
56
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
57
+ self._cos_cached = None
58
+ self._sin_cached = None
59
+ self._t_cached = None
60
+ self._device_cached = None
61
+
62
+ def get_cos_sin(self, T: int, device, dtype):
63
+ if (
64
+ self._t_cached == T
65
+ and self._cos_cached is not None
66
+ and self._device_cached == device
67
+ ):
68
+ return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
69
+ t = torch.arange(T, device=device, dtype=self.inv_freq.dtype)
70
+ freqs = torch.einsum("t,f->tf", t, self.inv_freq)
71
+ emb = torch.cat([freqs, freqs], dim=-1)
72
+ cos = emb.cos()[None, None, :, :]
73
+ sin = emb.sin()[None, None, :, :]
74
+ self._t_cached = T
75
+ self._device_cached = device
76
+ self._cos_cached = cos
77
+ self._sin_cached = sin
78
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
79
+
80
+
81
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
82
+ return (x * cos) + (_rotate_half(x) * sin)
83
+
84
+
85
+ def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
86
+ def get_slopes(n):
87
+ def power_of_2_slopes(n):
88
+ start = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
89
+ ratio = start
90
+ return [start * (ratio ** i) for i in range(n)]
91
+ if math.log2(n).is_integer():
92
+ return power_of_2_slopes(n)
93
+ closest = 2 ** math.floor(math.log2(n))
94
+ return power_of_2_slopes(closest) + get_slopes(2 * closest)[0::2][: n - closest]
95
+ return torch.tensor(get_slopes(num_heads), device=device, dtype=dtype)
96
+
97
+
98
+ def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
99
+ return torch.log(torch.expm1(y))
100
+
101
+
102
+ def phi(x: torch.Tensor) -> torch.Tensor:
103
+ """Performer-style feature map (elu + 1)."""
104
+ return F.elu(x) + 1.0
105
+
106
+
107
+ # --------------------------------------------------------- main module ---
108
+
109
+ class AddressedStateAttention(nn.Module):
110
+ """
111
+ Addressed State Attention (ASA) — unified research harness.
112
+
113
+ Core mechanism
114
+ --------------
115
+ * Prefix-softmax WRITE into K learned slots (streaming, O(T))
116
+ * READ routing from tokens → slots (softmax / top-k / external)
117
+ * Content-conditioned READ term (gamma-weighted)
118
+ * RoPE on write keys (geometry)
119
+ * ALiBi bias on write logits (prefix-friendly)
120
+
121
+ Slot-space refinement
122
+ ---------------------
123
+ * Causal linear attention in a low-dim slot-address coordinate space
124
+ * Produces per-token signed weights over slots
125
+ * Decoded through the same streaming slot-state basis
126
+ * Gated by learnable ``slotspace_gate`` (softplus)
127
+
128
+ Causal intervention (slot mask)
129
+ -------------------------------
130
+ * ``slot_mask`` [K] float/bool, 1=keep 0=mask
131
+ * ``slot_mask_where`` "read" | "content_read_only" | "slotspace_only"
132
+ * ``slot_mask_scope`` "all" | "last_pos_only"
133
+
134
+ Refine-delta intervention (instance attrs, NO-OP by default)
135
+ ----------------------------------------------------------------
136
+ * ``_intv_mode`` "off" | "delta_par" | "delta_orth" | "orth_gate" | …
137
+ * Decomposes refine delta into parallel / orthogonal vs base output
138
+ * See User Guide for configuration details.
139
+
140
+ Refine-geometry logging (NO output change)
141
+ ------------------------------------------------
142
+ * ``_log_refine_geom = True`` enables per-head geometry vectors in info dict.
143
+
144
+ Info storage
145
+ ------------
146
+ * ``info_level`` "basic" | "logits" | "full"
147
+ * ``info_cfg`` dict controlling which tensors to store, downsampling, CPU offload.
148
+ """
149
+
150
+ # ---------------------------------------------------------------- init ---
151
+
152
+ def __init__(
153
+ self,
154
+ embed_dim: int,
155
+ num_heads: int = 8,
156
+ num_slots: int = 8,
157
+ dropout: float = 0.1,
158
+ # temperatures / numerics
159
+ read_temperature: float = 1.0,
160
+ write_temperature: float = 1.0,
161
+ state_fp32: bool = True,
162
+ slot_dropout: float = 0.0,
163
+ normalize_k: bool = False,
164
+ # positions (write geometry)
165
+ use_rope_keys: bool = True,
166
+ rope_base: float = 10000.0,
167
+ # write bias (ALiBi)
168
+ use_alibi_write: bool = True,
169
+ alibi_strength_init: float = 0.1,
170
+ learn_alibi_strength: bool = True,
171
+ min_strength: float = 0.0,
172
+ # content-conditioned read term
173
+ use_content_read: bool = True,
174
+ content_read_init: float = -4.0,
175
+ content_read_max_gamma: float = 3.0,
176
+ # slot-space refinement
177
+ use_slotspace_refine: bool = True,
178
+ slotspace_dim: int = 32,
179
+ slotspace_gate_init: float = -4.0,
180
+ slotspace_dropout: float = 0.05,
181
+ slotspace_signed_weights: bool = True,
182
+ # RoPE in slot-space matcher
183
+ use_rope_slotspace: bool = True,
184
+ rope_base_slotspace: float = 100000.0,
185
+ # perf knobs
186
+ write_chunk_size: int = 128,
187
+ slotspace_chunk_size: int = 128,
188
+ ):
189
+ super().__init__()
190
+ assert embed_dim % num_heads == 0
191
+ self.embed_dim = embed_dim
192
+ self.num_heads = num_heads
193
+ self.num_slots = num_slots
194
+ self.head_dim = embed_dim // num_heads
195
+
196
+ self.dropout = nn.Dropout(dropout)
197
+
198
+ self.read_temperature = float(read_temperature)
199
+ self.write_temperature = float(write_temperature)
200
+ self.state_fp32 = bool(state_fp32)
201
+ self.slot_dropout = float(slot_dropout)
202
+ self.normalize_k = bool(normalize_k)
203
+ self.routing_override = None
204
+
205
+ self.use_rope_keys = bool(use_rope_keys)
206
+ self.use_alibi_write = bool(use_alibi_write)
207
+ self.learn_alibi_strength = bool(learn_alibi_strength)
208
+ self.min_strength = float(min_strength)
209
+
210
+ self.use_content_read = bool(use_content_read)
211
+ self.content_read_max_gamma = float(content_read_max_gamma)
212
+
213
+ self.use_slotspace_refine = bool(use_slotspace_refine)
214
+ self.slotspace_dim = int(slotspace_dim)
215
+ self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
216
+ self.slotspace_signed_weights = bool(slotspace_signed_weights)
217
+
218
+ self.write_chunk_size = int(write_chunk_size)
219
+ self.slotspace_chunk_size = int(slotspace_chunk_size)
220
+
221
+ # Learned slot keys: [H, K, d]
222
+ self.slot_keys = nn.Parameter(
223
+ torch.randn(num_heads, num_slots, self.head_dim) / math.sqrt(self.head_dim)
224
+ )
225
+
226
+ # Projections
227
+ self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
228
+ self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
229
+ self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
230
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
231
+
232
+ # RoPE (write geometry)
233
+ self.rope = RotaryEmbedding(self.head_dim, base=rope_base) if self.use_rope_keys else None
234
+
235
+ # ALiBi
236
+ if self.use_alibi_write:
237
+ self.register_buffer("_alibi_slopes", alibi_slopes(num_heads), persistent=False)
238
+ else:
239
+ self.register_buffer("_alibi_slopes", torch.zeros(num_heads), persistent=False)
240
+
241
+ if self.use_alibi_write and self.learn_alibi_strength:
242
+ init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
243
+ self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
244
+ else:
245
+ self._alibi_strength_param = None
246
+ self.alibi_strength = float(alibi_strength_init)
247
+
248
+ # Content read gamma
249
+ if self.use_content_read:
250
+ self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
251
+ else:
252
+ self._content_read_gamma_raw = None
253
+
254
+ # Slot-space refinement
255
+ self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)
256
+ if self.use_slotspace_refine:
257
+ self.slot_in = nn.Linear(num_slots, self.slotspace_dim, bias=False)
258
+ self.slot_q = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
259
+ self.slot_k = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
260
+ self.slot_v = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
261
+ self.slot_out = nn.Linear(self.slotspace_dim, num_slots, bias=False)
262
+ self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
263
+ if self.use_rope_slotspace:
264
+ assert (self.slotspace_dim % 2) == 0, "use_rope_slotspace requires even slotspace_dim"
265
+ self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
266
+ else:
267
+ self.rope_slotspace = None
268
+ else:
269
+ self.slot_in = None
270
+ self.slot_q = self.slot_k = self.slot_v = None
271
+ self.slot_out = None
272
+ self._slotspace_gate_raw = None
273
+ self.rope_slotspace = None
274
+
275
+ # ----- intervention defaults (NO-OP) -----
276
+ self._intv_mode: str = "off"
277
+ self._intv_beta: float = 1.0
278
+ self._intv_score_kind: str = "orth_frac"
279
+ self._intv_tau_kind: str = "pctl"
280
+ self._intv_tau: float = 0.15
281
+ self._intv_tau_pctl: float = 75.0
282
+ self._intv_mask_mode: str = "soft"
283
+ self._intv_soft_temp: float = 0.05
284
+ self._intv_par_beta: float = 1.0
285
+ self._intv_head_mask: Optional[torch.Tensor] = None
286
+ self._intv_score_clip_pctl: float = 99.0
287
+
288
+ # ----- refine-geometry logging (no compute change) -----
289
+ self._log_refine_geom: bool = False
290
+
291
+ # -------------------------------------------------------- scalar params ---
292
+
293
+ def _alibi_strength(self, dtype, device) -> torch.Tensor:
294
+ if not (self.use_alibi_write and self.learn_alibi_strength):
295
+ return torch.tensor(self.alibi_strength, dtype=dtype, device=device)
296
+ return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)
297
+
298
+ def _content_read_gamma(self, dtype, device) -> torch.Tensor:
299
+ if not self.use_content_read:
300
+ return torch.tensor(0.0, dtype=dtype, device=device)
301
+ g = F.softplus(self._content_read_gamma_raw)
302
+ if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
303
+ g = g.clamp(max=self.content_read_max_gamma)
304
+ return g.to(dtype=dtype, device=device)
305
+
306
+ def _slotspace_gate(self, dtype, device) -> torch.Tensor:
307
+ if not self.use_slotspace_refine:
308
+ return torch.tensor(0.0, dtype=dtype, device=device)
309
+ return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)
310
+
311
+ # --------------------------------------------------------- numerics ---
312
+
313
+ @staticmethod
314
+ def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
315
+ diff = s - m
316
+ diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
317
+ return torch.exp(diff)
318
+
319
+ # ------------------------------------------------------ slot mask ---
320
+
321
+ def _resolve_slot_mask(
322
+ self,
323
+ slot_mask: Optional[torch.Tensor],
324
+ *,
325
+ B: int, H: int, L: int, K: int,
326
+ device, dtype, scope: str,
327
+ ) -> Optional[torch.Tensor]:
328
+ """Expand [K] mask → [B,H,L,K]. Falls back to self.slot_mask attr."""
329
+ if slot_mask is None:
330
+ slot_mask = getattr(self, "slot_mask", None)
331
+ if slot_mask is None:
332
+ return None
333
+ sm = slot_mask.to(device=device, dtype=dtype)
334
+ if sm.ndim != 1 or sm.numel() != K:
335
+ raise ValueError(f"slot_mask must be shape [K]={K}, got {tuple(sm.shape)}")
336
+ sm = sm.view(1, 1, 1, K)
337
+ if scope == "all":
338
+ return sm.expand(B, H, L, K)
339
+ if scope == "last_pos_only":
340
+ out = torch.ones((B, H, L, K), device=device, dtype=dtype)
341
+ out[:, :, -1:, :] = sm.expand(B, H, 1, K)
342
+ return out
343
+ raise ValueError(f"Unknown slot_mask_scope={scope!r}")
344
+
345
+ @staticmethod
346
+ def _apply_hard_mask_and_renorm(w: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
347
+ w = w * keep.to(w.dtype)
348
+ return w / w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
349
+
350
+ # --------------------------------------------------- info helpers ---
351
+
352
+ @staticmethod
353
+ def default_info_cfg() -> Dict:
354
+ """Return default info_cfg dict. Copy and modify before passing to forward()."""
355
+ return dict(
356
+ store_read_weights=True,
357
+ store_read_logits=True,
358
+ store_write_logits=True,
359
+ store_slot_state_norm=True,
360
+ store_out1=False,
361
+ store_delta=False,
362
+ store_slot_w=False,
363
+ detach_to_cpu=False,
364
+ time_stride=1,
365
+ batch_stride=1,
366
+ )
367
+
368
+ @staticmethod
369
+ def _store_tensor(
370
+ t: Optional[torch.Tensor], *, cfg: Dict, kind: str,
371
+ ) -> Optional[torch.Tensor]:
372
+ """Downsample + detach (+ optional CPU offload)."""
373
+ if t is None:
374
+ return None
375
+ bstride = int(cfg.get("batch_stride", 1))
376
+ tstride = int(cfg.get("time_stride", 1))
377
+ to_cpu = bool(cfg.get("detach_to_cpu", False))
378
+ x = t
379
+ if x.dim() >= 1 and bstride > 1:
380
+ x = x[::bstride]
381
+ if x.dim() == 4 and tstride > 1:
382
+ if kind == "bhtk":
383
+ x = x[:, :, ::tstride, :]
384
+ elif kind == "bhkt":
385
+ x = x[:, :, :, ::tstride]
386
+ x = x.detach()
387
+ if to_cpu:
388
+ x = x.to("cpu", non_blocking=True)
389
+ return x
390
+
391
+ # ------------------------------------------------ read-weight routing ---
392
+
393
+ def _compute_read_weights(
394
+ self,
395
+ *,
396
+ read_logits: torch.Tensor,
397
+ read_logits_key: torch.Tensor,
398
+ read_logits_content: Optional[torch.Tensor],
399
+ routing_mode: str,
400
+ routing_topk: int,
401
+ read_weights_override: Optional[torch.Tensor],
402
+ routing_noise: Optional[str],
403
+ routing_noise_scale: float,
404
+ rtemp: float,
405
+ sm: Optional[torch.Tensor],
406
+ slot_mask_where: str,
407
+ B: int, H: int, L: int, K: int,
408
+ T_total: int,
409
+ t0: int, t1: int,
410
+ q_read_c: torch.Tensor,
411
+ slot_keys: torch.Tensor,
412
+ slot_state_t: torch.Tensor,
413
+ valid: Optional[torch.Tensor],
414
+ state_dtype,
415
+ ) -> torch.Tensor:
416
+ """Compute read weights for one write-chunk. Handles noise, overrides, masks."""
417
+ # routing noise
418
+ if routing_noise is not None:
419
+ if routing_noise == "gumbel":
420
+ u = torch.rand_like(read_logits)
421
+ g = -torch.log(-torch.log(u.clamp_min(1e-8)).clamp_min(1e-8))
422
+ read_logits = read_logits + routing_noise_scale * g
423
+ elif routing_noise == "gaussian":
424
+ read_logits = read_logits + routing_noise_scale * torch.randn_like(read_logits)
425
+ else:
426
+ raise ValueError(f"Unknown routing_noise={routing_noise}")
427
+
428
+ # routing override (external callable or tensor)
429
+ if self.routing_override is not None:
430
+ if callable(self.routing_override):
431
+ ctx = dict(
432
+ t0=t0, t1=t1, B=B, H=H, T=T_total, K=K, d=self.head_dim,
433
+ rtemp=rtemp, state_dtype=state_dtype,
434
+ q_read_c=q_read_c, slot_keys=slot_keys,
435
+ slot_state_t=slot_state_t, valid=valid,
436
+ )
437
+ read_w = self.routing_override(
438
+ t0, t1, read_logits, read_logits_key, read_logits_content, ctx,
439
+ )
440
+ else:
441
+ read_w = self.routing_override[:, :, t0:t1, :].to(read_logits.dtype)
442
+ read_w = torch.nan_to_num(read_w, nan=0.0, posinf=0.0, neginf=0.0)
443
+ read_w = read_w.clamp_min(0.0)
444
+ read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
445
+
446
+ else:
447
+ if routing_mode == "softmax":
448
+ read_w = torch.softmax(read_logits / rtemp, dim=-1)
449
+ elif routing_mode == "top1":
450
+ top = read_logits.argmax(dim=-1)
451
+ read_w = F.one_hot(top, num_classes=K).to(read_logits.dtype)
452
+ elif routing_mode == "topk":
453
+ kk = max(1, min(K, int(routing_topk)))
454
+ vals, idx = torch.topk(read_logits, k=kk, dim=-1)
455
+ masked = torch.full_like(read_logits, float("-inf"))
456
+ masked.scatter_(-1, idx, vals)
457
+ read_w = torch.softmax(masked / rtemp, dim=-1)
458
+ elif routing_mode == "external":
459
+ if read_weights_override is None:
460
+ raise ValueError("routing_mode='external' requires read_weights_override")
461
+ if read_weights_override.shape[-2] == T_total:
462
+ read_w = read_weights_override[:, :, t0:t1, :]
463
+ else:
464
+ read_w = read_weights_override
465
+ read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
466
+ else:
467
+ raise ValueError(f"Unknown routing_mode={routing_mode}")
468
+
469
+ # slot mask at read stage
470
+ if slot_mask_where == "read" and sm is not None:
471
+ read_w = self._apply_hard_mask_and_renorm(read_w, (sm > 0.0))
472
+
473
+ return read_w
474
+
475
+ # ------------------------------------------- refine-delta intervention ---
476
+
477
+ def _apply_refine_intervention(
478
+ self,
479
+ out1: torch.Tensor,
480
+ delta: torch.Tensor,
481
+ slot_w: Optional[torch.Tensor],
482
+ ):
483
+ """Decompose refine delta into par/orth vs base output, optionally gate."""
484
+ eps = 1e-8
485
+ B, H, L, d = out1.shape
486
+
487
+ # head mask
488
+ hm = getattr(self, "_intv_head_mask", None)
489
+ if hm is not None:
490
+ hm = hm.to(device=out1.device).view(1, H, 1, 1).to(dtype=out1.dtype)
491
+
492
+ out1_norm2 = (out1 * out1).sum(dim=-1, keepdim=True).clamp_min(eps)
493
+ alpha = (delta * out1).sum(dim=-1, keepdim=True) / out1_norm2
494
+ delta_par = alpha * out1
495
+ delta_orth = delta - delta_par
496
+
497
+ logs = None
498
+
499
+ # geometry logging (no output change)
500
+ if getattr(self, "_log_refine_geom", False):
501
+ out1n = out1.norm(dim=-1).clamp_min(eps)
502
+ dn = delta.norm(dim=-1).clamp_min(eps)
503
+ dparn = delta_par.norm(dim=-1)
504
+ dorthn = delta_orth.norm(dim=-1)
505
+ a = alpha.squeeze(-1)
506
+ logs = dict(
507
+ geom_alpha_mean=a.mean(dim=(0, 2)),
508
+ geom_alpha_abs=a.abs().mean(dim=(0, 2)),
509
+ geom_sign_pos=(a > 0).float().mean(dim=(0, 2)),
510
+ geom_orth_frac=(dorthn / dn).mean(dim=(0, 2)),
511
+ geom_d_ratio=(dn / out1n).mean(dim=(0, 2)),
512
+ geom_dpar_ratio=(dparn / dn).mean(dim=(0, 2)),
513
+ )
514
+
515
+ mode = getattr(self, "_intv_mode", "off")
516
+ if mode is None or mode == "off":
517
+ return delta, logs
518
+
519
+ # --- intervention modes ---
520
+ if mode == "delta_par":
521
+ delta_mod = delta_par
522
+ logs = logs or {}
523
+ logs["alpha"] = alpha.squeeze(-1)
524
+
525
+ elif mode == "delta_orth":
526
+ delta_mod = delta_orth
527
+ logs = logs or {}
528
+ logs["alpha"] = alpha.squeeze(-1)
529
+
530
+ elif mode == "delta_par_plus_orth":
531
+ delta_mod = delta_par + delta_orth
532
+ logs = logs or {}
533
+ logs["alpha"] = alpha.squeeze(-1)
534
+
535
+ elif mode == "orth_gate":
536
+ beta = float(getattr(self, "_intv_beta", 1.0))
537
+ sk = getattr(self, "_intv_score_kind", "orth_frac")
538
+ out1n = out1.norm(dim=-1).clamp_min(eps)
539
+ dorthn = delta_orth.norm(dim=-1)
540
+ dn = delta.norm(dim=-1).clamp_min(eps)
541
+
542
+ if sk == "orth_ratio":
543
+ score = dorthn / out1n
544
+ elif sk == "orth_frac":
545
+ score = dorthn / dn
546
+ elif sk == "alpha_abs":
547
+ score = alpha.abs().squeeze(-1)
548
+ elif sk == "slot_peaked":
549
+ if slot_w is None:
550
+ raise ValueError("score_kind='slot_peaked' requires slot_w")
551
+ p = torch.softmax(slot_w.float(), dim=-1).clamp_min(1e-8)
552
+ Hrw = -(p * p.log()).sum(dim=-1)
553
+ K = p.shape[-1]
554
+ score = (1.0 - Hrw / max(1e-8, math.log(K))).to(dtype=out1.dtype)
555
+ else:
556
+ raise ValueError(f"Unknown _intv_score_kind={sk}")
557
+
558
+ # score clipping
559
+ clip_p = getattr(self, "_intv_score_clip_pctl", None)
560
+ if clip_p is not None:
561
+ clip_p = float(clip_p)
562
+ if 0.0 < clip_p < 100.0:
563
+ smax = torch.quantile(score.detach().flatten(), clip_p / 100.0).to(score.dtype)
564
+ score = torch.clamp(score, max=smax)
565
+
566
+ # tau
567
+ tk = getattr(self, "_intv_tau_kind", "pctl")
568
+ if tk == "abs":
569
+ tau = torch.tensor(float(getattr(self, "_intv_tau", 0.15)),
570
+ device=score.device, dtype=score.dtype)
571
+ elif tk == "pctl":
572
+ tau = torch.quantile(
573
+ score.detach().flatten(),
574
+ float(getattr(self, "_intv_tau_pctl", 75.0)) / 100.0,
575
+ ).to(score.dtype)
576
+ else:
577
+ raise ValueError(f"Unknown _intv_tau_kind={tk}")
578
+
579
+ # mask
580
+ mm = getattr(self, "_intv_mask_mode", "soft")
581
+ if mm == "hard":
582
+ mask = (score > tau).to(out1.dtype)
583
+ elif mm == "soft":
584
+ temp = max(1e-6, float(getattr(self, "_intv_soft_temp", 0.05)))
585
+ mask = torch.sigmoid((score - tau) / temp).to(out1.dtype)
586
+ else:
587
+ raise ValueError(f"Unknown _intv_mask_mode={mm}")
588
+
589
+ par_beta = float(getattr(self, "_intv_par_beta", 1.0))
590
+ delta_mod = par_beta * delta_par + beta * mask.unsqueeze(-1) * delta_orth
591
+
592
+ logs = logs or {}
593
+ logs.update(dict(
594
+ score=score, tau=tau, mask=mask,
595
+ alpha=alpha.squeeze(-1),
596
+ out1_norm=out1n,
597
+ dpar_norm=delta_par.norm(dim=-1),
598
+ dorth_norm=dorthn,
599
+ ))
600
+ else:
601
+ raise ValueError(f"Unknown _intv_mode={mode}")
602
+
603
+ # head targeting
604
+ if hm is not None:
605
+ delta_mod = hm * delta_mod + (1.0 - hm) * delta
606
+ logs = logs or {}
607
+ logs["head_mask"] = hm.squeeze(0).squeeze(-1).squeeze(-1).detach()
608
+
609
+ return delta_mod, logs
610
+
611
+ # ============================================================ forward ===
612
+
613
+ def forward(
614
+ self,
615
+ x: torch.Tensor,
616
+ attention_mask: Optional[torch.Tensor] = None,
617
+ return_info: bool = False,
618
+
619
+ # routing
620
+ routing_mode: str = "softmax",
621
+ routing_topk: int = 2,
622
+ read_weights_override: Optional[torch.Tensor] = None,
623
+ routing_noise: Optional[str] = None,
624
+ routing_noise_scale: float = 1.0,
625
+
626
+ # slot mask (causal intervention)
627
+ slot_mask: Optional[torch.Tensor] = None,
628
+ slot_mask_where: str = "read",
629
+ slot_mask_scope: str = "all",
630
+
631
+ # info controls
632
+ info_level: str = "full",
633
+ info_cfg: Optional[Dict] = None,
634
+ ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
635
+ """
636
+ Parameters
637
+ ----------
638
+ x : [B, T, C]
639
+ attention_mask : [B, T] optional padding mask (1=valid, 0=pad)
640
+ return_info : if True, return diagnostics dict as second element
641
+ routing_mode : "softmax" | "top1" | "topk" | "external"
642
+ routing_topk : k for topk mode
643
+ read_weights_override : [B,H,T,K] or [B,H,L,K] for external routing
644
+ routing_noise : None | "gumbel" | "gaussian"
645
+ routing_noise_scale : scale for routing noise
646
+ slot_mask : [K] where 1=keep, 0=mask
647
+ slot_mask_where : "read" | "content_read_only" | "slotspace_only"
648
+ slot_mask_scope : "all" | "last_pos_only"
649
+ info_level : "basic" | "logits" | "full"
650
+ info_cfg : dict (see default_info_cfg())
651
+
652
+ Returns
653
+ -------
654
+ (output, info) where info is None if return_info=False.
655
+ """
656
+
657
+ B, T, C = x.shape
658
+ H, K, d = self.num_heads, self.num_slots, self.head_dim
659
+
660
+ # ---- resolve info config ----
661
+ if info_cfg is None:
662
+ info_cfg = self.default_info_cfg()
663
+ store_read_weights = bool(info_cfg.get("store_read_weights", True))
664
+ store_read_logits = bool(info_cfg.get("store_read_logits", True)) and info_level in ("logits", "full")
665
+ store_write_logits = bool(info_cfg.get("store_write_logits", True)) and info_level == "full"
666
+ store_slot_norm = bool(info_cfg.get("store_slot_state_norm", True)) and info_level == "full"
667
+ store_out1 = bool(info_cfg.get("store_out1", False)) and return_info
668
+ store_delta = bool(info_cfg.get("store_delta", False)) and return_info
669
+ store_slot_w = bool(info_cfg.get("store_slot_w", False)) and return_info
670
+
671
+ # ---- projections ----
672
+ k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2)
673
+ v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2)
674
+ q_read = self.Wq_read(x).view(B, T, H, d).transpose(1, 2)
675
+
676
+ if self.normalize_k:
677
+ k_write = F.normalize(k_write, dim=-1, eps=1e-8)
678
+
679
+ if self.use_rope_keys:
680
+ cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
681
+ k_write = apply_rope(k_write, cos, sin)
682
+
683
+ # slot dropout
684
+ slot_keys = self.slot_keys
685
+ if self.training and self.slot_dropout > 0.0:
686
+ drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
687
+ slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)
688
+
689
+ # ---- WRITE logits ----
690
+ write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d)
691
+ state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
692
+ write_logits = write_logits_raw.to(state_dtype) / max(1e-6, self.write_temperature)
693
+
694
+ # ALiBi
695
+ alibi_bias_applied = None
696
+ if self.use_alibi_write:
697
+ strength = self._alibi_strength(dtype=state_dtype, device=x.device)
698
+ slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength
699
+ pos_i = torch.arange(T, device=x.device, dtype=state_dtype)
700
+ alibi_bias = slopes.view(1, H, 1, 1) * pos_i.view(1, 1, 1, T)
701
+ write_logits = write_logits + alibi_bias
702
+ alibi_bias_applied = alibi_bias
703
+
704
+ # padding mask
705
+ if attention_mask is not None:
706
+ valid = attention_mask.to(dtype=torch.bool)
707
+ write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
708
+ else:
709
+ valid = None
710
+
711
+ # ================================================================
712
+ # STREAMING WRITE + READ
713
+ # ================================================================
714
+ content_read_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
715
+ rtemp = max(1e-6, self.read_temperature)
716
+
717
+ out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)
718
+
719
+ out1_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_out1 else None
720
+ delta_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_delta else None
721
+ slot_w_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if store_slot_w else None
722
+
723
+ need_rw = bool(self.use_slotspace_refine) or (return_info and store_read_weights)
724
+ read_weights = torch.empty((B, H, T, K), device=x.device, dtype=q_read.dtype) if need_rw else None
725
+
726
+ slot_state_norm_t = (
727
+ torch.empty((B, H, T, K), device=x.device, dtype=torch.float32)
728
+ if (return_info and store_slot_norm) else None
729
+ )
730
+
731
+ if return_info and store_read_logits:
732
+ read_logits_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype)
733
+ read_logits_key_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype)
734
+ read_logits_content_full = (
735
+ torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if self.use_content_read else None
736
+ )
737
+ else:
738
+ read_logits_full = read_logits_key_full = read_logits_content_full = None
739
+
740
+ # streaming state
741
+ denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
742
+ numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
743
+ m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
744
+
745
+ WRITE_CHUNK = self.write_chunk_size
746
+
747
+ for t0 in range(0, T, WRITE_CHUNK):
748
+ t1 = min(T, t0 + WRITE_CHUNK)
749
+ L = t1 - t0
750
+
751
+ wlog_c = write_logits[:, :, :, t0:t1]
752
+ m_c, _ = torch.cummax(wlog_c, dim=-1)
753
+ m_new = torch.maximum(m_state.unsqueeze(-1), m_c)
754
+
755
+ scale = torch.exp(m_state.unsqueeze(-1) - m_new)
756
+ denom_c = denom_state.unsqueeze(-1) * scale
757
+ numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1)
758
+
759
+ w_new = self._safe_exp_sub_max(wlog_c, m_new)
760
+ denom_c = denom_c + torch.cumsum(w_new, dim=-1)
761
+
762
+ v_c = v_write[:, :, t0:t1, :].to(state_dtype)
763
+ add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
764
+ numer_c = numer_c + add
765
+
766
+ slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)
767
+ slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous()
768
+
769
+ # READ logits
770
+ q_read_c = q_read[:, :, t0:t1, :]
771
+ read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d)
772
+
773
+ read_logits_content = None
774
+ if self.use_content_read:
775
+ read_logits_content = torch.einsum(
776
+ "bhld,bhlkd->bhlk", q_read_c, slot_state_t.to(q_read_c.dtype),
777
+ ) / math.sqrt(d)
778
+
779
+ # slot mask for this chunk
780
+ sm = self._resolve_slot_mask(
781
+ slot_mask, B=B, H=H, L=L, K=K,
782
+ device=x.device, dtype=read_logits_key.dtype, scope=slot_mask_scope,
783
+ )
784
+
785
+ # apply mask to logits according to slot_mask_where
786
+ if slot_mask_where == "read":
787
+ if sm is not None:
788
+ read_logits_key = read_logits_key.masked_fill(sm <= 0.0, float("-inf"))
789
+ if self.use_content_read and read_logits_content is not None:
790
+ read_logits_content = read_logits_content.masked_fill(sm <= 0.0, float("-inf"))
791
+ elif slot_mask_where == "content_read_only":
792
+ if sm is not None and self.use_content_read and read_logits_content is not None:
793
+ read_logits_content = read_logits_content.masked_fill(sm <= 0.0, 0.0)
794
+ elif slot_mask_where == "slotspace_only":
795
+ pass # applied later on slot_w
796
+ else:
797
+ raise ValueError(f"Unknown slot_mask_where={slot_mask_where!r}")
798
+
799
+ # combine
800
+ rl = read_logits_key
801
+ if self.use_content_read and read_logits_content is not None:
802
+ rl = rl + content_read_gamma.to(rl.dtype) * read_logits_content
803
+
804
+ if return_info and store_read_logits:
805
+ read_logits_full[:, :, t0:t1, :] = rl.to(state_dtype)
806
+ read_logits_key_full[:, :, t0:t1, :] = read_logits_key.to(state_dtype)
807
+ if self.use_content_read and read_logits_content_full is not None:
808
+ read_logits_content_full[:, :, t0:t1, :] = read_logits_content.to(state_dtype)
809
+
810
+ # read weights
811
+ read_w_c = self._compute_read_weights(
812
+ read_logits=rl, read_logits_key=read_logits_key,
813
+ read_logits_content=read_logits_content,
814
+ routing_mode=routing_mode, routing_topk=routing_topk,
815
+ read_weights_override=read_weights_override,
816
+ routing_noise=routing_noise, routing_noise_scale=routing_noise_scale,
817
+ rtemp=rtemp, sm=sm, slot_mask_where=slot_mask_where,
818
+ B=B, H=H, L=L, K=K, T_total=T, t0=t0, t1=t1,
819
+ q_read_c=q_read_c, slot_keys=slot_keys,
820
+ slot_state_t=slot_state_t, valid=valid,
821
+ state_dtype=state_dtype,
822
+ )
823
+
824
+ if read_weights is not None:
825
+ read_weights[:, :, t0:t1, :] = read_w_c
826
+
827
+ # base output
828
+ out_h[:, :, t0:t1, :] = torch.einsum(
829
+ "bhlk,bhlkd->bhld", read_w_c.to(state_dtype), slot_state_t.to(state_dtype),
830
+ )
831
+
832
+ if out1_full is not None:
833
+ out1_full[:, :, t0:t1, :] = out_h[:, :, t0:t1, :]
834
+
835
+ if slot_state_norm_t is not None:
836
+ slot_state_norm_t[:, :, t0:t1, :] = slot_state_t.to(torch.float32).norm(dim=-1)
837
+
838
+ m_state = m_new[:, :, :, -1]
839
+ denom_state = denom_c[:, :, :, -1]
840
+ numer_state = numer_c[:, :, :, -1, :]
841
+
842
+ # ================================================================
843
+ # SLOT-SPACE REFINEMENT
844
+ # ================================================================
845
+ slotspace_delta_norm_mean = None
846
+ intv_logs_acc: Optional[Dict] = None
847
+ intv_logs_count = 0
848
+
849
+ if self.use_slotspace_refine:
850
+ slotspace_dtype = state_dtype
851
+ M = self.slotspace_dim
852
+ assert read_weights is not None
853
+
854
+ u = self.slot_in(read_weights.to(slotspace_dtype))
855
+ q_s = self.slot_q(u)
856
+ k_s = self.slot_k(u)
857
+ v_s = self.slot_v(u)
858
+
859
+ if self.use_rope_slotspace:
860
+ cos_s, sin_s = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=q_s.dtype)
861
+ q_s = apply_rope(q_s, cos_s, sin_s)
862
+ k_s = apply_rope(k_s, cos_s, sin_s)
863
+
864
+ qf = phi(q_s)
865
+ kf = phi(k_s)
866
+
867
+ if valid is not None:
868
+ vmask = valid.view(B, 1, T, 1).to(slotspace_dtype)
869
+ qf = qf * vmask
870
+ kf = kf * vmask
871
+ v_s = v_s * vmask
872
+
873
+ u2 = torch.empty((B, H, T, M), device=x.device, dtype=slotspace_dtype)
874
+ S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype)
875
+ Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype)
876
+
877
+ SS_CHUNK = self.slotspace_chunk_size
878
+ for t0 in range(0, T, SS_CHUNK):
879
+ t1 = min(T, t0 + SS_CHUNK)
880
+ qf_c = qf[:, :, t0:t1, :]
881
+ kf_c = kf[:, :, t0:t1, :]
882
+ v_c = v_s[:, :, t0:t1, :]
883
+
884
+ kv = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_c)
885
+ S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2)
886
+ Z_c = (torch.cumsum(kf_c, dim=2) + Z_state.unsqueeze(2)).clamp_min(1e-8)
887
+
888
+ num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c)
889
+ den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1).clamp_min(1e-8)
890
+ u2[:, :, t0:t1, :] = num / den
891
+
892
+ S_state = S_c[:, :, -1, :, :]
893
+ Z_state = Z_c[:, :, -1, :]
894
+
895
+ u2 = self.slotspace_dropout(u2)
896
+ slot_w = self.slot_out(u2)
897
+
898
+ if slot_w_full is not None:
899
+ slot_w_full[:] = slot_w.to(state_dtype)
900
+
901
+ if self.slotspace_signed_weights:
902
+ slot_w_eff = torch.tanh(slot_w)
903
+ else:
904
+ slot_w_eff = torch.softmax(slot_w, dim=-1)
905
+
906
+ # slotspace-only mask
907
+ if slot_mask_where == "slotspace_only":
908
+ sm_full = self._resolve_slot_mask(
909
+ slot_mask, B=B, H=H, L=T, K=K,
910
+ device=x.device, dtype=slot_w_eff.dtype, scope=slot_mask_scope,
911
+ )
912
+ if sm_full is not None:
913
+ slot_w_eff = slot_w_eff * (sm_full > 0.0).to(slot_w_eff.dtype)
914
+ if not self.slotspace_signed_weights:
915
+ slot_w_eff = slot_w_eff / slot_w_eff.sum(dim=-1, keepdim=True).clamp_min(1e-8)
916
+
917
+ gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype)
918
+
919
+ # second streaming pass: decode delta through slot states
920
+ denom_state2 = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
921
+ numer_state2 = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
922
+ m_state2 = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
923
+
924
+ delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32)
925
+ delta_norm_count = 0
926
+
927
+ for t0 in range(0, T, WRITE_CHUNK):
928
+ t1 = min(T, t0 + WRITE_CHUNK)
929
+ Lc = t1 - t0
930
+
931
+ wlog_c = write_logits[:, :, :, t0:t1]
932
+ m_c, _ = torch.cummax(wlog_c, dim=-1)
933
+ m_new = torch.maximum(m_state2.unsqueeze(-1), m_c)
934
+
935
+ scale = torch.exp(m_state2.unsqueeze(-1) - m_new)
936
+ denom_c = denom_state2.unsqueeze(-1) * scale
937
+ numer_c = numer_state2.unsqueeze(-2) * scale.unsqueeze(-1)
938
+
939
+ w_new = self._safe_exp_sub_max(wlog_c, m_new)
940
+ denom_c = denom_c + torch.cumsum(w_new, dim=-1)
941
+
942
+ v_c = v_write[:, :, t0:t1, :].to(state_dtype)
943
+ add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
944
+ numer_c = numer_c + add
945
+
946
+ slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)
947
+ slot_state_t2 = slot_state_c.permute(0, 1, 3, 2, 4).contiguous()
948
+
949
+ slot_w_c = slot_w_eff[:, :, t0:t1, :].to(state_dtype)
950
+ delta_c = torch.einsum("bhlk,bhlkd->bhld", slot_w_c, slot_state_t2.to(state_dtype))
951
+
952
+ delta = gate * delta_c
953
+
954
+ if delta_full is not None:
955
+ delta_full[:, :, t0:t1, :] = delta
956
+
957
+ # intervention
958
+ slot_w_for_score = slot_w[:, :, t0:t1, :] if store_slot_w else None
959
+ delta_mod, logs = self._apply_refine_intervention(
960
+ out1=out_h[:, :, t0:t1, :], delta=delta, slot_w=slot_w_for_score,
961
+ )
962
+
963
+ out_h[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] + delta_mod
964
+
965
+ # accumulate logs
966
+ if logs is not None and return_info:
967
+ if intv_logs_acc is None:
968
+ intv_logs_acc = {}
969
+ for klog, v in logs.items():
970
+ if torch.is_tensor(v):
971
+ vv = v.detach().to(torch.float32)
972
+ intv_logs_acc[klog] = vv if vv.ndim == 1 else vv.mean()
973
+ intv_logs_count = 1
974
+ else:
975
+ for klog, v in logs.items():
976
+ if torch.is_tensor(v) and klog in intv_logs_acc:
977
+ vv = v.detach().to(torch.float32)
978
+ intv_logs_acc[klog] = intv_logs_acc[klog] + (vv if vv.ndim == 1 else vv.mean())
979
+ intv_logs_count += 1
980
+
981
+ delta_norm_sum = delta_norm_sum + delta.detach().to(torch.float32).norm(dim=-1).sum()
982
+ delta_norm_count += B * H * Lc
983
+
984
+ m_state2 = m_new[:, :, :, -1]
985
+ denom_state2 = denom_c[:, :, :, -1]
986
+ numer_state2 = numer_c[:, :, :, -1, :]
987
+
988
+ slotspace_delta_norm_mean = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu()
989
+
990
+ # ================================================================
991
+ # OUTPUT
992
+ # ================================================================
993
+ out = out_h.transpose(1, 2).contiguous().view(B, T, C)
994
+ out = self.out_proj(out)
995
+ out = self.dropout(out)
996
+
997
+ # ---- info dict ----
998
+ info = None
999
+ if return_info:
1000
+ info = {
1001
+ "content_read_gamma": content_read_gamma.detach().to(torch.float32).cpu(),
1002
+ "routing_mode": routing_mode,
1003
+ "slot_mask_where": slot_mask_where,
1004
+ "slot_mask_scope": slot_mask_scope,
1005
+ "intv_mode": getattr(self, "_intv_mode", "off"),
1006
+ }
1007
+
1008
+ if alibi_bias_applied is not None and info_level == "full":
1009
+ info["alibi_bias_applied"] = self._store_tensor(alibi_bias_applied.to(torch.float32), cfg=info_cfg, kind="other")
1010
+
1011
+ if self.use_alibi_write and self.learn_alibi_strength:
1012
+ info["alibi_strength"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu()
1013
+
1014
+ if self.use_slotspace_refine:
1015
+ info["slotspace_gate"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu()
1016
+ info["use_rope_slotspace"] = torch.tensor(bool(self.use_rope_slotspace))
1017
+ if slotspace_delta_norm_mean is not None:
1018
+ info["slotspace_delta_norm"] = slotspace_delta_norm_mean
1019
+
1020
+ # read weights
1021
+ if store_read_weights and read_weights is not None:
1022
+ info["read_weights"] = self._store_tensor(read_weights, cfg=info_cfg, kind="bhtk")
1023
+ else:
1024
+ info["read_weights"] = None
1025
+
1026
+ # slot state norm
1027
+ if store_slot_norm and slot_state_norm_t is not None:
1028
+ s = slot_state_norm_t.permute(0, 1, 3, 2).contiguous()
1029
+ info["slot_state_norm"] = self._store_tensor(s, cfg=info_cfg, kind="bhkt")
1030
+ else:
1031
+ info["slot_state_norm"] = None
1032
+
1033
+ # read logits
1034
+ if store_read_logits and read_logits_full is not None:
1035
+ info["read_logits"] = self._store_tensor(read_logits_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
1036
+ info["read_logits_key"] = self._store_tensor(read_logits_key_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
1037
+ info["read_logits_content"] = (
1038
+ self._store_tensor(read_logits_content_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
1039
+ if read_logits_content_full is not None else None
1040
+ )
1041
+ else:
1042
+ info["read_logits"] = info["read_logits_key"] = info["read_logits_content"] = None
1043
+
1044
+ # write logits
1045
+ if store_write_logits and info_level == "full":
1046
+ info["write_logits_raw"] = self._store_tensor(write_logits_raw, cfg=info_cfg, kind="bhkt")
1047
+ info["write_logits"] = self._store_tensor(write_logits.to(torch.float32), cfg=info_cfg, kind="bhkt")
1048
+ else:
1049
+ info["write_logits_raw"] = info["write_logits"] = None
1050
+
1051
+ # out1 / delta / slot_w
1052
+ info["out1"] = self._store_tensor(out1_full.to(torch.float32), cfg=info_cfg, kind="other") if out1_full is not None else None
1053
+ info["delta"] = self._store_tensor(delta_full.to(torch.float32), cfg=info_cfg, kind="other") if delta_full is not None else None
1054
+ info["slot_w"] = self._store_tensor(slot_w_full.to(torch.float32), cfg=info_cfg, kind="bhtk") if slot_w_full is not None else None
1055
+
1056
+ # averaged intervention / geometry logs
1057
+ if intv_logs_acc is not None and intv_logs_count > 0:
1058
+ for klog, v in intv_logs_acc.items():
1059
+ info[klog] = (v / float(intv_logs_count)).detach().cpu()
1060
+
1061
+ # backward-compatible scalar aliases
1062
+ for alias_from, alias_to in [
1063
+ ("score", "intv_score_mean"), ("mask", "intv_mask_mean"),
1064
+ ("tau", "intv_tau"), ("alpha", "intv_alpha_mean"),
1065
+ ("out1_norm", "intv_out1_norm_mean"),
1066
+ ("dpar_norm", "intv_dpar_norm_mean"),
1067
+ ("dorth_norm", "intv_dorth_norm_mean"),
1068
+ ]:
1069
+ if alias_from in intv_logs_acc:
1070
+ val = info.get(alias_from)
1071
+ if torch.is_tensor(val) and val.ndim != 1:
1072
+ info[alias_to] = val
1073
+
1074
+ return out, info
1075
+
1076
+
1077
+ # Addressed State Models (ASM): Config + Block + LM
1078
+ #
1079
+ # Unified companion for the consolidated AddressedStateAttention harness.
1080
+ # Block.forward() and LM.forward() pass through the full ASA forward() surface:
1081
+ # routing controls, slot mask, info_level, info_cfg.
1082
+ #
1083
+ # ============================================================================
1084
+ # Config
1085
+ # ============================================================================
1086
+ @dataclass
1087
+ class ASMTrainConfig:
1088
+ # Data
1089
+ dataset_name: str = "wikitext"
1090
+ dataset_config: str = "wikitext-103-raw-v1"
1091
+ tokenizer_name: str = "gpt2"
1092
+
1093
+ max_seq_len: int = 256
1094
+ stride_frac_val: float = 0.50
1095
+ seed: int = 1337
1096
+
1097
+ micro_batch_size: int = 2
1098
+ grad_accum_steps: int = 8
1099
+ train_samples_target: int = 100_000_000
1100
+ val_samples_target: int = 25_000
1101
+
1102
+ # Training
1103
+ batch_size: int = 64
1104
+ learning_rate: float = 3e-4
1105
+ weight_decay: float = 0.01
1106
+ betas: Tuple[float, float] = (0.9, 0.95)
1107
+ grad_clip: float = 1.0
1108
+ warmup_steps: int = 1_000
1109
+ total_steps: int = 75_000
1110
+ eval_interval: int = 1_000
1111
+ log_interval: int = 100
1112
+
1113
+ # Model
1114
+ vocab_size: int = 50257
1115
+ embed_dim: int = 384
1116
+ num_layers: int = 23
1117
+ num_heads: int = 8
1118
+ num_slots: int = 32
1119
+ mlp_ratio: float = 4.0
1120
+ dropout: float = 0.1
1121
+ tie_weights: bool = True
1122
+
1123
+ # ASA / numerics
1124
+ read_temperature: float = 1.0
1125
+ write_temperature: float = 1.0
1126
+ slot_dropout: float = 0.05
1127
+ state_fp32: bool = True
1128
+ normalize_k: bool = False
1129
+
1130
+ # Positions
1131
+ use_abs_pos: bool = False
1132
+ use_rope_keys: bool = True
1133
+ rope_base: float = 10000.0
1134
+ use_alibi_write: bool = True
1135
+ alibi_strength_init: float = 0.1
1136
+ learn_alibi_strength: bool = True
1137
+ min_strength: float = 0.0
1138
+
1139
+ # Content-conditioned read (gamma)
1140
+ use_content_read: bool = True
1141
+ content_read_init: float = -4.0
1142
+ content_read_max_gamma: float = 3.0
1143
+
1144
+ # Slot-space refinement
1145
+ use_slotspace_refine: bool = True
1146
+ slotspace_dim: int = 64
1147
+ slotspace_gate_init: float = -4.0
1148
+ slotspace_dropout: float = 0.05
1149
+ slotspace_signed_weights: bool = True
1150
+
1151
+ # RoPE inside slot-space matcher
1152
+ use_rope_slotspace: bool = True
1153
+ rope_base_slotspace: float = 100000.0
1154
+
1155
+ # Perf knobs
1156
+ write_chunk_size: int = 128
1157
+ slotspace_chunk_size: int = 128
1158
+ enable_compiled: bool = False
1159
+
1160
+ # Analytics
1161
+ eval_max_batches: int = 150
1162
+ analytics_last_k: int = 32
1163
+
1164
+ # IO / caches
1165
+ output_dir: str = "./drive/MyDrive/asm_outputs"
1166
+ tag: str = "asm_wikitext"
1167
+ cache_dir: str = "./drive/MyDrive/asm_caches"
1168
+ val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
1169
+
1170
+
1171
+ # ============================================================================
1172
+ # Block
1173
+ # ============================================================================
1174
+ class ASMBlock(nn.Module):
1175
+ def __init__(
1176
+ self,
1177
+ embed_dim: int,
1178
+ num_heads: int,
1179
+ num_slots: int,
1180
+ mlp_ratio: float = 4.0,
1181
+ dropout: float = 0.1,
1182
+ # temperatures / numerics
1183
+ read_temperature: float = 1.0,
1184
+ write_temperature: float = 1.0,
1185
+ state_fp32: bool = True,
1186
+ slot_dropout: float = 0.0,
1187
+ normalize_k: bool = False,
1188
+ # positions
1189
+ use_rope_keys: bool = True,
1190
+ rope_base: float = 10000.0,
1191
+ use_alibi_write: bool = True,
1192
+ # ALiBi
1193
+ alibi_strength_init: float = 0.1,
1194
+ learn_alibi_strength: bool = True,
1195
+ min_strength: float = 0.0,
1196
+ # content-conditioned read (gamma)
1197
+ use_content_read: bool = True,
1198
+ content_read_init: float = -4.0,
1199
+ content_read_max_gamma: float = 3.0,
1200
+ # slot-space refinement
1201
+ use_slotspace_refine: bool = True,
1202
+ slotspace_dim: int = 32,
1203
+ slotspace_gate_init: float = -10.0,
1204
+ slotspace_dropout: float = 0.0,
1205
+ slotspace_signed_weights: bool = True,
1206
+ # RoPE inside slot-space matcher
1207
+ use_rope_slotspace: bool = True,
1208
+ rope_base_slotspace: float = 100000.0,
1209
+ # chunk sizes
1210
+ write_chunk_size: int = 128,
1211
+ slotspace_chunk_size: int = 128,
1212
+ ):
1213
+ super().__init__()
1214
+ self.norm1 = nn.LayerNorm(embed_dim)
1215
+
1216
+ self.asa = AddressedStateAttention(
1217
+ embed_dim=embed_dim,
1218
+ num_heads=num_heads,
1219
+ num_slots=num_slots,
1220
+ dropout=dropout,
1221
+ read_temperature=read_temperature,
1222
+ write_temperature=write_temperature,
1223
+ state_fp32=state_fp32,
1224
+ slot_dropout=slot_dropout,
1225
+ normalize_k=normalize_k,
1226
+ use_rope_keys=use_rope_keys,
1227
+ rope_base=rope_base,
1228
+ use_alibi_write=use_alibi_write,
1229
+ alibi_strength_init=alibi_strength_init,
1230
+ learn_alibi_strength=learn_alibi_strength,
1231
+ min_strength=min_strength,
1232
+ use_content_read=use_content_read,
1233
+ content_read_init=content_read_init,
1234
+ content_read_max_gamma=content_read_max_gamma,
1235
+ use_slotspace_refine=use_slotspace_refine,
1236
+ slotspace_dim=slotspace_dim,
1237
+ slotspace_gate_init=slotspace_gate_init,
1238
+ slotspace_dropout=slotspace_dropout,
1239
+ slotspace_signed_weights=slotspace_signed_weights,
1240
+ use_rope_slotspace=use_rope_slotspace,
1241
+ rope_base_slotspace=rope_base_slotspace,
1242
+ write_chunk_size=write_chunk_size,
1243
+ slotspace_chunk_size=slotspace_chunk_size,
1244
+ )
1245
+
1246
+ self.norm2 = nn.LayerNorm(embed_dim)
1247
+ hidden = int(embed_dim * mlp_ratio)
1248
+ self.mlp = nn.Sequential(
1249
+ nn.Linear(embed_dim, hidden, bias=False),
1250
+ nn.GELU(),
1251
+ nn.Dropout(dropout),
1252
+ nn.Linear(hidden, embed_dim, bias=False),
1253
+ nn.Dropout(dropout),
1254
+ )
1255
+
1256
+ def forward(
1257
+ self,
1258
+ x: torch.Tensor,
1259
+ attention_mask: Optional[torch.Tensor] = None,
1260
+ return_info: bool = False,
1261
+ # routing
1262
+ routing_mode: str = "softmax",
1263
+ routing_topk: int = 2,
1264
+ read_weights_override: Optional[torch.Tensor] = None,
1265
+ routing_noise: Optional[str] = None,
1266
+ routing_noise_scale: float = 1.0,
1267
+ # slot mask
1268
+ slot_mask: Optional[torch.Tensor] = None,
1269
+ slot_mask_where: str = "read",
1270
+ slot_mask_scope: str = "all",
1271
+ # info controls
1272
+ info_level: str = "full",
1273
+ info_cfg: Optional[Dict] = None,
1274
+ ):
1275
+ a, info = self.asa(
1276
+ self.norm1(x),
1277
+ attention_mask=attention_mask,
1278
+ return_info=return_info,
1279
+ routing_mode=routing_mode,
1280
+ routing_topk=routing_topk,
1281
+ read_weights_override=read_weights_override,
1282
+ routing_noise=routing_noise,
1283
+ routing_noise_scale=routing_noise_scale,
1284
+ slot_mask=slot_mask,
1285
+ slot_mask_where=slot_mask_where,
1286
+ slot_mask_scope=slot_mask_scope,
1287
+ info_level=info_level,
1288
+ info_cfg=info_cfg,
1289
+ )
1290
+ x = x + a
1291
+ x = x + self.mlp(self.norm2(x))
1292
+ return x, info
1293
+
1294
+
1295
+ # ============================================================================
1296
+ # LM
1297
+ # ============================================================================
1298
+ class ASMLanguageModel(nn.Module):
1299
+ def __init__(
1300
+ self,
1301
+ vocab_size: int,
1302
+ embed_dim: int = 384,
1303
+ num_layers: int = 6,
1304
+ num_heads: int = 8,
1305
+ num_slots: int = 8,
1306
+ max_seq_len: int = 1024,
1307
+ mlp_ratio: float = 4.0,
1308
+ dropout: float = 0.1,
1309
+ # temperatures / numerics
1310
+ read_temperature: float = 1.0,
1311
+ write_temperature: float = 1.0,
1312
+ state_fp32: bool = True,
1313
+ slot_dropout: float = 0.05,
1314
+ normalize_k: bool = False,
1315
+ tie_weights: bool = True,
1316
+ # LM-level abs pos
1317
+ use_abs_pos: bool = False,
1318
+ # positions
1319
+ use_rope_keys: bool = True,
1320
+ rope_base: float = 10000.0,
1321
+ use_alibi_write: bool = True,
1322
+ # ALiBi
1323
+ alibi_strength_init: float = 0.1,
1324
+ learn_alibi_strength: bool = True,
1325
+ min_strength: float = 0.0,
1326
+ # content-conditioned read (gamma)
1327
+ use_content_read: bool = True,
1328
+ content_read_init: float = -4.0,
1329
+ content_read_max_gamma: float = 3.0,
1330
+ # slot-space refinement
1331
+ use_slotspace_refine: bool = True,
1332
+ slotspace_dim: int = 32,
1333
+ slotspace_gate_init: float = -10.0,
1334
+ slotspace_dropout: float = 0.0,
1335
+ slotspace_signed_weights: bool = True,
1336
+ # RoPE inside slot-space matcher
1337
+ use_rope_slotspace: bool = True,
1338
+ rope_base_slotspace: float = 100000.0,
1339
+ # chunk sizes
1340
+ write_chunk_size: int = 128,
1341
+ slotspace_chunk_size: int = 128,
1342
+ ):
1343
+ super().__init__()
1344
+ self.vocab_size = vocab_size
1345
+ self.embed_dim = embed_dim
1346
+ self.max_seq_len = max_seq_len
1347
+ self.use_abs_pos = bool(use_abs_pos)
1348
+
1349
+ self.tok = nn.Embedding(vocab_size, embed_dim)
1350
+ self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
1351
+ self.drop = nn.Dropout(dropout)
1352
+
1353
+ self.blocks = nn.ModuleList([
1354
+ ASMBlock(
1355
+ embed_dim=embed_dim,
1356
+ num_heads=num_heads,
1357
+ num_slots=num_slots,
1358
+ mlp_ratio=mlp_ratio,
1359
+ dropout=dropout,
1360
+ read_temperature=read_temperature,
1361
+ write_temperature=write_temperature,
1362
+ state_fp32=state_fp32,
1363
+ slot_dropout=slot_dropout,
1364
+ normalize_k=normalize_k,
1365
+ use_rope_keys=use_rope_keys,
1366
+ rope_base=rope_base,
1367
+ use_alibi_write=use_alibi_write,
1368
+ alibi_strength_init=alibi_strength_init,
1369
+ learn_alibi_strength=learn_alibi_strength,
1370
+ min_strength=min_strength,
1371
+ use_content_read=use_content_read,
1372
+ content_read_init=content_read_init,
1373
+ content_read_max_gamma=content_read_max_gamma,
1374
+ use_slotspace_refine=use_slotspace_refine,
1375
+ slotspace_dim=slotspace_dim,
1376
+ slotspace_gate_init=slotspace_gate_init,
1377
+ slotspace_dropout=slotspace_dropout,
1378
+ slotspace_signed_weights=slotspace_signed_weights,
1379
+ use_rope_slotspace=use_rope_slotspace,
1380
+ rope_base_slotspace=rope_base_slotspace,
1381
+ write_chunk_size=write_chunk_size,
1382
+ slotspace_chunk_size=slotspace_chunk_size,
1383
+ )
1384
+ for _ in range(num_layers)
1385
+ ])
1386
+
1387
+ self.norm = nn.LayerNorm(embed_dim)
1388
+ self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
1389
+ if tie_weights:
1390
+ self.lm_head.weight = self.tok.weight
1391
+
1392
+ self.apply(self._init)
1393
+
1394
+ def _init(self, m):
1395
+ if isinstance(m, nn.Linear):
1396
+ nn.init.normal_(m.weight, std=0.02)
1397
+ elif isinstance(m, nn.Embedding):
1398
+ nn.init.normal_(m.weight, std=0.02)
1399
+ elif isinstance(m, nn.LayerNorm):
1400
+ nn.init.ones_(m.weight)
1401
+ nn.init.zeros_(m.bias)
1402
+
1403
+ def forward(
1404
+ self,
1405
+ input_ids: torch.Tensor,
1406
+ attention_mask: Optional[torch.Tensor] = None,
1407
+ return_info: bool = False,
1408
+ # routing
1409
+ routing_mode: str = "softmax",
1410
+ routing_topk: int = 2,
1411
+ read_weights_override: Optional[torch.Tensor] = None,
1412
+ routing_noise: Optional[str] = None,
1413
+ routing_noise_scale: float = 1.0,
1414
+ # slot mask
1415
+ slot_mask: Optional[torch.Tensor] = None,
1416
+ slot_mask_where: str = "read",
1417
+ slot_mask_scope: str = "all",
1418
+ # info controls
1419
+ info_level: str = "full",
1420
+ info_cfg: Optional[Dict] = None,
1421
+ ):
1422
+ B, T = input_ids.shape
1423
+
1424
+ x = self.tok(input_ids)
1425
+ if self.use_abs_pos:
1426
+ pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
1427
+ x = x + self.pos(pos)
1428
+ x = self.drop(x)
1429
+
1430
+ infos: List[Optional[Dict[str, torch.Tensor]]] = []
1431
+ for blk in self.blocks:
1432
+ x, info = blk(
1433
+ x,
1434
+ attention_mask=attention_mask,
1435
+ return_info=return_info,
1436
+ routing_mode=routing_mode,
1437
+ routing_topk=routing_topk,
1438
+ read_weights_override=read_weights_override,
1439
+ routing_noise=routing_noise,
1440
+ routing_noise_scale=routing_noise_scale,
1441
+ slot_mask=slot_mask,
1442
+ slot_mask_where=slot_mask_where,
1443
+ slot_mask_scope=slot_mask_scope,
1444
+ info_level=info_level,
1445
+ info_cfg=info_cfg,
1446
+ )
1447
+ if return_info:
1448
+ infos.append(info)
1449
+
1450
+ x = self.norm(x)
1451
+ logits = self.lm_head(x)
1452
+ return (logits, infos) if return_info else logits
1453
+
1454
+
1455
+ # ============================================================================
1456
+ # Convenience: build model from config
1457
+ # ============================================================================
1458
+ def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
1459
+ return ASMLanguageModel(
1460
+ vocab_size=cfg.vocab_size,
1461
+ embed_dim=cfg.embed_dim,
1462
+ num_layers=cfg.num_layers,
1463
+ num_heads=cfg.num_heads,
1464
+ num_slots=cfg.num_slots,
1465
+ max_seq_len=cfg.max_seq_len,
1466
+ mlp_ratio=cfg.mlp_ratio,
1467
+ dropout=cfg.dropout,
1468
+ read_temperature=cfg.read_temperature,
1469
+ write_temperature=cfg.write_temperature,
1470
+ state_fp32=cfg.state_fp32,
1471
+ slot_dropout=cfg.slot_dropout,
1472
+ normalize_k=cfg.normalize_k,
1473
+ tie_weights=cfg.tie_weights,
1474
+ use_abs_pos=cfg.use_abs_pos,
1475
+ use_rope_keys=cfg.use_rope_keys,
1476
+ rope_base=cfg.rope_base,
1477
+ use_alibi_write=cfg.use_alibi_write,
1478
+ alibi_strength_init=cfg.alibi_strength_init,
1479
+ learn_alibi_strength=cfg.learn_alibi_strength,
1480
+ min_strength=cfg.min_strength,
1481
+ use_content_read=cfg.use_content_read,
1482
+ content_read_init=cfg.content_read_init,
1483
+ content_read_max_gamma=cfg.content_read_max_gamma,
1484
+ use_slotspace_refine=cfg.use_slotspace_refine,
1485
+ slotspace_dim=cfg.slotspace_dim,
1486
+ slotspace_gate_init=cfg.slotspace_gate_init,
1487
+ slotspace_dropout=cfg.slotspace_dropout,
1488
+ slotspace_signed_weights=cfg.slotspace_signed_weights,
1489
+ use_rope_slotspace=cfg.use_rope_slotspace,
1490
+ rope_base_slotspace=cfg.rope_base_slotspace,
1491
+ write_chunk_size=cfg.write_chunk_size,
1492
+ slotspace_chunk_size=cfg.slotspace_chunk_size,
1493
+ )