Torch Compile Variants

This file benchmarks Flash Attention under two torch.compile modes: "default" and "max-autotune".
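
Both cells compile the same SDPA-based Flash Attention wrapper and differ only in the mode argument passed to torch.compile. A minimal sketch of the distinction (the function and variable names here are illustrative, not part of the benchmark):

import torch

def attn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# "default" balances compile time against runtime performance.
attn_default = torch.compile(attn, mode="default")

# "max-autotune" spends extra compile time autotuning Triton kernel
# configurations in search of faster generated code.
attn_max_autotune = torch.compile(attn, mode="max-autotune")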

Flash Attention with torch.compile(mode="default")

Cell: benchmark_default (46.71s)
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    # Transpose (batch, seq, heads, dim) -> (batch, heads, seq, dim) for
    # SDPA, then transpose the output back.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


# Compile with default mode; fullgraph=True forbids graph breaks and
# dynamic=False specializes the graph on the fixed benchmark shapes.
compiled_flash_default = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=False
)

kbt.add(
    "torch_flash_compiled_default",
    compiled_flash_default,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_default.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn_default.jsonl"])
impl                          wl         p50(ms)  ok
torch_flash_compiled_default  flux_L128  0.52     True
torch_flash_compiled_default  flux_L256  0.56     True
torch_flash_compiled_default  flux_L320  0.68     True
torch_flash_compiled_default  flux_L384  0.72     True
torch_flash_compiled_default  flux_L448  0.75     True
torch_flash_compiled_default  flux_L512  0.77     True

Artifacts:

attn_default.jsonl
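
The JSONL artifact holds one JSON record per measurement, so results can be inspected without the summary table. A minimal loading sketch using only the standard library; the field names are an assumption based on the summary columns, so check a line of the actual file against the schema emitted by kernels-benchmark-tools:

import json

with open("attn_default.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

# "wl" and "p50" are assumed field names; adapt to the real schema.
for r in records:
    print(r.get("wl"), r.get("p50"))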

Flash Attention with torch.compile(mode="max-autotune")

Cell: benchmark_max_autotune (53.95s)
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    # Transpose (batch, seq, heads, dim) -> (batch, heads, seq, dim) for
    # SDPA, then transpose the output back.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


# Compile with max-autotune mode; longer compile time in exchange for
# more aggressive kernel autotuning.
compiled_flash_max_autotune = torch.compile(
    torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False
)

kbt.add(
    "torch_flash_compiled_max_autotune",
    compiled_flash_max_autotune,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_max_autotune.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn_max_autotune.jsonl"])
impl                               wl         p50(ms)  ok
torch_flash_compiled_max_autotune  flux_L128  0.64     True
torch_flash_compiled_max_autotune  flux_L256  0.68     True
torch_flash_compiled_max_autotune  flux_L320  0.81     True
torch_flash_compiled_max_autotune  flux_L384  0.85     True
torch_flash_compiled_max_autotune  flux_L448  0.90     True
torch_flash_compiled_max_autotune  flux_L512  0.92     True

Artifacts:

attn_max_autotune.jsonl
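
In this run, default mode was faster than max-autotune on every Flux workload (p50 of 0.52-0.77 ms versus 0.64-0.92 ms), so the extra autotuning did not pay off for these shapes. Since kbt.summarize takes a list of JSONL paths, both runs can also be summarized side by side once the two artifacts exist:

import kernels_benchmark_tools as kbt

# Combine both compile variants into one summary table.
kbt.summarize(["attn_default.jsonl", "attn_max_autotune.jsonl"])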