"""Maximalist ±1 binary language model.

Forward-pass invariants (what the paper calls "true 1-bit"):
- Embeddings, Q/K/V/O, FFN weights, attention matrix, layer activations: ±1 via sign-STE.
- All matmuls are between ±1 operands (XNOR-popcount equivalents). Intermediate accumulators
  are integers in [-k, k]. They are multiplied by the constant 1/sqrt(k) so the clipped STE
  keeps passing gradients. A learned threshold is then subtracted and sign is re-applied at
  the output.
- Residual stream: majority vote sign(x + F(x)). Ties {x + F(x) == 0} currently resolve
  deterministically to +1; the original design called for a stochastic tie-break.
- FFN gating: XNOR gate (elementwise multiply of two ±1 tensors).
- Normalization: none. ReZero-inspired initialization: thresholds start at 0, which keeps
  pre-activations balanced, and F(x) starts as near-balanced noise.
- Position: binary-ALiBi subtractive distance bias with fixed per-head slopes 2**(h - 2).
- Output head: a separate (untied) ±1 codebook. Score = popcount similarity, followed by a
  learned logit scale and per-vocab bias. Softmax is applied only for training-time
  cross-entropy and for sampling (acknowledged float concessions at the loss surface).

Training-pass concession (§3 of the proposal): each ±1 weight has a latent float that we
call the "counter" in signSGD mode. This is standard for STE-trained networks, but we bound
it. No clamp is applied in this module; presumably the bound is enforced by the
optimizer/training loop (TODO confirm).
"""
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def sign_ste(x):
    """Binarize ``x`` to ±1 with a pure straight-through (identity) gradient.

    Forward: +1 where ``x >= 0`` (so 0 maps to +1), -1 elsewhere.
    Backward: the op is treated as the identity, so d(out)/dx == 1 everywhere.
    """
    # Bool mask -> {0, 1} -> {-1, +1}, in x's dtype and on x's device.
    hard = (x >= 0).to(x.dtype) * 2 - 1
    # Value equals `hard`; gradient flows only through the un-detached `x`.
    return x + (hard - x).detach()
|
|
|
|
def sign_ste_clipped(x):
    """Binarize ``x`` to ±1 with a hard-tanh straight-through gradient.

    Forward: +1 where ``x >= 0`` (0 maps to +1), -1 elsewhere.
    Backward: gradient of ``clamp(x, -1, 1)``, i.e. 1 for |x| <= 1 and 0 outside.
    Only useful when ``x`` is already near unit scale; otherwise most gradients are zero.
    """
    clipped = x.clamp(-1.0, 1.0)
    hard = (x >= 0).to(x.dtype) * 2 - 1
    # Value equals `hard`; gradient flows through the clamp only.
    return clipped + (hard - clipped).detach()
|
|
|
|
class BitLinearRaw(nn.Module):
    """Linear layer whose weights are binarized to ±1 on every forward pass.

    The latent float weights are stored in ``self.weight``; only their sign (via
    ``sign_ste``) enters the product. If ``binarize_input`` is True, the input is
    also mapped to ±1 (via ``sign_ste_clipped``), so each output is a popcount-style
    sum of ±1 terms in [-in_features, in_features]. No threshold or sign is applied
    to the result.
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.binarize_input = binarize_input
        # Latent "counter" weights; their scale is irrelevant to the forward value.
        latent = torch.randn(out_features, in_features) * 0.02
        self.weight = nn.Parameter(latent)

    def forward(self, x):
        binary_w = sign_ste(self.weight)
        inputs = sign_ste_clipped(x) if self.binarize_input else x
        return F.linear(inputs, binary_w)
|
|
|
|
class BitLinear(nn.Module):
    """±1 linear layer + learned per-output threshold + sign. Output is ±1.

    The raw popcount from ``BitLinearRaw`` lies in [-k, k] (k = in_features), with
    spread ~sqrt(k) for balanced inputs. It is multiplied by the fixed constant
    1/sqrt(k) so the pre-sign value is roughly unit scale and the hard-tanh STE in
    ``sign_ste_clipped`` passes gradients. The constant is not a learned weight.
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.raw = BitLinearRaw(in_features, out_features, binarize_input=binarize_input)
        # One learned offset per output unit, subtracted before the sign.
        self.threshold = nn.Parameter(torch.zeros(out_features))
        self.scale = 1.0 / math.sqrt(in_features)

    def forward(self, x):
        pre = self.raw(x) * self.scale
        return sign_ste_clipped(pre - self.threshold)
|
|
|
|
class BiAttention(nn.Module):
    """BiBERT-style threshold causal self-attention with ±1 Q/K/V and ±1 attention.

    Per head (Dh = head_dim, i = query position, j = key position):
        S = Q @ K^T / sqrt(Dh)                    popcount, scaled to ~unit range
        S -= alibi_slope * |i - j| / sqrt(Dh)     fixed per-head distance penalty
        A = sign_ste_clipped(S - tau)             tau: learned per-head threshold
        A[i, j] = 0 for j > i                     future keys excluded from the sum
        O = (A @ V) / sqrt(i + 1)                 positive per-row scale
        return o_proj(O)                          ±1

    Past keys with A = -1 still contribute -V_j ("anti-attention"); that is the ±1
    design. Future keys must be 0, not -1: a -1 entry would add -V_j for j > i and
    leak future tokens (including the next-token target) into position i.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)

        # One learned attention threshold per head.
        self.attn_threshold = nn.Parameter(torch.zeros(n_heads))

        # Fixed slopes 2**(h - 2): 0.25, 0.5, 1, 2, ... (steeper for later heads).
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily-built causal mask cache; excluded from the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True where key j > query i (the future).

        The mask is cached and rebuilt only when T exceeds the cached size or the
        device differs.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """x: (B, T, d_model) activations. Returns (B, T, d_model) ±1 output."""
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)

        # Popcount of ±1 vectors lies in [-Dh, Dh]; scale to ~unit range for the STE.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)

        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias

        tau = self.attn_threshold.view(1, H, 1, 1)
        A = sign_ste_clipped(scores - tau)

        # Drop future keys from the popcount by zeroing them. Setting them to -1
        # (as before) still multiplies V_j and leaks future tokens backwards.
        # masked_fill also zeroes the gradient there, so no -1e9 pre-fill of
        # `scores` is needed (and -1e9 overflows in fp16).
        mask = self._get_mask(T, x.device)
        A = A.masked_fill(mask, 0.0)

        O = torch.matmul(A, V)
        # Row i sums i + 1 terms of ±1, so |O| reaches i + 1. o_proj re-binarizes its
        # input with the clipped STE, which passes gradient only for |x| <= 1, so
        # divide by sqrt(#keys) to keep O near unit scale. The factor is positive,
        # so the sign taken by o_proj (the forward value) is unchanged.
        n_keys = torch.arange(1, T + 1, device=x.device, dtype=O.dtype)
        O = O / n_keys.sqrt().view(1, 1, T, 1)

        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
|
|
|
|
class BitFFN(nn.Module):
    """Binary gated feed-forward block: ``down(gate(x) * up(x))``.

    ``gate`` and ``up`` both emit ±1, so their elementwise product (an XNOR in bit
    terms) is also ±1. ``down`` projects back to d_model and emits ±1.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Creation order is kept as gate, up, down (parameter order / init RNG order).
        self.gate = BitLinear(d_model, d_ff, binarize_input=True)
        self.up = BitLinear(d_model, d_ff, binarize_input=True)
        self.down = BitLinear(d_ff, d_model, binarize_input=True)

    def forward(self, x):
        return self.down(self.gate(x) * self.up(x))
|
|
|
|
class BitBlock(nn.Module):
    """One layer: binary attention, then binary FFN, each merged into the stream
    with a ±1 residual vote."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = BiAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def _residual(self, x, fx):
        """Merge stream ``x`` and branch output ``fx`` as ``sign_ste(x + fx)``.

        For ±1 inputs, x + fx is in {-2, 0, 2}. ``sign_ste`` maps the tie 0 to +1, so
        the result is -1 only when both inputs are -1. The forward value is the same
        in train and eval, and the STE passes the gradient unchanged to x and fx.

        NOTE(review): a tie occurs exactly when x and fx disagree, so ties cannot be
        learned away unless fx simply copies x. With ties mapped to +1, an entry that
        is +1 in x stays +1 whatever fx is, so stacked blocks can only flip -1 -> +1.
        Consider a stochastic or per-channel tie-break.
        """
        return sign_ste(x + fx)

    def forward(self, x):
        for branch in (self.attn, self.ffn):
            x = self._residual(x, branch(x))
        return x
|
|
|
|
class BinaryEmbedding(nn.Module):
    """Token embedding whose table is binarized to ±1 on every lookup."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Latent float table; only its sign is visible in the forward pass.
        self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)

    def get_codebook(self):
        """Return the full (vocab_size, d_model) ±1 table, with STE gradients."""
        return sign_ste(self.weight)

    def forward(self, idx):
        return F.embedding(idx, self.get_codebook())
|
|
|
|
class BitLM(nn.Module):
    """Decoder-only ±1 language model.

    Concessions at the loss surface (per graceful-degradation ladder):
      - learnable output logit scale (1 float scalar)
      - per-vocab output bias (V floats)
      - untied ±1 output codebook (independent from input embedding)
    Between the embedding lookup and the output codebook, every block emits ±1
    activations.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Latent floats for the untied ±1 output codebook.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Scores lie in [-d_model, d_model]; 1/sqrt(d_model) starts logits near unit spread.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and optionally the cross-entropy loss.

        Args:
            idx: (B, T) token ids.
            targets: optional token ids with the same number of elements as ``idx``.

        Returns:
            (logits of shape (B, T, vocab_size), loss tensor or None).
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. a slice, also work.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` ids and append them to ``idx``.

        The context is cropped to the last ``max_seq_len`` tokens. ``top_k`` values
        larger than the vocabulary are clamped. The module's train/eval mode is
        restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the dimension size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            # Previously the model was left in eval mode even if called mid-training.
            self.train(was_training)
        return idx
|
|
|
|
def param_count(m):
    """Return the total number of scalar entries across ``m.parameters()``."""
    total = 0
    for p in m.parameters():
        total += p.numel()
    return total
|
|
|
|
if __name__ == '__main__':
    # Smoke test: one forward/backward pass of the default model on random tokens.
    model = BitLM()
    print(f"total params: {param_count(model):,}")
    # Sample ids from the model's own vocab instead of a hard-coded 128.
    vocab = model.vocab_size
    x = torch.randint(0, vocab, (2, 64))
    y = torch.randint(0, vocab, (2, 64))
    logits, loss = model(x, y)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")

    # Causality check: changing only the last token must leave the logits of all
    # earlier positions unchanged. Reported rather than asserted.
    with torch.no_grad():
        x_alt = x.clone()
        x_alt[:, -1] = (x_alt[:, -1] + 1) % vocab
        before, _ = model(x)
        after, _ = model(x_alt)
    if torch.allclose(before[:, :-1], after[:, :-1]):
        print("causality OK")
    else:
        print("WARNING: earlier logits depend on a later token (causal leak)")
|
|