# bitnet-1bitllm / vm_backup/code/model_v22.py
# Author: hidude562
# Commit message: 1bitllm code (checkpoints to follow)
# Commit: 4754707 (verified)
"""v22: Track III.C — multi-bit integer FFN accumulator.
Architecture:
v18 FFN: x(±1) → gate,up (±1) → XNOR → down (±1) → sign → ±1
v22 FFN: x(±1) → gate (±1) and up_raw (integer popcount) → CLIP[-B,+B] → gate·up → down (±1 weights × small int) → sign → ±1
The hidden FFN activation is a small signed integer (3 magnitude bits plus sign for B=7, 4 plus sign for B=15)
instead of 1 bit. The down-projection becomes a signed-integer adder tree:
z_i = Σ_j (W_down[i,j] ∈ {±1}) · (y_j ∈ [−B, +B])
which is still strictly integer arithmetic — conditionally negate y_j, sum.
No float multiply anywhere. Hardware cost: adder width grows from 0 (popcount)
to log₂(d_ff · B). For d_ff=512, B=7: 13-bit INT adder tree of depth 9.
Per ParetoQ / BitNet a4.8: this is the single highest-impact change for closing
the FP32 gap under strict ±1 weights. Expected 0.20-0.35 BPC drop at equal params.
Inference path:
- All weights still 1-bit ±1
- Intermediate FFN activation is a small signed int in [−B, +B] (B=7 needs 4 bits incl. sign)
- All other activations still ±1
- No float on the hot path
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BinaryEmbedding
from model_v18 import IntBinaryAttention
from model_v16 import set_gumbel_tau
class IntFFN(nn.Module):
    """Gated FFN (SwiGLU analog) with a clipped, rounded integer `up` activation.

    Forward:
        g     = gate(x)                                   # BitLinear, as in v18
        u     = popcount(sign(W_up) @ x) * up_scale - up_shift
        u_int = round(clip(u, -B, +B))                    # small signed integer
        h     = g * u_int                                 # gated, in [-B, +B]
        z     = sign(sign(W_down) @ h / sqrt(d_ff * max(B, 1)) - down_threshold)

    The clip+round quantizer uses a straight-through estimator: the forward
    value is the quantized integer, while the backward pass treats the whole
    quantizer as the identity (gradient still flows where ``u`` saturates).

    Args:
        d_model: width of the residual stream (input and output size).
        d_ff: number of hidden FFN units.
        B: clip bound; hidden activations lie in [-B, +B].
        round_hidden: round the clipped activation to integers (default).
            Set False to reproduce the earlier continuous-clip behaviour,
            e.g. when evaluating checkpoints trained without rounding.
    """

    def __init__(self, d_model, d_ff, B=7, round_hidden=True):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.B = B
        self.round_hidden = round_hidden
        # Gate: BitLinear with binarized input (assumed to emit ±1 — see model.BitLinear).
        self.gate = BitLinear(d_model, d_ff, binarize_input=True)
        # Up: latent weights, binarized in forward; the raw popcount is
        # scaled/shifted/clipped/rounded instead of being passed through sign().
        self.up_w = nn.Parameter(torch.randn(d_ff, d_model) * 0.02)
        self.up_shift = nn.Parameter(torch.zeros(d_ff))
        # Presumably chosen so a ±1 popcount (std ~ sqrt(d_model) for random
        # inputs) starts near unit scale; learned afterwards.
        self.up_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        # Down: latent weights binarized in forward; consumes the integer hidden vector.
        self.down_w = nn.Parameter(torch.randn(d_model, d_ff) * 0.02)
        self.down_threshold = nn.Parameter(torch.zeros(d_model))

    @staticmethod
    def _quantize_ste(u, B, round_values=True):
        """Clamp ``u`` to [-B, +B] (and round if requested) with an identity gradient.

        Forward value: ``round(clamp(u, -B, B))`` (or just the clamp when
        ``round_values`` is False). Backward: d(out)/d(u) == 1 everywhere.
        """
        q = torch.clamp(u, -B, B)
        if round_values:
            q = torch.round(q)
        return u + (q - u).detach()

    def forward(self, x):
        # x is expected to be ±1 of shape (..., d_model)
        g = self.gate(x)
        W_up = sign_ste(self.up_w)
        x_bin = sign_ste_clipped(x)
        up_raw = F.linear(x_bin, W_up)  # integer popcount (for ±1 inputs/weights)
        up_scaled = up_raw * self.up_scale - up_shift_placeholder if False else up_raw * self.up_scale - self.up_shift
        # Round as well as clip so training sees the same small signed integers
        # that the integer inference path computes; clipping alone would leave
        # a continuous float in [-B, B].
        up_int = self._quantize_ste(up_scaled, self.B, self.round_hidden)
        # ±1 gate times a signed integer keeps the range [-B, +B]
        h = g * up_int
        # Down projection: ±1 weights, multi-bit input (signed-integer adder tree)
        W_down = sign_ste(self.down_w)
        down_raw = F.linear(h, W_down)
        # Normalize by sqrt(d_ff * B) to keep the pre-sign value at ~unit scale
        scale = 1.0 / math.sqrt(self.d_ff * max(self.B, 1))
        down_final = down_raw * scale - self.down_threshold
        return sign_ste_clipped(down_final)
class BitBlockV22(nn.Module):
    """One v22 layer: attention and integer-FFN branches applied in parallel.

    Both branches read the same input ``x``; their outputs are added to the
    residual and the sum is re-binarized with ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, B=7):
        super().__init__()
        # Registration order (attn, then ffn) is kept so state_dict keys match.
        self.attn = IntBinaryAttention(d_model, n_heads)
        self.ffn = IntFFN(d_model, d_ff, B=B)

    def forward(self, x):
        # Attention is evaluated before the FFN (left-to-right).
        # If all three terms are ±1 the sum is odd (±1 or ±3), so its sign is
        # never ambiguous.
        residual_sum = x + self.attn(x) + self.ffn(x)
        return sign_ste(residual_sum)
class BitLMv22(nn.Module):
    """Language model built from a stack of v22 blocks over a ±1 residual stream.

    Tokens are embedded with ``BinaryEmbedding``, passed through ``n_layers``
    ``BitBlockV22`` layers, and scored against a binarized output codebook.

    Args:
        vocab_size: number of token ids.
        d_model: residual-stream width.
        n_layers: number of ``BitBlockV22`` layers.
        n_heads: attention heads per layer.
        d_ff: hidden FFN width per layer.
        max_seq_len: context length used to crop inputs in ``generate``.
        B: clip bound for each layer's integer FFN hidden activation.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256, B=7):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.B = B
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV22(d_model, n_heads, d_ff, B=B) for _ in range(n_layers)
        ])
        # Latent output codebook; binarized with sign_ste in forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits (batch, seq, vocab) and, if ``targets`` is given, the CE loss.

        Args:
            idx: integer token ids, (batch, seq).
            targets: optional integer ids with the same number of elements as ``idx``.

        Returns:
            ``(logits, loss)``; ``loss`` is None when ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) also accepts non-contiguous targets/logits.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens appended to ``idx``.

        The model runs in eval mode while sampling; its previous train/eval
        mode is restored afterwards, so calling this mid-training is safe.

        Args:
            idx: (batch, seq) integer prompt ids.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, keep only the ``top_k`` highest logits per row
                (clamped to the vocabulary size).

        Returns:
            ``idx`` extended by ``max_new_tokens`` columns.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop to the model's context window.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx
if __name__ == '__main__':
    # Smoke test: build each clip-bound variant, report its parameter count,
    # and run one forward/backward pass on random token ids.
    set_gumbel_tau(0.5)
    for B in (3, 7, 15):
        model = BitLMv22(B=B)
        n = sum(param.numel() for param in model.parameters())
        print(f'v22 B={B}: {n:,} params ({n/1e6:.2f}M)')
        tokens = torch.randint(0, 128, (2, 64))
        targets = torch.randint(0, 128, (2, 64))
        _, loss = model(tokens, targets)
        loss.backward()
        print(f' loss={loss.item():.3f}, backward OK')