"""FP32 reference Transformer language model.

A standard pre-norm, decoder-only Transformer built from RMSNorm, causal
multi-head self-attention and a SwiGLU feed-forward network. Positions are
encoded with a learned absolute embedding table (no RoPE). The model is sized
so its parameter count matches the binary model it is compared against.

Used only as a reference baseline; nothing in this file is binarized.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Args:
        d: Size of the last (normalized) dimension.
        eps: Added to the mean square before the reciprocal square root, for
            numerical stability.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        # Scale by 1 / sqrt(mean(x^2) + eps). Unlike LayerNorm there is no
        # mean subtraction and no bias.
        n = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * n * self.weight


class MHA(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    Args:
        d: Model width. Must be divisible by ``h``.
        h: Number of attention heads; each head has width ``d // h``.

    Raises:
        ValueError: If ``h`` is not positive or ``d`` is not divisible by ``h``.
            Such a module could never run ``forward``: the head split reshape
            would fail.
    """

    def __init__(self, d, h):
        super().__init__()
        if h <= 0 or d % h != 0:
            raise ValueError(f"d ({d}) must be divisible by a positive h ({h})")
        self.d = d
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.o = nn.Linear(d, d, bias=False)

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, d); returns (B, T, d)."""
        B, T, D = x.shape
        # (B, T, 3*d) -> (3, B, h, T, dh): separate q/k/v, then split heads.
        qkv = self.qkv(x).reshape(B, T, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # is_causal=True applies a lower-triangular mask: position t sees <= t.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, h, T, dh) -> (B, T, h*dh): merge heads back into the model width.
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        return self.o(y)


class SwiGLU(nn.Module):
    """SwiGLU feed-forward network: ``d(silu(g(x)) * u(x))``.

    Args:
        d: Input/output width.
        d_ff: Hidden width of the gated projection.
    """

    def __init__(self, d, d_ff):
        super().__init__()
        self.g = nn.Linear(d, d_ff, bias=False)  # gate projection
        self.u = nn.Linear(d, d_ff, bias=False)  # up projection
        # Down projection. The attribute name `d` is kept as-is because it is
        # part of the state-dict keys.
        self.d = nn.Linear(d_ff, d, bias=False)

    def forward(self, x):
        return self.d(F.silu(self.g(x)) * self.u(x))


class Block(nn.Module):
    """Pre-norm residual block: attention sublayer, then feed-forward sublayer.

    Args:
        d: Model width.
        h: Number of attention heads.
        d_ff: Hidden width of the SwiGLU feed-forward.
    """

    def __init__(self, d, h, d_ff):
        super().__init__()
        self.n1 = RMSNorm(d)
        self.a = MHA(d, h)
        self.n2 = RMSNorm(d)
        self.f = SwiGLU(d, d_ff)

    def forward(self, x):
        # Normalize the sublayer input only; the residual stream stays unnormalized.
        x = x + self.a(self.n1(x))
        x = x + self.f(self.n2(x))
        return x


class FP32LM(nn.Module):
    """Decoder-only language model with tied input/output embeddings.

    Args:
        vocab_size: Number of token ids.
        d_model: Model width.
        n_layers: Number of Transformer blocks.
        n_heads: Attention heads per block (must divide ``d_model``).
        d_ff: Hidden width of each SwiGLU feed-forward.
        max_seq_len: Size of the learned position table; the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size=128, d_model=256, n_layers=8, n_heads=8, d_ff=512, max_seq_len=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.norm_f = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.embed.weight  # tie output projection to the token embedding

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) integer token ids, with ``T <= max_seq_len``.
            targets: Optional token ids with the same number of elements as
                ``idx``. Need not be contiguous.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size).
            ``loss`` is the cross-entropy over all positions, or ``None``
            when ``targets`` is not given.

        Raises:
            ValueError: If ``T`` exceeds ``max_seq_len``. Without this check
                the position-embedding lookup fails with an opaque error.
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len {self.max_seq_len}")
        pos = torch.arange(T, device=idx.device)
        x = self.embed(idx) + self.pos(pos)
        for b in self.blocks:
            x = b(x)
        x = self.norm_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Divides the last-position logits. It is clamped below
                at 1e-5, so 0 does not divide by zero.
            top_k: If given, sample only among the ``top_k`` highest logits.
                Values above the vocabulary size are clamped to it.

        Returns:
            A (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        The model runs in eval mode while sampling and is returned to its
        previous train/eval mode afterwards, even if sampling raises.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Keep only the most recent max_seq_len tokens, the window the
                # position table supports.
                idx_cond = idx[:, -self.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    # Mask every logit below the k-th largest one.
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx


if __name__ == '__main__':
    m = FP32LM()
    # parameters() yields the tied embedding/head weight once, so it is counted once.
    n = sum(p.numel() for p in m.parameters())
    print(f"fp32 params: {n:,} ({n/1e6:.2f}M)")