Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

Cell: benchmark (3.92s)
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # Harness supplies [batch, seq_len, heads, head_dim]; SDPA expects
    # [batch, heads, seq_len, head_dim], so transpose and force contiguity.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin the SDPA dispatcher to the memory-efficient backend.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the harness layout.
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.88%     347.876us        33.28%       2.372ms       2.372ms       0.000us         0.00%       5.473ms       5.473ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.465ms       100.77%       5.465ms       5.465ms             1  
                     aten::scaled_dot_product_attention         0.44%      31.501us         2.47%     176.074us      58.691us       0.000us         0.00%       4.806ms       1.602ms             3  
          aten::_scaled_dot_product_efficient_attention         0.33%      23.351us         2.03%     144.573us      48.191us       0.000us         0.00%       4.806ms       1.602ms             3  
                     aten::_efficient_attention_forward         0.48%      33.995us         1.40%      99.622us      33.207us       4.806ms        88.63%       4.806ms       1.602ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.806ms        88.63%       4.806ms       1.602ms             3  
                                       aten::contiguous         0.20%      13.962us        24.98%       1.780ms     197.762us       0.000us         0.00%     667.264us      74.140us             9  
                                            aten::clone         0.48%      34.432us        24.78%       1.766ms     196.211us       0.000us         0.00%     667.264us      74.140us             9  
                                            aten::copy_         1.03%      73.682us        23.27%       1.658ms     184.268us     616.768us        11.37%     667.264us      74.140us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     616.768us        11.37%     616.768us      68.530us             9  
                                Activity Buffer Request        21.06%       1.501ms        21.06%       1.501ms       1.501ms      50.496us         0.93%      50.496us      50.496us             1  
                                        aten::transpose         0.94%      67.099us         1.26%      89.541us       3.731us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      22.442us         0.31%      22.442us       0.935us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.26%      18.431us         1.03%      73.051us       8.117us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.15%      82.238us         1.15%      82.238us       3.916us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.53%     109.170us         1.53%     109.170us       9.098us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.169us         0.04%       3.169us       1.056us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.13%       9.530us         0.13%       9.530us       3.177us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        66.72%       4.754ms        66.72%       4.754ms       4.754ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.126ms
Self CUDA time total: 5.423ms
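In the trace above, the `aten::contiguous` → `aten::clone` → `aten::copy_` chain (and the matching `elementwise_kernel` launches) accounts for roughly 11% of self CUDA time: these are the layout copies made by the wrapper, not attention itself. Since `scaled_dot_product_attention` also accepts strided (non-contiguous) views, a copy-free variant can be sketched as below; shapes are hypothetical and this runs on any device, so it does not pin a backend:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H, D = 2, 16, 4, 8  # hypothetical small shapes
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

# Wrapper's approach: transpose to [B, H, S, D] and force a copy.
out_copy = F.scaled_dot_product_attention(
    *(x.transpose(1, 2).contiguous() for x in (q, k, v))
).transpose(1, 2).contiguous()

# SDPA also accepts the strided views directly; no copy kernels are launched.
out_view = F.scaled_dot_product_attention(
    *(x.transpose(1, 2) for x in (q, k, v))
).transpose(1, 2)

print(torch.allclose(out_copy, out_view, atol=1e-6))
```

Whether the explicit copies pay for themselves depends on how the fused kernel handles strided inputs, which is exactly the kind of question these traces answer.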



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.49%     251.026us        29.53%       2.123ms       2.123ms       0.000us         0.00%       5.671ms       5.671ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.625ms       100.14%       5.625ms       5.625ms             1  
                     aten::scaled_dot_product_attention         0.28%      19.941us         1.97%     141.843us      47.281us       0.000us         0.00%       4.980ms       1.660ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      17.669us         1.70%     121.902us      40.634us       0.000us         0.00%       4.980ms       1.660ms             3  
                     aten::_efficient_attention_forward         0.38%      27.651us         1.14%      82.182us      27.394us       4.980ms        88.66%       4.980ms       1.660ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.980ms        88.66%       4.980ms       1.660ms             3  
                                       aten::contiguous         0.10%       7.480us        23.48%       1.688ms     187.567us       0.000us         0.00%     691.071us      76.786us             9  
                                            aten::clone         0.30%      21.261us        23.38%       1.681ms     186.736us       0.000us         0.00%     691.071us      76.786us             9  
                                            aten::copy_         0.85%      60.983us        22.39%       1.610ms     178.842us     637.247us        11.34%     691.071us      76.786us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     637.247us        11.34%     637.247us      70.805us             9  
                                Activity Buffer Request        20.63%       1.483ms        20.63%       1.483ms       1.483ms      53.824us         0.96%      53.824us      53.824us             1  
                                        aten::transpose         0.67%      48.164us         0.89%      64.122us       2.672us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      15.958us         0.22%      15.958us       0.665us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      11.580us         0.69%      49.790us       5.532us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.89%      63.701us         0.89%      63.701us       3.033us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.22%      87.751us         1.22%      87.751us       7.313us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.090us         0.04%       3.090us       1.030us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.339us         0.05%       3.339us       1.113us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.47%       5.066ms        70.47%       5.066ms       5.066ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.190ms
Self CUDA time total: 5.617ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.37%     266.115us        31.17%       2.458ms       2.458ms       0.000us         0.00%       6.082ms       6.082ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.032ms       100.14%       6.032ms       6.032ms             1  
                     aten::scaled_dot_product_attention         0.25%      19.720us         1.92%     151.403us      50.468us       0.000us         0.00%       5.369ms       1.790ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      18.800us         1.67%     131.683us      43.894us       0.000us         0.00%       5.369ms       1.790ms             3  
                     aten::_efficient_attention_forward         0.36%      28.452us         1.04%      81.963us      27.321us       5.369ms        89.14%       5.369ms       1.790ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.369ms        89.14%       5.369ms       1.790ms             3  
                                       aten::contiguous         0.10%       7.851us        25.32%       1.997ms     221.887us       0.000us         0.00%     712.865us      79.207us             9  
                                            aten::clone         0.51%      40.412us        25.22%       1.989ms     221.015us       0.000us         0.00%     712.865us      79.207us             9  
                                            aten::copy_         0.83%      65.138us        24.07%       1.898ms     210.924us     654.369us        10.86%     712.865us      79.207us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     654.369us        10.86%     654.369us      72.708us             9  
                                Activity Buffer Request        22.37%       1.764ms        22.37%       1.764ms       1.764ms      58.496us         0.97%      58.496us      58.496us             1  
                                        aten::transpose         0.63%      49.872us         0.95%      74.812us       3.117us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      24.940us         0.32%      24.940us       1.039us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.509us         0.64%      50.401us       5.600us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.82%      64.330us         0.82%      64.330us       3.063us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.16%      91.554us         1.16%      91.554us       7.629us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.671us         0.03%       2.671us       0.890us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.101us         0.04%       3.101us       1.034us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.83%       5.428ms        68.83%       5.428ms       5.428ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.886ms
Self CUDA time total: 6.024ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.19%     329.379us        30.22%       2.377ms       2.377ms       0.000us         0.00%       6.195ms       6.195ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.146ms       100.15%       6.146ms       6.146ms             1  
                     aten::scaled_dot_product_attention         0.26%      20.400us         1.80%     141.523us      47.174us       0.000us         0.00%       5.484ms       1.828ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      17.780us         1.54%     121.123us      40.374us       0.000us         0.00%       5.484ms       1.828ms             3  
                     aten::_efficient_attention_forward         0.36%      28.239us         1.03%      81.303us      27.101us       5.484ms        89.36%       5.484ms       1.828ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.484ms        89.36%       5.484ms       1.828ms             3  
                                       aten::contiguous         0.10%       8.071us        23.69%       1.863ms     207.042us       0.000us         0.00%     711.166us      79.018us             9  
                                            aten::clone         0.27%      21.510us        23.59%       1.855ms     206.145us       0.000us         0.00%     711.166us      79.018us             9  
                                            aten::copy_         0.81%      63.940us        22.65%       1.781ms     197.883us     652.767us        10.64%     711.166us      79.018us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     652.767us        10.64%     652.767us      72.530us             9  
                                Activity Buffer Request        18.20%       1.431ms        18.20%       1.431ms       1.431ms      58.399us         0.95%      58.399us      58.399us             1  
                                        aten::transpose         0.61%      48.309us         0.82%      64.340us       2.681us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      16.031us         0.20%      16.031us       0.668us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.029us         0.67%      52.851us       5.872us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.84%      66.365us         0.84%      66.365us       3.160us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.91%     307.476us         3.91%     307.476us      25.623us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.550us         0.03%       2.550us       0.850us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       4.011us         0.05%       4.011us       1.337us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.78%       5.488ms        69.78%       5.488ms       5.488ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.864ms
Self CUDA time total: 6.137ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.07%     246.275us        28.09%       2.251ms       2.251ms       0.000us         0.00%       6.379ms       6.379ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.328ms       100.14%       6.328ms       6.328ms             1  
                     aten::scaled_dot_product_attention         0.24%      19.011us         1.78%     142.253us      47.418us       0.000us         0.00%       5.653ms       1.884ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.261us         1.54%     123.242us      41.081us       0.000us         0.00%       5.653ms       1.884ms             3  
                     aten::_efficient_attention_forward         0.35%      28.069us         1.02%      81.511us      27.170us       5.653ms        89.46%       5.653ms       1.884ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.653ms        89.46%       5.653ms       1.884ms             3  
                                       aten::contiguous         0.10%       7.649us        22.70%       1.819ms     202.115us       0.000us         0.00%     725.600us      80.622us             9  
                                            aten::clone         0.27%      22.011us        22.61%       1.811ms     201.265us       0.000us         0.00%     725.600us      80.622us             9  
                                            aten::copy_         0.79%      63.041us        21.68%       1.737ms     193.055us     666.112us        10.54%     725.600us      80.622us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     666.112us        10.54%     666.112us      74.012us             9  
                                Activity Buffer Request        18.14%       1.453ms        18.14%       1.453ms       1.453ms      59.488us         0.94%      59.488us      59.488us             1  
                                        aten::transpose         0.62%      49.849us         0.82%      66.103us       2.754us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      16.254us         0.20%      16.254us       0.677us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.889us         0.65%      51.880us       5.764us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.80%      64.291us         0.80%      64.291us       3.061us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.04%     243.917us         3.04%     243.917us      20.326us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.200us         0.04%       3.200us       1.067us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.130us         0.04%       3.130us       1.043us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.91%       5.762ms        71.91%       5.762ms       5.762ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.013ms
Self CUDA time total: 6.319ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.99%     249.826us        26.96%       2.254ms       2.254ms       0.000us         0.00%       6.738ms       6.738ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.686ms       100.13%       6.686ms       6.686ms             1  
                     aten::scaled_dot_product_attention         0.22%      18.532us         1.72%     143.464us      47.821us       0.000us         0.00%       6.005ms       2.002ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.750us         1.49%     124.932us      41.644us       0.000us         0.00%       6.005ms       2.002ms             3  
                     aten::_efficient_attention_forward         0.34%      28.159us         0.97%      81.312us      27.104us       6.005ms        89.92%       6.005ms       2.002ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       6.005ms        89.92%       6.005ms       2.002ms             3  
                                       aten::contiguous         0.11%       8.892us        21.70%       1.814ms     201.591us       0.000us         0.00%     733.564us      81.507us             9  
                                            aten::clone         0.28%      23.489us        21.59%       1.805ms     200.603us       0.000us         0.00%     733.564us      81.507us             9  
                                            aten::copy_         0.78%      65.381us        20.67%       1.729ms     192.090us     672.957us        10.08%     733.564us      81.507us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     672.957us        10.08%     672.957us      74.773us             9  
                                Activity Buffer Request        17.24%       1.442ms        17.24%       1.442ms       1.442ms      60.607us         0.91%      60.607us      60.607us             1  
                                        aten::transpose         0.64%      53.558us         0.84%      70.590us       2.941us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      17.032us         0.20%      17.032us       0.710us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.490us         0.64%      53.131us       5.903us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      65.813us         0.79%      65.813us       3.134us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.91%     243.356us         2.91%     243.356us      20.280us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.000us         0.04%       3.000us       1.000us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.289us         0.04%       3.289us       1.096us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        73.04%       6.108ms        73.04%       6.108ms       6.108ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.362ms
Self CUDA time total: 6.678ms


impl                     wl                  p50(ms)  ok
torch_mem_eff            cuda_attn_L128_bfloat16     1.84  True
torch_mem_eff            cuda_attn_L256_bfloat16     1.91  True
torch_mem_eff            cuda_attn_L320_bfloat16     1.96  True
torch_mem_eff            cuda_attn_L384_bfloat16     2.04  True
torch_mem_eff            cuda_attn_L448_bfloat16     2.10  True
torch_mem_eff            cuda_attn_L512_bfloat16     2.18  True

Artifacts:

attention.jsonl
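The memory-efficient backend owes its name to never materializing the full L × L score matrix: keys and values are visited in chunks while a running (online) softmax keeps only per-query scalars. A dependency-free sketch of that trick follows; it is an illustration of the idea, not the CUTLASS `fmha` kernel profiled above:

```python
import math
import random

def naive_attention(q, k, v):
    """Reference: softmax(q @ k^T / sqrt(d)) @ v with the full score matrix."""
    L, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][t] for j in range(L)) / z
                    for t in range(len(v[0]))])
    return out

def chunked_attention(q, k, v, chunk=2):
    """Memory-efficient variant: keys/values streamed chunk by chunk,
    keeping only a running max, normalizer, and weighted value sum."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        m = -math.inf            # running max of scores seen so far
        z = 0.0                  # running softmax normalizer
        acc = [0.0] * len(v[0])  # running weighted value sum
        for start in range(0, len(k), chunk):
            for j in range(start, min(start + chunk, len(k))):
                s = scale * sum(a * b for a, b in zip(qi, k[j]))
                new_m = max(m, s)
                corr = math.exp(m - new_m)  # rescale old terms to the new max
                z = z * corr + math.exp(s - new_m)
                acc = [a * corr + math.exp(s - new_m) * vt
                       for a, vt in zip(acc, v[j])]
                m = new_m
        out.append([a / z for a in acc])
    return out

random.seed(0)
L, d = 6, 4  # tiny illustrative sizes
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]

ref = naive_attention(q, k, v)
eff = chunked_attention(q, k, v, chunk=2)
max_err = max(abs(a - b) for ra, rb in zip(ref, eff) for a, b in zip(ra, rb))
print(f"max abs diff vs. naive: {max_err:.2e}")
```

The chunked pass holds O(d) state per query instead of O(L) scores, which is why the backend's peak memory stays flat as sequence length grows.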