bitnet-1bitllm / vm_backup /code /model_v25.py
hidude562's picture
1bitllm code (checkpoints to follow)
4754707 verified
"""v25: Track VII.A — Gumbel-routed ±1 MoE FFN.
Reuses the same Gumbel-softmax hard-argmax machinery we already know trains
well for attention; applies it to expert routing. Token → router scores (one
per expert) → Gumbel one-hot selection at training → pure argmax at inference.
Each of E experts is a standard v18 BitFFN. Matches v21's total active
per-token compute when `experts = 4, d_ff_per_expert = d_ff/4` (standard MoE
"fixed active FLOPs" setup), at cost of 4× more total parameters. We instead
use matched-total-params (each expert has d_ff = d_model), which means total
params equal v21 but active per-token FLOPs drop 4×.
Routing is pure-integer at inference:

    scores = popcount(W_router ⊕ x)   # (E,) integer per token per layer
    expert = argmax(scores)           # integer compare tree
    y = experts[expert](x)
All weights ±1. All activations ±1. Only train-time float: Gumbel-softmax's
softmax (same concession v18 already pays for attention).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import sign_ste, sign_ste_clipped, BitLinear, BitFFN, BinaryEmbedding
from model_v18 import IntBinaryAttention
from model_v16 import set_gumbel_tau, _get_tau
def gumbel_route(scores, mask=None):
    """Turn per-expert scores into a one-hot selection along the last dim.

    When ``scores.requires_grad`` is true, Gumbel noise is added, a
    temperature-scaled softmax is taken, and the result is returned through a
    straight-through trick: the forward value is the hard one-hot of the
    softmax's argmax, while gradients flow through the soft probabilities.
    Otherwise the plain argmax of ``scores`` is one-hot encoded, with no noise.

    Args:
        scores: tensor whose last dimension indexes the candidates.
        mask: accepted for signature compatibility; not used by this function.

    Returns:
        Tensor with the same shape and dtype as ``scores``.
    """
    # Presumably the temperature configured via set_gumbel_tau — confirm in
    # model_v16. Fetched up front on both paths, as before.
    tau = _get_tau(scores.device)

    def _one_hot_of_max(t):
        # 1.0 at the argmax position of the last dim, 0.0 elsewhere.
        picked = torch.zeros_like(t)
        picked.scatter_(-1, t.argmax(-1, keepdim=True), 1.0)
        return picked

    if not scores.requires_grad:
        return _one_hot_of_max(scores)

    # Gumbel(0, 1) sample via inverse CDF; the clamp and epsilon keep both
    # logarithms finite.
    uniform = torch.rand_like(scores).clamp(min=1e-9)
    gumbel = -torch.log(-torch.log(uniform) + 1e-9)
    soft = F.softmax((scores + gumbel) / tau, dim=-1)
    hard = _one_hot_of_max(soft)
    # Straight-through: value of `hard`, gradient of `soft`.
    return soft + (hard - soft).detach()
class MoEFFN(nn.Module):
    """Mixture-of-experts FFN with a binarized router and hard top-1 routing.

    A router produces one score per expert for every token; exactly one expert
    is selected per token and its output is returned. All experts are
    evaluated densely and the unselected outputs are zeroed by the one-hot
    route (simple, not compute-efficient).

    Args:
        d_model: feature width of the input and of the router weight rows.
        d_ff: hidden width handed to each ``BitFFN`` expert.
        E: number of experts; must be at least 1.

    Raises:
        ValueError: if ``E`` is smaller than 1.
    """

    def __init__(self, d_model, d_ff, E=4):
        super().__init__()
        if E < 1:
            # Fail at construction instead of inside torch.stack at forward.
            raise ValueError(f"MoEFFN needs at least one expert, got E={E}")
        self.E = E
        self.d_model = d_model
        # Router: latent float weights, binarized with sign_ste in forward.
        self.router_w = nn.Parameter(torch.randn(E, d_model) * 0.02)
        # Experts: each is a standard BitFFN.
        self.experts = nn.ModuleList([BitFFN(d_model, d_ff) for _ in range(E)])

    def forward(self, x):
        """Route each token to one expert and return that expert's output.

        Args:
            x: tensor of shape (..., d_model); the file's convention is
                (B, T, D) with ±1 entries.

        Returns:
            Tensor shaped like the experts' output for ``x``.
        """
        # Router scores: one value per expert for every token.
        W_r = sign_ste(self.router_w)       # (E, D)
        x_bin = sign_ste_clipped(x)
        scores = F.linear(x_bin, W_r)       # (..., E)

        if self.training:
            # Gumbel soft-to-hard selection with straight-through gradients.
            route = gumbel_route(scores)
        else:
            # Eval mode must be deterministic. gumbel_route decides on
            # scores.requires_grad, which stays True under model.eval()
            # whenever autograd is enabled, so take the argmax here instead.
            route = torch.zeros_like(scores)
            route.scatter_(-1, scores.argmax(-1, keepdim=True), 1.0)

        # Compute all E experts (simple implementation; real MoE would dispatch).
        outs = torch.stack([expert(x) for expert in self.experts], dim=-2)  # (..., E, D)
        # Mix by route weights: (..., E, 1) * (..., E, D) -> sum over E.
        return (route.unsqueeze(-1) * outs).sum(dim=-2)
class BitBlockV25(nn.Module):
    """One v25 layer: attention and MoE FFN applied in parallel to the input.

    Both sub-layers read the same block input; their outputs are added to the
    input and the sum is passed through ``sign_ste``.

    Args:
        d_model: feature width.
        n_heads: number of attention heads given to ``IntBinaryAttention``.
        d_ff: hidden width given to ``MoEFFN``.
        E: number of experts in the MoE FFN.
    """

    def __init__(self, d_model, n_heads, d_ff, E=4):
        super().__init__()
        self.attn = IntBinaryAttention(d_model, n_heads)
        self.ffn = MoEFFN(d_model, d_ff, E=E)

    def forward(self, x):
        """Return ``sign_ste(x + attn(x) + ffn(x))``."""
        # Attention runs before the FFN so any random draws keep their order.
        attn_out = self.attn(x)
        ffn_out = self.ffn(x)
        residual = x + attn_out + ffn_out
        return sign_ste(residual)
class BitLMv25(nn.Module):
    """Language model built from a binary embedding and ``BitBlockV25`` layers.

    Output logits are the dot product of the final hidden state with a
    sign-binarized codebook, scaled by a learned scalar and shifted by a
    learned per-token bias.

    Args:
        vocab_size: number of token ids; also the number of output logits.
        d_model: hidden width.
        n_layers: number of ``BitBlockV25`` layers.
        n_heads: attention heads per layer.
        d_ff: FFN hidden width passed to every layer.
        max_seq_len: longest context fed to the model during ``generate``.
        E: number of experts per MoE FFN.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512,
                 max_seq_len=256, E=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.E = E
        self.embed = BinaryEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            BitBlockV25(d_model, n_heads, d_ff, E=E) for _ in range(n_layers)
        ])
        # Latent float codebook; binarized with sign_ste at every forward.
        self.out_codebook = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)
        # Learned scalar applied to the raw dot products; starts at 1/sqrt(d_model).
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / math.sqrt(d_model)))
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: token ids, indexed as (batch, time) by ``generate``.
            targets: optional token ids with one entry per logit row.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        W_out = sign_ste(self.out_codebook)
        scores = torch.matmul(x, W_out.t())
        logits = scores * self.logit_scale + self.out_bias
        loss = None
        if targets is not None:
            # reshape (not view): targets sliced from a larger batch, e.g.
            # batch[:, 1:], are non-contiguous and would make view() raise.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them.

        Switches the module to eval mode and leaves it there.

        Args:
            idx: (batch, time) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: divides the logits; values below 1e-5 are raised to 1e-5.
            top_k: if given, only the ``top_k`` highest logits stay eligible.
                Values larger than the vocabulary are clamped to its size.

        Returns:
            ``idx`` extended along dim 1 by ``max_new_tokens`` sampled ids.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Keep only the most recent max_seq_len tokens as context.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                # torch.topk raises if k exceeds the size of the last dim.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                # Anything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
if __name__ == '__main__':
    # Smoke test: build the model for a few expert counts and report its size.
    set_gumbel_tau(0.5)
    for E in (2, 4, 8):
        m = BitLMv25(E=E)
        # Total number of scalar parameters across the whole model.
        n = 0
        for p in m.parameters():
            n += p.numel()
        print(f'v25 E={E}: {n:,} params ({n/1e6:.2f}M)')