Upload 14 files
Browse files
- .gitattributes +1 -0
- ACC_SPAR.png +0 -0
- ARCH.png +3 -0
- LOSS_SPAR.png +0 -0
- MonoidForCausalLM.py +85 -66
- README.md +166 -40
- config.json +1 -0
- model.safetensors +2 -2
- monoid_scan_cuda.py +61 -63
- training_args.bin +2 -2
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ARCH.png filter=lfs diff=lfs merge=lfs -text
|
ACC_SPAR.png
ADDED
|
ARCH.png
ADDED
|
Git LFS Details
|
LOSS_SPAR.png
ADDED
|
MonoidForCausalLM.py
CHANGED
|
@@ -23,11 +23,11 @@ Architecture / 架构概要:
|
|
| 23 |
其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
|
| 24 |
|
| 25 |
This is a monoid because the binary operator:
|
| 26 |
-
(
|
| 27 |
is associative → enables parallel prefix scan for training,
|
| 28 |
and O(1) sequential update for inference.
|
| 29 |
这是一个幺半群,因为二元算子:
|
| 30 |
-
(
|
| 31 |
满足结合律 → 训练时可用并行前缀扫描,推理时 O(1) 逐步递推。
|
| 32 |
|
| 33 |
Key properties / 关键特性:
|
|
@@ -70,33 +70,33 @@ except ImportError:
|
|
| 70 |
# Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device.
|
| 71 |
# Slower than the fused CUDA kernel but numerically identical.
|
| 72 |
|
| 73 |
-
def parallel_scan(
|
| 74 |
-
"""Sequential prefix scan fallback: S_t[i,:] =
|
| 75 |
B, H, T, d1, d2 = kv.shape
|
| 76 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 77 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 78 |
for t in range(T):
|
| 79 |
-
decay =
|
| 80 |
while decay.dim() < S.dim():
|
| 81 |
decay = decay.unsqueeze(-1)
|
| 82 |
S = S * decay + kv[:, :, t]
|
| 83 |
states[:, :, t] = S
|
| 84 |
return states
|
| 85 |
|
| 86 |
-
def parallel_scan_with_state(
|
| 87 |
-
"""Sequential prefix scan that also returns the final (
|
| 88 |
B, H, T, d1, d2 = kv.shape
|
| 89 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 90 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 91 |
-
|
| 92 |
for t in range(T):
|
| 93 |
-
decay =
|
| 94 |
while decay.dim() < S.dim():
|
| 95 |
decay = decay.unsqueeze(-1)
|
| 96 |
S = S * decay + kv[:, :, t]
|
| 97 |
states[:, :, t] = S
|
| 98 |
-
|
| 99 |
-
return states, (
|
| 100 |
|
| 101 |
|
| 102 |
|
|
@@ -169,14 +169,14 @@ class MonoidCache:
|
|
| 169 |
|
| 170 |
Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory),
|
| 171 |
each layer here stores exactly ONE state tuple:
|
| 172 |
-
(
|
| 173 |
-
This is the monoid "sum" of all past (
|
| 174 |
Memory is O(1) per layer regardless of sequence length.
|
| 175 |
|
| 176 |
不同于 Transformer 的 KV-Cache (存储所有过去的 key 和 value, O(T) 内存),
|
| 177 |
这里每层仅存储一个状态元组:
|
| 178 |
-
(
|
| 179 |
-
这是所有过去的 (
|
| 180 |
无论序列多长,每层内存 O(1)。
|
| 181 |
"""
|
| 182 |
|
|
@@ -219,12 +219,12 @@ def monoid_op(
|
|
| 219 |
b: tuple[Tensor, Tensor],
|
| 220 |
) -> tuple[Tensor, Tensor]:
|
| 221 |
"""
|
| 222 |
-
The monoid binary operator ⊕ on (
|
| 223 |
-
幺半群二元算子 ⊕,作用于 (
|
| 224 |
|
| 225 |
Definition / 定义:
|
| 226 |
-
(
|
| 227 |
-
where
|
| 228 |
|
| 229 |
Why this is a monoid / 为什么这是幺半群:
|
| 230 |
• Associativity / 结合律:
|
|
@@ -235,12 +235,7 @@ def monoid_op(
|
|
| 235 |
推理时可以 O(1) 左折叠 (逐步追加)。
|
| 236 |
|
| 237 |
• Identity / 单位元:
|
| 238 |
-
e = (
|
| 239 |
-
|
| 240 |
-
Why log-space / 为什么用对数空间:
|
| 241 |
-
Working in log-space for the decay factor avoids numerical
|
| 242 |
-
underflow when α^T → 0 for long sequences.
|
| 243 |
-
衰减因子在对数空间中运算,避免长序列下 α^T → 0 的数值下溢。
|
| 244 |
|
| 245 |
Causal semantics / 因果语义:
|
| 246 |
S_t = α_t · S_{t-1} + k_t ⊗ v_t
|
|
@@ -251,15 +246,14 @@ def monoid_op(
|
|
| 251 |
这就是 *显式因果建模* — 模型必须在每个时间步学习如何
|
| 252 |
平衡保留旧信息与吸收新信息。
|
| 253 |
"""
|
| 254 |
-
|
| 255 |
-
|
| 256 |
|
| 257 |
-
|
| 258 |
-
decay_b = torch.exp(log_b) # β = exp(log_β)
|
| 259 |
while decay_b.dim() < kv_a.dim():
|
| 260 |
-
decay_b = decay_b.unsqueeze(-1)
|
| 261 |
|
| 262 |
-
return
|
| 263 |
|
| 264 |
|
| 265 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
@@ -327,6 +321,17 @@ class MonoidAttention(nn.Module):
|
|
| 327 |
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
| 328 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
# --- Decay gate (novel component, randomly initialized) ---
|
| 331 |
# --- 衰减门 (全新组件, 随机初始化) ---
|
| 332 |
# Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
|
|
@@ -351,6 +356,7 @@ class MonoidAttention(nn.Module):
|
|
| 351 |
# 可能无界增长。
|
| 352 |
self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 353 |
self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
|
|
| 354 |
|
| 355 |
# --- Learnable initial state h0 (novel component, zero-initialized) ---
|
| 356 |
# --- 可学习初始状态 h0 (全新组件, 零初始化) ---
|
|
@@ -394,6 +400,10 @@ class MonoidAttention(nn.Module):
|
|
| 394 |
k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
|
| 395 |
v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
# --- QK-Norm: stabilize q·S readout scale ---
|
| 398 |
# --- QK 归一化: 稳定 q·S 读出尺度 ---
|
| 399 |
q = self.q_norm(q) * self.scaling
|
|
@@ -413,25 +423,22 @@ class MonoidAttention(nn.Module):
|
|
| 413 |
|
| 414 |
# --- Compute per-dimension vector decay gate α_t ---
|
| 415 |
# --- 计算每维度向量衰减门 α_t ---
|
| 416 |
-
#
|
| 417 |
-
# Value range:
|
| 418 |
-
# When Wx → -∞:
|
| 419 |
-
# When Wx → +∞:
|
| 420 |
-
# This avoids α > 1 explosion (unlike SiLU) while still allowing
|
| 421 |
-
# α = 1 for lossless memory (unlike Sigmoid which caps at <1).
|
| 422 |
# Each dimension of the d-vector decays independently:
|
| 423 |
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 424 |
#
|
| 425 |
-
#
|
| 426 |
-
# 值域:
|
| 427 |
-
# 当 Wx → -∞:
|
| 428 |
-
# 当 Wx → +∞:
|
| 429 |
-
# 避免了 SiLU 的 α > 1 爆炸, 同时允许 α = 1 无损记忆 (Sigmoid 无法做到)。
|
| 430 |
# d-向量的每个维度独立衰减:
|
| 431 |
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 432 |
raw = self.decay_proj(hidden_states) # [B,T,H*d]
|
| 433 |
-
|
| 434 |
-
|
| 435 |
|
| 436 |
# --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
|
| 437 |
# --- 应用注意力掩码: PAD token 必须对递推不可见 ---
|
|
@@ -441,10 +448,10 @@ class MonoidAttention(nn.Module):
|
|
| 441 |
# 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
|
| 442 |
if attention_mask is not None:
|
| 443 |
# attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
|
| 444 |
-
mask = attention_mask[:, None, :, None].to(
|
| 445 |
-
|
| 446 |
-
k = k * mask
|
| 447 |
-
v = v * mask
|
| 448 |
|
| 449 |
# ══════════════════════════════════════════════════════════
|
| 450 |
# Inference path (RNN mode): O(1) per token per layer
|
|
@@ -466,20 +473,20 @@ class MonoidAttention(nn.Module):
|
|
| 466 |
# Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 467 |
# 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 468 |
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
|
| 469 |
-
|
| 470 |
|
| 471 |
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
|
| 472 |
if prev is None:
|
| 473 |
# First token: initialize from learnable h0
|
| 474 |
# 第一个 token: 从可学习的 h0 初始化
|
| 475 |
-
decay_t =
|
| 476 |
while decay_t.dim() < self.h0.dim():
|
| 477 |
decay_t = decay_t.unsqueeze(-1)
|
| 478 |
-
new_state = (
|
| 479 |
else:
|
| 480 |
# Subsequent tokens: fold via monoid_op — O(1)!
|
| 481 |
# 后续 token: 通过 monoid_op 折叠 — O(1)!
|
| 482 |
-
new_state = monoid_op(prev, (
|
| 483 |
|
| 484 |
if monoid_cache is not None:
|
| 485 |
monoid_cache.update(self.layer_idx, new_state)
|
|
@@ -487,10 +494,11 @@ class MonoidAttention(nn.Module):
|
|
| 487 |
# Readout: o_t = q_t · S_t
|
| 488 |
# 读出: o_t = q_t · S_t
|
| 489 |
o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
|
|
|
|
| 490 |
# Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching scan path)
|
| 491 |
# 重塑 [B,H,d] → [B,1,H*d] (头连续排列, 与扫描路径一致)
|
| 492 |
o = o.contiguous().view(B, 1, -1)
|
| 493 |
-
return self.o_proj(o), new_state
|
| 494 |
|
| 495 |
# ══════════════════════════════════════════════════════════
|
| 496 |
# Inference prefill (use_cache=True, T>1): parallel scan + readout
|
|
@@ -504,20 +512,20 @@ class MonoidAttention(nn.Module):
|
|
| 504 |
# 内存: O(B·H·T·d²) — 与训练路径相同。
|
| 505 |
if use_cache:
|
| 506 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 507 |
-
states, (
|
| 508 |
|
| 509 |
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 510 |
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 511 |
-
|
| 512 |
-
h0_decay =
|
| 513 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 514 |
|
| 515 |
# Final state includes h0 contribution
|
| 516 |
# 最终状态包含 h0 贡献
|
| 517 |
-
total_h0_decay =
|
| 518 |
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
|
| 519 |
# h0 is [1,H,d,d]; squeeze(0) drops the leading singleton dim so [H,d,d] broadcasts over the batch (expand would also work)
|
| 520 |
-
final_state = (
|
| 521 |
|
| 522 |
if monoid_cache is not None:
|
| 523 |
monoid_cache.update(self.layer_idx, final_state)
|
|
@@ -525,8 +533,9 @@ class MonoidAttention(nn.Module):
|
|
| 525 |
# Vectorized readout: o_t = q_t · S_t for all t
|
| 526 |
# 向量化读出: 一次性计算所有 t 的 o_t = q_t · S_t
|
| 527 |
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
|
|
|
|
| 528 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 529 |
-
return self.o_proj(o), final_state
|
| 530 |
|
| 531 |
# ══════════════════════════════════════════════════════════
|
| 532 |
# Training path: parallel scan + vectorized readout
|
|
@@ -548,22 +557,23 @@ class MonoidAttention(nn.Module):
|
|
| 548 |
|
| 549 |
# Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
|
| 550 |
# 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
|
| 551 |
-
#
|
| 552 |
-
#
|
| 553 |
-
states = parallel_scan(
|
| 554 |
|
| 555 |
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 556 |
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 557 |
-
|
| 558 |
-
h0_decay =
|
| 559 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 560 |
|
| 561 |
# Vectorized readout: o_t = q_t · S_t for all t at once
|
| 562 |
# 向量化读出: 一次性计算所有 t 的 q_t · S_t
|
| 563 |
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
|
|
|
|
| 564 |
|
| 565 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 566 |
-
return self.o_proj(o), None
|
| 567 |
|
| 568 |
|
| 569 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
@@ -640,7 +650,16 @@ class MonoidPreTrainedModel(PreTrainedModel):
|
|
| 640 |
module.weight.data[module.padding_idx].zero_()
|
| 641 |
|
| 642 |
if isinstance(module, MonoidAttention):
|
| 643 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
|
| 645 |
class MonoidModel(MonoidPreTrainedModel):
|
| 646 |
"""
|
|
|
|
| 23 |
其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
|
| 24 |
|
| 25 |
This is a monoid because the binary operator:
|
| 26 |
+
(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
|
| 27 |
is associative → enables parallel prefix scan for training,
|
| 28 |
and O(1) sequential update for inference.
|
| 29 |
这是一个幺半群,因为二元算子:
|
| 30 |
+
(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
|
| 31 |
满足结合律 → 训练时可用并行前缀扫描,推理时 O(1) 逐步递推。
|
| 32 |
|
| 33 |
Key properties / 关键特性:
|
|
|
|
| 70 |
# Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device.
|
| 71 |
# Slower than the fused CUDA kernel but numerically identical.
|
| 72 |
|
| 73 |
+
def parallel_scan(alpha: Tensor, kv: Tensor) -> Tensor:
|
| 74 |
+
"""Sequential prefix scan fallback: S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]."""
|
| 75 |
B, H, T, d1, d2 = kv.shape
|
| 76 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 77 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 78 |
for t in range(T):
|
| 79 |
+
decay = alpha[:, :, t] # [B, H, d]
|
| 80 |
while decay.dim() < S.dim():
|
| 81 |
decay = decay.unsqueeze(-1)
|
| 82 |
S = S * decay + kv[:, :, t]
|
| 83 |
states[:, :, t] = S
|
| 84 |
return states
|
| 85 |
|
| 86 |
+
def parallel_scan_with_state(alpha: Tensor, kv: Tensor):
|
| 87 |
+
"""Sequential prefix scan that also returns the final (decay_acc, S) state."""
|
| 88 |
B, H, T, d1, d2 = kv.shape
|
| 89 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 90 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 91 |
+
decay_acc = torch.ones(B, H, d1, device=alpha.device, dtype=alpha.dtype)
|
| 92 |
for t in range(T):
|
| 93 |
+
decay = alpha[:, :, t]
|
| 94 |
while decay.dim() < S.dim():
|
| 95 |
decay = decay.unsqueeze(-1)
|
| 96 |
S = S * decay + kv[:, :, t]
|
| 97 |
states[:, :, t] = S
|
| 98 |
+
decay_acc = decay_acc * alpha[:, :, t]
|
| 99 |
+
return states, (decay_acc, S)
|
| 100 |
|
| 101 |
|
| 102 |
|
|
|
|
| 169 |
|
| 170 |
Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory),
|
| 171 |
each layer here stores exactly ONE state tuple:
|
| 172 |
+
(decay_acc, S) where S ∈ ℝ^{B, H, d, d}
|
| 173 |
+
This is the monoid "sum" of all past (α_i, k_i⊗v_i) via ⊕.
|
| 174 |
Memory is O(1) per layer regardless of sequence length.
|
| 175 |
|
| 176 |
不同于 Transformer 的 KV-Cache (存储所有过去的 key 和 value, O(T) 内存),
|
| 177 |
这里每层仅存储一个状态元组:
|
| 178 |
+
(decay_acc, S) 其中 S ∈ ℝ^{B, H, d, d}
|
| 179 |
+
这是所有过去的 (α_i, k_i⊗v_i) 通过 ⊕ 累积的幺半群 "和"。
|
| 180 |
无论序列多长,每层内存 O(1)。
|
| 181 |
"""
|
| 182 |
|
|
|
|
| 219 |
b: tuple[Tensor, Tensor],
|
| 220 |
) -> tuple[Tensor, Tensor]:
|
| 221 |
"""
|
| 222 |
+
The monoid binary operator ⊕ on (vector decay, state matrix) pairs.
|
| 223 |
+
幺半群二元算子 ⊕,作用于 (向量衰减, 状态矩阵) 对。
|
| 224 |
|
| 225 |
Definition / 定义:
|
| 226 |
+
(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
|
| 227 |
+
where α, β ∈ (0,1)^d are per-dimension vector decay gates (sigmoid output).
|
| 228 |
|
| 229 |
Why this is a monoid / 为什么这是幺半群:
|
| 230 |
• Associativity / 结合律:
|
|
|
|
| 235 |
推理时可以 O(1) 左折叠 (逐步追加)。
|
| 236 |
|
| 237 |
• Identity / 单位元:
|
| 238 |
+
e = (1, 0) → e ⊕ a = a ⊕ e = a ✓
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
Causal semantics / 因果语义:
|
| 241 |
S_t = α_t · S_{t-1} + k_t ⊗ v_t
|
|
|
|
| 246 |
这就是 *显式因果建模* — 模型必须在每个时间步学习如何
|
| 247 |
平衡保留旧信息与吸收新信息。
|
| 248 |
"""
|
| 249 |
+
decay_a, kv_a = a
|
| 250 |
+
decay_b, kv_b = b
|
| 251 |
|
| 252 |
+
new_decay = decay_a * decay_b # α·β (element-wise product)
|
|
|
|
| 253 |
while decay_b.dim() < kv_a.dim():
|
| 254 |
+
decay_b = decay_b.unsqueeze(-1) # broadcast to [B,H,...,1,1]
|
| 255 |
|
| 256 |
+
return new_decay, kv_a * decay_b + kv_b # β·S + X
|
| 257 |
|
| 258 |
|
| 259 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
| 321 |
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
| 322 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
|
| 323 |
|
| 324 |
+
# --- Output gate (novel component, randomly initialized) ---
|
| 325 |
+
# --- 输出门控 (全新组件, 随机初始化) ---
|
| 326 |
+
# Modulates the multi-head readout before o_proj, similar to GLA/RetNet.
|
| 327 |
+
# gate = SiLU(gate_proj(x)), output = gate ⊙ concat_heads(o)
|
| 328 |
+
# This lets the model suppress or amplify specific head outputs
|
| 329 |
+
# conditioned on the current input, increasing expressiveness.
|
| 330 |
+
# 在 o_proj 之前调制多头读出, 类似 GLA/RetNet。
|
| 331 |
+
# gate = SiLU(gate_proj(x)), output = gate ⊙ concat_heads(o)
|
| 332 |
+
# 使模型能根据当前输入抑制或放大特定头的输出, 增加表达力。
|
| 333 |
+
self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 334 |
+
|
| 335 |
# --- Decay gate (novel component, randomly initialized) ---
|
| 336 |
# --- 衰减门 (全新组件, 随机初始化) ---
|
| 337 |
# Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
|
|
|
|
| 356 |
# 可能无界增长。
|
| 357 |
self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 358 |
self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 359 |
+
self.o_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 360 |
|
| 361 |
# --- Learnable initial state h0 (novel component, zero-initialized) ---
|
| 362 |
# --- 可学习初始状态 h0 (全新组件, 零初始化) ---
|
|
|
|
| 400 |
k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
|
| 401 |
v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
|
| 402 |
|
| 403 |
+
# --- Output gate: computed from input, applied before o_proj ---
|
| 404 |
+
# --- 输出门控: 从输入计算, 在 o_proj 之前应用 ---
|
| 405 |
+
gate = torch.nn.functional.silu(self.gate_proj(hidden_states)) # [B,T,H*d]
|
| 406 |
+
|
| 407 |
# --- QK-Norm: stabilize q·S readout scale ---
|
| 408 |
# --- QK 归一化: 稳定 q·S 读出尺度 ---
|
| 409 |
q = self.q_norm(q) * self.scaling
|
|
|
|
| 423 |
|
| 424 |
# --- Compute per-dimension vector decay gate α_t ---
|
| 425 |
# --- 计算每维度向量衰减门 α_t ---
|
| 426 |
+
# Sigmoid: α = σ(Wx + b)
|
| 427 |
+
# Value range: α ∈ (0, 1).
|
| 428 |
+
# When Wx → -∞: σ → 0 (complete forgetting)
|
| 429 |
+
# When Wx → +∞: σ → 1 (perfect memory, no forgetting)
|
|
|
|
|
|
|
| 430 |
# Each dimension of the d-vector decays independently:
|
| 431 |
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 432 |
#
|
| 433 |
+
# Sigmoid: α = σ(Wx + b)
|
| 434 |
+
# 值域: α ∈ (0, 1)。
|
| 435 |
+
# 当 Wx → -∞: σ → 0 (完全遗忘)
|
| 436 |
+
# 当 Wx → +∞: σ → 1 (完美记忆, 不遗忘)
|
|
|
|
| 437 |
# d-向量的每个维度独立衰减:
|
| 438 |
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 439 |
raw = self.decay_proj(hidden_states) # [B,T,H*d]
|
| 440 |
+
alpha = torch.sigmoid(raw) # [B,T,H*d]
|
| 441 |
+
alpha = alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
|
| 442 |
|
| 443 |
# --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
|
| 444 |
# --- 应用注意力掩码: PAD token 必须对递推不可见 ---
|
|
|
|
| 448 |
# 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
|
| 449 |
if attention_mask is not None:
|
| 450 |
# attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
|
| 451 |
+
mask = attention_mask[:, None, :, None].to(alpha.dtype) # [B,1,T,1]
|
| 452 |
+
alpha = alpha * mask + (1 - mask) # PAD → α=1 (preserve state)
|
| 453 |
+
k = k * mask # PAD → k=0
|
| 454 |
+
v = v * mask # PAD → v=0 → kv=0
|
| 455 |
|
| 456 |
# ══════════════════════════════════════════════════════════
|
| 457 |
# Inference path (RNN mode): O(1) per token per layer
|
|
|
|
| 473 |
# Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 474 |
# 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 475 |
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
|
| 476 |
+
alpha_t = alpha[:, :, 0] # [B,H,d]
|
| 477 |
|
| 478 |
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
|
| 479 |
if prev is None:
|
| 480 |
# First token: initialize from learnable h0
|
| 481 |
# 第一个 token: 从可学习的 h0 初始化
|
| 482 |
+
decay_t = alpha_t
|
| 483 |
while decay_t.dim() < self.h0.dim():
|
| 484 |
decay_t = decay_t.unsqueeze(-1)
|
| 485 |
+
new_state = (alpha_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
|
| 486 |
else:
|
| 487 |
# Subsequent tokens: fold via monoid_op — O(1)!
|
| 488 |
# 后续 token: 通过 monoid_op 折叠 — O(1)!
|
| 489 |
+
new_state = monoid_op(prev, (alpha_t, kv_t))
|
| 490 |
|
| 491 |
if monoid_cache is not None:
|
| 492 |
monoid_cache.update(self.layer_idx, new_state)
|
|
|
|
| 494 |
# Readout: o_t = q_t · S_t
|
| 495 |
# 读出: o_t = q_t · S_t
|
| 496 |
o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
|
| 497 |
+
o = self.o_norm(o)
|
| 498 |
# Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching scan path)
|
| 499 |
# 重塑 [B,H,d] → [B,1,H*d] (头连续排列, 与扫描路径一致)
|
| 500 |
o = o.contiguous().view(B, 1, -1)
|
| 501 |
+
return self.o_proj(gate * o), new_state
|
| 502 |
|
| 503 |
# ══════════════════════════════════════════════════════════
|
| 504 |
# Inference prefill (use_cache=True, T>1): parallel scan + readout
|
|
|
|
| 512 |
# 内存: O(B·H·T·d²) — 与训练路径相同。
|
| 513 |
if use_cache:
|
| 514 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 515 |
+
states, (decay_acc, S_T) = parallel_scan_with_state(alpha, kv)
|
| 516 |
|
| 517 |
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 518 |
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 519 |
+
cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2)) # [B,H,T,d]
|
| 520 |
+
h0_decay = cum_alpha.unsqueeze(-1) # [B,H,T,d,1]
|
| 521 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 522 |
|
| 523 |
# Final state includes h0 contribution
|
| 524 |
# 最终状态包含 h0 贡献
|
| 525 |
+
total_h0_decay = decay_acc.unsqueeze(-1) # [B,H,d,1]
|
| 526 |
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
|
| 527 |
# h0 is [1,H,d,d]; squeeze(0) drops the leading singleton dim so [H,d,d] broadcasts over the batch (expand would also work)
|
| 528 |
+
final_state = (decay_acc, S_final)
|
| 529 |
|
| 530 |
if monoid_cache is not None:
|
| 531 |
monoid_cache.update(self.layer_idx, final_state)
|
|
|
|
| 533 |
# Vectorized readout: o_t = q_t · S_t for all t
|
| 534 |
# 向量化读出: 一次性计算所有 t 的 o_t = q_t · S_t
|
| 535 |
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
|
| 536 |
+
o = self.o_norm(o)
|
| 537 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 538 |
+
return self.o_proj(gate * o), final_state
|
| 539 |
|
| 540 |
# ══════════════════════════════════════════════════════════
|
| 541 |
# Training path: parallel scan + vectorized readout
|
|
|
|
| 557 |
|
| 558 |
# Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
|
| 559 |
# 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
|
| 560 |
+
# alpha is [B,H,T,d] — vector decay per dimension.
|
| 561 |
+
# alpha 为 [B,H,T,d] — 每维度向量衰减。
|
| 562 |
+
states = parallel_scan(alpha, kv) # [B,H,T,d,d]
|
| 563 |
|
| 564 |
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 565 |
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 566 |
+
cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2)) # [B,H,T,d]
|
| 567 |
+
h0_decay = cum_alpha.unsqueeze(-1) # [B,H,T,d,1]
|
| 568 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 569 |
|
| 570 |
# Vectorized readout: o_t = q_t · S_t for all t at once
|
| 571 |
# 向量化读出: 一次性计算所有 t 的 q_t · S_t
|
| 572 |
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
|
| 573 |
+
o = self.o_norm(o)
|
| 574 |
|
| 575 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 576 |
+
return self.o_proj(gate * o), None
|
| 577 |
|
| 578 |
|
| 579 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
| 650 |
module.weight.data[module.padding_idx].zero_()
|
| 651 |
|
| 652 |
if isinstance(module, MonoidAttention):
|
| 653 |
+
# decay_proj: bias init so sigmoid(bias) ≈ 0.95 → mostly remembering at start
|
| 654 |
+
# decay_proj: 偏置初始化使 sigmoid(bias) ≈ 0.95 → 初始时以记忆为主
|
| 655 |
+
nn.init.constant_(module.decay_proj.bias, 3.0)
|
| 656 |
+
# gate_proj: small init (std=0.01) keeps the gate pre-activations close to 0;
|
| 657 |
+
# since SiLU(0)=0, the gate output therefore starts small rather than near identity
|
| 658 |
+
# gate_proj: 小初始化 (std=0.01), 使门控预激活接近 0; 由于 SiLU(0)=0, 门控输出初始较小, 而非接近恒等
|
| 659 |
+
nn.init.normal_(module.gate_proj.weight, mean=0.0, std=0.01)
|
| 660 |
+
# o_norm: RMSNorm weight defaults to 1.0 (identity), explicit for clarity
|
| 661 |
+
# o_norm: RMSNorm 权重默认为 1.0 (恒等), 显式设置确保正确
|
| 662 |
+
nn.init.ones_(module.o_norm.weight)
|
| 663 |
|
| 664 |
class MonoidModel(MonoidPreTrainedModel):
|
| 665 |
"""
|
README.md
CHANGED
|
@@ -21,33 +21,153 @@ model-index:
|
|
| 21 |
|
| 22 |
A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
|
| 23 |
|
| 24 |
-
##
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
│
|
| 45 |
-
|
| 46 |
-
│
|
| 47 |
-
|
| 48 |
-
│
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
```
|
| 52 |
|
| 53 |
## Key Properties
|
|
@@ -75,11 +195,11 @@ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t — vector decay monoid recurrence
|
|
| 75 |
o_t = q_t · S_t — state readout
|
| 76 |
```
|
| 77 |
|
| 78 |
-
This is a monoid because the binary operator `(
|
| 79 |
|
| 80 |
## Vector Decay — Per-Dimension Memory Lifetimes
|
| 81 |
|
| 82 |
-
Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1
|
| 83 |
|
| 84 |
```
|
| 85 |
S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
|
@@ -89,26 +209,27 @@ This allows different feature dimensions to specialize:
|
|
| 89 |
- **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
|
| 90 |
- **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
|
| 91 |
|
| 92 |
-
The decay gate uses **
|
| 93 |
|
| 94 |
```
|
| 95 |
-
|
| 96 |
```
|
| 97 |
|
| 98 |
| Property | Value |
|
| 99 |
|---|---|
|
| 100 |
-
| Range | α ∈ (0, 1
|
| 101 |
-
| Perfect memory | W·x →
|
| 102 |
-
| Full forgetting | W·x →
|
| 103 |
-
| Stability | α
|
|
|
|
| 104 |
|
| 105 |
## Attention Mask — Padding-Aware Recurrence
|
| 106 |
|
| 107 |
The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
|
| 108 |
|
| 109 |
```
|
| 110 |
-
|
| 111 |
-
k =
|
| 112 |
```
|
| 113 |
|
| 114 |
Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
|
|
@@ -117,9 +238,11 @@ Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identi
|
|
| 117 |
|
| 118 |
- **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix S positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
|
| 119 |
- **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
|
| 120 |
-
- **
|
|
|
|
|
|
|
| 121 |
- **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
|
| 122 |
-
- **
|
| 123 |
|
| 124 |
## Three Forward Paths
|
| 125 |
|
|
@@ -141,7 +264,7 @@ Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identi
|
|
| 141 |
| Layers | 16 |
|
| 142 |
| Attention heads | 32 |
|
| 143 |
| Head dimension | 64 |
|
| 144 |
-
| Decay gate | Vector decay, d=64 per head |
|
| 145 |
| State matrix per head | 64 × 64 = 4,096 floats |
|
| 146 |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
|
| 147 |
| Precision | bfloat16 |
|
|
@@ -207,6 +330,9 @@ monoid_scan_cuda.py # Triton JIT parallel prefix scan (vector decay) + Py
|
|
| 207 |
model.safetensors # Model weights (bfloat16)
|
| 208 |
config.json # Model configuration
|
| 209 |
tokenizer.json # Llama-3.2 tokenizer
|
|
|
|
|
|
|
|
|
|
| 210 |
```
|
| 211 |
|
| 212 |
## Citation
|
|
|
|
| 21 |
|
| 22 |
A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
|
| 23 |
|
| 24 |
+
## SFT Training Curves
|
| 25 |
+
|
| 26 |
+
| Loss | Accuracy |
|
| 27 |
+
|:---:|:---:|
|
| 28 |
+
|  |  |
|
| 29 |
+
|
| 30 |
+
## Core Mechanism
|
| 31 |
+
|
| 32 |
+

|
| 33 |
+
|
| 34 |
+
## Architecture Overview
|
| 35 |
+
|
| 36 |
+
```
|
| 37 |
+
╔═══════════════════════════════════════════════════════════════════════════╗
|
| 38 |
+
║ MonoidForCausalLM (1.34B) ║
|
| 39 |
+
╠═══════════════════════════════════════════════════════════════════════════╣
|
| 40 |
+
║ ║
|
| 41 |
+
║ token_ids ──> [ embed_tokens 128256 × 2048 ] ──> x_0 ║
|
| 42 |
+
║ ║
|
| 43 |
+
║ ┌─────────────────────────┐ ║
|
| 44 |
+
║ │ MonoidDecoderLayer × 16 │ ◄── see detail below ║
|
| 45 |
+
║ └─────────────────────────┘ ║
|
| 46 |
+
║ │ ║
|
| 47 |
+
║ [ RMSNorm ] ║
|
| 48 |
+
║ │ ║
|
| 49 |
+
║ [ lm_head 2048 × 128256 ] ──> logits ║
|
| 50 |
+
║ (tied with embed_tokens) ║
|
| 51 |
+
╚═══════════════════════════════════════════════════════════════════════════╝
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
╔═══════════════════════════════════════════════════════════════════════════╗
|
| 55 |
+
║ MonoidDecoderLayer (× 16 layers) ║
|
| 56 |
+
╠═══════════════════════════════════════════════════════════════════════════╣
|
| 57 |
+
║ ║
|
| 58 |
+
║ x ─────────────────────────────────────────┐ (residual) ║
|
| 59 |
+
║ │ │ ║
|
| 60 |
+
║ [ input_layernorm RMSNorm ] │ ║
|
| 61 |
+
║ │ │ ║
|
| 62 |
+
║ [ MonoidAttention ] ◄── see detail below │ ║
|
| 63 |
+
║ │ │ ║
|
| 64 |
+
║ + <────────────────────────────────────────┘ ║
|
| 65 |
+
║ │ ║
|
| 66 |
+
║ x ─────────────────────────────────────────┐ (residual) ║
|
| 67 |
+
║ │ │ ║
|
| 68 |
+
║ [ post_attention_layernorm RMSNorm ] │ ║
|
| 69 |
+
║ │ │ ║
|
| 70 |
+
║ [ LlamaMLP 2048 → 8192 → 2048 ] │ ║
|
| 71 |
+
║ │ gate_proj ─┐ │ ║
|
| 72 |
+
║ │ up_proj ───┤─> SiLU(gate) ⊙ up │ ║
|
| 73 |
+
║ │ └──> down_proj ──> out │ ║
|
| 74 |
+
║ │ │ ║
|
| 75 |
+
║ + <────────────────────────────────────────┘ ║
|
| 76 |
+
║ │ ║
|
| 77 |
+
║ out ║
|
| 78 |
+
╚═══════════════════════════════════════════════════════════════════════════╝
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
╔═══════════════════════════════════════════════════════════════════════════╗
|
| 82 |
+
║ MonoidAttention (32 heads, d=64 per head) ║
|
| 83 |
+
╠═══════════════════════════════════════════════════════════════════════════╣
|
| 84 |
+
║ ║
|
| 85 |
+
║ x_t ∈ R^{2048} ║
|
| 86 |
+
║ │ ║
|
| 87 |
+
║ ├──> q_proj ──> [B,H,T,d] ──> RMSNorm ──> ×(1/√d) ──────> q_t ║
|
| 88 |
+
║ │ ║
|
| 89 |
+
║ ├──> k_proj ──> [B,H,T,d] ──> RMSNorm ──> SiLU ──────────> k_t ≥0 ║
|
| 90 |
+
║ │ ║
|
| 91 |
+
║ ├──> v_proj ──> [B,H,T,d] ────────────────────────────────> v_t ║
|
| 92 |
+
║ │ ║
|
| 93 |
+
║ ├──> decay_proj ──> Sigmoid ──> α_t ∈ (0,1)^d (vector decay gate) ║
|
| 94 |
+
║ │ bias init = 3.0 ║
|
| 95 |
+
║ │ → σ(3) ≈ 0.95 at start ║
|
| 96 |
+
║ │ ║
|
| 97 |
+
║ └──> gate_proj ──> SiLU ──────> g_t ∈ R^{H*d} (output gate) ║
|
| 98 |
+
║ ║
|
| 99 |
+
║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
|
| 100 |
+
║ Monoid Recurrence (training: parallel prefix scan, decode: O(1)) ║
|
| 101 |
+
║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄���┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
|
| 102 |
+
║ ║
|
| 103 |
+
║ k_t ⊗ v_t ──────────────┐ ║
|
| 104 |
+
║ [d×d] v ║
|
| 105 |
+
║ ┌─────────────────────────┐ ║
|
| 106 |
+
║ S_{t-1} ────> │ S_t = diag(α_t)·S_{t-1}│ ║
|
| 107 |
+
║ [d×d] │ + k_t ⊗ v_t │──> S_t ║
|
| 108 |
+
║ └─────────────────────────┘ [d×d] ║
|
| 109 |
+
║ "compressed causal history" ║
|
| 110 |
+
║ ║
|
| 111 |
+
║ h0 (learnable, zero-init) ──> S_0 at sequence start ║
|
| 112 |
+
║ ║
|
| 113 |
+
║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
|
| 114 |
+
║ Readout + Output Projection ║
|
| 115 |
+
║ ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄ ║
|
| 116 |
+
║ ║
|
| 117 |
+
║ q_t ──> einsum(q, S_t) ──> o_t ──> RMSNorm ──┐ ║
|
| 118 |
+
║ (o_norm) │ ║
|
| 119 |
+
║ v ║
|
| 120 |
+
║ g_t ──────────────────────────────────> g_t ⊙ o_t ──> o_proj ──> out ║
|
| 121 |
+
║ ║
|
| 122 |
+
╚═══════════════════════════════════════════════════════════════════════════╝
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
╔═══════════════════════════════════════════════════════════════════════════╗
|
| 126 |
+
║ MonoidCache — O(1) State (replaces O(T) KV-Cache) ║
|
| 127 |
+
╠═══════════════════════════════════════════════════════════════════════════╣
|
| 128 |
+
║ ║
|
| 129 |
+
║ Transformer KV-Cache: Monoid State Cache: ║
|
| 130 |
+
║ ┌──────────────────┐ ┌──────────────────┐ ║
|
| 131 |
+
║ │ K: [B,H,T,d] │ │ S: [B,H,d,d] │ ← fixed size ║
|
| 132 |
+
║ │ V: [B,H,T,d] │ │ α_acc: [B,H,d] │ ║
|
| 133 |
+
║ │ grows with T ↑↑↑ │ │ per layer │ ║
|
| 134 |
+
║ └──────────────────┘ └──────────────────┘ ║
|
| 135 |
+
║ Memory: O(T·H·d) Memory: O(H·d²) ║
|
| 136 |
+
║ 1000 tok → 2M floats/layer ANY length → 131K floats/layer ║
|
| 137 |
+
║ ║
|
| 138 |
+
║ Decode step: Decode step: ║
|
| 139 |
+
║ o = softmax(q·K^T)·V S_t = α_t·S_{t-1} + k_t⊗v_t ║
|
| 140 |
+
║ scan T keys ↑ o_t = q_t · S_t ║
|
| 141 |
+
║ Time: O(T·d) Time: O(d²) ← constant! ║
|
| 142 |
+
╚═══════════════════════════════════════════════════════════════════════════╝
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
╔═══════════════════════════════════════════════════════════════════════════╗
|
| 146 |
+
║ Weight Transfer from Llama-3.2-1B-Instruct ║
|
| 147 |
+
╠═══════════════════════════════════════════════════════════════════════════╣
|
| 148 |
+
║ ║
|
| 149 |
+
║ Reused directly (frozen-compatible): ║
|
| 150 |
+
║ ┌──────────────────────────────────────────────┐ ║
|
| 151 |
+
║ │ embed_tokens 128256 × 2048 │ ║
|
| 152 |
+
║ │ lm_head 2048 × 128256 (tied) │ ║
|
| 153 |
+
║ │ LlamaMLP × 16 gate/up/down_proj │ ║
|
| 154 |
+
║ │ LlamaRMSNorm × 33 input/post_attn/final │ ║
|
| 155 |
+
║ │ q_proj × 16 2048 → 2048 │ ║
|
| 156 |
+
║ │ k_proj × 16 2048 → 2048 (tiled 8→32 heads from GQA) │ ║
|
| 157 |
+
║ │ v_proj × 16 2048 → 2048 (tiled 8→32 heads from GQA) │ ║
|
| 158 |
+
║ │ o_proj × 16 2048 → 2048 │ ║
|
| 159 |
+
║ └──────────────────────────────────────────────┘ ║
|
| 160 |
+
║ ║
|
| 161 |
+
║ Novel (randomly initialized): ║
|
| 162 |
+
║ ┌──────────────────────────────────────────────┐ ║
|
| 163 |
+
║ │ decay_proj × 16 2048 → 2048 (bias=3.0) │ ║
|
| 164 |
+
║ │ gate_proj × 16 2048 → 2048 (std=0.01) │ ║
|
| 165 |
+
║ │ q_norm × 16 RMSNorm(64) │ ║
|
| 166 |
+
║ │ k_norm × 16 RMSNorm(64) │ ║
|
| 167 |
+
║ │ o_norm × 16 RMSNorm(64) (weight=1) │ ║
|
| 168 |
+
║ │ h0 × 16 [1,32,64,64] (zeros) │ ║
|
| 169 |
+
║ └──────────────────────────────────────────────┘ ║
|
| 170 |
+
╚═══════════════════════════════════════════════════════════════════════════╝
|
| 171 |
```
|
| 172 |
|
| 173 |
## Key Properties
|
|
|
|
| 195 |
o_t = q_t · S_t — state readout
|
| 196 |
```
|
| 197 |
|
| 198 |
+
This is a monoid because the binary operator `(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)` is **associative**, enabling O(T) parallel prefix scan for training and O(1) sequential update for inference.
|
| 199 |
|
| 200 |
## Vector Decay — Per-Dimension Memory Lifetimes
|
| 201 |
|
| 202 |
+
Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1):
|
| 203 |
|
| 204 |
```
|
| 205 |
S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
|
|
|
| 209 |
- **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
|
| 210 |
- **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
|
| 211 |
|
| 212 |
+
The decay gate uses **Sigmoid** activation:
|
| 213 |
|
| 214 |
```
|
| 215 |
+
α_t = σ(W·x_t + b)
|
| 216 |
```
|
| 217 |
|
| 218 |
| Property | Value |
|
| 219 |
|---|---|
|
| 220 |
+
| Range | α ∈ (0, 1) — bounded, no explosion |
|
| 221 |
+
| Perfect memory | W·x → +∞ ⟹ σ → 1 (lossless retention) |
|
| 222 |
+
| Full forgetting | W·x → -∞ ⟹ σ → 0 (complete reset) |
|
| 223 |
+
| Stability | α < 1 by construction — no divergence regardless of input magnitude |
|
| 224 |
+
| Bias init | b = 3.0 ⟹ σ(3) ≈ 0.95, model starts in "mostly remember" mode |
|
| 225 |
|
| 226 |
## Attention Mask — Padding-Aware Recurrence
|
| 227 |
|
| 228 |
The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
|
| 229 |
|
| 230 |
```
|
| 231 |
+
α = α * mask + (1 - mask) → α = 1 (preserve state unchanged)
|
| 232 |
+
k = k * mask, v = v * mask → kv = 0 (no information injected)
|
| 233 |
```
|
| 234 |
|
| 235 |
Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
|
|
|
|
| 238 |
|
| 239 |
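- **SiLU-activated keys**: `k = SiLU(k_proj(x))` keeps keys approximately non-negative (SiLU is bounded below by ≈ −0.28, so small negative values are possible), which makes the state matrix S behave roughly like a positive semi-definite (PSD) accumulator — an intuition rather than a strict guarantee, since S = Σ k⊗v is not symmetric in general. This limits "feature erasure" where one token's contribution cancels another's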
|
| 240 |
- **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
|
| 241 |
+
- **Output Norm**: RMSNorm on the readout o after `q·S`, further stabilizing scale before gating
|
| 242 |
+
- **Output Gate**: `gate = SiLU(gate_proj(x))`, modulates the multi-head readout before o_proj (similar to GLA/RetNet). Lets the model suppress or amplify specific head outputs conditioned on the current input
|
| 243 |
+
- **Sigmoid decay gate**: Ensures α ∈ (0, 1) by construction — allows near-perfect memory (α→1) while preventing state explosion (α>1). Bias initialized to 3.0 so σ(3)≈0.95, starting in high-retention mode
|
| 244 |
- **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
|
| 245 |
+
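- **Log-space decay accumulation**: The cumulative decay product over a sequence (`decay_acc`, returned alongside the final state for the inference handoff) is computed in log-space as `exp(Σ log α)` to avoid numerical underflow over long sequences; the prefix scan itself consumes the sigmoid decays α directly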
|
| 246 |
|
| 247 |
## Three Forward Paths
|
| 248 |
|
|
|
|
| 264 |
| Layers | 16 |
|
| 265 |
| Attention heads | 32 |
|
| 266 |
| Head dimension | 64 |
|
| 267 |
+
| Decay gate | Vector decay (Sigmoid), d=64 per head |
|
| 268 |
| State matrix per head | 64 × 64 = 4,096 floats |
|
| 269 |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
|
| 270 |
| Precision | bfloat16 |
|
|
|
|
| 330 |
model.safetensors # Model weights (bfloat16)
|
| 331 |
config.json # Model configuration
|
| 332 |
tokenizer.json # Llama-3.2 tokenizer
|
| 333 |
+
ARCH.png # Core mechanism diagram (monoid recurrence + parallel scan)
|
| 334 |
+
ACC_SPAR.png # SFT accuracy curve
|
| 335 |
+
LOSS_SPAR.png # SFT loss curve
|
| 336 |
```
|
| 337 |
|
| 338 |
## Citation
|
config.json
CHANGED
|
@@ -23,5 +23,6 @@
|
|
| 23 |
"pad_token_id": 128009,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"transformers_version": "4.57.6",
|
|
|
|
| 26 |
"vocab_size": 128256
|
| 27 |
}
|
|
|
|
| 23 |
"pad_token_id": 128009,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"transformers_version": "4.57.6",
|
| 26 |
+
"use_cache": false,
|
| 27 |
"vocab_size": 128256
|
| 28 |
}
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a423fff0e285f14e1da9e2bc7ace8ea963c408ad9715bb485c365affb0da4cf1
|
| 3 |
+
size 2945686352
|
monoid_scan_cuda.py
CHANGED
|
@@ -3,18 +3,18 @@ monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
|
|
| 3 |
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
|
| 4 |
|
| 5 |
This module implements the parallel prefix scan for the vector-decay monoid recurrence:
|
| 6 |
-
y_t[i,:] =
|
| 7 |
本模块实现向量衰减幺半群递推的并行前缀扫描:
|
| 8 |
-
y_t[i,:] =
|
| 9 |
|
| 10 |
This is the computational backbone of Monoid Attention's state compression.
|
| 11 |
这是幺半群注意力状态压缩的计算骨干。
|
| 12 |
|
| 13 |
Vector decay: each dimension of the D_k×D_v state matrix has its own
|
| 14 |
-
per-dimension decay rate α_t ∈
|
| 15 |
-
dimensions to have independent memory lifetimes
|
| 16 |
-
local syntax, slow-decaying for global entity memory).
|
| 17 |
-
向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈
|
| 18 |
使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
|
| 19 |
|
| 20 |
Implementation:
|
|
@@ -22,13 +22,13 @@ Implementation:
|
|
| 22 |
Each program handles one row of the state matrix (D_v elements)
|
| 23 |
with a scalar decay per row.
|
| 24 |
Backward: reverse-order adjoint scan for gradient computation.
|
| 25 |
-
Per-row reduction for
|
| 26 |
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
|
| 27 |
|
| 28 |
前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
|
| 29 |
每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
|
| 30 |
反向: 逆序伴随变量扫描计算梯度。
|
| 31 |
-
逐行归约计算
|
| 32 |
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
|
| 33 |
"""
|
| 34 |
|
|
@@ -52,28 +52,28 @@ except ImportError:
|
|
| 52 |
# 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
|
| 53 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 54 |
|
| 55 |
-
def _sequential_scan(
|
| 56 |
"""
|
| 57 |
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
|
| 58 |
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
|
| 59 |
|
| 60 |
Implements the vector-decay monoid recurrence step by step:
|
| 61 |
acc_0 = 0
|
| 62 |
-
acc_t[i,:] =
|
| 63 |
This is O(T) sequential — correct but slow on GPU.
|
| 64 |
逐步实现向量衰减幺半群递推:
|
| 65 |
acc_0 = 0
|
| 66 |
-
acc_t[i,:] =
|
| 67 |
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
|
| 68 |
|
| 69 |
Args:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
values:
|
| 73 |
-
|
| 74 |
Returns:
|
| 75 |
-
output:
|
| 76 |
-
|
| 77 |
"""
|
| 78 |
B, H, T, D_k, D_v = values.shape
|
| 79 |
out = torch.empty_like(values)
|
|
@@ -83,7 +83,7 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
|
| 83 |
for t in range(T):
|
| 84 |
# S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
|
| 85 |
# S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
|
| 86 |
-
decay_t =
|
| 87 |
acc = acc * decay_t + values[:, :, t]
|
| 88 |
out[:, :, t] = acc
|
| 89 |
return out
|
|
@@ -140,10 +140,9 @@ if HAS_TRITON:
|
|
| 140 |
o_base = O_ptr + bhdk * s_o_bhdk
|
| 141 |
|
| 142 |
for t in range(T):
|
| 143 |
-
# Load scalar
|
| 144 |
-
# 加载此行在时刻 t 的标量
|
| 145 |
-
|
| 146 |
-
decay = tl.exp(ld_val)
|
| 147 |
|
| 148 |
# Load kv_t[row, :] (one row of the outer product)
|
| 149 |
# 加载 kv_t[行, :] (外积的一行)
|
|
@@ -178,12 +177,12 @@ if HAS_TRITON:
|
|
| 178 |
反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
|
| 179 |
|
| 180 |
Each program handles one row of the state matrix (one d_k dimension).
|
| 181 |
-
The decay for this row is a scalar, so the
|
| 182 |
-
∂L/∂
|
| 183 |
The sum over j (D_v) is computed within this single program — no atomic_add.
|
| 184 |
每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
|
| 185 |
-
该行的衰减是标量, 因此
|
| 186 |
-
∂L/∂
|
| 187 |
对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
|
| 188 |
"""
|
| 189 |
bhdk = tl.program_id(0)
|
|
@@ -216,19 +215,18 @@ if HAS_TRITON:
|
|
| 216 |
lam, mask=dv_mask,
|
| 217 |
)
|
| 218 |
|
| 219 |
-
# ∂L/∂
|
| 220 |
# Per-row scalar gradient: sum over D_v within this program.
|
| 221 |
# 逐行标量梯度: 在此 program 内对 D_v 求和。
|
| 222 |
-
|
| 223 |
-
a_t = tl.exp(ld_val)
|
| 224 |
|
| 225 |
if t > 0:
|
| 226 |
y_prev = tl.load(
|
| 227 |
O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
|
| 228 |
mask=dv_mask, other=0.0,
|
| 229 |
).to(tl.float32)
|
| 230 |
-
|
| 231 |
-
tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t,
|
| 232 |
|
| 233 |
# Prepare for next step (t-1): adj = a_t · λ_t
|
| 234 |
# 为下一步 (t-1) 准备: adj = a_t · λ_t
|
|
@@ -255,16 +253,16 @@ if HAS_TRITON:
|
|
| 255 |
逐行归约消除大部分 atomic_add 开销。
|
| 256 |
"""
|
| 257 |
@staticmethod
|
| 258 |
-
def forward(ctx,
|
| 259 |
B, H, T, D_k, D_v = values.shape
|
| 260 |
|
| 261 |
# Reshape for row-parallel kernel:
|
| 262 |
-
#
|
| 263 |
-
# values:
|
| 264 |
# 为行并行核函数重塑:
|
| 265 |
-
#
|
| 266 |
-
# values:
|
| 267 |
-
ld_flat =
|
| 268 |
v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
|
| 269 |
o_flat = torch.empty_like(v_flat)
|
| 270 |
|
|
@@ -283,8 +281,8 @@ if HAS_TRITON:
|
|
| 283 |
BLOCK_DV=BLOCK_DV,
|
| 284 |
)
|
| 285 |
|
| 286 |
-
# Save for backward: need
|
| 287 |
-
# 为反向传播保存: 需要
|
| 288 |
ctx.save_for_backward(ld_flat, o_flat)
|
| 289 |
ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
|
| 290 |
# Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
|
|
@@ -318,15 +316,15 @@ if HAS_TRITON:
|
|
| 318 |
# Reshape gradients back to original layout
|
| 319 |
# 重塑梯度回原始布局
|
| 320 |
# gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
|
| 321 |
-
|
| 322 |
# gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
|
| 323 |
grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
|
| 324 |
-
return
|
| 325 |
|
| 326 |
-
def _triton_parallel_scan(
|
| 327 |
"""Triton-accelerated parallel scan entry point (vector decay).
|
| 328 |
Triton 加速的并行扫描入口 (向量衰减)。"""
|
| 329 |
-
return _ParallelScanFn.apply(
|
| 330 |
|
| 331 |
else:
|
| 332 |
_triton_parallel_scan = None
|
|
@@ -336,7 +334,7 @@ else:
|
|
| 336 |
# Public API / 公共接口
|
| 337 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 338 |
|
| 339 |
-
def parallel_scan(
|
| 340 |
"""
|
| 341 |
Parallel prefix scan — computes all prefix monoid sums (vector decay).
|
| 342 |
并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
|
|
@@ -357,21 +355,21 @@ def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
|
| 357 |
CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
|
| 358 |
|
| 359 |
Args:
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
values:
|
| 363 |
-
|
| 364 |
Returns:
|
| 365 |
-
states:
|
| 366 |
-
|
| 367 |
"""
|
| 368 |
if _triton_parallel_scan is not None and values.is_cuda:
|
| 369 |
-
return _triton_parallel_scan(
|
| 370 |
-
return _sequential_scan(
|
| 371 |
|
| 372 |
|
| 373 |
def parallel_scan_with_state(
|
| 374 |
-
|
| 375 |
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
|
| 376 |
"""
|
| 377 |
Parallel prefix scan + extract final state for inference handoff (vector decay).
|
|
@@ -389,23 +387,23 @@ def parallel_scan_with_state(
|
|
| 389 |
这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
|
| 390 |
|
| 391 |
Args:
|
| 392 |
-
|
| 393 |
-
values:
|
| 394 |
|
| 395 |
Returns:
|
| 396 |
output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
|
| 397 |
所有前缀状态
|
| 398 |
-
final_state: (
|
| 399 |
-
|
| 400 |
-
累积
|
| 401 |
final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
|
| 402 |
S_T, 压缩的因果摘要
|
| 403 |
"""
|
| 404 |
-
output = parallel_scan(
|
| 405 |
-
#
|
| 406 |
-
# 对所有
|
| 407 |
-
|
| 408 |
# The last timestep's state IS the full causal summary
|
| 409 |
# 最后一个时间步的状态就是完整的因果摘要
|
| 410 |
final_state = output[:, :, -1] # [B, H, D_k, D_v]
|
| 411 |
-
return output, (
|
|
|
|
| 3 |
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
|
| 4 |
|
| 5 |
This module implements the parallel prefix scan for the vector-decay monoid recurrence:
|
| 6 |
+
y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:]
|
| 7 |
本模块实现向量衰减幺半群递推的并行前缀扫描:
|
| 8 |
+
y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:]
|
| 9 |
|
| 10 |
This is the computational backbone of Monoid Attention's state compression.
|
| 11 |
这是幺半群注意力状态压缩的计算骨干。
|
| 12 |
|
| 13 |
Vector decay: each dimension of the D_k×D_v state matrix has its own
|
| 14 |
+
per-dimension decay rate α_t ∈ (0,1)^{D_k} (sigmoid output), enabling
|
| 15 |
+
different feature dimensions to have independent memory lifetimes
|
| 16 |
+
(fast-decaying for local syntax, slow-decaying for global entity memory).
|
| 17 |
+
向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ (0,1)^{D_k} (sigmoid 输出),
|
| 18 |
使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
|
| 19 |
|
| 20 |
Implementation:
|
|
|
|
| 22 |
Each program handles one row of the state matrix (D_v elements)
|
| 23 |
with a scalar decay per row.
|
| 24 |
Backward: reverse-order adjoint scan for gradient computation.
|
| 25 |
+
Per-row reduction for decay gradient (no atomic_add needed).
|
| 26 |
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
|
| 27 |
|
| 28 |
前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
|
| 29 |
每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
|
| 30 |
反向: 逆序伴随变量扫描计算梯度。
|
| 31 |
+
逐行归约计算 decay 梯度 (无需 atomic_add)。
|
| 32 |
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
|
| 33 |
"""
|
| 34 |
|
|
|
|
| 52 |
# 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
|
| 53 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 54 |
|
| 55 |
+
def _sequential_scan(decays: Tensor, values: Tensor) -> Tensor:
|
| 56 |
"""
|
| 57 |
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
|
| 58 |
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
|
| 59 |
|
| 60 |
Implements the vector-decay monoid recurrence step by step:
|
| 61 |
acc_0 = 0
|
| 62 |
+
acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:]
|
| 63 |
This is O(T) sequential — correct but slow on GPU.
|
| 64 |
逐步实现向量衰减幺半群递推:
|
| 65 |
acc_0 = 0
|
| 66 |
+
acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:]
|
| 67 |
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
|
| 68 |
|
| 69 |
Args:
|
| 70 |
+
decays: [B, H, T, D_k] — per-dimension per-step decay gates α_t ∈ (0,1)
|
| 71 |
+
每维度每步衰减门 α_t ∈ (0,1)
|
| 72 |
+
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
|
| 73 |
+
待累积的外积 k_t⊗v_t
|
| 74 |
Returns:
|
| 75 |
+
output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
|
| 76 |
+
所有前缀状态 S_1, ..., S_T
|
| 77 |
"""
|
| 78 |
B, H, T, D_k, D_v = values.shape
|
| 79 |
out = torch.empty_like(values)
|
|
|
|
| 83 |
for t in range(T):
|
| 84 |
# S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
|
| 85 |
# S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
|
| 86 |
+
decay_t = decays[:, :, t].unsqueeze(-1) # [B,H,D_k,1]
|
| 87 |
acc = acc * decay_t + values[:, :, t]
|
| 88 |
out[:, :, t] = acc
|
| 89 |
return out
|
|
|
|
| 140 |
o_base = O_ptr + bhdk * s_o_bhdk
|
| 141 |
|
| 142 |
for t in range(T):
|
| 143 |
+
# Load scalar decay for this row at time t
|
| 144 |
+
# 加载此行在时刻 t 的标量 decay
|
| 145 |
+
decay = tl.load(ld_base + t * s_ld_t).to(tl.float32)
|
|
|
|
| 146 |
|
| 147 |
# Load kv_t[row, :] (one row of the outer product)
|
| 148 |
# 加载 kv_t[行, :] (外积的一行)
|
|
|
|
| 177 |
反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
|
| 178 |
|
| 179 |
Each program handles one row of the state matrix (one d_k dimension).
|
| 180 |
+
The decay for this row is a scalar, so the decay gradient is:
|
| 181 |
+
∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 182 |
The sum over j (D_v) is computed within this single program — no atomic_add.
|
| 183 |
每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
|
| 184 |
+
该行的衰减是标量, 因此 decay 梯度为:
|
| 185 |
+
∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 186 |
对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
|
| 187 |
"""
|
| 188 |
bhdk = tl.program_id(0)
|
|
|
|
| 215 |
lam, mask=dv_mask,
|
| 216 |
)
|
| 217 |
|
| 218 |
+
# ∂L/∂α_t[i] = Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 219 |
# Per-row scalar gradient: sum over D_v within this program.
|
| 220 |
# 逐行标量梯度: 在此 program 内对 D_v 求和。
|
| 221 |
+
a_t = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
|
|
|
|
| 222 |
|
| 223 |
if t > 0:
|
| 224 |
y_prev = tl.load(
|
| 225 |
O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
|
| 226 |
mask=dv_mask, other=0.0,
|
| 227 |
).to(tl.float32)
|
| 228 |
+
grad_d = tl.sum(lam * y_prev)
|
| 229 |
+
tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_d)
|
| 230 |
|
| 231 |
# Prepare for next step (t-1): adj = a_t · λ_t
|
| 232 |
# 为下一步 (t-1) 准备: adj = a_t · λ_t
|
|
|
|
| 253 |
逐行归约消除大部分 atomic_add 开销。
|
| 254 |
"""
|
| 255 |
@staticmethod
|
| 256 |
+
def forward(ctx, decays: Tensor, values: Tensor) -> Tensor:
|
| 257 |
B, H, T, D_k, D_v = values.shape
|
| 258 |
|
| 259 |
# Reshape for row-parallel kernel:
|
| 260 |
+
# decays: [B, H, T, D_k] → permute to [B, H, D_k, T] → [B*H*D_k, T]
|
| 261 |
+
# values: [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
|
| 262 |
# 为行并行核函数重塑:
|
| 263 |
+
# decays: [B, H, T, D_k] → 转置为 [B, H, D_k, T] → [B*H*D_k, T]
|
| 264 |
+
# values: [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
|
| 265 |
+
ld_flat = decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
|
| 266 |
v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
|
| 267 |
o_flat = torch.empty_like(v_flat)
|
| 268 |
|
|
|
|
| 281 |
BLOCK_DV=BLOCK_DV,
|
| 282 |
)
|
| 283 |
|
| 284 |
+
# Save for backward: need decays and forward outputs y_t
|
| 285 |
+
# 为反向传播保存: 需要 decays 和前向输出 y_t
|
| 286 |
ctx.save_for_backward(ld_flat, o_flat)
|
| 287 |
ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
|
| 288 |
# Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
|
|
|
|
| 316 |
# Reshape gradients back to original layout
|
| 317 |
# 重塑梯度回原始布局
|
| 318 |
# gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
|
| 319 |
+
grad_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
|
| 320 |
# gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
|
| 321 |
grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
|
| 322 |
+
return grad_decays, grad_values
|
| 323 |
|
| 324 |
+
def _triton_parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
|
| 325 |
"""Triton-accelerated parallel scan entry point (vector decay).
|
| 326 |
Triton 加速的并行扫描入口 (向量衰减)。"""
|
| 327 |
+
return _ParallelScanFn.apply(decays, values)
|
| 328 |
|
| 329 |
else:
|
| 330 |
_triton_parallel_scan = None
|
|
|
|
| 334 |
# Public API / 公共接口
|
| 335 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 336 |
|
| 337 |
+
def parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
|
| 338 |
"""
|
| 339 |
Parallel prefix scan — computes all prefix monoid sums (vector decay).
|
| 340 |
并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
|
|
|
|
| 355 |
CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
|
| 356 |
|
| 357 |
Args:
|
| 358 |
+
decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1) (sigmoid output)
|
| 359 |
+
每维度衰减门 α_t ∈ (0,1) (sigmoid 输出)
|
| 360 |
+
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
|
| 361 |
+
外积 k_t⊗v_t
|
| 362 |
Returns:
|
| 363 |
+
states: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
|
| 364 |
+
所有前缀状态 S_1..S_T
|
| 365 |
"""
|
| 366 |
if _triton_parallel_scan is not None and values.is_cuda:
|
| 367 |
+
return _triton_parallel_scan(decays, values)
|
| 368 |
+
return _sequential_scan(decays, values)
|
| 369 |
|
| 370 |
|
| 371 |
def parallel_scan_with_state(
|
| 372 |
+
decays: Tensor, values: Tensor,
|
| 373 |
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
|
| 374 |
"""
|
| 375 |
Parallel prefix scan + extract final state for inference handoff (vector decay).
|
|
|
|
| 387 |
这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
|
| 388 |
|
| 389 |
Args:
|
| 390 |
+
decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1)
|
| 391 |
+
values: [B, H, T, D_k, D_v]
|
| 392 |
|
| 393 |
Returns:
|
| 394 |
output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
|
| 395 |
所有前缀状态
|
| 396 |
+
final_state: (decay_acc, S_T) where
|
| 397 |
+
decay_acc: [B, H, D_k] — accumulated decay product (for future monoid_op)
|
| 398 |
+
累积衰减乘积 (供后续 monoid_op 使用)
|
| 399 |
S_T: [B, H, D_k, D_v] — final state S_T, the compressed causal summary
|
| 400 |
S_T, 压缩的因果摘要
|
| 401 |
"""
|
| 402 |
+
output = parallel_scan(decays, values)
|
| 403 |
+
# Product of all decays over T — computed as exp(Σ log α) (sum of logs, not log-sum-exp) for numerical stability in bf16
|
| 404 |
+
# 对所有 decay 沿 T 求积 — 通过 exp(Σ log α) (对数求和, 而非 log-sum-exp) 计算, 保证 bf16 数值稳定
|
| 405 |
+
decay_acc = torch.exp(torch.sum(torch.log(decays + 1e-8), dim=2)) # [B, H, D_k]
|
| 406 |
# The last timestep's state IS the full causal summary
|
| 407 |
# 最后一个时间步的状态就是完整的因果摘要
|
| 408 |
final_state = output[:, :, -1] # [B, H, D_k, D_v]
|
| 409 |
+
return output, (decay_acc, final_state)
|
training_args.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a74d419f03ffc06a6f989ef4dc1768ad7f4298b971f129f0a2e121514a016053
|
| 3 |
+
size 6353
|