---
license: apache-2.0
tags:
- kernels
---

# triton moe

This repository contains Triton kernels for running Mixture of Experts (MoE) models.

```python
# /// script
# dependencies = [
#     "kernels",
#     "numpy",
#     "torch",
# ]
# ///

import torch
from kernels import get_kernel

# Make the example reproducible
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Download the optimized kernels from the Hugging Face Hub
triton_moe = get_kernel("kernels-community/triton-moe")

# Random input tensor of shape (10, 10)
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")

# Build a (10, 10, 2) gate/up tensor by duplicating x along a new last axis
gate_up_out = x.unsqueeze(-1).repeat(1, 1, 2)

# Run the fused GLU kernel
out = triton_moe.fused_glu.fused_glu_triton(gate_up_out=gate_up_out, alpha=1.0)

# Check the output
print("Output shape:", out.shape)
print("Output sum:", out.sum().item())
# Output shape: torch.Size([10, 10, 1])
# Output sum: 62.875
```

### Testing

```bash
nix develop -i -L .#test --command python -m pytest -s tests
```

Expected output of the tests in [`tests/test_triton_moe.py`](tests/test_triton_moe.py):

```text
warning: Git tree '/home/ubuntu/Projects/triton-moe' is dirty
evaluation warning: CUDA versions older than 12.0 will be removed in Nixpkgs 25.05; see the 24.11 release notes for more information
triton_moe-torch-ext> Running phase: unpackPhase
triton_moe-torch-ext> unpacking source archive /nix/store/5zm9aqzym4h6xx414sy17dynr1hjbwh8-source
triton_moe-torch-ext> source root is source
triton_moe-torch-ext> Running phase: patchPhase
triton_moe-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase
triton_moe-torch-ext> Running phase: configurePhase
triton_moe-torch-ext> no configure script, doing nothing
triton_moe-torch-ext> Running phase: installPhase
triton_moe-torch-ext> Running phase: fixupPhase
triton_moe-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
triton_moe-torch-ext> checking for references to /build/ in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext...
triton_moe-torch-ext> patching script interpreter paths in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
===================================== test session starts ======================================
platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/Projects/triton-moe
plugins: hypothesis-6.130.12
collected 8 items

tests/test_triton_moe.py Average difference: 0.009301766753196716
Max difference: 0.095703125
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
hidden_states.grad exists: True
✓ Backward test passed - all parameters have gradients
10 elements from gate_up_proj gradients:
tensor([ 179.9667, -852.7672, -3274.1992, -4076.2095, -2571.6282, -296.6539, 1800.2004, 503.5397, 48.4640, 191.8257], device='cuda:0')
10 elements from ref_layer.gate_up_proj gradients:
tensor([ 179.9663, -852.7676, -3274.1995, -4076.2188, -2571.6238, -296.6619, 1800.1997, 503.5378, 48.4632, 191.8266], device='cuda:0')
.Warming up...
Benchmarking reference implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs
Benchmarking custom implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs
================================================================================
BACKWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 20
Reference Implementation (OpenaiExperts):
- Mean: 1855.949 ms
- Std: 9.959 ms
- Min: 1851.829 ms
- Max: 1896.181 ms
Custom Implementation (MoE):
- Mean: 250.311 ms
- Std: 0.591 ms
- Min: 249.697 ms
- Max: 252.103 ms
Speedup: 7.41x
✓ Custom implementation is 7.41x faster
================================================================================
Detailed timings (ms):
Reference: [1869.8261399986222, 1852.2405139519833, 1853.5575779969804, 1852.6741919922642, 1853.342688002158, 1853.4869640134275, 1854.8076269798912, 1852.6032069930807, 1853.7065120181069, 1852.3418360273354, 1853.6653390037827, 1853.267052967567, 1853.421829000581, 1851.8838259624317, 1852.5341430213302, 1851.8291829968803, 1852.3553150007501, 1896.1806659935974, 1852.8080619871616, 1852.4555769981816]
Custom: [252.10309802787378, 251.01929501397535, 250.63456897623837, 250.02275395672768, 249.69729100121185, 249.89533895859495, 249.75963402539492, 249.9880829709582, 251.09507801244035, 250.78181497519836, 250.63572899671271, 250.75654400279745, 250.1296689733863, 249.93031000485644, 249.83260600129142, 250.06496504647657, 250.1023070071824, 250.03619497874752, 249.96405199635774, 249.77726401994005]
.
============================================================
MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 2.977 GB
Custom implementation: 2.737 GB
Memory ratio: 0.919x
✓ Custom uses 8.1% less memory
============================================================
.Warming up...
Benchmarking reference implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs
Benchmarking custom implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs
================================================================================
FORWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 50
Reference Implementation (OpenaiExperts):
- Mean: 45.218 ms
- Std: 0.643 ms
- Min: 44.657 ms
- Max: 49.252 ms
Custom Implementation (MoE):
- Mean: 45.092 ms
- Std: 0.382 ms
- Min: 44.630 ms
- Max: 45.988 ms
Speedup: 1.00x
✓ Custom implementation is 1.00x faster
================================================================================
Detailed timings (ms):
Reference: [49.2524920264259, 44.996229000389576, 44.75043900310993, 45.368854014668614, 45.28193100122735, 44.9596070102416, 45.46641802880913, 45.40711600566283, 44.76102895569056, 45.04138103220612, 44.8942250222899, 45.068661973346025, 45.07604299578816, 45.02651101211086, 44.988538953475654, 44.86601299140602, 45.13182397931814, 44.97082799207419, 44.656905985902995, 45.53678201045841, 45.52290996070951, 45.288041001185775, 45.18025700235739, 45.118414040189236, 44.841952971182764, 45.04251101752743, 45.11384398210794, 45.021480007562786, 44.88742502871901, 44.94614701252431, 45.80574203282595, 46.0561320069246, 45.724289026111364, 45.17030599527061, 45.18806800479069, 44.91667600814253, 45.08163296850398, 44.952887983527035, 45.330752967856824, 44.88741495879367, 44.882684014737606, 45.47502798959613, 45.74003902962431, 45.015780022367835, 45.045601029414684, 44.906274997629225, 44.9566770112142, 45.05321098258719, 45.05068197613582, 45.216178987175226]
Custom: [45.50682002445683, 44.99372898135334, 45.103324053343385, 45.18590698717162, 44.91502500604838, 45.08770297979936, 45.72224896401167, 45.902586018200964, 45.94191804062575, 45.69388699019328, 45.324411999899894, 45.01690098550171, 44.72995799733326, 44.99283799668774, 45.312902017030865, 45.248389011248946, 44.97709800489247, 44.99160900013521, 44.63031404884532, 44.66017603408545, 45.02782097551972, 44.99858903000131, 44.89475500304252, 44.744468992576, 44.88639399642125, 44.79811096098274, 44.76995003642514, 44.648215000052005, 44.673235970549285, 44.698296987917274, 44.77059002965689, 44.66089600464329, 44.90563599392772, 45.984469004906714, 45.526801026426256, 45.913067006040365, 45.988010009750724, 45.26515997713432, 45.067521976307034, 45.230168965645134, 44.975398981478065, 44.86092395382002, 45.25493999244645, 44.94089604122564, 44.82307197758928, 44.838363013695925, 45.01837998395786, 44.708857021760195, 44.941117987036705, 44.85234396997839]
.
============================================================
FORWARD MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 0.822 GB
Custom implementation: 1.758 GB
Memory ratio: 2.140x
✗ Custom uses 114.0% more memory
============================================================
.Warming up for throughput test...
======================================================================
FORWARD THROUGHPUT BENCHMARK
======================================================================
Configuration: 4096 tokens/batch × 100 runs = 409,600 tokens
Reference Implementation:
- Total time: 4.510 seconds
- Throughput: 90,816 tokens/second
Custom Implementation:
- Total time: 4.510 seconds
- Throughput: 90,827 tokens/second
Throughput improvement: 1.00x
✓ Custom processes 0.0% more tokens/second
======================================================================
.
================================= 8 passed in 75.71s (0:01:15) =================================
```
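The summary figures in the benchmark output (Mean, Std, Min, Max, Speedup, Memory ratio, Throughput) are plain arithmetic on the raw measurements. The sketch below recomputes them from the printed values. It is an illustration only, not the code in `tests/test_triton_moe.py`: the helper names `summarize`, `speedup`, `memory_ratio` and `throughput` are hypothetical, and the choice of a population standard deviation is an assumption.

```python
import statistics


def summarize(timings_ms):
    """Mean, standard deviation, min and max of per-run timings in ms."""
    return {
        "mean": statistics.fmean(timings_ms),
        # Assumption: population standard deviation (divides by N, not N - 1).
        "std": statistics.pstdev(timings_ms),
        "min": min(timings_ms),
        "max": max(timings_ms),
    }


def speedup(reference_mean_ms, custom_mean_ms):
    """Values above 1.0 mean the custom implementation is faster."""
    return reference_mean_ms / custom_mean_ms


def memory_ratio(reference_gb, custom_gb):
    """Values below 1.0 mean the custom implementation uses less memory."""
    return custom_gb / reference_gb


def throughput(tokens_per_batch, runs, total_seconds):
    """Tokens processed per second across all timed runs."""
    return tokens_per_batch * runs / total_seconds


if __name__ == "__main__":
    # Rounded summary values copied from the backward-pass and memory benchmarks.
    print(f"Speedup: {speedup(1855.949, 250.311):.2f}x")  # Speedup: 7.41x
    ratio = memory_ratio(2.977, 2.737)
    print(f"Memory ratio: {ratio:.3f}x, {(1 - ratio) * 100:.1f}% less")  # Memory ratio: 0.919x, 8.1% less
    # Prints 90,820 tokens/second; the log's 90,816 and 90,827 were presumably
    # computed from unrounded total times.
    print(f"Throughput: {throughput(4096, 100, 4.510):,.0f} tokens/second")
```

`summarize` can be applied directly to the "Detailed timings" lists. Note that NumPy's `np.std` defaults to the population estimator (`ddof=0`), while `torch.std` defaults to the sample estimator, so the Std column depends on which one the test uses.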