File size: 3,919 Bytes
440e322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
FP16 二进制编码/解码 — 模型边界操作(无可训练参数)。

IEEE 754 float16 位布局(K=16 时间步):
  时间步:  0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15
  位:     sign  E4   E3   E2   E1   E0   M9   M8   M7   M6   M5   M4   M3   M2   M1   M0
  含义:   符号  ←── 指数(bias=15) ──→  ←────────── 尾数(隐含 1.xxx) ──────────→

编码: embedding → IEEE 754 float16 位提取 → 16 帧二值 spike(detach,固定预处理)
解码: 16 帧二值 spike → IEEE 754 位重建 → 连续值(可微分,梯度通过 surrogate grad 传播)
"""

import torch
from torch import Tensor


def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
    """FP16 二进制编码(模型边界操作,固定预处理)。

    将连续 embedding 转为 IEEE 754 float16 位模式,作为 SNN 的 spike 输入。

    Args:
        emb: (batch, seq_len, D) 连续 embedding
        K: 时间步数(必须为 16,对应 float16 的 16 位)

    Returns:
        spike_seq: (seq_len*K, batch, D) 二值 {0, 1}, detached
    """
    batch, seq_len, D = emb.shape

    # 转为 float16 获取 IEEE 754 位模式
    # clamp 防止 overflow 产生 Inf(float16 最大值 65504)
    emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
    bits_int = emb_fp16.view(torch.int16)  # (batch, seq_len, D)

    # 提取 16 位(MSB first: sign, exponent, mantissa)
    shifts = torch.arange(15, -1, -1, device=emb.device)  # [15, 14, ..., 0]
    # bits_int: (batch, seq_len, D) → unsqueeze → (batch, seq_len, 1, D)
    # shifts: (K,) → view → (1, 1, K, 1)
    bits = ((bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1)  # (batch, seq_len, K, D)

    # 转为计算 dtype 并 detach(编码不参与梯度)
    bits = bits.to(emb.dtype).detach()

    # reshape → (seq_len*K, batch, D)
    return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()


def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
    """FP16 精确位解码:从 16 个二值 spike 重建 float16 值。

    fp16_encode 的精确逆操作。全程可微分——梯度通过 IEEE 754 重建公式
    传到每个 spike 输出,再经 surrogate gradient 传入 SNN。

    IEEE 754 float16 重建:
      Normal   (exp > 0):  (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
      Subnormal (exp = 0): (-1)^sign * 2^(-14)      * mant_frac
      其中 mant_frac = Σ mant_bit_i * 2^{-(i+1)}, i=0..9

    Args:
        spikes: (seq_len*K, batch, D) 二值 {0, 1}(输出神经元的 spike)
        seq_len: token 序列长度
        K: 时间步数(= 16)

    Returns:
        decoded: (batch, seq_len, D) 连续值
    """
    batch, D = spikes.shape[1], spikes.shape[2]

    # (seq_len*K, batch, D) → (batch, seq_len, K, D)
    s = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)

    # ---- Sign: bit 0 ----
    sign = 1.0 - 2.0 * s[:, :, 0, :]  # +1 or -1

    # ---- Exponent: bits 1-5, 加权求和 → 整数 0~31 ----
    exp_weights = torch.tensor(
        [16.0, 8.0, 4.0, 2.0, 1.0],
        device=spikes.device, dtype=spikes.dtype,
    )
    exp_val = (s[:, :, 1:6, :] * exp_weights.view(1, 1, 5, 1)).sum(dim=2)

    # ---- Mantissa fraction: bits 6-15, 加权求和 → [0, 1) ----
    mant_weights = torch.tensor(
        [2.0 ** (-i) for i in range(1, 11)],
        device=spikes.device, dtype=spikes.dtype,
    )
    mant_frac = (s[:, :, 6:, :] * mant_weights.view(1, 1, 10, 1)).sum(dim=2)

    # ---- IEEE 754 重建 ----
    # Normal:    (-1)^s * 2^(exp-15) * (1 + mant_frac)
    # Subnormal: (-1)^s * 2^(-14)   * mant_frac
    is_normal = (exp_val > 0)

    normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
    subnormal_val = sign * (2.0 ** -14) * mant_frac

    return torch.where(is_normal, normal_val, subnormal_val)