# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
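# The PEP 723 block above lets `uv run` provision the pinned dependencies
# (including the local, editable kernels-benchmark-tools checkout) automatically.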

import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark
from kernels import get_kernel

# Load the SageAttention kernel from the Hugging Face Hub
hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")
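# get_kernel fetches the prebuilt binary from the Hub repo and loads it as a
# Python module, exposing the compiled ops (such as `fwd` below) as attributes;
# no local CUDA compilation is needed.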

def sage_attention(query, key, value):
    """SageAttention with INT8 Q/K quantization and FP16 P/V."""
    # fwd returns a tuple; the attention output is its first element.
    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]
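
# A minimal standalone sanity check, kept commented out so the benchmark run
# below stays untouched. The (batch, seq_len, num_heads, head_dim) layout and
# fp16 dtype are assumptions about what fwd() expects, not confirmed by this
# script:
#
#   q = torch.randn(1, 512, 8, 64, dtype=torch.float16, device="cuda")
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   out = sage_attention(q, k, v)
#   assert out.shape == q.shape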

run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="sage_int8_fp16",
    impl_tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
    impl_func=sage_attention,
)