---
license: mit
tags:
- mixture-of-experts
- gpt2
- research
- expert-specialization
language:
- en
datasets:
- codeparrot/codeparrot-clean
- allenai/ai2_arc
- allenai/c4
base_model:
- openai-community/gpt2
library_name: transformers
pipeline_tag: text-generation
---
# MoE Emergence
Checkpoints from a research project studying expert specialization in Mixture-of-Experts models. I fine-tuned GPT-2 small on three domains -- code, math, and prose -- to see whether experts naturally specialize by domain when given the right routing incentives.
Short answer: they do. The MoE model beats the dense baseline by 3.6% on overall eval loss and 14% on math, with zero expert collapse across 10,000 training steps. Two ablations confirmed that the load-balancing loss is essential (without it, a single expert captures 73.6% of tokens by step 500) and that top-2 routing yields a negligible improvement over top-1.
---
## 1. Results
### Main comparison
| Metric | Dense Baseline | MoE (top-1) | Delta |
|---|---|---|---|
| eval/loss | 2.157 | 2.080 | -3.6% |
| loss_code | 1.554 | 1.521 | -2.1% |
| loss_math | 2.023 | 1.740 | -14.0% |
| loss_prose | 3.485 | 3.541 | +1.6% |
| perplexity | 8.64 | 7.91 | -8.4% |
Math benefits the most from expert routing. Prose is the one domain where dense wins; diverse web text doesn't lend itself to clean expert specialization. The MoE model's eval loss dropped below the dense baseline's at step ~3,600 (36% of training).
### Ablations
| Run | What it tests | Result |
|---|---|---|
| No-LB ablation | Remove load balancing loss (`lb_coef=0.0`) | Expert collapse at step 500. Single expert handles 73.6% of tokens. Z-loss alone doesn't prevent it. |
| Top-2 directional | Route to 2 experts instead of 1 | eval/loss = 2.077 vs. 2.080 for top-1, a 0.14% improvement. Not worth 2× the expert compute. |
---
## 2. Files
```
dense-baseline/
├── final-model.safetensors   # 622 MB -- dense GPT-2, 124M params
├── final-model.json          # metadata sidecar (config, metrics)
├── ckpt-step-4999.pt         # 1.4 GB -- full resume checkpoint
└── metrics.jsonl             # per-step training + eval metrics
moe-main/
├── final-model.safetensors   # 1.1 GB -- MoE GPT-2, 257M params (8 experts × 4 layers)
├── final-model.json          # metadata sidecar
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
└── metrics.jsonl
no-lb-ablation/
├── final-model.safetensors   # 1.1 GB -- collapsed MoE model at step 500
├── best-model.safetensors    # 1.1 GB -- best eval loss (step 400, pre-collapse)
├── ckpt-step-500.pt          # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl
top2-main-10k/
├── final-model.safetensors   # 1.2 GB -- top-2 MoE model at step 9999
├── best-model.safetensors    # 1.2 GB -- best eval loss (step 8000)
├── ckpt-step-9999.pt         # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl
```
The `.safetensors` files are the trained model weights. The `.pt` files contain the full training state for resuming runs (optimizer, LR scheduler, RNG states). The `.json` sidecars store architecture config and final eval metrics.
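As an illustration of what resuming from a `.pt` file involves, here is a minimal sketch of saving and restoring a full training state with `torch.save`/`torch.load`. The key names (`model`, `optimizer`, `step`, `rng_state`) are assumptions for this sketch; the actual schema of `ckpt-step-*.pt` may differ.

```python
import io

import torch

# Hypothetical resume-checkpoint layout -- the real key names in
# ckpt-step-*.pt may differ from this sketch.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

ckpt = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 9999,
    "rng_state": torch.get_rng_state(),  # for reproducible data order
}

# Round-trip through an in-memory buffer instead of a real .pt file.
buf = io.BytesIO()
torch.save(ckpt, buf)
buf.seek(0)
restored = torch.load(buf, weights_only=False)

model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
torch.set_rng_state(restored["rng_state"])
```

Carrying the optimizer and RNG state alongside the weights is what makes resuming bit-for-bit continuous rather than a warm restart.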
---
## 3. Usage
Clone the [source repo](https://github.com/sumitdotml/moe-emergence) and install dependencies:
```bash
git clone https://github.com/sumitdotml/moe-emergence.git
cd moe-emergence
uv sync
```
Run inference with a trained checkpoint:
```bash
# MoE model
uv run python moe_emergence/gpt2_inference.py \
  --checkpoint checkpoints/moe-main/final-model \
  --prompt "def fibonacci(n):" \
  --sample --temperature 0.8

# Dense baseline
uv run python moe_emergence/gpt2_inference.py \
  --checkpoint checkpoints/dense-baseline/final-model \
  --prompt "The meaning of life is"
```
The inference script reads the `.json` sidecar to detect mode (dense vs MoE) and architecture config automatically.
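A toy sketch of what such sidecar-based detection might look like; the field name `n_experts` is hypothetical, not the actual schema of `final-model.json`:

```python
import json
import tempfile
from pathlib import Path

# Write a toy sidecar; "n_experts" is an assumed field name, not
# necessarily what final-model.json actually contains.
sidecar_path = Path(tempfile.mkdtemp()) / "final-model.json"
sidecar_path.write_text(json.dumps({"n_experts": 8, "moe_layers": [8, 9, 10, 11]}))

# Detect mode from the sidecar: any expert count above 1 means MoE.
cfg = json.loads(sidecar_path.read_text())
mode = "moe" if cfg.get("n_experts", 0) > 1 else "dense"
print(mode)  # moe
```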
To resume training from a checkpoint:
```bash
uv run python -m moe_emergence.train \
  --preset moe-main --run-name moe-main \
  --device cuda \
  --resume checkpoints/moe-main/ckpt-step-9999.pt
```
---
## 4. Architecture
The dense baseline is standard GPT-2 small (124M parameters, 12 transformer layers).
The MoE model takes GPT-2 small and replaces layers 8-11 with MoE layers. Each MoE layer has 8 experts -- deep copies of the original GPT-2 MLP, warm-started from pretrained weights -- and a learned router with top-1 routing. Total: 257M parameters.
Routing uses the straight-through estimator (STE): the forward pass sends each token to a single expert with weight 1.0, while the backward pass flows gradients through the router's soft probability.
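A minimal sketch of STE top-1 routing under those assumptions (the actual router implementation lives in the source repo and may differ in detail):

```python
import torch
import torch.nn.functional as F

def ste_top1_route(router_logits: torch.Tensor) -> torch.Tensor:
    """Return one-hot routing weights with straight-through gradients."""
    probs = F.softmax(router_logits, dim=-1)  # (tokens, n_experts)
    hard = F.one_hot(probs.argmax(dim=-1), router_logits.size(-1)).float()
    # Forward value is exactly `hard` (probs - probs.detach() == 0 in value);
    # backward flows gradients through the soft `probs`.
    return hard + probs - probs.detach()

weights = ste_top1_route(torch.randn(4, 8))  # 4 tokens, 8 experts
```

The `hard + probs - probs.detach()` trick is the standard way to get a discrete forward pass with a differentiable backward pass in PyTorch.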
| Component | Detail |
|---|---|
| Base model | GPT-2 small (124M) |
| MoE layers | 8, 9, 10, 11 |
| Experts per layer | 8 |
| Routing | Top-1, STE |
| Expert init | `deepcopy(original_mlp)` + tiny noise |
| Load balance loss | `0.01 × n_experts × Σ(f_i × P_i)` |
| Z-loss | `0.001 × mean(logsumexp(logits)²)` |
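The two auxiliary losses in the table can be sketched directly from their formulas, where `f_i` is the fraction of tokens routed to expert `i` and `P_i` is the mean router probability on expert `i`. This is a paraphrase of the standard Switch-style load-balancing loss, not the repo's exact code:

```python
import torch
import torch.nn.functional as F

def moe_aux_losses(router_logits, lb_coef=0.01, z_coef=0.001):
    n_experts = router_logits.size(-1)
    probs = F.softmax(router_logits, dim=-1)  # (tokens, n_experts)
    # f_i: fraction of tokens whose top-1 choice is expert i
    f = F.one_hot(probs.argmax(dim=-1), n_experts).float().mean(dim=0)
    # P_i: mean router probability assigned to expert i
    P = probs.mean(dim=0)
    lb_loss = lb_coef * n_experts * (f * P).sum()
    # Z-loss keeps router logits from growing without bound
    z_loss = z_coef * torch.logsumexp(router_logits, dim=-1).pow(2).mean()
    return lb_loss, z_loss

lb, z = moe_aux_losses(torch.randn(64, 8))
```

With perfectly uniform routing, `Σ(f_i × P_i)` equals `1/n_experts`, so the load-balance term bottoms out at `lb_coef`; full collapse onto one expert pushes it toward `lb_coef × n_experts`.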
---
## 5. Training
All models were trained on ~6.6M tokens across three domains, balanced to equal token counts:
| Domain | Source | Size |
|---|---|---|
| Code | CodeParrot-clean (Python) | 10 MB |
| Math | MathQA (allenai) | 10 MB |
| Prose | C4 English (allenai) | 10 MB |
Training config:
| Parameter | Dense | MoE (top-1) | MoE (top-2) | No-LB |
|---|---|---|---|---|
| Max steps | 5,000 | 10,000 | 10,000 | 2,000 (early-stopped at 500) |
| Batch size | 8 | 8 | 8 | 8 |
| Block size | 512 | 512 | 512 | 512 |
| Learning rate | 5e-5 | 5e-5 | 5e-5 | 5e-5 |
| lb_coef | β€” | 0.01 | 0.01 | 0.0 |
| noise_std | β€” | 0.1 | 0.1 | 0.0 |
| Hardware | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 |
| Wall time | ~30 min | ~85 min | ~48 min | ~5 min |
Total GPU cost for all 4 runs: ~$2.79 (including setup/idle overhead).
---
## 6. W&B
Training curves are on Weights & Biases:
- [Dense baseline](https://wandb.ai/sumit-ml/moe-emergence/runs/fqhfblfv)
- [MoE main run](https://wandb.ai/sumit-ml/moe-emergence/runs/j08s2d1m)
- [No-LB ablation](https://wandb.ai/sumit-ml/moe-emergence/runs/06pljhrv)
- [Top-2 directional](https://wandb.ai/sumit-ml/moe-emergence/runs/6mw6qbac)
---
## 7. Links
- Code: [github.com/sumitdotml/moe-emergence](https://github.com/sumitdotml/moe-emergence)
- Experiment docs: [run-004 (dense)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-004-dense-baseline.md), [run-005 (MoE)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-005-moe-main.md), [run-006 (no-LB)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-006-no-lb-ablation.md), [run-007 (top-2)](https://github.com/sumitdotml/moe-emergence/blob/main/docs/experiments/run-007-top2-directional.md)
---
## License
MIT. See the [source repo](https://github.com/sumitdotml/moe-emergence/blob/main/LICENSE) for details. Third-party dataset licenses are documented in [THIRD-PARTY-NOTICES.md](https://github.com/sumitdotml/moe-emergence/blob/main/THIRD-PARTY-NOTICES.md).