# INT8 Fused Attention CUDA Kernel
A high-performance INT8 fused attention kernel for diffusion transformers implemented in CUDA with WMMA tensor core acceleration.
The kernel performs scaled dot-product attention, computing the QK^T matrix multiply in INT8 while keeping the softmax and value accumulation in FP32 for numerical stability.
The implementation approach is inspired by SageAttention.
## Builds
| PyTorch | CUDA | Architecture | ABI |
|---|---|---|---|
| 2.9 | 13.0 | sm_75, sm_80, sm_86, sm_89, sm_90 | cxx11, manylinux_2_28 |
| 2.10 | 12.6 | sm_75, sm_80, sm_86, sm_89, sm_90 | cxx11, manylinux_2_28 |
## Loading via `kernels` (recommended)

The kernel can be downloaded and loaded dynamically using the `kernels` package:
```python
import torch
from kernels import get_kernel

kernel = get_kernel(
    "kernels-community/attention-int8",
    version=1,
)
```
## Python API
```python
from attention_int8 import int8_attention_forward

output = int8_attention_forward(
    Q,                # [B, H, N, D] float16
    K,                # [B, kv_H, N, D] float16
    V,                # [B, kv_H, N, D] float16
    timestep_scales,  # [T] float32 or None
    timestep=0,
    causal=False,
)
```
Key properties:

- INT8 tensor core QK multiplication
- FP32 softmax accumulation
- FP16 inputs and outputs
- Causal masking support
- Grouped Query Attention (GQA) support
- Dynamic per-tile quantization of Q and K
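The GQA support listed above implies each query head reads from a shared KV head (note the `H` vs. `kv_H` shapes in the API). A common contiguous-group mapping, assumed here for illustration, looks like:

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head to its shared KV head under grouped-query attention.
    Assumes num_q_heads is an integer multiple of num_kv_heads
    and that consecutive query heads share one KV head."""
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size
```

For example, with 8 query heads and 2 KV heads, query heads 0-3 share KV head 0 and query heads 4-7 share KV head 1.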
Supported `HEAD_DIM` values: 32, 64, 80, 96, 128, 160, 256
Notes:

- INT8 WMMA tensor cores require sm_75 (Turing) or newer.
- Volta architectures (sm_70 / sm_72) are excluded, but a fallback path is provided for unsupported architectures.
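The capability gate behind those notes can be sketched as a small helper (name and dispatch logic are hypothetical, not the kernel's actual API):

```python
def supports_int8_wmma(compute_capability: tuple) -> bool:
    """True if the device can run the INT8 WMMA path.
    INT8 WMMA requires sm_75 (Turing) or newer; Volta (sm_70 / sm_72)
    and earlier would take the fallback path instead."""
    return tuple(compute_capability) >= (7, 5)
```

At runtime one could feed this the result of `torch.cuda.get_device_capability()` to choose between the INT8 path and the fallback.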