"""v5: combines sprint Track A+C top-EV bets.
- **A5 Hadamard rotation** before Q/K/V: rotate activations by a fixed ±1 Hadamard
matrix (fast Walsh-Hadamard transform). Outlier-reducing, natively ±1 (Hadamard is
a sign matrix), cheap at forward since FWHT needs only O(d log d) additions/subtractions.
- **A1 learnable integer τ** for the bool-threshold attention: τ is a float shadow
that is round-STE'd to the nearest integer in forward. Keeps the "all forward
arithmetic is integer/±1" invariant while letting τ move continuously under grad.
- **C2 5-way parallel residual**: y = sign(x + attn(x) + ffn(x) + pos_bias_A + pos_bias_B)
where pos_bias_A/B are per-layer learned ±1 position-independent channel bias
vectors (sign-STE of small float shadows). 5 = odd ⇒ no sum-to-zero ties.
- **D3 Hamming output head (implicit)**: we already use popcount similarity as the
logit; keep it unchanged.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import (
sign_ste, sign_ste_clipped, BitLinearRaw, BitLinear, BitFFN, BinaryEmbedding,
)
def int_ste(x):
    """Round to the nearest integer with a straight-through gradient.

    The forward value is ``x + (round(x) - x)``, i.e. ``round(x)`` up to
    floating-point rounding. The rounding residual is detached from the
    graph, so the gradient with respect to ``x`` is the identity.
    """
    residual = (torch.round(x) - x).detach()
    return x + residual
def hadamard_transform(x):
    """Unnormalized fast Walsh-Hadamard transform along the last dimension.

    Returns a new tensor with the same shape as ``x``; the butterfly stages
    build fresh tensors rather than writing into the input. The last dimension
    must be a power of two (checked with an assertion).

    No 1/sqrt(d) factor is applied, so applying the transform twice yields
    ``d * x``. Callers are expected to fold the normalization into a
    downstream scale.
    """
    d = x.shape[-1]
    assert (d & (d - 1)) == 0, f"d must be power of 2, got {d}"
    original_shape = x.shape
    flat = x.reshape(-1, d).contiguous()
    # One butterfly stage per bit of d: the half-block width doubles each stage.
    for level in range(d.bit_length() - 1):
        half = 1 << level
        blocks = flat.view(-1, d // (2 * half), 2, half)
        upper, lower = blocks.unbind(dim=2)
        # Sum goes to the first half of each block, difference to the second.
        flat = torch.stack((upper + lower, upper - lower), dim=2).view(-1, d)
    return flat.view(original_shape)
class BiAttentionV5(nn.Module):
    """Hadamard-rotated causal attention with a learnable integer threshold τ.

    Attention weights are binarized with ``sign_ste_clipped(scores - τ)`` for
    keys at or before the query position. Keys after the query position get a
    weight of exactly 0, so they contribute nothing to the output.
    """

    def __init__(self, d_model, n_heads):
        """Build the four projections, the per-head τ shadow and ALiBi slopes.

        Args:
            d_model: model width; must be divisible by ``n_heads``. ``forward``
                additionally requires it to be a power of two because of the
                Hadamard transform.
            n_heads: number of attention heads.
        """
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)
        # A1: float shadow for τ (one per head), rounded to an int in forward.
        self.attn_threshold_shadow = nn.Parameter(torch.zeros(n_heads))
        # Per-head distance penalty slopes: 2^-2, 2^-1, 2^0, ... (one per head).
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily built causal mask cache; not saved in the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        The mask is rebuilt when the cached one is smaller than ``T`` or lives
        on a different device; otherwise the cached mask is sliced.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Apply binarized causal attention to ``x`` of shape (B, T, D).

        Returns ``o_proj`` applied to the concatenated head outputs
        (presumably (B, T, D) — depends on ``BitLinear``, defined elsewhere).
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # A5: Hadamard rotation BEFORE the projections. The transform is
        # unnormalized, so divide by sqrt(D) to keep the activation scale.
        x_rot = hadamard_transform(x) / math.sqrt(D)
        Q = self.q_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x_rot).view(B, T, H, Dh).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)
        # Distance penalty: |i - j| scaled by the per-head slope.
        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias
        mask = self._get_mask(T, x.device)
        scores = scores.masked_fill(mask, -1e9)
        # A1: integer τ (rounded shadow), per head.
        tau_int = int_ste(self.attn_threshold_shadow).view(1, H, 1, 1)
        A = sign_ste_clipped(scores - tau_int)
        # Future keys must get weight 0, not -1. With -1, position i would
        # receive -V[j] for every j > i through A @ V, leaking future tokens
        # into a causal model. Zeroing simply drops those terms from the sum.
        A = A.masked_fill(mask, 0.0)
        O = torch.matmul(A, V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitBlockV5(nn.Module):
    """C2 parallel residual block with five summed terms.

    The output is ``sign_ste(x + attn(x) + ffn(x) + bias_a + bias_b)``, where
    ``bias_a`` and ``bias_b`` are per-channel vectors obtained by passing small
    float shadows through ``sign_ste`` and broadcasting over batch and
    sequence axes.
    """

    def __init__(self, d_model, n_heads, d_ff):
        """Create the attention branch, the FFN branch and both bias shadows."""
        super().__init__()
        self.attn = BiAttentionV5(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)
        # Float shadows for the two per-channel biases (binarized in forward).
        self.bias_a = nn.Parameter(0.02 * torch.randn(d_model))
        self.bias_b = nn.Parameter(0.02 * torch.randn(d_model))

    def forward(self, x):
        """Sum the residual, both branches and both biases, then binarize."""
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        channel_biases = [
            sign_ste(shadow).view(1, 1, -1)
            for shadow in (self.bias_a, self.bias_b)
        ]
        # Accumulate left to right: x, attention, FFN, bias_a, bias_b.
        total = x
        for term in (attn_out, ffn_out, *channel_biases):
            total = total + term
        return sign_ste(total)
class BitLMv5(nn.Module):
    """Language model built from a stack of ``BitBlockV5`` layers.

    The output logits are the dot product of the final hidden state with a
    sign-binarized output codebook, multiplied by a learned scalar and shifted
    by a learned per-token bias.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        """Build the embedding, the blocks and the output head.

        Args:
            vocab_size: number of token ids.
            d_model: model width; must be a power of two (Hadamard transform).
            n_layers: number of ``BitBlockV5`` layers.
            n_heads: attention heads per layer.
            d_ff: hidden width passed to each block's FFN.
            max_seq_len: context length used to crop the prompt in ``generate``.
        """
        super().__init__()
        assert (d_model & (d_model - 1)) == 0, "v5 requires d_model power of 2 for Hadamard"
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV5(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Float shadow of the output codebook; binarized with sign_ste in forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx``.

        ``loss`` is the cross-entropy against ``targets`` (flattened over all
        positions), or ``None`` when ``targets`` is not given.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Puts the module in eval mode (it is not switched back afterwards).

        Args:
            idx: prompt token ids; new tokens are concatenated along dim 1.
            max_new_tokens: number of sampling steps.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, only the ``top_k`` highest logits stay eligible.
                Values larger than the vocabulary are clamped to its size.

        Returns:
            The prompt followed by the sampled tokens.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the last max_seq_len tokens are fed to the model.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size, so clamp.
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_best, -float('inf'))
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
if __name__ == '__main__':
    # Smoke test: build the default model, run one forward/backward pass.
    model = BitLMv5()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"v5 params: {n_params:,} ({n_params/1e6:.2f}M)")
    tokens = torch.randint(0, 128, (2, 64))
    labels = torch.randint(0, 128, (2, 64))
    logits, loss = model(tokens, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")
    # Sanity check: the unnormalized transform applied twice gives d * input.
    width = 256
    sample = torch.randn(3, width)
    twice = hadamard_transform(hadamard_transform(sample))
    assert torch.allclose(twice, width * sample, atol=1e-4), "Hadamard self-inverse check failed"
    print("Hadamard self-inverse ok")
|