---
license: apache-2.0
tags:
- kernels
---

# Triton MoE

This repository contains Triton kernels for running Mixture of Experts (MoE) models.

The example below downloads the kernels from the Hugging Face Hub and runs the fused GLU kernel on a random tensor:

```python
# /// script
# dependencies = [
# "kernels",
# "numpy",
# "torch",
# ]
# ///
import torch

from kernels import get_kernel

# Make reproducible
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Download optimized kernels from the Hugging Face hub
triton_moe = get_kernel("kernels-community/triton-moe")

# Random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")

# Run the kernel
gate_up_out = x.unsqueeze(-1).repeat(1, 1, 2)
out = triton_moe.fused_glu.fused_glu_triton(gate_up_out=gate_up_out, alpha=1.0)

# Check the output
print("Output shape:", out.shape)
print("Output sum:", out.sum().item())
# Output shape: torch.Size([10, 10, 1])
# Output sum: 62.875
```

### Testing

```bash
nix develop -i -L .#test --command python -m pytest -s tests
```

Expected output of the tests in [`tests/test_triton_moe.py`](tests/test_triton_moe.py):

```text
warning: Git tree '/home/ubuntu/Projects/triton-moe' is dirty
evaluation warning: CUDA versions older than 12.0 will be removed in Nixpkgs 25.05; see the 24.11 release notes for more information
triton_moe-torch-ext> Running phase: unpackPhase
triton_moe-torch-ext> unpacking source archive /nix/store/5zm9aqzym4h6xx414sy17dynr1hjbwh8-source
triton_moe-torch-ext> source root is source
triton_moe-torch-ext> Running phase: patchPhase
triton_moe-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase
triton_moe-torch-ext> Running phase: configurePhase
triton_moe-torch-ext> no configure script, doing nothing
triton_moe-torch-ext> Running phase: installPhase
triton_moe-torch-ext> Running phase: fixupPhase
triton_moe-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
triton_moe-torch-ext> checking for references to /build/ in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext...
triton_moe-torch-ext> patching script interpreter paths in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
===================================== test session starts ======================================
platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/Projects/triton-moe
plugins: hypothesis-6.130.12
collected 8 items

tests/test_triton_moe.py Average difference: 0.009301766753196716
Max difference: 0.095703125
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
hidden_states.grad exists: True
✓ Backward test passed - all parameters have gradients
10 elements from gate_up_proj gradients:
tensor([ 179.9667, -852.7672, -3274.1992, -4076.2095, -2571.6282, -296.6539,
1800.2004, 503.5397, 48.4640, 191.8257], device='cuda:0')
10 elements from ref_layer.gate_up_proj gradients:
tensor([ 179.9663, -852.7676, -3274.1995, -4076.2188, -2571.6238, -296.6619,
1800.1997, 503.5378, 48.4632, 191.8266], device='cuda:0')
.Warming up...
Benchmarking reference implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs
Benchmarking custom implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs

================================================================================
BACKWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 20

Reference Implementation (OpenaiExperts):
- Mean: 1855.949 ms
- Std: 9.959 ms
- Min: 1851.829 ms
- Max: 1896.181 ms

Custom Implementation (MoE):
- Mean: 250.311 ms
- Std: 0.591 ms
- Min: 249.697 ms
- Max: 252.103 ms

Speedup: 7.41x
✓ Custom implementation is 7.41x faster
================================================================================

Detailed timings (ms):
Reference: [1869.8261399986222, 1852.2405139519833, 1853.5575779969804, 1852.6741919922642, 1853.342688002158, 1853.4869640134275, 1854.8076269798912, 1852.6032069930807, 1853.7065120181069, 1852.3418360273354, 1853.6653390037827, 1853.267052967567, 1853.421829000581, 1851.8838259624317, 1852.5341430213302, 1851.8291829968803, 1852.3553150007501, 1896.1806659935974, 1852.8080619871616, 1852.4555769981816]
Custom: [252.10309802787378, 251.01929501397535, 250.63456897623837, 250.02275395672768, 249.69729100121185, 249.89533895859495, 249.75963402539492, 249.9880829709582, 251.09507801244035, 250.78181497519836, 250.63572899671271, 250.75654400279745, 250.1296689733863, 249.93031000485644, 249.83260600129142, 250.06496504647657, 250.1023070071824, 250.03619497874752, 249.96405199635774, 249.77726401994005]
.
============================================================
MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 2.977 GB
Custom implementation: 2.737 GB
Memory ratio: 0.919x
✓ Custom uses 8.1% less memory
============================================================
.Warming up...
Benchmarking reference implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs
Benchmarking custom implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs

================================================================================
FORWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 50

Reference Implementation (OpenaiExperts):
- Mean: 45.218 ms
- Std: 0.643 ms
- Min: 44.657 ms
- Max: 49.252 ms

Custom Implementation (MoE):
- Mean: 45.092 ms
- Std: 0.382 ms
- Min: 44.630 ms
- Max: 45.988 ms

Speedup: 1.00x
✓ Custom implementation is 1.00x faster
================================================================================

Detailed timings (ms):
Reference: [49.2524920264259, 44.996229000389576, 44.75043900310993, 45.368854014668614, 45.28193100122735, 44.9596070102416, 45.46641802880913, 45.40711600566283, 44.76102895569056, 45.04138103220612, 44.8942250222899, 45.068661973346025, 45.07604299578816, 45.02651101211086, 44.988538953475654, 44.86601299140602, 45.13182397931814, 44.97082799207419, 44.656905985902995, 45.53678201045841, 45.52290996070951, 45.288041001185775, 45.18025700235739, 45.118414040189236, 44.841952971182764, 45.04251101752743, 45.11384398210794, 45.021480007562786, 44.88742502871901, 44.94614701252431, 45.80574203282595, 46.0561320069246, 45.724289026111364, 45.17030599527061, 45.18806800479069, 44.91667600814253, 45.08163296850398, 44.952887983527035, 45.330752967856824, 44.88741495879367, 44.882684014737606, 45.47502798959613, 45.74003902962431, 45.015780022367835, 45.045601029414684, 44.906274997629225, 44.9566770112142, 45.05321098258719, 45.05068197613582, 45.216178987175226]
Custom: [45.50682002445683, 44.99372898135334, 45.103324053343385, 45.18590698717162, 44.91502500604838, 45.08770297979936, 45.72224896401167, 45.902586018200964, 45.94191804062575, 45.69388699019328, 45.324411999899894, 45.01690098550171, 44.72995799733326, 44.99283799668774, 45.312902017030865, 45.248389011248946, 44.97709800489247, 44.99160900013521, 44.63031404884532, 44.66017603408545, 45.02782097551972, 44.99858903000131, 44.89475500304252, 44.744468992576, 44.88639399642125, 44.79811096098274, 44.76995003642514, 44.648215000052005, 44.673235970549285, 44.698296987917274, 44.77059002965689, 44.66089600464329, 44.90563599392772, 45.984469004906714, 45.526801026426256, 45.913067006040365, 45.988010009750724, 45.26515997713432, 45.067521976307034, 45.230168965645134, 44.975398981478065, 44.86092395382002, 45.25493999244645, 44.94089604122564, 44.82307197758928, 44.838363013695925, 45.01837998395786, 44.708857021760195, 44.941117987036705, 44.85234396997839]
.
============================================================
FORWARD MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 0.822 GB
Custom implementation: 1.758 GB
Memory ratio: 2.140x
✗ Custom uses 114.0% more memory
============================================================
.Warming up for throughput test...

======================================================================
FORWARD THROUGHPUT BENCHMARK
======================================================================
Configuration: 4096 tokens/batch × 100 runs = 409,600 tokens

Reference Implementation:
- Total time: 4.510 seconds
- Throughput: 90,816 tokens/second

Custom Implementation:
- Total time: 4.510 seconds
- Throughput: 90,827 tokens/second

Throughput improvement: 1.00x
✓ Custom processes 0.0% more tokens/second
======================================================================
.

================================= 8 passed in 75.71s (0:01:15) =================================
```
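
The headline backward-pass ratios in the log can be re-derived from the reported means. The snippet below is a minimal sketch that uses values copied from the output above. The variable names are illustrative and are not part of the test suite, and the formulas (reference divided by custom for speedup, custom divided by reference for memory) are an assumption that happens to reproduce the logged figures.

```python
# Backward-pass mean latencies copied from the benchmark log (20 runs each).
ref_mean_ms = 1855.949
custom_mean_ms = 250.311

# Assumed definition: speedup = reference mean / custom mean.
backward_speedup = ref_mean_ms / custom_mean_ms
print(f"Backward speedup: {backward_speedup:.2f}x")  # Backward speedup: 7.41x

# Memory figures copied from the MEMORY USAGE BENCHMARK section, in GB.
ref_mem_gb = 2.977
custom_mem_gb = 2.737

# Assumed definition: memory ratio = custom / reference.
memory_ratio = custom_mem_gb / ref_mem_gb
memory_saved_pct = (1 - memory_ratio) * 100
print(f"Memory ratio: {memory_ratio:.3f}x")        # Memory ratio: 0.919x
print(f"Memory saved: {memory_saved_pct:.1f}%")    # Memory saved: 8.1%
```

In this run the gain is concentrated in the backward pass: the forward pass is at parity (1.00x) and the forward memory benchmark reports the custom implementation using 114.0% more memory than the reference.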