---
license: mit
tags:
- kernel
---
```
oooo ooo .oooo. ooo. .oo. .oo. .ooooo. .ooooo.
`88. .8' `P )88b `888P"Y88bP"Y88b d88' `88b d88' `88b
`88..8' .oP"888 888 888 888 888 888 888ooo888
`888' d8( 888 888 888 888 888 888 888 .o
.8' `Y888""8o o888o o888o o888o `Y8bod8P' `Y8bod8P'
.o..P'
`Y8P'
Yet Another Mixture of Experts
```
`yamoe` is a no-nonsense, straightforward implementation of Mixture of Experts (MoE) kernels, designed to be easy to use and computationally efficient.
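For reference, here is a minimal pure-Python sketch of what an MoE experts layer computes. It assumes the tensor layout from the example script in "How to use": per-expert `gate_up_proj` of shape (hidden, 2*hidden), `down_proj` of shape (hidden, hidden), and a dense (tokens x experts) routing-weight matrix. It also assumes a SwiGLU-style activation with the gate and up halves concatenated. The activation, gate/up layout, and any clamping that yamoe actually uses may differ, so treat `torch-ext/yamoe/reference.py` as authoritative. The function names are hypothetical and not part of the yamoe API.

```python
import math

def matvec(x, w):
    # x has length n, w is an n x m nested list; returns x @ w (length m)
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def silu(v):
    # Numerically stable SiLU: v * sigmoid(v)
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)

def expert_mlp(x, gate_up, gate_up_bias, down, down_bias):
    # ASSUMPTION: SwiGLU-style expert with gate and up halves concatenated (not interleaved)
    h = len(x)
    gu = [a + b for a, b in zip(matvec(x, gate_up), gate_up_bias)]  # length 2h
    gate, up = gu[:h], gu[h:]
    act = [silu(g) * u for g, u in zip(gate, up)]                    # length h
    return [a + b for a, b in zip(matvec(act, down), down_bias)]      # length h

def moe_forward(tokens, router_indices, routing_weights, experts):
    """Hypothetical reference: sum each token's top-k expert outputs, scaled by routing weight.

    tokens: list of hidden vectors; router_indices[t]: the top_k expert ids for token t;
    routing_weights[t][e]: dense (tokens x num_experts) weight matrix;
    experts[e]: (gate_up, gate_up_bias, down, down_bias) for expert e.
    """
    out = []
    for t, x in enumerate(tokens):
        acc = [0.0] * len(x)
        for e in router_indices[t]:
            w = routing_weights[t][e]
            y = expert_mlp(x, *experts[e])
            acc = [a + w * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out
```

Only the `top_k` selected experts run for each token; the other `num_experts - top_k` experts are skipped entirely. A routed kernel saves work relative to evaluating every expert for exactly this reason.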
### Design goals
- simplicity: easy to read and understand the code
- efficiency: optimized for high throughput and low latency
- low memory usage: optimized to handle large batch sizes
- reproducibility: easy to reproduce results, with no requirement for a recent GPU architecture (`sm` version)
- availability: easy to install and use via the [kernels](https://github.com/huggingface/kernels) library
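The following back-of-envelope count shows why routing matters for the efficiency and memory goals. It assumes the expert shapes from the example script in "How to use" (gate_up (hidden, 2*hidden), down (hidden, hidden)) and compares against a naive layer that evaluates every expert on every token and replicates the input once per expert. Both the naive baseline and `bytes_per_elem=4` (float32) are assumptions for illustration. Biases and the activation are ignored.

```python
def expert_flops_per_token(hidden_dim):
    # One token through one expert: (1 x h) @ (h x 2h), then (1 x h) @ (h x h).
    # A (1 x n) @ (n x m) matmul costs about 2*n*m FLOPs.
    return 2 * hidden_dim * (2 * hidden_dim) + 2 * hidden_dim * hidden_dim

def moe_cost(batch_size, seq_len, hidden_dim, num_experts, top_k, bytes_per_elem=4):
    tokens = batch_size * seq_len
    per_pair = expert_flops_per_token(hidden_dim)
    dense_flops = tokens * num_experts * per_pair   # naive: every expert sees every token
    routed_flops = tokens * top_k * per_pair        # routed: each token visits only top_k experts
    # Extra activation memory if the input is replicated once per expert (naive)
    repeated_input_bytes = tokens * num_experts * hidden_dim * bytes_per_elem
    return dense_flops, routed_flops, repeated_input_bytes

# Same configuration as the example script
dense, routed, rep = moe_cost(16, 256, 2880, 8, 2)
print(f"dense/routed FLOP ratio: {dense / routed:.1f}")  # 4.0
print(f"replicated input: {rep / 2**20:.0f} MiB")         # 360
```

The compute ratio is simply `num_experts / top_k`. The replicated-input term grows linearly with batch size, so it is one place where large batches put pressure on memory in a naive implementation.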
### Kernel Hub
You can find the kernel on [Kernel Hub](https://huggingface.co/drbh/yamoe) and install it via the [kernels](https://github.com/huggingface/kernels) library.
```python
from kernels import get_kernel
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
```
### Performance
`yamoe` scales well as the batch size increases, compared with the naive method (implemented in [torch-ext/yamoe/reference.py](torch-ext/yamoe/reference.py)) of repeating the data and computation for every item in the batch. The full benchmark can be reproduced with `uv run perf_plot.py`; a smaller benchmark and correctness comparison can be run with `uv run compare_example.py`.
TL;DR: lower is better in the first two rows of charts.
<img width="3583" height="2358" alt="moe_performance_comparison" src="https://github.com/user-attachments/assets/72938f64-ec05-4eaa-82c4-507a43891543" />
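GPU kernel timings are sensitive to warmup and run-to-run noise, so a single timed call can be misleading. Below is a minimal, framework-agnostic timing sketch that runs a few warmup calls and then reports the median and minimum over several iterations. The helper name `bench` and its parameters are hypothetical. The optional `sync` callable is meant to block until queued device work has finished (for example `torch.cuda.synchronize`), so that asynchronous kernels are included in each measurement.

```python
import statistics
import time

def bench(fn, warmup=5, iters=20, sync=None):
    """Time fn() and return (median_ms, min_ms) over `iters` timed calls.

    sync, if given, is called once after warmup and after each timed call,
    so work launched asynchronously by fn() is counted in the measurement.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples), min(samples)

# CPU-only usage example
median_ms, min_ms = bench(lambda: sum(range(10_000)))
```

With the example script below, one option is to wrap the `yamoe.experts(...)` call in a lambda (inside `torch.no_grad()`) and pass `sync=torch.cuda.synchronize`.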
### How to use
```python
# /// script
# requires-python = "==3.10"
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
# [tool.uv.sources]
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
import time
import torch
from kernels import get_kernel
from torch.nn import functional as F
# Set seeds and deterministic flags for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
# Configuration
batch_size, seq_len, hidden_dim = 16, 256, 2880
num_experts, top_k = 8, 2
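# Create routing weights: softmax over expert logits, keep the top_k experts per token,
# then scatter their probabilities into a dense (batch_seq, num_experts) matrix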
logits = torch.randn(batch_size, seq_len, num_experts)
probs = F.softmax(logits, dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)
batch_seq = batch_size * seq_len
routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
routing_weights[batch_indices, flat_indices] = flat_weights
# Create model tensors
hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
routing_weights = routing_weights.cuda()
router_indices = flat_indices.cuda()
# Warmup (excluded from timing)
for _ in range(5):
    _ = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )
# Benchmark a single call
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
with torch.no_grad():
    output = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
print(f"First 3: {output.view(-1)[:3].tolist()}")
print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
```
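The routing setup in the script (softmax over expert logits, top-k selection, then a scatter into a dense `(batch_seq, num_experts)` matrix) can be illustrated without a GPU. The pure-Python sketch below mirrors those steps on a toy input. The helper names are hypothetical and not part of yamoe. As in the script, the selected top-k probabilities are not renormalized, so each dense row sums to less than 1 whenever `top_k < num_experts`.

```python
import math

def softmax(row):
    # Subtract the max for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def topk(row, k):
    # Values and indices of the k largest entries, largest first (torch.topk also sorts by default)
    idx = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return [row[i] for i in idx], idx

def dense_routing(logits, top_k):
    """Return (router_indices, routing_weights) for a list of per-token logit rows."""
    num_experts = len(logits[0])
    router_indices, routing_weights = [], []
    for row in logits:
        weights, idx = topk(softmax(row), top_k)
        dense_row = [0.0] * num_experts       # one slot per expert
        for e, w in zip(idx, weights):
            dense_row[e] = w                  # scatter, like routing_weights[batch_indices, flat_indices] = ...
        router_indices.append(idx)
        routing_weights.append(dense_row)
    return router_indices, routing_weights

# Two tokens, four experts, top_k = 2
indices, dense = dense_routing([[2.0, 1.0, 0.0, -1.0], [0.0, 3.0, 0.0, 1.0]], top_k=2)
print(indices)  # [[0, 1], [1, 3]]
```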