Flash Attention Implementation

GPU Info

Cell: nv | 0.21s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Wed Oct 29 04:14:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   36C    P0             80W /  350W |       0MiB /  46068MiB |     11%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
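The same device facts shown by nvidia-smi can be queried from PyTorch itself, which is handy inside a benchmark script. A minimal sketch (not part of the benchmark above; the printed values depend on the machine):

```python
import torch

# Query GPU name, memory, and SM count via PyTorch rather than nvidia-smi.
# Guarded so the script is still importable on CPU-only machines.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name)                                   # e.g. NVIDIA L40S
    print(f"{props.total_memory // (1024 ** 2)} MiB")   # total device memory
    print(f"SMs: {props.multi_processor_count}")
```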

Flash Attention Benchmark

Cell: benchmark | 3.81s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Inputs arrive as (batch, seq, heads, head_dim); SDPA expects
    # (batch, heads, seq, head_dim), so transpose in and transpose back out.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the flash-attention backend instead of letting it auto-select.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.590ms       102.17%       3.590ms       3.590ms             1  
                                         torch_flash_ma         6.85%     354.470us        47.44%       2.454ms       2.454ms       0.000us         0.00%       3.554ms       3.554ms             1  
                     aten::scaled_dot_product_attention         0.84%      43.371us         4.38%     226.614us      75.538us       0.000us         0.00%       2.798ms     932.564us             3  
              aten::_scaled_dot_product_flash_attention         0.52%      27.141us         3.54%     183.243us      61.081us       0.000us         0.00%       2.798ms     932.564us             3  
                         aten::_flash_attention_forward         0.84%      43.539us         2.59%     134.122us      44.707us       2.798ms        79.63%       2.798ms     932.564us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.798ms        79.63%       2.798ms     932.564us             3  
                                       aten::contiguous         0.29%      14.889us        34.84%       1.803ms     150.217us       0.000us         0.00%     755.939us      62.995us            12  
                                            aten::clone         0.79%      40.742us        34.56%       1.788ms     148.977us       0.000us         0.00%     755.939us      62.995us            12  
                                            aten::copy_         1.80%      93.020us        31.59%       1.634ms     136.197us     715.586us        20.37%     755.939us      62.995us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     715.586us        20.37%     715.586us      59.632us            12  
                                Activity Buffer Request        27.64%       1.430ms        27.64%       1.430ms       1.430ms      40.353us         1.15%      40.353us      40.353us             1  
                                        aten::transpose         1.35%      70.048us         1.79%      92.780us       3.866us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.44%      22.732us         0.44%      22.732us       0.947us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.49%      25.480us         2.63%     136.134us       9.076us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         2.37%     122.383us         2.37%     122.383us       5.099us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.63%     136.154us         2.63%     136.154us       9.077us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.35%      17.861us         0.35%      17.861us       5.954us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.732us         0.05%       2.732us       0.455us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.19%      10.040us         0.19%      10.040us       3.347us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        52.56%       2.719ms        52.56%       2.719ms       2.719ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.174ms
Self CUDA time total: 3.513ms
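In the trace above, roughly 20% of CUDA time is `aten::copy_` launched by the `.contiguous()` calls, since each transpose round-trip materializes a copy. A hedged sketch of the alternative: if the benchmark harness supplied tensors already in the (batch, heads, seq, head_dim) layout SDPA expects (an assumption about the harness, not something the code above does), the copies would disappear and only the flash kernel would remain.

```python
import torch
import torch.nn.functional as F


def torch_flash_no_copy(q, k, v):
    # Assumes q, k, v are already (batch, heads, seq, head_dim):
    # no transpose or .contiguous(), so no elementwise copy kernels.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)
```

Whether this is usable depends on the layout the benchmark workloads hand to the implementation; the version benchmarked here pays the copy cost to keep its (batch, seq, heads, head_dim) interface.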



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.13%     269.966us        42.38%       2.232ms       2.232ms       0.000us         0.00%       3.778ms       3.778ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.734ms       100.30%       3.734ms       3.734ms             1  
                     aten::scaled_dot_product_attention         0.51%      26.890us         3.58%     188.304us      62.768us       0.000us         0.00%       2.960ms     986.590us             3  
              aten::_scaled_dot_product_flash_attention         0.35%      18.589us         3.07%     161.414us      53.805us       0.000us         0.00%       2.960ms     986.590us             3  
                         aten::_flash_attention_forward         0.78%      41.299us         2.29%     120.413us      40.138us       2.960ms        79.51%       2.960ms     986.590us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.960ms        79.51%       2.960ms     986.590us             3  
                                       aten::contiguous         0.18%       9.501us        32.77%       1.726ms     143.802us       0.000us         0.00%     818.206us      68.184us            12  
                                            aten::clone         0.54%      28.568us        32.59%       1.716ms     143.010us       0.000us         0.00%     818.206us      68.184us            12  
                                            aten::copy_         1.52%      80.181us        30.79%       1.621ms     135.119us     762.846us        20.49%     818.206us      68.184us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     762.846us        20.49%     762.846us      63.571us            12  
                                Activity Buffer Request        27.52%       1.449ms        27.52%       1.449ms       1.449ms      55.360us         1.49%      55.360us      55.360us             1  
                                        aten::transpose         1.00%      52.915us         1.33%      70.084us       2.920us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.169us         0.33%      17.169us       0.715us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.39%      20.652us         1.64%      86.425us       5.762us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.51%      79.433us         1.51%      79.433us       3.310us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.18%     114.743us         2.18%     114.743us       7.650us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.29%      15.331us         0.29%      15.331us       5.110us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       1.900us         0.04%       1.900us       0.317us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.10%       5.520us         0.10%       5.520us       1.840us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        57.62%       3.034ms        57.62%       3.034ms       3.034ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.265ms
Self CUDA time total: 3.723ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.04%     266.137us        41.64%       2.197ms       2.197ms       0.000us         0.00%       3.820ms       3.820ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.772ms       100.29%       3.772ms       3.772ms             1  
                     aten::scaled_dot_product_attention         0.49%      25.880us         3.59%     189.194us      63.065us       0.000us         0.00%       2.983ms     994.205us             3  
              aten::_scaled_dot_product_flash_attention         0.37%      19.363us         3.10%     163.314us      54.438us       0.000us         0.00%       2.983ms     994.205us             3  
                         aten::_flash_attention_forward         0.81%      42.782us         2.31%     121.862us      40.621us       2.983ms        79.31%       2.983ms     994.205us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.983ms        79.31%       2.983ms     994.205us             3  
                                       aten::contiguous         0.18%       9.290us        32.12%       1.695ms     141.255us       0.000us         0.00%     836.990us      69.749us            12  
                                            aten::clone         0.53%      27.791us        31.95%       1.686ms     140.481us       0.000us         0.00%     836.990us      69.749us            12  
                                            aten::copy_         1.57%      82.879us        30.22%       1.595ms     132.896us     778.238us        20.69%     836.990us      69.749us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     778.238us        20.69%     778.238us      64.853us            12  
                                Activity Buffer Request        26.92%       1.420ms        26.92%       1.420ms       1.420ms      58.752us         1.56%      58.752us      58.752us             1  
                                        aten::transpose         0.98%      51.581us         1.30%      68.820us       2.868us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.239us         0.33%      17.239us       0.718us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      18.669us         1.58%      83.581us       5.572us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.49%      78.372us         1.49%      78.372us       3.265us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.17%     114.523us         2.17%     114.523us       7.635us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.29%      15.511us         0.29%      15.511us       5.170us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.300us         0.04%       2.300us       0.383us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.09%       4.560us         0.09%       4.560us       1.520us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.36%       3.079ms        58.36%       3.079ms       3.079ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.277ms
Self CUDA time total: 3.761ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.81%     269.664us        43.38%       2.432ms       2.432ms       0.000us         0.00%       3.921ms       3.921ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.875ms       100.29%       3.875ms       3.875ms             1  
                     aten::scaled_dot_product_attention         0.47%      26.530us         3.32%     186.254us      62.085us       0.000us         0.00%       3.079ms       1.026ms             3  
              aten::_scaled_dot_product_flash_attention         0.33%      18.670us         2.85%     159.724us      53.241us       0.000us         0.00%       3.079ms       1.026ms             3  
                         aten::_flash_attention_forward         0.73%      41.012us         2.12%     118.963us      39.654us       3.079ms        79.68%       3.079ms       1.026ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.079ms        79.68%       3.079ms       1.026ms             3  
                                       aten::contiguous         0.17%       9.411us        34.39%       1.928ms     160.703us       0.000us         0.00%     842.199us      70.183us            12  
                                            aten::clone         0.52%      28.883us        34.22%       1.919ms     159.919us       0.000us         0.00%     842.199us      70.183us            12  
                                            aten::copy_         1.48%      82.822us        32.55%       1.825ms     152.123us     784.952us        20.32%     842.199us      70.183us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     784.952us        20.32%     784.952us      65.413us            12  
                                Activity Buffer Request        25.77%       1.445ms        25.77%       1.445ms       1.445ms      57.247us         1.48%      57.247us      57.247us             1  
                                        aten::transpose         0.94%      52.967us         1.25%      70.184us       2.924us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      17.217us         0.31%      17.217us       0.717us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      19.178us         1.51%      84.829us       5.655us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.41%      78.973us         1.41%      78.973us       3.291us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.72%     320.465us         5.72%     320.465us      21.364us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.27%      15.229us         0.27%      15.229us       5.076us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.110us         0.04%       2.110us       0.352us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.130us         0.07%       4.130us       1.377us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        56.62%       3.175ms        56.62%       3.175ms       3.175ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.607ms
Self CUDA time total: 3.864ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.31%     318.398us        40.52%       2.428ms       2.428ms       0.000us         0.00%       4.370ms       4.370ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.321ms       100.25%       4.321ms       4.321ms             1  
                     aten::scaled_dot_product_attention         0.43%      25.890us         3.27%     195.733us      65.244us       0.000us         0.00%       3.503ms       1.168ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.430us         2.83%     169.843us      56.614us       0.000us         0.00%       3.503ms       1.168ms             3  
                         aten::_flash_attention_forward         0.75%      44.733us         2.13%     127.534us      42.511us       3.503ms        81.28%       3.503ms       1.168ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.503ms        81.28%       3.503ms       1.168ms             3  
                                       aten::contiguous         0.16%       9.533us        31.15%       1.866ms     155.517us       0.000us         0.00%     867.131us      72.261us            12  
                                            aten::clone         0.48%      28.649us        30.99%       1.857ms     154.722us       0.000us         0.00%     867.131us      72.261us            12  
                                            aten::copy_         1.37%      82.103us        29.43%       1.763ms     146.944us     806.940us        18.72%     867.131us      72.261us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     806.940us        18.72%     806.940us      67.245us            12  
                                Activity Buffer Request        23.90%       1.432ms        23.90%       1.432ms       1.432ms      60.191us         1.40%      60.191us      60.191us             1  
                                        aten::transpose         0.87%      52.328us         1.17%      70.130us       2.922us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.30%      17.802us         0.30%      17.802us       0.742us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.052us         1.44%      86.062us       5.737us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.32%      79.270us         1.32%      79.270us       3.303us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.58%     274.314us         4.58%     274.314us      18.288us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.27%      16.430us         0.27%      16.430us       5.477us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.360us         0.04%       2.360us       0.393us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.210us         0.07%       4.210us       1.403us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.48%       3.564ms        59.48%       3.564ms       3.564ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.991ms
Self CUDA time total: 4.310ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         3.92%     237.516us        38.06%       2.305ms       2.305ms       0.000us         0.00%       4.487ms       4.487ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.438ms       100.25%       4.438ms       4.438ms             1  
                     aten::scaled_dot_product_attention         0.44%      26.369us         3.02%     182.943us      60.981us       0.000us         0.00%       3.605ms       1.202ms             3  
              aten::_scaled_dot_product_flash_attention         0.31%      18.541us         2.59%     156.574us      52.191us       0.000us         0.00%       3.605ms       1.202ms             3  
                         aten::_flash_attention_forward         0.63%      38.112us         1.91%     115.882us      38.627us       3.605ms        81.43%       3.605ms       1.202ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.605ms        81.43%       3.605ms       1.202ms             3  
                                       aten::contiguous         0.15%       9.281us        30.31%       1.836ms     153.003us       0.000us         0.00%     882.684us      73.557us            12  
                                            aten::clone         0.47%      28.328us        30.16%       1.827ms     152.229us       0.000us         0.00%     882.684us      73.557us            12  
                                            aten::copy_         1.32%      79.871us        28.64%       1.734ms     144.531us     822.268us        18.57%     882.684us      73.557us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     822.268us        18.57%     822.268us      68.522us            12  
                                Activity Buffer Request        23.38%       1.416ms        23.38%       1.416ms       1.416ms      60.416us         1.36%      60.416us      60.416us             1  
                                        aten::transpose         0.89%      53.992us         1.17%      70.941us       2.956us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      16.949us         0.28%      16.949us       0.706us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      19.985us         1.39%      84.474us       5.632us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.27%      76.679us         1.27%      76.679us       3.195us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.33%     262.156us         4.33%     262.156us      17.477us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      15.620us         0.26%      15.620us       5.207us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.329us         0.04%       2.329us       0.388us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.781us         0.06%       3.781us       1.260us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        61.94%       3.751ms        61.94%       3.751ms       3.751ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.057ms
Self CUDA time total: 4.427ms


impl             workload                 p50 (ms)  ok
torch_flash_ma   cuda_attn_L128_bfloat16      1.21  True
torch_flash_ma   cuda_attn_L256_bfloat16      1.26  True
torch_flash_ma   cuda_attn_L320_bfloat16      1.29  True
torch_flash_ma   cuda_attn_L384_bfloat16      1.32  True
torch_flash_ma   cuda_attn_L448_bfloat16      1.48  True
torch_flash_ma   cuda_attn_L512_bfloat16      1.51  True

Artifacts:

attention.jsonl
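The summary table can be recomputed from the jsonl artifact. A sketch of a reader, assuming one JSON record per line; the field names ("impl", "wl", "lat_ms") are guesses at the schema, not confirmed by the source, so adjust them to whatever `attention.jsonl` actually contains:

```python
import json
import statistics


def p50_by_workload(path):
    # Group per-iteration latencies by (implementation, workload)
    # and report the median of each group, mirroring the p50 table.
    groups = {}
    with open(path) as f:
        for line in f:
            r = json.loads(line)
            groups.setdefault((r["impl"], r["wl"]), []).append(r["lat_ms"])
    return {key: statistics.median(vals) for key, vals in groups.items()}
```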