---
license: mit
tags:
  - kernel
---

``` 
                                                            
oooo    ooo  .oooo.   ooo. .oo.  .oo.    .ooooo.   .ooooo.  
 `88.  .8'  `P  )88b  `888P"Y88bP"Y88b  d88' `88b d88' `88b 
  `88..8'    .oP"888   888   888   888  888   888 888ooo888 
   `888'    d8(  888   888   888   888  888   888 888    .o 
    .8'     `Y888""8o o888o o888o o888o `Y8bod8P' `Y8bod8P' 
.o..P'                                                      
`Y8P'                                                       

                Yet Another Mixture of Experts 
```

`yamoe` is a no-nonsense, straightforward implementation of Mixture of Experts (MoE) kernels, designed to be easy to use and computationally efficient.

### Design goals
- simplicity: easy to read and understand the code
- efficiency: optimized for high throughput and low latency
- low memory usage: optimized to handle large batch sizes
- reproducibility: easy to reproduce results, no special new `sm` requirements
- availability: easy to install and use via the [kernels](https://github.com/huggingface/kernels) library

### Kernel Hub

You can find the kernel on [Kernel Hub](https://huggingface.co/drbh/yamoe) and install it via the [kernels](https://github.com/huggingface/kernels) library.

```python
from kernels import get_kernel
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
```

### Performance

`yamoe` scales well as batch sizes increase, in comparison to the naive method of repeating the data and computation for every item in the batch, as implemented in the reference at [torch-ext/yamoe/reference.py](torch-ext/yamoe/reference.py). This benchmark can be reproduced by running `uv run perf_plot.py`; a smaller benchmark with a correctness comparison can be run with `uv run compare_example.py`.
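To make the comparison concrete, here is a simplified sketch of what such a naive dense MoE looks like: every expert processes every token, and outputs are mixed by the routing weights. This is an illustrative toy, not the actual `reference.py`; in particular the gated SiLU activation and the tensor shapes (matching the example below: `gate_up` of shape `(E, H, 2H)`, `down` of shape `(E, H, H)`) are assumptions.

```python
import torch
from torch.nn import functional as F

def naive_moe(hidden, routing_weights, gate_up, gate_up_b, down, down_b):
    # hidden: (T, H) tokens, routing_weights: (T, E) dense routing probabilities,
    # gate_up: (E, H, 2H) + bias (E, 2H), down: (E, H, H) + bias (E, H).
    # Every expert runs on every token; sparsity is never exploited.
    E = routing_weights.shape[1]
    out = torch.zeros_like(hidden)
    for e in range(E):
        x = hidden @ gate_up[e] + gate_up_b[e]   # (T, 2H) projection for expert e
        gate, up = x.chunk(2, dim=-1)
        x = F.silu(gate) * up                    # assumed gated activation
        x = x @ down[e] + down_b[e]              # (T, H) back to hidden size
        out += routing_weights[:, e : e + 1] * x # weight by routing probability
    return out

# Tiny self-contained example
T, H, E = 8, 16, 4
hidden = torch.randn(T, H)
rw = F.softmax(torch.randn(T, E), dim=-1)
out = naive_moe(
    hidden, rw,
    torch.randn(E, H, 2 * H), torch.zeros(E, 2 * H),
    torch.randn(E, H, H), torch.zeros(E, H),
)
print(out.shape)  # torch.Size([8, 16])
```

Because the loop materializes every expert's output for every token, both compute and memory grow linearly with the number of experts regardless of `top_k`, which is exactly the overhead the fused kernel avoids.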


TL;DR: lower is better in the first two rows of charts.

<img width="3583" height="2358" alt="moe_performance_comparison" src="https://github.com/user-attachments/assets/72938f64-ec05-4eaa-82c4-507a43891543" />



### How to use

```python
# /// script
# requires-python = "==3.10"
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
# [tool.uv.sources]
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///

import time
import torch
from kernels import get_kernel
from pathlib import Path
from torch.nn import functional as F

# Set seeds and deterministic flags for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

# Configuration
batch_size, seq_len, hidden_dim = 16, 256, 2880
num_experts, top_k = 8, 2

# Create routing weights
logits = torch.randn(batch_size, seq_len, num_experts)
probs = F.softmax(logits, dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)

batch_seq = batch_size * seq_len
routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
routing_weights[batch_indices, flat_indices] = flat_weights

# Create model tensors
hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
routing_weights = routing_weights.cuda()
router_indices = flat_indices.cuda()

# Warmup
for _ in range(5):
    _ = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )

# Benchmark
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()

with torch.no_grad():
    output = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )

torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
print(f"First 3: {output.view(-1)[:3].tolist()}")
print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
```