Binned PyTorch - OpenAI-style MoE

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.25s | Raw GitHub
import subprocess
# Run `nvidia-smi` (argument list, no shell) and print its captured stdout,
# i.e. the GPU / driver report shown below this cell.
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 23:00:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   40C    P0             84W /  350W |       0MiB /  46068MiB |     60%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (Binned PyTorch)

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 723.84s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Copy token rows into a dense, zero-padded per-expert buffer.

    Args:
        x: 2-D tensor of token features; row ``t`` holds token ``t``.
        indices: 1-D integer tensor of flattened (token, slot) positions,
            grouped by expert. ``flat_pos // top_k`` recovers the token row.
        bins: 1-D integer tensor of cumulative per-expert counts. Expert ``e``
            owns ``indices[bins[e - 1]:bins[e]]`` (starting at 0 for ``e == 0``).
        expert_capacity: Maximum number of rows kept per expert; assignments
            past this limit are dropped.
        top_k: Number of slots per token, used to map a flat position back to
            its token row.

    Returns:
        Tensor of shape ``(len(bins), expert_capacity, x.shape[1])`` with the
        dtype and device of ``x``. Unused capacity rows stay zero.
    """
    num_experts, hidden = bins.shape[0], x.shape[1]
    out = torch.zeros(
        (num_experts, expert_capacity, hidden), device=x.device, dtype=x.dtype
    )

    # Read every bin boundary in a single transfer. Indexing `bins` / `indices`
    # one element at a time from Python costs a host round trip per access
    # when the tensors live on the GPU.
    bounds = [0] + bins.tolist()

    for e in range(num_experts):
        start = bounds[e]
        n = min(bounds[e + 1] - start, expert_capacity)
        if n <= 0:
            continue
        # Flat (token, slot) position -> token row, for the whole expert at once.
        tok = indices[start:start + n] // top_k
        out[e, :n] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    """Group flattened router assignments by expert id.

    Args:
        router_indices: Integer tensor of expert ids; it is flattened before
            sorting, so entry ``p`` of the flat view is one (token, slot)
            assignment.
        num_experts: Minimum length of the per-expert count vector, so experts
            that receive no assignment still get a zero entry.

    Returns:
        Tuple ``(sorted_indices, sorted_values, bins, tokens_per_expert)``:
        ``sorted_indices`` are the flat positions ordered by expert id,
        ``sorted_values`` the matching expert ids, ``tokens_per_expert`` the
        count per expert, and ``bins`` its cumulative sum.
    """
    flat_indices = router_indices.flatten()
    # A stable sort keeps assignments to the same expert in their original
    # flat order, so any later truncation of an expert's segment drops the
    # same assignments on every run and device.
    sorted_values, sorted_indices = torch.sort(flat_indices, stable=True)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    """Reference MoE forward pass using capacity-binned expert batches.

    Assignments are sorted by expert, gathered into a dense
    ``(E, expert_capacity, H)`` buffer, run through the gated expert MLP with
    two batched matmuls, and scattered back to tokens weighted by their
    routing weights. Assignments beyond ``expert_capacity`` for an expert are
    dropped and contribute nothing to the output.

    Args:
        hidden_states: Tensor unpacked as ``(B, S, H)``.
        router_indices: Expert ids per token, reshaped to ``(B*S, K)``.
        routing_weights: Dense routing weights, reshaped to ``(B*S, E)``.
        gate_up_proj: Batched weight for the first matmul; gate and up columns
            are interleaved on the last dim (even = gate, odd = up).
        gate_up_proj_bias: Bias added after the first matmul, broadcast over
            the capacity dimension.
        down_proj: Batched weight for the second matmul.
        down_proj_bias: Bias added after the second matmul, broadcast over the
            capacity dimension.
        expert_capacity: Rows reserved per expert in the dense buffer.

    Returns:
        Tensor of shape ``(B, S, H)``.
    """
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    # reshape (not view): callers may hand in non-contiguous tensors, for which
    # view raises; reshape returns a view when possible and copies otherwise.
    x = binned_gather(hidden_states.reshape(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # Clamp pre-activations: gate only from above, up symmetrically.
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    # Sigmoid-gated activation on the gate branch, then combine with (up + 1).
    alpha = 1.702
    glu = gate * torch.sigmoid(gate * alpha)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # Build routing weights aligned to the flattened (token, slot) positions.
    flat_dense = routing_weights.reshape(-1, E)  # [B*S, E]
    flat_router = router_indices.reshape(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # Scatter expert outputs back to their tokens and sum over slots.
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.view(B, S, H)


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Benchmark entry point for the binned PyTorch OpenAI-style MoE.

    Derives a per-expert capacity from the input sizes and forwards everything
    to ``binned_experts_ref``, which groups token assignments by expert and
    runs the experts as batched matmuls.
    """
    batch, seq_len = hidden_states.shape[:2]
    top_k = router_indices.shape[1]
    num_experts = routing_weights.shape[2]

    # Capacity is twice the mean number of assignments per expert
    # (batch * seq_len * top_k / num_experts), leaving headroom for imbalance.
    total_assignments = batch * seq_len * top_k
    capacity = (total_assignments * 2) // num_experts

    return binned_experts_ref(
        hidden_states=hidden_states,
        router_indices=router_indices,
        routing_weights=routing_weights,
        gate_up_proj=gate_up_proj,
        gate_up_proj_bias=gate_up_proj_bias,
        down_proj=down_proj,
        down_proj_bias=down_proj_bias,
        expert_capacity=capacity,
    )


# Hand the implementation to the kernels_benchmark_tools harness: it is
# registered under the name "binned_torch" for the OPENAI_MOE kernel type,
# tagged as eager PyTorch, with dtype "float32".
# NOTE(review): workload shapes, timing and profiling are decided inside
# run_benchmark and are not visible here — confirm against the harness.
run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     916.334ms      1818.27%     916.334ms     916.334ms             1  
                                           binned_torch        24.63%     226.221ms       100.00%     918.346ms     918.346ms       0.000us         0.00%      50.398ms      50.398ms             1  
                                             aten::item         1.84%      16.915ms        25.73%     236.247ms      15.396us       0.000us         0.00%      15.727ms       1.025us         15345  
                              aten::_local_scalar_dense         5.92%      54.373ms        23.88%     219.332ms      14.293us      15.726ms        31.20%      15.727ms       1.025us         15345  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      15.726ms        31.20%      15.726ms       1.025us         15345  
                                              aten::bmm         0.02%     194.226us         0.03%     236.195us      39.366us       8.013ms        15.90%       8.013ms       1.336ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       8.013ms        15.90%       8.013ms       1.336ms             6  
                                     aten::floor_divide         5.35%      49.157ms        13.15%     120.743ms      19.652us       7.547ms        14.98%       7.547ms       1.228us          6144  
                                            aten::copy_         3.75%      34.457ms         9.21%      84.535ms      13.732us       6.589ms        13.08%       6.592ms       1.071us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.585ms        13.07%       6.585ms       1.070us          6153  
                                              aten::mul         3.14%      28.847ms         5.63%      51.742ms      16.794us       4.707ms         9.34%       4.707ms       1.528us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.479ms         8.89%       4.479ms       1.458us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.026ms         7.99%       4.026ms       1.311us          3072  
                                        aten::remainder         3.09%      28.363ms         4.76%      43.750ms      14.241us       3.702ms         7.35%       3.702ms       1.205us          3072  
                                              aten::add         2.79%      25.584ms         4.81%      44.150ms      14.557us       3.631ms         7.20%       3.631ms       1.197us          3033  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.522ms         6.99%       3.522ms       1.147us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.235ms         6.42%       3.235ms       1.068us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.954ms         3.88%       1.954ms       1.272us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.749ms         3.47%       1.749ms       1.138us          1536  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     287.138us         0.57%     287.138us      47.856us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 918.353ms
Self CUDA time total: 50.396ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     930.604ms      1724.65%     930.604ms     930.604ms             1  
                                           binned_torch        24.29%     226.115ms       100.00%     930.865ms     930.865ms       0.000us         0.00%      53.966ms      53.966ms             1  
                                             aten::item         1.81%      16.815ms        27.55%     256.425ms      15.142us       0.000us         0.00%      17.838ms       1.053us         16935  
                              aten::_local_scalar_dense         6.14%      57.141ms        25.74%     239.611ms      14.149us      17.835ms        33.05%      17.838ms       1.053us         16935  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      17.835ms        33.05%      17.835ms       1.053us         16935  
                                              aten::bmm         0.02%     175.424us         0.02%     217.325us      36.221us       7.967ms        14.77%       7.967ms       1.328ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.967ms        14.77%       7.967ms       1.328ms             6  
                                     aten::floor_divide         5.05%      47.005ms        12.57%     117.000ms      19.043us       7.550ms        13.99%       7.551ms       1.229us          6144  
                                            aten::copy_         3.51%      32.640ms         8.36%      77.831ms      12.643us       6.635ms        12.30%       6.635ms       1.078us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.632ms        12.29%       6.632ms       1.078us          6152  
                                              aten::add         3.89%      36.256ms         6.95%      64.697ms      14.086us       5.059ms         9.38%       5.059ms       1.102us          4593  
                                              aten::mul         2.92%      27.144ms         5.32%      49.502ms      16.067us       4.707ms         8.72%       4.707ms       1.528us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.479ms         8.30%       4.479ms       1.458us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.026ms         7.46%       4.026ms       1.310us          3072  
                                        aten::remainder         2.81%      26.197ms         4.49%      41.800ms      13.607us       3.721ms         6.90%       3.721ms       1.211us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.524ms         6.53%       3.524ms       1.147us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.140ms         5.82%       3.140ms       1.036us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.965ms         3.64%       1.965ms       1.279us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.756ms         3.25%       1.756ms       1.143us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.517ms         2.81%       1.517ms       0.972us          1560  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 930.874ms
Self CUDA time total: 53.959ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.706s      1653.15%        1.706s        1.706s             1  
                                           binned_torch        24.03%     409.734ms       100.00%        1.705s        1.705s       0.000us         0.00%     103.183ms     103.183ms             1  
                                             aten::item         1.59%      27.070ms        26.54%     452.490ms      14.829us       0.000us         0.00%      31.572ms       1.035us         30513  
                              aten::_local_scalar_dense         5.90%     100.602ms        24.95%     425.421ms      13.942us      31.568ms        30.60%      31.572ms       1.035us         30513  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      31.568ms        30.60%      31.568ms       1.035us         30513  
                                              aten::bmm         0.01%     213.024us         0.02%     261.877us      43.646us      15.473ms        15.00%      15.473ms       2.579ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.473ms        15.00%      15.473ms       2.579ms             6  
                                     aten::floor_divide         5.42%      92.355ms        13.36%     227.861ms      18.543us      15.078ms        14.61%      15.078ms       1.227us         12288  
                                            aten::copy_         3.96%      67.445ms         9.41%     160.444ms      13.044us      13.330ms        12.92%      13.330ms       1.084us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.326ms        12.92%      13.326ms       1.084us         12294  
                                              aten::mul         3.18%      54.204ms         5.76%      98.288ms      15.974us      11.263ms        10.92%      11.265ms       1.831us          6153  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.919ms         9.61%       9.919ms       1.614us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.044ms         7.80%       8.044ms       1.309us          6144  
                                        aten::remainder         3.09%      52.622ms         4.84%      82.495ms      13.427us       7.409ms         7.18%       7.409ms       1.206us          6144  
                                              aten::add         2.82%      48.063ms         4.95%      84.371ms      14.269us       7.380ms         7.15%       7.380ms       1.248us          5913  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.034ms         6.82%       7.034ms       1.145us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.098ms         5.91%       6.098ms       1.032us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.912ms         3.79%       3.912ms       1.273us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.498ms         3.39%       3.498ms       1.139us          3072  
                                            aten::clamp         0.00%      70.381us         0.01%     115.343us      19.224us       1.182ms         1.15%       1.182ms     197.026us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.705s
Self CUDA time total: 103.179ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.835s      1676.06%        1.835s        1.835s             1  
                                           binned_torch        24.11%     442.690ms       100.00%        1.836s        1.836s       0.000us         0.00%     109.503ms     109.503ms             1  
                                             aten::item         1.62%      29.702ms        27.50%     504.982ms      14.972us       0.000us         0.00%      35.015ms       1.038us         33729  
                              aten::_local_scalar_dense         6.21%     114.112ms        25.88%     475.279ms      14.091us      35.012ms        31.97%      35.015ms       1.038us         33729  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      35.012ms        31.97%      35.012ms       1.038us         33728  
                                              aten::bmm         0.01%     232.655us         0.02%     282.685us      47.114us      15.567ms        14.22%      15.567ms       2.595ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.567ms        14.22%      15.567ms       2.595ms             6  
                                     aten::floor_divide         5.11%      93.914ms        12.52%     229.926ms      18.711us      15.067ms        13.76%      15.067ms       1.226us         12288  
                                            aten::copy_         3.50%      64.191ms         8.58%     157.627ms      12.815us      13.353ms        12.19%      13.355ms       1.086us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.350ms        12.19%      13.350ms       1.086us         12294  
                                              aten::mul         2.97%      54.553ms         5.34%      97.962ms      15.921us      10.925ms         9.98%      10.925ms       1.776us          6153  
                                              aten::add         3.96%      72.764ms         6.93%     127.157ms      13.975us      10.457ms         9.55%      10.457ms       1.149us          9099  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.572ms         8.74%       9.572ms       1.558us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.046ms         7.35%       8.046ms       1.310us          6144  
                                        aten::remainder         2.95%      54.099ms         4.66%      85.633ms      13.938us       7.422ms         6.78%       7.422ms       1.208us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.021ms         6.41%       7.021ms       1.143us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.106ms         5.58%       6.106ms       1.033us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.920ms         3.58%       3.920ms       1.276us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.502ms         3.20%       3.502ms       1.140us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.094ms         2.83%       3.094ms       0.971us          3186  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.836s
Self CUDA time total: 109.497ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.483s      1652.23%        3.483s        3.483s             1  
                                           binned_torch        24.18%     842.026ms       100.00%        3.482s        3.482s       0.000us         0.00%     210.838ms     210.838ms             1  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      63.561ms        30.15%      63.561ms       1.032us         61586  
                                             aten::item         1.74%      60.466ms        26.96%     938.865ms      15.245us       0.000us         0.00%      63.559ms       1.032us         61587  
                              aten::_local_scalar_dense         6.04%     210.488ms        25.22%     878.295ms      14.261us      63.559ms        30.15%      63.559ms       1.032us         61587  
                                     aten::floor_divide         5.38%     187.378ms        13.29%     462.870ms      18.834us      30.531ms        14.48%      30.538ms       1.243us         24576  
                                              aten::bmm         0.01%     232.923us         0.01%     283.154us      47.192us      29.267ms        13.88%      29.267ms       4.878ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.267ms        13.88%      29.267ms       4.878ms             6  
                                            aten::copy_         3.71%     129.087ms         8.89%     309.556ms      12.590us      26.727ms        12.68%      26.728ms       1.087us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.725ms        12.68%      26.725ms       1.087us         24582  
                                              aten::mul         3.12%     108.737ms         5.69%     198.327ms      16.128us      25.576ms        12.13%      25.578ms       2.080us         12297  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.134ms        10.50%      22.134ms       1.801us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.473ms         7.81%      16.473ms       1.341us         12288  
                                              aten::add         2.81%      97.833ms         4.96%     172.866ms      13.928us      16.092ms         7.63%      16.093ms       1.297us         12411  
                                        aten::remainder         3.07%     106.957ms         4.82%     167.982ms      13.670us      14.887ms         7.06%      14.889ms       1.212us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.058ms         6.67%      14.058ms       1.144us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.970ms         6.15%      12.970ms       1.045us         12408  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.857ms         3.73%       7.857ms       1.279us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.030ms         3.33%       7.030ms       1.144us          6144  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.605ms         1.24%       2.605ms     434.242us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.483s
Self CUDA time total: 210.821ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.725s      1668.35%        3.725s        3.725s             1  
                                           binned_torch        24.05%     896.242ms       100.00%        3.727s        3.727s       0.000us         0.00%     223.307ms     223.307ms             1  
                                             aten::item         1.73%      64.547ms        27.53%        1.026s      15.123us       0.000us         0.00%      69.633ms       1.026us         67845  
                              aten::_local_scalar_dense         6.19%     230.534ms        25.80%     961.495ms      14.172us      69.631ms        31.18%      69.633ms       1.026us         67845  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      69.632ms        31.18%      69.632ms       1.026us         67841  
                                     aten::floor_divide         5.09%     189.838ms        12.50%     465.764ms      18.952us      30.442ms        13.63%      30.448ms       1.239us         24576  
                                              aten::bmm         0.01%     247.707us         0.01%     294.697us      49.116us      29.554ms        13.24%      29.554ms       4.926ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.554ms        13.24%      29.554ms       4.926ms             6  
                                            aten::copy_         3.50%     130.326ms         8.36%     311.636ms      12.674us      26.718ms        11.97%      26.719ms       1.087us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.715ms        11.96%      26.715ms       1.087us         24581  
                                              aten::mul         2.92%     108.800ms         5.34%     198.878ms      16.173us      25.547ms        11.44%      25.547ms       2.077us         12297  
                                              aten::add         3.96%     147.436ms         7.04%     262.447ms      14.081us      22.490ms        10.07%      22.492ms       1.207us         18639  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.115ms         9.90%      22.115ms       1.800us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.451ms         7.37%      16.451ms       1.339us         12287  
                                        aten::remainder         2.81%     104.739ms         4.44%     165.425ms      13.462us      14.805ms         6.63%      14.806ms       1.205us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      13.992ms         6.27%      13.992ms       1.139us         12287  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.166ms         5.90%      13.166ms       1.061us         12407  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.819ms         3.50%       7.819ms       1.273us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.986ms         3.13%       6.986ms       1.137us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.214ms         2.78%       6.214ms       0.998us          6228  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.727s
Self CUDA time total: 223.293ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        6.919s      1639.48%        6.919s        6.919s             1  
                                           binned_torch        24.46%        1.695s       100.00%        6.929s        6.929s       0.000us         0.00%     422.036ms     422.036ms             1  
                                             aten::item         1.67%     115.500ms        26.73%        1.852s      15.089us       0.000us         0.00%     127.102ms       1.035us        122763  
                              aten::_local_scalar_dense         5.94%     411.594ms        25.07%        1.737s      14.148us     127.094ms        30.12%     127.102ms       1.035us        122763  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     127.096ms        30.12%     127.096ms       1.035us        122762  
                                     aten::floor_divide         5.38%     373.026ms        13.30%     921.425ms      18.746us      61.339ms        14.53%      61.343ms       1.248us         49152  
                                              aten::bmm         0.00%     231.234us         0.00%     280.225us      46.704us      57.287ms        13.57%      57.287ms       9.548ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      57.287ms        13.57%      57.287ms       9.548ms             6  
                                            aten::copy_         3.72%     257.654ms         8.91%     617.063ms      12.553us      53.696ms        12.72%      53.697ms       1.092us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.694ms        12.72%      53.694ms       1.092us         49154  
                                              aten::mul         3.13%     217.096ms         5.68%     393.622ms      16.011us      51.639ms        12.24%      51.644ms       2.101us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.676ms        10.59%      44.676ms       1.818us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      33.163ms         7.86%      33.163ms       1.349us         24576  
                                              aten::add         2.81%     194.866ms         4.91%     340.544ms      13.937us      32.585ms         7.72%      32.588ms       1.334us         24435  
                                        aten::remainder         3.09%     213.993ms         4.85%     335.801ms      13.664us      29.914ms         7.09%      29.918ms       1.217us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      28.177ms         6.68%      28.177ms       1.147us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      25.921ms         6.14%      25.921ms       1.061us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.786ms         3.74%      15.786ms       1.285us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.129ms         3.35%      14.129ms       1.150us         12288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.239ms         1.24%       5.239ms     873.180us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.929s
Self CUDA time total: 422.014ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.526s      1690.98%        7.526s        7.526s             1  
                                           binned_torch        24.06%        1.811s       100.00%        7.528s        7.528s       0.000us         0.00%     445.109ms     445.109ms             1  
                                             aten::item         1.62%     121.583ms        26.84%        2.020s      14.998us       0.000us         0.00%     138.816ms       1.030us        134715  
                              aten::_local_scalar_dense         6.12%     460.388ms        25.22%        1.899s      14.095us     138.805ms        31.19%     138.816ms       1.030us        134715  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     138.805ms        31.19%     138.805ms       1.030us        134707  
                                     aten::floor_divide         5.25%     395.063ms        12.72%     957.555ms      19.482us      61.331ms        13.78%      61.336ms       1.248us         49152  
                                              aten::bmm         0.00%     238.536us         0.00%     289.618us      48.270us      57.304ms        12.88%      57.304ms       9.551ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      57.304ms        12.88%      57.304ms       9.551ms             6  
                                            aten::copy_         3.62%     272.274ms         8.61%     648.516ms      13.192us      53.873ms        12.10%      53.876ms       1.096us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.870ms        12.10%      53.870ms       1.096us         49149  
                                              aten::mul         3.08%     231.551ms         5.44%     409.269ms      16.647us      51.546ms        11.58%      51.551ms       2.097us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.593ms        10.02%      44.593ms       1.814us         24576  
                                              aten::add         4.08%     306.812ms         7.05%     530.578ms      14.594us      43.966ms         9.88%      43.969ms       1.209us         36357  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      33.107ms         7.44%      33.107ms       1.347us         24573  
                                        aten::remainder         2.97%     223.921ms         4.70%     353.632ms      14.389us      29.770ms         6.69%      29.775ms       1.211us         24577  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      28.225ms         6.34%      28.225ms       1.149us         24573  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      25.583ms         5.75%      25.583ms       1.047us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.722ms         3.53%      15.722ms       1.279us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.047ms         3.16%      14.047ms       1.143us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      11.757ms         2.64%      11.757ms       0.986us         11922  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.528s
Self CUDA time total: 445.070ms


impl                     wl                  p50(ms)  ok
binned_torch             cuda_B1_S1024_E2     367.98  True
binned_torch             cuda_B1_S1024_E4     396.30  True
binned_torch             cuda_B1_S512_E2      154.35  True
binned_torch             cuda_B1_S512_E4      195.55  True
binned_torch             cuda_B4_S1024_E2    1510.09  True
binned_torch             cuda_B4_S1024_E4    1618.05  True
binned_torch             cuda_B4_S512_E2      733.47  True
binned_torch             cuda_B4_S512_E4      787.61  True
▶ UV Install Logs

Artifacts:

openai_moe.jsonl