File size: 10,345 Bytes
4754707 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 | """Maximalist ±1 binary language model.
Forward-pass invariants (what the paper calls "true 1-bit"):
- Embeddings, Q/K/V/O, FFN weights, attention matrix, layer activations: ±1 via sign-STE.
- All matmuls are between ±1 operands (XNOR-popcount equivalents). Intermediate accumulators
are integers in [-k, k]. Thresholds are subtracted and sign is re-applied at the output.
- Residual stream: majority vote sign(x + F(x)) with a deterministic tie-break: x + F(x) == 0 maps to +1 (same in train and eval).
- FFN gating: XNOR gate (elementwise multiply of two ±1 tensors).
- Normalization: none. ReZero-style identity residual path at init (threshold = 0 keeps
pre-activation balanced; F(x) starts near-balanced noise).
- Position: integer binary-ALiBi subtractive bias (per-head fixed slopes).
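- Output head: independent (untied) ±1 codebook. Score = popcount similarity, followed by a learnable float logit scale and a per-vocab float bias. Softmax is applied only
at the loss surface (training-time cross-entropy) and when sampling in generate() (acknowledged float concession).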
Training-pass concession (§3 of the proposal): each ±1 weight has a latent float that we
call the "counter" in signSGD mode; it's standard for STE-trained networks but we bound it.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def sign_ste(x):
    """Binarize ``x`` to ±1 with a straight-through (identity) gradient.

    Forward: +1 wherever ``x >= 0`` (so an exact zero maps to +1), -1 elsewhere.
    Backward: the correction term is detached, so the gradient with respect to
    ``x`` is the incoming gradient, unchanged.
    """
    plus_one = torch.ones_like(x)
    hard = torch.where(x >= 0, plus_one, -plus_one)
    # x + stop_grad(hard - x): evaluates to `hard`, differentiates like `x`.
    correction = (hard - x).detach()
    return x + correction
def sign_ste_clipped(x):
    """Binarize ``x`` to ±1 with a hard-tanh backward pass.

    Forward: +1 wherever ``x >= 0`` (zero maps to +1), -1 elsewhere.
    Backward: the gradient flows through ``clamp(x, -1, 1)``, so it is passed
    only where ``|x| <= 1`` and is zero outside that window. The input should
    therefore already sit at roughly unit scale, otherwise most gradients
    vanish.
    """
    plus_one = torch.ones_like(x)
    hard = torch.where(x >= 0, plus_one, -plus_one)
    soft = x.clamp(-1.0, 1.0)
    # soft + stop_grad(hard - soft): value of `hard`, gradient of `soft`.
    return soft + (hard - soft).detach()
class BitLinearRaw(nn.Module):
    """Bias-free linear layer whose effective weights are ±1.

    The output is the raw accumulator of the ±1 matmul (the popcount
    equivalent); no threshold or sign is applied here.
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.binarize_input = binarize_input
        # Latent float parameter; only its sign is used in the forward pass.
        # A small zero-mean gaussian init yields a balanced mix of +1 and -1.
        latent = torch.randn(out_features, in_features) * 0.02
        self.weight = nn.Parameter(latent)

    def forward(self, x):
        """Return ``x @ sign(weight)^T``, binarizing ``x`` first if configured."""
        signed_weight = sign_ste(self.weight)
        operand = sign_ste_clipped(x) if self.binarize_input else x
        return F.linear(operand, signed_weight)
class BitLinear(nn.Module):
    """``BitLinearRaw`` followed by a learned threshold and a sign; emits ±1.

    For balanced ±1 inputs the raw accumulator lies in [-k, k] with a standard
    deviation of about sqrt(k), where k = ``in_features``. Multiplying by the
    fixed constant 1/sqrt(k) brings the pre-sign value to roughly unit scale so
    that the hard-tanh straight-through estimator still passes gradients. The
    constant is not learned, so it adds no float weight. BiBERT and BitNet both
    use an equivalent scaling.
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        # Fixed (non-learned) normalization constant.
        self.scale = 1.0 / math.sqrt(in_features)
        self.raw = BitLinearRaw(in_features, out_features, binarize_input=binarize_input)
        # One learned threshold per output unit, starting at zero.
        self.threshold = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Return ``sign(raw(x) * scale - threshold)`` as a ±1 tensor."""
        accumulator = self.raw(x)
        pre_sign = accumulator * self.scale - self.threshold
        return sign_ste_clipped(pre_sign)
class BiAttention(nn.Module):
    """BiBERT-style thresholded causal self-attention over ±1 activations.

    Per head, for query position i and key position j:

        S = (Q @ K^T) / sqrt(head_dim)           popcount similarity, fixed scale
        S -= slope_h * |i - j| / sqrt(head_dim)  ALiBi-style distance penalty
        S -= tau_h                               learned per-head threshold
        A = sign(S)  where j <= i                ±1 (hard-tanh STE backward)
        A = 0        where j >  i                future keys contribute nothing
        O = A @ V
        return o_proj(O)                         re-binarized to ±1

    Future positions are filled with 0, not -1. With ±1 values a weight of -1
    does not switch a key off: it adds ``-V[j]`` to the output, which would make
    every position depend on tokens that come after it and break causality.
    A zero weight is the only value that removes the key from the sum (in
    popcount terms: the masked lanes are simply excluded from the count).
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)
        # One learned threshold per head, starting at zero.
        self.attn_threshold = nn.Parameter(torch.zeros(n_heads))
        # Fixed power-of-two ALiBi slopes: [0.25, 0.5, 1, 2, 4, 8, 16, 32, ...].
        # Small slopes (early heads) penalize distance weakly, i.e. look far
        # back; large slopes (late heads) are increasingly local.
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily built causal-mask cache; not saved in the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        The cached mask is rebuilt only when it is too small for ``T`` or lives
        on a different device; otherwise a top-left slice of it is returned.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Apply causal binary attention to ``x`` of shape (B, T, d_model).

        Returns a tensor of the same shape produced by ``o_proj``.
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # Split the model dimension into heads: (B, T, D) -> (B, H, T, Dh).
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)

        # (B, H, T, T) similarity. The 1/sqrt(head_dim) factor is a fixed
        # constant (not a weight) that keeps scores at O(1) so the hard-tanh
        # STE window is not missed.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)

        # Distance penalty, scaled by the same constant as the scores.
        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()  # (T, T)
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias

        tau = self.attn_threshold.view(1, H, 1, 1)
        A = sign_ste_clipped(scores - tau)

        # Causal masking: zero out future keys so they drop out of A @ V.
        # masked_fill also blocks gradient into the masked score entries, so
        # no separate pre-sign fill of the scores is needed.
        mask = self._get_mask(T, x.device)
        A = A.masked_fill(mask, 0.0)

        O = torch.matmul(A, V)  # (B, H, T, Dh)
        # Merge heads back: (B, H, T, Dh) -> (B, T, D).
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitFFN(nn.Module):
    """Binary feed-forward block with an XNOR gate: ``down(gate(x) * up(x))``.

    ``gate`` and ``up`` both emit ±1, and the elementwise product of two ±1
    tensors is again ±1, so the input to ``down`` stays binary.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # d_model -> d_ff branches, then a d_ff -> d_model projection.
        self.gate = BitLinear(d_model, d_ff, binarize_input=True)
        self.up = BitLinear(d_model, d_ff, binarize_input=True)
        self.down = BitLinear(d_ff, d_model, binarize_input=True)

    def forward(self, x):
        """Return the gated hidden activations projected back to ``d_model``."""
        gate_bits = self.gate(x)
        up_bits = self.up(x)
        hidden = gate_bits * up_bits
        return self.down(hidden)
class BitBlock(nn.Module):
    """One layer: binary attention then binary FFN, each merged into the
    stream by a majority-vote residual."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = BiAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def _residual(self, x, fx):
        """Merge the stream ``x`` and a branch output ``fx`` by majority vote.

        For ±1 operands the sum is one of -2, 0, 2. ``sign_ste`` resolves the
        tie (sum == 0) to +1, so the result is deterministic and identical in
        train and eval; the branches are expected to learn to avoid exact ties.
        The straight-through estimator passes the gradient through the sum
        unchanged.
        """
        vote = x + fx
        return sign_ste(vote)

    def forward(self, x):
        """Run the attention sub-layer, then the FFN sub-layer."""
        after_attn = self._residual(x, self.attn(x))
        after_ffn = self._residual(after_attn, self.ffn(after_attn))
        return after_ffn
class BinaryEmbedding(nn.Module):
    """Token embedding table whose rows are ±1 in the forward pass."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Latent float table; only its sign is ever looked up.
        latent = torch.randn(vocab_size, d_model) * 0.02
        self.weight = nn.Parameter(latent)

    def forward(self, idx):
        """Look up the ±1 rows of the binarized table for the ids in ``idx``."""
        return F.embedding(idx, self.get_codebook())

    def get_codebook(self):
        """Return the full binarized (±1) table, with straight-through gradient."""
        return sign_ste(self.weight)
class BitLM(nn.Module):
    """±1 language model: binary embedding, a stack of ``BitBlock`` layers and
    a binary output codebook.

    Concessions at the loss surface (per graceful-degradation ladder):
    - learnable output logit scale (1 float scalar)
    - per-vocab output bias (V floats)
    - untied ±1 output codebook (independent from input embedding)
    All hidden computations remain ±1 with integer popcounts.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        # Only used by generate() to crop the conditioning window.
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Independent output codebook (±1 like embedding, but not tied).
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Starts at 1/sqrt(d_model) so initial logits are at roughly unit scale.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids; the embedding adds a trailing ``d_model``
                axis and the blocks expect the result to be (B, T, d_model).
            targets: optional integer token ids with as many elements as
                ``idx``. When given, the mean cross-entropy is returned.

        Returns:
            ``(logits, loss)`` where ``logits`` has a trailing ``vocab_size``
            axis and ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())  # popcount similarity in [-D, D]
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices of a
            # larger batch tensor, are accepted as well.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Args:
            idx: (B, T) integer prompt ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this; values below 1e-5 are
                raised to 1e-5 to avoid division by zero.
            top_k: if given, sample only from the ``top_k`` highest-scoring
                tokens. Values larger than the vocabulary are clamped to the
                vocabulary size (i.e. no filtering) instead of raising.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        Note:
            Switches the module to eval mode and leaves it there.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Condition on at most the last max_seq_len tokens.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size.
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                logits[logits < kth_best] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
def param_count(m):
    """Return the total number of scalar elements across ``m.parameters()``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
if __name__ == '__main__':
    # Smoke test: build the default model, report its size, then run one
    # forward and backward pass on random token ids.
    model = BitLM()
    n_params = param_count(model)
    print(f"total params: {n_params:,}")
    batch_shape = (2, 64)
    inputs = torch.randint(0, 128, batch_shape)
    labels = torch.randint(0, 128, batch_shape)
    logits, loss = model(inputs, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")
|