Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

Cell: benchmark | 32.68s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # Inputs arrive as (batch, seq, heads, dim); SDPA expects (batch, heads, seq, dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the memory-efficient (CUTLASS FMHA) backend.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the benchmark's (batch, seq, heads, dim) layout.
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.77%     340.490us        32.91%       2.350ms       2.350ms       0.000us         0.00%       5.530ms       5.530ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.523ms       100.81%       5.523ms       5.523ms             1  
                     aten::scaled_dot_product_attention         0.44%      31.421us         2.67%     190.938us      63.646us       0.000us         0.00%       4.861ms       1.620ms             3  
          aten::_scaled_dot_product_efficient_attention         0.35%      24.771us         2.23%     159.517us      53.172us       0.000us         0.00%       4.861ms       1.620ms             3  
                     aten::_efficient_attention_forward         0.51%      36.163us         1.50%     107.413us      35.804us       4.861ms        88.73%       4.861ms       1.620ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.861ms        88.73%       4.861ms       1.620ms             3  
                                       aten::contiguous         0.17%      12.232us        24.52%       1.751ms     194.525us       0.000us         0.00%     668.128us      74.236us             9  
                                            aten::clone         0.48%      34.579us        24.35%       1.738ms     193.165us       0.000us         0.00%     668.128us      74.236us             9  
                                            aten::copy_         1.16%      82.494us        22.79%       1.628ms     180.845us     617.312us        11.27%     668.128us      74.236us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     617.312us        11.27%     617.312us      68.590us             9  
                                Activity Buffer Request        20.35%       1.453ms        20.35%       1.453ms       1.453ms      50.816us         0.93%      50.816us      50.816us             1  
                                        aten::transpose         1.00%      71.754us         1.33%      95.065us       3.961us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      23.311us         0.33%      23.311us       0.971us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.27%      19.481us         1.07%      76.301us       8.478us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.26%      89.759us         1.26%      89.759us       4.274us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.62%     115.656us         1.62%     115.656us       9.638us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       2.980us         0.04%       2.980us       0.993us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.16%      11.490us         0.16%      11.490us       3.830us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.09%       4.790ms        67.09%       4.790ms       4.790ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.140ms
Self CUDA time total: 5.479ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.38%     251.986us        27.98%       2.086ms       2.086ms       0.000us         0.00%       6.014ms       6.014ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.969ms       100.15%       5.969ms       5.969ms             1  
                     aten::scaled_dot_product_attention         0.27%      19.962us         1.97%     146.646us      48.882us       0.000us         0.00%       5.323ms       1.774ms             3  
          aten::_scaled_dot_product_efficient_attention         0.26%      19.141us         1.70%     126.684us      42.228us       0.000us         0.00%       5.323ms       1.774ms             3  
                     aten::_efficient_attention_forward         0.39%      29.281us         1.12%      83.514us      27.838us       5.323ms        89.32%       5.323ms       1.774ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.323ms        89.32%       5.323ms       1.774ms             3  
                                       aten::contiguous         0.10%       7.510us        22.05%       1.644ms     182.655us       0.000us         0.00%     690.909us      76.768us             9  
                                            aten::clone         0.31%      23.251us        21.95%       1.636ms     181.821us       0.000us         0.00%     690.909us      76.768us             9  
                                            aten::copy_         0.91%      68.131us        20.95%       1.562ms     173.540us     636.478us        10.68%     690.909us      76.768us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     636.478us        10.68%     636.478us      70.720us             9  
                                Activity Buffer Request        19.09%       1.423ms        19.09%       1.423ms       1.423ms      54.431us         0.91%      54.431us      54.431us             1  
                                        aten::transpose         0.68%      50.542us         0.90%      67.292us       2.804us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      16.750us         0.22%      16.750us       0.698us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.17%      12.371us         0.69%      51.272us       5.697us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.87%      64.771us         0.87%      64.771us       3.084us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.25%      93.466us         1.25%      93.466us       7.789us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.400us         0.03%       2.400us       0.800us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.371us         0.05%       3.371us       1.124us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        72.02%       5.368ms        72.02%       5.368ms       5.368ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.454ms
Self CUDA time total: 5.959ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.08%     235.490us        27.25%       2.083ms       2.083ms       0.000us         0.00%       6.182ms       6.182ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.132ms       100.15%       6.132ms       6.132ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.220us         1.86%     142.046us      47.349us       0.000us         0.00%       5.466ms       1.822ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      18.131us         1.62%     123.826us      41.275us       0.000us         0.00%       5.466ms       1.822ms             3  
                     aten::_efficient_attention_forward         0.37%      27.940us         1.08%      82.291us      27.430us       5.466ms        89.28%       5.466ms       1.822ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.466ms        89.28%       5.466ms       1.822ms             3  
                                       aten::contiguous         0.10%       7.272us        21.47%       1.642ms     182.409us       0.000us         0.00%     715.197us      79.466us             9  
                                            aten::clone         0.29%      22.290us        21.38%       1.634ms     181.601us       0.000us         0.00%     715.197us      79.466us             9  
                                            aten::copy_         0.83%      63.251us        20.39%       1.559ms     173.182us     656.318us        10.72%     715.197us      79.466us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     656.318us        10.72%     656.318us      72.924us             9  
                                Activity Buffer Request        18.70%       1.430ms        18.70%       1.430ms       1.430ms      58.879us         0.96%      58.879us      58.879us             1  
                                        aten::transpose         0.93%      71.209us         1.15%      87.625us       3.651us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      16.416us         0.21%      16.416us       0.684us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.741us         0.70%      53.481us       5.942us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.89%      67.840us         0.89%      67.840us       3.230us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.15%      88.022us         1.15%      88.022us       7.335us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.651us         0.03%       2.651us       0.884us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.370us         0.04%       3.370us       1.123us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        72.75%       5.562ms        72.75%       5.562ms       5.562ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.646ms
Self CUDA time total: 6.123ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.84%     224.838us        29.78%       2.354ms       2.354ms       0.000us         0.00%       6.170ms       6.170ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.121ms       100.15%       6.121ms       6.121ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.891us         1.82%     143.646us      47.882us       0.000us         0.00%       5.458ms       1.819ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.093us         1.58%     124.755us      41.585us       0.000us         0.00%       5.458ms       1.819ms             3  
                     aten::_efficient_attention_forward         0.36%      28.140us         1.04%      82.213us      27.404us       5.458ms        89.30%       5.458ms       1.819ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.458ms        89.30%       5.458ms       1.819ms             3  
                                       aten::contiguous         0.10%       7.739us        24.57%       1.942ms     215.806us       0.000us         0.00%     711.998us      79.111us             9  
                                            aten::clone         0.31%      24.450us        24.47%       1.935ms     214.946us       0.000us         0.00%     711.998us      79.111us             9  
                                            aten::copy_         0.86%      68.064us        23.51%       1.859ms     206.523us     653.982us        10.70%     711.998us      79.111us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     653.982us        10.70%     653.982us      72.665us             9  
                                Activity Buffer Request        18.84%       1.489ms        18.84%       1.489ms       1.489ms      58.016us         0.95%      58.016us      58.016us             1  
                                        aten::transpose         0.62%      49.288us         0.84%      66.489us       2.770us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.201us         0.22%      17.201us       0.717us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.041us         0.65%      51.362us       5.707us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.83%      65.351us         0.83%      65.351us       3.112us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         4.09%     323.234us         4.09%     323.234us      26.936us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.670us         0.03%       2.670us       0.890us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.430us         0.04%       3.430us       1.143us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.22%       5.551ms        70.22%       5.551ms       5.551ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.905ms
Self CUDA time total: 6.112ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.78%     220.799us        28.42%       2.258ms       2.258ms       0.000us         0.00%       6.296ms       6.296ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.245ms       100.15%       6.245ms       6.245ms             1  
                     aten::scaled_dot_product_attention         0.24%      19.311us         1.79%     142.116us      47.372us       0.000us         0.00%       5.574ms       1.858ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      17.909us         1.55%     122.805us      40.935us       0.000us         0.00%       5.574ms       1.858ms             3  
                     aten::_efficient_attention_forward         0.36%      28.682us         1.03%      82.073us      27.358us       5.574ms        89.39%       5.574ms       1.858ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.574ms        89.39%       5.574ms       1.858ms             3  
                                       aten::contiguous         0.09%       7.009us        23.32%       1.852ms     205.811us       0.000us         0.00%     721.599us      80.178us             9  
                                            aten::clone         0.28%      22.450us        23.23%       1.845ms     205.033us       0.000us         0.00%     721.599us      80.178us             9  
                                            aten::copy_         0.87%      68.713us        22.33%       1.774ms     197.096us     661.695us        10.61%     721.599us      80.178us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     661.695us        10.61%     661.695us      73.522us             9  
                                Activity Buffer Request        17.91%       1.422ms        17.91%       1.422ms       1.422ms      59.904us         0.96%      59.904us      59.904us             1  
                                        aten::transpose         0.61%      48.435us         0.82%      65.304us       2.721us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      16.869us         0.21%      16.869us       0.703us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.511us         0.62%      48.982us       5.442us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.78%      61.691us         0.78%      61.691us       2.938us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.85%     305.580us         3.85%     305.580us      25.465us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.440us         0.03%       2.440us       0.813us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.920us         0.05%       3.920us       1.307us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.58%       5.685ms        71.58%       5.685ms       5.685ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.943ms
Self CUDA time total: 6.236ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.27%     267.711us        29.30%       2.401ms       2.401ms       0.000us         0.00%       6.459ms       6.459ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.406ms       100.13%       6.406ms       6.406ms             1  
                     aten::scaled_dot_product_attention         0.24%      19.643us         1.85%     151.176us      50.392us       0.000us         0.00%       5.726ms       1.909ms             3  
          aten::_scaled_dot_product_efficient_attention         0.26%      20.920us         1.61%     131.533us      43.844us       0.000us         0.00%       5.726ms       1.909ms             3  
                     aten::_efficient_attention_forward         0.37%      30.563us         1.03%      84.603us      28.201us       5.726ms        89.50%       5.726ms       1.909ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.726ms        89.50%       5.726ms       1.909ms             3  
                                       aten::contiguous         0.09%       7.670us        23.58%       1.932ms     214.647us       0.000us         0.00%     733.247us      81.472us             9  
                                            aten::clone         0.31%      25.042us        23.48%       1.924ms     213.795us       0.000us         0.00%     733.247us      81.472us             9  
                                            aten::copy_         0.88%      72.162us        22.52%       1.845ms     205.052us     671.711us        10.50%     733.247us      81.472us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     671.711us        10.50%     671.711us      74.635us             9  
                                Activity Buffer Request        17.78%       1.456ms        17.78%       1.456ms       1.456ms      61.536us         0.96%      61.536us      61.536us             1  
                                        aten::transpose         0.71%      58.110us         0.93%      75.842us       3.160us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.732us         0.22%      17.732us       0.739us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.319us         0.65%      53.641us       5.960us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      66.513us         0.81%      66.513us       3.167us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         4.14%     339.159us         4.14%     339.159us      28.263us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.379us         0.03%       2.379us       0.793us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       4.230us         0.05%       4.230us       1.410us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.70%       5.793ms        70.70%       5.793ms       5.793ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.193ms
Self CUDA time total: 6.398ms


impl                     wl                  p50(ms)  ok
torch_mem_eff            cuda_attn_L128_bfloat16     1.86  True
torch_mem_eff            cuda_attn_L256_bfloat16     1.97  True
torch_mem_eff            cuda_attn_L320_bfloat16     2.04  True
torch_mem_eff            cuda_attn_L384_bfloat16     2.06  True
torch_mem_eff            cuda_attn_L448_bfloat16     2.03  True
torch_mem_eff            cuda_attn_L512_bfloat16     2.19  True
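The pattern is stable across all six traces: the CUTLASS FMHA kernel accounts for roughly 89% of self-CUDA time, while the `aten::copy_` launched by the `.contiguous()` layout round-trip in the wrapper costs a further ~10-11%. A quick check using the figures copied from the L512 trace above (the copies would disappear if inputs were produced in (batch, heads, seq, dim) layout to begin with):

```python
# Self-CUDA times taken from the cuda_attn_L512_bfloat16 profile trace.
fmha_ms = 5.726      # fmha_cutlassF kernel
copy_ms = 0.671711   # aten::copy_ from the .contiguous() round-trip
total_ms = 6.398     # self CUDA time total

print(f"FMHA share: {fmha_ms / total_ms:.1%}")  # 89.5%
print(f"copy share: {copy_ms / total_ms:.1%}")  # 10.5%
```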

Artifacts:

attention.jsonl