Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

Cell: benchmark | 4.15s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch

from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # SDPA expects (batch, heads, seq, head_dim); swap the seq and heads axes.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin the memory-efficient backend instead of letting SDPA auto-select,
    # so the benchmark measures exactly this implementation.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the harness's layout.
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
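The benchmark pins the EFFICIENT_ATTENTION backend of `scaled_dot_product_attention`, which computes softmax(QKᵀ/√d)·V. For orientation, a naive NumPy reference of that computation (illustrative single-head shapes, not the benchmark's actual workloads):

```python
import numpy as np

def sdpa_reference(q, k, v):
    """Naive softmax(q @ k.T / sqrt(d)) @ v for (seq, head_dim) inputs.

    Materializes the full (seq, seq) score matrix -- exactly the
    allocation that memory-efficient kernels avoid.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                  # (seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                             # (seq, head_dim)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
out = sdpa_reference(q, k, v)
```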
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.11%     302.695us        35.19%       2.592ms       2.592ms       0.000us         0.00%       5.476ms       5.476ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.440ms       100.33%       5.440ms       5.440ms             1  
                     aten::scaled_dot_product_attention         0.40%      29.210us         2.30%     169.213us      56.404us       0.000us         0.00%       4.805ms       1.602ms             3  
          aten::_scaled_dot_product_efficient_attention         0.29%      21.719us         1.90%     140.003us      46.668us       0.000us         0.00%       4.805ms       1.602ms             3  
                     aten::_efficient_attention_forward         0.48%      35.571us         1.32%      97.242us      32.414us       4.805ms        88.62%       4.805ms       1.602ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.805ms        88.62%       4.805ms       1.602ms             3  
                                       aten::contiguous         0.13%       9.829us        27.98%       2.062ms     229.090us       0.000us         0.00%     670.404us      74.489us             9  
                                            aten::clone         0.35%      25.869us        27.85%       2.052ms     227.998us       0.000us         0.00%     670.404us      74.489us             9  
                                            aten::copy_         0.98%      72.210us        26.54%       1.956ms     217.285us     616.836us        11.38%     670.404us      74.489us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     616.836us        11.38%     616.836us      68.537us             9  
                                Activity Buffer Request        24.39%       1.797ms        24.39%       1.797ms       1.797ms      53.568us         0.99%      53.568us      53.568us             1  
                                        aten::transpose         0.81%      59.530us         1.08%      79.784us       3.324us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.27%      20.254us         0.27%      20.254us       0.844us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.20%      14.892us         0.96%      70.554us       7.839us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.12%      82.341us         1.12%      82.341us       3.921us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.48%     109.241us         1.48%     109.241us       9.103us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.240us         0.04%       3.240us       1.080us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.12%       9.162us         0.12%       9.162us       3.054us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        64.81%       4.776ms        64.81%       4.776ms       4.776ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.368ms
Self CUDA time total: 5.422ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.18%     243.704us        30.16%       2.312ms       2.312ms       0.000us         0.00%       5.946ms       5.946ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.900ms       100.14%       5.900ms       5.900ms             1  
                     aten::scaled_dot_product_attention         0.23%      17.410us         1.83%     139.893us      46.631us       0.000us         0.00%       5.256ms       1.752ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      18.330us         1.60%     122.483us      40.828us       0.000us         0.00%       5.256ms       1.752ms             3  
                     aten::_efficient_attention_forward         0.36%      27.350us         1.07%      81.803us      27.268us       5.256ms        89.21%       5.256ms       1.752ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.256ms        89.21%       5.256ms       1.752ms             3  
                                       aten::contiguous         0.10%       7.470us        24.63%       1.888ms     209.765us       0.000us         0.00%     690.500us      76.722us             9  
                                            aten::clone         0.27%      20.522us        24.53%       1.880ms     208.935us       0.000us         0.00%     690.500us      76.722us             9  
                                            aten::copy_         0.86%      65.740us        23.60%       1.809ms     200.963us     635.844us        10.79%     690.500us      76.722us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     635.844us        10.79%     635.844us      70.649us             9  
                                Activity Buffer Request        21.87%       1.676ms        21.87%       1.676ms       1.676ms      54.656us         0.93%      54.656us      54.656us             1  
                                        aten::transpose         0.62%      47.210us         0.82%      62.900us       2.621us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      15.690us         0.20%      15.690us       0.654us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      11.901us         0.67%      51.221us       5.691us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.85%      65.201us         0.85%      65.201us       3.105us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.16%      89.161us         1.16%      89.161us       7.430us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.381us         0.03%       2.381us       0.794us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.881us         0.05%       3.881us       1.294us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.84%       5.353ms        69.84%       5.353ms       5.353ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.665ms
Self CUDA time total: 5.891ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.05%     239.816us        30.60%       2.409ms       2.409ms       0.000us         0.00%       6.068ms       6.068ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.021ms       100.14%       6.021ms       6.021ms             1  
                     aten::scaled_dot_product_attention         0.23%      17.959us         1.79%     140.600us      46.867us       0.000us         0.00%       5.365ms       1.788ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.141us         1.56%     122.641us      40.880us       0.000us         0.00%       5.365ms       1.788ms             3  
                     aten::_efficient_attention_forward         0.36%      28.699us         1.04%      81.531us      27.177us       5.365ms        89.24%       5.365ms       1.788ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.365ms        89.24%       5.365ms       1.788ms             3  
                                       aten::contiguous         0.10%       7.861us        25.24%       1.987ms     220.773us       0.000us         0.00%     702.468us      78.052us             9  
                                            aten::clone         0.26%      20.540us        25.14%       1.979ms     219.899us       0.000us         0.00%     702.468us      78.052us             9  
                                            aten::copy_         0.92%      72.171us        24.24%       1.908ms     212.002us     646.884us        10.76%     702.468us      78.052us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     646.884us        10.76%     646.884us      71.876us             9  
                                Activity Buffer Request        22.46%       1.768ms        22.46%       1.768ms       1.768ms      55.584us         0.92%      55.584us      55.584us             1  
                                        aten::transpose         0.60%      47.471us         0.81%      64.120us       2.672us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      16.649us         0.21%      16.649us       0.694us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.960us         0.64%      50.531us       5.615us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      63.971us         0.81%      63.971us       3.046us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.13%      89.282us         1.13%      89.282us       7.440us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.660us         0.03%       2.660us       0.887us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.150us         0.04%       3.150us       1.050us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.40%       5.462ms        69.40%       5.462ms       5.462ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.871ms
Self CUDA time total: 6.012ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.93%     240.625us        31.13%       2.555ms       2.555ms       0.000us         0.00%       6.259ms       6.259ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.208ms       100.13%       6.208ms       6.208ms             1  
                     aten::scaled_dot_product_attention         0.21%      17.361us         1.73%     142.203us      47.401us       0.000us         0.00%       5.537ms       1.846ms             3  
          aten::_scaled_dot_product_efficient_attention         0.22%      18.441us         1.52%     124.842us      41.614us       0.000us         0.00%       5.537ms       1.846ms             3  
                     aten::_efficient_attention_forward         0.36%      29.601us         1.03%      84.471us      28.157us       5.537ms        89.30%       5.537ms       1.846ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.537ms        89.30%       5.537ms       1.846ms             3  
                                       aten::contiguous         0.09%       7.769us        25.95%       2.130ms     236.658us       0.000us         0.00%     721.984us      80.220us             9  
                                            aten::clone         0.26%      21.609us        25.85%       2.122ms     235.795us       0.000us         0.00%     721.984us      80.220us             9  
                                            aten::copy_         0.80%      65.822us        24.94%       2.047ms     227.475us     663.552us        10.70%     721.984us      80.220us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     663.552us        10.70%     663.552us      73.728us             9  
                                Activity Buffer Request        21.30%       1.749ms        21.30%       1.749ms       1.749ms      58.432us         0.94%      58.432us      58.432us             1  
                                        aten::transpose         0.59%      48.680us         0.78%      64.131us       2.672us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.19%      15.451us         0.19%      15.451us       0.644us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.591us         0.65%      53.271us       5.919us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      66.120us         0.81%      66.120us       3.149us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.12%     256.044us         3.12%     256.044us      21.337us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.670us         0.03%       2.670us       0.890us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.480us         0.04%       3.480us       1.160us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.87%       5.653ms        68.87%       5.653ms       5.653ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.208ms
Self CUDA time total: 6.200ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.93%     245.582us        31.52%       2.645ms       2.645ms       0.000us         0.00%       6.354ms       6.354ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.303ms       100.13%       6.303ms       6.303ms             1  
                     aten::scaled_dot_product_attention         0.20%      17.170us         1.68%     140.693us      46.898us       0.000us         0.00%       5.628ms       1.876ms             3  
          aten::_scaled_dot_product_efficient_attention         0.21%      17.520us         1.47%     123.523us      41.174us       0.000us         0.00%       5.628ms       1.876ms             3  
                     aten::_efficient_attention_forward         0.35%      29.440us         1.00%      84.263us      28.088us       5.628ms        89.41%       5.628ms       1.876ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.628ms        89.41%       5.628ms       1.876ms             3  
                                       aten::contiguous         0.09%       7.259us        26.43%       2.218ms     246.393us       0.000us         0.00%     726.309us      80.701us             9  
                                            aten::clone         0.25%      21.219us        26.34%       2.210ms     245.587us       0.000us         0.00%     726.309us      80.701us             9  
                                            aten::copy_         0.78%      65.083us        25.46%       2.136ms     237.368us     666.948us        10.59%     726.309us      80.701us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     666.948us        10.59%     666.948us      74.105us             9  
                                Activity Buffer Request        21.84%       1.833ms        21.84%       1.833ms       1.833ms      59.361us         0.94%      59.361us      59.361us             1  
                                        aten::transpose         0.56%      46.780us         0.75%      62.730us       2.614us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.19%      15.950us         0.19%      15.950us       0.665us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.512us         0.63%      52.753us       5.861us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      66.642us         0.79%      66.642us       3.173us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.12%     261.945us         3.12%     261.945us      21.829us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.500us         0.03%       2.500us       0.833us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.581us         0.04%       3.581us       1.194us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.48%       5.745ms        68.48%       5.745ms       5.745ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.390ms
Self CUDA time total: 6.295ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.68%     234.298us        28.81%       2.516ms       2.516ms       0.000us         0.00%       6.820ms       6.820ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.768ms       100.12%       6.768ms       6.768ms             1  
                     aten::scaled_dot_product_attention         0.20%      17.618us         1.61%     140.900us      46.967us       0.000us         0.00%       6.087ms       2.029ms             3  
          aten::_scaled_dot_product_efficient_attention         0.21%      18.311us         1.41%     123.282us      41.094us       0.000us         0.00%       6.087ms       2.029ms             3  
                     aten::_efficient_attention_forward         0.33%      29.191us         0.95%      82.621us      27.540us       6.087ms        90.04%       6.087ms       2.029ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       6.087ms        90.04%       6.087ms       2.029ms             3  
                                       aten::contiguous         0.09%       7.641us        24.06%       2.101ms     233.417us       0.000us         0.00%     733.380us      81.487us             9  
                                            aten::clone         0.23%      20.279us        23.97%       2.093ms     232.568us       0.000us         0.00%     733.380us      81.487us             9  
                                            aten::copy_         0.74%      64.431us        23.10%       2.017ms     224.097us     672.964us         9.96%     733.380us      81.487us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     672.964us         9.96%     672.964us      74.774us             9  
                                Activity Buffer Request        19.61%       1.713ms        19.61%       1.713ms       1.713ms      60.416us         0.89%      60.416us      60.416us             1  
                                        aten::transpose         0.53%      46.410us         0.71%      62.109us       2.588us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.18%      15.699us         0.18%      15.699us       0.654us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.751us         0.64%      55.961us       6.218us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      69.050us         0.79%      69.050us       3.288us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.99%     261.415us         2.99%     261.415us      21.785us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.920us         0.03%       2.920us       0.973us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.03%       2.980us         0.03%       2.980us       0.993us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.19%       6.216ms        71.19%       6.216ms       6.216ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.732ms
Self CUDA time total: 6.759ms
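The `fmha_cutlassF` rows in the traces above are the memory-efficient kernel itself: instead of materializing the full (seq × seq) score matrix, it streams over key/value blocks while maintaining a running (online) softmax. A NumPy sketch of that scheme (illustrative, single head, no masking):

```python
import numpy as np

def chunked_attention(q, k, v, chunk=32):
    """Attention with an online softmax over key chunks.

    Peak temporary memory is O(seq * chunk) rather than O(seq * seq).
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full(q.shape[0], -np.inf)   # running row-wise max of scores
    l = np.zeros(q.shape[0])           # running softmax denominator
    acc = np.zeros_like(q)             # running weighted sum of values
    for start in range(0, k.shape[0], chunk):
        kc, vc = k[start:start + chunk], v[start:start + chunk]
        s = q @ kc.T * scale                   # (seq, chunk) score block
        m_new = np.maximum(m, s.max(axis=-1))
        corr = np.exp(m - m_new)               # rescale earlier partials
        p = np.exp(s - m_new[:, None])
        l = l * corr + p.sum(axis=-1)
        acc = acc * corr[:, None] + p @ vc
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
out = chunked_attention(q, k, v)
```

The chunked result matches the naive softmax(QKᵀ/√d)·V to floating-point precision while never holding more than one (seq × chunk) score block.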


impl                     wl                  p50(ms)  ok
torch_mem_eff            cuda_attn_L128_bfloat16     1.84  True
torch_mem_eff            cuda_attn_L256_bfloat16     1.95  True
torch_mem_eff            cuda_attn_L320_bfloat16     1.97  True
torch_mem_eff            cuda_attn_L384_bfloat16     2.08  True
torch_mem_eff            cuda_attn_L448_bfloat16     2.04  True
torch_mem_eff            cuda_attn_L512_bfloat16     2.25  True
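Note that the `.contiguous()` copies show up in every trace as `aten::copy_` / `elementwise_kernel` rows, costing roughly 10% of self CUDA time. Using the self-CUDA numbers from the `cuda_attn_L512_bfloat16` trace above:

```python
# Self CUDA times read off the cuda_attn_L512_bfloat16 trace.
attn_us = 6087.0   # fmha_cutlassF memory-efficient attention kernel
copy_us = 673.0    # elementwise copy kernels behind .contiguous()
total_us = 6759.0  # self CUDA time total
print(f"copy overhead: {copy_us / total_us:.1%}")  # -> copy overhead: 10.0%
```

If the harness tolerates strided outputs, dropping the `.contiguous()` calls (SDPA accepts non-contiguous inputs) might reclaim part of that time; that is a hypothesis to verify, not something these traces establish.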

Artifacts:

attention.jsonl