File size: 6,482 Bytes
"""v40: BitProto — attention augmented with learnable ±1 prototype keys/values.
v39 (separate Hopfield head in parallel with attention) was plateauing around
1.98 BPC: the extra head competes with real attention for residual bandwidth
and steals d_ff budget. v40 integrates prototypes *inside* the existing
attention mechanism:
    K_ext = [K_from_x | K_proto]   (T + M columns per head)
    V_ext = [V_from_x | V_proto]
    A     = Gumbel-argmax over (T + M) options
    O     = A @ V_ext
No separate head, no extra residual summand. Prototypes live per-head,
per-layer. They're non-causal (always visible). ALiBi bias is zero for
prototype columns.
Everything remains strictly ±1 on the forward path. Prototypes are stored as
latent floats and binarized with sign() in the forward pass. This adds only
2·n_proto·d_model parameters per layer (16K for n_proto=32, d_model=256).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, BitLinear, BitFFN, BinaryEmbedding
from model_v16 import gumbel_hard_attention
class IntBinaryAttentionWithProto(nn.Module):
    """Multi-head binary attention whose K/V set is extended by prototypes.

    Every head owns ``n_proto`` learnable key/value prototype vectors of size
    ``head_dim``. They are stored as latent floats, binarized with
    ``sign_ste`` on each forward pass, and appended after the ``T``
    sequence-derived keys/values, so each query chooses among ``T + n_proto``
    candidates. Prototype columns receive no ALiBi penalty and are never
    flagged in the causal mask.
    """

    def __init__(self, d_model, n_heads, n_proto=32):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.n_proto = n_proto

        # Submodules are registered in q, k, v, o order.
        for proj_name in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            setattr(self, proj_name, BitLinear(d_model, d_model))

        # Integer ALiBi slopes: head h uses slope 2**h.
        head_slopes = torch.tensor(
            [2 ** h for h in range(n_heads)], dtype=torch.long
        )
        self.register_buffer('alibi_slopes_int', head_slopes)

        # Latent float prototypes laid out as (M, H, Dh); keys are drawn
        # before values.
        proto_shape = (n_proto, n_heads, self.head_dim)
        self.key_proto = nn.Parameter(torch.randn(*proto_shape) * 0.02)
        self.val_proto = nn.Parameter(torch.randn(*proto_shape) * 0.02)

    def forward(self, x):
        """Attend over the sequence plus the prototypes.

        ``x`` is unpacked as ``(B, T, D)``; the result of ``o_proj`` applied
        to the re-merged heads is returned.
        """
        B, T, D = x.shape
        H, Dh, M = self.n_heads, self.head_dim, self.n_proto

        def to_heads(projected):
            # (B, T, D) -> (B, H, T, Dh)
            return projected.view(B, T, H, Dh).transpose(1, 2)

        def tiled_proto(latent):
            # (M, H, Dh) latent floats -> binarized, batch-expanded (B, H, M, Dh)
            binarized = sign_ste(latent).permute(1, 0, 2)
            return binarized.unsqueeze(0).expand(B, H, M, Dh)

        queries = to_heads(self.q_proj(x))
        seq_keys = to_heads(self.k_proj(x))
        seq_values = to_heads(self.v_proj(x))

        # Prototypes are appended after the T sequence positions.
        keys = torch.cat([seq_keys, tiled_proto(self.key_proto)], dim=2)
        values = torch.cat([seq_values, tiled_proto(self.val_proto)], dim=2)

        scores = queries @ keys.transpose(-2, -1)  # (B, H, T, T+M)

        # offset[i, j] = j - i for query row i and sequence key column j.
        positions = torch.arange(T, device=x.device)
        offset = positions.unsqueeze(0) - positions.unsqueeze(1)

        # ALiBi: distance penalty on the sequence columns; the appended zero
        # distances leave prototype columns unpenalized.
        distance = torch.cat([offset.abs(), offset.new_zeros(T, M)], dim=-1)
        slopes = self.alibi_slopes_int.view(1, H, 1, 1).to(scores.dtype)
        scores = scores - slopes * distance.to(scores.dtype)

        # Boolean mask, True strictly above the diagonal (future positions)
        # and False on every prototype column.
        # NOTE(review): presumably True means "blocked" inside
        # gumbel_hard_attention -- confirm against model_v16.
        future = offset > 0
        mask = torch.cat([future, future.new_zeros(T, M)], dim=-1)  # (T, T+M)

        selection = gumbel_hard_attention(scores, mask=mask)  # (B, H, T, T+M)
        merged = (selection @ values).transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(merged)
class BitBlockV40(nn.Module):
    """One v40 layer: prototype-augmented attention and an FFN in parallel.

    Both branches read the same input ``x``; their outputs are added to ``x``
    and the sum is passed through ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, n_proto):
        super().__init__()
        self.attn = IntBinaryAttentionWithProto(d_model, n_heads, n_proto=n_proto)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        """Return ``sign_ste(x + attn(x) + ffn(x))``."""
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        # Same left-to-right summation order: (x + attn) + ffn.
        residual = x + attn_out
        residual = residual + ffn_out
        return sign_ste(residual)
class BitLMv40(nn.Module):
    """Language model built from a stack of ``BitBlockV40`` layers.

    Token ids are embedded with ``BinaryEmbedding``, passed through
    ``n_layers`` blocks, and scored against a ``sign_ste``-binarized output
    codebook. The scores are multiplied by a learnable ``logit_scale``
    (initialised to ``1 / sqrt(d_model)``) and shifted by a learnable
    ``out_bias`` (initialised to zeros).

    Args:
        vocab_size: Number of output classes / rows in the output codebook.
        d_model: Width of the model.
        n_layers: Number of ``BitBlockV40`` layers.
        n_heads: Attention heads per layer.
        d_ff: Hidden width passed to each block's ``BitFFN``.
        n_proto: Prototype key/value pairs per head in each layer.
        max_seq_len: Context window used by ``generate`` when cropping the
            running sequence.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8,
                 d_ff=444, n_proto=32, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.n_proto = n_proto
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV40(d_model, n_heads, d_ff, n_proto) for _ in range(n_layers)
        ])
        # Latent float codebook; binarized with sign_ste on every forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: Token ids fed to the embedding; the callers in this file pass
                a ``(batch, seq)`` integer tensor.
            targets: Optional target ids with one entry per logit row.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. a sliced
            # ``batch[:, 1:]``, are accepted instead of raising.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Switches the module to eval mode and leaves it there.

        Args:
            idx: ``(batch, seq)`` tensor of starting token ids.
            max_new_tokens: Number of sampling steps.
            temperature: Divisor applied to the last-step logits; values below
                ``1e-5`` are clamped to ``1e-5``.
            top_k: If given, only the ``top_k`` highest logits stay eligible.
                Values larger than the vocabulary are clamped to the
                vocabulary size.

        Returns:
            ``idx`` with the sampled tokens concatenated along dimension 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens are fed to the model.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises when k exceeds the last dimension.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
if __name__ == '__main__':
    # Smoke test: print parameter counts for a few FFN widths, then run one
    # forward/backward pass on random tokens.
    from model_v16 import set_gumbel_tau

    set_gumbel_tau(0.5)

    for d_ff in (440, 444, 448):
        candidate = BitLMv40(d_ff=d_ff)
        n = sum(map(torch.numel, candidate.parameters()))
        print(f'd_ff={d_ff}: {n:,} ({n/1e6:.3f}M)')

    model = BitLMv40()
    batch_shape = (2, 64)
    inputs = torch.randint(0, 128, batch_shape)
    labels = torch.randint(0, 128, batch_shape)
    _, loss = model(inputs, labels)
    loss.backward()
    print(f'loss={loss.item():.3f}, backward OK')
|