File size: 5,218 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""v25: Track VII.A — Gumbel-routed ±1 MoE FFN.

Reuses the same Gumbel-softmax hard-argmax machinery we already know trains
well for attention; applies it to expert routing. Token → router scores (one
per expert) → Gumbel one-hot selection at training → pure argmax at inference.

Each of E experts is a standard v18 BitFFN. Matches v21's total active
per-token compute when `experts = 4, d_ff_per_expert = d_ff/4` (standard MoE
"fixed active FLOPs" setup), at cost of 4× more total parameters. We instead
use matched-total-params (each expert has d_ff = d_model), which means total
params equal v21 but active per-token FLOPs drop 4×.

Routing is pure-integer at inference:
  scores = popcount(W_router ⊕ x)   # (E,) integer per token per layer
  expert = argmax(scores)            # integer compare tree
  y = experts[expert](x)

All weights ±1. All activations ±1. Only train-time float: Gumbel-softmax's
softmax (same concession v18 already pays for attention).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
from model_v18 import IntBinaryAttention
from model_v16 import set_gumbel_tau, _get_tau


def gumbel_route(scores, mask=None):
    """Convert router scores into a one-hot selection over the last dim.

    Two paths, chosen by whether ``scores`` is attached to the autograd graph:

    * ``scores.requires_grad`` is True: Gumbel noise is added, a
      temperature-scaled softmax is taken, and the result is hardened with a
      straight-through estimator (forward value is the hard one-hot up to
      float rounding; gradients follow the soft distribution).
    * otherwise: a plain argmax one-hot, with no noise and no temperature.

    Note that the switch is ``requires_grad``, not a module's training flag:
    a forward pass whose scores still require grad takes the stochastic path
    even if the owning module is in eval mode.

    Args:
        scores: Float tensor; the last dimension indexes the candidates.
        mask: Accepted but currently ignored.

    Returns:
        Tensor with the same shape as ``scores``.
    """
    if not scores.requires_grad:
        # Deterministic path: the temperature is irrelevant here, so it is
        # not looked up at all.
        winner = scores.argmax(-1, keepdim=True)
        return torch.zeros_like(scores).scatter_(-1, winner, 1.0)

    tau = _get_tau(scores.device)
    # Gumbel(0, 1) sample; the clamp and the epsilon keep both logs finite.
    uniform = torch.rand_like(scores).clamp(min=1e-9)
    gumbel = -torch.log(-torch.log(uniform) + 1e-9)
    y_soft = F.softmax((scores + gumbel) / tau, dim=-1)
    y_hard = torch.zeros_like(y_soft)
    y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0)
    # Straight-through: hard value forward, soft gradient backward.
    return y_soft + (y_hard - y_soft).detach()


class MoEFFN(nn.Module):
    """Feed-forward layer that routes each token to one of ``E`` BitFFN experts.

    The router is a single ``(E, d_model)`` weight matrix passed through
    ``sign_ste`` on every forward; its scores are turned into a routing tensor
    by ``gumbel_route``. Every expert is evaluated on every token and the
    results are combined with the routing weights (dense evaluation, no
    dispatch).
    """

    def __init__(self, d_model, d_ff, E=4):
        super().__init__()
        self.E = E
        self.d_model = d_model
        # Latent router weights; binarised with sign_ste at forward time.
        self.router_w = nn.Parameter(torch.randn(E, d_model).mul(0.02))
        # One BitFFN per expert, all built with the same dimensions.
        self.experts = nn.ModuleList()
        for _ in range(E):
            self.experts.append(BitFFN(d_model, d_ff))

    def forward(self, x):
        """Route ``x`` and return the routing-weighted sum of expert outputs.

        Args:
            x: Rank-3 tensor, (B, T, D).

        Returns:
            Tensor produced by summing the stacked expert outputs over the
            expert axis, weighted by the routing tensor.
        """
        # The unpack doubles as a rank check: non rank-3 input raises here.
        _batch, _seq, _width = x.shape

        # Router scores: one value per expert for every token -> (B, T, E).
        binary_router = sign_ste(self.router_w)
        binary_x = sign_ste_clipped(x)
        route = gumbel_route(F.linear(binary_x, binary_router))

        # Dense evaluation of all experts, stacked on a new axis -> (B, T, E, D).
        expert_outputs = [expert(x) for expert in self.experts]
        stacked = torch.stack(expert_outputs, dim=-2)

        # Broadcast routing weights over D, then collapse the expert axis.
        gated = route.unsqueeze(-1) * stacked
        return gated.sum(dim=-2)


class BitBlockV25(nn.Module):
    """One v25 layer: attention and MoE FFN applied in parallel to the same input.

    Both sub-layers read the block input ``x``; their outputs are added to
    ``x`` and the sum is passed through ``sign_ste``.
    """

    def __init__(self, d_model, n_heads, d_ff, E=4):
        super().__init__()
        self.attn = IntBinaryAttention(d_model, n_heads)
        self.ffn = MoEFFN(d_model, d_ff, E=E)

    def forward(self, x):
        """Return ``sign_ste(x + attn(x) + ffn(x))``."""
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        # Same left-to-right accumulation order as a single chained sum.
        residual_sum = x + attn_out
        residual_sum = residual_sum + ffn_out
        return sign_ste(residual_sum)


class BitLMv25(nn.Module):
    """Language model built from a binary embedding and a stack of v25 blocks.

    The output head scores the final hidden state against a ``sign_ste``-
    binarised codebook, then applies a learned scalar scale and a per-token
    bias to produce logits.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256, E=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.E = E
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV25(d_model, n_heads, d_ff, E=E) for _ in range(n_layers)
        ])
        # Latent output codebook; binarised with sign_ste at forward time.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Learned logit scale, initialised to 1/sqrt(d_model).
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: Token ids passed to the embedding.
            targets: Optional target ids with one entry per logit row.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous logits/targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        The model is switched to eval mode and is left in eval mode on return.

        Args:
            idx: Integer tensor of prompt ids, indexed as (batch, time).
            max_new_tokens: Number of tokens to append.
            temperature: Divisor for the last-step logits; floored at 1e-5.
            top_k: If not ``None``, keep only the ``top_k`` highest logits per
                row before sampling. Values larger than the vocabulary are
                clamped to the vocabulary size.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens are fed to the model.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size, so clamp.
                k = min(top_k, logits.size(-1))
                kth_best, _ = torch.topk(logits, k)
                logits[logits < kth_best[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx


if __name__ == '__main__':
    # Smoke check: build the model at several expert counts and report its size.
    set_gumbel_tau(0.5)
    for E in (2, 4, 8):
        model = BitLMv25(E=E)
        n = 0
        for param in model.parameters():
            n += param.numel()
        print(f'v25 E={E}: {n:,} params ({n/1e6:.2f}M)')