File size: 6,856 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""v16: Gumbel hard-attention. Each query attends to exactly ONE key, selected via
Gumbel-softmax with temperature annealing from soft → hard.

Why this might work where v11 top-k failed: v11's STE through top-k gave gradients
that pushed scores up/down, but the discrete selection didn't move easily.
Gumbel-softmax gives a proper continuous-to-discrete bridge. The forward pass is
always one-hot (straight-through); the temperature only shapes the backward pass.
At high temperature the gradient is that of a smooth softmax spread over many
positions; at low temperature it concentrates on the selected position. Training
anneals tau high → low.

At eval (no gradient tracking): pure argmax, no noise. Each query attends to
exactly one position (attention as pointer), so the attention matrix is binary
{0, 1}: one 1 per row, rest 0s. A ternary {-1, 0, +1} variant with a per-row sign
flip carried via a separate bit is not implemented in this file.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding


# Shared Gumbel temperature, held in a one-element tensor rather than a Python float
# so torch.compile doesn't retrace every step when tau is annealed. It is created on
# the default device (normally CPU); _get_tau moves it next to the attention scores
# on first use, and set_gumbel_tau overwrites its value in place.
_GUMBEL_TAU = torch.ones(1)


def set_gumbel_tau(tau: float):
    """Set the Gumbel-softmax temperature used by the attention layers.

    The value is written into the existing shared tensor with ``fill_`` instead of
    rebinding the module global, so the tensor keeps its object identity and
    torch.compile does not see a new constant. Values below 0.05 are stored as
    given; the floor is applied when the temperature is read.
    """
    # No `global` needed: the name is only read here, and the tensor is mutated in place.
    _GUMBEL_TAU.fill_(float(tau))


def _get_tau(device):
    """Return the current Gumbel temperature on ``device``, floored at 0.05.

    If the shared tensor lives on a different device, it is copied there once and
    the module global is rebound to the copy, so later reads on the same device
    skip the transfer. The returned value is a fresh clamped tensor; the shared
    tensor itself is left unclamped.
    """
    global _GUMBEL_TAU
    shared = _GUMBEL_TAU
    if shared.device != device:
        shared = shared.to(device)
        _GUMBEL_TAU = shared
    return shared.clamp(min=0.05)


def gumbel_hard_attention(scores, mask=None, sample=None):
    """Turn each row of attention scores into a one-hot selection of a single key.

    Args:
        scores: attention logits; the attention module in this file passes
            (B, H, T, T). Selection happens along the last dimension.
        mask: optional bool tensor broadcastable to ``scores`` (e.g. a (T, T)
            causal mask) with True for positions that must never be selected.
        sample: if True, add Gumbel noise and return a straight-through one-hot
            sample whose gradient is that of ``softmax((scores + g) / tau)``.
            If False, return the deterministic argmax one-hot (no noise).
            Defaults to ``scores.requires_grad``, i.e. noise is sampled only while
            autograd is recording (the original behaviour).

    Returns:
        Tensor shaped like ``scores`` with exactly one 1.0 per row along the last
        dim and zeros elsewhere. Forward values are one-hot in both modes; only
        the sampling path carries a (soft) gradient.
    """
    if mask is not None:
        # -1e9 does not fit in float16 and masked_fill raises on overflow, so
        # clamp the fill to the dtype's range (still -1e9 for float32/bfloat16).
        fill = max(-1e9, torch.finfo(scores.dtype).min)
        scores = scores.masked_fill(mask, fill)
    if sample is None:
        sample = scores.requires_grad
    if not sample:
        # Deterministic pointer: argmax of the (masked) scores, no noise.
        y_hard = torch.zeros_like(scores)
        y_hard.scatter_(-1, scores.argmax(-1, keepdim=True), 1.0)
        return y_hard
    tau = _get_tau(scores.device)
    # Gumbel(0, 1) noise; the clamp and +1e-9 keep both logs finite.
    g = -torch.log(-torch.log(torch.rand_like(scores).clamp(min=1e-9)) + 1e-9)
    y_soft = F.softmax((scores + g) / tau, dim=-1)
    y_hard = torch.zeros_like(y_soft)
    y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0)
    # Straight-through: forward value is y_hard, gradient flows through y_soft.
    return y_soft + (y_hard - y_soft).detach()


class GumbelHardAttention(nn.Module):
    """Causal multi-head self-attention in which every query reads exactly one value.

    The Q/K/V/output projections are BitLinear layers. Attention scores are scaled
    dot products minus an ALiBi-style distance penalty
    ``slope_h * |i - j| / sqrt(head_dim)`` with per-head slopes ``2 ** (h - 2)``
    (0.25, 0.5, 1, 2, ...). Future positions are masked out and
    ``gumbel_hard_attention`` converts each score row into a one-hot selection.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Keep the q, k, v, o construction order: it fixes parameter registration
        # order (and the sequence of any random initialisation inside BitLinear).
        self.q_proj = BitLinear(d_model, d_model)
        self.k_proj = BitLinear(d_model, d_model)
        self.v_proj = BitLinear(d_model, d_model)
        self.o_proj = BitLinear(d_model, d_model)
        # 0.25 * 2**h == 2**(h - 2) exactly, since both are powers of two.
        self.register_buffer('alibi_slopes', torch.tensor([0.25 * 2.0 ** h for h in range(n_heads)]))
        # Lazily built causal-mask cache; non-persistent, so it is not saved in checkpoints.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal.

        The cached buffer is rebuilt only when it is smaller than T or lives on a
        different device; otherwise its top-left (T, T) corner is returned.
        """
        cached = self._causal_mask
        if cached.device != device or cached.shape[-1] < T:
            cached = torch.ones(T, T, device=device, dtype=torch.bool).triu(diagonal=1)
            self._causal_mask = cached
        return cached[:T, :T]

    def forward(self, x):
        batch, seq_len, width = x.shape
        heads, head_dim = self.n_heads, self.head_dim

        def to_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, heads, head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))
        scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)

        # ALiBi-style penalty: absolute distance |i - j| times the head's slope,
        # scaled by the same 1/sqrt(head_dim) factor as the dot products.
        positions = torch.arange(seq_len, device=x.device).float()
        distance = (positions[None, :] - positions[:, None]).abs()
        penalty = self.alibi_slopes.view(1, heads, 1, 1) * distance[None, None] / math.sqrt(head_dim)
        scores = scores - penalty

        # One-hot row selection. The forward values are one-hot whether or not
        # Gumbel noise is sampled; the straight-through trick only changes gradients.
        attn = gumbel_hard_attention(scores, mask=self._get_mask(seq_len, x.device))
        # Each output row therefore copies a single value vector.
        # NOTE(review): an earlier comment claimed V is already ±1 by construction;
        # that depends on BitLinear's output and is not verifiable here.
        out = attn @ v
        return self.o_proj(out.transpose(1, 2).reshape(batch, seq_len, width))


class BitBlockV16(nn.Module):
    """One v16 layer: hard attention and a BitFFN applied in parallel.

    Both branches read the same input. Their outputs are added to that input, and
    the sum is re-binarized with ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = GumbelHardAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        # Parallel residual: the FFN sees x itself, not x + attn(x).
        mixed = x + self.attn(x)
        mixed = mixed + self.ffn(x)
        return sign_ste(mixed)


class BitLMv16(nn.Module):
    """Binary language model built from BitBlockV16 layers.

    Token ids are embedded with BinaryEmbedding and passed through ``n_layers``
    blocks. The result is scored against a sign-binarized output codebook
    (``vocab_size`` x ``d_model``), scaled by a learned scalar, plus a per-token bias.
    ``max_seq_len`` is the context window that ``generate`` crops its input to.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([BitBlockV16(d_model, n_heads, d_ff) for _ in range(n_layers)])
        # Real-valued latent codebook; forward binarizes it with sign_ste.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids, presumably (B, T) (``generate`` slices them as
                ``idx[:, -max_seq_len:]``).
            targets: optional token ids with the same number of elements as the
                logits' leading dims.

        Returns:
            ``(logits, loss)``. ``logits`` has ``vocab_size`` as its last dim;
            ``loss`` is the mean cross-entropy, or None when ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices, are accepted.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        The model runs in eval mode during sampling. Its previous train/eval mode
        is restored afterwards, even if sampling raises. ``top_k`` values larger
        than the vocabulary are clamped to the vocabulary size.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                # Floor the temperature to avoid dividing by zero.
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the size of the dimension.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx


if __name__ == '__main__':
    # Smoke test: build a default-size model, run one forward pass on random token
    # ids with random targets, then backpropagate the loss once.
    set_gumbel_tau(1.0)
    model = BitLMv16()
    n = sum(param.numel() for param in model.parameters())
    print(f"v16 params: {n:,} ({n/1e6:.2f}M)")
    tokens = torch.randint(0, 128, (2, 64))
    labels = torch.randint(0, 128, (2, 64))
    logits, loss = model(tokens, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")