"""
monoid_scan_cuda.py: Triton CUDA JIT accelerated parallel prefix scan

This module implements the parallel prefix scan for the monoid recurrence:

    y_t = exp(log_decay_t) · y_{t-1} + x_t

This is the computational backbone of Monoid Attention's state compression.
|
|
|
|
|
Why parallel prefix scan matters:
  Written step by step, the monoid recurrence S_t = α_t·S_{t-1} + kv_t is sequential.
  However, because the operator
      (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
  is ASSOCIATIVE, all prefix states S_1..S_T can be computed with a parallel
  reduction tree in O(log T) depth instead of O(T) sequential steps.
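
Example (a minimal sketch, not part of this module). The helper names
combine, sequential_scan and log_depth_scan are hypothetical, and the inputs
are scalars rather than d_k x d_v matrices. It illustrates how associativity
lets a Hillis-Steele style scan finish in about log2(T) rounds while matching
the step-by-step recurrence:

```python
import math


def combine(left, right):
    # (log_a, S) (+) (log_b, X) = (log_a + log_b, exp(log_b) * S + X)
    log_a, s = left
    log_b, x = right
    return (log_a + log_b, math.exp(log_b) * s + x)


def sequential_scan(elems):
    # O(T) reference: S_t = exp(log_decay_t) * S_{t-1} + kv_t, with S_0 = 0.
    out, acc = [], 0.0
    for log_decay, kv in elems:
        acc = math.exp(log_decay) * acc + kv
        out.append(acc)
    return out


def log_depth_scan(elems):
    # Inclusive Hillis-Steele scan: ceil(log2 T) rounds. Within one round the
    # positions do not depend on each other, so parallel hardware could
    # update them all at once; here they are simply evaluated in a list.
    cur = list(elems)
    offset = 1
    while offset < len(cur):
        cur = [
            combine(cur[i - offset], cur[i]) if i >= offset else cur[i]
            for i in range(len(cur))
        ]
        offset *= 2
    return [s for _, s in cur]


# Hypothetical (log_decay_t, kv_t) pairs.
elems = [(-0.1, 1.0), (-0.5, 2.0), (-0.2, 3.0), (-0.3, 4.0), (-0.05, 5.0)]
seq = sequential_scan(elems)
par = log_depth_scan(elems)
assert all(math.isclose(a, b) for a, b in zip(seq, par))
```

The earlier element is always passed as the left argument of combine; the
operator is associative but not commutative, so the order matters.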
|
|
|
|
|
  Training uses the scan in this file: O(T) steps along T, all T states in one call.
  Inference uses the sequential monoid_op (in MonoidForCausalLM.py): O(1) per token.
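
Example (a minimal sketch under stated assumptions). The real monoid_op
lives in MonoidForCausalLM.py and its signature is not shown here, so the
scalar helper step below is a hypothetical stand-in. The point is only that
scanning a prefix and then applying one constant-time update gives the same
state as scanning the whole sequence:

```python
import math


def scan(log_decays, values):
    # Prefix states S_1..S_T of S_t = exp(log_decay_t) * S_{t-1} + kv_t.
    out, acc = [], 0.0
    for log_decay, kv in zip(log_decays, values):
        acc = math.exp(log_decay) * acc + kv
        out.append(acc)
    return out


def step(state, log_decay, kv):
    # One O(1) decode update; state is (accumulated log-decay, S).
    log_acc, s = state
    return (log_acc + log_decay, math.exp(log_decay) * s + kv)


# Hypothetical scalar inputs.
log_decays = [-0.1, -0.5, -0.2, -0.3]
values = [1.0, 2.0, 3.0, 4.0]

# Prefill on the first three tokens and keep only the final state.
prefix = scan(log_decays[:3], values[:3])
state = (sum(log_decays[:3]), prefix[-1])

# Decode the fourth token with a single constant-time update.
state = step(state, log_decays[3], values[3])

# The handoff matches a scan over all four tokens.
full = scan(log_decays, values)
assert math.isclose(state[1], full[-1])
```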
|
|
|
|
|
Implementation:
  Forward:  sequential scan along T, parallelized across B*H*D on GPU
            (the O(log T) tree form is not used by these kernels).
  Backward: reverse-order adjoint scan for gradient computation.
  Dispatch: CUDA tensors with Triton installed use the Triton kernels;
            everything else (CPU/MPS, or no Triton) uses the PyTorch fallback.
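
Example (a minimal sketch, not the kernel itself). The scalar helpers
forward and backward are hypothetical names. They mirror the reverse-order
adjoint scan described above, and a central finite difference serves as an
independent check on the log-decay gradient:

```python
import math


def forward(log_decays, xs):
    # y_t = exp(log_decay_t) * y_{t-1} + x_t, with y before the first step = 0.
    ys, acc = [], 0.0
    for log_decay, x in zip(log_decays, xs):
        acc = math.exp(log_decay) * acc + x
        ys.append(acc)
    return ys


def backward(log_decays, ys, grad_ys):
    # Reverse-time adjoint: lam_t = dL/dy_t + a_{t+1} * lam_{t+1}.
    T = len(ys)
    grad_x, grad_ld = [0.0] * T, [0.0] * T
    adj = 0.0  # carries a_{t+1} * lam_{t+1}; zero past the last step
    for t in reversed(range(T)):
        lam = grad_ys[t] + adj
        grad_x[t] = lam  # dL/dx_t = lam_t
        a_t = math.exp(log_decays[t])
        y_prev = ys[t - 1] if t > 0 else 0.0
        grad_ld[t] = a_t * lam * y_prev  # dL/dlog_a_t
        adj = a_t * lam
    return grad_x, grad_ld


# Hypothetical scalar inputs.
log_decays = [-0.1, -0.5, -0.2]
xs = [1.0, 2.0, 3.0]
ys = forward(log_decays, xs)

# Loss L = sum_t y_t, so dL/dy_t = 1 for every t.
grad_x, grad_ld = backward(log_decays, ys, [1.0, 1.0, 1.0])

# Central finite difference on log_decays[1].
eps = 1e-6
up = sum(forward([-0.1, -0.5 + eps, -0.2], xs))
down = sum(forward([-0.1, -0.5 - eps, -0.2], xs))
fd = (up - down) / (2 * eps)
assert math.isclose(grad_ld[1], fd, rel_tol=1e-5)
```

In the Triton kernel the product lam * y_prev is additionally summed over the
flattened D axis, because one decay scalar is shared by the whole state matrix.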
|
|
|
|
|
"""
|
|
|
|
|
from __future__ import annotations

from typing import Tuple

import torch
from torch import Tensor
from torch.autograd import Function

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    # Triton is optional; without it every call uses the PyTorch fallback.
    HAS_TRITON = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Pure PyTorch sequential scan fallback (used when CUDA or Triton is unavailable).

    Implements the monoid recurrence step by step:
        acc_0 = 0
        acc_t = exp(log_decay_t) · acc_{t-1} + values_t
    This takes O(T) sequential steps: correct, but slow on GPU.

    Args:
        log_decays: [B, H, T, 1], log of the per-head, per-step decay gates
        values:     [B, H, T, D_k, D_v], outer products k_t⊗v_t to accumulate
    Returns:
        output:     [B, H, T, D_k, D_v], all prefix states S_1, ..., S_T
    """
    B, H, T, D_k, D_v = values.shape
    out = torch.empty_like(values)

    # Running state, initialised to acc_0 = 0.
    acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
    for t in range(T):
        # [B, H, 1] -> [B, H, 1, 1] so the scalar gate broadcasts over (D_k, D_v).
        decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1)
        acc = acc * decay_t + values[:, :, t]
        out[:, :, t] = acc
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if HAS_TRITON:

    @triton.jit
    def _scan_fwd_kernel(
        LD_ptr, V_ptr, O_ptr,
        T, D,
        s_ld_bh, s_ld_t,
        s_v_bh, s_v_t, s_v_d,
        s_o_bh, s_o_t, s_o_d,
        BLOCK_D: tl.constexpr,
    ):
        """
        Forward scan kernel: computes all prefix states S_1..S_T.

        Parallelization strategy:
          - program_id(0) = bh: one program per (batch, head) pair
          - program_id(1) = db: one program per BLOCK_D-wide block of D
          - Sequential loop over T (each step depends on the previous state)

        Each program computes acc_t = exp(ld_t) * acc_{t-1} + val_t
        for a BLOCK_D-wide slice of the flattened d_k*d_v state matrix.

        Note: the T-loop is sequential within each program, but the
        B*H*ceil(D/BLOCK_D) programs can run in parallel on the GPU.
        """
        bh = tl.program_id(0)
        db = tl.program_id(1)
        d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_offs < D  # guards the tail block when D % BLOCK_D != 0

        # Accumulate in float32 regardless of the input dtype.
        acc = tl.zeros([BLOCK_D], dtype=tl.float32)

        ld_base = LD_ptr + bh * s_ld_bh
        v_base = V_ptr + bh * s_v_bh
        o_base = O_ptr + bh * s_o_bh

        for t in range(T):
            # One scalar log-decay per (batch, head, t), shared by the whole slice.
            ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
            decay = tl.exp(ld_val)

            # Load this program's slice of val_t.
            val = tl.load(
                v_base + t * s_v_t + d_offs * s_v_d,
                mask=d_mask, other=0.0,
            ).to(tl.float32)

            # Recurrence step: acc_t = exp(ld_t) * acc_{t-1} + val_t.
            acc = acc * decay + val

            # Write the prefix state for step t.
            tl.store(
                o_base + t * s_o_t + d_offs * s_o_d,
                acc, mask=d_mask,
            )
|
|
|
|
|
    @triton.jit
    def _scan_bwd_kernel(
        LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
        T, D,
        s_ld_bh, s_ld_t,
        s_o_bh, s_o_t, s_o_d,
        s_go_bh, s_go_t, s_go_d,
        s_gv_bh, s_gv_t, s_gv_d,
        s_gld_bh, s_gld_t,
        BLOCK_D: tl.constexpr,
    ):
        """
        Backward scan kernel: computes gradients via the adjoint method.

        The forward recurrence is: y_t = a_t * y_{t-1} + x_t

        The adjoint (reverse-time) recurrence for the Lagrange multiplier λ:
            λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}   (backward in time, zero past the last step)

        Gradients:
            ∂L/∂x_t     = λ_t                        (gradient w.r.t. input values)
            ∂L/∂log_a_t = a_t · Σ_D(λ_t · y_{t-1})   (gradient w.r.t. log-decay)

        The gradient of log_decay is critical for training the decay gate:
        it tells the model how to adjust each head's forgetting rate.
        """
        bh = tl.program_id(0)
        db = tl.program_id(1)
        d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_offs < D

        # adj carries a_{t+1} · λ_{t+1}; it starts at zero for the last step.
        adj = tl.zeros([BLOCK_D], dtype=tl.float32)

        for t_rev in range(T):
            t = T - 1 - t_rev  # walk time backwards

            # Upstream gradient ∂L/∂y_t for this slice.
            go = tl.load(
                GO_ptr + bh * s_go_bh + t * s_go_t + d_offs * s_go_d,
                mask=d_mask, other=0.0,
            ).to(tl.float32)

            # λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
            lam = go + adj

            # ∂L/∂x_t = λ_t
            tl.store(
                GV_ptr + bh * s_gv_bh + t * s_gv_t + d_offs * s_gv_d,
                lam, mask=d_mask,
            )

            ld_val = tl.load(LD_ptr + bh * s_ld_bh + t * s_ld_t).to(tl.float32)
            a_t = tl.exp(ld_val)

            # ∂L/∂log_a_t = a_t · Σ_D(λ_t · y_{t-1}); skipped at t = 0,
            # where the previous state is zero.
            if t > 0:
                y_prev = tl.load(
                    O_ptr + bh * s_o_bh + (t - 1) * s_o_t + d_offs * s_o_d,
                    mask=d_mask, other=0.0,
                ).to(tl.float32)
                grad_ld_partial = tl.sum(lam * y_prev) * a_t
                # Each D-block holds only a partial sum, so blocks combine atomically.
                tl.atomic_add(GLD_ptr + bh * s_gld_bh + t * s_gld_t, grad_ld_partial)

            # Carry a_t · λ_t to step t-1.
            adj = a_t * lam
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    class _ParallelScanFn(Function):
        """
        Custom autograd function for the parallel prefix scan.

        Forward:  launches _scan_fwd_kernel to compute all prefix states.
        Backward: launches _scan_bwd_kernel to compute gradients via the adjoint method.
        """

        @staticmethod
        def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
            B, H, T, D_k, D_v = values.shape
            D = D_k * D_v

            # Flatten (B, H) into one axis and (D_k, D_v) into one axis so the
            # kernel can index with plain strides.
            ld_flat = log_decays.squeeze(-1).contiguous().reshape(B * H, T)
            v_flat = values.reshape(B * H, T, D).contiguous()
            o_flat = torch.empty_like(v_flat)

            BH = B * H
            BLOCK_D = min(triton.next_power_of_2(D), 1024)

            # One program per (batch*head, D-block) pair.
            grid = (BH, triton.cdiv(D, BLOCK_D))

            _scan_fwd_kernel[grid](
                ld_flat, v_flat, o_flat,
                T, D,
                ld_flat.stride(0), ld_flat.stride(1),
                v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                BLOCK_D=BLOCK_D,
            )

            # The backward pass needs the log-decays and the prefix states y_{t-1}.
            ctx.save_for_backward(ld_flat, o_flat)
            ctx.shape_info = (B, H, T, D_k, D_v, D, BH, BLOCK_D)
            return o_flat.reshape(B, H, T, D_k, D_v)

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            ld_flat, o_flat = ctx.saved_tensors
            B, H, T, D_k, D_v, D, BH, BLOCK_D = ctx.shape_info

            go_flat = grad_output.reshape(BH, T, D).contiguous()
            gv_flat = torch.empty_like(go_flat)

            # Must start at zero and stay float32: the kernel accumulates
            # per-block partial sums into it with atomic_add.
            gld_flat = torch.zeros(BH, T, device=ld_flat.device, dtype=torch.float32)

            grid = (BH, triton.cdiv(D, BLOCK_D))

            _scan_bwd_kernel[grid](
                ld_flat, o_flat, go_flat, gv_flat, gld_flat,
                T, D,
                ld_flat.stride(0), ld_flat.stride(1),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
                gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
                gld_flat.stride(0), gld_flat.stride(1),
                BLOCK_D=BLOCK_D,
            )

            # Restore the caller's shapes: [B, H, T, 1] and [B, H, T, D_k, D_v].
            grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, T, 1)
            grad_values = gv_flat.reshape(B, H, T, D_k, D_v)
            return grad_log_decays, grad_values
|
|
|
|
|
    def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
        """Triton-accelerated parallel scan entry point."""
        return _ParallelScanFn.apply(log_decays, values)

else:
    # No Triton: parallel_scan() falls back to _sequential_scan.
    _triton_parallel_scan = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Parallel prefix scan: computes all prefix monoid sums.

    This is the training-time workhorse of Monoid Attention.
    It returns S_1, S_2, ..., S_T, where S_t = α_t·S_{t-1} + kv_t,
    for all timesteps in a single call.

    Dispatches on device:
        CUDA with Triton installed -> Triton JIT kernel (fast, with custom backward)
        otherwise (CPU/MPS)        -> PyTorch sequential scan (correct, slower)

    Args:
        log_decays: [B, H, T, 1], log of the decay gates α_t
        values:     [B, H, T, D_k, D_v], outer products k_t⊗v_t
    Returns:
        states:     [B, H, T, D_k, D_v], all prefix states S_1..S_T
    """
    if _triton_parallel_scan is not None and values.is_cuda:
        return _triton_parallel_scan(log_decays, values)
    return _sequential_scan(log_decays, values)
|
|
|
|
|
|
|
|
def parallel_scan_with_state(
    log_decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """
    Parallel prefix scan that also returns the final state for inference handoff.

    Used during prefill: compute all prefix states with the training-time scan,
    and extract the final accumulated state S_T so that subsequent
    tokens can be generated in O(1)-per-token RNN mode via monoid_op.

    This is the bridge between training mode (parallel scan)
    and inference mode (sequential monoid_op).

    Args:
        log_decays: [B, H, T, 1]
        values:     [B, H, T, D_k, D_v]

    Returns:
        output:      [B, H, T, D_k, D_v], all prefix states S_1..S_T
        final_state: (log_acc, S_T) where
            log_acc: [B, H, 1], accumulated log-decay (for future monoid_op)
            S_T:     [B, H, D_k, D_v], the compressed causal summary
    """
    output = parallel_scan(log_decays, values)

    # Sum the log-decays over T: [B, H, T] -> [B, H, 1].
    log_acc = log_decays.squeeze(-1).sum(dim=2, keepdim=True)

    # S_T is the last prefix state: [B, H, D_k, D_v].
    final_state = output[:, :, -1]
    return output, (log_acc, final_state)
|
|