koichi12 commited on
Commit
44b4c93
·
verified ·
1 Parent(s): 70fbf20

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc +3 -0
  4. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc +3 -0
  5. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py +1 -0
  26. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py +192 -0
  37. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py +679 -0
  38. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py +1843 -0
  39. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_decoding.py +570 -0
  40. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py +776 -0
  41. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py +466 -0
  42. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py +248 -0
  43. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_scaled.py +311 -0
  44. .venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +87 -0
  45. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__init__.py +0 -0
  46. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -142,3 +142,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
142
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
143
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
144
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
 
142
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
143
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
144
  .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
145
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
146
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
147
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee19b8e0b980d0895a7af50aa7c3244d133ce110a196485ab8cec5fa7b9767d4
3
+ size 121452
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8596ce3d305b9ea76fd93737e3fda25769b1901142db9efff0fde9757b03517
3
+ size 262897
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de158b207dec0ef6dd7cca5acc1db68fcc605b0046ed6c5ffcf0d9b8f34d3b82
3
+ size 138985
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-311.pyc ADDED
Binary file (26.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc ADDED
Binary file (8.35 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc ADDED
Binary file (5.27 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc ADDED
Binary file (14.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc ADDED
Binary file (41.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc ADDED
Binary file (75.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc ADDED
Binary file (34.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-311.pyc ADDED
Binary file (42 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc ADDED
Binary file (6.81 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc ADDED
Binary file (63.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc ADDED
Binary file (37.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc ADDED
Binary file (59.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc ADDED
Binary file (40.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc ADDED
Binary file (84.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc ADDED
Binary file (31.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc ADDED
Binary file (7.48 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (321 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-311.pyc ADDED
Binary file (9.36 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc ADDED
Binary file (25.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-311.pyc ADDED
Binary file (67.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-311.pyc ADDED
Binary file (21.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc ADDED
Binary file (30.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc ADDED
Binary file (7.92 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc ADDED
Binary file (3.59 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+
4
+ import torch
5
+
6
+ from .. import ir, lowering as L
7
+ from ..select_algorithm import (
8
+ autotune_select_algorithm,
9
+ ExternKernelChoice,
10
+ TritonTemplate,
11
+ )
12
+ from ..utils import (
13
+ ceildiv as cdiv,
14
+ use_aten_gemm_kernels,
15
+ use_cutlass_template,
16
+ use_triton_template,
17
+ )
18
+ from ..virtualized import V
19
+ from .mm import _is_static_problem
20
+ from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
21
+
22
+
23
+ log = logging.getLogger(__name__)
24
+ aten = torch.ops.aten
25
+
26
+
27
+ def bmm_grid(b, m, n, meta):
28
+ return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
29
+
30
+
31
+ bmm_template = TritonTemplate(
32
+ name="bmm",
33
+ grid=bmm_grid,
34
+ source=r"""
35
+ {{def_kernel("A", "B")}}
36
+ M = {{size("A", -2)}}
37
+ N = {{size("B", -1)}}
38
+ K = {{size("A", -1)}}
39
+
40
+ stride_aq = {{stride("A", 0)}}
41
+ stride_am = {{stride("A", 1)}}
42
+ stride_ak = {{stride("A", 2)}}
43
+
44
+ stride_bq = {{stride("B", 0)}}
45
+ stride_bk = {{stride("B", 1)}}
46
+ stride_bn = {{stride("B", 2)}}
47
+
48
+ # based on triton.ops.matmul
49
+ pid = tl.program_id(0)
50
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
51
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
52
+
53
+ # re-order program ID for better L2 performance
54
+ width = GROUP_M * grid_n
55
+ group_id = pid // width
56
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
57
+ pid_m = group_id * GROUP_M + (pid % group_size)
58
+ pid_n = (pid % width) // (group_size)
59
+
60
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
61
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
62
+ if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
63
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
64
+ else:
65
+ ram = rm % M
66
+ if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
67
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
68
+ else:
69
+ rbn = rn % N
70
+
71
+ rk = tl.arange(0, BLOCK_K)
72
+
73
+ idx_q = tl.program_id(1) # batch dimension for BMM
74
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
75
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
76
+
77
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
78
+ for k in range(K, 0, -BLOCK_K):
79
+ if EVEN_K:
80
+ a = tl.load(A)
81
+ b = tl.load(B)
82
+ else:
83
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
84
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
85
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
86
+ A += BLOCK_K * stride_ak
87
+ B += BLOCK_K * stride_bk
88
+
89
+ # rematerialize rm and rn to save registers
90
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
91
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
92
+ idx_q = tl.program_id(1) # batch dimension for BMM
93
+ idx_m = rm[:, None]
94
+ idx_n = rn[None, :]
95
+ mask = (idx_m < M) & (idx_n < N)
96
+
97
+ # inductor generates a suffix
98
+ {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
99
+ """,
100
+ )
101
+
102
+ aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
103
+ aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
104
+
105
+
106
+ @L.register_lowering(aten.bmm)
107
+ def tuned_bmm(mat1, mat2, *, layout=None):
108
+ if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
109
+ # decompose to small ops when memory bound
110
+ if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
111
+ mat1 = L.unsqueeze(mat1, -1)
112
+ mat2 = L.unsqueeze(mat2, 1)
113
+ return L.sum_(L.mul(mat1, mat2), axis=2)
114
+
115
+ def is_valid_to_require_contiguous(t):
116
+ if not ir.is_storage_and_layout(t):
117
+ return True
118
+ _, layout = ir.as_storage_and_layout(t, freeze=False)
119
+ return isinstance(layout, ir.FlexibleLayout)
120
+
121
+ def is_preferred_layout_as_bmm_input(sizes, strides):
122
+ # contiguous on one of the last two dims
123
+ return (
124
+ strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
125
+ ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))
126
+
127
+ # Make the input of bmm contiguous
128
+ # if it is not contiguous on either of the last two dims,
129
+ # because bmm cpu implementation would do contiguous() if not.
130
+ # This is to avoid additional copies in bmm.
131
+ def may_require_contiguous(t, meta_t):
132
+ sizes = meta_t.meta["val"].size()
133
+ strides = meta_t.meta["val"].stride()
134
+ if not is_preferred_layout_as_bmm_input(sizes, strides):
135
+ t = ir.ExternKernel.require_contiguous(t)
136
+ return t
137
+
138
+ if is_valid_to_require_contiguous(mat1):
139
+ meta_mat1 = V.graph.current_node.args[0]
140
+ mat1 = may_require_contiguous(mat1, meta_mat1)
141
+ if is_valid_to_require_contiguous(mat2):
142
+ meta_mat2 = V.graph.current_node.args[1]
143
+ mat2 = may_require_contiguous(mat2, meta_mat2)
144
+
145
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
146
+
147
+ # options to tune from
148
+ choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
149
+ if use_triton_template(layout):
150
+ for config in mm_configs(m, n, k):
151
+ bmm_template.maybe_append_choice(
152
+ choices,
153
+ input_nodes=(mat1, mat2),
154
+ layout=layout,
155
+ **mm_options(config, m, n, k, layout),
156
+ )
157
+ static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
158
+ if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
159
+ from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
160
+
161
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
162
+
163
+ if len(choices) == 0:
164
+ log.warning("No choices for GEMM, using ATen backend as fallback")
165
+ choices.append(aten_bmm.bind((mat1, mat2), layout))
166
+
167
+ return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
168
+
169
+
170
+ # Don't register this since it is slower than decomposing it
171
+ # @L.register_lowering(aten.baddbmm)
172
+ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
173
+ m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
174
+
175
+ # options to tune from
176
+ choices = (
177
+ [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
178
+ if use_aten_gemm_kernels()
179
+ else []
180
+ )
181
+ if use_triton_template(layout):
182
+ for config in mm_configs(m, n, k):
183
+ bmm_template.maybe_append_choice(
184
+ choices,
185
+ input_nodes=(inp, mat1, mat2),
186
+ layout=layout,
187
+ **mm_options(config, m, n, k, layout),
188
+ prefix_args=1,
189
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
190
+ )
191
+
192
+ return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import logging
7
+ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict
8
+
9
+ import torch
10
+
11
+ from .. import config, ir
12
+ from ..lowering import (
13
+ add_layout_constraint,
14
+ constrain_to_fx_strides,
15
+ lowerings as L,
16
+ register_lowering,
17
+ )
18
+ from ..select_algorithm import (
19
+ autotune_select_algorithm,
20
+ ExternKernelChoice,
21
+ TritonTemplate,
22
+ )
23
+ from ..utils import (
24
+ ceildiv,
25
+ is_ones,
26
+ is_zeros,
27
+ pad_listlike,
28
+ sympy_product,
29
+ use_triton_template,
30
+ )
31
+ from ..virtualized import V
32
+ from .mm_common import filtered_configs
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from ..ir import TensorBox
37
+
38
+ log = logging.getLogger(__name__)
39
+
40
+
41
+ aten = torch.ops.aten
42
+
43
+
44
+ def conv2d_grid(n, c, h, w, meta):
45
+ return (
46
+ ceildiv(n * h * w, meta["BLOCK_M"]),
47
+ ceildiv(c, meta["BLOCK_N"]),
48
+ meta["GROUPS"],
49
+ )
50
+
51
+
52
+ def conv3d_grid(n, c, d, h, w, meta):
53
+ return (
54
+ ceildiv(n * d * h * w, meta["BLOCK_M"]),
55
+ ceildiv(c, meta["BLOCK_N"]),
56
+ meta["GROUPS"],
57
+ )
58
+
59
+
60
+ # List of dictionaries to store the kernel configs. Configs that evaluate to true
61
+ # will be utilised on the target platform
62
+ kernel_configs = [
63
+ # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
64
+ {"config": (64, 256, 16, 2, 4), "cond": True},
65
+ {"config": (256, 64, 16, 2, 4), "cond": True},
66
+ {"config": (1024, 16, 16, 1, 8), "cond": True},
67
+ {"config": (128, 128, 32, 2, 8), "cond": True},
68
+ {"config": (64, 64, 32, 2, 4), "cond": True},
69
+ {"config": (64, 256, 32, 2, 8), "cond": True},
70
+ {"config": (256, 64, 32, 2, 8), "cond": True},
71
+ ]
72
+
73
+ # Create filtered list of configs based on conv
74
+ platform_configs = tuple(
75
+ cast(Tuple[int, int, int, int, int], config["config"])
76
+ for config in kernel_configs
77
+ if config["cond"]
78
+ )
79
+
80
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
81
+ if torch.version.hip:
82
+ platform_configs = tuple(
83
+ (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
84
+ )
85
+
86
+ conv_configs = functools.partial(
87
+ filtered_configs,
88
+ configs=platform_configs,
89
+ )
90
+
91
+ LOOP_BODY_2D = """
92
+ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
93
+ idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
94
+ idx_x_c = tl.arange(0, BLOCK_K) + k
95
+
96
+ x_ptrs = x_base + (
97
+ (idx_x_h * stride_xh)[:, None]
98
+ + (idx_x_w * stride_xw)[:, None]
99
+ + (idx_x_c * stride_xc)[None, :]
100
+ )
101
+ mask_x = (
102
+ (idx_n < BATCH)[:, None]
103
+ & (idx_x_h >= 0)[:, None]
104
+ & (idx_x_h < IN_H)[:, None]
105
+ & (idx_x_w >= 0)[:, None]
106
+ & (idx_x_w < IN_W)[:, None]
107
+ & (idx_x_c < GROUP_IN_C)[None, :]
108
+ )
109
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
110
+
111
+ w_ptrs = w_base + (
112
+ (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
113
+ )
114
+ mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
115
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
116
+ acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
117
+ """
118
+
119
+ """
120
+ This is a relatively simple conv implementation that can likely be
121
+ improved. Many alternate conv versions can be found here:
122
+ https://github.com/pytorch/torchdynamo/pull/971
123
+ """
124
+ conv2d_template = TritonTemplate(
125
+ name="convolution2d",
126
+ grid=conv2d_grid,
127
+ source=r"""
128
+ {{def_kernel("X", "W")}}
129
+ # Tensor dimensions
130
+ BATCH = {{size("X", 0)}}
131
+ IN_C = {{size("X", 1)}}
132
+ IN_H = {{size("X", 2)}}
133
+ IN_W = {{size("X", 3)}}
134
+ OUT_C = {{size(None, 1)}}
135
+ OUT_H = {{size(None, 2)}}
136
+ OUT_W = {{size(None, 3)}}
137
+
138
+ # Strides:
139
+ stride_xn = {{stride("X", 0)}}
140
+ stride_xc = {{stride("X", 1)}}
141
+ stride_xh = {{stride("X", 2)}}
142
+ stride_xw = {{stride("X", 3)}}
143
+ stride_wc_out = {{stride("W", 0)}}
144
+ stride_wc_in = {{stride("W", 1)}}
145
+ stride_wh = {{stride("W", 2)}}
146
+ stride_ww = {{stride("W", 3)}}
147
+
148
+ nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
149
+ idx_y_w = nhw % OUT_W
150
+ nh = nhw // OUT_W
151
+ idx_y_h = nh % OUT_H
152
+ idx_n = nh // OUT_H
153
+ idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
154
+
155
+ {% if GROUPS == 1 %}
156
+ group = 0
157
+ GROUP_IN_C = IN_C
158
+ GROUP_OUT_C = OUT_C
159
+ {% else %}
160
+ group = tl.program_id(2)
161
+ GROUP_IN_C = IN_C // GROUPS
162
+ GROUP_OUT_C = OUT_C // GROUPS
163
+ {% endif %}
164
+
165
+ x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
166
+ w_base = (
167
+ W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
168
+ )
169
+
170
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
171
+
172
+ {% if UNROLL %}
173
+ {% for i in range(KERNEL_H) %}
174
+ {% for j in range(KERNEL_W) %}
175
+ i = {{i}}
176
+ j = {{j}}
177
+ for k in range(0, GROUP_IN_C, BLOCK_K):
178
+ """
179
+ + LOOP_BODY_2D
180
+ + """
181
+ {% endfor %}
182
+ {% endfor %}
183
+ {% else %}
184
+ # Could be simplified, but slightly slower:
185
+ # for i in range(KERNEL_H):
186
+ # for j in range(KERNEL_W):
187
+ # for k in range(0, GROUP_IN_C, BLOCK_K):
188
+ BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
189
+ for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
190
+ k = (ijk % BLOCK_K_COUNT) * BLOCK_K
191
+ ij = ijk // BLOCK_K_COUNT
192
+ i = ij // KERNEL_W
193
+ j = ij % KERNEL_W
194
+ """
195
+ + LOOP_BODY_2D
196
+ + """
197
+ {% endif %}
198
+
199
+ mask = (
200
+ (idx_n < BATCH)[:, None]
201
+ & (idx_y_h < OUT_H)[:, None]
202
+ & (idx_y_w < OUT_W)[:, None]
203
+ & (idx_y_c < GROUP_OUT_C)[None, :]
204
+ )
205
+ idx_n = idx_n[:, None]
206
+ idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
207
+ idx_h = idx_y_h[:, None]
208
+ idx_w = idx_y_w[:, None]
209
+
210
+ # inductor generates a suffix
211
+ {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
212
+ """,
213
+ )
214
+
215
+ LOOP_BODY_3D = """
216
+ idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
217
+ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
218
+ idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
219
+ idx_x_c = tl.arange(0, BLOCK_K) + k
220
+
221
+ x_ptrs = x_base + (
222
+ (idx_x_d * stride_xd)[:, None]
223
+ + (idx_x_h * stride_xh)[:, None]
224
+ + (idx_x_w * stride_xw)[:, None]
225
+ + (idx_x_c * stride_xc)[None, :]
226
+ )
227
+ mask_x = (
228
+ (idx_n < BATCH)[:, None]
229
+ & (idx_x_d >= 0)[:, None]
230
+ & (idx_x_d < IN_D)[:, None]
231
+ & (idx_x_h >= 0)[:, None]
232
+ & (idx_x_h < IN_H)[:, None]
233
+ & (idx_x_w >= 0)[:, None]
234
+ & (idx_x_w < IN_W)[:, None]
235
+ & (idx_x_c < GROUP_IN_C)[None, :]
236
+ )
237
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
238
+
239
+ w_ptrs = w_base + (
240
+ (idx_x_c * stride_wc_in)[:, None] +
241
+ (d * stride_wd) + (i * stride_wh) + (j * stride_ww)
242
+ )
243
+ mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
244
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
245
+ acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
246
+ """
247
+
248
+ conv3d_template = TritonTemplate(
249
+ name="convolution3d",
250
+ grid=conv3d_grid,
251
+ source=r"""
252
+ {{def_kernel("X", "W")}}
253
+ # Tensor dimensions
254
+ BATCH = {{size("X", 0)}}
255
+ IN_C = {{size("X", 1)}}
256
+ IN_D = {{size("X", 2)}}
257
+ IN_H = {{size("X", 3)}}
258
+ IN_W = {{size("X", 4)}}
259
+ OUT_C = {{size(None, 1)}}
260
+ OUT_D = {{size(None, 2)}}
261
+ OUT_H = {{size(None, 3)}}
262
+ OUT_W = {{size(None, 4)}}
263
+
264
+ # Strides:
265
+ stride_xn = {{stride("X", 0)}}
266
+ stride_xc = {{stride("X", 1)}}
267
+ stride_xd = {{stride("X", 2)}}
268
+ stride_xh = {{stride("X", 3)}}
269
+ stride_xw = {{stride("X", 4)}}
270
+ stride_wc_out = {{stride("W", 0)}}
271
+ stride_wc_in = {{stride("W", 1)}}
272
+ stride_wd = {{stride("W", 2)}}
273
+ stride_wh = {{stride("W", 3)}}
274
+ stride_ww = {{stride("W", 4)}}
275
+
276
+ ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
277
+ idx_y_w = ndhw % OUT_W
278
+ ndh = ndhw // OUT_W
279
+ idx_y_h = ndh % OUT_H
280
+ nd = ndh // OUT_H
281
+ idx_y_d = nd % OUT_D
282
+ idx_n = nd // OUT_D
283
+ idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
284
+
285
+ {% if GROUPS == 1 %}
286
+ group = 0
287
+ GROUP_IN_C = IN_C
288
+ GROUP_OUT_C = OUT_C
289
+ {% else %}
290
+ group = tl.program_id(2)
291
+ GROUP_IN_C = IN_C // GROUPS
292
+ GROUP_OUT_C = OUT_C // GROUPS
293
+ {% endif %}
294
+
295
+ x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
296
+ w_base = (
297
+ W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
298
+ )
299
+
300
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
301
+
302
+ {% if UNROLL %}
303
+ {% for d in range(KERNEL_D) %}
304
+ {% for i in range(KERNEL_H) %}
305
+ {% for j in range(KERNEL_W) %}
306
+ d = {{d}}
307
+ i = {{i}}
308
+ j = {{j}}
309
+ for k in range(0, GROUP_IN_C, BLOCK_K):
310
+ """
311
+ + LOOP_BODY_3D
312
+ + """
313
+ {% endfor %}
314
+ {% endfor %}
315
+ {% endfor %}
316
+ {% else %}
317
+ # Could be simplified, but slightly slower:
318
+ # for d in range(KERNEL_D):
319
+ # for i in range(KERNEL_H):
320
+ # for j in range(KERNEL_W):
321
+ # for k in range(0, GROUP_IN_C, BLOCK_K):
322
+ BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
323
+ for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
324
+ k = (dijk % BLOCK_K_COUNT) * BLOCK_K
325
+ dij = dijk // BLOCK_K_COUNT
326
+ j = dij % KERNEL_W
327
+ di = dij // KERNEL_W
328
+ i = di % KERNEL_H
329
+ d = di // KERNEL_H
330
+ """
331
+ + LOOP_BODY_3D
332
+ + """
333
+ {% endif %}
334
+
335
+ mask = (
336
+ (idx_n < BATCH)[:, None]
337
+ & (idx_y_d < OUT_D)[:, None]
338
+ & (idx_y_h < OUT_H)[:, None]
339
+ & (idx_y_w < OUT_W)[:, None]
340
+ & (idx_y_c < GROUP_OUT_C)[None, :]
341
+ )
342
+ idx_n = idx_n[:, None]
343
+ idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
344
+ idx_d = idx_y_d[:, None]
345
+ idx_h = idx_y_h[:, None]
346
+ idx_w = idx_y_w[:, None]
347
+
348
+ # inductor generates a suffix
349
+ {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}}
350
+ """,
351
+ )
352
+
353
+ aten_convolution = ExternKernelChoice(
354
+ torch.convolution,
355
+ "at::convolution",
356
+ has_out_variant=False,
357
+ op_overload=aten.convolution.default,
358
+ )
359
+
360
+
361
+ def conv1x1_via_mm(x, w, *, out):
362
+ w = torch.squeeze(torch.squeeze(w, -1), -1)
363
+ return torch.matmul(
364
+ x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
365
+ )
366
+
367
+
368
+ aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
369
+
370
+
371
+ class ConvLayoutParams(TypedDict):
372
+ stride: tuple[int, ...]
373
+ padding: tuple[int, ...]
374
+ dilation: tuple[int, ...]
375
+ transposed: bool
376
+ output_padding: tuple[int, ...]
377
+ groups: int
378
+
379
+
380
def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution.

    Runs the real aten.convolution under the graph's fake mode to let the
    eager meta/fake implementation decide the output sizes and strides, then
    freezes those into a FixedLayout on x's device/dtype.
    """
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            # size_hints: concrete ints are required by the fake call
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
            transposed,
            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
            groups,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        # NOTE: rebinds the `stride` parameter with the *output* strides
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )
413
+
414
+
415
def channels_last_order(rank):
    """Return the channels-last fill order for a tensor of the given rank.

    Dimensions are listed from innermost (smallest stride) outwards, with the
    channel dimension (dim 1) moved to the second slot, e.g. rank 4 -> [3, 0, 2, 1].
    """
    dims = list(range(rank - 1, -1, -1))
    channel_dim = dims.pop()
    dims.insert(1, channel_dim)
    return dims
419
+
420
+
421
def convert_1x1_conv_to_mm(x, weight, bias):
    """Lower a 1x1 (all spatial dims == 1) convolution as a matmul.

    Squeezes the singleton kernel dims off `weight`, forces `x` into
    channels-last order, flattens batch+spatial into one matmul dim, runs
    mm/addmm, then reshapes and permutes the result back to NC...-style order.
    """
    # special case for 1x1 convolution, which is actually just a matmul
    rank = len(weight.get_size())
    # drop each trailing singleton spatial dim: (O, I, 1, ..., 1) -> (O, I)
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    # move the channel dim (1) to the end: N, C, *spatial -> N, *spatial, C
    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    # move the output-channel dim back to position 1
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)
442
+
443
+
444
@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Lowering for aten.convolution.

    Normalizes the arguments, handles the unbatched and 1x1-as-matmul special
    cases, chooses a memory layout, and then builds a list of autotuning
    choices (ATen extern kernel, 1x1-via-mm extern kernel, and the Triton
    conv2d/conv3d templates) which is handed to autotune_select_algorithm.
    """
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        # groups may arrive as a symbolic value; force it to a concrete int
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)

    # Need use hint for triton template since the template does not
    # work with a dynamic shape.
    #
    # No need to evaluate_static_shape for dilation and output_padding
    # since the template is only used when dilation is 1 and output_padding
    # is 0.
    stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
    padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))

    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    # broadcast scalar/short stride-like params to one entry per spatial dim
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        # True when layout optimization forces channels-last for 2d convs,
        # or when the inferred output layout is already NHWC-ordered.
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    # 1x1, stride-1, unpadded, undilated, ungrouped, non-transposed conv on a
    # non-empty input is exactly a matmul.
    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        # the C++ extern kernel takes bias as an explicit (None) kwarg
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

    choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
        choices = [
            aten_convolution.bind(
                args,
                layout,
                ordered_kwargs_for_cpp_kernel,
                **kwargs,
            )
        ]

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
        and use_triton_template(layout)
        # templates only support these:
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_H=kernel_shape[0],
                    KERNEL_W=kernel_shape[1],
                    STRIDE_H=stride[0],
                    STRIDE_W=stride[1],
                    PADDING_H=padding[0],
                    PADDING_W=padding[1],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
            elif ndim == 3:
                conv3d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_D=kernel_shape[0],
                    KERNEL_H=kernel_shape[1],
                    KERNEL_W=kernel_shape[2],
                    STRIDE_D=stride[0],
                    STRIDE_H=stride[1],
                    STRIDE_W=stride[2],
                    PADDING_D=padding[0],
                    PADDING_H=padding[1],
                    PADDING_W=padding[2],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )

    return autotune_select_algorithm("convolution", choices, args, layout)
648
+
649
+
650
@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    """Lowering for aten._convolution: delegate to the aten.convolution lowering.

    The trailing backend hints (benchmark, deterministic, cudnn_enabled,
    allow_tf32) are accepted for schema compatibility but intentionally ignored.
    """
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )
669
+
670
+
671
def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    """Layout constraint for aten.convolution nodes.

    When layout optimization is on, inductor picks its own strides, so the
    args pass through untouched; otherwise fall back to the generic
    constrain-to-fx-strides behavior.
    """
    assert fx_node.target == torch.ops.aten.convolution.default
    if not V.graph.layout_opt:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)
    return args, kwargs
677
+
678
+
679
+ add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py ADDED
@@ -0,0 +1,1843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """ Triton Implementation of the flex_attention Kernel"""
3
+
4
+ import logging
5
+ import math
6
+ from typing import Any, List, Optional, Sequence, Tuple
7
+
8
+ import sympy
9
+
10
+ import torch
11
+ from torch._inductor.virtualized import V
12
+ from torch.utils._pytree import tree_map
13
+
14
+ from .. import config
15
+ from ..ir import (
16
+ ComputedBuffer,
17
+ ExternKernel,
18
+ FixedLayout,
19
+ FlexibleLayout,
20
+ get_stride_order,
21
+ InputBuffer,
22
+ IRNode,
23
+ StorageBox,
24
+ stride_order2fill_order,
25
+ Subgraph,
26
+ TensorBox,
27
+ )
28
+ from ..lowering import empty, empty_strided, lowerings, register_lowering
29
+ from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
30
+
31
+
32
# Module-level logger and short aliases used throughout this file.
log = logging.getLogger(__name__)
aten = torch.ops.aten
Expr = sympy.Expr
35
+
36
+
37
def construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """Build strides for a tensor of *sizes* laid out per *fill_order*.

    *fill_order* lists dimension indices from innermost (stride 1) to
    outermost; each dimension's stride is the product of the sizes of all
    dimensions filled before it.
    """
    assert len(sizes) == len(
        fill_order
    ), "Length of sizes must match the length of the fill order"
    strides = [0] * len(sizes)
    running = 1
    for dim in fill_order:
        strides[dim], running = running, running * sizes[dim]
    return strides
57
+
58
+
59
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
    """How is this kernel parallelized?

    We create a grid of (ceil_div(n_queries, query_block_size),
    batch_size * num_heads, 1).  Each block is responsible for iterating over
    blocks of keys and values, calculating the final attention output.
    """
    import triton

    return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
68
+
69
+
70
def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """Create a scalar (0-d) placeholder input buffer for producing subgraph output."""
    # Empty size/stride lists: the placeholder stands in for a scalar value.
    input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], []))
    return TensorBox.create(input_buffer)
76
+
77
+
78
def maybe_realize(args: List[Optional[IRNode]]):
    """Accepts a list of optional IRNodes and returns a list of realized IRNodes.

    None entries are preserved as-is; tree_map keeps the container structure.
    """
    return tree_map(lambda x: realize_inputs(x) if x is not None else None, args)
81
+
82
+
83
def get_float32_precision():
    """Return the Triton input_precision literal for fp32 matmuls.

    'ieee' when full precision is requested (or on ROCm builds, where
    torch.version.hip is truthy); otherwise 'tf32'.  The value is a quoted
    string because it is substituted into generated kernel source.
    """
    full_precision = (
        torch.get_float32_matmul_precision() == "highest" or torch.version.hip
    )
    return "'ieee'" if full_precision else "'tf32'"
88
+
89
+
90
def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """This function's goal is to take in the required args and produce the subgraph buffer.

    The subgraph buffer is a ComputedBuffer that will be inlined into the
    triton template.

    Args:
        args: The args that are passed into the subgraph. Contains both fixed
            and lifted inputs.
        subgraph: The Subgraph ir for which to produce the output node.

    Returns:
        A ComputedBuffer (or a pytree of them mirroring the subgraph's output
        structure); None entries in the output map to None.

    Raises:
        ValueError: if the subgraph's fx graph contains no output node.
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        # There are two classes of placeholder inputs that we need
        # to handle differently. For the first n_scalar_inps inputs
        # we expect that these placeholders were generated by the make_fx call
        # in the flex Attention HOP. So we need to create a new placeholder
        # TensorBox for each of these inputs. For the rest of the inputs we
        # expect that these are lifted inputs that fill up the '*other_buffers'
        # tuple and already have corresponding TensorBoxes passed in as args.
        if node.op == "placeholder":
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args.
            # NOTE(review): this rebinding shadows the `args` parameter; safe
            # only because fx places all placeholder nodes first in graph order.
            args, kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*args, **kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                # Wrap one realized output TensorBox in a ComputedBuffer with
                # a FlexibleLayout so the template can inline it.
                if output is None:
                    return None
                output_node = output
                output_buffer = env[output_node]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node for flex attention's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for the flex attention subgraph must be a StorageBox, but got: ",
                    type(output_buffer),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] is either a single element or a list of elements
            # representing all outputs of the function.
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("FlexAttention was passed a subgraph with no output node!")
153
+
154
+
155
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
# get_offset_for_next_block computes how far to advance the K/V block pointers:
# within a sparse block it steps by BLOCK; when crossing into the next sparse
# block it jumps by the distance between consecutive entries of col_indices.
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK

    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
"""
168
+
169
+ compute_flex_attention = r"""
170
+ {{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
171
+ # Sub notation for this kernel:
172
+ #
173
+ # Q: Query, K: Key, V: Value
174
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
175
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
176
+ # V_HEAD_DIM: The dimension of the value embeddings
177
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
178
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
179
+ #
180
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
181
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
182
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
183
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
184
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
185
+ #
186
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
187
+ #
188
+ # (Modifiable) Performance tuning options
189
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
190
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
191
+
192
+ # The below are kernel options that can be applied for certain score_mods,
193
+ # or involve a numerics vs. perf tradeoff
194
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
195
+ # about 20% more numerical error, but slightly faster.
196
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
197
+ # is not masked out? If so, we can skip an extra safety check
198
+
199
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
200
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
201
+
202
+ # Define strides of inputs
203
+ stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
204
+ stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
205
+ stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
206
+
207
+ Z = {{size("Q", 0)}}
208
+ HQ = {{size("Q", 1)}}
209
+ Q_LEN = {{size("Q", 2)}}
210
+ KV_LEN = {{size("K", 2)}}
211
+
212
+ MATMUL_PRECISION = Q.dtype.element_ty
213
+
214
+ q_start = tl.program_id(0)
215
+ off_z = tl.program_id(1) // HQ
216
+ off_hq = tl.program_id(1) % HQ
217
+ off_hkv = off_hq // GQA_SHARED_HEADS
218
+ off_g = off_hq % GQA_SHARED_HEADS
219
+
220
+ q_offset = off_z * stride_qz + off_hq * stride_qh
221
+ k_offset = off_z * stride_kz + off_hkv * stride_kh
222
+ v_offset = off_z * stride_vz + off_hkv * stride_vh
223
+
224
+ Q = Q + q_offset
225
+ K = K + k_offset
226
+ V = V + v_offset
227
+
228
+ SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
229
+ SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
230
+
231
+ sparse_idx_z = off_z % SPARSE_Z
232
+ sparse_idx_hq = off_hq % SPARSE_HQ
233
+
234
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
235
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
236
+
237
+ SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
238
+ SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
239
+
240
+ # initialize pointer to m and l
241
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
242
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
243
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
244
+
245
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
246
+
247
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
248
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
249
+ sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
250
+ sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
251
+
252
+ Q_block_ptr = tl.make_block_ptr(
253
+ base=Q,
254
+ shape=(Q_LEN, QK_HEAD_DIM),
255
+ strides=(stride_qm, stride_qk),
256
+ offsets=(q_start * BLOCK_M, 0),
257
+ block_shape=(BLOCK_M, QK_HEAD_DIM),
258
+ order=(1, 0)
259
+ )
260
+
261
+ # load q: it stays in SRAM throughout the inner loop.
262
+ if IS_DIVISIBLE:
263
+ q = tl.load(Q_block_ptr)
264
+ else:
265
+ # boundary check is not free, so we only do it when necessary.
266
+ q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero")
267
+
268
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
269
+ # We don't know anything "special" about these blocks, so we need to apply
270
+ # both score_mod and mask_mod to it
271
+ kv_indices = KV_IDX + sparse_kv_idx_offset
272
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
273
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
274
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
275
+
276
+ K_block_ptr = tl.make_block_ptr(
277
+ base=K,
278
+ shape=(QK_HEAD_DIM, KV_LEN),
279
+ strides=(stride_kk, stride_kn),
280
+ offsets=(0, kv_start),
281
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
282
+ order=(0, 1)
283
+ )
284
+ V_block_ptr = tl.make_block_ptr(
285
+ base=V,
286
+ shape=(KV_LEN, V_HEAD_DIM),
287
+ strides=(stride_vn, stride_vk),
288
+ offsets=(kv_start, 0),
289
+ block_shape=(BLOCK_N, V_HEAD_DIM),
290
+ order=(1, 0)
291
+ )
292
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
293
+
294
+ acc, l_i, m_i = forward_inner(
295
+ {{gen_argdefs()}},
296
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
297
+ acc, l_i, m_i,
298
+ off_z, off_hq, offs_m[:, None], offs_n[None, :],
299
+ kv_indices, kv_num_blocks,
300
+ 0, block_n_end,
301
+ MATMUL_PRECISION,
302
+ IS_FULL_BLOCKS=False,
303
+ )
304
+
305
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
306
+ # We know these blocks are guaranteed to be "full", so we don't need to
307
+ # apply mask_mod to them - only score_mod
308
+ if HAS_FULL_BLOCKS:
309
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
310
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
311
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
312
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
313
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
314
+
315
+ K_block_ptr = tl.make_block_ptr(
316
+ base=K,
317
+ shape=(QK_HEAD_DIM, KV_LEN),
318
+ strides=(stride_kk, stride_kn),
319
+ offsets=(0, kv_start),
320
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
321
+ order=(0, 1)
322
+ )
323
+ V_block_ptr = tl.make_block_ptr(
324
+ base=V,
325
+ shape=(KV_LEN, V_HEAD_DIM),
326
+ strides=(stride_vn, stride_vk),
327
+ offsets=(kv_start, 0),
328
+ block_shape=(BLOCK_N, V_HEAD_DIM),
329
+ order=(1, 0)
330
+ )
331
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
332
+
333
+ acc, l_i, m_i = forward_inner(
334
+ {{gen_argdefs()}},
335
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
336
+ acc, l_i, m_i,
337
+ off_z, off_hq, offs_m[:, None], offs_n[None, :],
338
+ kv_indices, kv_num_blocks,
339
+ 0, block_n_end,
340
+ MATMUL_PRECISION,
341
+ IS_FULL_BLOCKS=True,
342
+ )
343
+
344
+
345
+ # [Note] Handle fully masked out rows:
346
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
347
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
348
+ l_i = tl.where(l_i == 0.0, 1, l_i)
349
+
350
+ acc = acc / l_i[:, None]
351
+ idx_z = tl.program_id(1) // HQ
352
+ idx_hq = tl.program_id(1) % HQ
353
+ idx_m = offs_m[:, None]
354
+ idx_d = tl.arange(0, V_HEAD_DIM)[None, :]
355
+
356
+ mask = idx_m < Q_LEN
357
+ # TODO generalize and add proper mask support
358
+ {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
359
+
360
+ # TODO dont want to write this if we dont require grad
361
+ if OUTPUT_LOGSUMEXP:
362
+ off_hz = tl.program_id(1)
363
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
364
+ lse = m_i + tl.math.log2(l_i)
365
+ if IS_DIVISIBLE:
366
+ tl.store(l_ptrs, lse)
367
+ else:
368
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
369
+ """
370
+
371
+
372
+ compute_forward_inner = r"""
373
+ @triton.jit
374
+ def forward_inner(
375
+ {{gen_argdefs()}},
376
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
377
+ # accumulated values
378
+ acc, l_i, m_i,
379
+ # Offsets used as inputs to score_mod & mask_mod
380
+ # of size [BLOCK_M, BLOCK_N] or scalar.
381
+ off_z, off_h, offs_m, offs_n,
382
+ # blocksparse data
383
+ kv_indices, kv_num_blocks,
384
+ # start kv and end kv block
385
+ block_n_start, block_n_end,
386
+ MATMUL_PRECISION,
387
+ IS_FULL_BLOCKS,
388
+ ):
389
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
390
+ {{gen_defines() | indent_except_first(1)}}
391
+
392
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
393
+ RCP_LN2: tl.constexpr = 1.44269504
394
+
395
+ if PRESCALE_QK:
396
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
397
+
398
+ # loop over k, v and update accumulator until block_n_end
399
+ for start_n in range(block_n_start, block_n_end):
400
+ if IS_DIVISIBLE:
401
+ acc, l_i, m_i = forward_block_mn(
402
+ {{gen_argdefs()}},
403
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
404
+ # accumulated values
405
+ acc, l_i, m_i,
406
+ # Offsets
407
+ off_z, off_h, offs_m, offs_n,
408
+ MATMUL_PRECISION, RCP_LN2,
409
+ IS_FULL_BLOCKS,
410
+ )
411
+ else:
412
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
413
+ # it's on par or slightly faster than only applying to the last block in fwd.
414
+ # However, we choose different strategy for bwd, where we only apply mod & mask
415
+ # to the last block because it's faster a lot.
416
+ acc, l_i, m_i = forward_block_mn(
417
+ {{gen_argdefs()}},
418
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
419
+ # accumulated values
420
+ acc, l_i, m_i,
421
+ # Offsets
422
+ off_z, off_h, offs_m, offs_n,
423
+ MATMUL_PRECISION, RCP_LN2,
424
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
425
+ )
426
+
427
+ # update pointers
428
+ offset = get_offset_for_next_block(
429
+ start_n, kv_indices, kv_num_blocks,
430
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N
431
+ )
432
+
433
+ V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
434
+ K_block_ptr = tl.advance(K_block_ptr, (0, offset))
435
+
436
+ offs_n = offs_n + offset
437
+
438
+ return acc, l_i, m_i
439
+
440
+ """
441
+
442
+
443
+ compute_forward_block_mn = r"""
444
+ @triton.jit
445
+ def forward_block_mn(
446
+ {{gen_argdefs()}},
447
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
448
+ # accumulated values
449
+ acc, l_i, m_i,
450
+ # Offsets
451
+ off_z, off_h, offs_m, offs_n,
452
+ MATMUL_PRECISION, RCP_LN2,
453
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
454
+ ):
455
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
456
+ {{gen_defines() | indent_except_first(1)}}
457
+
458
+ # -- load k --
459
+ if IS_DIVISIBLE:
460
+ k = tl.load(K_block_ptr)
461
+ else:
462
+ k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero")
463
+ # -- compute qk ---
464
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
465
+ if not PRESCALE_QK:
466
+ qk *= SM_SCALE
467
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
468
+ if CHECK_BLOCK_BOUNDARY:
469
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
470
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
471
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
472
+ m = offs_m % Q_LEN
473
+ n = offs_n % KV_LEN
474
+ else:
475
+ m = offs_m
476
+ n = offs_n
477
+
478
+ {{ modification(
479
+ subgraph_number=0,
480
+ output_name="post_mod_scores",
481
+ score="qk",
482
+ b="off_z",
483
+ h="off_h",
484
+ m="m",
485
+ n="n",
486
+ out="qk"
487
+ ) | indent_except_first(1) }}
488
+
489
+ if CHECK_BLOCK_BOUNDARY:
490
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
491
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
492
+
493
+ if not IS_FULL_BLOCKS:
494
+ {{ modification(
495
+ subgraph_number=1,
496
+ output_name="mask_mod_output",
497
+ score="qk",
498
+ b="off_z",
499
+ h="off_h",
500
+ m="m",
501
+ n="n",
502
+ ) | indent_except_first(2) }}
503
+
504
+ if CHECK_BLOCK_BOUNDARY:
505
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
506
+ # apply mask for partially unmasked blocks
507
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
508
+
509
+ # TODO: In the case that score_mod is linear, this can be LICMed
510
+ if not PRESCALE_QK:
511
+ post_mod_scores *= RCP_LN2
512
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
513
+
514
+ # -- compute scaling constant ---
515
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
516
+ if not ROWS_GUARANTEED_SAFE:
517
+ masked_out_rows = (m_ij == float("-inf"))
518
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
519
+ else:
520
+ m_ij_masked = m_ij
521
+
522
+ alpha = tl.math.exp2(m_i - m_ij_masked)
523
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
524
+
525
+ # NB: l_i update is pulled up here since it's a bit faster
526
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
527
+ # m_ij
528
+ l_i = l_i * alpha + tl.sum(p, 1)
529
+ # # -- scale and update acc --
530
+ acc = acc * alpha[:, None]
531
+
532
+ if IS_DIVISIBLE:
533
+ v = tl.load(V_block_ptr)
534
+ else:
535
+ v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero")
536
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
537
+
538
+ # -- update m_i
539
+ m_i = m_ij
540
+
541
+ return acc, l_i, m_i
542
+
543
+ """
544
+
545
+
546
+ flex_attention_template = TritonTemplate(
547
+ name="flex_attention",
548
+ grid=flex_attention_grid,
549
+ source=compute_flex_attention
550
+ + compute_forward_inner
551
+ + compute_next_offset_func
552
+ + compute_forward_block_mn,
553
+ )
554
+
555
+
556
+ def _use_flex_decoding(query, kernel_options):
557
+ # Decide which kernel to use, return true if use flex decoding kernel.
558
+ return (
559
+ not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
560
+ ) and V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 128))
561
+
562
+
563
+ _h100_default_config = {
564
+ (torch.float32, 64): (128, 32, 4, 3),
565
+ (torch.float32, 128): (32, 64, 4, 3),
566
+ (torch.float32, 256): (32, 32, 4, 3),
567
+ (torch.bfloat16, 64): (128, 128, 4, 3),
568
+ (torch.bfloat16, 128): (128, 64, 8, 3),
569
+ (torch.bfloat16, 256): (64, 32, 4, 3),
570
+ (torch.float16, 64): (128, 128, 4, 3),
571
+ (torch.float16, 128): (128, 128, 8, 3),
572
+ (torch.float16, 256): (64, 32, 4, 3),
573
+ }
574
+
575
+ _a100_default_config = {
576
+ (torch.float32, 64): (128, 32, 4, 3),
577
+ (torch.float32, 128): (128, 32, 4, 3),
578
+ (torch.float32, 256): (64, 16, 4, 3),
579
+ (torch.bfloat16, 64): (128, 64, 4, 3),
580
+ (torch.bfloat16, 128): (128, 64, 8, 3),
581
+ (torch.bfloat16, 256): (32, 64, 4, 3),
582
+ (torch.float16, 64): (128, 64, 4, 3),
583
+ (torch.float16, 128): (128, 64, 8, 3),
584
+ (torch.float16, 256): (32, 64, 4, 3),
585
+ }
586
+
587
+
588
+ def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
589
+ dtype = query.get_dtype()
590
+ head_dim = query.get_size()[-1]
591
+ default_config = None
592
+
593
+ if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
594
+ if dtype == torch.float32:
595
+ default_config = (64, 64, 4, 3)
596
+ else:
597
+ default_config = (128, 64, 4, 3)
598
+ default_config = _h100_default_config.get((dtype, head_dim), default_config)
599
+ elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100
600
+ if dtype == torch.float32:
601
+ default_config = (64, 64, 4, 3)
602
+ else:
603
+ default_config = (128, 64, 4, 3)
604
+ default_config = _a100_default_config.get((dtype, head_dim), default_config)
605
+ else: # modest hardware or extremely large head_dim
606
+ if dtype == torch.float32:
607
+ default_config = (32, 16, 4, 3)
608
+ else:
609
+ default_config = (64, 32, 4, 3)
610
+
611
+ return default_config
612
+
613
+
614
+ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
615
+ head_dim = query.get_size()[-1]
616
+ dtype = query.get_dtype()
617
+
618
+ if dtype == torch.float32:
619
+ return (16, 16, 4, 1)
620
+ if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
621
+ if head_dim == 64:
622
+ return (64, 64, 4, 3)
623
+ elif head_dim == 128:
624
+ return (64, 128, 8, 3)
625
+ else:
626
+ return (64, 64, 4, 2)
627
+ elif torch.cuda.get_device_capability() >= (8, 0): # A100
628
+ if head_dim == 64:
629
+ return (32, 128, 4, 3)
630
+ elif head_dim == 128:
631
+ return (64, 128, 8, 3)
632
+ else:
633
+ return (64, 64, 4, 2)
634
+ else: # modest hardware or extremely large head_dim
635
+ return (16, 16, 4, 1)
636
+
637
+
638
+ def create_num_blocks_fake_generator(sparse_indices):
639
+ # The idea here is that we need to create a real tensor with real data
640
+ # that's representative for benchmarking.
641
+ # For example, returning all zeros for the `kv_num_blocks` input would mean
642
+ # that we are computing 0 blocks for each row, which would provide bogus
643
+ # autotuning results.
644
+ #
645
+ # In this case, we choose to use min(16, max_block) blocks, because I
646
+ # (Horace) think it'll probably result in pretty representative performance.
647
+ # If it's too short then prefetching won't help. If it's too long then
648
+ # autotuning will take longer for no good reason.
649
+ def create_num_blocks_fake(x) -> torch.Tensor:
650
+ num_blocks_for_autotuning = min(16, sparse_indices.shape[-1])
651
+ return torch.full(
652
+ x.get_size(),
653
+ int(num_blocks_for_autotuning),
654
+ dtype=x.get_dtype(),
655
+ device=x.get_device(),
656
+ )
657
+
658
+ return create_num_blocks_fake
659
+
660
+
661
+ def create_indices_fake(x) -> torch.Tensor:
662
+ indices = torch.arange(
663
+ 0, int(x.get_size()[-1]), dtype=x.get_dtype(), device=x.get_device()
664
+ )
665
+ indices = indices.expand(x.get_size()).contiguous()
666
+ return indices
667
+
668
+
669
+ from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel
670
+
671
+
672
+ # TODO: We probably also need a layout constraint?
673
+ @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
674
+ def flex_attention(
675
+ query,
676
+ key,
677
+ value,
678
+ subgraph,
679
+ block_mask,
680
+ scale,
681
+ kernel_options,
682
+ score_mod_other_buffers,
683
+ mask_mod_other_buffers,
684
+ ):
685
+ (
686
+ kv_num_blocks,
687
+ kv_indices,
688
+ full_kv_num_blocks,
689
+ full_kv_indices,
690
+ q_num_blocks,
691
+ q_indices,
692
+ full_q_num_blocks,
693
+ full_q_indices,
694
+ SPARSE_KV_BLOCK_SIZE,
695
+ SPARSE_Q_BLOCK_SIZE,
696
+ mask_graph,
697
+ ) = block_mask
698
+ placeholder_inps = [
699
+ create_placeholder(name, dtype, query.get_device())
700
+ for name, dtype in [
701
+ ("score", query.get_dtype()),
702
+ ("b", torch.int32),
703
+ ("h", torch.int32),
704
+ ("m", torch.int32),
705
+ ("n", torch.int32),
706
+ ]
707
+ ]
708
+ subgraph_buffer = build_subgraph_buffer(
709
+ placeholder_inps + list(score_mod_other_buffers), subgraph
710
+ )
711
+ mask_graph_placeholder_inps = [
712
+ create_placeholder(name, dtype, query.get_device())
713
+ for name, dtype in [
714
+ ("b", torch.int32),
715
+ ("h", torch.int32),
716
+ ("m", torch.int32),
717
+ ("n", torch.int32),
718
+ ]
719
+ ]
720
+ mask_graph_buffer = build_subgraph_buffer(
721
+ mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
722
+ )
723
+ kernel_options = dict(kernel_options)
724
+ kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
725
+ if _use_flex_decoding(query, kernel_options):
726
+ return create_flex_decoding_kernel(
727
+ query,
728
+ key,
729
+ value,
730
+ block_mask,
731
+ scale,
732
+ kernel_options,
733
+ subgraph_buffer,
734
+ mask_graph_buffer,
735
+ score_mod_other_buffers,
736
+ mask_mod_other_buffers,
737
+ )
738
+
739
+ (
740
+ query,
741
+ key,
742
+ value,
743
+ kv_num_blocks,
744
+ kv_indices,
745
+ full_kv_num_blocks,
746
+ full_kv_indices,
747
+ q_num_blocks,
748
+ q_indices,
749
+ full_q_num_blocks,
750
+ full_q_indices,
751
+ ) = maybe_realize(
752
+ [
753
+ query,
754
+ key,
755
+ value,
756
+ kv_num_blocks,
757
+ kv_indices,
758
+ full_kv_num_blocks,
759
+ full_kv_indices,
760
+ q_num_blocks,
761
+ q_indices,
762
+ full_q_num_blocks,
763
+ full_q_indices,
764
+ ]
765
+ )
766
+
767
+ Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
768
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
769
+ assert Bq == Bkv, "Batch dimension must match"
770
+ B = Bq
771
+
772
+ if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
773
+ kernel_options.setdefault("IS_DIVISIBLE", False)
774
+ else:
775
+ kernel_options.setdefault("IS_DIVISIBLE", True)
776
+
777
+ # Reuse query strides for output layout despite different last dimension.
778
+ # This works because only the last dim differs and we check it is contiguous.
779
+ q_strides = query.get_stride()
780
+ assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
781
+
782
+ # Construct output layout with strides matching the query.
783
+ out_size = [B, Hq, seq_len_q, v_head_dim]
784
+ stride_order = get_stride_order(query.get_stride())
785
+ fill_order = stride_order2fill_order(stride_order)
786
+ out_strides = construct_strides(out_size, fill_order)
787
+
788
+ layout = FixedLayout(
789
+ query.get_device(),
790
+ query.get_dtype(),
791
+ [B, Hq, seq_len_q, v_head_dim],
792
+ stride=out_strides,
793
+ )
794
+ # see NOTE:[TritonTemplates with multiple outputs]
795
+ logsumexp_shape = [B, Hq, seq_len_q]
796
+ logsumexp = empty_strided(
797
+ logsumexp_shape,
798
+ None,
799
+ dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype
800
+ device=query.get_device(),
801
+ )
802
+ kernel_options.setdefault("SM_SCALE", scale)
803
+
804
+ # Determine GQA broadcast factor.
805
+ gqa_shared_heads = Hq // Hkv
806
+ kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
807
+
808
+ # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
809
+ # full_kv_num_blocks is None if partial blocks are not computed
810
+ has_full_blocks = full_kv_num_blocks is not None
811
+ kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
812
+ if not has_full_blocks:
813
+ full_kv_num_blocks, full_kv_indices = (
814
+ empty(0, device=query.get_device()) for _ in range(2)
815
+ )
816
+ kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
817
+ kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
818
+
819
+ choices: List[Any] = []
820
+ configs: List[Tuple[int, int, int, int]] = []
821
+ configs.append(_get_default_config_fwd(query))
822
+ if config.max_autotune:
823
+ configs += [
824
+ (128, 64, 4, 3),
825
+ (128, 128, 4, 3),
826
+ (128, 128, 8, 2),
827
+ (64, 128, 4, 3),
828
+ (64, 64, 4, 3),
829
+ ]
830
+
831
+ # Note, we don't need to pass in the captured buffers explicitly
832
+ # because they're implicitly added by the score_mod function
833
+ # We do need to explicitly pass it in for autotuning though.
834
+
835
+ for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
836
+ if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
837
+ continue
838
+ # Work around https://github.com/pytorch/pytorch/issues/129625
839
+ if num_stages == 2:
840
+ continue
841
+
842
+ # Performance tuning
843
+ kernel_options.setdefault("BLOCK_M", BLOCK_M)
844
+ kernel_options.setdefault("BLOCK_N", BLOCK_N)
845
+ # Blocksparse options
846
+ kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
847
+ kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
848
+
849
+ flex_attention_template.maybe_append_choice(
850
+ choices=choices,
851
+ input_nodes=[
852
+ query,
853
+ key,
854
+ value,
855
+ logsumexp,
856
+ kv_num_blocks,
857
+ kv_indices,
858
+ full_kv_num_blocks,
859
+ full_kv_indices,
860
+ ],
861
+ layout=layout,
862
+ subgraphs=[
863
+ subgraph_buffer,
864
+ mask_graph_buffer,
865
+ ],
866
+ mutated_inputs=[
867
+ logsumexp,
868
+ ],
869
+ num_stages=num_stages,
870
+ num_warps=num_warps,
871
+ call_sizes=query.get_size(),
872
+ **kernel_options,
873
+ )
874
+ inputs_for_autotuning = (
875
+ [
876
+ query,
877
+ key,
878
+ value,
879
+ logsumexp,
880
+ kv_num_blocks,
881
+ kv_indices,
882
+ full_kv_num_blocks,
883
+ full_kv_indices,
884
+ ]
885
+ + list(score_mod_other_buffers)
886
+ + list(mask_mod_other_buffers)
887
+ )
888
+ input_gen_fns = {
889
+ 4: create_num_blocks_fake_generator(kv_indices),
890
+ 5: create_indices_fake,
891
+ 6: create_num_blocks_fake_generator(full_kv_indices),
892
+ 7: create_indices_fake,
893
+ }
894
+ return (
895
+ autotune_select_algorithm(
896
+ "flex_attention",
897
+ choices,
898
+ inputs_for_autotuning,
899
+ layout,
900
+ input_gen_fns=input_gen_fns,
901
+ ),
902
+ logsumexp,
903
+ )
904
+
905
+
906
+ # ---------------------------- Backward HOP Implementation ----------------------------
907
+
908
+
909
+ def flex_attention_backward_grid(
910
+ batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
911
+ ):
912
+ """How is this kernel parallelized?
913
+ Currently this is only parallelizing over batch* kv_heads, but we can, and want to
914
+ parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
915
+ To do this will either require atomic updates to some grad values or to have a two pass kernel design.
916
+ """
917
+ import triton
918
+
919
+ return (
920
+ triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
921
+ + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
922
+ 1,
923
+ batch_size * kv_heads,
924
+ )
925
+
926
+
927
+ flex_attention_backward_template = TritonTemplate(
928
+ name="flex_attention_backward",
929
+ grid=flex_attention_backward_grid,
930
+ source=r"""
931
+ {{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
932
+ # Sub notation for this kernel:
933
+ #
934
+ # Q: Query, K: Key, V: Value
935
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
936
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
937
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
938
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
939
+ # inductor codegen
940
+ # M: Number of queries, N: Number of keys/values
941
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
942
+ # V_HEAD_DIM: The dimension of the value embeddings
943
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
944
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
945
+ # (Modifiable) Performance tuning options
946
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
947
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
948
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
949
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
950
+ #
951
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
952
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
953
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
954
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
955
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
956
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
957
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
958
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
959
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
960
+
961
+ # The below are kernel options that can be applied for certain score_mods,
962
+ # or involve a numerics vs. perf tradeoff
963
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
964
+ # about 20% more numerical error, but slightly faster.
965
+
966
+ # Define strides of inputs
967
+ stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
968
+ stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
969
+ stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
970
+ stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
971
+
972
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
973
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
974
+
975
+ Z = {{size("Q", 0)}}
976
+ HQ = {{size("Q", 1)}}
977
+ HKV = {{size("K", 1)}}
978
+ Q_LEN = {{size("Q", 2)}}
979
+ KV_LEN = {{size("K", 2)}}
980
+
981
+ MATMUL_PRECISION = Q.dtype.element_ty
982
+
983
+ pid = tl.program_id(0)
984
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
985
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
986
+
987
+ off_hz = tl.program_id(2)
988
+ off_z = off_hz // HKV # batch idx
989
+ off_hkv = off_hz % HKV # kv head idx
990
+
991
+ SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
992
+ SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
993
+
994
+ sparse_idx_z = off_z % SPARSE_Z
995
+
996
+ k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
997
+ v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
998
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
999
+
1000
+ # offset K, V, DV pointers for batch/kv-head
1001
+ K += k_adj
1002
+ V += v_adj
1003
+ DV += dv_adj
1004
+
1005
+ RCP_LN2 = 1.44269504
1006
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1007
+ offs_v = tl.arange(0, V_HEAD_DIM)
1008
+
1009
+ if pid >= NUM_KV_BLOCKS:
1010
+ off_pid = pid - NUM_KV_BLOCKS
1011
+ # THIS BLOCK DOES DQ
1012
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
1013
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
1014
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
1015
+ start_m2_block = off_pid % NUM_Q_BLOCKS
1016
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
1017
+ stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
1018
+ stride_kv_idx_h = {{stride("KV_IDX", 1)}}
1019
+ stride_kv_idx_m = {{stride("KV_IDX", 2)}}
1020
+
1021
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
1022
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
1023
+
1024
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
1025
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
1026
+
1027
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
1028
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
1029
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
1030
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
1031
+ off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
1032
+
1033
+ Q2 = Q + q_adj2
1034
+ DO2 = DO + do_adj2
1035
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
1036
+ # if Q is broadcasted)
1037
+ DQ2 = DQ + dq_adj2
1038
+ LSE2 = LSE + off_chz2
1039
+ DELTA2 = DELTA + off_chz2
1040
+
1041
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
1042
+
1043
+ start_m2 = start_m2_block * BLOCK_M2
1044
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
1045
+
1046
+ # load Q and do: they stay in SRAM throughout the inner loop.
1047
+ if IS_DIVISIBLE:
1048
+ q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
1049
+ do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod)
1050
+ else:
1051
+ q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=offs_m2[:, None] < Q_LEN)
1052
+ do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod, mask=offs_m2[:, None] < Q_LEN)
1053
+
1054
+ if PRESCALE_QK:
1055
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
1056
+
1057
+ if IS_DIVISIBLE:
1058
+ Di = tl.load(DELTA2 + offs_m2)
1059
+ lse = tl.load(LSE2 + offs_m2)
1060
+ else:
1061
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
1062
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
1063
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
1064
+ lse = lse[:, None]
1065
+
1066
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1067
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
1068
+ kv_indices = KV_IDX + sparse_kv_idx_offset
1069
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
1070
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
1071
+
1072
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
1073
+ dq = bwd_dq_inner(
1074
+ {{gen_argdefs()}},
1075
+ K, V,
1076
+ dq, q, do, Di, lse,
1077
+ off_z, off_hq2, offs_m2, offs_n2,
1078
+ stride_kn, stride_kd, stride_vn, stride_vd,
1079
+ kv_indices, sparse_kv_num_blocks,
1080
+ MATMUL_PRECISION,
1081
+ IS_FULL_BLOCKS=False,
1082
+ )
1083
+
1084
+ if HAS_FULL_BLOCKS:
1085
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1086
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
1087
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
1088
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
1089
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
1090
+
1091
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
1092
+ dq = bwd_dq_inner(
1093
+ {{gen_argdefs()}},
1094
+ K, V,
1095
+ dq, q, do, Di, lse,
1096
+ off_z, off_hq2, offs_m2, offs_n2,
1097
+ stride_kn, stride_kd, stride_vn, stride_vd,
1098
+ kv_indices, sparse_kv_num_blocks,
1099
+ MATMUL_PRECISION,
1100
+ IS_FULL_BLOCKS=True,
1101
+ )
1102
+
1103
+ # Write back dQ.
1104
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
1105
+ dq *= SM_SCALE
1106
+ if IS_DIVISIBLE:
1107
+ tl.store(dq_ptrs, dq)
1108
+ else:
1109
+ tl.store(dq_ptrs, dq, mask=offs_m2[:, None] < Q_LEN)
1110
+ else:
1111
+ # THIS BLOCK DOES DK & DV
1112
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
1113
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
1114
+
1115
+ pid_mask = pid // SPARSE_KV_MULTIPLE
1116
+
1117
+ stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
1118
+ stride_q_idx_h = {{stride("Q_IDX", 1)}}
1119
+ stride_q_idx_n = {{stride("Q_IDX", 2)}}
1120
+
1121
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM], dtype=tl.float32)
1122
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM], dtype=tl.float32)
1123
+
1124
+ start_n1 = pid * BLOCK_N1
1125
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
1126
+
1127
+ # load K and V: they stay in SRAM throughout the inner loop.
1128
+ if IS_DIVISIBLE:
1129
+ k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
1130
+ v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd)
1131
+ else:
1132
+ k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=offs_n1[:, None] < KV_LEN)
1133
+ v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd, mask=offs_n1[:, None] < KV_LEN)
1134
+ if PRESCALE_QK:
1135
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
1136
+
1137
+ for off_g in range(0, GQA_SHARED_HEADS):
1138
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
1139
+
1140
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
1141
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
1142
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
1143
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
1144
+ off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
1145
+
1146
+ Q1 = Q + q_adj1
1147
+ DO1 = DO + do_adj1
1148
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
1149
+ # if Q is broadcasted)
1150
+ LSE1 = LSE + off_chz1
1151
+ DELTA1 = DELTA + off_chz1
1152
+
1153
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
1154
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
1155
+
1156
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
1157
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
1158
+
1159
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1160
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
1161
+ q_indices = Q_IDX + sparse_q_idx_offset
1162
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
1163
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
1164
+
1165
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
1166
+ dk, dv = bwd_dkdv_inner(
1167
+ {{gen_argdefs()}},
1168
+ Q1, DO1, DELTA1, LSE1,
1169
+ dk, dv, k, v,
1170
+ off_z, off_hq1, offs_n1, offs_m1,
1171
+ stride_qm, stride_qd, stride_dom, stride_dod,
1172
+ q_indices, sparse_q_num_blocks,
1173
+ MATMUL_PRECISION,
1174
+ IS_FULL_BLOCKS=False,
1175
+ )
1176
+
1177
+
1178
+ if HAS_FULL_BLOCKS:
1179
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1180
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
1181
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
1182
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
1183
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
1184
+
1185
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
1186
+ dk, dv = bwd_dkdv_inner(
1187
+ {{gen_argdefs()}},
1188
+ Q1, DO1, DELTA1, LSE1,
1189
+ dk, dv, k, v,
1190
+ off_z, off_hq1, offs_n1, offs_m1,
1191
+ stride_qm, stride_qd, stride_dom, stride_dod,
1192
+ q_indices, sparse_q_num_blocks,
1193
+ MATMUL_PRECISION,
1194
+ IS_FULL_BLOCKS=True,
1195
+ )
1196
+
1197
+ # Write back dV and dK.
1198
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
1199
+
1200
+ index_n = offs_n1[:, None]
1201
+ index_k = offs_k[None, :]
1202
+
1203
+ if IS_DIVISIBLE:
1204
+ tl.store(dv_ptrs, dv)
1205
+ else:
1206
+ tl.store(dv_ptrs, dv, mask=index_n < KV_LEN)
1207
+
1208
+ dk *= SM_SCALE
1209
+ mask = index_n < KV_LEN
1210
+ {{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
1211
+
1212
+ @triton.jit
1213
+ def bwd_dq_inner(
1214
+ {{gen_argdefs()}},
1215
+ K, V, # pointers
1216
+ dq, q, do, Di, lse,
1217
+ off_z, off_hq, offs_m2, offs_n2,
1218
+ stride_kn, stride_kd, stride_vn, stride_vd,
1219
+ kv_indices, sparse_kv_num_blocks,
1220
+ MATMUL_PRECISION,
1221
+ IS_FULL_BLOCKS,
1222
+ ):
1223
+ {{gen_defines() | indent_except_first(1) }}
1224
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
1225
+ RCP_LN2: tl.constexpr = 1.44269504
1226
+ Q_LEN = {{size("Q", 2)}}
1227
+ KV_LEN = {{size("K", 2)}}
1228
+
1229
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1230
+ offs_v = tl.arange(0, V_HEAD_DIM)
1231
+
1232
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
1233
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
1234
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
1235
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
1236
+
1237
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
1238
+ if not IS_DIVISIBLE:
1239
+ if hi >= 1:
1240
+ for start_n in range(0, hi - 1):
1241
+ dq = bwd_dq_block_mn(
1242
+ {{gen_argdefs()}},
1243
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1244
+ off_z, off_hq, offs_m2, offs_n2,
1245
+ stride_kn, stride_kd, stride_vn, stride_vd,
1246
+ kv_indices, sparse_kv_num_blocks,
1247
+ MATMUL_PRECISION, RCP_LN2,
1248
+ IS_FULL_BLOCKS,
1249
+ )
1250
+
1251
+ # Increment pointers.
1252
+ offset = get_offset_for_next_block(
1253
+ start_n, kv_indices, sparse_kv_num_blocks,
1254
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
1255
+ )
1256
+
1257
+ kT_ptrs += offset * stride_kn
1258
+ vT_ptrs += offset * stride_vn
1259
+
1260
+ offs_n2 += offset
1261
+
1262
+ dq = bwd_dq_block_mn(
1263
+ {{gen_argdefs()}},
1264
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1265
+ off_z, off_hq, offs_m2, offs_n2,
1266
+ stride_kn, stride_kd, stride_vn, stride_vd,
1267
+ kv_indices, sparse_kv_num_blocks,
1268
+ MATMUL_PRECISION, RCP_LN2,
1269
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
1270
+ )
1271
+ else:
1272
+ for start_n in range(0, hi):
1273
+ dq = bwd_dq_block_mn(
1274
+ {{gen_argdefs()}},
1275
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1276
+ off_z, off_hq, offs_m2, offs_n2,
1277
+ stride_kn, stride_kd, stride_vn, stride_vd,
1278
+ kv_indices, sparse_kv_num_blocks,
1279
+ MATMUL_PRECISION, RCP_LN2,
1280
+ IS_FULL_BLOCKS,
1281
+ )
1282
+
1283
+ # Increment pointers.
1284
+ offset = get_offset_for_next_block(
1285
+ start_n, kv_indices, sparse_kv_num_blocks,
1286
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
1287
+ )
1288
+
1289
+ kT_ptrs += offset * stride_kn
1290
+ vT_ptrs += offset * stride_vn
1291
+
1292
+ offs_n2 += offset
1293
+
1294
+ return dq
1295
+
1296
+
1297
+ @triton.jit
1298
+ def bwd_dq_block_mn(
1299
+ {{gen_argdefs()}},
1300
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1301
+ off_z, off_hq, offs_m2, offs_n2,
1302
+ stride_kn, stride_kd, stride_vn, stride_vd,
1303
+ kv_indices, sparse_kv_num_blocks,
1304
+ MATMUL_PRECISION, RCP_LN2,
1305
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
1306
+ ):
1307
+ {{gen_defines() | indent_except_first(1)}}
1308
+
1309
+ if IS_DIVISIBLE:
1310
+ kT = tl.load(kT_ptrs)
1311
+ else:
1312
+ kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN)
1313
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
1314
+ if not PRESCALE_QK:
1315
+ qk *= SM_SCALE
1316
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
1317
+ pre_mod_scores = qk
1318
+ if CHECK_BLOCK_BOUNDARY:
1319
+ m = offs_m2[:, None] % Q_LEN
1320
+ n = offs_n2[None, :] % KV_LEN
1321
+ else:
1322
+ m = offs_m2[:, None]
1323
+ n = offs_n2[None, :]
1324
+ {{ modification(
1325
+ subgraph_number=0,
1326
+ output_name="post_mod_scores",
1327
+ score="qk",
1328
+ b="off_z",
1329
+ h="off_hq",
1330
+ m="m",
1331
+ n="n",
1332
+ out="qk"
1333
+ ) | indent_except_first(1) }}
1334
+
1335
+ if CHECK_BLOCK_BOUNDARY:
1336
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
1337
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
1338
+
1339
+ if not IS_FULL_BLOCKS:
1340
+ {{ modification(
1341
+ subgraph_number=2,
1342
+ output_name="mask_mod_output",
1343
+ score="qk",
1344
+ b="off_z",
1345
+ h="off_hq",
1346
+ m="m",
1347
+ n="n",
1348
+ ) | indent_except_first(2) }}
1349
+
1350
+ if CHECK_BLOCK_BOUNDARY:
1351
+ mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
1352
+ # apply mask for partial masked block
1353
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
1354
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1355
+ if not PRESCALE_QK:
1356
+ post_mod_scores *= RCP_LN2
1357
+ p = tl.math.exp2(post_mod_scores - lse)
1358
+ # Compute dP and dS.
1359
+ if IS_DIVISIBLE:
1360
+ vT = tl.load(vT_ptrs)
1361
+ else:
1362
+ vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN)
1363
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
1364
+ ds = p * (dp - Di[:, None])
1365
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
1366
+ {{ modification(
1367
+ subgraph_number=1,
1368
+ output_name = "grad_scores",
1369
+ score="pre_mod_scores",
1370
+ b="off_z",
1371
+ h="off_hq",
1372
+ m="m",
1373
+ n="n",
1374
+ grad_score_mod="ds"
1375
+ ) | indent_except_first(1) }}
1376
+ if CHECK_BLOCK_BOUNDARY:
1377
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
1378
+
1379
+ ds = grad_scores
1380
+
1381
+ if not IS_FULL_BLOCKS:
1382
+ if CHECK_BLOCK_BOUNDARY:
1383
+ mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
1384
+ # (grads) apply mask for partially unmasked block
1385
+ ds = tl.where(mask_mod_output, ds, 0.0)
1386
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1387
+ ds = ds.to(MATMUL_PRECISION)
1388
+ # Compute dQ.
1389
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
1390
+
1391
+ return dq
1392
+
1393
+
1394
+ @triton.jit
1395
+ def bwd_dkdv_inner(
1396
+ {{gen_argdefs()}},
1397
+ Q, DO, DELTA, LSE, # pointers
1398
+ dk, dv, k, v,
1399
+ off_z, off_hq, offs_n1, offs_m1,
1400
+ stride_qm, stride_qd, stride_dom, stride_dod,
1401
+ q_indices, sparse_q_num_blocks,
1402
+ MATMUL_PRECISION,
1403
+ IS_FULL_BLOCKS,
1404
+ ):
1405
+ {{gen_defines() | indent_except_first(1) }}
1406
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
1407
+ RCP_LN2: tl.constexpr = 1.44269504
1408
+ Q_LEN = {{size("Q", 2)}}
1409
+ KV_LEN = {{size("K", 2)}}
1410
+
1411
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1412
+ offs_v = tl.arange(0, V_HEAD_DIM)
1413
+
1414
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
1415
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
1416
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
1417
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
1418
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
1419
+
1420
+ if not IS_DIVISIBLE:
1421
+ if hi >= 1:
1422
+ for start_m in range(0, hi - 1):
1423
+ dk, dv = bwd_dkdv_block_mn(
1424
+ {{gen_argdefs()}},
1425
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1426
+ off_z, off_hq, offs_n1, offs_m1,
1427
+ stride_qm, stride_qd, stride_dom, stride_dod,
1428
+ q_indices, sparse_q_num_blocks,
1429
+ MATMUL_PRECISION, RCP_LN2,
1430
+ IS_FULL_BLOCKS,
1431
+ )
1432
+ # Increment pointers.
1433
+ offset = get_offset_for_next_block(
1434
+ start_m, q_indices, sparse_q_num_blocks,
1435
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
1436
+ )
1437
+
1438
+ qT_ptrs += offset * stride_qm
1439
+ do_ptrs += offset * stride_dom
1440
+
1441
+ offs_m1 += offset
1442
+
1443
+ dk, dv = bwd_dkdv_block_mn(
1444
+ {{gen_argdefs()}},
1445
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1446
+ off_z, off_hq, offs_n1, offs_m1,
1447
+ stride_qm, stride_qd, stride_dom, stride_dod,
1448
+ q_indices, sparse_q_num_blocks,
1449
+ MATMUL_PRECISION, RCP_LN2,
1450
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
1451
+ )
1452
+ else:
1453
+ for start_m in range(0, hi):
1454
+ dk, dv = bwd_dkdv_block_mn(
1455
+ {{gen_argdefs()}},
1456
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1457
+ off_z, off_hq, offs_n1, offs_m1,
1458
+ stride_qm, stride_qd, stride_dom, stride_dod,
1459
+ q_indices, sparse_q_num_blocks,
1460
+ MATMUL_PRECISION, RCP_LN2,
1461
+ IS_FULL_BLOCKS,
1462
+ )
1463
+ # Increment pointers.
1464
+ offset = get_offset_for_next_block(
1465
+ start_m, q_indices, sparse_q_num_blocks,
1466
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
1467
+ )
1468
+
1469
+ qT_ptrs += offset * stride_qm
1470
+ do_ptrs += offset * stride_dom
1471
+
1472
+ offs_m1 += offset
1473
+
1474
+ return dk, dv
1475
+
1476
+
1477
+ @triton.jit
1478
+ def bwd_dkdv_block_mn(
1479
+ {{gen_argdefs()}},
1480
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1481
+ off_z, off_hq, offs_n1, offs_m1,
1482
+ stride_qm, stride_qd, stride_dom, stride_dod,
1483
+ q_indices, sparse_q_num_blocks,
1484
+ MATMUL_PRECISION, RCP_LN2,
1485
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
1486
+ ):
1487
+ {{gen_defines() | indent_except_first(1) }}
1488
+
1489
+ # Load LSE before computing qk to reduce pipeline stall.
1490
+ if IS_DIVISIBLE:
1491
+ qT = tl.load(qT_ptrs)
1492
+ lse = tl.load(LSE + offs_m1)
1493
+ else:
1494
+ qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN)
1495
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
1496
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
1497
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
1498
+ if not PRESCALE_QK:
1499
+ qkT *= SM_SCALE
1500
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
1501
+ if CHECK_BLOCK_BOUNDARY:
1502
+ m = offs_m1[None, :] % Q_LEN
1503
+ n = offs_n1[:, None] % KV_LEN
1504
+ else:
1505
+ m = offs_m1[None, :]
1506
+ n = offs_n1[:, None]
1507
+ pre_mod_scores = qkT
1508
+ {{ modification(
1509
+ subgraph_number=0,
1510
+ output_name="post_mod_scores",
1511
+ score="qkT",
1512
+ b="off_z",
1513
+ h="off_hq",
1514
+ m="m",
1515
+ n="n",
1516
+ out="qkT"
1517
+ ) | indent_except_first(1) }}
1518
+
1519
+ if CHECK_BLOCK_BOUNDARY:
1520
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
1521
+ post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
1522
+
1523
+ if not IS_FULL_BLOCKS:
1524
+ {{ modification(
1525
+ subgraph_number=2,
1526
+ output_name="mask_mod_output",
1527
+ score="qkT",
1528
+ b="off_z",
1529
+ h="off_hq",
1530
+ m="m",
1531
+ n="n",
1532
+ ) | indent_except_first(2) }}
1533
+ if CHECK_BLOCK_BOUNDARY:
1534
+ mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
1535
+ # (grads) apply mask for fully masked block
1536
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
1537
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1538
+ if not PRESCALE_QK:
1539
+ post_mod_scores *= RCP_LN2
1540
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
1541
+ if IS_DIVISIBLE:
1542
+ do = tl.load(do_ptrs)
1543
+ else:
1544
+ do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN)
1545
+ # Compute dV.
1546
+ ppT = pT
1547
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
1548
+ if IS_DIVISIBLE:
1549
+ Di = tl.load(DELTA + offs_m1)
1550
+ else:
1551
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
1552
+ # Compute dP and dS.
1553
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
1554
+ dsT = pT * (dpT - Di[None, :])
1555
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
1556
+ {{ modification(
1557
+ subgraph_number=1,
1558
+ output_name = "grad_scores",
1559
+ score="pre_mod_scores",
1560
+ b="off_z",
1561
+ h="off_hq",
1562
+ m="m",
1563
+ n="n",
1564
+ grad_score_mod="dsT"
1565
+ ) | indent_except_first(1) }}
1566
+ if CHECK_BLOCK_BOUNDARY:
1567
+ grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
1568
+
1569
+ dsT = grad_scores
1570
+ if not IS_FULL_BLOCKS:
1571
+ if CHECK_BLOCK_BOUNDARY:
1572
+ mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
1573
+ # (grads) apply mask for partially unmasked block
1574
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
1575
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1576
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
1577
+
1578
+ return dk, dv
1579
+ """
1580
+ + compute_next_offset_func,
1581
+ )
1582
+
1583
+
1584
# TODO: We probably also need a layout constraint?
@register_lowering(
    torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
    """Lower the ``flex_attention_backward`` higher-order op to a Triton template.

    Unpacks the saved forward tensors and the block-sparse mask metadata,
    builds Triton subgraph buffers for the score_mod forward, the joint
    (backward) graph, and the mask_mod graph, computes the ``delta`` tensor
    needed by the backward kernel, and autotunes
    ``flex_attention_backward_template`` over a set of block-size configs.

    Returns a ``(grad_query, grad_key, grad_value)`` triple. Only ``grad_key``
    is produced via the template's ``store_output``; ``grad_query`` and
    ``grad_value`` are pre-allocated here and mutated by the kernel
    (see the ``mutated_inputs`` argument below).
    """
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    # Block-sparse mask metadata: per-(b, h) counts and indices of KV/Q blocks
    # that need masking, plus the "full" (mask-free) variants and block sizes.
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    # Realize all tensor inputs so they have concrete buffers before being
    # handed to the Triton template.
    (
        query,
        key,
        value,
        grad_out,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            grad_out,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    device = query.get_device()
    dtype = query.get_dtype()
    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq

    # Copy so our setdefault() calls below don't mutate the caller's dict.
    kernel_options = dict(kernel_options)
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    # IS_DIVISIBLE lets the kernel skip boundary masking when both sequence
    # lengths are multiples of 128.
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # Placeholders for the score_mod subgraph: (score, b, h, m, n).
    fwd_placeholder_inps = [
        create_placeholder(name, dtype, device)
        for name, dtype in [
            ("score", dtype),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    fw_subgraph_buffer = build_subgraph_buffer(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )

    # The joint (backward) graph additionally receives the incoming gradient.
    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
    ]
    # NOTE(review): this unpacks only the first output of the joint subgraph
    # buffer, while fw_subgraph_buffer/mask_graph_buffer take the full result
    # — presumably the joint graph has multiple outputs; confirm upstream.
    joint_subgraph_buffer, *_ = build_subgraph_buffer(
        joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
    )

    # mask_mod only sees index placeholders (b, h, m, n) — no score input.
    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )

    # The template's store_output writes grad_key, so its layout mirrors key.
    layout_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key.get_size(),
        key.get_stride(),
    )

    # Compute delta, which is needed by the backward kernel:
    # delta = sum(out * grad_out, dim=-1) - grad_logsumexp / ln(2).
    grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
    mul_delta = lowerings[aten.mul](out, grad_out)
    delta = lowerings[aten.sum](mul_delta, axis=-1)
    delta = lowerings[aten.sub](delta, grad_lse_exp2)
    delta = ExternKernel.require_contiguous(delta)

    grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])

    # see NOTE:[TritonTemplates with multiple outputs]
    # grad_query / grad_value are allocated here and filled in-place by the
    # kernel (passed via mutated_inputs below).
    grad_query = empty_strided(
        query.get_size(), query.get_stride(), dtype=dtype, device=device
    )
    grad_value = empty_strided(
        value.get_size(), value.get_stride(), dtype=dtype, device=device
    )

    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA factor
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed.
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Placeholder empty tensors so the template signature stays uniform.
        full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = (
            empty(0, device=query.get_device()) for _ in range(4)
        )
    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_bwd(query))
    if config.max_autotune:
        # Candidate (BLOCK1, BLOCK2, num_warps, num_stages) tuples; BLOCK2
        # must be a multiple of BLOCK1 (kernel static_assert requires it).
        configs.extend(
            [
                (BLOCK1, BLOCK2, w, s)
                for BLOCK1 in [32, 64]
                for BLOCK2 in [32, 64, 128]
                for w in [4, 8]
                for s in [1, 3, 4, 5]
                if BLOCK2 % BLOCK1 == 0
            ]
        )

    for BLOCK1, BLOCK2, num_warps, num_stages in configs:
        # Sparse block sizes must tile evenly by both kernel block sizes.
        if (
            SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
        ):
            continue

        # Performance tuning
        # NOTE(review): setdefault on a dict shared across loop iterations
        # means only the first valid config's BLOCK_* values ever stick; later
        # configs vary only in num_warps/num_stages. Confirm this is intended.
        kernel_options.setdefault("BLOCK_M1", BLOCK1)
        kernel_options.setdefault("BLOCK_N1", BLOCK2)
        kernel_options.setdefault("BLOCK_M2", BLOCK2)
        kernel_options.setdefault("BLOCK_N2", BLOCK1)
        # Blocksparse options
        kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
        kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                delta,
                grad_out,
                grad_query,
                grad_value,
                kv_num_blocks,
                kv_indices,
                q_num_blocks,
                q_indices,
                full_kv_num_blocks,
                full_kv_indices,
                full_q_num_blocks,
                full_q_indices,
            ],
            layout=layout_k,  # We use store_output only for grad_key
            subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
            mutated_inputs=[grad_query, grad_value],
            call_sizes=query.get_size() + key.get_size()[1:3],
            num_stages=num_stages,
            num_warps=num_warps,
            **kernel_options,
        )
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            delta,
            grad_out,
            grad_query,
            grad_value,
            kv_num_blocks,
            kv_indices,
            q_num_blocks,
            q_indices,
            full_kv_num_blocks,
            full_kv_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    # Fake-tensor generators for the block-sparse inputs (positions 8-15 in
    # inputs_for_autotuning) so benchmarking uses plausible sparse metadata.
    input_gen_fns = {
        8: create_num_blocks_fake_generator(kv_indices),  # kv_num_blocks
        9: create_indices_fake,
        10: create_num_blocks_fake_generator(q_indices),  # q_num_blocks
        11: create_indices_fake,
        12: create_num_blocks_fake_generator(full_kv_indices),  # full_kv_num_blocks
        13: create_indices_fake,
        14: create_num_blocks_fake_generator(full_q_indices),  # full_q_num_blocks
        15: create_indices_fake,
    }

    grad_key = autotune_select_algorithm(
        "flex_attention_backward",
        choices,
        inputs_for_autotuning,
        layout_k,
        input_gen_fns=input_gen_fns,
    )
    return (
        grad_query,
        grad_key,
        grad_value,
    )
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/flex_decoding.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """ Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
3
+ from typing import Any, List, Tuple
4
+
5
+ import sympy
6
+
7
+ import torch
8
+ from torch._inductor.virtualized import V
9
+
10
+ from .. import config, ir
11
+ from ..ir import FixedLayout, FlexibleLayout
12
+ from ..lowering import empty, empty_strided, lowerings
13
+ from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
14
+ from ..select_algorithm import autotune_select_algorithm, TritonTemplate
15
+ from .flex_attention import (
16
+ compute_forward_block_mn,
17
+ compute_forward_inner,
18
+ compute_next_offset_func,
19
+ create_indices_fake,
20
+ create_num_blocks_fake_generator,
21
+ maybe_realize,
22
+ )
23
+
24
+
25
# Shorthand aliases for the ATen and prims operator namespaces.
aten = torch.ops.aten
prims = torch.ops.prims
27
+
28
+
29
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
    """How is this kernel parallelized?

    Launch a (batch_size * kv_heads, SPLIT_KV, 1) grid: each program owns one
    (batch, kv-head) pair and one of the SPLIT_KV key/value splits, computing
    a local output for its tile of keys/values over the full query length.
    Groups of SPLIT_KV programs later combine their partial outputs into the
    final result. ``gqa_group_size``, ``n_keys`` and ``d_model`` are accepted
    for signature compatibility but unused here.
    """
    batch_head_programs = batch_size * kv_heads
    return (batch_head_programs, meta["SPLIT_KV"], 1)
38
+
39
+
40
# Templated Triton kernel for the flex-decoding (split-KV) forward pass.
# Each program handles one (batch, kv-head) pair and one KV split (see
# flex_decoding_grid); it accumulates a partial output ACC plus per-row
# max (M) and sumexp (L) buffers that a later step reduces across splits.
# The shared forward helpers are concatenated onto the template source below.
flex_decoding_template = TritonTemplate(
    name="flex_decoding",
    grid=flex_decoding_grid,
    source=r"""
    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
    # Sub notation for this kernel:
    # Q: Query, K: Key, V: Value
    # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
    # (Modifiable) Config options:
    # SPLIT_KV: number of blocks K & V are split into
    # TILE_KV: length of each local KV split
    # BLOCK_M: block size that Q is padded along seqlen dim.
    # BLOCK_N: block size of K & V along N dimension.
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # change of base out of the loop
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
    # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.

    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
    #
    # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    #
    #
    # Output: ACC output accumulated across local KV split.

    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define Q Strides
    stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
    stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
    stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
    stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
    stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}


    Z = {{size("Q", 0)}}
    HKV = {{size("Q", 1)}}
    G: tl.constexpr = GQA_SHARED_HEADS
    HQ = HKV * G
    Q_LEN = {{size("Q", 3)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    # Make sure each split is a multiple of BLOCK_N
    TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
    TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
    TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)

    off_z = tl.program_id(0) // HKV
    off_hkv = tl.program_id(0) % HKV
    off_t = tl.program_id(1)

    q_offset = off_z * stride_qz + off_hkv * stride_qh
    k_offset = off_z * stride_kz + off_hkv * stride_kh
    v_offset = off_z * stride_vz + off_hkv * stride_vh

    SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
    SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}

    sparse_idx_z = off_z % SPARSE_Z
    # TODO: support masks not broadcasted along the head dimension.
    tl.device_assert(SPARSE_HQ == 1)
    sparse_idx_h = 0

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)

    # initialize offsets
    tl.device_assert(BLOCK_M % G == 0)
    BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
    off_g = tl.arange(0, G) # [G]
    offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
    offs_hq = offs_g + off_hkv * G
    off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
    offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
    offs_d = tl.arange(0, QK_HEAD_DIM)
    offs_vd = tl.arange(0, V_HEAD_DIM)

    # KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h

    # Calculate KV blocks that belong this CTA.
    block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
    block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N

    q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]

    if SAFE_M_BOUNDARY:
        q = tl.load(Q + q_offset + q_range)
    else:
        mask = off_m[None, :, None] < Q_LEN
        q = tl.load(Q + q_offset + q_range, mask)

    q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM])


    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Apply both score_mod and mask_mod

    # find first kv block we are loading and the number of blocks we are loading
    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
    # first kv block we're loading

    # last valid block according to sparse mask
    block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
        strides=(stride_kk, stride_kn),
        offsets=(0, off_n),
        block_shape=(QK_HEAD_DIM, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(off_n, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = tl.arange(0, BLOCK_N) + off_n

    acc, l_i, m_i = forward_inner(
        {{gen_argdefs()}},
        q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
        # accumulatd values
        acc, l_i, m_i,
        #offsets
        off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
        #block sparse data
        kv_indices, kv_num_blocks,
        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS=False,
    )


    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N

        # last valid block according to sparse mask
        block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

        K_block_ptr = tl.make_block_ptr(
            base=K + k_offset,
            shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
            strides=(stride_kk, stride_kn),
            offsets=(0, off_n),
            block_shape=(QK_HEAD_DIM, BLOCK_N),
            order=(0, 1)
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + v_offset,
            shape=(KV_LEN, V_HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(off_n, 0),
            block_shape=(BLOCK_N, V_HEAD_DIM),
            order=(1, 0)
        )
        offs_n = tl.arange(0, BLOCK_N) + off_n

        acc, l_i, m_i = forward_inner(
            {{gen_argdefs()}},
            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
            # accumulatd values
            acc, l_i, m_i,
            #offsets
            off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
            #block sparse data
            kv_indices, kv_num_blocks,
            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=True,
        )

    m_offset = off_t * stride_mt + off_z * stride_mz
    l_offset = off_t * stride_lt + off_z * stride_lz

    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(G, Q_LEN), # (G, M)
        strides=(stride_mh, stride_mm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + l_offset,
        shape=(G, Q_LEN), # (G, M)
        strides=(stride_lh, stride_lm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )

    # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
    m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
    l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
    if SAFE_M_BOUNDARY:
        tl.store(M_block_ptr, m_i)
        tl.store(L_block_ptr, l_i)
    else:
        tl.store(M_block_ptr, m_i, boundary_check=(1,))
        tl.store(L_block_ptr, l_i, boundary_check=(1,))

    # -- store output
    idx_z = off_z
    idx_t = off_t
    idx_hq = off_hkv*G + off_g[:, None, None]
    idx_m = off_m[None, :, None]
    idx_d = offs_vd[None, None, :]

    mask = (idx_m < Q_LEN)
    acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
 """
    + compute_forward_inner
    + compute_next_offset_func
    + compute_forward_block_mn,
)
289
+
290
+
291
def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int:
    """Heuristic (from xFormers) for the number of KV splits.

    Distributes the ``SM`` streaming multiprocessors over the ``B * H``
    (batch, head) programs so every SM gets at least one block, and always
    returns at least one split. ``Mk`` is accepted but currently unused.
    """
    batch_heads = max(B * H, 1)  # NOTE: Handle B*h=0 case
    return max(SM // batch_heads, 1)
298
+
299
+
300
def _get_decoding_default_config(key) -> Tuple[int, int, int]:
    """Choose a default (BLOCK_N, num_warps, num_stages) config for decoding.

    On devices with compute capability >= 9.0, three pipeline stages are used
    unless the key tensor is float32 with head_dim > 128; older devices and
    that heavy fp32 case fall back to a single stage.
    """
    base_cfg = (64, 2, 1)
    if torch.cuda.get_device_capability() < (9, 0):
        return base_cfg
    # SM 9.0+: more stages, except the register-heavy wide-fp32 case.
    if key.get_dtype() == torch.float32 and key.get_size()[-1] > 128:
        return base_cfg
    return (64, 2, 3)
310
+
311
+
312
+ def create_flex_decoding_kernel(*args, **kwargs):
313
+ (
314
+ query,
315
+ key,
316
+ value,
317
+ block_mask,
318
+ scale,
319
+ kernel_options,
320
+ score_mod_subgraph,
321
+ mask_mod_subgraph,
322
+ score_mod_other_buffers,
323
+ mask_mod_other_buffers,
324
+ ) = args
325
+ (
326
+ kv_num_blocks,
327
+ kv_indices,
328
+ full_kv_num_blocks, # full_kv_num_blocks,
329
+ full_kv_indices, # full_kv_indices,
330
+ _, # q_num_blocks
331
+ _, # q_indices
332
+ _, # full_q_num_blocks,
333
+ _, # full_q_indices,
334
+ SPARSE_KV_BLOCK_SIZE,
335
+ _, # SPARSE_Q_BLOCK_SIZE,
336
+ _,
337
+ ) = block_mask
338
+
339
+ Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
340
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
341
+ assert Bq == Bkv, "Batch dimension must match"
342
+ B = Bq
343
+ kernel_options = dict(kernel_options)
344
+
345
+ # TODO: Fix flex decoding non-divisible case!
346
+ if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
347
+ kernel_options.setdefault("IS_DIVISIBLE", False)
348
+ else:
349
+ kernel_options.setdefault("IS_DIVISIBLE", True)
350
+
351
+ # Calculate GQA head sharing
352
+ gqa_shared_heads = Hq // Hkv
353
+ if not is_power_of_2(gqa_shared_heads):
354
+ raise ValueError(
355
+ "Number of shared query heads sharing the same KV head must be power of 2. "
356
+ )
357
+ kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
358
+
359
+ # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
360
+ has_full_blocks = full_kv_num_blocks is not None
361
+ kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
362
+ if not has_full_blocks:
363
+ # Create a plackeholder full block list in case it is empty
364
+ full_kv_num_blocks, full_kv_indices = (
365
+ empty(0, device=query.get_device()) for _ in range(2)
366
+ )
367
+
368
+ (
369
+ query,
370
+ key,
371
+ value,
372
+ kv_num_blocks,
373
+ kv_indices,
374
+ full_kv_num_blocks,
375
+ full_kv_indices,
376
+ ) = maybe_realize(
377
+ [
378
+ query,
379
+ key,
380
+ value,
381
+ kv_num_blocks,
382
+ kv_indices,
383
+ full_kv_num_blocks,
384
+ full_kv_indices,
385
+ ]
386
+ )
387
+
388
+ choices: List[Any] = []
389
+ configs: List[Tuple[int, int, int]] = []
390
+ configs.append(_get_decoding_default_config(key))
391
+ # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
392
+ if config.max_autotune:
393
+ configs += [
394
+ (64, 2, 2),
395
+ (32, 2, 3),
396
+ (128, 2, 3),
397
+ ]
398
+ # TODO: fix autotuning.
399
+
400
+ kernel_options.setdefault("SM_SCALE", scale)
401
+ kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
402
+ MAX_SPLIT_KV = kernel_options["SPLIT_KV"]
403
+
404
+ # create config dependent intermediate buffers
405
+ buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
406
+ buf_ML_shape = buf_ACC_shape[:-1]
407
+ buf_M = empty_strided(
408
+ buf_ML_shape,
409
+ None,
410
+ dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype
411
+ device=query.get_device(),
412
+ )
413
+ buf_L = empty_strided(
414
+ buf_ML_shape,
415
+ None,
416
+ dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype
417
+ device=query.get_device(),
418
+ )
419
+
420
+ layout_acc = FixedLayout(
421
+ query.get_device(),
422
+ torch.float32,
423
+ buf_ACC_shape,
424
+ FlexibleLayout.contiguous_strides(buf_ACC_shape),
425
+ )
426
+
427
+ kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
428
+ kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
429
+
430
+ kernel_options.setdefault(
431
+ "BLOCK_M",
432
+ (
433
+ # m
434
+ # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
435
+ # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
436
+ max(
437
+ next_power_of_2(
438
+ V.graph.sizevars.size_hint(
439
+ seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
440
+ )
441
+ * gqa_shared_heads
442
+ ),
443
+ 16,
444
+ )
445
+ ),
446
+ )
447
+
448
+ query = ir.ExternKernel.realize_input(query)
449
+ stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()
450
+
451
+ # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
452
+ gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim)
453
+ gqa_query_stride = (
454
+ stride_b,
455
+ stride_hq * gqa_shared_heads,
456
+ stride_hq,
457
+ stride_seq_len_q,
458
+ stride_qk_head_dim,
459
+ )
460
+ query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
461
+
462
+ V.graph.sizevars.guard_leq(
463
+ seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
464
+ )
465
+
466
+ kernel_options.setdefault(
467
+ "SAFE_M_BOUNDARY",
468
+ ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0,
469
+ )
470
+ # TODO: This feels sketchy
471
+ kernel_options.setdefault("SAFE_N_BOUNDARY", True)
472
+
473
+ # Note, we don't need to pass in the captured buffers explicitly
474
+ # because they're implicitly added by the score_mod function
475
+ # We do need to explicitly pass it in for autotuning though.
476
+ for BLOCK_N, num_warps, num_stages in configs:
477
+ if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0:
478
+ continue
479
+
480
+ # Performance tuning
481
+ kernel_options.setdefault("BLOCK_N", BLOCK_N)
482
+ kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
483
+
484
+ # Work around https://github.com/pytorch/pytorch/issues/129625
485
+ if num_stages == 2:
486
+ continue
487
+ flex_decoding_template.maybe_append_choice(
488
+ choices=choices,
489
+ input_nodes=[
490
+ query,
491
+ key,
492
+ value,
493
+ buf_M,
494
+ buf_L,
495
+ kv_num_blocks,
496
+ kv_indices,
497
+ full_kv_num_blocks,
498
+ full_kv_indices,
499
+ ],
500
+ layout=layout_acc,
501
+ subgraphs=[
502
+ score_mod_subgraph,
503
+ mask_mod_subgraph,
504
+ ],
505
+ mutated_inputs=[buf_M, buf_L],
506
+ num_stages=num_stages,
507
+ num_warps=num_warps,
508
+ call_sizes=query.get_size(),
509
+ **kernel_options,
510
+ )
511
+
512
+ inputs_for_flex_decoding = (
513
+ [
514
+ query,
515
+ key,
516
+ value,
517
+ buf_M,
518
+ buf_L,
519
+ kv_num_blocks,
520
+ kv_indices,
521
+ full_kv_num_blocks,
522
+ full_kv_indices,
523
+ ]
524
+ + list(score_mod_other_buffers)
525
+ + list(mask_mod_other_buffers)
526
+ )
527
+
528
+ input_gen_fns = {
529
+ 5: create_num_blocks_fake_generator(kv_indices),
530
+ 6: create_indices_fake,
531
+ 7: create_num_blocks_fake_generator(full_kv_indices),
532
+ 8: create_indices_fake,
533
+ }
534
+
535
+ buf_ACC = autotune_select_algorithm(
536
+ "flex_decoding",
537
+ choices,
538
+ inputs_for_flex_decoding,
539
+ layout_acc,
540
+ input_gen_fns=input_gen_fns,
541
+ )
542
+
543
+ # Reduction
544
+
545
+ g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
546
+ # See [Note] Handle fully masked out rows:
547
+ # g_M Is the global max among split kv blocks.
548
+ masked_rows = lowerings[aten.eq](g_M, -float("inf"))
549
+ adj_M = lowerings[aten.sub](buf_M, g_M)
550
+ adj_M = lowerings[aten.where](masked_rows, 0, adj_M)
551
+ alpha = lowerings[aten.exp2](adj_M)
552
+
553
+ buf_L = lowerings[aten.mul](buf_L, alpha)
554
+ g_L = lowerings[aten.sum](buf_L, axis=1)
555
+ masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
556
+ g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
557
+ logsumexp = lowerings[aten.log2](g_L)
558
+ logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
559
+
560
+ alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
561
+ buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
562
+ output = lowerings[aten.sum](buf_ACC, axis=1)
563
+ L_unseq = lowerings[aten.unsqueeze](g_L, 3)
564
+ output = lowerings[aten.div](output, L_unseq)
565
+ output = lowerings[prims.convert_element_type](output, query.get_dtype())
566
+
567
+ return (
568
+ output,
569
+ logsumexp,
570
+ )
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
8
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
9
+ AHContext,
10
+ context_add_strides,
11
+ context_add_using_tf32,
12
+ get_mixedmm_precondition,
13
+ mixed_mm_operations,
14
+ mm_operations,
15
+ )
16
+ from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
17
+ from torch._inductor.virtualized import V
18
+
19
+ from .. import config as inductor_config
20
+ from ..codegen.common import BackendFeature
21
+ from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
22
+ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
23
+ from ..codegen.wrapper import WrapperCodeGen
24
+ from ..ir import FlexibleLayout, is_triton
25
+ from ..lowering import register_lowering
26
+ from ..select_algorithm import (
27
+ autotune_select_algorithm,
28
+ ExternKernelChoice,
29
+ NoValidChoicesError,
30
+ TritonTemplate,
31
+ )
32
+ from ..utils import (
33
+ get_gpu_shared_memory,
34
+ use_aten_gemm_kernels,
35
+ use_ck_template,
36
+ use_cpp_packed_gemm_template,
37
+ use_cutlass_template,
38
+ use_max_autotune,
39
+ use_triton_template,
40
+ )
41
+ from .mm_common import (
42
+ addmm_epilogue,
43
+ extra_mm_configs,
44
+ int8_mm_configs,
45
+ mixed_mm_configs,
46
+ mm_args,
47
+ mm_configs,
48
+ mm_grid,
49
+ mm_options,
50
+ triton_config,
51
+ )
52
+
53
+
54
+ log = logging.getLogger(__name__)
55
+ aten = torch.ops.aten
56
+
57
+ mm_template = TritonTemplate(
58
+ name="mm",
59
+ grid=mm_grid,
60
+ source=r"""
61
+ {{def_kernel("A", "B")}}
62
+ M = {{size("A", 0)}}
63
+ N = {{size("B", 1)}}
64
+ K = {{size("A", 1)}}
65
+ if M * N == 0:
66
+ # early exit due to zero-size input(s)
67
+ return
68
+ stride_am = {{stride("A", 0)}}
69
+ stride_ak = {{stride("A", 1)}}
70
+ stride_bk = {{stride("B", 0)}}
71
+ stride_bn = {{stride("B", 1)}}
72
+
73
+ # based on triton.ops.matmul
74
+ pid = tl.program_id(0)
75
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
76
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
77
+
78
+ # re-order program ID for better L2 performance
79
+ width = GROUP_M * grid_n
80
+ group_id = pid // width
81
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
82
+ pid_m = group_id * GROUP_M + (pid % group_size)
83
+ pid_n = (pid % width) // (group_size)
84
+
85
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
86
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
87
+ if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
88
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
89
+ else:
90
+ ram = rm % M
91
+ if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
92
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
93
+ else:
94
+ rbn = rn % N
95
+ rk = tl.arange(0, BLOCK_K)
96
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
97
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
98
+
99
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
100
+ for k in range(K, 0, -BLOCK_K):
101
+ if EVEN_K:
102
+ a = tl.load(A)
103
+ b = tl.load(B)
104
+ else:
105
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
106
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
107
+ if B_PROLOGUE_CAST_TYPE is not None:
108
+ b = b.to(B_PROLOGUE_CAST_TYPE)
109
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
110
+ A += BLOCK_K * stride_ak
111
+ B += BLOCK_K * stride_bk
112
+
113
+ # rematerialize rm and rn to save registers
114
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
115
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
116
+ idx_m = rm[:, None]
117
+ idx_n = rn[None, :]
118
+ mask = (idx_m < M) & (idx_n < N)
119
+
120
+ # inductor generates a suffix
121
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
122
+ """,
123
+ )
124
+
125
# ATen fallback kernels. Each ExternKernelChoice wraps an eager-mode op so it
# can be autotuned against the generated Triton/CUTLASS/CK template choices.
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")

# _sparse_semi_structured_mm has no out= variant in ATen.
aten__sparse_semi_structured_mm = ExternKernelChoice(
    torch._sparse_semi_structured_mm,
    "at::_sparse_semi_structured_mm",
    has_out_variant=False,
)
138
+
139
+
140
+ def _is_int8_mat(mat):
141
+ return mat.get_dtype() in (torch.int8, torch.uint8)
142
+
143
+
144
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
    kernel under the hood. There are a few shapes where this is slower,
    but they are rare.
    """
    # A broadcast (zero-stride) or single-row bias can be collapsed to its
    # first row, which triggers the 1D-bias cublasLt path.
    collapse_bias = inp.stride(0) == 0 or inp.size(0) == 1
    bias = inp[0] if collapse_bias else inp
    return torch.addmm(bias, mat1, mat2, out=out, alpha=alpha, beta=beta)
153
+
154
+
155
# Extern choice that routes addmm through bias_addmm (the 1D-bias fast path).
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
156
+
157
+
158
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """Autotuned lowering for ``aten.mm``.

    Collects candidate implementations (ATen extern kernel, Triton template
    configs, CUTLASS, CK, CPP packed GEMM), optionally re-ranks them with
    AutoHeuristic, and hands them to ``autotune_select_algorithm``.
    Falls back to the ATen kernel when no (valid) choices remain and
    ``inductor_config.autotune_fallback_to_aten`` is set.
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    name = "mm"

    # Without max-autotune, give the ATen kernel a FlexibleLayout so the
    # output strides can still be padded/chosen later.
    aten_layout = layout
    if not use_max_autotune():
        aten_layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    # CUTLASS needs fully static shapes; CK and CPP templates have their own gates.
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if is_nonzero and use_ck_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [mat1, mat2],
        )

    input_nodes = [mat1, mat2]
    if (
        is_nonzero
        and use_triton_template(layout)
        and torch._inductor.config.run_autoheuristic(name)
        and is_triton(mat1)
    ):
        always_included = []
        if use_aten_gemm_kernels():
            always_included.append("extern_mm")
        # Remember the boundary so the extra configs can be dropped again if
        # AutoHeuristic returns nothing.
        num_choices_before_extra_configs = len(choices)
        for config in extra_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        # using AutoHeuristic for ranking
        ah_choices = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mm_operations(),
            None,
            top_k=10,
            always_included=always_included,
        )
        if not torch._inductor.config.collect_autoheuristic(name):
            # if we are collecting data, we do not want to modify choices
            if ah_choices is not None and len(ah_choices) > 0:
                # the order in which autoheuristic returns choices is not the same
                # as the order of choices, which affects things like epilogue fusion.
                # once epilogue fusion benchmarks choices in sorted order, I think we can
                # just use the order returned by autoheuristic
                choices = [choice for choice in choices if choice in ah_choices]
            else:
                choices = choices[:num_choices_before_extra_configs]

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    try:
        return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()
255
+
256
+
257
def _is_static_problem(inputs_tensors, layout):
    """Return ``(static_shape, nonzero)`` for the output *layout*.

    ``static_shape`` is True when every output dimension is statically known.
    ``nonzero`` is False only when some dimension is (statically) known to
    be zero, i.e. the GEMM is degenerate.
    """
    static_size = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size)
    if static_size is None:
        # Partly-dynamic shape: not static; degenerate only if some dim is
        # statically proven to be 0 (unknown dims count as nonzero).
        nonzero = all(
            WrapperCodeGen.statically_known_int_or_none(s) != 0 for s in layout.size
        )
        return False, nonzero
    numel = 1
    for dim in static_size:
        numel *= dim
    return True, numel > 0
276
+
277
+
278
@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    """Autotuned lowering for ``aten._int_mm`` (int8 x int8 -> int32 GEMM).

    Builds candidate choices from the ATen extern kernel, CUTLASS (static
    shapes only) and the Triton mm template, then autotunes among them.
    Falls back to the ATen kernel when no valid choice exists and
    ``inductor_config.autotune_fallback_to_aten`` is set.
    """
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

    choices = (
        [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    )

    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    # (the ATen choice is dropped whenever a template backend is usable).
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []

    if use_cutlass:
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
    if is_nonzero and use_triton_template(layout, enable_int32=True):
        for config in int8_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if len(choices) == 0:
        # Fix: log message previously misspelled "available" as "avaialbe".
        log.warning(
            "No choices for integer GEMM available using configured backends, using ATen backend as fallback"
        )
        choices = [aten__int_mm.bind((mat1, mat2), layout)]

    try:
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        choices = [aten__int_mm.bind((mat1, mat2), layout)]
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
320
+
321
+
322
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """Autotuned lowering for ``aten.addmm``: beta * inp + alpha * (mat1 @ mat2).

    When not autotuning (or the problem is degenerate) only the ATen kernel
    is considered; otherwise Triton/CUTLASS/CK/CPP template choices and the
    cublasLt bias fast path are added before autotuning.
    """
    ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout)
    if (not is_nonzero) or (not use_max_autotune()):
        # Use a FlexibleLayout if we are not autotuning.
        # This allows padding strides for the output.
        from torch._inductor.ir import FixedLayout, FlexibleLayout

        if isinstance(layout, FixedLayout):
            layout = FlexibleLayout(
                device=layout.device, dtype=layout.dtype, size=layout.size
            )
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            # prefix_args=1: the bias input is consumed by the epilogue, not
            # by the main GEMM loop.
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        # Filter out a known cause of CUDA illegal memory access errors
        # broadcasting on the last dim of the bias term seems not to be working
        # in the linear GEMM epilogue used by addmm.
        if (
            WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1])
            != 0
        ):
            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
                choices,
                layout,
                [mat1, mat2, inp_expanded],
                alpha=alpha,
                beta=beta,
            )

    if is_nonzero and use_ck_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
        )

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [inp_expanded, mat1, mat2],
            alpha=alpha,
            beta=beta,
            has_bias=True,
        )

    add_aten_fallback = False
    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        add_aten_fallback = True

    if add_aten_fallback:
        choices.append(
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                ordered_kwargs_for_cpp_kernel,
                alpha=alpha,
                beta=beta,
            )
        )

        if (
            inp_expanded.get_stride()[0] == 0
            and inp_expanded.get_device().type == "cuda"
            and inductor_config.triton.autotune_cublasLt
        ):
            # unexpand inp to make sure fused addmm from cublasLt is used
            choices.insert(
                0,
                aten_bias_addmm.bind(
                    (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
                ),
            )
    try:
        return autotune_select_algorithm(
            "addmm", choices, [inp_expanded, mat1, mat2], layout
        )
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        fallback_choice = aten_addmm.bind(
            (inp, mat1, mat2),
            layout,
            ordered_kwargs_for_cpp_kernel,
            alpha=alpha,
            beta=beta,
        )
        return fallback_choice.output_node()
467
+
468
+
469
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
def tuned_sparse_semi_structured_mm(
    mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
):
    """Autotuned lowering for 2:4 sparse ``_sparse_semi_structured_mm``.

    mat1 stores the compressed sparse operand (half the columns), mat1_meta
    its metadata; hence the ``2 * k1 == k2`` guard below.
    """
    from torch._inductor.select_algorithm import realize_inputs

    mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
    m1, k1 = mat1.get_size()
    m2, _ = mat1_meta.get_size()
    k2, n = mat2.get_size()
    m = V.graph.sizevars.guard_equals(m1, m2)
    # Compressed K is half the logical K of the dense operand.
    k = V.graph.sizevars.guard_equals(2 * k1, k2)

    if layout is None:
        from torch._inductor.ir import FixedLayout

        # Contiguous [m, n] output in mat2's dtype unless overridden.
        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [m, n],
            [n, 1],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."

    choices = (
        [
            aten__sparse_semi_structured_mm.bind(
                (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if m * n != 0 and use_cutlass_template(layout, m, n, k):
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
        )

    return autotune_select_algorithm(
        "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
    )
512
+
513
+
514
def fallback_mixed_mm(mat1, mat2, *, out):
    """Eager fallback for mixed-dtype mm: upcast mat2 to mat1's dtype, then mm."""
    upcast_mat2 = mat2.to(mat1.dtype)
    return torch.mm(mat1, upcast_mat2, out=out)
516
+
517
+
518
# Extern choice wrapping the upcast-then-mm fallback for mixed-dtype matmul.
aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
519
+
520
+
521
@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    """True when the CUDA device at *index* (default 0) has compute capability <= 7."""
    device_index = index if index else 0
    return torch.cuda.get_device_properties(device_index).major <= 7
525
+
526
+
527
def dims_are_int(dims):
    """Return True when every entry of *dims* is an ``int`` (i.e. no symbolic sizes)."""
    for dim in dims:
        if not isinstance(dim, int):
            return False
    return True
529
+
530
+
531
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
    """Hand-tuned mixed-mm config picker for fp16 inputs on A100.

    Returns a triton config for the (m, n, k) bucket, or None when the
    heuristic does not apply (symbolic dims, non-fp16, non-A100 hardware,
    or shapes outside the tuned buckets).
    """
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None

    if mat1.dtype != torch.float16:
        return None

    # only use heuristic if we are running on an A100
    # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
    # which does not have enough shared memory for one of the configs
    if (
        not torch.cuda.get_device_capability() >= (8, 0)
    ) or get_gpu_shared_memory() != 166912:
        return None

    # Degenerate single-row GEMMs need 16-aligned n and k for these configs.
    if m == 1 and (n % 16 != 0 or k % 16 != 0):
        return None

    # Bucketed by m; all buckets require large n and k.
    if m <= 16 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=16,
            BLOCK_N=64,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=32,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=64,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    return None
575
+
576
+
577
def mm_autoheuristic(
    mat1,
    mat2,
    m,
    n,
    k,
    choices,
    name,
    input_nodes,
    ops,
    precondition,
    top_k: Optional[int] = None,
    always_included=None,
):
    """Rank GEMM *choices* with AutoHeuristic.

    Builds a feature context (sizes, dtypes, strides, contiguity) and asks
    AutoHeuristicSelectAlgorithm either for the single best choice or, when
    *top_k* is given, for the top-k choices (optionally forcing
    *always_included* entries in). Returns None when any dim is symbolic.
    """
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None
    mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)

    def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride):
        # Feature vector describing this GEMM for the learned heuristic.
        context = AHContext()
        context.add_feature("m", m)
        context.add_feature("k", k)
        context.add_feature("n", n)
        context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
        context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
        context_add_strides(context, "mat1", mat1_stride)
        context_add_strides(context, "mat2", mat2_stride)
        context.add_feature(
            "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
        )
        context.add_feature(
            "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
        )
        if name == "mm":
            # for mixed_mm, we only consider fp16
            context_add_using_tf32(context, mat1.layout.dtype)
        return context

    def fallback():
        # No heuristic decision available.
        return None

    context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride)
    autoheuristic = AutoHeuristicSelectAlgorithm(
        fallback=fallback,
        choices=choices,
        input_nodes=input_nodes,
        context=context,
        name=name,
        augment_context=ops,
        precondition=precondition,
    )

    if top_k is not None:
        # TODO: is there a cleaner way to ensure aten.mm is always included?
        return autoheuristic.get_top_k_choices_caller(
            top_k, always_included=always_included
        )

    return autoheuristic.get_choice_caller()
637
+
638
+
639
def get_size_hints(mat1, mat2, m, n, k):
    """Replace symbolic m/n/k with concrete integer size hints.

    m and k come from mat1's size hints, k and n from mat2's.
    NOTE(review): when n is symbolic but m/k are not, the second branch still
    re-derives k from mat2 — presumably mat1 and mat2 agree on k, but the
    mat2 hint wins here; verify this is intentional.
    """
    if not isinstance(m, int) or not isinstance(k, int):
        (m, k) = V.graph.sizevars.size_hints(
            mat1.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    if not isinstance(n, int) or not isinstance(k, int):
        (k, n) = V.graph.sizevars.size_hints(
            mat2.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )
    return m, n, k
652
+
653
+
654
def get_size_hints_strides(mat1, mat2):
    """Return (mat1_strides, mat2_strides) with symbolic strides hinted to ints.

    NOTE(review): each ``stride`` here is a *list* of per-dim strides, so the
    ``isinstance(stride, int)`` guard below is effectively always False and
    ``size_hints`` is always called — looks like the intent was a per-element
    check; confirm before relying on the guard.
    """
    mat1_stride = mat1.layout.stride
    mat2_stride = mat2.layout.stride
    strides = [mat1_stride, mat2_stride]
    strides_hints = []
    for stride in strides:
        if not isinstance(stride, int):
            stride = V.graph.sizevars.size_hints(
                stride,
                fallback=torch._inductor.config.unbacked_symint_fallback,
            )
        strides_hints.append(stride)
    return strides_hints[0], strides_hints[1]
667
+
668
+
669
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
    """Autotuned lowering for mixed-dtype mm (mat2 upcast to mat1's dtype).

    Depending on ``inductor_config.mixed_mm_choice`` ("default"/"aten"/
    "triton"/"heuristic") and hardware gates, tunes among the upcast-and-mm
    fallback, Triton template configs and CUTLASS choices, optionally ranked
    by AutoHeuristic.
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)

    fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)

    choices = [fallback]

    # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
    skip_triton = (
        (
            mat1.layout.dtype != torch.float32
            and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
        )
        or _is_sm7x_or_older_gpu(layout.device.index)
        or inductor_config.mixed_mm_choice == "aten"
        or not V.graph.has_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
        or (
            mat1.layout.dtype == torch.float32 and torch.backends.cuda.matmul.allow_tf32
        )
        or (mat1.layout.dtype == torch.bfloat16 and mat2.layout.dtype == torch.uint8)
    )

    # "triton" forces template-only tuning: drop the ATen fallback.
    if inductor_config.mixed_mm_choice == "triton":
        choices = []

    if not skip_triton:
        # Template option to cast B tiles before the dot, e.g. "tl.uint8".
        b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
        if static_shape and inductor_config.mixed_mm_choice == "heuristic":
            choices = []
            config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
            if config is not None:
                mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                    **mm_options(config, m, n, k, layout, b_prologue_cast_type),
                )
            choices.append(fallback)

        has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
        for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout, b_prologue_cast_type),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )

    # Never leave the choice list empty when triton was skipped.
    if skip_triton and not choices:
        choices = [fallback]

    name = "mixed_mm"
    input_nodes = [mat1, mat2]
    if torch._inductor.config.run_autoheuristic(name):
        choice = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mixed_mm_operations(),
            get_mixedmm_precondition,
        )
        if (
            not skip_triton
            and inductor_config.mixed_mm_choice == "heuristic"
            and choice is not None
        ):
            # Heuristic winner goes first so it is preferred by the tuner.
            choices.insert(0, choice)
    return autotune_select_algorithm(name, choices, input_nodes, layout)
751
+
752
+
753
# This op is a special case of the int_mm op which we use based on the pattern
# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
# realization of the int32 _int_mm output by forcing fusion with the mul op.
# This is only used when config.force_fuse_int_mm_with_mul = True
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    """Autotuned lowering for fused ``_int_mm(mat1, mat2) * mat3``.

    Only Triton template choices are generated (no ATen fallback); the mul
    is applied as the template epilogue via ``suffix_args=1``.
    """
    # Default output dtype follows int32 promotion against mat3's dtype.
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
        if out_dtype is None
        else out_dtype
    )
    m, n, k, layout, mat1, mat2, mat3 = mm_args(
        mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
    )
    choices: List[Dict[Any, Any]] = []
    for config in int8_mm_configs(m, n, k):
        # Accumulate in int32; mat3 is the epilogue (suffix) operand.
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2, mat3),
            layout=layout,
            **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
            suffix_args=1,
            epilogue_fn=V.ops.mul,
        )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ from typing import cast, List, Tuple
6
+
7
+ import sympy
8
+
9
+ import torch
10
+ from torch._inductor.select_algorithm import realize_inputs
11
+ from torch._inductor.virtualized import V
12
+
13
+ from .. import config as inductor_config
14
+ from ..runtime.runtime_utils import next_power_of_2
15
+ from ..utils import ceildiv as cdiv
16
+
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+
21
def triton_config(num_stages, num_warps, **kwargs):
    """Build a ``triton.Config`` from block-size kwargs plus stage/warp counts."""
    # Imported lazily so this module can be imported without triton installed.
    from triton import Config

    cfg = Config(kwargs, num_stages=num_stages, num_warps=num_warps)
    return cfg
25
+
26
+
27
def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size"""

    # tl.dot needs at least 16 per dimension, so never tile below 16 in M/N.
    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    # Round each problem dimension up to a power of two (using a size hint for
    # unbacked/dynamic symbols) so block sizes can be clamped against it below.
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
    # Deduplicate after clamping: distinct input configs can collapse to the
    # same shrunken config for small problem sizes.
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes 16x16 tile = 256
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            # On ROCm additionally sweep matrix_instr_nonkdim
            # (0 = leave the choice to the compiler, 16 = force 16x16).
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
116
+
117
+
118
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
# When max_autotune_gemm_search_space == "EXHAUSTIVE" the curated list is
# replaced by the full cartesian product of block sizes, stages and warps.
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)
153
+
154
+ # these are only used in tuned_mm when AutoHeuristic is enabled
155
+ # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
156
+ # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
157
+ # which saves compilation time (since less configs are autotuned) and potentially increase performance
158
# because the learned heuristic might predict a config that is not part of mm_configs
159
# Additional (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) candidates,
# autotuned only while AutoHeuristic is enabled (see the comment block above).
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]
171
+
172
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) candidates for int8 matmuls.
int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    # The largest tiles are enabled only off-ROCm.
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]
188
+
189
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

# In EXHAUSTIVE mode mm_kernel_configs already covers the full sweep, so the
# small-m extras are appended only for the default search space.
mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)
200
+
201
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) candidates for scaled
# (fp8) matmuls: a hand-picked set of large tiles followed by a systematic
# small-M sweep (BLOCK_M in {16, 32}, BLOCK_N in {32, 64, 128, 256},
# BLOCK_K in {32, 64}) for num_stages 2..6.
scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]
300
+
301
+
302
# Create filtered list of configs based on cond evaluation
# (each table above is reduced to plain (BLOCK_M, BLOCK_N, BLOCK_K,
# num_stages, num_warps) tuples for the entries whose "cond" holds).
mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
extra_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in extra_mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)
mixed_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mixed_mm_kernel_configs
    if config["cond"]
)
scaled_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in scaled_mm_kernel_configs
    if config["cond"]
)
328
+
329
# On ROCm convert num_stages to 0 to enable software pipelining
# (each 5-tuple keeps its block sizes and warp count, only num_stages changes).
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        # BUGFIX: this comprehension previously iterated mm_platform_configs,
        # which silently replaced the int8-specific configs with the generic
        # mm configs on ROCm.
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )
351
+
352
# Thin wrappers over filtered_configs with the per-op config table baked in;
# callers pass only (m, n, k) plus optionally has_int8_tensor.
mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

extra_mm_configs = functools.partial(
    filtered_configs,
    configs=extra_mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
    filtered_configs,
    configs=mixed_mm_platform_configs,
)

scaled_mm_configs = functools.partial(
    filtered_configs,
    configs=scaled_mm_platform_configs,
)
376
+
377
+
378
def mm_grid(m, n, meta):
    """Return the CUDA launch grid for matmul triton templates."""
    blocks_m = cdiv(m, meta["BLOCK_M"])
    blocks_n = cdiv(n, meta["BLOCK_N"])
    # One program per (BLOCK_M x BLOCK_N) output tile, flattened onto axis 0.
    return (blocks_m * blocks_n, 1, 1)
383
+
384
+
385
def acc_type(dtype):
    """Return the triton accumulator type string for a given torch dtype."""
    # Half-precision inputs accumulate in fp32.
    if dtype is torch.float16 or dtype is torch.bfloat16:
        return "tl.float32"
    # e.g. torch.float32 -> "tl.float32", torch.int32 -> "tl.int32"
    return "tl." + str(dtype).removeprefix("torch.")
389
+
390
+
391
def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options to matmul triton templates.
    """
    block_k = config.kwargs["BLOCK_K"]
    # K divides into whole BLOCK_K chunks iff gcd(K, BLOCK_K) == BLOCK_K.
    # (it isn't worth guarding on this)
    even_k_symbolic = sympy.gcd(sym_k, block_k) == block_k
    # When force_same_precision is set, TF32 is only allowed for shapes with
    # 16/16/8 alignment — presumably to match eager's TF32 eligibility.
    aligned = (sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision or aligned
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )
414
+
415
+
416
def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
    """
    Common arg processing for mm,bmm,addmm,etc

    Realizes the inputs, guards that batch and K dimensions agree, and builds
    an output layout when one is not supplied.  Returns
    ``[m, n, k, layout, mat1, mat2, *others]`` where ``others`` (e.g. an addmm
    bias) have been expanded to the output size and realized.
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
    # Guard that any leading batch dimensions match (bmm case).
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        # mat2 packs two sub-byte values per element, so its K counts double —
        # TODO(review) confirm against the uint4x2 caller.
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]
456
+
457
+
458
def addmm_epilogue(dtype, alpha, beta):
    """Return an epilogue computing ``alpha * acc + beta * bias`` via V.ops."""

    def epilogue(acc, bias):
        scaled_acc = acc
        scaled_bias = bias
        # Skip the multiply entirely when the scale factor is the identity.
        if alpha != 1:
            scaled_acc = V.ops.mul(scaled_acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            scaled_bias = V.ops.mul(scaled_bias, V.ops.constant(beta, dtype))
        return V.ops.add(scaled_acc, scaled_bias)

    return epilogue
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+
4
+ import torch
5
+
6
+ from ..lowering import lowerings
7
+ from ..select_algorithm import (
8
+ autotune_select_algorithm,
9
+ ExternKernelChoice,
10
+ TritonTemplate,
11
+ )
12
+ from ..utils import use_aten_gemm_kernels, use_triton_template
13
+ from ..virtualized import V
14
+ from .mm_common import mm_args, mm_grid, mm_options
15
+
16
+
17
aten = torch.ops.aten

# ATen fallback choice: the fused torch.ops.inductor._mm_plus_mm kernel.
aten_mm_plus_mm = ExternKernelChoice(
    torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
)

# Triton template computing mm(A, B) + mm(C, D) into a single output tile.
# NOTE(review): both K loops iterate over K1, so the template assumes the two
# products share the same K; tuned_mm_plus_mm below only uses it when the
# input sizes are statically known to match.
mm_plus_mm_template = TritonTemplate(
    name="mm_plus_mm",
    grid=mm_grid,
    debug=False,
    source=r"""
{{def_kernel("A", "B", "C", "D")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K1 = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    # K2 = {{size("C", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
    stride_cm = {{stride("C", 0)}}
    stride_ck = {{stride("C", 1)}}
    stride_dk = {{stride("D", 0)}}
    stride_dn = {{stride("D", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1))
        and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M

    if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1))
        and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k1 in range(K1, 0, -BLOCK_K):
        # First matmul with A @ B
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    for k2 in range(K1, 0, -BLOCK_K):

        # Second matmul with C @ D
        if EVEN_K:
            c = tl.load(C)
            d = tl.load(D)
        else:
            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
        C += BLOCK_K * stride_ck
        D += BLOCK_K * stride_dk


    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
113
+
114
+
115
@functools.lru_cache(None)
def mm_configs():
    """Return the filtered triton config list for the mm_plus_mm template.

    Cached for the process lifetime: the candidate table is static and the
    only runtime input is the platform (CUDA vs ROCm).
    """
    import triton

    # List of dictionaries to store the kernel configs. Configs that evaluate to true
    # will be utilised on the target platform
    mm_triton_configs = [
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 2,
            "num_warps": 4,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 3,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 16,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
            "num_stages": 1,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
            "num_stages": 1,
            "num_warps": 8,
            "cond": True,
        },
        {
            # Large BLOCK_K config is CUDA-only.
            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
            "num_stages": 1,
            "num_warps": 8,
            "cond": torch.version.hip is None,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
            "num_stages": 2,
            "num_warps": 4,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
            "num_stages": 1,
            "num_warps": 2,
            "cond": True,
        },
    ]

    # Filter out configs in which cond evaluates to true
    # On ROCm convert num_stages to 1 as pipelining provides no benefit
    if torch.version.hip:
        filtered_configs = [
            triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
            for c in mm_triton_configs
            if c["cond"]
        ]
    else:
        filtered_configs = [
            triton.Config(
                c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
            )
            for c in mm_triton_configs
            if c["cond"]
        ]

    return filtered_configs
202
+
203
+
204
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
    """
    Computes mm(mat1, mat2) + mm(mat3, mat4)

    Falls back to two separate mm lowerings plus an add whenever the fused
    kernel's preconditions (non-empty outputs, statically matching sizes)
    do not hold.
    """
    m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
    # Optimization is optional, because we can always just not do the fusion
    if (
        m1 * n1 == 0
        or m2 * n2 == 0
        or not V.graph.sizevars.statically_known_list_equals(
            mat1.get_size(), mat3.get_size()
        )
        or not V.graph.sizevars.statically_known_list_equals(
            mat2.get_size(), mat4.get_size()
        )
    ):
        # TODO(jansel): support different K values when this is fixed:
        # https://github.com/openai/triton/issues/967
        return lowerings[aten.add](
            lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
        )

    # Equal input sizes imply equal output layouts from mm_args.
    assert layout1 == layout2
    # options to tune from
    choices = (
        [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout1):
        for config in mm_configs():
            # see https://github.com/openai/triton/issues/1298
            # BLOCK_K = K causes llvm error
            if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
                mm_plus_mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2, mat3, mat4),
                    layout=layout1,
                    **mm_options(config, m1, n1, k1, layout1),
                )

    return autotune_select_algorithm(
        "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
    )
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_scaled.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import sympy
5
+
6
+ import torch
7
+
8
+ from .. import config as inductor_config
9
+ from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
10
+ from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
11
+ from ..select_algorithm import (
12
+ autotune_select_algorithm,
13
+ ExternKernelChoice,
14
+ NoValidChoicesError,
15
+ realize_inputs,
16
+ TritonTemplate,
17
+ )
18
+ from ..utils import use_aten_gemm_kernels, use_triton_template
19
+ from .mm import _is_static_problem # TODO(yangsiyu) move to mm_common
20
+ from .mm_common import mm_args, mm_grid, scaled_mm_configs
21
+
22
+
23
+ log = logging.getLogger(__name__)
24
+ aten = torch.ops.aten
25
+
26
+
27
# Triton template for scaled (fp8) matmul: computes A @ B, then multiplies by
# the inverse scales — scalars for tensor-wise scaling, or per-row/per-column
# vectors when SCALING_ROWWISE is set.
scaled_mm_template = TritonTemplate(
    name="scaled_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
104
+
105
+
106
# Inductor does not allow optional tensor input arguments currently (pass None as an
# input node to template choices), but since for _scaled_mm there is only one such arg
# (bias), work around by having a second template when bias is provided.
# NOTE(review): keep this in sync with scaled_mm_template above — the only
# difference is the bias load/add after the scaling step.
scaled_mm_bias_template = TritonTemplate(
    name="scaled_mm_bias",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # bias
    bias = tl.load(bias_ptr + rn, mask=rn < N)
    acc += bias

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
190
+
191
+
192
+ aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm")
193
+
194
+
195
+ def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool:
196
+ # Same sized scales are compatable
197
+ if len(size_a) == len(size_b):
198
+ return True
199
+
200
+ # Both need to be scalars or len(1) tensors
201
+ if len(size_a) <= 1 and len(size_b) <= 1:
202
+ return True
203
+
204
+ return False
205
+
206
+
207
+ def scaled_mm_options( # type: ignore[no-untyped-def]
208
+ config, # triton.Config
209
+ sym_m: sympy.core.numbers.Integer,
210
+ sym_n: sympy.core.numbers.Integer,
211
+ sym_k: sympy.core.numbers.Integer,
212
+ layout: Layout,
213
+ scale_a: StorageBox,
214
+ scale_b: StorageBox,
215
+ use_fast_accum: bool,
216
+ b_prologue_cast_type: Optional[str] = None,
217
+ ) -> Dict[str, Any]:
218
+ even_k_symbolic = (
219
+ sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
220
+ )
221
+
222
+ size_a, size_b = scale_a.get_size(), scale_b.get_size()
223
+ assert are_compatible_scales(size_a, size_b), (
224
+ "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
225
+ f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
226
+ )
227
+ return dict(
228
+ GROUP_M=8,
229
+ EVEN_K=even_k_symbolic,
230
+ ACC_TYPE="tl.float32",
231
+ B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
232
+ USE_FAST_ACCUM=use_fast_accum,
233
+ num_stages=config.num_stages,
234
+ num_warps=config.num_warps,
235
+ # tensor-wise scaling if scalar scales
236
+ SCALING_ROWWISE=len(scale_a.get_size()) == 2,
237
+ **config.kwargs,
238
+ )
239
+
240
+
241
+ add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
242
+
243
+
244
+ @register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
245
+ def tuned_scaled_mm(
246
+ mat_a: TensorBox,
247
+ mat_b: TensorBox,
248
+ scale_a: TensorBox,
249
+ scale_b: TensorBox,
250
+ bias: Optional[TensorBox] = None,
251
+ scale_result: Optional[TensorBox] = None,
252
+ out_dtype: Optional[torch.dtype] = None,
253
+ use_fast_accum: bool = False,
254
+ layout: Optional[Layout] = None,
255
+ ) -> TensorBox:
256
+ m, n, k, layout, mat_a, mat_b = mm_args(
257
+ mat_a, mat_b, layout=layout, out_dtype=out_dtype
258
+ )
259
+ scale_a, scale_b = realize_inputs(scale_a, scale_b)
260
+
261
+ input_nodes: Tuple[Any, ...]
262
+ # workaround for Inductor not supporting optional tensor input arguments
263
+ if bias is None:
264
+ input_nodes = (mat_a, mat_b, scale_a, scale_b)
265
+ triton_template = scaled_mm_template
266
+ else:
267
+ bias = realize_inputs(bias)
268
+ input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
269
+ triton_template = scaled_mm_bias_template
270
+
271
+ aten_choice = aten__fp8_mm.bind(
272
+ input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
273
+ )
274
+
275
+ choices: List[ChoiceCaller] = []
276
+ if use_aten_gemm_kernels():
277
+ choices.append(aten_choice)
278
+
279
+ static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout)
280
+ if is_nonzero and use_triton_template(layout, enable_float8=True):
281
+ for config in scaled_mm_configs(m, n, k):
282
+ if k == 16 and config.kwargs["BLOCK_M"] >= 64:
283
+ continue # Triton crashes in this case
284
+ kwargs = scaled_mm_options(
285
+ config, m, n, k, layout, scale_a, scale_b, use_fast_accum
286
+ )
287
+ # possibly appends a TritonTemplateCaller to choices
288
+ triton_template.maybe_append_choice(
289
+ choices,
290
+ input_nodes=input_nodes,
291
+ layout=layout,
292
+ **kwargs,
293
+ )
294
+
295
+ if (
296
+ len(choices) == 0
297
+ and not use_aten_gemm_kernels()
298
+ and inductor_config.autotune_fallback_to_aten
299
+ ):
300
+ log.warning("No choices for scaled_mm, using ATen backend as fallback")
301
+ return aten_choice.output_node()
302
+
303
+ try:
304
+ return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
305
+ except NoValidChoicesError:
306
+ if not inductor_config.autotune_fallback_to_aten:
307
+ raise
308
+ log.warning(
309
+ "All choices for scaled_mm were invalid, using ATen backend as fallback"
310
+ )
311
+ return aten_choice.output_node()
.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from typing import List, TYPE_CHECKING
4
+
5
+ from ..select_algorithm import autotune_select_algorithm, TritonTemplate
6
+ from .mm_common import mm_args, mm_configs, mm_grid, mm_options
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from ..ir import ChoiceCaller
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+ uint4x2_mixed_mm_template = TritonTemplate(
15
+ name="uint4x2_mixed_mm",
16
+ grid=mm_grid,
17
+ source=r"""
18
+ {{def_kernel("A", "B")}}
19
+ M = {{size("A", 0)}}
20
+ N = {{size("B", 1)}}
21
+ K = {{size("A", 1)}}
22
+ stride_am = {{stride("A", 0)}}
23
+ stride_ak = {{stride("A", 1)}}
24
+ stride_bk = {{stride("B", 0)}}
25
+ stride_bn = {{stride("B", 1)}}
26
+
27
+ # based on triton.ops.matmul
28
+ pid = tl.program_id(0)
29
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
30
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
31
+
32
+ # re-order program ID for better L2 performance
33
+ width = GROUP_M * grid_n
34
+ group_id = pid // width
35
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
36
+ pid_m = group_id * GROUP_M + (pid % group_size)
37
+ pid_n = (pid % width) // (group_size)
38
+
39
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
40
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
41
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
42
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
43
+ rk = tl.arange(0, BLOCK_K)
44
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
45
+ B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
46
+ b_shifts = 4*(rk%2)
47
+ b_subs = 8*(1-(rk%2))
48
+
49
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
50
+ for k in range(K, 0, -BLOCK_K):
51
+ if EVEN_K:
52
+ a = tl.load(A)
53
+ b = tl.load(B)
54
+ else:
55
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
56
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
57
+ b = ((b >> b_shifts[:, None]) & 0xF) - 8
58
+ b = b.to(B_PROLOGUE_CAST_TYPE)
59
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
60
+ A += BLOCK_K * stride_ak
61
+ B += BLOCK_K//2 * stride_bk
62
+
63
+ # rematerialize rm and rn to save registers
64
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
65
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
66
+ idx_m = rm[:, None]
67
+ idx_n = rn[None, :]
68
+ mask = (idx_m < M) & (idx_n < N)
69
+
70
+ # inductor generates a suffix
71
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
72
+ """,
73
+ )
74
+
75
+
76
+ def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
77
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
78
+ choices: List[ChoiceCaller] = []
79
+ b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
80
+ for config in mm_configs(m, n, k):
81
+ uint4x2_mixed_mm_template.maybe_append_choice(
82
+ choices,
83
+ input_nodes=(mat1, mat2),
84
+ layout=layout,
85
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
86
+ )
87
+ return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-311.pyc ADDED
Binary file (4.34 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-311.pyc ADDED
Binary file (11.7 kB). View file