File size: 6,161 Bytes
"""v11: top-k binary attention (ternary {-1,0,+1} attention matrix).
Issue 1 isolation test. Identical to the v3 architecture except that the attention
matrix A selects only the top-k keys per query (ternary: {-1,0,+1}, where 0 = ignore).
Hypothesis: if the POC plateau at 3.20 BPC is caused by binary attention's
inability to express selective sparsity, then restoring sparse selection via
top-k should close most of the v3→v4 gap (0.48 BPC) while adding only the
minimal concession of ternary attention weights (per-position A ∈ {−1, 0, +1}).
Everything else (weights, Q/K/V/O projections, FFN, residuals, embeddings)
stays strict ±1.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
class TopKBinaryAttention(nn.Module):
    """Multi-head attention whose attention matrix is ternary {-1, 0, +1}.

    Q/K/V/O are ``BitLinear`` projections with ``binarize_input=True``. Per
    head, scores are ``Q @ K^T / sqrt(head_dim)`` minus an ALiBi distance
    penalty, with future keys causally masked. For every query only the
    ``topk`` highest-scoring keys are kept: a kept entry becomes ``+1`` when its
    score is >= 0 and ``-1`` otherwise, and every other entry is 0. The forward
    pass uses this ternary matrix. The backward pass is a straight-through
    estimator that passes the gradient to the masked float scores unchanged.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        topk: maximum number of keys kept per query (clamped to the sequence
            length at run time).
    """

    def __init__(self, d_model, n_heads, topk=8):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.topk = topk
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)
        # Per-head ALiBi slopes 2**(h - 2), i.e. 0.25, 0.5, 1, 2, ... growing with h.
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily (re)built causal-mask cache; excluded from the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        True marks a future key that a query must not attend to. The cached
        mask is rebuilt when it is too small or lives on another device.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Apply ternary top-k causal attention.

        Args:
            x: input of shape (B, T, d_model) (the ``view`` calls below fix this).

        Returns:
            ``o_proj`` applied to the attended values, shape (B, T, d_model).
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, H, T, T)
        scores_f = scores / math.sqrt(Dh)

        # ALiBi: subtract slope_h * |i - j| / sqrt(Dh) from every (query i, key j) score.
        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores_f = scores_f - alibi_bias

        mask = self._get_mask(T, x.device)
        # Use the dtype's most negative finite value instead of a literal -1e9,
        # which cannot be represented in half precision. In the STE line below,
        # min + (0 - min) is still exactly 0 at masked positions.
        scores_f = scores_f.masked_fill(mask, torch.finfo(scores_f.dtype).min)

        # Per-query top-k selection. k is clamped only to T, not to the number
        # of valid keys: a query i with i + 1 < k valid keys also selects masked
        # (future) keys, and those are zeroed explicitly below.
        k = min(self.topk, T)
        _, topk_idx = torch.topk(scores_f, k=k, dim=-1)  # (B, H, T, k)
        keep = torch.zeros_like(scores_f).scatter_(-1, topk_idx, 1.0)

        # Ternary attention: sign(score) (zero counts as +1) on kept entries, 0 elsewhere.
        sign_scores = torch.where(scores_f >= 0, torch.ones_like(scores_f), -torch.ones_like(scores_f))
        A_ternary = (sign_scores * keep).masked_fill(mask, 0.0)

        # Straight-through estimator: the forward value is A_ternary. The
        # gradient w.r.t. scores_f is the identity; masked positions get none
        # because masked_fill's backward zeroes them.
        A = scores_f + (A_ternary - scores_f).detach()

        # The float32 ALiBi term (pos is .float()) can leave A in a different
        # dtype than V, and matmul does not promote dtypes. Cast A to V.dtype;
        # its values {-1, 0, +1} are exact in any float dtype.
        O = torch.matmul(A.to(V.dtype), V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitBlockV11(nn.Module):
    """A single v11 block: ternary top-k attention and a BitFFN in parallel.

    Both sublayers read the same input. Their outputs are summed with the
    input and the result is re-binarized with ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, topk=8):
        super().__init__()
        self.attn = TopKBinaryAttention(d_model, n_heads, topk=topk)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        # Parallel residual: attention is evaluated first, then the FFN, both on x.
        pre_sign = x + self.attn(x) + self.ffn(x)
        return sign_ste(pre_sign)
class BitLMv11(nn.Module):
    """Binary language model (v11) built from ``BitBlockV11`` layers.

    Token ids are embedded with ``BinaryEmbedding`` and passed through
    ``n_layers`` blocks that use ternary top-k attention. The result is scored
    against the binarized output codebook ``sign_ste(out_codebook)``, then
    multiplied by a learned scalar ``logit_scale`` and shifted by a per-token
    ``out_bias``. No positional embedding is added here; the attention layers
    apply an ALiBi distance bias.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        n_layers: number of blocks.
        n_heads: attention heads per block.
        d_ff: hidden width passed to ``BitFFN``.
        max_seq_len: context length that ``generate`` crops its input to.
        topk: keys kept per query in every attention layer.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256, topk=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV11(d_model, n_heads, d_ff, topk=topk) for _ in range(n_layers)
        ])
        # Latent float codebook; only its sign is used in forward().
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: integer token ids, presumably (B, T) (``generate`` slices them as 2-D).
            targets: optional token ids with as many elements as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has a trailing ``vocab_size``
            dimension. ``loss`` is the mean cross-entropy, or ``None`` when
            ``targets`` is omitted.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)  # binarized output codebook
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. sliced batches, work.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens appended to ``idx``.

        Args:
            idx: (B, T) token ids used as the prompt.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, sample only from the ``top_k`` most likely tokens
                (clamped to the vocabulary size).

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.

        The model runs in eval mode while sampling, and its previous
        train/eval mode is restored afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last max_seq_len tokens are fed back as context.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _unused_loss = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the number of candidates.
                    k = min(top_k, logits.size(-1))
                    v, _unused_idx = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx
if __name__ == '__main__':
    # Smoke test: build the default-size model, then run a single forward and
    # backward pass on random token ids to check shapes and autograd wiring.
    m = BitLMv11(topk=8)
    n = sum(map(torch.numel, m.parameters()))
    print(f"v11 params: {n:,} ({n/1e6:.2f}M)")
    batch_shape = (2, 64)
    x, y = (torch.randint(0, 128, batch_shape) for _ in range(2))
    logits, loss = m(x, y)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")
|