Flash Attention Implementation

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.21s | Raw GitHub
import subprocess

# Record the host GPU state (driver, CUDA version, utilization) so the
# hardware context is captured alongside the benchmark results.
smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
print(smi.stdout)
Thu Oct 23 17:22:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   37C    P0             91W /  350W |       0MiB /  46068MiB |     26%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Flash Attention Benchmark

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 3.60s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt


def torch_flash(q, k, v):
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()

# Register the implementation with the benchmark harness under the name
# used in the profile traces and summary table below.
# NOTE(review): nothing here applies torch.compile, so the
# "compile": "max-autotune" tag looks stale/misleading — confirm intent.
kbt.add(
    "torch_flash_ma",
    torch_flash,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    # Choose hardware-appropriate settings: full-size bf16 workloads on
    # GPU, scaled-down fp32 workloads when only a CPU is available.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = "bfloat16" if use_cuda else "float32"

    # Flux-like workloads scaled down for CPU testing.
    base = 1024 if use_cuda else 512
    flux_sizes = [128, 256, 320, 384, 448, 512] if use_cuda else [64, 128, 192, 256]
    heads = 24 if use_cuda else 8
    head_dim = 128 if use_cuda else 64

    # One workload per extra sequence length; seed fixed for reproducibility.
    wl = [
        {
            "name": f"flux_L{L}",
            "batch": 1,
            "seq_len": base + L,
            "heads": heads,
            "head_dim": head_dim,
            "dtype": dtype,
            "device": device,
            "seed": 0,
        }
        for L in flux_sizes
    ]

    # Run every registered impl over the workloads, logging raw timings to
    # JSONL and emitting a profiler trace per (impl, workload) pair.
    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
        profile_trace=True,
    )
    kbt.summarize(["attn.jsonl"])
======================================================================
PROFILE TRACE: torch_flash_ma | flux_L128
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     799.070us       225.43%     799.070us     799.070us             1  
                                         torch_flash_ma        14.65%     361.148us        99.74%       2.458ms       2.458ms       0.000us         0.00%     362.241us     362.241us             1  
                     aten::scaled_dot_product_attention         1.75%      43.042us         9.34%     230.141us      76.714us       0.000us         0.00%     266.207us      88.736us             3  
              aten::_scaled_dot_product_flash_attention         1.09%      26.961us         7.59%     187.099us      62.366us       0.000us         0.00%     266.207us      88.736us             3  
                         aten::_flash_attention_forward         1.68%      41.361us         5.54%     136.527us      45.509us     266.207us        75.10%     266.207us      88.736us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     266.207us        75.10%     266.207us      88.736us             3  
                                       aten::contiguous         0.64%      15.860us        72.86%       1.796ms     149.661us       0.000us         0.00%      96.034us       8.003us            12  
                                            aten::clone         1.71%      42.134us        72.21%       1.780ms     148.339us       0.000us         0.00%      96.034us       8.003us            12  
                                            aten::copy_         3.86%      95.153us        66.84%       1.648ms     137.298us      88.258us        24.90%      96.034us       8.003us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      88.258us        24.90%      88.258us       7.355us            12  
                                Activity Buffer Request        58.01%       1.430ms        58.01%       1.430ms       1.430ms       7.776us         2.19%       7.776us       7.776us             1  
                                        aten::transpose         2.95%      72.712us         3.85%      94.884us       3.954us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.90%      22.172us         0.90%      22.172us       0.924us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         1.13%      27.832us         4.55%     112.245us       7.483us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         4.09%     100.886us         4.09%     100.886us       4.204us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.96%     146.998us         5.96%     146.998us       9.800us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.65%      15.960us         0.65%      15.960us       5.320us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.12%       2.850us         0.12%       2.850us       0.475us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.54%      13.411us         0.54%      13.411us       4.470us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.26%       6.530us         0.26%       6.530us       6.530us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.465ms
Self CUDA time total: 354.465us



======================================================================
PROFILE TRACE: torch_flash_ma | flux_L256
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     680.541us       161.63%     680.541us     680.541us             1  
                                         torch_flash_ma        11.51%     254.710us        99.74%       2.208ms       2.208ms       0.000us         0.00%     430.783us     430.783us             1  
                     aten::scaled_dot_product_attention         1.09%      24.080us         8.33%     184.408us      61.469us       0.000us         0.00%     312.064us     104.021us             3  
              aten::_scaled_dot_product_flash_attention         0.81%      17.821us         7.24%     160.328us      53.443us       0.000us         0.00%     312.064us     104.021us             3  
                         aten::_flash_attention_forward         1.85%      41.011us         5.37%     118.956us      39.652us     312.064us        74.11%     312.064us     104.021us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     312.064us        74.11%     312.064us     104.021us             3  
                                       aten::contiguous         0.42%       9.258us        77.80%       1.722ms     143.509us       0.000us         0.00%     118.719us       9.893us            12  
                                            aten::clone         1.32%      29.284us        77.38%       1.713ms     142.737us       0.000us         0.00%     118.719us       9.893us            12  
                                            aten::copy_         3.64%      80.568us        73.02%       1.616ms     134.703us     108.991us        25.89%     118.719us       9.893us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     108.991us        25.89%     108.991us       9.083us            12  
                                Activity Buffer Request        65.56%       1.451ms        65.56%       1.451ms       1.451ms       9.728us         2.31%       9.728us       9.728us             1  
                                        aten::transpose         2.36%      52.224us         3.17%      70.126us       2.922us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.81%      17.902us         0.81%      17.902us       0.746us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.96%      21.191us         3.98%      88.123us       5.875us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         3.58%      79.273us         3.58%      79.273us       3.303us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.85%     107.363us         4.85%     107.363us       7.158us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.70%      15.410us         0.70%      15.410us       5.137us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.09%       2.071us         0.09%       2.071us       0.345us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.20%       4.321us         0.20%       4.321us       1.440us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.26%       5.841us         0.26%       5.841us       5.841us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.214ms
Self CUDA time total: 421.055us



======================================================================
PROFILE TRACE: torch_flash_ma | flux_L320
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     690.203us       159.06%     690.203us     690.203us             1  
                                         torch_flash_ma        11.42%     254.276us        99.18%       2.209ms       2.209ms       0.000us         0.00%     443.100us     443.100us             1  
                     aten::scaled_dot_product_attention         1.09%      24.201us         8.13%     181.079us      60.360us       0.000us         0.00%     330.557us     110.186us             3  
              aten::_scaled_dot_product_flash_attention         0.78%      17.350us         7.04%     156.878us      52.293us       0.000us         0.00%     330.557us     110.186us             3  
                         aten::_flash_attention_forward         1.80%      40.093us         5.30%     118.035us      39.345us     330.557us        76.18%     330.557us     110.186us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     330.557us        76.18%     330.557us     110.186us             3  
                                       aten::contiguous         0.42%       9.369us        77.58%       1.728ms     143.991us       0.000us         0.00%     112.543us       9.379us            12  
                                            aten::clone         1.34%      29.740us        77.16%       1.719ms     143.210us       0.000us         0.00%     112.543us       9.379us            12  
                                            aten::copy_         3.81%      84.905us        72.90%       1.624ms     135.305us     103.359us        23.82%     112.543us       9.379us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     103.359us        23.82%     103.359us       8.613us            12  
                                Activity Buffer Request        65.38%       1.456ms        65.38%       1.456ms       1.456ms       9.184us         2.12%       9.184us       9.184us             1  
                                        aten::transpose         2.26%      50.400us         3.02%      67.214us       2.801us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.75%      16.814us         0.75%      16.814us       0.701us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.96%      21.489us         3.82%      85.044us       5.670us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         3.43%      76.464us         3.43%      76.464us       3.186us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.82%     107.405us         4.82%     107.405us       7.160us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.66%      14.631us         0.66%      14.631us       4.877us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.08%       1.710us         0.08%       1.710us       0.285us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.18%       3.930us         0.18%       3.930us       1.310us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.82%      18.331us         0.82%      18.331us      18.331us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.227ms
Self CUDA time total: 433.916us



======================================================================
PROFILE TRACE: torch_flash_ma | flux_L384
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     691.645us       147.68%     691.645us     691.645us             1  
                                         torch_flash_ma        10.40%     252.243us        99.18%       2.405ms       2.405ms       0.000us         0.00%     481.117us     481.117us             1  
                     aten::scaled_dot_product_attention         1.00%      24.352us         7.27%     176.289us      58.763us       0.000us         0.00%     341.277us     113.759us             3  
              aten::_scaled_dot_product_flash_attention         0.73%      17.811us         6.27%     151.937us      50.646us       0.000us         0.00%     341.277us     113.759us             3  
                         aten::_flash_attention_forward         1.38%      33.540us         4.54%     110.186us      36.729us     341.277us        72.87%     341.277us     113.759us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     341.277us        72.87%     341.277us     113.759us             3  
                                       aten::contiguous         0.39%       9.522us        79.59%       1.930ms     160.818us       0.000us         0.00%     139.840us      11.653us            12  
                                            aten::clone         1.25%      30.240us        79.20%       1.920ms     160.024us       0.000us         0.00%     139.840us      11.653us            12  
                                            aten::copy_         3.35%      81.274us        75.28%       1.825ms     152.111us     127.072us        27.13%     139.840us      11.653us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     127.072us        27.13%     127.072us      10.589us            12  
                                Activity Buffer Request        59.91%       1.453ms        59.91%       1.453ms       1.453ms      12.768us         2.73%      12.768us      12.768us             1  
                                        aten::transpose         2.18%      52.871us         2.90%      70.271us       2.928us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.72%      17.400us         0.72%      17.400us       0.725us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.83%      20.083us         3.47%      84.148us       5.610us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         3.18%      77.125us         3.18%      77.125us       3.214us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel        13.00%     315.205us        13.00%     315.205us      21.014us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.61%      14.781us         0.61%      14.781us       4.927us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.07%       1.670us         0.07%       1.670us       0.278us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.16%       3.970us         0.16%       3.970us       1.323us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         0.82%      19.911us         0.82%      19.911us      19.911us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.425ms
Self CUDA time total: 468.349us



======================================================================
PROFILE TRACE: torch_flash_ma | flux_L448
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     799.966us       130.76%     799.966us     799.966us             1  
                                         torch_flash_ma        12.25%     304.685us        97.28%       2.419ms       2.419ms       0.000us         0.00%     624.638us     624.638us             1  
                     aten::scaled_dot_product_attention         0.97%      24.122us         7.38%     183.559us      61.186us       0.000us         0.00%     485.886us     161.962us             3  
              aten::_scaled_dot_product_flash_attention         0.71%      17.700us         6.41%     159.437us      53.146us       0.000us         0.00%     485.886us     161.962us             3  
                         aten::_flash_attention_forward         1.59%      39.459us         4.74%     117.796us      39.265us     485.886us        79.42%     485.886us     161.962us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     485.886us        79.42%     485.886us     161.962us             3  
                                       aten::contiguous         0.39%       9.743us        75.79%       1.885ms     157.075us       0.000us         0.00%     138.752us      11.563us            12  
                                            aten::clone         1.21%      30.098us        75.40%       1.875ms     156.263us       0.000us         0.00%     138.752us      11.563us            12  
                                            aten::copy_         3.39%      84.237us        71.41%       1.776ms     147.998us     125.888us        20.58%     138.752us      11.563us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     125.888us        20.58%     125.888us      10.491us            12  
                                Activity Buffer Request        58.51%       1.455ms        58.51%       1.455ms       1.455ms      12.864us         2.10%      12.864us      12.864us             1  
                                        aten::transpose         2.11%      52.456us         2.81%      69.984us       2.916us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.70%      17.528us         0.70%      17.528us       0.730us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.83%      20.690us         3.57%      88.794us       5.920us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         3.29%      81.917us         3.29%      81.917us       3.413us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel        10.48%     260.751us        10.48%     260.751us      17.383us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.58%      14.540us         0.58%      14.540us       4.847us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.09%       2.170us         0.09%       2.170us       0.362us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.16%       3.911us         0.16%       3.911us       1.304us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         2.72%      67.754us         2.72%      67.754us      67.754us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.487ms
Self CUDA time total: 611.774us



======================================================================
PROFILE TRACE: torch_flash_ma | flux_L512
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us     754.076us       118.52%     754.076us     754.076us             1  
                                         torch_flash_ma        10.33%     251.863us        96.72%       2.358ms       2.358ms       0.000us         0.00%     647.964us     647.964us             1  
                     aten::scaled_dot_product_attention         1.02%      24.850us         7.50%     182.789us      60.930us       0.000us         0.00%     507.517us     169.172us             3  
              aten::_scaled_dot_product_flash_attention         0.72%      17.614us         6.48%     157.939us      52.646us       0.000us         0.00%     507.517us     169.172us             3  
                         aten::_flash_attention_forward         1.67%      40.594us         4.82%     117.465us      39.155us     507.517us        79.77%     507.517us     169.172us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     507.517us        79.77%     507.517us     169.172us             3  
                                       aten::contiguous         0.38%       9.202us        77.00%       1.877ms     156.434us       0.000us         0.00%     140.447us      11.704us            12  
                                            aten::clone         1.22%      29.851us        76.63%       1.868ms     155.667us       0.000us         0.00%     140.447us      11.704us            12  
                                            aten::copy_         3.45%      84.032us        72.63%       1.771ms     147.547us     128.703us        20.23%     140.447us      11.704us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     128.703us        20.23%     128.703us      10.725us            12  
                                Activity Buffer Request        59.63%       1.454ms        59.63%       1.454ms       1.454ms      11.744us         1.85%      11.744us      11.744us             1  
                                        aten::transpose         2.09%      51.002us         2.82%      68.782us       2.866us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.73%      17.780us         0.73%      17.780us       0.741us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.85%      20.819us         3.58%      87.161us       5.811us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         3.27%      79.813us         3.27%      79.813us       3.326us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel        10.50%     256.026us        10.50%     256.026us      17.068us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.59%      14.340us         0.59%      14.340us       4.780us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.08%       1.949us         0.08%       1.949us       0.325us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.18%       4.440us         0.18%       4.440us       1.480us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize         3.28%      80.003us         3.28%      80.003us      80.003us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.438ms
Self CUDA time total: 636.220us


impl                     wl                  p50(ms)  ok
torch_flash_ma           flux_L128              0.18  True
torch_flash_ma           flux_L256              0.21  True
torch_flash_ma           flux_L320              0.22  True
torch_flash_ma           flux_L384              0.22  True
torch_flash_ma           flux_L448              0.27  True
torch_flash_ma           flux_L512              0.28  True

Artifacts:

attn.jsonl