"""Maximalist ±1 binary language model. Forward-pass invariants (what the paper calls "true 1-bit"): - Embeddings, Q/K/V/O, FFN weights, attention matrix, layer activations: ±1 via sign-STE. - All matmuls are between ±1 operands (XNOR-popcount equivalents). Intermediate accumulators are integers in [-k, k]. Thresholds are subtracted and sign is re-applied at the output. - Residual stream: majority vote sign(x + F(x)) with stochastic tie-break on {x+F(x) == 0}. - FFN gating: XNOR gate (elementwise multiply of two ±1 tensors). - Normalization: none. ReZero-style identity residual path at init (threshold = 0 keeps pre-activation balanced; F(x) starts near-balanced noise). - Position: integer binary-ALiBi subtractive bias (per-head fixed slopes). - Output head: tied ±1 embedding codebook. Score = popcount similarity. Softmax applied only for training-time cross-entropy (acknowledged float concession at the loss surface). Training-pass concession (§3 of the proposal): each ±1 weight has a latent float that we call the "counter" in signSGD mode; it's standard for STE-trained networks but we bound it. """ import math import torch import torch.nn as nn import torch.nn.functional as F def sign_ste(x): """Sign with pure identity backward. Maps 0 -> +1.""" out = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) return x + (out - x).detach() def sign_ste_clipped(x): """Sign with hard-tanh backward (grad only for |x|<=1). Only works if x has been pre-normalized to ~unit scale; otherwise gradients die.""" out = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) x_clip = torch.clamp(x, -1.0, 1.0) return x_clip + (out - x_clip).detach() class BitLinearRaw(nn.Module): """Linear with ±1 weights. Returns raw integer popcount (no sign at output).""" def __init__(self, in_features, out_features, binarize_input=True): super().__init__() self.in_features = in_features self.out_features = out_features self.binarize_input = binarize_input # Latent float weight; forward uses sign(w). Small gaussian init gives balanced ±1. self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02) def forward(self, x): W = sign_ste(self.weight) if self.binarize_input: x = sign_ste_clipped(x) return F.linear(x, W) class BitLinear(nn.Module): """BitLinearRaw + learned threshold + sign. Returns ±1. The popcount integer output has range [-k, k] with std ~sqrt(k) for balanced inputs. We divide by sqrt(k) (a scalar constant) so the pre-sign values live at ~unit scale. This does NOT introduce a float weight — it is just a fixed normalization so hard-tanh STE actually passes gradients. BiBERT and BitNet both use an equivalent scaling. """ def __init__(self, in_features, out_features, binarize_input=True): super().__init__() self.raw = BitLinearRaw(in_features, out_features, binarize_input=binarize_input) self.threshold = nn.Parameter(torch.zeros(out_features)) self.scale = 1.0 / math.sqrt(in_features) def forward(self, x): s = self.raw(x) * self.scale - self.threshold return sign_ste_clipped(s) class BiAttention(nn.Module): """BiBERT-style bool-threshold causal self-attention, fully ±1. S = Q @ K^T (popcount integer) S -= alibi_slope * |i-j| (integer subtractive bias, per head) S -= tau (learned per-head threshold, BiBERT's entropy-max proxy) mask future -> -inf A = sign_ste(S) (±1) mask future -> -1 (force attention off on future tokens) O = A @ V (popcount integer) return BitLinear(O) -> ±1 """ def __init__(self, d_model, n_heads): super().__init__() assert d_model % n_heads == 0 self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.q_proj = BitLinear(d_model, d_model, binarize_input=True) self.k_proj = BitLinear(d_model, d_model, binarize_input=True) self.v_proj = BitLinear(d_model, d_model, binarize_input=True) self.o_proj = BitLinear(d_model, d_model, binarize_input=True) self.attn_threshold = nn.Parameter(torch.zeros(n_heads)) # Integer binary-ALiBi slopes (fixed). Head 0 is global, later heads are local. # slopes = [0.25, 0.5, 1, 2, 4, 8, 16, 32, ...] slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)]) self.register_buffer('alibi_slopes', slopes) self.register_buffer('_causal_mask', torch.empty(0), persistent=False) def _get_mask(self, T, device): if self._causal_mask.shape[-1] < T or self._causal_mask.device != device: m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1) self._causal_mask = m return self._causal_mask[:T, :T] def forward(self, x): B, T, D = x.shape H, Dh = self.n_heads, self.head_dim Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh) K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2) V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) # (B, H, T, T) integer popcount # Scale by 1/sqrt(head_dim) so |scores| ~ O(1). This is the standard attention # normalization; it's a fixed scalar constant, not a float weight. scores = scores / math.sqrt(Dh) pos = torch.arange(T, device=x.device).float() dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() # (T, T) alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh) scores = scores - alibi_bias mask = self._get_mask(T, x.device) scores = scores.masked_fill(mask, -1e9) tau = self.attn_threshold.view(1, H, 1, 1) A = sign_ste_clipped(scores - tau) # Force masked future positions to -1 on the forward (STE handles grad). A = A.masked_fill(mask, -1.0) O = torch.matmul(A, V) # (B, H, T, Dh), integer popcount O = O.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(O) class BitFFN(nn.Module): """XNOR-gated binary FFN: down(gate(x) * up(x)). Multiplication of ±1 tensors stays ±1.""" def __init__(self, d_model, d_ff): super().__init__() self.gate = BitLinear(d_model, d_ff, binarize_input=True) self.up = BitLinear(d_model, d_ff, binarize_input=True) self.down = BitLinear(d_ff, d_model, binarize_input=True) def forward(self, x): g = self.gate(x) u = self.up(x) return self.down(g * u) class BitBlock(nn.Module): def __init__(self, d_model, n_heads, d_ff): super().__init__() self.attn = BiAttention(d_model, n_heads) self.ffn = BitFFN(d_model, d_ff) def _residual(self, x, fx): """Majority-vote residual. s = x + fx in {-2, 0, 2}. Sign+STE maps 0 to +1; the branch inputs will learn to avoid exact ties. Forward is deterministic (same in train and eval). STE passes gradient identically through the sum.""" return sign_ste(x + fx) def forward(self, x): x = self._residual(x, self.attn(x)) x = self._residual(x, self.ffn(x)) return x class BinaryEmbedding(nn.Module): def __init__(self, vocab_size, d_model): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) def forward(self, idx): W = sign_ste(self.weight) return F.embedding(idx, W) def get_codebook(self): return sign_ste(self.weight) class BitLM(nn.Module): """Concessions at the loss surface (per graceful-degradation ladder): - learnable output logit scale (1 float scalar) - per-vocab output bias (V floats) - untied ±1 output codebook (independent from input embedding) All hidden computations remain ±1 with integer popcounts. """ def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_layers = n_layers self.max_seq_len = max_seq_len self.embed = BinaryEmbedding(vocab_size, d_model) self.blocks = nn.ModuleList([ BitBlock(d_model, n_heads, d_ff) for _ in range(n_layers) ]) # Independent output codebook (±1 like embedding, but not tied). self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model))) self.out_bias = nn.Parameter(torch.zeros(vocab_size)) def forward(self, idx, targets=None): x = self.embed(idx) for blk in self.blocks: x = blk(x) W_out = sign_ste(self.out_codebook) scores = torch.matmul(x, W_out.t()) # integer popcount in [-D, D] logits = scores * self.logit_scale + self.out_bias loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1)) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None): self.eval() for _ in range(max_new_tokens): idx_cond = idx[:, -self.max_seq_len:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = -float('inf') probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, nxt], dim=1) return idx def param_count(m): return sum(p.numel() for p in m.parameters()) if __name__ == '__main__': model = BitLM() print(f"total params: {param_count(model):,}") x = torch.randint(0, 128, (2, 64)) y = torch.randint(0, 128, (2, 64)) logits, loss = model(x, y) print("logits:", logits.shape, "loss:", loss.item()) loss.backward() print("backward OK")