---
|
|
tags: |
|
|
- mixture-of-experts |
|
|
- moe |
|
|
- transformer |
|
|
- language-model |
|
|
- pytorch |
|
|
- conditional-computation |
|
|
datasets: |
|
|
- custom |
|
|
pipeline_tag: text-generation |
|
|
license: mit |
|
|
--- |
|
|
|
|
|
# Mixture-of-Experts Language Models |
|
|
|
|
|
A PyTorch implementation exploring conditional computation in Transformers through Mixture-of-Experts (MoE). |
|
|
|
|
|
## Models |
|
|
|
|
|
This repository contains two MoE architectures: |
|
|
|
|
|
### 1. Sparse MoE (Top-K Routing) |
|
|
Routes each token to a fixed number of experts (k=2), increasing model capacity without proportionally increasing compute. |
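
A minimal, illustrative sketch of top-k gating (not the repository's exact implementation): the router scores every expert, keeps the two highest-scoring experts per token, and renormalizes their gate weights.

```python
import torch
import torch.nn.functional as F

def top_k_route(router_logits: torch.Tensor, k: int = 2):
    """Pick k experts per token and renormalize their gate weights.

    router_logits: (num_tokens, num_experts)
    returns: gate weights and expert indices, both (num_tokens, k)
    """
    probs = F.softmax(router_logits, dim=-1)
    gate_weights, expert_ids = probs.topk(k, dim=-1)
    gate_weights = gate_weights / gate_weights.sum(dim=-1, keepdim=True)
    return gate_weights, expert_ids
```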
|
|
|
|
|
### 2. Dynamic MoE (Confidence-Based Routing) |
|
|
Dynamically adjusts the number of experts per token based on routing confidence: "easy" tokens use fewer experts, "hard" tokens use more.
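
One plausible reading of confidence-based routing, following the referenced paper (the repository's exact rule may differ): sort the router distribution and activate the smallest set of experts whose cumulative probability reaches the threshold τ.

```python
import torch
import torch.nn.functional as F

def dynamic_route(router_logits: torch.Tensor, tau: float = 0.8):
    """Activate the smallest expert set whose cumulative router
    probability reaches tau; confident tokens stop after one expert.

    router_logits: (num_tokens, num_experts)
    returns: masked gate weights and a boolean expert mask
    """
    probs = F.softmax(router_logits, dim=-1)
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    # Keep an expert while the probability mass *before* it is < tau,
    # so the top expert is always kept and selection stops once tau is met.
    below_tau = (sorted_probs.cumsum(dim=-1) - sorted_probs) < tau
    mask = torch.zeros_like(probs).scatter(-1, order, below_tau.float()).bool()
    return probs * mask, mask
```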
|
|
|
|
|
## Model Details |
|
|
|
|
|
| Parameter | Sparse MoE | Dynamic MoE | |
|
|
|-----------|------------|-------------| |
|
|
| Layers | 4 | 4 | |
|
|
| Hidden Dim | 512 | 512 | |
|
|
| FFN Dim | 2048 | 2048 | |
|
|
| Attention Heads | 8 | 8 | |
|
|
| Experts | 8 | 4 | |
|
|
| Routing | Top-2 | τ = 0.8 threshold |
|
|
| Context Length | 256 | 256 | |
|
|
| Vocab Size | 10,000 | 10,000 | |
|
|
|
|
|
## Architecture |
|
|
|
|
|
```
Input → Embedding → [Transformer Block × N] → RMSNorm → Linear → Output

Transformer Block:
├─ RMSNorm → Multi-Head Self-Attention → Residual
└─ RMSNorm → MoE Layer → Residual

MoE Layer:
├─ Router (softmax gating)
├─ Expert Selection (Top-K or Dynamic)
└─ Weighted Expert Outputs
```
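
As a sketch of how the block wires together (assuming PyTorch ≥ 2.4 for `nn.RMSNorm`; `moe_layer` stands in for either routing variant, and the causal attention mask is omitted for brevity):

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm block matching the diagram above (illustrative)."""
    def __init__(self, d_model: int, num_heads: int, moe_layer: nn.Module):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.RMSNorm(d_model)
        self.moe = moe_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # attention residual
        x = x + self.moe(self.norm2(x))                    # MoE residual
        return x
```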
|
|
|
|
|
## Training |
|
|
|
|
|
Both models were trained with the following settings (sketched in code after the list):
|
|
- **Optimizer**: AdamW (β₁ = 0.9, β₂ = 0.95)
|
|
- **Learning Rate**: 3e-4 with cosine decay |
|
|
- **Warmup Steps**: 2,000 |
|
|
- **Weight Decay**: 0.1 |
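
A minimal sketch of that setup using PyTorch's built-in `LambdaLR` scheduler; the total step count is illustrative, and `model` is a stand-in for either network:

```python
import math
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the actual MoE LM

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)

def lr_lambda(step: int) -> float:
    """Linear warmup for 2,000 steps, then cosine decay to zero."""
    warmup, total = 2_000, 50_000  # total step count is illustrative
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```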
|
|
|
|
|
### Loss Functions |
|
|
|
|
|
**Sparse MoE:** |
|
|
```
L = L_CE + α * L_balance
```
|
|
|
|
|
**Dynamic MoE:** |
|
|
```
L = L_CE + β * L_balance + γ * L_entropy
```
|
|
|
|
|
Where: |
|
|
- `L_CE`: Cross-entropy loss |
|
|
- `L_balance`: Load balancing loss (encourages uniform expert utilization; sketched below)
|
|
- `L_entropy`: Entropy regularization (encourages sparse routing) |
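
For concreteness, one common formulation of the balance term (the Switch Transformer auxiliary loss; the repository's exact form may differ):

```python
import torch

def load_balance_loss(router_probs: torch.Tensor, expert_mask: torch.Tensor):
    """Switch-style balance loss: num_experts * sum_i f_i * P_i, where f_i
    is the fraction of tokens dispatched to expert i and P_i is the mean
    router probability assigned to it. Uniform utilization minimizes it.

    router_probs: (num_tokens, num_experts) softmax outputs
    expert_mask:  (num_tokens, num_experts) 0/1 dispatch decisions
    """
    num_experts = router_probs.size(-1)
    f = expert_mask.float().mean(dim=0)  # fraction of tokens per expert
    p = router_probs.mean(dim=0)         # mean router prob per expert
    return num_experts * (f * p).sum()
```

`L_entropy` can analogously be taken as the mean entropy of `router_probs`, pushing each token's routing distribution toward a confident peak.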
|
|
|
|
|
## Usage |
|
|
|
|
|
```python
import torch
from moe.moelm import MoeLM, DynamicMOELM

# Load Sparse MoE
sparse_model = MoeLM(
    vocab_size=10000,
    num_layers=4,
    context_length=256,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    num_experts=8,
    top_k=2,
)
sparse_model.load_state_dict(torch.load("sparse_moe_final.pt"))

# Load Dynamic MoE
dynamic_model = DynamicMOELM(
    vocab_size=10000,
    num_layers=4,
    context_length=256,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    num_experts=4,
    confidence_threshold=0.8,
)
dynamic_model.load_state_dict(torch.load("dynamic_moe_final.pt"))
```
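
This card does not document the forward signature or a tokenizer; assuming the model maps token ids of shape `(batch, seq)` to logits of shape `(batch, seq, vocab)`, greedy decoding would look roughly like:

```python
sparse_model.eval()
tokens = torch.tensor([[1]])  # placeholder start-token id
with torch.no_grad():
    for _ in range(64):
        logits = sparse_model(tokens[:, -256:])  # stay within context length
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=1)
```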
|
|
|
|
|
## Files |
|
|
|
|
|
| File | Description | |
|
|
|------|-------------| |
|
|
| `sparse_moe_final.pt` | Sparse MoE model weights | |
|
|
| `dynamic_moe_final.pt` | Dynamic MoE model weights | |
|
|
| `sparse_moe_config.json` | Sparse MoE configuration | |
|
|
| `dynamic_moe_config.json` | Dynamic MoE configuration | |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{moe-lm-2024,
  title={Mixture-of-Experts Language Model},
  author={Chaitanya},
  year={2024},
  url={https://github.com/chaitanya/transformers-and-MOE}
}
```
|
|
|
|
|
## Reference |
|
|
|
|
|
Based on ["Harder Tasks Need More Experts: Dynamic Routing in MoE Models"](https://arxiv.org/abs/2403.07652).
|
|
|
|
|
## License |
|
|
|
|
|
MIT |