| """ |
| Parallel Scan 工具函数:SNN 线性递推的高效并行求解 |
| |
| 实现三层后端: |
| 1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate): |
| 单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient) |
| · per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel |
| · row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel |
| 2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate): |
| 列级并行 scan,O(K) 工作量,1 次 kernel launch |
| 3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量 |
| |
| 线性递推: |
| V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init |
| |
| PLIF 神经元动力学: |
| V_pre[k] = beta[k] * V_post[k-1] + u[k] |
| s[k] = Θ(V_pre[k] - v_th[k]) |
| V_post[k] = V_pre[k] - v_th[k] * s[k] |
| |
| 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。 |
| """ |
|
|
| import os |
| import torch |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| _SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas' |
| if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ: |
| os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS |
|
|
| _HAS_TRITON = False |
| try: |
| import triton |
| import triton.language as tl |
| _HAS_TRITON = True |
| except ImportError: |
| pass |
|
|
if _HAS_TRITON:

    @triton.jit
    def _fwd_recurrence_kernel(
        A_ptr, B_ptr, INIT_ptr, OUT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.

        Grid: (ceil(num_cols / BLOCK),)
        Each program processes BLOCK columns across all K sequential steps.
        Accumulation in float32; storage in the element type of OUT_ptr.

        Row offsets are formed in int64: with ``k * num_cols`` evaluated in
        int32 the address wraps once the (K, num_cols) tensors reach 2**31
        elements, silently touching the wrong memory.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        # The last program may cover columns past num_cols.
        mask = cols < num_cols

        # Running state for this program's columns, kept in fp32.
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Sequential in k; parallelism comes from the independent columns.
        for k in range(K):
            off = k.to(tl.int64) * num_cols + cols
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v = a * v + b
            tl.store(OUT_ptr + off, v, mask=mask)
|
|
if _HAS_TRITON:

    @triton.jit
    def _bwd_recurrence_kernel(
        A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
        GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Backward for V[k] = A[k]*V[k-1] + B[k].

        Reverse accumulation (k from K-1 down to 0):
            g = 0
            for k = K-1, ..., 0:
                g += grad_out[k]
                grad_B[k] = g
                grad_A[k] = g * V[k-1]     (V[-1] = init)
                g = g * A[k]
            grad_init = g

        Row offsets are formed in int64 so tensors with 2**31 or more
        elements are addressed correctly.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # g holds dL/dV[k], including what flows back from V[k+1] via A[k+1].
        g = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k.to(tl.int64) * num_cols + cols

            dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g + dV

            tl.store(GRAD_B_ptr + off, g, mask=mask)

            # V[k-1]: previous row of V, or the initial state when k == 0.
            if k > 0:
                v_prev = tl.load(
                    V_ptr + (off - num_cols),
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)

            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g * a

        tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
|
|
if _HAS_TRITON:

    class _TritonLinearRecurrence(torch.autograd.Function):
        """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k].

        ``beta`` (A) and ``u`` (B) hold K rows of ``num_cols`` values and
        ``v_init`` holds one value per column. The kernels address all three
        with a single flat column index, so element counts are validated
        before launch; a mismatch would otherwise be read out of bounds.
        """

        # Columns handled by one Triton program.
        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta, u, v_init):
            beta_c = beta.contiguous()
            u_c = u.contiguous()
            v_init_c = v_init.contiguous()

            K = beta_c.shape[0]
            # Column count from v_init: unlike beta_c[0].numel() this also
            # works for an empty time axis (K == 0).
            num_cols = v_init_c.numel()
            if beta_c.numel() != K * num_cols or u_c.numel() != K * num_cols:
                raise ValueError(
                    "linear recurrence expects beta and u with K * v_init.numel() "
                    f"elements; got beta {tuple(beta.shape)}, u {tuple(u.shape)}, "
                    f"v_init {tuple(v_init.shape)}"
                )
            V = torch.empty_like(u_c)

            if V.numel() > 0:
                BLOCK = _TritonLinearRecurrence._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _fwd_recurrence_kernel[grid](
                    beta_c, u_c, v_init_c, V,
                    K, num_cols,
                    BLOCK=BLOCK,
                )

            # Keep tensors alive for backward only when a gradient is needed.
            if any(ctx.needs_input_grad[:3]):
                ctx.save_for_backward(beta_c, V, v_init_c)
                ctx.K = K
                ctx.num_cols = num_cols

            return V

        @staticmethod
        def backward(ctx, grad_V):
            beta, V, v_init = ctx.saved_tensors
            grad_V_c = grad_V.contiguous()

            K = ctx.K
            num_cols = ctx.num_cols

            grad_beta = torch.empty_like(beta)
            # V was allocated like u, so this gradient has u's shape and dtype.
            grad_u = torch.empty_like(V)
            grad_v_init = torch.empty_like(v_init)

            # With K == 0 the kernel still runs (zero loop trips) and writes
            # grad_v_init = 0; only an empty column set skips the launch.
            if num_cols > 0:
                BLOCK = _TritonLinearRecurrence._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _bwd_recurrence_kernel[grid](
                    beta, V, v_init, grad_V_c,
                    grad_beta, grad_u, grad_v_init,
                    K, num_cols,
                    BLOCK=BLOCK,
                )

            return grad_beta, grad_u, grad_v_init
|
|
| |
| |
| |
|
|
if _HAS_TRITON:

    @triton.jit
    def _fused_plif_fwd_kernel(
        BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.

        Exact computation -- the sequential scan is the ground truth.
        Replaces the 3-phase approach (linear scan + spike iteration + correction).

        Per column (columns are processed in parallel):
            v = v_init
            for k = 0..K-1:
                v_pre = beta[k]*v + u[k]
                spike[k] = Θ(v_pre - v_th[k])      (Θ(0) = 1: comparison is >=)
                v = v_pre - v_th[k]*spike[k]

        Row offsets are formed in int64 so tensors with 2**31 or more
        elements are addressed correctly.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Membrane state after reset, one lane per column, in fp32.
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k.to(tl.int64) * num_cols + cols
            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)

            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)
            # Soft reset: subtract the threshold instead of zeroing.
            v = v_pre - vth * spike

            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)
|
|
if _HAS_TRITON:

    @triton.jit
    def _fused_plif_bwd_kernel(
        BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.

        V_pre[k] = V_post[k] + v_th[k]*spike[k]  (reconstructed)
        surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
        where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])

        The spike inside the reset term is treated as a constant:
        dV_post/dV_pre = 1 and dV_post/dv_th = -spike.

        Reverse accumulation:
            acc = 0
            for k = K-1 downto 0:
                total_gV = grad_V_post[k] + acc
                sg = surrogate_grad(V_pre[k] - v_th[k])
                grad_v_pre = grad_spike[k]*sg + total_gV
                grad_beta[k] = grad_v_pre * V_post[k-1]
                grad_u[k] = grad_v_pre
                grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
                acc = grad_v_pre * beta[k]
            grad_v_init = acc

        Row offsets are formed in int64 so tensors with 2**31 or more
        elements are addressed correctly.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Gradient reaching V_post[k] from step k+1 (through beta[k+1]).
        acc = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k.to(tl.int64) * num_cols + cols

            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)

            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)

            # V_post[k-1]: previous row, or the initial state at k == 0.
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (off - num_cols),
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

            # Sigmoid surrogate at x = V_pre - v_th (V_pre rebuilt from V_post
            # and spike). -alpha*x is capped at 88 so exp() stays finite in fp32.
            x = v_post - vth * (1.0 - spike)
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)

            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV

            tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
            tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)

            acc = grad_v_pre * beta

        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
|
|
| |
| |
| |
| |
|
|
if _HAS_TRITON:

    @triton.jit
    def _fused_plif_fwd_rowparam_kernel(
        BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward with row-parameter beta and v_th.

        beta and v_th hold one value per column and are constant across the K
        steps, so they are loaded once before the loop. Per step only u is
        read (instead of beta, u and v_th).

        Row offsets are formed in int64 so tensors with 2**31 or more
        elements are addressed correctly.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k.to(tl.int64) * num_cols + cols
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)

            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)
            # Soft reset: subtract the threshold instead of zeroing.
            v = v_pre - vth * spike

            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)
|
|
if _HAS_TRITON:

    @triton.jit
    def _fused_plif_bwd_rowparam_kernel(
        BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward with row-parameter beta/v_th.

        Same per-step math as the per-element backward kernel. Because beta
        and v_th are shared by all K steps, their gradients are summed over k
        in registers, and only one value per column is written
        (grad_beta_row, grad_v_th_row) instead of per-step gradients.

        Row offsets are formed in int64 so tensors with 2**31 or more
        elements are addressed correctly.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # acc: gradient reaching V_post[k] from step k+1 (through beta).
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k.to(tl.int64) * num_cols + cols

            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)

            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)

            # V_post[k-1]: previous row, or the initial state at k == 0.
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (off - num_cols),
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

            # Sigmoid surrogate at x = V_pre - v_th; -alpha*x capped at 88 so
            # exp() stays finite in fp32.
            x = v_post - vth * (1.0 - spike)
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)

            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV

            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)

            # Shared parameters: reduce the per-step gradients over k.
            acc_grad_beta += grad_v_pre * v_prev
            acc_grad_vth += -g_s * sg - total_gV * spike

            acc = grad_v_pre * beta

        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
        tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
        tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
|
|
if _HAS_TRITON:

    class _TritonPLIFRowParamForward(torch.autograd.Function):
        """Fused Triton PLIF with row-parameter beta/v_th.

        For neurons whose beta/v_th are constant across the K steps
        (ParametricLIFNode). beta/v_th are passed as one value per column, so
        no (K, *shape) expand+contiguous copies are needed.

        ``beta_row``, ``v_th_row`` and ``v_init`` must each hold exactly one
        value per column of ``u``. This is checked before launch because the
        kernels address them with the same flat column index as ``u[k]``.
        """

        # Columns handled by one Triton program.
        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
            beta_row_c = beta_row.contiguous()
            u_c = u.contiguous()
            v_th_row_c = v_th_row.contiguous()
            v_init_c = v_init.contiguous()

            K = u_c.shape[0]
            # Column count from v_init so an empty time axis (K == 0) does not
            # need u_c[0].
            num_cols = v_init_c.numel()
            if (u_c.numel() != K * num_cols
                    or beta_row_c.numel() != num_cols
                    or v_th_row_c.numel() != num_cols):
                raise ValueError(
                    "row-parameter PLIF expects beta_row, v_th_row and v_init with "
                    f"one value per column of u; got beta_row {tuple(beta_row.shape)}, "
                    f"u {tuple(u.shape)}, v_th_row {tuple(v_th_row.shape)}, "
                    f"v_init {tuple(v_init.shape)}"
                )

            spike = torch.empty_like(u_c)
            V_post = torch.empty_like(u_c)

            if u_c.numel() > 0:
                BLOCK = _TritonPLIFRowParamForward._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _fused_plif_fwd_rowparam_kernel[grid](
                    beta_row_c, u_c, v_th_row_c, v_init_c,
                    spike, V_post,
                    K, num_cols,
                    BLOCK=BLOCK,
                )

            # alpha (input 4) is a hyperparameter and never receives a gradient.
            if any(ctx.needs_input_grad[:4]):
                ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
                ctx.K = K
                ctx.num_cols = num_cols
                ctx.alpha = alpha

            return spike, V_post

        @staticmethod
        def backward(ctx, grad_spike, grad_V_post):
            beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
            K = ctx.K
            num_cols = ctx.num_cols
            alpha = ctx.alpha

            # Defensive: treat a missing output gradient as zeros.
            if grad_spike is None:
                grad_spike = torch.zeros_like(spike)
            if grad_V_post is None:
                grad_V_post = torch.zeros_like(V_post)

            grad_spike_c = grad_spike.contiguous()
            grad_V_post_c = grad_V_post.contiguous()

            grad_beta_row = torch.empty_like(beta_row)
            grad_u = torch.empty_like(V_post)
            grad_v_th_row = torch.empty_like(v_th_row)
            grad_v_init = torch.empty_like(v_init)

            # With K == 0 the kernel still runs (zero loop trips) and writes
            # zero row/init gradients; skip only when there are no columns.
            if num_cols > 0:
                BLOCK = _TritonPLIFRowParamForward._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _fused_plif_bwd_rowparam_kernel[grid](
                    beta_row, v_th_row, v_init, V_post, spike,
                    grad_spike_c, grad_V_post_c,
                    grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
                    K, num_cols, float(alpha),
                    BLOCK=BLOCK,
                )

            return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
|
|
if _HAS_TRITON:

    class _TritonPLIFForward(torch.autograd.Function):
        """Fused Triton PLIF forward + backward.

        A single-pass sequential scan replaces the 3-phase approach:
            Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
            -> 1 fused kernel with inline spike detection + soft reset

        Advantages:
            - 1 kernel launch per direction
            - Exact computation (no fixed-point iteration to converge)
            - No intermediate V_L / delta_S / delta_S_prev tensors
            - fp32 accumulation inside the kernel

        ``beta``, ``u`` and ``v_th`` must hold K * v_init.numel() elements.
        This is checked before launch because the kernels address all of them
        with one flat column index.
        """

        # Columns handled by one Triton program.
        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta, u, v_th, v_init, alpha):
            beta_c = beta.contiguous()
            u_c = u.contiguous()
            v_th_c = v_th.contiguous()
            v_init_c = v_init.contiguous()

            K = beta_c.shape[0]
            # Column count from v_init so K == 0 does not need beta_c[0].
            num_cols = v_init_c.numel()
            expected = K * num_cols
            if (beta_c.numel() != expected
                    or u_c.numel() != expected
                    or v_th_c.numel() != expected):
                raise ValueError(
                    "PLIF expects beta, u and v_th with K * v_init.numel() elements; "
                    f"got beta {tuple(beta.shape)}, u {tuple(u.shape)}, "
                    f"v_th {tuple(v_th.shape)}, v_init {tuple(v_init.shape)}"
                )

            spike = torch.empty_like(u_c)
            V_post = torch.empty_like(u_c)

            if u_c.numel() > 0:
                BLOCK = _TritonPLIFForward._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _fused_plif_fwd_kernel[grid](
                    beta_c, u_c, v_th_c, v_init_c,
                    spike, V_post,
                    K, num_cols,
                    BLOCK=BLOCK,
                )

            # alpha (input 4) is a hyperparameter and never receives a gradient.
            if any(ctx.needs_input_grad[:4]):
                ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
                ctx.K = K
                ctx.num_cols = num_cols
                ctx.alpha = alpha

            return spike, V_post

        @staticmethod
        def backward(ctx, grad_spike, grad_V_post):
            beta, v_th, v_init, V_post, spike = ctx.saved_tensors
            K = ctx.K
            num_cols = ctx.num_cols
            alpha = ctx.alpha

            # Defensive: treat a missing output gradient as zeros.
            if grad_spike is None:
                grad_spike = torch.zeros_like(spike)
            if grad_V_post is None:
                grad_V_post = torch.zeros_like(V_post)

            grad_spike_c = grad_spike.contiguous()
            grad_V_post_c = grad_V_post.contiguous()

            grad_beta = torch.empty_like(beta)
            # V_post was allocated like u, so this gradient has u's shape/dtype.
            grad_u = torch.empty_like(V_post)
            grad_v_th = torch.empty_like(v_th)
            grad_v_init = torch.empty_like(v_init)

            # With K == 0 the kernel still runs (zero loop trips) and writes
            # grad_v_init = 0; skip only when there are no columns.
            if num_cols > 0:
                BLOCK = _TritonPLIFForward._BLOCK
                grid = ((num_cols + BLOCK - 1) // BLOCK,)
                _fused_plif_bwd_kernel[grid](
                    beta, v_th, v_init, V_post, spike,
                    grad_spike_c, grad_V_post_c,
                    grad_beta, grad_u, grad_v_th, grad_v_init,
                    K, num_cols, float(alpha),
                    BLOCK=BLOCK,
                )

            return grad_beta, grad_u, grad_v_th, grad_v_init, None
|
|
|
|
| |
| |
| |
|
|
def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Inclusive Hillis-Steele scan over a sequence of affine maps.

    Step k is the map f_k(x) = a[k] * x + b[k]. The scan returns every prefix
    composition F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0 in coefficient form, so that
    V[k] = F_k(v_init) = A[k] * v_init + B[k].

    Two maps compose as (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2).

    Each round builds fresh tensors with torch.cat rather than writing in
    place, so autograd can differentiate through the scan.

    Args:
        a: (K, *shape) multiplicative coefficients (e.g. beta).
        b: (K, *shape) additive terms (e.g. alpha * I).

    Returns:
        A: (K, *shape) running product, A[k] = prod_{j<=k} a[j].
        B: (K, *shape) offsets such that V[k] = A[k] * v_init + B[k].

    Depth: O(log K) rounds. Work: O(K log K).
    """
    n_steps = a.shape[0]
    coef, offset = a, b

    span = 1
    while span < n_steps:
        # Position i >= span absorbs the composite ending at position i - span;
        # the first `span` positions are already complete.
        later_coef = coef[span:]
        earlier_coef, earlier_offset = coef[:-span], offset[:-span]
        coef, offset = (
            torch.cat([coef[:span], later_coef * earlier_coef], dim=0),
            torch.cat([offset[:span], later_coef * earlier_offset + offset[span:]], dim=0),
        )
        span <<= 1

    return coef, offset
|
|
|
|
| |
| |
| |
|
|
def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
    """Solve the linear recurrence V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init.

    CUDA backend: fused Triton kernel (one launch, O(K) work).
    CPU backend: Hillis-Steele parallel scan (O(K log K) work).

    Both backends follow PyTorch broadcasting between ``beta``, ``u`` and
    ``v_init.unsqueeze(0)``, and return the same shape.

    Args:
        beta: (K, *shape) decay coefficients, typically in (0, 1).
        u: (K, *shape) input terms.
        v_init: (*shape) initial state.

    Returns:
        V: (K, *shape) states for all K steps.
    """
    if _HAS_TRITON and beta.is_cuda:
        # The Triton kernel walks beta, u and v_init with one flat column
        # index, so they need an identical layout. Broadcast them exactly as
        # the fallback below does implicitly. For matching shapes these are
        # free views and .contiguous() inside the kernel wrapper does not copy.
        beta_b, u_b, v_init_b = torch.broadcast_tensors(beta, u, v_init.unsqueeze(0))
        if u_b.numel() > 0:
            return _TritonLinearRecurrence.apply(beta_b, u_b, v_init_b[0])

    A, B = hillis_steele_scan(beta, u)
    V = A * v_init.unsqueeze(0) + B
    return V
|
|
|
|
| |
| |
| |
|
|
def _plif_spike_fixed_point(V_L, beta, v_th, v_init, max_iter):
    """Estimate the discrete PLIF spike pattern by fixed-point iteration, without autograd.

    Round 0 thresholds the reset-free trajectory ``V_L``. Each further round
    subtracts the decayed soft-reset charge implied by the current pattern
    and re-thresholds, stopping early once the pattern no longer changes. At
    most ``max_iter`` rounds run in total. If the pattern has not stabilised
    by then, the last pattern is returned as-is.

    Args:
        V_L: (K, *shape) reset-free membrane trajectory.
        beta: decay factors, broadcastable against ``V_L``.
        v_th: thresholds, broadcastable against ``V_L``.
        v_init: initial-potential tensor. Only its shape/dtype/device are used,
            because the accumulated reset charge starts at zero.
        max_iter: total number of thresholding rounds.

    Returns:
        float32 tensor of 0/1 spikes, detached from the autograd graph.
    """
    with torch.no_grad():
        V_L_det = V_L.detach()
        beta_det = beta.detach()
        v_th_det = v_th.detach()
        zero_init = torch.zeros_like(v_init)

        spike_pattern = (V_L_det >= v_th_det).float()

        for _ in range(max_iter - 1):
            # delta_S[k] = beta[k] * delta_S[k-1] + v_th[k] * s[k]: total
            # reset charge removed from the membrane up to and including step k.
            delta_S = linear_recurrence(beta_det, v_th_det * spike_pattern, zero_init)

            # Step k only sees resets from steps < k.
            delta_S_prev = torch.zeros_like(delta_S)
            delta_S_prev[1:] = delta_S[:-1]

            # V_pre[k] = V_L[k] - beta[k] * delta_S[k-1]
            V_pre_det = V_L_det - beta_det * delta_S_prev
            spike_new = (V_pre_det >= v_th_det).float()

            if torch.equal(spike_new, spike_pattern):
                break
            spike_pattern = spike_new

    return spike_pattern


def plif_parallel_forward(
    beta: torch.Tensor,
    u: torch.Tensor,
    v_th: torch.Tensor,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Parallel forward pass of PLIF neurons (soft reset, surrogate-gradient compatible).

    Solves:
        V_pre[k]  = beta[k] * V_post[k-1] + u[k]
        s[k]      = Θ(V_pre[k] - v_th[k])
        V_post[k] = V_pre[k] - v_th[k] * s[k]

    Backends:
        * CUDA with Triton and a surrogate whose class is named ``Sigmoid``
          and has an ``alpha`` attribute: one fused sequential-scan kernel.
          ``V_pre`` is not materialised and is returned as ``None``.
        * Otherwise, three phases:
            Phase 1: reset-free trajectory via linear_recurrence (differentiable).
            Phase 2: detached fixed-point iteration for the discrete spike
                     pattern (at most ``max_iter`` rounds).
            Phase 3: recompute V_post with that pattern's soft resets
                     (differentiable); spikes come from
                     ``surrogate_function(V_pre - v_th)``.

    Inputs are broadcast against each other (``v_init`` as
    ``v_init.unsqueeze(0)``) identically on both backends.

    Args:
        beta: (K, *shape) decay factors.
        u: (K, *shape) input drive (alpha * I).
        v_th: (K, *shape) dynamic thresholds.
        v_init: (*shape) initial membrane potential. A Python number is also
            accepted and used for every column.
        max_iter: upper bound on spike fixed-point rounds (non-fused path only).
        surrogate_function: SpikingJelly-style surrogate gradient function,
            e.g. ``surrogate.Sigmoid(alpha=4.0)``. ``None`` gives a hard
            threshold with no gradient through the spikes.

    Returns:
        spike: (K, *shape) spikes (carrying the surrogate gradient).
        V_post: (K, *shape) membrane potential after reset.
        V_pre: (K, *shape) membrane potential before reset, or ``None`` on the
            fused path.
    """
    if not isinstance(v_init, torch.Tensor):
        # Scalar initial potential: materialise one value per column so every
        # backend below receives a (*shape) tensor.
        v_init = torch.full(u.shape[1:], v_init, dtype=u.dtype, device=u.device)

    if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        # The fused kernel indexes beta, u, v_th and v_init with one flat
        # column index, so give them one common layout: the same broadcasting
        # the fallback path applies implicitly.
        full_shape = torch.broadcast_shapes(
            beta.shape, u.shape, v_th.shape, (1, *v_init.shape),
        )
        u_b = u.expand(full_shape)
        if u_b.numel() > 0:
            alpha = float(surrogate_function.alpha)
            spike, V_post = _TritonPLIFForward.apply(
                beta.expand(full_shape), u_b, v_th.expand(full_shape),
                v_init.expand(full_shape[1:]), alpha,
            )
            return spike, V_post, None

    # Phase 1: reset-free trajectory (keeps gradients).
    V_L = linear_recurrence(beta, u, v_init)

    # Phase 2: discrete spike pattern (no gradients).
    spike_pattern = _plif_spike_fixed_point(V_L, beta, v_th, v_init, max_iter)

    # Phase 3: with the pattern fixed, the soft reset is just an extra input
    # term -v_th * s, so the reset trajectory is again a linear recurrence.
    u_eff = u - v_th * spike_pattern
    V_post = linear_recurrence(beta, u_eff, v_init)

    # V_post[k-1] with V_post[-1] = v_init, built without in-place writes
    # (this also works for an empty time axis).
    v_first = v_init.expand(V_post.shape[1:]).unsqueeze(0).to(V_post.dtype)
    V_post_prev = torch.cat([v_first, V_post[:-1]], dim=0)
    V_pre = beta * V_post_prev + u

    if surrogate_function is not None:
        spike = surrogate_function(V_pre - v_th)
    else:
        # Hard threshold: no gradient through the spikes.
        spike = (V_pre >= v_th).float()

    return spike, V_post, V_pre
|
|
|
|
def plif_rowparam_forward(
    beta_row: torch.Tensor,
    u: torch.Tensor,
    v_th_row: torch.Tensor,
    v_init: torch.Tensor,
    surrogate_function=None,
    max_iter: int = 3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Row-parameter PLIF forward: beta and v_th are constant over the K steps.

    On the fused CUDA path this avoids materialising (K, *shape) copies of
    beta/v_th, and each step reads only ``u`` from global memory. Intended
    for ParametricLIFNode-style neurons with fixed beta/v_th, or for several
    fixed-parameter neurons merged together.

    Args:
        beta_row: per-column decay shared by all steps; (*shape) or
            broadcastable to it.
        u: (K, *shape) per-step input.
        v_th_row: per-column threshold shared by all steps; (*shape) or
            broadcastable to it.
        v_init: initial membrane potential; (*shape) or broadcastable to it.
        surrogate_function: surrogate gradient function.
        max_iter: spike fixed-point rounds for the non-fused fallback. The
            fused kernel is an exact sequential scan and ignores it.

    Returns:
        spike: (K, *shape) spikes.
        V_post: (K, *shape) membrane potential after reset.
    """
    shape = u.shape[1:]

    if (_HAS_TRITON and u.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        alpha = float(surrogate_function.alpha)
        # The kernel reads beta_row, v_th_row and v_init with the same flat
        # column index as u[k]. Expand them to exactly (*shape) so a
        # broadcastable row (e.g. per-channel) is not read out of bounds.
        # expand() is a free view when the shape already matches.
        spike, V_post = _TritonPLIFRowParamForward.apply(
            beta_row.expand(shape), u, v_th_row.expand(shape), v_init.expand(shape), alpha,
        )
        return spike, V_post

    # Fallback: materialise per-step parameters and use the generic path.
    K = u.shape[0]
    beta = beta_row.unsqueeze(0).expand(K, *shape).contiguous()
    v_th = v_th_row.unsqueeze(0).expand(K, *shape).contiguous()
    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init, max_iter=max_iter, surrogate_function=surrogate_function,
    )
    return spike, V_post
|
|
|
|
def plif_fixed_param_forward(
    beta,
    u: torch.Tensor,
    v_th,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Parallel forward for PLIF neurons with fixed parameters (e.g. output or FFN neurons).

    ParametricLIFNode dynamics: V[k] = beta * V[k-1] + (1-beta) * x[k], where
    beta = 1/(1+exp(w)) may be a scalar tensor (keeping gradients flowing to w).

    When beta and v_th are both scalars (Python numbers or 0-dim tensors) and
    the fused CUDA Sigmoid kernel applies, the row-parameter kernel is used
    and no (K, *shape) copies are made. Otherwise the parameters are
    broadcast to (K, *shape) and plif_parallel_forward runs with ``max_iter``.

    Args:
        beta: decay -- Python number, 0-dim tensor or (K, *shape) tensor.
        u: (K, *shape) input (already multiplied by (1-beta)).
        v_th: threshold -- Python number, 0-dim tensor or (K, *shape) tensor.
        v_init: (*shape) initial membrane potential.
        max_iter: spike fixed-point rounds for the non-fused path.
        surrogate_function: surrogate gradient function.

    Returns:
        spike: (K, *shape) spikes.
        V_post: (K, *shape) membrane potential after reset.
    """
    K = u.shape[0]
    shape = u.shape[1:]

    beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0
    beta_is_float = not isinstance(beta, torch.Tensor)
    vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0
    vth_is_float = not isinstance(v_th, torch.Tensor)

    # The row-parameter route only pays off when it reaches the fused Triton
    # kernel (same dispatch test as plif_rowparam_forward). Its non-fused
    # fallback would run plif_parallel_forward with the default max_iter and
    # silently drop ours, so in that case use the expanded path below.
    fused = (_HAS_TRITON and u.is_cuda and surrogate_function is not None
             and hasattr(surrogate_function, 'alpha')
             and type(surrogate_function).__name__ == 'Sigmoid')

    if fused and (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float):
        if beta_is_scalar:
            beta_row = beta.expand(*shape).contiguous()
        else:
            beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
        if vth_is_scalar:
            v_th_row = v_th.expand(*shape).contiguous()
        else:
            v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
        return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)

    # Generic path: broadcast scalar parameters to (K, *shape). expand() keeps
    # a 0-dim tensor's gradient path (e.g. to w) intact.
    if isinstance(beta, torch.Tensor):
        if beta.dim() == 0:
            beta = beta.expand(K, *shape).contiguous()
    else:
        beta = torch.full_like(u, beta)

    if isinstance(v_th, torch.Tensor):
        if v_th.dim() == 0:
            v_th = v_th.expand(K, *shape).contiguous()
    else:
        v_th = torch.full_like(u, v_th)

    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init, max_iter, surrogate_function,
    )
    return spike, V_post
|
|