File size: 5,962 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""v22: Track III.C — multi-bit integer FFN accumulator.

Architecture:
    v18 FFN: x(±1) → gate,up (±1) → XNOR → down (±1) → sign → ±1
    v22 FFN: x(±1) → gate (±1) × CLIP[-B,+B](scaled up_raw integer popcount) → down (±1 weights × small int) → sign → ±1

The hidden FFN activation is a small signed integer in [−B, +B] (3 magnitude bits
plus a sign bit for B=7, 4 plus sign for B=15) instead of 1 bit. The down-projection
becomes a signed-integer adder tree:
    z_i = Σ_j (W_down[i,j] ∈ {±1}) · (y_j ∈ [−B, +B])
which is still strictly integer arithmetic — conditionally negate y_j, sum.
No float multiply anywhere. Hardware cost: the adder width grows from about
log₂(d_ff) bits (pure popcount) to about log₂(d_ff · B) + 1 bits (signed sum).
For d_ff=512, B=7: a 13-bit INT adder tree of depth log₂(512) = 9.

Per ParetoQ / BitNet a4.8: this is the single highest-impact change for closing
the FP32 gap under strict ±1 weights. Expected 0.20-0.35 BPC drop at equal params.

Inference path:
  - All weights still 1-bit ±1
  - Intermediate FFN activation is a signed int in [−B, +B] (B=7: 3 magnitude bits + sign = 4 bits)
  - All other activations still ±1
  - No float on the hot path
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import sign_ste, sign_ste_clipped, BitLinear, BinaryEmbedding
from model_v18 import IntBinaryAttention
from model_v16 import set_gumbel_tau


class IntFFN(nn.Module):
    """Gated FFN (SwiGLU analog) whose hidden `up` activation is a clipped integer.

    Forward:
      g     = sign(popcount(W_gate @ x)) in ±1                 (unchanged from v18)
      u_int = clip(round(popcount(W_up @ x) * scale - shift), -B, +B)
                                                              # integer in [-B, +B]
      h     = g * u_int                                       # integer in [-B, +B]
      z     = sign((W_down @ h) / sqrt(d_ff * B) - threshold) in ±1

    During training the binarization and the round/clip are bridged with
    straight-through estimators: the forward value is quantized, while the
    gradient w.r.t. the pre-quantization value is passed through unchanged.

    Args:
        d_model: width of the input/output (the ±1 residual stream).
        d_ff: hidden width.
        B: clip bound of the hidden integer activation (default 7).
    """
    def __init__(self, d_model, d_ff, B=7):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.B = B
        # NOTE: parameter creation order is unchanged from the previous version,
        # so random initialization and state_dict layout stay the same.
        # Gate: standard ±1 BitLinear producing the sign mask.
        self.gate = BitLinear(d_model, d_ff, binarize_input=True)
        # Up: weights binarized in forward; the raw popcount is scaled, shifted,
        # rounded and clipped instead of being passed through sign.
        self.up_w = nn.Parameter(torch.randn(d_ff, d_model) * 0.02)
        self.up_shift = nn.Parameter(torch.zeros(d_ff))
        self.up_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        # Down: weights binarized in forward, applied to the integer activation.
        self.down_w = nn.Parameter(torch.randn(d_model, d_ff) * 0.02)
        self.down_threshold = nn.Parameter(torch.zeros(d_model))

    @staticmethod
    def _round_clip_ste(v, B):
        """Round ``v`` to the nearest integer and clip it to [-B, +B], straight-through.

        The forward value is the quantized integer (``torch.round`` rounds halves
        to even). The backward pass treats the whole op as identity, so gradients
        reach ``v`` unchanged — the same gradient path as a clip-only STE.
        """
        q = torch.clamp(torch.round(v), -B, B)
        return v + (q - v).detach()

    def forward(self, x):
        """Map ``x`` (±1, shape ``(..., d_model)``) to a ±1 tensor of the same shape."""
        g = self.gate(x)  # ±1 gate mask, shape (..., d_ff)

        W_up = sign_ste(self.up_w)
        x_bin = sign_ste_clipped(x)
        up_raw = F.linear(x_bin, W_up)  # integer popcount
        up_scaled = up_raw * self.up_scale - self.up_shift
        # Quantize to an actual integer in [-B, +B]. Clipping alone would leave
        # a continuous float activation on the inference path.
        up_int = self._round_clip_ste(up_scaled, self.B)

        # XNOR-style gate: a ±1 gate times a signed integer stays in [-B, +B].
        h = g * up_int

        # Down projection: ±1 weights, multi-bit integer input.
        W_down = sign_ste(self.down_w)
        down_raw = F.linear(h, W_down)  # signed integer adder tree
        # Normalize by sqrt(d_ff * B) to keep the pre-sign value at roughly unit scale.
        scale = 1.0 / math.sqrt(self.d_ff * max(self.B, 1))
        down_final = down_raw * scale - self.down_threshold
        return sign_ste_clipped(down_final)


class BitBlockV22(nn.Module):
    """One v22 layer: attention and the integer FFN both read the same input.

    The two branch outputs are added to the residual input and the sum is
    re-binarized with ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, B=7):
        super().__init__()
        # Keep attention registered before the FFN: construction order decides
        # random-init order and state_dict key order.
        self.attn = IntBinaryAttention(d_model, n_heads)
        self.ffn = IntFFN(d_model, d_ff, B=B)

    def forward(self, x):
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        merged = x + attn_out + ffn_out
        return sign_ste(merged)


class BitLMv22(nn.Module):
    """Binary language model built from a stack of ``BitBlockV22`` layers.

    Tokens are embedded with ``BinaryEmbedding`` and passed through ``n_layers``
    blocks. Logits are the dot product with a ±1 output codebook, times a
    learned scale, plus a bias.

    Args:
        vocab_size: number of token ids.
        d_model: residual-stream width.
        n_layers: number of ``BitBlockV22`` blocks.
        n_heads: attention heads per block.
        d_ff: FFN hidden width.
        max_seq_len: context length used to crop the prompt in ``generate``.
        B: clip bound of each block's integer FFN activation.
    """
    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256, B=7):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.B = B
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV22(d_model, n_heads, d_ff, B=B) for _ in range(n_layers)
        ])
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: tensor of token ids; ``generate`` passes shape (batch, time).
            targets: optional tensor of target ids with the same number of
                elements as ``idx``.

        Returns:
            ``(logits, loss)``, where ``logits`` has a trailing ``vocab_size``
            dimension and ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices, also work.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled token ids to ``idx``.

        Args:
            idx: (batch, time) tensor of prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: the logits are divided by this value (floored at 1e-5).
            top_k: if given, sample only among the ``top_k`` largest logits
                (capped at the vocabulary size).

        Returns:
            ``idx`` with the sampled ids concatenated along dim 1.

        The model runs in eval mode while sampling. Its previous train/eval
        mode is restored afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the most recent max_seq_len tokens.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the size of the last dimension.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx


if __name__ == '__main__':
    # Smoke test: for several clip bounds, report the parameter count and run a
    # single forward/backward pass on random token ids.
    set_gumbel_tau(0.5)
    for clip_bound in (3, 7, 15):
        model = BitLMv22(B=clip_bound)
        n_params = sum(param.numel() for param in model.parameters())
        print(f'v22 B={clip_bound}: {n_params:,} params ({n_params/1e6:.2f}M)')
        batch_shape = (2, 64)
        inputs = torch.randint(0, 128, batch_shape)
        targets = torch.randint(0, 128, batch_shape)
        _, loss = model(inputs, targets)
        loss.backward()
        print(f'  loss={loss.item():.3f}, backward OK')