# NeuronSpark-0.9B-Chat / atomic_ops/parallel_scan.py
# Commit 440e322 (Brain2nd, verified): "Initial release: NeuronSpark-0.9B-Chat
# instruction-tuned SNN language model"
"""
Parallel Scan 工具函数:SNN 线性递推的高效并行求解
实现三层后端:
1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate):
单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient)
· per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel
· row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel
2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate):
列级并行 scan,O(K) 工作量,1 次 kernel launch
3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量
线性递推:
V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init
PLIF 神经元动力学:
V_pre[k] = beta[k] * V_post[k-1] + u[k]
s[k] = Θ(V_pre[k] - v_th[k])
V_post[k] = V_pre[k] - v_th[k] * s[k]
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
"""
import os
import torch
# ============================================================
# Triton fused recurrence kernels
# ============================================================
# DGX Spark (GB10, sm_121a): Triton 3.5.1 自带 ptxas 不支持 sm_121a,
# 需要使用系统 CUDA 13.0 的 ptxas
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS
# Triton is optional. When it cannot be imported, _HAS_TRITON stays False and
# the public entry points take their pure-PyTorch code paths instead.
try:
    import triton
    import triton.language as tl
except ImportError:
    _HAS_TRITON = False
else:
    _HAS_TRITON = True
if _HAS_TRITON:
@triton.jit
def _fwd_recurrence_kernel(
    A_ptr, B_ptr, INIT_ptr, OUT_ptr,
    K, num_cols,
    BLOCK: tl.constexpr,
):
    """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.

    Grid: (ceil(num_cols / BLOCK),)
    Each program processes BLOCK columns across all K sequential steps.
    Accumulation in float32; the store converts to OUT's element type.

    Layout (from the offset arithmetic below): A, B and OUT are addressed as
    K consecutive rows of ``num_cols`` elements, element (k, c) living at
    ``k * num_cols + c``; INIT holds one value per column.

    NOTE(review): offsets are formed as ``k * num_cols + cols``; if
    K * num_cols can exceed 2**31 - 1, confirm the integer width Triton uses
    for this product.
    """
    pid = tl.program_id(0)
    # Columns owned by this program; lanes past num_cols are masked off.
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Sequential over the K steps (the recurrence), parallel over columns.
    for k in range(K):
        off = k * num_cols + cols
        a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
        v = a * v + b
        tl.store(OUT_ptr + off, v, mask=mask)
@triton.jit
def _bwd_recurrence_kernel(
    A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
    GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
    K, num_cols,
    BLOCK: tl.constexpr,
):
    """Backward for V[k] = A[k]*V[k-1] + B[k].

    Reverse accumulation (k from K-1 down to 0):
        g = 0
        for k = K-1, ..., 0:
            g += grad_out[k]
            grad_B[k] = g
            grad_A[k] = g * V[k-1]      (V[-1] = init)
            g = g * A[k]
        grad_init = g

    ``g`` is the float32 running gradient w.r.t. V[k]: the direct upstream
    gradient grad_out[k] plus what flows back from V[k+1] through A[k+1].
    Same row-major (K, num_cols) layout and grid as the forward kernel.
    """
    pid = tl.program_id(0)
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    g = tl.zeros([BLOCK], dtype=tl.float32)
    for k_rev in range(K):
        k = K - 1 - k_rev  # walk the steps in reverse
        off = k * num_cols + cols
        dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
        g = g + dV
        tl.store(GRAD_B_ptr + off, g, mask=mask)
        # V[k-1]: previous forward output, or the initial state at k == 0.
        if k > 0:
            v_prev = tl.load(
                V_ptr + (k - 1) * num_cols + cols,
                mask=mask, other=0.0,
            ).to(tl.float32)
        else:
            v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)
        a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
        # Propagate through V[k] = A[k]*V[k-1] + B[k] to get dL/dV[k-1].
        g = g * a
    # After step 0, g is the gradient w.r.t. V[-1] = init.
    tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
class _TritonLinearRecurrence(torch.autograd.Function):
    """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k].

    ``beta`` is A, ``u`` is B and ``v_init`` is V[-1]. All three are made
    contiguous and handled as K rows of ``num_cols`` flat columns, so ``beta``
    and ``u`` must share the shape (K, *shape) and ``v_init`` must hold exactly
    prod(shape) elements. One kernel launch each for forward and backward.
    """
    _BLOCK = 128

    @staticmethod
    def forward(ctx, beta, u, v_init):
        """Run the forward scan and return V with shape (K, *shape), in u's dtype."""
        beta_c = beta.contiguous()
        u_c = u.contiguous()
        v_init_c = v_init.contiguous()
        K = beta_c.shape[0]
        # Columns per step; computed from the shape so K == 0 does not index row 0.
        num_cols = beta_c.shape[1:].numel()
        V = torch.empty_like(u_c)
        BLOCK = _TritonLinearRecurrence._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)
        _fwd_recurrence_kernel[grid](
            beta_c, u_c, v_init_c, V,
            K, num_cols,
            BLOCK=BLOCK,
        )
        if any(ctx.needs_input_grad[:3]):
            ctx.save_for_backward(beta_c, V, v_init_c)
        ctx.K = K
        ctx.num_cols = num_cols
        return V

    @staticmethod
    def backward(ctx, grad_V):
        """Return gradients w.r.t. (beta, u, v_init) from one reverse scan."""
        beta, V, v_init = ctx.saved_tensors
        grad_V_c = grad_V.contiguous()
        K = ctx.K
        num_cols = ctx.num_cols
        grad_beta = torch.empty_like(beta)
        # V was allocated like u, so grad_u gets u's shape and dtype. The
        # previous empty_like(beta) used beta's dtype whenever the two differed.
        grad_u = torch.empty_like(V)
        grad_v_init = torch.empty_like(v_init)
        BLOCK = _TritonLinearRecurrence._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)
        _bwd_recurrence_kernel[grid](
            beta, V, v_init, grad_V_c,
            grad_beta, grad_u, grad_v_init,
            K, num_cols,
            BLOCK=BLOCK,
        )
        return grad_beta, grad_u, grad_v_init
# ============================================================
# Fused PLIF forward/backward kernels
# ============================================================
@triton.jit
def _fused_plif_fwd_kernel(
    BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
    SPIKE_ptr, VPOST_ptr,
    K, num_cols,
    BLOCK: tl.constexpr,
):
    """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.

    Exact computation: the sequential scan is the ground truth.
    Replaces the 3-phase approach (linear scan + spike iteration + correction).

    Per column (parallel across all num_cols columns):
        v = v_init
        for k = 0..K-1:
            v_pre    = beta[k]*v + u[k]
            spike[k] = Θ(v_pre - v_th[k])       (fires when v_pre >= v_th)
            v        = v_pre - v_th[k]*spike[k]
    Both spike[k] and V_post[k] (= v after the reset) are written out.
    BETA, U, VTH, SPIKE and VPOST share the row-major (K, num_cols) layout.

    NOTE(review): offsets are formed as ``k * num_cols + cols``; confirm the
    integer width if K * num_cols can exceed 2**31 - 1.
    """
    pid = tl.program_id(0)
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    for k in range(K):
        off = k * num_cols + cols
        beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
        u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
        v_pre = beta * v + u
        # Heaviside step with Θ(0) = 1.
        spike = tl.where(v_pre >= vth, 1.0, 0.0)
        v = v_pre - vth * spike  # soft reset: subtract the threshold on a spike
        tl.store(SPIKE_ptr + off, spike, mask=mask)
        tl.store(VPOST_ptr + off, v, mask=mask)
@triton.jit
def _fused_plif_bwd_kernel(
    BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
    GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
    GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
    K, num_cols, ALPHA,
    BLOCK: tl.constexpr,
):
    """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.

    V_pre[k] is not stored by the forward pass; it is reconstructed as
        V_pre[k] = V_post[k] + v_th[k]*spike[k]
    surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
        where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])

    The reset term v_th[k]*spike[k] is treated as having no surrogate path, so
    dV_post/dV_pre = 1 and dV_post/dv_th = -spike. Gradient reaches the spike
    decision only through grad_spike * surrogate_grad.

    Reverse accumulation:
        acc = 0
        for k = K-1 downto 0:
            total_gV     = grad_V_post[k] + acc
            sg           = surrogate_grad(V_pre[k] - v_th[k])
            grad_v_pre   = grad_spike[k]*sg + total_gV
            grad_beta[k] = grad_v_pre * V_post[k-1]
            grad_u[k]    = grad_v_pre
            grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
            acc          = grad_v_pre * beta[k]
        grad_v_init = acc
    """
    pid = tl.program_id(0)
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    # acc: gradient flowing into V_post[k] from step k+1 (zero after the last step).
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for k_rev in range(K):
        k = K - 1 - k_rev
        off = k * num_cols + cols
        beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
        v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
        spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
        g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
        g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
        # V_post[k-1] (the initial state at k == 0)
        if k > 0:
            v_prev = tl.load(
                VPOST_ptr + (k - 1) * num_cols + cols,
                mask=mask, other=0.0,
            ).to(tl.float32)
        else:
            v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # Sigmoid surrogate gradient
        x = v_post - vth * (1.0 - spike) # = V_pre - v_th
        neg_ax = -ALPHA * x
        # Clamp so exp() stays finite in float32 (exp(88) ~ 1.65e38 < FLT_MAX).
        neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) # prevent exp overflow
        sig = 1.0 / (1.0 + tl.exp(neg_ax))
        sg = ALPHA * sig * (1.0 - sig)
        total_gV = g_V + acc
        grad_v_pre = g_s * sg + total_gV
        tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
        tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
        tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)
        # Through V_pre[k] = beta[k]*V_post[k-1] + u[k] into V_post[k-1].
        acc = grad_v_pre * beta
    # After step 0, acc is the gradient w.r.t. the initial state.
    tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
# ============================================================
# Fused PLIF kernels with row-parameter beta/v_th
# (constant across K steps — e.g., ParametricLIFNode scalars)
# ============================================================
@triton.jit
def _fused_plif_fwd_rowparam_kernel(
    BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
    SPIKE_ptr, VPOST_ptr,
    K, num_cols,
    BLOCK: tl.constexpr,
):
    """Fused PLIF forward with row-parameter beta and v_th.

    beta and v_th have shape (*shape), i.e. one value per column, and are
    constant across the K steps, so each program loads them once before the
    loop. Per step only u is read from memory, instead of beta, u and v_th.
    Same dynamics as _fused_plif_fwd_kernel:
        v_pre = beta*v + u;  spike = Θ(v_pre - v_th);  v = v_pre - v_th*spike
    """
    pid = tl.program_id(0)
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Per-column parameters, loaded once and reused for all K steps.
    beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    for k in range(K):
        off = k * num_cols + cols
        u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
        v_pre = beta * v + u
        spike = tl.where(v_pre >= vth, 1.0, 0.0)
        v = v_pre - vth * spike  # soft reset
        tl.store(SPIKE_ptr + off, spike, mask=mask)
        tl.store(VPOST_ptr + off, v, mask=mask)
@triton.jit
def _fused_plif_bwd_rowparam_kernel(
    BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
    GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
    GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
    K, num_cols, ALPHA,
    BLOCK: tl.constexpr,
):
    """Fused PLIF backward with row-parameter beta/v_th.

    Same reverse recurrence and Sigmoid surrogate as _fused_plif_bwd_kernel.
    Because beta and v_th are shared by all K steps, their per-step gradients
    are summed inside the loop (per-program accumulators). The kernel writes
    grad_beta_row and grad_v_th_row with shape (*shape), a per-step grad_u
    with shape (K, *shape), and grad_init.
    """
    pid = tl.program_id(0)
    cols = pid * BLOCK + tl.arange(0, BLOCK)
    mask = cols < num_cols
    beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    acc = tl.zeros([BLOCK], dtype=tl.float32)  # gradient flowing into V_post[k] from step k+1
    acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
    acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)
    for k_rev in range(K):
        k = K - 1 - k_rev
        off = k * num_cols + cols
        v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
        spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
        g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
        g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
        # V_post[k-1] (the initial state at k == 0)
        if k > 0:
            v_prev = tl.load(
                VPOST_ptr + (k - 1) * num_cols + cols,
                mask=mask, other=0.0,
            ).to(tl.float32)
        else:
            v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # Sigmoid surrogate gradient; x = V_pre - v_th reconstructed from V_post
        x = v_post - vth * (1.0 - spike)
        neg_ax = -ALPHA * x
        neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)  # keep exp() finite in float32
        sig = 1.0 / (1.0 + tl.exp(neg_ax))
        sg = ALPHA * sig * (1.0 - sig)
        total_gV = g_V + acc
        grad_v_pre = g_s * sg + total_gV
        tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
        # Accumulate gradients for row parameters (reduction over K in registers)
        acc_grad_beta += grad_v_pre * v_prev
        acc_grad_vth += -g_s * sg - total_gV * spike
        acc = grad_v_pre * beta
    tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
    tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
    tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
class _TritonPLIFRowParamForward(torch.autograd.Function):
    """Fused Triton PLIF whose beta / v_th are per-column rows.

    Meant for neurons whose beta and v_th do not change across the K steps
    (e.g. ParametricLIFNode). The (*shape) rows go straight to the kernels, so
    they never need to be expanded to (K, *shape), and only ``u`` is read from
    memory at each step.
    """

    _BLOCK = 128

    @staticmethod
    def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
        """Return (spike, V_post), both shaped like ``u``."""
        # The kernels use flat column offsets, so every operand must be dense.
        beta_row, u, v_th_row, v_init = (
            t.contiguous() for t in (beta_row, u, v_th_row, v_init)
        )
        n_steps = u.shape[0]
        n_cols = u[0].numel()
        spikes = torch.empty_like(u)
        v_post = torch.empty_like(u)
        block = _TritonPLIFRowParamForward._BLOCK
        n_programs = -(-n_cols // block)  # ceil(n_cols / block)
        _fused_plif_fwd_rowparam_kernel[(n_programs,)](
            beta_row, u, v_th_row, v_init,
            spikes, v_post,
            n_steps, n_cols,
            BLOCK=block,
        )
        if any(ctx.needs_input_grad[:4]):
            ctx.save_for_backward(beta_row, v_th_row, v_init, v_post, spikes)
        ctx.K, ctx.num_cols, ctx.alpha = n_steps, n_cols, alpha
        return spikes, v_post

    @staticmethod
    def backward(ctx, grad_spike, grad_V_post):
        """Return grads for (beta_row, u, v_th_row, v_init); alpha gets None."""
        beta_row, v_th_row, v_init, v_post, spikes = ctx.saved_tensors
        # An output that did not take part in the loss contributes zero gradient.
        g_spike = torch.zeros_like(spikes) if grad_spike is None else grad_spike
        g_vpost = torch.zeros_like(v_post) if grad_V_post is None else grad_V_post
        grad_beta_row = torch.empty_like(beta_row)
        grad_u = torch.empty_like(v_post)
        grad_v_th_row = torch.empty_like(v_th_row)
        grad_v_init = torch.empty_like(v_init)
        block = _TritonPLIFRowParamForward._BLOCK
        n_programs = -(-ctx.num_cols // block)
        _fused_plif_bwd_rowparam_kernel[(n_programs,)](
            beta_row, v_th_row, v_init, v_post, spikes,
            g_spike.contiguous(), g_vpost.contiguous(),
            grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
            ctx.K, ctx.num_cols, float(ctx.alpha),
            BLOCK=block,
        )
        return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
class _TritonPLIFForward(torch.autograd.Function):
    """Fused Triton PLIF forward + backward with per-element beta / v_th.

    A single sequential-scan kernel with inline spike detection and soft
    reset replaces the 3-phase approach (linear scan + spike iteration +
    correction):
      - one kernel launch instead of several launches plus element-wise ops;
      - exact computation (no fixed-point iteration that may not converge);
      - no intermediate V_L / delta_S / delta_S_prev buffers;
      - float32 accumulation inside the kernel.
    The backward pass applies a Sigmoid surrogate gradient with slope ``alpha``.
    ``beta``, ``u`` and ``v_th`` must share shape (K, *shape); ``v_init`` holds
    prod(shape) elements.
    """
    _BLOCK = 128

    @staticmethod
    def forward(ctx, beta, u, v_th, v_init, alpha):
        """Return (spike, V_post), both allocated like ``u``."""
        beta_c = beta.contiguous()
        u_c = u.contiguous()
        v_th_c = v_th.contiguous()
        v_init_c = v_init.contiguous()
        K = beta_c.shape[0]
        # Columns per step; computed from the shape so K == 0 does not index row 0.
        num_cols = beta_c.shape[1:].numel()
        spike = torch.empty_like(u_c)
        V_post = torch.empty_like(u_c)
        BLOCK = _TritonPLIFForward._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)
        _fused_plif_fwd_kernel[grid](
            beta_c, u_c, v_th_c, v_init_c,
            spike, V_post,
            K, num_cols,
            BLOCK=BLOCK,
        )
        if any(ctx.needs_input_grad[:4]):
            ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
        ctx.K = K
        ctx.num_cols = num_cols
        ctx.alpha = alpha
        return spike, V_post

    @staticmethod
    def backward(ctx, grad_spike, grad_V_post):
        """Return grads for (beta, u, v_th, v_init); alpha gets None."""
        beta, v_th, v_init, V_post, spike = ctx.saved_tensors
        K = ctx.K
        num_cols = ctx.num_cols
        alpha = ctx.alpha
        # An output that did not take part in the loss contributes zero gradient.
        if grad_spike is None:
            grad_spike = torch.zeros_like(spike)
        if grad_V_post is None:
            grad_V_post = torch.zeros_like(V_post)
        grad_spike_c = grad_spike.contiguous()
        grad_V_post_c = grad_V_post.contiguous()
        grad_beta = torch.empty_like(beta)
        # V_post was allocated like u, so grad_u carries u's dtype; the previous
        # empty_like(beta) used beta's dtype (matches _TritonPLIFRowParamForward).
        grad_u = torch.empty_like(V_post)
        grad_v_th = torch.empty_like(v_th)
        grad_v_init = torch.empty_like(v_init)
        BLOCK = _TritonPLIFForward._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)
        _fused_plif_bwd_kernel[grid](
            beta, v_th, v_init, V_post, spike,
            grad_spike_c, grad_V_post_c,
            grad_beta, grad_u, grad_v_th, grad_v_init,
            K, num_cols, float(alpha),
            BLOCK=BLOCK,
        )
        return grad_beta, grad_u, grad_v_th, grad_v_init, None
# ============================================================
# Hillis-Steele parallel prefix scan (CPU fallback)
# ============================================================
def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hillis-Steele parallel prefix scan over a sequence of affine maps.

    Given maps f_k(x) = a[k] * x + b[k] for k = 0 .. K-1, compute every prefix
    composition F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0, so that
    V[k] = F_k(v_init) = A[k] * v_init + B[k].

    Composition rule: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2).

    Each round builds fresh tensors with ``torch.cat`` (no in-place writes), so
    the scan is fully autograd-compatible.

    Args:
        a: (K, *shape) multiplicative coefficients (e.g. β).
        b: (K, *shape) additive terms (e.g. α·I).

    Returns:
        A: (K, *shape) prefix products, A[k] = ∏_{j=0}^{k} a[j].
        B: (K, *shape) offsets such that V[k] = A[k] * v_init + B[k].

    Parallel depth O(log K); total work O(K log K).
    """
    n_steps = a.shape[0]
    pref_a, pref_b = a, b
    stride = 1
    while stride < n_steps:
        # Entries [0, stride) are already complete prefixes; each later entry
        # absorbs the prefix ending `stride` positions earlier.
        cur_a, cur_b = pref_a[stride:], pref_b[stride:]
        earlier_a, earlier_b = pref_a[:-stride], pref_b[:-stride]
        pref_a, pref_b = (
            torch.cat([pref_a[:stride], cur_a * earlier_a], dim=0),
            torch.cat([pref_b[:stride], cur_a * earlier_b + cur_b], dim=0),
        )
        stride *= 2
    return pref_a, pref_b
# ============================================================
# Public API: linear_recurrence
# ============================================================
def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
    """Solve the linear recurrence V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init.

    Backends:
      * CUDA with Triton available: fused Triton kernel (one launch, O(K) work).
      * Otherwise: Hillis-Steele parallel scan (O(K log K) work).

    Args:
        beta: (K, *shape) decay coefficients, expected in (0, 1).
        u: (K, *shape) input terms.
        v_init: initial state, broadcastable to (*shape). A Python scalar is
            also accepted and is broadcast to ``u.shape[1:]``.

    Returns:
        V: (K, *shape) states for all K steps.
    """
    if not isinstance(v_init, torch.Tensor):
        v_init = u.new_full(u.shape[1:], v_init)

    if _HAS_TRITON and beta.is_cuda:
        # The kernel reads exactly prod(shape) initial values, so a smaller
        # (broadcastable) v_init would be read out of bounds. Expand it first;
        # autograd's expand backward sums grad_v_init back to v_init's own
        # shape, matching the broadcasting behaviour of the CPU path below.
        return _TritonLinearRecurrence.apply(beta, u, v_init.expand(beta.shape[1:]))

    # CPU fallback: V[k] = A[k] * v_init + B[k] from the prefix-composed maps.
    A, B = hillis_steele_scan(beta, u)
    return A * v_init.unsqueeze(0) + B
# ============================================================
# PLIF parallel forward (with spike iteration)
# ============================================================
def plif_parallel_forward(
    beta: torch.Tensor,
    u: torch.Tensor,
    v_th: torch.Tensor,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Parallel PLIF neuron forward pass (soft reset, surrogate-gradient compatible).

    Solves, for k = 0 .. K-1 with V_post[-1] = v_init:
        V_pre[k]  = beta[k] * V_post[k-1] + u[k]
        s[k]      = Θ(V_pre[k] - v_th[k])
        V_post[k] = V_pre[k] - v_th[k] * s[k]

    Two code paths:

    * Fused Triton path: taken when Triton is available, ``beta`` is on CUDA,
      and ``surrogate_function`` has an ``alpha`` attribute and a class named
      ``Sigmoid``. One exact sequential-scan kernel computes spikes and
      V_post. V_pre is not materialised, and ``None`` is returned in its place.
    * 3-phase fallback: used in every other case (CPU, other surrogates, or
      no surrogate).
        Phase 1: linear trajectory V_L assuming no spikes (with grad).
        Phase 2: fixed-point iteration on the spike pattern (no grad) to
                 decide which neurons fire at which step. It runs at most
                 ``max_iter - 1`` refinements and stops early once the pattern
                 stops changing. If it has not converged by then, the last
                 pattern is used as-is.
        Phase 3: recompute V_post with that (constant) spike pattern (with
                 grad) and emit ``surrogate_function(V_pre - v_th)``.

    Args:
        beta: (K, *shape) decay coefficients.
        u: (K, *shape) input term (α·I).
        v_th: (K, *shape) dynamic threshold.
        v_init: (*shape) initial membrane potential. A Python scalar is also
            accepted and is broadcast to ``u.shape[1:]``.
        max_iter: upper bound on spike fixed-point iterations (fallback only).
        surrogate_function: SpikingJelly-style surrogate gradient function
            (e.g. ``surrogate.Sigmoid(alpha=4.0)``). If ``None``, a hard
            threshold is used and no gradient flows through the spike output.

    Returns:
        spike: (K, *shape) spike pattern (carries the surrogate gradient).
        V_post: (K, *shape) membrane potential after firing / soft reset.
        V_pre: (K, *shape) membrane potential before firing, or ``None`` on
            the fused Triton path.
    """
    # Scalar initial states (e.g. a neuron's freshly reset ``v = 0.``) must be
    # tensors: the Triton path calls .contiguous() and Phase 1 calls
    # .unsqueeze(). Previously a scalar crashed there, and Phase 3 would have
    # silently treated it as 0.
    if not isinstance(v_init, torch.Tensor):
        v_init = u.new_full(u.shape[1:], v_init)

    # Fused Triton path: single-pass sequential scan (exact, no iteration)
    if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        alpha = float(surrogate_function.alpha)
        spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
        return spike, V_post, None

    # Phase 1: linear trajectory V_L, assuming the neuron never fires.
    V_L = linear_recurrence(beta, u, v_init)  # (K, *shape)

    # Phase 2: spike fixed-point iteration with no autograd graph. Its only job
    # is to pick the discrete firing pattern.
    with torch.no_grad():
        V_L_det = V_L.detach()
        beta_det = beta.detach()
        v_th_det = v_th.detach()
        zero_init = torch.zeros_like(v_init.detach())
        spike_pattern = (V_L_det >= v_th_det).float()
        for _ in range(max_iter - 1):
            # Accumulated reset offset: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k],
            # so that V_post = V_L - ΔS.
            delta_S = linear_recurrence(beta_det, v_th_det * spike_pattern, zero_init)
            # ΔS_prev[k] = ΔS[k-1] (shifted one step, ΔS[-1] = 0)
            delta_S_prev = torch.zeros_like(delta_S)
            delta_S_prev[1:] = delta_S[:-1]
            # V_pre[k] = beta[k] * V_post[k-1] + u[k] = V_L[k] - beta[k] * ΔS[k-1]
            V_pre_det = V_L_det - beta_det * delta_S_prev
            spike_new = (V_pre_det >= v_th_det).float()
            if torch.equal(spike_new, spike_pattern):
                break  # converged
            spike_pattern = spike_new

    # Phase 3: recompute V_post with the chosen (detached) spike pattern.
    # spike_pattern is a constant, so gradients flow through u, v_th, beta
    # and v_init.
    u_eff = u - v_th * spike_pattern
    V_post = linear_recurrence(beta, u_eff, v_init)  # (K, *shape)

    # Rebuild V_pre with grad for the surrogate: V_post shifted by one step,
    # with v_init in front.
    V_post_prev = torch.zeros_like(V_post)
    V_post_prev[0] = v_init
    V_post_prev[1:] = V_post[:-1]
    V_pre = beta * V_post_prev + u

    if surrogate_function is not None:
        # forward: Heaviside(V_pre - v_th); backward: surrogate gradient
        spike = surrogate_function(V_pre - v_th)
    else:
        # Degenerate mode: hard threshold, no gradient through the spike.
        spike = (V_pre >= v_th).float()
    return spike, V_post, V_pre
def plif_rowparam_forward(
    beta_row: torch.Tensor,
    u: torch.Tensor,
    v_th_row: torch.Tensor,
    v_init: torch.Tensor,
    surrogate_function=None,
    max_iter: int = 3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """PLIF forward whose beta and v_th are constant across the K steps.

    On the fused path, beta / v_th are passed to the kernel as (*shape) rows
    and only ``u`` is read from memory at each step, instead of expanding
    beta / v_th to (K, *shape) as :func:`plif_parallel_forward` requires.
    Intended for ParametricLIFNode-style neurons with fixed beta / v_th, or for
    several fixed-parameter neurons merged into one call.

    Args:
        beta_row: (*shape) per-column decay, shared by all K steps.
        u: (K, *shape) per-step input.
        v_th_row: (*shape) per-column threshold, shared by all K steps.
        v_init: (*shape) initial membrane potential. A Python scalar is
            broadcast to ``u.shape[1:]``.
        surrogate_function: surrogate gradient function; the fused kernel is
            used only for a ``Sigmoid`` surrogate exposing ``alpha``.
        max_iter: spike fixed-point iteration bound, forwarded to the
            non-fused fallback. The fused Triton path ignores it because its
            sequential scan needs no iteration.

    Returns:
        spike: (K, *shape) spike pattern.
        V_post: (K, *shape) membrane potential after firing.
    """
    if not isinstance(v_init, torch.Tensor):
        v_init = u.new_full(u.shape[1:], v_init)

    use_fused = (
        _HAS_TRITON and u.is_cuda and surrogate_function is not None
        and hasattr(surrogate_function, 'alpha')
        and type(surrogate_function).__name__ == 'Sigmoid'
    )
    if use_fused:
        alpha = float(surrogate_function.alpha)
        spike, V_post = _TritonPLIFRowParamForward.apply(
            beta_row, u, v_th_row, v_init, alpha,
        )
        return spike, V_post

    # Fallback: materialise (K, *shape) parameters and use the generic path.
    K = u.shape[0]
    beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
    v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init,
        max_iter=max_iter, surrogate_function=surrogate_function,
    )
    return spike, V_post


def plif_fixed_param_forward(
    beta,
    u: torch.Tensor,
    v_th,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Parallel forward for fixed-parameter PLIF neurons (e.g. output or FFN neurons).

    ParametricLIFNode dynamics: V[k] = beta * V[k-1] + (1-beta) * x[k], with
    beta = 1/(1+exp(w)). beta may be a 0-dim tensor, which keeps the gradient
    path to w. When beta and v_th are both Python numbers or 0-dim tensors,
    they become (*shape) rows and go through :func:`plif_rowparam_forward`,
    with no expansion to (K, *shape). Otherwise everything is expanded to
    (K, *shape) and :func:`plif_parallel_forward` is used.

    Args:
        beta: decay; Python number, 0-dim tensor, or (K, *shape) tensor.
        u: (K, *shape) input, already multiplied by (1-beta).
        v_th: threshold; Python number, 0-dim tensor, or (K, *shape) tensor.
        v_init: (*shape) initial membrane potential; a Python scalar is
            broadcast to ``u.shape[1:]``.
        max_iter: spike fixed-point iteration bound, honoured on both paths.
        surrogate_function: surrogate gradient function.

    Returns:
        spike: (K, *shape) spike pattern.
        V_post: (K, *shape) membrane potential after firing.
    """
    K = u.shape[0]
    shape = u.shape[1:]
    if not isinstance(v_init, torch.Tensor):
        v_init = u.new_full(shape, v_init)

    beta_is_tensor = isinstance(beta, torch.Tensor)
    vth_is_tensor = isinstance(v_th, torch.Tensor)
    beta_is_scalar_like = not beta_is_tensor or beta.dim() == 0
    vth_is_scalar_like = not vth_is_tensor or v_th.dim() == 0

    # Row-param fast path: build (*shape) row vectors for beta and v_th.
    if beta_is_scalar_like and vth_is_scalar_like:
        if beta_is_tensor:
            # expand() (not a Python float) keeps the autograd link to beta.
            beta_row = beta.expand(*shape).contiguous()
        else:
            beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
        if vth_is_tensor:
            v_th_row = v_th.expand(*shape).contiguous()
        else:
            v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
        # Forward max_iter; it used to be silently dropped on this path.
        return plif_rowparam_forward(
            beta_row, u, v_th_row, v_init, surrogate_function,
            max_iter=max_iter,
        )

    # Full-tensor path: expand scalar parameters to (K, *shape) where needed.
    if beta_is_tensor:
        if beta.dim() == 0:
            beta = beta.expand(K, *shape).contiguous()
    else:
        beta = torch.full_like(u, beta)
    if vth_is_tensor:
        if v_th.dim() == 0:
            v_th = v_th.expand(K, *shape).contiguous()
    else:
        v_th = torch.full_like(u, v_th)
    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init, max_iter, surrogate_function,
    )
    return spike, V_post