DigitalDaimyo commited on
Commit
8f95e75
·
verified ·
1 Parent(s): 319be6d

Upload training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training.py +902 -0
training.py ADDED
@@ -0,0 +1,902 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Addressed State Attention (ASA) - Training Harness
4
+
5
+ Efficient implementation optimized for language model training.
6
+ For mechanistic analysis and interventions, use asm_analysis.py instead.
7
+
8
+ Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
9
+ Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/paper_drafts
10
+ """
11
+
12
+ import math
13
+ from dataclasses import dataclass
14
+ from typing import Optional, Dict, Tuple
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ __all__ = [
21
+ 'AddressedStateAttention',
22
+ 'ASMBlock',
23
+ 'ASMLanguageModel',
24
+ 'ASMTrainConfig',
25
+ 'build_model_from_cfg',
26
+ ]
27
+
28
+ # -------------------------
29
+ # RoPE helper (rotate-half)
30
+ # -------------------------
31
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
32
+ x1 = x[..., ::2]
33
+ x2 = x[..., 1::2]
34
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
35
+
36
+ class RotaryEmbedding(nn.Module):
37
+ def __init__(self, dim: int, base: float = 10000.0):
38
+ super().__init__()
39
+ assert dim % 2 == 0, "RoPE requires even dim"
40
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
41
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
42
+ self._cos_cached = None
43
+ self._sin_cached = None
44
+ self._t_cached = None
45
+ self._device_cached = None
46
+
47
+ def get_cos_sin(self, T: int, device, dtype):
48
+ if self._t_cached == T and self._cos_cached is not None and self._device_cached == device:
49
+ return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
50
+
51
+ t = torch.arange(T, device=device, dtype=self.inv_freq.dtype)
52
+ freqs = torch.einsum("t,f->tf", t, self.inv_freq) # [T, d/2]
53
+ emb = torch.cat([freqs, freqs], dim=-1) # [T, d]
54
+ cos = emb.cos()[None, None, :, :] # [1,1,T,d]
55
+ sin = emb.sin()[None, None, :, :] # [1,1,T,d]
56
+
57
+ self._t_cached = T
58
+ self._device_cached = device
59
+ self._cos_cached = cos
60
+ self._sin_cached = sin
61
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
62
+
63
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
64
+ return (x * cos) + (_rotate_half(x) * sin)
65
+
66
+ # -------------------------
67
+ # ALiBi slopes helper
68
+ # -------------------------
69
+ def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
70
+ def get_slopes(n):
71
+ def power_of_2_slopes(n):
72
+ start = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
73
+ ratio = start
74
+ return [start * (ratio ** i) for i in range(n)]
75
+ if math.log2(n).is_integer():
76
+ return power_of_2_slopes(n)
77
+ closest = 2 ** math.floor(math.log2(n))
78
+ return power_of_2_slopes(closest) + get_slopes(2 * closest)[0::2][: n - closest]
79
+ return torch.tensor(get_slopes(num_heads), device=device, dtype=dtype)
80
+
81
+ def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
82
+ return torch.log(torch.expm1(y))
83
+
84
+ class AddressedStateAttention(nn.Module):
85
+ """
86
+ ASA with integral slotspace refine fused into the compiled chunk kernel.
87
+ Fixes included:
88
+ (1) pad slotspace RoPE cos/sin to CH (identity on padded positions)
89
+ (2) build valid_mask_c even when attention_mask is None (padding-only)
90
+ (3) pad write logits with -inf (so padded positions contribute zero to scan)
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ embed_dim: int,
96
+ num_heads: int = 12,
97
+ num_slots: int = 16,
98
+ dropout: float = 0.1,
99
+
100
+ # temps / numerics
101
+ read_temperature: float = 1.0,
102
+ write_temperature: float = 1.0,
103
+ state_fp32: bool = True,
104
+ slot_dropout: float = 0.0,
105
+ normalize_k: bool = False,
106
+
107
+ # write geometry
108
+ use_rope_keys: bool = True,
109
+ rope_base: float = 10000.0,
110
+
111
+ # write bias
112
+ use_alibi_write: bool = True,
113
+ alibi_strength_init: float = 0.1,
114
+ learn_alibi_strength: bool = True,
115
+ min_strength: float = 0.0,
116
+
117
+ # content read gamma
118
+ use_content_read: bool = True,
119
+ content_read_init: float = -4.0,
120
+ content_read_max_gamma: float = 3.0,
121
+
122
+ # slotspace refine (INTEGRAL)
123
+ use_slotspace_refine: bool = True, # compat only
124
+ slotspace_dim: int = 8,
125
+ slotspace_gate_init: float = -4.0,
126
+ slotspace_dropout: float = 0.05,
127
+ slotspace_signed_weights: bool = True,
128
+
129
+ # slotspace RoPE (Q/K only)
130
+ use_rope_slotspace: bool = True,
131
+ rope_base_slotspace: float = 100000.0,
132
+
133
+ # perf
134
+ write_chunk_size: int = 1024,
135
+ enable_compiled: bool = True,
136
+ ):
137
+ super().__init__()
138
+ assert embed_dim % num_heads == 0
139
+ assert (slotspace_dim % 2) == 0, "slotspace_dim must be even if RoPE enabled"
140
+
141
+ self.embed_dim = embed_dim
142
+ self.num_heads = num_heads
143
+ self.num_slots = num_slots
144
+ self.head_dim = embed_dim // num_heads
145
+
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ self.read_temperature = float(read_temperature)
149
+ self.write_temperature = float(write_temperature)
150
+ self.state_fp32 = bool(state_fp32)
151
+ self.slot_dropout = float(slot_dropout)
152
+ self.normalize_k = bool(normalize_k)
153
+
154
+ self.use_rope_keys = bool(use_rope_keys)
155
+ self.use_alibi_write = bool(use_alibi_write)
156
+ self.learn_alibi_strength = bool(learn_alibi_strength)
157
+ self.min_strength = float(min_strength)
158
+
159
+ self.use_content_read = bool(use_content_read)
160
+ self.content_read_max_gamma = float(content_read_max_gamma)
161
+
162
+ self.slotspace_dim = int(slotspace_dim)
163
+ self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
164
+ self.slotspace_signed_weights = bool(slotspace_signed_weights)
165
+
166
+ self.use_rope_slotspace = bool(use_rope_slotspace)
167
+ self.write_chunk_size = int(write_chunk_size)
168
+
169
+ H, K, d = self.num_heads, self.num_slots, self.head_dim
170
+ M = self.slotspace_dim
171
+
172
+ self.slot_keys = nn.Parameter(torch.randn(H, K, d) / math.sqrt(d))
173
+
174
+ self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
175
+ self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
176
+ self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
177
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
178
+
179
+ self.rope = RotaryEmbedding(d, base=rope_base) if self.use_rope_keys else None
180
+
181
+ if self.use_alibi_write:
182
+ self.register_buffer("_alibi_slopes", alibi_slopes(H), persistent=False)
183
+ else:
184
+ self.register_buffer("_alibi_slopes", torch.zeros(H), persistent=False)
185
+
186
+ if self.use_alibi_write and self.learn_alibi_strength:
187
+ init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
188
+ self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
189
+ else:
190
+ self._alibi_strength_param = None
191
+ self.alibi_strength = float(alibi_strength_init)
192
+
193
+ if self.use_content_read:
194
+ self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
195
+ else:
196
+ self._content_read_gamma_raw = None
197
+
198
+ self.slot_in = nn.Linear(K, M, bias=False)
199
+ self.slot_q = nn.Linear(M, M, bias=False)
200
+ self.slot_k = nn.Linear(M, M, bias=False)
201
+ self.slot_v = nn.Linear(M, M, bias=False)
202
+ self.slot_out = nn.Linear(M, K, bias=False)
203
+
204
+ self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
205
+
206
+ self.rope_slotspace = RotaryEmbedding(M, base=float(rope_base_slotspace)) if self.use_rope_slotspace else None
207
+
208
+ self._compiled = None
209
+ if enable_compiled:
210
+ self.enable_compiled_kernel()
211
+
212
+ def enable_compiled_kernel(self):
213
+ if self._compiled is None:
214
+ self._compiled = torch.compile(self._asa_chunk_fused, dynamic=False, fullgraph=False)
215
+
216
+ def _alibi_strength(self, dtype, device) -> torch.Tensor:
217
+ if not (self.use_alibi_write and self.learn_alibi_strength):
218
+ return torch.tensor(getattr(self, "alibi_strength", 0.0), dtype=dtype, device=device)
219
+ return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)
220
+
221
+ def _content_read_gamma(self, dtype, device) -> torch.Tensor:
222
+ if not self.use_content_read:
223
+ return torch.tensor(0.0, dtype=dtype, device=device)
224
+ g = F.softplus(self._content_read_gamma_raw)
225
+ if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
226
+ g = g.clamp(max=self.content_read_max_gamma)
227
+ return g.to(dtype=dtype, device=device)
228
+
229
+ def _slotspace_gate(self, dtype, device) -> torch.Tensor:
230
+ return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)
231
+
232
+ @staticmethod
233
+ def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
234
+ diff = s - m
235
+ diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
236
+ return torch.exp(diff)
237
+
238
+ @staticmethod
239
+ def _phi(x: torch.Tensor) -> torch.Tensor:
240
+ return F.elu(x) + 1.0
241
+
242
+ @staticmethod
243
+ def _pad_time_slice(x: torch.Tensor, t0: int, L: int, CH: int, dim: int):
244
+ sl = x.narrow(dim, t0, L)
245
+ if L == CH:
246
+ return sl, None
247
+ pad_shape = list(sl.shape)
248
+ pad_shape[dim] = CH - L
249
+ pad = torch.zeros(pad_shape, device=sl.device, dtype=sl.dtype)
250
+ xpad = torch.cat([sl, pad], dim=dim)
251
+ mask = torch.zeros((CH,), device=sl.device, dtype=torch.bool)
252
+ mask[:L] = True
253
+ return xpad, mask
254
+
255
+ def _asa_chunk_fused(
256
+ self,
257
+ wlog_c: torch.Tensor, # [B,H,K,CH]
258
+ v_c: torch.Tensor, # [B,H,CH,d]
259
+ q_c: torch.Tensor, # [B,H,CH,d]
260
+ slot_keys_dk: torch.Tensor, # [1,H,d,K]
261
+ pos_cos_s: Optional[torch.Tensor], # [1,1,CH,M] or None
262
+ pos_sin_s: Optional[torch.Tensor], # [1,1,CH,M] or None
263
+ content_gamma: torch.Tensor,
264
+ rtemp_t: torch.Tensor,
265
+ gate_t: torch.Tensor,
266
+ m_state: torch.Tensor, # [B,H,K]
267
+ denom_state: torch.Tensor, # [B,H,K]
268
+ numer_state: torch.Tensor, # [B,H,K,d]
269
+ S_state: torch.Tensor, # [B,H,M,M]
270
+ Z_state: torch.Tensor, # [B,H,M]
271
+ valid_mask_c: Optional[torch.Tensor], # [B,1,CH,1] or None
272
+ do_dropout: bool,
273
+ dropout_p: float,
274
+ signed_slot_w: bool,
275
+ ):
276
+ B, H, K, CH = wlog_c.shape
277
+ d = numer_state.shape[-1]
278
+ M = S_state.shape[-1]
279
+ inv_sqrt_d = 1.0 / math.sqrt(d)
280
+
281
+ # ----- WRITE prefix-softmax scan -----
282
+ m_c, _ = torch.cummax(wlog_c, dim=-1) # [B,H,K,CH]
283
+ m_new = torch.maximum(m_state.unsqueeze(-1), m_c) # [B,H,K,CH]
284
+ scale = torch.exp(m_state.unsqueeze(-1) - m_new) # [B,H,K,CH]
285
+
286
+ denom_c = denom_state.unsqueeze(-1) * scale # [B,H,K,CH]
287
+ numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1) # [B,H,K,CH,d]
288
+
289
+ w_new = self._safe_exp_sub_max(wlog_c, m_new) # [B,H,K,CH]
290
+ denom_c = denom_c + torch.cumsum(w_new, dim=-1) # [B,H,K,CH]
291
+ numer_c = numer_c + torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) # [B,H,K,CH,d]
292
+
293
+ # ----- Routing logits -----
294
+ read_logits_key = torch.matmul(q_c, slot_keys_dk) * inv_sqrt_d # [B,H,CH,K]
295
+
296
+ if self.use_content_read:
297
+ numer_for_dot = numer_c.to(q_c.dtype).permute(0, 1, 3, 2, 4) # [B,H,CH,K,d]
298
+ denom_for_div = denom_c.to(q_c.dtype).permute(0, 1, 3, 2) # [B,H,CH,K]
299
+ read_logits_content = (q_c.unsqueeze(-2) * numer_for_dot).sum(dim=-1) * inv_sqrt_d
300
+ read_logits_content = read_logits_content / denom_for_div.clamp_min(1e-8)
301
+ read_logits = read_logits_key + content_gamma.to(read_logits_key.dtype) * read_logits_content
302
+ else:
303
+ read_logits = read_logits_key
304
+
305
+ read_w = torch.softmax(read_logits / rtemp_t, dim=-1) # [B,H,CH,K]
306
+
307
+ # ----- EXACT base output -----
308
+ inv_denom = (1.0 / denom_c.clamp_min(1e-8)).to(numer_c.dtype) # [B,H,K,CH]
309
+ w_scaled = read_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom # [B,H,K,CH]
310
+ out_base = (w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) # [B,H,CH,d]
311
+
312
+ # ----- Slotspace refine -----
313
+ u = self.slot_in(read_w.to(out_base.dtype)) # [B,H,CH,M]
314
+ q_s = self.slot_q(u)
315
+ k_s = self.slot_k(u)
316
+ v_s = self.slot_v(u)
317
+
318
+ if self.use_rope_slotspace and (pos_cos_s is not None) and (pos_sin_s is not None):
319
+ q_s = apply_rope(q_s, pos_cos_s, pos_sin_s)
320
+ k_s = apply_rope(k_s, pos_cos_s, pos_sin_s)
321
+
322
+ if valid_mask_c is not None:
323
+ q_s = q_s * valid_mask_c
324
+ k_s = k_s * valid_mask_c
325
+ v_s = v_s * valid_mask_c
326
+
327
+ qf = self._phi(q_s)
328
+ kf = self._phi(k_s)
329
+
330
+ kv = kf.unsqueeze(-1) * v_s.unsqueeze(-2) # [B,H,CH,M,M]
331
+ S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2) # [B,H,CH,M,M]
332
+ Z_c = torch.cumsum(kf, dim=2) + Z_state.unsqueeze(2) # [B,H,CH,M]
333
+ Z_c = Z_c.clamp_min(1e-8)
334
+
335
+ num = torch.matmul(qf.unsqueeze(-2), S_c).squeeze(-2) # [B,H,CH,M]
336
+ den = (qf * Z_c).sum(dim=-1, keepdim=True).clamp_min(1e-8) # [B,H,CH,1]
337
+ u2 = num / den # [B,H,CH,M]
338
+
339
+ S_state_new = S_c[:, :, -1, :, :]
340
+ Z_state_new = Z_c[:, :, -1, :]
341
+
342
+ if do_dropout and dropout_p > 0.0:
343
+ keep = (torch.rand_like(u2) > dropout_p).to(u2.dtype) / (1.0 - dropout_p)
344
+ u2 = u2 * keep
345
+
346
+ slot_w = self.slot_out(u2) # [B,H,CH,K]
347
+ if signed_slot_w:
348
+ slot_w = torch.tanh(slot_w)
349
+ else:
350
+ slot_w = torch.softmax(slot_w, dim=-1)
351
+
352
+ slot_w_scaled = slot_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom
353
+ delta = (slot_w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) # [B,H,CH,d]
354
+
355
+ out = out_base + gate_t.to(out_base.dtype) * delta
356
+
357
+ m_state_new = m_new[:, :, :, -1]
358
+ denom_state_new = denom_c[:, :, :, -1]
359
+ numer_state_new = numer_c[:, :, :, -1, :]
360
+
361
+ return out, read_w, m_state_new, denom_state_new, numer_state_new, S_state_new, Z_state_new
362
+
363
+ def forward(
364
+ self,
365
+ x: torch.Tensor,
366
+ attention_mask: Optional[torch.Tensor] = None,
367
+ return_info: bool = False,
368
+ return_light_stats: bool = False,
369
+ ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
370
+
371
+ B, T, C = x.shape
372
+ H, K, d = self.num_heads, self.num_slots, self.head_dim
373
+ M = self.slotspace_dim
374
+
375
+ k_write = self.Wk_write(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
376
+ v_write = self.Wv_write(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
377
+ q_read = self.Wq_read(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
378
+
379
+ if self.normalize_k:
380
+ k_write = F.normalize(k_write, dim=-1, eps=1e-8)
381
+
382
+ if self.use_rope_keys:
383
+ cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
384
+ k_write = apply_rope(k_write, cos, sin)
385
+
386
+ slot_keys = self.slot_keys
387
+ if self.training and self.slot_dropout > 0.0:
388
+ drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
389
+ slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)
390
+
391
+ slot_keys_dk = slot_keys.transpose(-1, -2).unsqueeze(0).to(q_read.dtype) # [1,H,d,K]
392
+
393
+ write_logits_raw = torch.matmul(k_write.to(q_read.dtype), slot_keys_dk).permute(0, 1, 3, 2) / math.sqrt(d)
394
+
395
+ state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
396
+ write_logits = write_logits_raw.to(state_dtype)
397
+
398
+ wtemp = max(1e-6, self.write_temperature)
399
+ write_logits = write_logits / wtemp
400
+
401
+ if self.use_alibi_write:
402
+ strength = self._alibi_strength(dtype=state_dtype, device=x.device)
403
+ slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength
404
+ pos = torch.arange(T, device=x.device, dtype=state_dtype)
405
+ write_logits = write_logits + slopes.view(1, H, 1, 1) * pos.view(1, 1, 1, T)
406
+
407
+ valid = None
408
+ if attention_mask is not None:
409
+ valid = attention_mask.to(dtype=torch.bool)
410
+ write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
411
+
412
+ content_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
413
+ rtemp_t = torch.tensor(max(1e-6, self.read_temperature), device=x.device, dtype=q_read.dtype)
414
+ gate_t = self._slotspace_gate(dtype=state_dtype, device=x.device)
415
+
416
+ denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
417
+ numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
418
+ m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
419
+
420
+ S_state = torch.zeros((B, H, M, M), device=x.device, dtype=state_dtype)
421
+ Z_state = torch.zeros((B, H, M), device=x.device, dtype=state_dtype)
422
+
423
+ out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)
424
+
425
+ if self.use_rope_slotspace:
426
+ cos_s_full, sin_s_full = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=state_dtype)
427
+ else:
428
+ cos_s_full = sin_s_full = None
429
+
430
+ CH = self.write_chunk_size
431
+ kernel = self._compiled if self._compiled is not None else self._asa_chunk_fused
432
+
433
+ do_dropout = bool(self.training and self.slotspace_dropout.p > 0.0)
434
+ dropout_p = float(self.slotspace_dropout.p)
435
+ signed_slot_w = bool(self.slotspace_signed_weights)
436
+
437
+ for t0 in range(0, T, CH):
438
+ t1 = min(T, t0 + CH)
439
+ L = t1 - t0
440
+
441
+ wlog_c, mask = self._pad_time_slice(write_logits, t0, L, CH, dim=3) # [B,H,K,CH]
442
+ v_c, _ = self._pad_time_slice(v_write.to(state_dtype), t0, L, CH, dim=2) # [B,H,CH,d]
443
+ q_c, _ = self._pad_time_slice(q_read, t0, L, CH, dim=2) # [B,H,CH,d]
444
+
445
+ # (3) ensure padded write logits contribute zero mass
446
+ if mask is not None:
447
+ wlog_c = wlog_c.clone()
448
+ wlog_c[:, :, :, L:] = float("-inf")
449
+
450
+ # (2) build valid_mask_c even when attention_mask is None (padding-only)
451
+ valid_mask_c = None
452
+ if (valid is not None) or (mask is not None):
453
+ if valid is None:
454
+ vm_pad = mask.view(1, CH).expand(B, CH) # [B,CH]
455
+ else:
456
+ if mask is None:
457
+ vm_pad = valid[:, t0:t1]
458
+ else:
459
+ vm = valid[:, t0:t1]
460
+ vm_pad = torch.zeros((B, CH), device=x.device, dtype=torch.bool)
461
+ vm_pad[:, :L] = vm
462
+ valid_mask_c = vm_pad.view(B, 1, CH, 1).to(state_dtype)
463
+
464
+ # (1) slotspace RoPE slice PADDED TO CH (identity on padded positions)
465
+ if self.use_rope_slotspace:
466
+ cos_slice = cos_s_full[:, :, t0:t1, :] # [1,1,L,M]
467
+ sin_slice = sin_s_full[:, :, t0:t1, :] # [1,1,L,M]
468
+ if L == CH:
469
+ cos_s, sin_s = cos_slice, sin_slice
470
+ else:
471
+ cos_s = torch.ones((1, 1, CH, M), device=x.device, dtype=state_dtype)
472
+ sin_s = torch.zeros((1, 1, CH, M), device=x.device, dtype=state_dtype)
473
+ cos_s[:, :, :L, :] = cos_slice
474
+ sin_s[:, :, :L, :] = sin_slice
475
+ else:
476
+ cos_s = sin_s = None
477
+
478
+ out_c, read_w_c, m_state, denom_state, numer_state, S_state, Z_state = kernel(
479
+ wlog_c, v_c, q_c, slot_keys_dk,
480
+ cos_s, sin_s,
481
+ content_gamma, rtemp_t, gate_t,
482
+ m_state, denom_state, numer_state,
483
+ S_state, Z_state,
484
+ valid_mask_c,
485
+ do_dropout, dropout_p,
486
+ signed_slot_w,
487
+ )
488
+
489
+ if mask is not None:
490
+ out_c = out_c * mask.view(1, 1, CH, 1).to(out_c.dtype)
491
+
492
+ out_h[:, :, t0:t1, :] = out_c[:, :, :L, :]
493
+
494
+ out = out_h.transpose(1, 2).reshape(B, T, C)
495
+ out = self.out_proj(out)
496
+ out = self.dropout(out)
497
+
498
+ info = None
499
+ if return_info or return_light_stats:
500
+ info = {
501
+ "content_read_gamma": content_gamma.detach().to(torch.float32).cpu(),
502
+ "slotspace_gate": gate_t.detach().to(torch.float32).cpu(),
503
+ }
504
+ return out, info
505
+
506
+
507
+ # ============================================================================
508
+ # Addressed State Models (ASM): Config + Block + LM
509
+ # - Naming aligned with paper: slots, read/write, slot-space refinement
510
+ # - No compatibility layer (fresh public tooling)
511
+ # ============================================================================
512
+
513
+
514
+ # ============================================================================
515
+ # Config
516
+ # ============================================================================
517
+ @dataclass
518
+ class ASMTrainConfig:
519
+ # Data
520
+ dataset_name: str = "wikitext"
521
+ dataset_config: str = "wikitext-103-raw-v1"
522
+ tokenizer_name: str = "gpt2"
523
+
524
+ max_seq_len: int = 256
525
+ stride_frac_val: float = 0.50
526
+ seed: int = 1337
527
+ micro_batch_size: int = 2
528
+ grad_accum_steps: int = 8
529
+ # Sample budgets
530
+ train_samples_target: int = 100_000_000
531
+ val_samples_target: int = 25_000
532
+
533
+ # Training
534
+ batch_size: int = 64
535
+ learning_rate: float = 3e-4
536
+ weight_decay: float = 0.01
537
+ betas: Tuple[float, float] = (0.9, 0.95)
538
+ grad_clip: float = 1.0
539
+ warmup_steps: int = 1_000
540
+ total_steps: int = 75_000
541
+ eval_interval: int = 1_000
542
+ log_interval: int = 100
543
+
544
+ # Model
545
+ vocab_size: int = 50257
546
+ embed_dim: int = 384
547
+ num_layers: int = 23
548
+ num_heads: int = 8
549
+ num_slots: int = 32
550
+ mlp_ratio: float = 4.0
551
+ dropout: float = 0.1
552
+ tie_weights: bool = True
553
+
554
+ # Addressed State Attention (ASA) / numerics
555
+ read_temperature: float = 1.0
556
+ write_temperature: float = 1.0
557
+ slot_dropout: float = 0.05
558
+ state_fp32: bool = True
559
+ normalize_k: bool = False
560
+
561
+ # Positions
562
+ use_abs_pos: bool = False
563
+ use_rope_keys: bool = True
564
+ rope_base: float = 10000.0
565
+ use_alibi_write: bool = True
566
+ alibi_strength_init: float = 0.1
567
+ learn_alibi_strength: bool = True
568
+ min_strength: float = 0.0
569
+
570
+ # Content-conditioned read term (gamma)
571
+ use_content_read: bool = True
572
+ content_read_init: float = -4.0
573
+ content_read_max_gamma: float = 3.0
574
+
575
+ # Optional slot-space refinement (formerly "k-space")
576
+ use_slotspace_refine: bool = True
577
+ slotspace_dim: int = 64
578
+ slotspace_gate_init: float = -4.0
579
+ slotspace_dropout: float = 0.05
580
+ slotspace_signed_weights: bool = True
581
+
582
+ # RoPE inside slot-space matcher (Q/K only)
583
+ use_rope_slotspace: bool = True
584
+ rope_base_slotspace: float = 100000.0
585
+
586
+ # Perf knobs (behavior-identical)
587
+ write_chunk_size: int = 128
588
+ enable_compiled: bool = True
589
+
590
+ # Analytics
591
+ eval_max_batches: int = 150
592
+ analytics_last_k: int = 4
593
+
594
+ # IO / caches
595
+ output_dir: str = "./drive/MyDrive/asm_outputs"
596
+ tag: str = "asm_wikitext"
597
+ cache_dir: str = "./drive/MyDrive/asm_caches/fineweb/1B"
598
+ val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
599
+
600
+
601
+ # ============================================================================
602
+ # Block
603
+ # ============================================================================
604
+ class ASMBlock(nn.Module):
605
+ def __init__(
606
+ self,
607
+ embed_dim: int,
608
+ num_heads: int,
609
+ num_slots: int,
610
+ mlp_ratio: float = 4.0,
611
+ dropout: float = 0.1,
612
+
613
+ # temperatures / numerics
614
+ read_temperature: float = 1.0,
615
+ write_temperature: float = 1.0,
616
+ state_fp32: bool = True,
617
+ slot_dropout: float = 0.0,
618
+ normalize_k: bool = False,
619
+
620
+ # positions
621
+ use_rope_keys: bool = True,
622
+ rope_base: float = 10000.0,
623
+ use_alibi_write: bool = True,
624
+
625
+ # ALiBi params
626
+ alibi_strength_init: float = 0.1,
627
+ learn_alibi_strength: bool = True,
628
+ min_strength: float = 0.0,
629
+
630
+ # content-conditioned read (gamma)
631
+ use_content_read: bool = True,
632
+ content_read_init: float = -4.0,
633
+ content_read_max_gamma: float = 3.0,
634
+
635
+ # optional slot-space refinement
636
+ use_slotspace_refine: bool = True,
637
+ slotspace_dim: int = 32,
638
+ slotspace_gate_init: float = -10.0,
639
+ slotspace_dropout: float = 0.0,
640
+ slotspace_signed_weights: bool = True,
641
+
642
+ # RoPE inside slot-space matcher
643
+ use_rope_slotspace: bool = True,
644
+ rope_base_slotspace: float = 100000.0,
645
+
646
+ # chunk sizes
647
+ write_chunk_size: int = 128,
648
+ enable_compiled: bool = False,
649
+ ):
650
+ super().__init__()
651
+ self.norm1 = nn.LayerNorm(embed_dim)
652
+
653
+ self.asa = AddressedStateAttention(
654
+ embed_dim=embed_dim,
655
+ num_heads=num_heads,
656
+ num_slots=num_slots,
657
+ dropout=dropout,
658
+
659
+ read_temperature=read_temperature,
660
+ write_temperature=write_temperature,
661
+ state_fp32=state_fp32,
662
+ slot_dropout=slot_dropout,
663
+ normalize_k=normalize_k,
664
+
665
+ use_rope_keys=use_rope_keys,
666
+ rope_base=rope_base,
667
+ use_alibi_write=use_alibi_write,
668
+ alibi_strength_init=alibi_strength_init,
669
+ learn_alibi_strength=learn_alibi_strength,
670
+ min_strength=min_strength,
671
+
672
+ use_content_read=use_content_read,
673
+ content_read_init=content_read_init,
674
+ content_read_max_gamma=content_read_max_gamma,
675
+
676
+ use_slotspace_refine=use_slotspace_refine,
677
+ slotspace_dim=slotspace_dim,
678
+ slotspace_gate_init=slotspace_gate_init,
679
+ slotspace_dropout=slotspace_dropout,
680
+ slotspace_signed_weights=slotspace_signed_weights,
681
+
682
+ use_rope_slotspace=use_rope_slotspace,
683
+ rope_base_slotspace=rope_base_slotspace,
684
+
685
+ write_chunk_size=write_chunk_size,
686
+ enable_compiled=enable_compiled,
687
+
688
+ )
689
+
690
+ self.norm2 = nn.LayerNorm(embed_dim)
691
+ hidden = int(embed_dim * mlp_ratio)
692
+ self.mlp = nn.Sequential(
693
+ nn.Linear(embed_dim, hidden, bias=False),
694
+ nn.GELU(),
695
+ nn.Dropout(dropout),
696
+ nn.Linear(hidden, embed_dim, bias=False),
697
+ nn.Dropout(dropout),
698
+ )
699
+
700
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_info: bool = False, return_light_stats: Optional[bool] = None):
701
+ a, info = self.asa(self.norm1(x), attention_mask=attention_mask, return_info=return_info, return_light_stats=return_light_stats)
702
+ x = x + a
703
+ x = x + self.mlp(self.norm2(x))
704
+ return x, info
705
+
706
+
707
+ # ============================================================================
708
+ # LM
709
+ # ============================================================================
710
+ class ASMLanguageModel(nn.Module):
711
+ def __init__(
712
+ self,
713
+ vocab_size: int,
714
+ embed_dim: int = 384,
715
+ num_layers: int = 6,
716
+ num_heads: int = 8,
717
+ num_slots: int = 8,
718
+ max_seq_len: int = 1024,
719
+ mlp_ratio: float = 4.0,
720
+ dropout: float = 0.1,
721
+
722
+ # temperatures / numerics
723
+ read_temperature: float = 1.0,
724
+ write_temperature: float = 1.0,
725
+ state_fp32: bool = True,
726
+ slot_dropout: float = 0.05,
727
+ normalize_k: bool = False,
728
+
729
+ tie_weights: bool = True,
730
+
731
+ # LM-level abs pos
732
+ use_abs_pos: bool = False,
733
+
734
+ # positions
735
+ use_rope_keys: bool = True,
736
+ rope_base: float = 10000.0,
737
+ use_alibi_write: bool = True,
738
+
739
+ # ALiBi
740
+ alibi_strength_init: float = 0.1,
741
+ learn_alibi_strength: bool = True,
742
+ min_strength: float = 0.0,
743
+
744
+ # content-conditioned read (gamma)
745
+ use_content_read: bool = True,
746
+ content_read_init: float = -4.0,
747
+ content_read_max_gamma: float = 3.0,
748
+
749
+ # optional slot-space refinement
750
+ use_slotspace_refine: bool = True,
751
+ slotspace_dim: int = 32,
752
+ slotspace_gate_init: float = -10.0,
753
+ slotspace_dropout: float = 0.0,
754
+ slotspace_signed_weights: bool = True,
755
+
756
+ # RoPE inside slot-space matcher
757
+ use_rope_slotspace: bool = True,
758
+ rope_base_slotspace: float = 100000.0,
759
+
760
+ # chunk sizes
761
+ write_chunk_size: int = 128,
762
+ enable_compiled: bool = False,
763
+ ):
764
+ super().__init__()
765
+ self.vocab_size = vocab_size
766
+ self.embed_dim = embed_dim
767
+ self.max_seq_len = max_seq_len
768
+ self.use_abs_pos = bool(use_abs_pos)
769
+
770
+ self.tok = nn.Embedding(vocab_size, embed_dim)
771
+ self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
772
+ self.drop = nn.Dropout(dropout)
773
+
774
+ self.blocks = nn.ModuleList([
775
+ ASMBlock(
776
+ embed_dim=embed_dim,
777
+ num_heads=num_heads,
778
+ num_slots=num_slots,
779
+ mlp_ratio=mlp_ratio,
780
+ dropout=dropout,
781
+
782
+ read_temperature=read_temperature,
783
+ write_temperature=write_temperature,
784
+ state_fp32=state_fp32,
785
+ slot_dropout=slot_dropout,
786
+ normalize_k=normalize_k,
787
+
788
+ use_rope_keys=use_rope_keys,
789
+ rope_base=rope_base,
790
+ use_alibi_write=use_alibi_write,
791
+
792
+ alibi_strength_init=alibi_strength_init,
793
+ learn_alibi_strength=learn_alibi_strength,
794
+ min_strength=min_strength,
795
+
796
+ use_content_read=use_content_read,
797
+ content_read_init=content_read_init,
798
+ content_read_max_gamma=content_read_max_gamma,
799
+
800
+ use_slotspace_refine=use_slotspace_refine,
801
+ slotspace_dim=slotspace_dim, slotspace_gate_init=slotspace_gate_init,
802
+ slotspace_dropout=slotspace_dropout,
803
+ slotspace_signed_weights=slotspace_signed_weights,
804
+ use_rope_slotspace=use_rope_slotspace,
805
+ rope_base_slotspace=rope_base_slotspace,
806
+
807
+ write_chunk_size=write_chunk_size,
808
+ enable_compiled=enable_compiled,
809
+ )
810
+ for _ in range(num_layers)
811
+ ])
812
+
813
+ self.norm = nn.LayerNorm(embed_dim)
814
+ self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
815
+ if tie_weights:
816
+ self.lm_head.weight = self.tok.weight
817
+
818
+ self.apply(self._init)
819
+
820
+ def _init(self, m):
821
+ if isinstance(m, nn.Linear):
822
+ nn.init.normal_(m.weight, std=0.02)
823
+ elif isinstance(m, nn.Embedding):
824
+ nn.init.normal_(m.weight, std=0.02)
825
+ elif isinstance(m, nn.LayerNorm):
826
+ nn.init.ones_(m.weight)
827
+ nn.init.zeros_(m.bias)
828
+
829
+ def forward(
830
+ self,
831
+ input_ids: torch.Tensor,
832
+ attention_mask: Optional[torch.Tensor] = None,
833
+ return_info: bool = False,
834
+ return_light_stats: Optional[bool] = None,
835
+ ):
836
+ B, T = input_ids.shape
837
+ assert T <= self.max_seq_len, f"T={T} exceeds max_seq_len={self.max_seq_len}"
838
+
839
+ x = self.tok(input_ids)
840
+ if self.use_abs_pos:
841
+ pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
842
+ x = x + self.pos(pos)
843
+
844
+ x = self.drop(x)
845
+
846
+ infos = []
847
+ for blk in self.blocks:
848
+ x, info = blk(x, attention_mask=attention_mask, return_info=return_info, return_light_stats=return_light_stats)
849
+ if return_info:
850
+ infos.append(info)
851
+
852
+ x = self.norm(x)
853
+ logits = self.lm_head(x)
854
+ return (logits, infos) if return_info else logits
855
+
856
+
857
+ # ============================================================================
858
+ # Convenience: build model from config
859
+ # ============================================================================
860
+ def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
861
+ return ASMLanguageModel(
862
+ vocab_size=cfg.vocab_size,
863
+ embed_dim=cfg.embed_dim,
864
+ num_layers=cfg.num_layers,
865
+ num_heads=cfg.num_heads,
866
+ num_slots=cfg.num_slots,
867
+ max_seq_len=cfg.max_seq_len,
868
+ mlp_ratio=cfg.mlp_ratio,
869
+ dropout=cfg.dropout,
870
+
871
+ read_temperature=cfg.read_temperature,
872
+ write_temperature=cfg.write_temperature,
873
+ state_fp32=cfg.state_fp32,
874
+ slot_dropout=cfg.slot_dropout,
875
+ normalize_k=cfg.normalize_k,
876
+
877
+ tie_weights=cfg.tie_weights,
878
+
879
+ use_abs_pos=cfg.use_abs_pos,
880
+ use_rope_keys=cfg.use_rope_keys,
881
+ rope_base=cfg.rope_base,
882
+ use_alibi_write=cfg.use_alibi_write,
883
+
884
+ alibi_strength_init=cfg.alibi_strength_init,
885
+ learn_alibi_strength=cfg.learn_alibi_strength,
886
+ min_strength=cfg.min_strength,
887
+
888
+ use_content_read=cfg.use_content_read,
889
+ content_read_init=cfg.content_read_init,
890
+ content_read_max_gamma=cfg.content_read_max_gamma,
891
+
892
+ use_slotspace_refine=cfg.use_slotspace_refine,
893
+ slotspace_dim=cfg.slotspace_dim,
894
+ slotspace_gate_init=cfg.slotspace_gate_init,
895
+ slotspace_dropout=cfg.slotspace_dropout,
896
+ slotspace_signed_weights=cfg.slotspace_signed_weights,
897
+ use_rope_slotspace=cfg.use_rope_slotspace,
898
+ rope_base_slotspace=cfg.rope_base_slotspace,
899
+
900
+ write_chunk_size=cfg.write_chunk_size,
901
+ enable_compiled=cfg.enable_compiled,
902
+ )