bitnet-1bitllm / vm_backup /code /model_v18.py
hidude562's picture
1bitllm code (checkpoints to follow)
4754707 verified
"""v18: v16 Gumbel hard-attention with a provably-integer inference path.
Training: same as v16 (Gumbel-softmax on float scores for gradient, hard argmax for
forward value).
Inference: call `forward_bin_eval_argmax_next(idx)` instead of `forward(idx)`. That path runs
*no float operations* on the hot path. All float scalars (1/√in, logit_scale,
threshold, out_bias, alibi float slopes) are absorbed at ckpt-load time into
integer thresholds that appear as simple signed-integer subtractions in
compare-against-zero decisions.
Integer-only ops used at inference:
- XNOR-popcount (binary matmul = count of agreements)
- Integer add/subtract (popcount − threshold)
- Sign (== popcount > threshold, a single compare)
- Integer ALiBi subtraction (distance · slope, both integer)
- Argmax as integer compare tree (log2(T) depth, single-bit result per match)
- Gather (pick V at the winning index — no multiply)
Key simplifications from v16:
1. `alibi_slopes` are integers (powers of 2), stored as int64.
2. `sqrt(d_head)` scaling on attention scores is REMOVED at eval; it was a
positive uniform scalar so it doesn't change argmax.
3. BitLinear's `s*scale − threshold` is refactored at eval to
`popcount − ceil(threshold/scale)`, a pure integer comparison.
4. Output head `scores*logit_scale + out_bias` is refactored to
`popcount*M + round(out_bias*M/logit_scale)` (fixed-point, M = 2^16) for integer argmax over vocab.
5. A ∈ {0,1}^{T×T} with one 1 per row (from argmax). O[i] = V[argmax_j S[i,j]]
is a gather, not a matmul.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
from model_v16 import set_gumbel_tau, gumbel_hard_attention
class IntBinaryAttention(nn.Module):
    """Gumbel hard-attention during training; argmax + gather at inference.

    Every head scores each (query, key) pair as ``Q @ K^T`` minus an ALiBi
    penalty ``slope * |i - j|``.  The slopes are the integers ``1, 2, 4, ...``
    (one power of two per head).  No ``1/sqrt(head_dim)`` factor is applied in
    either path.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model)
        self.k_proj = BitLinear(d_model, d_model)
        self.v_proj = BitLinear(d_model, d_model)
        self.o_proj = BitLinear(d_model, d_model)
        # INTEGER ALiBi slopes (power-of-2). Integer bias = slope * |i-j|.
        slopes = torch.tensor([1 << i for i in range(n_heads)], dtype=torch.long)
        self.register_buffer('alibi_slopes_int', slopes)
        # Lazily built causal-mask cache; non-persistent so it is not saved in
        # the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a ``(T, T)`` bool mask that is True strictly above the diagonal.

        The mask is cached on the module and rebuilt only when a longer
        sequence is requested or the target device changes; shorter requests
        are served by slicing the cached mask.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def _scores(self, Q, K):
        """Return ``Q @ K^T`` minus the per-head ALiBi distance penalty.

        ``Q`` and ``K`` are 4-D ``(B, H, T, Dh)``; the result is ``(B, H, T, T)``
        in the dtype of ``Q``.  No /sqrt(Dh): a uniform positive scalar doesn't
        change the argmax.
        """
        B, H, T, Dh = Q.shape
        scores = torch.matmul(Q, K.transpose(-2, -1))
        # Integer ALiBi: |i - j| times the head's integer slope.
        pos = torch.arange(T, device=Q.device)
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()  # (T, T) int
        alibi = self.alibi_slopes_int.view(1, H, 1, 1).to(Q.dtype) * dist.view(1, 1, T, T).to(Q.dtype)
        return scores - alibi

    def forward(self, x):
        """Training forward with the Gumbel-softmax gradient path.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.

        Returns:
            The output of ``o_proj`` applied to the attended values, fed with a
            ``(B, T, d_model)`` tensor.
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)
        scores = self._scores(Q, K)
        mask = self._get_mask(T, x.device)
        # Imported from model_v16; per the original note it is a soft-to-hard
        # STE at train time and a plain argmax at eval time.
        A = gumbel_hard_attention(scores, mask=mask)
        O = torch.matmul(A, V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)

    @torch.no_grad()
    def forward_bin_eval(self, x):
        """Inference forward: causal argmax over the scores, then a gather of V.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.

        Returns:
            The output of ``o_proj`` applied to the gathered values, fed with a
            ``(B, T, d_model)`` tensor.
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # NOTE(review): the original author states BitLinear is already
        # sign(integer popcount - integer threshold) at eval; BitLinear lives
        # in `model` and is not visible here.
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)
        scores = self._scores(Q, K)
        # Causal mask: push future positions to the lowest value representable
        # in the score dtype.  A fixed -1e18 does not fit in fp16 or in narrow
        # integer dtypes, so the sentinel is derived from the dtype itself.
        mask = self._get_mask(T, x.device)
        if scores.dtype.is_floating_point:
            lowest = torch.finfo(scores.dtype).min
        else:
            lowest = torch.iinfo(scores.dtype).min
        scores = scores.masked_fill(mask, lowest)
        # The diagonal is never masked, so every row keeps at least one
        # candidate and the argmax never lands on a masked entry.
        idx = scores.argmax(dim=-1, keepdim=True)  # (B,H,T,1)
        # Gather the winning V row per query. V shape (B,H,T,Dh).
        idx_exp = idx.expand(-1, -1, -1, Dh)
        O = torch.gather(V, dim=2, index=idx_exp)  # (B,H,T,Dh)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitBlockV18(nn.Module):
    """One v18 layer: attention and FFN run in parallel on the same input.

    Both branches read the block input ``x``; their outputs are added to ``x``
    and the sum is reduced back to a sign.

    Args:
        d_model: Feature width handed to both sub-modules.
        n_heads: Number of heads for the attention sub-module.
        d_ff: Hidden width handed to ``BitFFN``.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = IntBinaryAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        """Training path: ``sign_ste`` of the residual sum ``x + attn(x) + ffn(x)``."""
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        residual_sum = x + attn_out + ffn_out
        return sign_ste(residual_sum)

    @torch.no_grad()
    def forward_bin_eval(self, x):
        """Inference path: same residual sum, sign taken with a plain compare.

        Entries of the sum that are ``>= 0`` become ``+1``; all others ``-1``.
        """
        attn_out = self.attn.forward_bin_eval(x)
        # Original author's note: the FFN is already integer/sign under no-grad.
        ffn_out = self.ffn(x)
        # Per the original note the sum of three +/-1 terms lies in
        # {-3, -1, 1, 3}, so the compare against zero never sees a tie.
        total = x + attn_out + ffn_out
        plus_one = torch.ones_like(total)
        minus_one = -torch.ones_like(total)
        return torch.where(total >= 0, plus_one, minus_one)
class BitLMv18(nn.Module):
    """Language model: binary embedding, a stack of ``BitBlockV18``, sign-codebook head.

    The head scores a position against every vocabulary entry with
    ``x @ sign(out_codebook)^T`` and then applies a learned scalar
    ``logit_scale`` and a per-token ``out_bias``.

    Args:
        vocab_size: Number of tokens.
        d_model: Feature width.
        n_layers: Number of ``BitBlockV18`` layers.
        n_heads: Attention heads per layer.
        d_ff: FFN hidden width per layer.
        max_seq_len: Context window used by ``generate`` when cropping the prompt.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([BitBlockV18(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Float training/eval forward.

        Args:
            idx: Token ids fed to the embedding.
            targets: Optional target ids; when given, a cross-entropy loss over
                the flattened logits is returned as well.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def forward_bin_eval_argmax_next(self, idx):
        """Integer-logit inference returning the argmax next token per position.

        The float head ``scores * logit_scale + out_bias`` is replaced by a
        fixed-point form with the same ordering over the vocabulary:
        ``sign(logit_scale) * scores * M + round(out_bias * M / |logit_scale|)``
        with ``M = 2**16``.

        Args:
            idx: Token ids fed to the embedding.

        Returns:
            ``(next_pred, integer_logits)``: the argmax over the last dimension
            and the int64 fixed-point logits it was taken from.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk.forward_bin_eval(x)
        # Output head: scores = x @ W_out^T with W_out = sign(out_codebook),
        # where entries >= 0 map to +1.
        W_out = torch.where(self.out_codebook >= 0, torch.ones_like(self.out_codebook),
                            -torch.ones_like(self.out_codebook))
        scores = torch.matmul(x, W_out.t())  # last dim indexes the vocabulary
        # Fixed-point multiplier so rounding the bias has negligible effect on
        # the argmax while everything stays integer.
        M = 1 << 16
        scale = float(self.logit_scale)
        int_scores = scores.to(torch.int64) * M
        # logit_scale is an unconstrained learned parameter, so its sign must be
        # honoured: dividing by a negative scale reverses the ordering.
        if scale > 0:
            bias_fp = self.out_bias * M / scale
        elif scale < 0:
            int_scores = -int_scores
            bias_fp = self.out_bias * M / (-scale)
        else:
            # Degenerate head: the float logits reduce to out_bias alone.
            int_scores = torch.zeros_like(int_scores)
            bias_fp = self.out_bias * M
        # Cast straight to int64; going through the activation dtype first
        # could overflow or lose precision for low-precision activations.
        int_bias = torch.round(bias_fp).to(torch.int64)
        integer_logits = int_scores + int_bias.view(1, 1, -1)
        next_pred = integer_logits.argmax(dim=-1)
        return next_pred, integer_logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None, use_bin=False):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Puts the module in eval mode as a side effect.

        Args:
            idx: ``(B, T)`` tensor of prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature for sampling (floored at 1e-5);
                ignored when ``use_bin`` is True.
            top_k: If given, sample only among the ``top_k`` highest logits;
                values above the vocabulary size are clamped. Ignored when
                ``use_bin`` is True.
            use_bin: When True, decode greedily through
                ``forward_bin_eval_argmax_next`` instead of sampling.

        Returns:
            ``idx`` with the generated tokens concatenated along dim 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last max_seq_len tokens.
            idx_cond = idx[:, -self.max_seq_len:]
            if use_bin:
                pred, _ = self.forward_bin_eval_argmax_next(idx_cond)
                nxt = pred[:, -1:].long()
            else:
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises when k exceeds the dimension size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
if __name__ == '__main__':
    # Smoke test: one training step (forward + backward) followed by one pass
    # through the integer inference path on random token ids.
    set_gumbel_tau(0.3)
    model = BitLMv18()
    n_params = sum(param.numel() for param in model.parameters())
    print(f"v18 params: {n_params:,} ({n_params/1e6:.2f}M)")

    batch_shape = (2, 64)
    tokens = torch.randint(0, 128, batch_shape)
    targets = torch.randint(0, 128, batch_shape)

    # Training mode: float forward with loss, then backprop.
    model.train()
    _, train_loss = model(tokens, targets)
    print("train forward loss:", train_loss.item())
    train_loss.backward()
    print("backward OK")

    # Eval mode: integer argmax path.
    model.eval()
    next_tokens, fixed_point_logits = model.forward_bin_eval_argmax_next(tokens)
    print("bin_eval predictions shape:", next_tokens.shape, "dtype:", next_tokens.dtype)
    print("integer logits dtype:", fixed_point_logits.dtype, "— NO FLOAT in inference path")