---
license: mit
tags:
- mixture-of-experts
- gpt2
- research
- expert-specialization
language:
- en
datasets:
- codeparrot/codeparrot-clean
- allenai/ai2_arc
- allenai/c4
base_model:
- openai-community/gpt2
library_name: transformers
pipeline_tag: text-generation
---

# MoE Emergence

Checkpoints from a research project studying expert specialization in Mixture-of-Experts models. I fine-tuned GPT-2 small on three domains -- code, math, and prose -- to see whether experts naturally specialize by domain when given the right routing incentives.

Short answer: they do. MoE beats the dense baseline by 3.6% overall and 14% on math, with zero expert collapse across 10,000 training steps. Two ablation runs confirmed that the load-balancing loss is essential (without it, one expert captures 73.6% of tokens by step 500) and that top-2 routing provides negligible improvement over top-1.

---

## 1. Results

### Main comparison

| Metric | Dense Baseline | MoE (top-1) | Delta |
|---|---|---|---|
| eval/loss | 2.157 | 2.080 | -3.6% |
| loss_code | 1.554 | 1.521 | -2.1% |
| loss_math | 2.023 | 1.740 | -14.0% |
| loss_prose | 3.485 | 3.541 | +1.6% |
| perplexity | 8.64 | 7.91 | -8.4% |

Math benefits the most from expert routing. Prose is the one domain where the dense model wins; diverse web text doesn't lend itself to clean expert specialization. The MoE model crossed the dense baseline at step ~3,600 (36% of training).

### Ablations

| Run | What it tests | Result |
|---|---|---|
| No-LB ablation | Remove load-balancing loss (`lb_coef=0.0`) | Expert collapse at step 500: a single expert handles 73.6% of tokens. Z-loss alone doesn't prevent it. |
| Top-2 directional | Route to 2 experts instead of 1 | eval/loss = 2.077 vs. top-1's 2.080, a 0.14% difference. Not worth 2x expert compute. |

---
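The collapse dynamic in the No-LB ablation is easiest to see from the auxiliary loss itself (the exact formula appears in §4). Below is a minimal, plain-Python sketch of the Switch-style load-balancing term; the function and variable names are illustrative, not the repo's actual API:

```python
# Sketch of the load-balancing loss: lb_coef * n_experts * sum_i(f_i * P_i),
# where f_i = fraction of tokens routed to expert i and P_i = mean router
# probability assigned to expert i. Names here are illustrative only.

def load_balance_loss(assignments, probs, n_experts, lb_coef=0.01):
    """assignments: chosen expert id per token; probs: per-token router
    probability vectors (length n_experts each)."""
    n_tokens = len(assignments)
    f = [assignments.count(i) / n_tokens for i in range(n_experts)]
    P = [sum(p[i] for p in probs) / n_tokens for i in range(n_experts)]
    return lb_coef * n_experts * sum(fi * Pi for fi, Pi in zip(f, P))

# Balanced routing: 8 experts each get 1/8 of 64 tokens with uniform probs.
uniform = [i % 8 for i in range(64)]
u_probs = [[1 / 8] * 8 for _ in range(64)]

# Collapsed routing: expert 0 takes every token with near-certain probs.
collapsed = [0] * 64
c_probs = [[0.93] + [0.01] * 7 for _ in range(64)]

balanced_loss = load_balance_loss(uniform, u_probs, 8)    # 0.01
collapsed_loss = load_balance_loss(collapsed, c_probs, 8) # 0.0744
```

The loss is minimized by a uniform assignment, so with `lb_coef=0.01` the router is pushed away from the single-expert solution; with `lb_coef=0.0` nothing opposes collapse.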
## 2. Files

```
dense-baseline/
├── final-model.safetensors   # 622 MB -- dense GPT-2, 124M params
├── final-model.json          # metadata sidecar (config, metrics)
├── ckpt-step-4999.pt         # 1.4 GB -- full resume checkpoint
└── metrics.jsonl             # per-step training + eval metrics

moe-main/
├── final-model.safetensors   # 1.1 GB -- MoE GPT-2, 257M params (8 experts × 4 layers)
├── final-model.json          # metadata sidecar
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
└── metrics.jsonl

no-lb-ablation/
├── final-model.safetensors   # 1.1 GB -- collapsed MoE model at step 500
├── best-model.safetensors    # 1.1 GB -- best eval loss (step 400, pre-collapse)
├── ckpt-step-500.pt          # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl

top2-main-10k/
├── final-model.safetensors   # 1.2 GB -- top-2 MoE model at step 9999
├── best-model.safetensors    # 1.2 GB -- best eval loss (step 8000)
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl
```

The `.safetensors` files are the trained model weights. The `.pt` files contain the full training state for resuming runs (optimizer, LR scheduler, RNG states). The `.json` sidecars store the architecture config and final eval metrics.

---

## 3. Usage

Clone the [source repo](https://github.com/sumitdotml/moe-emergence) and install dependencies:

```bash
git clone https://github.com/sumitdotml/moe-emergence.git
cd moe-emergence
uv sync
```

Run inference with a trained checkpoint:

```bash
# MoE model
uv run python moe_emergence/gpt2_inference.py \
    --checkpoint checkpoints/moe-main/final-model \
    --prompt "def fibonacci(n):" \
    --sample --temperature 0.8

# Dense baseline
uv run python moe_emergence/gpt2_inference.py \
    --checkpoint checkpoints/dense-baseline/final-model \
    --prompt "The meaning of life is"
```

The inference script reads the `.json` sidecar to detect the mode (dense vs. MoE) and architecture config automatically.
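Sidecar-based dispatch of this kind can be sketched in a few lines. This is a hypothetical illustration only: the key names (`n_experts`, `moe_layers`) are assumptions for the sketch, not the actual schema the training scripts write:

```python
# Hedged sketch of sidecar-driven model dispatch. The JSON schema shown
# here is an assumption for illustration, not the repo's real format.
import json
import tempfile
from pathlib import Path

def load_sidecar(checkpoint_prefix):
    """Given e.g. 'checkpoints/moe-main/final-model', read the matching
    'final-model.json' and decide which architecture to build before
    loading the .safetensors weights."""
    cfg = json.loads(Path(f"{checkpoint_prefix}.json").read_text())
    mode = "moe" if cfg.get("n_experts", 0) > 1 else "dense"
    return mode, cfg

# Demo with a fake sidecar written to a temp directory:
prefix = Path(tempfile.mkdtemp()) / "final-model"
Path(f"{prefix}.json").write_text(
    json.dumps({"n_experts": 8, "moe_layers": [8, 9, 10, 11]})
)
mode, cfg = load_sidecar(prefix)  # mode == "moe"
```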
To resume training from a checkpoint:

```bash
uv run python -m moe_emergence.train \
    --preset moe-main --run-name moe-main \
    --device cuda \
    --resume checkpoints/moe-main/ckpt-step-9999.pt
```

---

## 4. Architecture

The dense baseline is standard GPT-2 small (124M parameters, 12 transformer layers). The MoE model takes GPT-2 small and replaces layers 8-11 with MoE layers. Each MoE layer has 8 experts -- deep copies of the original GPT-2 MLP, warm-started from pretrained weights -- and a learned router with top-1 routing. Total: 257M parameters.

Routing uses the Straight-Through Estimator (STE): the forward pass routes to one expert with weight 1.0, while the backward pass flows gradients through the soft probability from the router.

| Component | Detail |
|---|---|
| Base model | GPT-2 small (124M) |
| MoE layers | 8, 9, 10, 11 |
| Experts per layer | 8 |
| Routing | Top-1, STE |
| Expert init | `deepcopy(original_mlp)` + tiny noise |
| Load balance loss | `0.01 × n_experts × Σ(f_i × P_i)` |
| Z-loss | `0.001 × mean(logsumexp(logits)²)` |

---

## 5. Training

All models trained on ~6.6M tokens across three domains, balanced to equal token counts:

| Domain | Source | Size |
|---|---|---|
| Code | CodeParrot-clean (Python) | 10 MB |
| Math | MathQA (allenai) | 10 MB |
| Prose | C4 English (allenai) | 10 MB |

Training config:

| Parameter | Dense | MoE (top-1) | MoE (top-2) | No-LB |
|---|---|---|---|---|
| Max steps | 5,000 | 10,000 | 10,000 | 2,000 (early-stopped at 500) |
| Batch size | 8 | 8 | 8 | 8 |
| Block size | 512 | 512 | 512 | 512 |
| Learning rate | 5e-5 | 5e-5 | 5e-5 | 5e-5 |
| lb_coef | — | 0.01 | 0.01 | 0.0 |
| noise_std | — | 0.1 | 0.1 | 0.0 |
| Hardware | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 |
| Wall time | ~30 min | ~85 min | ~48 min | ~5 min |

Total GPU cost for all 4 runs: ~$2.79 (including setup/idle overhead).

---
## 6. W&B

Training curves are on Weights & Biases:

- [Dense baseline](https://wandb.ai/sumit-ml/moe-emergence/runs/fqhfblfv)
- [MoE main run](https://wandb.ai/sumit-ml/moe-emergence/runs/j08s2d1m)
- [No-LB ablation](https://wandb.ai/sumit-ml/moe-emergence/runs/06pljhrv)
- [Top-2 directional](https://wandb.ai/sumit-ml/moe-emergence/runs/6mw6qbac)

---

## 7. Links

- Code: [github.com/sumitdotml/moe-emergence](https://github.com/sumitdotml/moe-emergence)
- Experiment docs: [run-004 (dense)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-004-dense-baseline.md), [run-005 (MoE)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-005-moe-main.md), [run-006 (no-LB)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-006-no-lb-ablation.md), [run-007 (top-2)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-007-top2-directional.md)

---

## License

MIT. See the [source repo](https://github.com/sumitdotml/moe-emergence/blob/main/LICENSE) for details. Third-party dataset licenses are documented in [THIRD-PARTY-NOTICES.md](https://github.com/sumitdotml/moe-emergence/blob/main/THIRD-PARTY-NOTICES.md).