Binned PyTorch - OpenAI-style MoE

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.25s | Raw GitHub
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 18:56:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   34C    P0             80W /  350W |       0MiB /  46068MiB |     41%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (Binned PyTorch)

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 730.34s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.view(B, S, H)


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Binned PyTorch implementation of OpenAI-style MoE.
    Sorts tokens by expert assignment for more efficient batched processing.
    """
    B, S = hidden_states.shape[0], hidden_states.shape[1]
    K = router_indices.shape[1]

    # Set expert_capacity to a reasonable value (max tokens per expert)
    # Use 2x the average to handle imbalance
    expert_capacity = (B * S * K * 2) // routing_weights.shape[2]

    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        expert_capacity,
    )


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     919.007ms      1814.55%     919.007ms     919.007ms             1  
                                           binned_torch        24.74%     227.809ms       100.00%     920.989ms     920.989ms       0.000us         0.00%      50.650ms      50.650ms             1  
                                             aten::item         1.86%      17.169ms        26.20%     241.261ms      15.722us       0.000us         0.00%      15.873ms       1.034us         15345  
                              aten::_local_scalar_dense         5.94%      54.669ms        24.33%     224.092ms      14.604us      15.872ms        31.34%      15.873ms       1.034us         15345  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      15.872ms        31.34%      15.872ms       1.034us         15345  
                                     aten::floor_divide         5.47%      50.387ms        13.12%     120.822ms      19.665us       7.812ms        15.43%       7.812ms       1.272us          6144  
                                              aten::bmm         0.02%     191.383us         0.03%     231.124us      38.521us       7.592ms        14.99%       7.592ms       1.265ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.592ms        14.99%       7.592ms       1.265ms             6  
                                            aten::copy_         3.61%      33.260ms         9.01%      82.984ms      13.480us       6.583ms        13.00%       6.585ms       1.070us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.579ms        12.99%       6.579ms       1.069us          6153  
                                              aten::mul         3.25%      29.933ms         5.69%      52.377ms      17.000us       4.706ms         9.29%       4.706ms       1.527us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.478ms         8.84%       4.478ms       1.458us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.159ms         8.21%       4.159ms       1.354us          3072  
                                        aten::remainder         3.14%      28.956ms         4.78%      44.045ms      14.337us       3.839ms         7.58%       3.839ms       1.250us          3072  
                                              aten::add         2.87%      26.444ms         4.82%      44.437ms      14.651us       3.761ms         7.43%       3.761ms       1.240us          3033  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.655ms         7.22%       3.655ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.365ms         6.64%       3.365ms       1.110us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.023ms         3.99%       2.023ms       1.317us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.816ms         3.58%       1.816ms       1.182us          1536  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     287.650us         0.57%     287.650us      47.942us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 920.998ms
Self CUDA time total: 50.647ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     934.694ms      1714.22%     934.694ms     934.694ms             1  
                                           binned_torch        24.25%     226.767ms       100.00%     935.247ms     935.247ms       0.000us         0.00%      54.534ms      54.534ms             1  
                                             aten::item         1.76%      16.424ms        27.79%     259.914ms      15.348us       0.000us         0.00%      17.987ms       1.062us         16935  
                              aten::_local_scalar_dense         6.05%      56.595ms        26.03%     243.490ms      14.378us      17.985ms        32.98%      17.987ms       1.062us         16935  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      17.985ms        32.98%      17.985ms       1.062us         16935  
                                     aten::floor_divide         5.13%      47.972ms        12.39%     115.852ms      18.856us       7.812ms        14.33%       7.813ms       1.272us          6144  
                                              aten::bmm         0.02%     166.771us         0.02%     207.402us      34.567us       7.794ms        14.29%       7.794ms       1.299ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.794ms        14.29%       7.794ms       1.299ms             6  
                                            aten::copy_         3.47%      32.488ms         8.51%      79.554ms      12.923us       6.633ms        12.17%       6.635ms       1.078us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.630ms        12.16%       6.630ms       1.078us          6153  
                                              aten::add         4.14%      38.686ms         7.06%      65.992ms      14.368us       5.259ms         9.64%       5.259ms       1.145us          4593  
                                              aten::mul         3.02%      28.215ms         5.35%      50.047ms      16.244us       4.701ms         8.62%       4.701ms       1.526us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.474ms         8.21%       4.474ms       1.457us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.157ms         7.62%       4.157ms       1.353us          3072  
                                        aten::remainder         2.81%      26.265ms         4.43%      41.468ms      13.499us       3.852ms         7.06%       3.852ms       1.254us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.655ms         6.70%       3.655ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.270ms         6.00%       3.270ms       1.079us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.030ms         3.72%       2.030ms       1.322us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.822ms         3.34%       1.822ms       1.186us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.584ms         2.91%       1.584ms       1.015us          1560  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 935.255ms
Self CUDA time total: 54.526ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.775s      1705.66%        1.775s        1.775s             1  
                                           binned_torch        24.39%     432.670ms       100.00%        1.774s        1.774s       0.000us         0.00%     104.087ms     104.087ms             1  
                                             aten::item         1.67%      29.627ms        26.26%     465.825ms      15.266us       0.000us         0.00%      31.856ms       1.044us         30513  
                              aten::_local_scalar_dense         5.88%     104.231ms        24.59%     436.198ms      14.295us      31.854ms        30.61%      31.856ms       1.044us         30513  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      31.854ms        30.61%      31.854ms       1.044us         30513  
                                     aten::floor_divide         5.49%      97.404ms        13.46%     238.769ms      19.431us      15.611ms        15.00%      15.612ms       1.270us         12288  
                                              aten::bmm         0.01%     215.332us         0.01%     258.864us      43.144us      15.009ms        14.42%      15.009ms       2.502ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.009ms        14.42%      15.009ms       2.502ms             6  
                                            aten::copy_         3.73%      66.187ms         9.04%     160.371ms      13.038us      13.330ms        12.81%      13.331ms       1.084us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.326ms        12.80%      13.326ms       1.084us         12294  
                                              aten::mul         3.16%      56.128ms         5.72%     101.496ms      16.495us      11.275ms        10.83%      11.277ms       1.833us          6153  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.921ms         9.53%       9.921ms       1.615us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.311ms         7.99%       8.311ms       1.353us          6144  
                                        aten::remainder         3.23%      57.334ms         5.09%      90.371ms      14.709us       7.676ms         7.38%       7.678ms       1.250us          6144  
                                              aten::add         2.88%      51.067ms         5.02%      88.987ms      15.049us       7.641ms         7.34%       7.642ms       1.292us          5913  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.300ms         7.01%       7.300ms       1.188us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.359ms         6.11%       6.359ms       1.076us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.045ms         3.89%       4.045ms       1.317us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.632ms         3.49%       3.632ms       1.182us          3072  
                                            aten::clamp         0.00%      74.963us         0.01%     122.824us      20.471us       1.191ms         1.14%       1.191ms     198.444us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.774s
Self CUDA time total: 104.078ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.943s      1756.79%        1.943s        1.943s             1  
                                           binned_torch        24.29%     471.728ms       100.00%        1.942s        1.942s       0.000us         0.00%     110.592ms     110.592ms             1  
                                             aten::item         1.62%      31.476ms        26.94%     523.166ms      15.511us       0.000us         0.00%      35.330ms       1.047us         33729  
                              aten::_local_scalar_dense         6.11%     118.659ms        25.32%     491.691ms      14.578us      35.327ms        31.95%      35.330ms       1.047us         33729  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      35.327ms        31.95%      35.327ms       1.047us         33728  
                                     aten::floor_divide         5.19%     100.816ms        12.43%     241.273ms      19.635us      15.609ms        14.12%      15.611ms       1.270us         12288  
                                              aten::bmm         0.01%     222.165us         0.01%     267.105us      44.517us      15.085ms        13.64%      15.085ms       2.514ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.085ms        13.64%      15.085ms       2.514ms             6  
                                            aten::copy_         3.60%      69.833ms         8.76%     170.090ms      13.828us      13.355ms        12.08%      13.357ms       1.086us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.353ms        12.07%      13.353ms       1.086us         12294  
                                              aten::mul         2.94%      57.042ms         5.32%     103.331ms      16.794us      10.942ms         9.89%      10.942ms       1.778us          6153  
                                              aten::add         3.88%      75.326ms         6.94%     134.721ms      14.806us      10.866ms         9.83%      10.866ms       1.194us          9099  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.591ms         8.67%       9.591ms       1.561us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.314ms         7.52%       8.314ms       1.353us          6144  
                                        aten::remainder         2.77%      53.827ms         4.45%      86.321ms      14.050us       7.697ms         6.96%       7.697ms       1.253us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.295ms         6.60%       7.295ms       1.187us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.370ms         5.76%       6.370ms       1.078us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.058ms         3.67%       4.058ms       1.321us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.639ms         3.29%       3.639ms       1.185us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.234ms         2.92%       3.234ms       1.015us          3186  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.942s
Self CUDA time total: 110.585ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.554s      1668.92%        3.554s        3.554s             1  
                                           binned_torch        24.03%     852.954ms       100.00%        3.549s        3.549s       0.000us         0.00%     212.979ms     212.979ms             1  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      63.933ms        30.02%      63.933ms       1.038us         61586  
                                             aten::item         1.68%      59.518ms        26.66%     946.248ms      15.364us       0.000us         0.00%      63.933ms       1.038us         61587  
                              aten::_local_scalar_dense         6.15%     218.157ms        24.98%     886.634ms      14.396us      63.932ms        30.02%      63.933ms       1.038us         61587  
                                     aten::floor_divide         5.36%     190.145ms        13.28%     471.339ms      19.179us      31.621ms        14.85%      31.623ms       1.287us         24576  
                                              aten::bmm         0.01%     230.233us         0.01%     275.904us      45.984us      28.855ms        13.55%      28.855ms       4.809ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      28.855ms        13.55%      28.855ms       4.809ms             6  
                                            aten::copy_         3.84%     136.428ms         9.38%     333.073ms      13.546us      26.747ms        12.56%      26.749ms       1.088us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.744ms        12.56%      26.744ms       1.088us         24582  
                                              aten::mul         3.20%     113.415ms         5.79%     205.629ms      16.722us      25.614ms        12.03%      25.614ms       2.083us         12297  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.161ms        10.41%      22.161ms       1.803us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.018ms         7.99%      17.018ms       1.385us         12288  
                                              aten::add         2.93%     103.833ms         5.19%     184.217ms      14.843us      16.665ms         7.83%      16.666ms       1.343us         12411  
                                        aten::remainder         3.13%     110.979ms         5.01%     177.878ms      14.476us      15.442ms         7.25%      15.444ms       1.257us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.607ms         6.86%      14.607ms       1.189us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.543ms         6.36%      13.543ms       1.091us         12408  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.136ms         3.82%       8.136ms       1.324us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.305ms         3.43%       7.305ms       1.189us          6144  
                                            aten::clamp         0.00%      80.604us         0.00%     131.123us      21.854us       2.608ms         1.22%       2.608ms     434.678us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.549s
Self CUDA time total: 212.971ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.834s      1701.16%        3.834s        3.834s             1  
                                           binned_torch        23.91%     917.039ms       100.00%        3.836s        3.836s       0.000us         0.00%     225.394ms     225.394ms             1  
                                             aten::item         1.70%      65.086ms        27.21%        1.044s      15.386us       0.000us         0.00%      70.210ms       1.035us         67845  
                              aten::_local_scalar_dense         6.32%     242.356ms        25.52%     978.758ms      14.426us      70.207ms        31.15%      70.210ms       1.035us         67845  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      70.207ms        31.15%      70.207ms       1.035us         67840  
                                     aten::floor_divide         5.09%     195.347ms        12.48%     478.676ms      19.477us      31.474ms        13.97%      31.481ms       1.281us         24576  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      28.832ms        12.79%      28.832ms       4.805ms             6  
                                              aten::bmm         0.01%     227.473us         0.01%     274.364us      45.727us      28.832ms        12.79%      28.832ms       4.805ms             6  
                                            aten::copy_         3.61%     138.479ms         8.82%     338.314ms      13.759us      26.687ms        11.84%      26.689ms       1.085us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.685ms        11.84%      26.685ms       1.086us         24581  
                                              aten::mul         2.97%     113.735ms         5.38%     206.436ms      16.787us      25.537ms        11.33%      25.539ms       2.077us         12297  
                                              aten::add         4.18%     160.247ms         7.41%     284.235ms      15.249us      23.217ms        10.30%      23.217ms       1.246us         18639  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.084ms         9.80%      22.084ms       1.797us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.963ms         7.53%      16.963ms       1.381us         12287  
                                        aten::remainder         2.89%     110.779ms         4.66%     178.579ms      14.533us      15.327ms         6.80%      15.329ms       1.247us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.512ms         6.44%      14.512ms       1.181us         12287  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.655ms         6.06%      13.655ms       1.101us         12407  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.083ms         3.59%       8.083ms       1.316us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.244ms         3.21%       7.244ms       1.179us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.461ms         2.87%       6.461ms       1.037us          6228  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.836s
Self CUDA time total: 225.376ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.307s      1714.16%        7.307s        7.307s             1  
                                           binned_torch        24.10%        1.762s       100.00%        7.313s        7.313s       0.000us         0.00%     426.284ms     426.284ms             1  
                                             aten::item         1.74%     126.959ms        26.39%        1.930s      15.721us       0.000us         0.00%     128.245ms       1.045us        122763  
                              aten::_local_scalar_dense         6.22%     454.984ms        24.65%        1.803s      14.685us     128.239ms        30.08%     128.245ms       1.045us        122763  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     128.241ms        30.08%     128.241ms       1.045us        122762  
                                     aten::floor_divide         5.53%     404.463ms        13.23%     967.808ms      19.690us      63.393ms        14.87%      63.393ms       1.290us         49152  
                                              aten::bmm         0.00%     234.623us         0.00%     278.223us      46.371us      56.525ms        13.26%      56.525ms       9.421ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      56.525ms        13.26%      56.525ms       9.421ms             6  
                                            aten::copy_         4.05%     295.852ms         9.44%     690.402ms      14.045us      53.639ms        12.58%      53.640ms       1.091us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.636ms        12.58%      53.636ms       1.091us         49154  
                                              aten::mul         3.24%     237.068ms         5.73%     419.319ms      17.056us      51.499ms        12.08%      51.504ms       2.095us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.577ms        10.46%      44.577ms       1.814us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.181ms         8.02%      34.181ms       1.391us         24576  
                                              aten::add         2.92%     213.232ms         5.07%     370.760ms      15.173us      33.603ms         7.88%      33.606ms       1.375us         24435  
                                        aten::remainder         3.14%     229.281ms         5.03%     367.714ms      14.962us      30.916ms         7.25%      30.921ms       1.258us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.214ms         6.85%      29.214ms       1.189us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.954ms         6.32%      26.954ms       1.103us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.285ms         3.82%      16.285ms       1.325us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.630ms         3.43%      14.630ms       1.191us         12288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.208ms         1.22%       5.208ms     868.029us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.313s
Self CUDA time total: 426.263ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.520s      1665.26%        7.520s        7.520s             1  
                                           binned_torch        23.83%        1.792s       100.00%        7.522s        7.522s       0.000us         0.00%     451.603ms     451.603ms             1  
                                             aten::item         1.82%     136.877ms        27.31%        2.054s      15.246us       0.000us         0.00%     140.837ms       1.045us        134715  
                              aten::_local_scalar_dense         6.26%     471.062ms        25.49%        1.917s      14.230us     140.825ms        31.19%     140.837ms       1.045us        134715  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     140.826ms        31.19%     140.826ms       1.045us        134706  
                                     aten::floor_divide         5.15%     387.087ms        12.45%     936.766ms      19.059us      63.494ms        14.06%      63.499ms       1.292us         49152  
                                              aten::bmm         0.00%     222.563us         0.00%     265.513us      44.252us      56.696ms        12.56%      56.696ms       9.449ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      56.696ms        12.56%      56.696ms       9.449ms             6  
                                            aten::copy_         3.71%     279.306ms         8.85%     665.315ms      13.534us      53.897ms        11.94%      53.900ms       1.096us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.894ms        11.94%      53.894ms       1.097us         49149  
                                              aten::mul         3.04%     228.311ms         5.39%     405.691ms      16.502us      51.688ms        11.45%      51.695ms       2.103us         24585  
                                              aten::add         4.00%     300.523ms         6.98%     525.049ms      14.441us      45.565ms        10.09%      45.568ms       1.253us         36357  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.621ms         9.88%      44.621ms       1.816us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.193ms         7.57%      34.193ms       1.391us         24573  
                                        aten::remainder         2.86%     215.282ms         4.58%     344.226ms      14.007us      30.855ms         6.83%      30.857ms       1.256us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.302ms         6.49%      29.302ms       1.192us         24573  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.656ms         5.90%      26.656ms       1.091us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.266ms         3.60%      16.266ms       1.324us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.588ms         3.23%      14.588ms       1.187us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      12.278ms         2.72%      12.278ms       1.030us         11922  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.522s
Self CUDA time total: 451.562ms


impl                     wl                  p50(ms)  ok
binned_torch             cuda_B1_S1024_E2     383.31  True
binned_torch             cuda_B1_S1024_E4     421.42  True
binned_torch             cuda_B1_S512_E2      157.73  True
binned_torch             cuda_B1_S512_E4      204.82  True
binned_torch             cuda_B4_S1024_E2    1513.71  True
binned_torch             cuda_B4_S1024_E4    1658.74  True
binned_torch             cuda_B4_S512_E2      773.70  True
binned_torch             cuda_B4_S512_E4      840.01  True

Artifacts:

openai_moe.jsonl