Compiled SwiGLU Activation

GPU Info

import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Thu Oct 23 17:21:49 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   37C    P0             80W /  350W |       0MiB /  46068MiB |     13%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
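If you need the same GPU details in a machine-readable form, nvidia-smi's query mode can replace the full table. The sketch below assumes a single GPU and uses the standard `--query-gpu` and `--format=csv,noheader` flags. The helper names `parse_gpu_line` and `query_gpus` are hypothetical and are not part of the benchmark tools.

```python
import subprocess


def parse_gpu_line(line):
    """Split one CSV row from nvidia-smi query mode into a dict (hypothetical helper)."""
    name, driver, memory = (field.strip() for field in line.split(","))
    return {"name": name, "driver": driver, "memory_total": memory}


def query_gpus():
    """Return one dict per GPU; requires nvidia-smi on PATH."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version,memory.total",
         "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return [parse_gpu_line(row) for row in out.splitlines() if row.strip()]


if __name__ == "__main__":
    # Parse a row shaped like the L40S output above; this part needs no GPU.
    print(parse_gpu_line("NVIDIA L40S, 570.195.03, 46068 MiB"))
```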

SwiGLU Benchmark (torch.compile)

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_swiglu_base(input_tensor):
    """Base PyTorch SwiGLU implementation"""
    d = input_tensor.shape[-1] // 2
    x1 = input_tensor[..., :d]
    x2 = input_tensor[..., d:]
    return torch.nn.functional.silu(x1) * x2


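# Compile with Inductor's max-autotune mode, which autotunes kernel configs and, on GPU, enables CUDA graphs.
# fullgraph=True raises on graph breaks; dynamic=False specializes on exact shapes, so each new shape recompiles.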
compiled_swiglu = torch.compile(torch_swiglu_base, mode="max-autotune", fullgraph=True, dynamic=False)


# Register the compiled function with the benchmark harness under a name and descriptive tags
kbt.add(
    "compiled_swiglu_max_autotune",
    compiled_swiglu,
    tags={"family": "torch", "backend": "compiled", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Generate workloads; only the first three are used to keep the run short
    if device == "cuda":
        wl = list(kbt.activation.llama_workloads(dtype=dtype))[:3]
    else:
        wl = list(kbt.activation.cpu_workloads(dtype=dtype))[:3]

    print(f"Running SwiGLU benchmarks on {device} with {dtype}")
    print(f"Testing {len(wl)} workloads")

    # Run benchmark
    kbt.run(
        wl,
        jsonl="activation.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.activation.gen_inputs,
        ref=kbt.activation.ref_swiglu,
        cmp=kbt.activation.cmp_allclose,
        profile_trace=True
    )

    kbt.summarize(["activation.jsonl"])
Running SwiGLU benchmarks on cuda with bfloat16
Testing 3 workloads

======================================================================
PROFILE TRACE: compiled_swiglu_max_autotune | llama_T512_D4096
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           compiled_swiglu_max_autotune         0.00%       0.000us         0.00%       0.000us       0.000us       1.851ms      5297.74%       1.851ms     925.622us             2  
                           compiled_swiglu_max_autotune         0.10%     159.779us        99.99%     166.375ms     166.375ms       0.000us         0.00%      38.816us      38.816us             1  
                             Torch-Compiled Region: 0/1         1.45%       2.415ms        99.86%     166.157ms      55.386ms      11.007us        31.50%      38.816us      12.939us             3  
                                   aten::_foreach_copy_         0.02%      39.542us         0.05%      87.165us      29.055us      21.600us        61.81%      21.600us       7.200us             3  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      21.600us        61.81%      21.600us       7.200us             3  
                    CUDAGraphNode.record (dynamo_timed)         0.00%       0.000us         0.00%       0.000us       0.000us      20.673us        59.16%      20.673us      20.673us             1  
                            triton_poi_fused_mul_silu_0         0.00%       0.000us         0.00%       0.000us       0.000us      11.007us        31.50%      11.007us       3.669us             3  
                                Activity Buffer Request         0.86%       1.424ms         0.86%       1.424ms       1.424ms       3.872us        11.08%       3.872us       3.872us             1  
                    CUDAGraphNode.record (dynamo_timed)        96.87%     161.185ms        97.39%     162.045ms     162.045ms       0.000us         0.00%       2.337us       2.337us             1  
                                            aten::fill_         0.02%      34.251us         0.05%      74.934us      37.467us       2.337us         6.69%       2.337us       1.168us             2  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.337us         6.69%       2.337us       1.168us             2  
                               TorchDynamo Cache Lookup         0.03%      57.633us         0.03%      57.633us      19.211us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.01%      12.280us         0.01%      12.280us       4.093us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.01%      21.352us         0.01%      21.352us       7.117us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.07%     111.205us         0.07%     111.205us      18.534us       0.000us         0.00%       0.000us       0.000us             6  
                                  cudaStreamIsCapturing         0.01%      10.600us         0.01%      10.600us       0.815us       0.000us         0.00%       0.000us       0.000us            13  
                               cudaEventRecordWithFlags         0.00%       4.751us         0.00%       4.751us       1.584us       0.000us         0.00%       0.000us       0.000us             3  
                                    cudaStreamWaitEvent         0.00%       4.550us         0.00%       4.550us       1.517us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.01%      14.680us         0.01%      14.680us       4.893us       0.000us         0.00%       0.000us       0.000us             3  
                                       cudaLaunchKernel         0.05%      88.306us         0.05%      88.306us      17.661us       0.000us         0.00%       0.000us       0.000us             5  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 166.389ms
Self CUDA time total: 34.944us
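The `triton_poi_fused_mul_silu_0` row is the single fused Inductor kernel that computes `silu(x1) * x2`. For a memory-bound op like this, you can convert its average time into an effective bandwidth. This is a rough sketch with three assumptions: the input for workload `llama_T{T}_D{D}` is a bf16 tensor of shape [T, 2·D] (the actual shape produced by `kbt.activation.gen_inputs` is not shown here), each input element is read once, and each output element is written once.

```python
def swiglu_bytes_moved(tokens, d, elem_size=2):
    """Bytes for one SwiGLU call: read a [tokens, 2*d] input, write a [tokens, d] output.

    Assumes bf16 (2 bytes per element) and that the workload's D is the output width.
    """
    reads = tokens * 2 * d * elem_size
    writes = tokens * d * elem_size
    return reads + writes


def effective_bandwidth_gbs(num_bytes, time_us):
    """Convert bytes moved and a kernel time in microseconds into GB/s (1 GB = 1e9 bytes)."""
    return num_bytes / (time_us * 1e-6) / 1e9


if __name__ == "__main__":
    # 3.669 us is the average triton_poi_fused_mul_silu_0 time in the T512_D4096 trace.
    moved = swiglu_bytes_moved(512, 4096)
    print(moved, round(effective_bandwidth_gbs(moved, 3.669), 1))
```

The kernel runs right after the input copy. If the result comes out well above the card's rated DRAM bandwidth, the data is probably being served from L2 cache, or the shape assumption is wrong. In either case, the number does not reflect real DRAM throughput.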



======================================================================
PROFILE TRACE: compiled_swiglu_max_autotune | llama_T512_D8192
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           compiled_swiglu_max_autotune         0.00%       0.000us         0.00%       0.000us       0.000us       1.882ms      2857.54%       1.882ms     940.918us             2  
                           compiled_swiglu_max_autotune         0.08%     131.855us        99.99%     174.569ms     174.569ms       0.000us         0.00%      72.799us      72.799us             1  
                             Torch-Compiled Region: 0/3         1.26%       2.204ms        99.89%     174.392ms      58.131ms      18.240us        27.70%      72.799us      24.266us             3  
                                   aten::_foreach_copy_         0.02%      39.114us         0.05%      88.345us      29.448us      45.247us        68.71%      45.247us      15.082us             3  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      45.247us        68.71%      45.247us      15.082us             3  
                    CUDAGraphNode.record (dynamo_timed)         0.00%       0.000us         0.00%       0.000us       0.000us      19.904us        30.22%      19.904us      19.904us             1  
                            triton_poi_fused_mul_silu_0         0.00%       0.000us         0.00%       0.000us       0.000us      18.240us        27.70%      18.240us       6.080us             3  
                                Activity Buffer Request         0.83%       1.441ms         0.83%       1.441ms       1.441ms       6.944us        10.54%       6.944us       6.944us             1  
                    CUDAGraphNode.record (dynamo_timed)        96.65%     168.746ms        97.67%     170.521ms     170.521ms       0.000us         0.00%       2.368us       2.368us             1  
                                            aten::fill_         0.02%      36.482us         0.04%      78.354us      39.177us       2.368us         3.60%       2.368us       1.184us             2  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.368us         3.60%       2.368us       1.184us             2  
                               TorchDynamo Cache Lookup         0.03%      45.013us         0.03%      45.013us      15.004us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.01%       9.190us         0.01%       9.190us       3.063us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.01%      17.071us         0.01%      17.071us       5.690us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.04%      76.533us         0.04%      76.533us      12.755us       0.000us         0.00%       0.000us       0.000us             6  
                                  cudaStreamIsCapturing         0.01%       9.681us         0.01%       9.681us       0.745us       0.000us         0.00%       0.000us       0.000us            13  
                               cudaEventRecordWithFlags         0.00%       3.672us         0.00%       3.672us       1.224us       0.000us         0.00%       0.000us       0.000us             3  
                                    cudaStreamWaitEvent         0.00%       3.040us         0.00%       3.040us       1.013us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.01%      12.061us         0.01%      12.061us       4.020us       0.000us         0.00%       0.000us       0.000us             3  
                                       cudaLaunchKernel         0.05%      91.103us         0.05%      91.103us      18.221us       0.000us         0.00%       0.000us       0.000us             5  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 174.590ms
Self CUDA time total: 65.855us



======================================================================
PROFILE TRACE: compiled_swiglu_max_autotune | llama_T512_D11008
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           compiled_swiglu_max_autotune         0.00%       0.000us         0.00%       0.000us       0.000us       1.863ms      1771.89%       1.863ms     931.590us             2  
                           compiled_swiglu_max_autotune         0.07%     121.234us        99.99%     174.986ms     174.986ms       0.000us         0.00%     113.760us     113.760us             1  
                             Torch-Compiled Region: 0/5         1.21%       2.117ms        99.90%     174.826ms      58.275ms      24.864us        23.65%     113.760us      37.920us             3  
                                   aten::_foreach_copy_         0.02%      36.152us         0.05%      83.124us      27.708us      78.144us        74.32%      78.144us      26.048us             3  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      78.144us        74.32%      78.144us      26.048us             3  
                            triton_poi_fused_mul_silu_0         0.00%       0.000us         0.00%       0.000us       0.000us      24.864us        23.65%      24.864us       8.288us             3  
                    CUDAGraphNode.record (dynamo_timed)         0.00%       0.000us         0.00%       0.000us       0.000us      19.776us        18.81%      19.776us      19.776us             1  
                                Activity Buffer Request         0.77%       1.349ms         0.77%       1.349ms       1.349ms       8.608us         8.19%       8.608us       8.608us             1  
                    CUDAGraphNode.record (dynamo_timed)        96.23%     168.408ms        97.80%     171.145ms     171.145ms       0.000us         0.00%       2.144us       2.144us             1  
                                            aten::fill_         0.02%      32.121us         0.04%      72.933us      36.467us       2.144us         2.04%       2.144us       1.072us             2  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.144us         2.04%       2.144us       1.072us             2  
                               TorchDynamo Cache Lookup         0.02%      38.274us         0.02%      38.274us      12.758us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.01%       9.421us         0.01%       9.421us       3.140us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.01%      14.201us         0.01%      14.201us       4.734us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.04%      73.664us         0.04%      73.664us      12.277us       0.000us         0.00%       0.000us       0.000us             6  
                                  cudaStreamIsCapturing         0.01%       9.722us         0.01%       9.722us       0.748us       0.000us         0.00%       0.000us       0.000us            13  
                               cudaEventRecordWithFlags         0.00%       3.409us         0.00%       3.409us       1.136us       0.000us         0.00%       0.000us       0.000us             3  
                                    cudaStreamWaitEvent         0.00%       2.910us         0.00%       2.910us       0.970us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.01%      11.600us         0.01%      11.600us       3.867us       0.000us         0.00%       0.000us       0.000us             3  
                                       cudaLaunchKernel         0.05%      87.784us         0.05%      87.784us      17.557us       0.000us         0.00%       0.000us       0.000us             5  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 175.003ms
Self CUDA time total: 105.152us
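In all three traces, `aten::_foreach_copy_` takes more CUDA time than the fused kernel itself. With `mode="max-autotune"`, torch.compile runs the graph under CUDA graphs. This copy is most likely the step that moves the caller's input into the graph's static input buffer. The calculation below re-derives the copy-to-kernel ratio from the average Self CUDA times printed above. All numbers are copied from the traces; nothing else is assumed.

```python
# Average Self CUDA time per call (microseconds), copied from the three traces above.
traces = {
    "llama_T512_D4096": {"foreach_copy": 7.200, "fused_kernel": 3.669},
    "llama_T512_D8192": {"foreach_copy": 15.082, "fused_kernel": 6.080},
    "llama_T512_D11008": {"foreach_copy": 26.048, "fused_kernel": 8.288},
}


def copy_to_kernel_ratio(entry):
    """Microseconds of input copying spent per microsecond of fused SwiGLU compute."""
    return entry["foreach_copy"] / entry["fused_kernel"]


for name, entry in traces.items():
    print(f"{name}: copy/kernel = {copy_to_kernel_ratio(entry):.2f}")
```

The ratio grows with D, so the copy cost scales faster than the kernel cost across these shapes. One way to check how much of the end-to-end time the copy accounts for would be to benchmark the same function with `mode="max-autotune-no-cudagraphs"`.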


impl                          wl                 p50(ms)  ok
compiled_swiglu_max_autotune  llama_T512_D11008  0.11     True
compiled_swiglu_max_autotune  llama_T512_D4096   0.10     True
compiled_swiglu_max_autotune  llama_T512_D8192   0.11     True
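The `ok` column reports whether the compiled output matched the reference (`kbt.activation.ref_swiglu`, compared with `kbt.activation.cmp_allclose`). The pure-Python sketch below shows the computation being checked. It uses the same split as `torch_swiglu_base`: the first half of the last dimension goes through SiLU, x·σ(x), and gates the second half. The tolerance values in `allclose` are illustrative assumptions, not the ones `cmp_allclose` actually uses.

```python
import math


def silu(x):
    """SiLU (swish): x * sigmoid(x), branched to avoid overflow in math.exp."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu_row(row):
    """SwiGLU on one row: split in half, return silu(first half) * second half."""
    d = len(row) // 2
    return [silu(a) * b for a, b in zip(row[:d], row[d:])]


def allclose(actual, expected, rtol=1e-2, atol=1e-2):
    """Elementwise |a - e| <= atol + rtol * |e|, the rule torch.allclose applies.

    The default tolerances are illustrative, loosely sized for bf16 outputs.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


if __name__ == "__main__":
    row = [1.0, -2.0, 0.5, 3.0]  # d = 2: x1 = [1.0, -2.0], x2 = [0.5, 3.0]
    print(swiglu_row(row))
```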

Artifacts:

activation.jsonl
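`activation.jsonl` holds one JSON record per line. The record schema is defined by `kernels-benchmark-tools` and is not shown here. The sketch below therefore only loads the records and lists the keys they contain, without assuming any field names. The file path matches the `jsonl=` argument passed to `kbt.run`.

```python
import json
from pathlib import Path


def load_jsonl(path):
    """Parse a JSON Lines file into a list of dicts, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def field_names(records):
    """Sorted union of top-level keys across all records."""
    keys = set()
    for rec in records:
        keys.update(rec)
    return sorted(keys)


if __name__ == "__main__":
    path = Path("activation.jsonl")
    if path.exists():
        recs = load_jsonl(path)
        print(len(recs), "records; fields:", field_names(recs))
```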