File size: 19,439 Bytes
b0097d1
 
 
 
b6c0790
 
 
 
b0097d1
 
 
 
b6c0790
 
 
 
 
 
b0097d1
 
b6c0790
 
 
b0097d1
b6c0790
b0097d1
 
b6c0790
 
b0097d1
b6c0790
b0097d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c0790
b0097d1
b6c0790
b0097d1
b6c0790
b0097d1
b6c0790
b0097d1
 
 
b6c0790
 
b0097d1
 
 
 
 
 
 
 
 
 
 
 
b6c0790
 
 
b0097d1
 
 
 
 
 
b6c0790
 
b0097d1
 
 
 
 
 
 
b6c0790
 
 
 
 
b0097d1
 
b6c0790
 
b0097d1
 
b6c0790
 
 
 
b0097d1
 
 
b6c0790
 
 
 
 
b0097d1
b6c0790
 
b0097d1
b6c0790
 
 
 
b0097d1
b6c0790
 
 
b0097d1
b6c0790
 
 
b0097d1
 
b6c0790
 
b0097d1
 
 
b6c0790
 
b0097d1
b6c0790
 
b0097d1
 
b6c0790
 
b0097d1
 
b6c0790
b0097d1
b6c0790
 
b0097d1
 
 
 
 
b6c0790
 
 
 
 
 
 
b0097d1
 
b6c0790
 
 
 
 
 
 
 
 
 
 
b0097d1
b6c0790
 
 
 
b0097d1
 
 
b6c0790
b0097d1
 
 
 
b6c0790
 
b0097d1
b6c0790
 
b0097d1
 
 
 
 
 
b6c0790
 
b0097d1
b6c0790
 
b0097d1
 
b6c0790
 
 
 
b0097d1
 
 
 
b6c0790
 
b0097d1
b6c0790
 
b0097d1
 
 
 
 
 
 
 
 
 
 
 
b6c0790
 
b0097d1
 
b6c0790
b0097d1
b6c0790
b0097d1
 
b6c0790
b0097d1
b6c0790
b0097d1
 
 
 
 
b6c0790
 
 
 
 
 
 
 
b0097d1
 
b6c0790
 
 
 
 
b0097d1
 
 
b6c0790
b0097d1
 
 
b6c0790
b0097d1
 
 
 
 
b6c0790
 
 
b0097d1
 
 
 
b6c0790
b0097d1
b6c0790
 
b0097d1
b6c0790
 
 
b0097d1
b6c0790
b0097d1
 
 
b6c0790
b0097d1
 
 
 
 
b6c0790
b0097d1
 
b6c0790
 
 
 
 
 
b0097d1
 
 
b6c0790
 
b0097d1
 
 
 
 
 
 
 
 
 
 
 
b6c0790
 
b0097d1
 
b6c0790
 
b0097d1
 
 
b6c0790
b0097d1
 
 
 
 
 
 
 
 
b6c0790
 
b0097d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c0790
 
b0097d1
 
 
 
 
 
 
 
 
 
 
 
 
b6c0790
b0097d1
 
 
 
 
 
b6c0790
 
 
 
b0097d1
 
b6c0790
 
 
b0097d1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
"""
monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描

This module implements the parallel prefix scan for the vector-decay monoid recurrence:
  y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
本模块实现向量衰减幺半群递推的并行前缀扫描:
  y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]

This is the computational backbone of Monoid Attention's state compression.
这是幺半群注意力状态压缩的计算骨干。

Vector decay: each dimension of the D_k×D_v state matrix has its own
per-dimension decay rate α_t ∈ ℝ^{D_k}, enabling different feature
dimensions to have independent memory lifetimes (fast-decaying for
local syntax, slow-decaying for global entity memory).
向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ ℝ^{D_k},
使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。

Implementation:
  Forward: sequential scan along T, parallelized across B*H*D_k on GPU.
           Each program handles one row of the state matrix (D_v elements)
           with a scalar decay per row.
  Backward: reverse-order adjoint scan for gradient computation.
            Per-row reduction for log_decay gradient (no atomic_add needed).
  Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.

  前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
        每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
  反向: 逆序伴随变量扫描计算梯度。
        逐行归约计算 log_decay 梯度 (无需 atomic_add)。
  自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
"""

from __future__ import annotations

import torch
from torch import Tensor
from torch.autograd import Function
from typing import Tuple

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fallback: pure PyTorch sequential scan
# 回退: 纯 PyTorch 串行扫描 (CPU / MPS / no Triton)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
    纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。

    Implements the vector-decay monoid recurrence step by step:
      acc_0 = 0
      acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
    This is O(T) sequential — correct but slow on GPU.
    逐步实现向量衰减幺半群递推:
      acc_0 = 0
      acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
    这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。

    Args:
        log_decays: [B, H, T, D_k]     — log of per-dimension per-step decay gates
                                           每维度每步衰减门的对数
        values:     [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
                                           待累积的外积 k_t⊗v_t
    Returns:
        output:     [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
                                           所有前缀状态 S_1, ..., S_T
    """
    B, H, T, D_k, D_v = values.shape
    out = torch.empty_like(values)
    # acc represents S_t — the compressed causal state at time t
    # acc 代表 S_t — 时刻 t 的压缩因果状态
    acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
    for t in range(T):
        # S_t = diag(α_t) · S_{t-1} + kv_t  (vector decay monoid recurrence)
        # S_t = diag(α_t) · S_{t-1} + kv_t  (向量衰减幺半群递推)
        decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1)  # [B,H,D_k,1]
        acc = acc * decay_t + values[:, :, t]
        out[:, :, t] = acc
    return out


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Triton Kernels — GPU-accelerated scan (vector decay)
# Triton 核函数 — GPU 加速扫描 (向量衰减)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

if HAS_TRITON:

    @triton.jit
    def _scan_fwd_kernel(
        LD_ptr, V_ptr, O_ptr,
        T, D_v,
        s_ld_bhdk, s_ld_t,
        s_v_bhdk, s_v_t, s_v_dv,
        s_o_bhdk, s_o_t, s_o_dv,
        BLOCK_DV: tl.constexpr,
    ):
        """
        Forward scan kernel — computes all prefix states S_1..S_T (vector decay).
        前向扫描核函数 — 计算所有前缀状态 S_1..S_T (向量衰减)。

        Parallelization strategy / 并行化策略:
          - program_id(0) = bhdk: one program per (batch, head, d_k row) triple
            每个 (batch, head, d_k 行) 三元组一个 program
          - program_id(1) = dvb: one program per D_v-dimension block (typically 1 block)
            每个 D_v 维 block 一个 program (通常只有 1 个 block)
          - Sequential loop over T (the causal recurrence is inherently sequential)
            沿 T 维串行循环 (因果递推本质上是串行的)

        Each program handles one row of the D_k×D_v state matrix, where the
        decay is a single scalar per row. This eliminates the need for
        row-index computation in the inner loop.
        每个 program 处理 D_k×D_v 状态矩阵的一行, 该行的衰减是一个标量。
        这消除了内循环中行索引计算的需要。

        Grid: (B*H*D_k, ceil(D_v/BLOCK_DV))
        网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
        """
        bhdk = tl.program_id(0)
        dvb = tl.program_id(1)
        dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
        dv_mask = dv_offs < D_v

        # acc = S_0[row,:] = 0 (identity element of the monoid)
        # acc = S_0[行,:] = 0 (幺半群的单位元)
        acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

        ld_base = LD_ptr + bhdk * s_ld_bhdk
        v_base = V_ptr + bhdk * s_v_bhdk
        o_base = O_ptr + bhdk * s_o_bhdk

        for t in range(T):
            # Load scalar log_decay for this row at time t
            # 加载此行在时刻 t 的标量 log_decay
            ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
            decay = tl.exp(ld_val)

            # Load kv_t[row, :] (one row of the outer product)
            # 加载 kv_t[行, :] (外积的一行)
            val = tl.load(
                v_base + t * s_v_t + dv_offs * s_v_dv,
                mask=dv_mask, other=0.0,
            ).to(tl.float32)

            # Core recurrence: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
            # 核心递推: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
            acc = acc * decay + val

            # Store S_t[row, :]
            tl.store(
                o_base + t * s_o_t + dv_offs * s_o_dv,
                acc, mask=dv_mask,
            )

    @triton.jit
    def _scan_bwd_kernel(
        LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
        T, D_v,
        s_ld_bhdk, s_ld_t,
        s_o_bhdk, s_o_t, s_o_dv,
        s_go_bhdk, s_go_t, s_go_dv,
        s_gv_bhdk, s_gv_t, s_gv_dv,
        s_gld_bhdk, s_gld_t,
        BLOCK_DV: tl.constexpr,
    ):
        """
        Backward scan kernel — computes gradients via adjoint method (vector decay).
        反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。

        Each program handles one row of the state matrix (one d_k dimension).
        The decay for this row is a scalar, so the log_decay gradient is:
          ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
        The sum over j (D_v) is computed within this single program — no atomic_add.
        每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
        该行的衰减是标量, 因此 log_decay 梯度为:
          ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
        对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
        """
        bhdk = tl.program_id(0)
        dvb = tl.program_id(1)
        dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
        dv_mask = dv_offs < D_v

        # adj holds a_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
        # adj 保存 a_{t+1} · λ_{t+1}, 在序列末尾初始化为 0
        adj = tl.zeros([BLOCK_DV], dtype=tl.float32)

        for t_rev in range(T):
            t = T - 1 - t_rev     # reverse time / 逆序时间

            # Load ∂L/∂y_t[row, :] (upstream gradient)
            # 加载 ∂L/∂y_t[行, :] (上游梯度)
            go = tl.load(
                GO_ptr + bhdk * s_go_bhdk + t * s_go_t + dv_offs * s_go_dv,
                mask=dv_mask, other=0.0,
            ).to(tl.float32)

            # Adjoint: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
            # 伴随: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
            lam = go + adj

            # ∂L/∂x_t[row,:] = λ_t (gradient of values)
            # ∂L/∂x_t[行,:] = λ_t (值的梯度)
            tl.store(
                GV_ptr + bhdk * s_gv_bhdk + t * s_gv_t + dv_offs * s_gv_dv,
                lam, mask=dv_mask,
            )

            # ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
            # Per-row scalar gradient: sum over D_v within this program.
            # 逐行标量梯度: 在此 program 内对 D_v 求和。
            ld_val = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
            a_t = tl.exp(ld_val)

            if t > 0:
                y_prev = tl.load(
                    O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
                    mask=dv_mask, other=0.0,
                ).to(tl.float32)
                grad_ld = tl.sum(lam * y_prev) * a_t
                tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_ld)

            # Prepare for next step (t-1): adj = a_t · λ_t
            # 为下一步 (t-1) 准备: adj = a_t · λ_t
            adj = a_t * lam

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Autograd Function — bridges Triton kernels with PyTorch autograd
    # 自动微分函数 — 将 Triton 核函数与 PyTorch 自动微分桥接
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

    class _ParallelScanFn(Function):
        """
        Custom autograd function for the parallel prefix scan (vector decay).
        并行前缀扫描的自定义 autograd 函数 (向量衰减)。

        Forward: launches _scan_fwd_kernel to compute all prefix states.
                 Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)), one program per state row.
        Backward: launches _scan_bwd_kernel to compute gradients via adjoint method.
                  Per-row reduction eliminates most atomic_add overhead.

        前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
              网格: (B*H*D_k, ceil(D_v/BLOCK_DV)), 每行状态一个 program。
        反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
              逐行归约消除大部分 atomic_add 开销。
        """
        @staticmethod
        def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
            B, H, T, D_k, D_v = values.shape

            # Reshape for row-parallel kernel:
            #   log_decays: [B, H, T, D_k] → permute to [B, H, D_k, T] → [B*H*D_k, T]
            #   values:     [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
            # 为行并行核函数重塑:
            #   log_decays: [B, H, T, D_k] → 转置为 [B, H, D_k, T] → [B*H*D_k, T]
            #   values:     [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
            ld_flat = log_decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
            v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
            o_flat = torch.empty_like(v_flat)

            BHDK = B * H * D_k
            BLOCK_DV = min(triton.next_power_of_2(D_v), 1024)
            # Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — one program per (batch, head, row, dv-block)
            # 网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_fwd_kernel[grid](
                ld_flat, v_flat, o_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                BLOCK_DV=BLOCK_DV,
            )

            # Save for backward: need log_decays and forward outputs y_t
            # 为反向传播保存: 需要 log_decays 和前向输出 y_t
            ctx.save_for_backward(ld_flat, o_flat)
            ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
            # Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
            return o_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            ld_flat, o_flat = ctx.saved_tensors
            B, H, T, D_k, D_v, BHDK, BLOCK_DV = ctx.shape_info

            # Permute grad_output to match row-parallel layout: [B,H,T,D_k,D_v] → [B*H*D_k, T, D_v]
            go_flat = grad_output.permute(0, 1, 3, 2, 4).contiguous().reshape(BHDK, T, D_v)
            gv_flat = torch.empty_like(go_flat)
            # Use f32 for gradient accumulation precision
            # 使用 f32 保证梯度累积的精度
            gld_flat = torch.zeros(BHDK, T, device=ld_flat.device, dtype=torch.float32)

            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_bwd_kernel[grid](
                ld_flat, o_flat, go_flat, gv_flat, gld_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
                gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
                gld_flat.stride(0), gld_flat.stride(1),
                BLOCK_DV=BLOCK_DV,
            )

            # Reshape gradients back to original layout
            # 重塑梯度回原始布局
            # gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
            grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
            # gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
            grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
            return grad_log_decays, grad_values

    def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
        """Triton-accelerated parallel scan entry point (vector decay).
        Triton 加速的并行扫描入口 (向量衰减)。"""
        return _ParallelScanFn.apply(log_decays, values)

else:
    _triton_parallel_scan = None


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Public API / 公共接口
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Parallel prefix scan — computes all prefix monoid sums (vector decay).
    并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。

    This is the training-time workhorse of Monoid Attention.
    It computes S_1, S_2, ..., S_T where
      S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
    for ALL timesteps simultaneously.
    这是幺半群注意力训练时的主力计算。
    它同时计算所有时间步的 S_1, S_2, ..., S_T,
    其中 S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]。

    Auto-dispatches based on device:
      CUDA → Triton JIT kernel (fast, with custom backward)
      CPU/MPS → PyTorch sequential scan (correct, slower)
    根据设备自动分派:
      CUDA → Triton JIT 核函数 (快速, 带自定义反向传播)
      CPU/MPS → PyTorch 串行扫描 (正确, 较慢)

    Args:
        log_decays: [B, H, T, D_k]      — log of per-dimension decay gates α_t
                                            每维度衰减门 α_t 的对数
        values:     [B, H, T, D_k, D_v] — outer products k_t⊗v_t
                                           外积 k_t⊗v_t
    Returns:
        states:     [B, H, T, D_k, D_v] — all prefix states S_1..S_T
                                           所有前缀状态 S_1..S_T
    """
    if _triton_parallel_scan is not None and values.is_cuda:
        return _triton_parallel_scan(log_decays, values)
    return _sequential_scan(log_decays, values)


def parallel_scan_with_state(
    log_decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """
    Parallel prefix scan + extract final state for inference handoff (vector decay).
    并行前缀扫描 + 提取最终状态用于推理切换 (向量衰减)。

    Used during prefill: compute all training-time prefix states,
    AND extract the final accumulated state S_T so that subsequent
    tokens can be generated in O(1) RNN mode via monoid_op.
    在预填充时使用: 计算所有训练时的前缀状态,
    同时提取最终累积状态 S_T, 以便后续 token 可以
    通过 monoid_op 以 O(1) RNN 模式生成。

    This is the bridge between training mode (parallel scan)
    and inference mode (sequential monoid_op).
    这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。

    Args:
        log_decays: [B, H, T, D_k]
        values:     [B, H, T, D_k, D_v]

    Returns:
        output:      [B, H, T, D_k, D_v]  — all prefix states S_1..S_T
                                              所有前缀状态
        final_state: (log_acc, S_T) where
            log_acc:     [B, H, D_k]         — accumulated log-decay vector (for future monoid_op)
                                                累积对数衰减向量 (供后续 monoid_op 使用)
            final_state: [B, H, D_k, D_v]    — S_T, the compressed causal summary
                                                S_T, 压缩的因果摘要
    """
    output = parallel_scan(log_decays, values)
    # Sum all log-decays over T to get the total accumulated decay per dimension
    # 对所有 log-decay 沿 T 求和得到每个维度的总累积衰减
    log_acc = log_decays.sum(dim=2)  # [B, H, D_k]
    # The last timestep's state IS the full causal summary
    # 最后一个时间步的状态就是完整的因果摘要
    final_state = output[:, :, -1]  # [B, H, D_k, D_v]
    return output, (log_acc, final_state)