HF Kernels - Flash Attention 3

HuggingFace Kernels Flash Attention 3 Benchmark
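This report benchmarks the Flash Attention 3 kernel published at `kernels-community/flash-attn3` on the Hugging Face Hub, fetched at runtime with `kernels.get_kernel`. The workloads are Flux-like non-causal attention shapes (batch 1, 24 heads, head dimension 128, bfloat16), and each run is validated against a math reference before latencies are recorded.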

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
#     "kernels",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import sys

import torch
import kernels_benchmark_tools as kbt
from kernels import get_kernel

# Fetch the Flash Attention 3 kernel from the Hugging Face Hub
hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")


def hf_flash_attention3(query, key, value):
    return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]


# Register the implementation with the benchmark harness
kbt.add(
    "hf_kernels_flash_attn3",
    hf_flash_attention3,
    tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cpu":
        print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
        sys.exit(0)

    dtype = "bfloat16"

    # Flux-like workloads: total seq_len is a 1024-token base plus L extra tokens
    base = 1024
    flux_sizes = [128, 256, 320, 384, 448, 512]
    heads = 24
    head_dim = 128

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    # Time each workload, validating against a math reference, then summarize
    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
```
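Because the script declares its dependencies in inline `# /// script` metadata, `uv run` can resolve them (including the git-sourced `kernels-benchmark-tools`) before executing. The summary printed by `kbt.summarize` follows: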
```
impl                    wl         p50(ms)  ok
hf_kernels_flash_attn3  flux_L128  0.36     True
hf_kernels_flash_attn3  flux_L256  0.39     True
hf_kernels_flash_attn3  flux_L320  0.52     True
hf_kernels_flash_attn3  flux_L384  0.53     True
hf_kernels_flash_attn3  flux_L448  0.57     True
hf_kernels_flash_attn3  flux_L512  0.57     True
```
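Median latency grows with total sequence length, from 0.36 ms at L=128 (seq_len 1152) to 0.57 ms at L=512 (seq_len 1536), and every workload passes the allclose comparison against the math reference. To put the numbers in context, a baseline implementation can be registered through the same `kbt.add` API used above. The sketch below is a minimal, illustrative example, assuming `kbt.attn.gen_qkv` yields tensors in the `(batch, seq_len, heads, head_dim)` layout that `flash_attn_func` consumes; PyTorch's `scaled_dot_product_attention` expects `(batch, heads, seq_len, head_dim)`, hence the transposes.

```python
import torch
import kernels_benchmark_tools as kbt


def torch_sdpa(query, key, value):
    # Assumed input layout: (batch, seq_len, heads, head_dim); SDPA wants
    # (batch, heads, seq_len, head_dim), so transpose in and back out.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # non-causal by default
    return out.transpose(1, 2).contiguous()


kbt.add(
    "torch_sdpa",
    torch_sdpa,
    tags={"family": "torch", "backend": "sdpa", "compile": "none"},
)
```

With both implementations registered, the same `kbt.run` call would emit side-by-side rows in the summary table.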

Artifacts:

attn.jsonl
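Each line of `attn.jsonl` is one JSON record of the raw measurements behind the summary above. A minimal sketch for inspecting it; the record schema is not documented here, so this just prints whatever fields are present:

```python
import json

# Dump every benchmark record; check the actual keys before relying
# on any particular field name.
with open("attn.jsonl") as f:
    for line in f:
        print(json.loads(line))
```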