| """ |
| LateralInhibition: 侧抑制归一化(Divisive Normalization) |
| |
| 神经科学基础: |
| Carandini & Heeger (2012) "Normalization as a canonical neural computation" |
| 侧抑制是大脑中最基本的计算原语之一:兴奋性神经元的活动通过抑制性中间神经元池 |
| 反馈调节,实现增益控制(gain control)。 |
| |
| SNN 机制: |
| 1. 兴奋性群体活动: activity_i = h_i² |
| 2. 抑制性中间神经元池: pool = mean(activity) = mean(h²) |
| 3. 分裂抑制 (shunting inhibition): h_norm = h / sqrt(pool + ε) |
| 4. 增益调制 (gain modulation): output = gain · h_norm |
| |
| 替换 RMSNorm:数学操作等价,但在 SNN 框架中有明确的神经科学对应—— |
| RMSNorm 是 divisive normalization 的特例。 |
| |
| Triton fused kernel: |
| - 前向: {mean(h²), rsqrt, element-wise mul} → 1 kernel launch |
| - 反向: {recompute norm, grad_gain, grad_h} → 1 kernel launch |
| - 每行 (D dim) 一个 block,行间并行 |
| """ |
|
|
import os

import torch
import torch.nn as nn
from spikingjelly.activation_based import base
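

# The four steps from the module docstring, written out as a plain eager-PyTorch
# reference. This is an illustrative sketch only: `_reference_divisive_norm` is a
# name introduced here for exposition; it is not used by the Triton kernels or the
# module below, it just spells out the math they implement.
def _reference_divisive_norm(h: torch.Tensor, gain: torch.Tensor,
                             eps: float = 1e-6) -> torch.Tensor:
    activity = h * h                             # 1. excitatory population activity
    pool = activity.mean(dim=-1, keepdim=True)   # 2. inhibitory interneuron pool
    h_norm = h / torch.sqrt(pool + eps)          # 3. shunting inhibition
    return gain * h_norm                         # 4. gain modulation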
|
|
# Point Triton at the system ptxas if one is installed and the user has not
# already set TRITON_PTXAS_PATH.
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

# Triton is optional: CUDA tensors take the fused-kernel path below; otherwise
# LateralInhibition.forward() falls back to eager PyTorch.
_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass
|
|
if _HAS_TRITON:

    @triton.jit
    def _li_fwd_kernel(
        X_ptr, GAIN_ptr, OUT_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Forward: out = x * rsqrt(mean(x²) + eps) * gain

        Grid: (num_rows,). Each program processes one row of D elements.
        Computation in float32; storage in the input dtype.
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)

        out = x * rrms * gain

        tl.store(OUT_ptr + off, out, mask=mask)
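
    # Gradient derivation (a short sketch restating the formula implemented in
    # _li_bwd_kernel below). With rrms = (mean(x²) + eps)^(-1/2),
    # x_hat = x * rrms and out = gain * x_hat:
    #
    #   d_gain_j (per row) = dout_j * x_hat_j
    #   d_x_i = rrms * (dout_i * gain_i - x_hat_i * mean_j(dout_j * gain_j * x_hat_j))
    #
    # The coupling term comes from differentiating rrms with respect to each x_j,
    # which ties together all elements of a row.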
|
|
    @triton.jit
    def _li_bwd_kernel(
        DOUT_ptr, X_ptr, GAIN_ptr,
        DX_ptr, DGAIN_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Backward: grad_x and per-row grad_gain (reduced across rows externally).

        Grid: (num_rows,).
        d_x = rrms * (d_out * gain - x_hat * mean(d_out * gain * x_hat))
        d_gain_row = d_out * x_hat   (sum across rows is done outside the kernel)
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        dout = tl.load(DOUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Recompute the normalization statistics from the forward pass.
        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)
        x_hat = x * rrms

        # Per-row grad_gain; the caller sums it across rows.
        dgain = dout * x_hat
        tl.store(DGAIN_ptr + off, dgain, mask=mask)

        dout_gain = dout * gain
        dot = tl.sum(dout_gain * x_hat, axis=0) / D
        dx = (dout_gain - x_hat * dot) * rrms

        tl.store(DX_ptr + off, dx, mask=mask)
|
|
class _LateralInhibitionTriton(torch.autograd.Function):
    """Triton-accelerated lateral inhibition (divisive normalization)."""

    @staticmethod
    def forward(ctx, x, gain, eps):
        orig_shape = x.shape
        D = x.shape[-1]
        x_2d = x.reshape(-1, D).contiguous()
        N = x_2d.shape[0]

        out = torch.empty_like(x_2d)

        # One program per row; BLOCK_D is the next power of two >= D.
        BLOCK_D = triton.next_power_of_2(D)
        _li_fwd_kernel[(N,)](
            x_2d, gain, out,
            x_2d.stride(0),
            D=D, eps=eps, BLOCK_D=BLOCK_D,
        )

        ctx.save_for_backward(x_2d, gain)
        ctx.eps = eps
        ctx.orig_shape = orig_shape
        ctx.N = N
        ctx.D = D

        return out.reshape(orig_shape)

    @staticmethod
    def backward(ctx, grad_output):
        x_2d, gain = ctx.saved_tensors
        D = ctx.D
        N = ctx.N

        grad_2d = grad_output.reshape(N, D).contiguous()

        dx = torch.empty_like(x_2d)
        dgain_rows = torch.empty_like(x_2d)

        BLOCK_D = triton.next_power_of_2(D)
        _li_bwd_kernel[(N,)](
            grad_2d, x_2d, gain,
            dx, dgain_rows,
            x_2d.stride(0),
            D=D, eps=ctx.eps, BLOCK_D=BLOCK_D,
        )

        # Each row wrote its own grad_gain slice; reduce across rows here so the
        # kernel needs no atomics.
        dgain = dgain_rows.sum(dim=0)

        return dx.reshape(ctx.orig_shape), dgain, None
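

def _check_fused_against_reference(n: int = 4, d: int = 64, eps: float = 1e-6) -> None:
    """Illustrative consistency check (an addition for exposition, not part of the
    original API): compare outputs and gradients of the fused Triton path against
    the eager `_reference_divisive_norm` sketch defined near the top of this file.
    Only meaningful when CUDA and Triton are available; otherwise it returns.
    """
    if not (_HAS_TRITON and torch.cuda.is_available()):
        return
    x = torch.randn(n, d, device='cuda', requires_grad=True)
    gain = torch.randn(d, device='cuda', requires_grad=True)

    out_fused = _LateralInhibitionTriton.apply(x, gain, eps)
    out_ref = _reference_divisive_norm(x, gain, eps)
    assert torch.allclose(out_fused, out_ref, atol=1e-5, rtol=1e-4)

    grad_out = torch.randn_like(out_fused)
    gx_fused, gg_fused = torch.autograd.grad(out_fused, (x, gain), grad_out)
    gx_ref, gg_ref = torch.autograd.grad(out_ref, (x, gain), grad_out)
    assert torch.allclose(gx_fused, gx_ref, atol=1e-4, rtol=1e-3)
    assert torch.allclose(gg_fused, gg_ref, atol=1e-4, rtol=1e-3)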
|
|
class LateralInhibition(base.MemoryModule):
    """Lateral-inhibition normalization layer (divisive normalization).

    Implements gain control through an inhibitory interneuron pool.

    Math:
        pool   = mean(h², dim=-1)    # inhibitory pool: population activity level
        h_norm = h / sqrt(pool + ε)  # shunting inhibition
        output = gain · h_norm       # gain modulation

    Mathematically equivalent to RMSNorm, but within the SNN framework it maps
    onto divisive normalization (Carandini & Heeger, 2012), one of the most
    basic computational primitives in neuroscience. A small usage and
    consistency check lives under __main__ at the end of this file.

    CUDA: Triton fused kernels (1 launch each for forward and backward)
    CPU:  PyTorch fallback

    Args:
        dim: feature dimension (D)
        eps: numerical-stability constant
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.eps = eps
        self.dim = dim

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if _HAS_TRITON and h.is_cuda:
            return _LateralInhibitionTriton.apply(h, self.gain, self.eps)

        # PyTorch fallback (CPU, or CUDA without Triton).
        variance = h.pow(2).mean(-1, keepdim=True)
        h_norm = h * torch.rsqrt(variance + self.eps)
        return self.gain * h_norm

    def extra_repr(self):
        return f'dim={self.dim}, eps={self.eps}'
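

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): run the layer on CPU, compare it with
    # the eager reference defined near the top of this file, then exercise the
    # fused CUDA path if Triton and a GPU are present.
    torch.manual_seed(0)
    layer = LateralInhibition(dim=16)
    h = torch.randn(2, 8, 16)            # e.g. (time/batch, tokens, features)
    out = layer(h)
    ref = _reference_divisive_norm(h, layer.gain, layer.eps)
    print('CPU max abs diff vs reference:', (out - ref).abs().max().item())

    # RMSNorm equivalence, only on PyTorch builds that expose F.rms_norm.
    if hasattr(torch.nn.functional, 'rms_norm'):
        rms = torch.nn.functional.rms_norm(h, (16,), weight=layer.gain, eps=layer.eps)
        print('CPU max abs diff vs F.rms_norm:', (out - rms).abs().max().item())

    _check_fused_against_reference()
    print('fused-vs-reference check passed (or skipped without CUDA/Triton).')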
|
|