"""v30: Doubled Binary — each weight stored as TWO independent ±1 bits (W_A, W_B). Effective weight W = W_A + W_B has values in {−2, 0, +2} — strict ternary on a binary substrate. This closes the ternary-vs-binary gap ParetoQ identified (~0.2-0.3 BPC on LLaMA) while keeping every operation as XNOR + popcount + add. At inference the output of a DoubleBitLinear layer is: y_i = popcount(W_A[i] XNOR x) + popcount(W_B[i] XNOR x) − in_features which is one extra XNOR-popcount per output row vs standard v18. Memory doubles. Attention, FFN, embeddings, residuals, and output head all use DoubleBitLinear (and a doubled embedding codebook). Activations remain strictly ±1. """ import math import torch import torch.nn as nn import torch.nn.functional as F from model import sign_ste, sign_ste_clipped from model_v18 import IntBinaryAttention # reuse attention shell from model_v16 import set_gumbel_tau def double_bin_linear_forward(x, W_A_bits, W_B_bits, threshold, in_features, scale): """Both weight halves are ±1; output is the sum of two popcount dot products.""" W_A = sign_ste(W_A_bits) W_B = sign_ste(W_B_bits) x_bin = sign_ste_clipped(x) # Two matmuls; sum them; scale; threshold; sign. y = F.linear(x_bin, W_A) + F.linear(x_bin, W_B) # effective ternary weight sum return sign_ste_clipped(y * scale - threshold) class DoubleBitLinear(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.in_features = in_features self.out_features = out_features # Two independent ±1 weight matrices self.weight_A = nn.Parameter(torch.randn(out_features, in_features) * 0.02) self.weight_B = nn.Parameter(torch.randn(out_features, in_features) * 0.02) self.threshold = nn.Parameter(torch.zeros(out_features)) # Scale: since the effective sum is in [-2*in, +2*in] instead of [-in, +in], # we scale by 1/(2*sqrt(in)) to keep pre-sign at unit scale. self.scale = 1.0 / (2.0 * math.sqrt(in_features)) def forward(self, x): return double_bin_linear_forward( x, self.weight_A, self.weight_B, self.threshold, self.in_features, self.scale) class DoubleBiAttention(nn.Module): """v18's IntBinaryAttention but with DoubleBitLinear projections.""" def __init__(self, d_model, n_heads): super().__init__() assert d_model % n_heads == 0 self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.q_proj = DoubleBitLinear(d_model, d_model) self.k_proj = DoubleBitLinear(d_model, d_model) self.v_proj = DoubleBitLinear(d_model, d_model) self.o_proj = DoubleBitLinear(d_model, d_model) slopes = torch.tensor([1 << i for i in range(n_heads)], dtype=torch.long) self.register_buffer('alibi_slopes_int', slopes) self.register_buffer('_causal_mask', torch.empty(0), persistent=False) def _get_mask(self, T, device): if self._causal_mask.shape[-1] < T or self._causal_mask.device != device: m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1) self._causal_mask = m return self._causal_mask[:T, :T] def _gumbel_hard(self, scores): from model_v16 import _get_tau tau = _get_tau(scores.device) if scores.requires_grad: g = -torch.log(-torch.log(torch.rand_like(scores).clamp(min=1e-9)) + 1e-9) y_soft = F.softmax((scores + g) / tau, dim=-1) y_hard = torch.zeros_like(y_soft) y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0) return y_soft + (y_hard - y_soft).detach() else: y = torch.zeros_like(scores) y.scatter_(-1, scores.argmax(-1, keepdim=True), 1.0) return y def forward(self, x): B, T, D = x.shape H, Dh = self.n_heads, self.head_dim Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2) K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2) V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) pos = torch.arange(T, device=Q.device) dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().to(Q.dtype) alibi = self.alibi_slopes_int.view(1, H, 1, 1).to(Q.dtype) * dist.view(1, 1, T, T) scores = scores - alibi mask = self._get_mask(T, x.device) scores = scores.masked_fill(mask, -1e9) A = self._gumbel_hard(scores) O = torch.matmul(A, V) O = O.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(O) class DoubleBitFFN(nn.Module): def __init__(self, d_model, d_ff): super().__init__() self.gate = DoubleBitLinear(d_model, d_ff) self.up = DoubleBitLinear(d_model, d_ff) self.down = DoubleBitLinear(d_ff, d_model) def forward(self, x): return self.down(self.gate(x) * self.up(x)) class BitBlockV30(nn.Module): def __init__(self, d_model, n_heads, d_ff): super().__init__() self.attn = DoubleBiAttention(d_model, n_heads) self.ffn = DoubleBitFFN(d_model, d_ff) def forward(self, x): a = self.attn(x) f = self.ffn(x) return sign_ste(x + a + f) class DoubleBinaryEmbedding(nn.Module): """Embedding with two ±1 codebooks summed; effective ternary.""" def __init__(self, vocab_size, d_model): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.weight_A = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) self.weight_B = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) def forward(self, idx): W_A = sign_ste(self.weight_A) W_B = sign_ste(self.weight_B) # Sum-and-sign to keep embedding strictly ±1 at block input # (alternatively we could go ternary here too — but we keep input ±1 for clarity). W = sign_ste(W_A + W_B) return F.embedding(idx, W) def get_codebook(self): return sign_ste(sign_ste(self.weight_A) + sign_ste(self.weight_B)) class BitLMv30(nn.Module): def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_layers = n_layers self.max_seq_len = max_seq_len self.embed = DoubleBinaryEmbedding(vocab_size, d_model) self.blocks = nn.ModuleList([ BitBlockV30(d_model, n_heads, d_ff) for _ in range(n_layers) ]) # Doubled output codebook for ternary-effective output head self.out_codebook_A = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) self.out_codebook_B = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) self.logit_scale = nn.Parameter(torch.tensor(1.0 / (2.0 * math.sqrt(d_model)))) self.out_bias = nn.Parameter(torch.zeros(vocab_size)) def forward(self, idx, targets=None): x = self.embed(idx) for blk in self.blocks: x = blk(x) W_A = sign_ste(self.out_codebook_A) W_B = sign_ste(self.out_codebook_B) # Sum two popcount similarities for ternary effective logits scores = torch.matmul(x, W_A.t()) + torch.matmul(x, W_B.t()) logits = scores * self.logit_scale + self.out_bias loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1)) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None): self.eval() for _ in range(max_new_tokens): idx_cond = idx[:, -self.max_seq_len:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = -float('inf') probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, nxt], dim=1) return idx if __name__ == '__main__': set_gumbel_tau(0.5) for cfg_name, d, L, d_ff in [('5M', 256, 8, 512), ('50M', 768, 10, 1280)]: m = BitLMv30(vocab_size=128, d_model=d, n_layers=L, n_heads=max(8, d//64), d_ff=d_ff) n = sum(p.numel() for p in m.parameters()) print(f'v30 {cfg_name}: {n:,} params ({n/1e6:.2f}M)') x = torch.randint(0, 128, (2, 64)) y = torch.randint(0, 128, (2, 64)) logits, loss = m(x, y) loss.backward() print(f' loss={loss.item():.3f}, backward OK')