GptOssExperts - OpenAI-style MoE

GPU Info

Cell: nv | 0.28s
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 19:41:48 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   33C    P0            126W /  350W |       0MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (GptOssExperts Reference)
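The cell below calls `model.train()` so that GptOssExperts takes its per-expert loop instead of the batched path. That loop only visits tokens routed to each expert. The dependency-free sketch below spells out that computation on nested lists. It is meant for tiny shapes, and it rests on several assumptions taken from our reading of the Hugging Face transformers `GptOssExperts` module, not from this page: gate and up channels are interleaved in `gate_up_proj` (even columns are gate, odd columns are up), `alpha = 1.702`, `limit = 7.0`, and the output is `(up + 1) * gate * sigmoid(alpha * gate)`. The vendored yamoe copy is assumed to match. `moe_loop_reference`, `matvec_bias` and `sigmoid` are hypothetical helper names.

```python
import math

# Assumed constants (our reading of transformers' GptOssExperts; not confirmed here).
ALPHA = 1.702
LIMIT = 7.0


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def matvec_bias(x, w, b):
    """Return x @ w + b for x of length K, w of shape [K][N], b of length N."""
    return [sum(x[k] * w[k][n] for k in range(len(x))) + b[n] for n in range(len(b))]


def moe_loop_reference(hidden, router_indices, routing_weights,
                       gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
    """Per-expert MoE loop over flattened tokens (hypothetical sketch).

    hidden:          [T][H]    flattened hidden states, T = B * S
    router_indices:  [T][k]    top-k expert ids per token
    routing_weights: [T][E]    dense routing weights
    gate_up_proj:    [E][H][2I], gate_up_proj_bias: [E][2I]
    down_proj:       [E][I][H],  down_proj_bias:    [E][H]
    """
    T, H = len(hidden), len(hidden[0])
    out = [[0.0] * H for _ in range(T)]
    for e in range(len(gate_up_proj)):
        # Select the tokens routed to expert e (the nonzero/where step).
        token_idx = [t for t in range(T) if e in router_indices[t]]
        for t in token_idx:
            gate_up = matvec_bias(hidden[t], gate_up_proj[e], gate_up_proj_bias[e])
            # Assumed interleaved layout: even columns = gate, odd columns = up.
            gate = [min(g, LIMIT) for g in gate_up[0::2]]          # clamp from above only
            up = [max(-LIMIT, min(u, LIMIT)) for u in gate_up[1::2]]
            glu = [g * sigmoid(g * ALPHA) for g in gate]
            gated = [(u + 1.0) * v for u, v in zip(up, glu)]
            y = matvec_bias(gated, down_proj[e], down_proj_bias[e])
            # Scale by this token's weight for expert e and accumulate (index_add_).
            w = routing_weights[t][e]
            for h in range(H):
                out[t][h] += w * y[h]
    return out
```

The explicit token gather, the two matmuls per visited expert and the scatter-add correspond to the `aten::nonzero`/`aten::where`, `aten::index`, `aten::mm` and `aten::index_add_` rows that appear in the profile traces below.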

Cell: benchmark | 21.43s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
#     "kernels",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark
from kernels import get_kernel

# Load yamoe to get GptOssExperts reference
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts


def gpt_oss_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    GptOssExperts reference implementation of OpenAI-style MoE.
    This is the reference model implementation from the original GPT OSS codebase.
    """
    B, S, H = hidden_states.shape
    E = routing_weights.shape[2]

    # Minimal config object exposing the attributes GptOssExperts reads.
    # gate_up_proj has shape [E, H, 2 * intermediate_size]: it holds the fused
    # gate and up projections, so the intermediate size is half its last dim.
    config = type("Config", (), {})()
    config.hidden_size = H
    config.intermediate_size = gate_up_proj.shape[2] // 2
    config.num_local_experts = E

    # Initialize the module and load the benchmark weights into it.
    model = GptOssExperts(config)
    model.gate_up_proj.data = gate_up_proj
    model.gate_up_proj_bias.data = gate_up_proj_bias
    model.down_proj.data = down_proj
    model.down_proj_bias.data = down_proj_bias
    model = model.to(hidden_states.device)

    # Training mode forces the per-expert loop (the "CPU path"), which only visits
    # the experts listed in router_indices, matching naive_moe_ref. The batched GPU
    # path runs every expert, which can introduce small numerical differences.
    model.train()

    # Flatten routing_weights from [B, S, E] to [B * S, E].
    routing_weights_flat = routing_weights.view(-1, E)

    # Run the forward pass without building an autograd graph.
    with torch.no_grad():
        output = model(hidden_states, router_indices, routing_weights_flat)

    return output


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="gpt_oss_experts",
    impl_tags={"family": "reference", "backend": "pytorch"},
    impl_func=gpt_oss_openai_moe,
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      10.220ms       197.88%      10.220ms      10.220ms             1  
                                        gpt_oss_experts        16.01%       2.006ms        99.94%      12.523ms      12.523ms       0.000us         0.00%       5.168ms       5.168ms             1  
                                           aten::matmul         0.20%      24.744us         3.78%     473.582us      39.465us       0.000us         0.00%       4.543ms     378.565us            12  
                                               aten::mm         2.31%     289.874us         3.58%     448.838us      37.403us       4.543ms        87.96%       4.543ms     378.565us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.093ms        59.88%       3.093ms     343.626us             9  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.444ms        27.95%       1.444ms     481.227us             3  
                                              aten::mul         1.34%     167.604us         2.25%     281.908us      11.746us     108.865us         2.11%     108.865us       4.536us            24  
                                              aten::add         1.61%     201.238us         3.79%     474.483us      26.360us     102.656us         1.99%     102.656us       5.703us            18  
                                            aten::index         1.69%     212.259us         2.75%     345.042us      28.753us      88.512us         1.71%      88.512us       7.376us            12  
                                       aten::index_add_         0.46%      58.122us         0.75%      94.202us      15.700us      80.160us         1.55%      80.160us      13.360us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      80.160us         1.55%      80.160us      13.360us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      80.000us         1.55%      80.000us       6.667us            12  
                                          aten::nonzero         2.08%     261.099us         6.37%     797.848us      88.650us      65.246us         1.26%      76.095us       8.455us             9  
                                            aten::clamp         0.95%     119.641us         1.55%     194.514us      16.209us      63.010us         1.22%      63.010us       5.251us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      63.010us         1.22%      63.010us       5.251us            12  
                                            aten::where         0.06%       7.130us         5.02%     629.533us     104.922us       0.000us         0.00%      61.472us      10.245us             6  
                                    aten::nonzero_numpy         0.09%      11.550us         4.97%     622.403us     103.734us       0.000us         0.00%      61.472us      10.245us             6  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      60.800us         1.18%      60.800us      10.133us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      56.608us         1.10%      56.608us       4.717us            12  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      50.776us         0.98%      50.776us       1.128us            45  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 12.530ms
Self CUDA time total: 5.165ms
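In these tables the Self CUDA % column is each row's Self CUDA time divided by the "Self CUDA time total" printed under the table. That is why the first `gpt_oss_experts` row, which carries only device-side time for the profiled region (10.220 ms, with no self CPU time), can exceed 100%: it appears to be measured as a span on the GPU timeline, including gaps between kernels, rather than as summed kernel time. The short sketch below reproduces a few entries of the column from the values printed in the cuda_B1_S512_E2 trace; the dictionary and function names are hypothetical.

```python
# Values (ms) copied from the cuda_B1_S512_E2 trace.
SELF_CUDA_TOTAL_MS = 5.165  # "Self CUDA time total"
SELF_CUDA_MS = {
    "gpt_oss_experts (first row)": 10.220,
    "aten::mm": 4.543,
    "aten::mul": 0.108865,
    "aten::index_add_": 0.080160,
}


def self_cuda_pct(value_ms, total_ms=SELF_CUDA_TOTAL_MS):
    """Share of the total self CUDA time, as in the 'Self CUDA %' column."""
    return 100.0 * value_ms / total_ms


for name, ms in SELF_CUDA_MS.items():
    print(f"{name:28s} {self_cuda_pct(ms):7.2f}%")
```

The results match the table to within rounding of the printed inputs. They also confirm that `aten::mm` (the SGEMM kernels) accounts for roughly 88% of the kernel time in this workload.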



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      14.281ms       232.73%      14.281ms      14.281ms             1  
                                        gpt_oss_experts        16.85%       2.763ms        99.97%      16.396ms      16.396ms       0.000us         0.00%       6.139ms       6.139ms             1  
                                           aten::matmul         0.27%      44.470us         4.93%     808.156us      33.673us       0.000us         0.00%       5.322ms     221.756us            24  
                                               aten::mm         2.81%     461.070us         4.66%     763.686us      31.820us       5.322ms        86.73%       5.322ms     221.756us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.267ms        85.83%       5.267ms     219.440us            24  
                                          aten::nonzero         2.44%     399.465us         7.84%       1.285ms      85.683us     115.131us         1.88%     137.882us       9.192us            15  
                                              aten::mul         1.86%     305.625us         3.19%     523.892us      10.914us     131.841us         2.15%     131.841us       2.747us            48  
                                              aten::add         2.10%     345.215us         3.57%     585.271us      16.258us     127.810us         2.08%     127.810us       3.550us            36  
                                            aten::where         0.07%      10.792us         7.40%       1.214ms     101.132us       0.000us         0.00%     123.674us      10.306us            12  
                                    aten::nonzero_numpy         0.13%      21.688us         7.33%       1.203ms     100.233us       0.000us         0.00%     123.674us      10.306us            12  
                                            aten::index         2.22%     363.289us         3.85%     631.035us      26.293us     111.423us         1.82%     111.423us       4.643us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     101.762us         1.66%     101.762us       4.240us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      91.773us         1.50%      91.773us       1.055us            87  
                                            aten::clamp         1.29%     211.324us         2.19%     359.818us      14.992us      88.222us         1.44%      88.222us       3.676us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      88.222us         1.44%      88.222us       3.676us            24  
                                             aten::item         0.47%      77.138us        37.50%       6.150ms      85.417us       0.000us         0.00%      75.678us       1.051us            72  
                              aten::_local_scalar_dense         1.90%     311.363us        37.03%       6.073ms      84.345us      75.678us         1.23%      75.678us       1.051us            72  
                                       aten::index_add_         0.59%      96.073us         0.99%     162.304us      13.525us      70.526us         1.15%      70.526us       5.877us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      70.526us         1.15%      70.526us       5.877us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      66.368us         1.08%      66.368us       5.531us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 16.401ms
Self CUDA time total: 6.136ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      12.623ms       150.27%      12.623ms      12.623ms             1  
                                        gpt_oss_experts        13.47%       1.791ms        99.96%      13.283ms      13.283ms       0.000us         0.00%       8.405ms       8.405ms             1  
                                           aten::matmul         0.18%      23.339us         3.36%     446.659us      37.222us       0.000us         0.00%       7.382ms     615.173us            12  
                                               aten::mm         1.99%     264.803us         3.19%     423.320us      35.277us       7.382ms        87.88%       7.382ms     615.173us            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       4.494ms        53.50%       4.494ms     748.960us             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       1.479ms        17.61%       1.479ms     493.131us             3  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.402ms        16.69%       1.402ms     467.413us             3  
                                              aten::mul         1.17%     155.791us         2.03%     269.215us      11.217us     193.439us         2.30%     193.439us       8.060us            24  
                                              aten::add         1.34%     178.665us         2.34%     311.318us      17.295us     184.286us         2.19%     184.286us      10.238us            18  
                                       aten::index_add_         0.37%      48.760us         0.64%      85.661us      14.277us     167.358us         1.99%     167.358us      27.893us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     167.358us         1.99%     167.358us      27.893us             6  
                                            aten::index         1.43%     189.705us         2.42%     321.187us      26.766us     146.945us         1.75%     146.945us      12.245us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     145.824us         1.74%     145.824us      12.152us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     116.832us         1.39%     116.832us      19.472us             6  
                                            aten::clamp         0.82%     108.995us         1.40%     185.495us      15.458us     109.284us         1.30%     109.284us       9.107us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     109.284us         1.30%     109.284us       9.107us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     103.135us         1.23%     103.135us       8.595us            12  
                                          aten::nonzero         1.83%     243.374us         5.76%     765.236us      85.026us      70.402us         0.84%      81.794us       9.088us             9  
                                            aten::where         0.04%       5.651us         4.63%     615.153us     102.525us       0.000us         0.00%      66.851us      11.142us             6  
                                    aten::nonzero_numpy         0.08%      11.009us         4.59%     609.502us     101.584us       0.000us         0.00%      66.851us      11.142us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 13.289ms
Self CUDA time total: 8.400ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      18.138ms       172.84%      18.138ms      18.138ms             1  
                                        gpt_oss_experts        12.76%       2.622ms        99.97%      20.540ms      20.540ms       0.000us         0.00%      10.500ms      10.500ms             1  
                                           aten::matmul         0.22%      44.749us         4.11%     844.232us      35.176us       0.000us         0.00%       9.224ms     384.346us            24  
                                               aten::mm         2.32%     476.088us         3.89%     799.483us      33.312us       9.224ms        87.90%       9.224ms     384.346us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       6.287ms        59.90%       6.287ms     349.259us            18  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.925ms        27.87%       2.925ms     487.438us             6  
                                              aten::mul         1.51%     311.093us         2.62%     538.833us      11.226us     229.793us         2.19%     229.793us       4.787us            48  
                                              aten::add         1.68%     344.530us         2.88%     592.257us      16.452us     211.009us         2.01%     211.009us       5.861us            36  
                                            aten::index         1.75%     359.041us         3.02%     619.685us      25.820us     205.054us         1.95%     205.054us       8.544us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     164.639us         1.57%     164.639us       6.860us            24  
                                       aten::index_add_         0.48%      97.780us         0.85%     174.953us      14.579us     157.631us         1.50%     157.631us      13.136us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     157.631us         1.50%     157.631us      13.136us            12  
                                          aten::nonzero         1.89%     388.553us         6.17%       1.268ms      84.506us     122.654us         1.17%     146.847us       9.790us            15  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     145.663us         1.39%     145.663us      12.139us            12  
                                            aten::where         0.05%      10.471us         5.79%       1.190ms      99.134us       0.000us         0.00%     132.128us      11.011us            12  
                                    aten::nonzero_numpy         0.10%      21.340us         5.74%       1.179ms      98.262us       0.000us         0.00%     132.128us      11.011us            12  
                                            aten::clamp         1.02%     209.010us         1.74%     358.311us      14.930us     131.327us         1.25%     131.327us       5.472us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     131.327us         1.25%     131.327us       5.472us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     117.601us         1.12%     117.601us       4.900us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     108.253us         1.03%     108.253us       1.244us            87  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.546ms
Self CUDA time total: 10.495ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      20.935ms       121.00%      20.935ms      20.935ms             1  
                                        gpt_oss_experts         7.61%       1.780ms        99.98%      23.376ms      23.376ms       0.000us         0.00%      17.312ms      17.312ms             1  
                                           aten::matmul         0.10%      23.122us         1.96%     458.772us      38.231us       0.000us         0.00%      14.468ms       1.206ms            12  
                                               aten::mm         1.15%     269.268us         1.86%     435.650us      36.304us      14.468ms        83.62%      14.468ms       1.206ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       8.827ms        51.02%       8.827ms       1.471ms             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.632ms        32.55%       5.632ms     938.689us             6  
                                              aten::add         0.79%     184.599us         1.36%     318.590us      17.699us     771.593us         4.46%     771.593us      42.866us            18  
                                              aten::mul         0.68%     158.205us         1.17%     272.787us      11.366us     648.706us         3.75%     648.706us      27.029us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     492.134us         2.84%     492.134us      41.011us            12  
                                       aten::index_add_         0.22%      51.621us         0.39%      91.292us      15.215us     449.187us         2.60%     449.187us      74.864us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     449.187us         2.60%     449.187us      74.864us             6  
                                            aten::clamp         0.47%     109.062us         0.80%     186.384us      15.532us     328.069us         1.90%     328.069us      27.339us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     328.069us         1.90%     328.069us      27.339us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     298.432us         1.72%     298.432us      49.739us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     279.459us         1.62%     279.459us      46.576us             6  
                                            aten::index         0.79%     185.644us         1.37%     320.365us      26.697us     259.362us         1.50%     259.362us      21.614us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     252.002us         1.46%     252.002us      21.000us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     226.817us         1.31%     226.817us      37.803us             6  
                                          aten::sigmoid         0.16%      37.651us         0.31%      72.093us      12.016us     177.249us         1.02%     177.249us      29.542us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     177.249us         1.02%     177.249us      29.542us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 23.381ms
Self CUDA time total: 17.302ms
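A rough way to compare the five traces so far is to divide each Self CUDA time total by the number of tokens. This sketch assumes that workload names follow `cuda_B{batch}_S{seq}_E{experts}`, an inference from the `B, S, H` unpacking and `E = routing_weights.shape[2]` in the benchmark function. The `aten::mm` call counts (12 with two experts, 24 with four) suggest that each profiled region covers more than one forward pass, so treat the results as relative, not as per-forward costs. `tokens_in` and `us_per_token` are hypothetical names.

```python
# Self CUDA time totals (ms) copied from the traces.
SELF_CUDA_TOTAL_MS = {
    "cuda_B1_S512_E2": 5.165,
    "cuda_B1_S512_E4": 6.136,
    "cuda_B1_S1024_E2": 8.400,
    "cuda_B1_S1024_E4": 10.495,
    "cuda_B4_S512_E2": 17.302,
}


def tokens_in(workload):
    """Parse an assumed 'cuda_B{b}_S{s}_E{e}' name and return b * s."""
    fields = {p[0]: int(p[1:]) for p in workload.split("_")[1:]}
    return fields["B"] * fields["S"]


# Microseconds of profiled CUDA time per token.
us_per_token = {w: 1000.0 * ms / tokens_in(w) for w, ms in SELF_CUDA_TOTAL_MS.items()}

for w, us in us_per_token.items():
    print(f"{w:18s} {us:6.2f} us/token")
```

On these numbers, per-token time with two experts falls from about 10.1 us at 512 tokens to about 8.2 us at 1024 tokens. Four experts cost more per token than two at the same shape.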



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      24.710ms       141.76%      24.710ms      24.710ms             1  
                                        gpt_oss_experts        10.14%       2.749ms        99.98%      27.106ms      27.106ms       0.000us         0.00%      17.441ms      17.441ms             1  
                                           aten::matmul         0.17%      45.968us         3.40%     922.464us      38.436us       0.000us         0.00%      15.230ms     634.586us            24  
                                               aten::mm         2.05%     556.479us         3.23%     876.496us      36.521us      15.230ms        87.37%      15.230ms     634.586us            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       9.172ms        52.62%       9.172ms     764.334us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.147ms        18.05%       3.147ms     524.452us             6  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.898ms        16.62%       2.898ms     482.943us             6  
                                              aten::add         1.29%     350.116us         2.26%     613.465us      17.041us     420.321us         2.41%     420.321us      11.676us            36  
                                              aten::mul         1.13%     307.419us         1.97%     533.015us      11.104us     413.571us         2.37%     413.571us       8.616us            48  
                                       aten::index_add_         0.36%      98.853us         0.63%     169.455us      14.121us     380.323us         2.18%     380.323us      31.694us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     380.323us         2.18%     380.323us      31.694us            12  
                                            aten::index         1.34%     364.187us         2.36%     638.760us      26.615us     342.626us         1.97%     342.626us      14.276us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     337.185us         1.93%     337.185us      14.049us            24  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     278.754us         1.60%     278.754us      23.230us            12  
                                            aten::clamp         0.81%     219.710us         1.37%     372.721us      15.530us     226.367us         1.30%     226.367us       9.432us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     226.367us         1.30%     226.367us       9.432us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     219.298us         1.26%     219.298us       9.137us            24  
                                          aten::nonzero         1.48%     402.204us         4.91%       1.331ms      88.732us     129.571us         0.74%     155.747us      10.383us            15  
                                            aten::where         0.04%      10.572us         4.67%       1.267ms     105.600us       0.000us         0.00%     139.970us      11.664us            12  
                                    aten::nonzero_numpy         0.08%      21.969us         4.64%       1.257ms     104.719us       0.000us         0.00%     139.970us      11.664us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 27.112ms
Self CUDA time total: 17.431ms
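The profiler prints times with mixed units (`ms` and `us`), which makes shares hard to compare by eye. The following is a minimal stdlib sketch and is not part of the benchmark harness. `to_ms` is a hypothetical helper, and the row values are copied by hand from the trace above. It sums the three SGEMM kernel rows and recomputes their share of `Self CUDA time total`. It also re-derives one `CUDA time avg` entry, which is the total divided by the call count.

```python
def to_ms(value: str) -> float:
    """Convert a profiler time string such as '9.172ms' or '380.323us' to milliseconds."""
    scale = {"ms": 1.0, "us": 1e-3, "s": 1e3}
    # Check the two-letter suffixes first, since 'ms' and 'us' also end in 's'.
    for suffix in ("ms", "us", "s"):
        if value.endswith(suffix):
            return float(value[: -len(suffix)]) * scale[suffix]
    raise ValueError(f"unrecognized time string: {value!r}")


# (kernel name as printed, Self CUDA, # of Calls), copied from the trace above.
GEMM_ROWS = [
    ("void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...", "9.172ms", 12),
    ("ampere_sgemm_128x64_nn", "3.147ms", 6),
    ("void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...", "2.898ms", 6),
]
SELF_CUDA_TOTAL = to_ms("17.431ms")

# Total time spent in the SGEMM kernels and their share of all self CUDA time.
gemm_ms = sum(to_ms(t) for _, t, _ in GEMM_ROWS)
share = gemm_ms / SELF_CUDA_TOTAL

# "CUDA time avg" is the CUDA total divided by the number of calls.
name, total, calls = GEMM_ROWS[0]
avg_us = to_ms(total) / calls * 1000.0

print(f"SGEMM kernels: {gemm_ms:.3f} ms ({share:.1%} of self CUDA time)")
print(f"avg per call for {name}: {avg_us:.1f} us")
```

The three single-precision (`sgemm`) kernels sum to 15.217 ms. That is within about 13 µs of the 15.230 ms reported for `aten::mm`, so in this trace nearly all of the matmul time is spent in those kernels.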



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      40.438ms       109.96%      40.438ms      40.438ms             1  
                                        gpt_oss_experts         4.40%       1.882ms        99.82%      42.728ms      42.728ms       0.000us         0.00%      36.808ms      36.808ms             1  
                                           aten::matmul         0.05%      22.249us         1.02%     438.421us      36.535us       0.000us         0.00%      26.813ms       2.234ms            12  
                                               aten::mm         0.66%     281.965us         0.97%     416.172us      34.681us      26.813ms        72.91%      26.813ms       2.234ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      26.809ms        72.90%      26.809ms       2.234ms            12  
                                              aten::mul         0.40%     169.436us         0.68%     291.368us      12.140us       2.973ms         8.09%       2.973ms     123.894us            24  
                                              aten::add         0.45%     194.095us         1.09%     466.694us      25.927us       2.399ms         6.52%       2.399ms     133.270us            18  
                                            aten::clamp         0.28%     118.373us         0.48%     205.484us      17.124us       2.385ms         6.49%       2.385ms     198.780us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.385ms         6.49%       2.385ms     198.780us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.983ms         5.39%       1.983ms     165.284us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.623ms         4.41%       1.623ms     135.241us            12  
                                       aten::index_add_         0.12%      50.121us         0.21%      88.453us      14.742us     929.513us         2.53%     929.513us     154.919us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     929.513us         2.53%     929.513us     154.919us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     775.973us         2.11%     775.973us     129.329us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     743.622us         2.02%     743.622us     123.937us             6  
                                            aten::index         0.44%     190.163us         0.78%     332.417us      27.701us     705.798us         1.92%     705.798us      58.816us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     672.133us         1.83%     672.133us     112.022us             6  
                                          aten::sigmoid         0.10%      42.342us         0.17%      71.992us      11.999us     317.635us         0.86%     317.635us      52.939us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     317.635us         0.86%     317.635us      52.939us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     246.434us         0.67%     246.434us      41.072us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 42.805ms
Self CUDA time total: 36.776ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      40.917ms       118.34%      40.917ms      40.917ms             1  
                                        gpt_oss_experts         6.54%       2.832ms        99.99%      43.320ms      43.320ms       0.000us         0.00%      34.594ms      34.594ms             1  
                                           aten::matmul         0.11%      46.003us         2.16%     933.683us      38.903us       0.000us         0.00%      28.640ms       1.193ms            24  
                                               aten::mm         1.27%     551.595us         2.05%     887.680us      36.987us      28.640ms        82.83%      28.640ms       1.193ms            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      20.238ms        58.53%      20.238ms       1.349ms            15  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       8.385ms        24.25%       8.385ms     931.701us             9  
                                              aten::add         0.85%     367.713us         1.47%     637.625us      17.712us       1.485ms         4.30%       1.485ms      41.254us            36  
                                              aten::mul         0.73%     317.651us         1.28%     554.606us      11.554us       1.368ms         3.96%       1.368ms      28.495us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     932.164us         2.70%     932.164us      38.840us            24  
                                       aten::index_add_         0.23%      99.030us         0.39%     170.492us      14.208us     912.225us         2.64%     912.225us      76.019us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     912.225us         2.64%     912.225us      76.019us            12  
                                            aten::clamp         0.52%     223.402us         0.90%     389.994us      16.250us     772.775us         2.24%     772.775us      32.199us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     772.775us         2.24%     772.775us      32.199us            24  
                                            aten::index         0.84%     365.911us         1.48%     641.837us      26.743us     652.128us         1.89%     652.128us      27.172us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     646.273us         1.87%     646.273us      53.856us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     582.113us         1.68%     582.113us      48.509us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     552.993us         1.60%     552.993us      46.083us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     519.810us         1.50%     519.810us      21.659us            24  
                                          aten::sigmoid         0.18%      79.593us         0.31%     135.883us      11.324us     361.471us         1.05%     361.471us      30.123us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     361.471us         1.05%     361.471us      30.123us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 43.326ms
Self CUDA time total: 34.575ms
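The two `cuda_B4_S1024` traces differ mainly in how the GEMM work is split. The E4 run issues twice as many `aten::mm` calls (24 vs 12) but spends only about 7% more GPU time in them. The stdlib sketch below reproduces that comparison from the `aten::mm` rows, with values copied by hand. `MmRow` is a hypothetical helper and is not part of the benchmark.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MmRow:
    """One aten::mm row copied from a profiler trace (hypothetical helper)."""

    trace: str
    cuda_total_ms: float  # "CUDA total" column
    calls: int            # "# of Calls" column

    @property
    def avg_ms(self) -> float:
        # Mirrors the profiler's "CUDA time avg" column.
        return self.cuda_total_ms / self.calls


e2 = MmRow("cuda_B4_S1024_E2", 26.813, 12)
e4 = MmRow("cuda_B4_S1024_E4", 28.640, 24)

print(f"calls ratio E4/E2: {e4.calls / e2.calls:.1f}x")
print(f"GEMM time ratio E4/E2: {e4.cuda_total_ms / e2.cuda_total_ms:.3f}x")
print(f"avg per call: E2 {e2.avg_ms:.3f} ms, E4 {e4.avg_ms:.3f} ms")
```

The per-call averages computed this way (about 2.234 ms and 1.193 ms) match the `CUDA time avg` column of the two traces.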


impl                     wl                  p50(ms)  ok
gpt_oss_experts          cuda_B1_S1024_E2       3.85  True
gpt_oss_experts          cuda_B1_S1024_E4       5.31  True
gpt_oss_experts          cuda_B1_S512_E2        2.63  True
gpt_oss_experts          cuda_B1_S512_E4        3.93  True
gpt_oss_experts          cuda_B4_S1024_E2      13.24  True
gpt_oss_experts          cuda_B4_S1024_E4      13.36  True
gpt_oss_experts          cuda_B4_S512_E2        6.72  True
gpt_oss_experts          cuda_B4_S512_E4        7.52  True
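Normalizing the p50 latencies by the number of tokens makes workloads of different sizes easier to compare. This sketch assumes that `B` and `S` in each workload name are the batch size and sequence length, so each run processes `B * S` tokens. The `E` field is parsed but not interpreted. The latencies are copied from the table above, and `parse_workload` and `us_per_token` are hypothetical helpers.

```python
import re

# p50 latencies in milliseconds, copied from the results table above.
P50_MS = {
    "cuda_B1_S1024_E2": 3.85,
    "cuda_B1_S1024_E4": 5.31,
    "cuda_B1_S512_E2": 2.63,
    "cuda_B1_S512_E4": 3.93,
    "cuda_B4_S1024_E2": 13.24,
    "cuda_B4_S1024_E4": 13.36,
    "cuda_B4_S512_E2": 6.72,
    "cuda_B4_S512_E4": 7.52,
}

WORKLOAD_RE = re.compile(r"cuda_B(\d+)_S(\d+)_E(\d+)$")


def parse_workload(name: str) -> tuple[int, int, int]:
    """Split a workload name like 'cuda_B4_S1024_E2' into (B, S, E)."""
    m = WORKLOAD_RE.match(name)
    if m is None:
        raise ValueError(f"unrecognized workload name: {name!r}")
    b, s, e = (int(g) for g in m.groups())
    return b, s, e


def us_per_token(name: str) -> float:
    """p50 latency in microseconds divided by the assumed token count B * S."""
    b, s, _ = parse_workload(name)
    return P50_MS[name] * 1000.0 / (b * s)


# Sort by (B, S, E) so related workloads are printed next to each other.
for name in sorted(P50_MS, key=parse_workload):
    b, s, _ = parse_workload(name)
    print(f"{name:<18} tokens={b * s:>5}  {us_per_token(name):.2f} us/token")
```

Normalized this way, the B4 rows all fall between roughly 3.2 and 3.7 µs per token, while the B1 rows range from about 3.8 to 7.7 µs. Going from E2 to E4 raises the p50 by roughly 1.4 to 1.5x at B1. At B4 the increase is only about 1% (S1024) to 12% (S512).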

Artifacts:

openai_moe.jsonl