---
license: mit
tags:
- mixture-of-experts
- gpt2
- research
- expert-specialization
language:
- en
datasets:
- codeparrot/codeparrot-clean
- allenai/ai2_arc
- allenai/c4
base_model:
- openai-community/gpt2
library_name: transformers
pipeline_tag: text-generation
---
# MoE Emergence
Checkpoints from a research project studying expert specialization in Mixture-of-Experts models. I fine-tuned GPT-2 small on three domains -- code, math, and prose -- to see whether experts naturally specialize by domain when given the right routing incentives.
Short answer: they do. MoE beats the dense baseline by 3.6% overall and 14% on math, with zero expert collapse across 10,000 training steps. Two ablation runs confirmed that load balancing loss is essential (without it, one expert captures 73.6% of tokens by step 500) and that top-2 routing provides negligible improvement over top-1.
---
## 1. Results
### Main comparison
| Metric | Dense Baseline | MoE (top-1) | Delta |
|---|---|---|---|
| eval/loss | 2.157 | 2.080 | -3.6% |
| loss_code | 1.554 | 1.521 | -2.1% |
| loss_math | 2.023 | 1.740 | -14.0% |
| loss_prose | 3.485 | 3.541 | +1.6% |
| perplexity | 8.64 | 7.91 | -8.4% |
Math benefits the most from expert routing. Prose is the one domain where dense wins; diverse web text doesn't lend itself to clean expert specialization. The MoE model crossed the dense baseline at step ~3,600 (36% of training).
### Ablations
| Run | What it tests | Result |
|---|---|---|
| No-LB ablation | Remove load balancing loss (`lb_coef=0.0`) | Expert collapse at step 500. Single expert handles 73.6% of tokens. Z-loss alone doesn't prevent it. |
| Top-2 directional | Route to 2 experts instead of 1 | eval/loss = 2.077 vs. 2.080 for top-1, a 0.14% difference. Not worth 2× expert compute. |
---
## 2. Files
```
dense-baseline/
├── final-model.safetensors   # 622 MB -- dense GPT-2, 124M params
├── final-model.json          # metadata sidecar (config, metrics)
├── ckpt-step-4999.pt         # 1.4 GB -- full resume checkpoint
└── metrics.jsonl             # per-step training + eval metrics
moe-main/
├── final-model.safetensors   # 1.1 GB -- MoE GPT-2, 257M params (8 experts × 4 layers)
├── final-model.json          # metadata sidecar
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
└── metrics.jsonl
no-lb-ablation/
├── final-model.safetensors   # 1.1 GB -- collapsed MoE model at step 500
├── best-model.safetensors    # 1.1 GB -- best eval loss (step 400, pre-collapse)
├── ckpt-step-500.pt          # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl
top2-main-10k/
├── final-model.safetensors   # 1.2 GB -- top-2 MoE model at step 9999
├── best-model.safetensors    # 1.2 GB -- best eval loss (step 8000)
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl
```
The `.safetensors` files are the trained model weights. The `.pt` files contain the full training state for resuming runs (optimizer, LR scheduler, RNG states). The `.json` sidecars store architecture config and final eval metrics.
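As a minimal loading sketch (the sidecar field names below are illustrative assumptions; the real schema lives in the repo's `gpt2_inference.py`):

```python
import json
from pathlib import Path

def read_sidecar(model_path: str) -> dict:
    """Read the .json sidecar that sits next to a .safetensors file.

    E.g. checkpoints/moe-main/final-model.safetensors ->
         checkpoints/moe-main/final-model.json
    Keys like "n_experts" are hypothetical; check the repo for the real ones.
    """
    sidecar = Path(model_path).with_suffix(".json")
    return json.loads(sidecar.read_text())

# The weights themselves would load with the safetensors library, e.g.:
#   from safetensors.torch import load_file
#   state_dict = load_file("checkpoints/moe-main/final-model.safetensors")
```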
---
## 3. Usage
Clone the [source repo](https://github.com/sumitdotml/moe-emergence) and install dependencies:
```bash
git clone https://github.com/sumitdotml/moe-emergence.git
cd moe-emergence
uv sync
```
Run inference with a trained checkpoint:
```bash
# MoE model
uv run python moe_emergence/gpt2_inference.py \
--checkpoint checkpoints/moe-main/final-model \
--prompt "def fibonacci(n):" \
--sample --temperature 0.8
# Dense baseline
uv run python moe_emergence/gpt2_inference.py \
--checkpoint checkpoints/dense-baseline/final-model \
--prompt "The meaning of life is"
```
The inference script reads the `.json` sidecar to detect mode (dense vs MoE) and architecture config automatically.
To resume training from a checkpoint:
```bash
uv run python -m moe_emergence.train \
--preset moe-main --run-name moe-main \
--device cuda \
--resume checkpoints/moe-main/ckpt-step-9999.pt
```
---
## 4. Architecture
The dense baseline is standard GPT-2 small (124M parameters, 12 transformer layers).
The MoE model takes GPT-2 small and replaces layers 8-11 with MoE layers. Each MoE layer has 8 experts -- deep copies of the original GPT-2 MLP, warm-started from pretrained weights -- and a learned router with top-1 routing. Total: 257M parameters.
Routing uses the Straight-Through Estimator. Forward pass routes to one expert with weight 1.0, backward pass flows gradients through the soft probability from the router.
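The STE trick above can be sketched in a few lines of PyTorch (a minimal illustration, not the repo's actual router, which also adds noise and auxiliary losses):

```python
import torch
import torch.nn.functional as F

def ste_top1_route(logits: torch.Tensor) -> torch.Tensor:
    """Top-1 routing weights with a Straight-Through Estimator.

    Forward pass: a one-hot mask, so the chosen expert gets weight 1.0.
    Backward pass: gradients flow through the soft softmax probabilities,
    because (hard - soft).detach() + soft equals hard in value but
    differentiates like soft.
    """
    probs = F.softmax(logits, dim=-1)  # soft router probabilities
    hard = F.one_hot(probs.argmax(dim=-1), logits.size(-1)).float()
    return (hard - probs).detach() + probs

logits = torch.randn(4, 8, requires_grad=True)  # 4 tokens, 8 experts
weights = ste_top1_route(logits)
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))  # one expert per token
```

The returned weights multiply the selected expert's output, so the router logits still receive a useful gradient even though `argmax` itself is non-differentiable.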
| Component | Detail |
|---|---|
| Base model | GPT-2 small (124M) |
| MoE layers | 8, 9, 10, 11 |
| Experts per layer | 8 |
| Routing | Top-1, STE |
| Expert init | `deepcopy(original_mlp)` + tiny noise |
| Load balance loss | `0.01 × n_experts × Σ(f_i × P_i)` |
| Z-loss | `0.001 × mean(logsumexp(logits)²)` |
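The two auxiliary losses in the table can be sketched as follows (the formulas match the table; the exact reduction in the training code may differ):

```python
import torch
import torch.nn.functional as F

def aux_losses(logits: torch.Tensor, lb_coef: float = 0.01, z_coef: float = 0.001):
    """Switch-style auxiliary losses for a top-1 router.

    logits: [tokens, n_experts] router logits for one MoE layer.
    """
    n_experts = logits.size(-1)
    probs = F.softmax(logits, dim=-1)
    # f_i: fraction of tokens whose top-1 choice is expert i
    f = F.one_hot(probs.argmax(dim=-1), n_experts).float().mean(dim=0)
    # P_i: mean router probability assigned to expert i
    P = probs.mean(dim=0)
    lb_loss = lb_coef * n_experts * torch.sum(f * P)
    z_loss = z_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()
    return lb_loss, z_loss
```

With perfectly balanced routing (`f_i = P_i = 1/n_experts`), the load balance term bottoms out at `lb_coef`; a collapsed router pushes it toward `lb_coef × n_experts`, which is what creates the pressure against the single-expert collapse seen in the no-LB ablation.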
---
## 5. Training
All models trained on ~6.6M tokens across three domains, balanced to equal token counts:
| Domain | Source | Size |
|---|---|---|
| Code | CodeParrot-clean (Python) | 10 MB |
| Math | MathQA (allenai) | 10 MB |
| Prose | C4 English (allenai) | 10 MB |
Training config:
| Parameter | Dense | MoE (top-1) | MoE (top-2) | No-LB |
|---|---|---|---|---|
| Max steps | 5,000 | 10,000 | 10,000 | 2,000 (early-stopped at 500) |
| Batch size | 8 | 8 | 8 | 8 |
| Block size | 512 | 512 | 512 | 512 |
| Learning rate | 5e-5 | 5e-5 | 5e-5 | 5e-5 |
| lb_coef | β | 0.01 | 0.01 | 0.0 |
| noise_std | β | 0.1 | 0.1 | 0.0 |
| Hardware | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 |
| Wall time | ~30 min | ~85 min | ~48 min | ~5 min |
Total GPU cost for all 4 runs: ~$2.79 (including setup/idle overhead).
---
## 6. W&B
Training curves are on Weights & Biases:
- [Dense baseline](https://wandb.ai/sumit-ml/moe-emergence/runs/fqhfblfv)
- [MoE main run](https://wandb.ai/sumit-ml/moe-emergence/runs/j08s2d1m)
- [No-LB ablation](https://wandb.ai/sumit-ml/moe-emergence/runs/06pljhrv)
- [Top-2 directional](https://wandb.ai/sumit-ml/moe-emergence/runs/6mw6qbac)
---
## 7. Links
- Code: [github.com/sumitdotml/moe-emergence](https://github.com/sumitdotml/moe-emergence)
- Experiment docs: [run-004 (dense)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-004-dense-baseline.md), [run-005 (MoE)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-005-moe-main.md), [run-006 (no-LB)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-006-no-lb-ablation.md), [run-007 (top-2)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-007-top2-directional.md)
---
## License
MIT. See the [source repo](https://github.com/sumitdotml/moe-emergence/blob/main/LICENSE) for details. Third-party dataset licenses are documented in [THIRD-PARTY-NOTICES.md](https://github.com/sumitdotml/moe-emergence/blob/main/THIRD-PARTY-NOTICES.md).