# hidude562 — 1bitllm code (checkpoints to follow)
# Commit 4754707 (verified)
"""Maximalist ±1 binary language model.
Forward-pass invariants (what the paper calls "true 1-bit"):
- Embeddings, Q/K/V/O, FFN weights, attention matrix, layer activations: ±1 via sign-STE.
- All matmuls are between ±1 operands (XNOR-popcount equivalents). Intermediate accumulators
are integers in [-k, k]; they are multiplied by a fixed 1/sqrt(k), a learned threshold is
subtracted, and sign is re-applied at the output.
- Residual stream: majority vote sign(x + F(x)). Ties (x + F(x) == 0, i.e. x and F(x)
disagree) are broken deterministically: sign-STE maps 0 to +1, in training and eval alike.
- FFN gating: XNOR gate (elementwise multiply of two ±1 tensors).
- Normalization: none. ReZero-style identity residual path at init (threshold = 0 keeps
pre-activation balanced; F(x) starts near-balanced noise).
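- Position: binary-ALiBi subtractive distance bias with fixed power-of-two per-head slopes
(2**(h-2), i.e. 0.25, 0.5, 1, ...), scaled by 1/sqrt(head_dim) like the attention scores.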
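- Output head: a separate (untied) ±1 codebook. Score = popcount similarity, multiplied by a
learned float logit scale plus a per-vocab float bias. Softmax is applied only at the loss
(training-time cross-entropy) and when sampling in generate() (acknowledged float concessions).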
Training-pass concession (§3 of the proposal): each ±1 weight has a latent float that we
call the "counter" in signSGD mode; it's standard for STE-trained networks but we bound it.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def sign_ste(x):
    """Binarize ``x`` to ±1 in the forward pass; pass gradients straight through.

    Values ``>= 0`` (including exact zero) map to +1, everything else to -1.
    The ``x + (hard - x).detach()`` trick makes the forward value ``hard`` while
    autograd only sees the identity on ``x``.
    """
    ones = torch.ones_like(x)
    hard = torch.where(x >= 0, ones, -ones)
    correction = (hard - x).detach()
    return x + correction
def sign_ste_clipped(x):
    """Binarize ``x`` to ±1 with a hard-tanh straight-through gradient.

    Forward: +1 where ``x >= 0``, else -1. Backward: the gradient of
    ``clamp(x, -1, 1)``, i.e. it is zero wherever ``|x| > 1``. Callers must
    therefore keep ``x`` near unit scale, or gradients vanish.
    """
    ones = torch.ones_like(x)
    hard = torch.where(x >= 0, ones, -ones)
    soft = torch.clamp(x, -1.0, 1.0)
    return soft + (hard - soft).detach()
class BitLinearRaw(nn.Module):
    """Linear map with ±1 weights and no output nonlinearity.

    The weight used in the forward pass is ``sign(self.weight)``. When
    ``binarize_input`` is set, the input is binarized as well, so every output is
    a sum of ±1 products: an integer in [-in_features, in_features].
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.binarize_input = binarize_input
        # Latent float weights; small Gaussian init gives balanced ±1 signs.
        self.weight = nn.Parameter(0.02 * torch.randn(out_features, in_features))

    def forward(self, x):
        signed_weight = sign_ste(self.weight)
        inp = sign_ste_clipped(x) if self.binarize_input else x
        return F.linear(inp, signed_weight)
class BitLinear(nn.Module):
    """±1 linear layer with a fixed rescale, a learned per-output threshold, and sign.

    For balanced ±1 inputs, the raw popcount lies in [-in_features, in_features]
    with std of roughly sqrt(in_features). Multiplying by the constant
    1/sqrt(in_features) puts pre-sign values near unit scale, so the hard-tanh STE
    in ``sign_ste_clipped`` passes gradients. The scale is a fixed scalar, not a
    trainable float weight. Output values are ±1.
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.raw = BitLinearRaw(in_features, out_features, binarize_input=binarize_input)
        self.threshold = nn.Parameter(torch.zeros(out_features))
        self.scale = 1.0 / math.sqrt(in_features)

    def forward(self, x):
        scaled = self.raw(x) * self.scale
        return sign_ste_clipped(scaled - self.threshold)
class BiAttention(nn.Module):
    """BiBERT-style bool-threshold causal self-attention with ±1 Q/K/V/A.

        S = Q @ K^T / sqrt(Dh)                 (scaled popcount)
        S -= alibi_slope * |i-j| / sqrt(Dh)    (subtractive distance bias, per head)
        A = sign_ste_clipped(S - tau)          (±1; tau = learned per-head threshold)
        A[i, j] = 0 for j > i                  (future keys are excluded from the popcount)
        O = A @ V / sqrt(i + 1)                (popcount over the i+1 visible keys)
        return o_proj(O) -> ±1

    Causal masking: in a ±1 attention matrix, -1 does not mean "off". It means
    "subtract this value". Masking future entries to -1 would leak ``-V_j`` from
    every future token j into position i. Masked entries are therefore set to 0, so
    they drop out of the sum (a masked XNOR-popcount in hardware).

    The division of O by sqrt(#visible keys) is a positive constant per row. It
    never changes sign(O), so the forward output is unaffected. It keeps O near
    unit scale so o_proj's hard-tanh STE passes gradients. This is the same
    convention BitLinear uses for its popcounts.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.attn_threshold = nn.Parameter(torch.zeros(n_heads))
        # Fixed power-of-two ALiBi slopes: 2**(h-2) = [0.25, 0.5, 1, 2, 4, ...].
        # Head 0 has the smallest slope (most global); later heads are increasingly local.
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask, True strictly above the diagonal (future keys).

        The mask is cached and rebuilt only when a longer sequence or a different
        device is requested.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)  # (B, H, T, Dh)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, H, T, T) integer popcount
        # Fixed 1/sqrt(head_dim) so |scores| ~ O(1) and the hard-tanh STE passes gradients.
        scores = scores / math.sqrt(Dh)

        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()  # (T, T)
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias

        tau = self.attn_threshold.view(1, H, 1, 1)
        A = sign_ste_clipped(scores - tau)
        # Exclude future keys with weight 0 (not -1, which would leak -V_j from the future).
        # masked_fill also zeroes the gradient at those positions.
        mask = self._get_mask(T, x.device)
        A = A.masked_fill(mask, 0.0)

        O = torch.matmul(A, V)  # (B, H, T, Dh), integer popcount over visible keys
        # Query row i sees i + 1 keys. Dividing by sqrt(i + 1) keeps O near unit scale
        # for o_proj's hard-tanh STE without changing sign(O).
        n_visible = torch.arange(1, T + 1, device=x.device, dtype=O.dtype)
        O = O / n_visible.sqrt().view(1, 1, T, 1)

        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitFFN(nn.Module):
    """Binary feed-forward block: ``down(gate(x) * up(x))``.

    ``gate`` and ``up`` each emit ±1, so their elementwise product (an XNOR gate)
    is also ±1. It is then projected back to ``d_model`` with another ±1 layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Construction order (gate, up, down) fixes RNG consumption for the weight init.
        self.gate = BitLinear(d_model, d_ff)
        self.up = BitLinear(d_model, d_ff)
        self.down = BitLinear(d_ff, d_model)

    def forward(self, x):
        gated = self.gate(x) * self.up(x)
        return self.down(gated)
class BitBlock(nn.Module):
    """One layer: binary attention, then binary FFN, each merged by ±1 majority vote."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = BiAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def _residual(self, x, fx):
        """Merge the stream ``x`` with branch output ``fx`` via ``sign_ste(x + fx)``.

        Both inputs are ±1, so the sum is in {-2, 0, 2}. A tie (0) happens exactly when
        they disagree, and sign_ste maps it to +1. The result is deterministic in both
        train and eval. The STE passes the gradient unchanged to both summands.

        NOTE(review): because every tie resolves to +1, the merge behaves like
        "-1 only if both are -1". With balanced inputs this biases the stream toward +1.
        Consider a balanced tie-break if that drift is observed.
        """
        return sign_ste(x + fx)

    def forward(self, x):
        for sublayer in (self.attn, self.ffn):
            x = self._residual(x, sublayer(x))
        return x
class BinaryEmbedding(nn.Module):
    """Token embedding whose rows are ±1 codes (the sign of a latent float table)."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model
        # Latent float table; the forward pass uses its sign. Small init -> balanced ±1.
        self.weight = nn.Parameter(0.02 * torch.randn(vocab_size, d_model))

    def forward(self, idx):
        """Look up the ±1 code for each integer token id in ``idx``."""
        codebook = sign_ste(self.weight)
        return F.embedding(idx, codebook)

    def get_codebook(self):
        """Return the full (vocab_size, d_model) ±1 codebook with STE gradients."""
        return sign_ste(self.weight)
class BitLM(nn.Module):
    """±1 binary language model: binary embedding, stacked BitBlocks, ±1 output codebook.

    Concessions at the loss surface (per graceful-degradation ladder):
    - learnable output logit scale (1 float scalar)
    - per-vocab output bias (V floats)
    - untied ±1 output codebook (independent from input embedding)
    All hidden computations remain ±1 with integer popcounts.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Independent output codebook (±1 like the embedding, but not tied).
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if ``targets`` is given, the mean cross-entropy.

        Args:
            idx: integer token ids (presumably (B, T) — TODO confirm with callers).
            targets: optional integer ids with the same number of elements as ``idx``.
        Returns:
            ``(logits, loss)``. logits has a trailing vocab dimension; loss is None
            when no targets are given.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())  # integer popcount in [-D, D]
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets/logits are accepted.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: integer token ids; sampled tokens are concatenated along dim 1.
            max_new_tokens: number of tokens to sample.
            temperature: last-position logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, only the ``top_k`` largest logits per row are kept
                (clamped to the vocabulary size).
        Returns:
            ``idx`` with the sampled tokens appended.

        The model is switched to eval mode while sampling. Its previous train/eval
        mode is restored afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last max_seq_len tokens are fed back as context.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the size of the last dimension.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx
def param_count(m):
    """Return the total number of scalar elements across all parameters of module ``m``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
if __name__ == '__main__':
    # Smoke test: one forward/backward pass of a default BitLM on random token ids.
    model = BitLM()
    print(f"total params: {param_count(model):,}")
    # Draw ids from the model's own vocabulary instead of a hard-coded 128,
    # so the check stays valid if BitLM's default vocab_size changes.
    vocab = model.vocab_size
    x = torch.randint(0, vocab, (2, 64))
    y = torch.randint(0, vocab, (2, 64))
    logits, loss = model(x, y)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")