Memory Efficient Attention Implementation

Memory Efficient SDPA Benchmark

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


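def torch_mem_eff(q, k, v):
    # Inputs are presumably laid out as (batch, seq, heads, head_dim);
    # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA to the memory-efficient backend; if that backend cannot
    # handle the inputs, the call raises instead of falling back to another one.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Return the output in the caller's (batch, seq, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()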


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
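As a cross-check on what `torch_mem_eff` computes: the `transpose(1, 2)` calls suggest that the benchmark hands over `q`, `k` and `v` as (batch, seq, heads, head_dim), whereas `scaled_dot_product_attention` takes (batch, heads, seq, head_dim). With no mask, no dropout and the default 1/√head_dim scale, the efficient backend should match plain softmax(q·kᵀ/√head_dim)·v up to bf16 rounding. The pure-Python sketch below spells that out under the layout assumption; `reference_attention` is a hypothetical helper, not part of `kernels_benchmark_tools` or PyTorch.

```python
import math


def reference_attention(q, k, v):
    """Naive softmax(q k^T / sqrt(d)) v over nested lists.

    Hypothetical helper. Assumes q, k, v are shaped (batch, seq, heads, head_dim),
    the layout torch_mem_eff appears to receive, with no mask, no dropout and the
    default 1/sqrt(head_dim) scale. The result has the same layout as q.
    """
    batch, seq_q, heads, dim = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    seq_k = len(k[0])
    scale = 1.0 / math.sqrt(dim)
    out = [[[[0.0] * dim for _ in range(heads)] for _ in range(seq_q)] for _ in range(batch)]
    for b in range(batch):
        for h in range(heads):
            for i in range(seq_q):
                # Scaled dot product of query i with every key of the same head.
                scores = [
                    scale * sum(q[b][i][h][d] * k[b][j][h][d] for d in range(dim))
                    for j in range(seq_k)
                ]
                # Softmax over the key axis, shifted by the max for stability.
                peak = max(scores)
                weights = [math.exp(s - peak) for s in scores]
                total = sum(weights)
                # Output row: softmax-weighted sum of the value vectors.
                for d in range(dim):
                    out[b][i][h][d] = sum(weights[j] * v[b][j][h][d] for j in range(seq_k)) / total
    return out


# Usage: batch=1, one query, two keys, one head, head_dim=2.
q = [[[[1.0, 0.0]]]]
k = [[[[1.0, 0.0]], [[0.0, 1.0]]]]
v = [[[[1.0, 2.0]], [[3.0, 4.0]]]]
print(reference_attention(q, k, v))
```

The nested loops make the layout explicit at the cost of speed, so this is only practical for tiny shapes; comparing it against the bf16 kernel output would need a loose tolerance.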
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.462ms       101.52%       5.462ms       5.462ms             1  
                                          torch_mem_eff         4.78%     351.785us        36.36%       2.675ms       2.675ms       0.000us         0.00%       5.434ms       5.434ms             1  
                     aten::scaled_dot_product_attention         0.44%      32.361us         3.09%     227.216us      75.739us       0.000us         0.00%       4.760ms       1.587ms             3  
          aten::_scaled_dot_product_efficient_attention         0.32%      23.392us         2.65%     194.855us      64.952us       0.000us         0.00%       4.760ms       1.587ms             3  
                     aten::_efficient_attention_forward         0.47%      34.731us         1.98%     145.602us      48.534us       4.760ms        88.47%       4.760ms       1.587ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.760ms        88.47%       4.760ms       1.587ms             3  
                                       aten::contiguous         0.14%      10.161us        27.51%       2.023ms     224.817us       0.000us         0.00%     673.947us      74.883us             9  
                                            aten::clone         0.40%      29.063us        27.37%       2.013ms     223.688us       0.000us         0.00%     673.947us      74.883us             9  
                                            aten::copy_         1.06%      77.620us        25.90%       1.905ms     211.680us     620.444us        11.53%     673.947us      74.883us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     620.444us        11.53%     620.444us      68.938us             9  
                                Activity Buffer Request        23.68%       1.742ms        23.68%       1.742ms       1.742ms      53.503us         0.99%      53.503us      53.503us             1  
                                        aten::transpose         0.99%      72.964us         1.33%      98.194us       4.091us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.34%      25.230us         0.34%      25.230us       1.051us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.25%      18.168us         1.07%      79.009us       8.779us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.28%      94.381us         1.28%      94.381us       4.494us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.49%     109.573us         1.49%     109.573us       9.131us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.05%       3.660us         0.05%       3.660us       1.220us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.67%      49.491us         0.67%      49.491us      16.497us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        63.64%       4.681ms        63.64%       4.681ms       4.681ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.356ms
Self CUDA time total: 5.380ms
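The L128 trace can be turned into a per-call budget. `aten::scaled_dot_product_attention` shows 3 calls, and the `elementwise_kernel` row behind `aten::copy_` shows 9 launches, i.e. three per call. That matches the `.contiguous()` copies of `q`, `k` and `v`; the output's `.contiguous()` adds no copy, presumably because the transposed output is already contiguous. The sketch below redoes the arithmetic with values copied from the table; the variable names are illustrative.

```python
# Figures copied from the cuda_attn_L128_bfloat16 profile trace.
calls = 3                      # aten::scaled_dot_product_attention "# of Calls"
attn_kernel_us = 4760.0        # fmha_cutlassF_... Self CUDA (4.760 ms)
copy_kernel_us = 620.444       # elementwise_kernel Self CUDA
copy_launches = 9              # elementwise_kernel "# of Calls"
self_cuda_total_us = 5380.0    # "Self CUDA time total: 5.380ms"

attn_per_call_us = attn_kernel_us / calls            # attention kernel time per call
copies_per_call = copy_launches // calls             # q, k, v -> 3 copies per call
copy_per_call_us = copy_kernel_us / calls            # layout-copy time per call
copy_share = copy_kernel_us / self_cuda_total_us     # fraction of profiled GPU time

print(f"attention kernel per call: {attn_per_call_us:.1f} us")
print(f"layout copies per call:    {copies_per_call} ({copy_per_call_us:.1f} us)")
print(f"copy share of CUDA time:   {copy_share:.2%}")
```

On this workload the layout copies take roughly a tenth of the profiled GPU time. Whether they could be skipped depends on whether the efficient backend accepts the transposed views directly, which this benchmark does not measure.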



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.99%     227.637us        31.17%       2.369ms       2.369ms       0.000us         0.00%       5.835ms       5.835ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.790ms       100.14%       5.790ms       5.790ms             1  
                     aten::scaled_dot_product_attention         0.23%      17.721us         1.87%     142.143us      47.381us       0.000us         0.00%       5.146ms       1.715ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      18.819us         1.64%     124.422us      41.474us       0.000us         0.00%       5.146ms       1.715ms             3  
                     aten::_efficient_attention_forward         0.37%      28.141us         1.08%      82.262us      27.421us       5.146ms        89.01%       5.146ms       1.715ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.146ms        89.01%       5.146ms       1.715ms             3  
                                       aten::contiguous         0.09%       6.739us        25.75%       1.957ms     217.483us       0.000us         0.00%     689.503us      76.611us             9  
                                            aten::clone         0.27%      20.691us        25.66%       1.951ms     216.734us       0.000us         0.00%     689.503us      76.611us             9  
                                            aten::copy_         0.83%      62.851us        24.72%       1.879ms     208.808us     635.680us        10.99%     689.503us      76.611us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     635.680us        10.99%     635.680us      70.631us             9  
                                Activity Buffer Request        23.06%       1.753ms        23.06%       1.753ms       1.753ms      53.823us         0.93%      53.823us      53.823us             1  
                                        aten::transpose         0.63%      47.890us         0.86%      65.431us       2.726us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.23%      17.541us         0.23%      17.541us       0.731us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.310us         0.67%      50.641us       5.627us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.87%      66.232us         0.87%      66.232us       3.154us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.12%      85.492us         1.12%      85.492us       7.124us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.460us         0.03%       2.460us       0.820us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.070us         0.04%       3.070us       1.023us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.83%       5.232ms        68.83%       5.232ms       5.232ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.601ms
Self CUDA time total: 5.782ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.88%     222.044us        30.17%       2.327ms       2.327ms       0.000us         0.00%       5.986ms       5.986ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.939ms       100.13%       5.939ms       5.939ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.710us         1.85%     142.303us      47.434us       0.000us         0.00%       5.284ms       1.761ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      19.190us         1.60%     123.593us      41.198us       0.000us         0.00%       5.284ms       1.761ms             3  
                     aten::_efficient_attention_forward         0.36%      27.947us         1.05%      81.281us      27.094us       5.284ms        89.10%       5.284ms       1.761ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.284ms        89.10%       5.284ms       1.761ms             3  
                                       aten::contiguous         0.09%       7.300us        24.90%       1.920ms     213.350us       0.000us         0.00%     702.238us      78.026us             9  
                                            aten::clone         0.28%      21.930us        24.80%       1.913ms     212.539us       0.000us         0.00%     702.238us      78.026us             9  
                                            aten::copy_         0.79%      60.872us        23.86%       1.840ms     204.449us     646.526us        10.90%     702.238us      78.026us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     646.526us        10.90%     646.526us      71.836us             9  
                                Activity Buffer Request        22.23%       1.715ms        22.23%       1.715ms       1.715ms      55.712us         0.94%      55.712us      55.712us             1  
                                        aten::transpose         0.63%      48.814us         0.85%      65.893us       2.746us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.079us         0.22%      17.079us       0.712us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.801us         0.66%      50.882us       5.654us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.85%      65.644us         0.85%      65.644us       3.126us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.11%      85.622us         1.11%      85.622us       7.135us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.511us         0.03%       2.511us       0.837us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.110us         0.04%       3.110us       1.037us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.83%       5.385ms        69.83%       5.385ms       5.385ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.713ms
Self CUDA time total: 5.931ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.05%     248.737us        32.15%       2.620ms       2.620ms       0.000us         0.00%       6.167ms       6.167ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.117ms       100.13%       6.117ms       6.117ms             1  
                     aten::scaled_dot_product_attention         0.24%      19.380us         1.81%     147.173us      49.058us       0.000us         0.00%       5.450ms       1.817ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      19.059us         1.57%     127.793us      42.598us       0.000us         0.00%       5.450ms       1.817ms             3  
                     aten::_efficient_attention_forward         0.34%      28.111us         1.04%      84.373us      28.124us       5.450ms        89.21%       5.450ms       1.817ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.450ms        89.21%       5.450ms       1.817ms             3  
                                       aten::contiguous         0.09%       7.070us        26.79%       2.183ms     242.545us       0.000us         0.00%     717.472us      79.719us             9  
                                            aten::clone         0.26%      21.211us        26.70%       2.176ms     241.760us       0.000us         0.00%     717.472us      79.719us             9  
                                            aten::copy_         0.77%      62.427us        25.76%       2.100ms     233.287us     658.976us        10.79%     717.472us      79.719us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     658.976us        10.79%     658.976us      73.220us             9  
                                Activity Buffer Request        21.68%       1.767ms        21.68%       1.767ms       1.767ms      58.496us         0.96%      58.496us      58.496us             1  
                                        aten::transpose         0.59%      47.765us         0.81%      65.883us       2.745us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      18.118us         0.22%      18.118us       0.755us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.420us         0.68%      55.041us       6.116us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.87%      71.281us         0.87%      71.281us       3.394us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.59%     292.889us         3.59%     292.889us      24.407us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.781us         0.03%       2.781us       0.927us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.020us         0.04%       3.020us       1.007us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.85%       5.529ms        67.85%       5.529ms       5.529ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.150ms
Self CUDA time total: 6.109ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.74%     222.904us        29.02%       2.363ms       2.363ms       0.000us         0.00%       6.392ms       6.392ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.341ms       100.13%       6.341ms       6.341ms             1  
                     aten::scaled_dot_product_attention         0.23%      18.463us         1.76%     143.054us      47.685us       0.000us         0.00%       5.664ms       1.888ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.699us         1.53%     124.591us      41.530us       0.000us         0.00%       5.664ms       1.888ms             3  
                     aten::_efficient_attention_forward         0.35%      28.650us         1.01%      82.071us      27.357us       5.664ms        89.43%       5.664ms       1.888ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.664ms        89.43%       5.664ms       1.888ms             3  
                                       aten::contiguous         0.09%       7.480us        24.00%       1.954ms     217.122us       0.000us         0.00%     727.838us      80.871us             9  
                                            aten::clone         0.26%      21.231us        23.90%       1.947ms     216.290us       0.000us         0.00%     727.838us      80.871us             9  
                                            aten::copy_         0.78%      63.523us        23.01%       1.874ms     208.176us     669.182us        10.57%     727.838us      80.871us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     669.182us        10.57%     669.182us      74.354us             9  
                                Activity Buffer Request        19.19%       1.562ms        19.19%       1.562ms       1.562ms      58.656us         0.93%      58.656us      58.656us             1  
                                        aten::transpose         0.60%      48.754us         0.82%      66.672us       2.778us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.918us         0.22%      17.918us       0.747us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.269us         0.64%      51.800us       5.756us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      66.291us         0.81%      66.291us       3.157us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.31%     269.756us         3.31%     269.756us      22.480us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.590us         0.03%       2.590us       0.863us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       2.940us         0.04%       2.940us       0.980us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.98%       5.781ms        70.98%       5.781ms       5.781ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.144ms
Self CUDA time total: 6.333ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.91%     254.056us        31.19%       2.722ms       2.722ms       0.000us         0.00%       6.645ms       6.645ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.592ms       100.12%       6.592ms       6.592ms             1  
                     aten::scaled_dot_product_attention         0.23%      20.440us         1.69%     147.533us      49.178us       0.000us         0.00%       5.910ms       1.970ms             3  
          aten::_scaled_dot_product_efficient_attention         0.22%      19.250us         1.46%     127.093us      42.364us       0.000us         0.00%       5.910ms       1.970ms             3  
                     aten::_efficient_attention_forward         0.33%      28.899us         0.98%      85.242us      28.414us       5.910ms        89.76%       5.910ms       1.970ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.910ms        89.76%       5.910ms       1.970ms             3  
                                       aten::contiguous         0.08%       7.268us        26.04%       2.272ms     252.404us       0.000us         0.00%     734.815us      81.646us             9  
                                            aten::clone         0.28%      24.054us        25.95%       2.264ms     251.596us       0.000us         0.00%     734.815us      81.646us             9  
                                            aten::copy_         0.77%      66.891us        25.04%       2.185ms     242.745us     674.239us        10.24%     734.815us      81.646us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     674.239us        10.24%     674.239us      74.915us             9  
                                Activity Buffer Request        20.22%       1.764ms        20.22%       1.764ms       1.764ms      60.576us         0.92%      60.576us      60.576us             1  
                                        aten::transpose         0.62%      53.860us         0.81%      70.972us       2.957us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      17.112us         0.20%      17.112us       0.713us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.910us         0.64%      55.601us       6.178us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.82%      71.503us         0.82%      71.503us       3.405us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         4.30%     375.338us         4.30%     375.338us      31.278us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.571us         0.03%       2.571us       0.857us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.03%       3.000us         0.03%       3.000us       1.000us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.81%       6.003ms        68.81%       6.003ms       6.003ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.725ms
Self CUDA time total: 6.584ms
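Across the six traces the attention kernel time grows from 4.760 ms to 5.910 ms while the copy kernel grows only from about 620 us to 674 us, so the copies' share of GPU time drifts down as L increases. The sketch below recomputes the copy kernel's `Self CUDA %` for each workload from the raw `Self CUDA` and `Self CUDA time total` values.

```python
# (workload, elementwise_kernel Self CUDA in us, Self CUDA time total in us),
# copied from the six profile traces.
traces = [
    ("L128", 620.444, 5380.0),
    ("L256", 635.680, 5782.0),
    ("L320", 646.526, 5931.0),
    ("L384", 658.976, 6109.0),
    ("L448", 669.182, 6333.0),
    ("L512", 674.239, 6584.0),
]

# Share of each workload's profiled GPU time spent in the layout copies.
shares = {name: copy_us / total_us for name, copy_us, total_us in traces}
for name, share in shares.items():
    print(f"{name}: {share:.2%}")
```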


impl           workload                 p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.83  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.93  True
torch_mem_eff  cuda_attn_L320_bfloat16     1.95  True
torch_mem_eff  cuda_attn_L384_bfloat16     2.04  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.08  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.17  True
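The p50 column grows far more slowly than the sequence length. Going from L128 to L512 quadruples the length, which multiplies the number of query-key scores per head by 16, yet p50 rises only from 1.83 ms to 2.17 ms, about 1.19x. The arithmetic is sketched below with the table's values. The batch size, head count and head dimension of these workloads are not shown here, so this ratio alone does not say whether the kernel is bound by fixed overheads, memory traffic or compute at these sizes.

```python
# p50 latencies (ms) from the summary table, keyed by sequence length L.
p50_ms = {128: 1.83, 256: 1.93, 320: 1.95, 384: 2.04, 448: 2.08, 512: 2.17}

base_len = 128
base_ms = p50_ms[base_len]
ratios = {}
for seq_len, ms in p50_ms.items():
    len_ratio = seq_len / base_len
    # Query-key scores per head scale with L * L, i.e. the square of len_ratio.
    score_ratio = len_ratio ** 2
    time_ratio = ms / base_ms
    ratios[seq_len] = (len_ratio, score_ratio, time_ratio)
    print(f"L={seq_len}: length x{len_ratio:.2f}, scores x{score_ratio:.2f}, p50 x{time_ratio:.2f}")
```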

Artifacts:

attention.jsonl