# Training script: decoder-only LM with token-choice MoE + null experts (tiny-shakespeare demo).
| import argparse | |
| import json | |
| import math | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| import tiktoken | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
@dataclass
class Config:
    """Hyper-parameters for the MoE-with-null-experts model and its training run.

    Declared as a dataclass so fields can be overridden at construction
    (e.g. ``Config(rho=0.25)``) and still mutated afterwards by CLI flags.
    """
    # --- Model ---
    vocab_size: int = 50257  # GPT-2 BPE vocabulary (tiktoken "gpt2")
    d_model: int = 384
    n_heads: int = 6
    head_dim: int = 64
    n_layers: int = 6
    seq_len: int = 128
    # --- MoE ---
    n_real_experts: int = 8
    shared_expert_hidden: int = 768
    expert_hidden: int = 384
    top_k: int = 4
    rho: float = 0.5  # target data-sparsity (fraction of selections hitting real experts)
    # --- Loss weights ---
    balance_loss_weight: float = 2e-2
    z_loss_weight: float = 1e-3
    # --- Training ---
    batch_size: int = 32
    max_steps: int = 3000
    lr: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 100
    grad_clip: float = 1.0
    # --- Eval / logging ---
    eval_interval: int = 200
    log_interval: int = 50
    save_interval: int = 1000
    gen_tokens: int = 150  # tokens to generate mid-training

    # Properties (not methods): every call site reads these as attributes
    # (cfg.n_null_copies, cfg.n_total_slots), so plain methods would be passed
    # around as bound-method objects and crash arithmetic at runtime.
    @property
    def n_null_copies(self) -> int:
        """M = N * (1 - rho) / rho (Eq. 4 in paper)."""
        return int(self.n_real_experts * (1 - self.rho) / self.rho)

    @property
    def n_total_slots(self) -> int:
        """Total routing slots: N real experts + M null copies."""
        return self.n_real_experts + self.n_null_copies
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scale-only, no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in fp32 for numerical stability, then cast back to x's dtype.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).type_as(x) * self.weight
# Rotary Position Embeddings
def precompute_rope(dim: int, max_len: int, theta: float = 10000.0) -> tuple[torch.Tensor, torch.Tensor]:
    """Build RoPE angle tables.

    Returns (cos, sin), each of shape (max_len, dim // 2): row t holds the
    rotation angles t * theta^(-2i/dim) for each channel pair i.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_len, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]  # (max_len, dim//2)
    return angles.cos(), angles.sin()
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate (even, odd) channel pairs of x by position-dependent angles.

    x: (B, n_heads, T, head_dim); cos/sin: (max_len, head_dim // 2).
    """
    head_dim = x.size(-1)
    assert head_dim % 2 == 0
    seq_len = x.size(2)
    # Broadcast the tables over batch and head dims: (1, 1, T, head_dim//2).
    c = cos[:seq_len][None, None, :, :]
    s = sin[:seq_len][None, None, :, :]
    even, odd = x[..., 0::2], x[..., 1::2]
    rotated_even = even * c - odd * s
    rotated_odd = even * s + odd * c
    # stack + flatten interleaves even/odd back into the original channel layout.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with RoPE applied to queries and keys."""

    def __init__(self, cfg: "Config"):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.head_dim
        inner = cfg.n_heads * cfg.head_dim
        self.wq = nn.Linear(cfg.d_model, inner, bias=False)
        self.wk = nn.Linear(cfg.d_model, inner, bias=False)
        self.wv = nn.Linear(cfg.d_model, inner, bias=False)
        self.wo = nn.Linear(inner, cfg.d_model, bias=False)
        # Lower-triangular causal mask, broadcastable over (batch, heads).
        causal = torch.tril(torch.ones(cfg.seq_len, cfg.seq_len))
        self.register_buffer("mask", causal.view(1, 1, cfg.seq_len, cfg.seq_len))

    def forward(self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor) -> torch.Tensor:
        batch, seq, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, H*Dh) -> (B, H, T, Dh)
            return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

        q = apply_rope(split_heads(self.wq(x)), rope_cos, rope_sin)
        k = apply_rope(split_heads(self.wk(x)), rope_cos, rope_sin)
        v = split_heads(self.wv(x))
        scores = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        scores = scores.masked_fill(self.mask[:, :, :seq, :seq] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.wo(merged)
# Expert FFN
class ExpertFFN(nn.Module):
    """Two-layer bias-free feed-forward expert: Linear -> tanh-approx GELU -> Linear."""

    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.gelu = nn.GELU(approximate='tanh')
        self.w2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.gelu(self.w1(x)))
class MoELayer(nn.Module):
    """
    Token-choice MoE with null experts (Section 4.1).

    The router emits N+1 logits (N real experts + 1 null). The single null
    logit is duplicated M times to form N+M total slots, and top-K runs over
    all N+M slots. A selection that lands on a null copy contributes zero
    output (and zero compute). Gate weights are renormalised over the
    selected *real* experts only; a token whose entire top-K went to null is a
    "zero-compute" token. A shared expert is additionally applied to every
    token.
    """

    def __init__(self, cfg: "Config"):
        super().__init__()
        self.cfg = cfg
        N = cfg.n_real_experts
        M = cfg.n_null_copies
        # Always-on shared expert (larger hidden size than routed experts).
        self.shared_expert = ExpertFFN(cfg.d_model, cfg.shared_expert_hidden)
        # Routed experts.
        self.experts = nn.ModuleList(
            [ExpertFFN(cfg.d_model, cfg.expert_hidden) for _ in range(N)]
        )
        # Router: N real logits + 1 shared null logit.
        self.router = nn.Linear(cfg.d_model, N + 1, bias=False)
        self.N = N
        self.M = M
        self.top_k = cfg.top_k
        # Telemetry accumulators (filled during forward, read externally).
        self.last_expert_counts: torch.Tensor | None = None
        self.last_null_ratio: float = 0.0
        self.last_gate_weights: torch.Tensor | None = None
        self.last_zero_compute_ratio: float = 0.0
        self.last_balance_loss: torch.Tensor | None = None
        self.last_z_loss: torch.Tensor | None = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, D) -> (B, T, D); also refreshes aux losses and telemetry."""
        B, T, D = x.shape
        x_flat = x.view(-1, D)                      # (tokens, D)
        num_tokens = x_flat.size(0)

        # --- Router logits: duplicate the null logit M times -> N+M slots ---
        logits_raw = self.router(x_flat)            # (tokens, N+1)
        real_logits = logits_raw[:, :self.N]        # (tokens, N)
        null_logit = logits_raw[:, self.N:]         # (tokens, 1)
        expanded_logits = torch.cat([real_logits, null_logit.expand(-1, self.M)], dim=-1)

        # --- Z-loss: squared log-sum-exp keeps router logits from growing ---
        lse = torch.logsumexp(expanded_logits, dim=-1)  # (tokens,)
        z_loss = (lse ** 2).mean()

        # --- Top-K over all N+M slots; softmax only over the selected slots ---
        topk_vals, topk_idxs = torch.topk(expanded_logits, self.top_k, dim=-1)
        topk_gates = F.softmax(topk_vals, dim=-1)   # (tokens, top_k)

        # Zero out the gates of null slots and renormalise over the selected
        # real experts. `has_real` forces rows whose entire top-K was null to
        # exactly 0 (the clamp already protects the division from 0/0).
        is_real = topk_idxs < self.N                # (tokens, top_k)
        real_gate_sum = (topk_gates * is_real.float()).sum(dim=-1, keepdim=True).clamp(min=1e-9)
        has_real = is_real.any(dim=-1, keepdim=True).float()
        renorm_gates = topk_gates * is_real.float() / real_gate_sum * has_real

        # --- Load-balancing loss (Eq. 6): NM * sum_i f_i * P_i ---
        #   f_i = empirical fraction of routing decisions that chose slot i
        #         (one-hot encode each top-k decision, then average)
        #   P_i = mean softmax probability the router assigns to slot i
        # The product is minimised when usage and probability mass are spread
        # evenly across slots, discouraging expert collapse.
        NM = self.N + self.M
        slot_mask = torch.zeros(num_tokens, self.top_k, NM, device=x.device)
        slot_mask.scatter_(2, topk_idxs.unsqueeze(-1), 1.0)  # one-hot per decision
        f = slot_mask.sum(dim=1).sum(dim=0) / (num_tokens * self.top_k)  # (NM,)
        probs_all = F.softmax(expanded_logits, dim=-1)       # (tokens, NM)
        P = probs_all.mean(dim=0)                            # (NM,)
        balance_loss = NM * (f * P).sum()

        # --- Dispatch tokens to their selected real experts ---
        combined_output = torch.zeros_like(x_flat)           # (tokens, D)
        # Per-expert counters for telemetry.
        expert_counts = torch.zeros(self.N, device=x.device)
        gate_weight_sums = torch.zeros(self.N, device=x.device)
        gate_weight_counts = torch.zeros(self.N, device=x.device)
        for k_idx in range(self.top_k):
            slot_ids = topk_idxs[:, k_idx]                   # (tokens,)
            gates = renorm_gates[:, k_idx]
            for e in range(self.N):
                mask = (slot_ids == e)
                if not mask.any():
                    continue
                token_subset = x_flat[mask]                  # (n_e, D)
                gate_subset = gates[mask].unsqueeze(-1)      # (n_e, 1)
                expert_out = self.experts[e](token_subset)
                combined_output[mask] += gate_subset * expert_out
                expert_counts[e] += mask.sum().float()
                gate_weight_sums[e] += gates[mask].sum()
                gate_weight_counts[e] += mask.sum().float()

        # Shared expert is applied to every token; null slots add nothing.
        output = self.shared_expert(x_flat) + combined_output

        # --- Telemetry (statistics only; keep out of the autograd graph) ---
        with torch.no_grad():
            total_selections = float(num_tokens * self.top_k)
            self.last_null_ratio = (~is_real).sum().item() / total_selections
            self.last_expert_counts = expert_counts.detach()
            avg_gates = torch.where(gate_weight_counts > 0,
                                    gate_weight_sums / gate_weight_counts,
                                    torch.zeros_like(gate_weight_sums))
            self.last_gate_weights = avg_gates.detach()
            # Zero-compute tokens: every top-K selection was a null slot.
            zero_compute = (~is_real).all(dim=-1).sum().float()
            self.last_zero_compute_ratio = (zero_compute / num_tokens).item()
        self.last_balance_loss = balance_loss
        self.last_z_loss = z_loss

        return output.view(B, T, D)
# ---------------------------------------------------------------------------
# Transformer Block
# ---------------------------------------------------------------------------
class TransformerBlock(nn.Module):
    """Pre-norm block: attention sublayer then MoE sublayer, each with a residual."""

    def __init__(self, cfg: "Config"):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model)
        self.attn = MultiHeadAttention(cfg)
        self.moe_norm = RMSNorm(cfg.d_model)
        self.moe = MoELayer(cfg)

    def forward(self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.attn_norm(x), rope_cos, rope_sin)
        x = x + attn_out
        moe_out = self.moe(self.moe_norm(x))
        return x + moe_out
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------
class MoENullModel(nn.Module):
    """Decoder-only LM whose FFN sublayers are MoE layers with null experts."""

    def __init__(self, cfg: "Config"):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the embedding matrix doubles as the output projection.
        self.lm_head.weight = self.tok_emb.weight
        # Precompute RoPE tables once; buffers so .to(device) moves them.
        rope_cos, rope_sin = precompute_rope(cfg.head_dim, cfg.seq_len)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)
        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases.

        NOTE: the tied lm_head/tok_emb tensor is visited twice (once per
        module); the second init overwrites the first — redundant but harmless.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, T) token ids → logits: (B, T, vocab_size)."""
        x = self.tok_emb(idx)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)
        x = self.final_norm(x)
        return self.lm_head(x)

    def get_aux_losses(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (balance_loss, z_loss) averaged over layers, from the latest forward."""
        device = next(self.parameters()).device
        balance_loss = torch.tensor(0.0, device=device)
        z_loss = torch.tensor(0.0, device=device)
        for block in self.blocks:
            if block.moe.last_balance_loss is not None:
                balance_loss = balance_loss + block.moe.last_balance_loss
            if block.moe.last_z_loss is not None:
                z_loss = z_loss + block.moe.last_z_loss
        balance_loss = balance_loss / self.cfg.n_layers
        z_loss = z_loss / self.cfg.n_layers
        return balance_loss, z_loss

    def get_telemetry(self) -> dict:
        """Gather routing telemetry from all MoE layers (latest forward)."""
        expert_counts = []
        null_ratios = []
        gate_weights = []
        zero_compute_ratios = []
        for block in self.blocks:
            moe = block.moe
            if moe.last_expert_counts is not None:
                expert_counts.append(moe.last_expert_counts.cpu().tolist())
                null_ratios.append(moe.last_null_ratio)
            if moe.last_gate_weights is not None:
                gate_weights.append(moe.last_gate_weights.cpu().tolist())
            zero_compute_ratios.append(moe.last_zero_compute_ratio)
        return {
            "per_layer_expert_counts": expert_counts,
            "avg_expert_counts": [sum(c) / len(c) for c in zip(*expert_counts)] if expert_counts else [],
            "null_ratio": sum(null_ratios) / len(null_ratios) if null_ratios else 0,
            "avg_gate_weights": [sum(g) / len(g) for g in zip(*gate_weights)] if gate_weights else [],
            "zero_compute_ratio": sum(zero_compute_ratios) / len(zero_compute_ratios) if zero_compute_ratios else 0,
        }

    @torch.no_grad()  # sampling only — without this every generated token grows an autograd graph
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.8) -> torch.Tensor:
        """Autoregressively sample max_new_tokens, cropping context to seq_len."""
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.cfg.seq_len:]
            logits = self(idx_crop)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class TextDataset(Dataset):
    """Sliding-window LM dataset: sample i is (tokens[i:i+L], tokens[i+1:i+L+1])."""

    def __init__(self, tokens: list[int], seq_len: int):
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self) -> int:
        n = len(self.tokens) - self.seq_len - 1
        return n if n > 0 else 0

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        window = self.tokens[idx : idx + self.seq_len + 1]
        # Input drops the window's last token; the target is shifted by one.
        return window[:-1], window[1:]
# ---------------------------------------------------------------------------
# Learning rate schedule
# ---------------------------------------------------------------------------
def get_lr(step: int, cfg: "Config") -> float:
    """Linear warmup followed by cosine decay, floored at 10% of the peak LR."""
    if step < cfg.warmup_steps:
        # Linear ramp from 0 to cfg.lr over the warmup window.
        return cfg.lr * step / cfg.warmup_steps
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cfg.lr * max(cosine, 0.1)
| # --------------------------------------------------------------------------- | |
| # Perplexity computation | |
| # --------------------------------------------------------------------------- | |
| def compute_perplexity(model: MoENullModel, tokens: list[int], cfg: Config, device: torch.device) -> float: | |
| model.eval() | |
| n_eval = min(len(tokens) - cfg.seq_len - 1, cfg.batch_size * 20) | |
| if n_eval <= 0: | |
| model.train() | |
| return float("inf") | |
| total_loss = 0.0 | |
| count = 0 | |
| start = len(tokens) - n_eval - cfg.seq_len - 1 | |
| for i in range(0, n_eval, cfg.seq_len): | |
| offset = start + i | |
| if offset + cfg.seq_len + 1 > len(tokens): | |
| break | |
| chunk = torch.tensor(tokens[offset : offset + cfg.seq_len + 1], dtype=torch.long, device=device) | |
| x, y = chunk[:-1].unsqueeze(0), chunk[1:].unsqueeze(0) | |
| logits = model(x) | |
| loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y.view(-1)) | |
| total_loss += loss.item() | |
| count += 1 | |
| model.train() | |
| if count == 0: | |
| return float("inf") | |
| return math.exp(total_loss / count) | |
| # --------------------------------------------------------------------------- | |
| # ASCII bar chart for console | |
| # --------------------------------------------------------------------------- | |
| def ascii_bar(values: list[float], width: int = 30, labels: list[str] | None = None) -> str: | |
| if not values: | |
| return "" | |
| max_val = max(values) if max(values) > 0 else 1.0 | |
| lines = [] | |
| for i, v in enumerate(values): | |
| bar_len = int(v / max_val * width) | |
| label = labels[i] if labels else f"E{i}" | |
| lines.append(f" {label:>4s} |{'█' * bar_len}{' ' * (width - bar_len)}| {v:.0f}") | |
| return "\n".join(lines) | |
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
def main():
    """Train the MoE-with-null-experts LM on a plain-text corpus.

    Workflow: parse CLI overrides -> pick device -> tokenize the dataset with
    the GPT-2 BPE -> train with AdamW and a warmup/cosine LR schedule, logging
    routing telemetry to telemetry.json and saving periodic checkpoints.
    """
    parser = argparse.ArgumentParser(description="Train MoE with Null Experts")
    parser.add_argument("--dataset", type=str, default="input.txt",
                        help="Path to text file or 'tiny' for tiny-shakespeare")
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--eval_interval", type=int, default=None)
    parser.add_argument("--log_interval", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default=".")
    args = parser.parse_args()

    # CLI flags override the Config defaults only when explicitly supplied.
    cfg = Config()
    if args.steps is not None:
        cfg.max_steps = args.steps
    if args.batch is not None:
        cfg.batch_size = args.batch
    if args.lr is not None:
        cfg.lr = args.lr
    if args.eval_interval is not None:
        cfg.eval_interval = args.eval_interval
    if args.log_interval is not None:
        cfg.log_interval = args.log_interval

    # --- Device: prefer CUDA, then Apple MPS, else CPU ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # --- Tokenizer: GPT-2 BPE (matches Config.vocab_size = 50257) ---
    enc = tiktoken.get_encoding("gpt2")

    # --- Data ---
    data_path = args.dataset
    if data_path == "tiny":
        # Download tiny-shakespeare into output_dir (cached across runs)
        import urllib.request
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        data_path = os.path.join(args.output_dir, "input.txt")
        if not os.path.exists(data_path):
            print("Downloading tiny-shakespeare...")
            urllib.request.urlretrieve(url, data_path)
    with open(data_path, "r", encoding="utf-8") as f:
        text = f.read()
    tokens = enc.encode(text)
    print(f"Dataset: {len(tokens):,} tokens")
    dataset = TextDataset(tokens, cfg.seq_len)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                            drop_last=True, num_workers=0)

    # --- Model ---
    model = MoENullModel(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    # Subtract tied weights (counted once in tok_emb, once in lm_head)
    n_params_unique = n_params - model.tok_emb.weight.numel()
    print(f"Model parameters: {n_params_unique:,} (unique, with weight tying)")
    print(f" Null copies M = {cfg.n_null_copies}, total routing slots = {cfg.n_total_slots}")
    # NOTE(review): assert is stripped under `python -O`; a raise would be sturdier.
    assert n_params_unique < 42_000_000, f"Model too large: {n_params_unique:,} params"

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr,
                                  betas=(0.9, 0.95), weight_decay=cfg.weight_decay)

    # --- Telemetry log (one dict per step, dumped to JSON) ---
    telemetry_path = os.path.join(args.output_dir, "telemetry.json")
    telemetry_log: list[dict] = []

    # --- Training ---
    model.train()
    data_iter = iter(dataloader)
    start_time = time.time()
    for step in range(1, cfg.max_steps + 1):
        # Get batch (restart the iterator when an epoch is exhausted)
        try:
            x_batch, y_batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            x_batch, y_batch = next(data_iter)
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # LR schedule (warmup + cosine; see get_lr)
        lr = get_lr(step, cfg)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Forward
        logits = model(x_batch)
        lm_loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y_batch.view(-1))
        # Auxiliary losses (layer-averaged by get_aux_losses; refreshed by the forward above)
        balance_loss, z_loss = model.get_aux_losses()
        total_loss = lm_loss + cfg.balance_loss_weight * balance_loss + cfg.z_loss_weight * z_loss

        # Backward
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        # --- Telemetry ---
        telem = model.get_telemetry()
        step_data = {
            "step": step,
            "total_loss": total_loss.item(),
            "lm_loss": lm_loss.item(),
            "balance_loss": balance_loss.item(),
            "z_loss": z_loss.item(),
            "lr": lr,
            "expert_counts": telem["avg_expert_counts"],
            "null_ratio": telem["null_ratio"],
            "gate_weights": telem["avg_gate_weights"],
            "zero_compute_ratio": telem["zero_compute_ratio"],
        }
        # Periodic perplexity + generation
        if step % cfg.eval_interval == 0 or step == 1:
            ppl = compute_perplexity(model, tokens, cfg, device)
            step_data["perplexity"] = ppl
            # Generate a sample continuation from the first 10 dataset tokens
            prompt_tokens = tokens[:10]
            prompt = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
            gen_ids = model.generate(prompt, max_new_tokens=cfg.gen_tokens)
            gen_text = enc.decode(gen_ids[0].cpu().tolist())
            step_data["generated_text"] = gen_text
        telemetry_log.append(step_data)

        # --- Console output ---
        if step % cfg.log_interval == 0 or step == 1:
            elapsed = time.time() - start_time
            tokens_per_sec = step * cfg.batch_size * cfg.seq_len / elapsed
            print(f"\n{'='*60}")
            print(f"Step {step}/{cfg.max_steps} | LR: {lr:.2e} | {tokens_per_sec:,.0f} tok/s")
            print(f" Loss: {total_loss.item():.4f} (LM: {lm_loss.item():.4f}, "
                  f"Bal: {balance_loss.item():.4f}, Z: {z_loss.item():.4f})")
            if "perplexity" in step_data:
                print(f" Perplexity: {step_data['perplexity']:.2f}")
            print(f" Null routing: {telem['null_ratio']*100:.1f}% | "
                  f"Zero-compute tokens: {telem['zero_compute_ratio']*100:.1f}%")
            if telem["avg_expert_counts"]:
                print(f" Expert utilization:")
                print(ascii_bar(telem["avg_expert_counts"],
                                labels=[f"E{i}" for i in range(cfg.n_real_experts)]))
            if "generated_text" in step_data:
                print(f" Generated: {step_data['generated_text'][:200]}")

        # --- Save checkpoint ---
        if step % cfg.save_interval == 0:
            ckpt_path = os.path.join(args.output_dir, f"checkpoint_{step}.pt")
            torch.save({
                "step": step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": cfg,
            }, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")

        # --- Save telemetry periodically (overwrite with full history) ---
        if step % cfg.log_interval == 0:
            with open(telemetry_path, "w") as f:
                json.dump(telemetry_log, f)

    # --- Final save ---
    with open(telemetry_path, "w") as f:
        json.dump(telemetry_log, f)
    print(f"\nTelemetry saved to {telemetry_path}")
    final_ckpt = os.path.join(args.output_dir, "checkpoint_final.pt")
    torch.save({
        "step": cfg.max_steps,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "config": cfg,
    }, final_ckpt)
    print(f"Final checkpoint saved to {final_ckpt}")


if __name__ == "__main__":
    main()