# Uploaded by Brain2nd.
# Initial release: NeuronSpark-0.9B-Chat instruction-tuned SNN language model
# Commit 440e322 (verified).
"""
FP16 二进制编码/解码 — 模型边界操作(无可训练参数)。
IEEE 754 float16 位布局(K=16 时间步):
时间步: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
位: sign E4 E3 E2 E1 E0 M9 M8 M7 M6 M5 M4 M3 M2 M1 M0
含义: 符号 ←── 指数(bias=15) ──→ ←────────── 尾数(隐含 1.xxx) ──────────→
编码: embedding → IEEE 754 float16 位提取 → 16 帧二值 spike(detach,固定预处理)
解码: 16 帧二值 spike → IEEE 754 位重建 → 连续值(可微分,梯度通过 surrogate grad 传播)
"""
import torch
from torch import Tensor
def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
    """Encode continuous embeddings as IEEE 754 float16 bit patterns.

    Each embedding value is converted to float16 and its 16 bits are emitted
    MSB first (sign, 5 exponent bits, 10 mantissa bits), one bit per time
    step. The result is a fixed, non-trainable preprocessing step: the output
    is detached from the autograd graph.

    Args:
        emb: Tensor of shape (batch, seq_len, D) holding continuous values.
        K: Number of time steps per token. Must be 16, i.e. one time step per
            float16 bit.

    Returns:
        Tensor of shape (seq_len * K, batch, D) with values in {0, 1}, in the
        dtype of ``emb``, detached. Row ``t * K + k`` holds bit ``k`` (MSB
        first) of token ``t``.

    Raises:
        ValueError: If ``K`` is not 16, or if ``emb`` is not 3-dimensional
            (raised by the shape unpacking).
    """
    if K != 16:
        # The bit extraction below is tied to the 16-bit float16 layout; any
        # other K would otherwise fail later with an opaque shape error.
        raise ValueError(
            f"fp16_encode requires K == 16 (one time step per float16 bit), got K={K}"
        )

    batch, seq_len, D = emb.shape

    # Convert to float16 to obtain the IEEE 754 bit pattern. Clamp first so
    # that large magnitudes saturate at the float16 maximum (65504) instead of
    # overflowing to Inf.
    emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
    # Reinterpret the 16-bit floats as 16-bit integers (no numeric conversion).
    bits_int = emb_fp16.view(torch.int16)  # (batch, seq_len, D)

    # Shift amounts 15..0 select the bits MSB first. Keeping them in int16
    # avoids promoting the (batch, seq_len, K, D) intermediate to int64.
    shifts = torch.arange(K - 1, -1, -1, device=emb.device, dtype=torch.int16)

    # (batch, seq_len, 1, D) >> (1, 1, K, 1) -> (batch, seq_len, K, D).
    # The shift is arithmetic for negative int16 values, but `& 1` keeps only
    # the lowest bit, so sign extension does not affect the result.
    bits = (bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1

    # Cast to the compute dtype and detach: encoding takes no part in gradients.
    bits = bits.to(emb.dtype).detach()

    # (batch, seq_len, K, D) -> (batch, seq_len*K, D) -> (seq_len*K, batch, D)
    return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()
def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
    """Rebuild continuous values from 16 binary spike frames per token.

    Interprets each group of 16 time steps as the bits of an IEEE 754 float16
    number (MSB first: sign, 5 exponent bits, 10 mantissa bits) and evaluates
    the reconstruction formula with ordinary tensor arithmetic, so gradients
    flow back to every spike input:

        Normal    (exp > 0): (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
        Subnormal (exp = 0): (-1)^sign * 2^(-14) * mant_frac

    where mant_frac = sum_i mant_bit_i * 2^-(i+1) for i = 0..9.

    Note: an exponent field of 31 is handled by the normal-number formula
    (yielding magnitudes of 2^16 and above) rather than being mapped to
    Inf/NaN. All arithmetic is done in ``spikes.dtype``.

    Args:
        spikes: Tensor of shape (seq_len * K, batch, D), expected to hold
            binary {0, 1} values.
        seq_len: Number of tokens in the sequence.
        K: Number of time steps per token. Must be 16.

    Returns:
        Tensor of shape (batch, seq_len, D) with the decoded values.

    Raises:
        ValueError: If ``K`` is not 16.
    """
    if K != 16:
        # The slices below assume exactly 1 + 5 + 10 bits. Without this guard
        # some other values of K (e.g. 7) would broadcast against the weight
        # vectors and silently produce wrong results.
        raise ValueError(
            f"fp16_decode requires K == 16 (one time step per float16 bit), got K={K}"
        )

    batch, D = spikes.shape[1], spikes.shape[2]

    # (seq_len*K, batch, D) -> (batch, seq_len*K, D) -> (batch, seq_len, K, D)
    s = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)

    # ---- Sign: bit 0 -> +1 (bit clear) or -1 (bit set) ----
    sign = 1.0 - 2.0 * s[:, :, 0, :]

    # ---- Exponent: bits 1-5, weighted sum -> integer value 0..31 ----
    exp_weights = torch.tensor(
        [16.0, 8.0, 4.0, 2.0, 1.0],
        device=spikes.device,
        dtype=spikes.dtype,
    )
    exp_val = (s[:, :, 1:6, :] * exp_weights.view(1, 1, 5, 1)).sum(dim=2)

    # ---- Mantissa fraction: bits 6-15, weighted sum -> value in [0, 1) ----
    mant_weights = torch.tensor(
        [2.0 ** (-i) for i in range(1, 11)],
        device=spikes.device,
        dtype=spikes.dtype,
    )
    mant_frac = (s[:, :, 6:, :] * mant_weights.view(1, 1, 10, 1)).sum(dim=2)

    # ---- IEEE 754 reconstruction ----
    # Both branches are computed everywhere; torch.where picks per element.
    is_normal = exp_val > 0
    normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
    subnormal_val = sign * (2.0 ** -14) * mant_frac
    return torch.where(is_normal, normal_val, subnormal_val)