Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py +192 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py +679 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py +1843 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_decoding.py +570 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py +776 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py +466 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py +248 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_scaled.py +311 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +87 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -142,3 +142,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 142 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 143 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 144 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 143 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 144 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 145 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 146 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 147 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee19b8e0b980d0895a7af50aa7c3244d133ce110a196485ab8cec5fa7b9767d4
|
| 3 |
+
size 121452
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e8596ce3d305b9ea76fd93737e3fda25769b1901142db9efff0fde9757b03517
|
| 3 |
+
size 262897
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de158b207dec0ef6dd7cca5acc1db68fcc605b0046ed6c5ffcf0d9b8f34d3b82
|
| 3 |
+
size 138985
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-311.pyc
ADDED
|
Binary file (26.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc
ADDED
|
Binary file (8.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc
ADDED
|
Binary file (5.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc
ADDED
|
Binary file (41.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc
ADDED
|
Binary file (75.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc
ADDED
|
Binary file (34.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-311.pyc
ADDED
|
Binary file (42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc
ADDED
|
Binary file (63.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc
ADDED
|
Binary file (37.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc
ADDED
|
Binary file (59.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc
ADDED
|
Binary file (40.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc
ADDED
|
Binary file (84.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc
ADDED
|
Binary file (31.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc
ADDED
|
Binary file (7.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (321 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-311.pyc
ADDED
|
Binary file (9.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (25.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-311.pyc
ADDED
|
Binary file (67.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-311.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc
ADDED
|
Binary file (30.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc
ADDED
|
Binary file (7.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc
ADDED
|
Binary file (3.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .. import ir, lowering as L
|
| 7 |
+
from ..select_algorithm import (
|
| 8 |
+
autotune_select_algorithm,
|
| 9 |
+
ExternKernelChoice,
|
| 10 |
+
TritonTemplate,
|
| 11 |
+
)
|
| 12 |
+
from ..utils import (
|
| 13 |
+
ceildiv as cdiv,
|
| 14 |
+
use_aten_gemm_kernels,
|
| 15 |
+
use_cutlass_template,
|
| 16 |
+
use_triton_template,
|
| 17 |
+
)
|
| 18 |
+
from ..virtualized import V
|
| 19 |
+
from .mm import _is_static_problem
|
| 20 |
+
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
log = logging.getLogger(__name__)
|
| 24 |
+
aten = torch.ops.aten
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def bmm_grid(b, m, n, meta):
|
| 28 |
+
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
bmm_template = TritonTemplate(
|
| 32 |
+
name="bmm",
|
| 33 |
+
grid=bmm_grid,
|
| 34 |
+
source=r"""
|
| 35 |
+
{{def_kernel("A", "B")}}
|
| 36 |
+
M = {{size("A", -2)}}
|
| 37 |
+
N = {{size("B", -1)}}
|
| 38 |
+
K = {{size("A", -1)}}
|
| 39 |
+
|
| 40 |
+
stride_aq = {{stride("A", 0)}}
|
| 41 |
+
stride_am = {{stride("A", 1)}}
|
| 42 |
+
stride_ak = {{stride("A", 2)}}
|
| 43 |
+
|
| 44 |
+
stride_bq = {{stride("B", 0)}}
|
| 45 |
+
stride_bk = {{stride("B", 1)}}
|
| 46 |
+
stride_bn = {{stride("B", 2)}}
|
| 47 |
+
|
| 48 |
+
# based on triton.ops.matmul
|
| 49 |
+
pid = tl.program_id(0)
|
| 50 |
+
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
| 51 |
+
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
| 52 |
+
|
| 53 |
+
# re-order program ID for better L2 performance
|
| 54 |
+
width = GROUP_M * grid_n
|
| 55 |
+
group_id = pid // width
|
| 56 |
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
| 57 |
+
pid_m = group_id * GROUP_M + (pid % group_size)
|
| 58 |
+
pid_n = (pid % width) // (group_size)
|
| 59 |
+
|
| 60 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 61 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 62 |
+
if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
|
| 63 |
+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
| 64 |
+
else:
|
| 65 |
+
ram = rm % M
|
| 66 |
+
if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
|
| 67 |
+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
| 68 |
+
else:
|
| 69 |
+
rbn = rn % N
|
| 70 |
+
|
| 71 |
+
rk = tl.arange(0, BLOCK_K)
|
| 72 |
+
|
| 73 |
+
idx_q = tl.program_id(1) # batch dimension for BMM
|
| 74 |
+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
|
| 75 |
+
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
|
| 76 |
+
|
| 77 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
| 78 |
+
for k in range(K, 0, -BLOCK_K):
|
| 79 |
+
if EVEN_K:
|
| 80 |
+
a = tl.load(A)
|
| 81 |
+
b = tl.load(B)
|
| 82 |
+
else:
|
| 83 |
+
a = tl.load(A, mask=rk[None, :] < k, other=0.)
|
| 84 |
+
b = tl.load(B, mask=rk[:, None] < k, other=0.)
|
| 85 |
+
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
| 86 |
+
A += BLOCK_K * stride_ak
|
| 87 |
+
B += BLOCK_K * stride_bk
|
| 88 |
+
|
| 89 |
+
# rematerialize rm and rn to save registers
|
| 90 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 91 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 92 |
+
idx_q = tl.program_id(1) # batch dimension for BMM
|
| 93 |
+
idx_m = rm[:, None]
|
| 94 |
+
idx_n = rn[None, :]
|
| 95 |
+
mask = (idx_m < M) & (idx_n < N)
|
| 96 |
+
|
| 97 |
+
# inductor generates a suffix
|
| 98 |
+
{{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
|
| 99 |
+
""",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
|
| 103 |
+
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@L.register_lowering(aten.bmm)
|
| 107 |
+
def tuned_bmm(mat1, mat2, *, layout=None):
|
| 108 |
+
if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
|
| 109 |
+
# decompose to small ops when memory bound
|
| 110 |
+
if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
|
| 111 |
+
mat1 = L.unsqueeze(mat1, -1)
|
| 112 |
+
mat2 = L.unsqueeze(mat2, 1)
|
| 113 |
+
return L.sum_(L.mul(mat1, mat2), axis=2)
|
| 114 |
+
|
| 115 |
+
def is_valid_to_require_contiguous(t):
|
| 116 |
+
if not ir.is_storage_and_layout(t):
|
| 117 |
+
return True
|
| 118 |
+
_, layout = ir.as_storage_and_layout(t, freeze=False)
|
| 119 |
+
return isinstance(layout, ir.FlexibleLayout)
|
| 120 |
+
|
| 121 |
+
def is_preferred_layout_as_bmm_input(sizes, strides):
|
| 122 |
+
# contiguous on one of the last two dims
|
| 123 |
+
return (
|
| 124 |
+
strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
|
| 125 |
+
) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))
|
| 126 |
+
|
| 127 |
+
# Make the input of bmm contiguous
|
| 128 |
+
# if it is not contiguous on either of the last two dims,
|
| 129 |
+
# because bmm cpu implementation would do contiguous() if not.
|
| 130 |
+
# This is to avoid additional copies in bmm.
|
| 131 |
+
def may_require_contiguous(t, meta_t):
|
| 132 |
+
sizes = meta_t.meta["val"].size()
|
| 133 |
+
strides = meta_t.meta["val"].stride()
|
| 134 |
+
if not is_preferred_layout_as_bmm_input(sizes, strides):
|
| 135 |
+
t = ir.ExternKernel.require_contiguous(t)
|
| 136 |
+
return t
|
| 137 |
+
|
| 138 |
+
if is_valid_to_require_contiguous(mat1):
|
| 139 |
+
meta_mat1 = V.graph.current_node.args[0]
|
| 140 |
+
mat1 = may_require_contiguous(mat1, meta_mat1)
|
| 141 |
+
if is_valid_to_require_contiguous(mat2):
|
| 142 |
+
meta_mat2 = V.graph.current_node.args[1]
|
| 143 |
+
mat2 = may_require_contiguous(mat2, meta_mat2)
|
| 144 |
+
|
| 145 |
+
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
| 146 |
+
|
| 147 |
+
# options to tune from
|
| 148 |
+
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
|
| 149 |
+
if use_triton_template(layout):
|
| 150 |
+
for config in mm_configs(m, n, k):
|
| 151 |
+
bmm_template.maybe_append_choice(
|
| 152 |
+
choices,
|
| 153 |
+
input_nodes=(mat1, mat2),
|
| 154 |
+
layout=layout,
|
| 155 |
+
**mm_options(config, m, n, k, layout),
|
| 156 |
+
)
|
| 157 |
+
static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
|
| 158 |
+
if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
|
| 159 |
+
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
|
| 160 |
+
|
| 161 |
+
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
|
| 162 |
+
|
| 163 |
+
if len(choices) == 0:
|
| 164 |
+
log.warning("No choices for GEMM, using ATen backend as fallback")
|
| 165 |
+
choices.append(aten_bmm.bind((mat1, mat2), layout))
|
| 166 |
+
|
| 167 |
+
return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Don't register this since it is slower than decomposing it
|
| 171 |
+
# @L.register_lowering(aten.baddbmm)
|
| 172 |
+
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
| 173 |
+
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
|
| 174 |
+
|
| 175 |
+
# options to tune from
|
| 176 |
+
choices = (
|
| 177 |
+
[aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
|
| 178 |
+
if use_aten_gemm_kernels()
|
| 179 |
+
else []
|
| 180 |
+
)
|
| 181 |
+
if use_triton_template(layout):
|
| 182 |
+
for config in mm_configs(m, n, k):
|
| 183 |
+
bmm_template.maybe_append_choice(
|
| 184 |
+
choices,
|
| 185 |
+
input_nodes=(inp, mat1, mat2),
|
| 186 |
+
layout=layout,
|
| 187 |
+
**mm_options(config, m, n, k, layout),
|
| 188 |
+
prefix_args=1,
|
| 189 |
+
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py
ADDED
|
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
import logging
|
| 7 |
+
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .. import config, ir
|
| 12 |
+
from ..lowering import (
|
| 13 |
+
add_layout_constraint,
|
| 14 |
+
constrain_to_fx_strides,
|
| 15 |
+
lowerings as L,
|
| 16 |
+
register_lowering,
|
| 17 |
+
)
|
| 18 |
+
from ..select_algorithm import (
|
| 19 |
+
autotune_select_algorithm,
|
| 20 |
+
ExternKernelChoice,
|
| 21 |
+
TritonTemplate,
|
| 22 |
+
)
|
| 23 |
+
from ..utils import (
|
| 24 |
+
ceildiv,
|
| 25 |
+
is_ones,
|
| 26 |
+
is_zeros,
|
| 27 |
+
pad_listlike,
|
| 28 |
+
sympy_product,
|
| 29 |
+
use_triton_template,
|
| 30 |
+
)
|
| 31 |
+
from ..virtualized import V
|
| 32 |
+
from .mm_common import filtered_configs
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if TYPE_CHECKING:
|
| 36 |
+
from ..ir import TensorBox
|
| 37 |
+
|
| 38 |
+
log = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
aten = torch.ops.aten
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def conv2d_grid(n, c, h, w, meta):
|
| 45 |
+
return (
|
| 46 |
+
ceildiv(n * h * w, meta["BLOCK_M"]),
|
| 47 |
+
ceildiv(c, meta["BLOCK_N"]),
|
| 48 |
+
meta["GROUPS"],
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def conv3d_grid(n, c, d, h, w, meta):
    """Triton launch grid for the conv3d template.

    Same layout as conv2d_grid, with depth folded into the axis-0
    batch*spatial tiling.
    """
    m_tiles = ceildiv(n * d * h * w, meta["BLOCK_M"])
    n_tiles = ceildiv(c, meta["BLOCK_N"])
    return (m_tiles, n_tiles, meta["GROUPS"])
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create filtered list of configs based on conv
# (keep only the tuples whose "cond" flag evaluated truthy above).
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
    )

# Partial application: callers pass (m, n, k) size hints and get back the
# pruned/deduplicated config list from mm_common.filtered_configs.
conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)
|
| 90 |
+
|
| 91 |
+
LOOP_BODY_2D = """
|
| 92 |
+
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
|
| 93 |
+
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
|
| 94 |
+
idx_x_c = tl.arange(0, BLOCK_K) + k
|
| 95 |
+
|
| 96 |
+
x_ptrs = x_base + (
|
| 97 |
+
(idx_x_h * stride_xh)[:, None]
|
| 98 |
+
+ (idx_x_w * stride_xw)[:, None]
|
| 99 |
+
+ (idx_x_c * stride_xc)[None, :]
|
| 100 |
+
)
|
| 101 |
+
mask_x = (
|
| 102 |
+
(idx_n < BATCH)[:, None]
|
| 103 |
+
& (idx_x_h >= 0)[:, None]
|
| 104 |
+
& (idx_x_h < IN_H)[:, None]
|
| 105 |
+
& (idx_x_w >= 0)[:, None]
|
| 106 |
+
& (idx_x_w < IN_W)[:, None]
|
| 107 |
+
& (idx_x_c < GROUP_IN_C)[None, :]
|
| 108 |
+
)
|
| 109 |
+
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
|
| 110 |
+
|
| 111 |
+
w_ptrs = w_base + (
|
| 112 |
+
(idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
|
| 113 |
+
)
|
| 114 |
+
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
|
| 115 |
+
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
|
| 116 |
+
acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
This is a relatively simple conv implementation that can likely be
|
| 121 |
+
improved. Many alternate conv versions can be found here:
|
| 122 |
+
https://github.com/pytorch/torchdynamo/pull/971
|
| 123 |
+
"""
|
| 124 |
+
conv2d_template = TritonTemplate(
|
| 125 |
+
name="convolution2d",
|
| 126 |
+
grid=conv2d_grid,
|
| 127 |
+
source=r"""
|
| 128 |
+
{{def_kernel("X", "W")}}
|
| 129 |
+
# Tensor dimensions
|
| 130 |
+
BATCH = {{size("X", 0)}}
|
| 131 |
+
IN_C = {{size("X", 1)}}
|
| 132 |
+
IN_H = {{size("X", 2)}}
|
| 133 |
+
IN_W = {{size("X", 3)}}
|
| 134 |
+
OUT_C = {{size(None, 1)}}
|
| 135 |
+
OUT_H = {{size(None, 2)}}
|
| 136 |
+
OUT_W = {{size(None, 3)}}
|
| 137 |
+
|
| 138 |
+
# Strides:
|
| 139 |
+
stride_xn = {{stride("X", 0)}}
|
| 140 |
+
stride_xc = {{stride("X", 1)}}
|
| 141 |
+
stride_xh = {{stride("X", 2)}}
|
| 142 |
+
stride_xw = {{stride("X", 3)}}
|
| 143 |
+
stride_wc_out = {{stride("W", 0)}}
|
| 144 |
+
stride_wc_in = {{stride("W", 1)}}
|
| 145 |
+
stride_wh = {{stride("W", 2)}}
|
| 146 |
+
stride_ww = {{stride("W", 3)}}
|
| 147 |
+
|
| 148 |
+
nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 149 |
+
idx_y_w = nhw % OUT_W
|
| 150 |
+
nh = nhw // OUT_W
|
| 151 |
+
idx_y_h = nh % OUT_H
|
| 152 |
+
idx_n = nh // OUT_H
|
| 153 |
+
idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 154 |
+
|
| 155 |
+
{% if GROUPS == 1 %}
|
| 156 |
+
group = 0
|
| 157 |
+
GROUP_IN_C = IN_C
|
| 158 |
+
GROUP_OUT_C = OUT_C
|
| 159 |
+
{% else %}
|
| 160 |
+
group = tl.program_id(2)
|
| 161 |
+
GROUP_IN_C = IN_C // GROUPS
|
| 162 |
+
GROUP_OUT_C = OUT_C // GROUPS
|
| 163 |
+
{% endif %}
|
| 164 |
+
|
| 165 |
+
x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
|
| 166 |
+
w_base = (
|
| 167 |
+
W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
| 171 |
+
|
| 172 |
+
{% if UNROLL %}
|
| 173 |
+
{% for i in range(KERNEL_H) %}
|
| 174 |
+
{% for j in range(KERNEL_W) %}
|
| 175 |
+
i = {{i}}
|
| 176 |
+
j = {{j}}
|
| 177 |
+
for k in range(0, GROUP_IN_C, BLOCK_K):
|
| 178 |
+
"""
|
| 179 |
+
+ LOOP_BODY_2D
|
| 180 |
+
+ """
|
| 181 |
+
{% endfor %}
|
| 182 |
+
{% endfor %}
|
| 183 |
+
{% else %}
|
| 184 |
+
# Could be simplified, but slightly slower:
|
| 185 |
+
# for i in range(KERNEL_H):
|
| 186 |
+
# for j in range(KERNEL_W):
|
| 187 |
+
# for k in range(0, GROUP_IN_C, BLOCK_K):
|
| 188 |
+
BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
|
| 189 |
+
for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
|
| 190 |
+
k = (ijk % BLOCK_K_COUNT) * BLOCK_K
|
| 191 |
+
ij = ijk // BLOCK_K_COUNT
|
| 192 |
+
i = ij // KERNEL_W
|
| 193 |
+
j = ij % KERNEL_W
|
| 194 |
+
"""
|
| 195 |
+
+ LOOP_BODY_2D
|
| 196 |
+
+ """
|
| 197 |
+
{% endif %}
|
| 198 |
+
|
| 199 |
+
mask = (
|
| 200 |
+
(idx_n < BATCH)[:, None]
|
| 201 |
+
& (idx_y_h < OUT_H)[:, None]
|
| 202 |
+
& (idx_y_w < OUT_W)[:, None]
|
| 203 |
+
& (idx_y_c < GROUP_OUT_C)[None, :]
|
| 204 |
+
)
|
| 205 |
+
idx_n = idx_n[:, None]
|
| 206 |
+
idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
|
| 207 |
+
idx_h = idx_y_h[:, None]
|
| 208 |
+
idx_w = idx_y_w[:, None]
|
| 209 |
+
|
| 210 |
+
# inductor generates a suffix
|
| 211 |
+
{{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
|
| 212 |
+
""",
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
LOOP_BODY_3D = """
|
| 216 |
+
idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
|
| 217 |
+
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
|
| 218 |
+
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
|
| 219 |
+
idx_x_c = tl.arange(0, BLOCK_K) + k
|
| 220 |
+
|
| 221 |
+
x_ptrs = x_base + (
|
| 222 |
+
(idx_x_d * stride_xd)[:, None]
|
| 223 |
+
+ (idx_x_h * stride_xh)[:, None]
|
| 224 |
+
+ (idx_x_w * stride_xw)[:, None]
|
| 225 |
+
+ (idx_x_c * stride_xc)[None, :]
|
| 226 |
+
)
|
| 227 |
+
mask_x = (
|
| 228 |
+
(idx_n < BATCH)[:, None]
|
| 229 |
+
& (idx_x_d >= 0)[:, None]
|
| 230 |
+
& (idx_x_d < IN_D)[:, None]
|
| 231 |
+
& (idx_x_h >= 0)[:, None]
|
| 232 |
+
& (idx_x_h < IN_H)[:, None]
|
| 233 |
+
& (idx_x_w >= 0)[:, None]
|
| 234 |
+
& (idx_x_w < IN_W)[:, None]
|
| 235 |
+
& (idx_x_c < GROUP_IN_C)[None, :]
|
| 236 |
+
)
|
| 237 |
+
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
|
| 238 |
+
|
| 239 |
+
w_ptrs = w_base + (
|
| 240 |
+
(idx_x_c * stride_wc_in)[:, None] +
|
| 241 |
+
(d * stride_wd) + (i * stride_wh) + (j * stride_ww)
|
| 242 |
+
)
|
| 243 |
+
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
|
| 244 |
+
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
|
| 245 |
+
acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
# 3D variant of conv2d_template: depth is folded into the axis-0
# batch*spatial tiling (see conv3d_grid) and into the kernel loops.
conv3d_template = TritonTemplate(
    name="convolution3d",
    grid=conv3d_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_D = {{size("X", 2)}}
    IN_H = {{size("X", 3)}}
    IN_W = {{size("X", 4)}}
    OUT_C = {{size(None, 1)}}
    OUT_D = {{size(None, 2)}}
    OUT_H = {{size(None, 3)}}
    OUT_W = {{size(None, 4)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xd = {{stride("X", 2)}}
    stride_xh = {{stride("X", 3)}}
    stride_xw = {{stride("X", 4)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wd = {{stride("W", 2)}}
    stride_wh = {{stride("W", 3)}}
    stride_ww = {{stride("W", 4)}}

    ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = ndhw % OUT_W
    ndh = ndhw // OUT_W
    idx_y_h = ndh % OUT_H
    nd = ndh // OUT_H
    idx_y_d = nd % OUT_D
    idx_n = nd // OUT_D
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for d in range(KERNEL_D) %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    d = {{d}}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
"""
    + LOOP_BODY_3D
    + """
{% endfor %}
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for d in range(KERNEL_D):
    #   for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #       for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (dijk % BLOCK_K_COUNT) * BLOCK_K
        dij = dijk // BLOCK_K_COUNT
        j = dij % KERNEL_W
        di = dij // KERNEL_W
        i = di % KERNEL_H
        d = di // KERNEL_H
"""
    + LOOP_BODY_3D
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_d < OUT_D)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_d = idx_y_d[:, None]
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}}
""",
)
|
| 352 |
+
|
| 353 |
+
# ATen fallback entry used during autotuning: benchmarked against the
# Triton template choices by autotune_select_algorithm() below.
aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def conv1x1_via_mm(x, w, *, out):
    """Compute a 1x1 (stride-1, unpadded, ungrouped) conv2d as a matmul.

    Drops the trailing 1x1 kernel dims of *w*, moves channels last on the
    input, and writes the product into *out* (NCHW) through a
    channels-last permuted view.
    """
    kernel = torch.squeeze(torch.squeeze(w, -1), -1)
    x_cl = x.permute(0, 2, 3, 1)       # (N, H, W, C_in)
    w_t = kernel.permute(1, 0)         # (C_in, C_out)
    out_cl = out.permute(0, 2, 3, 1)   # channels-last view of the NCHW output
    return torch.matmul(x_cl, w_t, out=out_cl)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Extern autotune choice that lowers an eligible 1x1 conv to a matmul.
aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class ConvLayoutParams(TypedDict):
    """Keyword arguments of aten.convolution, minus the tensor inputs."""

    stride: tuple[int, ...]  # per-spatial-dim stride
    padding: tuple[int, ...]  # per-spatial-dim input padding
    dilation: tuple[int, ...]  # per-spatial-dim kernel dilation
    transposed: bool  # True for a transposed (deconvolution) op
    output_padding: tuple[int, ...]  # extra output size, transposed conv only
    groups: int  # number of channel groups
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution"""
    # Run the real op on fake tensors so ATen decides the output
    # size/stride (including any memory-format propagation) for us.
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
            transposed,
            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
            groups,
        )
        # Convert the fake tensor's shape/stride back to inductor sympy form.
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def channels_last_order(rank):
    """Return the channels-last stride order for a tensor of this rank.

    E.g. rank 4 (NCHW) -> [3, 0, 2, 1], matching ir.NHWC_STRIDE_ORDER:
    dims listed from innermost stride outward, with the channel dim (1)
    placed just after the innermost spatial dim.
    """
    order = list(range(rank - 1, -1, -1))
    # Relocate dim 0 (the final entry after reversal) next to the
    # innermost position so channels get the smallest stride.
    moved = order.pop()
    order.insert(1, moved)
    return order
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def convert_1x1_conv_to_mm(x, weight, bias):
    # special case for 1x1 convolution, which is actually just a matmul
    #
    # Lowers to: squeeze weight to (C_out, C_in) and transpose; put x in
    # channels-last order, flatten batch+spatial; then mm/addmm; finally
    # reshape and permute channels back to position 1.
    rank = len(weight.get_size())
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    # Require channels-last strides so the flattened matmul input is
    # contiguous over (batch*spatial, channels).
    x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))  # move channels to the last axis
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    # Undo the channels-last permute: channels return to dim 1.
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Lower aten.convolution: pick between the ATen kernel, a matmul
    rewrite for 1x1 convs, and the Triton conv templates via autotuning.
    """
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)

    # Need use hint for triton template since the template does not
    # work with a dynamic shape.
    #
    # No need to evaluate_static_shape for dilation and output_padding
    # since the template is only used when dilation is 1 and output_padding
    # is 0.
    stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
    padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))

    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    # Normalize possibly-scalar conv params to one entry per spatial dim.
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        # True when layout optimization forces NHWC, or when the eager
        # layout inference already produces NHWC strides.
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    # 1x1 / stride-1 / unpadded / ungrouped convs over non-empty inputs
    # can be rewritten directly as a matmul.
    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

    choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
        choices = [
            aten_convolution.bind(
                args,
                layout,
                ordered_kwargs_for_cpp_kernel,
                **kwargs,
            )
        ]

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
        and use_triton_template(layout)
        # templates only support these:
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_H=kernel_shape[0],
                    KERNEL_W=kernel_shape[1],
                    STRIDE_H=stride[0],
                    STRIDE_W=stride[1],
                    PADDING_H=padding[0],
                    PADDING_W=padding[1],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
            elif ndim == 3:
                conv3d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_D=kernel_shape[0],
                    KERNEL_H=kernel_shape[1],
                    KERNEL_W=kernel_shape[2],
                    STRIDE_D=stride[0],
                    STRIDE_H=stride[1],
                    STRIDE_W=stride[2],
                    PADDING_D=padding[0],
                    PADDING_H=padding[1],
                    PADDING_W=padding[2],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )

    return autotune_select_algorithm("convolution", choices, args, layout)
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    """Lower aten._convolution by delegating to the aten.convolution lowering.

    The trailing backend flags (benchmark, deterministic, cudnn_enabled,
    allow_tf32) are accepted for signature compatibility and intentionally
    ignored here.
    """
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    """Layout constraint hook for aten.convolution nodes.

    When inductor's layout optimization is active it chooses conv layouts
    itself, so the FX-recorded strides are not enforced; otherwise the
    inputs are pinned to the strides from the FX graph.
    """
    assert fx_node.target == torch.ops.aten.convolution.default
    if not V.graph.layout_opt:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)
    # Layout opt owns the decision: pass the args through untouched.
    return args, kwargs
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
# Register the hook so convolution lowerings receive correctly-laid-out inputs.
add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py
ADDED
|
@@ -0,0 +1,1843 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
""" Triton Implementation of the flex_attention Kernel"""
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
from typing import Any, List, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch._inductor.virtualized import V
|
| 12 |
+
from torch.utils._pytree import tree_map
|
| 13 |
+
|
| 14 |
+
from .. import config
|
| 15 |
+
from ..ir import (
|
| 16 |
+
ComputedBuffer,
|
| 17 |
+
ExternKernel,
|
| 18 |
+
FixedLayout,
|
| 19 |
+
FlexibleLayout,
|
| 20 |
+
get_stride_order,
|
| 21 |
+
InputBuffer,
|
| 22 |
+
IRNode,
|
| 23 |
+
StorageBox,
|
| 24 |
+
stride_order2fill_order,
|
| 25 |
+
Subgraph,
|
| 26 |
+
TensorBox,
|
| 27 |
+
)
|
| 28 |
+
from ..lowering import empty, empty_strided, lowerings, register_lowering
|
| 29 |
+
from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
log = logging.getLogger(__name__)
|
| 33 |
+
aten = torch.ops.aten
|
| 34 |
+
Expr = sympy.Expr
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """From a list of sizes and a fill order, construct the strides of the permuted tensor.

    `fill_order` lists dimensions innermost-first: the first entry gets stride 1,
    and each subsequent entry's stride is the running product of the extents of
    the dimensions already placed.
    """
    assert len(sizes) == len(
        fill_order
    ), "Length of sizes must match the length of the fill order"

    out = [0] * len(sizes)
    running = 1
    for axis in fill_order:
        out[axis] = running
        running *= sizes[axis]
    return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
    """How is this kernel parallelized?

    We launch a (ceil_div(num_queries, BLOCK_M), batch_size * q_heads, 1) grid:
    each program owns one BLOCK_M tile of queries for one (batch, head) pair and
    iterates over blocks of keys/values to produce its slice of the output.
    """
    import triton

    num_q_tiles = triton.cdiv(num_queries, meta["BLOCK_M"])
    return (num_q_tiles, batch_size * q_heads, 1)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """Creates a placeholder input buffer for producing subgraph_output.

    The buffer is rank-0 (size [] / stride []) — the subgraph placeholders only
    need to carry a dtype and device, not a concrete shape.
    """
    scalar_layout = FixedLayout(device, dtype, [], [])
    return TensorBox.create(InputBuffer(name, scalar_layout))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def maybe_realize(args: List[Optional[IRNode]]):
    """Accepts a list of optional IRNodes and returns a list of realized IRNodes.

    None entries (absent optional buffers) are passed through unchanged.
    """
    return tree_map(
        lambda node: None if node is None else realize_inputs(node), args
    )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_float32_precision():
    """Return the Triton `input_precision` literal (quoted) for fp32 matmuls.

    TF32 is not used on ROCm, nor when the user requested full fp32 precision
    via torch.set_float32_matmul_precision("highest"); IEEE is used instead.
    """
    wants_ieee = (
        bool(torch.version.hip)
        or torch.get_float32_matmul_precision() == "highest"
    )
    return "'ieee'" if wants_ieee else "'tf32'"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """This function's goal is to take in the required args and produce the subgraph buffer
    The subgraph buffer is a ComputedBuffer that will be inlined into the triton template

    Args:
        args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
        subgraph: The Subgraph ir for which to produce the output node
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        # There are two classes of placeholder inputs that we need
        # to handle differently. For the first n_scalar_inps inputs
        # we expect that these placeholders were generated by the make_fx call
        # in the flex Attention HOP. So we need to create a new placeholder
        # TensorBox for each of these inputs. For the rest of the inputs we
        # expect that these are lifted inputs that fill up the '*other_buffers'
        # tuple and already have corresponding TensorBoxes passed in as args.
        if node.op == "placeholder":
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args.
            # NOTE: use distinct names here — the previous code rebound `args`,
            # clobbering the function parameter; a placeholder node visited
            # after a call_function would then index the wrong sequence.
            lowered_args, lowered_kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*lowered_args, **lowered_kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                # Wrap one output node of the subgraph into a ComputedBuffer
                # (None outputs stay None).
                if output is None:
                    return None
                output_node = output
                output_buffer = env[output_node]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node for flex attention's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for the flex attention subgraph must be a StorageBox, but got: ",
                    type(output_buffer),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] is either a single element or a list of elements
            # representing all outputs of the function.
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("FlexAttention was passed a subgraph with no output node!")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
|
| 156 |
+
compute_next_offset_func = r"""
|
| 157 |
+
@triton.jit
|
| 158 |
+
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
|
| 159 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 160 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 161 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 162 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 163 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 164 |
+
|
| 165 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 166 |
+
return offset
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
compute_flex_attention = r"""
|
| 170 |
+
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
| 171 |
+
# Sub notation for this kernel:
|
| 172 |
+
#
|
| 173 |
+
# Q: Query, K: Key, V: Value
|
| 174 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 175 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 176 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 177 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 178 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 179 |
+
#
|
| 180 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 181 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 182 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 183 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 184 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 185 |
+
#
|
| 186 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 187 |
+
#
|
| 188 |
+
# (Modifiable) Performance tuning options
|
| 189 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 190 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 191 |
+
|
| 192 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 193 |
+
# or involve a numerics vs. perf tradeoff
|
| 194 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 195 |
+
# about 20% more numerical error, but slightly faster.
|
| 196 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 197 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 198 |
+
|
| 199 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 200 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 201 |
+
|
| 202 |
+
# Define strides of inputs
|
| 203 |
+
stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
|
| 204 |
+
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
| 205 |
+
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
| 206 |
+
|
| 207 |
+
Z = {{size("Q", 0)}}
|
| 208 |
+
HQ = {{size("Q", 1)}}
|
| 209 |
+
Q_LEN = {{size("Q", 2)}}
|
| 210 |
+
KV_LEN = {{size("K", 2)}}
|
| 211 |
+
|
| 212 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 213 |
+
|
| 214 |
+
q_start = tl.program_id(0)
|
| 215 |
+
off_z = tl.program_id(1) // HQ
|
| 216 |
+
off_hq = tl.program_id(1) % HQ
|
| 217 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 218 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
| 219 |
+
|
| 220 |
+
q_offset = off_z * stride_qz + off_hq * stride_qh
|
| 221 |
+
k_offset = off_z * stride_kz + off_hkv * stride_kh
|
| 222 |
+
v_offset = off_z * stride_vz + off_hkv * stride_vh
|
| 223 |
+
|
| 224 |
+
Q = Q + q_offset
|
| 225 |
+
K = K + k_offset
|
| 226 |
+
V = V + v_offset
|
| 227 |
+
|
| 228 |
+
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
| 229 |
+
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
| 230 |
+
|
| 231 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 232 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 233 |
+
|
| 234 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 235 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 236 |
+
|
| 237 |
+
SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
|
| 238 |
+
SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 239 |
+
|
| 240 |
+
# initialize pointer to m and l
|
| 241 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 242 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 243 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
|
| 244 |
+
|
| 245 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 246 |
+
|
| 247 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 248 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 249 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
|
| 250 |
+
sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
|
| 251 |
+
|
| 252 |
+
Q_block_ptr = tl.make_block_ptr(
|
| 253 |
+
base=Q,
|
| 254 |
+
shape=(Q_LEN, QK_HEAD_DIM),
|
| 255 |
+
strides=(stride_qm, stride_qk),
|
| 256 |
+
offsets=(q_start * BLOCK_M, 0),
|
| 257 |
+
block_shape=(BLOCK_M, QK_HEAD_DIM),
|
| 258 |
+
order=(1, 0)
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# load q: it stays in SRAM throughout the inner loop.
|
| 262 |
+
if IS_DIVISIBLE:
|
| 263 |
+
q = tl.load(Q_block_ptr)
|
| 264 |
+
else:
|
| 265 |
+
# boundary check is not free, so we only do it when necessary.
|
| 266 |
+
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero")
|
| 267 |
+
|
| 268 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 269 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 270 |
+
# both score_mod and mask_mod to it
|
| 271 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 272 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 273 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 274 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 275 |
+
|
| 276 |
+
K_block_ptr = tl.make_block_ptr(
|
| 277 |
+
base=K,
|
| 278 |
+
shape=(QK_HEAD_DIM, KV_LEN),
|
| 279 |
+
strides=(stride_kk, stride_kn),
|
| 280 |
+
offsets=(0, kv_start),
|
| 281 |
+
block_shape=(QK_HEAD_DIM, BLOCK_N),
|
| 282 |
+
order=(0, 1)
|
| 283 |
+
)
|
| 284 |
+
V_block_ptr = tl.make_block_ptr(
|
| 285 |
+
base=V,
|
| 286 |
+
shape=(KV_LEN, V_HEAD_DIM),
|
| 287 |
+
strides=(stride_vn, stride_vk),
|
| 288 |
+
offsets=(kv_start, 0),
|
| 289 |
+
block_shape=(BLOCK_N, V_HEAD_DIM),
|
| 290 |
+
order=(1, 0)
|
| 291 |
+
)
|
| 292 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 293 |
+
|
| 294 |
+
acc, l_i, m_i = forward_inner(
|
| 295 |
+
{{gen_argdefs()}},
|
| 296 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 297 |
+
acc, l_i, m_i,
|
| 298 |
+
off_z, off_hq, offs_m[:, None], offs_n[None, :],
|
| 299 |
+
kv_indices, kv_num_blocks,
|
| 300 |
+
0, block_n_end,
|
| 301 |
+
MATMUL_PRECISION,
|
| 302 |
+
IS_FULL_BLOCKS=False,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 306 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 307 |
+
# apply mask_mod to them - only score_mod
|
| 308 |
+
if HAS_FULL_BLOCKS:
|
| 309 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 310 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 311 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 312 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 313 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 314 |
+
|
| 315 |
+
K_block_ptr = tl.make_block_ptr(
|
| 316 |
+
base=K,
|
| 317 |
+
shape=(QK_HEAD_DIM, KV_LEN),
|
| 318 |
+
strides=(stride_kk, stride_kn),
|
| 319 |
+
offsets=(0, kv_start),
|
| 320 |
+
block_shape=(QK_HEAD_DIM, BLOCK_N),
|
| 321 |
+
order=(0, 1)
|
| 322 |
+
)
|
| 323 |
+
V_block_ptr = tl.make_block_ptr(
|
| 324 |
+
base=V,
|
| 325 |
+
shape=(KV_LEN, V_HEAD_DIM),
|
| 326 |
+
strides=(stride_vn, stride_vk),
|
| 327 |
+
offsets=(kv_start, 0),
|
| 328 |
+
block_shape=(BLOCK_N, V_HEAD_DIM),
|
| 329 |
+
order=(1, 0)
|
| 330 |
+
)
|
| 331 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 332 |
+
|
| 333 |
+
acc, l_i, m_i = forward_inner(
|
| 334 |
+
{{gen_argdefs()}},
|
| 335 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 336 |
+
acc, l_i, m_i,
|
| 337 |
+
off_z, off_hq, offs_m[:, None], offs_n[None, :],
|
| 338 |
+
kv_indices, kv_num_blocks,
|
| 339 |
+
0, block_n_end,
|
| 340 |
+
MATMUL_PRECISION,
|
| 341 |
+
IS_FULL_BLOCKS=True,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# [Note] Handle fully masked out rows:
|
| 346 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 347 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 348 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 349 |
+
|
| 350 |
+
acc = acc / l_i[:, None]
|
| 351 |
+
idx_z = tl.program_id(1) // HQ
|
| 352 |
+
idx_hq = tl.program_id(1) % HQ
|
| 353 |
+
idx_m = offs_m[:, None]
|
| 354 |
+
idx_d = tl.arange(0, V_HEAD_DIM)[None, :]
|
| 355 |
+
|
| 356 |
+
mask = idx_m < Q_LEN
|
| 357 |
+
# TODO generalize and add proper mask support
|
| 358 |
+
{{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
|
| 359 |
+
|
| 360 |
+
# TODO dont want to write this if we dont require grad
|
| 361 |
+
if OUTPUT_LOGSUMEXP:
|
| 362 |
+
off_hz = tl.program_id(1)
|
| 363 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 364 |
+
lse = m_i + tl.math.log2(l_i)
|
| 365 |
+
if IS_DIVISIBLE:
|
| 366 |
+
tl.store(l_ptrs, lse)
|
| 367 |
+
else:
|
| 368 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
compute_forward_inner = r"""
|
| 373 |
+
@triton.jit
|
| 374 |
+
def forward_inner(
|
| 375 |
+
{{gen_argdefs()}},
|
| 376 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 377 |
+
# accumulated values
|
| 378 |
+
acc, l_i, m_i,
|
| 379 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 380 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 381 |
+
off_z, off_h, offs_m, offs_n,
|
| 382 |
+
# blocksparse data
|
| 383 |
+
kv_indices, kv_num_blocks,
|
| 384 |
+
# start kv and end kv block
|
| 385 |
+
block_n_start, block_n_end,
|
| 386 |
+
MATMUL_PRECISION,
|
| 387 |
+
IS_FULL_BLOCKS,
|
| 388 |
+
):
|
| 389 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 390 |
+
{{gen_defines() | indent_except_first(1)}}
|
| 391 |
+
|
| 392 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 393 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 394 |
+
|
| 395 |
+
if PRESCALE_QK:
|
| 396 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 397 |
+
|
| 398 |
+
# loop over k, v and update accumulator until block_n_end
|
| 399 |
+
for start_n in range(block_n_start, block_n_end):
|
| 400 |
+
if IS_DIVISIBLE:
|
| 401 |
+
acc, l_i, m_i = forward_block_mn(
|
| 402 |
+
{{gen_argdefs()}},
|
| 403 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 404 |
+
# accumulated values
|
| 405 |
+
acc, l_i, m_i,
|
| 406 |
+
# Offsets
|
| 407 |
+
off_z, off_h, offs_m, offs_n,
|
| 408 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 409 |
+
IS_FULL_BLOCKS,
|
| 410 |
+
)
|
| 411 |
+
else:
|
| 412 |
+
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
|
| 413 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 414 |
+
# However, we choose different strategy for bwd, where we only apply mod & mask
|
| 415 |
+
# to the last block because it's faster a lot.
|
| 416 |
+
acc, l_i, m_i = forward_block_mn(
|
| 417 |
+
{{gen_argdefs()}},
|
| 418 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 419 |
+
# accumulated values
|
| 420 |
+
acc, l_i, m_i,
|
| 421 |
+
# Offsets
|
| 422 |
+
off_z, off_h, offs_m, offs_n,
|
| 423 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 424 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# update pointers
|
| 428 |
+
offset = get_offset_for_next_block(
|
| 429 |
+
start_n, kv_indices, kv_num_blocks,
|
| 430 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
|
| 434 |
+
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
|
| 435 |
+
|
| 436 |
+
offs_n = offs_n + offset
|
| 437 |
+
|
| 438 |
+
return acc, l_i, m_i
|
| 439 |
+
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
compute_forward_block_mn = r"""
|
| 444 |
+
@triton.jit
|
| 445 |
+
def forward_block_mn(
|
| 446 |
+
{{gen_argdefs()}},
|
| 447 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 448 |
+
# accumulated values
|
| 449 |
+
acc, l_i, m_i,
|
| 450 |
+
# Offsets
|
| 451 |
+
off_z, off_h, offs_m, offs_n,
|
| 452 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 453 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 454 |
+
):
|
| 455 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 456 |
+
{{gen_defines() | indent_except_first(1)}}
|
| 457 |
+
|
| 458 |
+
# -- load k --
|
| 459 |
+
if IS_DIVISIBLE:
|
| 460 |
+
k = tl.load(K_block_ptr)
|
| 461 |
+
else:
|
| 462 |
+
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero")
|
| 463 |
+
# -- compute qk ---
|
| 464 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 465 |
+
if not PRESCALE_QK:
|
| 466 |
+
qk *= SM_SCALE
|
| 467 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 468 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 469 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 470 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 471 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 472 |
+
m = offs_m % Q_LEN
|
| 473 |
+
n = offs_n % KV_LEN
|
| 474 |
+
else:
|
| 475 |
+
m = offs_m
|
| 476 |
+
n = offs_n
|
| 477 |
+
|
| 478 |
+
{{ modification(
|
| 479 |
+
subgraph_number=0,
|
| 480 |
+
output_name="post_mod_scores",
|
| 481 |
+
score="qk",
|
| 482 |
+
b="off_z",
|
| 483 |
+
h="off_h",
|
| 484 |
+
m="m",
|
| 485 |
+
n="n",
|
| 486 |
+
out="qk"
|
| 487 |
+
) | indent_except_first(1) }}
|
| 488 |
+
|
| 489 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 490 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 491 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 492 |
+
|
| 493 |
+
if not IS_FULL_BLOCKS:
|
| 494 |
+
{{ modification(
|
| 495 |
+
subgraph_number=1,
|
| 496 |
+
output_name="mask_mod_output",
|
| 497 |
+
score="qk",
|
| 498 |
+
b="off_z",
|
| 499 |
+
h="off_h",
|
| 500 |
+
m="m",
|
| 501 |
+
n="n",
|
| 502 |
+
) | indent_except_first(2) }}
|
| 503 |
+
|
| 504 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 505 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
|
| 506 |
+
# apply mask for partially unmasked blocks
|
| 507 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 508 |
+
|
| 509 |
+
# TODO: In the case that score_mod is linear, this can be LICMed
|
| 510 |
+
if not PRESCALE_QK:
|
| 511 |
+
post_mod_scores *= RCP_LN2
|
| 512 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 513 |
+
|
| 514 |
+
# -- compute scaling constant ---
|
| 515 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 516 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 517 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 518 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 519 |
+
else:
|
| 520 |
+
m_ij_masked = m_ij
|
| 521 |
+
|
| 522 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 523 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 524 |
+
|
| 525 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 526 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 527 |
+
# m_ij
|
| 528 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 529 |
+
# # -- scale and update acc --
|
| 530 |
+
acc = acc * alpha[:, None]
|
| 531 |
+
|
| 532 |
+
if IS_DIVISIBLE:
|
| 533 |
+
v = tl.load(V_block_ptr)
|
| 534 |
+
else:
|
| 535 |
+
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero")
|
| 536 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 537 |
+
|
| 538 |
+
# -- update m_i
|
| 539 |
+
m_i = m_ij
|
| 540 |
+
|
| 541 |
+
return acc, l_i, m_i
|
| 542 |
+
|
| 543 |
+
"""
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# The full forward template is the main kernel plus its @triton.jit helpers;
# concatenation order is irrelevant since each string defines a separate
# function in the generated module.
flex_attention_template = TritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=compute_flex_attention
    + compute_forward_inner
    + compute_next_offset_func
    + compute_forward_block_mn,
)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def _use_flex_decoding(query, kernel_options):
    # Decide which kernel to use, return true if use flex decoding kernel.
    # Heuristic: a short query sequence (dim -2 of `query` provably < 128)
    # is decode-shaped, unless the caller explicitly forced the main
    # flex-attention kernel via FORCE_USE_FLEX_ATTENTION.
    return (
        not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
    ) and V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 128))
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
# Hand-tuned forward-kernel defaults keyed by (dtype, head_dim), consumed by
# _get_default_config_fwd below. Each value is a 4-tuple of kernel parameters
# (presumably BLOCK_M, BLOCK_N, num_warps, num_stages — TODO confirm against
# the template invocation).
_h100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (32, 64, 4, 3),
    (torch.float32, 256): (32, 32, 4, 3),
    (torch.bfloat16, 64): (128, 128, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (64, 32, 4, 3),
    (torch.float16, 64): (128, 128, 4, 3),
    (torch.float16, 128): (128, 128, 8, 3),
    (torch.float16, 256): (64, 32, 4, 3),
}

# Same layout as _h100_default_config, tuned for A100 (sm80) devices.
_a100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (128, 32, 4, 3),
    (torch.float32, 256): (64, 16, 4, 3),
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (32, 64, 4, 3),
    (torch.float16, 64): (128, 64, 4, 3),
    (torch.float16, 128): (128, 64, 8, 3),
    (torch.float16, 256): (32, 64, 4, 3),
}
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
    """Choose the default forward-kernel config tuple for `query`.

    Looks up the per-(dtype, head_dim) table for the detected device
    generation, with a generic fallback for uncommon shapes or hardware.
    """
    dtype = query.get_dtype()
    head_dim = query.get_size()[-1]
    is_fp32 = dtype == torch.float32

    # Generic fallback used when the per-generation table has no entry.
    fallback = (64, 64, 4, 3) if is_fp32 else (128, 64, 4, 3)
    capability = torch.cuda.get_device_capability()

    if head_dim <= 256 and capability >= (9, 0):  # H100
        return _h100_default_config.get((dtype, head_dim), fallback)
    if head_dim <= 256 and capability >= (8, 0):  # A100
        return _a100_default_config.get((dtype, head_dim), fallback)
    # modest hardware or extremely large head_dim
    return (32, 16, 4, 3) if is_fp32 else (64, 32, 4, 3)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
    """Choose the default backward-kernel config tuple for `query`.

    fp32 always uses the conservative config; otherwise the choice depends on
    head_dim and the detected CUDA device generation.
    """
    head_dim = query.get_size()[-1]
    dtype = query.get_dtype()

    conservative = (16, 16, 4, 1)
    if dtype == torch.float32:
        return conservative

    capability = torch.cuda.get_device_capability()
    if head_dim <= 256 and capability >= (9, 0):  # H100
        by_head_dim = {64: (64, 64, 4, 3), 128: (64, 128, 8, 3)}
        return by_head_dim.get(head_dim, (64, 64, 4, 2))
    if capability >= (8, 0):  # A100
        by_head_dim = {64: (32, 128, 4, 3), 128: (64, 128, 8, 3)}
        return by_head_dim.get(head_dim, (64, 64, 4, 2))
    # modest hardware or extremely large head_dim
    return conservative
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def create_num_blocks_fake_generator(sparse_indices):
    """Return a factory that builds representative `*_NUM_BLKS` tensors for autotuning.

    The idea here is that we need a real tensor with real data that is
    representative for benchmarking: all-zero block counts would benchmark a
    kernel doing no work and yield bogus autotuning results. We claim
    min(16, max_block) blocks per row — long enough for prefetching to matter,
    short enough that autotuning stays fast.
    """
    # Upper bound comes from the last dim of the sparse index tensor.
    blocks_for_benchmark = min(16, sparse_indices.shape[-1])

    def create_num_blocks_fake(x) -> torch.Tensor:
        return torch.full(
            x.get_size(),
            int(blocks_for_benchmark),
            dtype=x.get_dtype(),
            device=x.get_device(),
        )

    return create_num_blocks_fake
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def create_indices_fake(x) -> torch.Tensor:
    """Build a representative `*_IDX` tensor for autotuning: 0..n-1 along the
    last dim, broadcast to x's full size."""
    last_dim = int(x.get_size()[-1])
    row = torch.arange(0, last_dim, dtype=x.get_dtype(), device=x.get_device())
    return row.expand(x.get_size()).contiguous()
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
    query,
    key,
    value,
    subgraph,
    block_mask,
    scale,
    kernel_options,
    score_mod_other_buffers,
    mask_mod_other_buffers,
):
    """Inductor lowering for the flex_attention higher-order op.

    Builds subgraph buffers for the score-mod and mask-mod graphs, constructs
    the output layout and the side-output logsumexp buffer, then autotunes the
    forward Triton template over a set of (BLOCK_M, BLOCK_N, num_warps,
    num_stages) configs.  Returns ``(attention_output, logsumexp)``.

    Fix vs. the previous revision: the per-config kernel options are applied
    to a fresh copy of ``kernel_options`` each iteration.  Previously
    ``setdefault("BLOCK_M"/"BLOCK_N", ...)`` was called on the shared dict, so
    every config after the first silently reused the first config's block
    sizes (while the divisibility check used the loop's values), making the
    extra autotuning configs bogus.
    """
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask
    placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("score", query.get_dtype()),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    # Copy so caller-supplied options are never mutated.
    kernel_options = dict(kernel_options)
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    if _use_flex_decoding(query, kernel_options):
        return create_flex_decoding_kernel(
            query,
            key,
            value,
            block_mask,
            scale,
            kernel_options,
            subgraph_buffer,
            mask_graph_buffer,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq

    # Masked loads are only needed when the sequence lengths don't divide the
    # (maximum) block sizes.
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # Reuse query strides for output layout despite different last dimension.
    # This works because only the last dim differs and we check it is contiguous.
    q_strides = query.get_stride()
    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"

    # Construct output layout with strides matching the query.
    out_size = [B, Hq, seq_len_q, v_head_dim]
    stride_order = get_stride_order(query.get_stride())
    fill_order = stride_order2fill_order(stride_order)
    out_strides = construct_strides(out_size, fill_order)

    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        [B, Hq, seq_len_q, v_head_dim],
        stride=out_strides,
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = [B, Hq, seq_len_q]
    logsumexp = empty_strided(
        logsumexp_shape,
        None,
        dtype=torch.float32,  # The logsumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA broadcast factor.
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is None if partial blocks are not computed
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Placeholders so the template always receives tensors for these slots.
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )
    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_fwd(query))
    if config.max_autotune:
        configs += [
            (128, 64, 4, 3),
            (128, 128, 4, 3),
            (128, 128, 8, 2),
            (64, 128, 4, 3),
            (64, 64, 4, 3),
        ]

    # These are loop-invariant, so they can go on the shared dict.
    kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
    kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.

    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
            continue
        # Work around https://github.com/pytorch/pytorch/issues/129625
        if num_stages == 2:
            continue

        # Performance tuning: each config must see its *own* block sizes, so
        # setdefault on a per-iteration copy (setdefault on the shared dict
        # would pin every choice to the first config's BLOCK_M/BLOCK_N).
        cur_kernel_options = kernel_options.copy()
        cur_kernel_options.setdefault("BLOCK_M", BLOCK_M)
        cur_kernel_options.setdefault("BLOCK_N", BLOCK_N)

        flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout,
            subgraphs=[
                subgraph_buffer,
                mask_graph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
            ],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            **cur_kernel_options,
        )
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    input_gen_fns = {
        4: create_num_blocks_fake_generator(kv_indices),
        5: create_indices_fake,
        6: create_num_blocks_fake_generator(full_kv_indices),
        7: create_indices_fake,
    }
    return (
        autotune_select_algorithm(
            "flex_attention",
            choices,
            inputs_for_autotuning,
            layout,
            input_gen_fns=input_gen_fns,
        ),
        logsumexp,
    )
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
# ---------------------------- Backward HOP Implementation ----------------------------
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def flex_attention_backward_grid(
    batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
):
    """Compute the launch grid for the flex-attention backward kernel.

    How is this kernel parallelized?  Currently only over batch * kv_heads
    (axis 2); axis 0 packs the dK/dV program blocks together with the dQ
    program blocks (one set per shared query head).  We can, and want to,
    also parallelize over ceil_div(q_heads // kv_heads * num_key_value,
    key_value_block_size) — doing so would require either atomic updates to
    some grad values or a two-pass kernel design.
    """
    import triton

    dq_programs = triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
    dkdv_programs = triton.cdiv(num_key_value, meta["BLOCK_N1"])
    return (dq_programs + dkdv_programs, 1, batch_size * kv_heads)
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
flex_attention_backward_template = TritonTemplate(
|
| 928 |
+
name="flex_attention_backward",
|
| 929 |
+
grid=flex_attention_backward_grid,
|
| 930 |
+
source=r"""
|
| 931 |
+
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
|
| 932 |
+
# Sub notation for this kernel:
|
| 933 |
+
#
|
| 934 |
+
# Q: Query, K: Key, V: Value
|
| 935 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 936 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 937 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 938 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 939 |
+
# inductor codegen
|
| 940 |
+
# M: Number of queries, N: Number of keys/values
|
| 941 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 942 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 943 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 944 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 945 |
+
# (Modifiable) Performance tuning options
|
| 946 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 947 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 948 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 949 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 950 |
+
#
|
| 951 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 952 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 953 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 954 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 955 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 956 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 957 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 958 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 959 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 960 |
+
|
| 961 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 962 |
+
# or involve a numerics vs. perf tradeoff
|
| 963 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 964 |
+
# about 20% more numerical error, but slightly faster.
|
| 965 |
+
|
| 966 |
+
# Define strides of inputs
|
| 967 |
+
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
|
| 968 |
+
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
|
| 969 |
+
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
|
| 970 |
+
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
|
| 971 |
+
|
| 972 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
|
| 973 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
|
| 974 |
+
|
| 975 |
+
Z = {{size("Q", 0)}}
|
| 976 |
+
HQ = {{size("Q", 1)}}
|
| 977 |
+
HKV = {{size("K", 1)}}
|
| 978 |
+
Q_LEN = {{size("Q", 2)}}
|
| 979 |
+
KV_LEN = {{size("K", 2)}}
|
| 980 |
+
|
| 981 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 982 |
+
|
| 983 |
+
pid = tl.program_id(0)
|
| 984 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 985 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 986 |
+
|
| 987 |
+
off_hz = tl.program_id(2)
|
| 988 |
+
off_z = off_hz // HKV # batch idx
|
| 989 |
+
off_hkv = off_hz % HKV # kv head idx
|
| 990 |
+
|
| 991 |
+
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
| 992 |
+
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
| 993 |
+
|
| 994 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 995 |
+
|
| 996 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
|
| 997 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
|
| 998 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
|
| 999 |
+
|
| 1000 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 1001 |
+
K += k_adj
|
| 1002 |
+
V += v_adj
|
| 1003 |
+
DV += dv_adj
|
| 1004 |
+
|
| 1005 |
+
RCP_LN2 = 1.44269504
|
| 1006 |
+
offs_k = tl.arange(0, QK_HEAD_DIM)
|
| 1007 |
+
offs_v = tl.arange(0, V_HEAD_DIM)
|
| 1008 |
+
|
| 1009 |
+
if pid >= NUM_KV_BLOCKS:
|
| 1010 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 1011 |
+
# THIS BLOCK DOES DQ
|
| 1012 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 1013 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 1014 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 1015 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 1016 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 1017 |
+
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
|
| 1018 |
+
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
|
| 1019 |
+
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
|
| 1020 |
+
|
| 1021 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 1022 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 1023 |
+
|
| 1024 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 1025 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 1026 |
+
|
| 1027 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
|
| 1028 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
|
| 1029 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
|
| 1030 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
|
| 1031 |
+
off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 1032 |
+
|
| 1033 |
+
Q2 = Q + q_adj2
|
| 1034 |
+
DO2 = DO + do_adj2
|
| 1035 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 1036 |
+
# if Q is broadcasted)
|
| 1037 |
+
DQ2 = DQ + dq_adj2
|
| 1038 |
+
LSE2 = LSE + off_chz2
|
| 1039 |
+
DELTA2 = DELTA + off_chz2
|
| 1040 |
+
|
| 1041 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 1042 |
+
|
| 1043 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 1044 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 1045 |
+
|
| 1046 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 1047 |
+
if IS_DIVISIBLE:
|
| 1048 |
+
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
|
| 1049 |
+
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod)
|
| 1050 |
+
else:
|
| 1051 |
+
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=offs_m2[:, None] < Q_LEN)
|
| 1052 |
+
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod, mask=offs_m2[:, None] < Q_LEN)
|
| 1053 |
+
|
| 1054 |
+
if PRESCALE_QK:
|
| 1055 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 1056 |
+
|
| 1057 |
+
if IS_DIVISIBLE:
|
| 1058 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 1059 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 1060 |
+
else:
|
| 1061 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 1062 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 1063 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 1064 |
+
lse = lse[:, None]
|
| 1065 |
+
|
| 1066 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1067 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 1068 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 1069 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 1070 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 1071 |
+
|
| 1072 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 1073 |
+
dq = bwd_dq_inner(
|
| 1074 |
+
{{gen_argdefs()}},
|
| 1075 |
+
K, V,
|
| 1076 |
+
dq, q, do, Di, lse,
|
| 1077 |
+
off_z, off_hq2, offs_m2, offs_n2,
|
| 1078 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1079 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1080 |
+
MATMUL_PRECISION,
|
| 1081 |
+
IS_FULL_BLOCKS=False,
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
if HAS_FULL_BLOCKS:
|
| 1085 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1086 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 1087 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 1088 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 1089 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 1090 |
+
|
| 1091 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 1092 |
+
dq = bwd_dq_inner(
|
| 1093 |
+
{{gen_argdefs()}},
|
| 1094 |
+
K, V,
|
| 1095 |
+
dq, q, do, Di, lse,
|
| 1096 |
+
off_z, off_hq2, offs_m2, offs_n2,
|
| 1097 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1098 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1099 |
+
MATMUL_PRECISION,
|
| 1100 |
+
IS_FULL_BLOCKS=True,
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
# Write back dQ.
|
| 1104 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 1105 |
+
dq *= SM_SCALE
|
| 1106 |
+
if IS_DIVISIBLE:
|
| 1107 |
+
tl.store(dq_ptrs, dq)
|
| 1108 |
+
else:
|
| 1109 |
+
tl.store(dq_ptrs, dq, mask=offs_m2[:, None] < Q_LEN)
|
| 1110 |
+
else:
|
| 1111 |
+
# THIS BLOCK DOES DK & DV
|
| 1112 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 1113 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 1114 |
+
|
| 1115 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 1116 |
+
|
| 1117 |
+
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
|
| 1118 |
+
stride_q_idx_h = {{stride("Q_IDX", 1)}}
|
| 1119 |
+
stride_q_idx_n = {{stride("Q_IDX", 2)}}
|
| 1120 |
+
|
| 1121 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM], dtype=tl.float32)
|
| 1122 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM], dtype=tl.float32)
|
| 1123 |
+
|
| 1124 |
+
start_n1 = pid * BLOCK_N1
|
| 1125 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 1126 |
+
|
| 1127 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 1128 |
+
if IS_DIVISIBLE:
|
| 1129 |
+
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
|
| 1130 |
+
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd)
|
| 1131 |
+
else:
|
| 1132 |
+
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=offs_n1[:, None] < KV_LEN)
|
| 1133 |
+
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd, mask=offs_n1[:, None] < KV_LEN)
|
| 1134 |
+
if PRESCALE_QK:
|
| 1135 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 1136 |
+
|
| 1137 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 1138 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 1139 |
+
|
| 1140 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
|
| 1141 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
|
| 1142 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
|
| 1143 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
|
| 1144 |
+
off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 1145 |
+
|
| 1146 |
+
Q1 = Q + q_adj1
|
| 1147 |
+
DO1 = DO + do_adj1
|
| 1148 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 1149 |
+
# if Q is broadcasted)
|
| 1150 |
+
LSE1 = LSE + off_chz1
|
| 1151 |
+
DELTA1 = DELTA + off_chz1
|
| 1152 |
+
|
| 1153 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 1154 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 1155 |
+
|
| 1156 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 1157 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 1158 |
+
|
| 1159 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1160 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 1161 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 1162 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 1163 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 1164 |
+
|
| 1165 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 1166 |
+
dk, dv = bwd_dkdv_inner(
|
| 1167 |
+
{{gen_argdefs()}},
|
| 1168 |
+
Q1, DO1, DELTA1, LSE1,
|
| 1169 |
+
dk, dv, k, v,
|
| 1170 |
+
off_z, off_hq1, offs_n1, offs_m1,
|
| 1171 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1172 |
+
q_indices, sparse_q_num_blocks,
|
| 1173 |
+
MATMUL_PRECISION,
|
| 1174 |
+
IS_FULL_BLOCKS=False,
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
if HAS_FULL_BLOCKS:
|
| 1179 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1180 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 1181 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 1182 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 1183 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 1184 |
+
|
| 1185 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 1186 |
+
dk, dv = bwd_dkdv_inner(
|
| 1187 |
+
{{gen_argdefs()}},
|
| 1188 |
+
Q1, DO1, DELTA1, LSE1,
|
| 1189 |
+
dk, dv, k, v,
|
| 1190 |
+
off_z, off_hq1, offs_n1, offs_m1,
|
| 1191 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1192 |
+
q_indices, sparse_q_num_blocks,
|
| 1193 |
+
MATMUL_PRECISION,
|
| 1194 |
+
IS_FULL_BLOCKS=True,
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
# Write back dV and dK.
|
| 1198 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 1199 |
+
|
| 1200 |
+
index_n = offs_n1[:, None]
|
| 1201 |
+
index_k = offs_k[None, :]
|
| 1202 |
+
|
| 1203 |
+
if IS_DIVISIBLE:
|
| 1204 |
+
tl.store(dv_ptrs, dv)
|
| 1205 |
+
else:
|
| 1206 |
+
tl.store(dv_ptrs, dv, mask=index_n < KV_LEN)
|
| 1207 |
+
|
| 1208 |
+
dk *= SM_SCALE
|
| 1209 |
+
mask = index_n < KV_LEN
|
| 1210 |
+
{{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
|
| 1211 |
+
|
| 1212 |
+
@triton.jit
|
| 1213 |
+
def bwd_dq_inner(
|
| 1214 |
+
{{gen_argdefs()}},
|
| 1215 |
+
K, V, # pointers
|
| 1216 |
+
dq, q, do, Di, lse,
|
| 1217 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1218 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1219 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1220 |
+
MATMUL_PRECISION,
|
| 1221 |
+
IS_FULL_BLOCKS,
|
| 1222 |
+
):
|
| 1223 |
+
{{gen_defines() | indent_except_first(1) }}
|
| 1224 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 1225 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 1226 |
+
Q_LEN = {{size("Q", 2)}}
|
| 1227 |
+
KV_LEN = {{size("K", 2)}}
|
| 1228 |
+
|
| 1229 |
+
offs_k = tl.arange(0, QK_HEAD_DIM)
|
| 1230 |
+
offs_v = tl.arange(0, V_HEAD_DIM)
|
| 1231 |
+
|
| 1232 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 1233 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 1234 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 1235 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 1236 |
+
|
| 1237 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 1238 |
+
if not IS_DIVISIBLE:
|
| 1239 |
+
if hi >= 1:
|
| 1240 |
+
for start_n in range(0, hi - 1):
|
| 1241 |
+
dq = bwd_dq_block_mn(
|
| 1242 |
+
{{gen_argdefs()}},
|
| 1243 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1244 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1245 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1246 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1247 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1248 |
+
IS_FULL_BLOCKS,
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
# Increment pointers.
|
| 1252 |
+
offset = get_offset_for_next_block(
|
| 1253 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 1254 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
kT_ptrs += offset * stride_kn
|
| 1258 |
+
vT_ptrs += offset * stride_vn
|
| 1259 |
+
|
| 1260 |
+
offs_n2 += offset
|
| 1261 |
+
|
| 1262 |
+
dq = bwd_dq_block_mn(
|
| 1263 |
+
{{gen_argdefs()}},
|
| 1264 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1265 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1266 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1267 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1268 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1269 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 1270 |
+
)
|
| 1271 |
+
else:
|
| 1272 |
+
for start_n in range(0, hi):
|
| 1273 |
+
dq = bwd_dq_block_mn(
|
| 1274 |
+
{{gen_argdefs()}},
|
| 1275 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1276 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1277 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1278 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1279 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1280 |
+
IS_FULL_BLOCKS,
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
+
# Increment pointers.
|
| 1284 |
+
offset = get_offset_for_next_block(
|
| 1285 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 1286 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
kT_ptrs += offset * stride_kn
|
| 1290 |
+
vT_ptrs += offset * stride_vn
|
| 1291 |
+
|
| 1292 |
+
offs_n2 += offset
|
| 1293 |
+
|
| 1294 |
+
return dq
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
@triton.jit
|
| 1298 |
+
def bwd_dq_block_mn(
|
| 1299 |
+
{{gen_argdefs()}},
|
| 1300 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1301 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1302 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1303 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1304 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1305 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 1306 |
+
):
|
| 1307 |
+
{{gen_defines() | indent_except_first(1)}}
|
| 1308 |
+
|
| 1309 |
+
if IS_DIVISIBLE:
|
| 1310 |
+
kT = tl.load(kT_ptrs)
|
| 1311 |
+
else:
|
| 1312 |
+
kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN)
|
| 1313 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 1314 |
+
if not PRESCALE_QK:
|
| 1315 |
+
qk *= SM_SCALE
|
| 1316 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 1317 |
+
pre_mod_scores = qk
|
| 1318 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1319 |
+
m = offs_m2[:, None] % Q_LEN
|
| 1320 |
+
n = offs_n2[None, :] % KV_LEN
|
| 1321 |
+
else:
|
| 1322 |
+
m = offs_m2[:, None]
|
| 1323 |
+
n = offs_n2[None, :]
|
| 1324 |
+
{{ modification(
|
| 1325 |
+
subgraph_number=0,
|
| 1326 |
+
output_name="post_mod_scores",
|
| 1327 |
+
score="qk",
|
| 1328 |
+
b="off_z",
|
| 1329 |
+
h="off_hq",
|
| 1330 |
+
m="m",
|
| 1331 |
+
n="n",
|
| 1332 |
+
out="qk"
|
| 1333 |
+
) | indent_except_first(1) }}
|
| 1334 |
+
|
| 1335 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1336 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 1337 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 1338 |
+
|
| 1339 |
+
if not IS_FULL_BLOCKS:
|
| 1340 |
+
{{ modification(
|
| 1341 |
+
subgraph_number=2,
|
| 1342 |
+
output_name="mask_mod_output",
|
| 1343 |
+
score="qk",
|
| 1344 |
+
b="off_z",
|
| 1345 |
+
h="off_hq",
|
| 1346 |
+
m="m",
|
| 1347 |
+
n="n",
|
| 1348 |
+
) | indent_except_first(2) }}
|
| 1349 |
+
|
| 1350 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1351 |
+
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
|
| 1352 |
+
# apply mask for partial masked block
|
| 1353 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 1354 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1355 |
+
if not PRESCALE_QK:
|
| 1356 |
+
post_mod_scores *= RCP_LN2
|
| 1357 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 1358 |
+
# Compute dP and dS.
|
| 1359 |
+
if IS_DIVISIBLE:
|
| 1360 |
+
vT = tl.load(vT_ptrs)
|
| 1361 |
+
else:
|
| 1362 |
+
vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN)
|
| 1363 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 1364 |
+
ds = p * (dp - Di[:, None])
|
| 1365 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 1366 |
+
{{ modification(
|
| 1367 |
+
subgraph_number=1,
|
| 1368 |
+
output_name = "grad_scores",
|
| 1369 |
+
score="pre_mod_scores",
|
| 1370 |
+
b="off_z",
|
| 1371 |
+
h="off_hq",
|
| 1372 |
+
m="m",
|
| 1373 |
+
n="n",
|
| 1374 |
+
grad_score_mod="ds"
|
| 1375 |
+
) | indent_except_first(1) }}
|
| 1376 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1377 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 1378 |
+
|
| 1379 |
+
ds = grad_scores
|
| 1380 |
+
|
| 1381 |
+
if not IS_FULL_BLOCKS:
|
| 1382 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1383 |
+
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
|
| 1384 |
+
# (grads) apply mask for partially unmasked block
|
| 1385 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 1386 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1387 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 1388 |
+
# Compute dQ.
|
| 1389 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 1390 |
+
|
| 1391 |
+
return dq
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
@triton.jit
|
| 1395 |
+
def bwd_dkdv_inner(
|
| 1396 |
+
{{gen_argdefs()}},
|
| 1397 |
+
Q, DO, DELTA, LSE, # pointers
|
| 1398 |
+
dk, dv, k, v,
|
| 1399 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 1400 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1401 |
+
q_indices, sparse_q_num_blocks,
|
| 1402 |
+
MATMUL_PRECISION,
|
| 1403 |
+
IS_FULL_BLOCKS,
|
| 1404 |
+
):
|
| 1405 |
+
{{gen_defines() | indent_except_first(1) }}
|
| 1406 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 1407 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 1408 |
+
Q_LEN = {{size("Q", 2)}}
|
| 1409 |
+
KV_LEN = {{size("K", 2)}}
|
| 1410 |
+
|
| 1411 |
+
offs_k = tl.arange(0, QK_HEAD_DIM)
|
| 1412 |
+
offs_v = tl.arange(0, V_HEAD_DIM)
|
| 1413 |
+
|
| 1414 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 1415 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 1416 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 1417 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 1418 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 1419 |
+
|
| 1420 |
+
if not IS_DIVISIBLE:
|
| 1421 |
+
if hi >= 1:
|
| 1422 |
+
for start_m in range(0, hi - 1):
|
| 1423 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 1424 |
+
{{gen_argdefs()}},
|
| 1425 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 1426 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 1427 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1428 |
+
q_indices, sparse_q_num_blocks,
|
| 1429 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1430 |
+
IS_FULL_BLOCKS,
|
| 1431 |
+
)
|
| 1432 |
+
# Increment pointers.
|
| 1433 |
+
offset = get_offset_for_next_block(
|
| 1434 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 1435 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
qT_ptrs += offset * stride_qm
|
| 1439 |
+
do_ptrs += offset * stride_dom
|
| 1440 |
+
|
| 1441 |
+
offs_m1 += offset
|
| 1442 |
+
|
| 1443 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 1444 |
+
{{gen_argdefs()}},
|
| 1445 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 1446 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 1447 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1448 |
+
q_indices, sparse_q_num_blocks,
|
| 1449 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1450 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 1451 |
+
)
|
| 1452 |
+
else:
|
| 1453 |
+
for start_m in range(0, hi):
|
| 1454 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 1455 |
+
{{gen_argdefs()}},
|
| 1456 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 1457 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 1458 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1459 |
+
q_indices, sparse_q_num_blocks,
|
| 1460 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1461 |
+
IS_FULL_BLOCKS,
|
| 1462 |
+
)
|
| 1463 |
+
# Increment pointers.
|
| 1464 |
+
offset = get_offset_for_next_block(
|
| 1465 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 1466 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
|
| 1467 |
+
)
|
| 1468 |
+
|
| 1469 |
+
qT_ptrs += offset * stride_qm
|
| 1470 |
+
do_ptrs += offset * stride_dom
|
| 1471 |
+
|
| 1472 |
+
offs_m1 += offset
|
| 1473 |
+
|
| 1474 |
+
return dk, dv
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
@triton.jit
|
| 1478 |
+
def bwd_dkdv_block_mn(
|
| 1479 |
+
{{gen_argdefs()}},
|
| 1480 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 1481 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 1482 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1483 |
+
q_indices, sparse_q_num_blocks,
|
| 1484 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1485 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 1486 |
+
):
|
| 1487 |
+
{{gen_defines() | indent_except_first(1) }}
|
| 1488 |
+
|
| 1489 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 1490 |
+
if IS_DIVISIBLE:
|
| 1491 |
+
qT = tl.load(qT_ptrs)
|
| 1492 |
+
lse = tl.load(LSE + offs_m1)
|
| 1493 |
+
else:
|
| 1494 |
+
qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN)
|
| 1495 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 1496 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 1497 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 1498 |
+
if not PRESCALE_QK:
|
| 1499 |
+
qkT *= SM_SCALE
|
| 1500 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 1501 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1502 |
+
m = offs_m1[None, :] % Q_LEN
|
| 1503 |
+
n = offs_n1[:, None] % KV_LEN
|
| 1504 |
+
else:
|
| 1505 |
+
m = offs_m1[None, :]
|
| 1506 |
+
n = offs_n1[:, None]
|
| 1507 |
+
pre_mod_scores = qkT
|
| 1508 |
+
{{ modification(
|
| 1509 |
+
subgraph_number=0,
|
| 1510 |
+
output_name="post_mod_scores",
|
| 1511 |
+
score="qkT",
|
| 1512 |
+
b="off_z",
|
| 1513 |
+
h="off_hq",
|
| 1514 |
+
m="m",
|
| 1515 |
+
n="n",
|
| 1516 |
+
out="qkT"
|
| 1517 |
+
) | indent_except_first(1) }}
|
| 1518 |
+
|
| 1519 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1520 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 1521 |
+
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
|
| 1522 |
+
|
| 1523 |
+
if not IS_FULL_BLOCKS:
|
| 1524 |
+
{{ modification(
|
| 1525 |
+
subgraph_number=2,
|
| 1526 |
+
output_name="mask_mod_output",
|
| 1527 |
+
score="qkT",
|
| 1528 |
+
b="off_z",
|
| 1529 |
+
h="off_hq",
|
| 1530 |
+
m="m",
|
| 1531 |
+
n="n",
|
| 1532 |
+
) | indent_except_first(2) }}
|
| 1533 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1534 |
+
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
|
| 1535 |
+
# (grads) apply mask for fully masked block
|
| 1536 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 1537 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1538 |
+
if not PRESCALE_QK:
|
| 1539 |
+
post_mod_scores *= RCP_LN2
|
| 1540 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 1541 |
+
if IS_DIVISIBLE:
|
| 1542 |
+
do = tl.load(do_ptrs)
|
| 1543 |
+
else:
|
| 1544 |
+
do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN)
|
| 1545 |
+
# Compute dV.
|
| 1546 |
+
ppT = pT
|
| 1547 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 1548 |
+
if IS_DIVISIBLE:
|
| 1549 |
+
Di = tl.load(DELTA + offs_m1)
|
| 1550 |
+
else:
|
| 1551 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 1552 |
+
# Compute dP and dS.
|
| 1553 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 1554 |
+
dsT = pT * (dpT - Di[None, :])
|
| 1555 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 1556 |
+
{{ modification(
|
| 1557 |
+
subgraph_number=1,
|
| 1558 |
+
output_name = "grad_scores",
|
| 1559 |
+
score="pre_mod_scores",
|
| 1560 |
+
b="off_z",
|
| 1561 |
+
h="off_hq",
|
| 1562 |
+
m="m",
|
| 1563 |
+
n="n",
|
| 1564 |
+
grad_score_mod="dsT"
|
| 1565 |
+
) | indent_except_first(1) }}
|
| 1566 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1567 |
+
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
|
| 1568 |
+
|
| 1569 |
+
dsT = grad_scores
|
| 1570 |
+
if not IS_FULL_BLOCKS:
|
| 1571 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 1572 |
+
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
|
| 1573 |
+
# (grads) apply mask for partially unmasked block
|
| 1574 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 1575 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1576 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 1577 |
+
|
| 1578 |
+
return dk, dv
|
| 1579 |
+
"""
|
| 1580 |
+
+ compute_next_offset_func,
|
| 1581 |
+
)
|
| 1582 |
+
|
| 1583 |
+
|
| 1584 |
+
# TODO: We probably also need a layout constraint?
|
| 1585 |
+
@register_lowering(
|
| 1586 |
+
torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
|
| 1587 |
+
)
|
| 1588 |
+
def flex_attention_backward(*args, **kwargs):
|
| 1589 |
+
(
|
| 1590 |
+
query,
|
| 1591 |
+
key,
|
| 1592 |
+
value,
|
| 1593 |
+
out,
|
| 1594 |
+
logsumexp,
|
| 1595 |
+
grad_out,
|
| 1596 |
+
grad_logsumexp,
|
| 1597 |
+
fw_graph,
|
| 1598 |
+
joint_graph,
|
| 1599 |
+
block_mask,
|
| 1600 |
+
scale,
|
| 1601 |
+
kernel_options,
|
| 1602 |
+
score_mod_other_buffers,
|
| 1603 |
+
mask_mod_other_buffers,
|
| 1604 |
+
) = args
|
| 1605 |
+
(
|
| 1606 |
+
kv_num_blocks,
|
| 1607 |
+
kv_indices,
|
| 1608 |
+
full_kv_num_blocks,
|
| 1609 |
+
full_kv_indices,
|
| 1610 |
+
q_num_blocks,
|
| 1611 |
+
q_indices,
|
| 1612 |
+
full_q_num_blocks,
|
| 1613 |
+
full_q_indices,
|
| 1614 |
+
SPARSE_KV_BLOCK_SIZE,
|
| 1615 |
+
SPARSE_Q_BLOCK_SIZE,
|
| 1616 |
+
mask_graph,
|
| 1617 |
+
) = block_mask
|
| 1618 |
+
|
| 1619 |
+
(
|
| 1620 |
+
query,
|
| 1621 |
+
key,
|
| 1622 |
+
value,
|
| 1623 |
+
grad_out,
|
| 1624 |
+
kv_num_blocks,
|
| 1625 |
+
kv_indices,
|
| 1626 |
+
full_kv_num_blocks,
|
| 1627 |
+
full_kv_indices,
|
| 1628 |
+
q_num_blocks,
|
| 1629 |
+
q_indices,
|
| 1630 |
+
full_q_num_blocks,
|
| 1631 |
+
full_q_indices,
|
| 1632 |
+
) = maybe_realize(
|
| 1633 |
+
[
|
| 1634 |
+
query,
|
| 1635 |
+
key,
|
| 1636 |
+
value,
|
| 1637 |
+
grad_out,
|
| 1638 |
+
kv_num_blocks,
|
| 1639 |
+
kv_indices,
|
| 1640 |
+
full_kv_num_blocks,
|
| 1641 |
+
full_kv_indices,
|
| 1642 |
+
q_num_blocks,
|
| 1643 |
+
q_indices,
|
| 1644 |
+
full_q_num_blocks,
|
| 1645 |
+
full_q_indices,
|
| 1646 |
+
]
|
| 1647 |
+
)
|
| 1648 |
+
|
| 1649 |
+
device = query.get_device()
|
| 1650 |
+
dtype = query.get_dtype()
|
| 1651 |
+
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
|
| 1652 |
+
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
|
| 1653 |
+
assert Bq == Bkv, "Batch dimension must match"
|
| 1654 |
+
B = Bq
|
| 1655 |
+
|
| 1656 |
+
kernel_options = dict(kernel_options)
|
| 1657 |
+
kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
|
| 1658 |
+
if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
|
| 1659 |
+
kernel_options.setdefault("IS_DIVISIBLE", False)
|
| 1660 |
+
else:
|
| 1661 |
+
kernel_options.setdefault("IS_DIVISIBLE", True)
|
| 1662 |
+
|
| 1663 |
+
fwd_placeholder_inps = [
|
| 1664 |
+
create_placeholder(name, dtype, device)
|
| 1665 |
+
for name, dtype in [
|
| 1666 |
+
("score", dtype),
|
| 1667 |
+
("b", torch.int32),
|
| 1668 |
+
("h", torch.int32),
|
| 1669 |
+
("m", torch.int32),
|
| 1670 |
+
("n", torch.int32),
|
| 1671 |
+
]
|
| 1672 |
+
]
|
| 1673 |
+
fw_subgraph_buffer = build_subgraph_buffer(
|
| 1674 |
+
fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
|
| 1675 |
+
)
|
| 1676 |
+
|
| 1677 |
+
joint_placeholder_inps = fwd_placeholder_inps + [
|
| 1678 |
+
create_placeholder("grad_score_mod", dtype, device)
|
| 1679 |
+
]
|
| 1680 |
+
joint_subgraph_buffer, *_ = build_subgraph_buffer(
|
| 1681 |
+
joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
|
| 1682 |
+
)
|
| 1683 |
+
|
| 1684 |
+
mask_graph_placeholder_inps = [
|
| 1685 |
+
create_placeholder(name, dtype, query.get_device())
|
| 1686 |
+
for name, dtype in [
|
| 1687 |
+
("b", torch.int32),
|
| 1688 |
+
("h", torch.int32),
|
| 1689 |
+
("m", torch.int32),
|
| 1690 |
+
("n", torch.int32),
|
| 1691 |
+
]
|
| 1692 |
+
]
|
| 1693 |
+
mask_graph_buffer = build_subgraph_buffer(
|
| 1694 |
+
mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
|
| 1695 |
+
)
|
| 1696 |
+
|
| 1697 |
+
layout_k = FixedLayout(
|
| 1698 |
+
key.get_device(),
|
| 1699 |
+
key.get_dtype(),
|
| 1700 |
+
key.get_size(),
|
| 1701 |
+
key.get_stride(),
|
| 1702 |
+
)
|
| 1703 |
+
|
| 1704 |
+
# Create delta which will is needed for the bwd's kernel
|
| 1705 |
+
grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
|
| 1706 |
+
mul_delta = lowerings[aten.mul](out, grad_out)
|
| 1707 |
+
delta = lowerings[aten.sum](mul_delta, axis=-1)
|
| 1708 |
+
delta = lowerings[aten.sub](delta, grad_lse_exp2)
|
| 1709 |
+
delta = ExternKernel.require_contiguous(delta)
|
| 1710 |
+
|
| 1711 |
+
grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])
|
| 1712 |
+
|
| 1713 |
+
# see NOTE:[TritonTemplates with multiple outputs]
|
| 1714 |
+
grad_query = empty_strided(
|
| 1715 |
+
query.get_size(), query.get_stride(), dtype=dtype, device=device
|
| 1716 |
+
)
|
| 1717 |
+
grad_value = empty_strided(
|
| 1718 |
+
value.get_size(), value.get_stride(), dtype=dtype, device=device
|
| 1719 |
+
)
|
| 1720 |
+
|
| 1721 |
+
kernel_options.setdefault("SM_SCALE", scale)
|
| 1722 |
+
|
| 1723 |
+
# Determine GQA factor
|
| 1724 |
+
gqa_shared_heads = Hq // Hkv
|
| 1725 |
+
kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
|
| 1726 |
+
|
| 1727 |
+
# Inside of Triton kernel, only apply partial masking if partial blocks are computed.
|
| 1728 |
+
# full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed.
|
| 1729 |
+
has_full_blocks = full_kv_num_blocks is not None
|
| 1730 |
+
kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
|
| 1731 |
+
if not has_full_blocks:
|
| 1732 |
+
full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = (
|
| 1733 |
+
empty(0, device=query.get_device()) for _ in range(4)
|
| 1734 |
+
)
|
| 1735 |
+
kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
|
| 1736 |
+
kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
|
| 1737 |
+
|
| 1738 |
+
choices: List[Any] = []
|
| 1739 |
+
configs: List[Tuple[int, int, int, int]] = []
|
| 1740 |
+
configs.append(_get_default_config_bwd(query))
|
| 1741 |
+
if config.max_autotune:
|
| 1742 |
+
configs.extend(
|
| 1743 |
+
[
|
| 1744 |
+
(BLOCK1, BLOCK2, w, s)
|
| 1745 |
+
for BLOCK1 in [32, 64]
|
| 1746 |
+
for BLOCK2 in [32, 64, 128]
|
| 1747 |
+
for w in [4, 8]
|
| 1748 |
+
for s in [1, 3, 4, 5]
|
| 1749 |
+
if BLOCK2 % BLOCK1 == 0
|
| 1750 |
+
]
|
| 1751 |
+
)
|
| 1752 |
+
|
| 1753 |
+
for BLOCK1, BLOCK2, num_warps, num_stages in configs:
|
| 1754 |
+
if (
|
| 1755 |
+
SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
|
| 1756 |
+
or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
|
| 1757 |
+
or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
|
| 1758 |
+
or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
|
| 1759 |
+
):
|
| 1760 |
+
continue
|
| 1761 |
+
|
| 1762 |
+
# Performance tuning
|
| 1763 |
+
kernel_options.setdefault("BLOCK_M1", BLOCK1)
|
| 1764 |
+
kernel_options.setdefault("BLOCK_N1", BLOCK2)
|
| 1765 |
+
kernel_options.setdefault("BLOCK_M2", BLOCK2)
|
| 1766 |
+
kernel_options.setdefault("BLOCK_N2", BLOCK1)
|
| 1767 |
+
# Blocksparse options
|
| 1768 |
+
kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
|
| 1769 |
+
kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
|
| 1770 |
+
|
| 1771 |
+
flex_attention_backward_template.maybe_append_choice(
|
| 1772 |
+
choices=choices,
|
| 1773 |
+
input_nodes=[
|
| 1774 |
+
query,
|
| 1775 |
+
key,
|
| 1776 |
+
value,
|
| 1777 |
+
logsumexp,
|
| 1778 |
+
delta,
|
| 1779 |
+
grad_out,
|
| 1780 |
+
grad_query,
|
| 1781 |
+
grad_value,
|
| 1782 |
+
kv_num_blocks,
|
| 1783 |
+
kv_indices,
|
| 1784 |
+
q_num_blocks,
|
| 1785 |
+
q_indices,
|
| 1786 |
+
full_kv_num_blocks,
|
| 1787 |
+
full_kv_indices,
|
| 1788 |
+
full_q_num_blocks,
|
| 1789 |
+
full_q_indices,
|
| 1790 |
+
],
|
| 1791 |
+
layout=layout_k, # We use store_output only for grad_key
|
| 1792 |
+
subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
|
| 1793 |
+
mutated_inputs=[grad_query, grad_value],
|
| 1794 |
+
call_sizes=query.get_size() + key.get_size()[1:3],
|
| 1795 |
+
num_stages=num_stages,
|
| 1796 |
+
num_warps=num_warps,
|
| 1797 |
+
**kernel_options,
|
| 1798 |
+
)
|
| 1799 |
+
inputs_for_autotuning = (
|
| 1800 |
+
[
|
| 1801 |
+
query,
|
| 1802 |
+
key,
|
| 1803 |
+
value,
|
| 1804 |
+
logsumexp,
|
| 1805 |
+
delta,
|
| 1806 |
+
grad_out,
|
| 1807 |
+
grad_query,
|
| 1808 |
+
grad_value,
|
| 1809 |
+
kv_num_blocks,
|
| 1810 |
+
kv_indices,
|
| 1811 |
+
q_num_blocks,
|
| 1812 |
+
q_indices,
|
| 1813 |
+
full_kv_num_blocks,
|
| 1814 |
+
full_kv_indices,
|
| 1815 |
+
full_q_num_blocks,
|
| 1816 |
+
full_q_indices,
|
| 1817 |
+
]
|
| 1818 |
+
+ list(score_mod_other_buffers)
|
| 1819 |
+
+ list(mask_mod_other_buffers)
|
| 1820 |
+
)
|
| 1821 |
+
input_gen_fns = {
|
| 1822 |
+
8: create_num_blocks_fake_generator(kv_indices), # kv_num_blocks
|
| 1823 |
+
9: create_indices_fake,
|
| 1824 |
+
10: create_num_blocks_fake_generator(q_indices), # q_num_blocks
|
| 1825 |
+
11: create_indices_fake,
|
| 1826 |
+
12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks
|
| 1827 |
+
13: create_indices_fake,
|
| 1828 |
+
14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks
|
| 1829 |
+
15: create_indices_fake,
|
| 1830 |
+
}
|
| 1831 |
+
|
| 1832 |
+
grad_key = autotune_select_algorithm(
|
| 1833 |
+
"flex_attention_backward",
|
| 1834 |
+
choices,
|
| 1835 |
+
inputs_for_autotuning,
|
| 1836 |
+
layout_k,
|
| 1837 |
+
input_gen_fns=input_gen_fns,
|
| 1838 |
+
)
|
| 1839 |
+
return (
|
| 1840 |
+
grad_query,
|
| 1841 |
+
grad_key,
|
| 1842 |
+
grad_value,
|
| 1843 |
+
)
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_decoding.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
|
| 3 |
+
from typing import Any, List, Tuple
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch._inductor.virtualized import V
|
| 9 |
+
|
| 10 |
+
from .. import config, ir
|
| 11 |
+
from ..ir import FixedLayout, FlexibleLayout
|
| 12 |
+
from ..lowering import empty, empty_strided, lowerings
|
| 13 |
+
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
|
| 14 |
+
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
|
| 15 |
+
from .flex_attention import (
|
| 16 |
+
compute_forward_block_mn,
|
| 17 |
+
compute_forward_inner,
|
| 18 |
+
compute_next_offset_func,
|
| 19 |
+
create_indices_fake,
|
| 20 |
+
create_num_blocks_fake_generator,
|
| 21 |
+
maybe_realize,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
aten = torch.ops.aten
|
| 26 |
+
prims = torch.ops.prims
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
|
| 30 |
+
"""How is this kernel parallelized?
|
| 31 |
+
We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)
|
| 32 |
+
Each block is responsible for iterating over blocks of keys and values calculating
|
| 33 |
+
the local output for their tile of keys and values over all full length of query.
|
| 34 |
+
groups of SPLIT_KV blocks then combine their output to produce the final result.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
return (batch_size * kv_heads, meta["SPLIT_KV"], 1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
flex_decoding_template = TritonTemplate(
|
| 41 |
+
name="flex_decoding",
|
| 42 |
+
grid=flex_decoding_grid,
|
| 43 |
+
source=r"""
|
| 44 |
+
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
| 45 |
+
# Sub notation for this kernel:
|
| 46 |
+
# Q: Query, K: Key, V: Value
|
| 47 |
+
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
| 48 |
+
# M: Number of queries, N: Number of keys/values
|
| 49 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 50 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 51 |
+
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
|
| 52 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
|
| 53 |
+
# (Modifiable) Config options:
|
| 54 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 55 |
+
# TILE_KV: length of each local KV split
|
| 56 |
+
# BLOCK_M: block size that Q is padded along seqlen dim.
|
| 57 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# change of base out of the loop
|
| 61 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 62 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 63 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 64 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 65 |
+
|
| 66 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
| 67 |
+
#
|
| 68 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 69 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 70 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 71 |
+
#
|
| 72 |
+
#
|
| 73 |
+
# Output: ACC output accumulated across local KV split.
|
| 74 |
+
|
| 75 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 76 |
+
|
| 77 |
+
# Define Q Strides
|
| 78 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
|
| 79 |
+
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
| 80 |
+
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
| 81 |
+
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
|
| 82 |
+
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
Z = {{size("Q", 0)}}
|
| 86 |
+
HKV = {{size("Q", 1)}}
|
| 87 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 88 |
+
HQ = HKV * G
|
| 89 |
+
Q_LEN = {{size("Q", 3)}}
|
| 90 |
+
KV_LEN = {{size("K", 2)}}
|
| 91 |
+
|
| 92 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 93 |
+
|
| 94 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 95 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 96 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 97 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
| 98 |
+
|
| 99 |
+
off_z = tl.program_id(0) // HKV
|
| 100 |
+
off_hkv = tl.program_id(0) % HKV
|
| 101 |
+
off_t = tl.program_id(1)
|
| 102 |
+
|
| 103 |
+
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
| 104 |
+
k_offset = off_z * stride_kz + off_hkv * stride_kh
|
| 105 |
+
v_offset = off_z * stride_vz + off_hkv * stride_vh
|
| 106 |
+
|
| 107 |
+
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
| 108 |
+
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
| 109 |
+
|
| 110 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 111 |
+
# TODO: support masks not broadcasted along the head dimension.
|
| 112 |
+
tl.device_assert(SPARSE_HQ == 1)
|
| 113 |
+
sparse_idx_h = 0
|
| 114 |
+
|
| 115 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 116 |
+
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 117 |
+
|
| 118 |
+
# initialize pointer to m and l
|
| 119 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 120 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 121 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
|
| 122 |
+
|
| 123 |
+
# initialize offsets
|
| 124 |
+
tl.device_assert(BLOCK_M % G == 0)
|
| 125 |
+
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
| 126 |
+
off_g = tl.arange(0, G) # [G]
|
| 127 |
+
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 128 |
+
offs_hq = offs_g + off_hkv * G
|
| 129 |
+
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
| 130 |
+
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 131 |
+
offs_d = tl.arange(0, QK_HEAD_DIM)
|
| 132 |
+
offs_vd = tl.arange(0, V_HEAD_DIM)
|
| 133 |
+
|
| 134 |
+
# KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
|
| 135 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h
|
| 136 |
+
|
| 137 |
+
# Calculate KV blocks that belong this CTA.
|
| 138 |
+
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
| 139 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
| 140 |
+
|
| 141 |
+
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
| 142 |
+
|
| 143 |
+
if SAFE_M_BOUNDARY:
|
| 144 |
+
q = tl.load(Q + q_offset + q_range)
|
| 145 |
+
else:
|
| 146 |
+
mask = off_m[None, :, None] < Q_LEN
|
| 147 |
+
q = tl.load(Q + q_offset + q_range, mask)
|
| 148 |
+
|
| 149 |
+
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM])
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 153 |
+
# Apply both score_mod and mask_mod
|
| 154 |
+
|
| 155 |
+
# find first kv block we are loading and the number of blocks we are loading
|
| 156 |
+
kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
|
| 157 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
|
| 158 |
+
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
| 159 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 160 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 161 |
+
# first kv block we're loading
|
| 162 |
+
|
| 163 |
+
# last valid block according to sparse mask
|
| 164 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 165 |
+
|
| 166 |
+
K_block_ptr = tl.make_block_ptr(
|
| 167 |
+
base=K + k_offset,
|
| 168 |
+
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
|
| 169 |
+
strides=(stride_kk, stride_kn),
|
| 170 |
+
offsets=(0, off_n),
|
| 171 |
+
block_shape=(QK_HEAD_DIM, BLOCK_N),
|
| 172 |
+
order=(0, 1)
|
| 173 |
+
)
|
| 174 |
+
V_block_ptr = tl.make_block_ptr(
|
| 175 |
+
base=V + v_offset,
|
| 176 |
+
shape=(KV_LEN, V_HEAD_DIM),
|
| 177 |
+
strides=(stride_vn, stride_vk),
|
| 178 |
+
offsets=(off_n, 0),
|
| 179 |
+
block_shape=(BLOCK_N, V_HEAD_DIM),
|
| 180 |
+
order=(1, 0)
|
| 181 |
+
)
|
| 182 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 183 |
+
|
| 184 |
+
acc, l_i, m_i = forward_inner(
|
| 185 |
+
{{gen_argdefs()}},
|
| 186 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 187 |
+
# accumulatd values
|
| 188 |
+
acc, l_i, m_i,
|
| 189 |
+
#offsets
|
| 190 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 191 |
+
#block sparse data
|
| 192 |
+
kv_indices, kv_num_blocks,
|
| 193 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 194 |
+
MATMUL_PRECISION,
|
| 195 |
+
IS_FULL_BLOCKS=False,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 200 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 201 |
+
# apply mask_mod to them - only score_mod
|
| 202 |
+
if HAS_FULL_BLOCKS:
|
| 203 |
+
kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
|
| 204 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
|
| 205 |
+
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
| 206 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 207 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 208 |
+
|
| 209 |
+
# last valid block according to sparse mask
|
| 210 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 211 |
+
|
| 212 |
+
K_block_ptr = tl.make_block_ptr(
|
| 213 |
+
base=K + k_offset,
|
| 214 |
+
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
|
| 215 |
+
strides=(stride_kk, stride_kn),
|
| 216 |
+
offsets=(0, off_n),
|
| 217 |
+
block_shape=(QK_HEAD_DIM, BLOCK_N),
|
| 218 |
+
order=(0, 1)
|
| 219 |
+
)
|
| 220 |
+
V_block_ptr = tl.make_block_ptr(
|
| 221 |
+
base=V + v_offset,
|
| 222 |
+
shape=(KV_LEN, V_HEAD_DIM),
|
| 223 |
+
strides=(stride_vn, stride_vk),
|
| 224 |
+
offsets=(off_n, 0),
|
| 225 |
+
block_shape=(BLOCK_N, V_HEAD_DIM),
|
| 226 |
+
order=(1, 0)
|
| 227 |
+
)
|
| 228 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 229 |
+
|
| 230 |
+
acc, l_i, m_i = forward_inner(
|
| 231 |
+
{{gen_argdefs()}},
|
| 232 |
+
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
|
| 233 |
+
# accumulatd values
|
| 234 |
+
acc, l_i, m_i,
|
| 235 |
+
#offsets
|
| 236 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 237 |
+
#block sparse data
|
| 238 |
+
kv_indices, kv_num_blocks,
|
| 239 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 240 |
+
MATMUL_PRECISION,
|
| 241 |
+
IS_FULL_BLOCKS=True,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
m_offset = off_t * stride_mt + off_z * stride_mz
|
| 245 |
+
l_offset = off_t * stride_lt + off_z * stride_lz
|
| 246 |
+
|
| 247 |
+
M_block_ptr = tl.make_block_ptr(
|
| 248 |
+
base=M + m_offset,
|
| 249 |
+
shape=(G, Q_LEN), # (G, M)
|
| 250 |
+
strides=(stride_mh, stride_mm),
|
| 251 |
+
offsets=(off_hkv*G, 0),
|
| 252 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 253 |
+
order=(1, 0)
|
| 254 |
+
)
|
| 255 |
+
L_block_ptr = tl.make_block_ptr(
|
| 256 |
+
base=L + l_offset,
|
| 257 |
+
shape=(G, Q_LEN), # (G, M)
|
| 258 |
+
strides=(stride_lh, stride_lm),
|
| 259 |
+
offsets=(off_hkv*G, 0),
|
| 260 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 261 |
+
order=(1, 0)
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
| 265 |
+
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
| 266 |
+
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
| 267 |
+
if SAFE_M_BOUNDARY:
|
| 268 |
+
tl.store(M_block_ptr, m_i)
|
| 269 |
+
tl.store(L_block_ptr, l_i)
|
| 270 |
+
else:
|
| 271 |
+
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
| 272 |
+
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
| 273 |
+
|
| 274 |
+
# -- store output
|
| 275 |
+
idx_z = off_z
|
| 276 |
+
idx_t = off_t
|
| 277 |
+
idx_hq = off_hkv*G + off_g[:, None, None]
|
| 278 |
+
idx_m = off_m[None, :, None]
|
| 279 |
+
idx_d = offs_vd[None, None, :]
|
| 280 |
+
|
| 281 |
+
mask = (idx_m < Q_LEN)
|
| 282 |
+
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
| 283 |
+
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
|
| 284 |
+
"""
|
| 285 |
+
+ compute_forward_inner
|
| 286 |
+
+ compute_next_offset_func
|
| 287 |
+
+ compute_forward_block_mn,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int:
    """Heuristic for the number of splits from xformer"""
    # NOTE: Handle B*h=0 case — never divide by zero.
    batch_heads = max(B * H, 1)
    # Hand each SM at least one block, but never fewer than one split overall.
    return max(SM // batch_heads, 1)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _get_decoding_default_config(key) -> Tuple[int, int, int]:
    """Pick a default (BLOCK_N, num_warps, num_stages) config for the decoding kernel.

    The choice depends on the key buffer's dtype, its head dimension (last size),
    and the CUDA compute capability of the current device.
    """
    dtype = key.get_dtype()
    head_dim = key.get_size()[-1]
    capability = torch.cuda.get_device_capability()
    fallback = (64, 2, 1)
    if capability < (9, 0):
        # Pre-Hopper devices always use the conservative config.
        return fallback
    if head_dim > 128 and dtype == torch.float32:
        # Wide fp32 heads stay on the low-stage config even on Hopper.
        return fallback
    # Hopper (sm90+): afford more pipeline stages.
    return (64, 2, 3)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def create_flex_decoding_kernel(*args, **kwargs):
    """Lower a flex-attention decoding call to a split-KV Triton kernel.

    Unpacks the (query, key, value, block_mask, ...) lowering arguments, builds
    autotuning choices for the `flex_decoding_template` split-K kernel (which
    writes partial accumulators plus per-split rowmax/sumexp into `buf_M`/`buf_L`),
    then emits the cross-split softmax reduction as ordinary pointwise/reduction
    lowerings.  Returns `(output, logsumexp)` IR nodes.
    """
    (
        query,
        key,
        value,
        block_mask,
        scale,
        kernel_options,
        score_mod_subgraph,
        mask_mod_subgraph,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    # Only the KV-side sparse metadata is used for decoding; Q-side fields are ignored.
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,  # full_kv_num_blocks,
        full_kv_indices,  # full_kv_indices,
        _,  # q_num_blocks
        _,  # q_indices
        _,  # full_q_num_blocks,
        _,  # full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        _,  # SPARSE_Q_BLOCK_SIZE,
        _,
    ) = block_mask

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq
    # Copy so our setdefault() calls below don't mutate the caller's dict.
    kernel_options = dict(kernel_options)

    # TODO: Fix flex decoding non-divisible case!
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # Calculate GQA head sharing
    gqa_shared_heads = Hq // Hkv
    if not is_power_of_2(gqa_shared_heads):
        raise ValueError(
            "Number of shared query heads sharing the same KV head must be power of 2. "
        )
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Create a placeholder full block list in case it is empty
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
    )

    choices: List[Any] = []
    # Each config is (BLOCK_N, num_warps, num_stages).
    configs: List[Tuple[int, int, int]] = []
    configs.append(_get_decoding_default_config(key))
    # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
    if config.max_autotune:
        configs += [
            (64, 2, 2),
            (32, 2, 3),
            (128, 2, 3),
        ]
    # TODO: fix autotuning.

    kernel_options.setdefault("SM_SCALE", scale)
    kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
    MAX_SPLIT_KV = kernel_options["SPLIT_KV"]

    # create config dependent intermediate buffers
    buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
    buf_ML_shape = buf_ACC_shape[:-1]
    buf_M = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The rowmax is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    buf_L = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The intermediate sumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )

    layout_acc = FixedLayout(
        query.get_device(),
        torch.float32,
        buf_ACC_shape,
        FlexibleLayout.contiguous_strides(buf_ACC_shape),
    )

    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    kernel_options.setdefault(
        "BLOCK_M",
        (
            # m
            # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
            # else  # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
            max(
                next_power_of_2(
                    V.graph.sizevars.size_hint(
                        seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
                    )
                    * gqa_shared_heads
                ),
                16,
            )
        ),
    )

    query = ir.ExternKernel.realize_input(query)
    stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()

    # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
    gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim)
    gqa_query_stride = (
        stride_b,
        stride_hq * gqa_shared_heads,
        stride_hq,
        stride_seq_len_q,
        stride_qk_head_dim,
    )
    query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)

    # The whole (G * Mq) tile must fit in a single BLOCK_M.
    V.graph.sizevars.guard_leq(
        seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
    )

    kernel_options.setdefault(
        "SAFE_M_BOUNDARY",
        ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0,
    )
    # TODO: This feels sketchy
    kernel_options.setdefault("SAFE_N_BOUNDARY", True)

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.
    for BLOCK_N, num_warps, num_stages in configs:
        # BLOCK_N must evenly tile a sparse KV block.
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0:
            continue

        # Performance tuning
        kernel_options.setdefault("BLOCK_N", BLOCK_N)
        kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        # Work around https://github.com/pytorch/pytorch/issues/129625
        if num_stages == 2:
            continue
        flex_decoding_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                buf_M,
                buf_L,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout_acc,
            subgraphs=[
                score_mod_subgraph,
                mask_mod_subgraph,
            ],
            mutated_inputs=[buf_M, buf_L],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            **kernel_options,
        )

    inputs_for_flex_decoding = (
        [
            query,
            key,
            value,
            buf_M,
            buf_L,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )

    # Indices 5..8 match the positions of the sparse-metadata tensors above;
    # fake generators keep autotuning inputs consistent with real masks.
    input_gen_fns = {
        5: create_num_blocks_fake_generator(kv_indices),
        6: create_indices_fake,
        7: create_num_blocks_fake_generator(full_kv_indices),
        8: create_indices_fake,
    }

    buf_ACC = autotune_select_algorithm(
        "flex_decoding",
        choices,
        inputs_for_flex_decoding,
        layout_acc,
        input_gen_fns=input_gen_fns,
    )

    # Reduction

    g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
    # See [Note] Handle fully masked out rows:
    # g_M Is the global max among split kv blocks.
    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
    adj_M = lowerings[aten.sub](buf_M, g_M)
    adj_M = lowerings[aten.where](masked_rows, 0, adj_M)
    alpha = lowerings[aten.exp2](adj_M)

    # Rescale each split's sumexp by exp2(m_split - m_global) before summing.
    buf_L = lowerings[aten.mul](buf_L, alpha)
    g_L = lowerings[aten.sum](buf_L, axis=1)
    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
    # Fully-masked rows get sumexp 1.0 so the division below is safe.
    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
    logsumexp = lowerings[aten.log2](g_L)
    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))

    # Rescale and combine the per-split accumulators, then normalize.
    alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
    buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
    output = lowerings[aten.sum](buf_ACC, axis=1)
    L_unseq = lowerings[aten.unsqueeze](g_L, 3)
    output = lowerings[aten.div](output, L_unseq)
    output = lowerings[prims.convert_element_type](output, query.get_dtype())

    return (
        output,
        logsumexp,
    )
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
|
| 8 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 9 |
+
AHContext,
|
| 10 |
+
context_add_strides,
|
| 11 |
+
context_add_using_tf32,
|
| 12 |
+
get_mixedmm_precondition,
|
| 13 |
+
mixed_mm_operations,
|
| 14 |
+
mm_operations,
|
| 15 |
+
)
|
| 16 |
+
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
|
| 17 |
+
from torch._inductor.virtualized import V
|
| 18 |
+
|
| 19 |
+
from .. import config as inductor_config
|
| 20 |
+
from ..codegen.common import BackendFeature
|
| 21 |
+
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
|
| 22 |
+
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
| 23 |
+
from ..codegen.wrapper import WrapperCodeGen
|
| 24 |
+
from ..ir import FlexibleLayout, is_triton
|
| 25 |
+
from ..lowering import register_lowering
|
| 26 |
+
from ..select_algorithm import (
|
| 27 |
+
autotune_select_algorithm,
|
| 28 |
+
ExternKernelChoice,
|
| 29 |
+
NoValidChoicesError,
|
| 30 |
+
TritonTemplate,
|
| 31 |
+
)
|
| 32 |
+
from ..utils import (
|
| 33 |
+
get_gpu_shared_memory,
|
| 34 |
+
use_aten_gemm_kernels,
|
| 35 |
+
use_ck_template,
|
| 36 |
+
use_cpp_packed_gemm_template,
|
| 37 |
+
use_cutlass_template,
|
| 38 |
+
use_max_autotune,
|
| 39 |
+
use_triton_template,
|
| 40 |
+
)
|
| 41 |
+
from .mm_common import (
|
| 42 |
+
addmm_epilogue,
|
| 43 |
+
extra_mm_configs,
|
| 44 |
+
int8_mm_configs,
|
| 45 |
+
mixed_mm_configs,
|
| 46 |
+
mm_args,
|
| 47 |
+
mm_configs,
|
| 48 |
+
mm_grid,
|
| 49 |
+
mm_options,
|
| 50 |
+
triton_config,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Module-level logger used for autotune fallback warnings in this file.
log = logging.getLogger(__name__)
# Shorthand for the ATen operator namespace.
aten = torch.ops.aten

# Triton template for a tiled matmul C[M, N] = A[M, K] @ B[K, N].
# The {{...}} placeholders are expanded per autotune config by TritonTemplate.
mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
|
| 124 |
+
|
| 125 |
+
# ATen extern-kernel fallbacks that compete with the Triton templates during autotuning.
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

# Integer matmul (int8 x int8 -> int32) extern fallback.
aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")

# 2:4 structured-sparsity matmul; the C++ op has no out-variant.
aten__sparse_semi_structured_mm = ExternKernelChoice(
    torch._sparse_semi_structured_mm,
    "at::_sparse_semi_structured_mm",
    has_out_variant=False,
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _is_int8_mat(mat):
|
| 141 |
+
return mat.get_dtype() in (torch.int8, torch.uint8)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
    kernel under the hood. There are a few shapes where this is slower,
    but they are rare.
    """
    # A bias that is broadcast along dim 0 (stride 0 or a single row) can be
    # collapsed to its first row and handed to addmm as a 1D tensor.
    row_broadcast = inp.stride(0) == 0 or inp.size(0) == 1
    bias = inp[0] if row_broadcast else inp
    return torch.addmm(bias, mat1, mat2, out=out, alpha=alpha, beta=beta)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Extern choice wrapping the 1D-bias cublasLt fast path above (no C++ out-variant name).
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """Lowering for aten.mm: gather ATen/Triton/CUTLASS/CK/CPP GEMM candidates,
    optionally rank them with AutoHeuristic, and autotune to pick the winner.

    Falls back to the ATen kernel when no candidate is available or valid
    (subject to `autotune_fallback_to_aten`).
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    name = "mm"

    # Without max-autotune the extern kernel may use a flexible output layout.
    aten_layout = layout
    if not use_max_autotune():
        aten_layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    # CUTLASS requires fully static shapes.
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if is_nonzero and use_ck_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [mat1, mat2],
        )

    input_nodes = [mat1, mat2]
    if (
        is_nonzero
        and use_triton_template(layout)
        and torch._inductor.config.run_autoheuristic(name)
        and is_triton(mat1)
    ):
        always_included = []
        if use_aten_gemm_kernels():
            always_included.append("extern_mm")
        # Remember the boundary so extra configs can be dropped again below.
        num_choices_before_extra_configs = len(choices)
        for config in extra_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        # using AutoHeuristic for ranking
        ah_choices = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mm_operations(),
            None,
            top_k=10,
            always_included=always_included,
        )
        if not torch._inductor.config.collect_autoheuristic(name):
            # if we are collecting data, we do not want to modify choices
            if ah_choices is not None and len(ah_choices) > 0:
                # the order in which autoheuristic returns choices is not the same as
                # as the order of choices, which affects things like epilogue fusion.
                # once epilogue fusion benchmarks choices in sorted order, I think we can
                # just use the order returned by autoheuristic
                choices = [choice for choice in choices if choice in ah_choices]
            else:
                choices = choices[:num_choices_before_extra_configs]

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    try:
        return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _is_static_problem(inputs_tensors, layout):
    """Classify the output layout as (static_shape, nonzero).

    static_shape is True only when every output dimension is a statically
    known int; nonzero is True unless the output is provably empty.
    """
    known_sizes = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size)
    if known_sizes is None:
        # Shape is not fully static; it is still provably empty when any
        # individual dim is statically known to be zero.
        provably_empty = any(
            WrapperCodeGen.statically_known_int_or_none(s) == 0 for s in layout.size
        )
        return False, not provably_empty
    numel = 1
    for extent in known_sizes:
        numel *= extent
    return True, numel > 0
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    """Lowering for aten._int_mm (int8 x int8 -> int32 GEMM).

    Builds CUTLASS/Triton candidates when available, autotunes, and falls
    back to the ATen `_int_mm` kernel when no other choice exists or all
    choices fail (subject to `autotune_fallback_to_aten`).
    """
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    # CUTLASS requires fully static, non-empty shapes.
    use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

    choices = (
        [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    )

    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    # Drop the ATen choice whenever a template backend can handle the op.
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []

    if use_cutlass:
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
    if is_nonzero and use_triton_template(layout, enable_int32=True):
        for config in int8_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if len(choices) == 0:
        # Fixed typo in the log message ("avaialbe" -> "available").
        log.warning(
            "No choices for integer GEMM available using configured backends, using ATen backend as fallback"
        )
        choices = [aten__int_mm.bind((mat1, mat2), layout)]

    try:
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        choices = [aten__int_mm.bind((mat1, mat2), layout)]
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """Lowering for aten.addmm: autotune over ATen, Triton, CUTLASS, CK and CPP-packed
    GEMM backends and return the output node of the best choice.

    inp is the bias term; mat1 @ mat2 is scaled by alpha, inp by beta.
    """
    # Positional kwarg order expected by the C++ kernel wrapper for aten.addmm.
    ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout)
    if (not is_nonzero) or (not use_max_autotune()):
        # Use a FlexibleLayout if we are not autotuning.
        # This allows padding strides for the output.
        from torch._inductor.ir import FixedLayout, FlexibleLayout

        if isinstance(layout, FixedLayout):
            layout = FlexibleLayout(
                device=layout.device, dtype=layout.dtype, size=layout.size
            )
        # Fast path: only the ATen kernel (or nothing) when not autotuning.
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    if is_nonzero and use_triton_template(layout):
        # Triton template choices: bias is consumed via an epilogue
        # (prefix_args=1 marks inp_expanded as the epilogue-only input).
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        # Filter out a known cause of CUDA illegal memory access errors
        # broadcasting on the last dim of the bias term seems not to be working
        # in the linear GEMM epilogue used by addmm.
        if (
            WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1])
            != 0
        ):
            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
                choices,
                layout,
                [mat1, mat2, inp_expanded],
                alpha=alpha,
                beta=beta,
            )

    if is_nonzero and use_ck_template(layout, m, n, k):
        # ROCm composable-kernel backend choices.
        CKGemmTemplate.add_ck_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
        )

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        # CPU packed-GEMM backend; bias is a real input here (has_bias=True).
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [inp_expanded, mat1, mat2],
            alpha=alpha,
            beta=beta,
            has_bias=True,
        )

    add_aten_fallback = False
    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        add_aten_fallback = True

    if add_aten_fallback:
        choices.append(
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                ordered_kwargs_for_cpp_kernel,
                alpha=alpha,
                beta=beta,
            )
        )

    if (
        inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )
    try:
        return autotune_select_algorithm(
            "addmm", choices, [inp_expanded, mat1, mat2], layout
        )
    except NoValidChoicesError:
        # All candidates failed to compile/run; optionally fall back to ATen.
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        fallback_choice = aten_addmm.bind(
            (inp, mat1, mat2),
            layout,
            ordered_kwargs_for_cpp_kernel,
            alpha=alpha,
            beta=beta,
        )
        return fallback_choice.output_node()
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
def tuned_sparse_semi_structured_mm(
    mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
):
    """Lowering for aten._sparse_semi_structured_mm (2:4 sparsity).

    mat1 holds the compressed sparse operand (half the dense K), mat1_meta its
    metadata; autotunes between the ATen kernel and CUTLASS 2.x templates.
    """
    from torch._inductor.select_algorithm import realize_inputs

    mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
    m1, k1 = mat1.get_size()
    m2, _ = mat1_meta.get_size()
    k2, n = mat2.get_size()
    m = V.graph.sizevars.guard_equals(m1, m2)
    # Compressed operand stores K/2 columns, so the dense K is 2 * k1.
    k = V.graph.sizevars.guard_equals(2 * k1, k2)

    if layout is None:
        from torch._inductor.ir import FixedLayout

        # Row-major [m, n] output on mat2's device.
        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [m, n],
            [n, 1],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."

    choices = (
        [
            aten__sparse_semi_structured_mm.bind(
                (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if m * n != 0 and use_cutlass_template(layout, m, n, k):
        # NOTE: input order for the CUTLASS template is (mat1, mat2, mat1_meta),
        # which differs from the autotune input_nodes order below.
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
        )

    return autotune_select_algorithm(
        "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
    )
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def fallback_mixed_mm(mat1, mat2, *, out):
    """Eager fallback for mixed-dtype matmul: cast mat2 up to mat1's dtype
    and run a plain torch.mm into the preallocated `out` tensor."""
    promoted_rhs = mat2.to(mat1.dtype)
    return torch.mm(mat1, promoted_rhs, out=out)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# ExternKernelChoice wrapping the eager fallback above so the autotuner can
# offer it as a candidate for mixed-dtype matmuls (None: no C++ kernel name).
aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    """Return True when the CUDA device at *index* has compute capability 7.x
    or older (e.g. V100); cached per device index."""
    device_index = 0 if index is None else index  # None and 0 both mean device 0
    device_props = torch.cuda.get_device_properties(device_index)
    return device_props.major <= 7
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def dims_are_int(dims):
    """Return True iff every entry of *dims* is a plain Python int (no SymInts)."""
    for dim in dims:
        if not isinstance(dim, int):
            return False
    return True
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
    """Return a hand-tuned Triton config for small-m fp16 mixed-mm, or None.

    Applies only on GPUs matching the shared-memory check below (the comment
    says A100); all other cases fall through to regular autotuning.
    NOTE(review): `choices`, `mat2_dtype` and `layout` are unused here —
    presumably kept for call-site symmetry; confirm before removing.
    """
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None

    # Heuristic is only calibrated for fp16 inputs.
    if mat1.dtype != torch.float16:
        return None

    # only use heuristic if we are running on an A100
    # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
    # which does not have enough shared memory for one of the configs
    if (
        not torch.cuda.get_device_capability() >= (8, 0)
    ) or get_gpu_shared_memory() != 166912:
        return None

    # gemv-like shapes need 16-aligned n and k for these configs.
    if m == 1 and (n % 16 != 0 or k % 16 != 0):
        return None

    # Block sizes were picked per m-bucket for large n/k (>= 4096).
    if m <= 16 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=16,
            BLOCK_N=64,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=32,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=64,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    return None
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def mm_autoheuristic(
    mat1,
    mat2,
    m,
    n,
    k,
    choices,
    name,
    input_nodes,
    ops,
    precondition,
    top_k: Optional[int] = None,
    always_included=None,
):
    """Use AutoHeuristic to pick (or rank) GEMM choices from learned data.

    Returns a choice caller (or the top-k choice callers when *top_k* is set),
    or None when the problem dims are not concrete ints.
    """
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        # Symbolic shapes are not supported by the learned heuristic.
        return None
    mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)

    def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride):
        # Build the feature context the learned heuristic predicts from:
        # problem sizes, dtypes, strides and contiguity of both operands.
        context = AHContext()
        context.add_feature("m", m)
        context.add_feature("k", k)
        context.add_feature("n", n)
        context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
        context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
        context_add_strides(context, "mat1", mat1_stride)
        context_add_strides(context, "mat2", mat2_stride)
        context.add_feature(
            "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
        )
        context.add_feature(
            "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
        )
        if name == "mm":
            # for mixed_mm, we only consider fp16
            context_add_using_tf32(context, mat1.layout.dtype)
        return context

    def fallback():
        # No learned choice available -> let the caller autotune as usual.
        return None

    context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride)
    autoheuristic = AutoHeuristicSelectAlgorithm(
        fallback=fallback,
        choices=choices,
        input_nodes=input_nodes,
        context=context,
        name=name,
        augment_context=ops,
        precondition=precondition,
    )

    if top_k is not None:
        # TODO: is there a cleaner way to ensure aten.mm is always included?
        return autoheuristic.get_top_k_choices_caller(
            top_k, always_included=always_included
        )

    return autoheuristic.get_choice_caller()
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def get_size_hints(mat1, mat2, m, n, k):
    """Resolve m, n, k to concrete integer hints.

    Dims that are already plain ints pass through unchanged; otherwise the
    size hints are re-read from the corresponding matrix via sizevars.
    """
    mat1_needs_hint = not (isinstance(m, int) and isinstance(k, int))
    if mat1_needs_hint:
        m, k = V.graph.sizevars.size_hints(
            mat1.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    # Re-check k here: the branch above may already have made it concrete.
    mat2_needs_hint = not (isinstance(n, int) and isinstance(k, int))
    if mat2_needs_hint:
        k, n = V.graph.sizevars.size_hints(
            mat2.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )
    return m, n, k
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def get_size_hints_strides(mat1, mat2):
    """Return concrete stride hints for mat1 and mat2.

    An int stride passes through; anything symbolic is resolved via sizevars.
    """

    def _resolve(stride):
        # Already concrete -> use as-is; otherwise ask sizevars for hints.
        if isinstance(stride, int):
            return stride
        return V.graph.sizevars.size_hints(
            stride,
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    return _resolve(mat1.layout.stride), _resolve(mat2.layout.stride)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
    """Autotune a mixed-dtype matmul (mat2 is cast to *mat2_dtype* in-kernel).

    Assembles choices from the eager fallback, Triton templates (with a cast
    prologue on mat2) and CUTLASS templates, steered by
    inductor_config.mixed_mm_choice ("default"/"aten"/"triton"/"heuristic").
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)

    fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)

    choices = [fallback]

    # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
    skip_triton = (
        (
            mat1.layout.dtype != torch.float32
            and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
        )
        or _is_sm7x_or_older_gpu(layout.device.index)
        or inductor_config.mixed_mm_choice == "aten"
        or not V.graph.has_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
        or (
            mat1.layout.dtype == torch.float32 and torch.backends.cuda.matmul.allow_tf32
        )
        or (mat1.layout.dtype == torch.bfloat16 and mat2.layout.dtype == torch.uint8)
    )

    if inductor_config.mixed_mm_choice == "triton":
        # Triton-only mode: drop the eager fallback from the candidate set.
        choices = []

    if not skip_triton:
        # e.g. torch.float16 -> "tl.float16" for the Triton cast prologue.
        b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
        if static_shape and inductor_config.mixed_mm_choice == "heuristic":
            # Heuristic mode: lead with the hand-picked config (if any),
            # then the eager fallback, then the regular template sweep.
            choices = []
            config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
            if config is not None:
                mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                    **mm_options(config, m, n, k, layout, b_prologue_cast_type),
                )
            choices.append(fallback)

        has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
        for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout, b_prologue_cast_type),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )

    if skip_triton and not choices:
        # "triton" mode emptied the list but Triton is unusable here;
        # make sure we still have the eager fallback.
        choices = [fallback]

    name = "mixed_mm"
    input_nodes = [mat1, mat2]
    if torch._inductor.config.run_autoheuristic(name):
        choice = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mixed_mm_operations(),
            get_mixedmm_precondition,
        )
        if (
            not skip_triton
            and inductor_config.mixed_mm_choice == "heuristic"
            and choice is not None
        ):
            # Put the learned choice first so it wins ties during autotuning.
            choices.insert(0, choice)
    return autotune_select_algorithm(name, choices, input_nodes, layout)
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# This op is a special case of the int_mm op which we use based on the pattern
# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
# realization of the int32 _int_mm output by forcing fusion with the mul op.
# This is only used when config.force_fuse_int_mm_with_mul = True
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    """Autotune a fused int8 matmul + elementwise-mul: (mat1 @ mat2) * mat3.

    mat3 feeds a mul epilogue (suffix_args=1) so the int32 matmul result is
    never materialized on its own.
    """
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
        if out_dtype is None
        else out_dtype
    )
    m, n, k, layout, mat1, mat2, mat3 = mm_args(
        mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
    )
    # NOTE(review): annotation looks off — the list holds template choice
    # callers, not dicts; kept as-is to avoid guessing the project type name.
    choices: List[Dict[Any, Any]] = []
    for config in int8_mm_configs(m, n, k):
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2, mat3),
            layout=layout,
            # Accumulate in int32 regardless of the promoted output dtype.
            **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
            suffix_args=1,
            epilogue_fn=V.ops.mul,
        )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from typing import cast, List, Tuple
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch._inductor.select_algorithm import realize_inputs
|
| 11 |
+
from torch._inductor.virtualized import V
|
| 12 |
+
|
| 13 |
+
from .. import config as inductor_config
|
| 14 |
+
from ..runtime.runtime_utils import next_power_of_2
|
| 15 |
+
from ..utils import ceildiv as cdiv
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def triton_config(num_stages, num_warps, **kwargs):
    """Build a triton.Config whose meta-parameters (BLOCK_M etc.) come from kwargs."""
    # Imported lazily so this module can be imported without triton installed.
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size.

    Yields de-duplicated triton configs from the (block_m, block_n, block_k,
    num_stages, num_warps) tuples in *configs*, with block sizes clamped to the
    (power-of-two-rounded) problem sizes. On ROCm each config is additionally
    expanded over matrix_instr_nonkdim in {0, 16}.
    """

    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    # Round each dim hint up to a power of two, but never below the minimum
    # block size, so clamping below cannot produce degenerate blocks.
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes 16x16 tile = 256
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            # Clamping may map several input tuples to the same config;
            # `used` keeps the yielded stream duplicate-free.
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
# When max_autotune_gemm_search_space == "EXHAUSTIVE", the curated list below is
# replaced by the full cross product of block sizes, stages and warps.
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)
|
| 153 |
+
|
| 154 |
+
# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part of mm_configs
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]
|
| 171 |
+
|
| 172 |
+
# Configs for int8 GEMMs; the two largest tilings are CUDA-only (skipped on ROCm).
int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]
|
| 188 |
+
|
| 189 |
+
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

# In EXHAUSTIVE mode mm_kernel_configs already covers the full grid, so the
# small-m extras are not appended.
mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)
|
| 200 |
+
|
| 201 |
+
# Configs for scaled GEMMs: a set of large tilings followed by a systematic
# sweep of BLOCK_M in {16, 32}, BLOCK_N in {32, 64, 128, 256}, BLOCK_K in
# {32, 64} for num_stages 2..6 (num_warps 2 for narrow N, 4 otherwise).
scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# Create filtered list of configs based on cond evaluation
def _enabled_platform_configs(kernel_configs):
    """Return the "config" tuples of every entry whose "cond" is truthy.

    Each entry is a dict with a (BLOCK_M, BLOCK_N, BLOCK_K, num_stages,
    num_warps) tuple under "config" and a platform predicate under "cond".
    """
    return tuple(
        cast(Tuple[int, int, int, int, int], entry["config"])
        for entry in kernel_configs
        if entry["cond"]
    )


mm_platform_configs = _enabled_platform_configs(mm_kernel_configs)
extra_mm_platform_configs = _enabled_platform_configs(extra_mm_kernel_configs)
int8_platform_configs = _enabled_platform_configs(int8_mm_kernel_configs)
mixed_mm_platform_configs = _enabled_platform_configs(mixed_mm_kernel_configs)
scaled_mm_platform_configs = _enabled_platform_configs(scaled_mm_kernel_configs)
|
| 328 |
+
|
| 329 |
+
# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
    # Each entry is a (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
    # tuple; rewrite it with num_stages forced to 0.
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    # BUGFIX: this previously iterated mm_platform_configs, silently
    # replacing the int8-specific config list with the generic mm configs
    # on ROCm.  Convert the int8 list itself instead.
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )
|
| 351 |
+
|
| 352 |
+
# Bind each platform-specific config tuple to the shared ``filtered_configs``
# helper, so callers invoke e.g. ``mm_configs(m, n, k)`` and get the configs
# appropriate for the current platform.
mm_configs = functools.partial(filtered_configs, configs=mm_platform_configs)

extra_mm_configs = functools.partial(
    filtered_configs, configs=extra_mm_platform_configs
)

int8_mm_configs = functools.partial(filtered_configs, configs=int8_platform_configs)

mixed_mm_configs = functools.partial(
    filtered_configs, configs=mixed_mm_platform_configs
)

scaled_mm_configs = functools.partial(
    filtered_configs, configs=scaled_mm_platform_configs
)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def mm_grid(m, n, meta):
    """Return the CUDA launch grid for matmul triton templates.

    The per-dimension tile counts are flattened into a single 1-D grid
    (the template recovers pid_m/pid_n from the linear program id).
    """
    tiles_m = cdiv(m, meta["BLOCK_M"])
    tiles_n = cdiv(n, meta["BLOCK_N"])
    return (tiles_m * tiles_n, 1, 1)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def acc_type(dtype):
    """Return the triton dtype string used to accumulate ``dtype`` matmuls.

    Half-precision inputs accumulate in fp32 for accuracy; every other
    dtype accumulates in its own type.
    """
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    # str(torch.float32) == "torch.float32" -> "tl.float32"
    return "tl." + str(dtype).replace("torch.", "")
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """Build the keyword options shared by all matmul triton templates.

    Returns a dict with the template constants (GROUP_M, EVEN_K,
    ALLOW_TF32, ACC_TYPE, B_PROLOGUE_CAST_TYPE), the launch parameters
    (num_stages, num_warps) and the block sizes from ``config.kwargs``.
    """
    block_k = config.kwargs["BLOCK_K"]
    # K is a symbolic multiple of BLOCK_K; it isn't worth guarding on this
    # at runtime.
    even_k = sympy.gcd(sym_k, block_k) == block_k

    if inductor_config.force_same_precision:
        # Only allow tf32 when the shape satisfies cuBLAS' alignment rules,
        # so triton and aten choices compute in matching precision.
        aligned = (sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0
        allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and aligned
    else:
        allow_tf32 = torch.backends.cuda.matmul.allow_tf32

    return dict(
        GROUP_M=8,
        EVEN_K=even_k,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
    """
    Common arg processing for mm,bmm,addmm,etc

    Realizes the inputs, guards that the shared dimensions agree, and
    builds an output layout when one is not supplied.

    Args:
        mat1, mat2: input IR nodes of the (batched) matmul.
        *others: extra inputs (e.g. the bias of addmm); each is expanded
            to the output size and realized.
        layout: optional output layout.  When None a FixedLayout with
            size [*batch, m, n] is created; out_dtype must then be None
            if layout is given.
        out_dtype: output dtype override (defaults to mat1's dtype).
        use_4x2_dim: doubles mat2's K extent -- used when mat2 packs two
            4-bit values per element (uint4x2 mixed mm).
        mat2_transposed: treat mat2 as (..., n, k) instead of (..., k, n).

    Returns:
        [m, n, k, layout, mat1, mat2, *others]
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
    # guard (and unify) each pair of batch dims
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        # each mat2 element packs two sub-elements along K
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def addmm_epilogue(dtype, alpha, beta):
    """Return an epilogue computing ``alpha * acc + beta * bias``.

    Multiplications by 1 are elided so the common alpha == beta == 1
    case emits a single add.
    """

    def epilogue(acc, bias):
        scaled_acc = (
            acc if alpha == 1 else V.ops.mul(acc, V.ops.constant(alpha, dtype))
        )
        scaled_bias = (
            bias if beta == 1 else V.ops.mul(bias, V.ops.constant(beta, dtype))
        )
        return V.ops.add(scaled_acc, scaled_bias)

    return epilogue
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ..lowering import lowerings
|
| 7 |
+
from ..select_algorithm import (
|
| 8 |
+
autotune_select_algorithm,
|
| 9 |
+
ExternKernelChoice,
|
| 10 |
+
TritonTemplate,
|
| 11 |
+
)
|
| 12 |
+
from ..utils import use_aten_gemm_kernels, use_triton_template
|
| 13 |
+
from ..virtualized import V
|
| 14 |
+
from .mm_common import mm_args, mm_grid, mm_options
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
aten = torch.ops.aten

# ATen fallback choice: torch.ops.inductor._mm_plus_mm computes
# mm(a, b) + mm(c, d) as a single extern kernel for autotuning.
aten_mm_plus_mm = ExternKernelChoice(
    torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
)
|
| 22 |
+
|
| 23 |
+
# Fused triton kernel computing mm(A, B) + mm(C, D) into one output tile.
# NOTE(review): both K loops iterate over K1 -- valid because the lowering
# (tuned_mm_plus_mm) only selects this template when the two gemms have
# statically equal sizes; confirm if that precondition ever changes.
mm_plus_mm_template = TritonTemplate(
    name="mm_plus_mm",
    grid=mm_grid,
    debug=False,
    source=r"""
{{def_kernel("A", "B", "C", "D")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K1 = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    # K2 = {{size("C", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
    stride_cm = {{stride("C", 0)}}
    stride_ck = {{stride("C", 1)}}
    stride_dk = {{stride("D", 0)}}
    stride_dn = {{stride("D", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1))
        and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M

    if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1))
        and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k1 in range(K1, 0, -BLOCK_K):
        # First matmul with A @ B
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    for k2 in range(K1, 0, -BLOCK_K):

        # Second matmul with C @ D
        if EVEN_K:
            c = tl.load(C)
            d = tl.load(D)
        else:
            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
        C += BLOCK_K * stride_ck
        D += BLOCK_K * stride_dk


    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@functools.lru_cache(None)
def mm_configs():
    """Return the cached list of triton.Config candidates for mm_plus_mm.

    Entries whose "cond" predicate is falsy are dropped; on ROCm
    num_stages is forced to 1 as pipelining provides no benefit there.
    """
    import triton

    # Kernel config candidates: block shape, pipeline depth, warp count,
    # and a predicate selecting the platforms the config is usable on.
    candidate_configs = [
        {"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, "num_stages": 2, "num_warps": 4, "cond": True},
        {"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, "num_stages": 3, "num_warps": 8, "cond": True},
        {"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, "num_stages": 4, "num_warps": 16, "cond": True},
        {"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, "num_stages": 4, "num_warps": 8, "cond": True},
        {"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, "num_stages": 4, "num_warps": 8, "cond": True},
        {"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, "num_stages": 1, "num_warps": 8, "cond": True},
        {"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, "num_stages": 1, "num_warps": 8, "cond": True},
        {"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, "num_stages": 1, "num_warps": 8, "cond": torch.version.hip is None},
        {"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, "num_stages": 2, "num_warps": 4, "cond": True},
        {"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, "num_stages": 1, "num_warps": 2, "cond": True},
    ]

    return [
        triton.Config(
            c["config"],
            num_stages=1 if torch.version.hip else c["num_stages"],
            num_warps=c["num_warps"],
        )
        for c in candidate_configs
        if c["cond"]
    ]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
    """
    Computes mm(mat1, mat2) + mm(mat3, mat4)

    Autotunes between the ATen extern kernel and the fused triton template.
    When the two gemms are zero-sized or their sizes are not statically
    known to match, falls back to lowering two separate mms plus an add.
    """
    m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
    # Optimization is optional, because we can always just not do the fusion
    if (
        m1 * n1 == 0
        or m2 * n2 == 0
        or not V.graph.sizevars.statically_known_list_equals(
            mat1.get_size(), mat3.get_size()
        )
        or not V.graph.sizevars.statically_known_list_equals(
            mat2.get_size(), mat4.get_size()
        )
    ):
        # TODO(jansel): support different K values when this is fixed:
        # https://github.com/openai/triton/issues/967
        return lowerings[aten.add](
            lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
        )

    # equal sizes were guarded above, so both mms share one output layout
    assert layout1 == layout2
    # options to tune from
    choices = (
        [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout1):
        for config in mm_configs():
            # see https://github.com/openai/triton/issues/1298
            # BLOCK_K = K causes llvm error
            if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
                mm_plus_mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2, mat3, mat4),
                    layout=layout1,
                    **mm_options(config, m1, n1, k1, layout1),
                )

    return autotune_select_algorithm(
        "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
    )
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_scaled.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from .. import config as inductor_config
|
| 9 |
+
from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
|
| 10 |
+
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
|
| 11 |
+
from ..select_algorithm import (
|
| 12 |
+
autotune_select_algorithm,
|
| 13 |
+
ExternKernelChoice,
|
| 14 |
+
NoValidChoicesError,
|
| 15 |
+
realize_inputs,
|
| 16 |
+
TritonTemplate,
|
| 17 |
+
)
|
| 18 |
+
from ..utils import use_aten_gemm_kernels, use_triton_template
|
| 19 |
+
from .mm import _is_static_problem # TODO(yangsiyu) move to mm_common
|
| 20 |
+
from .mm_common import mm_args, mm_grid, scaled_mm_configs
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
log = logging.getLogger(__name__)
|
| 24 |
+
aten = torch.ops.aten
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Triton template for FP8 scaled matmul without bias: the accumulated
# product is multiplied by the product of the two inverse scales, either
# tensor-wise (scalar scales) or row-wise (SCALING_ROWWISE).
scaled_mm_template = TritonTemplate(
    name="scaled_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


# Inductor does not allow optional tensor input arguments currently (pass None as an
# input node to template choices), but since for _scaled_mm there is only one such arg
# (bias), work around by having a second template when bias is provided.
# Identical to scaled_mm_template except that a 1-D bias (indexed by N) is
# added to the scaled accumulator before the store.
scaled_mm_bias_template = TritonTemplate(
    name="scaled_mm_bias",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # bias
    bias = tl.load(bias_ptr + rn, mask=rn < N)
    acc += bias

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ATen fallback choice for FP8 scaled matmul autotuning.
aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool:
    """Check whether two scale-tensor shapes can be used together.

    Shapes are compatible when they have the same rank, or when both are
    scalar-like (rank 0 or a single-element 1-D tensor).
    """
    same_rank = len(size_a) == len(size_b)
    both_scalar_like = len(size_a) <= 1 and len(size_b) <= 1
    return same_rank or both_scalar_like
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def scaled_mm_options(  # type: ignore[no-untyped-def]
    config,  # triton.Config
    sym_m: sympy.core.numbers.Integer,
    sym_n: sympy.core.numbers.Integer,
    sym_k: sympy.core.numbers.Integer,
    layout: Layout,
    scale_a: StorageBox,
    scale_b: StorageBox,
    use_fast_accum: bool,
    b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the keyword options for the scaled_mm triton templates.

    Validates that the two scale shapes are compatible, then returns the
    template constants plus the block sizes from ``config.kwargs``.
    """
    block_k = config.kwargs["BLOCK_K"]
    # symbolic check that BLOCK_K evenly divides K
    even_k_symbolic = sympy.gcd(sym_k, block_k) == block_k

    size_a, size_b = scale_a.get_size(), scale_b.get_size()
    assert are_compatible_scales(size_a, size_b), (
        "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
        f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ACC_TYPE="tl.float32",
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        USE_FAST_ACCUM=use_fast_accum,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        # a 2-D scale tensor signals row-wise scaling; scalars mean tensor-wise
        SCALING_ROWWISE=len(size_a) == 2,
        **config.kwargs,
    )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# _scaled_mm is layout-sensitive: keep the strides the FX graph observed.
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)


@register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
def tuned_scaled_mm(
    mat_a: TensorBox,
    mat_b: TensorBox,
    scale_a: TensorBox,
    scale_b: TensorBox,
    bias: Optional[TensorBox] = None,
    scale_result: Optional[TensorBox] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
    layout: Optional[Layout] = None,
) -> TensorBox:
    """Lowering for aten._scaled_mm: autotune ATen vs triton FP8 matmul.

    Builds an ATen extern-kernel choice and (when a triton template is
    usable for this layout) one triton choice per scaled_mm config, then
    runs algorithm selection.  Falls back to the ATen choice when no
    valid choices remain and the config allows it.

    Note: scale_result is accepted for signature compatibility with
    aten._scaled_mm but is not used in this lowering.
    """
    m, n, k, layout, mat_a, mat_b = mm_args(
        mat_a, mat_b, layout=layout, out_dtype=out_dtype
    )
    scale_a, scale_b = realize_inputs(scale_a, scale_b)

    input_nodes: Tuple[Any, ...]
    # workaround for Inductor not supporting optional tensor input arguments
    if bias is None:
        input_nodes = (mat_a, mat_b, scale_a, scale_b)
        triton_template = scaled_mm_template
    else:
        bias = realize_inputs(bias)
        input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
        triton_template = scaled_mm_bias_template

    aten_choice = aten__fp8_mm.bind(
        input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
    )

    choices: List[ChoiceCaller] = []
    if use_aten_gemm_kernels():
        choices.append(aten_choice)

    static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout)
    if is_nonzero and use_triton_template(layout, enable_float8=True):
        for config in scaled_mm_configs(m, n, k):
            if k == 16 and config.kwargs["BLOCK_M"] >= 64:
                continue  # Triton crashes in this case
            kwargs = scaled_mm_options(
                config, m, n, k, layout, scale_a, scale_b, use_fast_accum
            )
            # possibly appends a TritonTemplateCaller to choices
            triton_template.maybe_append_choice(
                choices,
                input_nodes=input_nodes,
                layout=layout,
                **kwargs,
            )

    # no triton choices and ATen disabled: fall back to ATen anyway if allowed
    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for scaled_mm, using ATen backend as fallback")
        return aten_choice.output_node()

    try:
        return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning(
            "All choices for scaled_mm were invalid, using ATen backend as fallback"
        )
        return aten_choice.output_node()
|
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
|
| 6 |
+
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from ..ir import ChoiceCaller
|
| 11 |
+
|
| 12 |
+
log = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
uint4x2_mixed_mm_template = TritonTemplate(
|
| 15 |
+
name="uint4x2_mixed_mm",
|
| 16 |
+
grid=mm_grid,
|
| 17 |
+
source=r"""
|
| 18 |
+
{{def_kernel("A", "B")}}
|
| 19 |
+
M = {{size("A", 0)}}
|
| 20 |
+
N = {{size("B", 1)}}
|
| 21 |
+
K = {{size("A", 1)}}
|
| 22 |
+
stride_am = {{stride("A", 0)}}
|
| 23 |
+
stride_ak = {{stride("A", 1)}}
|
| 24 |
+
stride_bk = {{stride("B", 0)}}
|
| 25 |
+
stride_bn = {{stride("B", 1)}}
|
| 26 |
+
|
| 27 |
+
# based on triton.ops.matmul
|
| 28 |
+
pid = tl.program_id(0)
|
| 29 |
+
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
| 30 |
+
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
| 31 |
+
|
| 32 |
+
# re-order program ID for better L2 performance
|
| 33 |
+
width = GROUP_M * grid_n
|
| 34 |
+
group_id = pid // width
|
| 35 |
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
| 36 |
+
pid_m = group_id * GROUP_M + (pid % group_size)
|
| 37 |
+
pid_n = (pid % width) // (group_size)
|
| 38 |
+
|
| 39 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 40 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 41 |
+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
| 42 |
+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
| 43 |
+
rk = tl.arange(0, BLOCK_K)
|
| 44 |
+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
| 45 |
+
B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
|
| 46 |
+
b_shifts = 4*(rk%2)
|
| 47 |
+
b_subs = 8*(1-(rk%2))
|
| 48 |
+
|
| 49 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
| 50 |
+
for k in range(K, 0, -BLOCK_K):
|
| 51 |
+
if EVEN_K:
|
| 52 |
+
a = tl.load(A)
|
| 53 |
+
b = tl.load(B)
|
| 54 |
+
else:
|
| 55 |
+
a = tl.load(A, mask=rk[None, :] < k, other=0.)
|
| 56 |
+
b = tl.load(B, mask=rk[:, None] < k, other=0.)
|
| 57 |
+
b = ((b >> b_shifts[:, None]) & 0xF) - 8
|
| 58 |
+
b = b.to(B_PROLOGUE_CAST_TYPE)
|
| 59 |
+
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
| 60 |
+
A += BLOCK_K * stride_ak
|
| 61 |
+
B += BLOCK_K//2 * stride_bk
|
| 62 |
+
|
| 63 |
+
# rematerialize rm and rn to save registers
|
| 64 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 65 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 66 |
+
idx_m = rm[:, None]
|
| 67 |
+
idx_n = rn[None, :]
|
| 68 |
+
mask = (idx_m < M) & (idx_n < N)
|
| 69 |
+
|
| 70 |
+
# inductor generates a suffix
|
| 71 |
+
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
|
| 72 |
+
""",
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
|
| 77 |
+
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
|
| 78 |
+
choices: List[ChoiceCaller] = []
|
| 79 |
+
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
|
| 80 |
+
for config in mm_configs(m, n, k):
|
| 81 |
+
uint4x2_mixed_mm_template.maybe_append_choice(
|
| 82 |
+
choices,
|
| 83 |
+
input_nodes=(mat1, mat2),
|
| 84 |
+
layout=layout,
|
| 85 |
+
**mm_options(config, m, n, k, layout, b_prologue_cast_type),
|
| 86 |
+
)
|
| 87 |
+
return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-311.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|