| """v30: Doubled Binary — each weight stored as TWO independent ±1 bits (W_A, W_B). |
| |
| Effective weight W = W_A + W_B has values in {−2, 0, +2} — strict ternary on a |
| binary substrate. This closes the ternary-vs-binary gap ParetoQ identified |
| (~0.2-0.3 BPC on LLaMA) while keeping every operation as XNOR + popcount + add. |
| |
| At inference the pre-activation sum of a DoubleBitLinear layer (before scaling, |
| thresholding and re-binarization) is: |
| y_i = 2 · (popcount(W_A[i] XNOR x) + popcount(W_B[i] XNOR x) − in_features) |
| which is one extra XNOR-popcount per output row vs standard v18. Memory doubles. |
| |
| Attention, FFN, embeddings, residuals, and output head all use DoubleBitLinear |
| (and a doubled embedding codebook). Activations remain strictly ±1. |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from model import sign_ste, sign_ste_clipped |
| from model_v18 import IntBinaryAttention |
| from model_v16 import set_gumbel_tau |
|
|
|
|
def double_bin_linear_forward(x, W_A_bits, W_B_bits, threshold, in_features, scale):
    """Apply a doubled-binary linear map and re-binarize the result.

    Both latent weight tensors go through ``sign_ste`` and the input goes
    through ``sign_ste_clipped`` (helpers imported from ``model``; presumably
    sign functions with straight-through gradients -- confirm against that
    module).  The two linear products are added, multiplied by ``scale``,
    offset by ``threshold`` and passed through ``sign_ste_clipped`` again.

    Args:
        x: Input activations; the last dimension must match the second
            dimension of the weight tensors (``F.linear`` convention).
        W_A_bits: Latent weights of the first half, laid out ``(out, in)``.
        W_B_bits: Latent weights of the second half, same layout.
        threshold: Offset subtracted from the scaled sum.
        in_features: Accepted for signature compatibility; not read here.
        scale: Multiplier applied to the summed products.

    Returns:
        ``sign_ste_clipped((lin_A + lin_B) * scale - threshold)``.
    """
    half_a = sign_ste(W_A_bits)
    half_b = sign_ste(W_B_bits)
    acts = sign_ste_clipped(x)

    # One linear product per weight half; their sum is the effective
    # (W_A + W_B) product.
    summed = F.linear(acts, half_a) + F.linear(acts, half_b)
    pre_activation = summed * scale - threshold
    return sign_ste_clipped(pre_activation)
|
|
|
|
class DoubleBitLinear(nn.Module):
    """Linear layer whose weight is stored as two latent tensors, A and B.

    The forward pass delegates to ``double_bin_linear_forward``, which
    combines both halves with a per-output threshold and a fixed scale.

    Attributes:
        in_features: Size of the input dimension.
        out_features: Size of the output dimension.
        weight_A, weight_B: Learnable latent weights, ``(out_features, in_features)``,
            initialised from a normal distribution scaled by 0.02.
        threshold: Learnable per-output offset, initialised to zero.
        scale: Plain Python float ``1 / (2 * sqrt(in_features))`` (not learnable).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        weight_shape = (out_features, in_features)
        # A is drawn before B so the random stream is consumed in a fixed order.
        self.weight_A = nn.Parameter(torch.randn(weight_shape).mul(0.02))
        self.weight_B = nn.Parameter(torch.randn(weight_shape).mul(0.02))
        self.threshold = nn.Parameter(torch.zeros(out_features))

        # The factor 2 accounts for the two summed weight halves.
        denominator = 2.0 * math.sqrt(in_features)
        self.scale = 1.0 / denominator

    def forward(self, x):
        """Run ``double_bin_linear_forward`` with this layer's parameters."""
        return double_bin_linear_forward(
            x,
            W_A_bits=self.weight_A,
            W_B_bits=self.weight_B,
            threshold=self.threshold,
            in_features=self.in_features,
            scale=self.scale,
        )
|
|
|
|
class DoubleBiAttention(nn.Module):
    """Causal multi-head hard attention with DoubleBitLinear projections.

    Described by its author as v18's IntBinaryAttention with the projections
    swapped for DoubleBitLinear.  Each query position selects exactly one key
    (a one-hot attention row); a distance penalty with integer per-head slopes
    ``2**h`` is subtracted from the scores before the selection.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Creation order q, k, v, o fixes both RNG consumption and parameter order.
        self.q_proj = DoubleBitLinear(d_model, d_model)
        self.k_proj = DoubleBitLinear(d_model, d_model)
        self.v_proj = DoubleBitLinear(d_model, d_model)
        self.o_proj = DoubleBitLinear(d_model, d_model)

        # Head h gets slope 2**h (1, 2, 4, ...), stored as int64.
        head_slopes = torch.tensor([2 ** h for h in range(n_heads)], dtype=torch.long)
        self.register_buffer('alibi_slopes_int', head_slopes)
        # Lazily built causal mask; non-persistent so it stays out of state_dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a ``(T, T)`` bool mask that is True strictly above the diagonal.

        The mask is cached and rebuilt only when the cached one is smaller
        than ``T`` or lives on a different device.
        """
        cached = self._causal_mask
        needs_rebuild = cached.shape[-1] < T or cached.device != device
        if needs_rebuild:
            all_true = torch.ones(T, T, device=device, dtype=torch.bool)
            self._causal_mask = torch.triu(all_true, diagonal=1)
        return self._causal_mask[:T, :T]

    def _gumbel_hard(self, scores):
        """Turn scores into one-hot rows along the last dimension.

        When ``scores`` does not require grad the result is a plain argmax
        one-hot.  Otherwise Gumbel noise is added, a softmax at temperature
        ``tau`` (read from ``model_v16._get_tau``) is taken, and the returned
        tensor carries the hard one-hot value while gradients flow through
        the soft distribution.
        """
        from model_v16 import _get_tau
        tau = _get_tau(scores.device)

        if not scores.requires_grad:
            one_hot = torch.zeros_like(scores)
            one_hot.scatter_(-1, scores.argmax(-1, keepdim=True), 1.0)
            return one_hot

        # Gumbel(0, 1) sample; the 1e-9 terms keep both logs finite.
        uniform = torch.rand_like(scores).clamp(min=1e-9)
        gumbel = -torch.log(-torch.log(uniform) + 1e-9)
        soft = F.softmax((scores + gumbel) / tau, dim=-1)
        winners = soft.argmax(-1, keepdim=True)
        hard = torch.zeros_like(soft).scatter_(-1, winners, 1.0)
        # Straight-through: value of `hard`, gradient of `soft`.
        return soft + (hard - soft).detach()

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, D)`` and return the same shape."""
        batch, seq_len, width = x.shape
        heads, head_dim = self.n_heads, self.head_dim

        def split_heads(t):
            # (B, T, D) -> (B, H, T, Dh)
            return t.view(batch, seq_len, heads, head_dim).transpose(1, 2)

        Q = split_heads(self.q_proj(x))
        K = split_heads(self.k_proj(x))
        V = split_heads(self.v_proj(x))

        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Per-head linear penalty on the absolute query/key distance.
        positions = torch.arange(seq_len, device=Q.device)
        distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().to(Q.dtype)
        slopes = self.alibi_slopes_int.view(1, heads, 1, 1).to(Q.dtype)
        penalty = slopes * distance.view(1, 1, seq_len, seq_len)
        scores = scores - penalty

        # Future positions get a very negative score so they are never selected.
        future = self._get_mask(seq_len, x.device)
        scores = scores.masked_fill(future, -1e9)

        selection = self._gumbel_hard(scores)
        mixed = torch.matmul(selection, V)
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(merged)
|
|
|
|
class DoubleBitFFN(nn.Module):
    """Gated feed-forward block built from three DoubleBitLinear layers.

    Computes ``down(gate(x) * up(x))``: two parallel projections from
    ``d_model`` to ``d_ff`` are multiplied elementwise, then projected back
    to ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Registration order gate -> up -> down is preserved (it determines
        # parameter ordering and the order random weights are drawn).
        expand_dims = (d_model, d_ff)
        self.gate = DoubleBitLinear(*expand_dims)
        self.up = DoubleBitLinear(*expand_dims)
        self.down = DoubleBitLinear(d_ff, d_model)

    def forward(self, x):
        """Return the gated projection of ``x``."""
        gate_out = self.gate(x)
        up_out = self.up(x)
        return self.down(gate_out * up_out)
|
|
|
|
class BitBlockV30(nn.Module):
    """One model block: attention and FFN applied in parallel to the same input.

    Both branches read the block input ``x`` (the FFN does not see the
    attention output); their results are added to ``x`` and the sum is passed
    through ``sign_ste`` (helper from ``model``; presumably re-binarizes the
    residual stream -- confirm against that module).
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = DoubleBiAttention(d_model, n_heads)
        self.ffn = DoubleBitFFN(d_model, d_ff)

    def forward(self, x):
        """Return ``sign_ste(x + attn(x) + ffn(x))``."""
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        residual_sum = x + attn_out + ffn_out
        return sign_ste(residual_sum)
|
|
|
|
class DoubleBinaryEmbedding(nn.Module):
    """Token embedding backed by two latent codebooks, A and B.

    Each codebook is passed through ``sign_ste``, the two results are added,
    and the sum is passed through ``sign_ste`` once more before the lookup.
    NOTE(review): what entries with a zero sum map to depends on how
    ``sign_ste`` (imported from ``model``) treats 0 -- confirm there.

    Attributes:
        vocab_size: Number of rows in each codebook.
        d_model: Embedding width.
        weight_A, weight_B: Learnable latent codebooks, ``(vocab_size, d_model)``.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        table_shape = (vocab_size, d_model)
        # A is drawn before B so the random stream is consumed in a fixed order.
        self.weight_A = nn.Parameter(torch.randn(table_shape).mul(0.02))
        self.weight_B = nn.Parameter(torch.randn(table_shape).mul(0.02))

    def forward(self, idx):
        """Look up rows ``idx`` in the combined codebook."""
        return F.embedding(idx, self.get_codebook())

    def get_codebook(self):
        """Return the combined ``(vocab_size, d_model)`` codebook used by ``forward``."""
        half_a = sign_ste(self.weight_A)
        half_b = sign_ste(self.weight_B)
        return sign_ste(half_a + half_b)
|
|
|
|
class BitLMv30(nn.Module):
    """Language model: doubled-binary embedding, a stack of blocks, doubled output head.

    The output head scores the final hidden state against two latent
    codebooks (each passed through ``sign_ste``, a helper from ``model``),
    sums the two score tensors, multiplies by a learnable ``logit_scale`` and
    adds a learnable per-token ``out_bias``.

    Args:
        vocab_size: Number of token ids.
        d_model: Hidden width.
        n_layers: Number of ``BitBlockV30`` blocks.
        n_heads: Attention heads per block.
        d_ff: Inner width of each block's FFN.
        max_seq_len: Context length that ``generate`` crops its input to.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = DoubleBinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV30(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])

        # Output head: two latent codebooks whose score tensors are summed.
        self.out_codebook_A = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.out_codebook_B = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Starts at 1 / (2 * sqrt(d_model)); the 2 accounts for the two summed halves.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / (2.0 * math.sqrt(d_model))))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: Integer token ids, indexed as ``(batch, time)`` by ``generate``.
            targets: Optional token ids with as many elements as ``idx``.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_A = sign_ste(self.out_codebook_A)
        W_B = sign_ste(self.out_codebook_B)

        scores = torch.matmul(x, W_A.t()) + torch.matmul(x, W_B.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices of a
            # larger batch tensor, are accepted as well.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Switches the module to eval mode and leaves it there.

        Args:
            idx: ``(batch, time)`` tensor of prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Divides the logits; values below 1e-5 are treated as 1e-5.
            top_k: If given (must be >= 1), sampling is restricted to the
                ``top_k`` highest-scoring tokens.  Values larger than the
                vocabulary are clamped to the vocabulary size.

        Returns:
            ``idx`` with ``max_new_tokens`` sampled ids concatenated along dim 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens are fed to the model.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises when k exceeds the last dimension, so clamp.
                k = min(top_k, logits.size(-1))
                kth_best, _ = torch.topk(logits, k)
                logits = logits.masked_fill(logits < kth_best[:, [-1]], float('-inf'))
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build two model sizes, report the parameter count, and run
    # one forward/backward pass on random tokens.
    set_gumbel_tau(0.5)
    configs = (
        ('5M', 256, 8, 512),
        ('50M', 768, 10, 1280),
    )
    for cfg_name, width, depth, ff_width in configs:
        # At least 8 heads; one head per 64 channels for wider models.
        head_count = max(8, width // 64)
        model = BitLMv30(vocab_size=128, d_model=width, n_layers=depth,
                         n_heads=head_count, d_ff=ff_width)
        n = sum(p.numel() for p in model.parameters())
        print(f'v30 {cfg_name}: {n:,} params ({n/1e6:.2f}M)')

        tokens = torch.randint(0, 128, (2, 64))
        labels = torch.randint(0, 128, (2, 64))
        logits, loss = model(tokens, labels)
        loss.backward()
        print(f' loss={loss.item():.3f}, backward OK')
|
|