
SageAttention Implementation


SageAttention Benchmark (INT8 Quantized)

Cell: benchmark | 40.11s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
#     "kernels-benchmark-tools",
#     "sageattention",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt
# from sageattention import sageattn_qk_int8_pv_fp16_cuda


# def sage_attention(q, k, v):
#     """SageAttention with INT8 Q/K quantization and FP16 P/V"""
#     return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")

from kernels import get_kernel

hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")

def sage_attention(query, key, value):
    """SageAttention forward pass via the HuggingFace kernels hub module"""
    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]

kbt.add(
    "sage_int8_fp16",
    sage_attention,
    tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cpu":
        print("SageAttention requires CUDA - skipping benchmark")
        sys.exit(0)

    dtype = "bfloat16"

    # Flux-like workloads
    base = 1024
    flux_sizes = [128, 256, 320, 384, 448, 512]
    heads = 24
    head_dim = 128

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
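The wrapper above routes everything through the hub module's `fwd` entry point. The commented-out lines at the top of the script show the alternative of calling the `sageattention` package directly; a minimal sketch of that fallback, assuming `sageattn_qk_int8_pv_fp16_cuda` is importable with a `tensor_layout` argument as those comments suggest (not verified here):

# Fallback sketch: call the sageattention package directly instead of the
# hub kernel. The function name and the tensor_layout="NHD" argument are
# taken from the commented-out lines in the script above, not verified here.
from sageattention import sageattn_qk_int8_pv_fp16_cuda

def sage_attention_direct(q, k, v):
    """INT8 Q/K quantization with FP16 P/V accumulation, NHD tensor layout."""
    return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")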
impl            wl         p50(ms)  ok
sage_int8_fp16  flux_L128  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
sage_int8_fp16  flux_L256  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
sage_int8_fp16  flux_L320  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
sage_int8_fp16  flux_L384  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
sage_int8_fp16  flux_L448  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
sage_int8_fp16  flux_L512  FAIL     False
  Error: module 'sage_attention_a39c012a73160148' has no attribute 'fwd'
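Every workload fails with the same AttributeError: the wrapper expects the downloaded hub module to expose a `fwd` function, but `sage_attention_a39c012a73160148` publishes no attribute by that name. A quick, hypothetical check of what the module actually exports, using only the `get_kernel` call already present in the script:

# Hypothetical diagnostic: list the public attributes the hub kernel module
# actually exposes, to find the correct entry point before re-running.
from kernels import get_kernel

mod = get_kernel("kernels-community/sage_attention")
print([name for name in dir(mod) if not name.startswith("_")])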
UV Install Logs
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 15.59it/s]

Artifacts:

attn.jsonl