""" Parallel Scan 工具函数:SNN 线性递推的高效并行求解 实现三层后端: 1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate): 单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient) · per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel · row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel 2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate): 列级并行 scan,O(K) 工作量,1 次 kernel launch 3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量 线性递推: V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init PLIF 神经元动力学: V_pre[k] = beta[k] * V_post[k-1] + u[k] s[k] = Θ(V_pre[k] - v_th[k]) V_post[k] = V_pre[k] - v_th[k] * s[k] 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。 """ import os import torch # ============================================================ # Triton fused recurrence kernels # ============================================================ # DGX Spark (GB10, sm_121a): Triton 3.5.1 自带 ptxas 不支持 sm_121a, # 需要使用系统 CUDA 13.0 的 ptxas _SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas' if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ: os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS _HAS_TRITON = False try: import triton import triton.language as tl _HAS_TRITON = True except ImportError: pass if _HAS_TRITON: @triton.jit def _fwd_recurrence_kernel( A_ptr, B_ptr, INIT_ptr, OUT_ptr, K, num_cols, BLOCK: tl.constexpr, ): """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init. Grid: (ceil(num_cols / BLOCK),) Each program processes BLOCK columns across all K sequential steps. Accumulation in float32; storage in input dtype. """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) for k in range(K): off = k * num_cols + cols a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32) b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32) v = a * v + b tl.store(OUT_ptr + off, v, mask=mask) @triton.jit def _bwd_recurrence_kernel( A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr, GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr, K, num_cols, BLOCK: tl.constexpr, ): """Backward for V[k] = A[k]*V[k-1] + B[k]. Reverse accumulation (k from K-1 down to 0): g = 0 for k = K-1, ..., 0: g += grad_out[k] grad_B[k] = g grad_A[k] = g * V[k-1] (V[-1] = init) g = g * A[k] grad_init = g """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols g = tl.zeros([BLOCK], dtype=tl.float32) for k_rev in range(K): k = K - 1 - k_rev off = k * num_cols + cols dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32) g = g + dV tl.store(GRAD_B_ptr + off, g, mask=mask) if k > 0: v_prev = tl.load( V_ptr + (k - 1) * num_cols + cols, mask=mask, other=0.0, ).to(tl.float32) else: v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask) a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32) g = g * a tl.store(GRAD_INIT_ptr + cols, g, mask=mask) class _TritonLinearRecurrence(torch.autograd.Function): """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k].""" _BLOCK = 128 @staticmethod def forward(ctx, beta, u, v_init): beta_c = beta.contiguous() u_c = u.contiguous() v_init_c = v_init.contiguous() K = beta_c.shape[0] num_cols = beta_c[0].numel() V = torch.empty_like(u_c) BLOCK = _TritonLinearRecurrence._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _fwd_recurrence_kernel[grid]( beta_c, u_c, v_init_c, V, K, num_cols, BLOCK=BLOCK, ) if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: ctx.save_for_backward(beta_c, V, v_init_c) ctx.K = K ctx.num_cols = num_cols return V @staticmethod def backward(ctx, grad_V): beta, V, v_init = ctx.saved_tensors grad_V_c = grad_V.contiguous() K = ctx.K num_cols = ctx.num_cols grad_beta = torch.empty_like(beta) grad_u = torch.empty_like(beta) grad_v_init = torch.empty_like(v_init) BLOCK = _TritonLinearRecurrence._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _bwd_recurrence_kernel[grid]( beta, V, v_init, grad_V_c, grad_beta, grad_u, grad_v_init, K, num_cols, BLOCK=BLOCK, ) return grad_beta, grad_u, grad_v_init # ============================================================ # Fused PLIF forward/backward kernels # ============================================================ @triton.jit def _fused_plif_fwd_kernel( BETA_ptr, U_ptr, VTH_ptr, INIT_ptr, SPIKE_ptr, VPOST_ptr, K, num_cols, BLOCK: tl.constexpr, ): """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset. Exact computation — sequential scan IS the ground truth. Replaces the 3-phase approach (linear scan + spike iteration + correction). Per column (parallel across batch*D): v = v_init for k = 0..K-1: v_pre = beta[k]*v + u[k] spike[k] = Θ(v_pre - v_th[k]) v = v_pre - v_th[k]*spike[k] """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) for k in range(K): off = k * num_cols + cols beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32) u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32) vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32) v_pre = beta * v + u spike = tl.where(v_pre >= vth, 1.0, 0.0) v = v_pre - vth * spike # soft reset tl.store(SPIKE_ptr + off, spike, mask=mask) tl.store(VPOST_ptr + off, v, mask=mask) @triton.jit def _fused_plif_bwd_kernel( BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr, GRAD_SPIKE_ptr, GRAD_VPOST_ptr, GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr, K, num_cols, ALPHA, BLOCK: tl.constexpr, ): """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient. V_pre[k] = V_post[k] + v_th[k]*spike[k] (reconstructed) surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x)) where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k]) Reverse accumulation: acc = 0 for k = K-1 downto 0: total_gV = grad_V_post[k] + acc sg = surrogate_grad(V_pre[k] - v_th[k]) grad_v_pre = grad_spike[k]*sg + total_gV grad_beta[k] = grad_v_pre * V_post[k-1] grad_u[k] = grad_v_pre grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k] acc = grad_v_pre * beta[k] grad_v_init = acc """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols acc = tl.zeros([BLOCK], dtype=tl.float32) for k_rev in range(K): k = K - 1 - k_rev off = k * num_cols + cols beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32) vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32) v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32) spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32) g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32) g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32) # V_post[k-1] if k > 0: v_prev = tl.load( VPOST_ptr + (k - 1) * num_cols + cols, mask=mask, other=0.0, ).to(tl.float32) else: v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Sigmoid surrogate gradient x = v_post - vth * (1.0 - spike) # = V_pre - v_th neg_ax = -ALPHA * x neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) # prevent exp overflow sig = 1.0 / (1.0 + tl.exp(neg_ax)) sg = ALPHA * sig * (1.0 - sig) total_gV = g_V + acc grad_v_pre = g_s * sg + total_gV tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask) tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask) tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask) acc = grad_v_pre * beta tl.store(GRAD_INIT_ptr + cols, acc, mask=mask) # ============================================================ # Fused PLIF kernels with row-parameter beta/v_th # (constant across K steps — e.g., ParametricLIFNode scalars) # ============================================================ @triton.jit def _fused_plif_fwd_rowparam_kernel( BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr, SPIKE_ptr, VPOST_ptr, K, num_cols, BLOCK: tl.constexpr, ): """Fused PLIF forward with row-parameter beta and v_th. beta and v_th are (*shape) — constant across K steps, loaded once into registers. Reduces global memory reads from 3 per step (beta, u, v_th) to 1 (u only). """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32) vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32) for k in range(K): off = k * num_cols + cols u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32) v_pre = beta * v + u spike = tl.where(v_pre >= vth, 1.0, 0.0) v = v_pre - vth * spike tl.store(SPIKE_ptr + off, spike, mask=mask) tl.store(VPOST_ptr + off, v, mask=mask) @triton.jit def _fused_plif_bwd_rowparam_kernel( BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr, GRAD_SPIKE_ptr, GRAD_VPOST_ptr, GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr, K, num_cols, ALPHA, BLOCK: tl.constexpr, ): """Fused PLIF backward with row-parameter beta/v_th. Gradients for beta and v_th are accumulated over K steps (reduction in registers). Returns grad_beta_row (*shape) and grad_v_th_row (*shape) instead of per-step gradients. """ pid = tl.program_id(0) cols = pid * BLOCK + tl.arange(0, BLOCK) mask = cols < num_cols beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32) vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32) acc = tl.zeros([BLOCK], dtype=tl.float32) acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32) acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32) for k_rev in range(K): k = K - 1 - k_rev off = k * num_cols + cols v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32) spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32) g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32) g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32) if k > 0: v_prev = tl.load( VPOST_ptr + (k - 1) * num_cols + cols, mask=mask, other=0.0, ).to(tl.float32) else: v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Sigmoid surrogate gradient x = v_post - vth * (1.0 - spike) neg_ax = -ALPHA * x neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) sig = 1.0 / (1.0 + tl.exp(neg_ax)) sg = ALPHA * sig * (1.0 - sig) total_gV = g_V + acc grad_v_pre = g_s * sg + total_gV tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask) # Accumulate gradients for row parameters (reduction over K in registers) acc_grad_beta += grad_v_pre * v_prev acc_grad_vth += -g_s * sg - total_gV * spike acc = grad_v_pre * beta tl.store(GRAD_INIT_ptr + cols, acc, mask=mask) tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask) tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask) class _TritonPLIFRowParamForward(torch.autograd.Function): """Fused Triton PLIF with row-parameter beta/v_th. For neurons with constant beta/v_th across K steps (ParametricLIFNode). Eliminates expand+contiguous for beta/v_th tensors, reduces memory I/O by ~40%. """ _BLOCK = 128 @staticmethod def forward(ctx, beta_row, u, v_th_row, v_init, alpha): beta_row_c = beta_row.contiguous() u_c = u.contiguous() v_th_row_c = v_th_row.contiguous() v_init_c = v_init.contiguous() K = u_c.shape[0] num_cols = u_c[0].numel() spike = torch.empty_like(u_c) V_post = torch.empty_like(u_c) BLOCK = _TritonPLIFRowParamForward._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _fused_plif_fwd_rowparam_kernel[grid]( beta_row_c, u_c, v_th_row_c, v_init_c, spike, V_post, K, num_cols, BLOCK=BLOCK, ) if any(ctx.needs_input_grad[:4]): ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike) ctx.K = K ctx.num_cols = num_cols ctx.alpha = alpha return spike, V_post @staticmethod def backward(ctx, grad_spike, grad_V_post): beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors K = ctx.K num_cols = ctx.num_cols alpha = ctx.alpha if grad_spike is None: grad_spike = torch.zeros_like(spike) if grad_V_post is None: grad_V_post = torch.zeros_like(V_post) grad_spike_c = grad_spike.contiguous() grad_V_post_c = grad_V_post.contiguous() grad_beta_row = torch.empty_like(beta_row) grad_u = torch.empty_like(V_post) grad_v_th_row = torch.empty_like(v_th_row) grad_v_init = torch.empty_like(v_init) BLOCK = _TritonPLIFRowParamForward._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _fused_plif_bwd_rowparam_kernel[grid]( beta_row, v_th_row, v_init, V_post, spike, grad_spike_c, grad_V_post_c, grad_beta_row, grad_u, grad_v_th_row, grad_v_init, K, num_cols, float(alpha), BLOCK=BLOCK, ) return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None class _TritonPLIFForward(torch.autograd.Function): """Fused Triton PLIF forward + backward. Single-pass sequential scan replaces the 3-phase approach: Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction) → 1 fused kernel with inline spike detection + soft reset Advantages: - 1 kernel launch (vs 3-4 launches + ~10 element-wise ops) - Exact computation (no iteration convergence issues) - Less memory (no intermediate V_L, delta_S, delta_S_prev) - Higher precision (fp32 accumulation, no bf16 intermediate store/load) """ _BLOCK = 128 @staticmethod def forward(ctx, beta, u, v_th, v_init, alpha): beta_c = beta.contiguous() u_c = u.contiguous() v_th_c = v_th.contiguous() v_init_c = v_init.contiguous() K = beta_c.shape[0] num_cols = beta_c[0].numel() spike = torch.empty_like(u_c) V_post = torch.empty_like(u_c) BLOCK = _TritonPLIFForward._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _fused_plif_fwd_kernel[grid]( beta_c, u_c, v_th_c, v_init_c, spike, V_post, K, num_cols, BLOCK=BLOCK, ) if any(ctx.needs_input_grad[:4]): ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike) ctx.K = K ctx.num_cols = num_cols ctx.alpha = alpha return spike, V_post @staticmethod def backward(ctx, grad_spike, grad_V_post): beta, v_th, v_init, V_post, spike = ctx.saved_tensors K = ctx.K num_cols = ctx.num_cols alpha = ctx.alpha if grad_spike is None: grad_spike = torch.zeros_like(spike) if grad_V_post is None: grad_V_post = torch.zeros_like(V_post) grad_spike_c = grad_spike.contiguous() grad_V_post_c = grad_V_post.contiguous() grad_beta = torch.empty_like(beta) grad_u = torch.empty_like(beta) grad_v_th = torch.empty_like(v_th) grad_v_init = torch.empty_like(v_init) BLOCK = _TritonPLIFForward._BLOCK grid = ((num_cols + BLOCK - 1) // BLOCK,) _fused_plif_bwd_kernel[grid]( beta, v_th, v_init, V_post, spike, grad_spike_c, grad_V_post_c, grad_beta, grad_u, grad_v_th, grad_v_init, K, num_cols, float(alpha), BLOCK=BLOCK, ) return grad_beta, grad_u, grad_v_th, grad_v_init, None # ============================================================ # Hillis-Steele parallel prefix scan (CPU fallback) # ============================================================ def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Hillis-Steele 并行前缀扫描:计算仿射映射序列的所有前缀复合。 给定仿射映射 f_k(x) = a[k] * x + b[k], k = 0, ..., K-1, 计算前缀复合 F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0, 使得 V[k] = F_k(v_init) = A[k] * v_init + B[k]。 复合规则: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2) 实现使用 torch.cat 重建张量(无原地操作),完全兼容 autograd。 Args: a: (K, *shape) — 乘性系数(如 β) b: (K, *shape) — 加性项(如 α·I) Returns: A: (K, *shape) — 前缀积 A[k] = ∏_{j=0}^{k} a[j] B: (K, *shape) — 前缀和 B[k] 使得 V[k] = A[k] * v_init + B[k] 并行深度: O(log K) 工作量: O(K * log K) """ K = a.shape[0] A = a B = b d = 1 while d < K: A_new_tail = A[d:] * A[:-d] B_new_tail = A[d:] * B[:-d] + B[d:] A = torch.cat([A[:d], A_new_tail], dim=0) B = torch.cat([B[:d], B_new_tail], dim=0) d *= 2 return A, B # ============================================================ # Public API: linear_recurrence # ============================================================ def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor: """ 求解线性递推: V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init CUDA 后端: Triton fused kernel(1 次 kernel launch,O(K) 工作量) CPU 后端: Hillis-Steele parallel scan(O(K log K) 工作量) Args: beta: (K, *shape) — 衰减系数,值域 (0, 1) u: (K, *shape) — 输入项 v_init: (*shape) — 初始状态 Returns: V: (K, *shape) — 所有 K 步的状态 """ if _HAS_TRITON and beta.is_cuda: return _TritonLinearRecurrence.apply(beta, u, v_init) # CPU fallback A, B = hillis_steele_scan(beta, u) V = A * v_init.unsqueeze(0) + B return V # ============================================================ # PLIF parallel forward (with spike iteration) # ============================================================ def plif_parallel_forward( beta: torch.Tensor, u: torch.Tensor, v_th: torch.Tensor, v_init: torch.Tensor, max_iter: int = 3, surrogate_function=None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ PLIF 神经元的并行前向传播(soft reset,surrogate gradient 兼容)。 求解: V_pre[k] = beta[k] * V_post[k-1] + u[k] s[k] = Θ(V_pre[k] - v_th[k]) V_post[k] = V_pre[k] - v_th[k] * s[k] 方法: Phase 1: 线性轨迹 parallel scan(有梯度) Phase 2: spike 不动点迭代(detach,确定离散 spike pattern) Phase 3: 用 converged spike pattern 重算 V_post(有梯度), surrogate_function(V_pre - v_th) 生成可微 spike 输出 Args: beta: (K, *shape) — 衰减系数 u: (K, *shape) — 输入 α·I v_th: (K, *shape) — 动态阈值 v_init: (*shape) — 初始膜电位 max_iter: spike 不动点迭代次数上限 surrogate_function: SpikingJelly surrogate gradient 函数(如 surrogate.Sigmoid(alpha=4.0)) None 时退化为硬阈值(无梯度) Returns: spike: (K, *shape) — spike 模式(有 surrogate gradient) V_post: (K, *shape) — 发放后膜电位 V_pre: (K, *shape) — 发放前膜电位(fused path 返回 None) """ # Fused Triton path: single-pass sequential scan (exact, no iteration) # Replaces 3-phase approach with 1 kernel launch — ~3x faster forward, ~5x faster backward if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None and hasattr(surrogate_function, 'alpha') and type(surrogate_function).__name__ == 'Sigmoid'): alpha = float(surrogate_function.alpha) spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha) return spike, V_post, None # Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate) # Phase 1: 线性轨迹 V_L (假设从不发放) V_L = linear_recurrence(beta, u, v_init) # (K, *shape) # Phase 2: Spike 不动点迭代(全部 detach,不建立梯度图) # 目的:确定哪些神经元在哪些步发放(离散决策) with torch.no_grad(): V_L_det = V_L.detach() beta_det = beta.detach() v_th_det = v_th.detach() v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init spike_pattern = (V_L_det >= v_th_det).float() for _ in range(max_iter - 1): # 计算 ΔS: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k] delta_S = linear_recurrence( beta_det, v_th_det * spike_pattern, torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor) else torch.zeros_like(V_L_det[0]), ) # ΔS_prev = ΔS[k-1](位移一步) delta_S_prev = torch.zeros_like(delta_S) delta_S_prev[1:] = delta_S[:-1] # V_pre = V_L - beta * ΔS_prev V_pre_det = V_L_det - beta_det * delta_S_prev # 更新 spike spike_new = (V_pre_det >= v_th_det).float() # 收敛检查 if torch.equal(spike_new, spike_pattern): break spike_pattern = spike_new # Phase 3: 用 converged spike pattern 重算 V_post(有完整梯度) # spike_pattern 是 detached 的,作为常数参与计算 # 梯度通过 u, v_th, beta, v_init 流动 u_eff = u - v_th * spike_pattern V_post = linear_recurrence(beta, u_eff, v_init) # (K, *shape) # 重建 V_pre(有梯度,用于 surrogate gradient) V_post_prev = torch.zeros_like(V_post) if isinstance(v_init, torch.Tensor): V_post_prev[0] = v_init V_post_prev[1:] = V_post[:-1] V_pre = beta * V_post_prev + u # 生成可微 spike 输出 if surrogate_function is not None: # forward: Heaviside(V_pre - v_th), backward: surrogate gradient spike = surrogate_function(V_pre - v_th) else: # 退化模式:硬阈值,无梯度 spike = (V_pre >= v_th).float() return spike, V_post, V_pre def plif_rowparam_forward( beta_row: torch.Tensor, u: torch.Tensor, v_th_row: torch.Tensor, v_init: torch.Tensor, surrogate_function=None, ) -> tuple[torch.Tensor, torch.Tensor]: """ 行参数 PLIF 前向:beta 和 v_th 在 K 步中保持恒定。 比 plif_parallel_forward 快 ~40%(省去 expand+contiguous,减少 2/3 显存读取)。 用于 ParametricLIFNode(固定 beta/v_th)或合并多个固定参数神经元。 Args: beta_row: (*shape) — 每列的衰减率(所有 K 步相同) u: (K, *shape) — 每步输入 v_th_row: (*shape) — 每列的阈值(所有 K 步相同) v_init: (*shape) — 初始膜电位 surrogate_function: surrogate gradient 函数 Returns: spike: (K, *shape) — spike 模式 V_post: (K, *shape) — 发放后膜电位 """ if (_HAS_TRITON and u.is_cuda and surrogate_function is not None and hasattr(surrogate_function, 'alpha') and type(surrogate_function).__name__ == 'Sigmoid'): alpha = float(surrogate_function.alpha) spike, V_post = _TritonPLIFRowParamForward.apply( beta_row, u, v_th_row, v_init, alpha, ) return spike, V_post # Fallback: expand to full (K, *shape) and use standard path K = u.shape[0] beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous() v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous() spike, V_post, _ = plif_parallel_forward(beta, u, v_th, v_init, surrogate_function=surrogate_function) return spike, V_post def plif_fixed_param_forward( beta, u: torch.Tensor, v_th, v_init: torch.Tensor, max_iter: int = 3, surrogate_function=None, ) -> tuple[torch.Tensor, torch.Tensor]: """ 固定参数 PLIF 神经元的并行前向(如输出神经元、FFN 神经元)。 ParametricLIFNode 方程: V[k] = beta * V[k-1] + (1-beta) * x[k] 其中 beta = 1/(1+exp(w)), 可为 scalar tensor(保持梯度流向 w)。 scalar/0-dim beta 和 v_th 使用 row-param 内核(无需 expand 到 (K, *shape))。 Args: beta: 衰减率 — scalar float、0-dim tensor 或 (K, *shape) tensor u: (K, *shape) — 输入(已乘以 (1-beta)) v_th: 阈值 — scalar float、0-dim tensor 或 (K, *shape) tensor v_init: (*shape) — 初始膜电位 max_iter: spike 迭代次数 surrogate_function: surrogate gradient 函数 Returns: spike: (K, *shape) — spike 模式 V_post: (K, *shape) — 发放后膜电位 """ K = u.shape[0] shape = u.shape[1:] # Row-param fast path: beta 和 v_th 都是 scalar/0-dim → 扩展为 (*shape) 行向量 beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0 beta_is_float = not isinstance(beta, torch.Tensor) vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0 vth_is_float = not isinstance(v_th, torch.Tensor) if (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float): # Build row vectors (*shape) if beta_is_scalar: beta_row = beta.expand(*shape).contiguous() else: beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype) if vth_is_scalar: v_th_row = v_th.expand(*shape).contiguous() else: v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype) return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function) # Full-tensor path: expand to (K, *shape) if needed if isinstance(beta, torch.Tensor): if beta.dim() == 0: beta = beta.expand(K, *shape).contiguous() else: beta = torch.full_like(u, beta) if isinstance(v_th, torch.Tensor): if v_th.dim() == 0: v_th = v_th.expand(K, *shape).contiguous() else: v_th = torch.full_like(u, v_th) spike, V_post, _ = plif_parallel_forward( beta, u, v_th, v_init, max_iter, surrogate_function, ) return spike, V_post