"""Shakespeare_MoE / train.py

Train a small decoder-only language model on Shakespeare whose FFN sublayers
are token-choice Mixture-of-Experts layers with duplicated "null expert"
routing slots (data-sparse computation).
"""
import argparse
import json
import math
import os
import time
from dataclasses import dataclass
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
@dataclass
class Config:
    """Hyperparameters for the model, the MoE routing scheme, and training."""

    # --- Model ---
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size (tiktoken "gpt2")
    d_model: int = 384
    n_heads: int = 6
    head_dim: int = 64
    n_layers: int = 6
    seq_len: int = 128

    # --- MoE ---
    n_real_experts: int = 8
    shared_expert_hidden: int = 768
    expert_hidden: int = 384
    top_k: int = 4
    rho: float = 0.5  # target data-sparsity (controls how many null copies exist)

    # --- Loss weights ---
    balance_loss_weight: float = 2e-2
    z_loss_weight: float = 1e-3

    # --- Training ---
    batch_size: int = 32
    max_steps: int = 3000
    lr: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 100
    grad_clip: float = 1.0

    # --- Eval / logging ---
    eval_interval: int = 200
    log_interval: int = 50
    save_interval: int = 1000
    gen_tokens: int = 150  # how many tokens to generate mid training

    @property
    def n_null_copies(self) -> int:
        """M = N * (1-rho)/rho (Eq. 4 in paper): number of duplicated null slots."""
        return int(self.n_real_experts * (1 - self.rho) / self.rho)

    @property
    def n_total_slots(self) -> int:
        """Total routing slots the router's top-k runs over: N real + M null."""
        return self.n_real_experts + self.n_null_copies
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, then cast back to input dtype.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).type_as(x) * self.weight
# Rotary Position Embeddings
def precompute_rope(dim: int, max_len: int, theta: float = 10000.0) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute RoPE rotation tables.

    Returns (cos, sin), each of shape (max_len, dim // 2), where entry [t, i]
    is the rotation angle t * theta^(-2i/dim) for position t and frequency i.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim//2,)
    positions = torch.arange(max_len).float()
    angles = positions[:, None] * inv_freq[None, :]  # (max_len, dim//2)
    return angles.cos(), angles.sin()
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate even/odd feature pairs of x by the precomputed angles.

    x: (B, n_heads, T, head_dim); cos/sin: (max_len, head_dim // 2).
    Returns a tensor of the same shape with pairs (x[2i], x[2i+1]) rotated.
    """
    head_dim = x.size(-1)
    assert head_dim % 2 == 0
    seqlen = x.size(2)
    # Slice to the current sequence length and broadcast over (batch, head).
    c = cos[:seqlen].view(1, 1, seqlen, -1)
    s = sin[:seqlen].view(1, 1, seqlen, -1)
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # stack + flatten re-interleaves the rotated pairs back into the original layout.
    rotated = torch.stack((even * c - odd * s, even * s + odd * c), dim=-1)
    return rotated.flatten(-2)
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.head_dim
        proj_dim = cfg.n_heads * cfg.head_dim
        self.wq = nn.Linear(cfg.d_model, proj_dim, bias=False)
        self.wk = nn.Linear(cfg.d_model, proj_dim, bias=False)
        self.wv = nn.Linear(cfg.d_model, proj_dim, bias=False)
        self.wo = nn.Linear(proj_dim, cfg.d_model, bias=False)
        # Lower-triangular causal mask, shaped to broadcast over (batch, head).
        causal = torch.tril(torch.ones(cfg.seq_len, cfg.seq_len))
        self.register_buffer("mask", causal.view(1, 1, cfg.seq_len, cfg.seq_len))

    def forward(self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, H*Dh) -> (B, H, T, Dh)
            return t.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)

        # RoPE is applied to queries and keys only; values are left unrotated.
        q = apply_rope(split_heads(self.wq(x)), rope_cos, rope_sin)
        k = apply_rope(split_heads(self.wk(x)), rope_cos, rope_sin)
        v = split_heads(self.wv(x))

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        scores = scores.masked_fill(self.mask[:, :, :seqlen, :seqlen] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(merged)
# Expert FFN
class ExpertFFN(nn.Module):
    """Two-layer feed-forward expert: Linear -> GELU(tanh) -> Linear, no biases."""

    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.gelu = nn.GELU(approximate='tanh')
        self.w2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.gelu(self.w1(x)))
class MoELayer(nn.Module):
    """
    Token-choice MoE with null experts (Section 4.1).

    Router produces N+1 logits (N real + 1 null). The single null logit is
    duplicated M times to form N+M total slots. Top-K is applied over all
    N+M slots. Selected slots pointing to null produce output=0 (no compute).
    Gate weights are renormalised over selected *real* experts only.
    A shared expert is additionally applied to every token unconditionally.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        N = cfg.n_real_experts
        M = cfg.n_null_copies
        # Shared expert: dense FFN applied to every token.
        self.shared_expert = ExpertFFN(cfg.d_model, cfg.shared_expert_hidden)
        # Routed experts.
        self.experts = nn.ModuleList([ExpertFFN(cfg.d_model, cfg.expert_hidden) for _ in range(N)])
        # Router: produces N+1 logits (N real + 1 null).
        self.router = nn.Linear(cfg.d_model, N + 1, bias=False)
        self.N = N
        self.M = M
        self.top_k = cfg.top_k
        # Telemetry accumulators (filled during forward, read externally).
        self.last_expert_counts: torch.Tensor | None = None
        self.last_null_ratio: float = 0.0
        self.last_gate_weights: torch.Tensor | None = None
        self.last_zero_compute_ratio: float = 0.0
        self.last_balance_loss: torch.Tensor | None = None
        self.last_z_loss: torch.Tensor | None = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, D) -> (B, T, D). Also stores aux losses and telemetry on self."""
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (tokens, D), tokens = B*T
        num_tokens = x_flat.size(0)

        # --- Router logits: N real + 1 null, null duplicated M times -> N+M slots ---
        logits_raw = self.router(x_flat)       # (tokens, N+1)
        real_logits = logits_raw[:, :self.N]   # (tokens, N)
        null_logit = logits_raw[:, self.N:]    # (tokens, 1)
        expanded_logits = torch.cat([real_logits, null_logit.expand(-1, self.M)], dim=-1)  # (tokens, N+M)

        # --- Z-loss: squared log-sum-exp penalty keeps router logits small ---
        lse = torch.logsumexp(expanded_logits, dim=-1)  # (tokens,)
        z_loss = (lse ** 2).mean()

        # --- Top-K selection over all N+M slots ---
        topk_vals, topk_idxs = torch.topk(expanded_logits, self.top_k, dim=-1)  # (tokens, top_k)
        topk_gates = F.softmax(topk_vals, dim=-1)  # softmax over selected slots only
        is_real = topk_idxs < self.N               # which selections hit real experts
        # Renormalise gates over the selected *real* experts. Tokens whose entire
        # top-K landed on null slots get all-zero gates (zero-compute tokens).
        real_gate_sum = (topk_gates * is_real.float()).sum(dim=-1, keepdim=True).clamp(min=1e-9)
        has_real = is_real.any(dim=-1, keepdim=True).float()
        renorm_gates = topk_gates * is_real.float() / real_gate_sum * has_real  # (tokens, top_k)

        # --- Load-balancing loss (Eq. 6): NM * sum_i f_i * P_i ---
        # f_i = empirical fraction of routing decisions that picked slot i
        #       (computed from a one-hot encoding of the top-k choices);
        # P_i = mean router softmax probability for slot i.
        # The product is minimised when usage and probability mass are spread
        # evenly, which discourages expert collapse.
        NM = self.N + self.M
        slot_mask = torch.zeros(num_tokens, self.top_k, NM, device=x.device)
        slot_mask.scatter_(2, topk_idxs.unsqueeze(-1), 1.0)  # one-hot per (token, k)
        f = slot_mask.sum(dim=1).sum(dim=0) / (num_tokens * self.top_k)  # (NM,)
        probs_all = F.softmax(expanded_logits, dim=-1)  # (tokens, NM)
        P = probs_all.mean(dim=0)                       # (NM,)
        balance_loss = NM * (f * P).sum()

        # --- Dispatch to real experts: one batched forward per expert ---
        # top-k indices are distinct within a token, so each expert is selected
        # at most once per token; its gate can be read off with a masked sum.
        combined_output = torch.zeros_like(x_flat)  # (tokens, D)
        expert_counts = torch.zeros(self.N, device=x.device)
        gate_weight_sums = torch.zeros(self.N, device=x.device)
        gate_weight_counts = torch.zeros(self.N, device=x.device)
        for e in range(self.N):
            sel = topk_idxs == e          # (tokens, top_k)
            token_mask = sel.any(dim=-1)  # (tokens,)
            if not token_mask.any():
                continue
            gates_e = (renorm_gates * sel.float()).sum(dim=-1)  # (tokens,)
            expert_out = self.experts[e](x_flat[token_mask])
            combined_output[token_mask] += gates_e[token_mask].unsqueeze(-1) * expert_out
            n_sel = sel.sum().float()
            expert_counts[e] += n_sel
            gate_weight_sums[e] += gates_e[token_mask].sum()
            gate_weight_counts[e] += n_sel

        # Shared expert (always active).
        shared_out = self.shared_expert(x_flat)
        output = shared_out + combined_output

        # --- Telemetry (no gradients needed) ---
        with torch.no_grad():
            null_selections = (~is_real).sum().float()
            total_selections = float(num_tokens * self.top_k)
            self.last_null_ratio = (null_selections / total_selections).item()
            self.last_expert_counts = expert_counts.detach()
            avg_gates = torch.where(gate_weight_counts > 0,
                                    gate_weight_sums / gate_weight_counts,
                                    torch.zeros_like(gate_weight_sums))
            self.last_gate_weights = avg_gates.detach()
            # Zero-compute tokens: all top-K selections went to null slots.
            zero_compute = (~is_real).all(dim=-1).sum().float()
            self.last_zero_compute_ratio = (zero_compute / num_tokens).item()
        # Keep the grad-carrying aux losses for get_aux_losses().
        self.last_balance_loss = balance_loss
        self.last_z_loss = z_loss
        return output.view(B, T, D)
# ---------------------------------------------------------------------------
# Transformer Block
# ---------------------------------------------------------------------------
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention sublayer followed by an MoE sublayer."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model)
        self.attn = MultiHeadAttention(cfg)
        self.moe_norm = RMSNorm(cfg.d_model)
        self.moe = MoELayer(cfg)

    def forward(self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor) -> torch.Tensor:
        # Residual connection around each pre-normalized sublayer.
        attn_out = self.attn(self.attn_norm(x), rope_cos, rope_sin)
        x = x + attn_out
        moe_out = self.moe(self.moe_norm(x))
        return x + moe_out
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------
class MoENullModel(nn.Module):
    """Decoder-only transformer LM whose FFN sublayers are MoE-with-null-expert layers."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: lm_head shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        # Precompute RoPE tables once; as buffers they follow the model's device.
        rope_cos, rope_sin = precompute_rope(cfg.head_dim, cfg.seq_len)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)
        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases.

        NOTE(review): runs after weight tying, so the tied embedding/lm_head
        matrix is re-initialized through both module references (same storage,
        second write wins) — harmless, but worth knowing.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, T) → logits: (B, T, vocab_size)"""
        x = self.tok_emb(idx)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)
        x = self.final_norm(x)
        return self.lm_head(x)

    def get_aux_losses(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Average the (balance_loss, z_loss) stored by each MoE layer during
        the most recent forward pass. Must be called after forward()."""
        balance_loss = torch.tensor(0.0, device=next(self.parameters()).device)
        z_loss = torch.tensor(0.0, device=next(self.parameters()).device)
        for block in self.blocks:
            if block.moe.last_balance_loss is not None:
                balance_loss = balance_loss + block.moe.last_balance_loss
            if block.moe.last_z_loss is not None:
                z_loss = z_loss + block.moe.last_z_loss
        balance_loss = balance_loss / self.cfg.n_layers
        z_loss = z_loss / self.cfg.n_layers
        return balance_loss, z_loss

    def get_telemetry(self) -> dict:
        """Gather telemetry from all MoE layers."""
        expert_counts = []
        null_ratios = []
        gate_weights = []
        zero_compute_ratios = []
        for block in self.blocks:
            moe = block.moe
            if moe.last_expert_counts is not None:
                expert_counts.append(moe.last_expert_counts.cpu().tolist())
            null_ratios.append(moe.last_null_ratio)
            if moe.last_gate_weights is not None:
                gate_weights.append(moe.last_gate_weights.cpu().tolist())
            zero_compute_ratios.append(moe.last_zero_compute_ratio)
        # "avg_*" entries average each expert's statistic across layers
        # (zip transposes the per-layer lists into per-expert tuples).
        return {
            "per_layer_expert_counts": expert_counts,
            "avg_expert_counts": [sum(c) / len(c) for c in zip(*expert_counts)] if expert_counts else [],
            "null_ratio": sum(null_ratios) / len(null_ratios) if null_ratios else 0,
            "avg_gate_weights": [sum(g) / len(g) for g in zip(*gate_weights)] if gate_weights else [],
            "zero_compute_ratio": sum(zero_compute_ratios) / len(zero_compute_ratios) if zero_compute_ratios else 0,
        }

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.8) -> torch.Tensor:
        """Autoregressively sample max_new_tokens continuations of idx (B, T)
        with temperature sampling (no top-k/top-p filtering)."""
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            idx_crop = idx[:, -self.cfg.seq_len:]
            logits = self(idx_crop)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class TextDataset(Dataset):
    """Sliding-window LM dataset: item i is (tokens[i:i+L], tokens[i+1:i+L+1])."""

    def __init__(self, tokens: list[int], seq_len: int):
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self) -> int:
        # Each window needs seq_len + 1 tokens; clamp at zero for tiny corpora.
        return max(0, len(self.tokens) - self.seq_len - 1)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        window = self.tokens[idx : idx + self.seq_len + 1]
        inputs, targets = window[:-1], window[1:]
        return inputs, targets
# ---------------------------------------------------------------------------
# Learning rate schedule
# ---------------------------------------------------------------------------
def get_lr(step: int, cfg: Config) -> float:
    """Linear warmup to cfg.lr over warmup_steps, then cosine decay floored
    at 10% of the peak learning rate."""
    if step < cfg.warmup_steps:
        return cfg.lr * step / cfg.warmup_steps
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cfg.lr * max(cosine, 0.1)
# ---------------------------------------------------------------------------
# Perplexity computation
# ---------------------------------------------------------------------------
@torch.no_grad()
def compute_perplexity(model: MoENullModel, tokens: list[int], cfg: Config, device: torch.device) -> float:
    """Compute perplexity on a window taken from the tail of the token stream.

    Evaluates at most batch_size * 20 positions, split into non-overlapping
    seq_len windows (batch size 1). Restores model.train() before returning;
    returns inf if the corpus is too small to form a single window.
    """
    model.eval()
    n_eval = min(len(tokens) - cfg.seq_len - 1, cfg.batch_size * 20)
    if n_eval <= 0:
        model.train()
        return float("inf")
    total_loss = 0.0
    count = 0
    # Start so that the evaluated windows occupy the last n_eval positions.
    start = len(tokens) - n_eval - cfg.seq_len - 1
    for i in range(0, n_eval, cfg.seq_len):
        offset = start + i
        if offset + cfg.seq_len + 1 > len(tokens):
            break
        chunk = torch.tensor(tokens[offset : offset + cfg.seq_len + 1], dtype=torch.long, device=device)
        # Shift by one token to form (input, target) pairs.
        x, y = chunk[:-1].unsqueeze(0), chunk[1:].unsqueeze(0)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y.view(-1))
        total_loss += loss.item()
        count += 1
    model.train()  # restore training mode before returning
    if count == 0:
        return float("inf")
    return math.exp(total_loss / count)
# ---------------------------------------------------------------------------
# ASCII bar chart for console
# ---------------------------------------------------------------------------
def ascii_bar(values: list[float], width: int = 30, labels: list[str] | None = None) -> str:
if not values:
return ""
max_val = max(values) if max(values) > 0 else 1.0
lines = []
for i, v in enumerate(values):
bar_len = int(v / max_val * width)
label = labels[i] if labels else f"E{i}"
lines.append(f" {label:>4s} |{'█' * bar_len}{' ' * (width - bar_len)}| {v:.0f}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse args, load data, build the model, train, and
    periodically log telemetry / generate samples / save checkpoints."""
    parser = argparse.ArgumentParser(description="Train MoE with Null Experts")
    parser.add_argument("--dataset", type=str, default="input.txt",
                        help="Path to text file or 'tiny' for tiny-shakespeare")
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--eval_interval", type=int, default=None)
    parser.add_argument("--log_interval", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default=".")
    args = parser.parse_args()

    # Apply CLI overrides on top of the Config defaults.
    cfg = Config()
    if args.steps is not None:
        cfg.max_steps = args.steps
    if args.batch is not None:
        cfg.batch_size = args.batch
    if args.lr is not None:
        cfg.lr = args.lr
    if args.eval_interval is not None:
        cfg.eval_interval = args.eval_interval
    if args.log_interval is not None:
        cfg.log_interval = args.log_interval

    # --- Device (prefer CUDA, then Apple MPS, then CPU) ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # --- Tokenizer ---
    enc = tiktoken.get_encoding("gpt2")

    # --- Data ---
    data_path = args.dataset
    if data_path == "tiny":
        # Download tiny-shakespeare
        import urllib.request
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        data_path = os.path.join(args.output_dir, "input.txt")
        if not os.path.exists(data_path):
            print("Downloading tiny-shakespeare...")
            urllib.request.urlretrieve(url, data_path)
    with open(data_path, "r", encoding="utf-8") as f:
        text = f.read()
    tokens = enc.encode(text)
    print(f"Dataset: {len(tokens):,} tokens")
    dataset = TextDataset(tokens, cfg.seq_len)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                            drop_last=True, num_workers=0)

    # --- Model ---
    model = MoENullModel(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    # Subtract tied weights (counted once in tok_emb, once in lm_head)
    n_params_unique = n_params - model.tok_emb.weight.numel()
    print(f"Model parameters: {n_params_unique:,} (unique, with weight tying)")
    print(f" Null copies M = {cfg.n_null_copies}, total routing slots = {cfg.n_total_slots}")
    assert n_params_unique < 42_000_000, f"Model too large: {n_params_unique:,} params"

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr,
                                  betas=(0.9, 0.95), weight_decay=cfg.weight_decay)

    # --- Telemetry log ---
    telemetry_path = os.path.join(args.output_dir, "telemetry.json")
    telemetry_log: list[dict] = []

    # --- Training ---
    model.train()
    data_iter = iter(dataloader)
    start_time = time.time()
    for step in range(1, cfg.max_steps + 1):
        # Get batch; restart the iterator when the dataloader is exhausted.
        try:
            x_batch, y_batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            x_batch, y_batch = next(data_iter)
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # LR schedule (manual: set lr on every param group each step).
        lr = get_lr(step, cfg)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Forward
        logits = model(x_batch)
        lm_loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y_batch.view(-1))
        # Auxiliary losses (stored by MoE layers during the forward pass)
        balance_loss, z_loss = model.get_aux_losses()
        total_loss = lm_loss + cfg.balance_loss_weight * balance_loss + cfg.z_loss_weight * z_loss

        # Backward
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        # --- Telemetry ---
        telem = model.get_telemetry()
        step_data = {
            "step": step,
            "total_loss": total_loss.item(),
            "lm_loss": lm_loss.item(),
            "balance_loss": balance_loss.item(),
            "z_loss": z_loss.item(),
            "lr": lr,
            "expert_counts": telem["avg_expert_counts"],
            "null_ratio": telem["null_ratio"],
            "gate_weights": telem["avg_gate_weights"],
            "zero_compute_ratio": telem["zero_compute_ratio"],
        }

        # Periodic perplexity + generation
        if step % cfg.eval_interval == 0 or step == 1:
            ppl = compute_perplexity(model, tokens, cfg, device)
            step_data["perplexity"] = ppl
            # Generate sample
            prompt_tokens = tokens[:10]
            prompt = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
            gen_ids = model.generate(prompt, max_new_tokens=cfg.gen_tokens)
            gen_text = enc.decode(gen_ids[0].cpu().tolist())
            step_data["generated_text"] = gen_text
        telemetry_log.append(step_data)

        # --- Console output ---
        if step % cfg.log_interval == 0 or step == 1:
            elapsed = time.time() - start_time
            tokens_per_sec = step * cfg.batch_size * cfg.seq_len / elapsed
            print(f"\n{'='*60}")
            print(f"Step {step}/{cfg.max_steps} | LR: {lr:.2e} | {tokens_per_sec:,.0f} tok/s")
            print(f" Loss: {total_loss.item():.4f} (LM: {lm_loss.item():.4f}, "
                  f"Bal: {balance_loss.item():.4f}, Z: {z_loss.item():.4f})")
            if "perplexity" in step_data:
                print(f" Perplexity: {step_data['perplexity']:.2f}")
            print(f" Null routing: {telem['null_ratio']*100:.1f}% | "
                  f"Zero-compute tokens: {telem['zero_compute_ratio']*100:.1f}%")
            if telem["avg_expert_counts"]:
                print(f" Expert utilization:")
                print(ascii_bar(telem["avg_expert_counts"],
                                labels=[f"E{i}" for i in range(cfg.n_real_experts)]))
            if "generated_text" in step_data:
                print(f" Generated: {step_data['generated_text'][:200]}")

        # --- Save checkpoint ---
        if step % cfg.save_interval == 0:
            ckpt_path = os.path.join(args.output_dir, f"checkpoint_{step}.pt")
            torch.save({
                "step": step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": cfg,
            }, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")

        # --- Save telemetry periodically ---
        if step % cfg.log_interval == 0:
            with open(telemetry_path, "w") as f:
                json.dump(telemetry_log, f)

    # --- Final save ---
    with open(telemetry_path, "w") as f:
        json.dump(telemetry_log, f)
    print(f"\nTelemetry saved to {telemetry_path}")
    final_ckpt = os.path.join(args.output_dir, "checkpoint_final.pt")
    torch.save({
        "step": cfg.max_steps,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "config": cfg,
    }, final_ckpt)
    print(f"Final checkpoint saved to {final_ckpt}")


if __name__ == "__main__":
    main()