File size: 4,008 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""v4 variant: softmax attention (concession #2 from proposal's degradation ladder).

Keeps ±1 for everything except the attention score pass, where softmax-over-floats is
restored. Hypothesis test: is the bool-threshold attention the binding constraint? If v4
dramatically outperforms v3, the answer is "yes, attention binarization was the problem."
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import (
    sign_ste, sign_ste_clipped, BitLinearRaw, BitLinear,
    BitFFN, BinaryEmbedding,
)


class SoftAttention(nn.Module):
    """Multi-head causal self-attention: binarized projections, float softmax scores.

    The Q/K/V/output projections are ``BitLinear`` layers built with
    ``binarize_input=True`` (the same ±1 projections as the earlier variants, per the
    module docstring). Only the score computation is float: it is delegated to
    ``F.scaled_dot_product_attention`` with a causal mask and its default
    ``1/sqrt(head_dim)`` scaling.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert not d_model % n_heads
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Creation order (q, k, v, o) is kept stable so parameter init and
        # state_dict key order do not change.
        self.q_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.k_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.v_proj = BitLinear(d_model, d_model, binarize_input=True)
        self.o_proj = BitLinear(d_model, d_model, binarize_input=True)

    def _split_heads(self, t, batch, seq_len):
        """Reshape ``(batch, seq_len, d_model)`` into ``(batch, n_heads, seq_len, head_dim)``."""
        return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, d_model)``; returns the same shape."""
        batch, seq_len, width = x.shape
        # The generator is consumed in order, so q, k, v are projected in that order.
        q, k, v = (
            self._split_heads(proj(x), batch, seq_len)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Fold heads back: (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, width).
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(merged)


class BitBlockV4(nn.Module):
    """One v4 block: parallel attention + FFN branches over a shared input.

    Both branches read the same ``x``. Their outputs are added to the residual,
    and the sum is passed through ``sign_ste`` (from ``model``; presumably a sign
    with a straight-through gradient).

    Args:
        d_model: model width.
        n_heads: attention heads for ``SoftAttention``.
        d_ff: hidden width for ``BitFFN``.
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # attn is created before ffn so parameter init order matches earlier checkpoints.
        self.attn = SoftAttention(d_model, n_heads)
        self.ffn = BitFFN(d_model, d_ff)

    def forward(self, x):
        """Return ``sign_ste(x + attn(x) + ffn(x))``."""
        residual_sum = x + self.attn(x) + self.ffn(x)
        return sign_ste(residual_sum)


class BitLMv4(nn.Module):
    """Binarized language model whose attention uses float softmax (the v4 variant).

    Pipeline: ``BinaryEmbedding`` -> ``n_layers`` x ``BitBlockV4`` -> dot-product
    scores against a sign-binarized output codebook. The scores are scaled by a
    learnable ``logit_scale`` and shifted by a per-token ``out_bias``.

    Args:
        vocab_size: number of token ids; also the size of the logits' last dim.
        d_model: hidden width.
        n_layers: number of stacked ``BitBlockV4`` blocks.
        n_heads: attention heads per block (must divide ``d_model``).
        d_ff: hidden width of each block's ``BitFFN``.
        max_seq_len: context window; ``generate`` crops its input to this length.

    NOTE(review): no positional encoding is added in this file. If
    ``BinaryEmbedding`` does not add one, order information comes only from the
    causal mask. Confirm this is intended.
    """
    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV4(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        # Float codebook; it is binarized with sign_ste at every forward pass.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Learnable temperature on the integer-valued ±1 dot products, initialised at 1/sqrt(d_model).
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: ``(B, T)`` integer token ids.
            targets: optional token ids with one entry per position in ``idx``;
                they are flattened before the loss and need not be contiguous.

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(B, T, vocab_size)``.
            ``loss`` is the mean cross-entropy, or ``None`` when ``targets`` is None.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        # Binarize the codebook on the fly; presumably sign_ste passes gradients
        # straight through to the float parameter.
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view): targets sliced out of a larger buffer may be non-contiguous.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: ``(B, T)`` prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: the last-step logits are divided by this value, floored at
                1e-5 so that 0 does not divide by zero.
            top_k: if given, sample only among the ``top_k`` highest logits per step.
                Values larger than the vocabulary are clamped to it.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.

        The module is switched to eval mode while sampling. Its previous train/eval
        mode is restored afterwards, even if sampling raises, so calling this from
        a training loop does not silently disable training mode.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the running sequence to the configured context window.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    # torch.topk raises if k exceeds the dimension size, so clamp to the vocab.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx


if __name__ == '__main__':
    # Smoke test: build the default model, then run one forward + backward pass
    # on random token ids.
    model = BitLMv4()
    n_params = sum(param.numel() for param in model.parameters())
    print(f"v4 params: {n_params:,} ({n_params/1e6:.2f}M)")
    tokens = torch.randint(0, 128, (2, 64))
    labels = torch.randint(0, 128, (2, 64))
    logits, loss = model(tokens, labels)
    print("logits:", logits.shape, "loss:", loss.item())
    loss.backward()
    print("backward OK")