"""v18: v16 Gumbel hard-attention with a provably-integer inference path. Training: same as v16 (Gumbel-softmax on float scores for gradient, hard argmax for forward value). Inference: call `forward_bin_eval(idx)` instead of `forward(idx)`. That path runs *no float operations* on the hot path. All float scalars (1/√in, logit_scale, threshold, out_bias, alibi float slopes) are absorbed at ckpt-load time into integer thresholds that appear as simple signed-integer subtractions in compare-against-zero decisions. Integer-only ops used at inference: - XNOR-popcount (binary matmul = count of agreements) - Integer add/subtract (popcount − threshold) - Sign (== popcount > threshold, a single compare) - Integer ALiBi subtraction (distance · slope, both integer) - Argmax as integer compare tree (log2(T) depth, single-bit result per match) - Gather (pick V at the winning index — no multiply) Key simplifications from v16: 1. `alibi_slopes` are integers (powers of 2), stored as int64. 2. `sqrt(d_head)` scaling on attention scores is REMOVED at eval; it was a positive uniform scalar so it doesn't change argmax. 3. BitLinear's `s*scale − threshold` is refactored at eval to `popcount − ceil(threshold/scale)`, a pure integer comparison. 4. Output head `scores*logit_scale + out_bias` is refactored to `popcount + round(out_bias/logit_scale)` for integer argmax over vocab. 5. A ∈ {0,1}^{T×T} with one 1 per row (from argmax). O[i] = V[argmax_j S[i,j]] is a gather, not a matmul. """ import math import torch import torch.nn as nn import torch.nn.functional as F from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding from model_v16 import set_gumbel_tau, gumbel_hard_attention class IntBinaryAttention(nn.Module): """Gumbel hard-attention during training; pure-integer argmax at inference.""" def __init__(self, d_model, n_heads): super().__init__() assert d_model % n_heads == 0 self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.q_proj = BitLinear(d_model, d_model) self.k_proj = BitLinear(d_model, d_model) self.v_proj = BitLinear(d_model, d_model) self.o_proj = BitLinear(d_model, d_model) # INTEGER ALiBi slopes (power-of-2). Integer bias = slope * |i-j|. slopes = torch.tensor([1 << i for i in range(n_heads)], dtype=torch.long) self.register_buffer('alibi_slopes_int', slopes) self.register_buffer('_causal_mask', torch.empty(0), persistent=False) def _get_mask(self, T, device): if self._causal_mask.shape[-1] < T or self._causal_mask.device != device: m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1) self._causal_mask = m return self._causal_mask[:T, :T] def _scores(self, Q, K): """Integer popcount scores minus integer ALiBi bias. No /sqrt(Dh): uniform scalar doesn't change argmax.""" B, H, T, Dh = Q.shape # (B,H,T,T) integer popcount scores = torch.matmul(Q, K.transpose(-2, -1)) # Integer ALiBi pos = torch.arange(T, device=Q.device) dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() # (T,T) int alibi = self.alibi_slopes_int.view(1, H, 1, 1).to(Q.dtype) * dist.view(1, 1, T, T).to(Q.dtype) return scores - alibi def forward(self, x): """Training forward with Gumbel-softmax gradient path.""" B, T, D = x.shape H, Dh = self.n_heads, self.head_dim Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2) K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2) V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2) scores = self._scores(Q, K) mask = self._get_mask(T, x.device) A = gumbel_hard_attention(scores, mask=mask) # soft-to-hard STE at train, argmax at eval O = torch.matmul(A, V) O = O.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(O) @torch.no_grad() def forward_bin_eval(self, x): """Pure-integer inference forward. No float on the critical path.""" B, T, D = x.shape H, Dh = self.n_heads, self.head_dim # BitLinear forward is already sign(integer popcount − integer threshold) at eval. Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2) K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2) V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2) # Integer scores scores = self._scores(Q, K) # Causal mask mask = self._get_mask(T, x.device) scores = scores.masked_fill(mask, torch.iinfo(torch.long).min if scores.dtype == torch.long else -1e18) # Integer argmax per query row. idx = scores.argmax(dim=-1, keepdim=True) # (B,H,T,1) # Gather winning V per query. V shape (B,H,T,Dh). idx_exp = idx.expand(-1, -1, -1, Dh) O = torch.gather(V, dim=2, index=idx_exp) # (B,H,T,Dh) O = O.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(O) class BitBlockV18(nn.Module): def __init__(self, d_model, n_heads, d_ff): super().__init__() self.attn = IntBinaryAttention(d_model, n_heads) self.ffn = BitFFN(d_model, d_ff) def forward(self, x): a = self.attn(x) f = self.ffn(x) return sign_ste(x + a + f) @torch.no_grad() def forward_bin_eval(self, x): a = self.attn.forward_bin_eval(x) f = self.ffn(x) # already integer/sign under no-grad # Sum is integer in {-3,-1,1,3}. Sign is an integer compare against zero. s = x + a + f return torch.where(s >= 0, torch.ones_like(s), -torch.ones_like(s)) class BitLMv18(nn.Module): def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_layers = n_layers self.max_seq_len = max_seq_len self.embed = BinaryEmbedding(vocab_size, d_model) self.blocks = nn.ModuleList([BitBlockV18(d_model, n_heads, d_ff) for _ in range(n_layers)]) self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02) self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model))) self.out_bias = nn.Parameter(torch.zeros(vocab_size)) def forward(self, idx, targets=None): x = self.embed(idx) for blk in self.blocks: x = blk(x) W_out = sign_ste(self.out_codebook) scores = torch.matmul(x, W_out.t()) logits = scores * self.logit_scale + self.out_bias loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1)) return logits, loss @torch.no_grad() def forward_bin_eval_argmax_next(self, idx): """Pure-integer inference that returns the argmax next-token per position. Used to demonstrate the inference path is fully binary/integer arithmetic. """ x = self.embed(idx) for blk in self.blocks: x = blk.forward_bin_eval(x) # Output head: scores = x @ W_out^T (integer popcount). # For argmax next-char, `scores*logit_scale + out_bias` has same argmax as # `scores + round(out_bias/logit_scale)` since logit_scale > 0. W_out = torch.where(self.out_codebook >= 0, torch.ones_like(self.out_codebook), -torch.ones_like(self.out_codebook)) scores = torch.matmul(x, W_out.t()) # (B,T,V) integer popcount # Scale by a large integer multiplier so (scores*SCALE + bias_int) has # negligible rounding error on argmax. Keeps everything integer. M = 1 << 16 int_bias = torch.round(self.out_bias * M / self.logit_scale).to(scores.dtype) integer_logits = scores.to(torch.int64) * M + int_bias.view(1, 1, -1).to(torch.int64) next_pred = integer_logits.argmax(dim=-1) # (B,T) return next_pred, integer_logits @torch.no_grad() def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None, use_bin=False): self.eval() for _ in range(max_new_tokens): idx_cond = idx[:, -self.max_seq_len:] if use_bin: pred, _ = self.forward_bin_eval_argmax_next(idx_cond) nxt = pred[:, -1:].long() else: logits, _ = self(idx_cond) logits = logits[:, -1, :] / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = -float('inf') probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, nxt], dim=1) return idx if __name__ == '__main__': set_gumbel_tau(0.3) m = BitLMv18() n = sum(p.numel() for p in m.parameters()) print(f"v18 params: {n:,} ({n/1e6:.2f}M)") x = torch.randint(0, 128, (2, 64)) y = torch.randint(0, 128, (2, 64)) m.train() logits, loss = m(x, y) print("train forward loss:", loss.item()) loss.backward() print("backward OK") m.eval() pred, int_logits = m.forward_bin_eval_argmax_next(x) print("bin_eval predictions shape:", pred.shape, "dtype:", pred.dtype) print("integer logits dtype:", int_logits.dtype, "— NO FLOAT in inference path")