# bitnet-1bitllm / vm_backup / code / model_v30.py
# Author: hidude562
# Commit message: "1bitllm code (checkpoints to follow)"
# Commit: 4754707 (verified)
"""v30: Doubled Binary — each weight stored as TWO independent ±1 bits (W_A, W_B).
Effective weight W = W_A + W_B has values in {−2, 0, +2} — strict ternary on a
binary substrate. This closes the ternary-vs-binary gap ParetoQ identified
(~0.2-0.3 BPC on LLaMA) while keeping every operation as XNOR + popcount + add.
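At inference the pre-activation of a DoubleBitLinear layer is
y_i = W_A[i]·x + W_B[i]·x = 2·(popcount(W_A[i] XNOR x) + popcount(W_B[i] XNOR x) − in_features)
(using a·x = 2·popcount(a XNOR x) − n for ±1 vectors of length n). It is then
scaled, shifted by a per-output threshold, and binarized with sign. That is one extra
XNOR-popcount per output row vs standard v18; weight memory doubles.
The attention and FFN projections use DoubleBitLinear; the embedding and the output
head use doubled (A/B) ±1 codebooks; residual sums are re-binarized with sign.
Activations remain ±1.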
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped
from model_v18 import IntBinaryAttention # reuse attention shell
from model_v16 import set_gumbel_tau
def double_bin_linear_forward(x, W_A_bits, W_B_bits, threshold, in_features, scale):
    """Binarized doubled-weight linear map: sign((x · (W_A + W_B)^T) * scale - threshold).

    Both latent weight tensors are binarized with ``sign_ste`` and the input with
    ``sign_ste_clipped``. ``F.linear`` is linear in its weight, so
    ``x W_A^T + x W_B^T == x (W_A + W_B)^T``. The two dot products are therefore
    computed as one matmul against the effective weight ``W_A + W_B``. At
    inference this is still realized as two XNOR-popcounts.

    Args:
        x: input activations; the last dim must equal the weights' second dim.
        W_A_bits: latent real-valued weights for half A, shape (out, in).
        W_B_bits: latent real-valued weights for half B, shape (out, in).
        threshold: per-output offset subtracted before the final binarization.
        in_features: unused here; kept so existing call sites keep working.
        scale: scalar applied to the summed dot products before thresholding.

    Returns:
        ``sign_ste_clipped(y * scale - threshold)``, where ``y`` is the sum of the
        two binarized dot products.
    """
    W_A = sign_ste(W_A_bits)
    W_B = sign_ste(W_B_bits)
    x_bin = sign_ste_clipped(x)
    # One matmul against the effective (ternary) weight instead of two matmuls
    # plus an add. This gives the same forward value and the same gradient for
    # both W_A and W_B: the sum passes the gradient through to each term. If
    # sign_ste yields ±1, the products are small integers whose sums are exact
    # in fp32.
    y = F.linear(x_bin, W_A + W_B)
    return sign_ste_clipped(y * scale - threshold)
class DoubleBitLinear(nn.Module):
    """Linear layer whose weight is the sum of two independently binarized matrices.

    Holds two latent real-valued matrices ``weight_A`` and ``weight_B`` of shape
    (out_features, in_features) and a learned per-output ``threshold``. The
    forward pass delegates to ``double_bin_linear_forward``, which:

    - binarizes both matrices and the input;
    - sums the two dot products;
    - scales the sum by ``self.scale``, subtracts ``threshold``, and binarizes
      the result.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Two independent latent weight matrices; binarized to ±1 in forward.
        self.weight_A = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.weight_B = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.threshold = nn.Parameter(torch.zeros(out_features))
        # The summed dot product lies in [-2*in, +2*in]. With independent random
        # ±1 weights and inputs (as at init), its std is sqrt(2*in). Scaling by
        # 1/(2*sqrt(in)) therefore gives a pre-sign std of about 1/sqrt(2) ≈ 0.71:
        # order-one, but not exactly unit. The value is left unchanged so that
        # learned thresholds keep their meaning.
        self.scale = 1.0 / (2.0 * math.sqrt(in_features))

    def forward(self, x):
        return double_bin_linear_forward(
            x, self.weight_A, self.weight_B, self.threshold, self.in_features, self.scale)

    def extra_repr(self):
        # Shown by repr(module), matching the style of nn.Linear.
        return f'in_features={self.in_features}, out_features={self.out_features}'
class DoubleBiAttention(nn.Module):
    """Hard (one-hot) causal attention with integer ALiBi biases.

    Same design as v18's IntBinaryAttention, but the q/k/v/o projections are
    DoubleBitLinear layers.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Projections are created in q, k, v, o order (fixes init RNG order and
        # parameter order).
        self.q_proj = DoubleBitLinear(d_model, d_model)
        self.k_proj = DoubleBitLinear(d_model, d_model)
        self.v_proj = DoubleBitLinear(d_model, d_model)
        self.o_proj = DoubleBitLinear(d_model, d_model)
        # One integer slope per head: 1, 2, 4, ..., 2**(n_heads - 1).
        self.register_buffer(
            'alibi_slopes_int',
            torch.tensor([2 ** h for h in range(n_heads)], dtype=torch.long))
        # Causal mask cache, rebuilt lazily; excluded from the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal."""
        cached = self._causal_mask
        if cached.device != device or cached.shape[-1] < T:
            cached = torch.triu(
                torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = cached
        return cached[:T, :T]

    def _gumbel_hard(self, scores):
        """Turn ``scores`` into one-hot weights over the last dimension.

        If ``scores`` is tracked by autograd, a straight-through Gumbel-softmax
        is used: the forward value is one-hot, and gradients flow through the
        softmax. Otherwise the result is a deterministic argmax one-hot.
        """
        from model_v16 import _get_tau
        tau = _get_tau(scores.device)
        if not scores.requires_grad:
            one_hot = torch.zeros_like(scores)
            one_hot.scatter_(-1, scores.argmax(-1, keepdim=True), 1.0)
            return one_hot
        uniform = torch.rand_like(scores).clamp(min=1e-9)
        gumbel = -torch.log(-torch.log(uniform) + 1e-9)
        soft = F.softmax((scores + gumbel) / tau, dim=-1)
        hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, keepdim=True), 1.0)
        # Straight-through: value of `hard`, gradient of `soft`.
        return soft + (hard - soft).detach()

    def forward(self, x):
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim

        def split_heads(t):
            # (B, T, D) -> (B, H, T, Dh)
            return t.view(B, T, H, Dh).transpose(1, 2)

        Q = split_heads(self.q_proj(x))
        K = split_heads(self.k_proj(x))
        V = split_heads(self.v_proj(x))

        # Raw dot-product scores; no 1/sqrt(head_dim) scaling.
        scores = Q @ K.transpose(-2, -1)

        # ALiBi: subtract slope_h * |i - j| for each head h.
        positions = torch.arange(T, device=Q.device)
        distance = (positions[None, :] - positions[:, None]).abs().to(Q.dtype)
        slopes = self.alibi_slopes_int.to(Q.dtype).view(1, H, 1, 1)
        scores = scores - slopes * distance.view(1, 1, T, T)

        # Hide future keys, then select one key per query.
        scores = scores.masked_fill(self._get_mask(T, x.device), -1e9)
        attn = self._gumbel_hard(scores)

        merged = (attn @ V).transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(merged)
class DoubleBitFFN(nn.Module):
    """Gated feed-forward block built from three DoubleBitLinear layers.

    The ``gate`` and ``up`` projections (d_model -> d_ff) are multiplied
    elementwise. The product is projected back to d_model by ``down``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Created in gate, up, down order (fixes init RNG order and state_dict order).
        self.gate = DoubleBitLinear(d_model, d_ff)
        self.up = DoubleBitLinear(d_model, d_ff)
        self.down = DoubleBitLinear(d_ff, d_model)

    def forward(self, x):
        gated = self.gate(x)
        lifted = self.up(x)
        return self.down(gated * lifted)
class BitBlockV30(nn.Module):
    """Parallel block: attention and FFN both read the same input ``x``.

    The residual sum ``x + attn(x) + ffn(x)`` is re-binarized with ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = DoubleBiAttention(d_model, n_heads)
        self.ffn = DoubleBitFFN(d_model, d_ff)

    def forward(self, x):
        # Attention runs before the FFN (its Gumbel sampling draws random numbers).
        residual = x + self.attn(x)
        # If all three terms are ±1, their sum is odd and therefore never zero,
        # so the sign below never hits a tie.
        return sign_ste(residual + self.ffn(x))
class DoubleBinaryEmbedding(nn.Module):
    """Token embedding backed by two latent codebooks, A and B.

    Each codebook is binarized with ``sign_ste``, the two are summed, and the
    sum is binarized again. Looked-up rows are therefore sign(sign(A) + sign(B)),
    which keeps the block input binary rather than ternary.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.weight_A = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.weight_B = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)

    def get_codebook(self):
        """Return the effective codebook sign(sign(A) + sign(B)), one row per token."""
        # NOTE(review): where sign(A) and sign(B) disagree, the sum is 0. What
        # sign_ste returns for 0 depends on its definition in `model`. Confirm
        # whether the codebook is strictly ±1 or can contain zeros.
        summed = sign_ste(self.weight_A) + sign_ste(self.weight_B)
        return sign_ste(summed)

    def forward(self, idx):
        # Row lookup into the binarized effective codebook.
        return F.embedding(idx, self.get_codebook())
class BitLMv30(nn.Module):
    """v30 LM: doubled-binary embedding, a stack of BitBlockV30, doubled output head.

    Args:
        vocab_size: number of tokens (rows of the embedding and output codebooks).
        d_model: width of the residual stream.
        n_layers: number of BitBlockV30 blocks.
        n_heads: attention heads per block (must divide d_model).
        d_ff: hidden width of each FFN.
        max_seq_len: context length used to crop the prompt in ``generate``.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = DoubleBinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV30(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Doubled output codebook for the ternary-effective output head.
        self.out_codebook_A = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.out_codebook_B = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Learnable logit scale, initialized with the same 1/(2*sqrt(d)) factor
        # used for DoubleBitLinear's scale.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / (2.0 * math.sqrt(d_model))))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) token ids.
            targets: optional (B, T) target ids. May be non-contiguous (e.g. a
                slice of a larger tensor).

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size). ``loss``
            is the mean token cross-entropy, or None when ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_A = sign_ste(self.out_codebook_A)
        W_B = sign_ste(self.out_codebook_B)
        # x·W_A^T + x·W_B^T == x·(W_A + W_B)^T: one matmul against the ternary-
        # effective codebook gives the same logits and gradients.
        scores = torch.matmul(x, (W_A + W_B).t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view): view() raises on non-contiguous targets.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size),
                                   targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: softmax temperature, floored at 1e-5 to avoid division by 0.
            top_k: if given, restrict sampling to the k most likely tokens. Values
                larger than the vocabulary are clamped to the vocabulary size.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.

        The model is switched to eval mode while sampling. Its previous
        train/eval mode is restored afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last max_seq_len tokens are fed to the model.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx
if __name__ == '__main__':
    # Smoke test: build both model sizes, report parameter counts, and run one
    # forward + backward pass on random tokens.
    set_gumbel_tau(0.5)
    configs = (('5M', 256, 8, 512), ('50M', 768, 10, 1280))
    for cfg_name, width, depth, ffn_width in configs:
        heads = max(8, width // 64)
        model = BitLMv30(vocab_size=128, d_model=width, n_layers=depth,
                         n_heads=heads, d_ff=ffn_width)
        n = sum(p.numel() for p in model.parameters())
        print(f'v30 {cfg_name}: {n:,} params ({n/1e6:.2f}M)')
        tokens = torch.randint(0, 128, (2, 64))
        labels = torch.randint(0, 128, (2, 64))
        _, loss = model(tokens, labels)
        loss.backward()
        print(f' loss={loss.item():.3f}, backward OK')