Binned PyTorch - OpenAI-style MoE

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.28s | Raw GitHub
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 19:41:48 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   33C    P0            126W /  350W |       0MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (Binned PyTorch)

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 733.46s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.view(B, S, H)


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Binned PyTorch implementation of OpenAI-style MoE.
    Sorts tokens by expert assignment for more efficient batched processing.
    """
    B, S = hidden_states.shape[0], hidden_states.shape[1]
    K = router_indices.shape[1]

    # Set expert_capacity to a reasonable value (max tokens per expert)
    # Use 2x the average to handle imbalance
    expert_capacity = (B * S * K * 2) // routing_weights.shape[2]

    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        expert_capacity,
    )


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     935.516ms      1843.92%     935.516ms     935.516ms             1  
                                           binned_torch        24.73%     231.815ms       100.00%     937.553ms     937.553ms       0.000us         0.00%      50.740ms      50.740ms             1  
                                             aten::item         1.92%      17.997ms        26.19%     245.573ms      16.003us       0.000us         0.00%      15.756ms       1.027us         15345  
                              aten::_local_scalar_dense         6.46%      60.533ms        24.27%     227.576ms      14.831us      15.755ms        31.05%      15.756ms       1.027us         15345  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      15.755ms        31.05%      15.755ms       1.027us         15345  
                                     aten::floor_divide         5.33%      49.954ms        13.00%     121.926ms      19.845us       7.813ms        15.40%       7.813ms       1.272us          6144  
                                              aten::bmm         0.02%     192.684us         0.02%     232.345us      38.724us       7.792ms        15.36%       7.792ms       1.299ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.792ms        15.36%       7.792ms       1.299ms             6  
                                            aten::copy_         3.73%      34.970ms         9.17%      86.008ms      13.971us       6.589ms        12.99%       6.590ms       1.071us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.585ms        12.98%       6.585ms       1.070us          6153  
                                              aten::mul         3.28%      30.750ms         5.69%      53.382ms      17.326us       4.708ms         9.28%       4.708ms       1.528us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.480ms         8.83%       4.480ms       1.458us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.159ms         8.20%       4.159ms       1.354us          3072  
                                        aten::remainder         3.15%      29.490ms         4.77%      44.737ms      14.563us       3.838ms         7.56%       3.838ms       1.249us          3072  
                                              aten::add         2.76%      25.910ms         4.76%      44.643ms      14.719us       3.755ms         7.40%       3.755ms       1.238us          3033  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.655ms         7.20%       3.655ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.364ms         6.63%       3.364ms       1.110us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.022ms         3.99%       2.022ms       1.316us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.816ms         3.58%       1.816ms       1.182us          1536  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     284.802us         0.56%     284.802us      47.467us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 937.562ms
Self CUDA time total: 50.735ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     958.363ms      1758.28%     958.363ms     958.363ms             1  
                                           binned_torch        24.25%     232.525ms       100.00%     958.754ms     958.754ms       0.000us         0.00%      54.510ms      54.510ms             1  
                                             aten::item         1.77%      17.002ms        27.44%     263.071ms      15.534us       0.000us         0.00%      17.862ms       1.055us         16935  
                              aten::_local_scalar_dense         6.54%      62.707ms        25.67%     246.070ms      14.530us      17.860ms        32.77%      17.862ms       1.055us         16935  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      17.860ms        32.77%      17.860ms       1.055us         16935  
                                              aten::bmm         0.02%     170.065us         0.02%     212.615us      35.436us       7.895ms        14.48%       7.895ms       1.316ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.895ms        14.48%       7.895ms       1.316ms             6  
                                     aten::floor_divide         4.96%      47.565ms        12.31%     117.977ms      19.202us       7.812ms        14.33%       7.813ms       1.272us          6144  
                                            aten::copy_         3.61%      34.645ms         8.68%      83.187ms      13.513us       6.631ms        12.17%       6.631ms       1.077us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.628ms        12.16%       6.628ms       1.077us          6152  
                                              aten::add         3.91%      37.531ms         7.22%      69.217ms      15.070us       5.262ms         9.65%       5.262ms       1.146us          4593  
                                              aten::mul         3.03%      29.029ms         5.30%      50.820ms      16.495us       4.703ms         8.63%       4.703ms       1.526us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.476ms         8.21%       4.476ms       1.457us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.156ms         7.62%       4.156ms       1.353us          3072  
                                        aten::remainder         2.84%      27.273ms         4.45%      42.673ms      13.891us       3.854ms         7.07%       3.854ms       1.255us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.656ms         6.71%       3.656ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.271ms         6.00%       3.271ms       1.080us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.031ms         3.73%       2.031ms       1.323us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.822ms         3.34%       1.822ms       1.187us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.585ms         2.91%       1.585ms       1.016us          1560  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 958.762ms
Self CUDA time total: 54.506ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.754s      1688.21%        1.754s        1.754s             1  
                                           binned_torch        24.13%     423.200ms       100.00%        1.754s        1.754s       0.000us         0.00%     103.889ms     103.889ms             1  
                                             aten::item         1.68%      29.485ms        26.54%     465.492ms      15.256us       0.000us         0.00%      31.587ms       1.035us         30513  
                              aten::_local_scalar_dense         6.17%     108.158ms        24.86%     436.007ms      14.289us      31.585ms        30.40%      31.587ms       1.035us         30513  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      31.585ms        30.40%      31.585ms       1.035us         30513  
                                     aten::floor_divide         5.33%      93.524ms        13.33%     233.711ms      19.019us      15.605ms        15.02%      15.605ms       1.270us         12288  
                                              aten::bmm         0.01%     221.157us         0.02%     267.387us      44.564us      15.098ms        14.53%      15.098ms       2.516ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.098ms        14.53%      15.098ms       2.516ms             6  
                                            aten::copy_         3.90%      68.459ms         9.45%     165.766ms      13.477us      13.325ms        12.83%      13.325ms       1.083us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.322ms        12.82%      13.322ms       1.084us         12294  
                                              aten::mul         3.29%      57.635ms         5.89%     103.357ms      16.798us      11.271ms        10.85%      11.273ms       1.832us          6153  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.920ms         9.55%       9.920ms       1.615us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.308ms         8.00%       8.308ms       1.352us          6144  
                                        aten::remainder         3.09%      54.193ms         4.85%      85.026ms      13.839us       7.675ms         7.39%       7.675ms       1.249us          6144  
                                              aten::add         2.79%      48.989ms         4.92%      86.297ms      14.595us       7.638ms         7.35%       7.639ms       1.292us          5913  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.297ms         7.02%       7.297ms       1.188us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.357ms         6.12%       6.357ms       1.076us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.044ms         3.89%       4.044ms       1.317us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.632ms         3.50%       3.632ms       1.182us          3072  
                                            aten::clamp         0.00%      73.899us         0.01%     123.411us      20.569us       1.193ms         1.15%       1.193ms     198.833us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.754s
Self CUDA time total: 103.882ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.874s      1695.99%        1.874s        1.874s             1  
                                           binned_torch        24.25%     455.076ms       100.00%        1.876s        1.876s       0.000us         0.00%     110.516ms     110.516ms             1  
                                             aten::item         1.77%      33.154ms        27.43%     514.675ms      15.259us       0.000us         0.00%      34.979ms       1.037us         33729  
                              aten::_local_scalar_dense         6.27%     117.583ms        25.66%     481.520ms      14.276us      34.976ms        31.65%      34.979ms       1.037us         33729  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      34.976ms        31.65%      34.976ms       1.037us         33728  
                                     aten::floor_divide         4.89%      91.819ms        12.09%     226.952ms      18.469us      15.582ms        14.10%      15.582ms       1.268us         12288  
                                              aten::bmm         0.01%     222.715us         0.01%     267.616us      44.603us      15.462ms        13.99%      15.462ms       2.577ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.462ms        13.99%      15.462ms       2.577ms             6  
                                            aten::copy_         3.58%      67.106ms         8.62%     161.781ms      13.153us      13.339ms        12.07%      13.341ms       1.085us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.337ms        12.07%      13.337ms       1.085us         12294  
                                              aten::mul         3.09%      57.893ms         5.35%     100.363ms      16.311us      10.926ms         9.89%      10.927ms       1.776us          6153  
                                              aten::add         4.06%      76.225ms         6.94%     130.290ms      14.319us      10.845ms         9.81%      10.845ms       1.192us          9099  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.572ms         8.66%       9.572ms       1.558us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.302ms         7.51%       8.302ms       1.351us          6144  
                                        aten::remainder         2.99%      56.031ms         4.55%      85.473ms      13.912us       7.682ms         6.95%       7.682ms       1.250us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.280ms         6.59%       7.280ms       1.185us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.358ms         5.75%       6.358ms       1.076us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.050ms         3.67%       4.050ms       1.318us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.631ms         3.29%       3.631ms       1.182us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.228ms         2.92%       3.228ms       1.013us          3186  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.876s
Self CUDA time total: 110.507ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.610s      1697.16%        3.610s        3.610s             1  
                                           binned_torch        23.68%     855.222ms       100.00%        3.611s        3.611s       0.000us         0.00%     212.735ms     212.735ms             1  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      63.569ms        29.88%      63.569ms       1.032us         61587  
                                             aten::item         1.81%      65.197ms        27.34%     987.119ms      16.028us       0.000us         0.00%      63.568ms       1.032us         61587  
                              aten::_local_scalar_dense         6.48%     233.826ms        25.53%     921.922ms      14.969us      63.567ms        29.88%      63.568ms       1.032us         61587  
                                     aten::floor_divide         5.24%     189.036ms        13.02%     470.235ms      19.134us      31.579ms        14.85%      31.582ms       1.285us         24576  
                                              aten::bmm         0.01%     232.455us         0.01%     281.845us      46.974us      29.001ms        13.63%      29.001ms       4.833ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.001ms        13.63%      29.001ms       4.833ms             6  
                                            aten::copy_         3.67%     132.477ms         9.25%     334.079ms      13.587us      26.719ms        12.56%      26.722ms       1.087us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.715ms        12.56%      26.715ms       1.087us         24585  
                                              aten::mul         3.15%     113.903ms         5.68%     205.201ms      16.687us      25.580ms        12.03%      25.582ms       2.080us         12297  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.132ms        10.40%      22.132ms       1.801us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.992ms         7.99%      16.992ms       1.383us         12288  
                                              aten::add         2.81%     101.355ms         4.98%     179.658ms      14.476us      16.634ms         7.82%      16.635ms       1.340us         12411  
                                        aten::remainder         3.15%     113.609ms         4.99%     180.020ms      14.650us      15.413ms         7.25%      15.415ms       1.255us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.588ms         6.86%      14.588ms       1.187us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.512ms         6.35%      13.512ms       1.089us         12408  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.121ms         3.82%       8.121ms       1.322us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.292ms         3.43%       7.292ms       1.187us          6144  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.612ms         1.23%       2.612ms     435.298us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.611s
Self CUDA time total: 212.720ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.762s      1666.70%        3.762s        3.762s             1  
                                           binned_torch        23.91%     899.748ms       100.00%        3.764s        3.764s       0.000us         0.00%     225.734ms     225.734ms             1  
                                             aten::item         1.82%      68.620ms        27.46%        1.034s      15.235us       0.000us         0.00%      69.795ms       1.029us         67845  
                              aten::_local_scalar_dense         6.31%     237.441ms        25.64%     964.994ms      14.224us      69.792ms        30.92%      69.795ms       1.029us         67845  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      69.793ms        30.92%      69.793ms       1.029us         67840  
                                     aten::floor_divide         4.95%     186.290ms        12.17%     458.105ms      18.640us      31.553ms        13.98%      31.560ms       1.284us         24576  
                                              aten::bmm         0.01%     226.315us         0.01%     272.505us      45.418us      29.269ms        12.97%      29.269ms       4.878ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.269ms        12.97%      29.269ms       4.878ms             6  
                                            aten::copy_         3.56%     134.013ms         8.54%     321.380ms      13.071us      26.742ms        11.85%      26.743ms       1.088us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.740ms        11.85%      26.740ms       1.088us         24581  
                                              aten::mul         3.06%     115.077ms         5.31%     199.757ms      16.244us      25.618ms        11.35%      25.618ms       2.083us         12297  
                                              aten::add         4.14%     155.825ms         7.08%     266.365ms      14.291us      23.275ms        10.31%      23.276ms       1.249us         18639  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.160ms         9.82%      22.160ms       1.803us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.005ms         7.53%      17.005ms       1.384us         12287  
                                        aten::remainder         2.93%     110.282ms         4.49%     168.952ms      13.749us      15.362ms         6.81%      15.364ms       1.250us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.548ms         6.45%      14.548ms       1.184us         12287  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.690ms         6.07%      13.690ms       1.103us         12407  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.098ms         3.59%       8.098ms       1.318us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.264ms         3.22%       7.264ms       1.182us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.476ms         2.87%       6.476ms       1.040us          6228  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.764s
Self CUDA time total: 225.722ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.172s      1685.34%        7.172s        7.172s             1  
                                           binned_torch        23.83%        1.712s       100.00%        7.184s        7.184s       0.000us         0.00%     425.602ms     425.602ms             1  
                                             aten::item         1.77%     127.233ms        27.17%        1.952s      15.898us       0.000us         0.00%     127.069ms       1.035us        122763  
                              aten::_local_scalar_dense         6.22%     446.668ms        25.40%        1.825s      14.862us     127.060ms        29.86%     127.069ms       1.035us        122763  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     127.060ms        29.86%     127.060ms       1.035us        122762  
                                     aten::floor_divide         5.22%     375.373ms        13.07%     938.750ms      19.099us      63.372ms        14.89%      63.374ms       1.289us         49152  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      57.057ms        13.41%      57.057ms       9.509ms             6  
                                              aten::bmm         0.00%     232.954us         0.00%     280.556us      46.759us      57.057ms        13.41%      57.057ms       9.509ms             6  
                                            aten::copy_         3.67%     263.382ms         9.14%     656.814ms      13.361us      53.605ms        12.60%      53.606ms       1.090us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.603ms        12.60%      53.603ms       1.091us         49154  
                                              aten::mul         3.19%     229.239ms         5.71%     410.065ms      16.679us      51.561ms        12.12%      51.568ms       2.098us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.597ms        10.48%      44.597ms       1.815us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.170ms         8.03%      34.170ms       1.390us         24576  
                                              aten::add         2.78%     199.917ms         4.97%     356.982ms      14.609us      33.583ms         7.89%      33.584ms       1.374us         24435  
                                        aten::remainder         3.17%     227.943ms         4.97%     356.780ms      14.517us      30.902ms         7.26%      30.903ms       1.257us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.202ms         6.86%      29.202ms       1.188us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.924ms         6.33%      26.924ms       1.102us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.278ms         3.82%      16.278ms       1.325us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.628ms         3.44%      14.628ms       1.190us         12288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.242ms         1.23%       5.242ms     873.601us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.184s
Self CUDA time total: 425.579ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.590s      1687.04%        7.590s        7.590s             1  
                                           binned_torch        23.93%        1.817s       100.00%        7.592s        7.592s       0.000us         0.00%     449.935ms     449.935ms             1  
                                             aten::item         1.74%     131.929ms        27.26%        2.070s      15.365us       0.000us         0.00%     139.467ms       1.035us        134715  
                              aten::_local_scalar_dense         6.36%     483.083ms        25.53%        1.938s      14.386us     139.456ms        31.00%     139.467ms       1.035us        134715  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     139.456ms        31.00%     139.456ms       1.035us        134706  
                                     aten::floor_divide         4.94%     375.293ms        12.19%     925.665ms      18.833us      63.455ms        14.10%      63.460ms       1.291us         49152  
                                              aten::bmm         0.00%     234.075us         0.00%     282.947us      47.158us      56.663ms        12.59%      56.663ms       9.444ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      56.663ms        12.59%      56.663ms       9.444ms             6  
                                            aten::copy_         3.75%     285.044ms         8.75%     664.131ms      13.510us      53.858ms        11.97%      53.860ms       1.096us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.855ms        11.97%      53.855ms       1.096us         49149  
                                              aten::mul         3.08%     233.920ms         5.34%     405.684ms      16.501us      51.582ms        11.47%      51.587ms       2.098us         24585  
                                              aten::add         3.87%     294.168ms         6.87%     521.854ms      14.354us      45.530ms        10.12%      45.534ms       1.252us         36357  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.640ms         9.92%      44.640ms       1.816us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.166ms         7.59%      34.166ms       1.390us         24573  
                                        aten::remainder         2.91%     220.707ms         4.59%     348.339ms      14.174us      30.841ms         6.86%      30.843ms       1.255us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.291ms         6.51%      29.291ms       1.192us         24573  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.632ms         5.92%      26.632ms       1.090us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.258ms         3.61%      16.258ms       1.323us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.582ms         3.24%      14.582ms       1.187us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      12.272ms         2.73%      12.272ms       1.029us         11922  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.592s
Self CUDA time total: 449.893ms


impl                     wl                  p50(ms)  ok
binned_torch             cuda_B1_S1024_E2     377.89  True
binned_torch             cuda_B1_S1024_E4     408.91  True
binned_torch             cuda_B1_S512_E2      158.27  True
binned_torch             cuda_B1_S512_E4      209.01  True
binned_torch             cuda_B4_S1024_E2    1516.51  True
binned_torch             cuda_B4_S1024_E4    1643.14  True
binned_torch             cuda_B4_S512_E2      769.64  True
binned_torch             cuda_B4_S512_E4      816.95  True
▶ UV Install Logs

Artifacts:

openai_moe.jsonl