"""
monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
This module implements the parallel prefix scan for the monoid recurrence:
y_t = exp(log_decay_t) · y_{t-1} + x_t
本模块实现幺半群递推的并行前缀扫描:
y_t = exp(log_decay_t) · y_{t-1} + x_t
This is the computational backbone of Monoid Attention's state compression.
这是幺半群注意力状态压缩的计算骨干。
Why parallel prefix scan matters / 并行前缀扫描为什么重要:
The monoid recurrence S_t = α_t·S_{t-1} + kv_t is inherently sequential.
However, because (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
is ASSOCIATIVE, we can compute all prefix sums S_1..S_T via a parallel
reduction tree in O(log T) depth instead of O(T) sequential steps.
幺半群递推 S_t = α_t·S_{t-1} + kv_t 本质上是串行的。
但因为 (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
满足结合律, 我们可以通过并行归约树在 O(log T) 深度内计算所有前缀和 S_1..S_T,
而非 O(T) 的串行步骤。
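Note: the kernels in this file do not build that O(log T) reduction tree.
They run the recurrence as O(T) sequential steps along T and get their
parallelism from the independent B*H*D lanes (see "Implementation" below).
Training uses this full-sequence scan (this file).
Inference uses O(1) sequential monoid_op (in MonoidForCausalLM.py).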
Implementation:
Forward: sequential scan along T, parallelized across B*H*D on GPU.
Backward: reverse-order adjoint scan for gradient computation.
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
前向: 沿 T 维顺序扫描, 跨 B*H*D 在 GPU 上并行。
反向: 逆序伴随变量扫描计算梯度。
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
"""
from __future__ import annotations
import torch
from torch import Tensor
from torch.autograd import Function
from typing import Tuple
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fallback: pure PyTorch sequential scan
# 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
"""
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
Implements the monoid recurrence step by step:
acc_0 = 0
acc_t = exp(log_decay_t) · acc_{t-1} + values_t
This is O(T) sequential — correct but slow on GPU.
逐步实现幺半群递推:
acc_0 = 0
acc_t = exp(log_decay_t) · acc_{t-1} + values_t
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
Args:
log_decays: [B, H, T, 1] — log of per-head per-step decay gates
每头每步衰减门的对数
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
待累积的外积 k_t⊗v_t
Returns:
output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
所有前缀状态 S_1, ..., S_T
"""
B, H, T, D_k, D_v = values.shape
out = torch.empty_like(values)
# acc represents S_t — the compressed causal state at time t
# acc 代表 S_t — 时刻 t 的压缩因果状态
acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
for t in range(T):
# S_t = α_t · S_{t-1} + kv_t (the core monoid recurrence)
# S_t = α_t · S_{t-1} + kv_t (核心幺半群递推)
decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,1,1]
acc = acc * decay_t + values[:, :, t]
out[:, :, t] = acc
return out
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Triton Kernels — GPU-accelerated scan
# Triton 核函数 — GPU 加速扫描
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if HAS_TRITON:

    @triton.jit
    def _scan_fwd_kernel(
        LD_ptr, V_ptr, O_ptr,
        T, D,
        s_ld_bh, s_ld_t,
        s_v_bh, s_v_t, s_v_d,
        s_o_bh, s_o_t, s_o_d,
        BLOCK_D: tl.constexpr,
    ):
        """
        Forward scan kernel - computes all prefix states S_1..S_T.

        Parallelization strategy:
          - program_id(0) = bh: one program per (batch, head) pair
          - program_id(1) = db: one program per block of the D dimension
          - sequential loop over T (the causal recurrence is sequential)

        Each program computes acc_t = exp(ld_t) * acc_{t-1} + val_t for a
        BLOCK_D-wide slice of the flattened d_k*d_v state matrix.

        Note: the T loop is sequential inside each program, but
        B*H*ceil(D/BLOCK_D) programs are launched side by side.
        """
        # Widen to int64 so that bh * stride cannot wrap in 32-bit
        # arithmetic when the state tensor holds 2^31 or more elements.
        bh = tl.program_id(0).to(tl.int64)
        db = tl.program_id(1)
        d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_offs < D

        # acc = S_0 = 0 (identity element of the monoid), kept in float32.
        acc = tl.zeros([BLOCK_D], dtype=tl.float32)

        ld_base = LD_ptr + bh * s_ld_bh
        v_base = V_ptr + bh * s_v_bh
        o_base = O_ptr + bh * s_o_bh

        for t in range(T):
            # Load log_decay and compute decay = exp(log_a_t).
            ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
            decay = tl.exp(ld_val)

            # Load kv_t (a slice of the outer product k_t (x) v_t).
            val = tl.load(
                v_base + t * s_v_t + d_offs * s_v_d,
                mask=d_mask, other=0.0,
            ).to(tl.float32)

            # Core recurrence: S_t = a_t * S_{t-1} + kv_t
            acc = acc * decay + val

            # Store S_t.
            tl.store(
                o_base + t * s_o_t + d_offs * s_o_d,
                acc, mask=d_mask,
            )

    @triton.jit
    def _scan_bwd_kernel(
        LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
        T, D,
        s_ld_bh, s_ld_t,
        s_o_bh, s_o_t, s_o_d,
        s_go_bh, s_go_t, s_go_d,
        s_gv_bh, s_gv_t, s_gv_d,
        s_gld_bh, s_gld_t,
        BLOCK_D: tl.constexpr,
    ):
        """
        Backward scan kernel - computes gradients via the adjoint method.

        Forward recurrence:   y_t = a_t * y_{t-1} + x_t

        Adjoint (reverse-time) recurrence for the multiplier lambda:
            lambda_t = dL/dy_t + a_{t+1} * lambda_{t+1}

        Gradients:
            dL/dx_t     = lambda_t
            dL/dlog_a_t = a_t * sum_D(lambda_t * y_{t-1})

        The log-decay gradient is reduced over D. Each D-block program adds
        its partial sum into GLD_ptr with an atomic add, so the GLD buffer
        must be zero-initialised by the caller. No term is added for t == 0
        because y_{-1} = 0.
        """
        # Widen to int64 so that bh * stride cannot wrap in 32-bit
        # arithmetic when the state tensor holds 2^31 or more elements.
        bh = tl.program_id(0).to(tl.int64)
        db = tl.program_id(1)
        d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_offs < D

        # adj holds a_{t+1} * lambda_{t+1}; it is 0 past the sequence end.
        adj = tl.zeros([BLOCK_D], dtype=tl.float32)

        for t_rev in range(T):
            t = T - 1 - t_rev  # walk time in reverse

            # Load dL/dy_t (upstream gradient).
            go = tl.load(
                GO_ptr + bh * s_go_bh + t * s_go_t + d_offs * s_go_d,
                mask=d_mask, other=0.0,
            ).to(tl.float32)

            # Adjoint: lambda_t = dL/dy_t + a_{t+1} * lambda_{t+1}
            lam = go + adj

            # dL/dx_t = lambda_t (gradient of the values).
            tl.store(
                GV_ptr + bh * s_gv_bh + t * s_gv_t + d_offs * s_gv_d,
                lam, mask=d_mask,
            )

            # dL/dlog_a_t = a_t * sum_D(lambda_t * y_{t-1})
            ld_val = tl.load(LD_ptr + bh * s_ld_bh + t * s_ld_t).to(tl.float32)
            a_t = tl.exp(ld_val)

            if t > 0:
                y_prev = tl.load(
                    O_ptr + bh * s_o_bh + (t - 1) * s_o_t + d_offs * s_o_d,
                    mask=d_mask, other=0.0,
                ).to(tl.float32)
                grad_ld_partial = tl.sum(lam * y_prev) * a_t
                tl.atomic_add(GLD_ptr + bh * s_gld_bh + t * s_gld_t, grad_ld_partial)

            # Prepare for the next step (t-1): adj = a_t * lambda_t
            adj = a_t * lam

    # ---------------------------------------------------------------
    # Autograd Function - bridges the Triton kernels with PyTorch autograd
    # ---------------------------------------------------------------
    class _ParallelScanFn(Function):
        """
        Custom autograd function for the prefix scan.

        Forward launches _scan_fwd_kernel to compute all prefix states.
        Backward launches _scan_bwd_kernel to compute gradients via the
        adjoint method.
        """

        @staticmethod
        def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
            """
            Args:
                log_decays: [B, H, T, 1] log decay gates.
                values:     [B, H, T, D_k, D_v] terms to accumulate.

            Returns:
                [B, H, T, D_k, D_v] prefix states, same dtype as ``values``.
            """
            B, H, T, D_k, D_v = values.shape
            D = D_k * D_v  # flattened state dimension
            BH = B * H

            # Flatten: [B,H,T,1] -> [BH, T], [B,H,T,Dk,Dv] -> [BH, T, D]
            ld_flat = log_decays.squeeze(-1).contiguous().reshape(BH, T)
            v_flat = values.reshape(BH, T, D).contiguous()
            o_flat = torch.empty_like(v_flat)

            # max(D, 1) keeps BLOCK_D non-zero so cdiv never divides by 0.
            BLOCK_D = min(triton.next_power_of_2(max(D, 1)), 1024)
            # Grid: (BH, ceil(D/BLOCK_D)) - one program per (batch*head, D-block)
            grid = (BH, triton.cdiv(D, BLOCK_D))

            if v_flat.numel() > 0:
                # Launch on the device that owns the tensors, which is not
                # necessarily the current CUDA device in multi-GPU runs.
                with torch.cuda.device(values.device):
                    _scan_fwd_kernel[grid](
                        ld_flat, v_flat, o_flat,
                        T, D,
                        ld_flat.stride(0), ld_flat.stride(1),
                        v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
                        o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                        BLOCK_D=BLOCK_D,
                    )

            # Backward needs the log decays and the forward outputs y_t.
            ctx.save_for_backward(ld_flat, o_flat)
            ctx.shape_info = (B, H, T, D_k, D_v, D, BH, BLOCK_D)
            return o_flat.reshape(B, H, T, D_k, D_v)

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            """
            Args:
                grad_output: [B, H, T, D_k, D_v] gradient w.r.t. the states.

            Returns:
                (grad_log_decays [B, H, T, 1], grad_values [B, H, T, D_k, D_v])
            """
            ld_flat, o_flat = ctx.saved_tensors
            B, H, T, D_k, D_v, D, BH, BLOCK_D = ctx.shape_info

            go_flat = grad_output.reshape(BH, T, D).contiguous()
            gv_flat = torch.empty_like(go_flat)
            # float32 and zero-initialised: the kernel accumulates partial
            # sums into this buffer with atomic_add.
            gld_flat = torch.zeros(BH, T, device=ld_flat.device, dtype=torch.float32)

            grid = (BH, triton.cdiv(D, BLOCK_D))

            if go_flat.numel() > 0:
                with torch.cuda.device(go_flat.device):
                    _scan_bwd_kernel[grid](
                        ld_flat, o_flat, go_flat, gv_flat, gld_flat,
                        T, D,
                        ld_flat.stride(0), ld_flat.stride(1),
                        o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                        go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
                        gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
                        gld_flat.stride(0), gld_flat.stride(1),
                        BLOCK_D=BLOCK_D,
                    )

            # Cast to the dtype of log_decays (not of grad_output): a float32
            # decay gate must not have its gradient rounded through the
            # half-precision dtype of the values.
            grad_log_decays = gld_flat.to(ld_flat.dtype).reshape(B, H, T, 1)
            grad_values = gv_flat.reshape(B, H, T, D_k, D_v)
            return grad_log_decays, grad_values

    def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
        """Triton-accelerated scan entry point (differentiable)."""
        return _ParallelScanFn.apply(log_decays, values)

else:
    # Triton is not installed: parallel_scan falls back to the PyTorch loop.
    _triton_parallel_scan = None
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Public API / 公共接口
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Prefix scan over the monoid recurrence S_t = a_t * S_{t-1} + kv_t.

    Returns every prefix state S_1, ..., S_T in one call; this is the
    training-time entry point of Monoid Attention.

    Dispatch:
        CUDA tensor and Triton available -> Triton kernel with custom backward
        otherwise (CPU / MPS / no Triton) -> PyTorch sequential scan

    Args:
        log_decays: [B, H, T, 1] - log of the decay gates a_t.
        values:     [B, H, T, D_k, D_v] - outer products k_t (x) v_t.

    Returns:
        [B, H, T, D_k, D_v] - all prefix states S_1..S_T.
    """
    triton_usable = _triton_parallel_scan is not None and values.is_cuda
    scan_impl = _triton_parallel_scan if triton_usable else _sequential_scan
    return scan_impl(log_decays, values)
def parallel_scan_with_state(
    log_decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """
    Run the prefix scan and also return the final state for inference handoff.

    Intended for prefill: all prefix states are computed as in training, and
    the accumulated state at the last timestep is extracted so that later
    tokens can be produced step by step via monoid_op.

    Args:
        log_decays: [B, H, T, 1]
        values:     [B, H, T, D_k, D_v]

    Returns:
        A pair ``(output, (log_acc, final_state))`` where
            output:      [B, H, T, D_k, D_v] - all prefix states S_1..S_T
            log_acc:     [B, H, 1] - sum of the log-decays over T
            final_state: [B, H, D_k, D_v] - S_T, the state at the last step
    """
    states = parallel_scan(log_decays, values)

    # Total accumulated decay in log space: sum over the time axis.
    per_step = log_decays.squeeze(-1)                     # [B, H, T]
    total_log_decay = torch.sum(per_step, dim=2, keepdim=True)  # [B, H, 1]

    # The state at the last timestep is the full causal summary S_T.
    last_state = states.select(2, -1)                     # [B, H, D_k, D_v]

    return states, (total_log_decay, last_state)
|