File size: 6,856 Bytes
4754707 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | """v16: Gumbel hard-attention. Each query attends to exactly ONE key, selected via
Gumbel-softmax with temperature annealing from soft → hard.
Why this might work where v11 top-k failed: v11's STE through top-k gave gradient
that pushed scores up/down but the discrete selection didn't move easily. Gumbel
softmax gives a proper continuous-to-discrete bridge. At high temperature, attn
is like softmax (multiple positions active). At low temperature, attn is
one-hot (single position). Training anneals high → low.
At eval: pure argmax. Each query attends to exactly one position (attention as a
pointer), so every row of the attention matrix is one-hot: a single +1 and 0s
elsewhere. A ternary {-1, 0, +1} variant would carry an optional sign flip in a
separate bit; that flip is not implemented in this file.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
# Shared Gumbel-softmax temperature, held as a one-element float32 tensor.
# It starts on the CPU and is moved to the compute device lazily by
# _get_tau(); set_gumbel_tau() overwrites it in place so that, as intended,
# torch.compile keeps seeing the same tensor object while tau is annealed
# instead of a fresh Python constant on every step.
_GUMBEL_TAU = torch.full((1,), 1.0)
def set_gumbel_tau(tau: float):
"""Mutate the tau tensor in place — keeps the same object identity so
torch.compile doesn't see a new constant."""
global _GUMBEL_TAU
_GUMBEL_TAU.fill_(float(tau))
def _get_tau(device):
"""Return the current tau as a device-resident tensor."""
global _GUMBEL_TAU
if _GUMBEL_TAU.device != device:
_GUMBEL_TAU = _GUMBEL_TAU.to(device)
return _GUMBEL_TAU.clamp(min=0.05)
def gumbel_hard_attention(scores, mask=None):
    """Turn attention scores into a one-hot-per-row attention matrix.

    Args:
        scores: attention logits, (B, H, T, T) in this file's usage.
        mask: optional bool tensor broadcastable to ``scores`` (e.g. (T, T));
            True marks positions that must not be selected.

    Returns:
        Tensor with the shape and dtype of ``scores``. If ``scores`` requires
        grad, each row is a straight-through Gumbel-softmax sample: the forward
        value is the one-hot argmax of softmax((scores + g) / tau), with Gumbel
        noise g and tau taken from ``_get_tau``, while gradients flow through
        the soft sample. Otherwise each row is the one-hot argmax of the
        (masked) scores, with no noise.
    """
    tau = _get_tau(scores.device)
    out_dtype = scores.dtype
    # Compute in at least float32. In float16, masked_fill(-1e9) raises an
    # overflow error (-1e9 is not representable), the 1e-9 clamp below
    # underflows to 0 (so log(0) yields -inf noise), and large-magnitude
    # logits divided by a small tau overflow. For float32/float64 inputs
    # this conversion returns `scores` itself, so behavior is unchanged.
    logits = scores.to(torch.promote_types(scores.dtype, torch.float32))
    if mask is not None:
        logits = logits.masked_fill(mask, -1e9)
    if not logits.requires_grad:
        # No-grad path: deterministic argmax, no noise.
        y_hard = torch.zeros_like(logits)
        y_hard.scatter_(-1, logits.argmax(-1, keepdim=True), 1.0)
        return y_hard.to(out_dtype)
    # Gumbel(0, 1) noise via -log(-log(U)); the epsilons keep both logs finite.
    u = torch.rand_like(logits).clamp(min=1e-9)
    g = -torch.log(-torch.log(u) + 1e-9)
    y_soft = F.softmax((logits + g) / tau, dim=-1)
    y_hard = torch.zeros_like(y_soft)
    y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0)
    # Straight-through: forward value is y_hard, gradient is that of y_soft.
    return (y_soft + (y_hard - y_soft).detach()).to(out_dtype)
class GumbelHardAttention(nn.Module):
    """Causal multi-head attention in which each query reads exactly one value.

    Queries, keys, values and the output are BitLinear projections. Scores
    are scaled dot products minus a per-head linear distance penalty
    (ALiBi-style), with future positions masked out; ``gumbel_hard_attention``
    turns them into one one-hot row per query, which selects a row of V.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = BitLinear(d_model, d_model)
        self.k_proj = BitLinear(d_model, d_model)
        self.v_proj = BitLinear(d_model, d_model)
        self.o_proj = BitLinear(d_model, d_model)
        # Per-head distance-penalty slopes 2**(i - 2): 0.25, 0.5, 1, 2, ...
        # NOTE(review): unlike standard ALiBi (slopes < 1 that shrink with
        # head index), these grow with head index — confirm this is intended.
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Causal-mask cache, (re)built on demand by _get_mask. Non-persistent,
        # so it is not written to the state_dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        The cached mask is rebuilt when it is smaller than T or lives on a
        different device; otherwise its top-left (T, T) slice is returned.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Map x of shape (B, T, d_model) to an output of shape (B, T, d_model)."""
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)
        # Linear distance penalty |i - j| * slope_h, scaled like the scores.
        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores = scores - alibi_bias
        mask = self._get_mask(T, x.device)
        # gumbel_hard_attention picks its noisy (Gumbel) path vs. its argmax
        # path from scores.requires_grad. Detach in eval mode so model.eval()
        # always gets the deterministic argmax, even when autograd is enabled
        # (e.g. a validation loss computed without torch.no_grad()). The
        # argmax path passes no gradient to the scores anyway.
        A = gumbel_hard_attention(scores if self.training else scores.detach(), mask=mask)
        # A is one-hot per row in the forward pass, so A @ V copies a single
        # value row per query (gradients still flow into V).
        # NOTE(review): V is assumed to be ±1-valued already; that depends on
        # BitLinear's output, which is not visible here — confirm.
        O = torch.matmul(A, V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)
class BitBlockV16(nn.Module):
    """Parallel residual block built from hard attention and a BitFFN.

    Attention and the feed-forward network both read the block input ``x``
    (not each other's output); ``x + attn(x) + ffn(x)`` is then passed
    through ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = GumbelHardAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        # Left-to-right evaluation keeps the original call order: attn, then ffn.
        residual_sum = x + self.attn(x) + self.ffn(x)
        return sign_ste(residual_sum)
class BitLMv16(nn.Module):
    """Language model stacking ``n_layers`` BitBlockV16 blocks.

    Tokens are embedded with BinaryEmbedding, run through the blocks, and
    scored against ``sign_ste(out_codebook)``. The scores are then scaled by
    a learned ``logit_scale`` and shifted by a learned per-token bias.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([BitBlockV16(d_model, n_heads, d_ff) for _ in range(n_layers)])
        # Real-valued latent output codebook; passed through sign_ste in forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None unless ``targets`` is given.

        Args:
            idx: token ids, (B, T) per the demo at the bottom of this file.
            targets: optional token ids with the same number of elements as
                ``idx``; cross-entropy is averaged over all positions.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices of a
            # larger batch, are accepted as well.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        The context is cropped to the last ``max_seq_len`` tokens at each step.
        Sampling runs in eval mode; the caller's train/eval mode is restored
        afterwards. ``top_k`` larger than the vocabulary is clamped to it.

        Returns:
            ``idx`` with the new tokens concatenated along dim 1.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)
        return idx
if __name__ == '__main__':
    # Smoke test: build the default model, run one forward/backward pass on
    # random token ids, and report the parameter count, logits shape and loss.
    set_gumbel_tau(1.0)
    model = BitLMv16()
    n = sum(map(torch.numel, model.parameters()))
    print(f"v16 params: {n:,} ({n/1e6:.2f}M)")
    batch_shape = (2, 64)
    inputs = torch.randint(0, 128, batch_shape)
    targets = torch.randint(0, 128, batch_shape)
    logits, loss = model(inputs, targets)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")
|