Flash Attention Implementation

GPU Info

Cell: nv | 0.26s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Tue Oct 28 14:08:39 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   32C    P0            153W /  350W |       0MiB /  46068MiB |     26%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Flash Attention Benchmark

Cell: benchmark | 3.83s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
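import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Assumption: the harness passes (batch, seq, heads, dim) tensors. SDPA expects
    # (batch, heads, seq, dim), so swap the two middle axes and make them contiguous.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA to the FlashAttention backend so no other kernel can be selected.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Return the output in the caller's original layout.
    return o.transpose(1, 2).contiguous()


# NOTE: the "_ma" suffix and the "max-autotune" tag suggest torch.compile,
# but torch_flash is passed to the benchmark uncompiled in this cell.
run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)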
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.585ms       101.47%       3.585ms       3.585ms             1  
                                         torch_flash_ma         6.34%     327.656us        45.53%       2.352ms       2.352ms       0.000us         0.00%       3.573ms       3.573ms             1  
                     aten::scaled_dot_product_attention         0.82%      42.312us         4.12%     213.057us      71.019us       0.000us         0.00%       2.820ms     940.062us             3  
              aten::_scaled_dot_product_flash_attention         0.51%      26.321us         3.31%     170.745us      56.915us       0.000us         0.00%       2.820ms     940.062us             3  
                         aten::_flash_attention_forward         0.73%      37.527us         2.40%     124.015us      41.338us       2.820ms        79.83%       2.820ms     940.062us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.820ms        79.83%       2.820ms     940.062us             3  
                                       aten::contiguous         0.27%      14.121us        33.79%       1.745ms     145.446us       0.000us         0.00%     752.928us      62.744us            12  
                                            aten::clone         0.72%      37.329us        33.52%       1.731ms     144.269us       0.000us         0.00%     752.928us      62.744us            12  
                                            aten::copy_         1.68%      87.013us        31.25%       1.614ms     134.513us     712.672us        20.17%     752.928us      62.744us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     712.672us        20.17%     712.672us      59.389us            12  
                                Activity Buffer Request        27.64%       1.428ms        27.64%       1.428ms       1.428ms      40.256us         1.14%      40.256us      40.256us             1  
                                        aten::transpose         1.24%      64.087us         1.67%      86.009us       3.584us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.42%      21.922us         0.42%      21.922us       0.913us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.48%      24.711us         1.99%     102.775us       6.852us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.74%      89.843us         1.74%      89.843us       3.743us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.38%     122.771us         2.38%     122.771us       8.185us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.34%      17.310us         0.34%      17.310us       5.770us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.229us         0.04%       2.229us       0.372us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.17%       8.900us         0.17%       8.900us       2.967us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        54.47%       2.814ms        54.47%       2.814ms       2.814ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.165ms
Self CUDA time total: 3.533ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.84%     255.079us        41.49%       2.188ms       2.188ms       0.000us         0.00%       3.787ms       3.787ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.743ms       100.29%       3.743ms       3.743ms             1  
                     aten::scaled_dot_product_attention         0.47%      24.640us         3.42%     180.356us      60.119us       0.000us         0.00%       2.967ms     989.106us             3  
              aten::_scaled_dot_product_flash_attention         0.36%      19.241us         2.95%     155.716us      51.905us       0.000us         0.00%       2.967ms     989.106us             3  
                         aten::_flash_attention_forward         0.73%      38.683us         2.19%     115.525us      38.508us       2.967ms        79.51%       2.967ms     989.106us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.967ms        79.51%       2.967ms     989.106us             3  
                                       aten::contiguous         0.17%       8.802us        32.41%       1.709ms     142.425us       0.000us         0.00%     819.868us      68.322us            12  
                                            aten::clone         0.52%      27.349us        32.24%       1.700ms     141.692us       0.000us         0.00%     819.868us      68.322us            12  
                                            aten::copy_         1.56%      82.061us        30.60%       1.614ms     134.473us     764.892us        20.49%     819.868us      68.322us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     764.892us        20.49%     764.892us      63.741us            12  
                                Activity Buffer Request        27.50%       1.450ms        27.50%       1.450ms       1.450ms      54.976us         1.47%      54.976us      54.976us             1  
                                        aten::transpose         0.91%      47.959us         1.22%      64.512us       2.688us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      16.553us         0.31%      16.553us       0.690us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.39%      20.732us         1.52%      80.304us       5.354us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.38%      72.972us         1.38%      72.972us       3.040us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.96%     103.146us         1.96%     103.146us       6.876us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      14.880us         0.28%      14.880us       4.960us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.800us         0.03%       1.800us       0.300us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.830us         0.07%       3.830us       1.277us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.51%       3.085ms        58.51%       3.085ms       3.085ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.273ms
Self CUDA time total: 3.732ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.77%     251.162us        41.45%       2.184ms       2.184ms       0.000us         0.00%       3.786ms       3.786ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.738ms       100.28%       3.738ms       3.738ms             1  
                     aten::scaled_dot_product_attention         0.46%      24.280us         3.42%     180.086us      60.029us       0.000us         0.00%       2.949ms     982.872us             3  
              aten::_scaled_dot_product_flash_attention         0.34%      18.160us         2.96%     155.806us      51.935us       0.000us         0.00%       2.949ms     982.872us             3  
                         aten::_flash_attention_forward         0.73%      38.599us         2.20%     115.865us      38.622us       2.949ms        79.09%       2.949ms     982.872us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.949ms        79.09%       2.949ms     982.872us             3  
                                       aten::contiguous         0.17%       8.991us        32.44%       1.710ms     142.465us       0.000us         0.00%     837.719us      69.810us            12  
                                            aten::clone         0.53%      27.728us        32.27%       1.701ms     141.715us       0.000us         0.00%     837.719us      69.810us            12  
                                            aten::copy_         1.52%      79.873us        30.57%       1.611ms     134.242us     779.480us        20.91%     837.719us      69.810us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     779.480us        20.91%     779.480us      64.957us            12  
                                Activity Buffer Request        27.50%       1.449ms        27.50%       1.449ms       1.449ms      58.239us         1.56%      58.239us      58.239us             1  
                                        aten::transpose         0.92%      48.219us         1.24%      65.252us       2.719us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      17.033us         0.32%      17.033us       0.710us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.37%      19.303us         1.55%      81.795us       5.453us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.44%      76.031us         1.44%      76.031us       3.168us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.98%     104.564us         1.98%     104.564us       6.971us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      14.492us         0.28%      14.492us       4.831us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       1.860us         0.04%       1.860us       0.310us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.10%       5.030us         0.10%       5.030us       1.677us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.55%       3.085ms        58.55%       3.085ms       3.085ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.269ms
Self CUDA time total: 3.728ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.01%     280.573us        44.17%       2.475ms       2.475ms       0.000us         0.00%       3.878ms       3.878ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.831ms       100.27%       3.831ms       3.831ms             1  
                     aten::scaled_dot_product_attention         0.48%      26.630us         3.39%     189.956us      63.319us       0.000us         0.00%       3.032ms       1.011ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      19.101us         2.91%     163.326us      54.442us       0.000us         0.00%       3.032ms       1.011ms             3  
                         aten::_flash_attention_forward         0.70%      39.063us         2.15%     120.325us      40.108us       3.032ms        79.37%       3.032ms       1.011ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.032ms        79.37%       3.032ms       1.011ms             3  
                                       aten::contiguous         0.17%       9.271us        34.98%       1.960ms     163.354us       0.000us         0.00%     845.820us      70.485us            12  
                                            aten::clone         0.52%      28.974us        34.82%       1.951ms     162.581us       0.000us         0.00%     845.820us      70.485us            12  
                                            aten::copy_         1.48%      83.180us        33.17%       1.859ms     154.908us     788.284us        20.63%     845.820us      70.485us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     788.284us        20.63%     788.284us      65.690us            12  
                                Activity Buffer Request        26.18%       1.467ms        26.18%       1.467ms       1.467ms      57.536us         1.51%      57.536us      57.536us             1  
                                        aten::transpose         0.89%      50.110us         1.21%      67.952us       2.831us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      17.842us         0.32%      17.842us       0.743us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.36%      19.969us         1.53%      85.492us       5.699us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.37%      76.982us         1.37%      76.982us       3.208us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.95%     333.480us         5.95%     333.480us      22.232us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.30%      17.041us         0.30%      17.041us       5.680us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.700us         0.03%       1.700us       0.283us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.040us         0.07%       4.040us       1.347us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.83%       3.129ms        55.83%       3.129ms       3.129ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.603ms
Self CUDA time total: 3.820ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.07%     303.893us        39.93%       2.395ms       2.395ms       0.000us         0.00%       4.370ms       4.370ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.320ms       100.25%       4.320ms       4.320ms             1  
                     aten::scaled_dot_product_attention         0.41%      24.650us         3.07%     184.006us      61.335us       0.000us         0.00%       3.503ms       1.168ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.311us         2.66%     159.356us      53.119us       0.000us         0.00%       3.503ms       1.168ms             3  
                         aten::_flash_attention_forward         0.68%      40.911us         1.97%     118.205us      39.402us       3.503ms        81.28%       3.503ms       1.168ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.503ms        81.28%       3.503ms       1.168ms             3  
                                       aten::contiguous         0.15%       8.977us        31.04%       1.862ms     155.201us       0.000us         0.00%     867.581us      72.298us            12  
                                            aten::clone         0.47%      28.114us        30.89%       1.853ms     154.453us       0.000us         0.00%     867.581us      72.298us            12  
                                            aten::copy_         1.36%      81.500us        29.40%       1.764ms     146.991us     806.749us        18.72%     867.581us      72.298us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     806.749us        18.72%     806.749us      67.229us            12  
                                Activity Buffer Request        23.82%       1.429ms        23.82%       1.429ms       1.429ms      60.832us         1.41%      60.832us      60.832us             1  
                                        aten::transpose         0.82%      49.363us         1.11%      66.863us       2.786us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      17.500us         0.29%      17.500us       0.729us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.081us         1.37%      82.424us       5.495us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.26%      75.593us         1.26%      75.593us       3.150us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.60%     275.759us         4.60%     275.759us      18.384us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      15.251us         0.25%      15.251us       5.084us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.740us         0.03%       1.740us       0.290us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.680us         0.06%       3.680us       1.227us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        60.07%       3.604ms        60.07%       3.604ms       3.604ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.999ms
Self CUDA time total: 4.309ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         3.83%     232.270us        37.82%       2.296ms       2.296ms       0.000us         0.00%       4.474ms       4.474ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.423ms       100.25%       4.423ms       4.423ms             1  
                     aten::scaled_dot_product_attention         0.41%      24.850us         2.85%     172.746us      57.582us       0.000us         0.00%       3.595ms       1.198ms             3  
              aten::_scaled_dot_product_flash_attention         0.30%      18.250us         2.44%     147.896us      49.299us       0.000us         0.00%       3.595ms       1.198ms             3  
                         aten::_flash_attention_forward         0.54%      32.692us         1.77%     107.224us      35.741us       3.595ms        81.48%       3.595ms       1.198ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.595ms        81.48%       3.595ms       1.198ms             3  
                                       aten::contiguous         0.14%       8.610us        30.41%       1.846ms     153.859us       0.000us         0.00%     878.139us      73.178us            12  
                                            aten::clone         0.45%      27.368us        30.27%       1.838ms     153.142us       0.000us         0.00%     878.139us      73.178us            12  
                                            aten::copy_         1.35%      81.917us        28.83%       1.750ms     145.831us     817.083us        18.52%     878.139us      73.178us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     817.083us        18.52%     817.083us      68.090us            12  
                                Activity Buffer Request        23.72%       1.440ms        23.72%       1.440ms       1.440ms      61.056us         1.38%      61.056us      61.056us             1  
                                        aten::transpose         0.82%      50.064us         1.10%      66.792us       2.783us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      16.728us         0.28%      16.728us       0.697us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.32%      19.431us         1.31%      79.591us       5.306us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.21%      73.220us         1.21%      73.220us       3.051us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.12%     249.950us         4.12%     249.950us      16.663us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      14.270us         0.24%      14.270us       4.757us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.680us         0.03%       1.680us       0.280us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.380us         0.07%       4.380us       1.460us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        62.18%       3.775ms        62.18%       3.775ms       3.775ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.071ms
Self CUDA time total: 4.413ms
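
Each profile trace above reports two GPU kernels: the FlashAttention forward kernel and an elementwise copy kernel that appears to come from the `.contiguous()` calls around it (12 copies over 3 profiled calls). As a rough sketch, the share of kernel time spent on those copies can be recomputed from the Self CUDA figures copied out of the tables. The `KERNEL_US` dictionary and the `copy_share` helper are illustrative names, not part of the benchmark tooling.

```python
# Self CUDA time per workload, copied from the profile traces above
# (microseconds, summed over the 3 profiled calls).
# Layout: seq_len -> (flash_fwd_kernel, elementwise copy kernel)
KERNEL_US = {
    128: (2820.0, 712.672),
    256: (2967.0, 764.892),
    320: (2949.0, 779.480),
    384: (3032.0, 788.284),
    448: (3503.0, 806.749),
    512: (3595.0, 817.083),
}


def copy_share(flash_us: float, copy_us: float) -> float:
    """Fraction of the two kernels' combined time spent in layout copies."""
    return copy_us / (flash_us + copy_us)


for seq_len, (flash_us, copy_us) in KERNEL_US.items():
    share = copy_share(flash_us, copy_us)
    print(f"L{seq_len}: copies = {share:.1%} of kernel time")
```

On these numbers the copies account for roughly 18 to 21% of kernel time at every sequence length, in line with the Self CUDA % column.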


impl            workload                 p50 (ms)  ok
torch_flash_ma  cuda_attn_L128_bfloat16      1.22  True
torch_flash_ma  cuda_attn_L256_bfloat16      1.27  True
torch_flash_ma  cuda_attn_L320_bfloat16      1.28  True
torch_flash_ma  cuda_attn_L384_bfloat16      1.31  True
torch_flash_ma  cuda_attn_L448_bfloat16      1.47  True
torch_flash_ma  cuda_attn_L512_bfloat16      1.50  True
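
To make the scaling in the summary table easier to read, the rows can be parsed and compared against the shortest workload. This is a minimal sketch: the rows are pasted in as a string, the regular expression assumes the `cuda_attn_L<seq>_bfloat16` naming seen above, and `parse_summary` is an illustrative helper rather than something the benchmark tools provide.

```python
import re

# Summary rows copied from the benchmark output above.
SUMMARY = """\
torch_flash_ma  cuda_attn_L128_bfloat16  1.22  True
torch_flash_ma  cuda_attn_L256_bfloat16  1.27  True
torch_flash_ma  cuda_attn_L320_bfloat16  1.28  True
torch_flash_ma  cuda_attn_L384_bfloat16  1.31  True
torch_flash_ma  cuda_attn_L448_bfloat16  1.47  True
torch_flash_ma  cuda_attn_L512_bfloat16  1.50  True
"""

# impl name, sequence length, p50 in ms, ok flag
ROW = re.compile(r"^(\S+)\s+cuda_attn_L(\d+)_bfloat16\s+([\d.]+)\s+(True|False)$")


def parse_summary(text: str) -> dict[int, float]:
    """Map sequence length -> p50 latency in milliseconds."""
    result = {}
    for line in text.splitlines():
        match = ROW.match(line.strip())
        if match:
            result[int(match.group(2))] = float(match.group(3))
    return result


p50 = parse_summary(SUMMARY)
base_len = min(p50)
for seq_len, ms in sorted(p50.items()):
    ratio = ms / p50[base_len]
    print(f"L{seq_len}: {ms:.2f} ms, {ratio:.2f}x the L{base_len} latency")
```

Going from L128 to L512 quadruples the sequence length while the p50 latency rises by about 23%.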

Artifacts:

attention.jsonl
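
The `attention.jsonl` artifact is presumably a JSON Lines file with one record per line. Its schema is not shown here, so the sketch below makes no assumption about field names: it loads each record and lists the keys it finds. The `load_jsonl` helper is an illustrative name.

```python
import json
from pathlib import Path


def load_jsonl(path):
    """Read a JSON Lines file into a list of records, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


if __name__ == "__main__":
    # Assumes the artifact sits in the current working directory.
    artifact = Path("attention.jsonl")
    if artifact.exists():
        for record in load_jsonl(artifact):
            # Print the field names of each record (assumes each record is an object).
            print(sorted(record))
```

Inspecting the keys first is a cautious way to learn the record layout before writing any analysis against it.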