# Source: bitnet-1bitllm / vm_backup/code/model_v5.py
# Uploaded by hidude562 — commit 4754707 (verified): "1bitllm code (checkpoints to follow)"
"""v5: combines sprint Track A+C top-EV bets.
- **A5 Hadamard rotation** before Q/K/V: rotate activations by a fixed ±1 Hadamard
matrix (fast Walsh-Hadamard transform). Outlier-reducing, natively ±1 (Hadamard is
a sign matrix), cost-free at forward since FWHT is O(d log d) with ±1 ops.
- **A1 learnable integer τ** for the bool-threshold attention: τ is a float shadow
that is round-STE'd to the nearest integer in forward. Keeps the "all forward
arithmetic is integer/±1" invariant while letting τ move continuously under grad.
- **C2 5-way parallel residual**: y = sign(x + attn(x) + ffn(x) + bias_A + bias_B)
where bias_A/B are per-layer learned ±1 channel bias vectors, shared across all
positions (sign-STE of small float shadows). 5 = odd ⇒ no sum-to-zero ties when
every term is ±1.
- **D3 Hamming output head (implicit)**: we already use popcount similarity as the
logit; keep it unchanged.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import (
sign_ste, sign_ste_clipped, BitLinearRaw, BitLinear, BitFFN, BinaryEmbedding,
)
def int_ste(x):
    """Straight-through integer rounding.

    The forward value is ``torch.round(x)``. In the backward pass the op behaves
    like the identity, so gradients reach ``x`` unchanged. This works because the
    rounding correction is detached and contributes no gradient.
    """
    correction = (torch.round(x) - x).detach()
    return x + correction
def hadamard_transform(x):
    """Fast Walsh-Hadamard transform along the last dimension.

    Returns a new tensor; the input is not modified. The transform is
    unnormalized, so applying it twice yields d·x (H @ H = d·I). Callers that
    want unit scale divide by sqrt(d) themselves. That is a fixed scalar
    constant, like BitLinear's 1/sqrt(in) normalization.

    Args:
        x: tensor of shape (..., d) where d is a positive power of 2. Leading
            dimensions may be empty (zero rows).

    Returns:
        Tensor with the same shape as ``x``.
    """
    d = x.shape[-1]
    # d == 0 would pass the bit trick (0 & -1 == 0) and then fail in reshape,
    # so require d > 0 explicitly.
    assert d > 0 and (d & (d - 1)) == 0, f"d must be power of 2, got {d}"
    shape = x.shape
    # Flatten leading dims. An explicit row count (instead of -1) keeps the
    # reshapes unambiguous when the batch is empty.
    rows = x.numel() // d
    y = x.reshape(rows, d).contiguous()
    h = 1
    while h < d:
        # Butterfly stage: inside each block of 2h elements, pair index k with k + h.
        y = y.view(rows, d // (2 * h), 2, h)
        a = y[:, :, 0, :]
        b = y[:, :, 1, :]
        y = torch.stack([a + b, a - b], dim=2).view(rows, d)
        h *= 2
    return y.view(shape)
class BiAttentionV5(nn.Module):
    """Hadamard-rotated, learnable-integer-τ causal attention with ±1 weights.

    Forward pipeline:
      1. Rotate the input with an unnormalized Walsh-Hadamard transform and divide
         by sqrt(d_model).
      2. Project to Q/K/V with ``BitLinear(..., binarize_input=True)``.
      3. Compute scores = Q·Kᵀ / sqrt(head_dim) minus a per-head distance penalty
         (ALiBi-style, slope * |i - j| / sqrt(head_dim)).
      4. Subtract a per-head integer τ (round-STE of a float shadow) and binarize
         with ``sign_ste_clipped``. Visible keys therefore get weight ±1.
      5. Give future (masked) keys weight 0 so they contribute nothing. Then
         compute O = A @ V and project with ``o_proj``.

    Args:
        d_model: model width; must be divisible by ``n_heads``. The Hadamard step
            also requires it to be a power of 2.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)
        # A1: float shadow for τ, rounded to int in forward.
        self.attn_threshold_shadow = nn.Parameter(torch.zeros(n_heads))
        # Per-head distance-penalty slopes 2^(h-2): 0.25, 0.5, 1, 2, ...
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily built causal-mask cache; excluded from state_dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        True marks key positions j > i (future tokens). The cached mask is rebuilt
        only when a longer T or a different device is requested. Shorter requests
        reuse its top-left slice.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Apply causal binary attention.

        Args:
            x: (B, T, D) tensor with D == d_model.

        Returns:
            ``o_proj`` applied to the (B, T, D) attention output. Presumably this is
            (B, T, d_model); it depends on BitLinear — confirm.
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # A5: Hadamard rotation BEFORE sign-binarize in projection.
        # Apply along last dim of x. Scale by 1/sqrt(d) to keep unit variance,
        # otherwise the rotated values scale up and BitLinear's own 1/sqrt(d) can't compensate.
        x_rot = hadamard_transform(x) / math.sqrt(D)
        Q = self.q_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)
        # Distance penalty |i - j| scaled per head.
        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias
        mask = self._get_mask(T, x.device)
        scores = scores.masked_fill(mask, -1e9)
        # A1: integer τ (rounded shadow), per head.
        tau_int = int_ste(self.attn_threshold_shadow).view(1, H, 1, 1)
        A = sign_ste_clipped(scores - tau_int)
        # Future keys must contribute nothing to O. A weight of -1 (what sign()
        # yields for the -1e9 scores) would add -V_j for every future token j.
        # That leaks future information into position i during training.
        A = A.masked_fill(mask, 0.0)
        O = torch.matmul(A, V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitBlockV5(nn.Module):
    """One v5 layer: a 5-way parallel residual, binarized with sign-STE.

    forward(x) = sign(x + attn(x) + ffn(x) + sign(bias_a) + sign(bias_b))

    Attention and FFN both read the same input ``x`` (parallel, not stacked).
    ``bias_a`` and ``bias_b`` are per-layer float shadows of length d_model. Their
    signs act as ±1 per-channel biases, broadcast over batch and sequence.

    NOTE(review): the "5 odd terms ⇒ no sum-to-zero tie" rationale holds only if
    attn(x) and ffn(x) are ±1-valued. That depends on BitLinear/BitFFN from
    ``model``; confirm.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = BiAttentionV5(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)
        # Float shadows for the two ±1 channel biases (sign-STE'd in forward).
        self.bias_a = nn.Parameter(torch.randn(d_model) * 0.02)
        self.bias_b = nn.Parameter(torch.randn(d_model) * 0.02)

    def forward(self, x):
        # Accumulate left-to-right, in the same order as the 5-term formula.
        total = x + self.attn(x)
        total = total + self.ffn(x)
        for shadow in (self.bias_a, self.bias_b):
            total = total + sign_ste(shadow).view(1, 1, -1)
        return sign_ste(total)
class BitLMv5(nn.Module):
    """Binary language model v5.

    Architecture: BinaryEmbedding → n_layers × BitBlockV5 → output head.
    The head computes logits = (x · sign(out_codebook)ᵀ) * logit_scale + out_bias.
    With ±1 block outputs, that dot product is an affine function of Hamming
    distance.

    Args:
        vocab_size: number of token ids.
        d_model: model width; must be a positive power of 2 (Hadamard rotation).
        n_layers: number of BitBlockV5 layers.
        n_heads: attention heads per layer.
        d_ff: FFN hidden width.
        max_seq_len: context window used to crop the prompt in ``generate``.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        assert d_model > 0 and (d_model & (d_model - 1)) == 0, "v5 requires d_model power of 2 for Hadamard"
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV5(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Float shadow of the ±1 output codebook (sign-STE'd in forward).
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the mean cross-entropy loss.

        Args:
            idx: (B, T) token ids.
            targets: optional (B, T) target token ids.

        Returns:
            ``(logits, loss)``: logits are (B, T, vocab_size); loss is None when
            ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices, also work.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample tokens and append them to ``idx``.

        Args:
            idx: (B, T) context token ids.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by max(temperature, 1e-5).
            top_k: if given, sample only among the k highest logits. It is clamped
                to vocab_size.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.

        The model runs in eval mode while sampling. Its previous train/eval mode is
        restored afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the last max_seq_len tokens.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises when k exceeds the dimension size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx
if __name__ == '__main__':
    # Smoke test: build the default model and run one forward/backward pass on
    # random token ids.
    model = BitLMv5()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"v5 params: {n_params:,} ({n_params/1e6:.2f}M)")
    tokens = torch.randint(0, 128, (2, 64))
    labels = torch.randint(0, 128, (2, 64))
    logits, loss = model(tokens, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")

    # Hadamard sanity check: the unnormalized transform applied twice scales by d.
    d = 256
    vec = torch.randn(3, d)
    twice = hadamard_transform(hadamard_transform(vec))  # should be d·x
    assert torch.allclose(twice, d * vec, atol=1e-4), "Hadamard self-inverse check failed"
    print("Hadamard self-inverse ok")