# bitnet-1bitllm / vm_backup / code / model_v16.py
# Repository page header (hidude562): "1bitllm code (checkpoints to follow)",
# commit 4754707 (verified).
"""v16: Gumbel hard-attention. Each query attends to exactly ONE key, selected via
Gumbel-softmax with temperature annealing from soft → hard.

Why this might work where v11 top-k failed: v11's STE through top-k gave gradients
that pushed scores up/down, but the discrete selection didn't move easily. Gumbel-
softmax gives a proper continuous-to-discrete bridge. At high temperature the
relaxed attention is like softmax (multiple positions active); at low temperature
it is close to one-hot (single position). Training anneals high → low. In this
file the forward pass is always hardened to one-hot with a straight-through
estimator, so the temperature shapes the surrogate gradient rather than which
position is selected.

At eval: pure argmax. Each query attends to exactly one position (attention as
pointer). This is ternary {-1, 0, +1} in the attention matrix: one +1 per row,
rest 0s, with optional sign flip carried via separate bit (no sign bit is
applied in this file).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
# Module-level temperature held in a mutable tensor so torch.compile doesn't
# retrace every step when we anneal tau. The tensor is created here without a
# device argument; _get_tau() moves it to the scores' device on first use.
_GUMBEL_TAU = torch.tensor([1.0])
def set_gumbel_tau(tau: float):
    """Set the shared Gumbel-softmax temperature.

    The module-level tensor is overwritten in place rather than rebound, so
    its object identity is preserved and torch.compile does not see a new
    constant each time the temperature is annealed.
    """
    # In-place fill: only the stored value changes, never the tensor object.
    _GUMBEL_TAU.fill_(float(tau))
def _get_tau(device):
    """Fetch the current temperature as a tensor living on ``device``.

    If the shared tensor is on a different device it is moved once and the
    module-level reference is rebound to the moved copy. The returned value
    is clamped from below at 0.05 so the caller never divides by a tiny tau.
    """
    global _GUMBEL_TAU
    tau = _GUMBEL_TAU
    if tau.device != device:
        # Relocate and remember the relocated tensor for subsequent calls.
        tau = tau.to(device)
        _GUMBEL_TAU = tau
    return tau.clamp(min=0.05)
def gumbel_hard_attention(scores, mask=None):
    """Turn attention scores into a one-hot attention matrix.

    scores: (B, H, T, T) raw attention scores.
    mask:   optional bool tensor, True marks positions that must not be
            selected (they are filled with -1e9 before selection).

    When ``scores`` requires grad, Gumbel noise is added, a tempered softmax
    is taken, and the result is hardened with a straight-through estimator:
    the forward value is one-hot, the gradient flows through the soft
    distribution. When ``scores`` does not require grad, the result is the
    plain argmax one-hot with no noise.
    """
    # Fetched up front in both modes so the shared tau follows scores.device.
    tau = _get_tau(scores.device)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)

    if not scores.requires_grad:
        # No-grad path: deterministic pointer to the best-scoring position.
        winners = scores.argmax(-1, keepdim=True)
        return torch.zeros_like(scores).scatter_(-1, winners, 1.0)

    # Gumbel(0, 1) noise via the inverse-CDF trick; the clamp and the 1e-9
    # offset keep both logarithms finite.
    uniform = torch.rand_like(scores).clamp(min=1e-9)
    gumbel = -torch.log(-torch.log(uniform) + 1e-9)
    soft = F.softmax((scores + gumbel) / tau, dim=-1)

    # Straight-through: forward value equals `hard`, backward uses `soft`.
    hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, keepdim=True), 1.0)
    return soft + (hard - soft).detach()
class GumbelHardAttention(nn.Module):
    """Multi-head causal attention where each query selects a single key.

    Scores are scaled dot products minus a per-head distance penalty
    (slope * |i - j|), and the selection itself is delegated to
    ``gumbel_hard_attention``. Q/K/V/output projections are ``BitLinear``
    layers imported from ``model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Projection order is kept (q, k, v, o) so parameter registration
        # and initialisation order are unchanged.
        self.q_proj = BitLinear(d_model, d_model)
        self.k_proj = BitLinear(d_model, d_model)
        self.v_proj = BitLinear(d_model, d_model)
        self.o_proj = BitLinear(d_model, d_model)

        # Per-head distance-penalty slopes: 2^-2, 2^-1, 2^0, ...
        head_slopes = torch.tensor([2.0 ** exponent for exponent in range(-2, n_heads - 2)])
        self.register_buffer('alibi_slopes', head_slopes)
        # Lazily built causal mask; not saved in the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        The mask is cached and rebuilt only when the cached one is too small
        or lives on a different device.
        """
        cached = self._causal_mask
        stale = cached.shape[-1] < T or cached.device != device
        if stale:
            cached = torch.ones(T, T, device=device, dtype=torch.bool).triu(diagonal=1)
            self._causal_mask = cached
        return cached[:T, :T]

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (B, T, D) into (B, H, T, Dh)."""
        return projected.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply hard attention to ``x`` of shape (B, T, D); returns (B, T, D)."""
        batch, seq_len, width = x.shape
        n_heads, head_dim = self.n_heads, self.head_dim

        queries = self._split_heads(self.q_proj(x), batch, seq_len)
        keys = self._split_heads(self.k_proj(x), batch, seq_len)
        values = self._split_heads(self.v_proj(x), batch, seq_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim)

        # Distance penalty: subtract slope * |i - j|, scaled like the scores.
        positions = torch.arange(seq_len, device=x.device).float()
        distance = (positions[None, :] - positions[:, None]).abs()
        penalty = (
            self.alibi_slopes.view(1, n_heads, 1, 1)
            * distance.view(1, 1, seq_len, seq_len)
            / math.sqrt(head_dim)
        )
        scores = scores - penalty

        causal = self._get_mask(seq_len, x.device)
        # One-hot per row in the forward pass; gradients come from the soft
        # distribution when training.
        attn = gumbel_hard_attention(scores, mask=causal)

        # With a one-hot row this picks out a single value vector per query.
        # NOTE(review): the original comment says V is already ±1 by
        # construction; that depends on BitLinear, which is not visible here.
        picked = torch.matmul(attn, values)
        merged = picked.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(merged)
class BitBlockV16(nn.Module):
    """One v16 layer: attention and FFN run in parallel on the same input.

    Both branches read the block input ``x``; their outputs are added to
    ``x`` and the sum is passed through ``sign_ste`` (imported from
    ``model``).
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = GumbelHardAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        # Attention is evaluated first, then the FFN, both on the raw input.
        with_attention = x + self.attn(x)
        return sign_ste(with_attention + self.ffn(x))
class BitLMv16(nn.Module):
    """Language model made of a ``BinaryEmbedding``, a stack of ``BitBlockV16``
    layers, and a codebook-based output head.

    The output head multiplies the final hidden state by ``sign_ste`` of a
    learned codebook, then applies a learned scalar scale and a per-token bias.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([BitBlockV16(d_model, n_heads, d_ff) for _ in range(n_layers)])
        # Real-valued codebook; it is passed through sign_ste at every forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Learned logit scale, initialised to 1/sqrt(d_model).
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, if ``targets`` is given, the cross-entropy loss.

        idx:     token ids fed to the embedding (assumed (B, T) — the
                 embedding is defined outside this file).
        targets: optional token ids with one entry per logit row.

        Returns ``(logits, loss)``; ``loss`` is None when ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. a slice such
            # as batch[:, 1:], are accepted instead of raising.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        idx:         (B, T) tensor of token ids used as the prompt.
        temperature: divides the logits; values below 1e-5 are raised to 1e-5.
        top_k:       if given, only the k highest-scoring tokens can be
                     sampled. Values larger than the vocabulary are clamped
                     to the vocabulary size.

        Puts the module in eval mode (and leaves it there) and returns the
        prompt with the sampled tokens appended along dim 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the last max_seq_len tokens are fed back into the model.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises if k exceeds the last dimension, so clamp.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                # Everything below the k-th best logit becomes unsampleable.
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
if __name__ == '__main__':
    # Smoke test: build the default model, report its size, and run one
    # forward/backward pass on random tokens.
    set_gumbel_tau(1.0)
    model = BitLMv16()
    n = sum(param.numel() for param in model.parameters())
    print(f"v16 params: {n:,} ({n/1e6:.2f}M)")

    batch_shape = (2, 64)
    tokens = torch.randint(0, 128, batch_shape)
    labels = torch.randint(0, 128, batch_shape)

    logits, loss = model(tokens, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")