
Flash Attention Implementation


GPU Info

Cell: nv | 0.68s

import subprocess

# Print the raw nvidia-smi report for the available GPUs
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Thu Oct  2 15:03:41 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                Persistence-M  | Bus-Id         Disp.A  | Volatile Uncorr. ECC |
| Fan  Temp  Perf          Pwr:Usage/Cap  |          Memory-Usage  | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                   On   |  00000000:00:1B.0 Off  |                    0 |
|  0%   31C   P0            87W /  300W   |      0MiB /  23028MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                   On   |  00000000:00:1C.0 Off  |                    0 |
|  0%   25C   P8            23W /  300W   |      0MiB /  23028MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                   On   |  00000000:00:1D.0 Off  |                    0 |
|  0%   26C   P8            23W /  300W   |      0MiB /  23028MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                   On   |  00000000:00:1E.0 Off  |                    0 |
|  0%   25C   P8            24W /  300W   |      0MiB /  23028MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                       GPU Memory  |
|        ID   ID                                                              Usage       |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Flash Attention Benchmark

Cell: benchmark | 35.67s

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_flash(q, k, v):
    # Inputs arrive as (batch, seq, heads, head_dim); SDPA expects (batch, heads, seq, head_dim)
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin the SDPA dispatcher to the Flash Attention backend
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the original layout
    return o.transpose(1, 2).contiguous()


kbt.add(
    "torch_flash_ma",
    torch_flash,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads, scaled down when only a CPU is available
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    # Run each workload 5 times after 2 warmup runs, checking outputs against a
    # math reference with allclose, and write per-run records to attn.jsonl
    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
impl            wl         p50(ms)  ok
torch_flash_ma  flux_L128  0.48     True
torch_flash_ma  flux_L256  0.53     True
torch_flash_ma  flux_L320  0.65     True
torch_flash_ma  flux_L384  0.68     True
torch_flash_ma  flux_L448  0.71     True
torch_flash_ma  flux_L512  0.74     True
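
The ok column reports whether the Flash Attention output matched the harness's math reference within allclose tolerances (ref=kbt.attn.ref_math, cmp=kbt.attn.cmp_allclose). A minimal standalone sketch of that kind of check in plain PyTorch is shown below; the tensor shapes mirror the flux_L128 workload, but the tolerances are assumptions rather than the harness's exact values.

import torch

# Assumed shapes: (batch, heads, seq_len, head_dim), matching the flux_L128 workload
q, k, v = (
    torch.randn(1, 24, 1152, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Flash Attention backend output
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
    out_flash = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Plain math reference: softmax(Q K^T / sqrt(d)) V, computed in float32 for stability
scale = q.shape[-1] ** -0.5
attn = torch.softmax(q.float() @ k.float().transpose(-2, -1) * scale, dim=-1)
out_ref = (attn @ v.float()).to(out_flash.dtype)

# Tolerances are an assumption; bfloat16 needs a wider margin than float32
print(torch.allclose(out_flash, out_ref, atol=2e-2, rtol=2e-2))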

Artifacts:

attn.jsonl
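
attn.jsonl holds one JSON record per benchmark measurement. The exact field names are whatever kernels-benchmark-tools writes, so the quick inspection sketch below only counts the records and prints the keys of the first one rather than assuming a schema.

import json

with open("attn.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

print(f"{len(records)} records")
print(sorted(records[0].keys()))  # field names as written by the harness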