Flash Attention Implementation

GPU Info

import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 19:41:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   31C    P0            107W /  350W |       0MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
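The cell above captures nvidia-smi's human-readable table. If the GPU details should also be stored in a machine-readable form next to benchmark results, one option is nvidia-smi's query mode. This is a sketch, not something the benchmark tooling is known to do. `--query-gpu` selects fields, and `--format=csv,noheader,nounits` prints one comma-separated row per GPU, with memory in MiB. The helper names `parse_gpu_csv` and `query_gpus` are hypothetical.

```python
import subprocess

# Standard nvidia-smi --query-gpu field names (hypothetical choice of fields).
FIELDS = ("name", "driver_version", "memory.total")


def parse_gpu_csv(text):
    """Parse `--format=csv,noheader,nounits` output into one dict per GPU.

    Assumes no field contains a comma; memory.total is reported in MiB.
    """
    gpus = []
    for line in text.strip().splitlines():
        name, driver, mem_mib = (field.strip() for field in line.split(","))
        gpus.append({"name": name, "driver_version": driver,
                     "memory_total_mib": int(mem_mib)})
    return gpus


def query_gpus():
    """Run nvidia-smi once and return parsed rows (needs an NVIDIA driver)."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={','.join(FIELDS)}",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_csv(out)
```

For the L40S shown above, the query row would presumably read `NVIDIA L40S, 580.105.08, 46068`. `parse_gpu_csv` turns that row into `[{'name': 'NVIDIA L40S', 'driver_version': '580.105.08', 'memory_total_mib': 46068}]`.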

Flash Attention Benchmark
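The benchmark cell pins PyTorch's SDPA to its FlashAttention backend. FlashAttention computes exact softmax attention, but it streams over key/value blocks. It keeps a running row maximum and a running softmax denominator, and it rescales the partial output whenever a new block raises the maximum.

The following minimal pure-Python sketch shows that online-softmax recurrence for a single query row and checks it against a direct softmax. It is an illustration only: the function names are hypothetical, and it models neither query tiling nor the GPU memory hierarchy used by the real kernel.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) weighted sum of values, one query row."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def online_attention(q, keys, values, block_size=2):
    """Same result, but keys/values are visited one block at a time.

    Keeps a running max `m`, a running softmax denominator `l`, and an
    unnormalized output `acc`. Earlier partial sums are rescaled by
    exp(m_old - m_new) whenever a block raises the max.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # 0.0 on the first block (m = -inf)
        l *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]
```

Any `block_size` gives the same output as `naive_attention`, up to floating-point rounding. Because of this, the full score row never has to be materialized at once.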

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
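import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Inputs are presumably (batch, seq_len, heads, head_dim); SDPA expects
    # (batch, heads, seq_len, head_dim), so swap dims 1 and 2.
    # Each .contiguous() materializes a copy (3 inputs + 1 output = 4 per call);
    # these appear as aten::copy_ / elementwise_kernel in the traces below.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA to the FlashAttention backend; if the inputs are not
    # supported it raises an error instead of silently using another backend.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the caller's (batch, seq_len, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()


# NOTE: impl_tags records compile="max-autotune", but torch_flash is passed
# to run_benchmark as-is; no torch.compile call appears in this cell.
run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)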
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.560ms       101.41%       3.560ms       3.560ms             1  
                                         torch_flash_ma         6.12%     330.406us        49.12%       2.651ms       2.651ms       0.000us         0.00%       3.550ms       3.550ms             1  
                     aten::scaled_dot_product_attention         0.76%      41.091us         4.12%     222.225us      74.075us       0.000us         0.00%       2.785ms     928.191us             3  
              aten::_scaled_dot_product_flash_attention         0.57%      30.902us         3.36%     181.134us      60.378us       0.000us         0.00%       2.785ms     928.191us             3  
                         aten::_flash_attention_forward         0.74%      39.881us         2.41%     130.323us      43.441us       2.785ms        79.34%       2.785ms     928.191us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.785ms        79.34%       2.785ms     928.191us             3  
                                       aten::contiguous         0.24%      12.809us        37.68%       2.033ms     169.455us       0.000us         0.00%     765.791us      63.816us            12  
                                            aten::clone         0.64%      34.521us        37.44%       2.021ms     168.387us       0.000us         0.00%     765.791us      63.816us            12  
                                            aten::copy_         1.67%      90.094us        35.26%       1.903ms     158.570us     725.311us        20.66%     765.791us      63.816us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     725.311us        20.66%     725.311us      60.443us            12  
                                Activity Buffer Request        31.66%       1.709ms        31.66%       1.709ms       1.709ms      40.480us         1.15%      40.480us      40.480us             1  
                                        aten::transpose         1.17%      63.269us         1.58%      85.140us       3.548us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.41%      21.871us         0.41%      21.871us       0.911us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.47%      25.421us         1.97%     106.322us       7.088us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.76%      94.971us         1.76%      94.971us       3.957us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.37%     128.144us         2.37%     128.144us       8.543us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.32%      17.100us         0.32%      17.100us       5.700us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.290us         0.04%       2.290us       0.382us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.18%       9.631us         0.18%       9.631us       3.210us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        50.88%       2.746ms        50.88%       2.746ms       2.746ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.397ms
Self CUDA time total: 3.510ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.59%     254.063us        44.59%       2.468ms       2.468ms       0.000us         0.00%       3.765ms       3.765ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.721ms       100.30%       3.721ms       3.721ms             1  
                     aten::scaled_dot_product_attention         0.43%      23.691us         3.30%     182.385us      60.795us       0.000us         0.00%       2.950ms     983.280us             3  
              aten::_scaled_dot_product_flash_attention         0.32%      17.969us         2.87%     158.694us      52.898us       0.000us         0.00%       2.950ms     983.280us             3  
                         aten::_flash_attention_forward         0.74%      40.930us         2.17%     120.223us      40.074us       2.950ms        79.52%       2.950ms     983.280us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.950ms        79.52%       2.950ms     983.280us             3  
                                       aten::contiguous         0.16%       8.922us        35.94%       1.989ms     165.775us       0.000us         0.00%     815.354us      67.946us            12  
                                            aten::clone         0.46%      25.650us        35.78%       1.980ms     165.031us       0.000us         0.00%     815.354us      67.946us            12  
                                            aten::copy_         1.41%      78.081us        34.18%       1.891ms     157.619us     759.770us        20.48%     815.354us      67.946us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     759.770us        20.48%     759.770us      63.314us            12  
                                Activity Buffer Request        31.33%       1.734ms        31.33%       1.734ms       1.734ms      55.584us         1.50%      55.584us      55.584us             1  
                                        aten::transpose         0.84%      46.272us         1.13%      62.592us       2.608us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      16.320us         0.29%      16.320us       0.680us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.39%      21.392us         1.49%      82.711us       5.514us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.42%      78.721us         1.42%      78.721us       3.280us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.84%     101.714us         1.84%     101.714us       6.781us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      13.930us         0.25%      13.930us       4.643us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.700us         0.03%       1.700us       0.283us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.08%       4.360us         0.08%       4.360us       1.453us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.41%       3.067ms        55.41%       3.067ms       3.067ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.534ms
Self CUDA time total: 3.710ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.62%     254.756us        44.14%       2.433ms       2.433ms       0.000us         0.00%       3.774ms       3.774ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.727ms       100.29%       3.727ms       3.727ms             1  
                     aten::scaled_dot_product_attention         0.43%      23.830us         3.33%     183.454us      61.151us       0.000us         0.00%       2.942ms     980.796us             3  
              aten::_scaled_dot_product_flash_attention         0.32%      17.891us         2.90%     159.624us      53.208us       0.000us         0.00%       2.942ms     980.796us             3  
                         aten::_flash_attention_forward         0.73%      40.074us         2.20%     121.152us      40.384us       2.942ms        79.17%       2.942ms     980.796us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.942ms        79.17%       2.942ms     980.796us             3  
                                       aten::contiguous         0.16%       8.718us        35.43%       1.953ms     162.745us       0.000us         0.00%     831.581us      69.298us            12  
                                            aten::clone         0.47%      25.749us        35.27%       1.944ms     162.019us       0.000us         0.00%     831.581us      69.298us            12  
                                            aten::copy_         1.40%      77.041us        33.64%       1.855ms     154.552us     774.142us        20.83%     831.581us      69.298us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     774.142us        20.83%     774.142us      64.512us            12  
                                Activity Buffer Request        30.83%       1.700ms        30.83%       1.700ms       1.700ms      57.439us         1.55%      57.439us      57.439us             1  
                                        aten::transpose         0.84%      46.360us         1.13%      62.482us       2.603us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      16.122us         0.29%      16.122us       0.672us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.36%      19.611us         1.53%      84.374us       5.625us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.44%      79.561us         1.44%      79.561us       3.315us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.87%     102.913us         1.87%     102.913us       6.861us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      15.330us         0.28%      15.330us       5.110us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.680us         0.03%       1.680us       0.280us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.840us         0.07%       3.840us       1.280us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.86%       3.080ms        55.86%       3.080ms       3.080ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.513ms
Self CUDA time total: 3.717ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.28%     249.055us        45.91%       2.672ms       2.672ms       0.000us         0.00%       3.870ms       3.870ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.822ms       100.28%       3.822ms       3.822ms             1  
                     aten::scaled_dot_product_attention         0.44%      25.342us         3.23%     187.955us      62.652us       0.000us         0.00%       3.022ms       1.007ms             3  
              aten::_scaled_dot_product_flash_attention         0.30%      17.701us         2.79%     162.613us      54.204us       0.000us         0.00%       3.022ms       1.007ms             3  
                         aten::_flash_attention_forward         0.71%      41.280us         2.11%     122.541us      40.847us       3.022ms        79.29%       3.022ms       1.007ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.022ms        79.29%       3.022ms       1.007ms             3  
                                       aten::contiguous         0.16%       9.081us        37.65%       2.191ms     182.597us       0.000us         0.00%     847.483us      70.624us            12  
                                            aten::clone         0.47%      27.546us        37.50%       2.182ms     181.840us       0.000us         0.00%     847.483us      70.624us            12  
                                            aten::copy_         1.40%      81.736us        35.91%       2.090ms     174.156us     789.211us        20.71%     847.483us      70.624us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     789.211us        20.71%     789.211us      65.768us            12  
                                Activity Buffer Request        29.46%       1.714ms        29.46%       1.714ms       1.714ms      58.272us         1.53%      58.272us      58.272us             1  
                                        aten::transpose         0.83%      48.521us         1.13%      65.981us       2.749us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.30%      17.460us         0.30%      17.460us       0.727us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      20.461us         1.45%      84.343us       5.623us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.38%      80.070us         1.38%      80.070us       3.336us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.47%     318.217us         5.47%     318.217us      21.214us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      14.521us         0.25%      14.521us       4.840us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.689us         0.03%       1.689us       0.282us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.08%       4.671us         0.08%       4.671us       1.557us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        54.09%       3.147ms        54.09%       3.147ms       3.147ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.819ms
Self CUDA time total: 3.811ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.79%     300.628us        43.01%       2.699ms       2.699ms       0.000us         0.00%       4.340ms       4.340ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.290ms       100.25%       4.290ms       4.290ms             1  
                     aten::scaled_dot_product_attention         0.40%      25.381us         2.96%     185.704us      61.901us       0.000us         0.00%       3.474ms       1.158ms             3  
              aten::_scaled_dot_product_flash_attention         0.28%      17.780us         2.55%     160.323us      53.441us       0.000us         0.00%       3.474ms       1.158ms             3  
                         aten::_flash_attention_forward         0.64%      40.370us         1.93%     121.223us      40.408us       3.474ms        81.17%       3.474ms       1.158ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.474ms        81.17%       3.474ms       1.158ms             3  
                                       aten::contiguous         0.14%       9.022us        34.56%       2.169ms     180.719us       0.000us         0.00%     866.336us      72.195us            12  
                                            aten::clone         0.44%      27.858us        34.41%       2.160ms     179.967us       0.000us         0.00%     866.336us      72.195us            12  
                                            aten::copy_         1.24%      77.719us        32.91%       2.066ms     172.130us     806.048us        18.83%     866.336us      72.195us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     806.048us        18.83%     806.048us      67.171us            12  
                                Activity Buffer Request        27.70%       1.738ms        27.70%       1.738ms       1.738ms      60.288us         1.41%      60.288us      60.288us             1  
                                        aten::transpose         0.77%      48.240us         1.05%      65.650us       2.735us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      17.410us         0.28%      17.410us       0.725us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      21.363us         1.38%      86.453us       5.764us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.28%      80.561us         1.28%      80.561us       3.357us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.36%     273.888us         4.36%     273.888us      18.259us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      14.900us         0.24%      14.900us       4.967us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.700us         0.03%       1.700us       0.283us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.100us         0.07%       4.100us       1.367us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        56.99%       3.576ms        56.99%       3.576ms       3.576ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.275ms
Self CUDA time total: 4.280ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.01%     253.526us        41.16%       2.602ms       2.602ms       0.000us         0.00%       4.429ms       4.429ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.378ms       100.24%       4.378ms       4.378ms             1  
                     aten::scaled_dot_product_attention         0.38%      23.889us         2.89%     182.483us      60.828us       0.000us         0.00%       3.556ms       1.185ms             3  
              aten::_scaled_dot_product_flash_attention         0.27%      17.360us         2.51%     158.594us      52.865us       0.000us         0.00%       3.556ms       1.185ms             3  
                         aten::_flash_attention_forward         0.66%      42.013us         1.90%     120.422us      40.141us       3.556ms        81.42%       3.556ms       1.185ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.556ms        81.42%       3.556ms       1.185ms             3  
                                       aten::contiguous         0.14%       8.630us        33.58%       2.122ms     176.875us       0.000us         0.00%     872.667us      72.722us            12  
                                            aten::clone         0.41%      26.047us        33.44%       2.114ms     176.156us       0.000us         0.00%     872.667us      72.722us            12  
                                            aten::copy_         1.25%      79.082us        32.00%       2.023ms     168.597us     811.483us        18.58%     872.667us      72.722us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     811.483us        18.58%     811.483us      67.624us            12  
                                Activity Buffer Request        26.87%       1.699ms        26.87%       1.699ms       1.699ms      61.184us         1.40%      61.184us      61.184us             1  
                                        aten::transpose         0.75%      47.653us         1.02%      64.533us       2.689us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.27%      16.880us         0.27%      16.880us       0.703us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.879us         1.34%      84.642us       5.643us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.25%      79.031us         1.25%      79.031us       3.293us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.24%     268.168us         4.24%     268.168us      17.878us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.23%      14.592us         0.23%      14.592us       4.864us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.679us         0.03%       1.679us       0.280us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.920us         0.06%       3.920us       1.307us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.84%       3.719ms        58.84%       3.719ms       3.719ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.322ms
Self CUDA time total: 4.367ms


impl            wl                       p50(ms)  ok
torch_flash_ma  cuda_attn_L128_bfloat16     1.21  True
torch_flash_ma  cuda_attn_L256_bfloat16     1.25  True
torch_flash_ma  cuda_attn_L320_bfloat16     1.28  True
torch_flash_ma  cuda_attn_L384_bfloat16     1.31  True
torch_flash_ma  cuda_attn_L448_bfloat16     1.45  True
torch_flash_ma  cuda_attn_L512_bfloat16     1.49  True

Artifacts:

attention.jsonl
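The summary table can be roughly cross-checked against the profile traces. The calculation rests on two assumptions: each trace covers three calls to torch_flash, as the 3 calls to aten::scaled_dot_product_attention suggest, and p50 is the median per-call latency.

The check adds the flash_fwd_kernel and elementwise (copy) kernel Self CUDA times, divides by 3, and compares the result with p50. The numbers are copied by hand from the traces above. The names `TRACES`, `PROFILED_CALLS` and `per_call_ms` are illustrative.

```python
# Hand-copied from the profile traces (Self CUDA, microseconds, summed over
# the profiled calls) and the summary table (p50, milliseconds).
PROFILED_CALLS = 3  # "# of Calls" for aten::scaled_dot_product_attention

TRACES = {
    # workload: (flash_fwd_kernel us, elementwise copy kernel us, p50 ms)
    "cuda_attn_L128_bfloat16": (2785.0, 725.311, 1.21),
    "cuda_attn_L256_bfloat16": (2950.0, 759.770, 1.25),
    "cuda_attn_L320_bfloat16": (2942.0, 774.142, 1.28),
    "cuda_attn_L384_bfloat16": (3022.0, 789.211, 1.31),
    "cuda_attn_L448_bfloat16": (3474.0, 806.048, 1.45),
    "cuda_attn_L512_bfloat16": (3556.0, 811.483, 1.49),
}


def per_call_ms(flash_us, copy_us, calls=PROFILED_CALLS):
    """GPU kernel time per torch_flash call, in milliseconds."""
    return (flash_us + copy_us) / calls / 1000.0


for wl, (flash_us, copy_us, p50) in TRACES.items():
    est = per_call_ms(flash_us, copy_us)
    copy_share = copy_us / (flash_us + copy_us)
    print(f"{wl}: kernels {est:.3f} ms/call vs p50 {p50:.2f} ms "
          f"(ratio {est / p50:.2f}), copies {copy_share:.1%} of kernel time")
```

With these numbers the ratio lands between about 0.97 and 0.99. The flash kernel plus the layout copies therefore account for most of the reported p50. The `.contiguous()` copies alone take roughly 19 to 21% of kernel time, falling from 20.7% at L128 to 18.6% at L512.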