# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    # Inputs arrive as (batch, seq_len, heads, head_dim); SDPA expects
    # (batch, heads, seq_len, head_dim), so transpose in and transpose back out.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()
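
# A quick layout smoke test, sketched as a comment (hypothetical, not executed
# by the benchmark; shapes mirror the flux_L128 workload defined below):
#   q = torch.randn(1, 1152, 24, 128, device="cuda", dtype=torch.bfloat16)
#   assert torch_flash_base(q, q, q).shape == q.shape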
# Compile with default mode. Note dynamic=False: Dynamo specializes the graph
# on each distinct seq_len, which is what eventually drives the recompile_limit
# failures reported in the summary below.
compiled_flash_default = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=False
)

kbt.add(
    "torch_flash_compiled_default",
    compiled_flash_default,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
)
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_default.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
        profile_trace=True,
    )
    kbt.summarize(["attn_default.jsonl"])
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L128
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 967.332us 298.12% 967.332us 967.332us 1
torch_flash_compiled_default 5.37% 154.798us 99.77% 2.878ms 2.878ms 0.000us 0.00% 324.481us 324.481us 1
Torch-Compiled Region: 0/1 20.96% 604.478us 92.49% 2.668ms 889.236us 0.000us 0.00% 324.481us 108.160us 3
aten::_scaled_dot_product_flash_attention 1.54% 44.432us 8.35% 240.853us 80.284us 0.000us 0.00% 276.257us 92.086us 3
aten::_flash_attention_forward 1.64% 47.371us 5.29% 152.657us 50.886us 276.257us 85.14% 276.257us 92.086us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 276.257us 85.14% 276.257us 92.086us 3
triton_poi_fused__scaled_dot_product_flash_attention... 3.50% 100.807us 6.04% 174.309us 19.368us 36.704us 11.31% 36.704us 4.078us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 36.704us 11.31% 36.704us 4.078us 9
triton_poi_fused_clone_1 1.27% 36.672us 2.17% 62.583us 20.861us 11.520us 3.55% 11.520us 3.840us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 11.520us 3.55% 11.520us 3.840us 3
TorchDynamo Cache Lookup 1.91% 55.093us 1.91% 55.093us 18.364us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.36% 10.400us 0.36% 10.400us 3.467us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.70% 20.280us 0.70% 20.280us 6.760us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 53.91% 1.555ms 53.91% 1.555ms 1.555ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 3.45% 99.413us 3.45% 99.413us 8.284us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 1.19% 34.395us 1.52% 43.764us 3.647us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.32% 9.369us 0.32% 9.369us 0.781us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.44% 12.621us 1.20% 34.732us 11.577us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.77% 22.111us 0.77% 22.111us 7.370us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.24% 35.841us 1.24% 35.841us 2.987us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.884ms
Self CUDA time total: 324.481us
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L256
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 834.378us 233.60% 834.378us 834.378us 1
torch_flash_compiled_default 4.04% 97.294us 99.68% 2.400ms 2.400ms 0.000us 0.00% 357.190us 357.190us 1
Torch-Compiled Region: 0/3 19.97% 480.803us 94.43% 2.274ms 757.987us 0.000us 0.00% 357.190us 119.063us 3
aten::_scaled_dot_product_flash_attention 1.08% 25.983us 7.33% 176.640us 58.880us 0.000us 0.00% 300.165us 100.055us 3
aten::_flash_attention_forward 1.50% 36.164us 5.01% 120.717us 40.239us 300.165us 84.04% 300.165us 100.055us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 300.165us 84.04% 300.165us 100.055us 3
triton_poi_fused__scaled_dot_product_flash_attention... 3.30% 79.496us 6.27% 150.937us 16.771us 40.161us 11.24% 40.161us 4.462us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 40.161us 11.24% 40.161us 4.462us 9
triton_poi_fused_clone_1 2.33% 56.123us 3.38% 81.404us 27.135us 16.864us 4.72% 16.864us 5.621us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 16.864us 4.72% 16.864us 5.621us 3
TorchDynamo Cache Lookup 1.21% 29.133us 1.21% 29.133us 9.711us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.32% 7.730us 0.32% 7.730us 2.577us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.49% 11.750us 0.49% 11.750us 3.917us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 56.67% 1.365ms 56.67% 1.365ms 1.365ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 4.02% 96.722us 4.02% 96.722us 8.060us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.90% 21.580us 1.24% 29.940us 2.495us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.35% 8.360us 0.35% 8.360us 0.697us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.27% 6.480us 1.00% 23.971us 7.990us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.73% 17.491us 0.73% 17.491us 5.830us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.24% 29.800us 1.24% 29.800us 2.483us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.408ms
Self CUDA time total: 357.190us
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L320
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 876.295us 230.02% 876.295us 876.295us 1
torch_flash_compiled_default 3.99% 99.235us 99.67% 2.477ms 2.477ms 0.000us 0.00% 380.963us 380.963us 1
Torch-Compiled Region: 0/5 19.71% 489.623us 94.50% 2.348ms 782.708us 0.000us 0.00% 380.963us 126.988us 3
aten::_scaled_dot_product_flash_attention 1.15% 28.583us 7.58% 188.458us 62.819us 0.000us 0.00% 323.107us 107.702us 3
aten::_flash_attention_forward 1.61% 40.110us 5.06% 125.615us 41.872us 323.107us 84.81% 323.107us 107.702us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 323.107us 84.81% 323.107us 107.702us 3
triton_poi_fused__scaled_dot_product_flash_attention... 3.47% 86.344us 6.19% 153.807us 17.090us 44.448us 11.67% 44.448us 4.939us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 44.448us 11.67% 44.448us 4.939us 9
triton_poi_fused_clone_1 1.44% 35.902us 2.40% 59.634us 19.878us 13.408us 3.52% 13.408us 4.469us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 13.408us 3.52% 13.408us 4.469us 3
TorchDynamo Cache Lookup 1.18% 29.223us 1.18% 29.223us 9.741us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.30% 7.450us 0.30% 7.450us 2.483us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.46% 11.502us 0.46% 11.502us 3.834us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 57.86% 1.438ms 57.86% 1.438ms 1.438ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 3.67% 91.195us 3.67% 91.195us 7.600us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.95% 23.681us 1.38% 34.260us 2.855us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.43% 10.579us 0.43% 10.579us 0.882us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.27% 6.811us 0.93% 23.051us 7.684us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.65% 16.240us 0.65% 16.240us 5.413us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.30% 32.232us 1.30% 32.232us 2.686us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.485ms
Self CUDA time total: 380.963us
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L384
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 900.385us 224.95% 900.385us 900.385us 1
torch_flash_compiled_default 3.56% 101.756us 99.74% 2.848ms 2.848ms 0.000us 0.00% 400.258us 400.258us 1
Torch-Compiled Region: 0/7 18.27% 521.655us 95.19% 2.718ms 906.103us 0.000us 0.00% 400.258us 133.419us 3
aten::_scaled_dot_product_flash_attention 0.99% 28.253us 6.33% 180.729us 60.243us 0.000us 0.00% 336.352us 112.117us 3
aten::_flash_attention_forward 1.29% 36.890us 4.19% 119.565us 39.855us 336.352us 84.03% 336.352us 112.117us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 336.352us 84.03% 336.352us 112.117us 3
triton_poi_fused__scaled_dot_product_flash_attention... 3.07% 87.777us 16.12% 460.302us 51.145us 49.985us 12.49% 49.985us 5.554us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 49.985us 12.49% 49.985us 5.554us 9
triton_poi_fused_clone_1 1.24% 35.330us 2.05% 58.492us 19.497us 13.921us 3.48% 13.921us 4.640us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 13.921us 3.48% 13.921us 4.640us 3
TorchDynamo Cache Lookup 0.99% 28.213us 0.99% 28.213us 9.404us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.25% 7.170us 0.25% 7.170us 2.390us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.43% 12.361us 0.43% 12.361us 4.120us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 51.74% 1.478ms 51.74% 1.478ms 1.478ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 13.86% 395.687us 13.86% 395.687us 32.974us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.83% 23.691us 1.15% 32.911us 2.743us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.32% 9.220us 0.32% 9.220us 0.768us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.23% 6.600us 0.78% 22.311us 7.437us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.55% 15.711us 0.55% 15.711us 5.237us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.03% 29.502us 1.03% 29.502us 2.459us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.856ms
Self CUDA time total: 400.258us
impl                          wl         p50(ms)  ok
torch_flash_compiled_default  flux_L128  0.20     True
torch_flash_compiled_default  flux_L256  0.23     True
torch_flash_compiled_default  flux_L320  0.24     True
torch_flash_compiled_default  flux_L384  0.24     True
torch_flash_compiled_default  flux_L448  FAIL     False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
torch_flash_compiled_default  flux_L512  FAIL     False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
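
Both failures come from TorchDynamo's graph cache: with dynamic=False each distinct seq_len compiles a fresh graph, and together with the num_threads guard change logged below, the default budget of 8 entries runs out before flux_L448 and flux_L512 are reached. A minimal sketch of the two usual workarounds, assuming the same torch_flash_base as above (the limit value 16 is arbitrary):

import torch

# Workaround 1: raise the per-function graph cache (the default is 8).
torch._dynamo.config.cache_size_limit = 16

# Workaround 2: mark shapes dynamic so a single graph serves every seq_len.
compiled_flash_dynamic = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=True
)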
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
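
The "last reason" lines pin the final recompiles on GLOBAL_STATE changed: num_threads, i.e. something in the run (plausibly the profiler hand-off; an assumption) altered the intra-op thread count between calls. Pinning it once before any compiled call keeps that guard stable; a sketch:

import torch

# Fix the intra-op thread count up front so Dynamo's num_threads guard
# never changes mid-run (the value 8 is illustrative).
torch.set_num_threads(8)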