OzTianlu commited on
Commit
a6cf12a
·
verified ·
1 Parent(s): b6c0790

Upload 14 files

Browse files
Files changed (10) hide show
  1. .gitattributes +1 -0
  2. ACC_SPAR.png +0 -0
  3. ARCH.png +3 -0
  4. LOSS_SPAR.png +0 -0
  5. MonoidForCausalLM.py +85 -66
  6. README.md +166 -40
  7. config.json +1 -0
  8. model.safetensors +2 -2
  9. monoid_scan_cuda.py +61 -63
  10. training_args.bin +2 -2
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ ARCH.png filter=lfs diff=lfs merge=lfs -text
ACC_SPAR.png ADDED
ARCH.png ADDED

Git LFS Details

  • SHA256: 4a9331e05338296049fe3e87e223538cd291f89234a9da8865d8febaf38ae2c2
  • Pointer size: 131 Bytes
  • Size of remote file: 668 kB
LOSS_SPAR.png ADDED
MonoidForCausalLM.py CHANGED
@@ -23,11 +23,11 @@ Architecture / 架构概要:
23
  其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
24
 
25
  This is a monoid because the binary operator:
26
- (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
27
  is associative → enables parallel prefix scan for training,
28
  and O(1) sequential update for inference.
29
  这是一个幺半群,因为二元算子:
30
- (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
31
  满足结合律 → 训练时可用并行前缀扫描,推理时 O(1) 逐步递推。
32
 
33
  Key properties / 关键特性:
@@ -70,33 +70,33 @@ except ImportError:
70
  # Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device.
71
  # Slower than the fused CUDA kernel but numerically identical.
72
 
73
- def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
74
- """Sequential prefix scan fallback: S_t[i,:] = exp(log_α_t[i])·S_{t-1}[i,:] + kv_t[i,:]."""
75
  B, H, T, d1, d2 = kv.shape
76
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
77
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
78
  for t in range(T):
79
- decay = torch.exp(log_alpha[:, :, t]) # [B, H, d]
80
  while decay.dim() < S.dim():
81
  decay = decay.unsqueeze(-1)
82
  S = S * decay + kv[:, :, t]
83
  states[:, :, t] = S
84
  return states
85
 
86
- def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor):
87
- """Sequential prefix scan that also returns the final (log_decay, S) state."""
88
  B, H, T, d1, d2 = kv.shape
89
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
90
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
91
- log_acc = torch.zeros(B, H, d1, device=log_alpha.device, dtype=log_alpha.dtype)
92
  for t in range(T):
93
- decay = torch.exp(log_alpha[:, :, t])
94
  while decay.dim() < S.dim():
95
  decay = decay.unsqueeze(-1)
96
  S = S * decay + kv[:, :, t]
97
  states[:, :, t] = S
98
- log_acc = log_acc + log_alpha[:, :, t]
99
- return states, (log_acc, S)
100
 
101
 
102
 
@@ -169,14 +169,14 @@ class MonoidCache:
169
 
170
  Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory),
171
  each layer here stores exactly ONE state tuple:
172
- (log_decay_acc, S) where S ∈ ℝ^{B, H, d, d}
173
- This is the monoid "sum" of all past (log_α_i, k_i⊗v_i) via ⊕.
174
  Memory is O(1) per layer regardless of sequence length.
175
 
176
  不同于 Transformer 的 KV-Cache (存储所有过去的 key 和 value, O(T) 内存),
177
  这里每层仅存储一个状态元组:
178
- (log_decay_acc, S) 其中 S ∈ ℝ^{B, H, d, d}
179
- 这是所有过去的 (log_α_i, k_i⊗v_i) 通过 ⊕ 累积的幺半群 "和"。
180
  无论序列多长,每层内存 O(1)。
181
  """
182
 
@@ -219,12 +219,12 @@ def monoid_op(
219
  b: tuple[Tensor, Tensor],
220
  ) -> tuple[Tensor, Tensor]:
221
  """
222
- The monoid binary operator ⊕ on (log-space vector decay, state matrix) pairs.
223
- 幺半群二元算子 ⊕,作用于 (对数向量衰减, 状态矩阵) 对。
224
 
225
  Definition / 定义:
226
- (log_α, S) ⊕ (log_β, X) = (log_α + log_β, diag(exp(log_β))·S + X)
227
- where log_α, log_β ∈ ℝ^d are per-dimension log decay vectors.
228
 
229
  Why this is a monoid / 为什么这是幺半群:
230
  • Associativity / 结合律:
@@ -235,12 +235,7 @@ def monoid_op(
235
  推理时可以 O(1) 左折叠 (逐步追加)。
236
 
237
  • Identity / 单位元:
238
- e = (0, 0) → e ⊕ a = a ⊕ e = a ✓
239
-
240
- Why log-space / 为什么用对数空间:
241
- Working in log-space for the decay factor avoids numerical
242
- underflow when α^T → 0 for long sequences.
243
- 衰减因子在对数空间中运算,避免长序列下 α^T → 0 的数值下溢。
244
 
245
  Causal semantics / 因果语义:
246
  S_t = α_t · S_{t-1} + k_t ⊗ v_t
@@ -251,15 +246,14 @@ def monoid_op(
251
  这就是 *显式因果建模* — 模型必须在每个时间步学习如何
252
  平衡保留旧信息与吸收新信息。
253
  """
254
- log_a, kv_a = a
255
- log_b, kv_b = b
256
 
257
- new_log = log_a + log_b # log(α·β) = log_α + log_β
258
- decay_b = torch.exp(log_b) # β = exp(log_β)
259
  while decay_b.dim() < kv_a.dim():
260
- decay_b = decay_b.unsqueeze(-1) # broadcast to [B,H,...,1,1]
261
 
262
- return new_log, kv_a * decay_b + kv_b # β·S + X
263
 
264
 
265
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@@ -327,6 +321,17 @@ class MonoidAttention(nn.Module):
327
  self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
328
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
329
 
 
 
 
 
 
 
 
 
 
 
 
330
  # --- Decay gate (novel component, randomly initialized) ---
331
  # --- 衰减门 (全新组件, 随机初始化) ---
332
  # Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
@@ -351,6 +356,7 @@ class MonoidAttention(nn.Module):
351
  # 可能无界增长。
352
  self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
353
  self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
354
 
355
  # --- Learnable initial state h0 (novel component, zero-initialized) ---
356
  # --- 可学习初始状态 h0 (全新组件, 零初始化) ---
@@ -394,6 +400,10 @@ class MonoidAttention(nn.Module):
394
  k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
395
  v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
396
 
 
 
 
 
397
  # --- QK-Norm: stabilize q·S readout scale ---
398
  # --- QK 归一化: 稳定 q·S 读出尺度 ---
399
  q = self.q_norm(q) * self.scaling
@@ -413,25 +423,22 @@ class MonoidAttention(nn.Module):
413
 
414
  # --- Compute per-dimension vector decay gate α_t ---
415
  # --- 计算每维度向量衰减门 α_t ---
416
- # Negative Softplus: log_α = -softplus(Wx + b)
417
- # Value range: log_α ∈ (-∞, 0), i.e. α ∈ (0, 1].
418
- # When Wx → -∞: softplus → 0, α → 1 (perfect memory, no forgetting)
419
- # When Wx → +∞: softplus ≈ Wx, α → 0 (complete forgetting)
420
- # This avoids α > 1 explosion (unlike SiLU) while still allowing
421
- # α = 1 for lossless memory (unlike Sigmoid which caps at <1).
422
  # Each dimension of the d-vector decays independently:
423
  # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
424
  #
425
- # 负 Softplus: log_α = -softplus(Wx + b)
426
- # 值域: log_α ∈ (-∞, 0), 即 α ∈ (0, 1]
427
- # 当 Wx → -∞: softplus → 0, α → 1 (完美记忆, 不遗忘)
428
- # 当 Wx → +∞: softplus ≈ Wx, α → 0 (完全遗忘)
429
- # 避免了 SiLU 的 α > 1 爆炸, 同时允许 α = 1 无损记忆 (Sigmoid 无法做到)。
430
  # d-向量的每个维度独立衰减:
431
  # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
432
  raw = self.decay_proj(hidden_states) # [B,T,H*d]
433
- log_alpha = -torch.nn.functional.softplus(raw) # [B,T,H*d]
434
- log_alpha = log_alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
435
 
436
  # --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
437
  # --- 应用注意力掩码: PAD token 必须对递推不可见 ---
@@ -441,10 +448,10 @@ class MonoidAttention(nn.Module):
441
  # 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
442
  if attention_mask is not None:
443
  # attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
444
- mask = attention_mask[:, None, :, None].to(log_alpha.dtype) # [B,1,T,1]
445
- log_alpha = log_alpha * mask # PAD → log_α=0 → α=1
446
- k = k * mask # PAD → k=0
447
- v = v * mask # PAD → v=0 → kv=0
448
 
449
  # ══════════════════════════════════════════════════════════
450
  # Inference path (RNN mode): O(1) per token per layer
@@ -466,20 +473,20 @@ class MonoidAttention(nn.Module):
466
  # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
467
  # 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
468
  kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
469
- log_t = log_alpha[:, :, 0] # [B,H,d]
470
 
471
  prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
472
  if prev is None:
473
  # First token: initialize from learnable h0
474
  # 第一个 token: 从可学习的 h0 初始化
475
- decay_t = torch.exp(log_t)
476
  while decay_t.dim() < self.h0.dim():
477
  decay_t = decay_t.unsqueeze(-1)
478
- new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
479
  else:
480
  # Subsequent tokens: fold via monoid_op — O(1)!
481
  # 后续 token: 通过 monoid_op 折叠 — O(1)!
482
- new_state = monoid_op(prev, (log_t, kv_t))
483
 
484
  if monoid_cache is not None:
485
  monoid_cache.update(self.layer_idx, new_state)
@@ -487,10 +494,11 @@ class MonoidAttention(nn.Module):
487
  # Readout: o_t = q_t · S_t
488
  # 读出: o_t = q_t · S_t
489
  o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
 
490
  # Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching scan path)
491
  # 重塑 [B,H,d] → [B,1,H*d] (头连续排列, 与扫描路径一致)
492
  o = o.contiguous().view(B, 1, -1)
493
- return self.o_proj(o), new_state
494
 
495
  # ══════════════════════════════════════════════════════════
496
  # Inference prefill (use_cache=True, T>1): parallel scan + readout
@@ -504,20 +512,20 @@ class MonoidAttention(nn.Module):
504
  # 内存: O(B·H·T·d²) — 与训练路径相同。
505
  if use_cache:
506
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
507
- states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
508
 
509
  # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
510
  # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
511
- cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
512
- h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
513
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
514
 
515
  # Final state includes h0 contribution
516
  # 最终状态包含 h0 贡献
517
- total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,d,1]
518
  S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
519
  # h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
520
- final_state = (log_acc, S_final)
521
 
522
  if monoid_cache is not None:
523
  monoid_cache.update(self.layer_idx, final_state)
@@ -525,8 +533,9 @@ class MonoidAttention(nn.Module):
525
  # Vectorized readout: o_t = q_t · S_t for all t
526
  # 向量化读出: 一次性计算所有 t 的 o_t = q_t · S_t
527
  o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
 
528
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
529
- return self.o_proj(o), final_state
530
 
531
  # ══════════════════════════════════════════════════════════
532
  # Training path: parallel scan + vectorized readout
@@ -548,22 +557,23 @@ class MonoidAttention(nn.Module):
548
 
549
  # Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
550
  # 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
551
- # log_alpha is [B,H,T,d] — vector decay per dimension.
552
- # log_alpha 为 [B,H,T,d] — 每维度向量衰减。
553
- states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
554
 
555
  # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
556
  # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
557
- cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
558
- h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
559
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
560
 
561
  # Vectorized readout: o_t = q_t · S_t for all t at once
562
  # 向量化读出: 一次性计算所有 t 的 q_t · S_t
563
  o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
 
564
 
565
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
566
- return self.o_proj(o), None
567
 
568
 
569
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@@ -640,7 +650,16 @@ class MonoidPreTrainedModel(PreTrainedModel):
640
  module.weight.data[module.padding_idx].zero_()
641
 
642
  if isinstance(module, MonoidAttention):
643
- nn.init.constant_(module.decay_proj.bias, 1.0)
 
 
 
 
 
 
 
 
 
644
 
645
  class MonoidModel(MonoidPreTrainedModel):
646
  """
 
23
  其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
24
 
25
  This is a monoid because the binary operator:
26
+ (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
27
  is associative → enables parallel prefix scan for training,
28
  and O(1) sequential update for inference.
29
  这是一个幺半群,因为二元算子:
30
+ (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
31
  满足结合律 → 训练时可用并行前缀扫描,推理时 O(1) 逐步递推。
32
 
33
  Key properties / 关键特性:
 
70
  # Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device.
71
  # Slower than the fused CUDA kernel but numerically identical.
72
 
73
+ def parallel_scan(alpha: Tensor, kv: Tensor) -> Tensor:
74
+ """Sequential prefix scan fallback: S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]."""
75
  B, H, T, d1, d2 = kv.shape
76
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
77
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
78
  for t in range(T):
79
+ decay = alpha[:, :, t] # [B, H, d]
80
  while decay.dim() < S.dim():
81
  decay = decay.unsqueeze(-1)
82
  S = S * decay + kv[:, :, t]
83
  states[:, :, t] = S
84
  return states
85
 
86
+ def parallel_scan_with_state(alpha: Tensor, kv: Tensor):
87
+ """Sequential prefix scan that also returns the final (decay_acc, S) state."""
88
  B, H, T, d1, d2 = kv.shape
89
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
90
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
91
+ decay_acc = torch.ones(B, H, d1, device=alpha.device, dtype=alpha.dtype)
92
  for t in range(T):
93
+ decay = alpha[:, :, t]
94
  while decay.dim() < S.dim():
95
  decay = decay.unsqueeze(-1)
96
  S = S * decay + kv[:, :, t]
97
  states[:, :, t] = S
98
+ decay_acc = decay_acc * alpha[:, :, t]
99
+ return states, (decay_acc, S)
100
 
101
 
102
 
 
169
 
170
  Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory),
171
  each layer here stores exactly ONE state tuple:
172
+ (decay_acc, S) where S ∈ ℝ^{B, H, d, d}
173
+ This is the monoid "sum" of all past (α_i, k_i⊗v_i) via ⊕.
174
  Memory is O(1) per layer regardless of sequence length.
175
 
176
  不同于 Transformer 的 KV-Cache (存储所有过去的 key 和 value, O(T) 内存),
177
  这里每层仅存储一个状态元组:
178
+ (decay_acc, S) 其中 S ∈ ℝ^{B, H, d, d}
179
+ 这是所有过去的 (α_i, k_i⊗v_i) 通过 ⊕ 累积的幺半群 "和"。
180
  无论序列多长,每层内存 O(1)。
181
  """
182
 
 
219
  b: tuple[Tensor, Tensor],
220
  ) -> tuple[Tensor, Tensor]:
221
  """
222
+ The monoid binary operator ⊕ on (vector decay, state matrix) pairs.
223
+ 幺半群二元算子 ⊕,作用于 (向量衰减, 状态矩阵) 对。
224
 
225
  Definition / 定义:
226
+ (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
227
+ where α, β ∈ (0,1)^d are per-dimension vector decay gates (sigmoid output).
228
 
229
  Why this is a monoid / 为什么这是幺半群:
230
  • Associativity / 结合律:
 
235
  推理时可以 O(1) 左折叠 (逐步追加)。
236
 
237
  • Identity / 单位元:
238
+ e = (1, 0) → e ⊕ a = a ⊕ e = a ✓
 
 
 
 
 
239
 
240
  Causal semantics / 因果语义:
241
  S_t = α_t · S_{t-1} + k_t ⊗ v_t
 
246
  这就是 *显式因果建模* — 模型必须在每个时间步学习如何
247
  平衡保留旧信息与吸收新信息。
248
  """
249
+ decay_a, kv_a = a
250
+ decay_b, kv_b = b
251
 
252
+ new_decay = decay_a * decay_b # α·β (element-wise product)
 
253
  while decay_b.dim() < kv_a.dim():
254
+ decay_b = decay_b.unsqueeze(-1) # broadcast to [B,H,...,1,1]
255
 
256
+ return new_decay, kv_a * decay_b + kv_b # β·S + X
257
 
258
 
259
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
321
  self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
322
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
323
 
324
+ # --- Output gate (novel component, randomly initialized) ---
325
+ # --- 输出门控 (全新组件, 随机初始化) ---
326
+ # Modulates the multi-head readout before o_proj, similar to GLA/RetNet.
327
+ # gate = SiLU(gate_proj(x)), output = gate ⊙ concat_heads(o)
328
+ # This lets the model suppress or amplify specific head outputs
329
+ # conditioned on the current input, increasing expressiveness.
330
+ # 在 o_proj 之前调制多头读出, 类似 GLA/RetNet。
331
+ # gate = SiLU(gate_proj(x)), output = gate ⊙ concat_heads(o)
332
+ # 使模型能根据当前输入抑制或放大特定头的输出, 增加表达力。
333
+ self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
334
+
335
  # --- Decay gate (novel component, randomly initialized) ---
336
  # --- 衰减门 (全新组件, 随机初始化) ---
337
  # Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
 
356
  # 可能无界增长。
357
  self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
358
  self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
359
+ self.o_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
360
 
361
  # --- Learnable initial state h0 (novel component, zero-initialized) ---
362
  # --- 可学习初始状态 h0 (全新组件, 零初始化) ---
 
400
  k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
401
  v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
402
 
403
+ # --- Output gate: computed from input, applied before o_proj ---
404
+ # --- 输出门控: 从输入计算, 在 o_proj 之前应用 ---
405
+ gate = torch.nn.functional.silu(self.gate_proj(hidden_states)) # [B,T,H*d]
406
+
407
  # --- QK-Norm: stabilize q·S readout scale ---
408
  # --- QK 归一化: 稳定 q·S 读出尺度 ---
409
  q = self.q_norm(q) * self.scaling
 
423
 
424
  # --- Compute per-dimension vector decay gate α_t ---
425
  # --- 计算每维度向量衰减门 α_t ---
426
+ # Sigmoid: α = σ(Wx + b)
427
+ # Value range: α ∈ (0, 1).
428
+ # When Wx → -∞: σ → 0 (complete forgetting)
429
+ # When Wx → +∞: σ → 1 (perfect memory, no forgetting)
 
 
430
  # Each dimension of the d-vector decays independently:
431
  # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
432
  #
433
+ # Sigmoid: α = σ(Wx + b)
434
+ # 值域: α ∈ (0, 1)
435
+ # 当 Wx → -∞: σ → 0 (完全遗忘)
436
+ # 当 Wx → +∞: σ → 1 (完美记忆, 不遗忘)
 
437
  # d-向量的每个维度独立衰减:
438
  # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
439
  raw = self.decay_proj(hidden_states) # [B,T,H*d]
440
+ alpha = torch.sigmoid(raw) # [B,T,H*d]
441
+ alpha = alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
442
 
443
  # --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
444
  # --- 应用注意力掩码: PAD token 必须对递推不可见 ---
 
448
  # 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
449
  if attention_mask is not None:
450
  # attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
451
+ mask = attention_mask[:, None, :, None].to(alpha.dtype) # [B,1,T,1]
452
+ alpha = alpha * mask + (1 - mask) # PAD → α=1 (preserve state)
453
+ k = k * mask # PAD → k=0
454
+ v = v * mask # PAD → v=0 → kv=0
455
 
456
  # ══════════════════════════════════════════════════════════
457
  # Inference path (RNN mode): O(1) per token per layer
 
473
  # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
474
  # 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
475
  kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
476
+ alpha_t = alpha[:, :, 0] # [B,H,d]
477
 
478
  prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
479
  if prev is None:
480
  # First token: initialize from learnable h0
481
  # 第一个 token: 从可学习的 h0 初始化
482
+ decay_t = alpha_t
483
  while decay_t.dim() < self.h0.dim():
484
  decay_t = decay_t.unsqueeze(-1)
485
+ new_state = (alpha_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
486
  else:
487
  # Subsequent tokens: fold via monoid_op — O(1)!
488
  # 后续 token: 通过 monoid_op 折叠 — O(1)!
489
+ new_state = monoid_op(prev, (alpha_t, kv_t))
490
 
491
  if monoid_cache is not None:
492
  monoid_cache.update(self.layer_idx, new_state)
 
494
  # Readout: o_t = q_t · S_t
495
  # 读出: o_t = q_t · S_t
496
  o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
497
+ o = self.o_norm(o)
498
  # Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching scan path)
499
  # 重塑 [B,H,d] → [B,1,H*d] (头连续排列, 与扫描路径一致)
500
  o = o.contiguous().view(B, 1, -1)
501
+ return self.o_proj(gate * o), new_state
502
 
503
  # ══════════════════════════════════════════════════════════
504
  # Inference prefill (use_cache=True, T>1): parallel scan + readout
 
512
  # 内存: O(B·H·T·d²) — 与训练路径相同。
513
  if use_cache:
514
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
515
+ states, (decay_acc, S_T) = parallel_scan_with_state(alpha, kv)
516
 
517
  # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
518
  # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
519
+ cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2)) # [B,H,T,d]
520
+ h0_decay = cum_alpha.unsqueeze(-1) # [B,H,T,d,1]
521
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
522
 
523
  # Final state includes h0 contribution
524
  # 最终状态包含 h0 贡献
525
+ total_h0_decay = decay_acc.unsqueeze(-1) # [B,H,d,1]
526
  S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
527
  # h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
528
+ final_state = (decay_acc, S_final)
529
 
530
  if monoid_cache is not None:
531
  monoid_cache.update(self.layer_idx, final_state)
 
533
  # Vectorized readout: o_t = q_t · S_t for all t
534
  # 向量化读出: 一次性计算所有 t 的 o_t = q_t · S_t
535
  o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
536
+ o = self.o_norm(o)
537
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
538
+ return self.o_proj(gate * o), final_state
539
 
540
  # ══════════════════════════════════════════════════════════
541
  # Training path: parallel scan + vectorized readout
 
557
 
558
  # Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
559
  # 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
560
+ # alpha is [B,H,T,d] — vector decay per dimension.
561
+ # alpha 为 [B,H,T,d] — 每维度向量衰减。
562
+ states = parallel_scan(alpha, kv) # [B,H,T,d,d]
563
 
564
  # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
565
  # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
566
+ cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2)) # [B,H,T,d]
567
+ h0_decay = cum_alpha.unsqueeze(-1) # [B,H,T,d,1]
568
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
569
 
570
  # Vectorized readout: o_t = q_t · S_t for all t at once
571
  # 向量化读出: 一次性计算所有 t 的 q_t · S_t
572
  o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
573
+ o = self.o_norm(o)
574
 
575
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
576
+ return self.o_proj(gate * o), None
577
 
578
 
579
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
650
  module.weight.data[module.padding_idx].zero_()
651
 
652
  if isinstance(module, MonoidAttention):
653
+ # decay_proj: bias init so sigmoid(bias) ≈ 0.95 → mostly remembering at start
654
+ # decay_proj: 偏置初始化使 sigmoid(bias) ≈ 0.95 → 初始时以记忆为主
655
+ nn.init.constant_(module.decay_proj.bias, 3.0)
656
+ # gate_proj: small init so gate starts near identity (SiLU(0)=0,
657
+ # but normal weights give moderate gate values)
658
+ # gate_proj: 小初始化, 使门控从接近恒等开始
659
+ nn.init.normal_(module.gate_proj.weight, mean=0.0, std=0.01)
660
+ # o_norm: RMSNorm weight defaults to 1.0 (identity), explicit for clarity
661
+ # o_norm: RMSNorm 权重默认为 1.0 (恒等), 显式设置确保正确
662
+ nn.init.ones_(module.o_norm.weight)
663
 
664
  class MonoidModel(MonoidPreTrainedModel):
665
  """
README.md CHANGED
@@ -21,33 +21,153 @@ model-index:
21
 
22
  A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
23
 
24
- ## Monoid Attention — Internal Structure
25
-
26
- ```
27
- MonoidAttention (per layer, per head)
28
- ┌─────────────────────────────────────────────────────────────────────────┐
29
- │ │
30
- │ x_t ∈ R^{2048} │
31
- │ │ │
32
- │ ├──> q_proj ──> RMSNorm ──> q_t ∈ R^d (query, scaled 1/√d)
33
- │ │ │
34
- │ ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^d (key, non-negative) │
35
- │ │ │
36
- │ ├──> v_proj ──> v_t ∈ R^d (value) │
37
- │ │ │
38
- │ └──> decay_proj ──> -Softplus ──> log α_t ∈ R^d (vector decay gate)
39
- │ │
40
- │ k_t ⊗ v_t │
41
- │ │ ┌─────────────────────────────────┐ │
42
- │ │ │ State Matrix S_t ∈ R^{d x d} │ │
43
- │ v │ "Compressed causal history" │ │
44
- S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t │ │
45
- │ │ │ α_t ∈ (0,1]^d per dimension │ │
46
- │ └─────────────────────────────────┘ │
47
- │ v │
48
- o_t = q_t · S_t ──> o_proj ──> output │
49
- │ │
50
- └─────────────────────────────────────────────────────────────────────────┘
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ```
52
 
53
  ## Key Properties
@@ -75,11 +195,11 @@ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t — vector decay monoid recurrence
75
  o_t = q_t · S_t — state readout
76
  ```
77
 
78
- This is a monoid because the binary operator `(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)` is **associative**, enabling O(T) parallel prefix scan for training and O(1) sequential update for inference.
79
 
80
  ## Vector Decay — Per-Dimension Memory Lifetimes
81
 
82
- Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1]:
83
 
84
  ```
85
  S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
@@ -89,26 +209,27 @@ This allows different feature dimensions to specialize:
89
  - **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
90
  - **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
91
 
92
- The decay gate uses **Negative Softplus** activation:
93
 
94
  ```
95
- log α_t = -softplus(W·x_t + b)
96
  ```
97
 
98
  | Property | Value |
99
  |---|---|
100
- | Range | α ∈ (0, 1] — bounded, no explosion |
101
- | Perfect memory | W·x → -∞ ⟹ softplus → 0 ⟹ α → 1 (lossless retention) |
102
- | Full forgetting | W·x → +∞ ⟹ softplus → ∞ ⟹ α → 0 (complete reset) |
103
- | Stability | α ≤ 1 by construction — no divergence regardless of input magnitude |
 
104
 
105
  ## Attention Mask — Padding-Aware Recurrence
106
 
107
  The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
108
 
109
  ```
110
- log_α = 0 → α = 1 (preserve state unchanged)
111
- k = 0, v = 0 → kv = 0 (no information injected)
112
  ```
113
 
114
  Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
@@ -117,9 +238,11 @@ Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identi
117
 
118
  - **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix S positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
119
  - **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
120
- - **Log-space decay**: Working in log-space `log(α)` avoids numerical underflow when α^T → 0 for long sequences
 
 
121
  - **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
122
- - **Negative Softplus gate**: Ensures α ∈ (0, 1] by construction — allows perfect memory (α=1) while preventing state explosion (α>1)
123
 
124
  ## Three Forward Paths
125
 
@@ -141,7 +264,7 @@ Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identi
141
  | Layers | 16 |
142
  | Attention heads | 32 |
143
  | Head dimension | 64 |
144
- | Decay gate | Vector decay, d=64 per head |
145
  | State matrix per head | 64 × 64 = 4,096 floats |
146
  | Vocabulary | 128,256 (Llama-3.2 tokenizer) |
147
  | Precision | bfloat16 |
@@ -207,6 +330,9 @@ monoid_scan_cuda.py # Triton JIT parallel prefix scan (vector decay) + Py
207
  model.safetensors # Model weights (bfloat16)
208
  config.json # Model configuration
209
  tokenizer.json # Llama-3.2 tokenizer
 
 
 
210
  ```
211
 
212
  ## Citation
 
21
 
22
  A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
23
 
24
+ ## SFT Training Curves
25
+
26
+ | Loss | Accuracy |
27
+ |:---:|:---:|
28
+ | ![SFT Loss](LOSS_SPAR.png) | ![SFT Accuracy](ACC_SPAR.png) |
29
+
30
+ ## Core Mechanism
31
+
32
+ ![Core Mechanism: The Monoid Recurrence](ARCH.png)
33
+
34
+ ## Architecture Overview
35
+
36
+ ```
37
+ ╔═══════════════════════════════════════════════════════════════════════════╗
38
+ ║ MonoidForCausalLM (1.34B)
39
+ ╠═══════════════════════════════════════════════════════════════════════════╣
40
+ ║ ║
41
+ ║ token_ids ──> [ embed_tokens 128256 × 2048 ] ──> x_0 ║
42
+ ║ ║
43
+ ║ ┌─────────────────────────┐ ║
44
+ MonoidDecoderLayer × 16 ◄── see detail below ║
45
+ ║ └─────────────────────────┘ ║
46
+
47
+ ║ [ RMSNorm ] ║
48
+
49
+ ║ [ lm_head 2048 × 128256 ] ──> logits ║
50
+ ║ (tied with embed_tokens) ║
51
+ ╚═══════════════════════════════════════════════════════════════════════════╝
52
+
53
+
54
+ ╔═══════════════════════════════════════════════════════════════════════════╗
55
+ ║ MonoidDecoderLayer (× 16 layers) ║
56
+ ╠═══════════════════════════════════════════════════════════════════════════╣
57
+ ║ ║
58
+ ║ x ─────────────────────────────────────────┐ (residual) ║
59
+ ║ │ │ ║
60
+ ║ [ input_layernorm RMSNorm ] │ ║
61
+ ║ │ │ ║
62
+ ║ [ MonoidAttention ] ◄── see detail below │ ║
63
+ ║ │ │ ║
64
+ ║ + <────────────────────────────────────────┘ ║
65
+ ║ │ ║
66
+ ║ x ─────────────────────────────────────────┐ (residual) ║
67
+ ║ │ │ ║
68
+ ║ [ post_attention_layernorm RMSNorm ] │ ║
69
+ ║ │ │ ║
70
+ ║ [ LlamaMLP 2048 → 8192 → 2048 ] │ ║
71
+ ║ │ gate_proj ─┐ │ ║
72
+ ║ │ up_proj ───┤─> SiLU(gate) ⊙ up │ ║
73
+ ║ │ └──> down_proj ──> out │ ║
74
+ ║ │ │ ║
75
+ ║ + <────────────────────────────────────────┘ ║
76
+ ║ │ ║
77
+ ║ out ║
78
+ ╚═══════════════════════════════════════════════════════════════════════════╝
79
+
80
+
81
+ ╔═══════════════════════════════════════════════════════════════════════════╗
82
+ ║ MonoidAttention (32 heads, d=64 per head) ║
83
+ ╠═══════════════════════════════════════════════════════════════════════════╣
84
+ ║ ║
85
+ ║ x_t ∈ R^{2048} ║
86
+ ║ │ ║
87
+ ║ ├──> q_proj ──> [B,H,T,d] ──> RMSNorm ──> ×(1/√d) ──────> q_t ║
88
+ ║ │ ║
89
+ ║ ├──> k_proj ──> [B,H,T,d] ──> RMSNorm ──> SiLU ──────────> k_t ≥0 ║
90
+ ║ │ ║
91
+ ║ ├──> v_proj ──> [B,H,T,d] ────────────────────────────────> v_t ║
92
+ ║ │ ║
93
+ ║ ├──> decay_proj ──> Sigmoid ──> α_t ∈ (0,1)^d (vector decay gate) ║
94
+ ║ │ bias init = 3.0 ║
95
+ ║ │ → σ(3) ≈ 0.95 at start ║
96
+ ║ │ ║
97
+ ║ └──> gate_proj ──> SiLU ──────> g_t ∈ R^{H*d} (output gate) ║
98
+ ║ ║
99
+ ║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
100
+ ║ Monoid Recurrence (training: parallel prefix scan, decode: O(1)) ║
101
+ ║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
102
+ ║ ║
103
+ ║ k_t ⊗ v_t ──────────────┐ ║
104
+ ║ [d×d] v ║
105
+ ║ ┌─────────────────────────┐ ║
106
+ ║ S_{t-1} ────> │ S_t = diag(α_t)·S_{t-1}│ ║
107
+ ║ [d×d] │ + k_t ⊗ v_t │──> S_t ║
108
+ ║ └─────────────────────────┘ [d×d] ║
109
+ ║ "compressed causal history" ║
110
+ ║ ║
111
+ ║ h0 (learnable, zero-init) ──> S_0 at sequence start ║
112
+ ║ ║
113
+ ║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
114
+ ║ Readout + Output Projection ║
115
+ ║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
116
+ ║ ║
117
+ ║ q_t ──> einsum(q, S_t) ──> o_t ──> RMSNorm ──┐ ║
118
+ ║ (o_norm) │ ║
119
+ ║ v ║
120
+ ║ g_t ──────────────────────────────────> g_t ⊙ o_t ──> o_proj ──> out ║
121
+ ║ ║
122
+ ╚═══════════════════════════════════════════════════════════════════════════╝
123
+
124
+
125
+ ╔═══════════════════════════════════════════════════════════════════════════╗
126
+ ║ MonoidCache — O(1) State (replaces O(T) KV-Cache) ║
127
+ ╠═══════════════════════════════════════════════════════════════════════════╣
128
+ ║ ║
129
+ ║ Transformer KV-Cache: Monoid State Cache: ║
130
+ ║ ┌──────────────────┐ ┌──────────────────┐ ║
131
+ ║ │ K: [B,H,T,d] │ │ S: [B,H,d,d] │ ← fixed size ║
132
+ ║ │ V: [B,H,T,d] │ │ α_acc: [B,H,d] │ ║
133
+ ║ │ grows with T ↑↑↑ │ │ per layer │ ║
134
+ ║ └──────────────────┘ └──────────────────┘ ║
135
+ ║ Memory: O(T·H·d) Memory: O(H·d²) ║
136
+ ║ 1000 tok → 2M floats/layer ANY length → 131K floats/layer ║
137
+ ║ ║
138
+ ║ Decode step: Decode step: ║
139
+ ║ o = softmax(q·K^T)·V S_t = α_t·S_{t-1} + k_t⊗v_t ║
140
+ ║ scan T keys ↑ o_t = q_t · S_t ║
141
+ ║ Time: O(T·d) Time: O(d²) ← constant! ║
142
+ ╚═══════════════════════════════════════════════════════════════════════════╝
143
+
144
+
145
+ ╔═══════════════════════════════════════════════════════════════════════════╗
146
+ ║ Weight Transfer from Llama-3.2-1B-Instruct ║
147
+ ╠═══════════════════════════════════════════════════════════════════════════╣
148
+ ║ ║
149
+ ║ Reused directly (frozen-compatible): ║
150
+ ║ ┌──────────────────────────────────────────────┐ ║
151
+ ║ │ embed_tokens 128256 × 2048 │ ║
152
+ ║ │ lm_head 2048 × 128256 (tied) │ ║
153
+ ║ │ LlamaMLP × 16 gate/up/down_proj │ ║
154
+ ║ │ LlamaRMSNorm × 33 input/post_attn/final │ ║
155
+ ║ │ q_proj × 16 2048 → 2048 │ ║
156
+ ║ │ k_proj × 16 2048 → 2048 (tiled 8→32 heads from GQA) │ ║
157
+ ║ │ v_proj × 16 2048 → 2048 (tiled 8→32 heads from GQA) │ ║
158
+ ║ │ o_proj × 16 2048 → 2048 │ ║
159
+ ║ └──────────────────────────────────────────────┘ ║
160
+ ║ ║
161
+ ║ Novel (randomly initialized): ║
162
+ ║ ┌──────────────────────────────────────────────┐ ║
163
+ ║ │ decay_proj × 16 2048 → 2048 (bias=3.0) │ ║
164
+ ║ │ gate_proj × 16 2048 → 2048 (std=0.01) │ ║
165
+ ║ │ q_norm × 16 RMSNorm(64) │ ║
166
+ ║ │ k_norm × 16 RMSNorm(64) │ ║
167
+ ║ │ o_norm × 16 RMSNorm(64) (weight=1) │ ║
168
+ ║ │ h0 × 16 [1,32,64,64] (zeros) │ ║
169
+ ║ └──────────────────────────────────────────────┘ ║
170
+ ╚═══════════════════════════════════════════════════════════════════════════╝
171
  ```
172
 
173
  ## Key Properties
 
195
  o_t = q_t · S_t — state readout
196
  ```
197
 
198
+ This is a monoid because the binary operator `(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)` is **associative**, enabling O(T) parallel prefix scan for training and O(1) sequential update for inference.
199
 
200
  ## Vector Decay — Per-Dimension Memory Lifetimes
201
 
202
+ Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1):
203
 
204
  ```
205
  S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
 
209
  - **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
210
  - **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
211
 
212
+ The decay gate uses **Sigmoid** activation:
213
 
214
  ```
215
+ α_t = σ(W·x_t + b)
216
  ```
217
 
218
  | Property | Value |
219
  |---|---|
220
+ | Range | α ∈ (0, 1) — bounded, no explosion |
221
+ | Perfect memory | W·x → +∞ ⟹ σ → 1 (lossless retention) |
222
+ | Full forgetting | W·x → -∞ ⟹ σ → 0 (complete reset) |
223
+ | Stability | α < 1 by construction — no divergence regardless of input magnitude |
224
+ | Bias init | b = 3.0 ⟹ σ(3) ≈ 0.95, model starts in "mostly remember" mode |
225
 
226
  ## Attention Mask — Padding-Aware Recurrence
227
 
228
  The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
229
 
230
  ```
231
+ α = α * mask + (1 - mask) → α = 1 (preserve state unchanged)
232
+ k = k * mask, v = v * mask → kv = 0 (no information injected)
233
  ```
234
 
235
  Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
 
238
 
239
  - **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix S positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
240
  - **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
241
+ - **Output Norm**: RMSNorm on the readout o after `q·S`, further stabilizing scale before gating
242
+ - **Output Gate**: `gate = SiLU(gate_proj(x))`, modulates the multi-head readout before o_proj (similar to GLA/RetNet). Lets the model suppress or amplify specific head outputs conditioned on the current input
243
+ - **Sigmoid decay gate**: Ensures α ∈ (0, 1) by construction — allows near-perfect memory (α→1) while preventing state explosion (α>1). Bias initialized to 3.0 so σ(3)≈0.95, starting in high-retention mode
244
  - **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
245
+ - **Log-space decay in scan**: The parallel prefix scan works in log-space `log(α)` to avoid numerical underflow when computing cumulative products over long sequences
246
 
247
  ## Three Forward Paths
248
 
 
264
  | Layers | 16 |
265
  | Attention heads | 32 |
266
  | Head dimension | 64 |
267
+ | Decay gate | Vector decay (Sigmoid), d=64 per head |
268
  | State matrix per head | 64 × 64 = 4,096 floats |
269
  | Vocabulary | 128,256 (Llama-3.2 tokenizer) |
270
  | Precision | bfloat16 |
 
330
  model.safetensors # Model weights (bfloat16)
331
  config.json # Model configuration
332
  tokenizer.json # Llama-3.2 tokenizer
333
+ ARCH.png # Core mechanism diagram (monoid recurrence + parallel scan)
334
+ ACC_SPAR.png # SFT accuracy curve
335
+ LOSS_SPAR.png # SFT loss curve
336
  ```
337
 
338
  ## Citation
config.json CHANGED
@@ -23,5 +23,6 @@
23
  "pad_token_id": 128009,
24
  "rms_norm_eps": 1e-05,
25
  "transformers_version": "4.57.6",
 
26
  "vocab_size": 128256
27
  }
 
23
  "pad_token_id": 128009,
24
  "rms_norm_eps": 1e-05,
25
  "transformers_version": "4.57.6",
26
+ "use_cache": false,
27
  "vocab_size": 128256
28
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d5cd463898c4ce262d12fe56c6227d0c1117680aa13892f9cac6e100a1db9077
3
- size 2811462896
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a423fff0e285f14e1da9e2bc7ace8ea963c408ad9715bb485c365affb0da4cf1
3
+ size 2945686352
monoid_scan_cuda.py CHANGED
@@ -3,18 +3,18 @@ monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
3
  monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
4
 
5
  This module implements the parallel prefix scan for the vector-decay monoid recurrence:
6
- y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
7
  本模块实现向量衰减幺半群递推的并行前缀扫描:
8
- y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
9
 
10
  This is the computational backbone of Monoid Attention's state compression.
11
  这是幺半群注意力状态压缩的计算骨干。
12
 
13
  Vector decay: each dimension of the D_k×D_v state matrix has its own
14
- per-dimension decay rate α_t ∈ ^{D_k}, enabling different feature
15
- dimensions to have independent memory lifetimes (fast-decaying for
16
- local syntax, slow-decaying for global entity memory).
17
- 向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ ^{D_k},
18
  使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
19
 
20
  Implementation:
@@ -22,13 +22,13 @@ Implementation:
22
  Each program handles one row of the state matrix (D_v elements)
23
  with a scalar decay per row.
24
  Backward: reverse-order adjoint scan for gradient computation.
25
- Per-row reduction for log_decay gradient (no atomic_add needed).
26
  Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
27
 
28
  前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
29
  每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
30
  反向: 逆序伴随变量扫描计算梯度。
31
- 逐行归约计算 log_decay 梯度 (无需 atomic_add)。
32
  自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
33
  """
34
 
@@ -52,28 +52,28 @@ except ImportError:
52
  # 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
53
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
54
 
55
- def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
56
  """
57
  Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
58
  纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
59
 
60
  Implements the vector-decay monoid recurrence step by step:
61
  acc_0 = 0
62
- acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
63
  This is O(T) sequential — correct but slow on GPU.
64
  逐步实现向量衰减幺半群递推:
65
  acc_0 = 0
66
- acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
67
  这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
68
 
69
  Args:
70
- log_decays: [B, H, T, D_k] — log of per-dimension per-step decay gates
71
- 每维度每步衰减门的对数
72
- values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
73
- 待累积的外积 k_t⊗v_t
74
  Returns:
75
- output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
76
- 所有前缀状态 S_1, ..., S_T
77
  """
78
  B, H, T, D_k, D_v = values.shape
79
  out = torch.empty_like(values)
@@ -83,7 +83,7 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
83
  for t in range(T):
84
  # S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
85
  # S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
86
- decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,D_k,1]
87
  acc = acc * decay_t + values[:, :, t]
88
  out[:, :, t] = acc
89
  return out
@@ -140,10 +140,9 @@ if HAS_TRITON:
140
  o_base = O_ptr + bhdk * s_o_bhdk
141
 
142
  for t in range(T):
143
- # Load scalar log_decay for this row at time t
144
- # 加载此行在时刻 t 的标量 log_decay
145
- ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
146
- decay = tl.exp(ld_val)
147
 
148
  # Load kv_t[row, :] (one row of the outer product)
149
  # 加载 kv_t[行, :] (外积的一行)
@@ -178,12 +177,12 @@ if HAS_TRITON:
178
  反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
179
 
180
  Each program handles one row of the state matrix (one d_k dimension).
181
- The decay for this row is a scalar, so the log_decay gradient is:
182
- ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
183
  The sum over j (D_v) is computed within this single program — no atomic_add.
184
  每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
185
- 该行的衰减是标量, 因此 log_decay 梯度为:
186
- ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
187
  对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
188
  """
189
  bhdk = tl.program_id(0)
@@ -216,19 +215,18 @@ if HAS_TRITON:
216
  lam, mask=dv_mask,
217
  )
218
 
219
- # ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
220
  # Per-row scalar gradient: sum over D_v within this program.
221
  # 逐行标量梯度: 在此 program 内对 D_v 求和。
222
- ld_val = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
223
- a_t = tl.exp(ld_val)
224
 
225
  if t > 0:
226
  y_prev = tl.load(
227
  O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
228
  mask=dv_mask, other=0.0,
229
  ).to(tl.float32)
230
- grad_ld = tl.sum(lam * y_prev) * a_t
231
- tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_ld)
232
 
233
  # Prepare for next step (t-1): adj = a_t · λ_t
234
  # 为下一步 (t-1) 准备: adj = a_t · λ_t
@@ -255,16 +253,16 @@ if HAS_TRITON:
255
  逐行归约消除大部分 atomic_add 开销。
256
  """
257
  @staticmethod
258
- def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
259
  B, H, T, D_k, D_v = values.shape
260
 
261
  # Reshape for row-parallel kernel:
262
- # log_decays: [B, H, T, D_k] → permute to [B, H, D_k, T] → [B*H*D_k, T]
263
- # values: [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
264
  # 为行并行核函数重塑:
265
- # log_decays: [B, H, T, D_k] → 转置为 [B, H, D_k, T] → [B*H*D_k, T]
266
- # values: [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
267
- ld_flat = log_decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
268
  v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
269
  o_flat = torch.empty_like(v_flat)
270
 
@@ -283,8 +281,8 @@ if HAS_TRITON:
283
  BLOCK_DV=BLOCK_DV,
284
  )
285
 
286
- # Save for backward: need log_decays and forward outputs y_t
287
- # 为反向传播保存: 需要 log_decays 和前向输出 y_t
288
  ctx.save_for_backward(ld_flat, o_flat)
289
  ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
290
  # Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
@@ -318,15 +316,15 @@ if HAS_TRITON:
318
  # Reshape gradients back to original layout
319
  # 重塑梯度回原始布局
320
  # gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
321
- grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
322
  # gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
323
  grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
324
- return grad_log_decays, grad_values
325
 
326
- def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
327
  """Triton-accelerated parallel scan entry point (vector decay).
328
  Triton 加速的并行扫描入口 (向量衰减)。"""
329
- return _ParallelScanFn.apply(log_decays, values)
330
 
331
  else:
332
  _triton_parallel_scan = None
@@ -336,7 +334,7 @@ else:
336
  # Public API / 公共接口
337
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
338
 
339
- def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
340
  """
341
  Parallel prefix scan — computes all prefix monoid sums (vector decay).
342
  并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
@@ -357,21 +355,21 @@ def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
357
  CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
358
 
359
  Args:
360
- log_decays: [B, H, T, D_k] — log of per-dimension decay gates α_t
361
- 每维度衰减门 α_t 的对数
362
- values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
363
- 外积 k_t⊗v_t
364
  Returns:
365
- states: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
366
- 所有前缀状态 S_1..S_T
367
  """
368
  if _triton_parallel_scan is not None and values.is_cuda:
369
- return _triton_parallel_scan(log_decays, values)
370
- return _sequential_scan(log_decays, values)
371
 
372
 
373
  def parallel_scan_with_state(
374
- log_decays: Tensor, values: Tensor,
375
  ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
376
  """
377
  Parallel prefix scan + extract final state for inference handoff (vector decay).
@@ -389,23 +387,23 @@ def parallel_scan_with_state(
389
  这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
390
 
391
  Args:
392
- log_decays: [B, H, T, D_k]
393
- values: [B, H, T, D_k, D_v]
394
 
395
  Returns:
396
  output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
397
  所有前缀状态
398
- final_state: (log_acc, S_T) where
399
- log_acc: [B, H, D_k] — accumulated log-decay vector (for future monoid_op)
400
- 累积对数衰减向量 (供后续 monoid_op 使用)
401
  final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
402
  S_T, 压缩的因果摘要
403
  """
404
- output = parallel_scan(log_decays, values)
405
- # Sum all log-decays over T to get the total accumulated decay per dimension
406
- # 对所有 log-decay 沿 T 求和得到每个维度的总累衰减
407
- log_acc = log_decays.sum(dim=2) # [B, H, D_k]
408
  # The last timestep's state IS the full causal summary
409
  # 最后一个时间步的状态就是完整的因果摘要
410
  final_state = output[:, :, -1] # [B, H, D_k, D_v]
411
- return output, (log_acc, final_state)
 
3
  monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
4
 
5
  This module implements the parallel prefix scan for the vector-decay monoid recurrence:
6
+ y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:]
7
  本模块实现向量衰减幺半群递推的并行前缀扫描:
8
+ y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:]
9
 
10
  This is the computational backbone of Monoid Attention's state compression.
11
  这是幺半群注意力状态压缩的计算骨干。
12
 
13
  Vector decay: each dimension of the D_k×D_v state matrix has its own
14
+ per-dimension decay rate α_t ∈ (0,1)^{D_k} (sigmoid output), enabling
15
+ different feature dimensions to have independent memory lifetimes
16
+ (fast-decaying for local syntax, slow-decaying for global entity memory).
17
+ 向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ (0,1)^{D_k} (sigmoid 输出),
18
  使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
19
 
20
  Implementation:
 
22
  Each program handles one row of the state matrix (D_v elements)
23
  with a scalar decay per row.
24
  Backward: reverse-order adjoint scan for gradient computation.
25
+ Per-row reduction for decay gradient (no atomic_add needed).
26
  Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
27
 
28
  前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
29
  每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
30
  反向: 逆序伴随变量扫描计算梯度。
31
+ 逐行归约计算 decay 梯度 (无需 atomic_add)。
32
  自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
33
  """
34
 
 
52
  # 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
53
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
54
 
55
+ def _sequential_scan(decays: Tensor, values: Tensor) -> Tensor:
56
  """
57
  Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
58
  纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
59
 
60
  Implements the vector-decay monoid recurrence step by step:
61
  acc_0 = 0
62
+ acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:]
63
  This is O(T) sequential — correct but slow on GPU.
64
  逐步实现向量衰减幺半群递推:
65
  acc_0 = 0
66
+ acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:]
67
  这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
68
 
69
  Args:
70
+ decays: [B, H, T, D_k] — per-dimension per-step decay gates α_t ∈ (0,1)
71
+ 每维度每步衰减门 α_t ∈ (0,1)
72
+ values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
73
+ 待累积的外积 k_t⊗v_t
74
  Returns:
75
+ output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
76
+ 所有前缀状态 S_1, ..., S_T
77
  """
78
  B, H, T, D_k, D_v = values.shape
79
  out = torch.empty_like(values)
 
83
  for t in range(T):
84
  # S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
85
  # S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
86
+ decay_t = decays[:, :, t].unsqueeze(-1) # [B,H,D_k,1]
87
  acc = acc * decay_t + values[:, :, t]
88
  out[:, :, t] = acc
89
  return out
 
140
  o_base = O_ptr + bhdk * s_o_bhdk
141
 
142
  for t in range(T):
143
+ # Load scalar decay for this row at time t
144
+ # 加载此行在时刻 t 的标量 decay
145
+ decay = tl.load(ld_base + t * s_ld_t).to(tl.float32)
 
146
 
147
  # Load kv_t[row, :] (one row of the outer product)
148
  # 加载 kv_t[行, :] (外积的一行)
 
177
  反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
178
 
179
  Each program handles one row of the state matrix (one d_k dimension).
180
+ The decay for this row is a scalar, so the decay gradient is:
181
+ ∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
182
  The sum over j (D_v) is computed within this single program — no atomic_add.
183
  每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
184
+ 该行的衰减是标量, 因此 decay 梯度为:
185
+ ∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
186
  对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
187
  """
188
  bhdk = tl.program_id(0)
 
215
  lam, mask=dv_mask,
216
  )
217
 
218
+ # ∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
219
  # Per-row scalar gradient: sum over D_v within this program.
220
  # 逐行标量梯度: 在此 program 内对 D_v 求和。
221
+ a_t = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
 
222
 
223
  if t > 0:
224
  y_prev = tl.load(
225
  O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
226
  mask=dv_mask, other=0.0,
227
  ).to(tl.float32)
228
+ grad_d = tl.sum(lam * y_prev)
229
+ tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_d)
230
 
231
  # Prepare for next step (t-1): adj = a_t · λ_t
232
  # 为下一步 (t-1) 准备: adj = a_t · λ_t
 
253
  逐行归约消除大部分 atomic_add 开销。
254
  """
255
  @staticmethod
256
+ def forward(ctx, decays: Tensor, values: Tensor) -> Tensor:
257
  B, H, T, D_k, D_v = values.shape
258
 
259
  # Reshape for row-parallel kernel:
260
+ # decays: [B, H, T, D_k] → permute to [B, H, D_k, T] → [B*H*D_k, T]
261
+ # values: [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
262
  # 为行并行核函数重塑:
263
+ # decays: [B, H, T, D_k] → 转置为 [B, H, D_k, T] → [B*H*D_k, T]
264
+ # values: [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
265
+ ld_flat = decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
266
  v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
267
  o_flat = torch.empty_like(v_flat)
268
 
 
281
  BLOCK_DV=BLOCK_DV,
282
  )
283
 
284
+ # Save for backward: need decays and forward outputs y_t
285
+ # 为反向传播保存: 需要 decays 和前向输出 y_t
286
  ctx.save_for_backward(ld_flat, o_flat)
287
  ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
288
  # Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
 
316
  # Reshape gradients back to original layout
317
  # 重塑梯度回原始布局
318
  # gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
319
+ grad_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
320
  # gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
321
  grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
322
+ return grad_decays, grad_values
323
 
324
+ def _triton_parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
325
  """Triton-accelerated parallel scan entry point (vector decay).
326
  Triton 加速的并行扫描入口 (向量衰减)。"""
327
+ return _ParallelScanFn.apply(decays, values)
328
 
329
  else:
330
  _triton_parallel_scan = None
 
334
  # Public API / 公共接口
335
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
336
 
337
+ def parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
338
  """
339
  Parallel prefix scan — computes all prefix monoid sums (vector decay).
340
  并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
 
355
  CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
356
 
357
  Args:
358
+ decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1) (sigmoid output)
359
+ 每维度衰减门 α_t ∈ (0,1) (sigmoid 输出)
360
+ values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
361
+ 外积 k_t⊗v_t
362
  Returns:
363
+ states: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
364
+ 所有前缀状态 S_1..S_T
365
  """
366
  if _triton_parallel_scan is not None and values.is_cuda:
367
+ return _triton_parallel_scan(decays, values)
368
+ return _sequential_scan(decays, values)
369
 
370
 
371
  def parallel_scan_with_state(
372
+ decays: Tensor, values: Tensor,
373
  ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
374
  """
375
  Parallel prefix scan + extract final state for inference handoff (vector decay).
 
387
  这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
388
 
389
  Args:
390
+ decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1)
391
+ values: [B, H, T, D_k, D_v]
392
 
393
  Returns:
394
  output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
395
  所有前缀状态
396
+ final_state: (decay_acc, S_T) where
397
+ decay_acc: [B, H, D_k] — accumulated decay product (for future monoid_op)
398
+ 累积衰减乘积 (供后续 monoid_op 使用)
399
  final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
400
  S_T, 压缩的因果摘要
401
  """
402
+ output = parallel_scan(decays, values)
403
+    # Product of all decays over T — computed as exp(Σ log α) for numerical stability in bf16
404
+    # 对所有 decay 沿 T 求积 — 以 exp(Σ log α) 在对数空间计算以保证 bf16 数值稳定
405
+ decay_acc = torch.exp(torch.sum(torch.log(decays + 1e-8), dim=2)) # [B, H, D_k]
406
  # The last timestep's state IS the full causal summary
407
  # 最后一个时间步的状态就是完整的因果摘要
408
  final_state = output[:, :, -1] # [B, H, D_k, D_v]
409
+ return output, (decay_acc, final_state)
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:32938ebb880dc58fa7d6f8e45383c55e1d5d4352618531d62a28069918595445
3
- size 6417
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a74d419f03ffc06a6f989ef4dc1768ad7f4298b971f129f0a2e121514a016053
3
+ size 6353