File size: 6,161 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""v11: top-k binary attention (ternary {-1,0,+1} attention matrix).

Issue 1 isolation test. Identical to the v3 architecture except that the attention
matrix A keeps only the top-k positions per query (ternary: {-1,0,+1}, where 0 = ignored).

Hypothesis: if the POC plateau at 3.20 BPC is caused by binary attention's
inability to express selective sparsity, then restoring sparse selection via
top-k should close most of the v3→v4 gap (0.48 BPC) while adding only the
minimal concession of ternary attention weights (per-position A ∈ {−1, 0, +1}).
Everything else (weights, Q/K/V/O projections, FFN, residuals, embeddings)
stays strict ±1.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding


class TopKBinaryAttention(nn.Module):
    """Causal multi-head attention with a ternary {-1, 0, +1} attention matrix.

    For every query, only the ``topk`` highest-scoring causal keys are kept.
    Each kept key contributes ``sign(score)`` (+1 for score >= 0, else -1), and
    every other key contributes 0. Gradients reach the float scores through a
    straight-through estimator that is the identity for every position.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        topk: maximum number of keys each query keeps (clamped to the sequence
            length at runtime).
    """

    def __init__(self, d_model, n_heads, topk=8):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.topk = topk

        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)

        # Per-head ALiBi slopes 2**(i-2): 0.25, 0.5, 1, 2, ... (increasing with head index).
        slopes = torch.tensor([2.0 ** (i - 2) for i in range(n_heads)])
        self.register_buffer('alibi_slopes', slopes)
        # Lazily built causal-mask cache; non-persistent, so it is not saved in the state dict.
        self.register_buffer('_causal_mask', torch.empty(0), persistent=False)

    def _get_mask(self, T, device):
        """Return a (T, T) bool mask that is True strictly above the diagonal (future keys).

        The cached mask is rebuilt only when it is too small or on another device.
        """
        if self._causal_mask.shape[-1] < T or self._causal_mask.device != device:
            m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            self._causal_mask = m
        return self._causal_mask[:T, :T]

    def forward(self, x):
        """Apply top-k ternary causal self-attention.

        Args:
            x: input of shape (B, T, d_model).

        Returns:
            ``self.o_proj`` applied to the merged per-head outputs, a (B, T, d_model) tensor.
        """
        B, T, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        Q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)
        K = self.k_proj(x).view(B, T, H, Dh).transpose(1, 2)
        V = self.v_proj(x).view(B, T, H, Dh).transpose(1, 2)

        # (B, H, T, T). These are integer-valued if BitLinear emits +/-1 values — depends on BitLinear.
        scores = torch.matmul(Q, K.transpose(-2, -1))
        scores_f = scores / math.sqrt(Dh)

        pos = torch.arange(T, device=x.device).float()
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
        alibi_bias = self.alibi_slopes.view(1, H, 1, 1) * dist.view(1, 1, T, T) / math.sqrt(Dh)
        scores_f = scores_f - alibi_bias

        mask = self._get_mask(T, x.device)
        # Use the most negative finite value of the dtype, so a masked key never
        # outranks a valid key in top-k. A hard-coded -1e9 overflows float16 and
        # makes masked_fill raise after model.half().
        scores_f = scores_f.masked_fill(mask, torch.finfo(scores_f.dtype).min)

        # Per-query top-k selection. Query i has i+1 causal keys. When i+1 < k,
        # the remaining picks land on masked positions and are zeroed below.
        k = min(self.topk, T)
        _, topk_idx = torch.topk(scores_f, k=k, dim=-1)  # (B,H,T,k)

        # 1 at the selected positions, 0 elsewhere.
        mask_on = torch.zeros_like(scores_f)
        mask_on.scatter_(-1, topk_idx, 1.0)

        # Ternary attention: sign(scores) on the selected keys, giving {-1, 0, +1}.
        sign_scores = torch.where(scores_f >= 0, torch.ones_like(scores_f), -torch.ones_like(scores_f))
        A_ternary = sign_scores * mask_on
        # Zero any selected-but-causally-masked positions explicitly.
        A_ternary = A_ternary.masked_fill(mask, 0.0)

        # Straight-through estimator: (scores_f - scores_f.detach()) is exactly 0
        # in the forward pass, so A equals A_ternary. Its gradient w.r.t. scores_f
        # is the identity at every position.
        A = A_ternary + (scores_f - scores_f.detach())

        O = torch.matmul(A, V)
        O = O.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(O)


class BitBlockV11(nn.Module):
    """Transformer block with parallel attention and FFN branches.

    Both branches read the same input ``x``. Their outputs are summed with the
    residual ``x``, and the sum is passed through ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, topk=8):
        super().__init__()
        # Registration order (attn, then ffn) fixes the parameter / state_dict order.
        self.attn = TopKBinaryAttention(d_model, n_heads, topk=topk)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        # Evaluated left to right: attention branch first, then the FFN branch.
        residual_sum = x + self.attn(x) + self.ffn(x)
        return sign_ste(residual_sum)


class BitLMv11(nn.Module):
    """Language model built from ``BitBlockV11`` blocks with top-k ternary attention.

    Tokens are embedded with ``BinaryEmbedding`` and passed through ``n_layers``
    blocks. Logits are scored against a ``sign_ste``-binarized output codebook,
    then scaled by a learnable ``logit_scale`` and shifted by ``out_bias``.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256, topk=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV11(d_model, n_heads, d_ff, topk=topk) for _ in range(n_layers)
        ])
        # Latent float codebook; only its sign is used in forward().
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids, presumably shape (B, T) — the shape is set by BinaryEmbedding.
            targets: optional token ids with the same number of elements as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has a trailing ``vocab_size`` dim, and
            ``loss`` is the mean cross-entropy, or ``None`` when no targets are given.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so that non-contiguous targets, e.g. slices, also work.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The model runs in eval mode while sampling. Its previous training/eval
        mode is restored afterwards, even if sampling raises.

        Args:
            idx: (B, T) tensor of prompt token ids. Only the last
                ``max_seq_len`` positions are fed to the model at each step.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, keep only the ``top_k`` most likely tokens. Values
                larger than the vocabulary are clamped.

        Returns:
            ``idx`` with the sampled tokens appended along dim 1.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            # Without this, calling generate() mid-training would silently leave the model in eval mode.
            self.train(was_training)
        return idx


if __name__ == '__main__':
    # Smoke test: build the default model, report its size, and run one
    # forward/backward pass on random token ids.
    model = BitLMv11(topk=8)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"v11 params: {n_params:,} ({n_params/1e6:.2f}M)")

    inputs = torch.randint(0, 128, (2, 64))
    labels = torch.randint(0, 128, (2, 64))
    out_logits, out_loss = model(inputs, labels)
    print("logits:", out_logits.shape, "loss:", out_loss.item())

    out_loss.backward()
    print("backward OK")