Kernels:
Trusted publisher
Uploaded using `kernel-builder`.
Browse files- build/torch-cuda/__init__.py +2 -6
- build/torch-cuda/_ops.py +2 -2
- build/torch-cuda/bench_utils.py +196 -0
- build/torch-cuda/block_info.py +35 -4
- build/torch-cuda/block_sparse_utils.py +19 -61
- build/torch-cuda/block_sparsity.py +33 -10
- build/torch-cuda/cache_utils.py +9 -15
- build/torch-cuda/cute_dsl_utils.py +0 -38
- build/torch-cuda/fa_logging.py +97 -0
- build/torch-cuda/flash_bwd.py +25 -6
- build/torch-cuda/flash_bwd_postprocess.py +31 -29
- build/torch-cuda/flash_bwd_preprocess.py +146 -157
- build/torch-cuda/flash_bwd_sm100.py +9 -6
- build/torch-cuda/flash_bwd_sm120.py +55 -0
- build/torch-cuda/flash_bwd_sm90.py +451 -151
- build/torch-cuda/flash_fwd.py +132 -1363
- build/torch-cuda/flash_fwd_combine.py +62 -56
- build/torch-cuda/flash_fwd_sm100.py +406 -257
- build/torch-cuda/flash_fwd_sm120.py +59 -0
- build/torch-cuda/flash_fwd_sm90.py +1534 -0
- build/torch-cuda/interface.py +734 -505
- build/torch-cuda/mask.py +168 -110
- build/torch-cuda/named_barrier.py +15 -0
- build/torch-cuda/pack_gqa.py +110 -12
- build/torch-cuda/paged_kv.py +35 -15
- build/torch-cuda/pipeline.py +198 -236
- build/torch-cuda/quack/copy_utils.py +186 -8
- build/torch-cuda/quack/cute_dsl_utils.py +20 -26
- build/torch-cuda/quack/layout_utils.py +34 -0
- build/torch-cuda/quack/utils.py +324 -0
- build/torch-cuda/seqlen_info.py +188 -39
- build/torch-cuda/sm90_config_search.py +402 -0
- build/torch-cuda/softmax.py +1 -1
- build/torch-cuda/tile_scheduler.py +419 -59
- build/torch-cuda/utils.py +104 -2
build/torch-cuda/__init__.py
CHANGED
|
@@ -1,19 +1,15 @@
|
|
| 1 |
"""Flash Attention CUTE (CUDA Template Engine) implementation."""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
# Update when syncing again.
|
| 6 |
-
__version__ = "4.0.0.beta4"
|
| 7 |
|
| 8 |
import cutlass.cute as cute
|
| 9 |
|
|
|
|
| 10 |
from .interface import (
|
| 11 |
flash_attn_func,
|
| 12 |
flash_attn_varlen_func,
|
| 13 |
)
|
| 14 |
|
| 15 |
-
from .cute_dsl_utils import cute_compile_patched
|
| 16 |
-
|
| 17 |
# Patch cute.compile to optionally dump SASS
|
| 18 |
cute.compile = cute_compile_patched
|
| 19 |
|
|
|
|
| 1 |
"""Flash Attention CUTE (CUDA Template Engine) implementation."""
|
| 2 |
|
| 3 |
+
__version__ = "4.0.0.beta8"
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import cutlass.cute as cute
|
| 6 |
|
| 7 |
+
from .cute_dsl_utils import cute_compile_patched
|
| 8 |
from .interface import (
|
| 9 |
flash_attn_func,
|
| 10 |
flash_attn_varlen_func,
|
| 11 |
)
|
| 12 |
|
|
|
|
|
|
|
| 13 |
# Patch cute.compile to optionally dump SASS
|
| 14 |
cute.compile = cute_compile_patched
|
| 15 |
|
build/torch-cuda/_ops.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
-
ops = torch.ops.
|
| 3 |
|
| 4 |
def add_op_namespace_prefix(op_name: str):
|
| 5 |
"""
|
| 6 |
Prefix op by namespace.
|
| 7 |
"""
|
| 8 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
ops = torch.ops._flash_attn4_c9a1374
|
| 3 |
|
| 4 |
def add_op_namespace_prefix(op_name: str):
|
| 5 |
"""
|
| 6 |
Prefix op by namespace.
|
| 7 |
"""
|
| 8 |
+
return f"_flash_attn4_c9a1374::{op_name}"
|
build/torch-cuda/bench_utils.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import cudnn
|
| 8 |
+
except ImportError:
|
| 9 |
+
cudnn = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ── FLOPS calculation ────────────────────────────────────────────────────────
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def flops(
|
| 16 |
+
batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
|
| 17 |
+
):
|
| 18 |
+
if causal:
|
| 19 |
+
avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
|
| 20 |
+
else:
|
| 21 |
+
if window_size == (None, None):
|
| 22 |
+
avg_seqlen = seqlen_k
|
| 23 |
+
else:
|
| 24 |
+
row_idx = torch.arange(seqlen_q, device="cuda")
|
| 25 |
+
col_left = (
|
| 26 |
+
torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
|
| 27 |
+
if window_size[0] is not None
|
| 28 |
+
else torch.zeros_like(row_idx)
|
| 29 |
+
)
|
| 30 |
+
col_right = (
|
| 31 |
+
torch.minimum(
|
| 32 |
+
row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)
|
| 33 |
+
)
|
| 34 |
+
if window_size[1] is not None
|
| 35 |
+
else torch.full_like(row_idx, seqlen_k - 1)
|
| 36 |
+
)
|
| 37 |
+
avg_seqlen = (col_right - col_left + 1).float().mean().item()
|
| 38 |
+
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ── Reference attention ─────────────────────────────────────────────────────
|
| 42 |
+
|
| 43 |
+
_attention_ref_mask_cache = {}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def attention_ref(q, k, v, causal=False):
|
| 47 |
+
"""Standard attention reference implementation.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
q, k, v: (batch, seqlen, nheads, headdim) tensors.
|
| 51 |
+
causal: whether to apply causal mask.
|
| 52 |
+
"""
|
| 53 |
+
softmax_scale = 1.0 / math.sqrt(q.shape[-1])
|
| 54 |
+
scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
|
| 55 |
+
if causal:
|
| 56 |
+
if scores.shape[-2] not in _attention_ref_mask_cache:
|
| 57 |
+
mask = torch.tril(
|
| 58 |
+
torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0
|
| 59 |
+
)
|
| 60 |
+
_attention_ref_mask_cache[scores.shape[-2]] = mask
|
| 61 |
+
else:
|
| 62 |
+
mask = _attention_ref_mask_cache[scores.shape[-2]]
|
| 63 |
+
scores = scores.masked_fill(mask, float("-inf"))
|
| 64 |
+
attn = torch.softmax(scores, dim=-1)
|
| 65 |
+
return torch.einsum("bhts,bshd->bthd", attn, v)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ── cuDNN graph helpers ─────────────────────────────────────────────────────
|
| 69 |
+
|
| 70 |
+
_TORCH_TO_CUDNN_DTYPE = {
|
| 71 |
+
torch.float16: "HALF",
|
| 72 |
+
torch.bfloat16: "BFLOAT16",
|
| 73 |
+
torch.float32: "FLOAT",
|
| 74 |
+
torch.int32: "INT32",
|
| 75 |
+
torch.int64: "INT64",
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _build_cudnn_graph(io_dtype, tensors, build_fn):
|
| 80 |
+
"""Build a cuDNN graph. Returns (graph, variant_pack, workspace)."""
|
| 81 |
+
assert cudnn is not None, "cuDNN is not available"
|
| 82 |
+
cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype])
|
| 83 |
+
graph = cudnn.pygraph(
|
| 84 |
+
io_data_type=cudnn_dtype,
|
| 85 |
+
intermediate_data_type=cudnn.data_type.FLOAT,
|
| 86 |
+
compute_data_type=cudnn.data_type.FLOAT,
|
| 87 |
+
)
|
| 88 |
+
graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()}
|
| 89 |
+
variant_pack = build_fn(graph, graph_tensors)
|
| 90 |
+
graph.validate()
|
| 91 |
+
graph.build_operation_graph()
|
| 92 |
+
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
| 93 |
+
graph.check_support()
|
| 94 |
+
graph.build_plans()
|
| 95 |
+
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
|
| 96 |
+
return graph, variant_pack, workspace
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):
|
| 100 |
+
"""Build a cuDNN forward SDPA graph.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout).
|
| 104 |
+
causal: whether to apply causal mask.
|
| 105 |
+
window_size_left: sliding window size (None for no window).
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
(fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable.
|
| 109 |
+
"""
|
| 110 |
+
b, nheads, seqlen_q, headdim = q.shape
|
| 111 |
+
headdim_v = v.shape[-1]
|
| 112 |
+
o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device)
|
| 113 |
+
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
|
| 114 |
+
|
| 115 |
+
def build(graph, gt):
|
| 116 |
+
o, stats = graph.sdpa(
|
| 117 |
+
name="sdpa",
|
| 118 |
+
q=gt["q"],
|
| 119 |
+
k=gt["k"],
|
| 120 |
+
v=gt["v"],
|
| 121 |
+
is_inference=False,
|
| 122 |
+
attn_scale=1.0 / math.sqrt(headdim),
|
| 123 |
+
use_causal_mask=causal or window_size_left is not None,
|
| 124 |
+
sliding_window_length=window_size_left
|
| 125 |
+
if window_size_left is not None and not causal
|
| 126 |
+
else None,
|
| 127 |
+
)
|
| 128 |
+
o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
|
| 129 |
+
stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
|
| 130 |
+
return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu}
|
| 131 |
+
|
| 132 |
+
graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build)
|
| 133 |
+
|
| 134 |
+
def fwd_fn():
|
| 135 |
+
graph.execute(variant_pack, workspace)
|
| 136 |
+
return o_gpu
|
| 137 |
+
|
| 138 |
+
return fwd_fn, o_gpu, stats_gpu
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):
|
| 142 |
+
"""Build a cuDNN backward SDPA graph.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout).
|
| 146 |
+
causal: whether to apply causal mask.
|
| 147 |
+
window_size_left: sliding window size (None for no window).
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
bwd_fn: zero-arg callable that returns (dq, dk, dv).
|
| 151 |
+
"""
|
| 152 |
+
headdim = q.shape[-1]
|
| 153 |
+
dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
| 154 |
+
|
| 155 |
+
def build(graph, gt):
|
| 156 |
+
dq, dk, dv = graph.sdpa_backward(
|
| 157 |
+
name="sdpa_backward",
|
| 158 |
+
q=gt["q"],
|
| 159 |
+
k=gt["k"],
|
| 160 |
+
v=gt["v"],
|
| 161 |
+
o=gt["o"],
|
| 162 |
+
dO=gt["g"],
|
| 163 |
+
stats=gt["lse"],
|
| 164 |
+
attn_scale=1.0 / math.sqrt(headdim),
|
| 165 |
+
use_causal_mask=causal or window_size_left is not None,
|
| 166 |
+
sliding_window_length=window_size_left
|
| 167 |
+
if window_size_left is not None and not causal
|
| 168 |
+
else None,
|
| 169 |
+
use_deterministic_algorithm=False,
|
| 170 |
+
)
|
| 171 |
+
dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
|
| 172 |
+
dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
|
| 173 |
+
dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
|
| 174 |
+
return {
|
| 175 |
+
gt["q"]: q,
|
| 176 |
+
gt["k"]: k,
|
| 177 |
+
gt["v"]: v,
|
| 178 |
+
gt["o"]: o,
|
| 179 |
+
gt["g"]: g,
|
| 180 |
+
gt["lse"]: lse,
|
| 181 |
+
dq: dq_gpu,
|
| 182 |
+
dk: dk_gpu,
|
| 183 |
+
dv: dv_gpu,
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
graph, variant_pack, workspace = _build_cudnn_graph(
|
| 187 |
+
q.dtype,
|
| 188 |
+
{"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse},
|
| 189 |
+
build,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def bwd_fn():
|
| 193 |
+
graph.execute(variant_pack, workspace)
|
| 194 |
+
return dq_gpu, dk_gpu, dv_gpu
|
| 195 |
+
|
| 196 |
+
return bwd_fn
|
build/torch-cuda/block_info.py
CHANGED
|
@@ -6,7 +6,7 @@ import cutlass
|
|
| 6 |
import cutlass.cute as cute
|
| 7 |
from cutlass import Int32, const_expr
|
| 8 |
|
| 9 |
-
from .seqlen_info import SeqlenInfoQK
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass(frozen=True)
|
|
@@ -25,8 +25,8 @@ class BlockInfo:
|
|
| 25 |
self,
|
| 26 |
seqlen_info: SeqlenInfoQK,
|
| 27 |
m_block: Int32,
|
| 28 |
-
split_idx:
|
| 29 |
-
num_splits:
|
| 30 |
) -> Tuple[Int32, Int32]:
|
| 31 |
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
|
| 32 |
if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
|
|
@@ -46,7 +46,7 @@ class BlockInfo:
|
|
| 46 |
n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
|
| 47 |
if cutlass.const_expr(self.is_split_kv):
|
| 48 |
num_n_blocks_per_split = (
|
| 49 |
-
|
| 50 |
if n_block_max <= n_block_min
|
| 51 |
else (n_block_max - n_block_min + num_splits - 1) // num_splits
|
| 52 |
)
|
|
@@ -70,6 +70,37 @@ class BlockInfo:
|
|
| 70 |
m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
|
| 71 |
return m_block_min, m_block_max
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
@cute.jit
|
| 74 |
def get_n_block_min_causal_local_mask(
|
| 75 |
self,
|
|
|
|
| 6 |
import cutlass.cute as cute
|
| 7 |
from cutlass import Int32, const_expr
|
| 8 |
|
| 9 |
+
from .seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass(frozen=True)
|
|
|
|
| 25 |
self,
|
| 26 |
seqlen_info: SeqlenInfoQK,
|
| 27 |
m_block: Int32,
|
| 28 |
+
split_idx: Int32 = 0,
|
| 29 |
+
num_splits: Int32 = 1,
|
| 30 |
) -> Tuple[Int32, Int32]:
|
| 31 |
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
|
| 32 |
if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
|
|
|
|
| 46 |
n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
|
| 47 |
if cutlass.const_expr(self.is_split_kv):
|
| 48 |
num_n_blocks_per_split = (
|
| 49 |
+
Int32(0)
|
| 50 |
if n_block_max <= n_block_min
|
| 51 |
else (n_block_max - n_block_min + num_splits - 1) // num_splits
|
| 52 |
)
|
|
|
|
| 70 |
m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
|
| 71 |
return m_block_min, m_block_max
|
| 72 |
|
| 73 |
+
@cute.jit
|
| 74 |
+
def get_n_block_k_new_min_max(
|
| 75 |
+
self,
|
| 76 |
+
seqlen_info: SeqlenInfoQKNewK,
|
| 77 |
+
m_block: Int32,
|
| 78 |
+
split_idx: Int32 = 0,
|
| 79 |
+
num_splits: Int32 = 1,
|
| 80 |
+
) -> Tuple[Int32, Int32]:
|
| 81 |
+
"""Get the block range for new K tokens (append KV).
|
| 82 |
+
|
| 83 |
+
First computes the full n_block range via get_n_block_min_max, then maps
|
| 84 |
+
those blocks into the new-K index space by subtracting seqlen_k_og.
|
| 85 |
+
"""
|
| 86 |
+
n_block_min, n_block_max = self.get_n_block_min_max(
|
| 87 |
+
seqlen_info,
|
| 88 |
+
m_block,
|
| 89 |
+
split_idx,
|
| 90 |
+
num_splits,
|
| 91 |
+
)
|
| 92 |
+
idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0)
|
| 93 |
+
idx_k_new_max = cutlass.min(
|
| 94 |
+
n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new
|
| 95 |
+
)
|
| 96 |
+
n_block_new_min = idx_k_new_min // self.tile_n
|
| 97 |
+
n_block_new_max = (
|
| 98 |
+
cute.ceil_div(idx_k_new_max, self.tile_n)
|
| 99 |
+
if idx_k_new_max > idx_k_new_min
|
| 100 |
+
else n_block_new_min
|
| 101 |
+
)
|
| 102 |
+
return n_block_new_min, n_block_new_max
|
| 103 |
+
|
| 104 |
@cute.jit
|
| 105 |
def get_n_block_min_causal_local_mask(
|
| 106 |
self,
|
build/torch-cuda/block_sparse_utils.py
CHANGED
|
@@ -72,24 +72,22 @@ from .named_barrier import NamedBarrierBwd
|
|
| 72 |
def load_block_list(
|
| 73 |
block_indices: cute.Tensor,
|
| 74 |
block_count,
|
| 75 |
-
load_q_with_first: cutlass.Constexpr,
|
| 76 |
first_block_preloaded: cutlass.Constexpr,
|
| 77 |
kv_producer_state,
|
| 78 |
-
load_Q,
|
| 79 |
load_K,
|
| 80 |
load_V,
|
| 81 |
pipeline_k,
|
| 82 |
pipeline_v,
|
| 83 |
-
use_tma_q: cutlass.Constexpr,
|
| 84 |
-
tma_q_bytes: cutlass.Constexpr,
|
| 85 |
intra_wg_overlap: cutlass.Constexpr,
|
| 86 |
):
|
| 87 |
-
"""Iterate over the sparse blocks and load K, V
|
| 88 |
-
|
| 89 |
means we need to pipeline the last V load from the partial block case,
|
| 90 |
with the loads for the full blocks. Set first_block_preloaded when the
|
| 91 |
caller has already issued the first K load for the list.
|
| 92 |
|
|
|
|
|
|
|
| 93 |
Note:
|
| 94 |
we iterate along the block_n indices in reverse.
|
| 95 |
|
|
@@ -99,21 +97,7 @@ def load_block_list(
|
|
| 99 |
"""
|
| 100 |
if block_count > 0:
|
| 101 |
if const_expr(not intra_wg_overlap):
|
| 102 |
-
|
| 103 |
-
# Parameters are already Constexpr, so no need to wrap in const_expr()
|
| 104 |
-
n_block_first = block_indices[block_count - 1]
|
| 105 |
-
extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
|
| 106 |
-
pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
|
| 107 |
-
|
| 108 |
-
if const_expr(load_q_with_first and use_tma_q):
|
| 109 |
-
load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
|
| 110 |
-
|
| 111 |
-
load_K(src_idx=n_block_first, producer_state=kv_producer_state)
|
| 112 |
-
pipeline_v.producer_acquire(kv_producer_state)
|
| 113 |
-
load_V(src_idx=n_block_first, producer_state=kv_producer_state)
|
| 114 |
-
kv_producer_state.advance()
|
| 115 |
-
|
| 116 |
-
for offset in cutlass.range(1, block_count):
|
| 117 |
n_block = block_indices[block_count - 1 - offset]
|
| 118 |
pipeline_k.producer_acquire(kv_producer_state)
|
| 119 |
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
|
@@ -123,14 +107,7 @@ def load_block_list(
|
|
| 123 |
else:
|
| 124 |
n_block_first = block_indices[block_count - 1]
|
| 125 |
if const_expr(not first_block_preloaded):
|
| 126 |
-
|
| 127 |
-
tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
|
| 128 |
-
)
|
| 129 |
-
pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
|
| 130 |
-
|
| 131 |
-
if const_expr(load_q_with_first and use_tma_q):
|
| 132 |
-
load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
|
| 133 |
-
|
| 134 |
load_K(src_idx=n_block_first, producer_state=kv_producer_state)
|
| 135 |
|
| 136 |
for idx in cutlass.range(block_count - 1, unroll=1):
|
|
@@ -186,19 +163,18 @@ def produce_block_sparse_loads(
|
|
| 186 |
head_idx,
|
| 187 |
m_block,
|
| 188 |
kv_producer_state,
|
| 189 |
-
load_Q,
|
| 190 |
load_K,
|
| 191 |
load_V,
|
| 192 |
pipeline_k,
|
| 193 |
pipeline_v,
|
| 194 |
-
use_tma_q: cutlass.Constexpr,
|
| 195 |
-
tma_q_bytes: cutlass.Constexpr,
|
| 196 |
intra_wg_overlap: cutlass.Constexpr,
|
| 197 |
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
|
| 198 |
q_subtile_factor: cutlass.Constexpr[int] = 1,
|
| 199 |
):
|
| 200 |
"""Iterate over the mask and full block lists for a single tile.
|
| 201 |
|
|
|
|
|
|
|
| 202 |
The masked (partial) list may leave the last V load pending when intra-warp-group
|
| 203 |
overlap is enabled. The first full block must consume that pending V while
|
| 204 |
issuing its own K load on the next pipeline stage.
|
|
@@ -230,20 +206,16 @@ def produce_block_sparse_loads(
|
|
| 230 |
full_empty = curr_full_block_cnt == 0
|
| 231 |
|
| 232 |
if mask_empty:
|
| 233 |
-
# No masked blocks: the full list owns the initial
|
| 234 |
kv_producer_state = load_block_list(
|
| 235 |
curr_full_block_idx,
|
| 236 |
curr_full_block_cnt,
|
| 237 |
-
load_q_with_first=True,
|
| 238 |
first_block_preloaded=False,
|
| 239 |
kv_producer_state=kv_producer_state,
|
| 240 |
-
load_Q=load_Q,
|
| 241 |
load_K=load_K,
|
| 242 |
load_V=load_V,
|
| 243 |
pipeline_k=pipeline_k,
|
| 244 |
pipeline_v=pipeline_v,
|
| 245 |
-
use_tma_q=use_tma_q,
|
| 246 |
-
tma_q_bytes=tma_q_bytes,
|
| 247 |
intra_wg_overlap=intra_wg_overlap,
|
| 248 |
)
|
| 249 |
|
|
@@ -256,21 +228,16 @@ def produce_block_sparse_loads(
|
|
| 256 |
kv_producer_state,
|
| 257 |
)
|
| 258 |
else:
|
| 259 |
-
# Masked blocks present
|
| 260 |
-
# start immediately. When overlap is disabled this fully drains the list.
|
| 261 |
kv_producer_state = load_block_list(
|
| 262 |
curr_mask_block_idx,
|
| 263 |
curr_mask_block_cnt,
|
| 264 |
-
load_q_with_first=True,
|
| 265 |
first_block_preloaded=False,
|
| 266 |
kv_producer_state=kv_producer_state,
|
| 267 |
-
load_Q=load_Q,
|
| 268 |
load_K=load_K,
|
| 269 |
load_V=load_V,
|
| 270 |
pipeline_k=pipeline_k,
|
| 271 |
pipeline_v=pipeline_v,
|
| 272 |
-
use_tma_q=use_tma_q,
|
| 273 |
-
tma_q_bytes=tma_q_bytes,
|
| 274 |
intra_wg_overlap=intra_wg_overlap,
|
| 275 |
)
|
| 276 |
|
|
@@ -299,16 +266,12 @@ def produce_block_sparse_loads(
|
|
| 299 |
kv_producer_state = load_block_list(
|
| 300 |
curr_full_block_idx,
|
| 301 |
curr_full_block_cnt,
|
| 302 |
-
load_q_with_first=False,
|
| 303 |
first_block_preloaded=True,
|
| 304 |
kv_producer_state=kv_producer_state,
|
| 305 |
-
load_Q=load_Q,
|
| 306 |
load_K=load_K,
|
| 307 |
load_V=load_V,
|
| 308 |
pipeline_k=pipeline_k,
|
| 309 |
pipeline_v=pipeline_v,
|
| 310 |
-
use_tma_q=use_tma_q,
|
| 311 |
-
tma_q_bytes=tma_q_bytes,
|
| 312 |
intra_wg_overlap=intra_wg_overlap,
|
| 313 |
)
|
| 314 |
|
|
@@ -320,21 +283,16 @@ def produce_block_sparse_loads(
|
|
| 320 |
kv_producer_state,
|
| 321 |
)
|
| 322 |
else:
|
| 323 |
-
# Non-overlap path with both lists: run the full list normally
|
| 324 |
-
# reload because the masked list already issued it).
|
| 325 |
kv_producer_state = load_block_list(
|
| 326 |
curr_full_block_idx,
|
| 327 |
curr_full_block_cnt,
|
| 328 |
-
load_q_with_first=False,
|
| 329 |
first_block_preloaded=False,
|
| 330 |
kv_producer_state=kv_producer_state,
|
| 331 |
-
load_Q=load_Q,
|
| 332 |
load_K=load_K,
|
| 333 |
load_V=load_V,
|
| 334 |
pipeline_k=pipeline_k,
|
| 335 |
pipeline_v=pipeline_v,
|
| 336 |
-
use_tma_q=use_tma_q,
|
| 337 |
-
tma_q_bytes=tma_q_bytes,
|
| 338 |
intra_wg_overlap=intra_wg_overlap,
|
| 339 |
)
|
| 340 |
|
|
@@ -1390,18 +1348,18 @@ def _store_one_dQaccum_sm90(
|
|
| 1390 |
m_block,
|
| 1391 |
sdQaccum: cute.Tensor,
|
| 1392 |
gdQaccum: cute.Tensor,
|
| 1393 |
-
|
| 1394 |
num_threads_per_warp_group: cutlass.Constexpr,
|
| 1395 |
tma_copy_bytes_dQ,
|
| 1396 |
):
|
| 1397 |
"""Store dQaccum for a single m_block."""
|
| 1398 |
-
for warp_group_idx in cutlass.range_constexpr(
|
| 1399 |
-
cute.arch.cp_async_bulk_wait_group(
|
| 1400 |
cute.arch.barrier_arrive(
|
| 1401 |
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
|
| 1402 |
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
| 1403 |
)
|
| 1404 |
-
for warp_group_idx in cutlass.range_constexpr(
|
| 1405 |
cute.arch.barrier(
|
| 1406 |
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1407 |
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
|
@@ -1409,7 +1367,7 @@ def _store_one_dQaccum_sm90(
|
|
| 1409 |
with cute.arch.elect_one():
|
| 1410 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1411 |
sdQaccum[None, warp_group_idx].iterator,
|
| 1412 |
-
gdQaccum[None, warp_group_idx, m_block].iterator,
|
| 1413 |
tma_copy_bytes_dQ,
|
| 1414 |
)
|
| 1415 |
cute.arch.cp_async_bulk_commit_group()
|
|
@@ -1425,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
|
|
| 1425 |
gdQaccum: cute.Tensor,
|
| 1426 |
subtile_factor: cutlass.Constexpr,
|
| 1427 |
m_block_max: int,
|
| 1428 |
-
|
| 1429 |
num_threads_per_warp_group: cutlass.Constexpr,
|
| 1430 |
tma_copy_bytes_dQ,
|
| 1431 |
):
|
|
@@ -1454,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
|
|
| 1454 |
m_block,
|
| 1455 |
sdQaccum,
|
| 1456 |
gdQaccum,
|
| 1457 |
-
|
| 1458 |
num_threads_per_warp_group,
|
| 1459 |
tma_copy_bytes_dQ,
|
| 1460 |
)
|
|
@@ -1470,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
|
|
| 1470 |
m_block,
|
| 1471 |
sdQaccum,
|
| 1472 |
gdQaccum,
|
| 1473 |
-
|
| 1474 |
num_threads_per_warp_group,
|
| 1475 |
tma_copy_bytes_dQ,
|
| 1476 |
)
|
|
|
|
| 72 |
def load_block_list(
|
| 73 |
block_indices: cute.Tensor,
|
| 74 |
block_count,
|
|
|
|
| 75 |
first_block_preloaded: cutlass.Constexpr,
|
| 76 |
kv_producer_state,
|
|
|
|
| 77 |
load_K,
|
| 78 |
load_V,
|
| 79 |
pipeline_k,
|
| 80 |
pipeline_v,
|
|
|
|
|
|
|
| 81 |
intra_wg_overlap: cutlass.Constexpr,
|
| 82 |
):
|
| 83 |
+
"""Iterate over the sparse blocks and load K, V into the pipeline.
|
| 84 |
+
For the intra_wg_overlap case, we overlap the loads of K and V. And this
|
| 85 |
means we need to pipeline the last V load from the partial block case,
|
| 86 |
with the loads for the full blocks. Set first_block_preloaded when the
|
| 87 |
caller has already issued the first K load for the list.
|
| 88 |
|
| 89 |
+
Q is loaded separately on its own mbarrier before this function is called.
|
| 90 |
+
|
| 91 |
Note:
|
| 92 |
we iterate along the block_n indices in reverse.
|
| 93 |
|
|
|
|
| 97 |
"""
|
| 98 |
if block_count > 0:
|
| 99 |
if const_expr(not intra_wg_overlap):
|
| 100 |
+
for offset in cutlass.range(block_count):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
n_block = block_indices[block_count - 1 - offset]
|
| 102 |
pipeline_k.producer_acquire(kv_producer_state)
|
| 103 |
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
|
|
|
| 107 |
else:
|
| 108 |
n_block_first = block_indices[block_count - 1]
|
| 109 |
if const_expr(not first_block_preloaded):
|
| 110 |
+
pipeline_k.producer_acquire(kv_producer_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
load_K(src_idx=n_block_first, producer_state=kv_producer_state)
|
| 112 |
|
| 113 |
for idx in cutlass.range(block_count - 1, unroll=1):
|
|
|
|
| 163 |
head_idx,
|
| 164 |
m_block,
|
| 165 |
kv_producer_state,
|
|
|
|
| 166 |
load_K,
|
| 167 |
load_V,
|
| 168 |
pipeline_k,
|
| 169 |
pipeline_v,
|
|
|
|
|
|
|
| 170 |
intra_wg_overlap: cutlass.Constexpr,
|
| 171 |
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
|
| 172 |
q_subtile_factor: cutlass.Constexpr[int] = 1,
|
| 173 |
):
|
| 174 |
"""Iterate over the mask and full block lists for a single tile.
|
| 175 |
|
| 176 |
+
Q is loaded separately on its own mbarrier before this function is called.
|
| 177 |
+
|
| 178 |
The masked (partial) list may leave the last V load pending when intra-warp-group
|
| 179 |
overlap is enabled. The first full block must consume that pending V while
|
| 180 |
issuing its own K load on the next pipeline stage.
|
|
|
|
| 206 |
full_empty = curr_full_block_cnt == 0
|
| 207 |
|
| 208 |
if mask_empty:
|
| 209 |
+
# No masked blocks: the full list owns the initial K load.
|
| 210 |
kv_producer_state = load_block_list(
|
| 211 |
curr_full_block_idx,
|
| 212 |
curr_full_block_cnt,
|
|
|
|
| 213 |
first_block_preloaded=False,
|
| 214 |
kv_producer_state=kv_producer_state,
|
|
|
|
| 215 |
load_K=load_K,
|
| 216 |
load_V=load_V,
|
| 217 |
pipeline_k=pipeline_k,
|
| 218 |
pipeline_v=pipeline_v,
|
|
|
|
|
|
|
| 219 |
intra_wg_overlap=intra_wg_overlap,
|
| 220 |
)
|
| 221 |
|
|
|
|
| 228 |
kv_producer_state,
|
| 229 |
)
|
| 230 |
else:
|
| 231 |
+
# Masked blocks present. When overlap is disabled this fully drains the list.
|
|
|
|
| 232 |
kv_producer_state = load_block_list(
|
| 233 |
curr_mask_block_idx,
|
| 234 |
curr_mask_block_cnt,
|
|
|
|
| 235 |
first_block_preloaded=False,
|
| 236 |
kv_producer_state=kv_producer_state,
|
|
|
|
| 237 |
load_K=load_K,
|
| 238 |
load_V=load_V,
|
| 239 |
pipeline_k=pipeline_k,
|
| 240 |
pipeline_v=pipeline_v,
|
|
|
|
|
|
|
| 241 |
intra_wg_overlap=intra_wg_overlap,
|
| 242 |
)
|
| 243 |
|
|
|
|
| 266 |
kv_producer_state = load_block_list(
|
| 267 |
curr_full_block_idx,
|
| 268 |
curr_full_block_cnt,
|
|
|
|
| 269 |
first_block_preloaded=True,
|
| 270 |
kv_producer_state=kv_producer_state,
|
|
|
|
| 271 |
load_K=load_K,
|
| 272 |
load_V=load_V,
|
| 273 |
pipeline_k=pipeline_k,
|
| 274 |
pipeline_v=pipeline_v,
|
|
|
|
|
|
|
| 275 |
intra_wg_overlap=intra_wg_overlap,
|
| 276 |
)
|
| 277 |
|
|
|
|
| 283 |
kv_producer_state,
|
| 284 |
)
|
| 285 |
else:
|
| 286 |
+
# Non-overlap path with both lists: run the full list normally.
|
|
|
|
| 287 |
kv_producer_state = load_block_list(
|
| 288 |
curr_full_block_idx,
|
| 289 |
curr_full_block_cnt,
|
|
|
|
| 290 |
first_block_preloaded=False,
|
| 291 |
kv_producer_state=kv_producer_state,
|
|
|
|
| 292 |
load_K=load_K,
|
| 293 |
load_V=load_V,
|
| 294 |
pipeline_k=pipeline_k,
|
| 295 |
pipeline_v=pipeline_v,
|
|
|
|
|
|
|
| 296 |
intra_wg_overlap=intra_wg_overlap,
|
| 297 |
)
|
| 298 |
|
|
|
|
| 1348 |
m_block,
|
| 1349 |
sdQaccum: cute.Tensor,
|
| 1350 |
gdQaccum: cute.Tensor,
|
| 1351 |
+
num_dQ_warp_groups: cutlass.Constexpr,
|
| 1352 |
num_threads_per_warp_group: cutlass.Constexpr,
|
| 1353 |
tma_copy_bytes_dQ,
|
| 1354 |
):
|
| 1355 |
"""Store dQaccum for a single m_block."""
|
| 1356 |
+
for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
|
| 1357 |
+
cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True)
|
| 1358 |
cute.arch.barrier_arrive(
|
| 1359 |
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
|
| 1360 |
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
| 1361 |
)
|
| 1362 |
+
for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
|
| 1363 |
cute.arch.barrier(
|
| 1364 |
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1365 |
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
|
|
|
| 1367 |
with cute.arch.elect_one():
|
| 1368 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1369 |
sdQaccum[None, warp_group_idx].iterator,
|
| 1370 |
+
gdQaccum[(None, warp_group_idx), m_block].iterator,
|
| 1371 |
tma_copy_bytes_dQ,
|
| 1372 |
)
|
| 1373 |
cute.arch.cp_async_bulk_commit_group()
|
|
|
|
| 1383 |
gdQaccum: cute.Tensor,
|
| 1384 |
subtile_factor: cutlass.Constexpr,
|
| 1385 |
m_block_max: int,
|
| 1386 |
+
num_dQ_warp_groups: cutlass.Constexpr,
|
| 1387 |
num_threads_per_warp_group: cutlass.Constexpr,
|
| 1388 |
tma_copy_bytes_dQ,
|
| 1389 |
):
|
|
|
|
| 1412 |
m_block,
|
| 1413 |
sdQaccum,
|
| 1414 |
gdQaccum,
|
| 1415 |
+
num_dQ_warp_groups,
|
| 1416 |
num_threads_per_warp_group,
|
| 1417 |
tma_copy_bytes_dQ,
|
| 1418 |
)
|
|
|
|
| 1428 |
m_block,
|
| 1429 |
sdQaccum,
|
| 1430 |
gdQaccum,
|
| 1431 |
+
num_dQ_warp_groups,
|
| 1432 |
num_threads_per_warp_group,
|
| 1433 |
tma_copy_bytes_dQ,
|
| 1434 |
)
|
build/torch-cuda/block_sparsity.py
CHANGED
|
@@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple):
|
|
| 34 |
block_size: tuple[int, int] | None = None
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def _expand_sparsity_tensor(
|
| 38 |
tensor: torch.Tensor,
|
| 39 |
expected_shape: Tuple[int, ...],
|
|
@@ -81,6 +98,12 @@ def _check_and_expand_block(
|
|
| 81 |
expanded_cnt = _expand_sparsity_tensor(
|
| 82 |
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
|
| 83 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
expanded_idx = _expand_sparsity_tensor(
|
| 85 |
idx, expected_index_shape, f"{name}_block_idx", context, hint
|
| 86 |
)
|
|
@@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes(
|
|
| 140 |
num_m_blocks = tensors.mask_block_idx.shape[2]
|
| 141 |
|
| 142 |
if sparse_block_size_q is None:
|
| 143 |
-
|
| 144 |
-
if
|
| 145 |
-
max_block_size = seqlen_q
|
| 146 |
-
else:
|
| 147 |
-
max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
|
| 148 |
-
if max_block_size != min_block_size and base_m_block != 1:
|
| 149 |
raise ValueError(
|
| 150 |
f"Block sparse tensors{context} require explicit sparse_block_size[0] "
|
| 151 |
f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
|
| 152 |
)
|
| 153 |
-
sparse_block_size_q
|
|
|
|
| 154 |
|
| 155 |
if sparse_block_size_q % base_m_block != 0:
|
| 156 |
raise ValueError(
|
|
@@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes(
|
|
| 186 |
raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
|
| 187 |
if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
|
| 188 |
raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
|
| 189 |
-
|
|
|
|
|
|
|
| 190 |
raise ValueError(
|
| 191 |
-
f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
|
| 192 |
)
|
| 193 |
if expected_m_blocks != num_m_blocks:
|
| 194 |
raise ValueError(
|
|
@@ -314,7 +336,7 @@ def normalize_block_sparse_config(
|
|
| 314 |
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
|
| 315 |
m_block_size, n_block_size = block_size
|
| 316 |
if tensors.block_size is None:
|
| 317 |
-
sparse_block_size_q, sparse_block_size_kv =
|
| 318 |
else:
|
| 319 |
sparse_block_size_q, sparse_block_size_kv = tensors.block_size
|
| 320 |
if sparse_block_size_kv != n_block_size:
|
|
@@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors(
|
|
| 401 |
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
|
| 402 |
if not is_block_sparsity_enabled(tensors):
|
| 403 |
return None
|
|
|
|
| 404 |
(
|
| 405 |
mask_block_cnt,
|
| 406 |
mask_block_idx,
|
|
|
|
| 34 |
block_size: tuple[int, int] | None = None
|
| 35 |
|
| 36 |
|
| 37 |
+
def get_sparse_q_block_size(
|
| 38 |
+
tensors: BlockSparseTensorsTorch | None,
|
| 39 |
+
seqlen_q: int,
|
| 40 |
+
) -> int | None:
|
| 41 |
+
"""Return the Q sparse block size, or None when sparsity is unset or ambiguous."""
|
| 42 |
+
if tensors is None:
|
| 43 |
+
return None
|
| 44 |
+
if tensors.block_size is not None:
|
| 45 |
+
return tensors.block_size[0]
|
| 46 |
+
num_m_blocks = tensors.mask_block_idx.shape[2]
|
| 47 |
+
min_block_size = ceildiv(seqlen_q, num_m_blocks)
|
| 48 |
+
max_block_size = seqlen_q if num_m_blocks == 1 else (seqlen_q - 1) // (num_m_blocks - 1)
|
| 49 |
+
if min_block_size != max_block_size:
|
| 50 |
+
return None
|
| 51 |
+
return min_block_size
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def _expand_sparsity_tensor(
|
| 55 |
tensor: torch.Tensor,
|
| 56 |
expected_shape: Tuple[int, ...],
|
|
|
|
| 98 |
expanded_cnt = _expand_sparsity_tensor(
|
| 99 |
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
|
| 100 |
)
|
| 101 |
+
# [Note] Allow Compact block sparse indices
|
| 102 |
+
# Allow the last dimension (n_blocks) of idx to be <= expected, since
|
| 103 |
+
# FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
|
| 104 |
+
# index tensors that avoid O(N^2) memory at long sequence lengths.
|
| 105 |
+
if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
|
| 106 |
+
expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
|
| 107 |
expanded_idx = _expand_sparsity_tensor(
|
| 108 |
idx, expected_index_shape, f"{name}_block_idx", context, hint
|
| 109 |
)
|
|
|
|
| 163 |
num_m_blocks = tensors.mask_block_idx.shape[2]
|
| 164 |
|
| 165 |
if sparse_block_size_q is None:
|
| 166 |
+
sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q)
|
| 167 |
+
if sparse_block_size_q is None and base_m_block != 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
raise ValueError(
|
| 169 |
f"Block sparse tensors{context} require explicit sparse_block_size[0] "
|
| 170 |
f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
|
| 171 |
)
|
| 172 |
+
if sparse_block_size_q is None:
|
| 173 |
+
sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks)
|
| 174 |
|
| 175 |
if sparse_block_size_q % base_m_block != 0:
|
| 176 |
raise ValueError(
|
|
|
|
| 206 |
raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
|
| 207 |
if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
|
| 208 |
raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
|
| 209 |
+
# [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
|
| 210 |
+
# per query tile, so idx.shape[3] can be <= expected_n_blocks.
|
| 211 |
+
if mask_block_idx.shape[3] > expected_n_blocks:
|
| 212 |
raise ValueError(
|
| 213 |
+
f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
|
| 214 |
)
|
| 215 |
if expected_m_blocks != num_m_blocks:
|
| 216 |
raise ValueError(
|
|
|
|
| 336 |
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
|
| 337 |
m_block_size, n_block_size = block_size
|
| 338 |
if tensors.block_size is None:
|
| 339 |
+
sparse_block_size_q, sparse_block_size_kv = None, n_block_size
|
| 340 |
else:
|
| 341 |
sparse_block_size_q, sparse_block_size_kv = tensors.block_size
|
| 342 |
if sparse_block_size_kv != n_block_size:
|
|
|
|
| 423 |
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
|
| 424 |
if not is_block_sparsity_enabled(tensors):
|
| 425 |
return None
|
| 426 |
+
|
| 427 |
(
|
| 428 |
mask_block_cnt,
|
| 429 |
mask_block_idx,
|
build/torch-cuda/cache_utils.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
# Manage Ahead-of-Time (AOT) compiled kernels
|
| 2 |
import fcntl
|
| 3 |
import hashlib
|
| 4 |
-
import logging
|
| 5 |
import os
|
| 6 |
import pickle
|
| 7 |
import sys
|
|
@@ -18,6 +17,7 @@ import cutlass
|
|
| 18 |
import cutlass.cute as cute
|
| 19 |
import tvm_ffi
|
| 20 |
from cutlass.cutlass_dsl import JitCompiledFunction
|
|
|
|
| 21 |
|
| 22 |
# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
|
| 23 |
# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
|
|
@@ -30,12 +30,6 @@ for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
|
|
| 30 |
CompileKeyType: TypeAlias = tuple[Hashable, ...]
|
| 31 |
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
|
| 32 |
|
| 33 |
-
logger = logging.getLogger(__name__)
|
| 34 |
-
_handler = logging.StreamHandler()
|
| 35 |
-
_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
| 36 |
-
logger.addHandler(_handler)
|
| 37 |
-
logger.setLevel(logging.DEBUG)
|
| 38 |
-
|
| 39 |
|
| 40 |
# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
|
| 41 |
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
|
|
@@ -222,13 +216,13 @@ class JITPersistentCache(JITCache):
|
|
| 222 |
label=sha256_hex,
|
| 223 |
):
|
| 224 |
if obj_path.exists():
|
| 225 |
-
|
| 226 |
m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
|
| 227 |
fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
|
| 228 |
JITCache.__setitem__(self, key, fn)
|
| 229 |
return True
|
| 230 |
else:
|
| 231 |
-
|
| 232 |
return False
|
| 233 |
|
| 234 |
def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
|
|
@@ -243,14 +237,14 @@ class JITPersistentCache(JITCache):
|
|
| 243 |
obj_path = self.cache_path / f"{sha256_hex}.o"
|
| 244 |
if obj_path.exists():
|
| 245 |
# Another process already exported.
|
| 246 |
-
|
| 247 |
return
|
| 248 |
-
|
| 249 |
fn.export_to_c(
|
| 250 |
object_file_path=str(obj_path),
|
| 251 |
function_name=self.EXPORT_FUNCTION_PREFIX,
|
| 252 |
)
|
| 253 |
-
|
| 254 |
|
| 255 |
def _key_to_hash(self, key: CompileKeyType) -> str:
|
| 256 |
return hashlib.sha256(pickle.dumps(key)).hexdigest()
|
|
@@ -262,7 +256,7 @@ class JITPersistentCache(JITCache):
|
|
| 262 |
"""
|
| 263 |
Not only clear the in-memory cache. Also purge persistent compilation cache.
|
| 264 |
"""
|
| 265 |
-
|
| 266 |
super().clear()
|
| 267 |
for child in self.cache_path.iterdir():
|
| 268 |
child.unlink()
|
|
@@ -281,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache:
|
|
| 281 |
path = get_cache_path() / _compute_source_fingerprint()
|
| 282 |
if name:
|
| 283 |
path = path / name
|
| 284 |
-
|
| 285 |
return JITPersistentCache(path)
|
| 286 |
else:
|
| 287 |
-
|
| 288 |
return JITCache()
|
|
|
|
| 1 |
# Manage Ahead-of-Time (AOT) compiled kernels
|
| 2 |
import fcntl
|
| 3 |
import hashlib
|
|
|
|
| 4 |
import os
|
| 5 |
import pickle
|
| 6 |
import sys
|
|
|
|
| 17 |
import cutlass.cute as cute
|
| 18 |
import tvm_ffi
|
| 19 |
from cutlass.cutlass_dsl import JitCompiledFunction
|
| 20 |
+
from .fa_logging import fa_log
|
| 21 |
|
| 22 |
# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
|
| 23 |
# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
|
|
|
|
| 30 |
CompileKeyType: TypeAlias = tuple[Hashable, ...]
|
| 31 |
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
|
| 35 |
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
|
|
|
|
| 216 |
label=sha256_hex,
|
| 217 |
):
|
| 218 |
if obj_path.exists():
|
| 219 |
+
fa_log(1, f"Loading compiled function from disk: {obj_path}")
|
| 220 |
m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
|
| 221 |
fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
|
| 222 |
JITCache.__setitem__(self, key, fn)
|
| 223 |
return True
|
| 224 |
else:
|
| 225 |
+
fa_log(1, f"Cache miss on disk for key hash {sha256_hex}")
|
| 226 |
return False
|
| 227 |
|
| 228 |
def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
|
|
|
|
| 237 |
obj_path = self.cache_path / f"{sha256_hex}.o"
|
| 238 |
if obj_path.exists():
|
| 239 |
# Another process already exported.
|
| 240 |
+
fa_log(1, f"Skipping export, already on disk: {obj_path}")
|
| 241 |
return
|
| 242 |
+
fa_log(1, f"Exporting compiled function to disk: {obj_path}")
|
| 243 |
fn.export_to_c(
|
| 244 |
object_file_path=str(obj_path),
|
| 245 |
function_name=self.EXPORT_FUNCTION_PREFIX,
|
| 246 |
)
|
| 247 |
+
fa_log(1, f"Successfully exported compiled function to disk: {obj_path}")
|
| 248 |
|
| 249 |
def _key_to_hash(self, key: CompileKeyType) -> str:
|
| 250 |
return hashlib.sha256(pickle.dumps(key)).hexdigest()
|
|
|
|
| 256 |
"""
|
| 257 |
Not only clear the in-memory cache. Also purge persistent compilation cache.
|
| 258 |
"""
|
| 259 |
+
fa_log(1, f"Clearing persistent cache at {self.cache_path}")
|
| 260 |
super().clear()
|
| 261 |
for child in self.cache_path.iterdir():
|
| 262 |
child.unlink()
|
|
|
|
| 275 |
path = get_cache_path() / _compute_source_fingerprint()
|
| 276 |
if name:
|
| 277 |
path = path / name
|
| 278 |
+
fa_log(1, f"Creating persistent JIT cache at {path}")
|
| 279 |
return JITPersistentCache(path)
|
| 280 |
else:
|
| 281 |
+
fa_log(1, "Persistent cache disabled, using in-memory JIT cache")
|
| 282 |
return JITCache()
|
build/torch-cuda/cute_dsl_utils.py
CHANGED
|
@@ -4,7 +4,6 @@ import os
|
|
| 4 |
import pathlib
|
| 5 |
from typing import Tuple
|
| 6 |
from functools import partial, lru_cache
|
| 7 |
-
from dataclasses import dataclass, fields
|
| 8 |
|
| 9 |
import torch
|
| 10 |
|
|
@@ -15,7 +14,6 @@ except ImportError:
|
|
| 15 |
|
| 16 |
import cutlass
|
| 17 |
import cutlass.cute as cute
|
| 18 |
-
from cutlass.base_dsl.typing import JitArgument
|
| 19 |
from cutlass.cutlass_dsl import NumericMeta
|
| 20 |
from cutlass.cute.runtime import from_dlpack
|
| 21 |
|
|
@@ -43,42 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
|
|
| 43 |
return torch.cuda.get_device_capability(device)
|
| 44 |
|
| 45 |
|
| 46 |
-
@dataclass
|
| 47 |
-
class ArgumentsBase(JitArgument):
|
| 48 |
-
def __c_pointers__(self):
|
| 49 |
-
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 50 |
-
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 51 |
-
c_ptrs = []
|
| 52 |
-
for obj in non_constexpr_fields:
|
| 53 |
-
if hasattr(obj, "__c_pointers__"):
|
| 54 |
-
c_ptrs.extend(obj.__c_pointers__())
|
| 55 |
-
return c_ptrs
|
| 56 |
-
|
| 57 |
-
def __get_mlir_types__(self):
|
| 58 |
-
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 59 |
-
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 60 |
-
types, self._values_pos = [], []
|
| 61 |
-
for obj in non_constexpr_fields:
|
| 62 |
-
if hasattr(obj, "__get_mlir_types__"):
|
| 63 |
-
obj_types = obj.__get_mlir_types__()
|
| 64 |
-
types.extend(obj_types)
|
| 65 |
-
self._values_pos.append(len(obj_types))
|
| 66 |
-
else:
|
| 67 |
-
self._values_pos.append(0)
|
| 68 |
-
return types
|
| 69 |
-
|
| 70 |
-
def __new_from_mlir_values__(self, values):
|
| 71 |
-
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 72 |
-
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 73 |
-
non_constexpr_fields = {
|
| 74 |
-
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 75 |
-
}
|
| 76 |
-
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 77 |
-
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 78 |
-
values = values[n_items:]
|
| 79 |
-
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
def load_cubin_module_data_patched(cubin_data, filepath):
|
| 83 |
pathlib.Path(filepath).write_bytes(cubin_data)
|
| 84 |
return load_cubin_module_data_og(cubin_data)
|
|
|
|
| 4 |
import pathlib
|
| 5 |
from typing import Tuple
|
| 6 |
from functools import partial, lru_cache
|
|
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
|
|
|
| 14 |
|
| 15 |
import cutlass
|
| 16 |
import cutlass.cute as cute
|
|
|
|
| 17 |
from cutlass.cutlass_dsl import NumericMeta
|
| 18 |
from cutlass.cute.runtime import from_dlpack
|
| 19 |
|
|
|
|
| 41 |
return torch.cuda.get_device_capability(device)
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def load_cubin_module_data_patched(cubin_data, filepath):
|
| 45 |
pathlib.Path(filepath).write_bytes(cubin_data)
|
| 46 |
return load_cubin_module_data_og(cubin_data)
|
build/torch-cuda/fa_logging.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
"""Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var.
|
| 4 |
+
|
| 5 |
+
Host-side messages go through Python ``logging`` (logger name ``flash_attn``).
|
| 6 |
+
A default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1``
|
| 7 |
+
so that standalone scripts get output without extra setup; applications that
|
| 8 |
+
configure their own logging can remove or replace it via the standard API.
|
| 9 |
+
|
| 10 |
+
FA_LOG_LEVEL mapping::
|
| 11 |
+
|
| 12 |
+
0 off nothing logged
|
| 13 |
+
1 host host-side summaries only (no kernel printf)
|
| 14 |
+
2 kernel host + curated kernel traces
|
| 15 |
+
3 max host + all kernel traces (noisy, perf hit)
|
| 16 |
+
|
| 17 |
+
Set via environment variable::
|
| 18 |
+
|
| 19 |
+
FA_LOG_LEVEL=1 python train.py
|
| 20 |
+
|
| 21 |
+
Device-side ``cute.printf`` calls are compile-time eliminated via
|
| 22 |
+
``cutlass.const_expr`` when the log level is below the callsite threshold,
|
| 23 |
+
so there is zero performance cost when device logging is off.
|
| 24 |
+
Changing the log level after kernel compilation requires a recompile
|
| 25 |
+
(the level participates in the forward compile key).
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
|
| 32 |
+
import cutlass.cute as cute
|
| 33 |
+
from cutlass import const_expr
|
| 34 |
+
|
| 35 |
+
_LOG_LEVEL_NAMES = {"off": 0, "host": 1, "kernel": 2, "max": 3}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _parse_log_level(raw: str) -> int:
|
| 39 |
+
if raw in _LOG_LEVEL_NAMES:
|
| 40 |
+
return _LOG_LEVEL_NAMES[raw]
|
| 41 |
+
try:
|
| 42 |
+
level = int(raw)
|
| 43 |
+
except ValueError:
|
| 44 |
+
return 0
|
| 45 |
+
return max(0, min(level, 3))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
_fa_log_level: int = _parse_log_level(os.environ.get("FA_LOG_LEVEL", "0"))
|
| 49 |
+
|
| 50 |
+
_logger = logging.getLogger("flash_attn")
|
| 51 |
+
_logger.addHandler(logging.NullHandler())
|
| 52 |
+
_default_handler: logging.Handler | None = None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _configure_default_handler() -> None:
|
| 56 |
+
global _default_handler
|
| 57 |
+
if _fa_log_level >= 1:
|
| 58 |
+
if _default_handler is None:
|
| 59 |
+
_default_handler = logging.StreamHandler(sys.stdout)
|
| 60 |
+
_default_handler.setFormatter(logging.Formatter("[FA] %(message)s"))
|
| 61 |
+
_logger.addHandler(_default_handler)
|
| 62 |
+
_logger.setLevel(logging.DEBUG)
|
| 63 |
+
else:
|
| 64 |
+
if _default_handler is not None:
|
| 65 |
+
_logger.removeHandler(_default_handler)
|
| 66 |
+
_default_handler = None
|
| 67 |
+
_logger.setLevel(logging.WARNING)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
_configure_default_handler()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_fa_log_level() -> int:
|
| 74 |
+
return _fa_log_level
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def set_fa_log_level(level: int | str) -> None:
|
| 78 |
+
"""Set the FA log level programmatically.
|
| 79 |
+
|
| 80 |
+
Host logging takes effect immediately. Device logging changes only
|
| 81 |
+
affect kernels compiled after this call (new compile-key selection).
|
| 82 |
+
"""
|
| 83 |
+
global _fa_log_level
|
| 84 |
+
if isinstance(level, str):
|
| 85 |
+
level = _parse_log_level(level)
|
| 86 |
+
_fa_log_level = max(0, min(int(level), 3))
|
| 87 |
+
_configure_default_handler()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def fa_log(level: int, msg: str):
|
| 91 |
+
if _fa_log_level >= level:
|
| 92 |
+
_logger.info(msg)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def fa_printf(level: int, fmt, *args):
|
| 96 |
+
if const_expr(_fa_log_level >= level):
|
| 97 |
+
cute.printf(fmt, *args)
|
build/torch-cuda/flash_bwd.py
CHANGED
|
@@ -22,6 +22,7 @@ from .mask import AttentionMask
|
|
| 22 |
from .seqlen_info import SeqlenInfoQK
|
| 23 |
from .quack.cute_dsl_utils import ParamsBase
|
| 24 |
from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class FlashAttentionBackwardSm80:
|
|
@@ -372,7 +373,6 @@ class FlashAttentionBackwardSm80:
|
|
| 372 |
mdK: cute.Tensor,
|
| 373 |
mdV: cute.Tensor,
|
| 374 |
softmax_scale: cutlass.Float32,
|
| 375 |
-
stream: cuda.CUstream,
|
| 376 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 377 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 378 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
@@ -381,8 +381,16 @@ class FlashAttentionBackwardSm80:
|
|
| 381 |
window_size_left: Int32 | int | None = None,
|
| 382 |
window_size_right: Int32 | int | None = None,
|
| 383 |
mdQ_semaphore: Optional[cute.Tensor] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
):
|
| 385 |
-
assert mdQ_semaphore is None
|
|
|
|
|
|
|
| 386 |
# Get the data type and check if it is fp16 or bf16
|
| 387 |
self._check_type(*(t.element_type if t is not None else None
|
| 388 |
for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
|
|
@@ -512,7 +520,17 @@ class FlashAttentionBackwardSm80:
|
|
| 512 |
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 513 |
|
| 514 |
if work_tile.is_valid_tile:
|
| 515 |
-
seqlen = SeqlenInfoQK.create(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
|
| 518 |
m_block_min = 0
|
|
@@ -538,7 +556,7 @@ class FlashAttentionBackwardSm80:
|
|
| 538 |
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
|
| 539 |
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
| 540 |
else:
|
| 541 |
-
padded_offset_q = seqlen.
|
| 542 |
mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
|
| 543 |
mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
|
| 544 |
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
|
|
@@ -794,9 +812,10 @@ class FlashAttentionBackwardSm80:
|
|
| 794 |
# Mainloop
|
| 795 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 796 |
# Start processing of the first n-block.
|
| 797 |
-
mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen
|
| 798 |
mask_fn = partial(
|
| 799 |
mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
|
|
|
|
| 800 |
mask_seqlen=True, mask_causal=self.is_causal
|
| 801 |
)
|
| 802 |
smem_pipe_read_q = cutlass.Int32(0)
|
|
@@ -968,7 +987,7 @@ class FlashAttentionBackwardSm80:
|
|
| 968 |
|
| 969 |
# MMA dK
|
| 970 |
if cutlass.const_expr(self.Mma_dKV_is_RS):
|
| 971 |
-
|
| 972 |
else:
|
| 973 |
tdKrdS = mma_params.tdKrdS
|
| 974 |
sm80_utils.gemm(
|
|
|
|
| 22 |
from .seqlen_info import SeqlenInfoQK
|
| 23 |
from .quack.cute_dsl_utils import ParamsBase
|
| 24 |
from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
|
| 25 |
+
from .block_sparsity import BlockSparseTensors
|
| 26 |
|
| 27 |
|
| 28 |
class FlashAttentionBackwardSm80:
|
|
|
|
| 373 |
mdK: cute.Tensor,
|
| 374 |
mdV: cute.Tensor,
|
| 375 |
softmax_scale: cutlass.Float32,
|
|
|
|
| 376 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 377 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 378 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
|
|
| 381 |
window_size_left: Int32 | int | None = None,
|
| 382 |
window_size_right: Int32 | int | None = None,
|
| 383 |
mdQ_semaphore: Optional[cute.Tensor] = None,
|
| 384 |
+
mdK_semaphore: Optional[cute.Tensor] = None,
|
| 385 |
+
mdV_semaphore: Optional[cute.Tensor] = None,
|
| 386 |
+
aux_tensors: Optional[list] = None,
|
| 387 |
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 388 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 389 |
+
stream: cuda.CUstream = None,
|
| 390 |
):
|
| 391 |
+
assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
|
| 392 |
+
"determinism not supported yet for Sm80"
|
| 393 |
+
)
|
| 394 |
# Get the data type and check if it is fp16 or bf16
|
| 395 |
self._check_type(*(t.element_type if t is not None else None
|
| 396 |
for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
|
|
|
|
| 520 |
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 521 |
|
| 522 |
if work_tile.is_valid_tile:
|
| 523 |
+
seqlen = SeqlenInfoQK.create(
|
| 524 |
+
batch_idx,
|
| 525 |
+
mQ.shape[1],
|
| 526 |
+
mK.shape[1],
|
| 527 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 528 |
+
mCuSeqlensK=mCuSeqlensK,
|
| 529 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 530 |
+
mSeqUsedK=mSeqUsedK,
|
| 531 |
+
tile_m=self.m_block_size,
|
| 532 |
+
tile_n=self.n_block_size,
|
| 533 |
+
)
|
| 534 |
|
| 535 |
m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
|
| 536 |
m_block_min = 0
|
|
|
|
| 556 |
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
|
| 557 |
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
| 558 |
else:
|
| 559 |
+
padded_offset_q = seqlen.padded_offset_q
|
| 560 |
mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
|
| 561 |
mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
|
| 562 |
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
|
|
|
|
| 812 |
# Mainloop
|
| 813 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 814 |
# Start processing of the first n-block.
|
| 815 |
+
mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen)
|
| 816 |
mask_fn = partial(
|
| 817 |
mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
|
| 818 |
+
batch_idx=batch_idx, head_idx=head_idx,
|
| 819 |
mask_seqlen=True, mask_causal=self.is_causal
|
| 820 |
)
|
| 821 |
smem_pipe_read_q = cutlass.Int32(0)
|
|
|
|
| 987 |
|
| 988 |
# MMA dK
|
| 989 |
if cutlass.const_expr(self.Mma_dKV_is_RS):
|
| 990 |
+
tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)
|
| 991 |
else:
|
| 992 |
tdKrdS = mma_params.tdKrdS
|
| 993 |
sm80_utils.gemm(
|
build/torch-cuda/flash_bwd_postprocess.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
|
| 3 |
# from Cutlass C++ to Cute-DSL.
|
| 4 |
import math
|
| 5 |
-
from typing import Callable, Optional, Type
|
| 6 |
|
| 7 |
import cuda.bindings.driver as cuda
|
| 8 |
|
|
@@ -36,7 +36,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 36 |
self,
|
| 37 |
dtype: Type[cutlass.Numeric],
|
| 38 |
head_dim: int,
|
| 39 |
-
arch:
|
| 40 |
tile_m: int = 128,
|
| 41 |
num_threads: int = 256,
|
| 42 |
AtomLayoutMdQ: int = 1,
|
|
@@ -52,8 +52,8 @@ class FlashAttentionBackwardPostprocess:
|
|
| 52 |
"""
|
| 53 |
self.dtype = dtype
|
| 54 |
self.tile_m = tile_m
|
| 55 |
-
assert arch // 10 in [8, 9, 10, 11], (
|
| 56 |
-
"Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x) are supported"
|
| 57 |
)
|
| 58 |
self.arch = arch
|
| 59 |
# padding head_dim to a multiple of 32 as k_block_size
|
|
@@ -63,7 +63,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 63 |
self.num_threads = num_threads
|
| 64 |
self.AtomLayoutMdQ = AtomLayoutMdQ
|
| 65 |
self.dQ_swapAB = dQ_swapAB
|
| 66 |
-
self.use_2cta_instrs = use_2cta_instrs and arch ==
|
| 67 |
self.cluster_size = cluster_size
|
| 68 |
|
| 69 |
@staticmethod
|
|
@@ -89,7 +89,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 89 |
return True
|
| 90 |
|
| 91 |
def _get_tiled_mma(self):
|
| 92 |
-
if const_expr(self.arch
|
| 93 |
num_mma_warps = self.num_threads // 32
|
| 94 |
atom_layout_dQ = (
|
| 95 |
(self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
|
|
@@ -101,9 +101,9 @@ class FlashAttentionBackwardPostprocess:
|
|
| 101 |
atom_layout_dQ,
|
| 102 |
permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
|
| 103 |
)
|
| 104 |
-
elif const_expr(self.arch ==
|
| 105 |
-
|
| 106 |
-
atom_layout_dQ = (self.AtomLayoutMdQ,
|
| 107 |
tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
|
| 108 |
tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
|
| 109 |
self.dtype,
|
|
@@ -125,7 +125,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 125 |
cta_group,
|
| 126 |
(self.tile_m, self.tile_hdim),
|
| 127 |
)
|
| 128 |
-
if const_expr(self.arch in [
|
| 129 |
assert self.num_threads == tiled_mma.size
|
| 130 |
return tiled_mma
|
| 131 |
|
|
@@ -148,22 +148,22 @@ class FlashAttentionBackwardPostprocess:
|
|
| 148 |
cute.make_layout(self.num_threads),
|
| 149 |
cute.make_layout(async_copy_elems_accum),
|
| 150 |
)
|
| 151 |
-
num_s2r_copy_elems = 1 if const_expr(self.arch
|
| 152 |
-
if const_expr(self.arch
|
| 153 |
self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
|
| 154 |
Float32, self.num_threads, num_s2r_copy_elems
|
| 155 |
)
|
| 156 |
self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
|
| 157 |
-
elif const_expr(self.arch ==
|
| 158 |
num_threads_per_warp_group = 128
|
| 159 |
-
|
| 160 |
self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
|
| 161 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 162 |
-
cute.make_layout((num_threads_per_warp_group,
|
| 163 |
cute.make_layout(128 // Float32.width), # val_layout
|
| 164 |
)
|
| 165 |
self.sdQaccum_layout = cute.make_layout(
|
| 166 |
-
(self.tile_m * self.tile_hdim //
|
| 167 |
)
|
| 168 |
else:
|
| 169 |
self.dQ_reduce_ncol = 32
|
|
@@ -188,14 +188,18 @@ class FlashAttentionBackwardPostprocess:
|
|
| 188 |
# then setting kBlockKSmem to 32 will cause "Static shape_div failure".
|
| 189 |
# We want to treat it as 64 x 48, so kBlockKSmem should be 16.
|
| 190 |
mma_shape_n = self.tiled_mma.get_tile_size(1)
|
| 191 |
-
if const_expr(self.arch
|
| 192 |
sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
|
| 193 |
self.sdQ_layout = cute.tile_to_shape(
|
| 194 |
sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
|
| 195 |
)
|
| 196 |
-
elif const_expr(self.arch ==
|
|
|
|
| 197 |
self.sdQ_layout = sm90_utils.make_smem_layout(
|
| 198 |
-
self.dtype,
|
|
|
|
|
|
|
|
|
|
| 199 |
)
|
| 200 |
else:
|
| 201 |
# TODO: this is hard-coded for hdim 128
|
|
@@ -211,7 +215,8 @@ class FlashAttentionBackwardPostprocess:
|
|
| 211 |
scale: cutlass.Float32,
|
| 212 |
mCuSeqlensQ: Optional[cute.Tensor],
|
| 213 |
mSeqUsedQ: Optional[cute.Tensor],
|
| 214 |
-
stream:
|
|
|
|
| 215 |
):
|
| 216 |
# Get the data type and check if it is fp16 or bf16
|
| 217 |
if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
|
|
@@ -305,7 +310,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 305 |
smem = cutlass.utils.SmemAllocator()
|
| 306 |
sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
|
| 307 |
sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
|
| 308 |
-
if const_expr(self.arch in [
|
| 309 |
sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
|
| 310 |
else:
|
| 311 |
# extra stage dimension
|
|
@@ -343,10 +348,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 343 |
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
| 344 |
head_dim = mdQ.shape[3]
|
| 345 |
else:
|
| 346 |
-
|
| 347 |
-
padded_offset_q = seqlen.padded_offset_q
|
| 348 |
-
else:
|
| 349 |
-
padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
|
| 350 |
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
|
| 351 |
mdQaccum_cur = cute.domain_offset(
|
| 352 |
(padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
|
|
@@ -371,7 +373,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 371 |
seqlen_q = seqlen.seqlen_q
|
| 372 |
seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
|
| 373 |
|
| 374 |
-
if const_expr(self.arch ==
|
| 375 |
# 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
|
| 376 |
num_reduce_threads = self.num_threads
|
| 377 |
thr_mma_dsk = tiled_mma.get_slice(tidx)
|
|
@@ -502,7 +504,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 502 |
tile_shape = (self.tile_m, self.tile_hdim)
|
| 503 |
acc = None
|
| 504 |
tiled_copy_t2r = None
|
| 505 |
-
if const_expr(self.arch in [
|
| 506 |
acc_shape = tiled_mma.partition_shape_C(
|
| 507 |
tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
|
| 508 |
)
|
|
@@ -531,7 +533,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 531 |
|
| 532 |
# Step 3: Copy dQ from register to smem
|
| 533 |
cute.arch.barrier() # make sure all threads have finished loading dQaccum
|
| 534 |
-
if const_expr(self.arch in [
|
| 535 |
copy_atom_r2s_dQ = utils.get_smem_store_atom(
|
| 536 |
self.arch, self.dtype, transpose=self.dQ_swapAB
|
| 537 |
)
|
|
@@ -553,7 +555,7 @@ class FlashAttentionBackwardPostprocess:
|
|
| 553 |
)
|
| 554 |
thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
|
| 555 |
cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
|
| 556 |
-
if const_expr(self.arch in [
|
| 557 |
taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
|
| 558 |
else:
|
| 559 |
taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
|
|
|
|
| 2 |
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
|
| 3 |
# from Cutlass C++ to Cute-DSL.
|
| 4 |
import math
|
| 5 |
+
from typing import Callable, Optional, Type
|
| 6 |
|
| 7 |
import cuda.bindings.driver as cuda
|
| 8 |
|
|
|
|
| 36 |
self,
|
| 37 |
dtype: Type[cutlass.Numeric],
|
| 38 |
head_dim: int,
|
| 39 |
+
arch: int,
|
| 40 |
tile_m: int = 128,
|
| 41 |
num_threads: int = 256,
|
| 42 |
AtomLayoutMdQ: int = 1,
|
|
|
|
| 52 |
"""
|
| 53 |
self.dtype = dtype
|
| 54 |
self.tile_m = tile_m
|
| 55 |
+
assert arch // 10 in [8, 9, 10, 11, 12], (
|
| 56 |
+
"Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported"
|
| 57 |
)
|
| 58 |
self.arch = arch
|
| 59 |
# padding head_dim to a multiple of 32 as k_block_size
|
|
|
|
| 63 |
self.num_threads = num_threads
|
| 64 |
self.AtomLayoutMdQ = AtomLayoutMdQ
|
| 65 |
self.dQ_swapAB = dQ_swapAB
|
| 66 |
+
self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64
|
| 67 |
self.cluster_size = cluster_size
|
| 68 |
|
| 69 |
@staticmethod
|
|
|
|
| 89 |
return True
|
| 90 |
|
| 91 |
def _get_tiled_mma(self):
|
| 92 |
+
if const_expr(self.arch // 10 in [8, 12]):
|
| 93 |
num_mma_warps = self.num_threads // 32
|
| 94 |
atom_layout_dQ = (
|
| 95 |
(self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
|
|
|
|
| 101 |
atom_layout_dQ,
|
| 102 |
permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
|
| 103 |
)
|
| 104 |
+
elif const_expr(self.arch // 10 == 9):
|
| 105 |
+
num_wg_mma = self.num_threads // 128
|
| 106 |
+
atom_layout_dQ = (self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ)
|
| 107 |
tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
|
| 108 |
tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
|
| 109 |
self.dtype,
|
|
|
|
| 125 |
cta_group,
|
| 126 |
(self.tile_m, self.tile_hdim),
|
| 127 |
)
|
| 128 |
+
if const_expr(self.arch // 10 in [8, 9, 12]):
|
| 129 |
assert self.num_threads == tiled_mma.size
|
| 130 |
return tiled_mma
|
| 131 |
|
|
|
|
| 148 |
cute.make_layout(self.num_threads),
|
| 149 |
cute.make_layout(async_copy_elems_accum),
|
| 150 |
)
|
| 151 |
+
num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4
|
| 152 |
+
if const_expr(self.arch // 10 in [8, 12]):
|
| 153 |
self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
|
| 154 |
Float32, self.num_threads, num_s2r_copy_elems
|
| 155 |
)
|
| 156 |
self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
|
| 157 |
+
elif const_expr(self.arch // 10 == 9):
|
| 158 |
num_threads_per_warp_group = 128
|
| 159 |
+
num_wg_mma = self.num_threads // 128
|
| 160 |
self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
|
| 161 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 162 |
+
cute.make_layout((num_threads_per_warp_group, num_wg_mma)), # thr_layout
|
| 163 |
cute.make_layout(128 // Float32.width), # val_layout
|
| 164 |
)
|
| 165 |
self.sdQaccum_layout = cute.make_layout(
|
| 166 |
+
(self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma)
|
| 167 |
)
|
| 168 |
else:
|
| 169 |
self.dQ_reduce_ncol = 32
|
|
|
|
| 188 |
# then setting kBlockKSmem to 32 will cause "Static shape_div failure".
|
| 189 |
# We want to treat it as 64 x 48, so kBlockKSmem should be 16.
|
| 190 |
mma_shape_n = self.tiled_mma.get_tile_size(1)
|
| 191 |
+
if const_expr(self.arch // 10 in [8, 12]):
|
| 192 |
sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
|
| 193 |
self.sdQ_layout = cute.tile_to_shape(
|
| 194 |
sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
|
| 195 |
)
|
| 196 |
+
elif const_expr(self.arch // 10 == 9):
|
| 197 |
+
wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ
|
| 198 |
self.sdQ_layout = sm90_utils.make_smem_layout(
|
| 199 |
+
self.dtype,
|
| 200 |
+
LayoutEnum.ROW_MAJOR,
|
| 201 |
+
(self.tile_m, self.tile_hdim),
|
| 202 |
+
major_mode_size=self.tile_hdim // wg_d_dQ,
|
| 203 |
)
|
| 204 |
else:
|
| 205 |
# TODO: this is hard-coded for hdim 128
|
|
|
|
| 215 |
scale: cutlass.Float32,
|
| 216 |
mCuSeqlensQ: Optional[cute.Tensor],
|
| 217 |
mSeqUsedQ: Optional[cute.Tensor],
|
| 218 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 219 |
+
stream: cuda.CUstream = None,
|
| 220 |
):
|
| 221 |
# Get the data type and check if it is fp16 or bf16
|
| 222 |
if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
|
|
|
|
| 310 |
smem = cutlass.utils.SmemAllocator()
|
| 311 |
sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
|
| 312 |
sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
|
| 313 |
+
if const_expr(self.arch // 10 in [8, 9, 12]):
|
| 314 |
sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
|
| 315 |
else:
|
| 316 |
# extra stage dimension
|
|
|
|
| 348 |
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
| 349 |
head_dim = mdQ.shape[3]
|
| 350 |
else:
|
| 351 |
+
padded_offset_q = seqlen.padded_offset_q
|
|
|
|
|
|
|
|
|
|
| 352 |
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
|
| 353 |
mdQaccum_cur = cute.domain_offset(
|
| 354 |
(padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
|
|
|
|
| 373 |
seqlen_q = seqlen.seqlen_q
|
| 374 |
seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
|
| 375 |
|
| 376 |
+
if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs):
|
| 377 |
# 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
|
| 378 |
num_reduce_threads = self.num_threads
|
| 379 |
thr_mma_dsk = tiled_mma.get_slice(tidx)
|
|
|
|
| 504 |
tile_shape = (self.tile_m, self.tile_hdim)
|
| 505 |
acc = None
|
| 506 |
tiled_copy_t2r = None
|
| 507 |
+
if const_expr(self.arch // 10 in [8, 9, 12]):
|
| 508 |
acc_shape = tiled_mma.partition_shape_C(
|
| 509 |
tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
|
| 510 |
)
|
|
|
|
| 533 |
|
| 534 |
# Step 3: Copy dQ from register to smem
|
| 535 |
cute.arch.barrier() # make sure all threads have finished loading dQaccum
|
| 536 |
+
if const_expr(self.arch // 10 in [8, 9, 12]):
|
| 537 |
copy_atom_r2s_dQ = utils.get_smem_store_atom(
|
| 538 |
self.arch, self.dtype, transpose=self.dQ_swapAB
|
| 539 |
)
|
|
|
|
| 555 |
)
|
| 556 |
thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
|
| 557 |
cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
|
| 558 |
+
if const_expr(self.arch // 10 in [8, 9, 12]):
|
| 559 |
taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
|
| 560 |
else:
|
| 561 |
taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
|
build/torch-cuda/flash_bwd_preprocess.py
CHANGED
|
@@ -1,21 +1,32 @@
|
|
| 1 |
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
|
| 3 |
# from Cutlass C++ to Cute-DSL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import math
|
| 5 |
import operator
|
| 6 |
-
from
|
|
|
|
| 7 |
|
| 8 |
import cuda.bindings.driver as cuda
|
| 9 |
|
| 10 |
import cutlass
|
| 11 |
import cutlass.cute as cute
|
| 12 |
-
from cutlass import Float32
|
|
|
|
| 13 |
|
| 14 |
-
from .quack import copy_utils
|
| 15 |
|
| 16 |
from . import utils
|
| 17 |
-
from .
|
| 18 |
-
from .seqlen_info import SeqlenInfoQK
|
| 19 |
from .quack.cute_dsl_utils import ParamsBase
|
| 20 |
from .tile_scheduler import (
|
| 21 |
SingleTileScheduler,
|
|
@@ -30,9 +41,8 @@ class FlashAttentionBackwardPreprocess:
|
|
| 30 |
dtype: Type[cutlass.Numeric],
|
| 31 |
head_dim: int,
|
| 32 |
head_dim_v: int,
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
num_threads: int = 128,
|
| 36 |
):
|
| 37 |
"""
|
| 38 |
All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
|
|
@@ -40,14 +50,14 @@ class FlashAttentionBackwardPreprocess:
|
|
| 40 |
|
| 41 |
:param head_dim: head dimension
|
| 42 |
:type head_dim: int
|
| 43 |
-
:param
|
| 44 |
-
:type
|
| 45 |
:param num_threads: number of threads
|
| 46 |
:type num_threads: int
|
| 47 |
"""
|
|
|
|
| 48 |
self.dtype = dtype
|
| 49 |
-
self.
|
| 50 |
-
self.arch = arch
|
| 51 |
# padding head_dim to a multiple of 32 as k_block_size
|
| 52 |
hdim_multiple_of = 32
|
| 53 |
self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
|
|
@@ -56,15 +66,15 @@ class FlashAttentionBackwardPreprocess:
|
|
| 56 |
self.num_threads = num_threads
|
| 57 |
|
| 58 |
@staticmethod
|
| 59 |
-
def can_implement(dtype, head_dim,
|
| 60 |
"""Check if the kernel can be implemented with the given parameters.
|
| 61 |
|
| 62 |
:param dtype: data type
|
| 63 |
:type dtype: cutlass.Numeric
|
| 64 |
:param head_dim: head dimension
|
| 65 |
:type head_dim: int
|
| 66 |
-
:param
|
| 67 |
-
:type
|
| 68 |
:param num_threads: number of threads
|
| 69 |
:type num_threads: int
|
| 70 |
|
|
@@ -77,7 +87,7 @@ class FlashAttentionBackwardPreprocess:
|
|
| 77 |
return False
|
| 78 |
if num_threads % 32 != 0:
|
| 79 |
return False
|
| 80 |
-
if num_threads <
|
| 81 |
return False
|
| 82 |
return True
|
| 83 |
|
|
@@ -105,7 +115,7 @@ class FlashAttentionBackwardPreprocess:
|
|
| 105 |
universal_copy_bits = 128
|
| 106 |
num_copy_elems_dQaccum = universal_copy_bits // Float32.width
|
| 107 |
assert (
|
| 108 |
-
self.
|
| 109 |
) % self.num_threads == 0
|
| 110 |
self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
|
| 111 |
Float32, self.num_threads, num_copy_elems_dQaccum
|
|
@@ -114,38 +124,53 @@ class FlashAttentionBackwardPreprocess:
|
|
| 114 |
@cute.jit
|
| 115 |
def __call__(
|
| 116 |
self,
|
| 117 |
-
mO: cute.Tensor,
|
| 118 |
-
mdO: cute.Tensor,
|
| 119 |
-
|
| 120 |
-
mLSE: Optional[cute.Tensor],
|
| 121 |
-
mLSElog2: Optional[cute.Tensor],
|
|
|
|
| 122 |
mdQaccum: Optional[cute.Tensor],
|
| 123 |
-
mCuSeqlensQ: Optional[cute.Tensor],
|
| 124 |
-
mSeqUsedQ: Optional[cute.Tensor],
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
):
|
| 127 |
# Get the data type and check if it is fp16 or bf16
|
| 128 |
-
if
|
| 129 |
raise TypeError("All tensors must have the same data type")
|
| 130 |
-
if
|
| 131 |
raise TypeError("Only Float16 or BFloat16 is supported")
|
| 132 |
-
if
|
| 133 |
-
raise TypeError("
|
| 134 |
-
if
|
| 135 |
-
if
|
| 136 |
raise TypeError("dQaccum tensor must be Float32")
|
| 137 |
-
if
|
| 138 |
assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
|
| 139 |
-
if
|
| 140 |
raise TypeError("LSE tensor must be Float32")
|
| 141 |
-
if
|
| 142 |
raise TypeError("LSElog2 tensor must be Float32")
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
|
| 146 |
self._setup_attributes()
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
TileScheduler = SingleTileVarlenScheduler
|
| 150 |
num_head = mO.shape[1]
|
| 151 |
num_batch = mCuSeqlensQ.shape[0] - 1
|
|
@@ -155,7 +180,7 @@ class FlashAttentionBackwardPreprocess:
|
|
| 155 |
num_batch = mO.shape[0]
|
| 156 |
|
| 157 |
tile_sched_args = TileSchedulerArguments(
|
| 158 |
-
num_block=cute.ceil_div(mO.shape[1], self.
|
| 159 |
num_head=num_head,
|
| 160 |
num_batch=num_batch,
|
| 161 |
num_splits=1,
|
|
@@ -163,7 +188,7 @@ class FlashAttentionBackwardPreprocess:
|
|
| 163 |
headdim=0,
|
| 164 |
headdim_v=mO.shape[2],
|
| 165 |
total_q=mO.shape[0],
|
| 166 |
-
tile_shape_mn=(self.
|
| 167 |
mCuSeqlensQ=mCuSeqlensQ,
|
| 168 |
mSeqUsedQ=mSeqUsedQ,
|
| 169 |
)
|
|
@@ -174,12 +199,13 @@ class FlashAttentionBackwardPreprocess:
|
|
| 174 |
self.kernel(
|
| 175 |
mO,
|
| 176 |
mdO,
|
| 177 |
-
|
| 178 |
mLSE,
|
| 179 |
mLSElog2,
|
| 180 |
mdQaccum,
|
| 181 |
mCuSeqlensQ,
|
| 182 |
mSeqUsedQ,
|
|
|
|
| 183 |
self.gmem_tiled_copy_O,
|
| 184 |
self.gmem_tiled_copy_dQaccum,
|
| 185 |
tile_sched_params,
|
|
@@ -188,6 +214,7 @@ class FlashAttentionBackwardPreprocess:
|
|
| 188 |
grid=grid_dim,
|
| 189 |
block=[self.num_threads, 1, 1],
|
| 190 |
stream=stream,
|
|
|
|
| 191 |
)
|
| 192 |
|
| 193 |
@cute.kernel
|
|
@@ -195,12 +222,13 @@ class FlashAttentionBackwardPreprocess:
|
|
| 195 |
self,
|
| 196 |
mO: cute.Tensor,
|
| 197 |
mdO: cute.Tensor,
|
| 198 |
-
|
| 199 |
mLSE: Optional[cute.Tensor],
|
| 200 |
mLSElog2: Optional[cute.Tensor],
|
| 201 |
mdQaccum: Optional[cute.Tensor],
|
| 202 |
mCuSeqlensQ: Optional[cute.Tensor],
|
| 203 |
mSeqUsedQ: Optional[cute.Tensor],
|
|
|
|
| 204 |
gmem_tiled_copy_O: cute.TiledCopy,
|
| 205 |
gmem_tiled_copy_dQaccum: cute.TiledCopy,
|
| 206 |
tile_sched_params: ParamsBase,
|
|
@@ -217,145 +245,106 @@ class FlashAttentionBackwardPreprocess:
|
|
| 217 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 218 |
# Get the appropriate tiles for this thread block.
|
| 219 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 220 |
-
seqlen =
|
| 221 |
-
batch_idx,
|
| 222 |
-
mO.shape[1],
|
| 223 |
-
0,
|
| 224 |
-
mCuSeqlensQ=mCuSeqlensQ,
|
| 225 |
-
mCuSeqlensK=None,
|
| 226 |
-
mSeqUsedQ=mSeqUsedQ,
|
| 227 |
-
mSeqUsedK=None,
|
| 228 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
|
| 234 |
-
headdim_v = mO.shape[3]
|
| 235 |
-
else:
|
| 236 |
-
mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
|
| 237 |
-
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
|
| 238 |
-
|
| 239 |
-
padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
|
| 240 |
-
if cutlass.const_expr(self.arch >= 90):
|
| 241 |
-
padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
|
| 242 |
-
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
|
| 243 |
-
headdim_v = mO.shape[2]
|
| 244 |
-
|
| 245 |
-
blkOdO_shape = (self.m_block_size, self.head_dim_v_padded)
|
| 246 |
-
# (m_block_size, head_dim_v)
|
| 247 |
-
gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
|
| 248 |
-
gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))
|
| 249 |
-
|
| 250 |
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 251 |
# (CPY_Atom, CPY_M, CPY_K)
|
| 252 |
tOgO = gmem_thr_copy_O.partition_S(gO)
|
| 253 |
tOgdO = gmem_thr_copy_O.partition_S(gdO)
|
| 254 |
-
|
| 255 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 256 |
-
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
| 257 |
-
# of tile_shape
|
| 258 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 259 |
-
# Construct identity layout for KV
|
| 260 |
-
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
|
| 261 |
tOcO = gmem_thr_copy_O.partition_S(cO)
|
| 262 |
t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
|
| 263 |
-
tOpO =
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
|
| 276 |
-
lse = Float32.inf
|
| 277 |
-
if tidx < seqlen_q - m_block * self.m_block_size:
|
| 278 |
-
lse = gLSE[tidx]
|
| 279 |
-
|
| 280 |
-
tOrO = cute.make_fragment_like(tOgO)
|
| 281 |
-
tOrdO = cute.make_fragment_like(tOgdO)
|
| 282 |
-
assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
|
| 283 |
-
assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
|
| 284 |
-
assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
|
| 285 |
for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
|
| 286 |
-
# Instead of using tOcO, we using t0OcO and subtract the offset from the limit
|
| 287 |
-
#
|
| 288 |
-
if t0OcO[0, m, 0][0] <
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
else None,
|
| 296 |
-
)
|
| 297 |
-
cute.copy(
|
| 298 |
-
gmem_thr_copy_O,
|
| 299 |
-
tOgdO[None, m, None],
|
| 300 |
-
tOrdO[None, m, None],
|
| 301 |
-
pred=tOpdO[None, m, None]
|
| 302 |
-
if cutlass.const_expr(self.check_hdim_v_oob)
|
| 303 |
-
else None,
|
| 304 |
-
)
|
| 305 |
# Sum across the "k" dimension
|
| 306 |
-
|
| 307 |
cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
|
| 308 |
)
|
| 309 |
threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
|
| 310 |
assert cute.arch.WARP_SIZE % threads_per_row == 0
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
#
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
if tOcO[0, 0, 0][1] == 0:
|
| 319 |
-
for m in cutlass.range(cute.size(
|
| 320 |
row = tOcO[0, m, 0][0]
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
# Clear dQaccum
|
| 324 |
-
if
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
(padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
|
| 330 |
-
)
|
| 331 |
-
|
| 332 |
-
# HACK: Compiler doesn't seem to recognize that padding
|
| 333 |
-
# by padded_offset_q * self.head_dim_padded keeps alignment
|
| 334 |
-
# since statically divisible by 4
|
| 335 |
-
|
| 336 |
-
mdQaccum_cur_ptr = cute.make_ptr(
|
| 337 |
-
dtype=mdQaccum_cur.element_type,
|
| 338 |
-
value=mdQaccum_cur.iterator.toint(),
|
| 339 |
-
mem_space=mdQaccum_cur.iterator.memspace,
|
| 340 |
-
assumed_align=mdQaccum.iterator.alignment,
|
| 341 |
-
)
|
| 342 |
-
mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)
|
| 343 |
-
|
| 344 |
-
blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
|
| 345 |
gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
|
| 346 |
gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
|
| 347 |
tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
|
| 348 |
-
zero = cute.
|
| 349 |
zero.fill(0.0)
|
| 350 |
cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)
|
| 351 |
|
| 352 |
-
if
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
|
| 359 |
LOG2_E = math.log2(math.e)
|
| 360 |
-
if tidx < seqlen_q_rounded - m_block * self.
|
| 361 |
gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
|
|
|
|
| 1 |
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
|
| 3 |
# from Cutlass C++ to Cute-DSL.
|
| 4 |
+
#
|
| 5 |
+
# Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient:
|
| 6 |
+
# D'_i = D_i - dLSE_i
|
| 7 |
+
# This works because in the backward pass:
|
| 8 |
+
# dS_ij = P_ij * (dP_ij - D_i) [standard]
|
| 9 |
+
# When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij
|
| 10 |
+
# (since d(LSE_i)/d(S_ij) = P_ij), giving:
|
| 11 |
+
# dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij
|
| 12 |
+
# = P_ij * (dP_ij - (D_i - dLSE_i))
|
| 13 |
+
# So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here.
|
| 14 |
import math
|
| 15 |
import operator
|
| 16 |
+
from functools import partial
|
| 17 |
+
from typing import Callable, Type, Optional
|
| 18 |
|
| 19 |
import cuda.bindings.driver as cuda
|
| 20 |
|
| 21 |
import cutlass
|
| 22 |
import cutlass.cute as cute
|
| 23 |
+
from cutlass import Float32, const_expr
|
| 24 |
+
from cutlass.cutlass_dsl import Arch, BaseDSL
|
| 25 |
|
| 26 |
+
from .quack import copy_utils, layout_utils
|
| 27 |
|
| 28 |
from . import utils
|
| 29 |
+
from .seqlen_info import SeqlenInfo
|
|
|
|
| 30 |
from .quack.cute_dsl_utils import ParamsBase
|
| 31 |
from .tile_scheduler import (
|
| 32 |
SingleTileScheduler,
|
|
|
|
| 41 |
dtype: Type[cutlass.Numeric],
|
| 42 |
head_dim: int,
|
| 43 |
head_dim_v: int,
|
| 44 |
+
tile_m: int = 128,
|
| 45 |
+
num_threads: int = 256,
|
|
|
|
| 46 |
):
|
| 47 |
"""
|
| 48 |
All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
|
|
|
|
| 50 |
|
| 51 |
:param head_dim: head dimension
|
| 52 |
:type head_dim: int
|
| 53 |
+
:param tile_m: m block size
|
| 54 |
+
:type tile_m: int
|
| 55 |
:param num_threads: number of threads
|
| 56 |
:type num_threads: int
|
| 57 |
"""
|
| 58 |
+
self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a
|
| 59 |
self.dtype = dtype
|
| 60 |
+
self.tile_m = tile_m
|
|
|
|
| 61 |
# padding head_dim to a multiple of 32 as k_block_size
|
| 62 |
hdim_multiple_of = 32
|
| 63 |
self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
|
|
|
|
| 66 |
self.num_threads = num_threads
|
| 67 |
|
| 68 |
@staticmethod
|
| 69 |
+
def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
|
| 70 |
"""Check if the kernel can be implemented with the given parameters.
|
| 71 |
|
| 72 |
:param dtype: data type
|
| 73 |
:type dtype: cutlass.Numeric
|
| 74 |
:param head_dim: head dimension
|
| 75 |
:type head_dim: int
|
| 76 |
+
:param tile_m: m block size
|
| 77 |
+
:type tile_m: int
|
| 78 |
:param num_threads: number of threads
|
| 79 |
:type num_threads: int
|
| 80 |
|
|
|
|
| 87 |
return False
|
| 88 |
if num_threads % 32 != 0:
|
| 89 |
return False
|
| 90 |
+
if num_threads < tile_m: # For multiplying lse with log2
|
| 91 |
return False
|
| 92 |
return True
|
| 93 |
|
|
|
|
| 115 |
universal_copy_bits = 128
|
| 116 |
num_copy_elems_dQaccum = universal_copy_bits // Float32.width
|
| 117 |
assert (
|
| 118 |
+
self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum
|
| 119 |
) % self.num_threads == 0
|
| 120 |
self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
|
| 121 |
Float32, self.num_threads, num_copy_elems_dQaccum
|
|
|
|
| 124 |
@cute.jit
|
| 125 |
def __call__(
|
| 126 |
self,
|
| 127 |
+
mO: cute.Tensor, # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v)
|
| 128 |
+
mdO: cute.Tensor, # same shape as mO
|
| 129 |
+
mPdPsum: cute.Tensor, # (batch, nheads, seqlen_padded) or (nheads, total_q_padded)
|
| 130 |
+
mLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q)
|
| 131 |
+
mLSElog2: Optional[cute.Tensor], # same shape as mPdPsum
|
| 132 |
+
# (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v)
|
| 133 |
mdQaccum: Optional[cute.Tensor],
|
| 134 |
+
mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,)
|
| 135 |
+
mSeqUsedQ: Optional[cute.Tensor], # (batch,)
|
| 136 |
+
mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q)
|
| 137 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 138 |
+
stream: cuda.CUstream = None,
|
| 139 |
):
|
| 140 |
# Get the data type and check if it is fp16 or bf16
|
| 141 |
+
if const_expr(not (mO.element_type == mdO.element_type)):
|
| 142 |
raise TypeError("All tensors must have the same data type")
|
| 143 |
+
if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
|
| 144 |
raise TypeError("Only Float16 or BFloat16 is supported")
|
| 145 |
+
if const_expr(mPdPsum.element_type not in [Float32]):
|
| 146 |
+
raise TypeError("PdPsum tensor must be Float32")
|
| 147 |
+
if const_expr(mdQaccum is not None):
|
| 148 |
+
if const_expr(mdQaccum.element_type not in [Float32]):
|
| 149 |
raise TypeError("dQaccum tensor must be Float32")
|
| 150 |
+
if const_expr(mLSE is not None):
|
| 151 |
assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
|
| 152 |
+
if const_expr(mLSE.element_type not in [Float32]):
|
| 153 |
raise TypeError("LSE tensor must be Float32")
|
| 154 |
+
if const_expr(mLSElog2.element_type not in [Float32]):
|
| 155 |
raise TypeError("LSElog2 tensor must be Float32")
|
| 156 |
+
if const_expr(mdLSE is not None):
|
| 157 |
+
if const_expr(mdLSE.element_type not in [Float32]):
|
| 158 |
+
raise TypeError("dLSE tensor must be Float32")
|
| 159 |
|
| 160 |
self._setup_attributes()
|
| 161 |
|
| 162 |
+
# (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (total_q, nheads) -> (nheads, total_q)
|
| 163 |
+
transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
| 164 |
+
mPdPsum = layout_utils.select(mPdPsum, transpose)
|
| 165 |
+
if const_expr(mLSE is not None):
|
| 166 |
+
mLSE = layout_utils.select(mLSE, transpose)
|
| 167 |
+
mLSElog2 = layout_utils.select(mLSElog2, transpose)
|
| 168 |
+
if const_expr(mdLSE is not None):
|
| 169 |
+
mdLSE = layout_utils.select(mdLSE, transpose)
|
| 170 |
+
if const_expr(mdQaccum is not None):
|
| 171 |
+
mdQaccum = layout_utils.select(mdQaccum, transpose)
|
| 172 |
+
|
| 173 |
+
if const_expr(mCuSeqlensQ is not None):
|
| 174 |
TileScheduler = SingleTileVarlenScheduler
|
| 175 |
num_head = mO.shape[1]
|
| 176 |
num_batch = mCuSeqlensQ.shape[0] - 1
|
|
|
|
| 180 |
num_batch = mO.shape[0]
|
| 181 |
|
| 182 |
tile_sched_args = TileSchedulerArguments(
|
| 183 |
+
num_block=cute.ceil_div(mO.shape[1], self.tile_m),
|
| 184 |
num_head=num_head,
|
| 185 |
num_batch=num_batch,
|
| 186 |
num_splits=1,
|
|
|
|
| 188 |
headdim=0,
|
| 189 |
headdim_v=mO.shape[2],
|
| 190 |
total_q=mO.shape[0],
|
| 191 |
+
tile_shape_mn=(self.tile_m, 1),
|
| 192 |
mCuSeqlensQ=mCuSeqlensQ,
|
| 193 |
mSeqUsedQ=mSeqUsedQ,
|
| 194 |
)
|
|
|
|
| 199 |
self.kernel(
|
| 200 |
mO,
|
| 201 |
mdO,
|
| 202 |
+
mPdPsum,
|
| 203 |
mLSE,
|
| 204 |
mLSElog2,
|
| 205 |
mdQaccum,
|
| 206 |
mCuSeqlensQ,
|
| 207 |
mSeqUsedQ,
|
| 208 |
+
mdLSE,
|
| 209 |
self.gmem_tiled_copy_O,
|
| 210 |
self.gmem_tiled_copy_dQaccum,
|
| 211 |
tile_sched_params,
|
|
|
|
| 214 |
grid=grid_dim,
|
| 215 |
block=[self.num_threads, 1, 1],
|
| 216 |
stream=stream,
|
| 217 |
+
use_pdl=self.use_pdl,
|
| 218 |
)
|
| 219 |
|
| 220 |
@cute.kernel
|
|
|
|
| 222 |
self,
|
| 223 |
mO: cute.Tensor,
|
| 224 |
mdO: cute.Tensor,
|
| 225 |
+
mPdPsum: cute.Tensor,
|
| 226 |
mLSE: Optional[cute.Tensor],
|
| 227 |
mLSElog2: Optional[cute.Tensor],
|
| 228 |
mdQaccum: Optional[cute.Tensor],
|
| 229 |
mCuSeqlensQ: Optional[cute.Tensor],
|
| 230 |
mSeqUsedQ: Optional[cute.Tensor],
|
| 231 |
+
mdLSE: Optional[cute.Tensor],
|
| 232 |
gmem_tiled_copy_O: cute.TiledCopy,
|
| 233 |
gmem_tiled_copy_dQaccum: cute.TiledCopy,
|
| 234 |
tile_sched_params: ParamsBase,
|
|
|
|
| 245 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 246 |
# Get the appropriate tiles for this thread block.
|
| 247 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 248 |
+
seqlen = SeqlenInfo.create(
|
| 249 |
+
batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
)
|
| 251 |
+
mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None]
|
| 252 |
+
mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None]
|
| 253 |
+
mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[
|
| 254 |
+
None, head_idx
|
| 255 |
+
]
|
| 256 |
+
headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1]
|
| 257 |
+
seqlen_q = seqlen.seqlen
|
| 258 |
+
seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
|
| 259 |
+
seqlen_limit = seqlen_q - m_block * self.tile_m
|
| 260 |
+
|
| 261 |
+
lse = None
|
| 262 |
+
if const_expr(mLSE is not None):
|
| 263 |
+
mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx]
|
| 264 |
+
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
|
| 265 |
+
lse = Float32.inf
|
| 266 |
+
if tidx < seqlen_limit:
|
| 267 |
+
lse = gLSE[tidx]
|
| 268 |
|
| 269 |
+
blk_shape = (self.tile_m, self.head_dim_v_padded)
|
| 270 |
+
gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0))
|
| 271 |
+
gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 273 |
# (CPY_Atom, CPY_M, CPY_K)
|
| 274 |
tOgO = gmem_thr_copy_O.partition_S(gO)
|
| 275 |
tOgdO = gmem_thr_copy_O.partition_S(gdO)
|
| 276 |
+
cO = cute.make_identity_tensor(blk_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
tOcO = gmem_thr_copy_O.partition_S(cO)
|
| 278 |
t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
|
| 279 |
+
tOpO = None
|
| 280 |
+
if const_expr(self.check_hdim_v_oob):
|
| 281 |
+
tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v)
|
| 282 |
+
# Each copy will use the same predicate
|
| 283 |
+
copy = partial(copy_utils.copy, pred=tOpO)
|
| 284 |
+
|
| 285 |
+
tOrO = cute.make_rmem_tensor_like(tOgO)
|
| 286 |
+
tOrdO = cute.make_rmem_tensor_like(tOgdO)
|
| 287 |
+
if const_expr(self.check_hdim_v_oob):
|
| 288 |
+
tOrO.fill(0.0)
|
| 289 |
+
tOrdO.fill(0.0)
|
| 290 |
+
assert tOgO.shape == tOgdO.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
|
| 292 |
+
# Instead of using tOcO, we using t0OcO and subtract the offset from the limit.
|
| 293 |
+
# This is bc the entries of t0OcO are known at compile time.
|
| 294 |
+
if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]:
|
| 295 |
+
copy(tOgO[None, m, None], tOrO[None, m, None])
|
| 296 |
+
copy(tOgdO[None, m, None], tOrdO[None, m, None])
|
| 297 |
+
# O and dO loads are done; signal that the next kernel can start.
|
| 298 |
+
# Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs.
|
| 299 |
+
if const_expr(self.use_pdl):
|
| 300 |
+
cute.arch.griddepcontrol_launch_dependents()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
# Sum across the "k" dimension
|
| 302 |
+
pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
|
| 303 |
cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
|
| 304 |
)
|
| 305 |
threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
|
| 306 |
assert cute.arch.WARP_SIZE % threads_per_row == 0
|
| 307 |
+
pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row)
|
| 308 |
+
PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32)
|
| 309 |
+
PdP_sum.store(pdpsum)
|
| 310 |
+
|
| 311 |
+
# If dLSE is provided, compute D' = D - dLSE (see module docstring for derivation).
|
| 312 |
+
gdLSE = None
|
| 313 |
+
if const_expr(mdLSE is not None):
|
| 314 |
+
mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx]
|
| 315 |
+
gdLSE = cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,))
|
| 316 |
+
|
| 317 |
+
# Write PdPsum from rmem -> gmem
|
| 318 |
+
gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,))
|
| 319 |
+
# Only the thread corresponding to column 0 writes out the PdPsum to gmem
|
| 320 |
if tOcO[0, 0, 0][1] == 0:
|
| 321 |
+
for m in cutlass.range(cute.size(PdP_sum), unroll_full=True):
|
| 322 |
row = tOcO[0, m, 0][0]
|
| 323 |
+
PdPsum_val = 0.0
|
| 324 |
+
if row < seqlen_limit:
|
| 325 |
+
PdPsum_val = PdP_sum[m]
|
| 326 |
+
if const_expr(mdLSE is not None):
|
| 327 |
+
PdPsum_val -= gdLSE[row]
|
| 328 |
+
gPdPsum[row] = PdPsum_val
|
| 329 |
|
| 330 |
# Clear dQaccum
|
| 331 |
+
if const_expr(mdQaccum is not None):
|
| 332 |
+
mdQaccum_cur = seqlen.offset_batch(
|
| 333 |
+
mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded
|
| 334 |
+
)[None, head_idx]
|
| 335 |
+
blkdQaccum_shape = (self.tile_m * self.head_dim_padded,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
|
| 337 |
gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
|
| 338 |
tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
|
| 339 |
+
zero = cute.make_rmem_tensor_like(tdQgdQaccum)
|
| 340 |
zero.fill(0.0)
|
| 341 |
cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)
|
| 342 |
|
| 343 |
+
if const_expr(mLSE is not None):
|
| 344 |
+
mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[
|
| 345 |
+
None, head_idx
|
| 346 |
+
]
|
| 347 |
+
gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,))
|
|
|
|
|
|
|
| 348 |
LOG2_E = math.log2(math.e)
|
| 349 |
+
if tidx < seqlen_q_rounded - m_block * self.tile_m:
|
| 350 |
gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
|
build/torch-cuda/flash_bwd_sm100.py
CHANGED
|
@@ -84,7 +84,6 @@ class FlashAttentionBackwardSm100:
|
|
| 84 |
self.use_2cta_instrs = bool(
|
| 85 |
use_2cta_instrs
|
| 86 |
and cluster_size == 2
|
| 87 |
-
and not is_local
|
| 88 |
and score_mod is None
|
| 89 |
and score_mod_bwd is None
|
| 90 |
and mask_mod is None
|
|
@@ -453,7 +452,6 @@ class FlashAttentionBackwardSm100:
|
|
| 453 |
mdK: cute.Tensor,
|
| 454 |
mdV: cute.Tensor,
|
| 455 |
softmax_scale: Float32,
|
| 456 |
-
stream: cuda.CUstream,
|
| 457 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 458 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 459 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
@@ -467,6 +465,8 @@ class FlashAttentionBackwardSm100:
|
|
| 467 |
aux_tensors: Optional[list] = None,
|
| 468 |
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
|
| 469 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
|
|
|
|
|
| 470 |
):
|
| 471 |
self.q_dtype = mQ.element_type
|
| 472 |
self.k_dtype = mK.element_type
|
|
@@ -927,10 +927,6 @@ class FlashAttentionBackwardSm100:
|
|
| 927 |
"2-CTA mode does not support block sparsity. "
|
| 928 |
"Please create kernel with use_2cta_instrs=False for block sparse attention."
|
| 929 |
)
|
| 930 |
-
assert window_size_left is None and window_size_right is None, (
|
| 931 |
-
"2-CTA mode does not support window attention. "
|
| 932 |
-
"Please create kernel with use_2cta_instrs=False for window attention."
|
| 933 |
-
)
|
| 934 |
# 2-CTA: 231424 and 1-CTA: 232448
|
| 935 |
# print("SMEM: ", self.shared_storage.size_in_bytes())
|
| 936 |
if const_expr(self.use_block_sparsity or aux_tensors is not None):
|
|
@@ -3143,6 +3139,8 @@ class FlashAttentionBackwardSm100:
|
|
| 3143 |
with cute.arch.elect_one():
|
| 3144 |
pipeline_S_P.consumer_release(consumer_state_S_P_dP)
|
| 3145 |
# pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
|
|
|
|
|
|
|
| 3146 |
pipeline_LSE.consumer_release(consumer_state_LSE)
|
| 3147 |
consumer_state_LSE.advance()
|
| 3148 |
# ---------------------------------------------
|
|
@@ -3253,6 +3251,8 @@ class FlashAttentionBackwardSm100:
|
|
| 3253 |
|
| 3254 |
cute.arch.fence_view_async_shared()
|
| 3255 |
self.compute_sync_barrier.arrive_and_wait()
|
|
|
|
|
|
|
| 3256 |
pipeline_dPsum.consumer_release(consumer_state_dPsum)
|
| 3257 |
consumer_state_dPsum.advance()
|
| 3258 |
# when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred
|
|
@@ -3650,6 +3650,9 @@ class FlashAttentionBackwardSm100:
|
|
| 3650 |
tile_scheduler.advance_to_next_work()
|
| 3651 |
work_tile = tile_scheduler.get_current_work()
|
| 3652 |
|
|
|
|
|
|
|
|
|
|
| 3653 |
@cute.jit
|
| 3654 |
def epilogue_dKV(
|
| 3655 |
self,
|
|
|
|
| 84 |
self.use_2cta_instrs = bool(
|
| 85 |
use_2cta_instrs
|
| 86 |
and cluster_size == 2
|
|
|
|
| 87 |
and score_mod is None
|
| 88 |
and score_mod_bwd is None
|
| 89 |
and mask_mod is None
|
|
|
|
| 452 |
mdK: cute.Tensor,
|
| 453 |
mdV: cute.Tensor,
|
| 454 |
softmax_scale: Float32,
|
|
|
|
| 455 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 456 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 457 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
|
|
| 465 |
aux_tensors: Optional[list] = None,
|
| 466 |
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
|
| 467 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 468 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 469 |
+
stream: cuda.CUstream = None,
|
| 470 |
):
|
| 471 |
self.q_dtype = mQ.element_type
|
| 472 |
self.k_dtype = mK.element_type
|
|
|
|
| 927 |
"2-CTA mode does not support block sparsity. "
|
| 928 |
"Please create kernel with use_2cta_instrs=False for block sparse attention."
|
| 929 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
# 2-CTA: 231424 and 1-CTA: 232448
|
| 931 |
# print("SMEM: ", self.shared_storage.size_in_bytes())
|
| 932 |
if const_expr(self.use_block_sparsity or aux_tensors is not None):
|
|
|
|
| 3139 |
with cute.arch.elect_one():
|
| 3140 |
pipeline_S_P.consumer_release(consumer_state_S_P_dP)
|
| 3141 |
# pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
|
| 3142 |
+
# Normally we'd need syncwarp here since only 1 thread will signal in
|
| 3143 |
+
# consumer_release, but we already have the self.compute_sync_barrier before this
|
| 3144 |
pipeline_LSE.consumer_release(consumer_state_LSE)
|
| 3145 |
consumer_state_LSE.advance()
|
| 3146 |
# ---------------------------------------------
|
|
|
|
| 3251 |
|
| 3252 |
cute.arch.fence_view_async_shared()
|
| 3253 |
self.compute_sync_barrier.arrive_and_wait()
|
| 3254 |
+
# Normally we'd need syncwarp here since only 1 thread will signal in
|
| 3255 |
+
# consumer_release, but we already have the self.compute_sync_barrier before this
|
| 3256 |
pipeline_dPsum.consumer_release(consumer_state_dPsum)
|
| 3257 |
consumer_state_dPsum.advance()
|
| 3258 |
# when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred
|
|
|
|
| 3650 |
tile_scheduler.advance_to_next_work()
|
| 3651 |
work_tile = tile_scheduler.get_current_work()
|
| 3652 |
|
| 3653 |
+
if const_expr(not self.deterministic):
|
| 3654 |
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
| 3655 |
+
|
| 3656 |
@cute.jit
|
| 3657 |
def epilogue_dKV(
|
| 3658 |
self,
|
build/torch-cuda/flash_bwd_sm120.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# SM120 (Blackwell GeForce / DGX Spark) backward pass.
|
| 3 |
+
#
|
| 4 |
+
# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
|
| 5 |
+
# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
|
| 6 |
+
# FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly.
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.utils as utils_basic
|
| 10 |
+
|
| 11 |
+
from .flash_bwd import FlashAttentionBackwardSm80
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):
|
| 15 |
+
@staticmethod
|
| 16 |
+
def can_implement(
|
| 17 |
+
dtype,
|
| 18 |
+
head_dim,
|
| 19 |
+
head_dim_v,
|
| 20 |
+
m_block_size,
|
| 21 |
+
n_block_size,
|
| 22 |
+
num_stages_Q,
|
| 23 |
+
num_stages_dO,
|
| 24 |
+
num_threads,
|
| 25 |
+
is_causal,
|
| 26 |
+
V_in_regs=False,
|
| 27 |
+
) -> bool:
|
| 28 |
+
"""Check if the kernel can be implemented on SM120.
|
| 29 |
+
|
| 30 |
+
Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
|
| 31 |
+
"""
|
| 32 |
+
if dtype not in [cutlass.Float16, cutlass.BFloat16]:
|
| 33 |
+
return False
|
| 34 |
+
if head_dim % 8 != 0:
|
| 35 |
+
return False
|
| 36 |
+
if head_dim_v % 8 != 0:
|
| 37 |
+
return False
|
| 38 |
+
if n_block_size % 16 != 0:
|
| 39 |
+
return False
|
| 40 |
+
if num_threads % 32 != 0:
|
| 41 |
+
return False
|
| 42 |
+
# Shared memory usage: Q tile + dO tile + K tile + V tile
|
| 43 |
+
smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
|
| 44 |
+
smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
|
| 45 |
+
smem_usage_K = n_block_size * head_dim * 2
|
| 46 |
+
smem_usage_V = n_block_size * head_dim_v * 2
|
| 47 |
+
smem_usage_QV = (
|
| 48 |
+
(smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
|
| 49 |
+
)
|
| 50 |
+
smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
|
| 51 |
+
# SM120 has 99 KB shared memory (vs 163 KB on SM80)
|
| 52 |
+
smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
|
| 53 |
+
if smem_usage > smem_capacity:
|
| 54 |
+
return False
|
| 55 |
+
return True
|
build/torch-cuda/flash_bwd_sm90.py
CHANGED
|
@@ -24,7 +24,13 @@ from .seqlen_info import SeqlenInfoQK
|
|
| 24 |
from .block_info import BlockInfo
|
| 25 |
from . import pipeline
|
| 26 |
from .quack.cute_dsl_utils import ParamsBase
|
| 27 |
-
from .tile_scheduler import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from .named_barrier import NamedBarrierBwd
|
| 29 |
from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
|
| 30 |
from .block_sparsity import BlockSparseTensors
|
|
@@ -46,6 +52,8 @@ class FlashAttentionBackwardSm90:
|
|
| 46 |
head_dim_v: Optional[int] = None,
|
| 47 |
qhead_per_kvhead: int = 1,
|
| 48 |
is_causal: bool = False,
|
|
|
|
|
|
|
| 49 |
tile_m: int = 64,
|
| 50 |
tile_n: int = 128,
|
| 51 |
Q_stage: int = 2,
|
|
@@ -64,6 +72,7 @@ class FlashAttentionBackwardSm90:
|
|
| 64 |
mask_mod: cutlass.Constexpr | None = None,
|
| 65 |
has_aux_tensors: cutlass.Constexpr = False,
|
| 66 |
subtile_factor: cutlass.Constexpr[int] = 1,
|
|
|
|
| 67 |
):
|
| 68 |
self.dtype = dtype
|
| 69 |
# padding head_dim to a multiple of 16 as k_block_size
|
|
@@ -77,7 +86,8 @@ class FlashAttentionBackwardSm90:
|
|
| 77 |
self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
|
| 78 |
self.qhead_per_kvhead = qhead_per_kvhead
|
| 79 |
self.is_causal = is_causal
|
| 80 |
-
self.is_local =
|
|
|
|
| 81 |
self.tile_m = tile_m
|
| 82 |
self.tile_n = tile_n
|
| 83 |
self.num_threads = num_threads
|
|
@@ -92,23 +102,23 @@ class FlashAttentionBackwardSm90:
|
|
| 92 |
self.AtomLayoutMSdP = AtomLayoutMSdP
|
| 93 |
self.AtomLayoutNdKV = AtomLayoutNdKV
|
| 94 |
self.AtomLayoutMdQ = AtomLayoutMdQ
|
| 95 |
-
self.
|
| 96 |
self.mma_dkv_is_rs = (
|
| 97 |
AtomLayoutMSdP == 1
|
| 98 |
-
and AtomLayoutNdKV == self.
|
| 99 |
and SdP_swapAB
|
| 100 |
and not dKV_swapAB
|
| 101 |
)
|
| 102 |
self.V_in_regs = V_in_regs
|
|
|
|
| 103 |
if qhead_per_kvhead > 1:
|
| 104 |
assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
|
| 105 |
-
assert self.
|
| 106 |
# These are tuned for speed
|
| 107 |
# Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
|
| 108 |
# them and then shuffle to get the value whenever we need? This can reduce register
|
| 109 |
# pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
|
| 110 |
# rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
|
| 111 |
-
# TODO: impl these for hdim 64
|
| 112 |
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
|
| 113 |
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
|
| 114 |
|
|
@@ -124,6 +134,12 @@ class FlashAttentionBackwardSm90:
|
|
| 124 |
else:
|
| 125 |
self.vec_size: cutlass.Constexpr = 4
|
| 126 |
self.qk_acc_dtype = Float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
@staticmethod
|
| 129 |
def can_implement(
|
|
@@ -182,32 +198,58 @@ class FlashAttentionBackwardSm90:
|
|
| 182 |
assert mQ_type == self.dtype
|
| 183 |
|
| 184 |
def _setup_attributes(self):
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
| 193 |
]
|
| 194 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
self.sdQaccum_layout = cute.make_layout(
|
| 196 |
-
(self.tile_m * self.tile_hdim // self.
|
| 197 |
)
|
| 198 |
# dQaccum R->S
|
| 199 |
self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
|
| 200 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 201 |
# thr_layout
|
| 202 |
-
cute.make_layout((self.num_threads_per_warp_group, self.
|
| 203 |
cute.make_layout(128 // Float32.width), # val_layout
|
| 204 |
)
|
| 205 |
# dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
|
| 206 |
# TODO: assert that sVaccum and sKaccum don't overflow smem
|
| 207 |
|
| 208 |
def _get_tiled_mma(self):
|
|
|
|
| 209 |
# S = Q @ K.T, dP = dO @ V.T
|
| 210 |
-
atom_layout_SdP = (self.AtomLayoutMSdP, self.
|
| 211 |
tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
|
| 212 |
tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
|
| 213 |
self.dtype,
|
|
@@ -215,12 +257,11 @@ class FlashAttentionBackwardSm90:
|
|
| 215 |
warpgroup.OperandMajorMode.K,
|
| 216 |
warpgroup.OperandMajorMode.K,
|
| 217 |
Float32,
|
| 218 |
-
atom_layout_mnk=(atom_layout_SdP
|
| 219 |
-
|
| 220 |
-
tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1],
|
| 221 |
)
|
| 222 |
# dV = P.T @ dO, dK = dS.T @ Q
|
| 223 |
-
atom_layout_dKV = (self.AtomLayoutNdKV, self.
|
| 224 |
tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
|
| 225 |
tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
|
| 226 |
tiled_mma_dK, tiled_mma_dV = [
|
|
@@ -232,9 +273,8 @@ class FlashAttentionBackwardSm90:
|
|
| 232 |
else warpgroup.OperandMajorMode.K,
|
| 233 |
warpgroup.OperandMajorMode.MN,
|
| 234 |
Float32,
|
| 235 |
-
atom_layout_mnk=(atom_layout_dKV
|
| 236 |
-
|
| 237 |
-
tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1],
|
| 238 |
a_source=warpgroup.OperandSource.RMEM
|
| 239 |
if self.mma_dkv_is_rs
|
| 240 |
else warpgroup.OperandSource.SMEM,
|
|
@@ -242,7 +282,8 @@ class FlashAttentionBackwardSm90:
|
|
| 242 |
for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
|
| 243 |
]
|
| 244 |
# dQ = dS @ K
|
| 245 |
-
|
|
|
|
| 246 |
tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
|
| 247 |
tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
|
| 248 |
self.dtype,
|
|
@@ -250,8 +291,8 @@ class FlashAttentionBackwardSm90:
|
|
| 250 |
warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
|
| 251 |
warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
|
| 252 |
Float32,
|
| 253 |
-
atom_layout_mnk=(atom_layout_dQ
|
| 254 |
-
tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[
|
| 255 |
)
|
| 256 |
return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
|
| 257 |
|
|
@@ -305,7 +346,6 @@ class FlashAttentionBackwardSm90:
|
|
| 305 |
mdK: cute.Tensor,
|
| 306 |
mdV: cute.Tensor,
|
| 307 |
softmax_scale: Float32,
|
| 308 |
-
stream: cuda.CUstream,
|
| 309 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 310 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 311 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
@@ -318,10 +358,13 @@ class FlashAttentionBackwardSm90:
|
|
| 318 |
mdV_semaphore: Optional[cute.Tensor] = None,
|
| 319 |
aux_tensors: Optional[list] = None,
|
| 320 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
|
|
|
|
|
| 321 |
):
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
| 325 |
|
| 326 |
self._check_type(
|
| 327 |
*(
|
|
@@ -330,23 +373,36 @@ class FlashAttentionBackwardSm90:
|
|
| 330 |
)
|
| 331 |
)
|
| 332 |
|
|
|
|
|
|
|
| 333 |
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
|
| 334 |
assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
|
| 335 |
]
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
if const_expr(self.qhead_per_kvhead == 1):
|
| 340 |
-
mdK, mdV = [
|
| 341 |
else:
|
| 342 |
-
|
|
|
|
| 343 |
mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
|
| 344 |
-
|
|
|
|
| 345 |
mLSE, mdPsum, mdQaccum = [
|
| 346 |
layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
|
| 347 |
]
|
| 348 |
|
| 349 |
tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
self.num_mma_threads = tiled_mma_SdP.size
|
| 352 |
assert self.num_mma_threads + 128 == self.num_threads
|
|
@@ -354,10 +410,25 @@ class FlashAttentionBackwardSm90:
|
|
| 354 |
self.num_threads_per_warp_group = 128
|
| 355 |
self.num_producer_threads = 32
|
| 356 |
|
| 357 |
-
self.
|
| 358 |
-
self.
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
self._setup_attributes()
|
| 363 |
SharedStorage = self._get_shared_storage_cls()
|
|
@@ -374,7 +445,7 @@ class FlashAttentionBackwardSm90:
|
|
| 374 |
self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
|
| 375 |
self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
|
| 376 |
self.tma_copy_bytes["dQ"] = (
|
| 377 |
-
self.tile_m * self.tile_hdim * Float32.width // 8 // self.
|
| 378 |
)
|
| 379 |
self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
|
| 380 |
self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8
|
|
@@ -404,38 +475,59 @@ class FlashAttentionBackwardSm90:
|
|
| 404 |
(self.tile_m, self.tile_hdimv),
|
| 405 |
)
|
| 406 |
if const_expr(self.qhead_per_kvhead == 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
|
| 408 |
cpasync.CopyBulkTensorTileS2GOp(),
|
| 409 |
-
|
| 410 |
cute.select(self.sK_layout, mode=[0, 1]),
|
| 411 |
(self.tile_n, self.tile_hdim),
|
| 412 |
)
|
| 413 |
tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
|
| 414 |
cpasync.CopyBulkTensorTileS2GOp(),
|
| 415 |
-
|
| 416 |
cute.select(self.sV_layout, mode=[0, 1]),
|
| 417 |
(self.tile_n, self.tile_hdimv),
|
| 418 |
)
|
| 419 |
else:
|
| 420 |
tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None
|
| 421 |
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
tile_sched_args = TileSchedulerArguments(
|
| 424 |
cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
|
| 425 |
cute.size(mQ.shape[2]),
|
| 426 |
-
cute.size(
|
|
|
|
|
|
|
| 427 |
1, # num_splits
|
| 428 |
-
cute.size(
|
| 429 |
-
mQ.shape[1],
|
| 430 |
-
mV.shape[1],
|
| 431 |
-
total_q=cute.size(
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
|
|
|
|
|
|
| 435 |
qhead_per_kvhead_packgqa=1,
|
| 436 |
element_size=self.dtype.width // 8,
|
| 437 |
is_persistent=False,
|
| 438 |
-
lpt=
|
|
|
|
| 439 |
)
|
| 440 |
|
| 441 |
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
|
@@ -461,6 +553,11 @@ class FlashAttentionBackwardSm90:
|
|
| 461 |
|
| 462 |
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
| 463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
self.kernel(
|
| 465 |
tma_tensor_Q,
|
| 466 |
tma_tensor_K,
|
|
@@ -477,6 +574,10 @@ class FlashAttentionBackwardSm90:
|
|
| 477 |
mLSE,
|
| 478 |
mdPsum,
|
| 479 |
mdQaccum,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
self.sQ_layout,
|
| 481 |
self.sK_layout,
|
| 482 |
self.sV_layout,
|
|
@@ -497,11 +598,15 @@ class FlashAttentionBackwardSm90:
|
|
| 497 |
fastdiv_mods,
|
| 498 |
blocksparse_tensors,
|
| 499 |
qhead_per_kvhead_divmod,
|
|
|
|
|
|
|
|
|
|
| 500 |
).launch(
|
| 501 |
grid=grid_dim,
|
| 502 |
block=[self.num_threads, 1, 1],
|
| 503 |
stream=stream,
|
| 504 |
min_blocks_per_mp=1,
|
|
|
|
| 505 |
)
|
| 506 |
|
| 507 |
@cute.kernel
|
|
@@ -522,6 +627,10 @@ class FlashAttentionBackwardSm90:
|
|
| 522 |
mLSE: cute.Tensor,
|
| 523 |
mdPsum: cute.Tensor,
|
| 524 |
mdQaccum: cute.Tensor,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
sQ_layout: cute.ComposedLayout,
|
| 526 |
sK_layout: cute.ComposedLayout,
|
| 527 |
sV_layout: cute.ComposedLayout,
|
|
@@ -542,15 +651,17 @@ class FlashAttentionBackwardSm90:
|
|
| 542 |
fastdiv_mods=(None, None),
|
| 543 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 544 |
qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
|
|
|
|
|
|
|
|
|
|
| 545 |
):
|
| 546 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 547 |
|
| 548 |
# prefetch TMA descriptors
|
| 549 |
if warp_idx == 0:
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
cpasync.prefetch_descriptor(tma_atom_dO)
|
| 554 |
|
| 555 |
smem = cutlass.utils.SmemAllocator()
|
| 556 |
storage = smem.allocate(SharedStorage)
|
|
@@ -604,25 +715,27 @@ class FlashAttentionBackwardSm90:
|
|
| 604 |
self.is_causal,
|
| 605 |
self.is_local,
|
| 606 |
False, # is_split_kv
|
| 607 |
-
|
| 608 |
-
|
| 609 |
qhead_per_kvhead_packgqa=1,
|
| 610 |
)
|
| 611 |
SeqlenInfoCls = partial(
|
| 612 |
SeqlenInfoQK.create,
|
| 613 |
seqlen_q_static=mQ.shape[0],
|
| 614 |
seqlen_k_static=mK.shape[0],
|
| 615 |
-
mCuSeqlensQ=
|
| 616 |
-
mCuSeqlensK=
|
| 617 |
-
mSeqUsedQ=
|
| 618 |
-
mSeqUsedK=
|
|
|
|
|
|
|
| 619 |
)
|
| 620 |
AttentionMaskCls = partial(
|
| 621 |
AttentionMask,
|
| 622 |
self.tile_m,
|
| 623 |
self.tile_n,
|
| 624 |
-
window_size_left=
|
| 625 |
-
window_size_right=
|
| 626 |
swap_AB=self.SdP_swapAB,
|
| 627 |
)
|
| 628 |
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
|
@@ -663,12 +776,12 @@ class FlashAttentionBackwardSm90:
|
|
| 663 |
TileSchedulerCls,
|
| 664 |
SeqlenInfoCls,
|
| 665 |
blocksparse_tensors,
|
|
|
|
| 666 |
)
|
| 667 |
else:
|
| 668 |
-
cute.arch.setmaxregister_increase(self.num_mma_regs)
|
| 669 |
tidx, _, _ = cute.arch.thread_idx()
|
| 670 |
tidx = tidx - 128
|
| 671 |
-
|
| 672 |
tiled_mma_SdP,
|
| 673 |
tiled_mma_dK,
|
| 674 |
tiled_mma_dV,
|
|
@@ -702,6 +815,19 @@ class FlashAttentionBackwardSm90:
|
|
| 702 |
blocksparse_tensors,
|
| 703 |
qhead_per_kvhead_divmod,
|
| 704 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
|
| 706 |
@cute.jit
|
| 707 |
def load(
|
|
@@ -749,18 +875,22 @@ class FlashAttentionBackwardSm90:
|
|
| 749 |
if const_expr(self.qhead_per_kvhead == 1)
|
| 750 |
else head_idx // qhead_per_kvhead_divmod
|
| 751 |
)
|
| 752 |
-
mK_cur = mK[None, None, head_idx_kv
|
|
|
|
| 753 |
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
| 754 |
-
mV_cur = mV[None, None, head_idx_kv, batch_idx]
|
| 755 |
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
| 756 |
|
| 757 |
-
mQ_cur = mQ[None, None, head_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
|
| 759 |
-
mdO_cur = mdO[None, None, head_idx, batch_idx]
|
| 760 |
gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
|
| 761 |
-
mLSE_cur = mLSE[None, head_idx, batch_idx]
|
| 762 |
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
|
| 763 |
-
mdPsum_cur = mdPsum[None, head_idx, batch_idx]
|
| 764 |
gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
|
| 765 |
|
| 766 |
load_K, _, _ = copy_utils.tma_get_copy_fn(
|
|
@@ -786,7 +916,10 @@ class FlashAttentionBackwardSm90:
|
|
| 786 |
|
| 787 |
if const_expr(not self.use_block_sparsity):
|
| 788 |
total_m_block_cnt = m_block_max - m_block_min
|
| 789 |
-
process_tile =
|
|
|
|
|
|
|
|
|
|
| 790 |
else:
|
| 791 |
total_m_block_cnt = get_total_q_block_count_bwd(
|
| 792 |
blocksparse_tensors,
|
|
@@ -806,6 +939,8 @@ class FlashAttentionBackwardSm90:
|
|
| 806 |
)
|
| 807 |
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
|
| 808 |
load_Q(first_m_block, producer_state=producer_state_Q)
|
|
|
|
|
|
|
| 809 |
load_LSE(first_m_block, producer_state=producer_state_Q)
|
| 810 |
producer_state_dO_cur = (
|
| 811 |
producer_state_dO
|
|
@@ -984,16 +1119,20 @@ class FlashAttentionBackwardSm90:
|
|
| 984 |
fastdiv_mods=(None, None),
|
| 985 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 986 |
qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
|
|
|
|
| 987 |
):
|
| 988 |
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 989 |
warp_group_thread_layout = cute.make_layout(
|
| 990 |
-
self.
|
| 991 |
)
|
| 992 |
thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
|
| 993 |
wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 994 |
wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 995 |
wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 996 |
-
wg_mma_dQ =
|
|
|
|
|
|
|
|
|
|
| 997 |
# S = Q @ K.T
|
| 998 |
shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
|
| 999 |
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
|
|
@@ -1039,23 +1178,43 @@ class FlashAttentionBackwardSm90:
|
|
| 1039 |
# dQ = dS @ K
|
| 1040 |
sKt = layout_utils.transpose_view(sK)
|
| 1041 |
shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1048 |
|
| 1049 |
-
# Smem copy atom tiling
|
| 1050 |
copy_P_r2s = None
|
|
|
|
| 1051 |
if const_expr(sP is not None):
|
| 1052 |
sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
|
| 1053 |
copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1054 |
-
tiled_mma_SdP,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
)
|
| 1056 |
sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
|
| 1057 |
copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1058 |
-
tiled_mma_SdP,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1059 |
)
|
| 1060 |
|
| 1061 |
tLSEsLSE = layout_utils.mma_partition_C_vec(
|
|
@@ -1064,9 +1223,21 @@ class FlashAttentionBackwardSm90:
|
|
| 1064 |
tLSEsdPsum = layout_utils.mma_partition_C_vec(
|
| 1065 |
sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
|
| 1066 |
)
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1070 |
|
| 1071 |
PdS_barrier = cutlass.pipeline.NamedBarrier(
|
| 1072 |
barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
|
|
@@ -1105,6 +1276,7 @@ class FlashAttentionBackwardSm90:
|
|
| 1105 |
PdS_barrier=PdS_barrier,
|
| 1106 |
# acc_dV=acc_dV,
|
| 1107 |
# acc_dK=acc_dK,
|
|
|
|
| 1108 |
)
|
| 1109 |
|
| 1110 |
consumer_state_Q = cutlass.pipeline.make_pipeline_state(
|
|
@@ -1136,7 +1308,10 @@ class FlashAttentionBackwardSm90:
|
|
| 1136 |
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
|
| 1137 |
|
| 1138 |
if const_expr(not self.use_block_sparsity):
|
| 1139 |
-
process_tile =
|
|
|
|
|
|
|
|
|
|
| 1140 |
else:
|
| 1141 |
total_m_block_cnt = get_total_q_block_count_bwd(
|
| 1142 |
blocksparse_tensors,
|
|
@@ -1218,8 +1393,8 @@ class FlashAttentionBackwardSm90:
|
|
| 1218 |
qhead_per_kvhead_divmod,
|
| 1219 |
)
|
| 1220 |
else:
|
| 1221 |
-
#
|
| 1222 |
-
if const_expr(self.use_block_sparsity):
|
| 1223 |
acc_dK.fill(0.0)
|
| 1224 |
acc_dV.fill(0.0)
|
| 1225 |
self.epilogue_dKV(
|
|
@@ -1248,6 +1423,22 @@ class FlashAttentionBackwardSm90:
|
|
| 1248 |
if warp_idx == 4:
|
| 1249 |
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
| 1250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1251 |
@cute.jit
|
| 1252 |
def mma_one_m_block(
|
| 1253 |
self,
|
|
@@ -1266,16 +1457,17 @@ class FlashAttentionBackwardSm90:
|
|
| 1266 |
pipeline_dO: cutlass.pipeline.PipelineAsync,
|
| 1267 |
tLSEsLSE: cute.Tensor,
|
| 1268 |
tLSEsdPsum: cute.Tensor,
|
| 1269 |
-
tdQsdQaccum: cute.Tensor,
|
| 1270 |
softmax_scale_log2: Float32,
|
| 1271 |
PdS_barrier: cutlass.pipeline.NamedBarrier,
|
|
|
|
| 1272 |
mask_fn: Optional[Callable] = None,
|
| 1273 |
score_mod_fn: Optional[Callable] = None,
|
| 1274 |
score_mod_bwd_fn: Optional[Callable] = None,
|
| 1275 |
dKV_accumulate: Boolean = True,
|
| 1276 |
):
|
| 1277 |
consumer_state_dO_cur = (
|
| 1278 |
-
|
| 1279 |
)
|
| 1280 |
smem_idx_Q = consumer_state_Q.index
|
| 1281 |
smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
|
|
@@ -1283,6 +1475,7 @@ class FlashAttentionBackwardSm90:
|
|
| 1283 |
# (1) [GEMM 1] S = Q @ K^T
|
| 1284 |
pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
|
| 1285 |
acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
|
|
|
|
| 1286 |
tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
|
| 1287 |
# (2) [GEMM 2] dP = dO @ V.T
|
| 1288 |
pipeline_dO.consumer_wait(
|
|
@@ -1301,10 +1494,12 @@ class FlashAttentionBackwardSm90:
|
|
| 1301 |
if cutlass.const_expr(mask_fn is not None):
|
| 1302 |
mask_fn(acc_S, m_block=m_block)
|
| 1303 |
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
|
|
|
|
| 1304 |
for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
|
|
|
|
| 1305 |
for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
|
| 1306 |
acc_S_mn[r, c] = cute.math.exp2(
|
| 1307 |
-
acc_S_mn[r, c] * softmax_scale_log2 -
|
| 1308 |
)
|
| 1309 |
tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
|
| 1310 |
|
|
@@ -1321,8 +1516,9 @@ class FlashAttentionBackwardSm90:
|
|
| 1321 |
warpgroup.wait_group(0)
|
| 1322 |
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
|
| 1323 |
for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
|
|
|
|
| 1324 |
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
|
| 1325 |
-
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] -
|
| 1326 |
|
| 1327 |
if const_expr(self.score_mod_bwd is not None):
|
| 1328 |
score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
|
|
@@ -1354,36 +1550,50 @@ class FlashAttentionBackwardSm90:
|
|
| 1354 |
# smem fence to make sure sdS is written before it's read by WGMMA
|
| 1355 |
cute.arch.fence_view_async_shared()
|
| 1356 |
PdS_barrier.arrive_and_wait()
|
| 1357 |
-
# (6) [GEMM 4] dQ = dS @ K
|
| 1358 |
-
acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
|
| 1359 |
-
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
|
| 1360 |
-
pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done
|
| 1361 |
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
|
| 1365 |
-
|
| 1366 |
-
)
|
| 1367 |
-
else:
|
| 1368 |
-
mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
|
| 1369 |
-
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ)
|
| 1370 |
|
| 1371 |
-
|
| 1372 |
-
|
| 1373 |
-
|
| 1374 |
-
|
| 1375 |
-
|
| 1376 |
-
|
| 1377 |
-
|
| 1378 |
-
cute.arch.barrier_arrive(
|
| 1379 |
-
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1380 |
-
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
| 1381 |
-
)
|
| 1382 |
|
| 1383 |
-
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1387 |
|
| 1388 |
consumer_state_Q.advance()
|
| 1389 |
consumer_state_dO.advance()
|
|
@@ -1415,8 +1625,12 @@ class FlashAttentionBackwardSm90:
|
|
| 1415 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 1416 |
|
| 1417 |
if const_expr(self.qhead_per_kvhead == 1):
|
| 1418 |
-
|
| 1419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1420 |
gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
| 1421 |
gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
| 1422 |
store_dK, _, _ = copy_utils.tma_get_copy_fn(
|
|
@@ -1428,10 +1642,20 @@ class FlashAttentionBackwardSm90:
|
|
| 1428 |
sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
|
| 1429 |
sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
|
| 1430 |
copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1431 |
-
tiled_mma_dV,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1432 |
)
|
| 1433 |
copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1434 |
-
tiled_mma_dK,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1435 |
)
|
| 1436 |
cute.arch.cp_async_bulk_wait_group(1, read=True)
|
| 1437 |
epi_barrier.arrive_and_wait()
|
|
@@ -1450,15 +1674,19 @@ class FlashAttentionBackwardSm90:
|
|
| 1450 |
store_dK()
|
| 1451 |
cute.arch.cp_async_bulk_commit_group()
|
| 1452 |
else:
|
| 1453 |
-
sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.
|
| 1454 |
-
sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.
|
| 1455 |
-
sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.
|
| 1456 |
-
sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.
|
| 1457 |
head_idx_kv = head_idx // qhead_per_kvhead_divmod
|
| 1458 |
-
mdKaccum_cur =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1459 |
gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
|
| 1460 |
gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
|
| 1461 |
-
mdVaccum_cur = mdV[None, head_idx_kv, batch_idx]
|
| 1462 |
gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
|
| 1463 |
gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
|
| 1464 |
# These two overlap each other
|
|
@@ -1467,7 +1695,7 @@ class FlashAttentionBackwardSm90:
|
|
| 1467 |
sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
|
| 1468 |
tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
|
| 1469 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 1470 |
-
cute.make_layout((self.num_threads_per_warp_group, self.
|
| 1471 |
cute.make_layout(128 // Float32.width),
|
| 1472 |
)
|
| 1473 |
thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
|
|
@@ -1482,11 +1710,11 @@ class FlashAttentionBackwardSm90:
|
|
| 1482 |
epi_barrier.arrive_and_wait()
|
| 1483 |
if warp_idx == 4:
|
| 1484 |
with cute.arch.elect_one():
|
| 1485 |
-
for wg_idx in cutlass.range_constexpr(self.
|
| 1486 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1487 |
sdKaccum[None, wg_idx].iterator,
|
| 1488 |
gdKaccum[None, wg_idx].iterator,
|
| 1489 |
-
self.tma_copy_bytes["dKacc"] // self.
|
| 1490 |
)
|
| 1491 |
cute.arch.cp_async_bulk_commit_group()
|
| 1492 |
|
|
@@ -1498,11 +1726,11 @@ class FlashAttentionBackwardSm90:
|
|
| 1498 |
epi_barrier.arrive_and_wait()
|
| 1499 |
if warp_idx == 4:
|
| 1500 |
with cute.arch.elect_one():
|
| 1501 |
-
for wg_idx in cutlass.range_constexpr(self.
|
| 1502 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1503 |
sdVaccum[None, wg_idx].iterator,
|
| 1504 |
gdVaccum[None, wg_idx].iterator,
|
| 1505 |
-
self.tma_copy_bytes["dVacc"] // self.
|
| 1506 |
)
|
| 1507 |
cute.arch.cp_async_bulk_commit_group()
|
| 1508 |
|
|
@@ -1515,21 +1743,45 @@ class FlashAttentionBackwardSm90:
|
|
| 1515 |
TileSchedulerCls: cutlass.Constexpr[Callable],
|
| 1516 |
SeqlenInfoCls: cutlass.Constexpr[Callable],
|
| 1517 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
|
|
| 1518 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1519 |
tile_scheduler = TileSchedulerCls()
|
| 1520 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1521 |
while work_tile.is_valid_tile:
|
| 1522 |
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 1523 |
seqlen = SeqlenInfoCls(batch_idx)
|
| 1524 |
-
|
| 1525 |
-
|
| 1526 |
-
|
| 1527 |
-
|
| 1528 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1529 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1530 |
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
|
| 1531 |
if const_expr(not self.use_block_sparsity):
|
| 1532 |
-
process_tile =
|
|
|
|
|
|
|
|
|
|
| 1533 |
loop_count = m_block_max - m_block_min
|
| 1534 |
else:
|
| 1535 |
total_block_cnt = get_total_q_block_count_bwd(
|
|
@@ -1548,17 +1800,36 @@ class FlashAttentionBackwardSm90:
|
|
| 1548 |
m_block = m_block_min + iter_idx
|
| 1549 |
m_block_safe = m_block
|
| 1550 |
|
| 1551 |
-
|
| 1552 |
-
|
| 1553 |
-
|
| 1554 |
-
|
|
|
|
|
|
|
|
|
|
| 1555 |
cute.arch.barrier_arrive(
|
| 1556 |
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
|
| 1557 |
number_of_threads=self.num_threads_per_warp_group
|
| 1558 |
+ cute.arch.WARP_SIZE,
|
| 1559 |
)
|
| 1560 |
|
| 1561 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1562 |
cute.arch.barrier(
|
| 1563 |
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1564 |
number_of_threads=self.num_threads_per_warp_group
|
|
@@ -1567,11 +1838,24 @@ class FlashAttentionBackwardSm90:
|
|
| 1567 |
with cute.arch.elect_one():
|
| 1568 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1569 |
sdQaccum[None, warp_group_idx].iterator,
|
| 1570 |
-
gdQaccum[None, warp_group_idx, m_block_safe].iterator,
|
| 1571 |
self.tma_copy_bytes["dQ"],
|
| 1572 |
)
|
| 1573 |
cute.arch.cp_async_bulk_commit_group()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1574 |
else:
|
|
|
|
|
|
|
|
|
|
| 1575 |
dQaccum_store_block_sparse_bwd_sm90(
|
| 1576 |
blocksparse_tensors,
|
| 1577 |
batch_idx,
|
|
@@ -1581,11 +1865,27 @@ class FlashAttentionBackwardSm90:
|
|
| 1581 |
gdQaccum,
|
| 1582 |
subtile_factor=self.subtile_factor,
|
| 1583 |
m_block_max=m_block_max,
|
| 1584 |
-
|
| 1585 |
num_threads_per_warp_group=self.num_threads_per_warp_group,
|
| 1586 |
tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
|
| 1587 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1588 |
tile_scheduler.advance_to_next_work()
|
| 1589 |
work_tile = tile_scheduler.get_current_work()
|
| 1590 |
|
| 1591 |
-
|
|
|
|
|
|
| 24 |
from .block_info import BlockInfo
|
| 25 |
from . import pipeline
|
| 26 |
from .quack.cute_dsl_utils import ParamsBase
|
| 27 |
+
from .tile_scheduler import (
|
| 28 |
+
TileSchedulerArguments,
|
| 29 |
+
SingleTileScheduler,
|
| 30 |
+
SingleTileLPTBwdScheduler,
|
| 31 |
+
SingleTileVarlenScheduler,
|
| 32 |
+
)
|
| 33 |
+
from . import barrier
|
| 34 |
from .named_barrier import NamedBarrierBwd
|
| 35 |
from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
|
| 36 |
from .block_sparsity import BlockSparseTensors
|
|
|
|
| 52 |
head_dim_v: Optional[int] = None,
|
| 53 |
qhead_per_kvhead: int = 1,
|
| 54 |
is_causal: bool = False,
|
| 55 |
+
is_local: bool = False,
|
| 56 |
+
deterministic: bool = False,
|
| 57 |
tile_m: int = 64,
|
| 58 |
tile_n: int = 128,
|
| 59 |
Q_stage: int = 2,
|
|
|
|
| 72 |
mask_mod: cutlass.Constexpr | None = None,
|
| 73 |
has_aux_tensors: cutlass.Constexpr = False,
|
| 74 |
subtile_factor: cutlass.Constexpr[int] = 1,
|
| 75 |
+
dQ_single_wg: bool = False,
|
| 76 |
):
|
| 77 |
self.dtype = dtype
|
| 78 |
# padding head_dim to a multiple of 16 as k_block_size
|
|
|
|
| 86 |
self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
|
| 87 |
self.qhead_per_kvhead = qhead_per_kvhead
|
| 88 |
self.is_causal = is_causal
|
| 89 |
+
self.is_local = is_local
|
| 90 |
+
self.deterministic = deterministic
|
| 91 |
self.tile_m = tile_m
|
| 92 |
self.tile_n = tile_n
|
| 93 |
self.num_threads = num_threads
|
|
|
|
| 102 |
self.AtomLayoutMSdP = AtomLayoutMSdP
|
| 103 |
self.AtomLayoutNdKV = AtomLayoutNdKV
|
| 104 |
self.AtomLayoutMdQ = AtomLayoutMdQ
|
| 105 |
+
self.num_wg_mma = (self.num_threads // 128) - 1
|
| 106 |
self.mma_dkv_is_rs = (
|
| 107 |
AtomLayoutMSdP == 1
|
| 108 |
+
and AtomLayoutNdKV == self.num_wg_mma
|
| 109 |
and SdP_swapAB
|
| 110 |
and not dKV_swapAB
|
| 111 |
)
|
| 112 |
self.V_in_regs = V_in_regs
|
| 113 |
+
# May be overridden in __call__ for varlen inputs.
|
| 114 |
if qhead_per_kvhead > 1:
|
| 115 |
assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
|
| 116 |
+
assert self.num_wg_mma == 2, "GQA backward assumes 2 warp groups"
|
| 117 |
# These are tuned for speed
|
| 118 |
# Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
|
| 119 |
# them and then shuffle to get the value whenever we need? This can reduce register
|
| 120 |
# pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
|
| 121 |
# rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
|
|
|
|
| 122 |
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
|
| 123 |
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
|
| 124 |
|
|
|
|
| 134 |
else:
|
| 135 |
self.vec_size: cutlass.Constexpr = 4
|
| 136 |
self.qk_acc_dtype = Float32
|
| 137 |
+
# dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it.
|
| 138 |
+
# Only valid for 2 MMA warp groups.
|
| 139 |
+
# Credit: Ben Spector
|
| 140 |
+
if dQ_single_wg:
|
| 141 |
+
assert self.num_wg_mma == 2, "dQ_single_wg only supports 2 warp groups"
|
| 142 |
+
self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma
|
| 143 |
|
| 144 |
@staticmethod
|
| 145 |
def can_implement(
|
|
|
|
| 198 |
assert mQ_type == self.dtype
|
| 199 |
|
| 200 |
def _setup_attributes(self):
|
| 201 |
+
# We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
|
| 202 |
+
# Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
|
| 203 |
+
# The M dimension (tile_m) doesn't matter for the layout, only the K dimension
|
| 204 |
+
wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV
|
| 205 |
+
self.sQ_layout, self.sdO_layout = [
|
| 206 |
+
# Need to set major_mode_size (mms) to accommodate Q and Q.T
|
| 207 |
+
sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms)
|
| 208 |
+
for shape, stage, mms in [
|
| 209 |
+
((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV),
|
| 210 |
+
((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV),
|
| 211 |
]
|
| 212 |
]
|
| 213 |
+
wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ
|
| 214 |
+
# Accomodate both K and K.T
|
| 215 |
+
self.sK_layout = sm90_utils.make_smem_layout(
|
| 216 |
+
self.dtype,
|
| 217 |
+
LayoutEnum.ROW_MAJOR,
|
| 218 |
+
(self.tile_n, self.tile_hdim),
|
| 219 |
+
stage=None,
|
| 220 |
+
major_mode_size=self.tile_hdim // wg_d_dQ,
|
| 221 |
+
)
|
| 222 |
+
# There's only V, no V.T, so layout is normal
|
| 223 |
+
self.sV_layout = sm90_utils.make_smem_layout(
|
| 224 |
+
self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None
|
| 225 |
+
)
|
| 226 |
+
# Accomodate both S and S.T
|
| 227 |
+
wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP
|
| 228 |
+
wg_n_dKV = self.AtomLayoutNdKV
|
| 229 |
+
self.sPdS_layout = sm90_utils.make_smem_layout(
|
| 230 |
+
self.dtype,
|
| 231 |
+
LayoutEnum.ROW_MAJOR,
|
| 232 |
+
(self.tile_m, self.tile_n),
|
| 233 |
+
stage=self.PdS_stage,
|
| 234 |
+
major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV),
|
| 235 |
+
)
|
| 236 |
self.sdQaccum_layout = cute.make_layout(
|
| 237 |
+
(self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)
|
| 238 |
)
|
| 239 |
# dQaccum R->S
|
| 240 |
self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
|
| 241 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 242 |
# thr_layout
|
| 243 |
+
cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)),
|
| 244 |
cute.make_layout(128 // Float32.width), # val_layout
|
| 245 |
)
|
| 246 |
# dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
|
| 247 |
# TODO: assert that sVaccum and sKaccum don't overflow smem
|
| 248 |
|
| 249 |
def _get_tiled_mma(self):
|
| 250 |
+
maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape
|
| 251 |
# S = Q @ K.T, dP = dO @ V.T
|
| 252 |
+
atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1)
|
| 253 |
tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
|
| 254 |
tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
|
| 255 |
self.dtype,
|
|
|
|
| 257 |
warpgroup.OperandMajorMode.K,
|
| 258 |
warpgroup.OperandMajorMode.K,
|
| 259 |
Float32,
|
| 260 |
+
atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB),
|
| 261 |
+
tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]),
|
|
|
|
| 262 |
)
|
| 263 |
# dV = P.T @ dO, dK = dS.T @ Q
|
| 264 |
+
atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1)
|
| 265 |
tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
|
| 266 |
tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
|
| 267 |
tiled_mma_dK, tiled_mma_dV = [
|
|
|
|
| 273 |
else warpgroup.OperandMajorMode.K,
|
| 274 |
warpgroup.OperandMajorMode.MN,
|
| 275 |
Float32,
|
| 276 |
+
atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB),
|
| 277 |
+
tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]),
|
|
|
|
| 278 |
a_source=warpgroup.OperandSource.RMEM
|
| 279 |
if self.mma_dkv_is_rs
|
| 280 |
else warpgroup.OperandSource.SMEM,
|
|
|
|
| 282 |
for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
|
| 283 |
]
|
| 284 |
# dQ = dS @ K
|
| 285 |
+
assert self.num_wg_dQ % self.AtomLayoutMdQ == 0
|
| 286 |
+
atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1)
|
| 287 |
tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
|
| 288 |
tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
|
| 289 |
self.dtype,
|
|
|
|
| 291 |
warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
|
| 292 |
warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
|
| 293 |
Float32,
|
| 294 |
+
atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB),
|
| 295 |
+
tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]),
|
| 296 |
)
|
| 297 |
return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
|
| 298 |
|
|
|
|
| 346 |
mdK: cute.Tensor,
|
| 347 |
mdV: cute.Tensor,
|
| 348 |
softmax_scale: Float32,
|
|
|
|
| 349 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 350 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 351 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
|
|
| 358 |
mdV_semaphore: Optional[cute.Tensor] = None,
|
| 359 |
aux_tensors: Optional[list] = None,
|
| 360 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 361 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 362 |
+
stream: cuda.CUstream = None,
|
| 363 |
):
|
| 364 |
+
# For GQA (qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV,
|
| 365 |
+
# so we need the float32 accum path + postprocess.
|
| 366 |
+
# For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors.
|
| 367 |
+
self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
|
| 368 |
|
| 369 |
self._check_type(
|
| 370 |
*(
|
|
|
|
| 373 |
)
|
| 374 |
)
|
| 375 |
|
| 376 |
+
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
|
| 377 |
+
|
| 378 |
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
|
| 379 |
assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
|
| 380 |
]
|
| 381 |
|
| 382 |
+
# Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h).
|
| 383 |
+
# We convert both to a seqlen-major view with head-dim second.
|
| 384 |
+
# Each tensor may have different rank when Q is padded (seqused_q) but K/V are unpadded (cu_seqlens_k).
|
| 385 |
+
def _qkv_transpose(t):
|
| 386 |
+
return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1])
|
| 387 |
+
|
| 388 |
+
mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)]
|
| 389 |
if const_expr(self.qhead_per_kvhead == 1):
|
| 390 |
+
mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)]
|
| 391 |
else:
|
| 392 |
+
# Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen.
|
| 393 |
+
accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0]
|
| 394 |
mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
|
| 395 |
+
# Non-varlen stats are (b, n, s), varlen stats are (n, s).
|
| 396 |
+
LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0]
|
| 397 |
mLSE, mdPsum, mdQaccum = [
|
| 398 |
layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
|
| 399 |
]
|
| 400 |
|
| 401 |
tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()
|
| 402 |
+
# (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch)
|
| 403 |
+
if const_expr(self.deterministic):
|
| 404 |
+
assert mdQ_semaphore is not None
|
| 405 |
+
mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0])
|
| 406 |
|
| 407 |
self.num_mma_threads = tiled_mma_SdP.size
|
| 408 |
assert self.num_mma_threads + 128 == self.num_threads
|
|
|
|
| 410 |
self.num_threads_per_warp_group = 128
|
| 411 |
self.num_producer_threads = 32
|
| 412 |
|
| 413 |
+
REG_LIMIT = 504 if self.num_wg_mma == 2 else 512
|
| 414 |
+
if const_expr(self.num_wg_mma == 2):
|
| 415 |
+
if const_expr(self.num_wg_dQ == 1):
|
| 416 |
+
self.num_mma_regs_wg0 = 256
|
| 417 |
+
self.num_mma_regs_wg1 = 224
|
| 418 |
+
else:
|
| 419 |
+
self.num_mma_regs_wg0 = 240
|
| 420 |
+
self.num_mma_regs_wg1 = 240
|
| 421 |
+
self.num_mma_regs = self.num_mma_regs_wg0 # for backward compat
|
| 422 |
+
self.num_producer_regs = 24
|
| 423 |
+
assert (
|
| 424 |
+
self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT
|
| 425 |
+
)
|
| 426 |
+
else: # 3 warp groups
|
| 427 |
+
self.num_mma_regs_wg0 = 160
|
| 428 |
+
self.num_mma_regs_wg1 = 160
|
| 429 |
+
self.num_mma_regs = 160
|
| 430 |
+
self.num_producer_regs = 32
|
| 431 |
+
assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT
|
| 432 |
|
| 433 |
self._setup_attributes()
|
| 434 |
SharedStorage = self._get_shared_storage_cls()
|
|
|
|
| 445 |
self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
|
| 446 |
self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
|
| 447 |
self.tma_copy_bytes["dQ"] = (
|
| 448 |
+
self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ
|
| 449 |
)
|
| 450 |
self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
|
| 451 |
self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8
|
|
|
|
| 475 |
(self.tile_m, self.tile_hdimv),
|
| 476 |
)
|
| 477 |
if const_expr(self.qhead_per_kvhead == 1):
|
| 478 |
+
mdK_tma = (
|
| 479 |
+
copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True)
|
| 480 |
+
if self.varlen_k
|
| 481 |
+
else mdK
|
| 482 |
+
)
|
| 483 |
+
mdV_tma = (
|
| 484 |
+
copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True)
|
| 485 |
+
if self.varlen_k
|
| 486 |
+
else mdV
|
| 487 |
+
)
|
| 488 |
tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
|
| 489 |
cpasync.CopyBulkTensorTileS2GOp(),
|
| 490 |
+
mdK_tma,
|
| 491 |
cute.select(self.sK_layout, mode=[0, 1]),
|
| 492 |
(self.tile_n, self.tile_hdim),
|
| 493 |
)
|
| 494 |
tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
|
| 495 |
cpasync.CopyBulkTensorTileS2GOp(),
|
| 496 |
+
mdV_tma,
|
| 497 |
cute.select(self.sV_layout, mode=[0, 1]),
|
| 498 |
(self.tile_n, self.tile_hdimv),
|
| 499 |
)
|
| 500 |
else:
|
| 501 |
tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None
|
| 502 |
|
| 503 |
+
if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None):
|
| 504 |
+
TileScheduler = SingleTileVarlenScheduler
|
| 505 |
+
elif const_expr(self.deterministic):
|
| 506 |
+
TileScheduler = SingleTileLPTBwdScheduler
|
| 507 |
+
else:
|
| 508 |
+
TileScheduler = SingleTileScheduler
|
| 509 |
+
self.spt = (self.is_causal or self.is_local) and self.deterministic
|
| 510 |
tile_sched_args = TileSchedulerArguments(
|
| 511 |
cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
|
| 512 |
cute.size(mQ.shape[2]),
|
| 513 |
+
cute.size(mK.shape[3])
|
| 514 |
+
if const_expr(mCuSeqlensK is None)
|
| 515 |
+
else cute.size(mCuSeqlensK.shape[0] - 1), # num_batch
|
| 516 |
1, # num_splits
|
| 517 |
+
cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
|
| 518 |
+
mQ.shape[1], # headdim
|
| 519 |
+
mV.shape[1], # headdim_v
|
| 520 |
+
total_q=cute.size(mK.shape[0])
|
| 521 |
+
if const_expr(mCuSeqlensK is not None)
|
| 522 |
+
else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
|
| 523 |
+
tile_shape_mn=(self.tile_n, self.tile_m), # Swapping the role of Q & K
|
| 524 |
+
mCuSeqlensQ=mCuSeqlensK,
|
| 525 |
+
mSeqUsedQ=mSeqUsedK,
|
| 526 |
qhead_per_kvhead_packgqa=1,
|
| 527 |
element_size=self.dtype.width // 8,
|
| 528 |
is_persistent=False,
|
| 529 |
+
lpt=self.spt,
|
| 530 |
+
head_swizzle=self.deterministic,
|
| 531 |
)
|
| 532 |
|
| 533 |
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
|
|
|
| 553 |
|
| 554 |
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
| 555 |
|
| 556 |
+
if const_expr(window_size_left is not None):
|
| 557 |
+
window_size_left = Int32(window_size_left)
|
| 558 |
+
if const_expr(window_size_right is not None):
|
| 559 |
+
window_size_right = Int32(window_size_right)
|
| 560 |
+
|
| 561 |
self.kernel(
|
| 562 |
tma_tensor_Q,
|
| 563 |
tma_tensor_K,
|
|
|
|
| 574 |
mLSE,
|
| 575 |
mdPsum,
|
| 576 |
mdQaccum,
|
| 577 |
+
mCuSeqlensQ,
|
| 578 |
+
mCuSeqlensK,
|
| 579 |
+
mSeqUsedQ,
|
| 580 |
+
mSeqUsedK,
|
| 581 |
self.sQ_layout,
|
| 582 |
self.sK_layout,
|
| 583 |
self.sV_layout,
|
|
|
|
| 598 |
fastdiv_mods,
|
| 599 |
blocksparse_tensors,
|
| 600 |
qhead_per_kvhead_divmod,
|
| 601 |
+
mdQ_semaphore,
|
| 602 |
+
window_size_left,
|
| 603 |
+
window_size_right,
|
| 604 |
).launch(
|
| 605 |
grid=grid_dim,
|
| 606 |
block=[self.num_threads, 1, 1],
|
| 607 |
stream=stream,
|
| 608 |
min_blocks_per_mp=1,
|
| 609 |
+
use_pdl=True,
|
| 610 |
)
|
| 611 |
|
| 612 |
@cute.kernel
|
|
|
|
| 627 |
mLSE: cute.Tensor,
|
| 628 |
mdPsum: cute.Tensor,
|
| 629 |
mdQaccum: cute.Tensor,
|
| 630 |
+
mCuSeqlensQ: Optional[cute.Tensor],
|
| 631 |
+
mCuSeqlensK: Optional[cute.Tensor],
|
| 632 |
+
mSeqUsedQ: Optional[cute.Tensor],
|
| 633 |
+
mSeqUsedK: Optional[cute.Tensor],
|
| 634 |
sQ_layout: cute.ComposedLayout,
|
| 635 |
sK_layout: cute.ComposedLayout,
|
| 636 |
sV_layout: cute.ComposedLayout,
|
|
|
|
| 651 |
fastdiv_mods=(None, None),
|
| 652 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 653 |
qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
|
| 654 |
+
mdQ_semaphore: Optional[cute.Tensor] = None,
|
| 655 |
+
window_size_left: Optional[Int32] = None,
|
| 656 |
+
window_size_right: Optional[Int32] = None,
|
| 657 |
):
|
| 658 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 659 |
|
| 660 |
# prefetch TMA descriptors
|
| 661 |
if warp_idx == 0:
|
| 662 |
+
for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]:
|
| 663 |
+
if const_expr(atom is not None):
|
| 664 |
+
cpasync.prefetch_descriptor(atom)
|
|
|
|
| 665 |
|
| 666 |
smem = cutlass.utils.SmemAllocator()
|
| 667 |
storage = smem.allocate(SharedStorage)
|
|
|
|
| 715 |
self.is_causal,
|
| 716 |
self.is_local,
|
| 717 |
False, # is_split_kv
|
| 718 |
+
window_size_left,
|
| 719 |
+
window_size_right,
|
| 720 |
qhead_per_kvhead_packgqa=1,
|
| 721 |
)
|
| 722 |
SeqlenInfoCls = partial(
|
| 723 |
SeqlenInfoQK.create,
|
| 724 |
seqlen_q_static=mQ.shape[0],
|
| 725 |
seqlen_k_static=mK.shape[0],
|
| 726 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 727 |
+
mCuSeqlensK=mCuSeqlensK,
|
| 728 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 729 |
+
mSeqUsedK=mSeqUsedK,
|
| 730 |
+
tile_m=self.tile_m,
|
| 731 |
+
tile_n=self.tile_n,
|
| 732 |
)
|
| 733 |
AttentionMaskCls = partial(
|
| 734 |
AttentionMask,
|
| 735 |
self.tile_m,
|
| 736 |
self.tile_n,
|
| 737 |
+
window_size_left=window_size_left,
|
| 738 |
+
window_size_right=window_size_right,
|
| 739 |
swap_AB=self.SdP_swapAB,
|
| 740 |
)
|
| 741 |
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
|
|
|
| 776 |
TileSchedulerCls,
|
| 777 |
SeqlenInfoCls,
|
| 778 |
blocksparse_tensors,
|
| 779 |
+
mdQ_semaphore,
|
| 780 |
)
|
| 781 |
else:
|
|
|
|
| 782 |
tidx, _, _ = cute.arch.thread_idx()
|
| 783 |
tidx = tidx - 128
|
| 784 |
+
mma_args = (
|
| 785 |
tiled_mma_SdP,
|
| 786 |
tiled_mma_dK,
|
| 787 |
tiled_mma_dV,
|
|
|
|
| 815 |
blocksparse_tensors,
|
| 816 |
qhead_per_kvhead_divmod,
|
| 817 |
)
|
| 818 |
+
if const_expr(self.num_wg_dQ == self.num_wg_mma):
|
| 819 |
+
# Both WGs compute dQ
|
| 820 |
+
cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
|
| 821 |
+
self.mma(*mma_args, is_dQ_wg=True)
|
| 822 |
+
else:
|
| 823 |
+
# WG0 computes dQ, WG1 skips it
|
| 824 |
+
warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4
|
| 825 |
+
if warp_idx_in_mma < 4:
|
| 826 |
+
cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
|
| 827 |
+
self.mma(*mma_args, is_dQ_wg=True)
|
| 828 |
+
else:
|
| 829 |
+
cute.arch.setmaxregister_increase(self.num_mma_regs_wg1)
|
| 830 |
+
self.mma(*mma_args, is_dQ_wg=False)
|
| 831 |
|
| 832 |
@cute.jit
|
| 833 |
def load(
|
|
|
|
| 875 |
if const_expr(self.qhead_per_kvhead == 1)
|
| 876 |
else head_idx // qhead_per_kvhead_divmod
|
| 877 |
)
|
| 878 |
+
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
|
| 879 |
+
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
|
| 880 |
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
|
|
|
| 881 |
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
| 882 |
|
| 883 |
+
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
| 884 |
+
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[
|
| 885 |
+
None, head_idx
|
| 886 |
+
]
|
| 887 |
+
mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx]
|
| 888 |
+
mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
|
| 889 |
+
None, head_idx
|
| 890 |
+
]
|
| 891 |
gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
|
|
|
|
| 892 |
gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
|
|
|
|
| 893 |
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
|
|
|
|
| 894 |
gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
|
| 895 |
|
| 896 |
load_K, _, _ = copy_utils.tma_get_copy_fn(
|
|
|
|
| 916 |
|
| 917 |
if const_expr(not self.use_block_sparsity):
|
| 918 |
total_m_block_cnt = m_block_max - m_block_min
|
| 919 |
+
process_tile = (
|
| 920 |
+
const_expr(not self.is_local and not self.is_varlen_q)
|
| 921 |
+
or m_block_min < m_block_max
|
| 922 |
+
)
|
| 923 |
else:
|
| 924 |
total_m_block_cnt = get_total_q_block_count_bwd(
|
| 925 |
blocksparse_tensors,
|
|
|
|
| 939 |
)
|
| 940 |
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
|
| 941 |
load_Q(first_m_block, producer_state=producer_state_Q)
|
| 942 |
+
# Wait for bwd preprocess to finish writing LSE and dPsum
|
| 943 |
+
cute.arch.griddepcontrol_wait()
|
| 944 |
load_LSE(first_m_block, producer_state=producer_state_Q)
|
| 945 |
producer_state_dO_cur = (
|
| 946 |
producer_state_dO
|
|
|
|
| 1119 |
fastdiv_mods=(None, None),
|
| 1120 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 1121 |
qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
|
| 1122 |
+
is_dQ_wg: cutlass.Constexpr[bool] = True,
|
| 1123 |
):
|
| 1124 |
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 1125 |
warp_group_thread_layout = cute.make_layout(
|
| 1126 |
+
self.num_wg_mma, stride=self.num_threads_per_warp_group
|
| 1127 |
)
|
| 1128 |
thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
|
| 1129 |
wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 1130 |
wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 1131 |
wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 1132 |
+
wg_mma_dQ = None
|
| 1133 |
+
if const_expr(is_dQ_wg):
|
| 1134 |
+
wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ > 1) else 0
|
| 1135 |
+
wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ))
|
| 1136 |
# S = Q @ K.T
|
| 1137 |
shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
|
| 1138 |
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
|
|
|
|
| 1178 |
# dQ = dS @ K
|
| 1179 |
sKt = layout_utils.transpose_view(sK)
|
| 1180 |
shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
|
| 1181 |
+
mma_dsk_fn = None
|
| 1182 |
+
if const_expr(is_dQ_wg):
|
| 1183 |
+
_, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
|
| 1184 |
+
wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
|
| 1185 |
+
)
|
| 1186 |
+
mma_dsk_fn = partial(
|
| 1187 |
+
gemm_zero_init,
|
| 1188 |
+
tiled_mma_dQ,
|
| 1189 |
+
shape_mnk_dQ[:2],
|
| 1190 |
+
tdQrdS,
|
| 1191 |
+
tdQrKt,
|
| 1192 |
+
swap_AB=self.dQ_swapAB,
|
| 1193 |
+
)
|
| 1194 |
|
| 1195 |
+
# Smem copy atom tiling for P/dS R2S
|
| 1196 |
copy_P_r2s = None
|
| 1197 |
+
mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP)
|
| 1198 |
if const_expr(sP is not None):
|
| 1199 |
sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
|
| 1200 |
copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1201 |
+
tiled_mma_SdP,
|
| 1202 |
+
sP_cpy,
|
| 1203 |
+
tidx,
|
| 1204 |
+
self.arch,
|
| 1205 |
+
transpose=self.SdP_swapAB,
|
| 1206 |
+
position_independent=True,
|
| 1207 |
+
major_mode_size=mms_PdS,
|
| 1208 |
)
|
| 1209 |
sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
|
| 1210 |
copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1211 |
+
tiled_mma_SdP,
|
| 1212 |
+
sdS_cpy,
|
| 1213 |
+
tidx,
|
| 1214 |
+
self.arch,
|
| 1215 |
+
transpose=self.SdP_swapAB,
|
| 1216 |
+
position_independent=True,
|
| 1217 |
+
major_mode_size=mms_PdS,
|
| 1218 |
)
|
| 1219 |
|
| 1220 |
tLSEsLSE = layout_utils.mma_partition_C_vec(
|
|
|
|
| 1223 |
tLSEsdPsum = layout_utils.mma_partition_C_vec(
|
| 1224 |
sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
|
| 1225 |
)
|
| 1226 |
+
# When shuffle=True, rows are distributed across 8 quads (4 threads each) within a warp.
|
| 1227 |
+
# Each thread loads only ceil(num_rows/8) values;
|
| 1228 |
+
shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2)
|
| 1229 |
+
if const_expr(self.shuffle_LSE):
|
| 1230 |
+
tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE)
|
| 1231 |
+
# ((2, 1), 1, 2) -> (((2, 1), 1), 2)
|
| 1232 |
+
tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2)
|
| 1233 |
+
if const_expr(self.shuffle_dPsum):
|
| 1234 |
+
tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum)
|
| 1235 |
+
tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2)
|
| 1236 |
+
|
| 1237 |
+
tdQsdQaccum = None
|
| 1238 |
+
if const_expr(is_dQ_wg):
|
| 1239 |
+
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
|
| 1240 |
+
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
|
| 1241 |
|
| 1242 |
PdS_barrier = cutlass.pipeline.NamedBarrier(
|
| 1243 |
barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
|
|
|
|
| 1276 |
PdS_barrier=PdS_barrier,
|
| 1277 |
# acc_dV=acc_dV,
|
| 1278 |
# acc_dK=acc_dK,
|
| 1279 |
+
is_dQ_wg=is_dQ_wg,
|
| 1280 |
)
|
| 1281 |
|
| 1282 |
consumer_state_Q = cutlass.pipeline.make_pipeline_state(
|
|
|
|
| 1308 |
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
|
| 1309 |
|
| 1310 |
if const_expr(not self.use_block_sparsity):
|
| 1311 |
+
process_tile = (
|
| 1312 |
+
const_expr(not self.is_local and not self.is_varlen_q)
|
| 1313 |
+
or m_block_min < m_block_max
|
| 1314 |
+
)
|
| 1315 |
else:
|
| 1316 |
total_m_block_cnt = get_total_q_block_count_bwd(
|
| 1317 |
blocksparse_tensors,
|
|
|
|
| 1393 |
qhead_per_kvhead_divmod,
|
| 1394 |
)
|
| 1395 |
else:
|
| 1396 |
+
# KV tile with zero Q blocks produces no dK/dV; write zeros.
|
| 1397 |
+
if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q):
|
| 1398 |
acc_dK.fill(0.0)
|
| 1399 |
acc_dV.fill(0.0)
|
| 1400 |
self.epilogue_dKV(
|
|
|
|
| 1423 |
if warp_idx == 4:
|
| 1424 |
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
| 1425 |
|
| 1426 |
+
@staticmethod
|
| 1427 |
+
@cute.jit
|
| 1428 |
+
def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32:
|
| 1429 |
+
"""Retrieve the statistic for a given accumulator row.
|
| 1430 |
+
|
| 1431 |
+
When shuffle=False, direct register indexing.
|
| 1432 |
+
When shuffle=True, warp shuffle from the thread group that holds the value.
|
| 1433 |
+
"""
|
| 1434 |
+
if const_expr(not shuffle):
|
| 1435 |
+
return tSrS[row]
|
| 1436 |
+
# tSrS: (((2, 1), 1), 1)), distributed across 8 threads in the warp
|
| 1437 |
+
vecsize = cute.size(tSrS, mode=[0, 0]) # 2
|
| 1438 |
+
idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1])))
|
| 1439 |
+
# register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ...
|
| 1440 |
+
return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4))
|
| 1441 |
+
|
| 1442 |
@cute.jit
|
| 1443 |
def mma_one_m_block(
|
| 1444 |
self,
|
|
|
|
| 1457 |
pipeline_dO: cutlass.pipeline.PipelineAsync,
|
| 1458 |
tLSEsLSE: cute.Tensor,
|
| 1459 |
tLSEsdPsum: cute.Tensor,
|
| 1460 |
+
tdQsdQaccum: Optional[cute.Tensor],
|
| 1461 |
softmax_scale_log2: Float32,
|
| 1462 |
PdS_barrier: cutlass.pipeline.NamedBarrier,
|
| 1463 |
+
is_dQ_wg: cutlass.Constexpr[bool] = True,
|
| 1464 |
mask_fn: Optional[Callable] = None,
|
| 1465 |
score_mod_fn: Optional[Callable] = None,
|
| 1466 |
score_mod_bwd_fn: Optional[Callable] = None,
|
| 1467 |
dKV_accumulate: Boolean = True,
|
| 1468 |
):
|
| 1469 |
consumer_state_dO_cur = (
|
| 1470 |
+
consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO
|
| 1471 |
)
|
| 1472 |
smem_idx_Q = consumer_state_Q.index
|
| 1473 |
smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
|
|
|
|
| 1475 |
# (1) [GEMM 1] S = Q @ K^T
|
| 1476 |
pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
|
| 1477 |
acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
|
| 1478 |
+
# If shuffle_LSE, OOB reads are OK since sLSE is already padded
|
| 1479 |
tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
|
| 1480 |
# (2) [GEMM 2] dP = dO @ V.T
|
| 1481 |
pipeline_dO.consumer_wait(
|
|
|
|
| 1494 |
if cutlass.const_expr(mask_fn is not None):
|
| 1495 |
mask_fn(acc_S, m_block=m_block)
|
| 1496 |
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
|
| 1497 |
+
lane_idx = cute.arch.lane_idx()
|
| 1498 |
for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
|
| 1499 |
+
lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE)
|
| 1500 |
for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
|
| 1501 |
acc_S_mn[r, c] = cute.math.exp2(
|
| 1502 |
+
acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True
|
| 1503 |
)
|
| 1504 |
tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
|
| 1505 |
|
|
|
|
| 1516 |
warpgroup.wait_group(0)
|
| 1517 |
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
|
| 1518 |
for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
|
| 1519 |
+
dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum)
|
| 1520 |
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
|
| 1521 |
+
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val)
|
| 1522 |
|
| 1523 |
if const_expr(self.score_mod_bwd is not None):
|
| 1524 |
score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
|
|
|
|
| 1550 |
# smem fence to make sure sdS is written before it's read by WGMMA
|
| 1551 |
cute.arch.fence_view_async_shared()
|
| 1552 |
PdS_barrier.arrive_and_wait()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1553 |
|
| 1554 |
+
if const_expr(is_dQ_wg):
|
| 1555 |
+
# (6) [GEMM 4] dQ = dS @ K
|
| 1556 |
+
acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
|
| 1557 |
+
pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1558 |
|
| 1559 |
+
# (7) [GEMM 5] dK += dS.T @ Q
|
| 1560 |
+
if const_expr(not self.mma_dkv_is_rs):
|
| 1561 |
+
mma_dsq_fn(
|
| 1562 |
+
A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
|
| 1563 |
+
)
|
| 1564 |
+
else:
|
| 1565 |
+
mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
|
| 1567 |
+
# dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem
|
| 1568 |
+
# When dQ_single_wg, only WG0 enters here so warp_group_idx == 0
|
| 1569 |
+
cute.arch.barrier(
|
| 1570 |
+
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
|
| 1571 |
+
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
| 1572 |
+
)
|
| 1573 |
+
tdQrdQaccum_flat = cute.make_tensor(
|
| 1574 |
+
acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)
|
| 1575 |
+
)
|
| 1576 |
+
cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
|
| 1577 |
+
cute.arch.fence_view_async_shared()
|
| 1578 |
+
cute.arch.barrier_arrive(
|
| 1579 |
+
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1580 |
+
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
|
| 1581 |
+
)
|
| 1582 |
+
|
| 1583 |
+
warpgroup.wait_group(0)
|
| 1584 |
+
pipeline_Q.consumer_release(consumer_state_Q)
|
| 1585 |
+
else:
|
| 1586 |
+
# dQ_single_wg: WG1 skips dQ, only does dV wait + dK
|
| 1587 |
+
# (7) [GEMM 5] dK += dS.T @ Q
|
| 1588 |
+
if const_expr(not self.mma_dkv_is_rs):
|
| 1589 |
+
mma_dsq_fn(
|
| 1590 |
+
A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
|
| 1591 |
+
)
|
| 1592 |
+
else:
|
| 1593 |
+
mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
|
| 1594 |
+
pipeline_dO.consumer_release(consumer_state_dO_cur)
|
| 1595 |
+
warpgroup.wait_group(0)
|
| 1596 |
+
pipeline_Q.consumer_release(consumer_state_Q)
|
| 1597 |
|
| 1598 |
consumer_state_Q.advance()
|
| 1599 |
consumer_state_dO.advance()
|
|
|
|
| 1625 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 1626 |
|
| 1627 |
if const_expr(self.qhead_per_kvhead == 1):
|
| 1628 |
+
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[
|
| 1629 |
+
None, None, head_idx
|
| 1630 |
+
]
|
| 1631 |
+
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[
|
| 1632 |
+
None, None, head_idx
|
| 1633 |
+
]
|
| 1634 |
gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
| 1635 |
gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
| 1636 |
store_dK, _, _ = copy_utils.tma_get_copy_fn(
|
|
|
|
| 1642 |
sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
|
| 1643 |
sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
|
| 1644 |
copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1645 |
+
tiled_mma_dV,
|
| 1646 |
+
sdV,
|
| 1647 |
+
tidx,
|
| 1648 |
+
self.arch,
|
| 1649 |
+
transpose=self.dKV_swapAB,
|
| 1650 |
+
position_independent=True,
|
| 1651 |
)
|
| 1652 |
copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
|
| 1653 |
+
tiled_mma_dK,
|
| 1654 |
+
sdK,
|
| 1655 |
+
tidx,
|
| 1656 |
+
self.arch,
|
| 1657 |
+
transpose=self.dKV_swapAB,
|
| 1658 |
+
position_independent=True,
|
| 1659 |
)
|
| 1660 |
cute.arch.cp_async_bulk_wait_group(1, read=True)
|
| 1661 |
epi_barrier.arrive_and_wait()
|
|
|
|
| 1674 |
store_dK()
|
| 1675 |
cute.arch.cp_async_bulk_commit_group()
|
| 1676 |
else:
|
| 1677 |
+
sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma
|
| 1678 |
+
sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma
|
| 1679 |
+
sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma))
|
| 1680 |
+
sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma))
|
| 1681 |
head_idx_kv = head_idx // qhead_per_kvhead_divmod
|
| 1682 |
+
mdKaccum_cur = seqlen.offset_batch_K(
|
| 1683 |
+
mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim
|
| 1684 |
+
)[None, head_idx_kv]
|
| 1685 |
+
mdVaccum_cur = seqlen.offset_batch_K(
|
| 1686 |
+
mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv
|
| 1687 |
+
)[None, head_idx_kv]
|
| 1688 |
gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
|
| 1689 |
gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
|
|
|
|
| 1690 |
gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
|
| 1691 |
gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
|
| 1692 |
# These two overlap each other
|
|
|
|
| 1695 |
sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
|
| 1696 |
tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
|
| 1697 |
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
|
| 1698 |
+
cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)),
|
| 1699 |
cute.make_layout(128 // Float32.width),
|
| 1700 |
)
|
| 1701 |
thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
|
|
|
|
| 1710 |
epi_barrier.arrive_and_wait()
|
| 1711 |
if warp_idx == 4:
|
| 1712 |
with cute.arch.elect_one():
|
| 1713 |
+
for wg_idx in cutlass.range_constexpr(self.num_wg_mma):
|
| 1714 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1715 |
sdKaccum[None, wg_idx].iterator,
|
| 1716 |
gdKaccum[None, wg_idx].iterator,
|
| 1717 |
+
self.tma_copy_bytes["dKacc"] // self.num_wg_mma,
|
| 1718 |
)
|
| 1719 |
cute.arch.cp_async_bulk_commit_group()
|
| 1720 |
|
|
|
|
| 1726 |
epi_barrier.arrive_and_wait()
|
| 1727 |
if warp_idx == 4:
|
| 1728 |
with cute.arch.elect_one():
|
| 1729 |
+
for wg_idx in cutlass.range_constexpr(self.num_wg_mma):
|
| 1730 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1731 |
sdVaccum[None, wg_idx].iterator,
|
| 1732 |
gdVaccum[None, wg_idx].iterator,
|
| 1733 |
+
self.tma_copy_bytes["dVacc"] // self.num_wg_mma,
|
| 1734 |
)
|
| 1735 |
cute.arch.cp_async_bulk_commit_group()
|
| 1736 |
|
|
|
|
| 1743 |
TileSchedulerCls: cutlass.Constexpr[Callable],
|
| 1744 |
SeqlenInfoCls: cutlass.Constexpr[Callable],
|
| 1745 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 1746 |
+
mdQ_semaphore: Optional[cute.Tensor] = None,
|
| 1747 |
):
|
| 1748 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 1749 |
+
# warp-local thread index (dQaccum_store runs on warp 1, global tidx 32-63)
|
| 1750 |
+
warp_local_tidx = tidx % cute.arch.WARP_SIZE
|
| 1751 |
+
read_flag = const_expr(not self.deterministic)
|
| 1752 |
+
|
| 1753 |
tile_scheduler = TileSchedulerCls()
|
| 1754 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1755 |
while work_tile.is_valid_tile:
|
| 1756 |
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 1757 |
seqlen = SeqlenInfoCls(batch_idx)
|
| 1758 |
+
if const_expr(not seqlen.has_cu_seqlens_q):
|
| 1759 |
+
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
|
| 1760 |
+
else:
|
| 1761 |
+
mdQaccum_cur = cute.domain_offset(
|
| 1762 |
+
(seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
|
| 1763 |
+
)
|
| 1764 |
+
# ((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks)
|
| 1765 |
+
gdQaccum = cute.local_tile(
|
| 1766 |
+
mdQaccum_cur,
|
| 1767 |
+
(
|
| 1768 |
+
cute.make_layout(
|
| 1769 |
+
(self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)
|
| 1770 |
+
),
|
| 1771 |
+
),
|
| 1772 |
+
(None,),
|
| 1773 |
)
|
| 1774 |
+
|
| 1775 |
+
if const_expr(mdQ_semaphore is not None):
|
| 1776 |
+
# mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose
|
| 1777 |
+
mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]
|
| 1778 |
+
|
| 1779 |
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
|
| 1780 |
if const_expr(not self.use_block_sparsity):
|
| 1781 |
+
process_tile = (
|
| 1782 |
+
const_expr(not self.is_local and not self.is_varlen_q)
|
| 1783 |
+
or m_block_min < m_block_max
|
| 1784 |
+
)
|
| 1785 |
loop_count = m_block_max - m_block_min
|
| 1786 |
else:
|
| 1787 |
total_block_cnt = get_total_q_block_count_bwd(
|
|
|
|
| 1800 |
m_block = m_block_min + iter_idx
|
| 1801 |
m_block_safe = m_block
|
| 1802 |
|
| 1803 |
+
num_dQ_chunks = self.num_wg_dQ
|
| 1804 |
+
for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):
|
| 1805 |
+
if const_expr(not self.deterministic):
|
| 1806 |
+
# If deterministic, we already waited at the end of the prev iter
|
| 1807 |
+
cute.arch.cp_async_bulk_wait_group(
|
| 1808 |
+
num_dQ_chunks - 1 - warp_group_idx, read=read_flag
|
| 1809 |
+
)
|
| 1810 |
cute.arch.barrier_arrive(
|
| 1811 |
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
|
| 1812 |
number_of_threads=self.num_threads_per_warp_group
|
| 1813 |
+ cute.arch.WARP_SIZE,
|
| 1814 |
)
|
| 1815 |
|
| 1816 |
+
# Semaphore acquire: wait for prior n_blocks to finish writing this m_block
|
| 1817 |
+
if const_expr(self.deterministic):
|
| 1818 |
+
if const_expr(self.spt):
|
| 1819 |
+
_, n_block_max_for_m_block = block_info.get_n_block_min_max(
|
| 1820 |
+
seqlen, m_block_safe
|
| 1821 |
+
)
|
| 1822 |
+
lock_value = n_block_max_for_m_block - 1 - n_block
|
| 1823 |
+
else:
|
| 1824 |
+
lock_value = n_block
|
| 1825 |
+
barrier.wait_eq(
|
| 1826 |
+
mdQ_semaphore_cur[(m_block_safe, None)].iterator,
|
| 1827 |
+
warp_local_tidx,
|
| 1828 |
+
0, # flag_offset
|
| 1829 |
+
lock_value,
|
| 1830 |
+
)
|
| 1831 |
+
|
| 1832 |
+
for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):
|
| 1833 |
cute.arch.barrier(
|
| 1834 |
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
|
| 1835 |
number_of_threads=self.num_threads_per_warp_group
|
|
|
|
| 1838 |
with cute.arch.elect_one():
|
| 1839 |
copy_utils.cpasync_reduce_bulk_add_f32(
|
| 1840 |
sdQaccum[None, warp_group_idx].iterator,
|
| 1841 |
+
gdQaccum[(None, warp_group_idx), m_block_safe].iterator,
|
| 1842 |
self.tma_copy_bytes["dQ"],
|
| 1843 |
)
|
| 1844 |
cute.arch.cp_async_bulk_commit_group()
|
| 1845 |
+
|
| 1846 |
+
# Semaphore release: signal that this n_block is done with this m_block
|
| 1847 |
+
if const_expr(self.deterministic):
|
| 1848 |
+
cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
|
| 1849 |
+
barrier.arrive_inc(
|
| 1850 |
+
mdQ_semaphore_cur[(m_block_safe, None)].iterator,
|
| 1851 |
+
warp_local_tidx,
|
| 1852 |
+
0, # flag_offset
|
| 1853 |
+
1,
|
| 1854 |
+
)
|
| 1855 |
else:
|
| 1856 |
+
assert not self.deterministic, (
|
| 1857 |
+
"Deterministic not implemented for block-sparse backward"
|
| 1858 |
+
)
|
| 1859 |
dQaccum_store_block_sparse_bwd_sm90(
|
| 1860 |
blocksparse_tensors,
|
| 1861 |
batch_idx,
|
|
|
|
| 1865 |
gdQaccum,
|
| 1866 |
subtile_factor=self.subtile_factor,
|
| 1867 |
m_block_max=m_block_max,
|
| 1868 |
+
num_dQ_warp_groups=self.num_wg_dQ,
|
| 1869 |
num_threads_per_warp_group=self.num_threads_per_warp_group,
|
| 1870 |
tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
|
| 1871 |
)
|
| 1872 |
+
|
| 1873 |
+
# For local masking + deterministic (non-spt): signal remaining m_blocks
|
| 1874 |
+
# that this n_block won't visit, so they don't deadlock waiting.
|
| 1875 |
+
if const_expr(
|
| 1876 |
+
self.deterministic and not self.spt and block_info.window_size_left is not None
|
| 1877 |
+
):
|
| 1878 |
+
m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)
|
| 1879 |
+
for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):
|
| 1880 |
+
barrier.arrive_inc(
|
| 1881 |
+
mdQ_semaphore_cur[(m_block, None)].iterator,
|
| 1882 |
+
warp_local_tidx,
|
| 1883 |
+
0, # flag_offset
|
| 1884 |
+
1,
|
| 1885 |
+
)
|
| 1886 |
+
|
| 1887 |
tile_scheduler.advance_to_next_work()
|
| 1888 |
work_tile = tile_scheduler.get_current_work()
|
| 1889 |
|
| 1890 |
+
if const_expr(not self.deterministic):
|
| 1891 |
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
build/torch-cuda/flash_fwd.py
CHANGED
|
@@ -15,42 +15,28 @@ import cuda.bindings.driver as cuda
|
|
| 15 |
import cutlass
|
| 16 |
import cutlass.cute as cute
|
| 17 |
from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
|
| 18 |
-
from cutlass.cute.nvgpu import cpasync, warp
|
| 19 |
import cutlass.utils as utils_basic
|
| 20 |
-
from cutlass.
|
| 21 |
-
|
| 22 |
|
| 23 |
from .quack import copy_utils
|
| 24 |
from .quack import layout_utils
|
| 25 |
-
from .quack import sm90_utils
|
| 26 |
|
| 27 |
from . import ampere_helpers as sm80_utils
|
| 28 |
from .cute_dsl_utils import assume_tensor_aligned
|
| 29 |
from . import utils
|
| 30 |
from .mask import AttentionMask
|
| 31 |
-
from .softmax import Softmax
|
| 32 |
from .seqlen_info import SeqlenInfoQK
|
| 33 |
from .block_info import BlockInfo
|
| 34 |
-
from .block_sparsity import BlockSparseTensors
|
| 35 |
-
from .block_sparse_utils import (
|
| 36 |
-
produce_block_sparse_loads,
|
| 37 |
-
consume_block_sparse_loads,
|
| 38 |
-
)
|
| 39 |
-
from . import pipeline
|
| 40 |
from .pack_gqa import PackGQA
|
| 41 |
from .named_barrier import NamedBarrierFwd
|
| 42 |
-
from .
|
| 43 |
-
from .tile_scheduler import
|
| 44 |
-
TileSchedulerArguments,
|
| 45 |
-
SingleTileScheduler,
|
| 46 |
-
SingleTileLPTScheduler,
|
| 47 |
-
SingleTileVarlenScheduler,
|
| 48 |
-
)
|
| 49 |
-
from cutlass.cute import FastDivmodDivisor
|
| 50 |
|
| 51 |
|
| 52 |
class FlashAttentionForwardBase:
|
| 53 |
-
arch: int = 80
|
| 54 |
|
| 55 |
def __init__(
|
| 56 |
self,
|
|
@@ -116,6 +102,12 @@ class FlashAttentionForwardBase:
|
|
| 116 |
self.vec_size: cutlass.Constexpr = getattr(
|
| 117 |
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
|
| 118 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
@staticmethod
|
| 121 |
def can_implement(
|
|
@@ -318,7 +310,8 @@ class FlashAttentionForwardBase:
|
|
| 318 |
mO: cute.Tensor,
|
| 319 |
mLSE: Optional[cute.Tensor],
|
| 320 |
softmax_scale: Float32,
|
| 321 |
-
stream:
|
|
|
|
| 322 |
):
|
| 323 |
"""Configures and launches the flash attention kernel.
|
| 324 |
|
|
@@ -351,7 +344,7 @@ class FlashAttentionForwardBase:
|
|
| 351 |
cute.arch.barrier(
|
| 352 |
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
|
| 353 |
)
|
| 354 |
-
smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype)
|
| 355 |
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
|
| 356 |
taccOrO = smem_thr_copy_O.retile(rO)
|
| 357 |
taccOsO = smem_thr_copy_O.partition_D(sO)
|
|
@@ -366,11 +359,7 @@ class FlashAttentionForwardBase:
|
|
| 366 |
|
| 367 |
# Write LSE from rmem -> gmem
|
| 368 |
if const_expr(mLSE is not None):
|
| 369 |
-
|
| 370 |
-
mLSE_cur = mLSE[None, head_idx, batch_idx]
|
| 371 |
-
else:
|
| 372 |
-
offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
|
| 373 |
-
mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
|
| 374 |
if const_expr(not self.pack_gqa):
|
| 375 |
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
|
| 376 |
gLSE_expanded_layout = cute.append(
|
|
@@ -384,7 +373,7 @@ class FlashAttentionForwardBase:
|
|
| 384 |
t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
|
| 385 |
# Only the thread corresponding to column 0 writes out the lse to gmem
|
| 386 |
if taccOcO[0][1] == 0:
|
| 387 |
-
for m in cutlass.
|
| 388 |
if (
|
| 389 |
t0accOcO[m, 0][0]
|
| 390 |
< seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
|
|
@@ -393,11 +382,8 @@ class FlashAttentionForwardBase:
|
|
| 393 |
else:
|
| 394 |
pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
else:
|
| 399 |
-
offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
|
| 400 |
-
mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx])
|
| 401 |
# thr_mma = tiled_mma.get_slice(tidx)
|
| 402 |
# taccOgO = thr_mma.partition_C(gO)
|
| 403 |
# cute.autovec_copy(rO, taccOgO)
|
|
@@ -634,12 +620,19 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 634 |
mV: cute.Tensor,
|
| 635 |
mO: cute.Tensor,
|
| 636 |
mLSE: Optional[cute.Tensor],
|
| 637 |
-
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
window_size_left: Optional[Int32] = None,
|
| 640 |
window_size_right: Optional[Int32] = None,
|
| 641 |
learnable_sink: Optional[cute.Tensor] = None,
|
|
|
|
| 642 |
aux_tensors=None,
|
|
|
|
|
|
|
| 643 |
):
|
| 644 |
"""Configures and launches the flash attention kernel.
|
| 645 |
|
|
@@ -648,7 +641,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 648 |
"""
|
| 649 |
assert learnable_sink is None, "Learnable sink is not supported in this kernel"
|
| 650 |
self._check_type(
|
| 651 |
-
*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))
|
| 652 |
)
|
| 653 |
tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
|
| 654 |
self.num_mma_threads = tiled_mma_pv.size
|
|
@@ -656,41 +649,54 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 656 |
self.num_Q_load_threads = self.num_threads
|
| 657 |
self.num_epilogue_threads = self.num_threads
|
| 658 |
# self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
|
| 659 |
-
self.use_tma_O = self.arch >=
|
| 660 |
self._setup_attributes()
|
| 661 |
SharedStorage = self._get_shared_storage_cls()
|
| 662 |
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
|
|
|
|
|
|
|
|
|
| 666 |
]
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
if const_expr(
|
| 676 |
-
|
| 677 |
-
softmax_scale = None
|
| 678 |
else:
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
seqlen_k
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
|
| 695 |
self.kernel(
|
| 696 |
mQ,
|
|
@@ -698,6 +704,10 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 698 |
mV,
|
| 699 |
mO,
|
| 700 |
mLSE,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
softmax_scale_log2,
|
| 702 |
softmax_scale,
|
| 703 |
window_size_left,
|
|
@@ -714,6 +724,8 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 714 |
tiled_mma_qk,
|
| 715 |
tiled_mma_pv,
|
| 716 |
SharedStorage,
|
|
|
|
|
|
|
| 717 |
aux_tensors,
|
| 718 |
fastdiv_mods,
|
| 719 |
).launch(
|
|
@@ -731,6 +743,10 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 731 |
mV: cute.Tensor,
|
| 732 |
mO: cute.Tensor,
|
| 733 |
mLSE: Optional[cute.Tensor],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
softmax_scale_log2: Float32,
|
| 735 |
softmax_scale: Optional[Float32],
|
| 736 |
window_size_left: Optional[Int32],
|
|
@@ -747,12 +763,17 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 747 |
tiled_mma_qk: cute.TiledMma,
|
| 748 |
tiled_mma_pv: cute.TiledMma,
|
| 749 |
SharedStorage: cutlass.Constexpr,
|
|
|
|
|
|
|
| 750 |
aux_tensors=None,
|
| 751 |
fastdiv_mods=None,
|
| 752 |
):
|
| 753 |
# Thread index, block index
|
| 754 |
tidx, _, _ = cute.arch.thread_idx()
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
| 756 |
|
| 757 |
block_info = BlockInfo(
|
| 758 |
self.tile_m,
|
|
@@ -764,13 +785,21 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 764 |
window_size_right,
|
| 765 |
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 766 |
)
|
| 767 |
-
seqlen = SeqlenInfoQK.create(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 769 |
-
#
|
| 770 |
-
#
|
| 771 |
-
#
|
| 772 |
-
#
|
| 773 |
-
n_block = n_block_max - 1
|
| 774 |
|
| 775 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 776 |
# Get the appropriate tiles for this thread block.
|
|
@@ -778,10 +807,20 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 778 |
blkQ_shape = (self.tile_m, self.tile_hdim)
|
| 779 |
blkK_shape = (self.tile_n, self.tile_hdim)
|
| 780 |
blkV_shape = (self.tile_n, self.tile_hdimv)
|
| 781 |
-
gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0))
|
| 782 |
num_head_kv = num_head // self.qhead_per_kvhead
|
| 783 |
-
|
| 784 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
|
| 786 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 787 |
# Get shared memory buffer
|
|
@@ -953,18 +992,20 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 953 |
mask = AttentionMask(
|
| 954 |
self.tile_m,
|
| 955 |
self.tile_n,
|
| 956 |
-
seqlen
|
| 957 |
-
seqlen.seqlen_k,
|
| 958 |
window_size_left,
|
| 959 |
window_size_right,
|
| 960 |
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 961 |
)
|
| 962 |
mask_fn = partial(
|
| 963 |
mask.apply_mask,
|
|
|
|
|
|
|
| 964 |
m_block=m_block,
|
| 965 |
thr_mma=thr_mma_qk,
|
| 966 |
mask_causal=self.is_causal,
|
| 967 |
mask_local=self.is_local,
|
|
|
|
| 968 |
fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
|
| 969 |
)
|
| 970 |
|
|
@@ -976,8 +1017,8 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 976 |
smem_pipe_read,
|
| 977 |
smem_pipe_write,
|
| 978 |
is_first_n_block=True,
|
| 979 |
-
|
| 980 |
-
mask_fn=partial(mask_fn, mask_seqlen=True),
|
| 981 |
)
|
| 982 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 983 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
|
@@ -992,15 +1033,17 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 992 |
n_block,
|
| 993 |
smem_pipe_read,
|
| 994 |
smem_pipe_write,
|
| 995 |
-
|
| 996 |
-
mask_fn=partial(mask_fn, mask_seqlen=
|
| 997 |
)
|
| 998 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 999 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
| 1000 |
# The remaining iterations have no masking
|
| 1001 |
for n_tile in cutlass.range(n_block, unroll=1):
|
| 1002 |
compute_one_n_block(
|
| 1003 |
-
n_block - n_tile - 1, smem_pipe_read, smem_pipe_write,
|
|
|
|
|
|
|
| 1004 |
)
|
| 1005 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 1006 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
|
@@ -1144,1283 +1187,9 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
|
| 1144 |
# load_K_next()
|
| 1145 |
|
| 1146 |
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
| 1153 |
-
intra_wg_overlap: bool = True,
|
| 1154 |
-
mma_pv_is_rs: bool = True,
|
| 1155 |
-
**kwargs,
|
| 1156 |
-
):
|
| 1157 |
-
super().__init__(*args, **kwargs)
|
| 1158 |
-
self.intra_wg_overlap = intra_wg_overlap
|
| 1159 |
-
self.mma_pv_is_rs = mma_pv_is_rs
|
| 1160 |
-
self.buffer_align_bytes = 1024
|
| 1161 |
-
|
| 1162 |
-
def _get_smem_layout_atom(self):
|
| 1163 |
-
sQ_layout_atom = warpgroup.make_smem_layout_atom(
|
| 1164 |
-
sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
|
| 1165 |
-
self.dtype,
|
| 1166 |
-
)
|
| 1167 |
-
sK_layout_atom = sQ_layout_atom
|
| 1168 |
-
sV_layout_atom = warpgroup.make_smem_layout_atom(
|
| 1169 |
-
sm90_utils_basic.get_smem_layout_atom(
|
| 1170 |
-
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
|
| 1171 |
-
),
|
| 1172 |
-
self.dtype,
|
| 1173 |
-
)
|
| 1174 |
-
sO_layout_atom = sV_layout_atom
|
| 1175 |
-
if not self.mma_pv_is_rs:
|
| 1176 |
-
sP_layout_atom = warpgroup.make_smem_layout_atom(
|
| 1177 |
-
sm90_utils_basic.get_smem_layout_atom(
|
| 1178 |
-
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
|
| 1179 |
-
),
|
| 1180 |
-
self.dtype,
|
| 1181 |
-
)
|
| 1182 |
-
else:
|
| 1183 |
-
sP_layout_atom = None
|
| 1184 |
-
return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
|
| 1185 |
-
|
| 1186 |
-
def _get_tiled_mma(self):
|
| 1187 |
-
tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
|
| 1188 |
-
self.dtype,
|
| 1189 |
-
self.dtype,
|
| 1190 |
-
warpgroup.OperandMajorMode.K,
|
| 1191 |
-
warpgroup.OperandMajorMode.K,
|
| 1192 |
-
Float32,
|
| 1193 |
-
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
| 1194 |
-
tiler_mn=(64, self.tile_n),
|
| 1195 |
-
)
|
| 1196 |
-
tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
|
| 1197 |
-
self.dtype,
|
| 1198 |
-
self.dtype,
|
| 1199 |
-
warpgroup.OperandMajorMode.K,
|
| 1200 |
-
warpgroup.OperandMajorMode.MN,
|
| 1201 |
-
Float32,
|
| 1202 |
-
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
| 1203 |
-
tiler_mn=(64, self.tile_hdimv),
|
| 1204 |
-
a_source=warpgroup.OperandSource.RMEM
|
| 1205 |
-
if self.mma_pv_is_rs
|
| 1206 |
-
else warpgroup.OperandSource.SMEM,
|
| 1207 |
-
)
|
| 1208 |
-
return tiled_mma_qk, tiled_mma_pv
|
| 1209 |
-
|
| 1210 |
-
def _get_shared_storage_cls(self):
|
| 1211 |
-
sQ_struct, sK_struct, sV_struct = [
|
| 1212 |
-
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
|
| 1213 |
-
for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
|
| 1214 |
-
|
| 1215 |
-
]
|
| 1216 |
-
cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
|
| 1217 |
-
sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
|
| 1218 |
-
cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
|
| 1219 |
-
sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
|
| 1220 |
-
# 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V,
|
| 1221 |
-
mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2]
|
| 1222 |
-
mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
| 1223 |
-
mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
| 1224 |
-
|
| 1225 |
-
@cute.struct
|
| 1226 |
-
class SharedStorageQKV:
|
| 1227 |
-
mbar_ptr: mbar_ptr_QO_struct
|
| 1228 |
-
mbar_ptr_K: mbar_ptr_K_struct
|
| 1229 |
-
mbar_ptr_V: mbar_ptr_V_struct
|
| 1230 |
-
sV: sV_struct
|
| 1231 |
-
sQ: sQ_struct
|
| 1232 |
-
sK: sK_struct
|
| 1233 |
-
sP: sP_struct
|
| 1234 |
-
|
| 1235 |
-
@cute.struct
|
| 1236 |
-
class SharedStorageSharedQV:
|
| 1237 |
-
mbar_ptr: mbar_ptr_QO_struct
|
| 1238 |
-
mbar_ptr_K: mbar_ptr_K_struct
|
| 1239 |
-
mbar_ptr_V: mbar_ptr_V_struct
|
| 1240 |
-
sQ: sQV_struct
|
| 1241 |
-
sK: sK_struct
|
| 1242 |
-
sP: sP_struct
|
| 1243 |
-
|
| 1244 |
-
return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
|
| 1245 |
-
|
| 1246 |
-
@cute.jit
|
| 1247 |
-
def __call__(
|
| 1248 |
-
self,
|
| 1249 |
-
mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
| 1250 |
-
mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
|
| 1251 |
-
mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
|
| 1252 |
-
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 1253 |
-
mLSE: Optional[cute.Tensor],
|
| 1254 |
-
softmax_scale: Float32,
|
| 1255 |
-
stream: cuda.CUstream,
|
| 1256 |
-
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 1257 |
-
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 1258 |
-
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 1259 |
-
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 1260 |
-
mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
|
| 1261 |
-
window_size_left: Int32 | int | None = None,
|
| 1262 |
-
window_size_right: Int32 | int | None = None,
|
| 1263 |
-
learnable_sink: Optional[cute.Tensor] = None,
|
| 1264 |
-
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 1265 |
-
aux_tensors: Optional[list] = None,
|
| 1266 |
-
):
|
| 1267 |
-
"""Configures and launches the flash attention kernel.
|
| 1268 |
-
|
| 1269 |
-
mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:
|
| 1270 |
-
(batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
|
| 1271 |
-
"""
|
| 1272 |
-
|
| 1273 |
-
self._check_type(
|
| 1274 |
-
*(
|
| 1275 |
-
t.element_type if t is not None else None
|
| 1276 |
-
for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
|
| 1277 |
-
)
|
| 1278 |
-
)
|
| 1279 |
-
|
| 1280 |
-
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
|
| 1281 |
-
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
| 1282 |
-
mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
|
| 1283 |
-
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
| 1284 |
-
mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
|
| 1285 |
-
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
| 1286 |
-
mLSE = layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None
|
| 1287 |
-
|
| 1288 |
-
tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
|
| 1289 |
-
self.num_mma_threads = tiled_mma_qk.size
|
| 1290 |
-
self.num_threads_per_warp_group = 128
|
| 1291 |
-
self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group
|
| 1292 |
-
self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1)
|
| 1293 |
-
self.num_producer_threads = 32
|
| 1294 |
-
self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q
|
| 1295 |
-
self.num_epilogue_threads = self.num_mma_threads
|
| 1296 |
-
self.num_mma_regs = (
|
| 1297 |
-
256
|
| 1298 |
-
if self.num_mma_warp_groups == 1
|
| 1299 |
-
else (240 if self.num_mma_warp_groups == 2 else 160)
|
| 1300 |
-
)
|
| 1301 |
-
self.num_producer_regs = (
|
| 1302 |
-
56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32)
|
| 1303 |
-
)
|
| 1304 |
-
# self.num_mma_regs = 232
|
| 1305 |
-
# self.num_producer_regs = 40
|
| 1306 |
-
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
| 1307 |
-
|
| 1308 |
-
self.use_scheduler_barrier = (
|
| 1309 |
-
(self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128)
|
| 1310 |
-
if const_expr(self.intra_wg_overlap)
|
| 1311 |
-
else (self.num_mma_warp_groups == 2)
|
| 1312 |
-
)
|
| 1313 |
-
self.use_tma_Q = self.arch >= 90 and not (
|
| 1314 |
-
self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
|
| 1315 |
-
)
|
| 1316 |
-
self.use_tma_O = (
|
| 1317 |
-
self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa
|
| 1318 |
-
)
|
| 1319 |
-
# TODO: rescale_O_before_gemm
|
| 1320 |
-
self._setup_attributes()
|
| 1321 |
-
# TODO: we prob don't need most of what's in _setup_attributes
|
| 1322 |
-
self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
|
| 1323 |
-
sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
|
| 1324 |
-
for mX, shape, stage in [
|
| 1325 |
-
(mQ, (self.tile_m, self.tile_hdim), None),
|
| 1326 |
-
(mK, (self.tile_n, self.tile_hdim), self.num_stages),
|
| 1327 |
-
(mV, (self.tile_n, self.tile_hdimv), self.num_stages),
|
| 1328 |
-
(mO, (self.tile_m, self.tile_hdimv), None),
|
| 1329 |
-
]
|
| 1330 |
-
]
|
| 1331 |
-
self.sP_layout = None
|
| 1332 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 1333 |
-
self.sP_layout = sm90_utils.make_smem_layout(
|
| 1334 |
-
mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
|
| 1335 |
-
)
|
| 1336 |
-
|
| 1337 |
-
SharedStorage = self._get_shared_storage_cls()
|
| 1338 |
-
|
| 1339 |
-
if const_expr(self.pack_gqa):
|
| 1340 |
-
shape_Q_packed = (
|
| 1341 |
-
(self.qhead_per_kvhead, mQ.shape[0]),
|
| 1342 |
-
mQ.shape[1],
|
| 1343 |
-
mK.shape[2],
|
| 1344 |
-
*mQ.shape[3:],
|
| 1345 |
-
)
|
| 1346 |
-
stride_Q_packed = (
|
| 1347 |
-
(mQ.stride[2], mQ.stride[0]),
|
| 1348 |
-
mQ.stride[1],
|
| 1349 |
-
mQ.stride[2] * self.qhead_per_kvhead,
|
| 1350 |
-
*mQ.stride[3:],
|
| 1351 |
-
)
|
| 1352 |
-
mQ = cute.make_tensor(
|
| 1353 |
-
mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
|
| 1354 |
-
)
|
| 1355 |
-
shape_O_packed = (
|
| 1356 |
-
(self.qhead_per_kvhead, mO.shape[0]),
|
| 1357 |
-
mK.shape[1],
|
| 1358 |
-
mK.shape[2],
|
| 1359 |
-
*mO.shape[3:],
|
| 1360 |
-
)
|
| 1361 |
-
stride_O_packed = (
|
| 1362 |
-
(mO.stride[2], mO.stride[0]),
|
| 1363 |
-
mO.stride[1],
|
| 1364 |
-
mO.stride[2] * self.qhead_per_kvhead,
|
| 1365 |
-
*mO.stride[3:],
|
| 1366 |
-
)
|
| 1367 |
-
mO = cute.make_tensor(
|
| 1368 |
-
mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
|
| 1369 |
-
)
|
| 1370 |
-
if const_expr(mLSE is not None):
|
| 1371 |
-
shape_LSE_packed = (
|
| 1372 |
-
(self.qhead_per_kvhead, mLSE.shape[0]),
|
| 1373 |
-
mK.shape[2],
|
| 1374 |
-
*mLSE.shape[2:],
|
| 1375 |
-
)
|
| 1376 |
-
stride_LSE_packed = (
|
| 1377 |
-
(mLSE.stride[1], mLSE.stride[0]),
|
| 1378 |
-
mLSE.stride[1] * self.qhead_per_kvhead,
|
| 1379 |
-
*mLSE.stride[2:],
|
| 1380 |
-
)
|
| 1381 |
-
mLSE = cute.make_tensor(
|
| 1382 |
-
mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
|
| 1383 |
-
)
|
| 1384 |
-
|
| 1385 |
-
# TMA
|
| 1386 |
-
gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
|
| 1387 |
-
gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
|
| 1388 |
-
gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
|
| 1389 |
-
self.tma_copy_bytes = {
|
| 1390 |
-
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
|
| 1391 |
-
for name, mX, layout in [
|
| 1392 |
-
("Q", mQ, self.sQ_layout),
|
| 1393 |
-
("K", mK, self.sK_layout),
|
| 1394 |
-
("V", mV, self.sV_layout),
|
| 1395 |
-
]
|
| 1396 |
-
}
|
| 1397 |
-
tma_atom_Q, tma_tensor_Q = None, None
|
| 1398 |
-
if const_expr(self.use_tma_Q):
|
| 1399 |
-
tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
|
| 1400 |
-
gmem_tiled_copy_Q,
|
| 1401 |
-
mQ,
|
| 1402 |
-
self.sQ_layout,
|
| 1403 |
-
(self.tile_m, self.tile_hdim), # No mcast
|
| 1404 |
-
)
|
| 1405 |
-
tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
|
| 1406 |
-
gmem_tiled_copy_KV,
|
| 1407 |
-
mK,
|
| 1408 |
-
cute.select(self.sK_layout, mode=[0, 1]),
|
| 1409 |
-
(self.tile_n, self.tile_hdim),
|
| 1410 |
-
1, # No mcast for now
|
| 1411 |
-
)
|
| 1412 |
-
tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
|
| 1413 |
-
gmem_tiled_copy_KV,
|
| 1414 |
-
mV,
|
| 1415 |
-
cute.select(self.sV_layout, mode=[0, 1]),
|
| 1416 |
-
(self.tile_n, self.tile_hdimv),
|
| 1417 |
-
1, # No mcast for now
|
| 1418 |
-
)
|
| 1419 |
-
tma_atom_O, tma_tensor_O = None, None
|
| 1420 |
-
if const_expr(self.use_tma_O):
|
| 1421 |
-
tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom(
|
| 1422 |
-
gmem_tiled_copy_O,
|
| 1423 |
-
mO,
|
| 1424 |
-
self.sO_layout,
|
| 1425 |
-
(self.tile_m, self.tile_hdimv), # No mcast
|
| 1426 |
-
)
|
| 1427 |
-
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
|
| 1428 |
-
TileScheduler = SingleTileVarlenScheduler
|
| 1429 |
-
else:
|
| 1430 |
-
TileScheduler = (
|
| 1431 |
-
SingleTileScheduler
|
| 1432 |
-
if const_expr(not self.is_causal or self.is_local)
|
| 1433 |
-
else SingleTileLPTScheduler
|
| 1434 |
-
)
|
| 1435 |
-
tile_sched_args = TileSchedulerArguments(
|
| 1436 |
-
cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
|
| 1437 |
-
cute.size(mQ.shape[2]),
|
| 1438 |
-
cute.size(mQ.shape[3])
|
| 1439 |
-
if const_expr(mCuSeqlensQ is None)
|
| 1440 |
-
else cute.size(mCuSeqlensQ.shape[0] - 1),
|
| 1441 |
-
1, # num_splits
|
| 1442 |
-
cute.size(mK.shape[0]),
|
| 1443 |
-
mQ.shape[1],
|
| 1444 |
-
mV.shape[1],
|
| 1445 |
-
total_q=cute.size(mQ.shape[0])
|
| 1446 |
-
if const_expr(mCuSeqlensQ is not None)
|
| 1447 |
-
else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
|
| 1448 |
-
tile_shape_mn=(self.tile_m, self.tile_n),
|
| 1449 |
-
mCuSeqlensQ=mCuSeqlensQ,
|
| 1450 |
-
mSeqUsedQ=mSeqUsedQ,
|
| 1451 |
-
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1452 |
-
element_size=self.dtype.width // 8,
|
| 1453 |
-
is_persistent=False,
|
| 1454 |
-
lpt=self.is_causal or self.is_local,
|
| 1455 |
-
)
|
| 1456 |
-
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
| 1457 |
-
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
| 1458 |
-
LOG2_E = math.log2(math.e)
|
| 1459 |
-
if const_expr(self.score_mod is None):
|
| 1460 |
-
softmax_scale_log2 = softmax_scale * LOG2_E
|
| 1461 |
-
softmax_scale = None
|
| 1462 |
-
else:
|
| 1463 |
-
# NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk
|
| 1464 |
-
# But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
|
| 1465 |
-
# and correctly apply the softmax_scale prior to score_mod in the softmax step
|
| 1466 |
-
softmax_scale_log2 = LOG2_E
|
| 1467 |
-
softmax_scale = softmax_scale
|
| 1468 |
-
if const_expr(window_size_left is not None):
|
| 1469 |
-
window_size_left = Int32(window_size_left)
|
| 1470 |
-
if const_expr(window_size_right is not None):
|
| 1471 |
-
window_size_right = Int32(window_size_right)
|
| 1472 |
-
|
| 1473 |
-
fastdiv_mods = None
|
| 1474 |
-
if const_expr(aux_tensors is not None):
|
| 1475 |
-
seqlen_q = cute.size(mQ.shape[0]) // (
|
| 1476 |
-
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
|
| 1477 |
-
)
|
| 1478 |
-
seqlen_k = (
|
| 1479 |
-
cute.size(mK.shape[0])
|
| 1480 |
-
if const_expr(mPageTable is None)
|
| 1481 |
-
else mK.shape[0] * mPageTable.shape[1]
|
| 1482 |
-
)
|
| 1483 |
-
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
|
| 1484 |
-
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
|
| 1485 |
-
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
| 1486 |
-
|
| 1487 |
-
self.kernel(
|
| 1488 |
-
tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
|
| 1489 |
-
tma_tensor_K,
|
| 1490 |
-
tma_tensor_V,
|
| 1491 |
-
tma_tensor_O if const_expr(self.use_tma_O) else mO,
|
| 1492 |
-
mLSE,
|
| 1493 |
-
mCuSeqlensQ,
|
| 1494 |
-
mCuSeqlensK,
|
| 1495 |
-
mSeqUsedQ,
|
| 1496 |
-
mSeqUsedK,
|
| 1497 |
-
tma_atom_Q,
|
| 1498 |
-
tma_atom_K,
|
| 1499 |
-
tma_atom_V,
|
| 1500 |
-
tma_atom_O,
|
| 1501 |
-
softmax_scale_log2,
|
| 1502 |
-
softmax_scale,
|
| 1503 |
-
window_size_left,
|
| 1504 |
-
window_size_right,
|
| 1505 |
-
learnable_sink,
|
| 1506 |
-
blocksparse_tensors,
|
| 1507 |
-
self.sQ_layout,
|
| 1508 |
-
self.sK_layout,
|
| 1509 |
-
self.sV_layout,
|
| 1510 |
-
self.sO_layout,
|
| 1511 |
-
self.sP_layout,
|
| 1512 |
-
self.gmem_tiled_copy_Q,
|
| 1513 |
-
self.gmem_tiled_copy_K,
|
| 1514 |
-
self.gmem_tiled_copy_V,
|
| 1515 |
-
self.gmem_tiled_copy_O,
|
| 1516 |
-
tiled_mma_qk,
|
| 1517 |
-
tiled_mma_pv,
|
| 1518 |
-
tile_sched_params,
|
| 1519 |
-
TileScheduler,
|
| 1520 |
-
SharedStorage,
|
| 1521 |
-
aux_tensors,
|
| 1522 |
-
fastdiv_mods,
|
| 1523 |
-
).launch(
|
| 1524 |
-
grid=grid_dim,
|
| 1525 |
-
block=[self.num_threads, 1, 1],
|
| 1526 |
-
stream=stream,
|
| 1527 |
-
min_blocks_per_mp=1,
|
| 1528 |
-
)
|
| 1529 |
-
|
| 1530 |
-
@cute.kernel
|
| 1531 |
-
def kernel(
|
| 1532 |
-
self,
|
| 1533 |
-
mQ: cute.Tensor,
|
| 1534 |
-
mK: cute.Tensor,
|
| 1535 |
-
mV: cute.Tensor,
|
| 1536 |
-
mO: cute.Tensor,
|
| 1537 |
-
mLSE: Optional[cute.Tensor],
|
| 1538 |
-
mCuSeqlensQ: Optional[cute.Tensor],
|
| 1539 |
-
mCuSeqlensK: Optional[cute.Tensor],
|
| 1540 |
-
mSeqUsedQ: Optional[cute.Tensor],
|
| 1541 |
-
mSeqUsedK: Optional[cute.Tensor],
|
| 1542 |
-
tma_atom_Q: Optional[cute.CopyAtom],
|
| 1543 |
-
tma_atom_K: Optional[cute.CopyAtom],
|
| 1544 |
-
tma_atom_V: Optional[cute.CopyAtom],
|
| 1545 |
-
tma_atom_O: Optional[cute.CopyAtom],
|
| 1546 |
-
softmax_scale_log2: Float32,
|
| 1547 |
-
softmax_scale: Optional[Float32],
|
| 1548 |
-
window_size_left: Optional[Int32],
|
| 1549 |
-
window_size_right: Optional[Int32],
|
| 1550 |
-
learnable_sink: Optional[cute.Tensor],
|
| 1551 |
-
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 1552 |
-
sQ_layout: cute.ComposedLayout,
|
| 1553 |
-
sK_layout: cute.ComposedLayout,
|
| 1554 |
-
sV_layout: cute.ComposedLayout,
|
| 1555 |
-
sO_layout: cute.ComposedLayout,
|
| 1556 |
-
sP_layout: cute.ComposedLayout | None,
|
| 1557 |
-
gmem_tiled_copy_Q: cute.TiledCopy,
|
| 1558 |
-
gmem_tiled_copy_K: cute.TiledCopy,
|
| 1559 |
-
gmem_tiled_copy_V: cute.TiledCopy,
|
| 1560 |
-
gmem_tiled_copy_O: cute.TiledCopy,
|
| 1561 |
-
tiled_mma_qk: cute.TiledMma,
|
| 1562 |
-
tiled_mma_pv: cute.TiledMma,
|
| 1563 |
-
tile_sched_params: ParamsBase,
|
| 1564 |
-
TileScheduler: cutlass.Constexpr[Callable],
|
| 1565 |
-
SharedStorage: cutlass.Constexpr[Callable],
|
| 1566 |
-
aux_tensors=Optional[list[cute.Tensor]],
|
| 1567 |
-
fastdiv_mods=None,
|
| 1568 |
-
):
|
| 1569 |
-
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 1570 |
-
# Prefetch tma descriptor
|
| 1571 |
-
if warp_idx == 0:
|
| 1572 |
-
for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
|
| 1573 |
-
if const_expr(tma_atom is not None):
|
| 1574 |
-
cpasync.prefetch_descriptor(tma_atom)
|
| 1575 |
-
|
| 1576 |
-
smem = cutlass.utils.SmemAllocator()
|
| 1577 |
-
storage = smem.allocate(SharedStorage)
|
| 1578 |
-
|
| 1579 |
-
# Mbarrier init
|
| 1580 |
-
mbar_ptr_Q = storage.mbar_ptr.data_ptr()
|
| 1581 |
-
if warp_idx == 1:
|
| 1582 |
-
# if tidx < 2:
|
| 1583 |
-
# # barrierO num threads should be self.num_mma_threads
|
| 1584 |
-
# cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads)
|
| 1585 |
-
if const_expr(not self.use_tma_Q):
|
| 1586 |
-
cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads)
|
| 1587 |
-
# cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads)
|
| 1588 |
-
# We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync
|
| 1589 |
-
pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(
|
| 1590 |
-
cutlass.pipeline.Agent.Thread
|
| 1591 |
-
)
|
| 1592 |
-
pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
|
| 1593 |
-
cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
|
| 1594 |
-
)
|
| 1595 |
-
pipeline_k = pipeline.PipelineTmaAsync.create(
|
| 1596 |
-
barrier_storage=storage.mbar_ptr_K.data_ptr(),
|
| 1597 |
-
num_stages=self.num_stages,
|
| 1598 |
-
producer_group=pipeline_kv_producer_group,
|
| 1599 |
-
consumer_group=pipeline_kv_consumer_group,
|
| 1600 |
-
tx_count=self.tma_copy_bytes["K"],
|
| 1601 |
-
defer_sync=True,
|
| 1602 |
-
)
|
| 1603 |
-
pipeline_v = pipeline.PipelineTmaAsync.create(
|
| 1604 |
-
barrier_storage=storage.mbar_ptr_V.data_ptr(),
|
| 1605 |
-
num_stages=self.num_stages,
|
| 1606 |
-
producer_group=pipeline_kv_producer_group,
|
| 1607 |
-
consumer_group=pipeline_kv_consumer_group,
|
| 1608 |
-
tx_count=self.tma_copy_bytes["V"],
|
| 1609 |
-
defer_sync=False
|
| 1610 |
-
)
|
| 1611 |
-
|
| 1612 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1613 |
-
# Get shared memory buffer
|
| 1614 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1615 |
-
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
|
| 1616 |
-
sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
|
| 1617 |
-
if const_expr(not self.Q_in_regs):
|
| 1618 |
-
sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
|
| 1619 |
-
else:
|
| 1620 |
-
sV = storage.sQ.get_tensor(
|
| 1621 |
-
sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
|
| 1622 |
-
)
|
| 1623 |
-
# Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
|
| 1624 |
-
sVt = layout_utils.transpose_view(sV)
|
| 1625 |
-
sP = None
|
| 1626 |
-
if const_expr(sP_layout is not None):
|
| 1627 |
-
sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
|
| 1628 |
-
# reuse sQ's data iterator
|
| 1629 |
-
sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
|
| 1630 |
-
|
| 1631 |
-
block_info = BlockInfo(
|
| 1632 |
-
self.tile_m,
|
| 1633 |
-
self.tile_n,
|
| 1634 |
-
self.is_causal,
|
| 1635 |
-
self.is_local,
|
| 1636 |
-
False, # is_split_kv
|
| 1637 |
-
window_size_left,
|
| 1638 |
-
window_size_right,
|
| 1639 |
-
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1640 |
-
)
|
| 1641 |
-
SeqlenInfoCls = partial(
|
| 1642 |
-
SeqlenInfoQK.create,
|
| 1643 |
-
seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
|
| 1644 |
-
seqlen_k_static=mK.shape[0],
|
| 1645 |
-
mCuSeqlensQ=mCuSeqlensQ,
|
| 1646 |
-
mCuSeqlensK=mCuSeqlensK,
|
| 1647 |
-
mSeqUsedQ=mSeqUsedQ,
|
| 1648 |
-
mSeqUsedK=mSeqUsedK,
|
| 1649 |
-
)
|
| 1650 |
-
AttentionMaskCls = partial(
|
| 1651 |
-
AttentionMask,
|
| 1652 |
-
self.tile_m,
|
| 1653 |
-
self.tile_n,
|
| 1654 |
-
window_size_left=window_size_left,
|
| 1655 |
-
window_size_right=window_size_right,
|
| 1656 |
-
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1657 |
-
)
|
| 1658 |
-
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
| 1659 |
-
|
| 1660 |
-
if warp_idx < 4: # Producer
|
| 1661 |
-
cute.arch.setmaxregister_decrease(self.num_producer_regs)
|
| 1662 |
-
self.load(
|
| 1663 |
-
mQ,
|
| 1664 |
-
mK,
|
| 1665 |
-
mV,
|
| 1666 |
-
sQ,
|
| 1667 |
-
sK,
|
| 1668 |
-
sV,
|
| 1669 |
-
tma_atom_Q,
|
| 1670 |
-
tma_atom_K,
|
| 1671 |
-
tma_atom_V,
|
| 1672 |
-
pipeline_k,
|
| 1673 |
-
pipeline_v,
|
| 1674 |
-
mbar_ptr_Q,
|
| 1675 |
-
blocksparse_tensors,
|
| 1676 |
-
block_info,
|
| 1677 |
-
SeqlenInfoCls,
|
| 1678 |
-
TileSchedulerCls,
|
| 1679 |
-
)
|
| 1680 |
-
|
| 1681 |
-
else: # Consumer
|
| 1682 |
-
cute.arch.setmaxregister_increase(self.num_mma_regs)
|
| 1683 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1684 |
-
# Tile MMA compute thread partitions and allocate accumulators
|
| 1685 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1686 |
-
tidx, _, _ = cute.arch.thread_idx()
|
| 1687 |
-
tidx = tidx - 128
|
| 1688 |
-
self.mma(
|
| 1689 |
-
tiled_mma_qk,
|
| 1690 |
-
tiled_mma_pv,
|
| 1691 |
-
mQ,
|
| 1692 |
-
mO,
|
| 1693 |
-
mLSE,
|
| 1694 |
-
sQ,
|
| 1695 |
-
sK,
|
| 1696 |
-
sVt,
|
| 1697 |
-
sP,
|
| 1698 |
-
sO,
|
| 1699 |
-
learnable_sink,
|
| 1700 |
-
pipeline_k,
|
| 1701 |
-
pipeline_v,
|
| 1702 |
-
mbar_ptr_Q,
|
| 1703 |
-
gmem_tiled_copy_Q,
|
| 1704 |
-
gmem_tiled_copy_O,
|
| 1705 |
-
tma_atom_O,
|
| 1706 |
-
tidx,
|
| 1707 |
-
softmax_scale_log2,
|
| 1708 |
-
softmax_scale,
|
| 1709 |
-
block_info,
|
| 1710 |
-
SeqlenInfoCls,
|
| 1711 |
-
AttentionMaskCls,
|
| 1712 |
-
TileSchedulerCls,
|
| 1713 |
-
blocksparse_tensors,
|
| 1714 |
-
aux_tensors,
|
| 1715 |
-
fastdiv_mods,
|
| 1716 |
-
)
|
| 1717 |
-
|
| 1718 |
-
@cute.jit
|
| 1719 |
-
def load(
|
| 1720 |
-
self,
|
| 1721 |
-
mQ: cute.Tensor,
|
| 1722 |
-
mK: cute.Tensor,
|
| 1723 |
-
mV: cute.Tensor,
|
| 1724 |
-
sQ: cute.Tensor,
|
| 1725 |
-
sK: cute.Tensor,
|
| 1726 |
-
sV: cute.Tensor,
|
| 1727 |
-
tma_atom_Q: cute.CopyAtom,
|
| 1728 |
-
tma_atom_K: cute.CopyAtom,
|
| 1729 |
-
tma_atom_V: cute.CopyAtom,
|
| 1730 |
-
pipeline_k: cutlass.pipeline.PipelineAsync,
|
| 1731 |
-
pipeline_v: cutlass.pipeline.PipelineAsync,
|
| 1732 |
-
mbar_ptr_Q: cutlass.Pointer,
|
| 1733 |
-
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 1734 |
-
block_info: BlockInfo,
|
| 1735 |
-
SeqlenInfoCls: Callable,
|
| 1736 |
-
TileSchedulerCls: Callable,
|
| 1737 |
-
):
|
| 1738 |
-
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
| 1739 |
-
if warp_idx_in_wg == 0:
|
| 1740 |
-
q_producer_phase = Int32(1)
|
| 1741 |
-
kv_producer_state = pipeline.make_pipeline_state(
|
| 1742 |
-
cutlass.pipeline.PipelineUserType.Producer, self.num_stages
|
| 1743 |
-
)
|
| 1744 |
-
tile_scheduler = TileSchedulerCls()
|
| 1745 |
-
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1746 |
-
while work_tile.is_valid_tile:
|
| 1747 |
-
# if work_tile.is_valid_tile:
|
| 1748 |
-
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 1749 |
-
seqlen = SeqlenInfoCls(batch_idx)
|
| 1750 |
-
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
| 1751 |
-
head_idx_kv = (
|
| 1752 |
-
head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
|
| 1753 |
-
)
|
| 1754 |
-
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
|
| 1755 |
-
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
|
| 1756 |
-
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
|
| 1757 |
-
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
|
| 1758 |
-
load_Q = None
|
| 1759 |
-
if const_expr(self.use_tma_Q):
|
| 1760 |
-
gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
|
| 1761 |
-
load_Q, _, _ = copy_utils.tma_get_copy_fn(
|
| 1762 |
-
tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
|
| 1763 |
-
)
|
| 1764 |
-
# TODO: mcast
|
| 1765 |
-
# TODO check warp_idx if we have 128 producer threads
|
| 1766 |
-
load_K, _, _ = copy_utils.tma_get_copy_fn(
|
| 1767 |
-
tma_atom_K, 0, cute.make_layout(1), gK, sK
|
| 1768 |
-
)
|
| 1769 |
-
load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k)
|
| 1770 |
-
load_V, _, _ = copy_utils.tma_get_copy_fn(
|
| 1771 |
-
tma_atom_V, 0, cute.make_layout(1), gV, sV
|
| 1772 |
-
)
|
| 1773 |
-
load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v)
|
| 1774 |
-
|
| 1775 |
-
if const_expr(not self.use_block_sparsity):
|
| 1776 |
-
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 1777 |
-
# if cute.arch.thread_idx()[0] == 0:
|
| 1778 |
-
# cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
|
| 1779 |
-
# First iteration: load both Q & K with the same mbarrier
|
| 1780 |
-
n_block = n_block_max - 1
|
| 1781 |
-
pipeline_k.producer_acquire(
|
| 1782 |
-
kv_producer_state,
|
| 1783 |
-
extra_tx_count=self.tma_copy_bytes["Q"]
|
| 1784 |
-
if const_expr(self.use_tma_Q)
|
| 1785 |
-
else 0,
|
| 1786 |
-
)
|
| 1787 |
-
if const_expr(self.use_tma_Q):
|
| 1788 |
-
load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
|
| 1789 |
-
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
| 1790 |
-
|
| 1791 |
-
if const_expr(not self.intra_wg_overlap):
|
| 1792 |
-
pipeline_v.producer_acquire(kv_producer_state)
|
| 1793 |
-
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
| 1794 |
-
kv_producer_state.advance()
|
| 1795 |
-
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 1796 |
-
n_block = n_block_max - 1 - i - 1
|
| 1797 |
-
pipeline_k.producer_acquire(kv_producer_state)
|
| 1798 |
-
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
| 1799 |
-
pipeline_v.producer_acquire(kv_producer_state)
|
| 1800 |
-
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
| 1801 |
-
kv_producer_state.advance()
|
| 1802 |
-
else:
|
| 1803 |
-
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 1804 |
-
n_block_prev = n_block_max - i - 1
|
| 1805 |
-
n_block = n_block_prev - 1
|
| 1806 |
-
kv_producer_state_prev = kv_producer_state.clone()
|
| 1807 |
-
kv_producer_state.advance()
|
| 1808 |
-
pipeline_k.producer_acquire(kv_producer_state)
|
| 1809 |
-
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
| 1810 |
-
pipeline_v.producer_acquire(kv_producer_state_prev)
|
| 1811 |
-
load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
|
| 1812 |
-
n_block = n_block_min
|
| 1813 |
-
pipeline_v.producer_acquire(kv_producer_state)
|
| 1814 |
-
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
| 1815 |
-
kv_producer_state.advance()
|
| 1816 |
-
else:
|
| 1817 |
-
kv_producer_state = produce_block_sparse_loads(
|
| 1818 |
-
blocksparse_tensors,
|
| 1819 |
-
batch_idx,
|
| 1820 |
-
head_idx,
|
| 1821 |
-
m_block,
|
| 1822 |
-
kv_producer_state,
|
| 1823 |
-
load_Q,
|
| 1824 |
-
load_K,
|
| 1825 |
-
load_V,
|
| 1826 |
-
pipeline_k,
|
| 1827 |
-
pipeline_v,
|
| 1828 |
-
self.use_tma_Q,
|
| 1829 |
-
self.tma_copy_bytes["Q"],
|
| 1830 |
-
self.intra_wg_overlap,
|
| 1831 |
-
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1832 |
-
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 1833 |
-
)
|
| 1834 |
-
|
| 1835 |
-
tile_scheduler.prefetch_next_work()
|
| 1836 |
-
tile_scheduler.advance_to_next_work()
|
| 1837 |
-
work_tile = tile_scheduler.get_current_work()
|
| 1838 |
-
# End of persistent scheduler loop
|
| 1839 |
-
|
| 1840 |
-
@cute.jit
|
| 1841 |
-
def mma(
|
| 1842 |
-
self,
|
| 1843 |
-
tiled_mma_qk: cute.TiledMma,
|
| 1844 |
-
tiled_mma_pv: cute.TiledMma,
|
| 1845 |
-
# softmax: Softmax,
|
| 1846 |
-
# acc_O: cute.Tensor,
|
| 1847 |
-
mQ: cute.Tensor,
|
| 1848 |
-
mO: cute.Tensor,
|
| 1849 |
-
mLSE: Optional[cute.Tensor],
|
| 1850 |
-
sQ: cute.Tensor,
|
| 1851 |
-
sK: cute.Tensor,
|
| 1852 |
-
sVt: cute.Tensor,
|
| 1853 |
-
sP: Optional[cute.Tensor],
|
| 1854 |
-
sO: cute.Tensor,
|
| 1855 |
-
learnable_sink: Optional[cute.Tensor],
|
| 1856 |
-
pipeline_k: cutlass.pipeline.PipelineAsync,
|
| 1857 |
-
pipeline_v: cutlass.pipeline.PipelineAsync,
|
| 1858 |
-
mbar_ptr_Q: cutlass.Pointer,
|
| 1859 |
-
gmem_tiled_copy_Q: cute.TiledCopy,
|
| 1860 |
-
gmem_tiled_copy_O: cute.TiledCopy,
|
| 1861 |
-
tma_atom_O: Optional[cute.CopyAtom],
|
| 1862 |
-
tidx: Int32,
|
| 1863 |
-
softmax_scale_log2: Float32,
|
| 1864 |
-
softmax_scale: Optional[Float32],
|
| 1865 |
-
block_info: BlockInfo,
|
| 1866 |
-
SeqlenInfoCls: Callable,
|
| 1867 |
-
AttentionMaskCls: Callable,
|
| 1868 |
-
TileSchedulerCls: Callable,
|
| 1869 |
-
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 1870 |
-
aux_tensors: Optional[list],
|
| 1871 |
-
fastdiv_mods=None,
|
| 1872 |
-
):
|
| 1873 |
-
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 1874 |
-
warp_group_thread_layout = cute.make_layout(
|
| 1875 |
-
self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
|
| 1876 |
-
)
|
| 1877 |
-
thr_mma_qk = tiled_mma_qk.get_slice(tidx)
|
| 1878 |
-
wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 1879 |
-
wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 1880 |
-
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
|
| 1881 |
-
wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
|
| 1882 |
-
)
|
| 1883 |
-
mma_qk_fn = partial(
|
| 1884 |
-
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
|
| 1885 |
-
)
|
| 1886 |
-
acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
|
| 1887 |
-
wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
|
| 1888 |
-
)
|
| 1889 |
-
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
|
| 1890 |
-
|
| 1891 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1892 |
-
# Smem copy atom tiling
|
| 1893 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 1894 |
-
smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype)
|
| 1895 |
-
smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
|
| 1896 |
-
tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
|
| 1897 |
-
smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
|
| 1898 |
-
|
| 1899 |
-
self.mma_init()
|
| 1900 |
-
|
| 1901 |
-
mma_one_n_block_all = partial(
|
| 1902 |
-
self.mma_one_n_block_intrawg_overlap
|
| 1903 |
-
if const_expr(self.intra_wg_overlap)
|
| 1904 |
-
else self.mma_one_n_block,
|
| 1905 |
-
mma_qk_fn=mma_qk_fn,
|
| 1906 |
-
pipeline_k=pipeline_k,
|
| 1907 |
-
pipeline_v=pipeline_v,
|
| 1908 |
-
acc_O=acc_O,
|
| 1909 |
-
tOrP=tOrP,
|
| 1910 |
-
smem_copy_params=smem_copy_params,
|
| 1911 |
-
check_inf=True,
|
| 1912 |
-
)
|
| 1913 |
-
|
| 1914 |
-
q_consumer_phase = Int32(0)
|
| 1915 |
-
kv_consumer_state = pipeline.make_pipeline_state(
|
| 1916 |
-
cutlass.pipeline.PipelineUserType.Consumer, self.num_stages
|
| 1917 |
-
)
|
| 1918 |
-
|
| 1919 |
-
tile_scheduler = TileSchedulerCls()
|
| 1920 |
-
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1921 |
-
softmax = Softmax.create(
|
| 1922 |
-
softmax_scale_log2,
|
| 1923 |
-
num_rows=acc_O.shape[0][0] * acc_O.shape[1],
|
| 1924 |
-
softmax_scale=softmax_scale,
|
| 1925 |
-
)
|
| 1926 |
-
|
| 1927 |
-
process_first_half_block = partial(
|
| 1928 |
-
self.first_half_block_overlap,
|
| 1929 |
-
mma_qk_fn=mma_qk_fn,
|
| 1930 |
-
pipeline_k=pipeline_k,
|
| 1931 |
-
tOrP=tOrP,
|
| 1932 |
-
smem_copy_params=smem_copy_params,
|
| 1933 |
-
softmax=softmax,
|
| 1934 |
-
)
|
| 1935 |
-
process_last_half_block = partial(
|
| 1936 |
-
self.last_half_block_overlap,
|
| 1937 |
-
pipeline_v=pipeline_v,
|
| 1938 |
-
mma_pv_fn=mma_pv_fn,
|
| 1939 |
-
)
|
| 1940 |
-
while work_tile.is_valid_tile:
|
| 1941 |
-
# if work_tile.is_valid_tile:
|
| 1942 |
-
|
| 1943 |
-
# shape: (atom_v_m * rest_m)
|
| 1944 |
-
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 1945 |
-
seqlen = SeqlenInfoCls(batch_idx)
|
| 1946 |
-
|
| 1947 |
-
# Recompute fastdiv_mods if necessary for varlen with aux_tensors
|
| 1948 |
-
recompute_fastdiv_mods_q = cutlass.const_expr(
|
| 1949 |
-
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
|
| 1950 |
-
)
|
| 1951 |
-
recompute_fastdiv_mods_k = cutlass.const_expr(
|
| 1952 |
-
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
|
| 1953 |
-
)
|
| 1954 |
-
if cutlass.const_expr(fastdiv_mods is not None):
|
| 1955 |
-
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
|
| 1956 |
-
fastdiv_mods = (
|
| 1957 |
-
seqlen_q_divmod
|
| 1958 |
-
if not recompute_fastdiv_mods_q
|
| 1959 |
-
else FastDivmodDivisor(seqlen.seqlen_q),
|
| 1960 |
-
seqlen_k_divmod
|
| 1961 |
-
if not recompute_fastdiv_mods_k
|
| 1962 |
-
else FastDivmodDivisor(seqlen.seqlen_k),
|
| 1963 |
-
)
|
| 1964 |
-
|
| 1965 |
-
mask = AttentionMaskCls(seqlen)
|
| 1966 |
-
mask_fn = partial(
|
| 1967 |
-
mask.apply_mask,
|
| 1968 |
-
batch_idx=batch_idx,
|
| 1969 |
-
head_idx=head_idx,
|
| 1970 |
-
m_block=m_block,
|
| 1971 |
-
thr_mma=thr_mma_qk,
|
| 1972 |
-
mask_causal=self.is_causal,
|
| 1973 |
-
mask_local=self.is_local,
|
| 1974 |
-
aux_tensors=aux_tensors,
|
| 1975 |
-
fastdiv_mods=fastdiv_mods,
|
| 1976 |
-
)
|
| 1977 |
-
score_mod_fn = None
|
| 1978 |
-
if const_expr(self.score_mod is not None):
|
| 1979 |
-
score_mod_fn = partial(
|
| 1980 |
-
self.apply_score_mod,
|
| 1981 |
-
thr_mma_qk,
|
| 1982 |
-
batch_idx,
|
| 1983 |
-
head_idx,
|
| 1984 |
-
m_block,
|
| 1985 |
-
softmax_scale=softmax_scale,
|
| 1986 |
-
aux_tensors=aux_tensors,
|
| 1987 |
-
fastdiv_mods=fastdiv_mods,
|
| 1988 |
-
)
|
| 1989 |
-
mma_one_n_block = partial(
|
| 1990 |
-
mma_one_n_block_all,
|
| 1991 |
-
seqlen=seqlen,
|
| 1992 |
-
softmax=softmax,
|
| 1993 |
-
score_mod_fn=score_mod_fn,
|
| 1994 |
-
)
|
| 1995 |
-
# Load Q if not TMA_Q
|
| 1996 |
-
if const_expr(not self.use_tma_Q):
|
| 1997 |
-
pack_gqa = PackGQA(
|
| 1998 |
-
self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
|
| 1999 |
-
)
|
| 2000 |
-
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
| 2001 |
-
# gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
|
| 2002 |
-
# gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
|
| 2003 |
-
# self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
|
| 2004 |
-
# headdim=mQ.shape[1])
|
| 2005 |
-
pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
|
| 2006 |
-
cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)
|
| 2007 |
-
|
| 2008 |
-
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 2009 |
-
if const_expr(not self.use_tma_Q):
|
| 2010 |
-
cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
|
| 2011 |
-
q_consumer_phase ^= 1
|
| 2012 |
-
# For performance reason, we separate out two kinds of iterations:
|
| 2013 |
-
# those that need masking on S, and those that don't.
|
| 2014 |
-
# We need masking on S for the very last block when K and V has length not multiple of tile_n.
|
| 2015 |
-
# We also need masking on S if it's causal, for the last several blocks.
|
| 2016 |
-
# softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True
|
| 2017 |
-
O_should_accumulate = False
|
| 2018 |
-
|
| 2019 |
-
# ==========================================
|
| 2020 |
-
# MAINLOOP
|
| 2021 |
-
# ==========================================
|
| 2022 |
-
if const_expr(not self.use_block_sparsity):
|
| 2023 |
-
# ==========================================
|
| 2024 |
-
# No block-sparsity (original path)
|
| 2025 |
-
# ==========================================
|
| 2026 |
-
# First iteration with seqlen masking
|
| 2027 |
-
if const_expr(self.intra_wg_overlap):
|
| 2028 |
-
kv_consumer_state = process_first_half_block(
|
| 2029 |
-
n_block=n_block_max - 1,
|
| 2030 |
-
seqlen=seqlen,
|
| 2031 |
-
kv_consumer_state=kv_consumer_state,
|
| 2032 |
-
mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
|
| 2033 |
-
score_mod_fn=score_mod_fn,
|
| 2034 |
-
is_first_block=True,
|
| 2035 |
-
)
|
| 2036 |
-
# Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
|
| 2037 |
-
# acc_O.fill(0.0)
|
| 2038 |
-
else:
|
| 2039 |
-
self.warp_scheduler_barrier_sync()
|
| 2040 |
-
kv_consumer_state = mma_one_n_block(
|
| 2041 |
-
kv_consumer_state,
|
| 2042 |
-
n_block=n_block_max - 1,
|
| 2043 |
-
seqlen=seqlen,
|
| 2044 |
-
mma_pv_fn=partial(mma_pv_fn, zero_init=True),
|
| 2045 |
-
is_first_n_block=True,
|
| 2046 |
-
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
|
| 2047 |
-
)
|
| 2048 |
-
O_should_accumulate = True
|
| 2049 |
-
# if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
|
| 2050 |
-
n_block_max -= 1
|
| 2051 |
-
# Next couple of iterations with causal masking
|
| 2052 |
-
if const_expr(self.is_causal or self.is_local):
|
| 2053 |
-
n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
|
| 2054 |
-
seqlen, m_block, n_block_min
|
| 2055 |
-
)
|
| 2056 |
-
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
|
| 2057 |
-
for n_tile in cutlass.range(
|
| 2058 |
-
n_block_max - n_block_min_causal_local_mask, unroll=1
|
| 2059 |
-
):
|
| 2060 |
-
kv_consumer_state = mma_one_n_block(
|
| 2061 |
-
kv_consumer_state,
|
| 2062 |
-
n_block=n_block_max - 1 - n_tile,
|
| 2063 |
-
seqlen=seqlen,
|
| 2064 |
-
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 2065 |
-
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 2066 |
-
)
|
| 2067 |
-
O_should_accumulate = True
|
| 2068 |
-
n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
|
| 2069 |
-
# The remaining iterations have no masking
|
| 2070 |
-
n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
|
| 2071 |
-
seqlen, m_block, n_block_min
|
| 2072 |
-
)
|
| 2073 |
-
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
|
| 2074 |
-
for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
|
| 2075 |
-
kv_consumer_state = mma_one_n_block(
|
| 2076 |
-
kv_consumer_state,
|
| 2077 |
-
n_block=n_block_max - 1 - n_tile,
|
| 2078 |
-
seqlen=seqlen,
|
| 2079 |
-
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 2080 |
-
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 2081 |
-
)
|
| 2082 |
-
O_should_accumulate = True
|
| 2083 |
-
# Separate iterations with local masking on the left
|
| 2084 |
-
if const_expr(self.is_local and block_info.window_size_left is not None):
|
| 2085 |
-
n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
|
| 2086 |
-
for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
|
| 2087 |
-
kv_consumer_state = mma_one_n_block(
|
| 2088 |
-
kv_consumer_state,
|
| 2089 |
-
n_block=n_block_max - 1 - n_tile,
|
| 2090 |
-
seqlen=seqlen,
|
| 2091 |
-
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 2092 |
-
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 2093 |
-
)
|
| 2094 |
-
O_should_accumulate = True
|
| 2095 |
-
# Last "half" iteration
|
| 2096 |
-
if const_expr(self.intra_wg_overlap):
|
| 2097 |
-
kv_consumer_state = process_last_half_block(
|
| 2098 |
-
kv_consumer_state=kv_consumer_state,
|
| 2099 |
-
zero_init=not O_should_accumulate,
|
| 2100 |
-
)
|
| 2101 |
-
O_should_accumulate = True
|
| 2102 |
-
else:
|
| 2103 |
-
self.warp_scheduler_barrier_arrive()
|
| 2104 |
-
|
| 2105 |
-
else:
|
| 2106 |
-
# ==========================================
|
| 2107 |
-
# Block sparsity
|
| 2108 |
-
# ==========================================
|
| 2109 |
-
kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
|
| 2110 |
-
blocksparse_tensors,
|
| 2111 |
-
batch_idx,
|
| 2112 |
-
head_idx,
|
| 2113 |
-
m_block,
|
| 2114 |
-
seqlen,
|
| 2115 |
-
kv_consumer_state,
|
| 2116 |
-
mma_pv_fn,
|
| 2117 |
-
mma_one_n_block,
|
| 2118 |
-
process_first_half_block,
|
| 2119 |
-
process_last_half_block,
|
| 2120 |
-
mask_fn,
|
| 2121 |
-
score_mod_fn,
|
| 2122 |
-
O_should_accumulate,
|
| 2123 |
-
self.mask_mod,
|
| 2124 |
-
fastdiv_mods,
|
| 2125 |
-
self.intra_wg_overlap,
|
| 2126 |
-
self.warp_scheduler_barrier_sync,
|
| 2127 |
-
self.warp_scheduler_barrier_arrive,
|
| 2128 |
-
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 2129 |
-
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 2130 |
-
)
|
| 2131 |
-
|
| 2132 |
-
# Handle empty case (when no blocks to process)
|
| 2133 |
-
if not processed_any:
|
| 2134 |
-
softmax.reset()
|
| 2135 |
-
acc_O.fill(0.0)
|
| 2136 |
-
|
| 2137 |
-
sink_val = None
|
| 2138 |
-
if const_expr(learnable_sink is not None):
|
| 2139 |
-
if const_expr(not self.pack_gqa):
|
| 2140 |
-
sink_val = Float32(learnable_sink[head_idx])
|
| 2141 |
-
else: # Each thread might have a different sink value due to different q_head
|
| 2142 |
-
sink_val = cute.make_fragment_like(softmax.row_max, Float32)
|
| 2143 |
-
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
| 2144 |
-
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
|
| 2145 |
-
for r in cutlass.range(cute.size(sink_val), unroll_full=True):
|
| 2146 |
-
row = m_block * self.tile_m + tScS_mn[r][0]
|
| 2147 |
-
q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
|
| 2148 |
-
sink_val[r] = Float32(learnable_sink[q_head_idx])
|
| 2149 |
-
|
| 2150 |
-
# normalize acc_O by row_sum and calculate the lse
|
| 2151 |
-
row_scale = softmax.finalize(sink_val=sink_val)
|
| 2152 |
-
softmax.rescale_O(acc_O, row_scale)
|
| 2153 |
-
|
| 2154 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 2155 |
-
# Epilogue
|
| 2156 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 2157 |
-
self.epilogue(
|
| 2158 |
-
acc_O,
|
| 2159 |
-
softmax.row_sum,
|
| 2160 |
-
mO,
|
| 2161 |
-
mLSE,
|
| 2162 |
-
sO,
|
| 2163 |
-
seqlen,
|
| 2164 |
-
gmem_tiled_copy_O,
|
| 2165 |
-
tma_atom_O,
|
| 2166 |
-
tiled_mma_pv,
|
| 2167 |
-
tidx,
|
| 2168 |
-
m_block,
|
| 2169 |
-
head_idx,
|
| 2170 |
-
batch_idx,
|
| 2171 |
-
)
|
| 2172 |
-
|
| 2173 |
-
tile_scheduler.advance_to_next_work()
|
| 2174 |
-
work_tile = tile_scheduler.get_current_work()
|
| 2175 |
-
|
| 2176 |
-
|
| 2177 |
-
@cute.jit
|
| 2178 |
-
def first_half_block_overlap(
|
| 2179 |
-
self,
|
| 2180 |
-
n_block: Int32,
|
| 2181 |
-
mma_qk_fn: Callable,
|
| 2182 |
-
kv_consumer_state,
|
| 2183 |
-
pipeline_k,
|
| 2184 |
-
tOrP: cute.Tensor,
|
| 2185 |
-
smem_copy_params: SimpleNamespace,
|
| 2186 |
-
softmax: Softmax,
|
| 2187 |
-
seqlen: SeqlenInfoQK,
|
| 2188 |
-
mask_fn: Callable = None,
|
| 2189 |
-
score_mod_fn: Optional[Callable] = None,
|
| 2190 |
-
is_first_block: bool = False,
|
| 2191 |
-
):
|
| 2192 |
-
"""Processes the first half block when using intra-warpgroup-overlap"""
|
| 2193 |
-
|
| 2194 |
-
pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
|
| 2195 |
-
acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
|
| 2196 |
-
pipeline_k.consumer_release(kv_consumer_state)
|
| 2197 |
-
|
| 2198 |
-
# Apply score modification if present
|
| 2199 |
-
if const_expr(score_mod_fn is not None):
|
| 2200 |
-
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 2201 |
-
|
| 2202 |
-
# Apply mask; mask_seqlen always True for first block
|
| 2203 |
-
# Caveat: if full block further right than mask block, seqlen masking is redundant;
|
| 2204 |
-
# however, masking is being applied anyway, so essentially no perf hit
|
| 2205 |
-
mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
|
| 2206 |
-
|
| 2207 |
-
softmax.online_softmax(acc_S, is_first=is_first_block)
|
| 2208 |
-
|
| 2209 |
-
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 2210 |
-
tOrP_cur = (
|
| 2211 |
-
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
| 2212 |
-
)
|
| 2213 |
-
tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
| 2214 |
-
|
| 2215 |
-
# if pv gemm not rs
|
| 2216 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 2217 |
-
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 2218 |
-
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 2219 |
-
# Fence and barrier to make smem store visible to WGMMA
|
| 2220 |
-
cute.arch.fence_view_async_shared()
|
| 2221 |
-
cute.arch.sync_warp()
|
| 2222 |
-
|
| 2223 |
-
return kv_consumer_state
|
| 2224 |
-
|
| 2225 |
-
@cute.jit
|
| 2226 |
-
def last_half_block_overlap(
|
| 2227 |
-
self,
|
| 2228 |
-
kv_consumer_state,
|
| 2229 |
-
pipeline_v,
|
| 2230 |
-
mma_pv_fn: Callable,
|
| 2231 |
-
zero_init: bool,
|
| 2232 |
-
):
|
| 2233 |
-
"""Processes the final PV GEMM when using intra-warpgroup-overlap"""
|
| 2234 |
-
|
| 2235 |
-
pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
|
| 2236 |
-
mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
|
| 2237 |
-
pipeline_v.consumer_release(kv_consumer_state)
|
| 2238 |
-
kv_consumer_state.advance()
|
| 2239 |
-
return kv_consumer_state
|
| 2240 |
-
|
| 2241 |
-
@cute.jit
|
| 2242 |
-
def mma_one_n_block(
|
| 2243 |
-
self,
|
| 2244 |
-
smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
|
| 2245 |
-
n_block: Int32,
|
| 2246 |
-
mma_qk_fn: Callable,
|
| 2247 |
-
mma_pv_fn: Callable,
|
| 2248 |
-
pipeline_k: cutlass.pipeline.PipelineAsync,
|
| 2249 |
-
pipeline_v: cutlass.pipeline.PipelineAsync,
|
| 2250 |
-
acc_O: cute.Tensor,
|
| 2251 |
-
tOrP: cute.Tensor,
|
| 2252 |
-
smem_copy_params: SimpleNamespace,
|
| 2253 |
-
softmax: Softmax,
|
| 2254 |
-
seqlen: SeqlenInfoQK,
|
| 2255 |
-
score_mod_fn: Optional[Callable] = None,
|
| 2256 |
-
mask_fn: Optional[Callable] = None,
|
| 2257 |
-
is_first_n_block: cutlass.Constexpr = False,
|
| 2258 |
-
check_inf: cutlass.Constexpr = True,
|
| 2259 |
-
):
|
| 2260 |
-
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
| 2261 |
-
# S = Q @ K.T
|
| 2262 |
-
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
| 2263 |
-
self.warp_scheduler_barrier_arrive()
|
| 2264 |
-
warpgroup.wait_group(0)
|
| 2265 |
-
pipeline_k.consumer_release(smem_pipe_read)
|
| 2266 |
-
|
| 2267 |
-
# handle score mods and masking
|
| 2268 |
-
if const_expr(score_mod_fn is not None):
|
| 2269 |
-
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 2270 |
-
if const_expr(mask_fn is not None):
|
| 2271 |
-
mask_fn(acc_S=acc_S, n_block=n_block)
|
| 2272 |
-
|
| 2273 |
-
row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
|
| 2274 |
-
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
| 2275 |
-
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 2276 |
-
tOrP_cur = (
|
| 2277 |
-
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
| 2278 |
-
)
|
| 2279 |
-
# tOrP.store(tOrP_acc.load().to(self.dtype))
|
| 2280 |
-
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
| 2281 |
-
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
| 2282 |
-
# 2 elements. So we just call ptx directly.
|
| 2283 |
-
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
| 2284 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 2285 |
-
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 2286 |
-
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 2287 |
-
softmax.rescale_O(acc_O, row_scale)
|
| 2288 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 2289 |
-
# Fence and barrier to make sure smem store is visible to WGMMA
|
| 2290 |
-
cute.arch.fence_view_async_shared()
|
| 2291 |
-
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
| 2292 |
-
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
|
| 2293 |
-
self.warp_scheduler_barrier_sync()
|
| 2294 |
-
# O += P @ V
|
| 2295 |
-
mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
|
| 2296 |
-
pipeline_v.consumer_release(smem_pipe_read)
|
| 2297 |
-
smem_pipe_read.advance()
|
| 2298 |
-
return smem_pipe_read
|
| 2299 |
-
|
| 2300 |
-
@cute.jit
|
| 2301 |
-
def mma_one_n_block_intrawg_overlap(
|
| 2302 |
-
self,
|
| 2303 |
-
smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
|
| 2304 |
-
n_block: Int32,
|
| 2305 |
-
mma_qk_fn: Callable,
|
| 2306 |
-
mma_pv_fn: Callable,
|
| 2307 |
-
pipeline_k: cutlass.pipeline.PipelineAsync,
|
| 2308 |
-
pipeline_v: cutlass.pipeline.PipelineAsync,
|
| 2309 |
-
acc_O: cute.Tensor,
|
| 2310 |
-
tOrP: cute.Tensor,
|
| 2311 |
-
smem_copy_params: SimpleNamespace,
|
| 2312 |
-
softmax: Softmax,
|
| 2313 |
-
seqlen: SeqlenInfoQK,
|
| 2314 |
-
score_mod_fn: Optional[Callable] = None,
|
| 2315 |
-
mask_fn: Optional[Callable] = None,
|
| 2316 |
-
check_inf: cutlass.Constexpr = True,
|
| 2317 |
-
):
|
| 2318 |
-
smem_pipe_read_v = smem_pipe_read.clone()
|
| 2319 |
-
smem_pipe_read.advance()
|
| 2320 |
-
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
| 2321 |
-
self.warp_scheduler_barrier_sync()
|
| 2322 |
-
# S = Q @ K.T
|
| 2323 |
-
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
| 2324 |
-
pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
|
| 2325 |
-
# O += P @ V
|
| 2326 |
-
mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
|
| 2327 |
-
self.warp_scheduler_barrier_arrive()
|
| 2328 |
-
warpgroup.wait_group(1)
|
| 2329 |
-
pipeline_k.consumer_release(smem_pipe_read)
|
| 2330 |
-
|
| 2331 |
-
# handle score mods and masking
|
| 2332 |
-
if const_expr(score_mod_fn is not None):
|
| 2333 |
-
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 2334 |
-
if const_expr(mask_fn is not None):
|
| 2335 |
-
mask_fn(acc_S=acc_S, n_block=n_block)
|
| 2336 |
-
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
| 2337 |
-
|
| 2338 |
-
row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
|
| 2339 |
-
warpgroup.wait_group(0)
|
| 2340 |
-
pipeline_v.consumer_release(smem_pipe_read_v)
|
| 2341 |
-
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 2342 |
-
tOrP_cur = (
|
| 2343 |
-
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
| 2344 |
-
)
|
| 2345 |
-
# tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
| 2346 |
-
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
| 2347 |
-
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
| 2348 |
-
# 2 elements. So we just call ptx directly.
|
| 2349 |
-
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
| 2350 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 2351 |
-
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 2352 |
-
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 2353 |
-
softmax.rescale_O(acc_O, row_scale)
|
| 2354 |
-
if const_expr(not self.mma_pv_is_rs):
|
| 2355 |
-
# Fence and barrier to make sure smem store is visible to WGMMA
|
| 2356 |
-
cute.arch.fence_view_async_shared()
|
| 2357 |
-
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
| 2358 |
-
return smem_pipe_read
|
| 2359 |
-
|
| 2360 |
-
@cute.jit
|
| 2361 |
-
def mma_init(self):
|
| 2362 |
-
warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
| 2363 |
-
if const_expr(self.use_scheduler_barrier):
|
| 2364 |
-
if warp_group_idx == 1:
|
| 2365 |
-
cute.arch.barrier_arrive(
|
| 2366 |
-
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
|
| 2367 |
-
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 2368 |
-
)
|
| 2369 |
-
|
| 2370 |
-
@cute.jit
|
| 2371 |
-
def apply_score_mod(
|
| 2372 |
-
self,
|
| 2373 |
-
thr_mma_qk,
|
| 2374 |
-
batch_idx,
|
| 2375 |
-
head_idx,
|
| 2376 |
-
m_block,
|
| 2377 |
-
acc_S,
|
| 2378 |
-
n_block,
|
| 2379 |
-
softmax_scale,
|
| 2380 |
-
seqlen,
|
| 2381 |
-
aux_tensors: Optional[list] = None,
|
| 2382 |
-
fastdiv_mods=None,
|
| 2383 |
-
):
|
| 2384 |
-
# Prepare index tensor
|
| 2385 |
-
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
| 2386 |
-
cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
|
| 2387 |
-
tScS = thr_mma_qk.partition_C(cS)
|
| 2388 |
-
|
| 2389 |
-
apply_score_mod_inner(
|
| 2390 |
-
acc_S,
|
| 2391 |
-
tScS,
|
| 2392 |
-
self.score_mod,
|
| 2393 |
-
batch_idx,
|
| 2394 |
-
head_idx,
|
| 2395 |
-
softmax_scale,
|
| 2396 |
-
self.vec_size,
|
| 2397 |
-
self.qk_acc_dtype,
|
| 2398 |
-
aux_tensors,
|
| 2399 |
-
fastdiv_mods,
|
| 2400 |
-
seqlen_info=seqlen,
|
| 2401 |
-
constant_q_idx=None,
|
| 2402 |
-
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 2403 |
-
)
|
| 2404 |
-
|
| 2405 |
-
def warp_scheduler_barrier_sync(self):
|
| 2406 |
-
if const_expr(self.use_scheduler_barrier):
|
| 2407 |
-
cute.arch.barrier(
|
| 2408 |
-
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
|
| 2409 |
-
- 1
|
| 2410 |
-
+ utils.canonical_warp_group_idx(sync=False),
|
| 2411 |
-
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 2412 |
-
)
|
| 2413 |
-
|
| 2414 |
-
def warp_scheduler_barrier_arrive(self):
|
| 2415 |
-
if const_expr(self.use_scheduler_barrier):
|
| 2416 |
-
assert self.num_mma_warp_groups in [2, 3]
|
| 2417 |
-
cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
|
| 2418 |
-
if const_expr(self.num_mma_warp_groups == 2):
|
| 2419 |
-
next_wg = 1 - cur_wg
|
| 2420 |
-
else:
|
| 2421 |
-
t = cur_wg + 1
|
| 2422 |
-
next_wg = t % self.num_mma_warp_groups
|
| 2423 |
-
cute.arch.barrier_arrive(
|
| 2424 |
-
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
|
| 2425 |
-
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 2426 |
-
)
|
|
|
|
| 15 |
import cutlass
|
| 16 |
import cutlass.cute as cute
|
| 17 |
from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
|
| 18 |
+
from cutlass.cute.nvgpu import cpasync, warp
|
| 19 |
import cutlass.utils as utils_basic
|
| 20 |
+
from cutlass.base_dsl.arch import Arch
|
| 21 |
+
from cutlass.cutlass_dsl import BaseDSL
|
| 22 |
|
| 23 |
from .quack import copy_utils
|
| 24 |
from .quack import layout_utils
|
|
|
|
| 25 |
|
| 26 |
from . import ampere_helpers as sm80_utils
|
| 27 |
from .cute_dsl_utils import assume_tensor_aligned
|
| 28 |
from . import utils
|
| 29 |
from .mask import AttentionMask
|
| 30 |
+
from .softmax import Softmax
|
| 31 |
from .seqlen_info import SeqlenInfoQK
|
| 32 |
from .block_info import BlockInfo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
from .pack_gqa import PackGQA
|
| 34 |
from .named_barrier import NamedBarrierFwd
|
| 35 |
+
from .block_sparsity import BlockSparseTensors
|
| 36 |
+
from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
class FlashAttentionForwardBase:
|
|
|
|
| 40 |
|
| 41 |
def __init__(
|
| 42 |
self,
|
|
|
|
| 102 |
self.vec_size: cutlass.Constexpr = getattr(
|
| 103 |
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
|
| 104 |
)
|
| 105 |
+
if self.vec_size > 2:
|
| 106 |
+
raise ValueError(
|
| 107 |
+
f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 "
|
| 108 |
+
"due to accumulator thread ownership pattern."
|
| 109 |
+
)
|
| 110 |
+
self.arch = BaseDSL._get_dsl().get_arch_enum()
|
| 111 |
|
| 112 |
@staticmethod
|
| 113 |
def can_implement(
|
|
|
|
| 310 |
mO: cute.Tensor,
|
| 311 |
mLSE: Optional[cute.Tensor],
|
| 312 |
softmax_scale: Float32,
|
| 313 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 314 |
+
stream: cuda.CUstream = None,
|
| 315 |
):
|
| 316 |
"""Configures and launches the flash attention kernel.
|
| 317 |
|
|
|
|
| 344 |
cute.arch.barrier(
|
| 345 |
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
|
| 346 |
)
|
| 347 |
+
smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
|
| 348 |
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
|
| 349 |
taccOrO = smem_thr_copy_O.retile(rO)
|
| 350 |
taccOsO = smem_thr_copy_O.partition_D(sO)
|
|
|
|
| 359 |
|
| 360 |
# Write LSE from rmem -> gmem
|
| 361 |
if const_expr(mLSE is not None):
|
| 362 |
+
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
if const_expr(not self.pack_gqa):
|
| 364 |
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
|
| 365 |
gLSE_expanded_layout = cute.append(
|
|
|
|
| 373 |
t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
|
| 374 |
# Only the thread corresponding to column 0 writes out the lse to gmem
|
| 375 |
if taccOcO[0][1] == 0:
|
| 376 |
+
for m in cutlass.range(cute.size(taccOgLSE.shape[1]), unroll_full=True):
|
| 377 |
if (
|
| 378 |
t0accOcO[m, 0][0]
|
| 379 |
< seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
|
|
|
|
| 382 |
else:
|
| 383 |
pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
|
| 384 |
|
| 385 |
+
ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
|
| 386 |
+
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
|
|
|
|
|
|
|
|
|
|
| 387 |
# thr_mma = tiled_mma.get_slice(tidx)
|
| 388 |
# taccOgO = thr_mma.partition_C(gO)
|
| 389 |
# cute.autovec_copy(rO, taccOgO)
|
|
|
|
| 620 |
mV: cute.Tensor,
|
| 621 |
mO: cute.Tensor,
|
| 622 |
mLSE: Optional[cute.Tensor],
|
| 623 |
+
softmax_scale: Float32,
|
| 624 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 625 |
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 626 |
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 627 |
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 628 |
+
mPageTable: Optional[cute.Tensor] = None,
|
| 629 |
window_size_left: Optional[Int32] = None,
|
| 630 |
window_size_right: Optional[Int32] = None,
|
| 631 |
learnable_sink: Optional[cute.Tensor] = None,
|
| 632 |
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 633 |
aux_tensors=None,
|
| 634 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 635 |
+
stream: cuda.CUstream = None,
|
| 636 |
):
|
| 637 |
"""Configures and launches the flash attention kernel.
|
| 638 |
|
|
|
|
| 641 |
"""
|
| 642 |
assert learnable_sink is None, "Learnable sink is not supported in this kernel"
|
| 643 |
self._check_type(
|
| 644 |
+
*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))
|
| 645 |
)
|
| 646 |
tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
|
| 647 |
self.num_mma_threads = tiled_mma_pv.size
|
|
|
|
| 649 |
self.num_Q_load_threads = self.num_threads
|
| 650 |
self.num_epilogue_threads = self.num_threads
|
| 651 |
# self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
|
| 652 |
+
self.use_tma_O = self.arch >= Arch.sm_90
|
| 653 |
self._setup_attributes()
|
| 654 |
SharedStorage = self._get_shared_storage_cls()
|
| 655 |
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
|
| 656 |
+
# Layout permutation: 4D non-varlen vs 3D varlen
|
| 657 |
+
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
| 658 |
+
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
| 659 |
+
mQ, mO = [
|
| 660 |
+
cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose))
|
| 661 |
+
for t in (mQ, mO)
|
| 662 |
]
|
| 663 |
+
mK, mV = [
|
| 664 |
+
cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))
|
| 665 |
+
for t in (mK, mV)
|
| 666 |
+
]
|
| 667 |
+
if const_expr(mLSE is not None):
|
| 668 |
+
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
| 669 |
+
mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
|
| 670 |
+
# TileScheduler for varlen, simple grid for non-varlen
|
| 671 |
+
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
|
| 672 |
+
TileScheduler = SingleTileVarlenScheduler
|
|
|
|
| 673 |
else:
|
| 674 |
+
TileScheduler = SingleTileScheduler
|
| 675 |
+
num_batch = (
|
| 676 |
+
mCuSeqlensQ.shape[0] - 1
|
| 677 |
+
if const_expr(mCuSeqlensQ is not None)
|
| 678 |
+
else mQ.shape[3]
|
| 679 |
+
)
|
| 680 |
+
tile_sched_args = TileSchedulerArguments(
|
| 681 |
+
num_block=cute.ceil_div(mQ.shape[0], self.tile_m),
|
| 682 |
+
num_head=cute.size(mQ.shape[2]),
|
| 683 |
+
num_batch=num_batch,
|
| 684 |
+
num_splits=1,
|
| 685 |
+
seqlen_k=0,
|
| 686 |
+
headdim=mQ.shape[1],
|
| 687 |
+
headdim_v=mV.shape[1],
|
| 688 |
+
total_q=cute.size(mQ.shape[0])
|
| 689 |
+
if const_expr(mCuSeqlensQ is not None)
|
| 690 |
+
else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
|
| 691 |
+
tile_shape_mn=(self.tile_m, self.tile_n),
|
| 692 |
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 693 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 694 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 695 |
+
)
|
| 696 |
+
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
| 697 |
+
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
| 698 |
+
softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
|
| 699 |
+
fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors)
|
| 700 |
|
| 701 |
self.kernel(
|
| 702 |
mQ,
|
|
|
|
| 704 |
mV,
|
| 705 |
mO,
|
| 706 |
mLSE,
|
| 707 |
+
mCuSeqlensQ,
|
| 708 |
+
mCuSeqlensK,
|
| 709 |
+
mSeqUsedQ,
|
| 710 |
+
mSeqUsedK,
|
| 711 |
softmax_scale_log2,
|
| 712 |
softmax_scale,
|
| 713 |
window_size_left,
|
|
|
|
| 724 |
tiled_mma_qk,
|
| 725 |
tiled_mma_pv,
|
| 726 |
SharedStorage,
|
| 727 |
+
tile_sched_params,
|
| 728 |
+
TileScheduler,
|
| 729 |
aux_tensors,
|
| 730 |
fastdiv_mods,
|
| 731 |
).launch(
|
|
|
|
| 743 |
mV: cute.Tensor,
|
| 744 |
mO: cute.Tensor,
|
| 745 |
mLSE: Optional[cute.Tensor],
|
| 746 |
+
mCuSeqlensQ: Optional[cute.Tensor],
|
| 747 |
+
mCuSeqlensK: Optional[cute.Tensor],
|
| 748 |
+
mSeqUsedQ: Optional[cute.Tensor],
|
| 749 |
+
mSeqUsedK: Optional[cute.Tensor],
|
| 750 |
softmax_scale_log2: Float32,
|
| 751 |
softmax_scale: Optional[Float32],
|
| 752 |
window_size_left: Optional[Int32],
|
|
|
|
| 763 |
tiled_mma_qk: cute.TiledMma,
|
| 764 |
tiled_mma_pv: cute.TiledMma,
|
| 765 |
SharedStorage: cutlass.Constexpr,
|
| 766 |
+
tile_sched_params,
|
| 767 |
+
TileScheduler: cutlass.Constexpr[Callable],
|
| 768 |
aux_tensors=None,
|
| 769 |
fastdiv_mods=None,
|
| 770 |
):
|
| 771 |
# Thread index, block index
|
| 772 |
tidx, _, _ = cute.arch.thread_idx()
|
| 773 |
+
|
| 774 |
+
tile_scheduler = TileScheduler.create(tile_sched_params)
|
| 775 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 776 |
+
m_block, num_head, batch_size, _ = work_tile.tile_idx
|
| 777 |
|
| 778 |
block_info = BlockInfo(
|
| 779 |
self.tile_m,
|
|
|
|
| 785 |
window_size_right,
|
| 786 |
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 787 |
)
|
| 788 |
+
seqlen = SeqlenInfoQK.create(
|
| 789 |
+
batch_idx=batch_size,
|
| 790 |
+
seqlen_q_static=mQ.shape[0],
|
| 791 |
+
seqlen_k_static=mK.shape[0],
|
| 792 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 793 |
+
mCuSeqlensK=mCuSeqlensK,
|
| 794 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 795 |
+
mSeqUsedK=mSeqUsedK,
|
| 796 |
+
)
|
| 797 |
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 798 |
+
# For varlen, wasted grid tiles (where batch_idx >= num_batch) will have
|
| 799 |
+
# seqlen_q=seqlen_k=0 and n_block_max=0. Clamp to 0 so we don't use a
|
| 800 |
+
# negative block index for K/V loads; the load/store predicates already
|
| 801 |
+
# guard all memory accesses when seqlen is 0.
|
| 802 |
+
n_block = cutlass.max(n_block_max - 1, 0)
|
| 803 |
|
| 804 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 805 |
# Get the appropriate tiles for this thread block.
|
|
|
|
| 807 |
blkQ_shape = (self.tile_m, self.tile_hdim)
|
| 808 |
blkK_shape = (self.tile_n, self.tile_hdim)
|
| 809 |
blkV_shape = (self.tile_n, self.tile_hdimv)
|
|
|
|
| 810 |
num_head_kv = num_head // self.qhead_per_kvhead
|
| 811 |
+
if const_expr(not seqlen.has_cu_seqlens_q):
|
| 812 |
+
mQ_cur = mQ[None, None, num_head, batch_size]
|
| 813 |
+
else:
|
| 814 |
+
mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head])
|
| 815 |
+
if const_expr(not seqlen.has_cu_seqlens_k):
|
| 816 |
+
mK_cur = mK[None, None, num_head_kv, batch_size]
|
| 817 |
+
mV_cur = mV[None, None, num_head_kv, batch_size]
|
| 818 |
+
else:
|
| 819 |
+
mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv])
|
| 820 |
+
mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv])
|
| 821 |
+
gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0))
|
| 822 |
+
gK = cute.local_tile(mK_cur, blkK_shape, (None, 0))
|
| 823 |
+
gV = cute.local_tile(mV_cur, blkV_shape, (None, 0))
|
| 824 |
|
| 825 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 826 |
# Get shared memory buffer
|
|
|
|
| 992 |
mask = AttentionMask(
|
| 993 |
self.tile_m,
|
| 994 |
self.tile_n,
|
| 995 |
+
seqlen,
|
|
|
|
| 996 |
window_size_left,
|
| 997 |
window_size_right,
|
| 998 |
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 999 |
)
|
| 1000 |
mask_fn = partial(
|
| 1001 |
mask.apply_mask,
|
| 1002 |
+
batch_idx=batch_size,
|
| 1003 |
+
head_idx=num_head,
|
| 1004 |
m_block=m_block,
|
| 1005 |
thr_mma=thr_mma_qk,
|
| 1006 |
mask_causal=self.is_causal,
|
| 1007 |
mask_local=self.is_local,
|
| 1008 |
+
aux_tensors=aux_tensors,
|
| 1009 |
fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
|
| 1010 |
)
|
| 1011 |
|
|
|
|
| 1017 |
smem_pipe_read,
|
| 1018 |
smem_pipe_write,
|
| 1019 |
is_first_n_block=True,
|
| 1020 |
+
seqlen=seqlen,
|
| 1021 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
|
| 1022 |
)
|
| 1023 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 1024 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
|
|
|
| 1033 |
n_block,
|
| 1034 |
smem_pipe_read,
|
| 1035 |
smem_pipe_write,
|
| 1036 |
+
seqlen=seqlen,
|
| 1037 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
|
| 1038 |
)
|
| 1039 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 1040 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
| 1041 |
# The remaining iterations have no masking
|
| 1042 |
for n_tile in cutlass.range(n_block, unroll=1):
|
| 1043 |
compute_one_n_block(
|
| 1044 |
+
n_block - n_tile - 1, smem_pipe_read, smem_pipe_write,
|
| 1045 |
+
seqlen=seqlen, is_first_n_block=False,
|
| 1046 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False)
|
| 1047 |
)
|
| 1048 |
smem_pipe_read = self.advance_pipeline(smem_pipe_read)
|
| 1049 |
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
|
|
|
|
| 1187 |
# load_K_next()
|
| 1188 |
|
| 1189 |
|
| 1190 |
+
# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility
|
| 1191 |
+
def __getattr__(name):
|
| 1192 |
+
if name == "FlashAttentionForwardSm90":
|
| 1193 |
+
from .flash_fwd_sm90 import FlashAttentionForwardSm90
|
| 1194 |
+
return FlashAttentionForwardSm90
|
| 1195 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/flash_fwd_combine.py
CHANGED
|
@@ -10,7 +10,7 @@ import cuda.bindings.driver as cuda
|
|
| 10 |
import cutlass
|
| 11 |
import cutlass.cute as cute
|
| 12 |
from cutlass.cute.nvgpu import cpasync
|
| 13 |
-
from cutlass import Float32, Int32, const_expr
|
| 14 |
|
| 15 |
from . import utils
|
| 16 |
from .cute_dsl_utils import assume_tensor_aligned
|
|
@@ -24,7 +24,7 @@ class FlashAttentionForwardCombine:
|
|
| 24 |
dtype: Type[cutlass.Numeric],
|
| 25 |
dtype_partial: Type[cutlass.Numeric],
|
| 26 |
head_dim: int,
|
| 27 |
-
|
| 28 |
k_block_size: int = 64,
|
| 29 |
log_max_splits: int = 4,
|
| 30 |
num_threads: int = 256,
|
|
@@ -36,7 +36,7 @@ class FlashAttentionForwardCombine:
|
|
| 36 |
:param dtype: output data type
|
| 37 |
:param dtype_partial: partial accumulation data type
|
| 38 |
:param head_dim: head dimension
|
| 39 |
-
:param
|
| 40 |
:param k_block_size: k block size
|
| 41 |
:param log_max_splits: log2 of maximum splits
|
| 42 |
:param num_threads: number of threads
|
|
@@ -46,7 +46,7 @@ class FlashAttentionForwardCombine:
|
|
| 46 |
self.dtype = dtype
|
| 47 |
self.dtype_partial = dtype_partial
|
| 48 |
self.head_dim = head_dim
|
| 49 |
-
self.
|
| 50 |
self.k_block_size = k_block_size
|
| 51 |
self.max_splits = 1 << log_max_splits
|
| 52 |
self.num_threads = num_threads
|
|
@@ -58,7 +58,7 @@ class FlashAttentionForwardCombine:
|
|
| 58 |
dtype,
|
| 59 |
dtype_partial,
|
| 60 |
head_dim,
|
| 61 |
-
|
| 62 |
k_block_size,
|
| 63 |
log_max_splits,
|
| 64 |
num_threads,
|
|
@@ -72,12 +72,12 @@ class FlashAttentionForwardCombine:
|
|
| 72 |
return False
|
| 73 |
if num_threads % 32 != 0:
|
| 74 |
return False
|
| 75 |
-
if
|
| 76 |
return False
|
| 77 |
max_splits = 1 << log_max_splits
|
| 78 |
if max_splits > 256:
|
| 79 |
return False
|
| 80 |
-
if (
|
| 81 |
return False
|
| 82 |
return True
|
| 83 |
|
|
@@ -124,15 +124,11 @@ class FlashAttentionForwardCombine:
|
|
| 124 |
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
|
| 125 |
m_block_smem = (
|
| 126 |
128
|
| 127 |
-
if self.
|
| 128 |
else (
|
| 129 |
64
|
| 130 |
-
if self.
|
| 131 |
-
else (
|
| 132 |
-
32
|
| 133 |
-
if self.m_block_size % 32 == 0
|
| 134 |
-
else (16 if self.m_block_size % 16 == 0 else 8)
|
| 135 |
-
)
|
| 136 |
)
|
| 137 |
)
|
| 138 |
gmem_threads_per_row_lse = m_block_smem
|
|
@@ -183,12 +179,12 @@ class FlashAttentionForwardCombine:
|
|
| 183 |
smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
|
| 184 |
)
|
| 185 |
self.smem_layout_lse = cute.tile_to_shape(
|
| 186 |
-
smem_layout_atom_lse, (self.max_splits, self.
|
| 187 |
)
|
| 188 |
|
| 189 |
# O partial shared memory layout (simple layout for pipeline stages)
|
| 190 |
self.smem_layout_o = cute.make_ordered_layout(
|
| 191 |
-
(self.
|
| 192 |
)
|
| 193 |
|
| 194 |
@cute.jit
|
|
@@ -201,7 +197,9 @@ class FlashAttentionForwardCombine:
|
|
| 201 |
cu_seqlens: Optional[cute.Tensor] = None,
|
| 202 |
seqused: Optional[cute.Tensor] = None,
|
| 203 |
num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
|
|
|
|
| 204 |
semaphore_to_reset: Optional[cute.Tensor] = None,
|
|
|
|
| 205 |
stream: cuda.CUstream = None,
|
| 206 |
):
|
| 207 |
# Type checking
|
|
@@ -269,7 +267,7 @@ class FlashAttentionForwardCombine:
|
|
| 269 |
sLSE: cute.struct.Align[
|
| 270 |
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
| 271 |
]
|
| 272 |
-
sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.
|
| 273 |
sO: cute.struct.Align[
|
| 274 |
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
|
| 275 |
]
|
|
@@ -290,7 +288,7 @@ class FlashAttentionForwardCombine:
|
|
| 290 |
head_divmod = FastDivmodDivisor(num_head)
|
| 291 |
|
| 292 |
grid_dim = (
|
| 293 |
-
cute.ceil_div(seqlen * num_head, self.
|
| 294 |
cute.ceil_div(self.head_dim, self.k_block_size),
|
| 295 |
batch_size,
|
| 296 |
)
|
|
@@ -303,6 +301,7 @@ class FlashAttentionForwardCombine:
|
|
| 303 |
cu_seqlens,
|
| 304 |
seqused,
|
| 305 |
num_splits_dynamic_ptr,
|
|
|
|
| 306 |
semaphore_to_reset,
|
| 307 |
SharedStorage,
|
| 308 |
self.smem_layout_lse,
|
|
@@ -331,6 +330,7 @@ class FlashAttentionForwardCombine:
|
|
| 331 |
cu_seqlens: Optional[cute.Tensor],
|
| 332 |
seqused: Optional[cute.Tensor],
|
| 333 |
num_splits_dynamic_ptr: Optional[cute.Tensor],
|
|
|
|
| 334 |
semaphore_to_reset: Optional[cute.Tensor],
|
| 335 |
SharedStorage: cutlass.Constexpr,
|
| 336 |
smem_layout_lse: cute.Layout | cute.ComposedLayout,
|
|
@@ -345,7 +345,14 @@ class FlashAttentionForwardCombine:
|
|
| 345 |
):
|
| 346 |
# Thread and block indices
|
| 347 |
tidx, _, _ = cute.arch.thread_idx()
|
| 348 |
-
m_block, k_block,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 351 |
# Get shared memory buffer
|
|
@@ -353,22 +360,23 @@ class FlashAttentionForwardCombine:
|
|
| 353 |
smem = cutlass.utils.SmemAllocator()
|
| 354 |
storage = smem.allocate(SharedStorage)
|
| 355 |
sLSE = storage.sLSE.get_tensor(smem_layout_lse)
|
| 356 |
-
sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.
|
| 357 |
sO = storage.sO.get_tensor(smem_layout_o)
|
| 358 |
|
| 359 |
-
# Handle semaphore reset
|
| 360 |
if const_expr(semaphore_to_reset is not None):
|
| 361 |
if (
|
| 362 |
tidx == 0
|
| 363 |
and m_block == cute.arch.grid_dim()[0] - 1
|
| 364 |
and k_block == cute.arch.grid_dim()[1] - 1
|
| 365 |
-
and
|
| 366 |
):
|
|
|
|
| 367 |
semaphore_to_reset[0] = 0
|
| 368 |
|
| 369 |
-
# Get number of splits
|
| 370 |
num_splits = (
|
| 371 |
-
num_splits_dynamic_ptr[
|
| 372 |
if const_expr(num_splits_dynamic_ptr is not None)
|
| 373 |
else mLSE_partial.shape[1]
|
| 374 |
)
|
|
@@ -378,6 +386,7 @@ class FlashAttentionForwardCombine:
|
|
| 378 |
seqlen_static=mO_partial.shape[0],
|
| 379 |
cu_seqlens=cu_seqlens,
|
| 380 |
seqused=seqused,
|
|
|
|
| 381 |
)
|
| 382 |
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
|
| 383 |
|
|
@@ -387,29 +396,27 @@ class FlashAttentionForwardCombine:
|
|
| 387 |
|
| 388 |
# Early exit for single split if dynamic
|
| 389 |
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
|
| 390 |
-
const_expr(not varlen) or m_block * self.
|
| 391 |
):
|
|
|
|
|
|
|
|
|
|
| 392 |
# ===============================
|
| 393 |
# Step 1: Load LSE_partial from gmem to shared memory
|
| 394 |
# ===============================
|
| 395 |
|
| 396 |
-
|
| 397 |
-
mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
|
| 398 |
-
else:
|
| 399 |
-
mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
|
| 400 |
mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
|
| 401 |
-
|
| 402 |
gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
|
| 403 |
tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
|
| 404 |
-
|
| 405 |
# Create identity tensor for coordinate tracking
|
| 406 |
-
cLSE = cute.make_identity_tensor((self.max_splits, self.
|
| 407 |
tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
|
| 408 |
|
| 409 |
# Load LSE partial values
|
| 410 |
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
| 411 |
mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
|
| 412 |
-
idx = m_block * self.
|
| 413 |
if idx < max_idx:
|
| 414 |
# Calculate actual sequence position and head using FastDivmodDivisor
|
| 415 |
if const_expr(not varlen):
|
|
@@ -436,22 +443,19 @@ class FlashAttentionForwardCombine:
|
|
| 436 |
# ===============================
|
| 437 |
|
| 438 |
gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
|
| 439 |
-
cO = cute.make_identity_tensor((self.
|
| 440 |
tOcO = gmem_thr_copy_O_partial.partition_D(cO)
|
| 441 |
tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
|
| 442 |
-
|
| 443 |
-
mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
|
| 444 |
-
else:
|
| 445 |
-
mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)
|
| 446 |
|
| 447 |
# Precompute these values to avoid recomputing them in the loop
|
| 448 |
num_rows = const_expr(cute.size(tOcO, mode=[1]))
|
| 449 |
-
tOmidx = cute.
|
| 450 |
-
tOhidx = cute.
|
| 451 |
-
tOrOptr = cute.
|
| 452 |
for m in cutlass.range(num_rows, unroll_full=True):
|
| 453 |
mi = tOcO[0, m, 0][0] # m coordinate
|
| 454 |
-
idx = m_block * self.
|
| 455 |
if const_expr(not varlen):
|
| 456 |
tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
|
| 457 |
else:
|
|
@@ -463,11 +467,12 @@ class FlashAttentionForwardCombine:
|
|
| 463 |
if idx >= max_idx:
|
| 464 |
tOhidx[m] = -1
|
| 465 |
|
| 466 |
-
tOpO =
|
| 467 |
if const_expr(not self.is_even_k):
|
|
|
|
| 468 |
for k in cutlass.range(cute.size(tOpO), unroll_full=True):
|
| 469 |
tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
|
| 470 |
-
|
| 471 |
|
| 472 |
load_O_partial = partial(
|
| 473 |
self.load_O_partial,
|
|
@@ -501,17 +506,17 @@ class FlashAttentionForwardCombine:
|
|
| 501 |
|
| 502 |
s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
|
| 503 |
ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
|
| 504 |
-
ts2rrLSE = cute.
|
| 505 |
cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
|
| 506 |
|
| 507 |
# ===============================
|
| 508 |
# Step 4: Compute final LSE along split dimension
|
| 509 |
# ===============================
|
| 510 |
|
| 511 |
-
lse_sum = cute.
|
| 512 |
ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
|
| 513 |
# We compute the max valid split for each row to short-circuit the computation later
|
| 514 |
-
max_valid_split = cute.
|
| 515 |
assert cute.size(ts2rrLSE, mode=[0]) == 1
|
| 516 |
# Compute max, scales, and final LSE for each row
|
| 517 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
|
@@ -561,7 +566,7 @@ class FlashAttentionForwardCombine:
|
|
| 561 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 562 |
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 563 |
mi = ts2rcLSE[0, 0, m][1]
|
| 564 |
-
if mi < self.
|
| 565 |
sMaxValidSplit[mi] = max_valid_split[m]
|
| 566 |
|
| 567 |
# ===============================
|
|
@@ -577,7 +582,7 @@ class FlashAttentionForwardCombine:
|
|
| 577 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 578 |
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 579 |
mi = ts2rcLSE[0, 0, m][1]
|
| 580 |
-
idx = m_block * self.
|
| 581 |
if idx < max_idx:
|
| 582 |
if const_expr(not varlen):
|
| 583 |
head_idx, m_idx = divmod(idx, seqlen_divmod)
|
|
@@ -594,11 +599,11 @@ class FlashAttentionForwardCombine:
|
|
| 594 |
|
| 595 |
# Get max valid split for this thread
|
| 596 |
thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
|
| 597 |
-
for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
|
| 598 |
thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
|
| 599 |
|
| 600 |
-
tOrO_partial = cute.
|
| 601 |
-
tOrO = cute.
|
| 602 |
tOrO.fill(0.0)
|
| 603 |
|
| 604 |
stage_load = self.stages - 1
|
|
@@ -607,7 +612,7 @@ class FlashAttentionForwardCombine:
|
|
| 607 |
# Main accumulation loop
|
| 608 |
for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
|
| 609 |
# Get scales for this split
|
| 610 |
-
scale = cute.
|
| 611 |
for m in cutlass.range(num_rows, unroll_full=True):
|
| 612 |
scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
|
| 613 |
|
|
@@ -637,8 +642,9 @@ class FlashAttentionForwardCombine:
|
|
| 637 |
# Step 7: Write final O to gmem
|
| 638 |
# ===============================
|
| 639 |
|
| 640 |
-
rO = cute.
|
| 641 |
rO.store(tOrO.load().to(self.dtype))
|
|
|
|
| 642 |
if const_expr(cu_seqlens is None):
|
| 643 |
mO_cur = mO[None, None, None, batch_idx]
|
| 644 |
else:
|
|
@@ -665,7 +671,7 @@ class FlashAttentionForwardCombine:
|
|
| 665 |
tOrOptr: cute.Tensor,
|
| 666 |
tOsO_partial: cute.Tensor,
|
| 667 |
tOhidx: cute.Tensor,
|
| 668 |
-
tOpO: cute.Tensor,
|
| 669 |
tOcO: cute.Tensor,
|
| 670 |
mO_cur_partial_layout: cute.Layout,
|
| 671 |
split: Int32,
|
|
@@ -684,7 +690,7 @@ class FlashAttentionForwardCombine:
|
|
| 684 |
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
|
| 685 |
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 686 |
k_idx = tOcO[0, 0, k][1] // elems_per_load
|
| 687 |
-
if const_expr(
|
| 688 |
cute.copy(
|
| 689 |
gmem_tiled_copy_O_partial,
|
| 690 |
mO_partial_cur_copy[None, k_idx, split],
|
|
|
|
| 10 |
import cutlass
|
| 11 |
import cutlass.cute as cute
|
| 12 |
from cutlass.cute.nvgpu import cpasync
|
| 13 |
+
from cutlass import Float32, Int32, Boolean, const_expr
|
| 14 |
|
| 15 |
from . import utils
|
| 16 |
from .cute_dsl_utils import assume_tensor_aligned
|
|
|
|
| 24 |
dtype: Type[cutlass.Numeric],
|
| 25 |
dtype_partial: Type[cutlass.Numeric],
|
| 26 |
head_dim: int,
|
| 27 |
+
tile_m: int = 8,
|
| 28 |
k_block_size: int = 64,
|
| 29 |
log_max_splits: int = 4,
|
| 30 |
num_threads: int = 256,
|
|
|
|
| 36 |
:param dtype: output data type
|
| 37 |
:param dtype_partial: partial accumulation data type
|
| 38 |
:param head_dim: head dimension
|
| 39 |
+
:param tile_m: m block size
|
| 40 |
:param k_block_size: k block size
|
| 41 |
:param log_max_splits: log2 of maximum splits
|
| 42 |
:param num_threads: number of threads
|
|
|
|
| 46 |
self.dtype = dtype
|
| 47 |
self.dtype_partial = dtype_partial
|
| 48 |
self.head_dim = head_dim
|
| 49 |
+
self.tile_m = tile_m
|
| 50 |
self.k_block_size = k_block_size
|
| 51 |
self.max_splits = 1 << log_max_splits
|
| 52 |
self.num_threads = num_threads
|
|
|
|
| 58 |
dtype,
|
| 59 |
dtype_partial,
|
| 60 |
head_dim,
|
| 61 |
+
tile_m,
|
| 62 |
k_block_size,
|
| 63 |
log_max_splits,
|
| 64 |
num_threads,
|
|
|
|
| 72 |
return False
|
| 73 |
if num_threads % 32 != 0:
|
| 74 |
return False
|
| 75 |
+
if tile_m % 8 != 0:
|
| 76 |
return False
|
| 77 |
max_splits = 1 << log_max_splits
|
| 78 |
if max_splits > 256:
|
| 79 |
return False
|
| 80 |
+
if (tile_m * max_splits) % num_threads != 0:
|
| 81 |
return False
|
| 82 |
return True
|
| 83 |
|
|
|
|
| 124 |
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
|
| 125 |
m_block_smem = (
|
| 126 |
128
|
| 127 |
+
if self.tile_m % 128 == 0
|
| 128 |
else (
|
| 129 |
64
|
| 130 |
+
if self.tile_m % 64 == 0
|
| 131 |
+
else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
)
|
| 133 |
)
|
| 134 |
gmem_threads_per_row_lse = m_block_smem
|
|
|
|
| 179 |
smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
|
| 180 |
)
|
| 181 |
self.smem_layout_lse = cute.tile_to_shape(
|
| 182 |
+
smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1)
|
| 183 |
)
|
| 184 |
|
| 185 |
# O partial shared memory layout (simple layout for pipeline stages)
|
| 186 |
self.smem_layout_o = cute.make_ordered_layout(
|
| 187 |
+
(self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)
|
| 188 |
)
|
| 189 |
|
| 190 |
@cute.jit
|
|
|
|
| 197 |
cu_seqlens: Optional[cute.Tensor] = None,
|
| 198 |
seqused: Optional[cute.Tensor] = None,
|
| 199 |
num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
|
| 200 |
+
varlen_batch_idx: Optional[cute.Tensor] = None,
|
| 201 |
semaphore_to_reset: Optional[cute.Tensor] = None,
|
| 202 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 203 |
stream: cuda.CUstream = None,
|
| 204 |
):
|
| 205 |
# Type checking
|
|
|
|
| 267 |
sLSE: cute.struct.Align[
|
| 268 |
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
| 269 |
]
|
| 270 |
+
sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128]
|
| 271 |
sO: cute.struct.Align[
|
| 272 |
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
|
| 273 |
]
|
|
|
|
| 288 |
head_divmod = FastDivmodDivisor(num_head)
|
| 289 |
|
| 290 |
grid_dim = (
|
| 291 |
+
cute.ceil_div(seqlen * num_head, self.tile_m),
|
| 292 |
cute.ceil_div(self.head_dim, self.k_block_size),
|
| 293 |
batch_size,
|
| 294 |
)
|
|
|
|
| 301 |
cu_seqlens,
|
| 302 |
seqused,
|
| 303 |
num_splits_dynamic_ptr,
|
| 304 |
+
varlen_batch_idx,
|
| 305 |
semaphore_to_reset,
|
| 306 |
SharedStorage,
|
| 307 |
self.smem_layout_lse,
|
|
|
|
| 330 |
cu_seqlens: Optional[cute.Tensor],
|
| 331 |
seqused: Optional[cute.Tensor],
|
| 332 |
num_splits_dynamic_ptr: Optional[cute.Tensor],
|
| 333 |
+
varlen_batch_idx: Optional[cute.Tensor],
|
| 334 |
semaphore_to_reset: Optional[cute.Tensor],
|
| 335 |
SharedStorage: cutlass.Constexpr,
|
| 336 |
smem_layout_lse: cute.Layout | cute.ComposedLayout,
|
|
|
|
| 345 |
):
|
| 346 |
# Thread and block indices
|
| 347 |
tidx, _, _ = cute.arch.thread_idx()
|
| 348 |
+
m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()
|
| 349 |
+
|
| 350 |
+
# Map virtual batch index to real batch index (for persistent tile schedulers)
|
| 351 |
+
batch_idx = (
|
| 352 |
+
varlen_batch_idx[maybe_virtual_batch]
|
| 353 |
+
if const_expr(varlen_batch_idx is not None)
|
| 354 |
+
else maybe_virtual_batch
|
| 355 |
+
)
|
| 356 |
|
| 357 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 358 |
# Get shared memory buffer
|
|
|
|
| 360 |
smem = cutlass.utils.SmemAllocator()
|
| 361 |
storage = smem.allocate(SharedStorage)
|
| 362 |
sLSE = storage.sLSE.get_tensor(smem_layout_lse)
|
| 363 |
+
sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
|
| 364 |
sO = storage.sO.get_tensor(smem_layout_o)
|
| 365 |
|
| 366 |
+
# Handle semaphore reset — wait for dependent grids first
|
| 367 |
if const_expr(semaphore_to_reset is not None):
|
| 368 |
if (
|
| 369 |
tidx == 0
|
| 370 |
and m_block == cute.arch.grid_dim()[0] - 1
|
| 371 |
and k_block == cute.arch.grid_dim()[1] - 1
|
| 372 |
+
and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1
|
| 373 |
):
|
| 374 |
+
cute.arch.griddepcontrol_wait()
|
| 375 |
semaphore_to_reset[0] = 0
|
| 376 |
|
| 377 |
+
# Get number of splits (use maybe_virtual_batch for per-batch-slot splits)
|
| 378 |
num_splits = (
|
| 379 |
+
num_splits_dynamic_ptr[maybe_virtual_batch]
|
| 380 |
if const_expr(num_splits_dynamic_ptr is not None)
|
| 381 |
else mLSE_partial.shape[1]
|
| 382 |
)
|
|
|
|
| 386 |
seqlen_static=mO_partial.shape[0],
|
| 387 |
cu_seqlens=cu_seqlens,
|
| 388 |
seqused=seqused,
|
| 389 |
+
# Don't need to pass in tile size since we won't use offset_padded
|
| 390 |
)
|
| 391 |
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
|
| 392 |
|
|
|
|
| 396 |
|
| 397 |
# Early exit for single split if dynamic
|
| 398 |
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
|
| 399 |
+
const_expr(not varlen) or m_block * self.tile_m < max_idx
|
| 400 |
):
|
| 401 |
+
# Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)
|
| 402 |
+
cute.arch.griddepcontrol_wait()
|
| 403 |
+
|
| 404 |
# ===============================
|
| 405 |
# Step 1: Load LSE_partial from gmem to shared memory
|
| 406 |
# ===============================
|
| 407 |
|
| 408 |
+
mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
|
|
|
|
|
|
|
|
|
|
| 409 |
mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
|
|
|
|
| 410 |
gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
|
| 411 |
tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
|
|
|
|
| 412 |
# Create identity tensor for coordinate tracking
|
| 413 |
+
cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m))
|
| 414 |
tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
|
| 415 |
|
| 416 |
# Load LSE partial values
|
| 417 |
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
| 418 |
mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
|
| 419 |
+
idx = m_block * self.tile_m + mi
|
| 420 |
if idx < max_idx:
|
| 421 |
# Calculate actual sequence position and head using FastDivmodDivisor
|
| 422 |
if const_expr(not varlen):
|
|
|
|
| 443 |
# ===============================
|
| 444 |
|
| 445 |
gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
|
| 446 |
+
cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
|
| 447 |
tOcO = gmem_thr_copy_O_partial.partition_D(cO)
|
| 448 |
tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
|
| 449 |
+
mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4)
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
# Precompute these values to avoid recomputing them in the loop
|
| 452 |
num_rows = const_expr(cute.size(tOcO, mode=[1]))
|
| 453 |
+
tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
|
| 454 |
+
tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
|
| 455 |
+
tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64)
|
| 456 |
for m in cutlass.range(num_rows, unroll_full=True):
|
| 457 |
mi = tOcO[0, m, 0][0] # m coordinate
|
| 458 |
+
idx = m_block * self.tile_m + mi
|
| 459 |
if const_expr(not varlen):
|
| 460 |
tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
|
| 461 |
else:
|
|
|
|
| 467 |
if idx >= max_idx:
|
| 468 |
tOhidx[m] = -1
|
| 469 |
|
| 470 |
+
tOpO = None
|
| 471 |
if const_expr(not self.is_even_k):
|
| 472 |
+
tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean)
|
| 473 |
for k in cutlass.range(cute.size(tOpO), unroll_full=True):
|
| 474 |
tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
|
| 475 |
+
# if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
|
| 476 |
|
| 477 |
load_O_partial = partial(
|
| 478 |
self.load_O_partial,
|
|
|
|
| 506 |
|
| 507 |
s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
|
| 508 |
ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
|
| 509 |
+
ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
|
| 510 |
cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
|
| 511 |
|
| 512 |
# ===============================
|
| 513 |
# Step 4: Compute final LSE along split dimension
|
| 514 |
# ===============================
|
| 515 |
|
| 516 |
+
lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
|
| 517 |
ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
|
| 518 |
# We compute the max valid split for each row to short-circuit the computation later
|
| 519 |
+
max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
|
| 520 |
assert cute.size(ts2rrLSE, mode=[0]) == 1
|
| 521 |
# Compute max, scales, and final LSE for each row
|
| 522 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
|
|
|
| 566 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 567 |
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 568 |
mi = ts2rcLSE[0, 0, m][1]
|
| 569 |
+
if mi < self.tile_m:
|
| 570 |
sMaxValidSplit[mi] = max_valid_split[m]
|
| 571 |
|
| 572 |
# ===============================
|
|
|
|
| 582 |
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 583 |
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 584 |
mi = ts2rcLSE[0, 0, m][1]
|
| 585 |
+
idx = m_block * self.tile_m + mi
|
| 586 |
if idx < max_idx:
|
| 587 |
if const_expr(not varlen):
|
| 588 |
head_idx, m_idx = divmod(idx, seqlen_divmod)
|
|
|
|
| 599 |
|
| 600 |
# Get max valid split for this thread
|
| 601 |
thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
|
| 602 |
+
for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
|
| 603 |
thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
|
| 604 |
|
| 605 |
+
tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
|
| 606 |
+
tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
|
| 607 |
tOrO.fill(0.0)
|
| 608 |
|
| 609 |
stage_load = self.stages - 1
|
|
|
|
| 612 |
# Main accumulation loop
|
| 613 |
for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
|
| 614 |
# Get scales for this split
|
| 615 |
+
scale = cute.make_rmem_tensor(num_rows, Float32)
|
| 616 |
for m in cutlass.range(num_rows, unroll_full=True):
|
| 617 |
scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
|
| 618 |
|
|
|
|
| 642 |
# Step 7: Write final O to gmem
|
| 643 |
# ===============================
|
| 644 |
|
| 645 |
+
rO = cute.make_rmem_tensor_like(tOrO, self.dtype)
|
| 646 |
rO.store(tOrO.load().to(self.dtype))
|
| 647 |
+
mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)
|
| 648 |
if const_expr(cu_seqlens is None):
|
| 649 |
mO_cur = mO[None, None, None, batch_idx]
|
| 650 |
else:
|
|
|
|
| 671 |
tOrOptr: cute.Tensor,
|
| 672 |
tOsO_partial: cute.Tensor,
|
| 673 |
tOhidx: cute.Tensor,
|
| 674 |
+
tOpO: Optional[cute.Tensor],
|
| 675 |
tOcO: cute.Tensor,
|
| 676 |
mO_cur_partial_layout: cute.Layout,
|
| 677 |
split: Int32,
|
|
|
|
| 690 |
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
|
| 691 |
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 692 |
k_idx = tOcO[0, 0, k][1] // elems_per_load
|
| 693 |
+
if const_expr(tOpO is None) or tOpO[k]:
|
| 694 |
cute.copy(
|
| 695 |
gmem_tiled_copy_O_partial,
|
| 696 |
mO_partial_cur_copy[None, k_idx, split],
|
build/torch-cuda/flash_fwd_sm100.py
CHANGED
|
@@ -13,9 +13,8 @@
|
|
| 13 |
# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
|
| 14 |
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
|
| 15 |
|
| 16 |
-
import enum
|
| 17 |
import math
|
| 18 |
-
from typing import
|
| 19 |
from functools import partial
|
| 20 |
|
| 21 |
import cuda.bindings.driver as cuda
|
|
@@ -28,6 +27,7 @@ import cutlass.cute.nvgpu.tcgen05 as tcgen05
|
|
| 28 |
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
| 29 |
from cutlass import pipeline
|
| 30 |
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
|
|
|
| 31 |
from cutlass.base_dsl.arch import Arch
|
| 32 |
from cutlass.cutlass_dsl import BaseDSL
|
| 33 |
|
|
@@ -35,7 +35,9 @@ from .quack import copy_utils, layout_utils
|
|
| 35 |
|
| 36 |
from .paged_kv import PagedKVManager
|
| 37 |
from .cute_dsl_utils import assume_tensor_aligned
|
|
|
|
| 38 |
from . import pipeline as pipeline_custom
|
|
|
|
| 39 |
from .mask import AttentionMask
|
| 40 |
from .softmax import SoftmaxSm100, apply_score_mod_inner
|
| 41 |
from .seqlen_info import SeqlenInfoQK
|
|
@@ -47,33 +49,45 @@ from .block_sparse_utils import (
|
|
| 47 |
softmax_block_sparse_sm100,
|
| 48 |
handle_block_sparse_empty_tile_correction_sm100,
|
| 49 |
)
|
| 50 |
-
from .pack_gqa import PackGQA
|
| 51 |
from . import mma_sm100_desc as sm100_desc
|
| 52 |
from . import blackwell_helpers as sm100_utils
|
|
|
|
| 53 |
from cutlass.cute import FastDivmodDivisor
|
| 54 |
from .quack.cute_dsl_utils import ParamsBase
|
| 55 |
from .tile_scheduler import (
|
|
|
|
|
|
|
| 56 |
TileSchedulerArguments,
|
|
|
|
| 57 |
SingleTileScheduler,
|
| 58 |
StaticPersistentTileScheduler,
|
| 59 |
SingleTileLPTScheduler,
|
| 60 |
SingleTileVarlenScheduler,
|
| 61 |
)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
class FlashAttentionForwardSm100:
|
|
@@ -99,6 +113,7 @@ class FlashAttentionForwardSm100:
|
|
| 99 |
paged_kv_non_tma: bool = False,
|
| 100 |
is_varlen_q: bool = False,
|
| 101 |
use_2cta_instrs: bool = False,
|
|
|
|
| 102 |
):
|
| 103 |
self.use_tma_KV = not paged_kv_non_tma
|
| 104 |
# self.dtype = dtype
|
|
@@ -145,10 +160,6 @@ class FlashAttentionForwardSm100:
|
|
| 145 |
self.is_split_kv = is_split_kv
|
| 146 |
self.pack_gqa = pack_gqa
|
| 147 |
self.q_subtile_factor = q_subtile_factor
|
| 148 |
-
if pack_gqa:
|
| 149 |
-
assert m_block_size % self.qhead_per_kvhead == 0, (
|
| 150 |
-
"For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
|
| 151 |
-
)
|
| 152 |
assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
|
| 153 |
"SplitKV is not supported for hdim >= 192"
|
| 154 |
)
|
|
@@ -160,8 +171,10 @@ class FlashAttentionForwardSm100:
|
|
| 160 |
# Does S1 need to wait for S0 to finish
|
| 161 |
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
|
| 162 |
is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
| 165 |
self.s0_s1_barrier = False
|
| 166 |
self.overlap_sO_sQ = (
|
| 167 |
(self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
|
|
@@ -174,6 +187,32 @@ class FlashAttentionForwardSm100:
|
|
| 174 |
"Paged KV does not support irregular head dim"
|
| 175 |
)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
self.softmax0_warp_ids = (0, 1, 2, 3)
|
| 178 |
self.softmax1_warp_ids = (4, 5, 6, 7)
|
| 179 |
self.correction_warp_ids = (8, 9, 10, 11)
|
|
@@ -195,8 +234,10 @@ class FlashAttentionForwardSm100:
|
|
| 195 |
)
|
| 196 |
)
|
| 197 |
|
|
|
|
|
|
|
| 198 |
if self.q_stage == 1:
|
| 199 |
-
if not self.use_tma_KV:
|
| 200 |
self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
|
| 201 |
self.load_warp_ids = self.softmax1_warp_ids
|
| 202 |
else:
|
|
@@ -212,6 +253,8 @@ class FlashAttentionForwardSm100:
|
|
| 212 |
elif self.is_varlen_q: # fallback
|
| 213 |
self.epilogue_warp_ids = (13, 14)
|
| 214 |
|
|
|
|
|
|
|
| 215 |
self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
|
| 216 |
self.tmem_o_offset = [
|
| 217 |
self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
|
|
@@ -227,31 +270,26 @@ class FlashAttentionForwardSm100:
|
|
| 227 |
# vec buffer for row_max & row_sum
|
| 228 |
self.tmem_vec_offset = self.tmem_s_offset
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
if self.head_dim_padded < 96:
|
| 231 |
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
|
| 232 |
self.num_regs_correction = 64
|
| 233 |
self.num_regs_other = 48 if not paged_kv_non_tma else 80
|
| 234 |
else:
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 238 |
else:
|
| 239 |
-
|
| 240 |
-
self.
|
| 241 |
-
|
| 242 |
-
# self.num_regs_correction = 96
|
| 243 |
-
# self.num_regs_correction = 64 if self.is_causal or self.is_local else 80
|
| 244 |
-
if not self.enable_ex2_emu:
|
| 245 |
-
self.num_regs_correction = 80 if not paged_kv_non_tma else 64
|
| 246 |
-
else:
|
| 247 |
-
# self.num_regs_correction = 64
|
| 248 |
-
self.num_regs_correction = 80 if not paged_kv_non_tma else 64
|
| 249 |
-
# self.num_regs_other = 32
|
| 250 |
-
# self.num_regs_other = 64
|
| 251 |
-
# self.num_regs_other = 80
|
| 252 |
-
self.num_regs_other = 48 if not paged_kv_non_tma else 80
|
| 253 |
-
# self.num_regs_other = 96 if self.is_causal or self.is_local else 80
|
| 254 |
-
# self.num_regs_other = 64 if self.is_causal or self.is_local else 80
|
| 255 |
|
| 256 |
self.buffer_align_bytes = 1024
|
| 257 |
|
|
@@ -289,7 +327,7 @@ class FlashAttentionForwardSm100:
|
|
| 289 |
self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
|
| 290 |
)
|
| 291 |
self.uneven_kv_smem_offset = (
|
| 292 |
-
self.
|
| 293 |
if self.uneven_kv_smem
|
| 294 |
else 0
|
| 295 |
)
|
|
@@ -304,7 +342,6 @@ class FlashAttentionForwardSm100:
|
|
| 304 |
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 305 |
mLSE: Optional[cute.Tensor],
|
| 306 |
softmax_scale: Float32,
|
| 307 |
-
stream: cuda.CUstream,
|
| 308 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 309 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 310 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
@@ -315,6 +352,8 @@ class FlashAttentionForwardSm100:
|
|
| 315 |
learnable_sink: Optional[cute.Tensor] = None,
|
| 316 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 317 |
aux_tensors: Optional[list] = None,
|
|
|
|
|
|
|
| 318 |
):
|
| 319 |
"""Execute the Fused Multi-Head Attention operation on the provided tensors.
|
| 320 |
|
|
@@ -367,22 +406,21 @@ class FlashAttentionForwardSm100:
|
|
| 367 |
if const_expr(self.q_dtype != self.v_dtype):
|
| 368 |
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
|
| 369 |
self._setup_attributes()
|
| 370 |
-
self.use_tma_O =
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
self.ex2_emu_freq = 0
|
| 374 |
-
|
| 375 |
-
self.ex2_emu_start_frg = 1
|
| 376 |
if const_expr(self.enable_ex2_emu):
|
| 377 |
-
self.ex2_emu_freq = 16
|
| 378 |
-
if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs):
|
| 379 |
-
self.ex2_emu_freq = 12
|
| 380 |
if const_expr(
|
| 381 |
self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
|
| 382 |
):
|
| 383 |
-
self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
|
| 384 |
-
if const_expr(self.head_dim_padded > 64 and self.is_causal):
|
| 385 |
-
self.ex2_emu_freq = 10
|
| 386 |
|
| 387 |
cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
| 388 |
q_major_mode = tcgen05.OperandMajorMode.K
|
|
@@ -462,50 +500,11 @@ class FlashAttentionForwardSm100:
|
|
| 462 |
)
|
| 463 |
|
| 464 |
if const_expr(self.pack_gqa):
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
mK.shape[2],
|
| 469 |
-
*mQ.shape[3:],
|
| 470 |
-
)
|
| 471 |
-
stride_Q_packed = (
|
| 472 |
-
(mQ.stride[2], mQ.stride[0]),
|
| 473 |
-
mQ.stride[1],
|
| 474 |
-
mQ.stride[2] * self.qhead_per_kvhead,
|
| 475 |
-
*mQ.stride[3:],
|
| 476 |
-
)
|
| 477 |
-
mQ = cute.make_tensor(
|
| 478 |
-
mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
|
| 479 |
-
)
|
| 480 |
-
shape_O_packed = (
|
| 481 |
-
(self.qhead_per_kvhead, mO.shape[0]),
|
| 482 |
-
mO.shape[1],
|
| 483 |
-
mK.shape[2],
|
| 484 |
-
*mO.shape[3:],
|
| 485 |
-
)
|
| 486 |
-
stride_O_packed = (
|
| 487 |
-
(mO.stride[2], mO.stride[0]),
|
| 488 |
-
mO.stride[1],
|
| 489 |
-
mO.stride[2] * self.qhead_per_kvhead,
|
| 490 |
-
*mO.stride[3:],
|
| 491 |
-
)
|
| 492 |
-
mO = cute.make_tensor(
|
| 493 |
-
mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
|
| 494 |
-
)
|
| 495 |
if const_expr(mLSE is not None):
|
| 496 |
-
|
| 497 |
-
(self.qhead_per_kvhead, mLSE.shape[0]),
|
| 498 |
-
mK.shape[2],
|
| 499 |
-
*mLSE.shape[2:],
|
| 500 |
-
)
|
| 501 |
-
stride_LSE_packed = (
|
| 502 |
-
(mLSE.stride[1], mLSE.stride[0]),
|
| 503 |
-
mLSE.stride[1] * self.qhead_per_kvhead,
|
| 504 |
-
*mLSE.stride[2:],
|
| 505 |
-
)
|
| 506 |
-
mLSE = cute.make_tensor(
|
| 507 |
-
mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
|
| 508 |
-
)
|
| 509 |
|
| 510 |
self.tma_copy_bytes = {
|
| 511 |
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
|
|
@@ -522,14 +521,24 @@ class FlashAttentionForwardSm100:
|
|
| 522 |
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
|
| 523 |
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
|
| 524 |
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
|
| 534 |
tma_atom_K = None
|
| 535 |
tma_atom_V = None
|
|
@@ -578,19 +587,10 @@ class FlashAttentionForwardSm100:
|
|
| 578 |
vO_layout = cute.make_layout((1, async_copy_elems))
|
| 579 |
gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
|
| 580 |
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
else:
|
| 584 |
-
if const_expr(self.is_causal or self.is_local):
|
| 585 |
-
TileScheduler = SingleTileLPTScheduler
|
| 586 |
-
else:
|
| 587 |
-
TileScheduler = (
|
| 588 |
-
SingleTileScheduler
|
| 589 |
-
if const_expr(not self.is_persistent)
|
| 590 |
-
else StaticPersistentTileScheduler
|
| 591 |
-
)
|
| 592 |
tile_sched_args = TileSchedulerArguments(
|
| 593 |
-
cute.ceil_div(cute.size(mQ.shape[0]),
|
| 594 |
cute.size(mQ.shape[2]),
|
| 595 |
cute.size(mQ.shape[3])
|
| 596 |
if const_expr(mCuSeqlensQ is None)
|
|
@@ -613,8 +613,11 @@ class FlashAttentionForwardSm100:
|
|
| 613 |
lpt=self.is_causal or self.is_local,
|
| 614 |
is_split_kv=self.is_split_kv,
|
| 615 |
cluster_shape_mn=self.cluster_shape_mn,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
)
|
| 617 |
-
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
| 618 |
self.tile_scheduler_cls = TileScheduler
|
| 619 |
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
| 620 |
|
|
@@ -624,6 +627,9 @@ class FlashAttentionForwardSm100:
|
|
| 624 |
cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
|
| 625 |
)
|
| 626 |
|
|
|
|
|
|
|
|
|
|
| 627 |
@cute.struct
|
| 628 |
class SharedStorage:
|
| 629 |
# m_barriers for pipelines
|
|
@@ -643,6 +649,13 @@ class FlashAttentionForwardSm100:
|
|
| 643 |
# Smem tensors
|
| 644 |
# store row max and row sum
|
| 645 |
sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
sO: cute.struct.Align[
|
| 647 |
cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes
|
| 648 |
]
|
|
@@ -657,35 +670,10 @@ class FlashAttentionForwardSm100:
|
|
| 657 |
|
| 658 |
self.shared_storage = SharedStorage
|
| 659 |
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
else:
|
| 665 |
-
# NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk
|
| 666 |
-
# But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
|
| 667 |
-
# and correctly apply the softmax_scale prior to score_mod in the softmax step
|
| 668 |
-
softmax_scale_log2 = LOG2_E
|
| 669 |
-
softmax_scale = softmax_scale
|
| 670 |
-
|
| 671 |
-
if const_expr(window_size_left is not None):
|
| 672 |
-
window_size_left = Int32(window_size_left)
|
| 673 |
-
if const_expr(window_size_right is not None):
|
| 674 |
-
window_size_right = Int32(window_size_right)
|
| 675 |
-
|
| 676 |
-
fastdiv_mods = None
|
| 677 |
-
if cutlass.const_expr(aux_tensors is not None):
|
| 678 |
-
seqlen_q = cute.size(mQ.shape[0]) // (
|
| 679 |
-
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
|
| 680 |
-
)
|
| 681 |
-
seqlen_k = (
|
| 682 |
-
cute.size(mK.shape[0])
|
| 683 |
-
if const_expr(mPageTable is None)
|
| 684 |
-
else mK.shape[0] * mPageTable.shape[1]
|
| 685 |
-
)
|
| 686 |
-
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
|
| 687 |
-
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
|
| 688 |
-
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
| 689 |
|
| 690 |
head_divmod = None
|
| 691 |
if cutlass.const_expr(self.pack_gqa):
|
|
@@ -722,6 +710,7 @@ class FlashAttentionForwardSm100:
|
|
| 722 |
tP_layout,
|
| 723 |
sV_layout,
|
| 724 |
sO_layout,
|
|
|
|
| 725 |
gmem_tiled_copy_O,
|
| 726 |
tiled_mma_qk,
|
| 727 |
tiled_mma_pv,
|
|
@@ -752,7 +741,7 @@ class FlashAttentionForwardSm100:
|
|
| 752 |
mSeqUsedQ: Optional[cute.Tensor],
|
| 753 |
mSeqUsedK: Optional[cute.Tensor],
|
| 754 |
mPageTable: Optional[cute.Tensor],
|
| 755 |
-
tma_atom_Q: cute.CopyAtom,
|
| 756 |
tma_atom_K: Optional[cute.CopyAtom],
|
| 757 |
tma_atom_V: Optional[cute.CopyAtom],
|
| 758 |
tma_atom_O: Optional[cute.CopyAtom],
|
|
@@ -767,6 +756,7 @@ class FlashAttentionForwardSm100:
|
|
| 767 |
tP_layout: cute.ComposedLayout,
|
| 768 |
sV_layout: cute.ComposedLayout,
|
| 769 |
sO_layout: cute.ComposedLayout,
|
|
|
|
| 770 |
gmem_tiled_copy_O: Optional[cute.TiledCopy],
|
| 771 |
tiled_mma_qk: cute.TiledMma,
|
| 772 |
tiled_mma_pv: cute.TiledMma,
|
|
@@ -814,7 +804,7 @@ class FlashAttentionForwardSm100:
|
|
| 814 |
storage = smem.allocate(self.shared_storage)
|
| 815 |
|
| 816 |
tmem_alloc_barrier = pipeline.NamedBarrier(
|
| 817 |
-
barrier_id=int(
|
| 818 |
num_threads=cute.arch.WARP_SIZE * len(
|
| 819 |
(self.mma_warp_id,
|
| 820 |
*self.softmax0_warp_ids,
|
|
@@ -833,8 +823,8 @@ class FlashAttentionForwardSm100:
|
|
| 833 |
|
| 834 |
ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
|
| 835 |
mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))
|
| 836 |
-
load_warps = ThreadCooperativeGroup(len(self.load_warp_ids))
|
| 837 |
tma_warp = ThreadCooperativeGroup(1)
|
|
|
|
| 838 |
softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))
|
| 839 |
softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))
|
| 840 |
# softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
|
|
@@ -857,15 +847,25 @@ class FlashAttentionForwardSm100:
|
|
| 857 |
softmax_correction_threads_cluster = ThreadCooperativeGroup(
|
| 858 |
cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size
|
| 859 |
)
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
if const_expr(self.use_tma_KV):
|
| 870 |
pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
|
| 871 |
barrier_storage=storage.mbar_load_KV.data_ptr(),
|
|
@@ -877,13 +877,10 @@ class FlashAttentionForwardSm100:
|
|
| 877 |
defer_sync=True,
|
| 878 |
)
|
| 879 |
else:
|
| 880 |
-
cpasync_producer_group = pipeline.CooperativeGroup(
|
| 881 |
-
pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
|
| 882 |
-
)
|
| 883 |
pipeline_kv = pipeline.PipelineAsyncUmma.create(
|
| 884 |
barrier_storage=storage.mbar_load_KV.data_ptr(),
|
| 885 |
num_stages=self.kv_stage,
|
| 886 |
-
producer_group=
|
| 887 |
consumer_group=mma_warp,
|
| 888 |
cta_layout_vmnk=cta_layout_vmnk,
|
| 889 |
defer_sync=True,
|
|
@@ -938,7 +935,7 @@ class FlashAttentionForwardSm100:
|
|
| 938 |
)
|
| 939 |
# Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats
|
| 940 |
sm_stats_barrier = pipeline_custom.NamedBarrier(
|
| 941 |
-
barrier_id=int(
|
| 942 |
)
|
| 943 |
pipeline_o_epi = None
|
| 944 |
if const_expr(not self.use_correction_warps_for_epi):
|
|
@@ -1019,17 +1016,69 @@ class FlashAttentionForwardSm100:
|
|
| 1019 |
window_size_right=window_size_right,
|
| 1020 |
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1021 |
)
|
| 1022 |
-
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
|
| 1023 |
-
|
| 1024 |
# Cluster wait before tensor memory alloc
|
| 1025 |
pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)
|
| 1026 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1028 |
-
# EMPTY
|
| 1029 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1030 |
-
|
| 1031 |
-
if warp_idx == self.
|
| 1032 |
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
|
| 1034 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1035 |
# LOAD
|
|
@@ -1049,13 +1098,14 @@ class FlashAttentionForwardSm100:
|
|
| 1049 |
tma_atom_Q,
|
| 1050 |
tma_atom_K,
|
| 1051 |
tma_atom_V,
|
|
|
|
| 1052 |
pipeline_q,
|
| 1053 |
pipeline_kv,
|
| 1054 |
block_info,
|
| 1055 |
num_splits,
|
| 1056 |
SeqlenInfoCls,
|
| 1057 |
-
TileSchedulerCls,
|
| 1058 |
blocksparse_tensors,
|
|
|
|
| 1059 |
)
|
| 1060 |
|
| 1061 |
# ///////////////////////////////////////////////////////////////////////////////
|
|
@@ -1085,8 +1135,8 @@ class FlashAttentionForwardSm100:
|
|
| 1085 |
block_info,
|
| 1086 |
num_splits,
|
| 1087 |
SeqlenInfoCls,
|
| 1088 |
-
TileSchedulerCls,
|
| 1089 |
blocksparse_tensors,
|
|
|
|
| 1090 |
)
|
| 1091 |
# Dealloc the tensor memory buffer
|
| 1092 |
tmem.relinquish_alloc_permit()
|
|
@@ -1108,8 +1158,8 @@ class FlashAttentionForwardSm100:
|
|
| 1108 |
block_info,
|
| 1109 |
num_splits,
|
| 1110 |
SeqlenInfoCls,
|
| 1111 |
-
TileSchedulerCls,
|
| 1112 |
mma_tile_coord_v,
|
|
|
|
| 1113 |
)
|
| 1114 |
|
| 1115 |
# ///////////////////////////////////////////////////////////////////////////////
|
|
@@ -1141,11 +1191,11 @@ class FlashAttentionForwardSm100:
|
|
| 1141 |
num_splits=num_splits,
|
| 1142 |
SeqlenInfoCls=SeqlenInfoCls,
|
| 1143 |
AttentionMaskCls=AttentionMaskCls,
|
| 1144 |
-
TileSchedulerCls=TileSchedulerCls,
|
| 1145 |
aux_tensors=aux_tensors,
|
| 1146 |
fastdiv_mods=fastdiv_mods,
|
| 1147 |
head_divmod=head_divmod,
|
| 1148 |
blocksparse_tensors=blocksparse_tensors,
|
|
|
|
| 1149 |
)
|
| 1150 |
|
| 1151 |
if const_expr(not self.s0_s1_barrier):
|
|
@@ -1189,8 +1239,8 @@ class FlashAttentionForwardSm100:
|
|
| 1189 |
block_info,
|
| 1190 |
num_splits,
|
| 1191 |
SeqlenInfoCls,
|
| 1192 |
-
TileSchedulerCls,
|
| 1193 |
blocksparse_tensors,
|
|
|
|
| 1194 |
)
|
| 1195 |
tmem_alloc_barrier.arrive()
|
| 1196 |
|
|
@@ -1208,35 +1258,38 @@ class FlashAttentionForwardSm100:
|
|
| 1208 |
sK: cute.Tensor,
|
| 1209 |
sV: cute.Tensor,
|
| 1210 |
mPageTable: Optional[cute.Tensor],
|
| 1211 |
-
tma_atom_Q: cute.CopyAtom,
|
| 1212 |
tma_atom_K: Optional[cute.CopyAtom],
|
| 1213 |
tma_atom_V: Optional[cute.CopyAtom],
|
|
|
|
| 1214 |
pipeline_q: pipeline.PipelineAsync,
|
| 1215 |
pipeline_kv: pipeline.PipelineAsync,
|
| 1216 |
block_info: BlockInfo,
|
| 1217 |
num_splits: Int32,
|
| 1218 |
SeqlenInfoCls: Callable,
|
| 1219 |
-
TileSchedulerCls: Callable,
|
| 1220 |
blocksparse_tensors: Optional[BlockSparseTensors],
|
|
|
|
| 1221 |
):
|
| 1222 |
num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
|
| 1223 |
tidx = cute.arch.thread_idx()[0] % num_load_threads
|
| 1224 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1225 |
q_producer_phase = Int32(1)
|
| 1226 |
kv_producer_state = pipeline.make_pipeline_state(
|
| 1227 |
pipeline.PipelineUserType.Producer, self.kv_stage
|
| 1228 |
)
|
| 1229 |
-
tile_scheduler = TileSchedulerCls()
|
| 1230 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1231 |
while work_tile.is_valid_tile:
|
| 1232 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
| 1233 |
seqlen = SeqlenInfoCls(batch_idx)
|
| 1234 |
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
| 1235 |
-
tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
|
| 1236 |
-
gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128)
|
| 1237 |
-
gQ = layout_utils.select(
|
| 1238 |
-
cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
|
| 1239 |
-
) # (128, 128, 2)
|
| 1240 |
|
| 1241 |
head_idx_kv = (
|
| 1242 |
head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
|
|
@@ -1258,12 +1311,32 @@ class FlashAttentionForwardSm100:
|
|
| 1258 |
gV = cute.local_tile(
|
| 1259 |
mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
|
| 1260 |
)
|
| 1261 |
-
tSgQ = thr_mma_qk.partition_A(gQ)
|
| 1262 |
tSgK = thr_mma_qk.partition_B(gK)
|
| 1263 |
tOgV = thr_mma_pv.partition_B(gV)
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
|
| 1268 |
if const_expr(self.use_tma_KV):
|
| 1269 |
tKsK, tKgK = cpasync.tma_partition(
|
|
@@ -1302,7 +1375,6 @@ class FlashAttentionForwardSm100:
|
|
| 1302 |
tKsK, tKgK = None, None
|
| 1303 |
tVsV, tVgV = None, None
|
| 1304 |
|
| 1305 |
-
load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
|
| 1306 |
load_K = partial(
|
| 1307 |
self.load_KV,
|
| 1308 |
tma_atom_K,
|
|
@@ -1337,24 +1409,19 @@ class FlashAttentionForwardSm100:
|
|
| 1337 |
)
|
| 1338 |
if const_expr(not self.use_tma_KV):
|
| 1339 |
paged_kv_manager.load_page_table(n_block_first)
|
| 1340 |
-
|
|
|
|
| 1341 |
# load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0
|
| 1342 |
-
if
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr)
|
| 1349 |
-
kv_producer_state.advance()
|
| 1350 |
-
if const_expr(self.q_stage == 2) and (const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]):
|
| 1351 |
-
# load_Q(block=1, stage=1) # Q1
|
| 1352 |
-
pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase)
|
| 1353 |
-
tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1)
|
| 1354 |
-
load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr)
|
| 1355 |
q_producer_phase ^= 1
|
| 1356 |
-
|
| 1357 |
-
|
|
|
|
| 1358 |
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 1359 |
n_block = n_block_max - 2 - i
|
| 1360 |
page_idx = (
|
|
@@ -1365,10 +1432,11 @@ class FlashAttentionForwardSm100:
|
|
| 1365 |
if const_expr(not self.use_tma_KV):
|
| 1366 |
paged_kv_manager.load_page_table(n_block)
|
| 1367 |
# if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
|
|
|
|
| 1372 |
|
| 1373 |
else:
|
| 1374 |
kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
|
|
@@ -1387,14 +1455,14 @@ class FlashAttentionForwardSm100:
|
|
| 1387 |
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 1388 |
)
|
| 1389 |
|
| 1390 |
-
|
| 1391 |
-
tile_scheduler.advance_to_next_work()
|
| 1392 |
-
work_tile = tile_scheduler.get_current_work()
|
| 1393 |
# End of persistent scheduler loop
|
| 1394 |
|
| 1395 |
-
|
| 1396 |
-
|
| 1397 |
-
|
|
|
|
| 1398 |
pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)
|
| 1399 |
|
| 1400 |
@cute.jit
|
|
@@ -1417,8 +1485,8 @@ class FlashAttentionForwardSm100:
|
|
| 1417 |
block_info: BlockInfo,
|
| 1418 |
num_splits: Int32,
|
| 1419 |
SeqlenInfoCls: Callable,
|
| 1420 |
-
TileSchedulerCls: Callable,
|
| 1421 |
blocksparse_tensors: Optional[BlockSparseTensors],
|
|
|
|
| 1422 |
):
|
| 1423 |
tSrQ = tiled_mma_qk.make_fragment_A(sQ)
|
| 1424 |
tSrK = tiled_mma_qk.make_fragment_B(sK)
|
|
@@ -1507,7 +1575,6 @@ class FlashAttentionForwardSm100:
|
|
| 1507 |
)
|
| 1508 |
P_full_O_rescaled_phase = Int32(0)
|
| 1509 |
|
| 1510 |
-
tile_scheduler = TileSchedulerCls()
|
| 1511 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1512 |
while work_tile.is_valid_tile:
|
| 1513 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
@@ -1678,8 +1745,7 @@ class FlashAttentionForwardSm100:
|
|
| 1678 |
# End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)
|
| 1679 |
|
| 1680 |
# Advance to next tile
|
| 1681 |
-
tile_scheduler.advance_to_next_work()
|
| 1682 |
-
work_tile = tile_scheduler.get_current_work()
|
| 1683 |
# End of persistent scheduler loop
|
| 1684 |
|
| 1685 |
# We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end
|
|
@@ -1708,11 +1774,11 @@ class FlashAttentionForwardSm100:
|
|
| 1708 |
num_splits: Int32,
|
| 1709 |
SeqlenInfoCls: Callable,
|
| 1710 |
AttentionMaskCls: Callable,
|
| 1711 |
-
TileSchedulerCls: Callable,
|
| 1712 |
aux_tensors: Optional[list] = None,
|
| 1713 |
fastdiv_mods=(None, None),
|
| 1714 |
head_divmod=None,
|
| 1715 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
|
|
| 1716 |
):
|
| 1717 |
"""Compute softmax on attention scores from QK matrix multiplication.
|
| 1718 |
|
|
@@ -1772,7 +1838,6 @@ class FlashAttentionForwardSm100:
|
|
| 1772 |
|
| 1773 |
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
| 1774 |
|
| 1775 |
-
tile_scheduler = TileSchedulerCls()
|
| 1776 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1777 |
while work_tile.is_valid_tile:
|
| 1778 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
@@ -2015,8 +2080,7 @@ class FlashAttentionForwardSm100:
|
|
| 2015 |
# gLSE[tidx] = lse
|
| 2016 |
|
| 2017 |
# Advance to next tile
|
| 2018 |
-
tile_scheduler.advance_to_next_work()
|
| 2019 |
-
work_tile = tile_scheduler.get_current_work()
|
| 2020 |
# End of persistent scheduler loop
|
| 2021 |
|
| 2022 |
# This is equivalent to pipeline_sm_stats.producer_tail
|
|
@@ -2186,8 +2250,8 @@ class FlashAttentionForwardSm100:
|
|
| 2186 |
block_info: BlockInfo,
|
| 2187 |
num_splits: Int32,
|
| 2188 |
SeqlenInfoCls: Callable,
|
| 2189 |
-
TileSchedulerCls: Callable,
|
| 2190 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
|
|
| 2191 |
):
|
| 2192 |
tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
|
| 2193 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
|
@@ -2217,7 +2281,6 @@ class FlashAttentionForwardSm100:
|
|
| 2217 |
o_corr_consumer_phase = Int32(0)
|
| 2218 |
corr_epi_producer_phase = Int32(1)
|
| 2219 |
|
| 2220 |
-
tile_scheduler = TileSchedulerCls()
|
| 2221 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2222 |
while work_tile.is_valid_tile:
|
| 2223 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
@@ -2228,12 +2291,14 @@ class FlashAttentionForwardSm100:
|
|
| 2228 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
|
| 2229 |
else:
|
| 2230 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
|
| 2231 |
-
|
| 2232 |
-
|
| 2233 |
-
|
| 2234 |
-
cute.
|
| 2235 |
-
|
| 2236 |
-
|
|
|
|
|
|
|
| 2237 |
|
| 2238 |
# Default LSE to -inf for invalid split_idx tiles
|
| 2239 |
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
|
|
@@ -2334,6 +2399,7 @@ class FlashAttentionForwardSm100:
|
|
| 2334 |
pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
|
| 2335 |
if const_expr(not self.use_correction_warps_for_epi):
|
| 2336 |
pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
|
|
|
|
| 2337 |
self.correction_epilogue(
|
| 2338 |
thr_mma_pv,
|
| 2339 |
tOtO[None, None, None, stage],
|
|
@@ -2344,7 +2410,7 @@ class FlashAttentionForwardSm100:
|
|
| 2344 |
scale,
|
| 2345 |
sO[None, None, stage],
|
| 2346 |
mO_cur,
|
| 2347 |
-
|
| 2348 |
gmem_tiled_copy_O,
|
| 2349 |
)
|
| 2350 |
# Signal for the next work tile that O buffers in tmem are already read, so
|
|
@@ -2414,7 +2480,6 @@ class FlashAttentionForwardSm100:
|
|
| 2414 |
mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
|
| 2415 |
for stage in cutlass.range_constexpr(self.q_stage):
|
| 2416 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
| 2417 |
-
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
|
| 2418 |
row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
|
| 2419 |
# if tidx == 0 and stage <= 1:
|
| 2420 |
# cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
|
|
@@ -2429,13 +2494,24 @@ class FlashAttentionForwardSm100:
|
|
| 2429 |
if const_expr(not self.pack_gqa)
|
| 2430 |
else seqlen.seqlen_q * self.qhead_per_kvhead
|
| 2431 |
)
|
| 2432 |
-
if
|
| 2433 |
-
|
| 2434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2435 |
|
| 2436 |
# Advance to next tile
|
| 2437 |
-
tile_scheduler.advance_to_next_work()
|
| 2438 |
-
work_tile = tile_scheduler.get_current_work()
|
| 2439 |
# End of persistent scheduler loop
|
| 2440 |
|
| 2441 |
# This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps
|
|
@@ -2574,7 +2650,7 @@ class FlashAttentionForwardSm100:
|
|
| 2574 |
if const_expr(self.use_correction_warps_for_epi):
|
| 2575 |
assert(not self.use_tma_O)
|
| 2576 |
assert(gmem_tiled_copy_O is not None)
|
| 2577 |
-
cute.arch.barrier(barrier_id=int(
|
| 2578 |
number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
|
| 2579 |
mma_tile_coord_v = thr_mma.thr_idx
|
| 2580 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
|
@@ -2586,7 +2662,7 @@ class FlashAttentionForwardSm100:
|
|
| 2586 |
def _store_O_to_gmem(
|
| 2587 |
self,
|
| 2588 |
sO_stage: cute.Tensor,
|
| 2589 |
-
gO: cute.Tensor,
|
| 2590 |
mO_cur: cute.Tensor,
|
| 2591 |
gmem_tiled_copy_O: cute.TiledCopy,
|
| 2592 |
tidx: Int32,
|
|
@@ -2597,7 +2673,6 @@ class FlashAttentionForwardSm100:
|
|
| 2597 |
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 2598 |
tOsO = gmem_thr_copy_O.partition_S(sO_stage)
|
| 2599 |
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
|
| 2600 |
-
tOgO = gmem_thr_copy_O.partition_D(gO)
|
| 2601 |
tOcO = gmem_thr_copy_O.partition_S(cO)
|
| 2602 |
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
|
| 2603 |
tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
|
|
@@ -2613,6 +2688,8 @@ class FlashAttentionForwardSm100:
|
|
| 2613 |
cute.autovec_copy(tOsO, tOrO)
|
| 2614 |
# copy acc O from rmem to gmem
|
| 2615 |
if const_expr(not self.pack_gqa):
|
|
|
|
|
|
|
| 2616 |
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
|
| 2617 |
if (
|
| 2618 |
t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
|
|
@@ -2641,11 +2718,10 @@ class FlashAttentionForwardSm100:
|
|
| 2641 |
block_info: BlockInfo,
|
| 2642 |
num_splits: int,
|
| 2643 |
SeqlenInfoCls: Callable,
|
| 2644 |
-
TileSchedulerCls: Callable,
|
| 2645 |
mma_tile_coord_v: Int32 = 0,
|
|
|
|
| 2646 |
):
|
| 2647 |
epi_consumer_phase = Int32(0)
|
| 2648 |
-
tile_scheduler = TileSchedulerCls()
|
| 2649 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2650 |
while work_tile.is_valid_tile:
|
| 2651 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
@@ -2657,12 +2733,14 @@ class FlashAttentionForwardSm100:
|
|
| 2657 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
|
| 2658 |
else:
|
| 2659 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
|
| 2660 |
-
|
| 2661 |
-
|
| 2662 |
-
|
| 2663 |
-
cute.
|
| 2664 |
-
|
| 2665 |
-
|
|
|
|
|
|
|
| 2666 |
|
| 2667 |
if const_expr(self.use_tma_O):
|
| 2668 |
store_O, _, _ = copy_utils.tma_get_copy_fn(
|
|
@@ -2689,8 +2767,9 @@ class FlashAttentionForwardSm100:
|
|
| 2689 |
pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
|
| 2690 |
# 2. copy O0 / O1 to gmem
|
| 2691 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
|
|
|
| 2692 |
self._store_O_to_gmem(
|
| 2693 |
-
sO[None, None, stage],
|
| 2694 |
tidx, seqlen.seqlen_q, m_tile_idx,
|
| 2695 |
)
|
| 2696 |
pipeline_o_epi.consumer_release_w_index(stage)
|
|
@@ -2698,8 +2777,39 @@ class FlashAttentionForwardSm100:
|
|
| 2698 |
epi_consumer_phase ^= 1
|
| 2699 |
|
| 2700 |
# Advance to next tile
|
| 2701 |
-
tile_scheduler.advance_to_next_work()
|
| 2702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2703 |
|
| 2704 |
def load_Q(
|
| 2705 |
self,
|
|
@@ -2712,6 +2822,39 @@ class FlashAttentionForwardSm100:
|
|
| 2712 |
pipeline_q.producer_acquire_w_index_phase(stage, phase)
|
| 2713 |
load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))
|
| 2714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2715 |
@cute.jit
|
| 2716 |
def load_KV(
|
| 2717 |
self,
|
|
@@ -2754,7 +2897,10 @@ class FlashAttentionForwardSm100:
|
|
| 2754 |
else:
|
| 2755 |
assert paged_kv_manager is not None
|
| 2756 |
assert extra_tx_count is None
|
| 2757 |
-
|
|
|
|
|
|
|
|
|
|
| 2758 |
cute.arch.cp_async_commit_group()
|
| 2759 |
pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)
|
| 2760 |
|
|
@@ -2765,6 +2911,9 @@ class FlashAttentionForwardSm100:
|
|
| 2765 |
# (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
|
| 2766 |
# phase == 0, or left by offset if phase == 1.
|
| 2767 |
offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
|
|
|
|
|
|
|
|
|
|
| 2768 |
return cute.make_tensor(sX.iterator + offset, sX.layout)
|
| 2769 |
else:
|
| 2770 |
return sX
|
|
@@ -2774,12 +2923,12 @@ class FlashAttentionForwardSm100:
|
|
| 2774 |
# warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
| 2775 |
# if warp_group_idx == 0:
|
| 2776 |
# cute.arch.barrier_arrive(
|
| 2777 |
-
# barrier_id=int(
|
| 2778 |
# )
|
| 2779 |
|
| 2780 |
# def warp_scheduler_barrier_sync(self):
|
| 2781 |
# cute.arch.barrier(
|
| 2782 |
-
# barrier_id=int(
|
| 2783 |
# number_of_threads=2 * 128
|
| 2784 |
# )
|
| 2785 |
|
|
@@ -2787,7 +2936,7 @@ class FlashAttentionForwardSm100:
|
|
| 2787 |
# cur_wg = utils.canonical_warp_group_idx(sync=False)
|
| 2788 |
# next_wg = 1 - cur_wg
|
| 2789 |
# cute.arch.barrier_arrive(
|
| 2790 |
-
# barrier_id=int(
|
| 2791 |
# )
|
| 2792 |
|
| 2793 |
@cute.jit
|
|
|
|
| 13 |
# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
|
| 14 |
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
|
| 15 |
|
|
|
|
| 16 |
import math
|
| 17 |
+
from typing import Tuple, Callable, Optional, Literal
|
| 18 |
from functools import partial
|
| 19 |
|
| 20 |
import cuda.bindings.driver as cuda
|
|
|
|
| 27 |
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
| 28 |
from cutlass import pipeline
|
| 29 |
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
| 30 |
+
from cutlass.utils import ClcDynamicPersistentTileScheduler
|
| 31 |
from cutlass.base_dsl.arch import Arch
|
| 32 |
from cutlass.cutlass_dsl import BaseDSL
|
| 33 |
|
|
|
|
| 35 |
|
| 36 |
from .paged_kv import PagedKVManager
|
| 37 |
from .cute_dsl_utils import assume_tensor_aligned
|
| 38 |
+
from . import utils
|
| 39 |
from . import pipeline as pipeline_custom
|
| 40 |
+
import cutlass.pipeline as cutlass_pipeline
|
| 41 |
from .mask import AttentionMask
|
| 42 |
from .softmax import SoftmaxSm100, apply_score_mod_inner
|
| 43 |
from .seqlen_info import SeqlenInfoQK
|
|
|
|
| 49 |
softmax_block_sparse_sm100,
|
| 50 |
handle_block_sparse_empty_tile_correction_sm100,
|
| 51 |
)
|
| 52 |
+
from .pack_gqa import PackGQA, pack_gqa_layout
|
| 53 |
from . import mma_sm100_desc as sm100_desc
|
| 54 |
from . import blackwell_helpers as sm100_utils
|
| 55 |
+
from .named_barrier import NamedBarrierFwdSm100
|
| 56 |
from cutlass.cute import FastDivmodDivisor
|
| 57 |
from .quack.cute_dsl_utils import ParamsBase
|
| 58 |
from .tile_scheduler import (
|
| 59 |
+
ClcState,
|
| 60 |
+
SchedulingMode,
|
| 61 |
TileSchedulerArguments,
|
| 62 |
+
TileSchedulerProtocol,
|
| 63 |
SingleTileScheduler,
|
| 64 |
StaticPersistentTileScheduler,
|
| 65 |
SingleTileLPTScheduler,
|
| 66 |
SingleTileVarlenScheduler,
|
| 67 |
)
|
| 68 |
+
from .fa_logging import fa_log, fa_printf
|
| 69 |
+
from .utils import smid
|
| 70 |
+
|
| 71 |
+
# === TUNING KNOBS (agent-editable) ===
|
| 72 |
+
# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool)
|
| 73 |
+
# Values:
|
| 74 |
+
# ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation).
|
| 75 |
+
# SM103 has fast native exp2, so set freq=0 there.
|
| 76 |
+
# ex2_emu_start_frg: int — fragment index to start emulation from
|
| 77 |
+
# num_regs_softmax: int — register count for softmax warps (multiple of 8)
|
| 78 |
+
# num_regs_correction: int — register count for correction warps (multiple of 8)
|
| 79 |
+
# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction
|
| 80 |
+
_TUNING_CONFIG = {
|
| 81 |
+
(True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88},
|
| 82 |
+
(False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72},
|
| 83 |
+
(True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80},
|
| 84 |
+
(False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72},
|
| 85 |
+
(True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80},
|
| 86 |
+
(False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64},
|
| 87 |
+
(True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64},
|
| 88 |
+
(False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72},
|
| 89 |
+
}
|
| 90 |
+
# === END TUNING KNOBS ===
|
| 91 |
|
| 92 |
|
| 93 |
class FlashAttentionForwardSm100:
|
|
|
|
| 113 |
paged_kv_non_tma: bool = False,
|
| 114 |
is_varlen_q: bool = False,
|
| 115 |
use_2cta_instrs: bool = False,
|
| 116 |
+
use_clc_scheduler: bool = False,
|
| 117 |
):
|
| 118 |
self.use_tma_KV = not paged_kv_non_tma
|
| 119 |
# self.dtype = dtype
|
|
|
|
| 160 |
self.is_split_kv = is_split_kv
|
| 161 |
self.pack_gqa = pack_gqa
|
| 162 |
self.q_subtile_factor = q_subtile_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
|
| 164 |
"SplitKV is not supported for hdim >= 192"
|
| 165 |
)
|
|
|
|
| 171 |
# Does S1 need to wait for S0 to finish
|
| 172 |
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
|
| 173 |
is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
|
| 174 |
+
self.is_sm103 = is_sm103
|
| 175 |
+
# enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic
|
| 176 |
+
_default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
|
| 177 |
+
self.enable_ex2_emu = _default_enable_ex2_emu
|
| 178 |
self.s0_s1_barrier = False
|
| 179 |
self.overlap_sO_sQ = (
|
| 180 |
(self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
|
|
|
|
| 187 |
"Paged KV does not support irregular head dim"
|
| 188 |
)
|
| 189 |
|
| 190 |
+
self.use_clc_scheduler = (
|
| 191 |
+
use_clc_scheduler
|
| 192 |
+
and self.use_tma_KV
|
| 193 |
+
and not self.overlap_sO_sQ
|
| 194 |
+
)
|
| 195 |
+
self.sched_stages = 1
|
| 196 |
+
if self.use_clc_scheduler:
|
| 197 |
+
assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}"
|
| 198 |
+
assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}"
|
| 199 |
+
assert self.cluster_shape_mn[0] == self.cta_group_size, (
|
| 200 |
+
f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC
|
| 204 |
+
|
| 205 |
+
if is_varlen_q:
|
| 206 |
+
self.TileScheduler = SingleTileVarlenScheduler
|
| 207 |
+
elif self.is_causal or self.is_local or self.use_clc_scheduler:
|
| 208 |
+
self.TileScheduler = SingleTileLPTScheduler
|
| 209 |
+
elif self.is_persistent:
|
| 210 |
+
self.TileScheduler = StaticPersistentTileScheduler
|
| 211 |
+
else:
|
| 212 |
+
self.TileScheduler = SingleTileScheduler
|
| 213 |
+
|
| 214 |
+
fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}")
|
| 215 |
+
|
| 216 |
self.softmax0_warp_ids = (0, 1, 2, 3)
|
| 217 |
self.softmax1_warp_ids = (4, 5, 6, 7)
|
| 218 |
self.correction_warp_ids = (8, 9, 10, 11)
|
|
|
|
| 234 |
)
|
| 235 |
)
|
| 236 |
|
| 237 |
+
self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
|
| 238 |
+
|
| 239 |
if self.q_stage == 1:
|
| 240 |
+
if not self.use_tma_KV or not self.use_tma_Q:
|
| 241 |
self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
|
| 242 |
self.load_warp_ids = self.softmax1_warp_ids
|
| 243 |
else:
|
|
|
|
| 253 |
elif self.is_varlen_q: # fallback
|
| 254 |
self.epilogue_warp_ids = (13, 14)
|
| 255 |
|
| 256 |
+
self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None
|
| 257 |
+
|
| 258 |
self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
|
| 259 |
self.tmem_o_offset = [
|
| 260 |
self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
|
|
|
|
| 270 |
# vec buffer for row_max & row_sum
|
| 271 |
self.tmem_vec_offset = self.tmem_s_offset
|
| 272 |
|
| 273 |
+
# Look up tuning config for register counts and ex2_emu params
|
| 274 |
+
_tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103)
|
| 275 |
+
self._tune = _TUNING_CONFIG.get(_tune_key, {})
|
| 276 |
+
if "ex2_emu_freq" in self._tune:
|
| 277 |
+
self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0
|
| 278 |
if self.head_dim_padded < 96:
|
| 279 |
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
|
| 280 |
self.num_regs_correction = 64
|
| 281 |
self.num_regs_other = 48 if not paged_kv_non_tma else 80
|
| 282 |
else:
|
| 283 |
+
if not paged_kv_non_tma and "num_regs_softmax" in self._tune:
|
| 284 |
+
self.num_regs_softmax = self._tune["num_regs_softmax"]
|
| 285 |
+
self.num_regs_correction = self._tune["num_regs_correction"]
|
| 286 |
+
elif not paged_kv_non_tma:
|
| 287 |
+
self.num_regs_softmax = 192
|
| 288 |
+
self.num_regs_correction = 80
|
| 289 |
else:
|
| 290 |
+
self.num_regs_softmax = 184
|
| 291 |
+
self.num_regs_correction = 64
|
| 292 |
+
self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
self.buffer_align_bytes = 1024
|
| 295 |
|
|
|
|
| 327 |
self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
|
| 328 |
)
|
| 329 |
self.uneven_kv_smem_offset = (
|
| 330 |
+
self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2
|
| 331 |
if self.uneven_kv_smem
|
| 332 |
else 0
|
| 333 |
)
|
|
|
|
| 342 |
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 343 |
mLSE: Optional[cute.Tensor],
|
| 344 |
softmax_scale: Float32,
|
|
|
|
| 345 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 346 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 347 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
|
|
| 352 |
learnable_sink: Optional[cute.Tensor] = None,
|
| 353 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 354 |
aux_tensors: Optional[list] = None,
|
| 355 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 356 |
+
stream: cuda.CUstream = None,
|
| 357 |
):
|
| 358 |
"""Execute the Fused Multi-Head Attention operation on the provided tensors.
|
| 359 |
|
|
|
|
| 406 |
if const_expr(self.q_dtype != self.v_dtype):
|
| 407 |
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
|
| 408 |
self._setup_attributes()
|
| 409 |
+
self.use_tma_O = (
|
| 410 |
+
self.arch >= Arch.sm_90
|
| 411 |
+
and mCuSeqlensQ is None
|
| 412 |
+
and mSeqUsedQ is None
|
| 413 |
+
and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
|
| 414 |
+
and not (self.pack_gqa and self.is_split_kv)
|
| 415 |
+
)
|
| 416 |
self.ex2_emu_freq = 0
|
| 417 |
+
self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1)
|
|
|
|
| 418 |
if const_expr(self.enable_ex2_emu):
|
| 419 |
+
self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16)
|
|
|
|
|
|
|
| 420 |
if const_expr(
|
| 421 |
self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
|
| 422 |
):
|
| 423 |
+
self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10)
|
|
|
|
|
|
|
| 424 |
|
| 425 |
cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
| 426 |
q_major_mode = tcgen05.OperandMajorMode.K
|
|
|
|
| 500 |
)
|
| 501 |
|
| 502 |
if const_expr(self.pack_gqa):
|
| 503 |
+
nheads_kv = mK.shape[2]
|
| 504 |
+
mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
|
| 505 |
+
mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
if const_expr(mLSE is not None):
|
| 507 |
+
mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
self.tma_copy_bytes = {
|
| 510 |
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
|
|
|
|
| 521 |
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
|
| 522 |
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
|
| 523 |
|
| 524 |
+
if const_expr(self.use_tma_Q):
|
| 525 |
+
tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
|
| 526 |
+
tma_load_op,
|
| 527 |
+
mQ,
|
| 528 |
+
cute.select(sQ_layout, mode=[0, 1, 2]),
|
| 529 |
+
self.mma_tiler_qk,
|
| 530 |
+
tiled_mma_qk,
|
| 531 |
+
cta_layout_vmnk.shape,
|
| 532 |
+
)
|
| 533 |
+
gmem_tiled_copy_Q = None
|
| 534 |
+
else:
|
| 535 |
+
tma_atom_Q = None
|
| 536 |
+
async_copy_elems = 128 // self.q_dtype.width
|
| 537 |
+
num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
|
| 538 |
+
threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads)
|
| 539 |
+
gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
|
| 540 |
+
self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
|
| 541 |
+
)
|
| 542 |
|
| 543 |
tma_atom_K = None
|
| 544 |
tma_atom_V = None
|
|
|
|
| 587 |
vO_layout = cute.make_layout((1, async_copy_elems))
|
| 588 |
gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
|
| 589 |
|
| 590 |
+
TileScheduler = self.TileScheduler
|
| 591 |
+
_num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
tile_sched_args = TileSchedulerArguments(
|
| 593 |
+
cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor),
|
| 594 |
cute.size(mQ.shape[2]),
|
| 595 |
cute.size(mQ.shape[3])
|
| 596 |
if const_expr(mCuSeqlensQ is None)
|
|
|
|
| 613 |
lpt=self.is_causal or self.is_local,
|
| 614 |
is_split_kv=self.is_split_kv,
|
| 615 |
cluster_shape_mn=self.cluster_shape_mn,
|
| 616 |
+
use_cluster_idx=not self.is_persistent and self.cta_group_size > 1,
|
| 617 |
+
)
|
| 618 |
+
tile_sched_params = TileScheduler.to_underlying_arguments(
|
| 619 |
+
tile_sched_args, scheduling_mode=self.scheduling_mode
|
| 620 |
)
|
|
|
|
| 621 |
self.tile_scheduler_cls = TileScheduler
|
| 622 |
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
| 623 |
|
|
|
|
| 627 |
cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
|
| 628 |
)
|
| 629 |
|
| 630 |
+
clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0
|
| 631 |
+
clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0
|
| 632 |
+
|
| 633 |
@cute.struct
|
| 634 |
class SharedStorage:
|
| 635 |
# m_barriers for pipelines
|
|
|
|
| 649 |
# Smem tensors
|
| 650 |
# store row max and row sum
|
| 651 |
sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
|
| 652 |
+
# CLC buffers placed here to utilize padding before sO's 1024-byte alignment.
|
| 653 |
+
# This avoids adding bytes at the end when we're at the smem limit.
|
| 654 |
+
# PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty).
|
| 655 |
+
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size]
|
| 656 |
+
# CLC response storage (16 bytes per stage, stored as 4 Int32s).
|
| 657 |
+
clc_response: cute.struct.MemRange[Int32, clc_response_size]
|
| 658 |
+
# Large TMA buffers with 1024-byte alignment
|
| 659 |
sO: cute.struct.Align[
|
| 660 |
cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes
|
| 661 |
]
|
|
|
|
| 670 |
|
| 671 |
self.shared_storage = SharedStorage
|
| 672 |
|
| 673 |
+
softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
|
| 674 |
+
window_size_left = Int32(window_size_left) if window_size_left is not None else None
|
| 675 |
+
window_size_right = Int32(window_size_right) if window_size_right is not None else None
|
| 676 |
+
fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
head_divmod = None
|
| 679 |
if cutlass.const_expr(self.pack_gqa):
|
|
|
|
| 710 |
tP_layout,
|
| 711 |
sV_layout,
|
| 712 |
sO_layout,
|
| 713 |
+
gmem_tiled_copy_Q,
|
| 714 |
gmem_tiled_copy_O,
|
| 715 |
tiled_mma_qk,
|
| 716 |
tiled_mma_pv,
|
|
|
|
| 741 |
mSeqUsedQ: Optional[cute.Tensor],
|
| 742 |
mSeqUsedK: Optional[cute.Tensor],
|
| 743 |
mPageTable: Optional[cute.Tensor],
|
| 744 |
+
tma_atom_Q: Optional[cute.CopyAtom],
|
| 745 |
tma_atom_K: Optional[cute.CopyAtom],
|
| 746 |
tma_atom_V: Optional[cute.CopyAtom],
|
| 747 |
tma_atom_O: Optional[cute.CopyAtom],
|
|
|
|
| 756 |
tP_layout: cute.ComposedLayout,
|
| 757 |
sV_layout: cute.ComposedLayout,
|
| 758 |
sO_layout: cute.ComposedLayout,
|
| 759 |
+
gmem_tiled_copy_Q: Optional[cute.TiledCopy],
|
| 760 |
gmem_tiled_copy_O: Optional[cute.TiledCopy],
|
| 761 |
tiled_mma_qk: cute.TiledMma,
|
| 762 |
tiled_mma_pv: cute.TiledMma,
|
|
|
|
| 804 |
storage = smem.allocate(self.shared_storage)
|
| 805 |
|
| 806 |
tmem_alloc_barrier = pipeline.NamedBarrier(
|
| 807 |
+
barrier_id=int(NamedBarrierFwdSm100.TmemPtr),
|
| 808 |
num_threads=cute.arch.WARP_SIZE * len(
|
| 809 |
(self.mma_warp_id,
|
| 810 |
*self.softmax0_warp_ids,
|
|
|
|
| 823 |
|
| 824 |
ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
|
| 825 |
mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))
|
|
|
|
| 826 |
tma_warp = ThreadCooperativeGroup(1)
|
| 827 |
+
load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE)
|
| 828 |
softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))
|
| 829 |
softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))
|
| 830 |
# softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
|
|
|
|
| 847 |
softmax_correction_threads_cluster = ThreadCooperativeGroup(
|
| 848 |
cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size
|
| 849 |
)
|
| 850 |
+
if const_expr(self.use_tma_Q):
|
| 851 |
+
pipeline_q = pipeline_custom.PipelineTmaUmma.create(
|
| 852 |
+
barrier_storage=storage.mbar_load_Q.data_ptr(),
|
| 853 |
+
num_stages=self.q_stage,
|
| 854 |
+
producer_group=tma_warp,
|
| 855 |
+
consumer_group=mma_warp,
|
| 856 |
+
tx_count=self.tma_copy_bytes["Q"],
|
| 857 |
+
cta_layout_vmnk=cta_layout_vmnk,
|
| 858 |
+
defer_sync=True,
|
| 859 |
+
)
|
| 860 |
+
else:
|
| 861 |
+
pipeline_q = pipeline_custom.PipelineAsyncUmma.create(
|
| 862 |
+
barrier_storage=storage.mbar_load_Q.data_ptr(),
|
| 863 |
+
num_stages=self.q_stage,
|
| 864 |
+
producer_group=load_threads,
|
| 865 |
+
consumer_group=mma_warp,
|
| 866 |
+
cta_layout_vmnk=cta_layout_vmnk,
|
| 867 |
+
defer_sync=True,
|
| 868 |
+
)
|
| 869 |
if const_expr(self.use_tma_KV):
|
| 870 |
pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
|
| 871 |
barrier_storage=storage.mbar_load_KV.data_ptr(),
|
|
|
|
| 877 |
defer_sync=True,
|
| 878 |
)
|
| 879 |
else:
|
|
|
|
|
|
|
|
|
|
| 880 |
pipeline_kv = pipeline.PipelineAsyncUmma.create(
|
| 881 |
barrier_storage=storage.mbar_load_KV.data_ptr(),
|
| 882 |
num_stages=self.kv_stage,
|
| 883 |
+
producer_group=load_threads,
|
| 884 |
consumer_group=mma_warp,
|
| 885 |
cta_layout_vmnk=cta_layout_vmnk,
|
| 886 |
defer_sync=True,
|
|
|
|
| 935 |
)
|
| 936 |
# Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats
|
| 937 |
sm_stats_barrier = pipeline_custom.NamedBarrier(
|
| 938 |
+
barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2
|
| 939 |
)
|
| 940 |
pipeline_o_epi = None
|
| 941 |
if const_expr(not self.use_correction_warps_for_epi):
|
|
|
|
| 1016 |
window_size_right=window_size_right,
|
| 1017 |
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1018 |
)
|
|
|
|
|
|
|
| 1019 |
# Cluster wait before tensor memory alloc
|
| 1020 |
pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)
|
| 1021 |
|
| 1022 |
+
if const_expr(self.use_clc_scheduler):
|
| 1023 |
+
clc_response_ptr = storage.clc_response.data_ptr()
|
| 1024 |
+
clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr()
|
| 1025 |
+
|
| 1026 |
+
clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup(
|
| 1027 |
+
cutlass_pipeline.Agent.Thread
|
| 1028 |
+
)
|
| 1029 |
+
num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE
|
| 1030 |
+
# NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume
|
| 1031 |
+
num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size
|
| 1032 |
+
clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup(
|
| 1033 |
+
cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
block_idx = cute.arch.block_idx()
|
| 1037 |
+
clc = ClcState.create(
|
| 1038 |
+
hw_scheduler=ClcDynamicPersistentTileScheduler.create(
|
| 1039 |
+
self.tile_scheduler_cls.clc_problem_shape(tile_sched_params),
|
| 1040 |
+
block_idx,
|
| 1041 |
+
cute.arch.grid_dim(),
|
| 1042 |
+
clc_response_ptr,
|
| 1043 |
+
),
|
| 1044 |
+
pipeline=cutlass_pipeline.PipelineClcFetchAsync.create(
|
| 1045 |
+
barrier_storage=clc_mbar_ptr,
|
| 1046 |
+
num_stages=self.sched_stages,
|
| 1047 |
+
producer_group=clc_pipeline_producer_group,
|
| 1048 |
+
consumer_group=clc_pipeline_consumer_group,
|
| 1049 |
+
tx_count=16,
|
| 1050 |
+
cta_layout_vmnk=cta_layout_vmnk,
|
| 1051 |
+
),
|
| 1052 |
+
consumer_state=cutlass_pipeline.make_pipeline_state(
|
| 1053 |
+
cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages
|
| 1054 |
+
),
|
| 1055 |
+
producer_state=cutlass_pipeline.make_pipeline_state(
|
| 1056 |
+
cutlass_pipeline.PipelineUserType.Producer, self.sched_stages
|
| 1057 |
+
),
|
| 1058 |
+
)
|
| 1059 |
+
tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc)
|
| 1060 |
+
else:
|
| 1061 |
+
tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params)
|
| 1062 |
+
assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}"
|
| 1063 |
+
|
| 1064 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1065 |
+
# EMPTY / CLC SCHEDULER WARP
|
| 1066 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1067 |
+
if const_expr(self.use_clc_scheduler):
|
| 1068 |
+
if warp_idx == self.clc_scheduler_warp_id:
|
| 1069 |
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
| 1070 |
+
if is_leader_cta:
|
| 1071 |
+
self.clc_scheduler_warp(tile_scheduler)
|
| 1072 |
+
else:
|
| 1073 |
+
self.empty_warp(tile_scheduler)
|
| 1074 |
+
for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
|
| 1075 |
+
if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id:
|
| 1076 |
+
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
| 1077 |
+
self.empty_warp(tile_scheduler)
|
| 1078 |
+
else:
|
| 1079 |
+
for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
|
| 1080 |
+
if warp_idx == self.empty_warp_ids[i]:
|
| 1081 |
+
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
| 1082 |
|
| 1083 |
# ///////////////////////////////////////////////////////////////////////////////
|
| 1084 |
# LOAD
|
|
|
|
| 1098 |
tma_atom_Q,
|
| 1099 |
tma_atom_K,
|
| 1100 |
tma_atom_V,
|
| 1101 |
+
gmem_tiled_copy_Q,
|
| 1102 |
pipeline_q,
|
| 1103 |
pipeline_kv,
|
| 1104 |
block_info,
|
| 1105 |
num_splits,
|
| 1106 |
SeqlenInfoCls,
|
|
|
|
| 1107 |
blocksparse_tensors,
|
| 1108 |
+
tile_scheduler=tile_scheduler,
|
| 1109 |
)
|
| 1110 |
|
| 1111 |
# ///////////////////////////////////////////////////////////////////////////////
|
|
|
|
| 1135 |
block_info,
|
| 1136 |
num_splits,
|
| 1137 |
SeqlenInfoCls,
|
|
|
|
| 1138 |
blocksparse_tensors,
|
| 1139 |
+
tile_scheduler=tile_scheduler,
|
| 1140 |
)
|
| 1141 |
# Dealloc the tensor memory buffer
|
| 1142 |
tmem.relinquish_alloc_permit()
|
|
|
|
| 1158 |
block_info,
|
| 1159 |
num_splits,
|
| 1160 |
SeqlenInfoCls,
|
|
|
|
| 1161 |
mma_tile_coord_v,
|
| 1162 |
+
tile_scheduler=tile_scheduler,
|
| 1163 |
)
|
| 1164 |
|
| 1165 |
# ///////////////////////////////////////////////////////////////////////////////
|
|
|
|
| 1191 |
num_splits=num_splits,
|
| 1192 |
SeqlenInfoCls=SeqlenInfoCls,
|
| 1193 |
AttentionMaskCls=AttentionMaskCls,
|
|
|
|
| 1194 |
aux_tensors=aux_tensors,
|
| 1195 |
fastdiv_mods=fastdiv_mods,
|
| 1196 |
head_divmod=head_divmod,
|
| 1197 |
blocksparse_tensors=blocksparse_tensors,
|
| 1198 |
+
tile_scheduler=tile_scheduler,
|
| 1199 |
)
|
| 1200 |
|
| 1201 |
if const_expr(not self.s0_s1_barrier):
|
|
|
|
| 1239 |
block_info,
|
| 1240 |
num_splits,
|
| 1241 |
SeqlenInfoCls,
|
|
|
|
| 1242 |
blocksparse_tensors,
|
| 1243 |
+
tile_scheduler=tile_scheduler,
|
| 1244 |
)
|
| 1245 |
tmem_alloc_barrier.arrive()
|
| 1246 |
|
|
|
|
| 1258 |
sK: cute.Tensor,
|
| 1259 |
sV: cute.Tensor,
|
| 1260 |
mPageTable: Optional[cute.Tensor],
|
| 1261 |
+
tma_atom_Q: Optional[cute.CopyAtom],
|
| 1262 |
tma_atom_K: Optional[cute.CopyAtom],
|
| 1263 |
tma_atom_V: Optional[cute.CopyAtom],
|
| 1264 |
+
gmem_tiled_copy_Q: Optional[cute.TiledCopy],
|
| 1265 |
pipeline_q: pipeline.PipelineAsync,
|
| 1266 |
pipeline_kv: pipeline.PipelineAsync,
|
| 1267 |
block_info: BlockInfo,
|
| 1268 |
num_splits: Int32,
|
| 1269 |
SeqlenInfoCls: Callable,
|
|
|
|
| 1270 |
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 1271 |
+
tile_scheduler: TileSchedulerProtocol,
|
| 1272 |
):
|
| 1273 |
num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
|
| 1274 |
tidx = cute.arch.thread_idx()[0] % num_load_threads
|
| 1275 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 1276 |
+
issue_kv_for_this_warp = (
|
| 1277 |
+
const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or
|
| 1278 |
+
warp_idx == self.load_warp_ids[0]
|
| 1279 |
+
)
|
| 1280 |
+
issue_q_for_this_warp = (
|
| 1281 |
+
const_expr(not self.use_tma_Q or len(self.load_warp_ids) == 1) or
|
| 1282 |
+
warp_idx == self.load_warp_ids[0]
|
| 1283 |
+
)
|
| 1284 |
q_producer_phase = Int32(1)
|
| 1285 |
kv_producer_state = pipeline.make_pipeline_state(
|
| 1286 |
pipeline.PipelineUserType.Producer, self.kv_stage
|
| 1287 |
)
|
|
|
|
| 1288 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1289 |
while work_tile.is_valid_tile:
|
| 1290 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
| 1291 |
seqlen = SeqlenInfoCls(batch_idx)
|
| 1292 |
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1293 |
|
| 1294 |
head_idx_kv = (
|
| 1295 |
head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
|
|
|
|
| 1311 |
gV = cute.local_tile(
|
| 1312 |
mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
|
| 1313 |
)
|
|
|
|
| 1314 |
tSgK = thr_mma_qk.partition_B(gK)
|
| 1315 |
tOgV = thr_mma_pv.partition_B(gV)
|
| 1316 |
+
if const_expr(self.use_tma_Q):
|
| 1317 |
+
tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
|
| 1318 |
+
gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128)
|
| 1319 |
+
gQ = layout_utils.select(
|
| 1320 |
+
cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
|
| 1321 |
+
) # (128, 128, 2)
|
| 1322 |
+
tSgQ = thr_mma_qk.partition_A(gQ)
|
| 1323 |
+
load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
|
| 1324 |
+
tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
|
| 1325 |
+
)
|
| 1326 |
+
load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
|
| 1327 |
+
else:
|
| 1328 |
+
assert gmem_tiled_copy_Q is not None
|
| 1329 |
+
load_Q = partial(
|
| 1330 |
+
self.load_Q_non_tma,
|
| 1331 |
+
mQ_cur,
|
| 1332 |
+
sQ,
|
| 1333 |
+
gmem_tiled_copy_Q,
|
| 1334 |
+
pipeline_q,
|
| 1335 |
+
tidx,
|
| 1336 |
+
seqlen.seqlen_q,
|
| 1337 |
+
m_block,
|
| 1338 |
+
phase=q_producer_phase,
|
| 1339 |
+
)
|
| 1340 |
|
| 1341 |
if const_expr(self.use_tma_KV):
|
| 1342 |
tKsK, tKgK = cpasync.tma_partition(
|
|
|
|
| 1375 |
tKsK, tKgK = None, None
|
| 1376 |
tVsV, tVgV = None, None
|
| 1377 |
|
|
|
|
| 1378 |
load_K = partial(
|
| 1379 |
self.load_KV,
|
| 1380 |
tma_atom_K,
|
|
|
|
| 1409 |
)
|
| 1410 |
if const_expr(not self.use_tma_KV):
|
| 1411 |
paged_kv_manager.load_page_table(n_block_first)
|
| 1412 |
+
if issue_kv_for_this_warp:
|
| 1413 |
+
load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0
|
| 1414 |
# load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0
|
| 1415 |
+
if issue_q_for_this_warp:
|
| 1416 |
+
load_Q(block=0, stage=0)
|
| 1417 |
+
if issue_kv_for_this_warp:
|
| 1418 |
+
kv_producer_state.advance()
|
| 1419 |
+
if const_expr(self.q_stage == 2) and issue_q_for_this_warp:
|
| 1420 |
+
load_Q(block=1, stage=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1421 |
q_producer_phase ^= 1
|
| 1422 |
+
if issue_kv_for_this_warp:
|
| 1423 |
+
load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0
|
| 1424 |
+
kv_producer_state.advance()
|
| 1425 |
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 1426 |
n_block = n_block_max - 2 - i
|
| 1427 |
page_idx = (
|
|
|
|
| 1432 |
if const_expr(not self.use_tma_KV):
|
| 1433 |
paged_kv_manager.load_page_table(n_block)
|
| 1434 |
# if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
|
| 1435 |
+
if issue_kv_for_this_warp:
|
| 1436 |
+
load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki
|
| 1437 |
+
kv_producer_state.advance()
|
| 1438 |
+
load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi
|
| 1439 |
+
kv_producer_state.advance()
|
| 1440 |
|
| 1441 |
else:
|
| 1442 |
kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
|
|
|
|
| 1455 |
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 1456 |
)
|
| 1457 |
|
| 1458 |
+
|
| 1459 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
|
|
|
| 1460 |
# End of persistent scheduler loop
|
| 1461 |
|
| 1462 |
+
if issue_kv_for_this_warp:
|
| 1463 |
+
pipeline_kv.producer_tail(kv_producer_state)
|
| 1464 |
+
# This is equivalent to pipeline_q.producer_tail for the TMA-Q producer warp.
|
| 1465 |
+
if issue_q_for_this_warp:
|
| 1466 |
pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)
|
| 1467 |
|
| 1468 |
@cute.jit
|
|
|
|
| 1485 |
block_info: BlockInfo,
|
| 1486 |
num_splits: Int32,
|
| 1487 |
SeqlenInfoCls: Callable,
|
|
|
|
| 1488 |
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 1489 |
+
tile_scheduler=None,
|
| 1490 |
):
|
| 1491 |
tSrQ = tiled_mma_qk.make_fragment_A(sQ)
|
| 1492 |
tSrK = tiled_mma_qk.make_fragment_B(sK)
|
|
|
|
| 1575 |
)
|
| 1576 |
P_full_O_rescaled_phase = Int32(0)
|
| 1577 |
|
|
|
|
| 1578 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1579 |
while work_tile.is_valid_tile:
|
| 1580 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
|
|
| 1745 |
# End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)
|
| 1746 |
|
| 1747 |
# Advance to next tile
|
| 1748 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
|
|
|
| 1749 |
# End of persistent scheduler loop
|
| 1750 |
|
| 1751 |
# We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end
|
|
|
|
| 1774 |
num_splits: Int32,
|
| 1775 |
SeqlenInfoCls: Callable,
|
| 1776 |
AttentionMaskCls: Callable,
|
|
|
|
| 1777 |
aux_tensors: Optional[list] = None,
|
| 1778 |
fastdiv_mods=(None, None),
|
| 1779 |
head_divmod=None,
|
| 1780 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 1781 |
+
tile_scheduler=None,
|
| 1782 |
):
|
| 1783 |
"""Compute softmax on attention scores from QK matrix multiplication.
|
| 1784 |
|
|
|
|
| 1838 |
|
| 1839 |
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
| 1840 |
|
|
|
|
| 1841 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 1842 |
while work_tile.is_valid_tile:
|
| 1843 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
|
|
| 2080 |
# gLSE[tidx] = lse
|
| 2081 |
|
| 2082 |
# Advance to next tile
|
| 2083 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
|
|
|
| 2084 |
# End of persistent scheduler loop
|
| 2085 |
|
| 2086 |
# This is equivalent to pipeline_sm_stats.producer_tail
|
|
|
|
| 2250 |
block_info: BlockInfo,
|
| 2251 |
num_splits: Int32,
|
| 2252 |
SeqlenInfoCls: Callable,
|
|
|
|
| 2253 |
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 2254 |
+
tile_scheduler=None,
|
| 2255 |
):
|
| 2256 |
tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
|
| 2257 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
|
|
|
| 2281 |
o_corr_consumer_phase = Int32(0)
|
| 2282 |
corr_epi_producer_phase = Int32(1)
|
| 2283 |
|
|
|
|
| 2284 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2285 |
while work_tile.is_valid_tile:
|
| 2286 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
|
|
| 2291 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
|
| 2292 |
else:
|
| 2293 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
|
| 2294 |
+
gO = None
|
| 2295 |
+
if const_expr(self.use_tma_O or not self.pack_gqa):
|
| 2296 |
+
tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
|
| 2297 |
+
gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
|
| 2298 |
+
gO = layout_utils.select(
|
| 2299 |
+
cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
|
| 2300 |
+
) # (128, 128, 2)
|
| 2301 |
+
gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
|
| 2302 |
|
| 2303 |
# Default LSE to -inf for invalid split_idx tiles
|
| 2304 |
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
|
|
|
|
| 2399 |
pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
|
| 2400 |
if const_expr(not self.use_correction_warps_for_epi):
|
| 2401 |
pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
|
| 2402 |
+
gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
|
| 2403 |
self.correction_epilogue(
|
| 2404 |
thr_mma_pv,
|
| 2405 |
tOtO[None, None, None, stage],
|
|
|
|
| 2410 |
scale,
|
| 2411 |
sO[None, None, stage],
|
| 2412 |
mO_cur,
|
| 2413 |
+
gO_stage,
|
| 2414 |
gmem_tiled_copy_O,
|
| 2415 |
)
|
| 2416 |
# Signal for the next work tile that O buffers in tmem are already read, so
|
|
|
|
| 2480 |
mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
|
| 2481 |
for stage in cutlass.range_constexpr(self.q_stage):
|
| 2482 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
|
|
|
| 2483 |
row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
|
| 2484 |
# if tidx == 0 and stage <= 1:
|
| 2485 |
# cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
|
|
|
|
| 2494 |
if const_expr(not self.pack_gqa)
|
| 2495 |
else seqlen.seqlen_q * self.qhead_per_kvhead
|
| 2496 |
)
|
| 2497 |
+
if const_expr(not self.pack_gqa or self.m_block_size % self.qhead_per_kvhead == 0):
|
| 2498 |
+
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
|
| 2499 |
+
if tidx < seqlen_q - m_tile_idx * self.m_block_size:
|
| 2500 |
+
# This actually just works with PackGQA too
|
| 2501 |
+
gLSE[tidx] = lse
|
| 2502 |
+
else:
|
| 2503 |
+
idx = m_tile_idx * self.m_block_size + tidx
|
| 2504 |
+
if idx < seqlen_q:
|
| 2505 |
+
m_idx = idx // self.qhead_per_kvhead
|
| 2506 |
+
h_idx = idx - m_idx * self.qhead_per_kvhead
|
| 2507 |
+
lse_ptr_i64 = utils.elem_pointer(mLSE_cur, ((h_idx, m_idx),)).toint()
|
| 2508 |
+
lse_gmem_ptr = cute.make_ptr(
|
| 2509 |
+
mLSE_cur.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
|
| 2510 |
+
)
|
| 2511 |
+
cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse
|
| 2512 |
|
| 2513 |
# Advance to next tile
|
| 2514 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
|
|
|
| 2515 |
# End of persistent scheduler loop
|
| 2516 |
|
| 2517 |
# This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps
|
|
|
|
| 2650 |
if const_expr(self.use_correction_warps_for_epi):
|
| 2651 |
assert(not self.use_tma_O)
|
| 2652 |
assert(gmem_tiled_copy_O is not None)
|
| 2653 |
+
cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue),
|
| 2654 |
number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
|
| 2655 |
mma_tile_coord_v = thr_mma.thr_idx
|
| 2656 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
|
|
|
| 2662 |
def _store_O_to_gmem(
|
| 2663 |
self,
|
| 2664 |
sO_stage: cute.Tensor,
|
| 2665 |
+
gO: Optional[cute.Tensor],
|
| 2666 |
mO_cur: cute.Tensor,
|
| 2667 |
gmem_tiled_copy_O: cute.TiledCopy,
|
| 2668 |
tidx: Int32,
|
|
|
|
| 2673 |
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 2674 |
tOsO = gmem_thr_copy_O.partition_S(sO_stage)
|
| 2675 |
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
|
|
|
|
| 2676 |
tOcO = gmem_thr_copy_O.partition_S(cO)
|
| 2677 |
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
|
| 2678 |
tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
|
|
|
|
| 2688 |
cute.autovec_copy(tOsO, tOrO)
|
| 2689 |
# copy acc O from rmem to gmem
|
| 2690 |
if const_expr(not self.pack_gqa):
|
| 2691 |
+
assert gO is not None
|
| 2692 |
+
tOgO = gmem_thr_copy_O.partition_D(gO)
|
| 2693 |
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
|
| 2694 |
if (
|
| 2695 |
t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
|
|
|
|
| 2718 |
block_info: BlockInfo,
|
| 2719 |
num_splits: int,
|
| 2720 |
SeqlenInfoCls: Callable,
|
|
|
|
| 2721 |
mma_tile_coord_v: Int32 = 0,
|
| 2722 |
+
tile_scheduler=None,
|
| 2723 |
):
|
| 2724 |
epi_consumer_phase = Int32(0)
|
|
|
|
| 2725 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2726 |
while work_tile.is_valid_tile:
|
| 2727 |
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
|
|
|
|
| 2733 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
|
| 2734 |
else:
|
| 2735 |
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
|
| 2736 |
+
gO = None
|
| 2737 |
+
if const_expr(self.use_tma_O or not self.pack_gqa):
|
| 2738 |
+
tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
|
| 2739 |
+
gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
|
| 2740 |
+
gO = layout_utils.select(
|
| 2741 |
+
cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
|
| 2742 |
+
) # (128, 128, 2)
|
| 2743 |
+
gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
|
| 2744 |
|
| 2745 |
if const_expr(self.use_tma_O):
|
| 2746 |
store_O, _, _ = copy_utils.tma_get_copy_fn(
|
|
|
|
| 2767 |
pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
|
| 2768 |
# 2. copy O0 / O1 to gmem
|
| 2769 |
m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
|
| 2770 |
+
gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
|
| 2771 |
self._store_O_to_gmem(
|
| 2772 |
+
sO[None, None, stage], gO_stage, mO_cur, gmem_tiled_copy_O,
|
| 2773 |
tidx, seqlen.seqlen_q, m_tile_idx,
|
| 2774 |
)
|
| 2775 |
pipeline_o_epi.consumer_release_w_index(stage)
|
|
|
|
| 2777 |
epi_consumer_phase ^= 1
|
| 2778 |
|
| 2779 |
# Advance to next tile
|
| 2780 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
| 2781 |
+
|
| 2782 |
+
@cute.jit
|
| 2783 |
+
def clc_scheduler_warp(
|
| 2784 |
+
self,
|
| 2785 |
+
tile_scheduler: TileSchedulerProtocol,
|
| 2786 |
+
):
|
| 2787 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2788 |
+
while work_tile.is_valid_tile:
|
| 2789 |
+
tile_scheduler.prefetch_next_work()
|
| 2790 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
| 2791 |
+
if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE:
|
| 2792 |
+
fa_printf(
|
| 2793 |
+
3,
|
| 2794 |
+
"[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n",
|
| 2795 |
+
smid(),
|
| 2796 |
+
cute.arch.block_idx()[0],
|
| 2797 |
+
work_tile.tile_idx[0],
|
| 2798 |
+
work_tile.tile_idx[1],
|
| 2799 |
+
work_tile.tile_idx[2],
|
| 2800 |
+
work_tile.tile_idx[3],
|
| 2801 |
+
work_tile.is_valid_tile,
|
| 2802 |
+
)
|
| 2803 |
+
tile_scheduler.producer_tail()
|
| 2804 |
+
|
| 2805 |
+
@cute.jit
|
| 2806 |
+
def empty_warp(
|
| 2807 |
+
self,
|
| 2808 |
+
tile_scheduler: TileSchedulerProtocol,
|
| 2809 |
+
):
|
| 2810 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 2811 |
+
while work_tile.is_valid_tile:
|
| 2812 |
+
work_tile = tile_scheduler.advance_to_next_work()
|
| 2813 |
|
| 2814 |
def load_Q(
|
| 2815 |
self,
|
|
|
|
| 2822 |
pipeline_q.producer_acquire_w_index_phase(stage, phase)
|
| 2823 |
load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))
|
| 2824 |
|
| 2825 |
+
def load_Q_non_tma(
|
| 2826 |
+
self,
|
| 2827 |
+
mQ: cute.Tensor,
|
| 2828 |
+
sQ: cute.Tensor,
|
| 2829 |
+
gmem_tiled_copy_Q: cute.TiledCopy,
|
| 2830 |
+
pipeline_q: pipeline.PipelineAsync,
|
| 2831 |
+
tidx: Int32,
|
| 2832 |
+
seqlen_q: Int32,
|
| 2833 |
+
m_block: Int32,
|
| 2834 |
+
block: Int32,
|
| 2835 |
+
stage: int,
|
| 2836 |
+
phase: Int32,
|
| 2837 |
+
):
|
| 2838 |
+
assert self.cta_group_size == 1, "cta_group_size must be 1 for non-tma Q load"
|
| 2839 |
+
pipeline_q.producer_acquire_w_index_phase(stage, phase)
|
| 2840 |
+
pack_gqa = PackGQA(
|
| 2841 |
+
self.m_block_size,
|
| 2842 |
+
self.head_dim_padded,
|
| 2843 |
+
self.check_hdim_oob,
|
| 2844 |
+
self.qhead_per_kvhead,
|
| 2845 |
+
)
|
| 2846 |
+
sQ_stage = sQ[None, None, None, stage]
|
| 2847 |
+
sQ_pi = cute.make_tensor(
|
| 2848 |
+
sQ_stage.iterator,
|
| 2849 |
+
cute.make_layout(
|
| 2850 |
+
(sQ_stage.shape[0][0], (sQ_stage.shape[0][1], sQ_stage.shape[2])),
|
| 2851 |
+
stride=(sQ_stage.stride[0][0], (sQ_stage.stride[0][1], sQ_stage.stride[2])),
|
| 2852 |
+
),
|
| 2853 |
+
)
|
| 2854 |
+
pack_gqa.load_Q(mQ, sQ_pi, gmem_tiled_copy_Q, tidx, m_block * self.q_stage + block, seqlen_q)
|
| 2855 |
+
cute.arch.cp_async_commit_group()
|
| 2856 |
+
pipeline_q.sync_object_full.arrive_cp_async_mbarrier(stage)
|
| 2857 |
+
|
| 2858 |
@cute.jit
|
| 2859 |
def load_KV(
|
| 2860 |
self,
|
|
|
|
| 2897 |
else:
|
| 2898 |
assert paged_kv_manager is not None
|
| 2899 |
assert extra_tx_count is None
|
| 2900 |
+
sX_cur = sX[None, None, None, stage]
|
| 2901 |
+
if const_expr(self.uneven_kv_smem):
|
| 2902 |
+
sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1)
|
| 2903 |
+
paged_kv_manager.load_KV(block, sX_cur, K_or_V)
|
| 2904 |
cute.arch.cp_async_commit_group()
|
| 2905 |
pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)
|
| 2906 |
|
|
|
|
| 2911 |
# (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
|
| 2912 |
# phase == 0, or left by offset if phase == 1.
|
| 2913 |
offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
|
| 2914 |
+
# Hint that the offset is 128-bit aligned so that
|
| 2915 |
+
# ptr + offset preserves the alignment needed by cp.async.
|
| 2916 |
+
offset = cute.assume(offset, divby=128 // self.k_dtype.width)
|
| 2917 |
return cute.make_tensor(sX.iterator + offset, sX.layout)
|
| 2918 |
else:
|
| 2919 |
return sX
|
|
|
|
| 2923 |
# warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
| 2924 |
# if warp_group_idx == 0:
|
| 2925 |
# cute.arch.barrier_arrive(
|
| 2926 |
+
# barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128,
|
| 2927 |
# )
|
| 2928 |
|
| 2929 |
# def warp_scheduler_barrier_sync(self):
|
| 2930 |
# cute.arch.barrier(
|
| 2931 |
+
# barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
|
| 2932 |
# number_of_threads=2 * 128
|
| 2933 |
# )
|
| 2934 |
|
|
|
|
| 2936 |
# cur_wg = utils.canonical_warp_group_idx(sync=False)
|
| 2937 |
# next_wg = 1 - cur_wg
|
| 2938 |
# cute.arch.barrier_arrive(
|
| 2939 |
+
# barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
|
| 2940 |
# )
|
| 2941 |
|
| 2942 |
@cute.jit
|
build/torch-cuda/flash_fwd_sm120.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# SM120 (Blackwell GeForce / DGX Spark) forward pass.
|
| 3 |
+
#
|
| 4 |
+
# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
|
| 5 |
+
# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
|
| 6 |
+
# FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly.
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.utils as utils_basic
|
| 10 |
+
|
| 11 |
+
from .flash_fwd import FlashAttentionForwardSm80
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
|
| 15 |
+
# Keep arch = 80 to use CpAsync code paths (no TMA for output).
|
| 16 |
+
# The compilation target is determined by the GPU at compile time, not this field.
|
| 17 |
+
arch = 80
|
| 18 |
+
|
| 19 |
+
@staticmethod
|
| 20 |
+
def can_implement(
|
| 21 |
+
dtype,
|
| 22 |
+
head_dim,
|
| 23 |
+
head_dim_v,
|
| 24 |
+
tile_m,
|
| 25 |
+
tile_n,
|
| 26 |
+
num_stages,
|
| 27 |
+
num_threads,
|
| 28 |
+
is_causal,
|
| 29 |
+
Q_in_regs=False,
|
| 30 |
+
) -> bool:
|
| 31 |
+
"""Check if the kernel can be implemented on SM120.
|
| 32 |
+
|
| 33 |
+
Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
|
| 34 |
+
"""
|
| 35 |
+
if dtype not in [cutlass.Float16, cutlass.BFloat16]:
|
| 36 |
+
return False
|
| 37 |
+
if head_dim % 8 != 0:
|
| 38 |
+
return False
|
| 39 |
+
if head_dim_v % 8 != 0:
|
| 40 |
+
return False
|
| 41 |
+
if tile_n % 16 != 0:
|
| 42 |
+
return False
|
| 43 |
+
if num_threads % 32 != 0:
|
| 44 |
+
return False
|
| 45 |
+
# Shared memory usage: Q tile + (K tile + V tile)
|
| 46 |
+
smem_usage_Q = tile_m * head_dim * 2
|
| 47 |
+
smem_usage_K = tile_n * head_dim * num_stages * 2
|
| 48 |
+
smem_usage_V = tile_n * head_dim_v * num_stages * 2
|
| 49 |
+
smem_usage_QV = (
|
| 50 |
+
(smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)
|
| 51 |
+
)
|
| 52 |
+
smem_usage = smem_usage_QV + smem_usage_K
|
| 53 |
+
# SM120 has 99 KB shared memory (vs 163 KB on SM80)
|
| 54 |
+
smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
|
| 55 |
+
if smem_usage > smem_capacity:
|
| 56 |
+
return False
|
| 57 |
+
if (tile_m * 2) % num_threads != 0:
|
| 58 |
+
return False
|
| 59 |
+
return True
|
build/torch-cuda/flash_fwd_sm90.py
ADDED
|
@@ -0,0 +1,1534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py.
|
| 3 |
+
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
from typing import Callable, Literal, Optional
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import cuda.bindings.driver as cuda
|
| 9 |
+
|
| 10 |
+
import cutlass
|
| 11 |
+
import cutlass.cute as cute
|
| 12 |
+
from cutlass import Float32, Int32, const_expr
|
| 13 |
+
from cutlass.cute.nvgpu import cpasync, warpgroup
|
| 14 |
+
from cutlass.utils import LayoutEnum
|
| 15 |
+
import cutlass.utils.hopper_helpers as sm90_utils_basic
|
| 16 |
+
from cutlass import pipeline
|
| 17 |
+
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
| 18 |
+
from cutlass.base_dsl.arch import Arch
|
| 19 |
+
|
| 20 |
+
from .quack import copy_utils
|
| 21 |
+
from .quack import layout_utils
|
| 22 |
+
from .quack import sm90_utils
|
| 23 |
+
|
| 24 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 25 |
+
from . import utils
|
| 26 |
+
from .mask import AttentionMask
|
| 27 |
+
from .softmax import Softmax, apply_score_mod_inner
|
| 28 |
+
from .seqlen_info import SeqlenInfoQK
|
| 29 |
+
from .block_info import BlockInfo
|
| 30 |
+
from .block_sparsity import BlockSparseTensors
|
| 31 |
+
from .block_sparse_utils import (
|
| 32 |
+
produce_block_sparse_loads,
|
| 33 |
+
consume_block_sparse_loads,
|
| 34 |
+
)
|
| 35 |
+
from . import pipeline as pipeline_custom
|
| 36 |
+
from .pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom
|
| 37 |
+
from .paged_kv import PagedKVManager
|
| 38 |
+
from .named_barrier import NamedBarrierFwd
|
| 39 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 40 |
+
from .tile_scheduler import (
|
| 41 |
+
TileSchedulerArguments,
|
| 42 |
+
SingleTileScheduler,
|
| 43 |
+
SingleTileLPTScheduler,
|
| 44 |
+
SingleTileVarlenScheduler,
|
| 45 |
+
)
|
| 46 |
+
from cutlass.cute import FastDivmodDivisor
|
| 47 |
+
|
| 48 |
+
from .flash_fwd import FlashAttentionForwardBase
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
*args,
|
| 55 |
+
intra_wg_overlap: bool = True,
|
| 56 |
+
mma_pv_is_rs: bool = True,
|
| 57 |
+
paged_kv_non_tma: bool = False,
|
| 58 |
+
**kwargs,
|
| 59 |
+
):
|
| 60 |
+
super().__init__(*args, **kwargs)
|
| 61 |
+
self.intra_wg_overlap = intra_wg_overlap
|
| 62 |
+
self.mma_pv_is_rs = mma_pv_is_rs
|
| 63 |
+
self.buffer_align_bytes = 1024
|
| 64 |
+
self.use_tma_KV = not paged_kv_non_tma
|
| 65 |
+
assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
|
| 66 |
+
"Paged KV does not support irregular head dim"
|
| 67 |
+
)
|
| 68 |
+
self.cluster_shape_mn = (1, 1)
|
| 69 |
+
assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported"
|
| 70 |
+
|
| 71 |
+
def _get_smem_layout_atom(self):
|
| 72 |
+
sQ_layout_atom = warpgroup.make_smem_layout_atom(
|
| 73 |
+
sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
|
| 74 |
+
self.dtype,
|
| 75 |
+
)
|
| 76 |
+
sK_layout_atom = sQ_layout_atom
|
| 77 |
+
sV_layout_atom = warpgroup.make_smem_layout_atom(
|
| 78 |
+
sm90_utils_basic.get_smem_layout_atom(
|
| 79 |
+
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
|
| 80 |
+
),
|
| 81 |
+
self.dtype,
|
| 82 |
+
)
|
| 83 |
+
sO_layout_atom = sV_layout_atom
|
| 84 |
+
if not self.mma_pv_is_rs:
|
| 85 |
+
sP_layout_atom = warpgroup.make_smem_layout_atom(
|
| 86 |
+
sm90_utils_basic.get_smem_layout_atom(
|
| 87 |
+
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
|
| 88 |
+
),
|
| 89 |
+
self.dtype,
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
sP_layout_atom = None
|
| 93 |
+
return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
|
| 94 |
+
|
| 95 |
+
def _get_tiled_mma(self):
|
| 96 |
+
tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
|
| 97 |
+
self.dtype,
|
| 98 |
+
self.dtype,
|
| 99 |
+
warpgroup.OperandMajorMode.K,
|
| 100 |
+
warpgroup.OperandMajorMode.K,
|
| 101 |
+
Float32,
|
| 102 |
+
atom_layout_mnk=(self.tile_m // 64, 1, 1),
|
| 103 |
+
tiler_mn=(64, self.tile_n),
|
| 104 |
+
)
|
| 105 |
+
tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
|
| 106 |
+
self.dtype,
|
| 107 |
+
self.dtype,
|
| 108 |
+
warpgroup.OperandMajorMode.K,
|
| 109 |
+
warpgroup.OperandMajorMode.MN,
|
| 110 |
+
Float32,
|
| 111 |
+
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
| 112 |
+
tiler_mn=(64, self.tile_hdimv),
|
| 113 |
+
a_source=warpgroup.OperandSource.RMEM
|
| 114 |
+
if self.mma_pv_is_rs
|
| 115 |
+
else warpgroup.OperandSource.SMEM,
|
| 116 |
+
)
|
| 117 |
+
return tiled_mma_qk, tiled_mma_pv
|
| 118 |
+
|
| 119 |
+
def _get_shared_storage_cls(self):
|
| 120 |
+
sQ_struct, sK_struct, sV_struct = [
|
| 121 |
+
cute.struct.Align[
|
| 122 |
+
cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes
|
| 123 |
+
]
|
| 124 |
+
for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
|
| 125 |
+
]
|
| 126 |
+
cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
|
| 127 |
+
sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
|
| 128 |
+
cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
|
| 129 |
+
sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
|
| 130 |
+
# 1 stage * 2 for Q pipeline (full + empty), self.num_stages*2 for K, self.num_stages*2 for V,
|
| 131 |
+
mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2]
|
| 132 |
+
mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
| 133 |
+
mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
| 134 |
+
|
| 135 |
+
@cute.struct
|
| 136 |
+
class SharedStorageQKV:
|
| 137 |
+
mbar_ptr_Q: mbar_ptr_Q_struct
|
| 138 |
+
mbar_ptr_K: mbar_ptr_K_struct
|
| 139 |
+
mbar_ptr_V: mbar_ptr_V_struct
|
| 140 |
+
sV: sV_struct
|
| 141 |
+
sQ: sQ_struct
|
| 142 |
+
sK: sK_struct
|
| 143 |
+
sP: sP_struct
|
| 144 |
+
|
| 145 |
+
@cute.struct
|
| 146 |
+
class SharedStorageSharedQV:
|
| 147 |
+
mbar_ptr_Q: mbar_ptr_Q_struct
|
| 148 |
+
mbar_ptr_K: mbar_ptr_K_struct
|
| 149 |
+
mbar_ptr_V: mbar_ptr_V_struct
|
| 150 |
+
sQ: sQV_struct
|
| 151 |
+
sK: sK_struct
|
| 152 |
+
sP: sP_struct
|
| 153 |
+
|
| 154 |
+
return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
|
| 155 |
+
|
| 156 |
+
@cute.jit
|
| 157 |
+
def __call__(
|
| 158 |
+
self,
|
| 159 |
+
mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
| 160 |
+
mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
|
| 161 |
+
mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
|
| 162 |
+
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 163 |
+
mLSE: Optional[cute.Tensor],
|
| 164 |
+
softmax_scale: Float32,
|
| 165 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 166 |
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 167 |
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 168 |
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 169 |
+
mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
|
| 170 |
+
window_size_left: Int32 | int | None = None,
|
| 171 |
+
window_size_right: Int32 | int | None = None,
|
| 172 |
+
learnable_sink: Optional[cute.Tensor] = None,
|
| 173 |
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
| 174 |
+
aux_tensors: Optional[list] = None,
|
| 175 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 176 |
+
stream: cuda.CUstream = None,
|
| 177 |
+
):
|
| 178 |
+
"""Configures and launches the flash attention kernel.
|
| 179 |
+
|
| 180 |
+
mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:
|
| 181 |
+
(batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
self._check_type(
|
| 185 |
+
*(
|
| 186 |
+
t.element_type if t is not None else None
|
| 187 |
+
for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
|
| 188 |
+
)
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
|
| 192 |
+
|
| 193 |
+
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
|
| 194 |
+
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
| 195 |
+
mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
|
| 196 |
+
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
| 197 |
+
mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
|
| 198 |
+
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
| 199 |
+
mLSE = (
|
| 200 |
+
layout_utils.select(mLSE, LSE_layout_transpose)
|
| 201 |
+
if const_expr(mLSE is not None)
|
| 202 |
+
else None
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
|
| 206 |
+
self.num_mma_threads = tiled_mma_qk.size
|
| 207 |
+
self.num_threads_per_warp_group = 128
|
| 208 |
+
self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group
|
| 209 |
+
assert self.num_wg_mma in [1, 2, 3]
|
| 210 |
+
self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1)
|
| 211 |
+
self.num_producer_threads = 32
|
| 212 |
+
self.num_Q_load_threads = self.num_threads_per_warp_group # If not TMA_Q
|
| 213 |
+
self.num_epilogue_threads = self.num_mma_threads
|
| 214 |
+
self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[
|
| 215 |
+
self.num_wg_mma
|
| 216 |
+
]
|
| 217 |
+
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
| 218 |
+
|
| 219 |
+
self.use_scheduler_barrier = (
|
| 220 |
+
(self.num_wg_mma >= 2 and self.tile_hdim <= 128)
|
| 221 |
+
if const_expr(self.intra_wg_overlap)
|
| 222 |
+
else (self.num_wg_mma == 2)
|
| 223 |
+
)
|
| 224 |
+
self.use_tma_Q = self.arch >= Arch.sm_90 and not (
|
| 225 |
+
self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
|
| 226 |
+
)
|
| 227 |
+
self.use_tma_O = self.use_tma_Q
|
| 228 |
+
# Producer needs more registers when doing cp.async Q or KV loads
|
| 229 |
+
if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)):
|
| 230 |
+
self.num_mma_regs, self.num_producer_regs = 224, 40
|
| 231 |
+
self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap
|
| 232 |
+
self._setup_attributes()
|
| 233 |
+
# TODO: we prob don't need most of what's in _setup_attributes
|
| 234 |
+
self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
|
| 235 |
+
sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
|
| 236 |
+
for mX, shape, stage in [
|
| 237 |
+
(mQ, (self.tile_m, self.tile_hdim), None),
|
| 238 |
+
(mK, (self.tile_n, self.tile_hdim), self.num_stages),
|
| 239 |
+
(mV, (self.tile_n, self.tile_hdimv), self.num_stages),
|
| 240 |
+
(mO, (self.tile_m, self.tile_hdimv), None),
|
| 241 |
+
]
|
| 242 |
+
]
|
| 243 |
+
self.sP_layout = None
|
| 244 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 245 |
+
self.sP_layout = sm90_utils.make_smem_layout(
|
| 246 |
+
mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
SharedStorage = self._get_shared_storage_cls()
|
| 250 |
+
|
| 251 |
+
mQ_og, mO_og = mQ, mO
|
| 252 |
+
if const_expr(self.pack_gqa):
|
| 253 |
+
nheads_kv = mK.shape[2]
|
| 254 |
+
mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
|
| 255 |
+
mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
|
| 256 |
+
if const_expr(mLSE is not None):
|
| 257 |
+
mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)
|
| 258 |
+
|
| 259 |
+
# TMA
|
| 260 |
+
gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
|
| 261 |
+
gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
|
| 262 |
+
gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
|
| 263 |
+
self.tma_copy_bytes = {
|
| 264 |
+
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
|
| 265 |
+
for name, mX, layout in [
|
| 266 |
+
("Q", mQ, self.sQ_layout),
|
| 267 |
+
("K", mK, self.sK_layout),
|
| 268 |
+
("V", mV, self.sV_layout),
|
| 269 |
+
]
|
| 270 |
+
}
|
| 271 |
+
make_tiled_tma_atom_fn = (
|
| 272 |
+
partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2)
|
| 273 |
+
if const_expr(self.pack_gqa)
|
| 274 |
+
else cpasync.make_tiled_tma_atom
|
| 275 |
+
)
|
| 276 |
+
tma_atom_Q, tma_tensor_Q = None, None
|
| 277 |
+
if const_expr(self.use_tma_Q):
|
| 278 |
+
tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn(
|
| 279 |
+
gmem_tiled_copy_Q,
|
| 280 |
+
mQ_og if const_expr(self.pack_gqa) else mQ,
|
| 281 |
+
self.sQ_layout,
|
| 282 |
+
(self.tile_m, self.tile_hdim), # No mcast
|
| 283 |
+
)
|
| 284 |
+
tma_atom_K, tma_tensor_K = None, None
|
| 285 |
+
tma_atom_V, tma_tensor_V = None, None
|
| 286 |
+
if const_expr(self.use_tma_KV):
|
| 287 |
+
tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
|
| 288 |
+
gmem_tiled_copy_KV,
|
| 289 |
+
mK,
|
| 290 |
+
cute.select(self.sK_layout, mode=[0, 1]),
|
| 291 |
+
(self.tile_n, self.tile_hdim),
|
| 292 |
+
1, # No mcast for now
|
| 293 |
+
)
|
| 294 |
+
tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
|
| 295 |
+
gmem_tiled_copy_KV,
|
| 296 |
+
mV,
|
| 297 |
+
cute.select(self.sV_layout, mode=[0, 1]),
|
| 298 |
+
(self.tile_n, self.tile_hdimv),
|
| 299 |
+
1, # No mcast for now
|
| 300 |
+
)
|
| 301 |
+
tma_atom_O, tma_tensor_O = None, None
|
| 302 |
+
if const_expr(self.use_tma_O):
|
| 303 |
+
mO_tma = mO_og if const_expr(self.pack_gqa) else mO
|
| 304 |
+
if const_expr(self.varlen_q):
|
| 305 |
+
mO_tma = copy_utils.create_ragged_tensor_for_tma(
|
| 306 |
+
mO_tma, ragged_dim=0, ptr_shift=True
|
| 307 |
+
)
|
| 308 |
+
tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn(
|
| 309 |
+
gmem_tiled_copy_O,
|
| 310 |
+
mO_tma,
|
| 311 |
+
self.sO_layout,
|
| 312 |
+
(self.tile_m, self.tile_hdimv), # No mcast
|
| 313 |
+
)
|
| 314 |
+
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
|
| 315 |
+
TileScheduler = SingleTileVarlenScheduler
|
| 316 |
+
else:
|
| 317 |
+
TileScheduler = (
|
| 318 |
+
SingleTileScheduler
|
| 319 |
+
if const_expr(not self.is_causal or self.is_local)
|
| 320 |
+
else SingleTileLPTScheduler
|
| 321 |
+
)
|
| 322 |
+
tile_sched_args = TileSchedulerArguments(
|
| 323 |
+
cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
|
| 324 |
+
cute.size(mQ.shape[2]),
|
| 325 |
+
cute.size(mQ.shape[3])
|
| 326 |
+
if const_expr(mCuSeqlensQ is None)
|
| 327 |
+
else cute.size(mCuSeqlensQ.shape[0] - 1),
|
| 328 |
+
1, # num_splits
|
| 329 |
+
cute.size(mK.shape[0])
|
| 330 |
+
if const_expr(mPageTable is None)
|
| 331 |
+
else mK.shape[0] * mPageTable.shape[1],
|
| 332 |
+
mQ.shape[1],
|
| 333 |
+
mV.shape[1],
|
| 334 |
+
total_q=cute.size(mQ.shape[0])
|
| 335 |
+
if const_expr(mCuSeqlensQ is not None)
|
| 336 |
+
else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
|
| 337 |
+
tile_shape_mn=(self.tile_m, self.tile_n),
|
| 338 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 339 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 340 |
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 341 |
+
element_size=self.dtype.width // 8,
|
| 342 |
+
is_persistent=False,
|
| 343 |
+
lpt=self.is_causal or self.is_local,
|
| 344 |
+
)
|
| 345 |
+
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
| 346 |
+
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
| 347 |
+
softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(
|
| 348 |
+
softmax_scale, self.score_mod
|
| 349 |
+
)
|
| 350 |
+
window_size_left = Int32(window_size_left) if window_size_left is not None else None
|
| 351 |
+
window_size_right = Int32(window_size_right) if window_size_right is not None else None
|
| 352 |
+
fastdiv_mods = utils.compute_fastdiv_mods(
|
| 353 |
+
mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
self.kernel(
|
| 357 |
+
tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
|
| 358 |
+
tma_tensor_K if const_expr(self.use_tma_KV) else mK,
|
| 359 |
+
tma_tensor_V if const_expr(self.use_tma_KV) else mV,
|
| 360 |
+
tma_tensor_O if const_expr(self.use_tma_O) else mO,
|
| 361 |
+
mLSE,
|
| 362 |
+
mCuSeqlensQ,
|
| 363 |
+
mCuSeqlensK,
|
| 364 |
+
mSeqUsedQ,
|
| 365 |
+
mSeqUsedK,
|
| 366 |
+
mPageTable,
|
| 367 |
+
tma_atom_Q,
|
| 368 |
+
tma_atom_K,
|
| 369 |
+
tma_atom_V,
|
| 370 |
+
tma_atom_O,
|
| 371 |
+
softmax_scale_log2,
|
| 372 |
+
softmax_scale,
|
| 373 |
+
window_size_left,
|
| 374 |
+
window_size_right,
|
| 375 |
+
learnable_sink,
|
| 376 |
+
blocksparse_tensors,
|
| 377 |
+
self.sQ_layout,
|
| 378 |
+
self.sK_layout,
|
| 379 |
+
self.sV_layout,
|
| 380 |
+
self.sO_layout,
|
| 381 |
+
self.sP_layout,
|
| 382 |
+
self.gmem_tiled_copy_Q,
|
| 383 |
+
self.gmem_tiled_copy_K,
|
| 384 |
+
self.gmem_tiled_copy_V,
|
| 385 |
+
self.gmem_tiled_copy_O,
|
| 386 |
+
tiled_mma_qk,
|
| 387 |
+
tiled_mma_pv,
|
| 388 |
+
tile_sched_params,
|
| 389 |
+
TileScheduler,
|
| 390 |
+
SharedStorage,
|
| 391 |
+
aux_tensors,
|
| 392 |
+
fastdiv_mods,
|
| 393 |
+
).launch(
|
| 394 |
+
grid=grid_dim,
|
| 395 |
+
block=[self.num_threads, 1, 1],
|
| 396 |
+
stream=stream,
|
| 397 |
+
min_blocks_per_mp=1,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
@cute.kernel
|
| 401 |
+
def kernel(
|
| 402 |
+
self,
|
| 403 |
+
mQ: cute.Tensor,
|
| 404 |
+
mK: cute.Tensor,
|
| 405 |
+
mV: cute.Tensor,
|
| 406 |
+
mO: cute.Tensor,
|
| 407 |
+
mLSE: Optional[cute.Tensor],
|
| 408 |
+
mCuSeqlensQ: Optional[cute.Tensor],
|
| 409 |
+
mCuSeqlensK: Optional[cute.Tensor],
|
| 410 |
+
mSeqUsedQ: Optional[cute.Tensor],
|
| 411 |
+
mSeqUsedK: Optional[cute.Tensor],
|
| 412 |
+
mPageTable: Optional[cute.Tensor],
|
| 413 |
+
tma_atom_Q: Optional[cute.CopyAtom],
|
| 414 |
+
tma_atom_K: Optional[cute.CopyAtom],
|
| 415 |
+
tma_atom_V: Optional[cute.CopyAtom],
|
| 416 |
+
tma_atom_O: Optional[cute.CopyAtom],
|
| 417 |
+
softmax_scale_log2: Float32,
|
| 418 |
+
softmax_scale: Optional[Float32],
|
| 419 |
+
window_size_left: Optional[Int32],
|
| 420 |
+
window_size_right: Optional[Int32],
|
| 421 |
+
learnable_sink: Optional[cute.Tensor],
|
| 422 |
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 423 |
+
sQ_layout: cute.ComposedLayout,
|
| 424 |
+
sK_layout: cute.ComposedLayout,
|
| 425 |
+
sV_layout: cute.ComposedLayout,
|
| 426 |
+
sO_layout: cute.ComposedLayout,
|
| 427 |
+
sP_layout: cute.ComposedLayout | None,
|
| 428 |
+
gmem_tiled_copy_Q: cute.TiledCopy,
|
| 429 |
+
gmem_tiled_copy_K: cute.TiledCopy,
|
| 430 |
+
gmem_tiled_copy_V: cute.TiledCopy,
|
| 431 |
+
gmem_tiled_copy_O: cute.TiledCopy,
|
| 432 |
+
tiled_mma_qk: cute.TiledMma,
|
| 433 |
+
tiled_mma_pv: cute.TiledMma,
|
| 434 |
+
tile_sched_params: ParamsBase,
|
| 435 |
+
TileScheduler: cutlass.Constexpr[Callable],
|
| 436 |
+
SharedStorage: cutlass.Constexpr[Callable],
|
| 437 |
+
aux_tensors=Optional[list[cute.Tensor]],
|
| 438 |
+
fastdiv_mods=None,
|
| 439 |
+
):
|
| 440 |
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 441 |
+
# Prefetch tma descriptor
|
| 442 |
+
if warp_idx == 0:
|
| 443 |
+
for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
|
| 444 |
+
if const_expr(tma_atom is not None):
|
| 445 |
+
cpasync.prefetch_descriptor(tma_atom)
|
| 446 |
+
|
| 447 |
+
smem = cutlass.utils.SmemAllocator()
|
| 448 |
+
storage = smem.allocate(SharedStorage)
|
| 449 |
+
|
| 450 |
+
# Mbarrier / pipeline init
|
| 451 |
+
mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr()
|
| 452 |
+
|
| 453 |
+
ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
|
| 454 |
+
tma_warp = ThreadCooperativeGroup(1)
|
| 455 |
+
load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group)
|
| 456 |
+
mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE)
|
| 457 |
+
if const_expr(self.use_tma_Q):
|
| 458 |
+
pipeline_q = pipeline_custom.PipelineTmaAsync.create(
|
| 459 |
+
barrier_storage=mbar_ptr_Q,
|
| 460 |
+
num_stages=1,
|
| 461 |
+
producer_group=tma_warp,
|
| 462 |
+
consumer_group=mma_warps,
|
| 463 |
+
tx_count=self.tma_copy_bytes["Q"],
|
| 464 |
+
defer_sync=True,
|
| 465 |
+
)
|
| 466 |
+
else:
|
| 467 |
+
pipeline_q = pipeline_custom.PipelineCpAsync.create(
|
| 468 |
+
barrier_storage=mbar_ptr_Q,
|
| 469 |
+
num_stages=1,
|
| 470 |
+
producer_group=load_threads,
|
| 471 |
+
consumer_group=mma_warps,
|
| 472 |
+
defer_sync=True,
|
| 473 |
+
elect_one_release=True,
|
| 474 |
+
syncwarp_before_release=False,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
if const_expr(self.use_tma_KV):
|
| 478 |
+
pipeline_k = pipeline_custom.PipelineTmaAsync.create(
|
| 479 |
+
barrier_storage=storage.mbar_ptr_K.data_ptr(),
|
| 480 |
+
num_stages=self.num_stages,
|
| 481 |
+
producer_group=tma_warp,
|
| 482 |
+
consumer_group=mma_warps,
|
| 483 |
+
tx_count=self.tma_copy_bytes["K"],
|
| 484 |
+
defer_sync=True,
|
| 485 |
+
)
|
| 486 |
+
pipeline_v = pipeline_custom.PipelineTmaAsync.create(
|
| 487 |
+
barrier_storage=storage.mbar_ptr_V.data_ptr(),
|
| 488 |
+
num_stages=self.num_stages,
|
| 489 |
+
producer_group=tma_warp,
|
| 490 |
+
consumer_group=mma_warps,
|
| 491 |
+
tx_count=self.tma_copy_bytes["V"],
|
| 492 |
+
defer_sync=True,
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
pipeline_k = pipeline_custom.PipelineCpAsync.create(
|
| 496 |
+
barrier_storage=storage.mbar_ptr_K.data_ptr(),
|
| 497 |
+
num_stages=self.num_stages,
|
| 498 |
+
producer_group=load_threads,
|
| 499 |
+
consumer_group=mma_warps,
|
| 500 |
+
defer_sync=True,
|
| 501 |
+
elect_one_release=True,
|
| 502 |
+
syncwarp_before_release=False,
|
| 503 |
+
)
|
| 504 |
+
pipeline_v = pipeline_custom.PipelineCpAsync.create(
|
| 505 |
+
barrier_storage=storage.mbar_ptr_V.data_ptr(),
|
| 506 |
+
num_stages=self.num_stages,
|
| 507 |
+
producer_group=load_threads,
|
| 508 |
+
consumer_group=mma_warps,
|
| 509 |
+
defer_sync=True,
|
| 510 |
+
elect_one_release=True,
|
| 511 |
+
syncwarp_before_release=False,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Cluster arrive after barrier init
|
| 515 |
+
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
|
| 516 |
+
|
| 517 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 518 |
+
# Get shared memory buffer
|
| 519 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 520 |
+
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
|
| 521 |
+
sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
|
| 522 |
+
if const_expr(not self.Q_in_regs):
|
| 523 |
+
sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
|
| 524 |
+
else:
|
| 525 |
+
sV = storage.sQ.get_tensor(
|
| 526 |
+
sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
|
| 527 |
+
)
|
| 528 |
+
# Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
|
| 529 |
+
sVt = layout_utils.transpose_view(sV)
|
| 530 |
+
sP = None
|
| 531 |
+
if const_expr(sP_layout is not None):
|
| 532 |
+
sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
|
| 533 |
+
# reuse sQ's data iterator
|
| 534 |
+
sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
|
| 535 |
+
|
| 536 |
+
block_info = BlockInfo(
|
| 537 |
+
self.tile_m,
|
| 538 |
+
self.tile_n,
|
| 539 |
+
self.is_causal,
|
| 540 |
+
self.is_local,
|
| 541 |
+
False, # is_split_kv
|
| 542 |
+
window_size_left,
|
| 543 |
+
window_size_right,
|
| 544 |
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 545 |
+
)
|
| 546 |
+
SeqlenInfoCls = partial(
|
| 547 |
+
SeqlenInfoQK.create,
|
| 548 |
+
seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
|
| 549 |
+
seqlen_k_static=mK.shape[0]
|
| 550 |
+
if const_expr(mPageTable is None)
|
| 551 |
+
else mK.shape[0] * mPageTable.shape[1],
|
| 552 |
+
mCuSeqlensQ=mCuSeqlensQ,
|
| 553 |
+
mCuSeqlensK=mCuSeqlensK,
|
| 554 |
+
mSeqUsedQ=mSeqUsedQ,
|
| 555 |
+
mSeqUsedK=mSeqUsedK,
|
| 556 |
+
# Don't need to pass in tile_mn because we won't access offset_padded
|
| 557 |
+
)
|
| 558 |
+
AttentionMaskCls = partial(
|
| 559 |
+
AttentionMask,
|
| 560 |
+
self.tile_m,
|
| 561 |
+
self.tile_n,
|
| 562 |
+
window_size_left=window_size_left,
|
| 563 |
+
window_size_right=window_size_right,
|
| 564 |
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 565 |
+
)
|
| 566 |
+
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
| 567 |
+
|
| 568 |
+
# Cluster wait before starting
|
| 569 |
+
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
|
| 570 |
+
|
| 571 |
+
if warp_idx < 4: # Producer
|
| 572 |
+
cute.arch.setmaxregister_decrease(self.num_producer_regs)
|
| 573 |
+
self.load(
|
| 574 |
+
mQ,
|
| 575 |
+
mK,
|
| 576 |
+
mV,
|
| 577 |
+
sQ,
|
| 578 |
+
sK,
|
| 579 |
+
sV,
|
| 580 |
+
tma_atom_Q,
|
| 581 |
+
tma_atom_K,
|
| 582 |
+
tma_atom_V,
|
| 583 |
+
pipeline_k,
|
| 584 |
+
pipeline_v,
|
| 585 |
+
pipeline_q,
|
| 586 |
+
gmem_tiled_copy_Q,
|
| 587 |
+
mPageTable,
|
| 588 |
+
blocksparse_tensors,
|
| 589 |
+
block_info,
|
| 590 |
+
SeqlenInfoCls,
|
| 591 |
+
TileSchedulerCls,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
else: # Consumer
|
| 595 |
+
cute.arch.setmaxregister_increase(self.num_mma_regs)
|
| 596 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 597 |
+
# Tile MMA compute thread partitions and allocate accumulators
|
| 598 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 599 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 600 |
+
tidx = tidx - 128
|
| 601 |
+
self.mma(
|
| 602 |
+
tiled_mma_qk,
|
| 603 |
+
tiled_mma_pv,
|
| 604 |
+
mO,
|
| 605 |
+
mLSE,
|
| 606 |
+
sQ,
|
| 607 |
+
sK,
|
| 608 |
+
sVt,
|
| 609 |
+
sP,
|
| 610 |
+
sO,
|
| 611 |
+
learnable_sink,
|
| 612 |
+
pipeline_k,
|
| 613 |
+
pipeline_v,
|
| 614 |
+
pipeline_q,
|
| 615 |
+
gmem_tiled_copy_O,
|
| 616 |
+
tma_atom_O,
|
| 617 |
+
tidx,
|
| 618 |
+
softmax_scale_log2,
|
| 619 |
+
softmax_scale,
|
| 620 |
+
block_info,
|
| 621 |
+
SeqlenInfoCls,
|
| 622 |
+
AttentionMaskCls,
|
| 623 |
+
TileSchedulerCls,
|
| 624 |
+
blocksparse_tensors,
|
| 625 |
+
aux_tensors,
|
| 626 |
+
fastdiv_mods,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
@cute.jit
|
| 630 |
+
def load(
|
| 631 |
+
self,
|
| 632 |
+
mQ: cute.Tensor,
|
| 633 |
+
mK: cute.Tensor,
|
| 634 |
+
mV: cute.Tensor,
|
| 635 |
+
sQ: cute.Tensor,
|
| 636 |
+
sK: cute.Tensor,
|
| 637 |
+
sV: cute.Tensor,
|
| 638 |
+
tma_atom_Q: Optional[cute.CopyAtom],
|
| 639 |
+
tma_atom_K: Optional[cute.CopyAtom],
|
| 640 |
+
tma_atom_V: Optional[cute.CopyAtom],
|
| 641 |
+
pipeline_k: pipeline.PipelineAsync,
|
| 642 |
+
pipeline_v: pipeline.PipelineAsync,
|
| 643 |
+
pipeline_q: pipeline.PipelineAsync,
|
| 644 |
+
gmem_tiled_copy_Q: cute.TiledCopy,
|
| 645 |
+
mPageTable: Optional[cute.Tensor],
|
| 646 |
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 647 |
+
block_info: BlockInfo,
|
| 648 |
+
SeqlenInfoCls: Callable,
|
| 649 |
+
TileSchedulerCls: Callable,
|
| 650 |
+
):
|
| 651 |
+
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
| 652 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 653 |
+
|
| 654 |
+
# TMA: only warp 0 loads. cp_async: all warps load.
|
| 655 |
+
# When not use_tma_Q, all 128 producer threads participate in Q loading.
|
| 656 |
+
is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV or not self.use_tma_Q)
|
| 657 |
+
# KV loading restricted to warp 0 for TMA, all warps for non-TMA KV
|
| 658 |
+
is_kv_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV)
|
| 659 |
+
|
| 660 |
+
if is_load_warp:
|
| 661 |
+
q_producer_phase = Int32(1)
|
| 662 |
+
kv_producer_state = pipeline.make_pipeline_state(
|
| 663 |
+
pipeline.PipelineUserType.Producer, self.num_stages
|
| 664 |
+
)
|
| 665 |
+
tile_scheduler = TileSchedulerCls()
|
| 666 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 667 |
+
while work_tile.is_valid_tile:
|
| 668 |
+
# if work_tile.is_valid_tile:
|
| 669 |
+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 670 |
+
seqlen = SeqlenInfoCls(batch_idx)
|
| 671 |
+
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
| 672 |
+
head_idx_kv = (
|
| 673 |
+
head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
load_Q = None
|
| 677 |
+
if const_expr(self.use_tma_Q):
|
| 678 |
+
gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
|
| 679 |
+
load_Q, _, _ = copy_utils.tma_get_copy_fn(
|
| 680 |
+
tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
paged_kv_manager = None
|
| 684 |
+
tma_load_K_fn = None
|
| 685 |
+
tma_load_V_fn = None
|
| 686 |
+
if const_expr(self.use_tma_KV):
|
| 687 |
+
# === TMA path (non-paged and paged with page_size == n_block_size) ===
|
| 688 |
+
if const_expr(mPageTable is not None):
|
| 689 |
+
# Paged TMA: keep page dimension indexable
|
| 690 |
+
mK_cur = mK[None, None, head_idx_kv, None]
|
| 691 |
+
mV_cur = mV[None, None, head_idx_kv, None]
|
| 692 |
+
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None))
|
| 693 |
+
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None))
|
| 694 |
+
else:
|
| 695 |
+
# Non-paged TMA
|
| 696 |
+
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[
|
| 697 |
+
None, None, head_idx_kv
|
| 698 |
+
]
|
| 699 |
+
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[
|
| 700 |
+
None, None, head_idx_kv
|
| 701 |
+
]
|
| 702 |
+
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
|
| 703 |
+
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
|
| 704 |
+
# TODO: mcast
|
| 705 |
+
tma_load_K_fn, _, _ = copy_utils.tma_get_copy_fn(
|
| 706 |
+
tma_atom_K, 0, cute.make_layout(1), gK, sK
|
| 707 |
+
)
|
| 708 |
+
tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k)
|
| 709 |
+
tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn(
|
| 710 |
+
tma_atom_V, 0, cute.make_layout(1), gV, sV
|
| 711 |
+
)
|
| 712 |
+
tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v)
|
| 713 |
+
else:
|
| 714 |
+
# === cp_async path (paged KV with page_size != n_block_size) ===
|
| 715 |
+
paged_kv_manager = PagedKVManager.create(
|
| 716 |
+
mPageTable,
|
| 717 |
+
mK,
|
| 718 |
+
mV,
|
| 719 |
+
FastDivmodDivisor(mK.shape[0]),
|
| 720 |
+
batch_idx,
|
| 721 |
+
head_idx_kv,
|
| 722 |
+
tidx,
|
| 723 |
+
seqlen.seqlen_k,
|
| 724 |
+
0, # leftpad_k
|
| 725 |
+
self.tile_n,
|
| 726 |
+
self.tile_hdim,
|
| 727 |
+
self.tile_hdimv,
|
| 728 |
+
self.num_threads_per_warp_group,
|
| 729 |
+
mK.element_type,
|
| 730 |
+
arch=self.arch.major * 10 + self.arch.minor,
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
load_K = partial(
|
| 734 |
+
self.load_KV,
|
| 735 |
+
tma_load_K_fn,
|
| 736 |
+
paged_kv_manager,
|
| 737 |
+
sK,
|
| 738 |
+
pipeline_kv=pipeline_k,
|
| 739 |
+
K_or_V="K",
|
| 740 |
+
)
|
| 741 |
+
load_V = partial(
|
| 742 |
+
self.load_KV,
|
| 743 |
+
tma_load_V_fn,
|
| 744 |
+
paged_kv_manager,
|
| 745 |
+
sV,
|
| 746 |
+
pipeline_kv=pipeline_v,
|
| 747 |
+
K_or_V="V",
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
pack_gqa = None
|
| 751 |
+
if const_expr(not self.use_tma_Q):
|
| 752 |
+
pack_gqa = PackGQA(
|
| 753 |
+
self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
if const_expr(not self.use_block_sparsity):
|
| 757 |
+
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 758 |
+
# if cute.arch.thread_idx()[0] == 0:
|
| 759 |
+
# cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
|
| 760 |
+
# Clamp n_block to 0 when n_block_max == 0 (can happen with causal
|
| 761 |
+
# + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1
|
| 762 |
+
# gracefully (fills zeros), but cp.async would crash on
|
| 763 |
+
# out-of-bounds page table access.
|
| 764 |
+
n_block = (
|
| 765 |
+
n_block_max - 1
|
| 766 |
+
if const_expr(self.use_tma_KV)
|
| 767 |
+
else cutlass.max(n_block_max - 1, 0)
|
| 768 |
+
)
|
| 769 |
+
page_idx = (
|
| 770 |
+
mPageTable[batch_idx, n_block]
|
| 771 |
+
if const_expr(mPageTable is not None and self.use_tma_KV)
|
| 772 |
+
else None
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
# First iteration: load K on pipeline_k, Q on pipeline_q
|
| 776 |
+
if is_kv_load_warp:
|
| 777 |
+
pipeline_k.producer_acquire(kv_producer_state)
|
| 778 |
+
if const_expr(not self.use_tma_KV):
|
| 779 |
+
paged_kv_manager.load_page_table(n_block)
|
| 780 |
+
load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)
|
| 781 |
+
if const_expr(self.use_tma_Q):
|
| 782 |
+
if warp_idx_in_wg == 0:
|
| 783 |
+
pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
|
| 784 |
+
load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
|
| 785 |
+
q_producer_phase ^= 1
|
| 786 |
+
else:
|
| 787 |
+
pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
|
| 788 |
+
pack_gqa.load_Q(
|
| 789 |
+
mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q
|
| 790 |
+
)
|
| 791 |
+
cute.arch.cp_async_commit_group()
|
| 792 |
+
pipeline_q.producer_commit_w_index(0)
|
| 793 |
+
q_producer_phase ^= 1
|
| 794 |
+
|
| 795 |
+
if is_kv_load_warp:
|
| 796 |
+
if const_expr(not self.intra_wg_overlap or not self.use_tma_KV):
|
| 797 |
+
pipeline_v.producer_acquire(kv_producer_state)
|
| 798 |
+
load_V(
|
| 799 |
+
block=n_block, producer_state=kv_producer_state, page_idx=page_idx
|
| 800 |
+
)
|
| 801 |
+
kv_producer_state.advance()
|
| 802 |
+
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 803 |
+
n_block = n_block_max - 1 - i - 1
|
| 804 |
+
page_idx = (
|
| 805 |
+
mPageTable[batch_idx, n_block]
|
| 806 |
+
if const_expr(mPageTable is not None and self.use_tma_KV)
|
| 807 |
+
else None
|
| 808 |
+
)
|
| 809 |
+
if const_expr(not self.use_tma_KV):
|
| 810 |
+
paged_kv_manager.load_page_table(n_block)
|
| 811 |
+
pipeline_k.producer_acquire(kv_producer_state)
|
| 812 |
+
load_K(
|
| 813 |
+
block=n_block,
|
| 814 |
+
producer_state=kv_producer_state,
|
| 815 |
+
page_idx=page_idx,
|
| 816 |
+
)
|
| 817 |
+
pipeline_v.producer_acquire(kv_producer_state)
|
| 818 |
+
load_V(
|
| 819 |
+
block=n_block,
|
| 820 |
+
producer_state=kv_producer_state,
|
| 821 |
+
page_idx=page_idx,
|
| 822 |
+
)
|
| 823 |
+
kv_producer_state.advance()
|
| 824 |
+
else:
|
| 825 |
+
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
| 826 |
+
n_block_prev = n_block_max - i - 1
|
| 827 |
+
n_block = n_block_prev - 1
|
| 828 |
+
page_idx = (
|
| 829 |
+
mPageTable[batch_idx, n_block]
|
| 830 |
+
if const_expr(mPageTable is not None)
|
| 831 |
+
else None
|
| 832 |
+
)
|
| 833 |
+
page_idx_prev = (
|
| 834 |
+
mPageTable[batch_idx, n_block_prev]
|
| 835 |
+
if const_expr(mPageTable is not None)
|
| 836 |
+
else None
|
| 837 |
+
)
|
| 838 |
+
kv_producer_state_prev = kv_producer_state.clone()
|
| 839 |
+
kv_producer_state.advance()
|
| 840 |
+
pipeline_k.producer_acquire(kv_producer_state)
|
| 841 |
+
load_K(
|
| 842 |
+
block=n_block,
|
| 843 |
+
producer_state=kv_producer_state,
|
| 844 |
+
page_idx=page_idx,
|
| 845 |
+
)
|
| 846 |
+
pipeline_v.producer_acquire(kv_producer_state_prev)
|
| 847 |
+
load_V(
|
| 848 |
+
block=n_block_prev,
|
| 849 |
+
producer_state=kv_producer_state_prev,
|
| 850 |
+
page_idx=page_idx_prev,
|
| 851 |
+
)
|
| 852 |
+
n_block = n_block_min
|
| 853 |
+
page_idx = (
|
| 854 |
+
mPageTable[batch_idx, n_block]
|
| 855 |
+
if const_expr(mPageTable is not None)
|
| 856 |
+
else None
|
| 857 |
+
)
|
| 858 |
+
pipeline_v.producer_acquire(kv_producer_state)
|
| 859 |
+
load_V(
|
| 860 |
+
block=n_block, producer_state=kv_producer_state, page_idx=page_idx
|
| 861 |
+
)
|
| 862 |
+
kv_producer_state.advance()
|
| 863 |
+
else:
|
| 864 |
+
# Block sparsity: use TMA closures directly (not paged)
|
| 865 |
+
# Load Q on pipeline_q, separate from K/V pipeline
|
| 866 |
+
if const_expr(self.use_tma_Q):
|
| 867 |
+
if warp_idx_in_wg == 0:
|
| 868 |
+
pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
|
| 869 |
+
load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
|
| 870 |
+
q_producer_phase ^= 1
|
| 871 |
+
else:
|
| 872 |
+
pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
|
| 873 |
+
pack_gqa.load_Q(
|
| 874 |
+
mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q
|
| 875 |
+
)
|
| 876 |
+
cute.arch.cp_async_commit_group()
|
| 877 |
+
pipeline_q.producer_commit_w_index(0)
|
| 878 |
+
q_producer_phase ^= 1
|
| 879 |
+
if is_kv_load_warp:
|
| 880 |
+
kv_producer_state = produce_block_sparse_loads(
|
| 881 |
+
blocksparse_tensors,
|
| 882 |
+
batch_idx,
|
| 883 |
+
head_idx,
|
| 884 |
+
m_block,
|
| 885 |
+
kv_producer_state,
|
| 886 |
+
tma_load_K_fn,
|
| 887 |
+
tma_load_V_fn,
|
| 888 |
+
pipeline_k,
|
| 889 |
+
pipeline_v,
|
| 890 |
+
self.intra_wg_overlap,
|
| 891 |
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 892 |
+
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
tile_scheduler.prefetch_next_work()
|
| 896 |
+
tile_scheduler.advance_to_next_work()
|
| 897 |
+
work_tile = tile_scheduler.get_current_work()
|
| 898 |
+
# End of persistent scheduler loop
|
| 899 |
+
|
| 900 |
+
# Producer tail is only useful for cluster to avoid early exit of blocks.
|
| 901 |
+
# We only need producer_tail on V since that's the last that's loaded, we don't
|
| 902 |
+
# need it for Q (no cluster) and K.
|
| 903 |
+
if is_kv_load_warp:
|
| 904 |
+
pipeline_v.producer_tail(kv_producer_state)
|
| 905 |
+
|
| 906 |
+
@cute.jit
|
| 907 |
+
def load_KV(
|
| 908 |
+
self,
|
| 909 |
+
tma_load_fn: Optional[Callable],
|
| 910 |
+
paged_kv_manager: Optional[PagedKVManager],
|
| 911 |
+
sX: cute.Tensor,
|
| 912 |
+
block: Int32,
|
| 913 |
+
pipeline_kv: pipeline.PipelineAsync,
|
| 914 |
+
producer_state: pipeline.PipelineState,
|
| 915 |
+
K_or_V: Literal["K", "V"],
|
| 916 |
+
page_idx: Optional[Int32] = None,
|
| 917 |
+
):
|
| 918 |
+
if const_expr(self.use_tma_KV):
|
| 919 |
+
src_idx = block if const_expr(page_idx is None) else page_idx
|
| 920 |
+
tma_load_fn(src_idx=src_idx, producer_state=producer_state)
|
| 921 |
+
else:
|
| 922 |
+
paged_kv_manager.load_KV(block, sX[None, None, producer_state.index], K_or_V)
|
| 923 |
+
cute.arch.cp_async_commit_group()
|
| 924 |
+
pipeline_kv.producer_commit(producer_state)
|
| 925 |
+
|
| 926 |
+
@cute.jit
|
| 927 |
+
def mma(
|
| 928 |
+
self,
|
| 929 |
+
tiled_mma_qk: cute.TiledMma,
|
| 930 |
+
tiled_mma_pv: cute.TiledMma,
|
| 931 |
+
mO: cute.Tensor,
|
| 932 |
+
mLSE: Optional[cute.Tensor],
|
| 933 |
+
sQ: cute.Tensor,
|
| 934 |
+
sK: cute.Tensor,
|
| 935 |
+
sVt: cute.Tensor,
|
| 936 |
+
sP: Optional[cute.Tensor],
|
| 937 |
+
sO: cute.Tensor,
|
| 938 |
+
learnable_sink: Optional[cute.Tensor],
|
| 939 |
+
pipeline_k: pipeline.PipelineAsync,
|
| 940 |
+
pipeline_v: pipeline.PipelineAsync,
|
| 941 |
+
pipeline_q: pipeline.PipelineAsync,
|
| 942 |
+
gmem_tiled_copy_O: cute.TiledCopy,
|
| 943 |
+
tma_atom_O: Optional[cute.CopyAtom],
|
| 944 |
+
tidx: Int32,
|
| 945 |
+
softmax_scale_log2: Float32,
|
| 946 |
+
softmax_scale: Optional[Float32],
|
| 947 |
+
block_info: BlockInfo,
|
| 948 |
+
SeqlenInfoCls: Callable,
|
| 949 |
+
AttentionMaskCls: Callable,
|
| 950 |
+
TileSchedulerCls: Callable,
|
| 951 |
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
| 952 |
+
aux_tensors: Optional[list],
|
| 953 |
+
fastdiv_mods=None,
|
| 954 |
+
):
|
| 955 |
+
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 956 |
+
warp_group_thread_layout = cute.make_layout(
|
| 957 |
+
self.num_wg_mma, stride=self.num_threads_per_warp_group
|
| 958 |
+
)
|
| 959 |
+
thr_mma_qk = tiled_mma_qk.get_slice(tidx)
|
| 960 |
+
wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 961 |
+
wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
|
| 962 |
+
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
|
| 963 |
+
wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
|
| 964 |
+
)
|
| 965 |
+
mma_qk_fn = partial(
|
| 966 |
+
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
|
| 967 |
+
)
|
| 968 |
+
acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
|
| 969 |
+
wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
|
| 970 |
+
)
|
| 971 |
+
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
|
| 972 |
+
|
| 973 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 974 |
+
# Smem copy atom tiling
|
| 975 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 976 |
+
smem_copy_atom_P = utils.get_smem_store_atom(
|
| 977 |
+
self.arch.major * 10 + self.arch.minor, self.dtype
|
| 978 |
+
)
|
| 979 |
+
smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
|
| 980 |
+
tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
|
| 981 |
+
smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
|
| 982 |
+
|
| 983 |
+
self.mma_init()
|
| 984 |
+
|
| 985 |
+
q_consumer_phase = Int32(0)
|
| 986 |
+
kv_consumer_state = pipeline.make_pipeline_state(
|
| 987 |
+
pipeline.PipelineUserType.Consumer, self.num_stages
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
tile_scheduler = TileSchedulerCls()
|
| 991 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 992 |
+
softmax = Softmax.create(
|
| 993 |
+
softmax_scale_log2,
|
| 994 |
+
num_rows=acc_O.shape[0][0] * acc_O.shape[1],
|
| 995 |
+
softmax_scale=softmax_scale,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
# For RescaleOBeforeGemm: persistent scores_scale across iterations
|
| 999 |
+
scores_scale = None
|
| 1000 |
+
if const_expr(self.rescale_O_before_gemm):
|
| 1001 |
+
scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32)
|
| 1002 |
+
|
| 1003 |
+
mma_one_n_block_all = partial(
|
| 1004 |
+
self.mma_one_n_block_intrawg_overlap
|
| 1005 |
+
if const_expr(self.intra_wg_overlap)
|
| 1006 |
+
else self.mma_one_n_block,
|
| 1007 |
+
mma_qk_fn=mma_qk_fn,
|
| 1008 |
+
pipeline_k=pipeline_k,
|
| 1009 |
+
pipeline_v=pipeline_v,
|
| 1010 |
+
acc_O=acc_O,
|
| 1011 |
+
tOrP=tOrP,
|
| 1012 |
+
smem_copy_params=smem_copy_params,
|
| 1013 |
+
check_inf=True,
|
| 1014 |
+
scores_scale=scores_scale,
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
process_first_half_block = partial(
|
| 1018 |
+
self.first_half_block_overlap,
|
| 1019 |
+
mma_qk_fn=mma_qk_fn,
|
| 1020 |
+
pipeline_k=pipeline_k,
|
| 1021 |
+
tOrP=tOrP,
|
| 1022 |
+
smem_copy_params=smem_copy_params,
|
| 1023 |
+
scores_scale=scores_scale,
|
| 1024 |
+
softmax=softmax,
|
| 1025 |
+
acc_O=acc_O,
|
| 1026 |
+
)
|
| 1027 |
+
process_last_half_block = partial(
|
| 1028 |
+
self.last_half_block_overlap,
|
| 1029 |
+
pipeline_v=pipeline_v,
|
| 1030 |
+
mma_pv_fn=mma_pv_fn,
|
| 1031 |
+
scores_scale=scores_scale,
|
| 1032 |
+
softmax=softmax,
|
| 1033 |
+
acc_O=acc_O,
|
| 1034 |
+
)
|
| 1035 |
+
while work_tile.is_valid_tile:
|
| 1036 |
+
# if work_tile.is_valid_tile:
|
| 1037 |
+
|
| 1038 |
+
# shape: (atom_v_m * rest_m)
|
| 1039 |
+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 1040 |
+
seqlen = SeqlenInfoCls(batch_idx)
|
| 1041 |
+
|
| 1042 |
+
# Recompute fastdiv_mods if necessary for varlen with aux_tensors
|
| 1043 |
+
recompute_fastdiv_mods_q = cutlass.const_expr(
|
| 1044 |
+
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
|
| 1045 |
+
)
|
| 1046 |
+
recompute_fastdiv_mods_k = cutlass.const_expr(
|
| 1047 |
+
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
|
| 1048 |
+
)
|
| 1049 |
+
if cutlass.const_expr(fastdiv_mods is not None):
|
| 1050 |
+
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
|
| 1051 |
+
fastdiv_mods = (
|
| 1052 |
+
seqlen_q_divmod
|
| 1053 |
+
if not recompute_fastdiv_mods_q
|
| 1054 |
+
else FastDivmodDivisor(seqlen.seqlen_q),
|
| 1055 |
+
seqlen_k_divmod
|
| 1056 |
+
if not recompute_fastdiv_mods_k
|
| 1057 |
+
else FastDivmodDivisor(seqlen.seqlen_k),
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
mask = AttentionMaskCls(seqlen)
|
| 1061 |
+
mask_fn = partial(
|
| 1062 |
+
mask.apply_mask,
|
| 1063 |
+
batch_idx=batch_idx,
|
| 1064 |
+
head_idx=head_idx,
|
| 1065 |
+
m_block=m_block,
|
| 1066 |
+
thr_mma=thr_mma_qk,
|
| 1067 |
+
mask_causal=self.is_causal,
|
| 1068 |
+
mask_local=self.is_local,
|
| 1069 |
+
aux_tensors=aux_tensors,
|
| 1070 |
+
fastdiv_mods=fastdiv_mods,
|
| 1071 |
+
)
|
| 1072 |
+
score_mod_fn = None
|
| 1073 |
+
if const_expr(self.score_mod is not None):
|
| 1074 |
+
score_mod_fn = partial(
|
| 1075 |
+
self.apply_score_mod,
|
| 1076 |
+
thr_mma_qk,
|
| 1077 |
+
batch_idx,
|
| 1078 |
+
head_idx,
|
| 1079 |
+
m_block,
|
| 1080 |
+
softmax_scale=softmax_scale,
|
| 1081 |
+
aux_tensors=aux_tensors,
|
| 1082 |
+
fastdiv_mods=fastdiv_mods,
|
| 1083 |
+
)
|
| 1084 |
+
mma_one_n_block = partial(
|
| 1085 |
+
mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn
|
| 1086 |
+
)
|
| 1087 |
+
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
| 1088 |
+
pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase)
|
| 1089 |
+
# For performance reason, we separate out two kinds of iterations:
|
| 1090 |
+
# those that need masking on S, and those that don't.
|
| 1091 |
+
# We need masking on S for the very last block when K and V has length not multiple of tile_n.
|
| 1092 |
+
# We also need masking on S if it's causal, for the last several blocks.
|
| 1093 |
+
# softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True
|
| 1094 |
+
O_should_accumulate = False
|
| 1095 |
+
|
| 1096 |
+
# ==========================================
|
| 1097 |
+
# MAINLOOP
|
| 1098 |
+
# ==========================================
|
| 1099 |
+
if const_expr(not self.use_block_sparsity):
|
| 1100 |
+
# ==========================================
|
| 1101 |
+
# No block-sparsity (original path)
|
| 1102 |
+
# ==========================================
|
| 1103 |
+
# First iteration with seqlen masking
|
| 1104 |
+
if const_expr(self.intra_wg_overlap):
|
| 1105 |
+
kv_consumer_state = process_first_half_block(
|
| 1106 |
+
n_block=n_block_max - 1,
|
| 1107 |
+
seqlen=seqlen,
|
| 1108 |
+
kv_consumer_state=kv_consumer_state,
|
| 1109 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
|
| 1110 |
+
score_mod_fn=score_mod_fn,
|
| 1111 |
+
is_first_block=True,
|
| 1112 |
+
)
|
| 1113 |
+
else:
|
| 1114 |
+
self.warp_scheduler_barrier_sync()
|
| 1115 |
+
kv_consumer_state = mma_one_n_block(
|
| 1116 |
+
kv_consumer_state,
|
| 1117 |
+
n_block=n_block_max - 1,
|
| 1118 |
+
seqlen=seqlen,
|
| 1119 |
+
mma_pv_fn=partial(mma_pv_fn, zero_init=True),
|
| 1120 |
+
is_first_n_block=True,
|
| 1121 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
|
| 1122 |
+
)
|
| 1123 |
+
O_should_accumulate = True
|
| 1124 |
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
|
| 1125 |
+
n_block_max -= 1
|
| 1126 |
+
# Next couple of iterations with causal masking
|
| 1127 |
+
if const_expr(self.is_causal or self.is_local):
|
| 1128 |
+
n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
|
| 1129 |
+
seqlen, m_block, n_block_min
|
| 1130 |
+
)
|
| 1131 |
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
|
| 1132 |
+
for n_tile in cutlass.range(
|
| 1133 |
+
n_block_max - n_block_min_causal_local_mask, unroll=1
|
| 1134 |
+
):
|
| 1135 |
+
kv_consumer_state = mma_one_n_block(
|
| 1136 |
+
kv_consumer_state,
|
| 1137 |
+
n_block=n_block_max - 1 - n_tile,
|
| 1138 |
+
seqlen=seqlen,
|
| 1139 |
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 1140 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 1141 |
+
)
|
| 1142 |
+
O_should_accumulate = True
|
| 1143 |
+
n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
|
| 1144 |
+
# The remaining iterations have no masking
|
| 1145 |
+
n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
|
| 1146 |
+
seqlen, m_block, n_block_min
|
| 1147 |
+
)
|
| 1148 |
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
|
| 1149 |
+
for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
|
| 1150 |
+
kv_consumer_state = mma_one_n_block(
|
| 1151 |
+
kv_consumer_state,
|
| 1152 |
+
n_block=n_block_max - 1 - n_tile,
|
| 1153 |
+
seqlen=seqlen,
|
| 1154 |
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 1155 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 1156 |
+
)
|
| 1157 |
+
O_should_accumulate = True
|
| 1158 |
+
# Separate iterations with local masking on the left
|
| 1159 |
+
if const_expr(self.is_local and block_info.window_size_left is not None):
|
| 1160 |
+
n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
|
| 1161 |
+
for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
|
| 1162 |
+
kv_consumer_state = mma_one_n_block(
|
| 1163 |
+
kv_consumer_state,
|
| 1164 |
+
n_block=n_block_max - 1 - n_tile,
|
| 1165 |
+
seqlen=seqlen,
|
| 1166 |
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
| 1167 |
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
| 1168 |
+
)
|
| 1169 |
+
O_should_accumulate = True
|
| 1170 |
+
# Release Q pipeline so the producer can load the next tile's Q
|
| 1171 |
+
pipeline_q.consumer_release_w_index(0)
|
| 1172 |
+
# Last "half" iteration
|
| 1173 |
+
if const_expr(self.intra_wg_overlap):
|
| 1174 |
+
kv_consumer_state = process_last_half_block(
|
| 1175 |
+
kv_consumer_state=kv_consumer_state,
|
| 1176 |
+
zero_init=not O_should_accumulate,
|
| 1177 |
+
)
|
| 1178 |
+
O_should_accumulate = True
|
| 1179 |
+
else:
|
| 1180 |
+
self.warp_scheduler_barrier_arrive()
|
| 1181 |
+
|
| 1182 |
+
else:
|
| 1183 |
+
# ==========================================
|
| 1184 |
+
# Block sparsity
|
| 1185 |
+
# ==========================================
|
| 1186 |
+
kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
|
| 1187 |
+
blocksparse_tensors,
|
| 1188 |
+
batch_idx,
|
| 1189 |
+
head_idx,
|
| 1190 |
+
m_block,
|
| 1191 |
+
seqlen,
|
| 1192 |
+
kv_consumer_state,
|
| 1193 |
+
mma_pv_fn,
|
| 1194 |
+
mma_one_n_block,
|
| 1195 |
+
process_first_half_block,
|
| 1196 |
+
process_last_half_block,
|
| 1197 |
+
mask_fn,
|
| 1198 |
+
score_mod_fn,
|
| 1199 |
+
O_should_accumulate,
|
| 1200 |
+
self.mask_mod,
|
| 1201 |
+
fastdiv_mods,
|
| 1202 |
+
self.intra_wg_overlap,
|
| 1203 |
+
self.warp_scheduler_barrier_sync,
|
| 1204 |
+
self.warp_scheduler_barrier_arrive,
|
| 1205 |
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1206 |
+
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
|
| 1207 |
+
)
|
| 1208 |
+
|
| 1209 |
+
# Release Q pipeline so the producer can load the next tile's Q
|
| 1210 |
+
pipeline_q.consumer_release_w_index(0)
|
| 1211 |
+
|
| 1212 |
+
# Handle empty case (when no blocks to process)
|
| 1213 |
+
if not processed_any:
|
| 1214 |
+
softmax.reset()
|
| 1215 |
+
acc_O.fill(0.0)
|
| 1216 |
+
|
| 1217 |
+
q_consumer_phase ^= 1
|
| 1218 |
+
|
| 1219 |
+
sink_val = None
|
| 1220 |
+
if const_expr(learnable_sink is not None):
|
| 1221 |
+
if const_expr(not self.pack_gqa):
|
| 1222 |
+
sink_val = Float32(learnable_sink[head_idx])
|
| 1223 |
+
else: # Each thread might have a different sink value due to different q_head
|
| 1224 |
+
sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32)
|
| 1225 |
+
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
| 1226 |
+
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
|
| 1227 |
+
for r in cutlass.range(cute.size(sink_val), unroll_full=True):
|
| 1228 |
+
row = m_block * self.tile_m + tScS_mn[r][0]
|
| 1229 |
+
q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
|
| 1230 |
+
sink_val[r] = Float32(learnable_sink[q_head_idx])
|
| 1231 |
+
|
| 1232 |
+
# normalize acc_O by row_sum and calculate the lse
|
| 1233 |
+
row_scale = softmax.finalize(sink_val=sink_val)
|
| 1234 |
+
softmax.rescale_O(acc_O, row_scale)
|
| 1235 |
+
|
| 1236 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 1237 |
+
# Epilogue
|
| 1238 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 1239 |
+
self.epilogue(
|
| 1240 |
+
acc_O,
|
| 1241 |
+
softmax.row_sum,
|
| 1242 |
+
mO,
|
| 1243 |
+
mLSE,
|
| 1244 |
+
sO,
|
| 1245 |
+
seqlen,
|
| 1246 |
+
gmem_tiled_copy_O,
|
| 1247 |
+
tma_atom_O,
|
| 1248 |
+
tiled_mma_pv,
|
| 1249 |
+
tidx,
|
| 1250 |
+
m_block,
|
| 1251 |
+
head_idx,
|
| 1252 |
+
batch_idx,
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
tile_scheduler.advance_to_next_work()
|
| 1256 |
+
work_tile = tile_scheduler.get_current_work()
|
| 1257 |
+
|
| 1258 |
+
@cute.jit
|
| 1259 |
+
def first_half_block_overlap(
|
| 1260 |
+
self,
|
| 1261 |
+
n_block: Int32,
|
| 1262 |
+
mma_qk_fn: Callable,
|
| 1263 |
+
kv_consumer_state,
|
| 1264 |
+
pipeline_k,
|
| 1265 |
+
tOrP: cute.Tensor,
|
| 1266 |
+
smem_copy_params: SimpleNamespace,
|
| 1267 |
+
softmax: Softmax,
|
| 1268 |
+
seqlen: SeqlenInfoQK,
|
| 1269 |
+
scores_scale: Optional[cute.Tensor] = None,
|
| 1270 |
+
acc_O: Optional[cute.Tensor] = None,
|
| 1271 |
+
mask_fn: Callable = None,
|
| 1272 |
+
score_mod_fn: Optional[Callable] = None,
|
| 1273 |
+
is_first_block: bool = False,
|
| 1274 |
+
):
|
| 1275 |
+
"""Processes the first half block when using intra-warpgroup-overlap"""
|
| 1276 |
+
|
| 1277 |
+
pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
|
| 1278 |
+
acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
|
| 1279 |
+
pipeline_k.consumer_release(kv_consumer_state)
|
| 1280 |
+
|
| 1281 |
+
# Apply score modification if present
|
| 1282 |
+
if const_expr(score_mod_fn is not None):
|
| 1283 |
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 1284 |
+
|
| 1285 |
+
# Apply mask; mask_seqlen always True for first block
|
| 1286 |
+
# Caveat: if full block further right than mask block, seqlen masking is redundant;
|
| 1287 |
+
# however, masking is being applied anyway, so essentially no perf hit
|
| 1288 |
+
mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
|
| 1289 |
+
|
| 1290 |
+
row_scale = softmax.online_softmax(acc_S, is_first=is_first_block)
|
| 1291 |
+
|
| 1292 |
+
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 1293 |
+
tOrP_cur = (
|
| 1294 |
+
tOrP
|
| 1295 |
+
if const_expr(self.mma_pv_is_rs)
|
| 1296 |
+
else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
|
| 1297 |
+
)
|
| 1298 |
+
tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
| 1299 |
+
|
| 1300 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 1301 |
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 1302 |
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 1303 |
+
# Fence and barrier to make smem store visible to WGMMA
|
| 1304 |
+
cute.arch.fence_view_async_shared()
|
| 1305 |
+
cute.arch.sync_warp()
|
| 1306 |
+
|
| 1307 |
+
# For RescaleOBeforeGemm: initialize acc_O
|
| 1308 |
+
if const_expr(self.rescale_O_before_gemm):
|
| 1309 |
+
acc_O.fill(0.0)
|
| 1310 |
+
scores_scale.store(row_scale.load())
|
| 1311 |
+
|
| 1312 |
+
return kv_consumer_state
|
| 1313 |
+
|
| 1314 |
+
@cute.jit
|
| 1315 |
+
def last_half_block_overlap(
|
| 1316 |
+
self,
|
| 1317 |
+
kv_consumer_state,
|
| 1318 |
+
pipeline_v,
|
| 1319 |
+
mma_pv_fn: Callable,
|
| 1320 |
+
zero_init: bool,
|
| 1321 |
+
scores_scale: Optional[cute.Tensor] = None,
|
| 1322 |
+
softmax: Optional[Softmax] = None,
|
| 1323 |
+
acc_O: Optional[cute.Tensor] = None,
|
| 1324 |
+
):
|
| 1325 |
+
"""Processes the final PV GEMM when using intra-warpgroup-overlap"""
|
| 1326 |
+
|
| 1327 |
+
# For RescaleOBeforeGemm: rescale O before the final PV GEMM
|
| 1328 |
+
if const_expr(self.rescale_O_before_gemm):
|
| 1329 |
+
softmax.rescale_O(acc_O, scores_scale)
|
| 1330 |
+
|
| 1331 |
+
pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
|
| 1332 |
+
mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
|
| 1333 |
+
pipeline_v.consumer_release(kv_consumer_state)
|
| 1334 |
+
kv_consumer_state.advance()
|
| 1335 |
+
return kv_consumer_state
|
| 1336 |
+
|
| 1337 |
+
@cute.jit
|
| 1338 |
+
def mma_one_n_block(
|
| 1339 |
+
self,
|
| 1340 |
+
smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,
|
| 1341 |
+
n_block: Int32,
|
| 1342 |
+
mma_qk_fn: Callable,
|
| 1343 |
+
mma_pv_fn: Callable,
|
| 1344 |
+
pipeline_k: pipeline.PipelineAsync,
|
| 1345 |
+
pipeline_v: pipeline.PipelineAsync,
|
| 1346 |
+
acc_O: cute.Tensor,
|
| 1347 |
+
tOrP: cute.Tensor,
|
| 1348 |
+
smem_copy_params: SimpleNamespace,
|
| 1349 |
+
softmax: Softmax,
|
| 1350 |
+
seqlen: SeqlenInfoQK,
|
| 1351 |
+
scores_scale: Optional[cute.Tensor] = None, # not used
|
| 1352 |
+
score_mod_fn: Optional[Callable] = None,
|
| 1353 |
+
mask_fn: Optional[Callable] = None,
|
| 1354 |
+
is_first_n_block: cutlass.Constexpr = False,
|
| 1355 |
+
check_inf: cutlass.Constexpr = True,
|
| 1356 |
+
):
|
| 1357 |
+
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
| 1358 |
+
# S = Q @ K.T
|
| 1359 |
+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
| 1360 |
+
self.warp_scheduler_barrier_arrive()
|
| 1361 |
+
warpgroup.wait_group(0)
|
| 1362 |
+
pipeline_k.consumer_release(smem_pipe_read)
|
| 1363 |
+
|
| 1364 |
+
# handle score mods and masking
|
| 1365 |
+
if const_expr(score_mod_fn is not None):
|
| 1366 |
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 1367 |
+
if const_expr(mask_fn is not None):
|
| 1368 |
+
mask_fn(acc_S=acc_S, n_block=n_block)
|
| 1369 |
+
|
| 1370 |
+
row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
|
| 1371 |
+
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
| 1372 |
+
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 1373 |
+
tOrP_cur = (
|
| 1374 |
+
tOrP
|
| 1375 |
+
if const_expr(self.mma_pv_is_rs)
|
| 1376 |
+
else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
|
| 1377 |
+
)
|
| 1378 |
+
# tOrP.store(tOrP_acc.load().to(self.dtype))
|
| 1379 |
+
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
| 1380 |
+
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
| 1381 |
+
# 2 elements. So we just call ptx directly.
|
| 1382 |
+
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
| 1383 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 1384 |
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 1385 |
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 1386 |
+
softmax.rescale_O(acc_O, row_scale)
|
| 1387 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 1388 |
+
# Fence and barrier to make sure smem store is visible to WGMMA
|
| 1389 |
+
cute.arch.fence_view_async_shared()
|
| 1390 |
+
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
| 1391 |
+
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
|
| 1392 |
+
self.warp_scheduler_barrier_sync()
|
| 1393 |
+
# O += P @ V
|
| 1394 |
+
mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
|
| 1395 |
+
pipeline_v.consumer_release(smem_pipe_read)
|
| 1396 |
+
smem_pipe_read.advance()
|
| 1397 |
+
return smem_pipe_read
|
| 1398 |
+
|
| 1399 |
+
@cute.jit
|
| 1400 |
+
def mma_one_n_block_intrawg_overlap(
|
| 1401 |
+
self,
|
| 1402 |
+
smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,
|
| 1403 |
+
n_block: Int32,
|
| 1404 |
+
mma_qk_fn: Callable,
|
| 1405 |
+
mma_pv_fn: Callable,
|
| 1406 |
+
pipeline_k: pipeline.PipelineAsync,
|
| 1407 |
+
pipeline_v: pipeline.PipelineAsync,
|
| 1408 |
+
acc_O: cute.Tensor,
|
| 1409 |
+
tOrP: cute.Tensor,
|
| 1410 |
+
smem_copy_params: SimpleNamespace,
|
| 1411 |
+
softmax: Softmax,
|
| 1412 |
+
seqlen: SeqlenInfoQK,
|
| 1413 |
+
scores_scale: Optional[cute.Tensor] = None,
|
| 1414 |
+
score_mod_fn: Optional[Callable] = None,
|
| 1415 |
+
mask_fn: Optional[Callable] = None,
|
| 1416 |
+
check_inf: cutlass.Constexpr = True,
|
| 1417 |
+
):
|
| 1418 |
+
smem_pipe_read_v = smem_pipe_read.clone()
|
| 1419 |
+
smem_pipe_read.advance()
|
| 1420 |
+
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
| 1421 |
+
self.warp_scheduler_barrier_sync()
|
| 1422 |
+
# S = Q @ K.T
|
| 1423 |
+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
| 1424 |
+
# RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM
|
| 1425 |
+
if const_expr(self.rescale_O_before_gemm):
|
| 1426 |
+
softmax.rescale_O(acc_O, scores_scale)
|
| 1427 |
+
pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
|
| 1428 |
+
# O += P @ V
|
| 1429 |
+
mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
|
| 1430 |
+
self.warp_scheduler_barrier_arrive()
|
| 1431 |
+
warpgroup.wait_group(1)
|
| 1432 |
+
pipeline_k.consumer_release(smem_pipe_read)
|
| 1433 |
+
|
| 1434 |
+
# handle score mods and masking
|
| 1435 |
+
if const_expr(score_mod_fn is not None):
|
| 1436 |
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
| 1437 |
+
if const_expr(mask_fn is not None):
|
| 1438 |
+
mask_fn(acc_S=acc_S, n_block=n_block)
|
| 1439 |
+
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
| 1440 |
+
|
| 1441 |
+
row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
|
| 1442 |
+
warpgroup.wait_group(0)
|
| 1443 |
+
pipeline_v.consumer_release(smem_pipe_read_v)
|
| 1444 |
+
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
| 1445 |
+
tOrP_cur = (
|
| 1446 |
+
tOrP
|
| 1447 |
+
if const_expr(self.mma_pv_is_rs)
|
| 1448 |
+
else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
|
| 1449 |
+
)
|
| 1450 |
+
# tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
| 1451 |
+
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
| 1452 |
+
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
| 1453 |
+
# 2 elements. So we just call ptx directly.
|
| 1454 |
+
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
| 1455 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 1456 |
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
| 1457 |
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
| 1458 |
+
if const_expr(not self.rescale_O_before_gemm):
|
| 1459 |
+
softmax.rescale_O(acc_O, row_scale)
|
| 1460 |
+
if const_expr(self.rescale_O_before_gemm):
|
| 1461 |
+
scores_scale.store(row_scale.load())
|
| 1462 |
+
if const_expr(not self.mma_pv_is_rs):
|
| 1463 |
+
# Fence and barrier to make sure smem store is visible to WGMMA
|
| 1464 |
+
cute.arch.fence_view_async_shared()
|
| 1465 |
+
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
| 1466 |
+
return smem_pipe_read
|
| 1467 |
+
|
| 1468 |
+
@cute.jit
|
| 1469 |
+
def mma_init(self):
|
| 1470 |
+
warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
| 1471 |
+
if const_expr(self.use_scheduler_barrier):
|
| 1472 |
+
if warp_group_idx == 1:
|
| 1473 |
+
cute.arch.barrier_arrive(
|
| 1474 |
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
|
| 1475 |
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 1476 |
+
)
|
| 1477 |
+
|
| 1478 |
+
@cute.jit
|
| 1479 |
+
def apply_score_mod(
|
| 1480 |
+
self,
|
| 1481 |
+
thr_mma_qk,
|
| 1482 |
+
batch_idx,
|
| 1483 |
+
head_idx,
|
| 1484 |
+
m_block,
|
| 1485 |
+
acc_S,
|
| 1486 |
+
n_block,
|
| 1487 |
+
softmax_scale,
|
| 1488 |
+
seqlen,
|
| 1489 |
+
aux_tensors: Optional[list] = None,
|
| 1490 |
+
fastdiv_mods=None,
|
| 1491 |
+
):
|
| 1492 |
+
# Prepare index tensor
|
| 1493 |
+
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
| 1494 |
+
cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
|
| 1495 |
+
tScS = thr_mma_qk.partition_C(cS)
|
| 1496 |
+
|
| 1497 |
+
apply_score_mod_inner(
|
| 1498 |
+
acc_S,
|
| 1499 |
+
tScS,
|
| 1500 |
+
self.score_mod,
|
| 1501 |
+
batch_idx,
|
| 1502 |
+
head_idx,
|
| 1503 |
+
softmax_scale,
|
| 1504 |
+
self.vec_size,
|
| 1505 |
+
self.qk_acc_dtype,
|
| 1506 |
+
aux_tensors,
|
| 1507 |
+
fastdiv_mods,
|
| 1508 |
+
seqlen_info=seqlen,
|
| 1509 |
+
constant_q_idx=None,
|
| 1510 |
+
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
| 1511 |
+
)
|
| 1512 |
+
|
| 1513 |
+
def warp_scheduler_barrier_sync(self):
|
| 1514 |
+
if const_expr(self.use_scheduler_barrier):
|
| 1515 |
+
cute.arch.barrier(
|
| 1516 |
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
|
| 1517 |
+
- 1
|
| 1518 |
+
+ utils.canonical_warp_group_idx(sync=False),
|
| 1519 |
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 1520 |
+
)
|
| 1521 |
+
|
| 1522 |
+
def warp_scheduler_barrier_arrive(self):
|
| 1523 |
+
if const_expr(self.use_scheduler_barrier):
|
| 1524 |
+
assert self.num_wg_mma in [2, 3]
|
| 1525 |
+
cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
|
| 1526 |
+
if const_expr(self.num_wg_mma == 2):
|
| 1527 |
+
next_wg = 1 - cur_wg
|
| 1528 |
+
else:
|
| 1529 |
+
t = cur_wg + 1
|
| 1530 |
+
next_wg = t % self.num_wg_mma
|
| 1531 |
+
cute.arch.barrier_arrive(
|
| 1532 |
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
|
| 1533 |
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
| 1534 |
+
)
|
build/torch-cuda/interface.py
CHANGED
|
@@ -21,6 +21,7 @@
|
|
| 21 |
|
| 22 |
import os
|
| 23 |
import math
|
|
|
|
| 24 |
from functools import lru_cache
|
| 25 |
from typing import Optional, Tuple, Callable
|
| 26 |
|
|
@@ -31,6 +32,8 @@ import cuda.bindings.driver as cuda
|
|
| 31 |
|
| 32 |
import cutlass
|
| 33 |
import cutlass.cute as cute
|
|
|
|
|
|
|
| 34 |
from .cache_utils import get_jit_cache
|
| 35 |
from .testing import is_fake_mode
|
| 36 |
|
|
@@ -43,30 +46,201 @@ if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
|
|
| 43 |
|
| 44 |
|
| 45 |
from . import utils
|
|
|
|
| 46 |
from .cute_dsl_utils import (
|
| 47 |
to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
|
| 48 |
)
|
| 49 |
-
from .flash_fwd import
|
|
|
|
| 50 |
from .flash_fwd_sm100 import FlashAttentionForwardSm100
|
|
|
|
| 51 |
from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
|
| 52 |
from .flash_bwd import FlashAttentionBackwardSm80
|
| 53 |
from .flash_bwd_sm90 import FlashAttentionBackwardSm90
|
| 54 |
from .flash_bwd_sm100 import FlashAttentionBackwardSm100
|
|
|
|
| 55 |
from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
|
| 56 |
from .flash_fwd_combine import FlashAttentionForwardCombine
|
| 57 |
|
| 58 |
from .block_sparsity import (
|
| 59 |
BlockSparseTensorsTorch,
|
|
|
|
| 60 |
to_cute_block_sparse_tensors,
|
| 61 |
normalize_block_sparse_config,
|
| 62 |
normalize_block_sparse_config_bwd,
|
| 63 |
)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
@lru_cache(maxsize=None)
|
| 66 |
def _get_device_arch():
|
| 67 |
-
"""Cached device arch check.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
major, minor = torch.cuda.get_device_capability()
|
| 69 |
-
return major * 10 + minor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
def maybe_contiguous(x):
|
| 72 |
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
|
@@ -76,7 +250,8 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
|
|
| 76 |
assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
|
| 77 |
assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
|
| 78 |
assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
torch2cute_dtype_map = {
|
|
@@ -96,6 +271,29 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
|
|
| 96 |
return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def _flash_attn_fwd(
|
| 100 |
q: torch.Tensor,
|
| 101 |
k: torch.Tensor,
|
|
@@ -113,11 +311,9 @@ def _flash_attn_fwd(
|
|
| 113 |
window_size_left: Optional[int] = None,
|
| 114 |
window_size_right: Optional[int] = None,
|
| 115 |
learnable_sink: Optional[torch.Tensor] = None,
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
m_block_size: int = 128,
|
| 120 |
-
n_block_size: int = 128,
|
| 121 |
num_threads: int = 384,
|
| 122 |
num_splits: int = 1,
|
| 123 |
pack_gqa: Optional[bool] = None,
|
|
@@ -138,7 +334,7 @@ def _flash_attn_fwd(
|
|
| 138 |
mask_mod: A callable that takes token position information and selectively masks
|
| 139 |
block_sparse_tensors: A tuple of tensors used for block sparsity.
|
| 140 |
return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
|
| 141 |
-
|
| 142 |
out: Optional pre-allocated output tensor. If None, will be allocated internally.
|
| 143 |
lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
|
| 144 |
aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
|
|
@@ -203,25 +399,27 @@ def _flash_attn_fwd(
|
|
| 203 |
assert learnable_sink.shape == (num_head,)
|
| 204 |
assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
| 220 |
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
|
| 221 |
-
assert head_dim <= 256, "head_dim must be less than or equal to 256"
|
| 222 |
alignment = 16 // q.element_size()
|
| 223 |
-
|
| 224 |
-
|
| 225 |
if softmax_scale is None:
|
| 226 |
softmax_scale = 1.0 / math.sqrt(head_dim)
|
| 227 |
if softcap == 0.0:
|
|
@@ -253,43 +451,47 @@ def _flash_attn_fwd(
|
|
| 253 |
_validate_tensor(lse, "lse", lse_shape, torch.float32, device)
|
| 254 |
|
| 255 |
dtype = torch2cute_dtype_map[q.dtype]
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
-
|
|
|
|
| 259 |
|
| 260 |
-
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
| 273 |
else:
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
else:
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
if arch // 10 in [10, 11]:
|
| 285 |
-
if (
|
| 286 |
-
pack_gqa
|
| 287 |
-
and (128 % qhead_per_kvhead != 0)
|
| 288 |
-
):
|
| 289 |
-
pack_gqa = False
|
| 290 |
-
# TODO: fix GQA + SplitKV + non-varlen
|
| 291 |
-
if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
|
| 292 |
-
pack_gqa = False
|
| 293 |
|
| 294 |
if max_seqlen_q is None:
|
| 295 |
max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
|
|
@@ -297,28 +499,50 @@ def _flash_attn_fwd(
|
|
| 297 |
max_seqlen_k = seqlen_k
|
| 298 |
seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
|
| 299 |
if arch // 10 == 10:
|
| 300 |
-
q_stage = 2 if seqlen_q_packgqa >
|
| 301 |
else:
|
| 302 |
q_stage = 1
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
if num_splits < 1:
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
num_n_blocks,
|
| 314 |
-
|
| 315 |
-
|
| 316 |
|
| 317 |
is_split_kv = num_splits > 1
|
| 318 |
if is_split_kv:
|
| 319 |
out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
|
| 320 |
lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
# hash score and mask mods for compile cache
|
| 323 |
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
|
| 324 |
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
|
|
@@ -370,14 +594,14 @@ def _flash_attn_fwd(
|
|
| 370 |
num_head=num_head,
|
| 371 |
seqlen_q=seqlen_q,
|
| 372 |
seqlen_k=seqlen_k,
|
| 373 |
-
block_size=(
|
| 374 |
q_stage=q_stage,
|
| 375 |
)
|
| 376 |
-
if aux_tensors is not None:
|
| 377 |
aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
|
| 378 |
else:
|
| 379 |
aux_tensor_metadata = None
|
| 380 |
-
|
| 381 |
compile_key = (
|
| 382 |
dtype,
|
| 383 |
head_dim,
|
|
@@ -398,15 +622,20 @@ def _flash_attn_fwd(
|
|
| 398 |
window_size_left is not None,
|
| 399 |
window_size_right is not None,
|
| 400 |
learnable_sink is not None,
|
| 401 |
-
|
| 402 |
-
|
| 403 |
q_stage,
|
| 404 |
num_threads,
|
| 405 |
is_split_kv,
|
| 406 |
pack_gqa,
|
| 407 |
arch,
|
| 408 |
-
page_size not in [None,
|
|
|
|
| 409 |
q_subtile_factor,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
)
|
| 411 |
if compile_key not in _flash_attn_fwd.compile_cache:
|
| 412 |
(
|
|
@@ -445,10 +674,28 @@ def _flash_attn_fwd(
|
|
| 445 |
if aux_tensors is not None:
|
| 446 |
cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]
|
| 447 |
|
| 448 |
-
if arch // 10 ==
|
| 449 |
-
assert page_table is None, "paged KV not supported on SM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
assert not is_split_kv, "SplitKV not supported on SM 9.0"
|
| 451 |
-
# fa_fwd = FlashAttentionForwardSm80(
|
| 452 |
fa_fwd = FlashAttentionForwardSm90(
|
| 453 |
dtype,
|
| 454 |
head_dim,
|
|
@@ -457,33 +704,21 @@ def _flash_attn_fwd(
|
|
| 457 |
is_causal=causal,
|
| 458 |
is_local=local,
|
| 459 |
pack_gqa=pack_gqa,
|
| 460 |
-
tile_m=
|
| 461 |
-
tile_n=
|
| 462 |
# num_stages=1,
|
| 463 |
num_stages=2,
|
| 464 |
num_threads=num_threads,
|
| 465 |
Q_in_regs=False,
|
| 466 |
-
intra_wg_overlap=
|
| 467 |
-
mma_pv_is_rs=
|
| 468 |
mask_mod=mask_mod,
|
| 469 |
score_mod=score_mod,
|
| 470 |
has_aux_tensors=aux_tensors is not None,
|
| 471 |
q_subtile_factor=q_subtile_factor,
|
|
|
|
| 472 |
)
|
| 473 |
elif arch // 10 in [10, 11]:
|
| 474 |
-
head_dim_padded = int(math.ceil(head_dim / 16) * 16)
|
| 475 |
-
head_dim_v_padded = int(math.ceil(head_dim / 16) * 16)
|
| 476 |
-
use_2cta_instrs = (
|
| 477 |
-
not causal
|
| 478 |
-
and not local
|
| 479 |
-
and not is_split_kv
|
| 480 |
-
and cu_seqlens_q is None
|
| 481 |
-
and seqused_q is None
|
| 482 |
-
and not use_block_sparsity
|
| 483 |
-
and page_size in [None, 128]
|
| 484 |
-
and head_dim_padded == 128
|
| 485 |
-
and head_dim_v_padded == 128
|
| 486 |
-
)
|
| 487 |
fa_fwd = FlashAttentionForwardSm100(
|
| 488 |
head_dim,
|
| 489 |
head_dim_v,
|
|
@@ -492,8 +727,8 @@ def _flash_attn_fwd(
|
|
| 492 |
is_local=local,
|
| 493 |
is_split_kv=is_split_kv,
|
| 494 |
pack_gqa=pack_gqa,
|
| 495 |
-
m_block_size=
|
| 496 |
-
n_block_size=
|
| 497 |
q_stage=q_stage,
|
| 498 |
is_persistent=not causal
|
| 499 |
and not local
|
|
@@ -503,14 +738,37 @@ def _flash_attn_fwd(
|
|
| 503 |
score_mod=score_mod,
|
| 504 |
mask_mod=mask_mod,
|
| 505 |
has_aux_tensors=aux_tensors is not None,
|
| 506 |
-
paged_kv_non_tma=page_size not in [None,
|
| 507 |
is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
|
| 508 |
q_subtile_factor=q_subtile_factor,
|
| 509 |
use_2cta_instrs=use_2cta_instrs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
)
|
| 511 |
else:
|
| 512 |
raise ValueError(
|
| 513 |
-
f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x"
|
| 514 |
)
|
| 515 |
# TODO: check @can_implement
|
| 516 |
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(
|
|
@@ -521,7 +779,6 @@ def _flash_attn_fwd(
|
|
| 521 |
o_tensor,
|
| 522 |
lse_tensor,
|
| 523 |
softmax_scale,
|
| 524 |
-
current_stream,
|
| 525 |
cu_seqlens_q_tensor,
|
| 526 |
cu_seqlens_k_tensor,
|
| 527 |
seqused_q_tensor,
|
|
@@ -532,6 +789,7 @@ def _flash_attn_fwd(
|
|
| 532 |
learnable_sink_tensor,
|
| 533 |
sparse_tensors,
|
| 534 |
cute_aux_tensors,
|
|
|
|
| 535 |
options="--enable-tvm-ffi",
|
| 536 |
)
|
| 537 |
|
|
@@ -547,7 +805,6 @@ def _flash_attn_fwd(
|
|
| 547 |
out.detach() if not is_split_kv else out_partial,
|
| 548 |
lse_partial if is_split_kv else lse,
|
| 549 |
softmax_scale,
|
| 550 |
-
current_stream,
|
| 551 |
cu_seqlens_q,
|
| 552 |
cu_seqlens_k,
|
| 553 |
seqused_q,
|
|
@@ -574,6 +831,140 @@ def _flash_attn_fwd(
|
|
| 574 |
_flash_attn_fwd.compile_cache = get_jit_cache("fwd")
|
| 575 |
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
def _flash_attn_bwd(
|
| 578 |
q: torch.Tensor,
|
| 579 |
k: torch.Tensor,
|
|
@@ -614,47 +1005,74 @@ def _flash_attn_bwd(
|
|
| 614 |
mask_mod: Optional[Callable] = None,
|
| 615 |
aux_tensors: Optional[list[torch.Tensor]] = None,
|
| 616 |
block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
|
|
|
|
| 617 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 618 |
arch = _get_device_arch()
|
| 619 |
-
assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
|
|
|
|
|
|
|
|
|
|
| 620 |
|
| 621 |
num_head, head_dim = q.shape[-2:]
|
|
|
|
| 622 |
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
window_size_left = None
|
| 627 |
-
window_size_right = None
|
| 628 |
-
local = window_size_left is not None or window_size_right is not None
|
| 629 |
-
if local:
|
| 630 |
-
if window_size_left is None and window_size_right == 0:
|
| 631 |
-
causal, local = True, False
|
| 632 |
-
window_size_right = None
|
| 633 |
-
else:
|
| 634 |
-
causal, local = False, True
|
| 635 |
|
| 636 |
-
if arch // 10 ==
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
dKV_swapAB = False
|
| 644 |
-
dQ_swapAB =
|
| 645 |
-
AtomLayoutMSdP =
|
| 646 |
-
AtomLayoutNdKV =
|
| 647 |
-
AtomLayoutMdQ =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
cluster_size = 1
|
| 649 |
use_2cta_instrs = False
|
| 650 |
-
assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
|
| 651 |
is_varlen = (
|
| 652 |
cu_seqlens_q is not None
|
| 653 |
or cu_seqlens_k is not None
|
| 654 |
or seqused_q is not None
|
| 655 |
or seqused_k is not None
|
| 656 |
)
|
| 657 |
-
assert not is_varlen, "varlen backward is not yet supported on sm90"
|
| 658 |
else:
|
| 659 |
m_block_size = 128
|
| 660 |
n_block_size = 128
|
|
@@ -662,15 +1080,17 @@ def _flash_attn_bwd(
|
|
| 662 |
dKV_swapAB = False
|
| 663 |
AtomLayoutMdQ = 1
|
| 664 |
AtomLayoutNdKV = 1
|
|
|
|
| 665 |
disable_2cta = (
|
| 666 |
-
|
| 667 |
or score_mod is not None
|
| 668 |
or score_mod_bwd is not None
|
| 669 |
or mask_mod is not None
|
|
|
|
| 670 |
)
|
| 671 |
cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
|
| 672 |
use_2cta_instrs = cluster_size==2
|
| 673 |
-
|
| 674 |
q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
|
| 675 |
maybe_contiguous(t)
|
| 676 |
for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
|
@@ -692,19 +1112,9 @@ def _flash_attn_bwd(
|
|
| 692 |
seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
|
| 693 |
|
| 694 |
num_head_kv = k.shape[-2]
|
| 695 |
-
head_dim_v = v.shape[-1]
|
| 696 |
|
| 697 |
use_block_sparsity = block_sparse_tensors is not None
|
| 698 |
-
|
| 699 |
-
# SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,
|
| 700 |
-
# the base block_m of 128 from forward, and block-sparse size for subtiling.
|
| 701 |
-
if arch // 10 == 9 and use_block_sparsity:
|
| 702 |
-
m_block_size = 64
|
| 703 |
-
# dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
|
| 704 |
-
dQ_swapAB = False
|
| 705 |
-
|
| 706 |
-
# NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
|
| 707 |
-
subtile_factor = 2
|
| 708 |
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
|
| 709 |
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
|
| 710 |
num_n_blocks = seqlen_k_rounded // n_block_size
|
|
@@ -744,14 +1154,16 @@ def _flash_attn_bwd(
|
|
| 744 |
if t is not None:
|
| 745 |
assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
|
| 746 |
assert lse.dtype == torch.float32, "lse must be float32"
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
|
|
|
|
|
|
|
|
|
| 750 |
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
|
| 751 |
-
assert head_dim <= 256, "head_dim must be less than or equal to 256"
|
| 752 |
alignment = 16 // q.element_size()
|
| 753 |
-
|
| 754 |
-
|
| 755 |
if softmax_scale is None:
|
| 756 |
softmax_scale = 1.0 / math.sqrt(head_dim)
|
| 757 |
qhead_per_kvhead = num_head // num_head_kv
|
|
@@ -759,9 +1171,6 @@ def _flash_attn_bwd(
|
|
| 759 |
pack_gqa = qhead_per_kvhead > 1
|
| 760 |
# pack_gqa backward not yet supported in bwd
|
| 761 |
pack_gqa = False
|
| 762 |
-
if arch // 10 not in [10, 11]:
|
| 763 |
-
assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"
|
| 764 |
-
|
| 765 |
if score_mod is not None:
|
| 766 |
assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
|
| 767 |
assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
|
|
@@ -813,6 +1222,9 @@ def _flash_attn_bwd(
|
|
| 813 |
dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
| 814 |
lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
| 815 |
|
|
|
|
|
|
|
|
|
|
| 816 |
dKV_postprocess = qhead_per_kvhead > 1
|
| 817 |
if dKV_postprocess:
|
| 818 |
head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
|
|
@@ -850,83 +1262,30 @@ def _flash_attn_bwd(
|
|
| 850 |
)
|
| 851 |
|
| 852 |
dtype = torch2cute_dtype_map[q.dtype]
|
| 853 |
-
current_stream =
|
| 854 |
|
| 855 |
if deterministic:
|
| 856 |
-
dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=
|
| 857 |
else:
|
| 858 |
dQ_semaphore = None
|
| 859 |
|
| 860 |
if deterministic and qhead_per_kvhead > 1:
|
| 861 |
-
dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=
|
| 862 |
-
dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=
|
| 863 |
else:
|
| 864 |
dK_semaphore = None
|
| 865 |
dV_semaphore = None
|
| 866 |
|
| 867 |
-
# Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
head_dim,
|
| 872 |
-
head_dim_v,
|
| 873 |
-
m_block_size,
|
| 874 |
-
num_threads,
|
| 875 |
-
cu_seqlens_q is None,
|
| 876 |
-
seqused_q is None,
|
| 877 |
-
get_broadcast_dims(out),
|
| 878 |
-
get_broadcast_dims(dout),
|
| 879 |
)
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
]
|
| 885 |
-
lse_tensor = to_cute_tensor(lse, assumed_align=4)
|
| 886 |
-
cu_seqlens_q_tensor, seqused_q_tensor = [
|
| 887 |
-
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
| 888 |
-
for t in (cu_seqlens_q, seqused_q)
|
| 889 |
-
]
|
| 890 |
-
fa_bwd_pre = FlashAttentionBackwardPreprocess(
|
| 891 |
-
dtype,
|
| 892 |
-
head_dim,
|
| 893 |
-
head_dim_v,
|
| 894 |
-
arch,
|
| 895 |
-
m_block_size,
|
| 896 |
-
num_threads=num_threads,
|
| 897 |
-
)
|
| 898 |
-
# TODO: check @can_implement
|
| 899 |
-
_flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
|
| 900 |
-
fa_bwd_pre,
|
| 901 |
-
o_tensor,
|
| 902 |
-
do_tensor,
|
| 903 |
-
dpsum_tensor,
|
| 904 |
-
lse_tensor,
|
| 905 |
-
lse_log2_tensor,
|
| 906 |
-
dq_accum_tensor,
|
| 907 |
-
cu_seqlens_q_tensor,
|
| 908 |
-
seqused_q_tensor,
|
| 909 |
-
current_stream,
|
| 910 |
-
options="--enable-tvm-ffi",
|
| 911 |
-
)
|
| 912 |
-
if not is_fake_mode():
|
| 913 |
-
_flash_attn_bwd.compile_cache_pre[compile_key_pre](
|
| 914 |
-
out,
|
| 915 |
-
dout,
|
| 916 |
-
dpsum,
|
| 917 |
-
lse,
|
| 918 |
-
lse_log2,
|
| 919 |
-
dq_accum,
|
| 920 |
-
cu_seqlens_q,
|
| 921 |
-
seqused_q,
|
| 922 |
-
current_stream,
|
| 923 |
-
)
|
| 924 |
-
|
| 925 |
-
# NB num_threads application for 3 kernels
|
| 926 |
-
# There are pre, main, post processing kernels, currenlty num_threads is only actually
|
| 927 |
-
# used for the pre proc, and then we hard code to 384 for the main and post proc, and we do
|
| 928 |
-
# before cache key gen
|
| 929 |
-
num_threads = 384
|
| 930 |
|
| 931 |
# Backward kernel: compute dk, dv, dq_accum.
|
| 932 |
score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
|
|
@@ -953,7 +1312,7 @@ def _flash_attn_bwd(
|
|
| 953 |
subtile_factor=subtile_factor,
|
| 954 |
)
|
| 955 |
|
| 956 |
-
if arch // 10
|
| 957 |
compile_key = (
|
| 958 |
arch,
|
| 959 |
dtype,
|
|
@@ -961,6 +1320,8 @@ def _flash_attn_bwd(
|
|
| 961 |
head_dim_v,
|
| 962 |
qhead_per_kvhead,
|
| 963 |
causal,
|
|
|
|
|
|
|
| 964 |
softcap != 0.0,
|
| 965 |
m_block_size,
|
| 966 |
n_block_size,
|
|
@@ -975,6 +1336,8 @@ def _flash_attn_bwd(
|
|
| 975 |
AtomLayoutNdKV,
|
| 976 |
AtomLayoutMdQ,
|
| 977 |
V_in_regs,
|
|
|
|
|
|
|
| 978 |
cu_seqlens_q is None,
|
| 979 |
cu_seqlens_k is None,
|
| 980 |
seqused_q is None,
|
|
@@ -1043,51 +1406,56 @@ def _flash_attn_bwd(
|
|
| 1043 |
if t is not None else None
|
| 1044 |
for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
|
| 1045 |
]
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
head_dim_v,
|
| 1050 |
-
qhead_per_kvhead,
|
| 1051 |
-
m_block_size,
|
| 1052 |
-
n_block_size,
|
| 1053 |
-
num_stages_Q,
|
| 1054 |
-
num_stages_dO,
|
| 1055 |
-
num_threads,
|
| 1056 |
-
pack_gqa,
|
| 1057 |
-
causal,
|
| 1058 |
-
SdP_swapAB,
|
| 1059 |
-
dKV_swapAB,
|
| 1060 |
-
dQ_swapAB,
|
| 1061 |
-
AtomLayoutMSdP,
|
| 1062 |
-
AtomLayoutNdKV,
|
| 1063 |
-
AtomLayoutMdQ,
|
| 1064 |
-
V_in_regs=V_in_regs,
|
| 1065 |
-
)
|
| 1066 |
-
if arch // 10 == 9:
|
| 1067 |
-
fa_bwd_obj = FlashAttentionBackwardSm90(
|
| 1068 |
dtype,
|
| 1069 |
head_dim,
|
| 1070 |
head_dim_v,
|
| 1071 |
qhead_per_kvhead,
|
| 1072 |
-
causal,
|
| 1073 |
m_block_size,
|
| 1074 |
n_block_size,
|
| 1075 |
num_stages_Q,
|
| 1076 |
num_stages_dO,
|
| 1077 |
-
|
|
|
|
|
|
|
| 1078 |
SdP_swapAB,
|
| 1079 |
dKV_swapAB,
|
| 1080 |
dQ_swapAB,
|
| 1081 |
AtomLayoutMSdP,
|
| 1082 |
AtomLayoutNdKV,
|
| 1083 |
AtomLayoutMdQ,
|
| 1084 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1085 |
V_in_regs=V_in_regs,
|
| 1086 |
score_mod=score_mod,
|
| 1087 |
score_mod_bwd=score_mod_bwd,
|
| 1088 |
mask_mod=mask_mod,
|
| 1089 |
has_aux_tensors=aux_tensors is not None,
|
| 1090 |
subtile_factor=subtile_factor,
|
|
|
|
| 1091 |
)
|
| 1092 |
else:
|
| 1093 |
fa_bwd_obj = FlashAttentionBackwardSm100(
|
|
@@ -1126,7 +1494,6 @@ def _flash_attn_bwd(
|
|
| 1126 |
dk_tensor if not dKV_postprocess else dk_accum_tensor,
|
| 1127 |
dv_tensor if not dKV_postprocess else dv_accum_tensor,
|
| 1128 |
softmax_scale,
|
| 1129 |
-
current_stream,
|
| 1130 |
cu_seqlens_q_tensor,
|
| 1131 |
cu_seqlens_k_tensor,
|
| 1132 |
seqused_q_tensor,
|
|
@@ -1139,6 +1506,7 @@ def _flash_attn_bwd(
|
|
| 1139 |
dV_semaphore_tensor,
|
| 1140 |
cute_aux_tensors,
|
| 1141 |
sparse_tensors_compile,
|
|
|
|
| 1142 |
options="--enable-tvm-ffi",
|
| 1143 |
)
|
| 1144 |
if not is_fake_mode():
|
|
@@ -1153,7 +1521,6 @@ def _flash_attn_bwd(
|
|
| 1153 |
dk if not dKV_postprocess else dk_accum,
|
| 1154 |
dv if not dKV_postprocess else dv_accum,
|
| 1155 |
softmax_scale,
|
| 1156 |
-
current_stream,
|
| 1157 |
cu_seqlens_q,
|
| 1158 |
cu_seqlens_k,
|
| 1159 |
seqused_q,
|
|
@@ -1168,157 +1535,45 @@ def _flash_attn_bwd(
|
|
| 1168 |
normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
|
| 1169 |
)
|
| 1170 |
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
|
| 1186 |
-
get_broadcast_dims(dq),
|
| 1187 |
)
|
| 1188 |
-
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
| 1189 |
-
dq_accum_tensor = to_cute_tensor(dq_accum)
|
| 1190 |
-
dq_tensor = to_cute_tensor(dq)
|
| 1191 |
-
cu_seqlens_q_tensor, seqused_q_tensor = [
|
| 1192 |
-
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
| 1193 |
-
for t in (cu_seqlens_q, seqused_q)
|
| 1194 |
-
]
|
| 1195 |
-
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
| 1196 |
-
dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB,
|
| 1197 |
-
use_2cta_instrs=use_2cta_instrs,
|
| 1198 |
-
)
|
| 1199 |
-
# TODO: check @can_implement
|
| 1200 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
| 1201 |
-
fa_bwd_post,
|
| 1202 |
-
dq_accum_tensor,
|
| 1203 |
-
dq_tensor,
|
| 1204 |
-
softmax_scale,
|
| 1205 |
-
cu_seqlens_q_tensor,
|
| 1206 |
-
seqused_q_tensor,
|
| 1207 |
-
current_stream,
|
| 1208 |
-
options="--enable-tvm-ffi",
|
| 1209 |
-
)
|
| 1210 |
-
|
| 1211 |
-
if not is_fake_mode():
|
| 1212 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post](
|
| 1213 |
-
dq_accum,
|
| 1214 |
-
dq,
|
| 1215 |
-
softmax_scale,
|
| 1216 |
-
cu_seqlens_q,
|
| 1217 |
-
seqused_q,
|
| 1218 |
-
current_stream,
|
| 1219 |
-
)
|
| 1220 |
|
| 1221 |
if dKV_postprocess:
|
| 1222 |
-
# Postprocess
|
| 1223 |
-
|
| 1224 |
-
|
| 1225 |
-
|
| 1226 |
-
head_dim,
|
| 1227 |
-
|
| 1228 |
-
|
| 1229 |
-
AtomLayoutNdKV,
|
| 1230 |
-
dKV_swapAB,
|
| 1231 |
-
cu_seqlens_k is None,
|
| 1232 |
-
seqused_k is None,
|
| 1233 |
-
False, # even for 2cta, is split along hdim, so always False
|
| 1234 |
-
cluster_size, # cluster is for tile_n
|
| 1235 |
-
get_broadcast_dims(dk_accum),
|
| 1236 |
-
get_broadcast_dims(dk),
|
| 1237 |
)
|
| 1238 |
-
|
| 1239 |
-
|
| 1240 |
-
|
| 1241 |
-
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
| 1246 |
-
dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
|
| 1247 |
-
cluster_size=cluster_size,
|
| 1248 |
-
)
|
| 1249 |
-
# TODO: check @can_implement
|
| 1250 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
| 1251 |
-
fa_bwd_post,
|
| 1252 |
-
dk_accum_tensor,
|
| 1253 |
-
dk_tensor,
|
| 1254 |
-
softmax_scale,
|
| 1255 |
-
cu_seqlens_k_tensor,
|
| 1256 |
-
seqused_k_tensor,
|
| 1257 |
-
current_stream,
|
| 1258 |
-
options="--enable-tvm-ffi",
|
| 1259 |
-
)
|
| 1260 |
-
if not is_fake_mode():
|
| 1261 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post](
|
| 1262 |
-
dk_accum,
|
| 1263 |
-
dk,
|
| 1264 |
-
softmax_scale,
|
| 1265 |
-
cu_seqlens_k,
|
| 1266 |
-
seqused_k,
|
| 1267 |
-
current_stream,
|
| 1268 |
-
)
|
| 1269 |
-
compile_key_post = (
|
| 1270 |
-
arch,
|
| 1271 |
-
dtype,
|
| 1272 |
-
head_dim_v,
|
| 1273 |
-
n_block_size,
|
| 1274 |
-
num_threads,
|
| 1275 |
-
AtomLayoutNdKV,
|
| 1276 |
-
dKV_swapAB,
|
| 1277 |
-
cu_seqlens_k is None,
|
| 1278 |
-
seqused_k is None,
|
| 1279 |
-
False,
|
| 1280 |
-
cluster_size,
|
| 1281 |
-
get_broadcast_dims(dv_accum),
|
| 1282 |
-
get_broadcast_dims(dv),
|
| 1283 |
)
|
| 1284 |
-
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
| 1285 |
-
dv_accum_tensor = to_cute_tensor(dv_accum)
|
| 1286 |
-
dv_tensor = to_cute_tensor(dv)
|
| 1287 |
-
cu_seqlens_k_tensor, seqused_k_tensor = [
|
| 1288 |
-
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
| 1289 |
-
for t in (cu_seqlens_k, seqused_k)
|
| 1290 |
-
]
|
| 1291 |
-
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
| 1292 |
-
dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
|
| 1293 |
-
cluster_size=cluster_size,
|
| 1294 |
-
)
|
| 1295 |
-
# TODO: check @can_implement
|
| 1296 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
| 1297 |
-
fa_bwd_post,
|
| 1298 |
-
dv_accum_tensor,
|
| 1299 |
-
dv_tensor,
|
| 1300 |
-
cutlass.Float32(1.0),
|
| 1301 |
-
cu_seqlens_k_tensor,
|
| 1302 |
-
seqused_k_tensor,
|
| 1303 |
-
current_stream,
|
| 1304 |
-
options="--enable-tvm-ffi",
|
| 1305 |
-
)
|
| 1306 |
-
if not is_fake_mode():
|
| 1307 |
-
_flash_attn_bwd.compile_cache_post[compile_key_post](
|
| 1308 |
-
dv_accum,
|
| 1309 |
-
dv,
|
| 1310 |
-
1.0,
|
| 1311 |
-
cu_seqlens_k,
|
| 1312 |
-
seqused_k,
|
| 1313 |
-
current_stream,
|
| 1314 |
-
)
|
| 1315 |
|
| 1316 |
return dq, dk, dv
|
| 1317 |
|
| 1318 |
|
| 1319 |
-
_flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre")
|
| 1320 |
_flash_attn_bwd.compile_cache = get_jit_cache("bwd")
|
| 1321 |
-
_flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post")
|
| 1322 |
|
| 1323 |
|
| 1324 |
class FlashAttnFunc(torch.autograd.Function):
|
|
@@ -1376,14 +1631,17 @@ class FlashAttnFunc(torch.autograd.Function):
|
|
| 1376 |
ctx.window_size = window_size
|
| 1377 |
ctx.softcap = softcap
|
| 1378 |
ctx.deterministic = deterministic
|
| 1379 |
-
|
| 1380 |
-
|
| 1381 |
-
ctx.mark_non_differentiable(lse)
|
| 1382 |
return out, lse
|
| 1383 |
|
| 1384 |
@staticmethod
|
| 1385 |
-
def backward(ctx, dout,
|
| 1386 |
q, k, v, out, lse = ctx.saved_tensors
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1387 |
dq, dk, dv = _flash_attn_bwd(
|
| 1388 |
q,
|
| 1389 |
k,
|
|
@@ -1397,6 +1655,7 @@ class FlashAttnFunc(torch.autograd.Function):
|
|
| 1397 |
window_size_left=ctx.window_size[0],
|
| 1398 |
window_size_right=ctx.window_size[1],
|
| 1399 |
deterministic=ctx.deterministic,
|
|
|
|
| 1400 |
)
|
| 1401 |
return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
|
| 1402 |
|
|
@@ -1458,15 +1717,18 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
|
| 1458 |
ctx.deterministic = deterministic
|
| 1459 |
ctx.max_seqlen_q = max_seqlen_q
|
| 1460 |
ctx.max_seqlen_k = max_seqlen_k
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
ctx.mark_non_differentiable(lse)
|
| 1464 |
return out, lse
|
| 1465 |
|
| 1466 |
@staticmethod
|
| 1467 |
-
def backward(ctx, dout,
|
| 1468 |
q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
|
| 1469 |
assert ctx.softcap == 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1470 |
dq, dk, dv = _flash_attn_bwd(
|
| 1471 |
q,
|
| 1472 |
k,
|
|
@@ -1486,6 +1748,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
|
| 1486 |
max_seqlen_q=ctx.max_seqlen_q,
|
| 1487 |
max_seqlen_k=ctx.max_seqlen_k,
|
| 1488 |
deterministic=ctx.deterministic,
|
|
|
|
| 1489 |
)
|
| 1490 |
|
| 1491 |
return dq, dk, dv, *((None,) * 20)
|
|
@@ -1581,6 +1844,63 @@ def flash_attn_varlen_func(
|
|
| 1581 |
)
|
| 1582 |
|
| 1583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1584 |
def _flash_attn_fwd_combine(
|
| 1585 |
out_partial: torch.Tensor,
|
| 1586 |
lse_partial: torch.Tensor,
|
|
@@ -1589,6 +1909,7 @@ def _flash_attn_fwd_combine(
|
|
| 1589 |
cu_seqlens: Optional[torch.Tensor] = None,
|
| 1590 |
seqused: Optional[torch.Tensor] = None,
|
| 1591 |
num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
|
|
|
|
| 1592 |
semaphore_to_reset: Optional[torch.Tensor] = None,
|
| 1593 |
) -> None:
|
| 1594 |
"""Forward combine kernel for split attention computation.
|
|
@@ -1612,27 +1933,13 @@ def _flash_attn_fwd_combine(
|
|
| 1612 |
Returns:
|
| 1613 |
None
|
| 1614 |
"""
|
| 1615 |
-
# Input validation
|
| 1616 |
-
assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
|
| 1617 |
-
assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
|
| 1618 |
assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
|
| 1619 |
"out_partial must be fp16, bf16, or fp32"
|
| 1620 |
)
|
| 1621 |
-
|
| 1622 |
-
|
| 1623 |
-
assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
|
| 1624 |
-
assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
|
| 1625 |
-
assert lse_partial.shape == out_partial.shape[:-1]
|
| 1626 |
-
|
| 1627 |
# Determine if this is variable length based on dimensions
|
| 1628 |
is_varlen = out_partial.dim() == 4
|
| 1629 |
-
|
| 1630 |
-
# Validate output tensor shapes and types
|
| 1631 |
-
assert out.shape == out_partial.shape[1:], "out shape mismatch"
|
| 1632 |
-
if lse is not None:
|
| 1633 |
-
assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
|
| 1634 |
-
assert lse.dtype == torch.float32, "lse must be fp32"
|
| 1635 |
-
|
| 1636 |
# Validate optional tensors
|
| 1637 |
for t, name in [
|
| 1638 |
(cu_seqlens, "cu_seqlens"),
|
|
@@ -1640,10 +1947,9 @@ def _flash_attn_fwd_combine(
|
|
| 1640 |
(num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
|
| 1641 |
]:
|
| 1642 |
if t is not None:
|
| 1643 |
-
|
| 1644 |
-
|
| 1645 |
assert t.is_contiguous(), f"{name} must be contiguous"
|
| 1646 |
-
|
| 1647 |
head_dim = out_partial.shape[-1]
|
| 1648 |
num_splits = out_partial.shape[0]
|
| 1649 |
assert num_splits <= 256
|
|
@@ -1652,101 +1958,37 @@ def _flash_attn_fwd_combine(
|
|
| 1652 |
k_block_size = 64 if head_dim <= 64 else 128
|
| 1653 |
# We want kBlockM to be as small as possible to maximize parallelism.
|
| 1654 |
# E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
|
| 1655 |
-
|
| 1656 |
log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
|
| 1657 |
-
if
|
| 1658 |
# If kBlockM == 8 then the minimum number of splits is 32.
|
| 1659 |
# TODO: we can deal w this by using 128 threads instead
|
| 1660 |
log_max_splits = max(log_max_splits, 5)
|
| 1661 |
|
| 1662 |
-
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
| 1663 |
-
|
| 1664 |
# Create combine kernel configuration
|
| 1665 |
dtype = torch2cute_dtype_map[out.dtype]
|
| 1666 |
dtype_partial = torch2cute_dtype_map[out_partial.dtype]
|
| 1667 |
-
|
| 1668 |
compile_key = (
|
| 1669 |
dtype,
|
| 1670 |
dtype_partial,
|
| 1671 |
head_dim,
|
| 1672 |
-
|
| 1673 |
k_block_size,
|
| 1674 |
log_max_splits,
|
| 1675 |
cu_seqlens is not None,
|
| 1676 |
seqused is not None,
|
| 1677 |
lse is not None,
|
|
|
|
| 1678 |
)
|
| 1679 |
-
|
| 1680 |
if compile_key not in _flash_attn_fwd_combine.compile_cache:
|
| 1681 |
-
|
| 1682 |
-
|
| 1683 |
-
)
|
| 1684 |
-
lse_partial_tensor = to_cute_tensor(
|
| 1685 |
-
lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
|
| 1686 |
-
)
|
| 1687 |
-
out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
|
| 1688 |
-
lse_tensor = (
|
| 1689 |
-
to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
|
| 1690 |
-
if lse is not None
|
| 1691 |
-
else None
|
| 1692 |
-
)
|
| 1693 |
-
|
| 1694 |
-
optional_tensors = [
|
| 1695 |
-
to_cute_tensor(t, assumed_align=4, leading_dim=0)
|
| 1696 |
-
if t is not None
|
| 1697 |
-
else None
|
| 1698 |
-
for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
|
| 1699 |
-
]
|
| 1700 |
-
cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
|
| 1701 |
-
optional_tensors
|
| 1702 |
-
)
|
| 1703 |
-
fa_combine = FlashAttentionForwardCombine(
|
| 1704 |
-
dtype=dtype,
|
| 1705 |
-
dtype_partial=dtype_partial,
|
| 1706 |
-
head_dim=head_dim,
|
| 1707 |
-
m_block_size=m_block_size,
|
| 1708 |
-
k_block_size=k_block_size,
|
| 1709 |
-
log_max_splits=log_max_splits,
|
| 1710 |
-
)
|
| 1711 |
-
|
| 1712 |
-
# Check if implementation is supported
|
| 1713 |
-
if not fa_combine.can_implement(
|
| 1714 |
-
dtype,
|
| 1715 |
-
dtype_partial,
|
| 1716 |
-
head_dim,
|
| 1717 |
-
m_block_size,
|
| 1718 |
-
k_block_size,
|
| 1719 |
-
log_max_splits,
|
| 1720 |
-
num_threads=256,
|
| 1721 |
-
):
|
| 1722 |
-
raise RuntimeError(
|
| 1723 |
-
"FlashAttention combine kernel cannot be implemented with given parameters"
|
| 1724 |
-
)
|
| 1725 |
-
|
| 1726 |
-
_flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
|
| 1727 |
-
fa_combine,
|
| 1728 |
-
out_partial_tensor,
|
| 1729 |
-
lse_partial_tensor,
|
| 1730 |
-
out_tensor,
|
| 1731 |
-
lse_tensor,
|
| 1732 |
-
cu_seqlens_tensor,
|
| 1733 |
-
seqused_tensor,
|
| 1734 |
-
num_splits_dynamic_tensor,
|
| 1735 |
-
semaphore_tensor,
|
| 1736 |
-
current_stream,
|
| 1737 |
-
options="--enable-tvm-ffi",
|
| 1738 |
)
|
| 1739 |
if not is_fake_mode():
|
| 1740 |
_flash_attn_fwd_combine.compile_cache[compile_key](
|
| 1741 |
-
out_partial,
|
| 1742 |
-
|
| 1743 |
-
out,
|
| 1744 |
-
lse,
|
| 1745 |
-
cu_seqlens,
|
| 1746 |
-
seqused,
|
| 1747 |
-
num_splits_dynamic_ptr,
|
| 1748 |
semaphore_to_reset,
|
| 1749 |
-
current_stream,
|
| 1750 |
)
|
| 1751 |
|
| 1752 |
|
|
@@ -1760,6 +2002,7 @@ def flash_attn_combine(
|
|
| 1760 |
out_dtype: Optional[torch.dtype] = None,
|
| 1761 |
cu_seqlens: Optional[torch.Tensor] = None,
|
| 1762 |
seqused: Optional[torch.Tensor] = None,
|
|
|
|
| 1763 |
return_lse: bool = True,
|
| 1764 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 1765 |
"""Flash Attention combine function for split attention computation.
|
|
@@ -1779,6 +2022,9 @@ def flash_attn_combine(
|
|
| 1779 |
out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input.
|
| 1780 |
cu_seqlens: Cumulative sequence lengths for variable length sequences
|
| 1781 |
seqused: Used sequence lengths for each batch
|
|
|
|
|
|
|
|
|
|
| 1782 |
return_lse: Whether to return the combined LSE tensor. Default is True.
|
| 1783 |
|
| 1784 |
Returns:
|
|
@@ -1795,32 +2041,19 @@ def flash_attn_combine(
|
|
| 1795 |
"""
|
| 1796 |
# Input validation
|
| 1797 |
assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
|
| 1798 |
-
assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
|
| 1799 |
-
assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
|
| 1800 |
-
assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
|
| 1801 |
-
|
| 1802 |
# Determine if this is variable length based on dimensions
|
| 1803 |
is_varlen = out_partial.dim() == 4
|
| 1804 |
-
|
| 1805 |
if is_varlen:
|
| 1806 |
# Variable length: (num_splits, total_q, num_heads, head_size)
|
| 1807 |
num_splits, total_q, num_heads, head_size = out_partial.shape
|
| 1808 |
-
assert lse_partial.shape == (num_splits, total_q, num_heads), (
|
| 1809 |
-
"lse_partial shape mismatch for varlen"
|
| 1810 |
-
)
|
| 1811 |
batch_size = 1 # Treat as single batch for varlen
|
| 1812 |
seqlen = total_q
|
| 1813 |
else:
|
| 1814 |
# Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
|
| 1815 |
num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
|
| 1816 |
-
assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
|
| 1817 |
-
"lse_partial shape mismatch"
|
| 1818 |
-
)
|
| 1819 |
-
|
| 1820 |
# Determine output dtype
|
| 1821 |
if out_dtype is None:
|
| 1822 |
out_dtype = out_partial.dtype
|
| 1823 |
-
|
| 1824 |
# Create output if not provided
|
| 1825 |
device = out_partial.device
|
| 1826 |
if out is None:
|
|
@@ -1830,20 +2063,15 @@ def flash_attn_combine(
|
|
| 1830 |
out = torch.empty(
|
| 1831 |
batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
|
| 1832 |
)
|
| 1833 |
-
|
| 1834 |
# Create lse output only if requested
|
| 1835 |
if return_lse:
|
| 1836 |
if is_varlen:
|
| 1837 |
-
lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device)
|
| 1838 |
-
0, 1
|
| 1839 |
-
)
|
| 1840 |
else:
|
| 1841 |
-
lse = torch.empty(
|
| 1842 |
-
|
| 1843 |
-
).transpose(1, 2)
|
| 1844 |
else:
|
| 1845 |
lse = None
|
| 1846 |
-
|
| 1847 |
_flash_attn_fwd_combine(
|
| 1848 |
out_partial,
|
| 1849 |
lse_partial,
|
|
@@ -1851,5 +2079,6 @@ def flash_attn_combine(
|
|
| 1851 |
lse,
|
| 1852 |
cu_seqlens,
|
| 1853 |
seqused,
|
|
|
|
| 1854 |
)
|
| 1855 |
return out, lse
|
|
|
|
| 21 |
|
| 22 |
import os
|
| 23 |
import math
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
from functools import lru_cache
|
| 26 |
from typing import Optional, Tuple, Callable
|
| 27 |
|
|
|
|
| 32 |
|
| 33 |
import cutlass
|
| 34 |
import cutlass.cute as cute
|
| 35 |
+
from cutlass import Int32, Float32
|
| 36 |
+
from .quack.compile_utils import make_fake_tensor as fake_tensor
|
| 37 |
from .cache_utils import get_jit_cache
|
| 38 |
from .testing import is_fake_mode
|
| 39 |
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
from . import utils
|
| 49 |
+
from . import fa_logging
|
| 50 |
from .cute_dsl_utils import (
|
| 51 |
to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
|
| 52 |
)
|
| 53 |
+
from .flash_fwd import FlashAttentionForwardSm80
|
| 54 |
+
from .flash_fwd_sm90 import FlashAttentionForwardSm90
|
| 55 |
from .flash_fwd_sm100 import FlashAttentionForwardSm100
|
| 56 |
+
from .flash_fwd_sm120 import FlashAttentionForwardSm120
|
| 57 |
from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
|
| 58 |
from .flash_bwd import FlashAttentionBackwardSm80
|
| 59 |
from .flash_bwd_sm90 import FlashAttentionBackwardSm90
|
| 60 |
from .flash_bwd_sm100 import FlashAttentionBackwardSm100
|
| 61 |
+
from .flash_bwd_sm120 import FlashAttentionBackwardSm120
|
| 62 |
from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
|
| 63 |
from .flash_fwd_combine import FlashAttentionForwardCombine
|
| 64 |
|
| 65 |
from .block_sparsity import (
|
| 66 |
BlockSparseTensorsTorch,
|
| 67 |
+
get_sparse_q_block_size,
|
| 68 |
to_cute_block_sparse_tensors,
|
| 69 |
normalize_block_sparse_config,
|
| 70 |
normalize_block_sparse_config_bwd,
|
| 71 |
)
|
| 72 |
|
| 73 |
+
def _parse_arch_str(arch_str):
|
| 74 |
+
"""Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100)."""
|
| 75 |
+
import re
|
| 76 |
+
match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str)
|
| 77 |
+
if not match:
|
| 78 |
+
raise ValueError(f"Invalid arch format: {arch_str}")
|
| 79 |
+
major, minor, _ = match.groups()
|
| 80 |
+
return int(major) * 10 + int(minor)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
@lru_cache(maxsize=None)
|
| 84 |
def _get_device_arch():
|
| 85 |
+
"""Cached device arch check.
|
| 86 |
+
|
| 87 |
+
Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which
|
| 88 |
+
kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation
|
| 89 |
+
target (CUTE_DSL_ARCH).
|
| 90 |
+
|
| 91 |
+
For CPU-only compilation (no GPU), set both:
|
| 92 |
+
FLASH_ATTENTION_ARCH=sm_80 (kernel selection)
|
| 93 |
+
CUTE_DSL_ARCH=sm_80 (compilation target)
|
| 94 |
+
"""
|
| 95 |
+
arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None)
|
| 96 |
+
if arch_override is not None:
|
| 97 |
+
return _parse_arch_str(arch_override)
|
| 98 |
major, minor = torch.cuda.get_device_capability()
|
| 99 |
+
return major * 10 + int(minor)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:
|
| 103 |
+
"""Validate head dimension constraints based on compute capability."""
|
| 104 |
+
is_deepseek_shape = head_dim == 192 and head_dim_v == 128
|
| 105 |
+
is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
|
| 106 |
+
|
| 107 |
+
is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256
|
| 108 |
+
if compute_capability == 9:
|
| 109 |
+
assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
|
| 110 |
+
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
|
| 111 |
+
f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}."
|
| 112 |
+
)
|
| 113 |
+
elif compute_capability in [10, 11]:
|
| 114 |
+
assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
|
| 115 |
+
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
|
| 116 |
+
f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek."
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass(frozen=True)
|
| 121 |
+
class FwdConfig:
|
| 122 |
+
m_block_size: int
|
| 123 |
+
n_block_size: int
|
| 124 |
+
mma_pv_is_rs: bool
|
| 125 |
+
intra_wg_overlap: bool
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None):
|
| 129 |
+
"""Return FwdConfig for SM90 forward.
|
| 130 |
+
|
| 131 |
+
Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted
|
| 132 |
+
for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM).
|
| 133 |
+
|
| 134 |
+
When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the
|
| 135 |
+
optimal tile_m=192 is used when compatible, otherwise we fall back to 128.
|
| 136 |
+
"""
|
| 137 |
+
if head_dim <= 64:
|
| 138 |
+
# C++: 192×192 non-causal, 192×128 causal/local.
|
| 139 |
+
# Python: 192×128 RS+OL is consistently best across seqlens.
|
| 140 |
+
if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0:
|
| 141 |
+
return FwdConfig(128, 128, True, True)
|
| 142 |
+
return FwdConfig(192, 128, True, True)
|
| 143 |
+
elif head_dim <= 96:
|
| 144 |
+
# C++: 192×144 noRS+OL for all cases.
|
| 145 |
+
# Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS).
|
| 146 |
+
# noRS+OL is always required. Causal: 192×128 slightly better short seqlen.
|
| 147 |
+
if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0:
|
| 148 |
+
return FwdConfig(128, 128, False, True)
|
| 149 |
+
if is_causal or is_local:
|
| 150 |
+
return FwdConfig(192, 128, False, True)
|
| 151 |
+
else:
|
| 152 |
+
return FwdConfig(192, 144, False, True)
|
| 153 |
+
elif head_dim <= 128:
|
| 154 |
+
return FwdConfig(128, 128, True, True)
|
| 155 |
+
elif head_dim <= 192:
|
| 156 |
+
tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112)
|
| 157 |
+
return FwdConfig(128, tile_n, True, True)
|
| 158 |
+
else: # hdim 256
|
| 159 |
+
tile_n = 64 if is_local else 80
|
| 160 |
+
return FwdConfig(128, tile_n, True, True)
|
| 161 |
+
|
| 162 |
+
@dataclass(frozen=True)
|
| 163 |
+
class BwdConfig:
|
| 164 |
+
m_block_size: int
|
| 165 |
+
n_block_size: int
|
| 166 |
+
num_stages_Q: int
|
| 167 |
+
num_stages_dO: int
|
| 168 |
+
num_stages_PdS: int
|
| 169 |
+
SdP_swapAB: bool
|
| 170 |
+
dKV_swapAB: bool
|
| 171 |
+
dQ_swapAB: bool
|
| 172 |
+
AtomLayoutMSdP: int
|
| 173 |
+
AtomLayoutNdKV: int
|
| 174 |
+
AtomLayoutMdQ: int
|
| 175 |
+
num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128)
|
| 176 |
+
dQ_single_wg: bool = False
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None):
|
| 180 |
+
"""Return BwdConfig for SM90.
|
| 181 |
+
|
| 182 |
+
Configs based on C++ FA3 hopper/flash_bwd_launch_template.h,
|
| 183 |
+
benchmarked on H100 SXM.
|
| 184 |
+
"""
|
| 185 |
+
if head_dim <= 64:
|
| 186 |
+
# C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2
|
| 187 |
+
return BwdConfig(
|
| 188 |
+
m_block_size=128, n_block_size=128,
|
| 189 |
+
num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
|
| 190 |
+
SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,
|
| 191 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2,
|
| 192 |
+
)
|
| 193 |
+
elif head_dim <= 96:
|
| 194 |
+
# C++ FA3: 64, 128, 96, dQ_swapAB=False
|
| 195 |
+
return BwdConfig(
|
| 196 |
+
m_block_size=64, n_block_size=128,
|
| 197 |
+
num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
|
| 198 |
+
SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,
|
| 199 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
|
| 200 |
+
dQ_single_wg=True,
|
| 201 |
+
)
|
| 202 |
+
elif head_dim <= 128:
|
| 203 |
+
# C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB
|
| 204 |
+
is_causal_or_local = causal or local
|
| 205 |
+
m_block_size = 64 if is_causal_or_local else 80
|
| 206 |
+
if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0:
|
| 207 |
+
m_block_size = 64
|
| 208 |
+
return BwdConfig(
|
| 209 |
+
m_block_size=m_block_size,
|
| 210 |
+
n_block_size=128,
|
| 211 |
+
num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
|
| 212 |
+
SdP_swapAB=True, dKV_swapAB=False,
|
| 213 |
+
dQ_swapAB=m_block_size % 64 != 0,
|
| 214 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
|
| 215 |
+
)
|
| 216 |
+
elif head_dim <= 192:
|
| 217 |
+
hdimv128 = head_dim_v <= 128
|
| 218 |
+
if hdimv128:
|
| 219 |
+
return BwdConfig(
|
| 220 |
+
m_block_size=64, n_block_size=96,
|
| 221 |
+
num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1,
|
| 222 |
+
SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,
|
| 223 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
|
| 224 |
+
num_wg=2,
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
return BwdConfig(
|
| 228 |
+
m_block_size=64, n_block_size=96,
|
| 229 |
+
num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1,
|
| 230 |
+
SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,
|
| 231 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
|
| 232 |
+
num_wg=2,
|
| 233 |
+
)
|
| 234 |
+
else:
|
| 235 |
+
# hdim 256
|
| 236 |
+
return BwdConfig(
|
| 237 |
+
m_block_size=64, n_block_size=64,
|
| 238 |
+
num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1,
|
| 239 |
+
SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False,
|
| 240 |
+
AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
|
| 245 |
def maybe_contiguous(x):
|
| 246 |
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
|
|
|
| 250 |
assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
|
| 251 |
assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
|
| 252 |
assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
|
| 253 |
+
if not is_fake_mode():
|
| 254 |
+
assert t.is_cuda, f"{name} must be on CUDA"
|
| 255 |
|
| 256 |
|
| 257 |
torch2cute_dtype_map = {
|
|
|
|
| 271 |
return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
|
| 272 |
|
| 273 |
|
| 274 |
+
def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None):
|
| 275 |
+
"""Resolve causal/local/window settings into canonical form.
|
| 276 |
+
|
| 277 |
+
Returns (causal, local, window_size_left, window_size_right).
|
| 278 |
+
"""
|
| 279 |
+
if mask_mod is not None:
|
| 280 |
+
return False, False, window_size_left, window_size_right
|
| 281 |
+
if causal:
|
| 282 |
+
window_size_right = 0
|
| 283 |
+
if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
|
| 284 |
+
window_size_left = None
|
| 285 |
+
window_size_right = None
|
| 286 |
+
if window_size_left is not None or window_size_right is not None:
|
| 287 |
+
if window_size_left is None and window_size_right == 0:
|
| 288 |
+
causal, local = True, False
|
| 289 |
+
window_size_right = None
|
| 290 |
+
else:
|
| 291 |
+
causal, local = False, True
|
| 292 |
+
else:
|
| 293 |
+
local = False
|
| 294 |
+
return causal, local, window_size_left, window_size_right
|
| 295 |
+
|
| 296 |
+
|
| 297 |
def _flash_attn_fwd(
|
| 298 |
q: torch.Tensor,
|
| 299 |
k: torch.Tensor,
|
|
|
|
| 311 |
window_size_left: Optional[int] = None,
|
| 312 |
window_size_right: Optional[int] = None,
|
| 313 |
learnable_sink: Optional[torch.Tensor] = None,
|
| 314 |
+
tile_mn: Optional[Tuple[int, int]] = None,
|
| 315 |
+
mma_pv_is_rs: Optional[bool] = None,
|
| 316 |
+
intra_wg_overlap: Optional[bool] = None,
|
|
|
|
|
|
|
| 317 |
num_threads: int = 384,
|
| 318 |
num_splits: int = 1,
|
| 319 |
pack_gqa: Optional[bool] = None,
|
|
|
|
| 334 |
mask_mod: A callable that takes token position information and selectively masks
|
| 335 |
block_sparse_tensors: A tuple of tensors used for block sparsity.
|
| 336 |
return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
|
| 337 |
+
The returned LSE supports taking gradient.
|
| 338 |
out: Optional pre-allocated output tensor. If None, will be allocated internally.
|
| 339 |
lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
|
| 340 |
aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
|
|
|
|
| 399 |
assert learnable_sink.shape == (num_head,)
|
| 400 |
assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
|
| 401 |
|
| 402 |
+
if not is_fake_mode():
|
| 403 |
+
assert all(
|
| 404 |
+
t is None or t.is_cuda
|
| 405 |
+
for t in (
|
| 406 |
+
q,
|
| 407 |
+
k,
|
| 408 |
+
v,
|
| 409 |
+
cu_seqlens_q,
|
| 410 |
+
cu_seqlens_k,
|
| 411 |
+
seqused_q,
|
| 412 |
+
seqused_k,
|
| 413 |
+
page_table,
|
| 414 |
+
learnable_sink,
|
| 415 |
+
)
|
| 416 |
+
), "inputs must be on CUDA device"
|
| 417 |
+
arch = _get_device_arch() if _arch is None else _arch
|
| 418 |
+
assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. Supported: 8.x, 9.x, 10.x, 11.x, 12.x"
|
| 419 |
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
|
|
|
|
| 420 |
alignment = 16 // q.element_size()
|
| 421 |
+
if arch // 10 not in [8, 12]:
|
| 422 |
+
_validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
|
| 423 |
if softmax_scale is None:
|
| 424 |
softmax_scale = 1.0 / math.sqrt(head_dim)
|
| 425 |
if softcap == 0.0:
|
|
|
|
| 451 |
_validate_tensor(lse, "lse", lse_shape, torch.float32, device)
|
| 452 |
|
| 453 |
dtype = torch2cute_dtype_map[q.dtype]
|
| 454 |
+
use_block_sparsity = block_sparse_tensors is not None
|
| 455 |
+
|
| 456 |
+
causal, local, window_size_left, window_size_right = _resolve_causal_local_window(
|
| 457 |
+
causal, window_size_left, window_size_right, mask_mod
|
| 458 |
+
)
|
| 459 |
|
| 460 |
+
requested_use_clc_scheduler = utils._get_use_clc_scheduler_default()
|
| 461 |
+
requested_disable_2cta = utils._get_disable_2cta_default()
|
| 462 |
|
| 463 |
+
current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
|
| 464 |
|
| 465 |
+
# SM80/SM120: uses SM80 MMA, 128 threads (4 warps)
|
| 466 |
+
if arch // 10 in [8, 12]:
|
| 467 |
+
num_threads = 128
|
| 468 |
+
|
| 469 |
+
fwd_cfg = FwdConfig(128, 128, True, True) # default
|
| 470 |
+
if tile_mn is None:
|
| 471 |
+
if arch // 10 == 12:
|
| 472 |
+
# SM120 tile sizes tuned for 99 KB SMEM capacity:
|
| 473 |
+
# D<=64: 128x128 → 48 KB (good occupancy)
|
| 474 |
+
# D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy)
|
| 475 |
+
if head_dim <= 64:
|
| 476 |
+
fwd_cfg = FwdConfig(128, 128, True, True)
|
| 477 |
else:
|
| 478 |
+
fwd_cfg = FwdConfig(128, 64, True, True)
|
| 479 |
+
elif arch // 10 == 8:
|
| 480 |
+
fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune
|
| 481 |
+
elif arch // 10 == 9:
|
| 482 |
+
sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q)
|
| 483 |
+
fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q)
|
| 484 |
else:
|
| 485 |
+
fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap)
|
| 486 |
+
tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size
|
| 487 |
+
if mma_pv_is_rs is None:
|
| 488 |
+
mma_pv_is_rs = fwd_cfg.mma_pv_is_rs
|
| 489 |
+
if intra_wg_overlap is None:
|
| 490 |
+
intra_wg_overlap = fwd_cfg.intra_wg_overlap
|
| 491 |
|
| 492 |
+
# TODO: fix GQA + SplitKV + non-varlen
|
| 493 |
+
if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
|
| 494 |
+
pack_gqa = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
if max_seqlen_q is None:
|
| 497 |
max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
|
|
|
|
| 499 |
max_seqlen_k = seqlen_k
|
| 500 |
seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
|
| 501 |
if arch // 10 == 10:
|
| 502 |
+
q_stage = 2 if seqlen_q_packgqa > tile_m else 1
|
| 503 |
else:
|
| 504 |
q_stage = 1
|
| 505 |
|
| 506 |
+
m_block_size_effective = q_stage * tile_m
|
| 507 |
+
seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m))
|
| 508 |
+
num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
|
| 509 |
+
total_mblocks = batch_size * num_head_kv * num_m_blocks
|
| 510 |
+
num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n
|
| 511 |
+
num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count
|
| 512 |
if num_splits < 1:
|
| 513 |
+
num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
|
| 514 |
+
|
| 515 |
+
# SplitKV uses float32 partial output, which doubles the O buffer size
|
| 516 |
+
# in shared memory, causing OOM for diff-headdim (192, 128)
|
| 517 |
+
if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1:
|
| 518 |
+
if num_n_blocks >= 64:
|
| 519 |
+
tile_n = 64
|
| 520 |
+
num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n
|
| 521 |
+
num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
|
| 522 |
+
else:
|
| 523 |
+
num_splits = 1
|
| 524 |
|
| 525 |
is_split_kv = num_splits > 1
|
| 526 |
if is_split_kv:
|
| 527 |
out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
|
| 528 |
lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
|
| 529 |
|
| 530 |
+
use_2cta_instrs = (
|
| 531 |
+
arch // 10 in [10, 11]
|
| 532 |
+
and not requested_disable_2cta
|
| 533 |
+
and not causal
|
| 534 |
+
and not local
|
| 535 |
+
and not is_split_kv
|
| 536 |
+
and cu_seqlens_q is None
|
| 537 |
+
and seqused_q is None
|
| 538 |
+
and not use_block_sparsity
|
| 539 |
+
and page_size in [None, 128]
|
| 540 |
+
and int(math.ceil(head_dim / 16) * 16) in [128, 192]
|
| 541 |
+
and int(math.ceil(head_dim_v / 16) * 16) == 128
|
| 542 |
+
and seqlen_q_packgqa > 2 * tile_m
|
| 543 |
+
and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
# hash score and mask mods for compile cache
|
| 547 |
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
|
| 548 |
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
|
|
|
|
| 594 |
num_head=num_head,
|
| 595 |
seqlen_q=seqlen_q,
|
| 596 |
seqlen_k=seqlen_k,
|
| 597 |
+
block_size=(tile_m, tile_n),
|
| 598 |
q_stage=q_stage,
|
| 599 |
)
|
| 600 |
+
if aux_tensors is not None:
|
| 601 |
aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
|
| 602 |
else:
|
| 603 |
aux_tensor_metadata = None
|
| 604 |
+
|
| 605 |
compile_key = (
|
| 606 |
dtype,
|
| 607 |
head_dim,
|
|
|
|
| 622 |
window_size_left is not None,
|
| 623 |
window_size_right is not None,
|
| 624 |
learnable_sink is not None,
|
| 625 |
+
tile_m,
|
| 626 |
+
tile_n,
|
| 627 |
q_stage,
|
| 628 |
num_threads,
|
| 629 |
is_split_kv,
|
| 630 |
pack_gqa,
|
| 631 |
arch,
|
| 632 |
+
page_size not in [None, tile_n], # paged KV non-TMA
|
| 633 |
+
use_2cta_instrs,
|
| 634 |
q_subtile_factor,
|
| 635 |
+
mma_pv_is_rs,
|
| 636 |
+
intra_wg_overlap,
|
| 637 |
+
requested_use_clc_scheduler,
|
| 638 |
+
fa_logging.get_fa_log_level(),
|
| 639 |
)
|
| 640 |
if compile_key not in _flash_attn_fwd.compile_cache:
|
| 641 |
(
|
|
|
|
| 674 |
if aux_tensors is not None:
|
| 675 |
cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]
|
| 676 |
|
| 677 |
+
if arch // 10 == 8:
|
| 678 |
+
assert page_table is None, "paged KV not supported on SM 8.0"
|
| 679 |
+
assert not is_split_kv, "SplitKV not supported on SM 8.0"
|
| 680 |
+
fa_fwd = FlashAttentionForwardSm80(
|
| 681 |
+
dtype,
|
| 682 |
+
head_dim,
|
| 683 |
+
head_dim_v,
|
| 684 |
+
qhead_per_kvhead,
|
| 685 |
+
is_causal=causal,
|
| 686 |
+
is_local=local,
|
| 687 |
+
pack_gqa=pack_gqa,
|
| 688 |
+
tile_m=tile_m,
|
| 689 |
+
tile_n=tile_n,
|
| 690 |
+
num_stages=1,
|
| 691 |
+
num_threads=num_threads,
|
| 692 |
+
Q_in_regs=False,
|
| 693 |
+
score_mod=score_mod,
|
| 694 |
+
mask_mod=mask_mod,
|
| 695 |
+
has_aux_tensors=aux_tensors is not None,
|
| 696 |
+
)
|
| 697 |
+
elif arch // 10 == 9:
|
| 698 |
assert not is_split_kv, "SplitKV not supported on SM 9.0"
|
|
|
|
| 699 |
fa_fwd = FlashAttentionForwardSm90(
|
| 700 |
dtype,
|
| 701 |
head_dim,
|
|
|
|
| 704 |
is_causal=causal,
|
| 705 |
is_local=local,
|
| 706 |
pack_gqa=pack_gqa,
|
| 707 |
+
tile_m=tile_m,
|
| 708 |
+
tile_n=tile_n,
|
| 709 |
# num_stages=1,
|
| 710 |
num_stages=2,
|
| 711 |
num_threads=num_threads,
|
| 712 |
Q_in_regs=False,
|
| 713 |
+
intra_wg_overlap=intra_wg_overlap,
|
| 714 |
+
mma_pv_is_rs=mma_pv_is_rs,
|
| 715 |
mask_mod=mask_mod,
|
| 716 |
score_mod=score_mod,
|
| 717 |
has_aux_tensors=aux_tensors is not None,
|
| 718 |
q_subtile_factor=q_subtile_factor,
|
| 719 |
+
paged_kv_non_tma=page_size not in [None, tile_n],
|
| 720 |
)
|
| 721 |
elif arch // 10 in [10, 11]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
fa_fwd = FlashAttentionForwardSm100(
|
| 723 |
head_dim,
|
| 724 |
head_dim_v,
|
|
|
|
| 727 |
is_local=local,
|
| 728 |
is_split_kv=is_split_kv,
|
| 729 |
pack_gqa=pack_gqa,
|
| 730 |
+
m_block_size=tile_m,
|
| 731 |
+
n_block_size=tile_n,
|
| 732 |
q_stage=q_stage,
|
| 733 |
is_persistent=not causal
|
| 734 |
and not local
|
|
|
|
| 738 |
score_mod=score_mod,
|
| 739 |
mask_mod=mask_mod,
|
| 740 |
has_aux_tensors=aux_tensors is not None,
|
| 741 |
+
paged_kv_non_tma=page_size not in [None, tile_n],
|
| 742 |
is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
|
| 743 |
q_subtile_factor=q_subtile_factor,
|
| 744 |
use_2cta_instrs=use_2cta_instrs,
|
| 745 |
+
use_clc_scheduler=requested_use_clc_scheduler,
|
| 746 |
+
)
|
| 747 |
+
elif arch // 10 == 12:
|
| 748 |
+
# SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity
|
| 749 |
+
assert not use_block_sparsity, "Block sparsity not supported on SM 12.0"
|
| 750 |
+
assert page_table is None, "Paged KV not supported on SM 12.0 in this PR"
|
| 751 |
+
assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR"
|
| 752 |
+
fa_fwd = FlashAttentionForwardSm120(
|
| 753 |
+
dtype,
|
| 754 |
+
head_dim,
|
| 755 |
+
head_dim_v,
|
| 756 |
+
qhead_per_kvhead,
|
| 757 |
+
is_causal=causal,
|
| 758 |
+
is_local=local,
|
| 759 |
+
pack_gqa=pack_gqa,
|
| 760 |
+
tile_m=tile_m,
|
| 761 |
+
tile_n=tile_n,
|
| 762 |
+
num_stages=1,
|
| 763 |
+
num_threads=num_threads,
|
| 764 |
+
Q_in_regs=False,
|
| 765 |
+
score_mod=score_mod,
|
| 766 |
+
mask_mod=mask_mod,
|
| 767 |
+
has_aux_tensors=aux_tensors is not None,
|
| 768 |
)
|
| 769 |
else:
|
| 770 |
raise ValueError(
|
| 771 |
+
f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x"
|
| 772 |
)
|
| 773 |
# TODO: check @can_implement
|
| 774 |
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(
|
|
|
|
| 779 |
o_tensor,
|
| 780 |
lse_tensor,
|
| 781 |
softmax_scale,
|
|
|
|
| 782 |
cu_seqlens_q_tensor,
|
| 783 |
cu_seqlens_k_tensor,
|
| 784 |
seqused_q_tensor,
|
|
|
|
| 789 |
learnable_sink_tensor,
|
| 790 |
sparse_tensors,
|
| 791 |
cute_aux_tensors,
|
| 792 |
+
current_stream,
|
| 793 |
options="--enable-tvm-ffi",
|
| 794 |
)
|
| 795 |
|
|
|
|
| 805 |
out.detach() if not is_split_kv else out_partial,
|
| 806 |
lse_partial if is_split_kv else lse,
|
| 807 |
softmax_scale,
|
|
|
|
| 808 |
cu_seqlens_q,
|
| 809 |
cu_seqlens_k,
|
| 810 |
seqused_q,
|
|
|
|
| 831 |
_flash_attn_fwd.compile_cache = get_jit_cache("fwd")
|
| 832 |
|
| 833 |
|
| 834 |
+
def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k):
|
| 835 |
+
sym = cute.sym_int
|
| 836 |
+
# divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8
|
| 837 |
+
# For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4
|
| 838 |
+
div = 128 // dtype.width # 8 for fp16/bf16
|
| 839 |
+
# Shared sym_ints for dimensions that must match across tensors
|
| 840 |
+
b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym()
|
| 841 |
+
h_kv = h_q if not has_gqa else sym()
|
| 842 |
+
seqlen_q_rounded, seqlen_k_rounded = sym(), sym()
|
| 843 |
+
seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym()
|
| 844 |
+
total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym()
|
| 845 |
+
total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym()
|
| 846 |
+
b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,)
|
| 847 |
+
b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,)
|
| 848 |
+
mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
|
| 849 |
+
mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
|
| 850 |
+
mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
|
| 851 |
+
mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
|
| 852 |
+
mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
|
| 853 |
+
mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
|
| 854 |
+
mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
|
| 855 |
+
mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
|
| 856 |
+
if not varlen_q:
|
| 857 |
+
mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1)
|
| 858 |
+
mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
|
| 859 |
+
mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
|
| 860 |
+
dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4)
|
| 861 |
+
else:
|
| 862 |
+
mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1)
|
| 863 |
+
mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
|
| 864 |
+
mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
|
| 865 |
+
dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4)
|
| 866 |
+
if not has_gqa:
|
| 867 |
+
mdKaccum, mdVaccum = None, None
|
| 868 |
+
else:
|
| 869 |
+
if not varlen_k:
|
| 870 |
+
mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4)
|
| 871 |
+
mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4)
|
| 872 |
+
else:
|
| 873 |
+
mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4)
|
| 874 |
+
mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4)
|
| 875 |
+
return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def _compile_bwd_preprocess(
|
| 879 |
+
dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse,
|
| 880 |
+
):
|
| 881 |
+
"""Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed)."""
|
| 882 |
+
mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
|
| 883 |
+
dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
|
| 884 |
+
)
|
| 885 |
+
batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
|
| 886 |
+
batchp1 = cute.sym_int()
|
| 887 |
+
mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
|
| 888 |
+
mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
|
| 889 |
+
mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None
|
| 890 |
+
fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size)
|
| 891 |
+
return cute.compile(
|
| 892 |
+
fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE,
|
| 893 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 894 |
+
options="--enable-tvm-ffi",
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def _bwd_preprocess(
|
| 899 |
+
out, dout, dpsum, lse, lse_log2, dq_accum,
|
| 900 |
+
cu_seqlens_q, seqused_q, dlse,
|
| 901 |
+
dtype, head_dim, head_dim_v, m_block_size,
|
| 902 |
+
):
|
| 903 |
+
"""Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum."""
|
| 904 |
+
is_varlen = cu_seqlens_q is not None
|
| 905 |
+
compile_key = (
|
| 906 |
+
dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None,
|
| 907 |
+
)
|
| 908 |
+
if compile_key not in _bwd_preprocess.compile_cache:
|
| 909 |
+
_bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key)
|
| 910 |
+
if not is_fake_mode():
|
| 911 |
+
_bwd_preprocess.compile_cache[compile_key](
|
| 912 |
+
out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
_bwd_preprocess.compile_cache = get_jit_cache("bwd_pre")
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def _compile_bwd_postprocess(
|
| 920 |
+
dtype, hdim, block_size, num_threads, atom_layout, swap_ab,
|
| 921 |
+
has_cuseqlens_q, has_seqused_q,
|
| 922 |
+
use_2cta_instrs, cluster_size, arch,
|
| 923 |
+
):
|
| 924 |
+
"""Compile bwd postprocess kernel using cute fake tensors."""
|
| 925 |
+
mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
|
| 926 |
+
dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
|
| 927 |
+
)
|
| 928 |
+
batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
|
| 929 |
+
batchp1 = cute.sym_int()
|
| 930 |
+
mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
|
| 931 |
+
mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
|
| 932 |
+
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
| 933 |
+
dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab,
|
| 934 |
+
use_2cta_instrs=use_2cta_instrs,
|
| 935 |
+
cluster_size=cluster_size,
|
| 936 |
+
)
|
| 937 |
+
return cute.compile(
|
| 938 |
+
fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ,
|
| 939 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 940 |
+
options="--enable-tvm-ffi",
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def _bwd_postprocess_convert(
|
| 945 |
+
accum, output, scale,
|
| 946 |
+
cu_seqlens, seqused,
|
| 947 |
+
arch, dtype, hdim, block_size, num_threads,
|
| 948 |
+
atom_layout, swap_ab,
|
| 949 |
+
use_2cta_instrs=False, cluster_size=1,
|
| 950 |
+
):
|
| 951 |
+
"""Backward postprocess: convert float32 accumulator to bf16/fp16 output."""
|
| 952 |
+
compile_key = (
|
| 953 |
+
dtype, hdim, block_size, num_threads, atom_layout, swap_ab,
|
| 954 |
+
cu_seqlens is not None, seqused is not None,
|
| 955 |
+
use_2cta_instrs, cluster_size, arch,
|
| 956 |
+
)
|
| 957 |
+
if compile_key not in _bwd_postprocess_convert.compile_cache:
|
| 958 |
+
_bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key)
|
| 959 |
+
if not is_fake_mode():
|
| 960 |
+
_bwd_postprocess_convert.compile_cache[compile_key](
|
| 961 |
+
accum, output, scale, cu_seqlens, seqused,
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
_bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post")
|
| 966 |
+
|
| 967 |
+
|
| 968 |
def _flash_attn_bwd(
|
| 969 |
q: torch.Tensor,
|
| 970 |
k: torch.Tensor,
|
|
|
|
| 1005 |
mask_mod: Optional[Callable] = None,
|
| 1006 |
aux_tensors: Optional[list[torch.Tensor]] = None,
|
| 1007 |
block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
|
| 1008 |
+
dlse: Optional[torch.Tensor] = None,
|
| 1009 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 1010 |
arch = _get_device_arch()
|
| 1011 |
+
assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x"
|
| 1012 |
+
sparse_q = None
|
| 1013 |
+
if block_sparse_tensors is not None and arch // 10 == 9:
|
| 1014 |
+
sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128
|
| 1015 |
|
| 1016 |
num_head, head_dim = q.shape[-2:]
|
| 1017 |
+
head_dim_v = v.shape[-1]
|
| 1018 |
|
| 1019 |
+
causal, local, window_size_left, window_size_right = _resolve_causal_local_window(
|
| 1020 |
+
causal, window_size_left, window_size_right
|
| 1021 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1022 |
|
| 1023 |
+
if arch // 10 == 12:
|
| 1024 |
+
# SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).
|
| 1025 |
+
m_block_size = 64
|
| 1026 |
+
n_block_size = 64
|
| 1027 |
+
if head_dim <= 64:
|
| 1028 |
+
num_stages_Q = 2
|
| 1029 |
+
num_stages_dO = 2
|
| 1030 |
+
else:
|
| 1031 |
+
num_stages_Q = 1
|
| 1032 |
+
num_stages_dO = 1
|
| 1033 |
+
SdP_swapAB = False
|
| 1034 |
dKV_swapAB = False
|
| 1035 |
+
dQ_swapAB = False
|
| 1036 |
+
AtomLayoutMSdP = 4
|
| 1037 |
+
AtomLayoutNdKV = 4
|
| 1038 |
+
AtomLayoutMdQ = 4
|
| 1039 |
+
V_in_regs = False
|
| 1040 |
+
cluster_size = 1
|
| 1041 |
+
use_2cta_instrs = False
|
| 1042 |
+
num_threads = 128
|
| 1043 |
+
assert not (block_sparse_tensors is not None), "Block sparsity backward not supported on SM 12.0"
|
| 1044 |
+
assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0"
|
| 1045 |
+
assert mask_mod is None, "mask_mod backward not supported on SM 12.0"
|
| 1046 |
+
assert deterministic is False, "deterministic backward not supported on SM 12.0"
|
| 1047 |
+
elif arch // 10 == 9:
|
| 1048 |
+
cfg = _tile_size_bwd_sm90(
|
| 1049 |
+
head_dim,
|
| 1050 |
+
head_dim_v,
|
| 1051 |
+
causal,
|
| 1052 |
+
local,
|
| 1053 |
+
sparse_block_size_q=sparse_q,
|
| 1054 |
+
)
|
| 1055 |
+
m_block_size = cfg.m_block_size
|
| 1056 |
+
n_block_size = cfg.n_block_size
|
| 1057 |
+
num_stages_Q = cfg.num_stages_Q
|
| 1058 |
+
num_stages_dO = cfg.num_stages_dO
|
| 1059 |
+
num_stages_PdS = cfg.num_stages_PdS
|
| 1060 |
+
SdP_swapAB = cfg.SdP_swapAB
|
| 1061 |
+
dKV_swapAB = cfg.dKV_swapAB
|
| 1062 |
+
dQ_swapAB = cfg.dQ_swapAB
|
| 1063 |
+
AtomLayoutMSdP = cfg.AtomLayoutMSdP
|
| 1064 |
+
AtomLayoutNdKV = cfg.AtomLayoutNdKV
|
| 1065 |
+
AtomLayoutMdQ = cfg.AtomLayoutMdQ
|
| 1066 |
+
num_threads = (cfg.num_wg + 1) * 128
|
| 1067 |
+
dQ_single_wg = cfg.dQ_single_wg
|
| 1068 |
cluster_size = 1
|
| 1069 |
use_2cta_instrs = False
|
|
|
|
| 1070 |
is_varlen = (
|
| 1071 |
cu_seqlens_q is not None
|
| 1072 |
or cu_seqlens_k is not None
|
| 1073 |
or seqused_q is not None
|
| 1074 |
or seqused_k is not None
|
| 1075 |
)
|
|
|
|
| 1076 |
else:
|
| 1077 |
m_block_size = 128
|
| 1078 |
n_block_size = 128
|
|
|
|
| 1080 |
dKV_swapAB = False
|
| 1081 |
AtomLayoutMdQ = 1
|
| 1082 |
AtomLayoutNdKV = 1
|
| 1083 |
+
requested_disable_2cta = utils._get_disable_2cta_default()
|
| 1084 |
disable_2cta = (
|
| 1085 |
+
requested_disable_2cta
|
| 1086 |
or score_mod is not None
|
| 1087 |
or score_mod_bwd is not None
|
| 1088 |
or mask_mod is not None
|
| 1089 |
+
or block_sparse_tensors is not None
|
| 1090 |
)
|
| 1091 |
cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
|
| 1092 |
use_2cta_instrs = cluster_size==2
|
| 1093 |
+
|
| 1094 |
q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
|
| 1095 |
maybe_contiguous(t)
|
| 1096 |
for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
|
|
|
| 1112 |
seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
|
| 1113 |
|
| 1114 |
num_head_kv = k.shape[-2]
|
|
|
|
| 1115 |
|
| 1116 |
use_block_sparsity = block_sparse_tensors is not None
|
| 1117 |
+
subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1118 |
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
|
| 1119 |
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
|
| 1120 |
num_n_blocks = seqlen_k_rounded // n_block_size
|
|
|
|
| 1154 |
if t is not None:
|
| 1155 |
assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
|
| 1156 |
assert lse.dtype == torch.float32, "lse must be float32"
|
| 1157 |
+
if dlse is not None:
|
| 1158 |
+
dlse = maybe_contiguous(dlse)
|
| 1159 |
+
if not is_fake_mode():
|
| 1160 |
+
assert all(
|
| 1161 |
+
t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
|
| 1162 |
+
), "inputs must be on CUDA device"
|
| 1163 |
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
|
|
|
|
| 1164 |
alignment = 16 // q.element_size()
|
| 1165 |
+
if arch // 10 != 12:
|
| 1166 |
+
_validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
|
| 1167 |
if softmax_scale is None:
|
| 1168 |
softmax_scale = 1.0 / math.sqrt(head_dim)
|
| 1169 |
qhead_per_kvhead = num_head // num_head_kv
|
|
|
|
| 1171 |
pack_gqa = qhead_per_kvhead > 1
|
| 1172 |
# pack_gqa backward not yet supported in bwd
|
| 1173 |
pack_gqa = False
|
|
|
|
|
|
|
|
|
|
| 1174 |
if score_mod is not None:
|
| 1175 |
assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
|
| 1176 |
assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
|
|
|
|
| 1222 |
dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
| 1223 |
lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
| 1224 |
|
| 1225 |
+
# GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads
|
| 1226 |
+
# accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses
|
| 1227 |
+
# ragged TMA tensors for direct store, so no longer needs accum+postprocess.
|
| 1228 |
dKV_postprocess = qhead_per_kvhead > 1
|
| 1229 |
if dKV_postprocess:
|
| 1230 |
head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
|
|
|
|
| 1262 |
)
|
| 1263 |
|
| 1264 |
dtype = torch2cute_dtype_map[q.dtype]
|
| 1265 |
+
current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
|
| 1266 |
|
| 1267 |
if deterministic:
|
| 1268 |
+
dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device)
|
| 1269 |
else:
|
| 1270 |
dQ_semaphore = None
|
| 1271 |
|
| 1272 |
if deterministic and qhead_per_kvhead > 1:
|
| 1273 |
+
dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
|
| 1274 |
+
dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
|
| 1275 |
else:
|
| 1276 |
dK_semaphore = None
|
| 1277 |
dV_semaphore = None
|
| 1278 |
|
| 1279 |
+
# Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.
|
| 1280 |
+
_bwd_preprocess(
|
| 1281 |
+
out, dout, dpsum, lse, lse_log2, dq_accum,
|
| 1282 |
+
cu_seqlens_q, seqused_q, dlse,
|
| 1283 |
+
dtype, head_dim, head_dim_v, m_block_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1284 |
)
|
| 1285 |
+
# num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above,
|
| 1286 |
+
# SM100/SM110 uses default from function signature (384).
|
| 1287 |
+
if arch // 10 not in [9, 12]:
|
| 1288 |
+
num_threads = 384
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1289 |
|
| 1290 |
# Backward kernel: compute dk, dv, dq_accum.
|
| 1291 |
score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
|
|
|
|
| 1312 |
subtile_factor=subtile_factor,
|
| 1313 |
)
|
| 1314 |
|
| 1315 |
+
if arch // 10 in [8, 9, 12]:
|
| 1316 |
compile_key = (
|
| 1317 |
arch,
|
| 1318 |
dtype,
|
|
|
|
| 1320 |
head_dim_v,
|
| 1321 |
qhead_per_kvhead,
|
| 1322 |
causal,
|
| 1323 |
+
window_size_left is not None,
|
| 1324 |
+
window_size_right is not None,
|
| 1325 |
softcap != 0.0,
|
| 1326 |
m_block_size,
|
| 1327 |
n_block_size,
|
|
|
|
| 1336 |
AtomLayoutNdKV,
|
| 1337 |
AtomLayoutMdQ,
|
| 1338 |
V_in_regs,
|
| 1339 |
+
dQ_single_wg,
|
| 1340 |
+
deterministic,
|
| 1341 |
cu_seqlens_q is None,
|
| 1342 |
cu_seqlens_k is None,
|
| 1343 |
seqused_q is None,
|
|
|
|
| 1406 |
if t is not None else None
|
| 1407 |
for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
|
| 1408 |
]
|
| 1409 |
+
if arch // 10 in [8, 12]:
|
| 1410 |
+
flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80
|
| 1411 |
+
fa_bwd_obj = flash_bwd_obj_cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1412 |
dtype,
|
| 1413 |
head_dim,
|
| 1414 |
head_dim_v,
|
| 1415 |
qhead_per_kvhead,
|
|
|
|
| 1416 |
m_block_size,
|
| 1417 |
n_block_size,
|
| 1418 |
num_stages_Q,
|
| 1419 |
num_stages_dO,
|
| 1420 |
+
num_threads,
|
| 1421 |
+
pack_gqa,
|
| 1422 |
+
causal,
|
| 1423 |
SdP_swapAB,
|
| 1424 |
dKV_swapAB,
|
| 1425 |
dQ_swapAB,
|
| 1426 |
AtomLayoutMSdP,
|
| 1427 |
AtomLayoutNdKV,
|
| 1428 |
AtomLayoutMdQ,
|
| 1429 |
+
V_in_regs=V_in_regs,
|
| 1430 |
+
)
|
| 1431 |
+
elif arch // 10 == 9:
|
| 1432 |
+
fa_bwd_obj = FlashAttentionBackwardSm90(
|
| 1433 |
+
dtype,
|
| 1434 |
+
head_dim,
|
| 1435 |
+
head_dim_v,
|
| 1436 |
+
qhead_per_kvhead,
|
| 1437 |
+
causal,
|
| 1438 |
+
is_local=local,
|
| 1439 |
+
deterministic=deterministic,
|
| 1440 |
+
tile_m=m_block_size,
|
| 1441 |
+
tile_n=n_block_size,
|
| 1442 |
+
Q_stage=num_stages_Q,
|
| 1443 |
+
dO_stage=num_stages_dO,
|
| 1444 |
+
PdS_stage=num_stages_PdS,
|
| 1445 |
+
SdP_swapAB=SdP_swapAB,
|
| 1446 |
+
dKV_swapAB=dKV_swapAB,
|
| 1447 |
+
dQ_swapAB=dQ_swapAB,
|
| 1448 |
+
AtomLayoutMSdP=AtomLayoutMSdP,
|
| 1449 |
+
AtomLayoutNdKV=AtomLayoutNdKV,
|
| 1450 |
+
AtomLayoutMdQ=AtomLayoutMdQ,
|
| 1451 |
+
num_threads=num_threads,
|
| 1452 |
V_in_regs=V_in_regs,
|
| 1453 |
score_mod=score_mod,
|
| 1454 |
score_mod_bwd=score_mod_bwd,
|
| 1455 |
mask_mod=mask_mod,
|
| 1456 |
has_aux_tensors=aux_tensors is not None,
|
| 1457 |
subtile_factor=subtile_factor,
|
| 1458 |
+
dQ_single_wg=dQ_single_wg,
|
| 1459 |
)
|
| 1460 |
else:
|
| 1461 |
fa_bwd_obj = FlashAttentionBackwardSm100(
|
|
|
|
| 1494 |
dk_tensor if not dKV_postprocess else dk_accum_tensor,
|
| 1495 |
dv_tensor if not dKV_postprocess else dv_accum_tensor,
|
| 1496 |
softmax_scale,
|
|
|
|
| 1497 |
cu_seqlens_q_tensor,
|
| 1498 |
cu_seqlens_k_tensor,
|
| 1499 |
seqused_q_tensor,
|
|
|
|
| 1506 |
dV_semaphore_tensor,
|
| 1507 |
cute_aux_tensors,
|
| 1508 |
sparse_tensors_compile,
|
| 1509 |
+
current_stream,
|
| 1510 |
options="--enable-tvm-ffi",
|
| 1511 |
)
|
| 1512 |
if not is_fake_mode():
|
|
|
|
| 1521 |
dk if not dKV_postprocess else dk_accum,
|
| 1522 |
dv if not dKV_postprocess else dv_accum,
|
| 1523 |
softmax_scale,
|
|
|
|
| 1524 |
cu_seqlens_q,
|
| 1525 |
cu_seqlens_k,
|
| 1526 |
seqused_q,
|
|
|
|
| 1535 |
normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
|
| 1536 |
)
|
| 1537 |
|
| 1538 |
+
if arch // 10 == 9:
|
| 1539 |
+
# dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg
|
| 1540 |
+
num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128
|
| 1541 |
+
num_threads_post_dKV = cfg.num_wg * 128
|
| 1542 |
+
else:
|
| 1543 |
+
num_threads_post_dQ = 128
|
| 1544 |
+
num_threads_post_dKV = 128
|
| 1545 |
+
|
| 1546 |
+
# Postprocess: convert dq_accum from float32 to dq in bf16/fp16
|
| 1547 |
+
_bwd_postprocess_convert(
|
| 1548 |
+
dq_accum, dq, softmax_scale,
|
| 1549 |
+
cu_seqlens_q, seqused_q,
|
| 1550 |
+
arch, dtype, head_dim, m_block_size, num_threads_post_dQ,
|
| 1551 |
+
AtomLayoutMdQ, dQ_swapAB,
|
| 1552 |
+
use_2cta_instrs=use_2cta_instrs, cluster_size=1,
|
|
|
|
| 1553 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1554 |
|
| 1555 |
if dKV_postprocess:
|
| 1556 |
+
# Postprocess: convert dk_accum from float32 to dk in bf16/fp16
|
| 1557 |
+
_bwd_postprocess_convert(
|
| 1558 |
+
dk_accum, dk, softmax_scale,
|
| 1559 |
+
cu_seqlens_k, seqused_k,
|
| 1560 |
+
arch, dtype, head_dim, n_block_size, num_threads_post_dKV,
|
| 1561 |
+
AtomLayoutNdKV, dKV_swapAB,
|
| 1562 |
+
cluster_size=cluster_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1563 |
)
|
| 1564 |
+
# Postprocess: convert dv_accum from float32 to dv in bf16/fp16
|
| 1565 |
+
_bwd_postprocess_convert(
|
| 1566 |
+
dv_accum, dv, 1.0,
|
| 1567 |
+
cu_seqlens_k, seqused_k,
|
| 1568 |
+
arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV,
|
| 1569 |
+
AtomLayoutNdKV, dKV_swapAB,
|
| 1570 |
+
cluster_size=cluster_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1571 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1572 |
|
| 1573 |
return dq, dk, dv
|
| 1574 |
|
| 1575 |
|
|
|
|
| 1576 |
_flash_attn_bwd.compile_cache = get_jit_cache("bwd")
|
|
|
|
| 1577 |
|
| 1578 |
|
| 1579 |
class FlashAttnFunc(torch.autograd.Function):
|
|
|
|
| 1631 |
ctx.window_size = window_size
|
| 1632 |
ctx.softcap = softcap
|
| 1633 |
ctx.deterministic = deterministic
|
| 1634 |
+
ctx.return_lse = return_lse
|
| 1635 |
+
ctx.set_materialize_grads(False)
|
|
|
|
| 1636 |
return out, lse
|
| 1637 |
|
| 1638 |
@staticmethod
|
| 1639 |
+
def backward(ctx, dout, dlse):
|
| 1640 |
q, k, v, out, lse = ctx.saved_tensors
|
| 1641 |
+
if not ctx.return_lse:
|
| 1642 |
+
dlse = None
|
| 1643 |
+
if dout is None:
|
| 1644 |
+
dout = torch.zeros_like(out)
|
| 1645 |
dq, dk, dv = _flash_attn_bwd(
|
| 1646 |
q,
|
| 1647 |
k,
|
|
|
|
| 1655 |
window_size_left=ctx.window_size[0],
|
| 1656 |
window_size_right=ctx.window_size[1],
|
| 1657 |
deterministic=ctx.deterministic,
|
| 1658 |
+
dlse=dlse,
|
| 1659 |
)
|
| 1660 |
return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
|
| 1661 |
|
|
|
|
| 1717 |
ctx.deterministic = deterministic
|
| 1718 |
ctx.max_seqlen_q = max_seqlen_q
|
| 1719 |
ctx.max_seqlen_k = max_seqlen_k
|
| 1720 |
+
ctx.return_lse = return_lse
|
| 1721 |
+
ctx.set_materialize_grads(False)
|
|
|
|
| 1722 |
return out, lse
|
| 1723 |
|
| 1724 |
@staticmethod
|
| 1725 |
+
def backward(ctx, dout, dlse):
|
| 1726 |
q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
|
| 1727 |
assert ctx.softcap == 0.0
|
| 1728 |
+
if not ctx.return_lse:
|
| 1729 |
+
dlse = None
|
| 1730 |
+
if dout is None:
|
| 1731 |
+
dout = torch.zeros_like(out)
|
| 1732 |
dq, dk, dv = _flash_attn_bwd(
|
| 1733 |
q,
|
| 1734 |
k,
|
|
|
|
| 1748 |
max_seqlen_q=ctx.max_seqlen_q,
|
| 1749 |
max_seqlen_k=ctx.max_seqlen_k,
|
| 1750 |
deterministic=ctx.deterministic,
|
| 1751 |
+
dlse=dlse,
|
| 1752 |
)
|
| 1753 |
|
| 1754 |
return dq, dk, dv, *((None,) * 20)
|
|
|
|
| 1844 |
)
|
| 1845 |
|
| 1846 |
|
| 1847 |
+
def _compile_fwd_combine(
|
| 1848 |
+
dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,
|
| 1849 |
+
has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx,
|
| 1850 |
+
):
|
| 1851 |
+
"""Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed)."""
|
| 1852 |
+
sym = cute.sym_int
|
| 1853 |
+
div = 128 // dtype_partial.width # 16-byte alignment in elements
|
| 1854 |
+
|
| 1855 |
+
fa_combine = FlashAttentionForwardCombine(
|
| 1856 |
+
dtype=dtype,
|
| 1857 |
+
dtype_partial=dtype_partial,
|
| 1858 |
+
head_dim=head_dim,
|
| 1859 |
+
tile_m=tile_m,
|
| 1860 |
+
k_block_size=k_block_size,
|
| 1861 |
+
log_max_splits=log_max_splits,
|
| 1862 |
+
)
|
| 1863 |
+
if not fa_combine.can_implement(
|
| 1864 |
+
dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,
|
| 1865 |
+
num_threads=256,
|
| 1866 |
+
):
|
| 1867 |
+
raise RuntimeError(
|
| 1868 |
+
"FlashAttention combine kernel cannot be implemented with given parameters"
|
| 1869 |
+
)
|
| 1870 |
+
|
| 1871 |
+
if has_cu_seqlens:
|
| 1872 |
+
# Varlen: (num_splits, total_q, nheads, headdim)
|
| 1873 |
+
num_splits, total_q, nheads = sym(), sym(), sym()
|
| 1874 |
+
mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div)
|
| 1875 |
+
mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1)
|
| 1876 |
+
mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div)
|
| 1877 |
+
mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None
|
| 1878 |
+
else:
|
| 1879 |
+
# Batched: (num_splits, batch, seqlen, nheads, headdim)
|
| 1880 |
+
num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym()
|
| 1881 |
+
mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div)
|
| 1882 |
+
mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2)
|
| 1883 |
+
mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div)
|
| 1884 |
+
mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None
|
| 1885 |
+
batch = mO_partial.shape[1]
|
| 1886 |
+
|
| 1887 |
+
batch_for_1d = batch if not has_cu_seqlens else sym()
|
| 1888 |
+
batchp1 = sym()
|
| 1889 |
+
mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None
|
| 1890 |
+
mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None
|
| 1891 |
+
mNumSplitsDynamic = None # Not parametrized in compile_key
|
| 1892 |
+
mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None
|
| 1893 |
+
mSemaphore = None # Not parametrized in compile_key
|
| 1894 |
+
|
| 1895 |
+
return cute.compile(
|
| 1896 |
+
fa_combine,
|
| 1897 |
+
mO_partial, mLSE_partial, mO, mLSE,
|
| 1898 |
+
mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore,
|
| 1899 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 1900 |
+
options="--enable-tvm-ffi",
|
| 1901 |
+
)
|
| 1902 |
+
|
| 1903 |
+
|
| 1904 |
def _flash_attn_fwd_combine(
|
| 1905 |
out_partial: torch.Tensor,
|
| 1906 |
lse_partial: torch.Tensor,
|
|
|
|
| 1909 |
cu_seqlens: Optional[torch.Tensor] = None,
|
| 1910 |
seqused: Optional[torch.Tensor] = None,
|
| 1911 |
num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
|
| 1912 |
+
varlen_batch_idx: Optional[torch.Tensor] = None,
|
| 1913 |
semaphore_to_reset: Optional[torch.Tensor] = None,
|
| 1914 |
) -> None:
|
| 1915 |
"""Forward combine kernel for split attention computation.
|
|
|
|
| 1933 |
Returns:
|
| 1934 |
None
|
| 1935 |
"""
|
|
|
|
|
|
|
|
|
|
| 1936 |
assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
|
| 1937 |
"out_partial must be fp16, bf16, or fp32"
|
| 1938 |
)
|
| 1939 |
+
if not is_fake_mode():
|
| 1940 |
+
assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1941 |
# Determine if this is variable length based on dimensions
|
| 1942 |
is_varlen = out_partial.dim() == 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1943 |
# Validate optional tensors
|
| 1944 |
for t, name in [
|
| 1945 |
(cu_seqlens, "cu_seqlens"),
|
|
|
|
| 1947 |
(num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
|
| 1948 |
]:
|
| 1949 |
if t is not None:
|
| 1950 |
+
if not is_fake_mode():
|
| 1951 |
+
assert t.is_cuda, f"{name} must be on CUDA device"
|
| 1952 |
assert t.is_contiguous(), f"{name} must be contiguous"
|
|
|
|
| 1953 |
head_dim = out_partial.shape[-1]
|
| 1954 |
num_splits = out_partial.shape[0]
|
| 1955 |
assert num_splits <= 256
|
|
|
|
| 1958 |
k_block_size = 64 if head_dim <= 64 else 128
|
| 1959 |
# We want kBlockM to be as small as possible to maximize parallelism.
|
| 1960 |
# E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
|
| 1961 |
+
tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
|
| 1962 |
log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
|
| 1963 |
+
if tile_m == 8:
|
| 1964 |
# If kBlockM == 8 then the minimum number of splits is 32.
|
| 1965 |
# TODO: we can deal w this by using 128 threads instead
|
| 1966 |
log_max_splits = max(log_max_splits, 5)
|
| 1967 |
|
|
|
|
|
|
|
| 1968 |
# Create combine kernel configuration
|
| 1969 |
dtype = torch2cute_dtype_map[out.dtype]
|
| 1970 |
dtype_partial = torch2cute_dtype_map[out_partial.dtype]
|
|
|
|
| 1971 |
compile_key = (
|
| 1972 |
dtype,
|
| 1973 |
dtype_partial,
|
| 1974 |
head_dim,
|
| 1975 |
+
tile_m,
|
| 1976 |
k_block_size,
|
| 1977 |
log_max_splits,
|
| 1978 |
cu_seqlens is not None,
|
| 1979 |
seqused is not None,
|
| 1980 |
lse is not None,
|
| 1981 |
+
varlen_batch_idx is not None,
|
| 1982 |
)
|
|
|
|
| 1983 |
if compile_key not in _flash_attn_fwd_combine.compile_cache:
|
| 1984 |
+
_flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine(
|
| 1985 |
+
*compile_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1986 |
)
|
| 1987 |
if not is_fake_mode():
|
| 1988 |
_flash_attn_fwd_combine.compile_cache[compile_key](
|
| 1989 |
+
out_partial, lse_partial, out, lse,
|
| 1990 |
+
cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1991 |
semaphore_to_reset,
|
|
|
|
| 1992 |
)
|
| 1993 |
|
| 1994 |
|
|
|
|
| 2002 |
out_dtype: Optional[torch.dtype] = None,
|
| 2003 |
cu_seqlens: Optional[torch.Tensor] = None,
|
| 2004 |
seqused: Optional[torch.Tensor] = None,
|
| 2005 |
+
varlen_batch_idx: Optional[torch.Tensor] = None,
|
| 2006 |
return_lse: bool = True,
|
| 2007 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 2008 |
"""Flash Attention combine function for split attention computation.
|
|
|
|
| 2022 |
out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input.
|
| 2023 |
cu_seqlens: Cumulative sequence lengths for variable length sequences
|
| 2024 |
seqused: Used sequence lengths for each batch
|
| 2025 |
+
varlen_batch_idx: Optional mapping from virtual batch index to real batch index
|
| 2026 |
+
(int32 tensor of shape (batch_size,)). Used by persistent tile schedulers
|
| 2027 |
+
that reorder batch processing for load balancing.
|
| 2028 |
return_lse: Whether to return the combined LSE tensor. Default is True.
|
| 2029 |
|
| 2030 |
Returns:
|
|
|
|
| 2041 |
"""
|
| 2042 |
# Input validation
|
| 2043 |
assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2044 |
# Determine if this is variable length based on dimensions
|
| 2045 |
is_varlen = out_partial.dim() == 4
|
|
|
|
| 2046 |
if is_varlen:
|
| 2047 |
# Variable length: (num_splits, total_q, num_heads, head_size)
|
| 2048 |
num_splits, total_q, num_heads, head_size = out_partial.shape
|
|
|
|
|
|
|
|
|
|
| 2049 |
batch_size = 1 # Treat as single batch for varlen
|
| 2050 |
seqlen = total_q
|
| 2051 |
else:
|
| 2052 |
# Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
|
| 2053 |
num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2054 |
# Determine output dtype
|
| 2055 |
if out_dtype is None:
|
| 2056 |
out_dtype = out_partial.dtype
|
|
|
|
| 2057 |
# Create output if not provided
|
| 2058 |
device = out_partial.device
|
| 2059 |
if out is None:
|
|
|
|
| 2063 |
out = torch.empty(
|
| 2064 |
batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
|
| 2065 |
)
|
|
|
|
| 2066 |
# Create lse output only if requested
|
| 2067 |
if return_lse:
|
| 2068 |
if is_varlen:
|
| 2069 |
+
lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device)
|
|
|
|
|
|
|
| 2070 |
else:
|
| 2071 |
+
lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device)
|
| 2072 |
+
lse = lse.transpose(-1, -2)
|
|
|
|
| 2073 |
else:
|
| 2074 |
lse = None
|
|
|
|
| 2075 |
_flash_attn_fwd_combine(
|
| 2076 |
out_partial,
|
| 2077 |
lse_partial,
|
|
|
|
| 2079 |
lse,
|
| 2080 |
cu_seqlens,
|
| 2081 |
seqused,
|
| 2082 |
+
varlen_batch_idx=varlen_batch_idx,
|
| 2083 |
)
|
| 2084 |
return out, lse
|
build/torch-cuda/mask.py
CHANGED
|
@@ -1,109 +1,102 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
-
from typing import Optional, Callable
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
-
from cutlass import Float32, Int32, const_expr
|
| 9 |
|
| 10 |
from .quack import layout_utils
|
| 11 |
-
from . import utils
|
| 12 |
from .seqlen_info import SeqlenInfoQK
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
@cute.jit
|
| 16 |
-
def
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
col_limit_transformed = col_limit
|
| 25 |
-
ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
|
| 26 |
-
# Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
|
| 27 |
-
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
|
| 28 |
-
# Don't need to clamp to 32 since the shr.u32 instruction does that already
|
| 29 |
-
col_limit_right_s = max(col_limit_transformed - s * 24, 0)
|
| 30 |
-
# 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
|
| 31 |
-
mask = (1 << col_limit_right_s) - 1
|
| 32 |
-
# This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
|
| 33 |
-
for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
|
| 34 |
-
in_bound = cutlass.Boolean(mask & (1 << i))
|
| 35 |
-
c = s * 24 + i
|
| 36 |
-
if const_expr(rank1):
|
| 37 |
-
X[c] = X[c] if in_bound else -Float32.inf
|
| 38 |
-
# This is the equivalent of:
|
| 39 |
-
# X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf
|
| 40 |
-
else:
|
| 41 |
-
for r in cutlass.range_constexpr(cute.size(X.shape[0])):
|
| 42 |
-
X[r, c] = X[r, c] if in_bound else -Float32.inf
|
| 43 |
|
| 44 |
|
| 45 |
@cute.jit
|
| 46 |
-
def
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
row_limit_top % (num_rep * num_wg), num_rep
|
| 55 |
-
)
|
| 56 |
-
ncol = cute.size(X.shape)
|
| 57 |
-
# Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
|
| 58 |
-
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
|
| 59 |
-
row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
|
| 60 |
-
# 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
|
| 61 |
-
mask = (1 << row_limit_top_s) - 1
|
| 62 |
-
# This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
|
| 63 |
-
for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
|
| 64 |
-
out_bound = cutlass.Boolean(mask & (1 << i))
|
| 65 |
-
c = s * 24 + i
|
| 66 |
-
X[c] = -Float32.inf if out_bound else X[c]
|
| 67 |
-
# tidx = cute.arch.thread_idx()[0] % 256
|
| 68 |
-
# if tidx == 128:
|
| 69 |
-
# cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
|
| 70 |
|
| 71 |
|
| 72 |
@cute.jit
|
| 73 |
-
def
|
| 74 |
X: cute.Tensor,
|
| 75 |
-
|
| 76 |
-
|
| 77 |
) -> None:
|
| 78 |
-
"""
|
| 79 |
-
Dual-bound masking using two bitmasks for SM100, following mask_r2p.
|
| 80 |
-
Masks elements where: NOT (col_limit_left <= col < col_limit_right)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1
|
| 86 |
"""
|
| 87 |
-
ncol = const_expr(cute.size(X.shape))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
|
| 90 |
-
right_s = max(col_limit_right - s * 24, 0)
|
| 91 |
-
left_s = max(col_limit_left - s * 24, 0)
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
@dataclass(frozen=True)
|
|
@@ -161,8 +154,7 @@ class AttentionMask:
|
|
| 161 |
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
|
| 162 |
if const_expr(not mask_causal and not mask_local and mask_mod is None):
|
| 163 |
if const_expr(mask_seqlen):
|
| 164 |
-
|
| 165 |
-
r2p = const_expr(False and not self.swap_AB)
|
| 166 |
if const_expr(not r2p):
|
| 167 |
# traverse column index.
|
| 168 |
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
|
|
@@ -170,7 +162,8 @@ class AttentionMask:
|
|
| 170 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 171 |
acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
|
| 172 |
else:
|
| 173 |
-
|
|
|
|
| 174 |
|
| 175 |
elif const_expr(
|
| 176 |
not mask_causal and not mask_local and mask_mod is not None
|
|
@@ -272,7 +265,12 @@ class AttentionMask:
|
|
| 272 |
else acc_S_mn[r, c]
|
| 273 |
)
|
| 274 |
else:
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
else: # Local
|
| 277 |
local_row_offset_right = (
|
| 278 |
causal_row_offset + self.window_size_right
|
|
@@ -284,6 +282,7 @@ class AttentionMask:
|
|
| 284 |
if const_expr(self.window_size_left is not None)
|
| 285 |
else None
|
| 286 |
)
|
|
|
|
| 287 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 288 |
if const_expr(self.qhead_per_kvhead_packgqa == 1):
|
| 289 |
row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
|
|
@@ -302,13 +301,22 @@ class AttentionMask:
|
|
| 302 |
if const_expr(self.window_size_left is not None)
|
| 303 |
else 0
|
| 304 |
)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
else: # swap_AB
|
| 313 |
assert self.qhead_per_kvhead_packgqa == 1
|
| 314 |
thr_row_offset = tScS_mn[0][ROW]
|
|
@@ -338,11 +346,18 @@ class AttentionMask:
|
|
| 338 |
# column, by setting row limit to be self.tile_m.
|
| 339 |
row_limit_top = (
|
| 340 |
self.tile_m
|
| 341 |
-
if col0 >= seqlenk_col_limit
|
| 342 |
-
else
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
)
|
| 344 |
-
# TODO: do we need col_limit_sink?
|
| 345 |
-
row_limit_bot = col0 - causal_row_offset + self.window_size_left
|
| 346 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 347 |
row_idx = t0ScS_mn[r, 0][ROW]
|
| 348 |
acc_S_mn[r, c] = (
|
|
@@ -392,7 +407,11 @@ class AttentionMask:
|
|
| 392 |
# For some reason the 2 lines above generate really bad SASS
|
| 393 |
acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
|
| 394 |
else:
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
|
| 398 |
# Block sparse case w/ mask_mod
|
|
@@ -445,12 +464,12 @@ class AttentionMask:
|
|
| 445 |
acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
|
| 446 |
|
| 447 |
else: # Causal or local
|
| 448 |
-
causal_row_offset =
|
| 449 |
row_idx = tScS_t2r[0][0] + m_block * self.tile_m
|
| 450 |
if const_expr(self.qhead_per_kvhead_packgqa != 1):
|
| 451 |
row_idx = row_idx // self.qhead_per_kvhead_packgqa
|
| 452 |
if const_expr(mask_causal):
|
| 453 |
-
col_limit_right = row_idx + causal_row_offset
|
| 454 |
if const_expr(mask_seqlen):
|
| 455 |
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
|
| 456 |
# if cute.arch.thread_idx()[0] % 32 == 0:
|
|
@@ -460,15 +479,19 @@ class AttentionMask:
|
|
| 460 |
for i in cutlass.range(ncol, unroll_full=True):
|
| 461 |
acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
|
| 462 |
else:
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
else:
|
| 465 |
local_row_offset_right = (
|
| 466 |
-
causal_row_offset + self.window_size_right
|
| 467 |
if const_expr(self.window_size_right is not None)
|
| 468 |
else None
|
| 469 |
)
|
| 470 |
local_row_offset_left = (
|
| 471 |
-
causal_row_offset -
|
| 472 |
if const_expr(self.window_size_left is not None)
|
| 473 |
else None
|
| 474 |
)
|
|
@@ -493,8 +516,15 @@ class AttentionMask:
|
|
| 493 |
else acc_S[i]
|
| 494 |
)
|
| 495 |
else:
|
| 496 |
-
#
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
@cute.jit
|
| 500 |
def apply_mask_sm100_transposed(
|
|
@@ -634,7 +664,13 @@ class AttentionMask:
|
|
| 634 |
)
|
| 635 |
else:
|
| 636 |
num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
else:
|
| 639 |
if const_expr(self.window_size_right is not None):
|
| 640 |
row_limit_top = causal_offset - self.window_size_right
|
|
@@ -645,9 +681,31 @@ class AttentionMask:
|
|
| 645 |
if const_expr(mask_seqlen):
|
| 646 |
if seqlenk_col_limit <= 0:
|
| 647 |
row_limit_top = self.tile_m
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
local_mask
|
| 653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
+
from typing import Optional, Callable, TypeAlias
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Float32, Int32, Uint32, const_expr
|
| 9 |
|
| 10 |
from .quack import layout_utils
|
| 11 |
+
from . import utils as utils
|
| 12 |
from .seqlen_info import SeqlenInfoQK
|
| 13 |
|
| 14 |
+
MaskGenFn: TypeAlias = Callable[[int], Uint32]
|
| 15 |
+
MASK_R2P_CHUNK_SIZE: int = 32
|
| 16 |
+
|
| 17 |
|
| 18 |
@cute.jit
|
| 19 |
+
def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
|
| 20 |
+
"""32-bit R2P bitmask keeping positions < limit (exclusive upper bound).
|
| 21 |
+
|
| 22 |
+
Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
|
| 23 |
+
Uses inline PTX to avoid shift-by-type-width UB.
|
| 24 |
+
"""
|
| 25 |
+
m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)
|
| 26 |
+
return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
@cute.jit
|
| 30 |
+
def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
|
| 31 |
+
"""32-bit R2P bitmask keeping positions >= limit (inclusive lower bound).
|
| 32 |
+
|
| 33 |
+
Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
|
| 34 |
+
Uses inline PTX to avoid shift-by-type-width UB.
|
| 35 |
+
"""
|
| 36 |
+
n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0)
|
| 37 |
+
return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
@cute.jit
|
| 41 |
+
def mask_r2p_lambda(
|
| 42 |
X: cute.Tensor,
|
| 43 |
+
mask_gen_fn: cutlass.Constexpr[MaskGenFn],
|
| 44 |
+
rank1: bool = False,
|
| 45 |
) -> None:
|
| 46 |
+
"""Apply R2P masking with a custom bitmask generator.
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
mask_gen_fn(chunk_idx: constexpr int) -> Uint32:
|
| 49 |
+
Returns a 32-bit bitmask for the chunk. Bit i set means column
|
| 50 |
+
chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf.
|
|
|
|
| 51 |
"""
|
| 52 |
+
ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
|
| 53 |
+
# 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep).
|
| 54 |
+
CHUNK_SIZE = MASK_R2P_CHUNK_SIZE
|
| 55 |
+
for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)):
|
| 56 |
+
mask = mask_gen_fn(s)
|
| 57 |
+
# This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
|
| 58 |
+
for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)):
|
| 59 |
+
in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
|
| 60 |
+
c = s * CHUNK_SIZE + i
|
| 61 |
+
if const_expr(rank1):
|
| 62 |
+
X[c] = X[c] if in_bound else -Float32.inf
|
| 63 |
+
else:
|
| 64 |
+
for r in cutlass.range_constexpr(cute.size(X.shape[0])):
|
| 65 |
+
X[r, c] = X[r, c] if in_bound else -Float32.inf
|
| 66 |
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
@cute.jit
|
| 69 |
+
def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:
|
| 70 |
+
"""Transform SM90 MMA column coordinate to R2P element index.
|
| 71 |
|
| 72 |
+
SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ...
|
| 73 |
+
Element indices are contiguous: 0, 1, 2, 3, 4, 5, ...
|
| 74 |
+
This converts a column-space threshold to element-space for r2p_bitmask_below/above.
|
| 75 |
+
"""
|
| 76 |
+
return col_limit // 8 * 2 + min(col_limit % 8, 2)
|
| 77 |
|
| 78 |
+
|
| 79 |
+
@cute.jit
|
| 80 |
+
def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
|
| 81 |
+
"""Convert a row coordinate to an R2P element index in the warp-group interleaved layout.
|
| 82 |
+
|
| 83 |
+
In the SM100 backward pass, 2 warp groups share TMEM. The TMEM load atom
|
| 84 |
+
distributes rows in an interleaved pattern: elements 0..num_rep-1 map to
|
| 85 |
+
rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to
|
| 86 |
+
rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on.
|
| 87 |
+
Row-coordinate thresholds (causal limits, window bounds, uih_len) must be
|
| 88 |
+
converted to element indices before use with r2p_bitmask_above/below.
|
| 89 |
+
|
| 90 |
+
Rows not owned by this thread (in the gap between warp groups) are clamped
|
| 91 |
+
to the boundary element index, which is safe because R2P thresholds are
|
| 92 |
+
monotonic.
|
| 93 |
+
|
| 94 |
+
Example with num_rep=16, num_wg=2:
|
| 95 |
+
row 0 -> elem 0, row 15 -> elem 15,
|
| 96 |
+
row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped),
|
| 97 |
+
row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31.
|
| 98 |
+
"""
|
| 99 |
+
return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep)
|
| 100 |
|
| 101 |
|
| 102 |
@dataclass(frozen=True)
|
|
|
|
| 154 |
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
|
| 155 |
if const_expr(not mask_causal and not mask_local and mask_mod is None):
|
| 156 |
if const_expr(mask_seqlen):
|
| 157 |
+
r2p = const_expr(not self.swap_AB)
|
|
|
|
| 158 |
if const_expr(not r2p):
|
| 159 |
# traverse column index.
|
| 160 |
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
|
|
|
|
| 162 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 163 |
acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
|
| 164 |
else:
|
| 165 |
+
seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit)
|
| 166 |
+
mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s))
|
| 167 |
|
| 168 |
elif const_expr(
|
| 169 |
not mask_causal and not mask_local and mask_mod is not None
|
|
|
|
| 265 |
else acc_S_mn[r, c]
|
| 266 |
)
|
| 267 |
else:
|
| 268 |
+
col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right)
|
| 269 |
+
mask_r2p_lambda(
|
| 270 |
+
acc_S_mn[r, None],
|
| 271 |
+
lambda s: r2p_bitmask_below(col_limit_r2p, s),
|
| 272 |
+
rank1=True,
|
| 273 |
+
)
|
| 274 |
else: # Local
|
| 275 |
local_row_offset_right = (
|
| 276 |
causal_row_offset + self.window_size_right
|
|
|
|
| 282 |
if const_expr(self.window_size_left is not None)
|
| 283 |
else None
|
| 284 |
)
|
| 285 |
+
r2p_local = const_expr(not self.swap_AB)
|
| 286 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 287 |
if const_expr(self.qhead_per_kvhead_packgqa == 1):
|
| 288 |
row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
|
|
|
|
| 301 |
if const_expr(self.window_size_left is not None)
|
| 302 |
else 0
|
| 303 |
)
|
| 304 |
+
if const_expr(not r2p_local):
|
| 305 |
+
# traverse column index.
|
| 306 |
+
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
|
| 307 |
+
col_idx = t0ScS_mn[0, c][1]
|
| 308 |
+
if col_idx >= col_limit_right or col_idx < col_limit_left:
|
| 309 |
+
acc_S_mn[r, c] = -Float32.inf
|
| 310 |
+
else:
|
| 311 |
+
col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right)
|
| 312 |
+
col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left)
|
| 313 |
+
|
| 314 |
+
def mask_gen_fn(s: int) -> Uint32:
|
| 315 |
+
return r2p_bitmask_below(
|
| 316 |
+
col_limit_right_r2p, s
|
| 317 |
+
) & r2p_bitmask_above(col_limit_left_r2p, s)
|
| 318 |
+
|
| 319 |
+
mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True)
|
| 320 |
else: # swap_AB
|
| 321 |
assert self.qhead_per_kvhead_packgqa == 1
|
| 322 |
thr_row_offset = tScS_mn[0][ROW]
|
|
|
|
| 346 |
# column, by setting row limit to be self.tile_m.
|
| 347 |
row_limit_top = (
|
| 348 |
self.tile_m
|
| 349 |
+
if col0 >= seqlenk_col_limit and mask_seqlen
|
| 350 |
+
else (
|
| 351 |
+
col0 - causal_row_offset - self.window_size_right
|
| 352 |
+
if const_expr(self.window_size_right is not None)
|
| 353 |
+
else 0
|
| 354 |
+
)
|
| 355 |
+
)
|
| 356 |
+
row_limit_bot = (
|
| 357 |
+
col0 - causal_row_offset + self.window_size_left
|
| 358 |
+
if const_expr(self.window_size_left is not None)
|
| 359 |
+
else self.tile_m
|
| 360 |
)
|
|
|
|
|
|
|
| 361 |
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
|
| 362 |
row_idx = t0ScS_mn[r, 0][ROW]
|
| 363 |
acc_S_mn[r, c] = (
|
|
|
|
| 407 |
# For some reason the 2 lines above generate really bad SASS
|
| 408 |
acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
|
| 409 |
else:
|
| 410 |
+
mask_r2p_lambda(
|
| 411 |
+
acc_S,
|
| 412 |
+
lambda s: r2p_bitmask_below(seqlenk_col_limit, s),
|
| 413 |
+
rank1=True,
|
| 414 |
+
)
|
| 415 |
|
| 416 |
elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
|
| 417 |
# Block sparse case w/ mask_mod
|
|
|
|
| 464 |
acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
|
| 465 |
|
| 466 |
else: # Causal or local
|
| 467 |
+
causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q
|
| 468 |
row_idx = tScS_t2r[0][0] + m_block * self.tile_m
|
| 469 |
if const_expr(self.qhead_per_kvhead_packgqa != 1):
|
| 470 |
row_idx = row_idx // self.qhead_per_kvhead_packgqa
|
| 471 |
if const_expr(mask_causal):
|
| 472 |
+
col_limit_right = row_idx + causal_row_offset + 1
|
| 473 |
if const_expr(mask_seqlen):
|
| 474 |
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
|
| 475 |
# if cute.arch.thread_idx()[0] % 32 == 0:
|
|
|
|
| 479 |
for i in cutlass.range(ncol, unroll_full=True):
|
| 480 |
acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
|
| 481 |
else:
|
| 482 |
+
mask_r2p_lambda(
|
| 483 |
+
acc_S,
|
| 484 |
+
lambda s: r2p_bitmask_below(col_limit_right, s),
|
| 485 |
+
rank1=True,
|
| 486 |
+
)
|
| 487 |
else:
|
| 488 |
local_row_offset_right = (
|
| 489 |
+
causal_row_offset + 1 + self.window_size_right
|
| 490 |
if const_expr(self.window_size_right is not None)
|
| 491 |
else None
|
| 492 |
)
|
| 493 |
local_row_offset_left = (
|
| 494 |
+
causal_row_offset - self.window_size_left
|
| 495 |
if const_expr(self.window_size_left is not None)
|
| 496 |
else None
|
| 497 |
)
|
|
|
|
| 516 |
else acc_S[i]
|
| 517 |
)
|
| 518 |
else:
|
| 519 |
+
# Dual-bound R2P masking for SM100.
|
| 520 |
+
# Masks elements where: NOT (col_limit_left <= col < col_limit_right)
|
| 521 |
+
|
| 522 |
+
def mask_gen_fn(s: int) -> Uint32:
|
| 523 |
+
return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above(
|
| 524 |
+
col_limit_left, s
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True)
|
| 528 |
|
| 529 |
@cute.jit
|
| 530 |
def apply_mask_sm100_transposed(
|
|
|
|
| 664 |
)
|
| 665 |
else:
|
| 666 |
num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
|
| 667 |
+
num_wg = 2
|
| 668 |
+
row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
|
| 669 |
+
mask_r2p_lambda(
|
| 670 |
+
acc_S,
|
| 671 |
+
lambda s: r2p_bitmask_above(row_limit, s),
|
| 672 |
+
rank1=True,
|
| 673 |
+
)
|
| 674 |
else:
|
| 675 |
if const_expr(self.window_size_right is not None):
|
| 676 |
row_limit_top = causal_offset - self.window_size_right
|
|
|
|
| 681 |
if const_expr(mask_seqlen):
|
| 682 |
if seqlenk_col_limit <= 0:
|
| 683 |
row_limit_top = self.tile_m
|
| 684 |
+
r2p = True
|
| 685 |
+
if const_expr(not r2p):
|
| 686 |
+
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
|
| 687 |
+
row_idx = t0ScS_t2r[i][ROW]
|
| 688 |
+
local_mask = row_idx < row_limit_top
|
| 689 |
+
if const_expr(self.window_size_left is not None):
|
| 690 |
+
local_mask |= row_idx > row_limit_bot
|
| 691 |
+
acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
|
| 692 |
+
else:
|
| 693 |
+
|
| 694 |
+
def mask_gen_fn(s: int) -> Uint32:
|
| 695 |
+
num_rep = cute.size(tScS_t2r, mode=[0])
|
| 696 |
+
num_wg = 2
|
| 697 |
+
|
| 698 |
+
row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
|
| 699 |
+
mask = r2p_bitmask_above(row_limit, s)
|
| 700 |
+
|
| 701 |
+
if const_expr(self.window_size_left is not None):
|
| 702 |
+
row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg)
|
| 703 |
+
mask = mask & r2p_bitmask_below(row_limit_bottom, s)
|
| 704 |
+
|
| 705 |
+
return mask
|
| 706 |
+
|
| 707 |
+
mask_r2p_lambda(
|
| 708 |
+
acc_S,
|
| 709 |
+
mask_gen_fn,
|
| 710 |
+
rank1=True,
|
| 711 |
+
)
|
build/torch-cuda/named_barrier.py
CHANGED
|
@@ -12,6 +12,19 @@ class NamedBarrierFwd(enum.IntEnum):
|
|
| 12 |
PEmpty = enum.auto()
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
class NamedBarrierBwd(enum.IntEnum):
|
| 16 |
Epilogue = enum.auto()
|
| 17 |
WarpSchedulerWG1 = enum.auto()
|
|
@@ -20,8 +33,10 @@ class NamedBarrierBwd(enum.IntEnum):
|
|
| 20 |
PdS = enum.auto()
|
| 21 |
dQFullWG0 = enum.auto()
|
| 22 |
dQFullWG1 = enum.auto()
|
|
|
|
| 23 |
dQEmptyWG0 = enum.auto()
|
| 24 |
dQEmptyWG1 = enum.auto()
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class NamedBarrierBwdSm100(enum.IntEnum):
|
|
|
|
| 12 |
PEmpty = enum.auto()
|
| 13 |
|
| 14 |
|
| 15 |
+
class NamedBarrierFwdSm100(enum.IntEnum):
|
| 16 |
+
Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
|
| 17 |
+
TmemPtr = enum.auto()
|
| 18 |
+
SoftmaxStatsW0 = enum.auto()
|
| 19 |
+
SoftmaxStatsW1 = enum.auto()
|
| 20 |
+
SoftmaxStatsW2 = enum.auto()
|
| 21 |
+
SoftmaxStatsW3 = enum.auto()
|
| 22 |
+
SoftmaxStatsW4 = enum.auto()
|
| 23 |
+
SoftmaxStatsW5 = enum.auto()
|
| 24 |
+
SoftmaxStatsW6 = enum.auto()
|
| 25 |
+
SoftmaxStatsW7 = enum.auto()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
class NamedBarrierBwd(enum.IntEnum):
|
| 29 |
Epilogue = enum.auto()
|
| 30 |
WarpSchedulerWG1 = enum.auto()
|
|
|
|
| 33 |
PdS = enum.auto()
|
| 34 |
dQFullWG0 = enum.auto()
|
| 35 |
dQFullWG1 = enum.auto()
|
| 36 |
+
dQFullWG2 = enum.auto()
|
| 37 |
dQEmptyWG0 = enum.auto()
|
| 38 |
dQEmptyWG1 = enum.auto()
|
| 39 |
+
dQEmptyWG2 = enum.auto()
|
| 40 |
|
| 41 |
|
| 42 |
class NamedBarrierBwdSm100(enum.IntEnum):
|
build/torch-cuda/pack_gqa.py
CHANGED
|
@@ -1,25 +1,123 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import cutlass
|
| 5 |
import cutlass.cute as cute
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from .quack import layout_utils
|
| 8 |
-
from . import utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
class PackGQA:
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
check_hdim_oob: cutlass.Constexpr[bool],
|
| 17 |
-
qhead_per_kvhead: cutlass.Constexpr[bool],
|
| 18 |
-
):
|
| 19 |
-
self.m_block_size = m_block_size
|
| 20 |
-
self.head_dim_padded = head_dim_padded
|
| 21 |
-
self.check_hdim_oob = check_hdim_oob
|
| 22 |
-
self.qhead_per_kvhead = qhead_per_kvhead
|
| 23 |
|
| 24 |
@cute.jit
|
| 25 |
def compute_ptr(
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Union, Tuple
|
| 5 |
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
+
from cutlass.cute.nvgpu import cpasync
|
| 9 |
+
|
| 10 |
|
| 11 |
from .quack import layout_utils
|
| 12 |
+
from . import utils as utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):
|
| 16 |
+
"""Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0).
|
| 17 |
+
|
| 18 |
+
The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
|
| 19 |
+
are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
|
| 20 |
+
as-is (e.g. batch).
|
| 21 |
+
|
| 22 |
+
For Q/O tensors (head_idx=2):
|
| 23 |
+
(seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...)
|
| 24 |
+
For LSE tensors (head_idx=1):
|
| 25 |
+
(seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...)
|
| 26 |
+
"""
|
| 27 |
+
head_stride = T.stride[head_idx]
|
| 28 |
+
shape_packed = (
|
| 29 |
+
(qhead_per_kvhead, T.shape[0]),
|
| 30 |
+
*[T.shape[i] for i in range(1, head_idx)],
|
| 31 |
+
nheads_kv,
|
| 32 |
+
*[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
|
| 33 |
+
)
|
| 34 |
+
stride_packed = (
|
| 35 |
+
(head_stride, T.stride[0]),
|
| 36 |
+
*[T.stride[i] for i in range(1, head_idx)],
|
| 37 |
+
head_stride * qhead_per_kvhead,
|
| 38 |
+
*[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
|
| 39 |
+
)
|
| 40 |
+
return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def make_packgqa_tiled_tma_atom(
|
| 44 |
+
op: cute.atom.CopyOp,
|
| 45 |
+
gmem_tensor: cute.Tensor,
|
| 46 |
+
smem_layout: Union[cute.Layout, cute.ComposedLayout],
|
| 47 |
+
cta_tiler: Tuple[int, int],
|
| 48 |
+
qhead_per_kvhead: int,
|
| 49 |
+
head_idx: int,
|
| 50 |
+
):
|
| 51 |
+
# This packing and unpacking of the layout is so that we keep the same TMA dimension as usual.
|
| 52 |
+
# e.g. for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to
|
| 53 |
+
# ((nheads, seqlen), d, b).
|
| 54 |
+
# If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA.
|
| 55 |
+
# Pack headdim and seqlen dim into 1: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b)
|
| 56 |
+
gmem_tensor = layout_utils.select(
|
| 57 |
+
gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))]
|
| 58 |
+
)
|
| 59 |
+
gmem_tensor = cute.group_modes(gmem_tensor, 0, 2)
|
| 60 |
+
assert cta_tiler[0] % qhead_per_kvhead == 0, (
|
| 61 |
+
"CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead"
|
| 62 |
+
)
|
| 63 |
+
tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
|
| 64 |
+
op,
|
| 65 |
+
gmem_tensor,
|
| 66 |
+
smem_layout,
|
| 67 |
+
((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]), # No mcast
|
| 68 |
+
)
|
| 69 |
+
# Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b)
|
| 70 |
+
T = tma_tensor
|
| 71 |
+
shape_packed = (
|
| 72 |
+
(qhead_per_kvhead, T.shape[0][1]),
|
| 73 |
+
*[T.shape[i] for i in range(1, head_idx)],
|
| 74 |
+
T.shape[0][0] // qhead_per_kvhead,
|
| 75 |
+
*[T.shape[i] for i in range(head_idx, len(T.shape))],
|
| 76 |
+
)
|
| 77 |
+
stride_packed = (
|
| 78 |
+
*[T.stride[i] for i in range(head_idx)],
|
| 79 |
+
T.stride[0][0] * qhead_per_kvhead,
|
| 80 |
+
*[T.stride[i] for i in range(head_idx, len(T.shape))],
|
| 81 |
+
)
|
| 82 |
+
tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))
|
| 83 |
+
return tma_atom, tma_tensor
|
| 84 |
|
| 85 |
|
| 86 |
+
def unpack_gqa_layout(T, qhead_per_kvhead, head_idx):
|
| 87 |
+
"""Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0).
|
| 88 |
+
|
| 89 |
+
The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
|
| 90 |
+
are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
|
| 91 |
+
as-is (e.g. batch).
|
| 92 |
+
|
| 93 |
+
For Q/O tensors (head_idx=2):
|
| 94 |
+
((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...)
|
| 95 |
+
For LSE tensors (head_idx=1):
|
| 96 |
+
((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)
|
| 97 |
+
"""
|
| 98 |
+
seqlen_stride = T.stride[0][1]
|
| 99 |
+
head_stride = T.stride[0][0]
|
| 100 |
+
shape_unpacked = (
|
| 101 |
+
T.shape[0][1],
|
| 102 |
+
*[T.shape[i] for i in range(1, head_idx)],
|
| 103 |
+
T.shape[head_idx] * qhead_per_kvhead,
|
| 104 |
+
*[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
|
| 105 |
+
)
|
| 106 |
+
stride_unpacked = (
|
| 107 |
+
seqlen_stride,
|
| 108 |
+
*[T.stride[i] for i in range(1, head_idx)],
|
| 109 |
+
head_stride,
|
| 110 |
+
*[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
|
| 111 |
+
)
|
| 112 |
+
return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
|
| 116 |
class PackGQA:
|
| 117 |
+
m_block_size: cutlass.Constexpr[int]
|
| 118 |
+
head_dim_padded: cutlass.Constexpr[int]
|
| 119 |
+
check_hdim_oob: cutlass.Constexpr[bool]
|
| 120 |
+
qhead_per_kvhead: cutlass.Constexpr[bool]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
@cute.jit
|
| 123 |
def compute_ptr(
|
build/torch-cuda/paged_kv.py
CHANGED
|
@@ -28,6 +28,9 @@ class PagedKVManager(ParamsBase):
|
|
| 28 |
head_dim_padded: cutlass.Constexpr[Int32]
|
| 29 |
head_dim_v_padded: cutlass.Constexpr[Int32]
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
gmem_threads_per_row: cutlass.Constexpr[Int32]
|
| 32 |
page_entry_per_thread: Int32
|
| 33 |
async_copy_elems: Int32
|
|
@@ -55,7 +58,11 @@ class PagedKVManager(ParamsBase):
|
|
| 55 |
head_dim_v_padded: cutlass.Constexpr[Int32],
|
| 56 |
num_threads: cutlass.Constexpr[Int32],
|
| 57 |
dtype: Type[cutlass.Numeric],
|
|
|
|
| 58 |
):
|
|
|
|
|
|
|
|
|
|
| 59 |
universal_copy_bits = 128
|
| 60 |
async_copy_elems = universal_copy_bits // dtype.width
|
| 61 |
dtype_bytes = dtype.width // 8
|
|
@@ -97,7 +104,8 @@ class PagedKVManager(ParamsBase):
|
|
| 97 |
else:
|
| 98 |
cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
|
| 99 |
tVcV = gmem_thr_copy_KV.partition_S(cV)
|
| 100 |
-
|
|
|
|
| 101 |
|
| 102 |
return PagedKVManager(
|
| 103 |
mPageTable,
|
|
@@ -111,6 +119,8 @@ class PagedKVManager(ParamsBase):
|
|
| 111 |
num_threads,
|
| 112 |
head_dim_padded,
|
| 113 |
head_dim_v_padded,
|
|
|
|
|
|
|
| 114 |
gmem_threads_per_row,
|
| 115 |
page_entry_per_thread,
|
| 116 |
async_copy_elems,
|
|
@@ -146,13 +156,17 @@ class PagedKVManager(ParamsBase):
|
|
| 146 |
@cute.jit
|
| 147 |
def compute_X_ptr(self, K_or_V: str):
|
| 148 |
tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
|
| 150 |
page = self.tPrPage[i]
|
| 151 |
page_offset = self.tPrPageOffset[i]
|
| 152 |
-
if const_expr(
|
| 153 |
-
tPrXPtr[i] = utils.elem_pointer(
|
| 154 |
else:
|
| 155 |
-
tPrXPtr[i] = utils.elem_pointer(
|
| 156 |
return tPrXPtr
|
| 157 |
|
| 158 |
@cute.jit
|
|
@@ -161,18 +175,24 @@ class PagedKVManager(ParamsBase):
|
|
| 161 |
|
| 162 |
tPrXPtr = self.compute_X_ptr(K_or_V)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
cute.
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
|
| 177 |
head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
|
| 178 |
cX = cute.make_identity_tensor((self.n_block_size, head_dim))
|
|
|
|
| 28 |
head_dim_padded: cutlass.Constexpr[Int32]
|
| 29 |
head_dim_v_padded: cutlass.Constexpr[Int32]
|
| 30 |
|
| 31 |
+
arch: cutlass.Constexpr[Int32]
|
| 32 |
+
v_gmem_transposed: cutlass.Constexpr[bool]
|
| 33 |
+
|
| 34 |
gmem_threads_per_row: cutlass.Constexpr[Int32]
|
| 35 |
page_entry_per_thread: Int32
|
| 36 |
async_copy_elems: Int32
|
|
|
|
| 58 |
head_dim_v_padded: cutlass.Constexpr[Int32],
|
| 59 |
num_threads: cutlass.Constexpr[Int32],
|
| 60 |
dtype: Type[cutlass.Numeric],
|
| 61 |
+
arch: cutlass.Constexpr[int] = 100,
|
| 62 |
):
|
| 63 |
+
# SM100 transposes V in gmem to (dv, page_size, num_pages);
|
| 64 |
+
# SM90 keeps V as (page_size, dv, num_pages), same layout as K.
|
| 65 |
+
v_gmem_transposed = arch != 90
|
| 66 |
universal_copy_bits = 128
|
| 67 |
async_copy_elems = universal_copy_bits // dtype.width
|
| 68 |
dtype_bytes = dtype.width // 8
|
|
|
|
| 104 |
else:
|
| 105 |
cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
|
| 106 |
tVcV = gmem_thr_copy_KV.partition_S(cV)
|
| 107 |
+
# When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K)
|
| 108 |
+
tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1])
|
| 109 |
|
| 110 |
return PagedKVManager(
|
| 111 |
mPageTable,
|
|
|
|
| 119 |
num_threads,
|
| 120 |
head_dim_padded,
|
| 121 |
head_dim_v_padded,
|
| 122 |
+
arch,
|
| 123 |
+
v_gmem_transposed,
|
| 124 |
gmem_threads_per_row,
|
| 125 |
page_entry_per_thread,
|
| 126 |
async_copy_elems,
|
|
|
|
| 156 |
@cute.jit
|
| 157 |
def compute_X_ptr(self, K_or_V: str):
|
| 158 |
tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
|
| 159 |
+
mX = self.mK_paged if const_expr(K_or_V == "K") else self.mV_paged
|
| 160 |
+
# K is always (page_size, d, num_pages). V matches K when not transposed,
|
| 161 |
+
# but is (dv, page_size, num_pages) when transposed (SM100).
|
| 162 |
+
transposed = const_expr(K_or_V == "V" and self.v_gmem_transposed)
|
| 163 |
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
|
| 164 |
page = self.tPrPage[i]
|
| 165 |
page_offset = self.tPrPageOffset[i]
|
| 166 |
+
if const_expr(transposed):
|
| 167 |
+
tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint()
|
| 168 |
else:
|
| 169 |
+
tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint()
|
| 170 |
return tPrXPtr
|
| 171 |
|
| 172 |
@cute.jit
|
|
|
|
| 175 |
|
| 176 |
tPrXPtr = self.compute_X_ptr(K_or_V)
|
| 177 |
|
| 178 |
+
if const_expr(self.arch == 90):
|
| 179 |
+
# SM90: sX is already stage-sliced by caller (sK[None, None, stage]).
|
| 180 |
+
# Flatten hierarchical modes to get (n_block_size, head_dim).
|
| 181 |
+
sX_pi = cute.group_modes(sX, 0, 1)
|
| 182 |
+
# SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA)
|
| 183 |
+
else:
|
| 184 |
+
# SM100: Finesse sX layout to be (M, N).
|
| 185 |
+
sX_pi = cute.make_tensor(
|
| 186 |
+
sX.iterator,
|
| 187 |
+
cute.make_layout(
|
| 188 |
+
(sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
|
| 189 |
+
stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
|
| 190 |
+
),
|
| 191 |
+
)
|
| 192 |
|
| 193 |
+
if const_expr(K_or_V == "V"):
|
| 194 |
+
# Transpose smem V to match transposed gmem layout
|
| 195 |
+
sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))
|
| 196 |
|
| 197 |
head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
|
| 198 |
cX = cute.make_identity_tensor((self.n_block_size, head_dim))
|
build/torch-cuda/pipeline.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
-
# import math
|
| 4 |
from typing import Optional
|
| 5 |
from dataclasses import dataclass
|
| 6 |
|
|
@@ -11,12 +10,31 @@ from cutlass.pipeline import PipelineState
|
|
| 11 |
from cutlass.pipeline import PipelineUserType
|
| 12 |
from cutlass.pipeline import NamedBarrier as NamedBarrierOg
|
| 13 |
from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
|
|
|
|
| 14 |
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
| 15 |
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
| 16 |
from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
|
| 17 |
from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class PipelineStateSimple:
|
| 21 |
"""
|
| 22 |
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
|
|
@@ -25,9 +43,6 @@ class PipelineStateSimple:
|
|
| 25 |
"""
|
| 26 |
|
| 27 |
def __init__(self, stages: int, phase_index: Int32):
|
| 28 |
-
# assert stages < 2**16
|
| 29 |
-
# self._log_stages = int(math.log2(stages))
|
| 30 |
-
# assert 1 << self._log_stages == stages, "Number of stages must be a power of 2."
|
| 31 |
self._stages = stages
|
| 32 |
self._phase_index = phase_index
|
| 33 |
|
|
@@ -36,13 +51,10 @@ class PipelineStateSimple:
|
|
| 36 |
|
| 37 |
@property
|
| 38 |
def stages(self) -> int:
|
| 39 |
-
# return 1 << self._log_stages
|
| 40 |
return self._stages
|
| 41 |
|
| 42 |
@property
|
| 43 |
def index(self) -> Int32:
|
| 44 |
-
# return self._phase_index & 0xFFFF
|
| 45 |
-
# return self._phase_index & ((1 << self._log_stages) - 1)
|
| 46 |
if const_expr(self._stages == 1):
|
| 47 |
return Int32(0)
|
| 48 |
else:
|
|
@@ -50,11 +62,8 @@ class PipelineStateSimple:
|
|
| 50 |
|
| 51 |
@property
|
| 52 |
def phase(self) -> Int32:
|
| 53 |
-
# return self._phase_index >> 16
|
| 54 |
# PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
|
| 55 |
# take modulo 2. But in practice just passing the phase in without modulo works fine.
|
| 56 |
-
# return (self._phase_index >> self._log_stages) % 2
|
| 57 |
-
# return self._phase_index >> self._log_stages
|
| 58 |
if const_expr(self._stages == 1):
|
| 59 |
return self._phase_index
|
| 60 |
else:
|
|
@@ -66,21 +75,6 @@ class PipelineStateSimple:
|
|
| 66 |
else:
|
| 67 |
self._phase_index += 1
|
| 68 |
|
| 69 |
-
# def then_body(phase_index):
|
| 70 |
-
# # XOR the phase bit and set the index to 0
|
| 71 |
-
# return (phase_index & 0xFFFF0000) ^ (1 << 16)
|
| 72 |
-
|
| 73 |
-
# def else_body(phase_index):
|
| 74 |
-
# return phase_index
|
| 75 |
-
|
| 76 |
-
# self._phase_index = if_generate(
|
| 77 |
-
# (self._phase_index & 0xFFFF) == self.stages,
|
| 78 |
-
# then_body,
|
| 79 |
-
# else_body,
|
| 80 |
-
# [self._phase_index],
|
| 81 |
-
# [Int32],
|
| 82 |
-
# )
|
| 83 |
-
|
| 84 |
def __extract_mlir_values__(self):
|
| 85 |
phase_index = self._phase_index
|
| 86 |
return [phase_index.ir_value()]
|
|
@@ -94,7 +88,6 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
|
|
| 94 |
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
| 95 |
"""
|
| 96 |
if type is PipelineUserType.Producer:
|
| 97 |
-
# return PipelineStateSimple(stages, Int32(1 << 16))
|
| 98 |
return PipelineStateSimple(stages, Int32(stages))
|
| 99 |
elif type is PipelineUserType.Consumer:
|
| 100 |
return PipelineStateSimple(stages, Int32(0))
|
|
@@ -102,14 +95,73 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
|
|
| 102 |
assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
@dataclass(frozen=True)
|
| 106 |
class NamedBarrier(NamedBarrierOg):
|
| 107 |
-
|
| 108 |
-
def create(*args, **kwargs):
|
| 109 |
-
obj = NamedBarrierOg.create(*args, **kwargs)
|
| 110 |
-
# Can't assign to __class__ directly since the dataclass is frozen
|
| 111 |
-
object.__setattr__(obj, "__class__", NamedBarrier)
|
| 112 |
-
return obj
|
| 113 |
|
| 114 |
@dsl_user_op
|
| 115 |
def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
|
@@ -134,72 +186,121 @@ class NamedBarrier(NamedBarrierOg):
|
|
| 134 |
)
|
| 135 |
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
@dataclass(frozen=True)
|
| 138 |
-
class PipelineAsync(PipelineAsyncOg):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
@staticmethod
|
| 140 |
-
def create(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
obj = PipelineAsyncOg.create(*args, **kwargs)
|
| 142 |
-
# Can't assign to __class__ directly since the dataclass is frozen
|
| 143 |
-
# obj.__class__ = PipelineAsync
|
| 144 |
object.__setattr__(obj, "__class__", PipelineAsync)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
return obj
|
| 146 |
|
| 147 |
@dsl_user_op
|
| 148 |
-
def
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
if_generate(
|
| 158 |
-
try_acquire_token is None or try_acquire_token == 0,
|
| 159 |
-
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 160 |
-
loc=loc,
|
| 161 |
-
ip=ip,
|
| 162 |
)
|
| 163 |
|
| 164 |
@dsl_user_op
|
| 165 |
-
def
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
*,
|
| 175 |
-
loc=None,
|
| 176 |
-
ip=None,
|
| 177 |
-
):
|
| 178 |
-
if_generate(
|
| 179 |
-
try_wait_token is None or try_wait_token == 0,
|
| 180 |
-
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 181 |
-
loc=loc,
|
| 182 |
-
ip=ip,
|
| 183 |
)
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
@dataclass(frozen=True)
|
| 191 |
-
class
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
"""
|
| 195 |
|
| 196 |
@staticmethod
|
| 197 |
-
def create(
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
return obj
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
@dsl_user_op
|
| 204 |
def producer_acquire(
|
| 205 |
self,
|
|
@@ -226,19 +327,15 @@ class PipelineTmaAsync(PipelineTmaAsyncOg):
|
|
| 226 |
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
|
| 227 |
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
"""
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
# obj.__class__ = PipelineTmaUmma
|
| 240 |
-
object.__setattr__(obj, "__class__", PipelineTmaUmma)
|
| 241 |
-
return obj
|
| 242 |
|
| 243 |
@dsl_user_op
|
| 244 |
def producer_acquire(
|
|
@@ -279,162 +376,27 @@ class PipelineTmaUmma(PipelineTmaUmmaOg):
|
|
| 279 |
ip=ip,
|
| 280 |
)
|
| 281 |
|
| 282 |
-
@dsl_user_op
|
| 283 |
-
def producer_acquire_w_index_phase(
|
| 284 |
-
self,
|
| 285 |
-
index: Int32,
|
| 286 |
-
phase: Int32,
|
| 287 |
-
try_acquire_token: Optional[Boolean] = None,
|
| 288 |
-
*,
|
| 289 |
-
loc=None,
|
| 290 |
-
ip=None,
|
| 291 |
-
):
|
| 292 |
-
"""
|
| 293 |
-
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 294 |
-
"""
|
| 295 |
-
if_generate(
|
| 296 |
-
try_acquire_token is None or try_acquire_token == 0,
|
| 297 |
-
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 298 |
-
loc=loc,
|
| 299 |
-
ip=ip,
|
| 300 |
-
)
|
| 301 |
-
if_generate(
|
| 302 |
-
self.is_leader_cta,
|
| 303 |
-
lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),
|
| 304 |
-
loc=loc,
|
| 305 |
-
ip=ip,
|
| 306 |
-
)
|
| 307 |
|
| 308 |
-
|
| 309 |
-
def consumer_wait_w_index_phase(
|
| 310 |
-
self,
|
| 311 |
-
index: Int32,
|
| 312 |
-
phase: Int32,
|
| 313 |
-
try_wait_token: Optional[Boolean] = None,
|
| 314 |
-
*,
|
| 315 |
-
loc=None,
|
| 316 |
-
ip=None,
|
| 317 |
-
):
|
| 318 |
-
if_generate(
|
| 319 |
-
try_wait_token is None or try_wait_token == 0,
|
| 320 |
-
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 321 |
-
loc=loc,
|
| 322 |
-
ip=ip,
|
| 323 |
-
)
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
"""
|
| 328 |
-
UMMA consumer release buffer empty, cta_group needs to be provided.
|
| 329 |
-
"""
|
| 330 |
-
self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
| 331 |
|
| 332 |
|
| 333 |
@dataclass(frozen=True)
|
| 334 |
-
class PipelineUmmaAsync(PipelineUmmaAsyncOg):
|
| 335 |
-
|
| 336 |
-
def create(*args, **kwargs):
|
| 337 |
-
obj = PipelineUmmaAsyncOg.create(*args, **kwargs)
|
| 338 |
-
# Can't assign to __class__ directly since the dataclass is frozen
|
| 339 |
-
object.__setattr__(obj, "__class__", PipelineUmmaAsync)
|
| 340 |
-
return obj
|
| 341 |
|
| 342 |
-
@dsl_user_op
|
| 343 |
-
def producer_acquire_w_index_phase(
|
| 344 |
-
self,
|
| 345 |
-
index: Int32,
|
| 346 |
-
phase: Int32,
|
| 347 |
-
try_acquire_token: Optional[Boolean] = None,
|
| 348 |
-
*,
|
| 349 |
-
loc=None,
|
| 350 |
-
ip=None,
|
| 351 |
-
):
|
| 352 |
-
if_generate(
|
| 353 |
-
try_acquire_token is None or try_acquire_token == 0,
|
| 354 |
-
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 355 |
-
loc=loc,
|
| 356 |
-
ip=ip,
|
| 357 |
-
)
|
| 358 |
|
| 359 |
-
|
| 360 |
-
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 361 |
-
"""
|
| 362 |
-
UMMA producer commit buffer full, cta_group needs to be provided.
|
| 363 |
-
"""
|
| 364 |
-
self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)
|
| 365 |
|
| 366 |
-
@dsl_user_op
|
| 367 |
-
def consumer_wait_w_index_phase(
|
| 368 |
-
self,
|
| 369 |
-
index: Int32,
|
| 370 |
-
phase: Int32,
|
| 371 |
-
try_wait_token: Optional[Boolean] = None,
|
| 372 |
-
*,
|
| 373 |
-
loc=None,
|
| 374 |
-
ip=None,
|
| 375 |
-
):
|
| 376 |
-
if_generate(
|
| 377 |
-
try_wait_token is None or try_wait_token == 0,
|
| 378 |
-
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 379 |
-
loc=loc,
|
| 380 |
-
ip=ip,
|
| 381 |
-
)
|
| 382 |
|
| 383 |
-
|
| 384 |
-
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 385 |
-
self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
|
| 386 |
|
| 387 |
|
| 388 |
@dataclass(frozen=True)
|
| 389 |
-
class PipelineAsyncUmma(PipelineAsyncUmmaOg):
|
| 390 |
-
|
| 391 |
-
def create(*args, **kwargs):
|
| 392 |
-
obj = PipelineAsyncUmmaOg.create(*args, **kwargs)
|
| 393 |
-
# Can't assign to __class__ directly since the dataclass is frozen
|
| 394 |
-
object.__setattr__(obj, "__class__", PipelineAsyncUmma)
|
| 395 |
-
return obj
|
| 396 |
|
| 397 |
-
@dsl_user_op
|
| 398 |
-
def producer_acquire_w_index_phase(
|
| 399 |
-
self,
|
| 400 |
-
index: Int32,
|
| 401 |
-
phase: Int32,
|
| 402 |
-
try_acquire_token: Optional[Boolean] = None,
|
| 403 |
-
*,
|
| 404 |
-
loc=None,
|
| 405 |
-
ip=None,
|
| 406 |
-
):
|
| 407 |
-
if_generate(
|
| 408 |
-
try_acquire_token is None or try_acquire_token == 0,
|
| 409 |
-
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 410 |
-
loc=loc,
|
| 411 |
-
ip=ip,
|
| 412 |
-
)
|
| 413 |
-
|
| 414 |
-
@dsl_user_op
|
| 415 |
-
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 416 |
-
self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)
|
| 417 |
-
|
| 418 |
-
@dsl_user_op
|
| 419 |
-
def consumer_wait_w_index_phase(
|
| 420 |
-
self,
|
| 421 |
-
index: Int32,
|
| 422 |
-
phase: Int32,
|
| 423 |
-
try_wait_token: Optional[Boolean] = None,
|
| 424 |
-
*,
|
| 425 |
-
loc=None,
|
| 426 |
-
ip=None,
|
| 427 |
-
):
|
| 428 |
-
if_generate(
|
| 429 |
-
try_wait_token is None or try_wait_token == 0,
|
| 430 |
-
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 431 |
-
loc=loc,
|
| 432 |
-
ip=ip,
|
| 433 |
-
)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 437 |
-
"""
|
| 438 |
-
UMMA consumer release buffer empty, cta_group needs to be provided.
|
| 439 |
-
"""
|
| 440 |
-
self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
|
|
|
| 3 |
from typing import Optional
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
|
|
|
| 10 |
from cutlass.pipeline import PipelineUserType
|
| 11 |
from cutlass.pipeline import NamedBarrier as NamedBarrierOg
|
| 12 |
from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
|
| 13 |
+
from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg
|
| 14 |
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
| 15 |
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
| 16 |
from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
|
| 17 |
from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
|
| 18 |
|
| 19 |
|
| 20 |
+
def _override_create(parent_cls, child_cls):
|
| 21 |
+
"""Create a static factory that constructs parent_cls then re-classes to child_cls."""
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def create(*args, **kwargs):
|
| 25 |
+
obj = parent_cls.create(*args, **kwargs)
|
| 26 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 27 |
+
object.__setattr__(obj, "__class__", child_cls)
|
| 28 |
+
return obj
|
| 29 |
+
|
| 30 |
+
return create
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _make_state(index: Int32, phase: Int32) -> PipelineState:
|
| 34 |
+
"""Construct a PipelineState from index and phase (count/stages unused by callers)."""
|
| 35 |
+
return PipelineState(stages=0, count=Int32(0), index=index, phase=phase)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
class PipelineStateSimple:
|
| 39 |
"""
|
| 40 |
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
def __init__(self, stages: int, phase_index: Int32):
|
|
|
|
|
|
|
|
|
|
| 46 |
self._stages = stages
|
| 47 |
self._phase_index = phase_index
|
| 48 |
|
|
|
|
| 51 |
|
| 52 |
@property
|
| 53 |
def stages(self) -> int:
|
|
|
|
| 54 |
return self._stages
|
| 55 |
|
| 56 |
@property
|
| 57 |
def index(self) -> Int32:
|
|
|
|
|
|
|
| 58 |
if const_expr(self._stages == 1):
|
| 59 |
return Int32(0)
|
| 60 |
else:
|
|
|
|
| 62 |
|
| 63 |
@property
|
| 64 |
def phase(self) -> Int32:
|
|
|
|
| 65 |
# PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
|
| 66 |
# take modulo 2. But in practice just passing the phase in without modulo works fine.
|
|
|
|
|
|
|
| 67 |
if const_expr(self._stages == 1):
|
| 68 |
return self._phase_index
|
| 69 |
else:
|
|
|
|
| 75 |
else:
|
| 76 |
self._phase_index += 1
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def __extract_mlir_values__(self):
|
| 79 |
phase_index = self._phase_index
|
| 80 |
return [phase_index.ir_value()]
|
|
|
|
| 88 |
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
| 89 |
"""
|
| 90 |
if type is PipelineUserType.Producer:
|
|
|
|
| 91 |
return PipelineStateSimple(stages, Int32(stages))
|
| 92 |
elif type is PipelineUserType.Consumer:
|
| 93 |
return PipelineStateSimple(stages, Int32(0))
|
|
|
|
| 95 |
assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
|
| 96 |
|
| 97 |
|
| 98 |
+
# ── Shared helpers ───────────────────────────────────────────────────────────
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip):
|
| 102 |
+
"""Optionally wrap a parent pipeline method call in sync_warp + elect_one."""
|
| 103 |
+
if const_expr(elect_one):
|
| 104 |
+
if const_expr(syncwarp):
|
| 105 |
+
cute.arch.sync_warp()
|
| 106 |
+
with cute.arch.elect_one():
|
| 107 |
+
parent_method(self, state, loc=loc, ip=ip)
|
| 108 |
+
else:
|
| 109 |
+
parent_method(self, state, loc=loc, ip=ip)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ── Mixin: _w_index / _w_index_phase variants that delegate to parent ───────
|
| 113 |
+
# Each parent class has PipelineState-based methods (producer_acquire, producer_commit,
|
| 114 |
+
# consumer_wait, consumer_release). The _w_index_phase variants just construct a
|
| 115 |
+
# PipelineState from (index, phase) and delegate.
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class _PipelineIndexPhaseMixin:
|
| 119 |
+
"""Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents."""
|
| 120 |
+
|
| 121 |
+
@dsl_user_op
|
| 122 |
+
def producer_acquire_w_index_phase(
|
| 123 |
+
self,
|
| 124 |
+
index: Int32,
|
| 125 |
+
phase: Int32,
|
| 126 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 127 |
+
*,
|
| 128 |
+
loc=None,
|
| 129 |
+
ip=None,
|
| 130 |
+
):
|
| 131 |
+
state = _make_state(index, phase)
|
| 132 |
+
# Call the parent's producer_acquire (which takes PipelineState)
|
| 133 |
+
self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip)
|
| 134 |
+
|
| 135 |
+
@dsl_user_op
|
| 136 |
+
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 137 |
+
state = _make_state(index, Int32(0))
|
| 138 |
+
self.producer_commit(state, loc=loc, ip=ip)
|
| 139 |
+
|
| 140 |
+
@dsl_user_op
|
| 141 |
+
def consumer_wait_w_index_phase(
|
| 142 |
+
self,
|
| 143 |
+
index: Int32,
|
| 144 |
+
phase: Int32,
|
| 145 |
+
try_wait_token: Optional[Boolean] = None,
|
| 146 |
+
*,
|
| 147 |
+
loc=None,
|
| 148 |
+
ip=None,
|
| 149 |
+
):
|
| 150 |
+
state = _make_state(index, phase)
|
| 151 |
+
self.consumer_wait(state, try_wait_token, loc=loc, ip=ip)
|
| 152 |
+
|
| 153 |
+
@dsl_user_op
|
| 154 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 155 |
+
state = _make_state(index, Int32(0))
|
| 156 |
+
self.consumer_release(state, loc=loc, ip=ip)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ── NamedBarrier ─────────────────────────────────────────────────────────────
|
| 160 |
+
|
| 161 |
+
|
| 162 |
@dataclass(frozen=True)
|
| 163 |
class NamedBarrier(NamedBarrierOg):
|
| 164 |
+
create = _override_create(NamedBarrierOg, None) # patched below
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
@dsl_user_op
|
| 167 |
def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
|
| 189 |
+
NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ── PipelineAsync ────────────────────────────────────────────────────────────
|
| 193 |
+
|
| 194 |
+
|
| 195 |
@dataclass(frozen=True)
|
| 196 |
+
class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg):
|
| 197 |
+
"""
|
| 198 |
+
PipelineAsync with optional elect_one for producer_commit and consumer_release.
|
| 199 |
+
|
| 200 |
+
When elect_one_*=True (set at create time), only one elected thread per warp
|
| 201 |
+
signals the barrier arrive. This is useful when the mask count is set to 1 per warp.
|
| 202 |
+
|
| 203 |
+
Args (to create):
|
| 204 |
+
elect_one_commit: If True, only elected thread signals producer_commit.
|
| 205 |
+
syncwarp_before_commit: If True (default), issue syncwarp before elect_one.
|
| 206 |
+
elect_one_release: If True, only elected thread signals consumer_release.
|
| 207 |
+
syncwarp_before_release: If True (default), issue syncwarp before elect_one.
|
| 208 |
+
Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group).
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
_elect_one_commit: bool = False
|
| 212 |
+
_syncwarp_before_commit: bool = True
|
| 213 |
+
_elect_one_release: bool = False
|
| 214 |
+
_syncwarp_before_release: bool = True
|
| 215 |
+
|
| 216 |
@staticmethod
|
| 217 |
+
def create(
|
| 218 |
+
*args,
|
| 219 |
+
elect_one_commit: bool = False,
|
| 220 |
+
syncwarp_before_commit: bool = True,
|
| 221 |
+
elect_one_release: bool = False,
|
| 222 |
+
syncwarp_before_release: bool = True,
|
| 223 |
+
**kwargs,
|
| 224 |
+
):
|
| 225 |
obj = PipelineAsyncOg.create(*args, **kwargs)
|
|
|
|
|
|
|
| 226 |
object.__setattr__(obj, "__class__", PipelineAsync)
|
| 227 |
+
object.__setattr__(obj, "_elect_one_commit", elect_one_commit)
|
| 228 |
+
object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit)
|
| 229 |
+
object.__setattr__(obj, "_elect_one_release", elect_one_release)
|
| 230 |
+
object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
|
| 231 |
return obj
|
| 232 |
|
| 233 |
@dsl_user_op
|
| 234 |
+
def producer_commit(self, state: PipelineState, *, loc=None, ip=None):
|
| 235 |
+
_call_with_elect_one(
|
| 236 |
+
PipelineAsyncOg.producer_commit,
|
| 237 |
+
self,
|
| 238 |
+
state,
|
| 239 |
+
self._elect_one_commit,
|
| 240 |
+
self._syncwarp_before_commit,
|
| 241 |
+
loc,
|
| 242 |
+
ip,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
)
|
| 244 |
|
| 245 |
@dsl_user_op
|
| 246 |
+
def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
|
| 247 |
+
_call_with_elect_one(
|
| 248 |
+
PipelineAsyncOg.consumer_release,
|
| 249 |
+
self,
|
| 250 |
+
state,
|
| 251 |
+
self._elect_one_release,
|
| 252 |
+
self._syncwarp_before_release,
|
| 253 |
+
loc,
|
| 254 |
+
ip,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
)
|
| 256 |
|
| 257 |
+
# _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate
|
| 258 |
+
# to producer_commit / consumer_release above.
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ── PipelineCpAsync ──────────────────────────────────────────────────────────
|
| 262 |
|
| 263 |
|
| 264 |
@dataclass(frozen=True)
|
| 265 |
+
class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg):
|
| 266 |
+
_elect_one_release: bool = False
|
| 267 |
+
_syncwarp_before_release: bool = True
|
|
|
|
| 268 |
|
| 269 |
@staticmethod
|
| 270 |
+
def create(
|
| 271 |
+
*args,
|
| 272 |
+
elect_one_release: bool = False,
|
| 273 |
+
syncwarp_before_release: bool = True,
|
| 274 |
+
**kwargs,
|
| 275 |
+
):
|
| 276 |
+
obj = PipelineCpAsyncOg.create(*args, **kwargs)
|
| 277 |
+
object.__setattr__(obj, "__class__", PipelineCpAsync)
|
| 278 |
+
object.__setattr__(obj, "_elect_one_release", elect_one_release)
|
| 279 |
+
object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
|
| 280 |
return obj
|
| 281 |
|
| 282 |
+
@dsl_user_op
|
| 283 |
+
def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
|
| 284 |
+
_call_with_elect_one(
|
| 285 |
+
PipelineCpAsyncOg.consumer_release,
|
| 286 |
+
self,
|
| 287 |
+
state,
|
| 288 |
+
self._elect_one_release,
|
| 289 |
+
self._syncwarp_before_release,
|
| 290 |
+
loc,
|
| 291 |
+
ip,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# _w_index variants inherited from _PipelineIndexPhaseMixin.
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# ── PipelineTmaAsync ────────────────────────────────────────────────────────
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@dataclass(frozen=True)
|
| 301 |
+
class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
|
| 302 |
+
"""Override producer_acquire to take in extra_tx_count parameter."""
|
| 303 |
+
|
| 304 |
@dsl_user_op
|
| 305 |
def producer_acquire(
|
| 306 |
self,
|
|
|
|
| 327 |
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
|
| 328 |
|
| 329 |
|
| 330 |
+
PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# ── PipelineTmaUmma ─────────────────────────────────────────────────────────
|
|
|
|
| 334 |
|
| 335 |
+
|
| 336 |
+
@dataclass(frozen=True)
|
| 337 |
+
class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg):
|
| 338 |
+
"""Override producer_acquire to take in extra_tx_count parameter."""
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
@dsl_user_op
|
| 341 |
def producer_acquire(
|
|
|
|
| 376 |
ip=ip,
|
| 377 |
)
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
+
PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
|
| 383 |
+
# ── PipelineUmmaAsync ───────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
|
| 386 |
@dataclass(frozen=True)
|
| 387 |
+
class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg):
|
| 388 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
+
PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
+
# ── PipelineAsyncUmma ───────────────────────────────────────────────────────
|
|
|
|
|
|
|
| 395 |
|
| 396 |
|
| 397 |
@dataclass(frozen=True)
|
| 398 |
+
class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg):
|
| 399 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
+
PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/copy_utils.py
CHANGED
|
@@ -15,6 +15,9 @@ from cutlass._mlir.dialects import llvm
|
|
| 15 |
from cutlass._mlir import ir
|
| 16 |
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
Sm100MmaPeerBitMask = 0xFEFFFFFF
|
| 20 |
|
|
@@ -41,6 +44,30 @@ def cvt_copy(
|
|
| 41 |
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
@dsl_user_op
|
| 45 |
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 46 |
dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
|
|
@@ -796,17 +823,17 @@ def gather_m_get_copy_fn(
|
|
| 796 |
limit_m: Int32,
|
| 797 |
limit_k: Int32,
|
| 798 |
) -> Callable:
|
| 799 |
-
|
| 800 |
-
tAsA =
|
| 801 |
# k-major
|
| 802 |
assert tAsA.shape[2] == 1
|
| 803 |
tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
|
| 804 |
|
| 805 |
-
is_even_m_smem =
|
| 806 |
if const_expr(not is_even_m_smem):
|
| 807 |
-
limit_m = min(limit_m,
|
| 808 |
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 809 |
-
cA = cute.make_identity_tensor(
|
| 810 |
tAcA = thr_copy_A.partition_S(cA)
|
| 811 |
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 812 |
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
|
@@ -828,13 +855,13 @@ def gather_m_get_copy_fn(
|
|
| 828 |
else:
|
| 829 |
m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
|
| 830 |
|
| 831 |
-
mA_k = cute.logical_divide(mA, (None,
|
| 832 |
|
| 833 |
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
| 834 |
tApA_k = None
|
| 835 |
if const_expr(pred):
|
| 836 |
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 837 |
-
limit_k_cur = limit_k - src_idx *
|
| 838 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 839 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 840 |
mA_cur = mA_k[None, (None, src_idx)]
|
|
@@ -997,11 +1024,162 @@ def gather_m_get_tma_copy_fn(
|
|
| 997 |
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 998 |
|
| 999 |
def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
|
|
|
|
| 1000 |
col_idx = tile_K * src_idx
|
| 1001 |
for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 1002 |
row_indices = [tSR_rAIdx[v, m] for v in range(4)]
|
| 1003 |
-
smem_ptr =
|
| 1004 |
with cute.arch.elect_one():
|
| 1005 |
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
|
| 1006 |
|
| 1007 |
return copy_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from cutlass._mlir import ir
|
| 16 |
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
|
| 18 |
+
from . import layout_utils
|
| 19 |
+
from .utils import make_vector
|
| 20 |
+
|
| 21 |
|
| 22 |
Sm100MmaPeerBitMask = 0xFEFFFFFF
|
| 23 |
|
|
|
|
| 44 |
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 45 |
|
| 46 |
|
| 47 |
+
@dsl_user_op
|
| 48 |
+
def sr_cvt_copy(
|
| 49 |
+
tiled_copy: cute.TiledCopy,
|
| 50 |
+
src: cute.Tensor,
|
| 51 |
+
dst: cute.Tensor,
|
| 52 |
+
seed: Int32,
|
| 53 |
+
tidx: Int32,
|
| 54 |
+
*,
|
| 55 |
+
loc=None,
|
| 56 |
+
ip=None,
|
| 57 |
+
) -> None:
|
| 58 |
+
"""Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion."""
|
| 59 |
+
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 60 |
+
from .rounding import convert_f32_to_bf16_sr
|
| 61 |
+
from cutlass.cute.tensor import TensorSSA
|
| 62 |
+
|
| 63 |
+
src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
|
| 64 |
+
src_vec = src.load()
|
| 65 |
+
raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip)
|
| 66 |
+
src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type))
|
| 67 |
+
src = src_cvt
|
| 68 |
+
cute.copy(tiled_copy, src, dst, loc=loc, ip=ip)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
@dsl_user_op
|
| 72 |
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 73 |
dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
|
|
|
|
| 823 |
limit_m: Int32,
|
| 824 |
limit_k: Int32,
|
| 825 |
) -> Callable:
|
| 826 |
+
tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1])
|
| 827 |
+
tAsA = partition_D_position_independent(thr_copy_A, sA)
|
| 828 |
# k-major
|
| 829 |
assert tAsA.shape[2] == 1
|
| 830 |
tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
|
| 831 |
|
| 832 |
+
is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0
|
| 833 |
if const_expr(not is_even_m_smem):
|
| 834 |
+
limit_m = min(limit_m, tile_M)
|
| 835 |
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 836 |
+
cA = cute.make_identity_tensor((tile_M, tile_K))
|
| 837 |
tAcA = thr_copy_A.partition_S(cA)
|
| 838 |
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 839 |
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
|
|
|
| 855 |
else:
|
| 856 |
m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
|
| 857 |
|
| 858 |
+
mA_k = cute.logical_divide(mA, (None, tile_K))
|
| 859 |
|
| 860 |
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
| 861 |
tApA_k = None
|
| 862 |
if const_expr(pred):
|
| 863 |
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 864 |
+
limit_k_cur = limit_k - src_idx * tile_K
|
| 865 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 866 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 867 |
mA_cur = mA_k[None, (None, src_idx)]
|
|
|
|
| 1024 |
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 1025 |
|
| 1026 |
def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
|
| 1027 |
+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
|
| 1028 |
col_idx = tile_K * src_idx
|
| 1029 |
for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 1030 |
row_indices = [tSR_rAIdx[v, m] for v in range(4)]
|
| 1031 |
+
smem_ptr = tSR_sA_cur[None, m, None].iterator
|
| 1032 |
with cute.arch.elect_one():
|
| 1033 |
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
|
| 1034 |
|
| 1035 |
return copy_fn
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
@cute.jit
|
| 1039 |
+
def gather_k_get_tma_copy_fn(
|
| 1040 |
+
tma_atom: cute.CopyAtom,
|
| 1041 |
+
sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout
|
| 1042 |
+
sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem
|
| 1043 |
+
col_idx: Int32, # M offset in global tensor (contiguous dim for M-major)
|
| 1044 |
+
warp_idx: Int32,
|
| 1045 |
+
num_warps: int,
|
| 1046 |
+
num_cta: int = 1,
|
| 1047 |
+
) -> Tuple[Callable, Callable]:
|
| 1048 |
+
"""Build a copy function for TMA gather4 in K dimension (M-major A).
|
| 1049 |
+
|
| 1050 |
+
Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements.
|
| 1051 |
+
col_idx is the absolute M position in the global tensor.
|
| 1052 |
+
K indices come from sAIdx (prefetched to smem by the scheduler warp).
|
| 1053 |
+
|
| 1054 |
+
Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which:
|
| 1055 |
+
Issues gather4 calls with those K indices as row_indices
|
| 1056 |
+
"""
|
| 1057 |
+
tile_K = cute.size(sAIdx, mode=[0])
|
| 1058 |
+
assert tile_K % 4 == 0
|
| 1059 |
+
cta_group = num_cta
|
| 1060 |
+
|
| 1061 |
+
# Tiled copy for loading K indices from smem to registers (4 per vector, across warps)
|
| 1062 |
+
copy_AIdx_s2r = cute.make_tiled_copy_tv(
|
| 1063 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
|
| 1064 |
+
cute.make_layout(num_warps), # thr_layout
|
| 1065 |
+
cute.make_layout(4), # val_layout — 4 K indices per gather4
|
| 1066 |
+
)
|
| 1067 |
+
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
| 1068 |
+
warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
|
| 1069 |
+
tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4))
|
| 1070 |
+
# ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192))
|
| 1071 |
+
tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA))
|
| 1072 |
+
tma_desc_ptr = get_tma_desc_addr(tma_atom)
|
| 1073 |
+
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 1074 |
+
|
| 1075 |
+
def prefetch_from_smem_fn(
|
| 1076 |
+
a_prefetch_pipeline,
|
| 1077 |
+
src_idx,
|
| 1078 |
+
dst_idx,
|
| 1079 |
+
a_prefetch_consumer_state,
|
| 1080 |
+
) -> cute.Tensor:
|
| 1081 |
+
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
| 1082 |
+
tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx])
|
| 1083 |
+
cute.arch.sync_warp()
|
| 1084 |
+
with cute.arch.elect_one():
|
| 1085 |
+
a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
|
| 1086 |
+
return tSR_rAIdx
|
| 1087 |
+
|
| 1088 |
+
def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer):
|
| 1089 |
+
# Issue gather4: col_idx = M position, row_indices = 4 K positions
|
| 1090 |
+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
|
| 1091 |
+
gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64
|
| 1092 |
+
for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 1093 |
+
row_indices = [tSR_rAIdx[v, k] for v in range(4)]
|
| 1094 |
+
for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True):
|
| 1095 |
+
smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator
|
| 1096 |
+
with cute.arch.elect_one():
|
| 1097 |
+
tma_gather4_load_fn(
|
| 1098 |
+
smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
return copy_fn, prefetch_from_smem_fn
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
# ---------------------------------------------------------------------------
|
| 1105 |
+
# Store helpers
|
| 1106 |
+
# ---------------------------------------------------------------------------
|
| 1107 |
+
|
| 1108 |
+
|
| 1109 |
+
@dsl_user_op
|
| 1110 |
+
@cute.jit
|
| 1111 |
+
def store(
|
| 1112 |
+
ptr: cute.Pointer,
|
| 1113 |
+
val,
|
| 1114 |
+
pred: Optional[Boolean] = None,
|
| 1115 |
+
cop: cutlass.Constexpr = None,
|
| 1116 |
+
*,
|
| 1117 |
+
loc=None,
|
| 1118 |
+
ip=None,
|
| 1119 |
+
):
|
| 1120 |
+
"""Store a scalar value via cute.arch.store.
|
| 1121 |
+
|
| 1122 |
+
ptr: cute.Pointer (any address space).
|
| 1123 |
+
val: DSL Numeric value.
|
| 1124 |
+
pred: None → unconditional. DSL Boolean → skipped when pred == 0.
|
| 1125 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1126 |
+
"""
|
| 1127 |
+
if const_expr(pred is None):
|
| 1128 |
+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
|
| 1129 |
+
else:
|
| 1130 |
+
if pred:
|
| 1131 |
+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
@dsl_user_op
|
| 1135 |
+
@cute.jit
|
| 1136 |
+
def store_v2(
|
| 1137 |
+
ptr: cute.Pointer,
|
| 1138 |
+
v0,
|
| 1139 |
+
v1,
|
| 1140 |
+
pred: Optional[Boolean] = None,
|
| 1141 |
+
cop: cutlass.Constexpr = None,
|
| 1142 |
+
*,
|
| 1143 |
+
loc=None,
|
| 1144 |
+
ip=None,
|
| 1145 |
+
):
|
| 1146 |
+
"""Vectorized store of 2 elements via cute.arch.store.
|
| 1147 |
+
|
| 1148 |
+
Packs v0, v1 into an MLIR <2 x T> vector.
|
| 1149 |
+
ptr: cute.Pointer (any address space, must be aligned for vector width).
|
| 1150 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1151 |
+
"""
|
| 1152 |
+
vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
|
| 1153 |
+
if const_expr(pred is None):
|
| 1154 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1155 |
+
else:
|
| 1156 |
+
if pred:
|
| 1157 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
@dsl_user_op
|
| 1161 |
+
@cute.jit
|
| 1162 |
+
def store_v4(
|
| 1163 |
+
ptr: cute.Pointer,
|
| 1164 |
+
v0,
|
| 1165 |
+
v1,
|
| 1166 |
+
v2,
|
| 1167 |
+
v3,
|
| 1168 |
+
pred: Optional[Boolean] = None,
|
| 1169 |
+
cop: cutlass.Constexpr = None,
|
| 1170 |
+
*,
|
| 1171 |
+
loc=None,
|
| 1172 |
+
ip=None,
|
| 1173 |
+
):
|
| 1174 |
+
"""Vectorized store of 4 elements via cute.arch.store.
|
| 1175 |
+
|
| 1176 |
+
Packs v0–v3 into an MLIR <4 x T> vector.
|
| 1177 |
+
ptr: cute.Pointer (any address space, must be aligned for vector width).
|
| 1178 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1179 |
+
"""
|
| 1180 |
+
vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
|
| 1181 |
+
if const_expr(pred is None):
|
| 1182 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1183 |
+
else:
|
| 1184 |
+
if pred:
|
| 1185 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
build/torch-cuda/quack/cute_dsl_utils.py
CHANGED
|
@@ -4,6 +4,9 @@ from typing import Tuple, get_origin
|
|
| 4 |
from functools import lru_cache
|
| 5 |
from dataclasses import dataclass, fields
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
|
| 9 |
try:
|
|
@@ -14,7 +17,6 @@ except ImportError:
|
|
| 14 |
import cutlass
|
| 15 |
import cutlass.cute as cute
|
| 16 |
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 17 |
-
from cutlass.base_dsl.typing import JitArgument
|
| 18 |
from cutlass.base_dsl.tvm_ffi_builder import spec
|
| 19 |
from cutlass.cutlass_dsl import NumericMeta
|
| 20 |
|
|
@@ -65,8 +67,25 @@ def get_max_active_clusters(cluster_size):
|
|
| 65 |
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 66 |
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
@lru_cache
|
| 69 |
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
return torch.cuda.get_device_capability(device)
|
| 71 |
|
| 72 |
|
|
@@ -138,28 +157,3 @@ class ParamsBase:
|
|
| 138 |
return values
|
| 139 |
|
| 140 |
__new_from_mlir_values__ = _new_from_mlir_values
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
@dataclass
|
| 144 |
-
class ArgumentsBase(JitArgument):
|
| 145 |
-
def __c_pointers__(self):
|
| 146 |
-
_, non_constexpr_fields = _partition_fields(self)
|
| 147 |
-
c_ptrs = []
|
| 148 |
-
for obj in non_constexpr_fields.values():
|
| 149 |
-
if hasattr(obj, "__c_pointers__"):
|
| 150 |
-
c_ptrs.extend(obj.__c_pointers__())
|
| 151 |
-
return c_ptrs
|
| 152 |
-
|
| 153 |
-
def __get_mlir_types__(self):
|
| 154 |
-
_, non_constexpr_fields = _partition_fields(self)
|
| 155 |
-
types, self._values_pos = [], []
|
| 156 |
-
for obj in non_constexpr_fields.values():
|
| 157 |
-
if hasattr(obj, "__get_mlir_types__"):
|
| 158 |
-
obj_types = obj.__get_mlir_types__()
|
| 159 |
-
types.extend(obj_types)
|
| 160 |
-
self._values_pos.append(len(obj_types))
|
| 161 |
-
else:
|
| 162 |
-
self._values_pos.append(0)
|
| 163 |
-
return types
|
| 164 |
-
|
| 165 |
-
__new_from_mlir_values__ = _new_from_mlir_values
|
|
|
|
| 4 |
from functools import lru_cache
|
| 5 |
from dataclasses import dataclass, fields
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
import torch
|
| 11 |
|
| 12 |
try:
|
|
|
|
| 17 |
import cutlass
|
| 18 |
import cutlass.cute as cute
|
| 19 |
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
|
|
|
| 20 |
from cutlass.base_dsl.tvm_ffi_builder import spec
|
| 21 |
from cutlass.cutlass_dsl import NumericMeta
|
| 22 |
|
|
|
|
| 67 |
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 68 |
|
| 69 |
|
| 70 |
+
def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
|
| 71 |
+
"""Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple."""
|
| 72 |
+
match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE)
|
| 73 |
+
if not match:
|
| 74 |
+
raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')")
|
| 75 |
+
major, minor, _ = match.groups()
|
| 76 |
+
return int(major), int(minor)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
@lru_cache
|
| 80 |
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
|
| 81 |
+
"""Return (major, minor) device capability.
|
| 82 |
+
|
| 83 |
+
Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
|
| 84 |
+
without a GPU present.
|
| 85 |
+
"""
|
| 86 |
+
arch_override = os.environ.get("QUACK_ARCH")
|
| 87 |
+
if arch_override is not None:
|
| 88 |
+
return _parse_arch_str(arch_override)
|
| 89 |
return torch.cuda.get_device_capability(device)
|
| 90 |
|
| 91 |
|
|
|
|
| 157 |
return values
|
| 158 |
|
| 159 |
__new_from_mlir_values__ = _new_from_mlir_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/layout_utils.py
CHANGED
|
@@ -295,3 +295,37 @@ def mma_partition_A_vec(
|
|
| 295 |
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 296 |
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
|
| 297 |
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 296 |
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
|
| 297 |
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def copy_partition_S_vec(
|
| 301 |
+
sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
|
| 302 |
+
) -> cute.Tensor:
|
| 303 |
+
assert cute.rank(sVec) == 2
|
| 304 |
+
assert sVec.stride[0] == 1
|
| 305 |
+
stage = sVec.shape[1]
|
| 306 |
+
shape = (
|
| 307 |
+
(sVec.shape[0], expand_shape, stage)
|
| 308 |
+
if const_expr(is_colvec)
|
| 309 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 310 |
+
)
|
| 311 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 312 |
+
sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 313 |
+
tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr))
|
| 314 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def copy_partition_D_vec(
|
| 318 |
+
sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
|
| 319 |
+
) -> cute.Tensor:
|
| 320 |
+
assert cute.rank(sVec) == 2
|
| 321 |
+
assert sVec.stride[0] == 1
|
| 322 |
+
stage = sVec.shape[1]
|
| 323 |
+
shape = (
|
| 324 |
+
(sVec.shape[0], expand_shape, stage)
|
| 325 |
+
if const_expr(is_colvec)
|
| 326 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 327 |
+
)
|
| 328 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 329 |
+
sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 330 |
+
tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr))
|
| 331 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
build/torch-cuda/quack/utils.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
|
| 9 |
+
from cutlass import Float32, Int32, const_expr
|
| 10 |
+
from cutlass._mlir.dialects import arith as _arith
|
| 11 |
+
from cutlass._mlir.dialects import llvm, nvvm, vector
|
| 12 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dsl_user_op
|
| 16 |
+
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
|
| 17 |
+
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@cute.jit
|
| 21 |
+
def load_scalar_or_pointer(x, dtype=Float32):
|
| 22 |
+
if const_expr(isinstance(x, cute.Pointer)):
|
| 23 |
+
return dtype(cute.make_tensor(x, cute.make_layout(1))[0])
|
| 24 |
+
else:
|
| 25 |
+
return x
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dsl_user_op
|
| 29 |
+
def set_block_rank(
|
| 30 |
+
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
|
| 31 |
+
) -> Int32:
|
| 32 |
+
"""Map the given smem pointer to the address at another CTA rank in the cluster."""
|
| 33 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 34 |
+
return Int32(
|
| 35 |
+
llvm.inline_asm(
|
| 36 |
+
T.i32(),
|
| 37 |
+
[smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
|
| 38 |
+
"mapa.shared::cluster.u32 $0, $1, $2;",
|
| 39 |
+
"=r,r,r",
|
| 40 |
+
has_side_effects=False,
|
| 41 |
+
is_align_stack=False,
|
| 42 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 43 |
+
)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dsl_user_op
|
| 48 |
+
def store_shared_remote(
|
| 49 |
+
val: float | Float32 | Int32 | cutlass.Int64,
|
| 50 |
+
smem_ptr: cute.Pointer,
|
| 51 |
+
mbar_ptr: cute.Pointer,
|
| 52 |
+
peer_cta_rank_in_cluster: cute.typing.Int,
|
| 53 |
+
*,
|
| 54 |
+
loc=None,
|
| 55 |
+
ip=None,
|
| 56 |
+
) -> None:
|
| 57 |
+
remote_smem_ptr_i32 = set_block_rank(
|
| 58 |
+
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 59 |
+
).ir_value()
|
| 60 |
+
remote_mbar_ptr_i32 = set_block_rank(
|
| 61 |
+
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 62 |
+
).ir_value()
|
| 63 |
+
if const_expr(isinstance(val, float)):
|
| 64 |
+
val = Float32(val)
|
| 65 |
+
assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
|
| 66 |
+
suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
|
| 67 |
+
constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
|
| 68 |
+
llvm.inline_asm(
|
| 69 |
+
None,
|
| 70 |
+
[remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
|
| 71 |
+
f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
|
| 72 |
+
f"r,{constraint},r",
|
| 73 |
+
has_side_effects=True,
|
| 74 |
+
is_align_stack=False,
|
| 75 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dsl_user_op
|
| 80 |
+
def store_shared_remote_x4(
|
| 81 |
+
val0: Float32 | Int32,
|
| 82 |
+
val1: Float32 | Int32,
|
| 83 |
+
val2: Float32 | Int32,
|
| 84 |
+
val3: Float32 | Int32,
|
| 85 |
+
smem_ptr: cute.Pointer,
|
| 86 |
+
mbar_ptr: cute.Pointer,
|
| 87 |
+
peer_cta_rank_in_cluster: cute.typing.Int,
|
| 88 |
+
*,
|
| 89 |
+
loc=None,
|
| 90 |
+
ip=None,
|
| 91 |
+
) -> None:
|
| 92 |
+
remote_smem_ptr_i32 = set_block_rank(
|
| 93 |
+
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 94 |
+
).ir_value()
|
| 95 |
+
remote_mbar_ptr_i32 = set_block_rank(
|
| 96 |
+
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 97 |
+
).ir_value()
|
| 98 |
+
assert isinstance(val0, (Float32, Int32)), "val must be Float32, or Int32"
|
| 99 |
+
dtype = Float32 if isinstance(val0, Float32) else Int32
|
| 100 |
+
suffix = {Float32: "f32", Int32: "s32"}[dtype]
|
| 101 |
+
constraint = {Float32: "f", Int32: "r"}[dtype]
|
| 102 |
+
llvm.inline_asm(
|
| 103 |
+
None,
|
| 104 |
+
[
|
| 105 |
+
remote_smem_ptr_i32,
|
| 106 |
+
remote_mbar_ptr_i32,
|
| 107 |
+
dtype(val0).ir_value(loc=loc, ip=ip),
|
| 108 |
+
dtype(val1).ir_value(loc=loc, ip=ip),
|
| 109 |
+
dtype(val2).ir_value(loc=loc, ip=ip),
|
| 110 |
+
dtype(val3).ir_value(loc=loc, ip=ip),
|
| 111 |
+
],
|
| 112 |
+
"{\n\t"
|
| 113 |
+
f".reg .v4 .{suffix} abcd;\n\t"
|
| 114 |
+
f"mov.{suffix} abcd.x, $2;\n\t"
|
| 115 |
+
f"mov.{suffix} abcd.y, $3;\n\t"
|
| 116 |
+
f"mov.{suffix} abcd.z, $4;\n\t"
|
| 117 |
+
f"mov.{suffix} abcd.w, $5;\n\t"
|
| 118 |
+
f"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.{suffix} [$0], abcd, [$1];\n\t"
|
| 119 |
+
"}\n",
|
| 120 |
+
f"r,r,{constraint},{constraint},{constraint},{constraint}",
|
| 121 |
+
has_side_effects=True,
|
| 122 |
+
is_align_stack=False,
|
| 123 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@dsl_user_op
|
| 128 |
+
def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
| 129 |
+
if cutlass.const_expr(cutlass.CUDA_VERSION.major) == 12:
|
| 130 |
+
return Float32(
|
| 131 |
+
nvvm.fmin(
|
| 132 |
+
T.f32(),
|
| 133 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 134 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 135 |
+
loc=loc,
|
| 136 |
+
ip=ip,
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
return Float32(
|
| 140 |
+
nvvm.fmin(
|
| 141 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 142 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 143 |
+
loc=loc,
|
| 144 |
+
ip=ip,
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@dsl_user_op
|
| 150 |
+
def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
| 151 |
+
return Float32(
|
| 152 |
+
llvm.inline_asm(
|
| 153 |
+
T.f32(),
|
| 154 |
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
| 155 |
+
"sqrt.approx.f32 $0, $1;",
|
| 156 |
+
"=f,f",
|
| 157 |
+
has_side_effects=False,
|
| 158 |
+
is_align_stack=False,
|
| 159 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dsl_user_op
|
| 165 |
+
def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
| 166 |
+
return Int32(
|
| 167 |
+
llvm.inline_asm(
|
| 168 |
+
T.i32(),
|
| 169 |
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
| 170 |
+
"cvt.rpi.ftz.s32.f32 $0, $1;",
|
| 171 |
+
"=r,f",
|
| 172 |
+
has_side_effects=False,
|
| 173 |
+
is_align_stack=False,
|
| 174 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 175 |
+
)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@cute.jit
|
| 180 |
+
def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
|
| 181 |
+
"""Fill out-of-bounds values in shared memory tensor.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
tXsX: Shared memory tensor to fill
|
| 185 |
+
tXpX: Predicate tensor indicating valid elements
|
| 186 |
+
fill_value: Value to fill OOB locations with
|
| 187 |
+
"""
|
| 188 |
+
tXrX_fill = cute.make_rmem_tensor_like(tXsX[(None, 0), None, 0])
|
| 189 |
+
tXrX_fill.fill(fill_value)
|
| 190 |
+
for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
|
| 191 |
+
for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
|
| 192 |
+
if const_expr(tXpX is not None):
|
| 193 |
+
if not tXpX[rest_v, 0, rest_k]:
|
| 194 |
+
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
| 195 |
+
else:
|
| 196 |
+
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
# General-purpose DSL store / vector helpers
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@dsl_user_op
|
| 205 |
+
def make_vector(elem_type, *values, loc=None, ip=None):
|
| 206 |
+
"""Build an MLIR vector <N x elem_type> from N scalar DSL values.
|
| 207 |
+
|
| 208 |
+
Example: make_vector(cutlass.Uint32, v0, v1) -> <2 x i32> MLIR vector
|
| 209 |
+
"""
|
| 210 |
+
from cutlass._mlir import ir
|
| 211 |
+
|
| 212 |
+
n = len(values)
|
| 213 |
+
mlir_ty = elem_type.mlir_type
|
| 214 |
+
vec_ty = ir.VectorType.get([n], mlir_ty)
|
| 215 |
+
vec = llvm.mlir_undef(vec_ty, loc=loc, ip=ip)
|
| 216 |
+
for i, v in enumerate(values):
|
| 217 |
+
vec = vector.insertelement(
|
| 218 |
+
elem_type(v).ir_value(loc=loc, ip=ip),
|
| 219 |
+
vec,
|
| 220 |
+
position=_arith.constant(T.i32(), i, loc=loc, ip=ip),
|
| 221 |
+
loc=loc,
|
| 222 |
+
ip=ip,
|
| 223 |
+
)
|
| 224 |
+
return vec
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dsl_user_op
|
| 228 |
+
def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
|
| 229 |
+
vec_f32x2 = vector.from_elements(
|
| 230 |
+
T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
|
| 231 |
+
)
|
| 232 |
+
vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
|
| 233 |
+
res = cutlass.Int64(
|
| 234 |
+
vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
| 235 |
+
)
|
| 236 |
+
return res
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dsl_user_op
|
| 240 |
+
def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
| 241 |
+
vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
|
| 242 |
+
vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
|
| 243 |
+
res0 = Float32(
|
| 244 |
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
| 245 |
+
)
|
| 246 |
+
res1 = Float32(
|
| 247 |
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
|
| 248 |
+
)
|
| 249 |
+
return res0, res1
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
@cute.jit
|
| 253 |
+
def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
|
| 254 |
+
if const_expr(lane is None):
|
| 255 |
+
lane = cute.arch.lane_idx()
|
| 256 |
+
for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
|
| 257 |
+
offset = 1 << i
|
| 258 |
+
# Very important that we set mask_and_clamp to 0
|
| 259 |
+
partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
|
| 260 |
+
if lane >= offset:
|
| 261 |
+
val += partial_sum
|
| 262 |
+
return val
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@dsl_user_op
|
| 266 |
+
def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
|
| 267 |
+
from cutlass import CUDA_VERSION
|
| 268 |
+
|
| 269 |
+
# * NVVM call based on nvvm version
|
| 270 |
+
if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
|
| 271 |
+
# Old API: requires explicit result type as first positional argument
|
| 272 |
+
return nvvm.atomicrmw(
|
| 273 |
+
res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
# New API: infers result type automatically
|
| 277 |
+
return nvvm.atomicrmw(
|
| 278 |
+
op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@dsl_user_op
|
| 283 |
+
def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
|
| 284 |
+
from cutlass import CUDA_VERSION
|
| 285 |
+
|
| 286 |
+
# * NVVM call based on nvvm version
|
| 287 |
+
if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
|
| 288 |
+
# Old API: requires explicit result type as first positional argument
|
| 289 |
+
return nvvm.atomicrmw(
|
| 290 |
+
res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
# New API: infers result type automatically
|
| 294 |
+
return nvvm.atomicrmw(
|
| 295 |
+
op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@dsl_user_op
|
| 300 |
+
def issue_clc_query_nomulticast(
|
| 301 |
+
mbar_ptr: cute.Pointer,
|
| 302 |
+
clc_response_ptr: cute.Pointer,
|
| 303 |
+
loc=None,
|
| 304 |
+
ip=None,
|
| 305 |
+
) -> None:
|
| 306 |
+
"""
|
| 307 |
+
The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch
|
| 308 |
+
of a cluster that has not started running yet. It asynchronously writes an opaque response
|
| 309 |
+
to shared memory indicating whether the operation succeeded or failed. On success, the
|
| 310 |
+
opaque response contains the ctaid of the first CTA of the canceled cluster.
|
| 311 |
+
|
| 312 |
+
:param mbar_ptr: A pointer to the mbarrier address in SMEM
|
| 313 |
+
:type mbar_ptr: Pointer
|
| 314 |
+
:param clc_response_ptr: A pointer to the cluster launch control response address in SMEM
|
| 315 |
+
:type clc_response_ptr: Pointer
|
| 316 |
+
"""
|
| 317 |
+
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
| 318 |
+
clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
|
| 319 |
+
nvvm.clusterlaunchcontrol_try_cancel(
|
| 320 |
+
clc_response_llvm_ptr,
|
| 321 |
+
mbar_llvm_ptr,
|
| 322 |
+
loc=loc,
|
| 323 |
+
ip=ip,
|
| 324 |
+
)
|
build/torch-cuda/seqlen_info.py
CHANGED
|
@@ -5,6 +5,8 @@ import cutlass
|
|
| 5 |
import cutlass.cute as cute
|
| 6 |
from cutlass import Int32, const_expr
|
| 7 |
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
This consolidates all the info related to sequence length. This is so that we can do all
|
| 10 |
the gmem reads once at the beginning of each tile, rather than having to repeat these reads
|
|
@@ -14,34 +16,61 @@ to compute various things like n_block_min, n_block_max, etc.
|
|
| 14 |
|
| 15 |
@dataclass(frozen=True)
|
| 16 |
class SeqlenInfo:
|
| 17 |
-
offset:
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
@staticmethod
|
| 21 |
def create(
|
| 22 |
-
batch_idx:
|
| 23 |
-
seqlen_static:
|
| 24 |
cu_seqlens: Optional[cute.Tensor] = None,
|
| 25 |
seqused: Optional[cute.Tensor] = None,
|
|
|
|
| 26 |
):
|
| 27 |
offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
if const_expr(seqused is not None):
|
| 29 |
seqlen = seqused[batch_idx]
|
| 30 |
elif const_expr(cu_seqlens is not None):
|
| 31 |
seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
|
| 32 |
else:
|
| 33 |
seqlen = seqlen_static
|
| 34 |
-
return SeqlenInfo(offset, seqlen)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
@dataclass(frozen=True)
|
| 38 |
class SeqlenInfoQK:
|
| 39 |
-
offset_q:
|
| 40 |
-
offset_k:
|
| 41 |
-
padded_offset_q:
|
| 42 |
-
padded_offset_k:
|
| 43 |
-
seqlen_q:
|
| 44 |
-
seqlen_k:
|
| 45 |
has_cu_seqlens_q: cutlass.Constexpr[bool]
|
| 46 |
has_cu_seqlens_k: cutlass.Constexpr[bool]
|
| 47 |
has_seqused_q: cutlass.Constexpr[bool]
|
|
@@ -49,27 +78,27 @@ class SeqlenInfoQK:
|
|
| 49 |
|
| 50 |
@staticmethod
|
| 51 |
def create(
|
| 52 |
-
batch_idx:
|
| 53 |
-
seqlen_q_static:
|
| 54 |
-
seqlen_k_static:
|
| 55 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 56 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 57 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 58 |
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 59 |
-
tile_m: cutlass.Constexpr[
|
| 60 |
-
tile_n: cutlass.Constexpr[
|
| 61 |
):
|
| 62 |
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
|
| 63 |
offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
|
| 64 |
padded_offset_q = (
|
| 65 |
0
|
| 66 |
if const_expr(mCuSeqlensQ is None)
|
| 67 |
-
else (offset_q + batch_idx * tile_m) // tile_m * tile_m
|
| 68 |
)
|
| 69 |
padded_offset_k = (
|
| 70 |
0
|
| 71 |
if const_expr(mCuSeqlensK is None)
|
| 72 |
-
else (offset_k + batch_idx * tile_n) // tile_n * tile_n
|
| 73 |
)
|
| 74 |
if const_expr(mSeqUsedQ is not None):
|
| 75 |
seqlen_q = mSeqUsedQ[batch_idx]
|
|
@@ -87,10 +116,6 @@ class SeqlenInfoQK:
|
|
| 87 |
if const_expr(mCuSeqlensK is None)
|
| 88 |
else mCuSeqlensK[batch_idx + 1] - offset_k
|
| 89 |
)
|
| 90 |
-
has_cu_seqlens_q: int = mCuSeqlensQ is not None
|
| 91 |
-
has_cu_seqlens_k: int = mCuSeqlensK is not None
|
| 92 |
-
has_seqused_q: int = mSeqUsedQ is not None
|
| 93 |
-
has_seqused_k: int = mSeqUsedK is not None
|
| 94 |
return SeqlenInfoQK(
|
| 95 |
offset_q,
|
| 96 |
offset_k,
|
|
@@ -98,10 +123,10 @@ class SeqlenInfoQK:
|
|
| 98 |
padded_offset_k,
|
| 99 |
seqlen_q,
|
| 100 |
seqlen_k,
|
| 101 |
-
has_cu_seqlens_q,
|
| 102 |
-
has_cu_seqlens_k,
|
| 103 |
-
has_seqused_q,
|
| 104 |
-
has_seqused_k,
|
| 105 |
)
|
| 106 |
|
| 107 |
def offset_batch_Q(
|
|
@@ -110,16 +135,38 @@ class SeqlenInfoQK:
|
|
| 110 |
batch_idx: Int32,
|
| 111 |
dim: int,
|
| 112 |
padded: cutlass.Constexpr[bool] = False,
|
|
|
|
| 113 |
) -> cute.Tensor:
|
| 114 |
"""Seqlen must be the first dimension of mQ"""
|
| 115 |
-
if const_expr(not
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
else:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
def offset_batch_K(
|
| 125 |
self,
|
|
@@ -127,12 +174,114 @@ class SeqlenInfoQK:
|
|
| 127 |
batch_idx: Int32,
|
| 128 |
dim: int,
|
| 129 |
padded: cutlass.Constexpr[bool] = False,
|
|
|
|
|
|
|
| 130 |
) -> cute.Tensor:
|
| 131 |
"""Seqlen must be the first dimension of mK"""
|
| 132 |
-
if const_expr(not
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
else:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import cutlass.cute as cute
|
| 6 |
from cutlass import Int32, const_expr
|
| 7 |
|
| 8 |
+
from .quack import copy_utils
|
| 9 |
+
|
| 10 |
"""
|
| 11 |
This consolidates all the info related to sequence length. This is so that we can do all
|
| 12 |
the gmem reads once at the beginning of each tile, rather than having to repeat these reads
|
|
|
|
| 16 |
|
| 17 |
@dataclass(frozen=True)
|
| 18 |
class SeqlenInfo:
|
| 19 |
+
offset: Int32
|
| 20 |
+
offset_padded: Int32
|
| 21 |
+
seqlen: Int32
|
| 22 |
+
has_cu_seqlens: cutlass.Constexpr[bool] = False
|
| 23 |
|
| 24 |
@staticmethod
|
| 25 |
def create(
|
| 26 |
+
batch_idx: Int32,
|
| 27 |
+
seqlen_static: Int32,
|
| 28 |
cu_seqlens: Optional[cute.Tensor] = None,
|
| 29 |
seqused: Optional[cute.Tensor] = None,
|
| 30 |
+
tile: cutlass.Constexpr[int] = 128,
|
| 31 |
):
|
| 32 |
offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
|
| 33 |
+
offset_padded = (
|
| 34 |
+
0
|
| 35 |
+
if const_expr(cu_seqlens is None)
|
| 36 |
+
# Add divby so that the compiler knows the alignment when moving by offset_padded
|
| 37 |
+
else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile)
|
| 38 |
+
)
|
| 39 |
if const_expr(seqused is not None):
|
| 40 |
seqlen = seqused[batch_idx]
|
| 41 |
elif const_expr(cu_seqlens is not None):
|
| 42 |
seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
|
| 43 |
else:
|
| 44 |
seqlen = seqlen_static
|
| 45 |
+
return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None)
|
| 46 |
+
|
| 47 |
+
def offset_batch(
|
| 48 |
+
self,
|
| 49 |
+
mT: cute.Tensor,
|
| 50 |
+
batch_idx: Int32,
|
| 51 |
+
dim: int,
|
| 52 |
+
padded: cutlass.Constexpr[bool] = False,
|
| 53 |
+
multiple: int = 1,
|
| 54 |
+
) -> cute.Tensor:
|
| 55 |
+
"""Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0."""
|
| 56 |
+
if const_expr(not self.has_cu_seqlens):
|
| 57 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim)
|
| 58 |
+
return mT[idx]
|
| 59 |
+
else:
|
| 60 |
+
off = multiple * (self.offset if const_expr(not padded) else self.offset_padded)
|
| 61 |
+
offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off)
|
| 62 |
+
idx = (offset,) + (None,) * (cute.rank(mT) - 1)
|
| 63 |
+
return cute.domain_offset(idx, mT)
|
| 64 |
|
| 65 |
|
| 66 |
@dataclass(frozen=True)
|
| 67 |
class SeqlenInfoQK:
|
| 68 |
+
offset_q: Int32
|
| 69 |
+
offset_k: Int32
|
| 70 |
+
padded_offset_q: Int32
|
| 71 |
+
padded_offset_k: Int32
|
| 72 |
+
seqlen_q: Int32
|
| 73 |
+
seqlen_k: Int32
|
| 74 |
has_cu_seqlens_q: cutlass.Constexpr[bool]
|
| 75 |
has_cu_seqlens_k: cutlass.Constexpr[bool]
|
| 76 |
has_seqused_q: cutlass.Constexpr[bool]
|
|
|
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
def create(
|
| 81 |
+
batch_idx: Int32,
|
| 82 |
+
seqlen_q_static: Int32,
|
| 83 |
+
seqlen_k_static: Int32,
|
| 84 |
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 85 |
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 86 |
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 87 |
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 88 |
+
tile_m: cutlass.Constexpr[Int32] = 128,
|
| 89 |
+
tile_n: cutlass.Constexpr[Int32] = 128,
|
| 90 |
):
|
| 91 |
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
|
| 92 |
offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
|
| 93 |
padded_offset_q = (
|
| 94 |
0
|
| 95 |
if const_expr(mCuSeqlensQ is None)
|
| 96 |
+
else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m)
|
| 97 |
)
|
| 98 |
padded_offset_k = (
|
| 99 |
0
|
| 100 |
if const_expr(mCuSeqlensK is None)
|
| 101 |
+
else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n)
|
| 102 |
)
|
| 103 |
if const_expr(mSeqUsedQ is not None):
|
| 104 |
seqlen_q = mSeqUsedQ[batch_idx]
|
|
|
|
| 116 |
if const_expr(mCuSeqlensK is None)
|
| 117 |
else mCuSeqlensK[batch_idx + 1] - offset_k
|
| 118 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return SeqlenInfoQK(
|
| 120 |
offset_q,
|
| 121 |
offset_k,
|
|
|
|
| 123 |
padded_offset_k,
|
| 124 |
seqlen_q,
|
| 125 |
seqlen_k,
|
| 126 |
+
has_cu_seqlens_q=mCuSeqlensQ is not None,
|
| 127 |
+
has_cu_seqlens_k=mCuSeqlensK is not None,
|
| 128 |
+
has_seqused_q=mSeqUsedQ is not None,
|
| 129 |
+
has_seqused_k=mSeqUsedK is not None,
|
| 130 |
)
|
| 131 |
|
| 132 |
def offset_batch_Q(
|
|
|
|
| 135 |
batch_idx: Int32,
|
| 136 |
dim: int,
|
| 137 |
padded: cutlass.Constexpr[bool] = False,
|
| 138 |
+
ragged: cutlass.Constexpr[bool] = False,
|
| 139 |
) -> cute.Tensor:
|
| 140 |
"""Seqlen must be the first dimension of mQ"""
|
| 141 |
+
if const_expr(not ragged):
|
| 142 |
+
if const_expr(not self.has_cu_seqlens_q):
|
| 143 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
|
| 144 |
+
return mQ[idx]
|
| 145 |
+
else:
|
| 146 |
+
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
|
| 147 |
+
offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q)
|
| 148 |
+
idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1)
|
| 149 |
+
return cute.domain_offset(idx, mQ)
|
| 150 |
else:
|
| 151 |
+
if const_expr(not self.has_cu_seqlens_q):
|
| 152 |
+
offset_q = 0
|
| 153 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
|
| 154 |
+
mQ = mQ[idx]
|
| 155 |
+
else:
|
| 156 |
+
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
|
| 157 |
+
if const_expr(cute.rank(mQ.shape[0]) == 1):
|
| 158 |
+
return copy_utils.offset_ragged_tensor(
|
| 159 |
+
mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True
|
| 160 |
+
)
|
| 161 |
+
else: # PackGQA
|
| 162 |
+
assert cute.rank(mQ.shape[0]) == 2
|
| 163 |
+
# Unpack before calling offset_ragged_tensor, then pack
|
| 164 |
+
idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1)
|
| 165 |
+
mQ = mQ[idx]
|
| 166 |
+
mQ = copy_utils.offset_ragged_tensor(
|
| 167 |
+
mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True
|
| 168 |
+
)
|
| 169 |
+
return cute.group_modes(mQ, 0, 2)
|
| 170 |
|
| 171 |
def offset_batch_K(
|
| 172 |
self,
|
|
|
|
| 174 |
batch_idx: Int32,
|
| 175 |
dim: int,
|
| 176 |
padded: cutlass.Constexpr[bool] = False,
|
| 177 |
+
ragged: cutlass.Constexpr[bool] = False,
|
| 178 |
+
multiple: int = 1,
|
| 179 |
) -> cute.Tensor:
|
| 180 |
"""Seqlen must be the first dimension of mK"""
|
| 181 |
+
if const_expr(not ragged):
|
| 182 |
+
if const_expr(not self.has_cu_seqlens_k):
|
| 183 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
|
| 184 |
+
return mK[idx]
|
| 185 |
+
else:
|
| 186 |
+
offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
|
| 187 |
+
offset_k *= multiple
|
| 188 |
+
idx = (offset_k,) + (None,) * (cute.rank(mK) - 1)
|
| 189 |
+
return cute.domain_offset(idx, mK)
|
| 190 |
else:
|
| 191 |
+
if const_expr(not self.has_cu_seqlens_k):
|
| 192 |
+
offset_k = 0
|
| 193 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
|
| 194 |
+
mK = mK[idx]
|
| 195 |
+
else:
|
| 196 |
+
offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
|
| 197 |
+
offset_k *= multiple
|
| 198 |
+
return copy_utils.offset_ragged_tensor(
|
| 199 |
+
mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@dataclass(frozen=True)
|
| 204 |
+
class SeqlenInfoQKNewK:
|
| 205 |
+
"""Sequence length info for append-KV with left-padding and new K support.
|
| 206 |
+
|
| 207 |
+
Extends SeqlenInfoQK with:
|
| 208 |
+
- leftpad_k: left padding for K (tokens to skip at the start of the KV cache)
|
| 209 |
+
- offset_k_new: offset into the new K tensor
|
| 210 |
+
- seqlen_k_og: original K length (before appending new K), excluding leftpad
|
| 211 |
+
- seqlen_k_new: length of new K to append
|
| 212 |
+
- seqlen_k: total K length (seqlen_k_og + seqlen_k_new)
|
| 213 |
+
- seqlen_rotary: position for rotary embedding computation
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
leftpad_k: Int32
|
| 217 |
+
offset_q: Int32
|
| 218 |
+
offset_k: Int32
|
| 219 |
+
offset_k_new: Int32
|
| 220 |
+
seqlen_q: Int32
|
| 221 |
+
seqlen_k_og: Int32
|
| 222 |
+
seqlen_k_new: Int32
|
| 223 |
+
seqlen_k: Int32
|
| 224 |
+
seqlen_rotary: Int32
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def create(
|
| 228 |
+
batch_idx: Int32,
|
| 229 |
+
seqlen_q_static: Int32,
|
| 230 |
+
seqlen_k_static: Int32,
|
| 231 |
+
shape_K_new_0: Int32,
|
| 232 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 233 |
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 234 |
+
mCuSeqlensKNew: Optional[cute.Tensor] = None,
|
| 235 |
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 236 |
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 237 |
+
mLeftpadK: Optional[cute.Tensor] = None,
|
| 238 |
+
mSeqlensRotary: Optional[cute.Tensor] = None,
|
| 239 |
+
):
|
| 240 |
+
leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx]
|
| 241 |
+
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
|
| 242 |
+
if const_expr(mCuSeqlensK is not None):
|
| 243 |
+
offset_k = mCuSeqlensK[batch_idx] + leftpad_k
|
| 244 |
+
else:
|
| 245 |
+
offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0
|
| 246 |
+
offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx]
|
| 247 |
+
# seqlen_q
|
| 248 |
+
if const_expr(mSeqUsedQ is not None):
|
| 249 |
+
seqlen_q = mSeqUsedQ[batch_idx]
|
| 250 |
+
elif const_expr(mCuSeqlensQ is not None):
|
| 251 |
+
seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx]
|
| 252 |
+
else:
|
| 253 |
+
seqlen_q = seqlen_q_static
|
| 254 |
+
# seqlen_k_og: original K length (excluding leftpad)
|
| 255 |
+
if const_expr(mSeqUsedK is not None):
|
| 256 |
+
seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k
|
| 257 |
+
elif const_expr(mCuSeqlensK is not None):
|
| 258 |
+
seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k
|
| 259 |
+
else:
|
| 260 |
+
seqlen_k_og = (
|
| 261 |
+
seqlen_k_static - leftpad_k
|
| 262 |
+
if const_expr(mCuSeqlensQ is not None)
|
| 263 |
+
else seqlen_k_static
|
| 264 |
+
)
|
| 265 |
+
# seqlen_k_new
|
| 266 |
+
if const_expr(mCuSeqlensKNew is None):
|
| 267 |
+
seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0
|
| 268 |
+
else:
|
| 269 |
+
seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx]
|
| 270 |
+
seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new
|
| 271 |
+
|
| 272 |
+
# seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided
|
| 273 |
+
if const_expr(mSeqlensRotary is not None):
|
| 274 |
+
seqlen_rotary = mSeqlensRotary[batch_idx]
|
| 275 |
+
else:
|
| 276 |
+
seqlen_rotary = seqlen_k_og + leftpad_k
|
| 277 |
+
return SeqlenInfoQKNewK(
|
| 278 |
+
leftpad_k,
|
| 279 |
+
offset_q,
|
| 280 |
+
offset_k,
|
| 281 |
+
offset_k_new,
|
| 282 |
+
seqlen_q,
|
| 283 |
+
seqlen_k_og,
|
| 284 |
+
seqlen_k_new,
|
| 285 |
+
seqlen_k,
|
| 286 |
+
seqlen_rotary,
|
| 287 |
+
)
|
build/torch-cuda/sm90_config_search.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v).
|
| 2 |
+
|
| 3 |
+
Enumerates tile sizes, swap modes, atom layouts, and staging options.
|
| 4 |
+
Checks GMMA divisibility, register budget, and shared memory budget.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python flash_attn/cute/sm90_config_search.py --headdim 128
|
| 8 |
+
python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128
|
| 9 |
+
python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
# H100 hardware limits
|
| 15 |
+
SMEM_LIMIT = 224 * 1024 # 228 KB minus ~3 KB for LSE, dPsum, mbarriers
|
| 16 |
+
REG_LIMITS = {2: 216, 3: 128} # per-WG budget: 2WG=240-24, 3WG=160-32
|
| 17 |
+
THREADS_PER_WG = 128
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _divisors(n):
|
| 21 |
+
return [d for d in range(1, n + 1) if n % d == 0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _acc_regs(M, N, num_wg):
|
| 25 |
+
"""Accumulator registers per thread per WG."""
|
| 26 |
+
return M * N // (num_wg * THREADS_PER_WG)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _check_mma(M, N, num_wg, atom_layout_m, swap_AB):
|
| 30 |
+
"""Check MMA feasibility. Returns regs per WG, or None if infeasible.
|
| 31 |
+
|
| 32 |
+
GMMA atom M=64. Swap exchanges (M, N) and atom layout.
|
| 33 |
+
Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8).
|
| 34 |
+
"""
|
| 35 |
+
if swap_AB:
|
| 36 |
+
M, N = N, M
|
| 37 |
+
atom_layout_m = num_wg // atom_layout_m
|
| 38 |
+
atom_layout_n = num_wg // atom_layout_m
|
| 39 |
+
if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0:
|
| 40 |
+
return None
|
| 41 |
+
return _acc_regs(M, N, num_wg)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False):
|
| 45 |
+
"""Total SMEM read traffic for one MMA (all WGs combined).
|
| 46 |
+
|
| 47 |
+
num_instr = (M_eff / 64) * wg_n instructions total.
|
| 48 |
+
Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16).
|
| 49 |
+
"""
|
| 50 |
+
num_instr = (M_eff // 64) * wg_n
|
| 51 |
+
A_per = 64 * K_red * 2 if not is_rs else 0
|
| 52 |
+
B_per = (N_eff // wg_n) * K_red * 2
|
| 53 |
+
return num_instr * (A_per + B_per)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ============================================================================
|
| 57 |
+
# Backward
|
| 58 |
+
# ============================================================================
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _check_bwd_config(
|
| 62 |
+
hdim,
|
| 63 |
+
hdimv,
|
| 64 |
+
tile_m,
|
| 65 |
+
tile_n,
|
| 66 |
+
num_wg,
|
| 67 |
+
SdP_swapAB,
|
| 68 |
+
dKV_swapAB,
|
| 69 |
+
dQ_swapAB,
|
| 70 |
+
AtomLayoutMSdP,
|
| 71 |
+
AtomLayoutNdKV,
|
| 72 |
+
AtomLayoutMdQ,
|
| 73 |
+
):
|
| 74 |
+
reg_limit = REG_LIMITS[num_wg]
|
| 75 |
+
|
| 76 |
+
# MMA feasibility
|
| 77 |
+
regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB)
|
| 78 |
+
regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB)
|
| 79 |
+
regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB)
|
| 80 |
+
regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB)
|
| 81 |
+
if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)):
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
# Peak regs: max(S+dP, dQ) + dK + dV
|
| 85 |
+
total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV
|
| 86 |
+
if total_regs > reg_limit:
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
# SMEM
|
| 90 |
+
mma_dkv_is_rs = (
|
| 91 |
+
AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB
|
| 92 |
+
)
|
| 93 |
+
Q_stage, PdS_stage = 2, 1
|
| 94 |
+
|
| 95 |
+
for dO_stage in (2, 1):
|
| 96 |
+
sQ = tile_m * hdim * 2 * Q_stage
|
| 97 |
+
sK = tile_n * hdim * 2
|
| 98 |
+
sV = tile_n * hdimv * 2
|
| 99 |
+
sdO = tile_m * hdimv * 2 * dO_stage
|
| 100 |
+
sPdS = tile_m * tile_n * 2 * PdS_stage
|
| 101 |
+
sP = sPdS if not mma_dkv_is_rs else 0
|
| 102 |
+
sdQaccum = tile_m * hdim * 4
|
| 103 |
+
smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum
|
| 104 |
+
if smem <= SMEM_LIMIT:
|
| 105 |
+
break
|
| 106 |
+
else:
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
# SMEM traffic
|
| 110 |
+
def _swap(a, b, s):
|
| 111 |
+
return (b, a) if s else (a, b)
|
| 112 |
+
|
| 113 |
+
def _wg_n(al_m, s):
|
| 114 |
+
return al_m if s else num_wg // al_m
|
| 115 |
+
|
| 116 |
+
M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB)
|
| 117 |
+
wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB)
|
| 118 |
+
traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP)
|
| 119 |
+
traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP)
|
| 120 |
+
|
| 121 |
+
wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB)
|
| 122 |
+
M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB)
|
| 123 |
+
traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)
|
| 124 |
+
M_dk, N_dk = _swap(tile_n, hdim, dKV_swapAB)
|
| 125 |
+
traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)
|
| 126 |
+
|
| 127 |
+
M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB)
|
| 128 |
+
wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB)
|
| 129 |
+
traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ)
|
| 130 |
+
|
| 131 |
+
traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0
|
| 132 |
+
traffic_dS_store = tile_m * tile_n * 2
|
| 133 |
+
traffic_dQ_smem = tile_m * hdim * 4 * 2 # store + TMA load
|
| 134 |
+
|
| 135 |
+
smem_traffic = (
|
| 136 |
+
traffic_S
|
| 137 |
+
+ traffic_dP
|
| 138 |
+
+ traffic_dV
|
| 139 |
+
+ traffic_dK
|
| 140 |
+
+ traffic_dQ
|
| 141 |
+
+ traffic_P_store
|
| 142 |
+
+ traffic_dS_store
|
| 143 |
+
+ traffic_dQ_smem
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return dict(
|
| 147 |
+
tile_m=tile_m,
|
| 148 |
+
tile_n=tile_n,
|
| 149 |
+
num_wg=num_wg,
|
| 150 |
+
Q_stage=Q_stage,
|
| 151 |
+
dO_stage=dO_stage,
|
| 152 |
+
PdS_stage=PdS_stage,
|
| 153 |
+
SdP_swapAB=SdP_swapAB,
|
| 154 |
+
dKV_swapAB=dKV_swapAB,
|
| 155 |
+
dQ_swapAB=dQ_swapAB,
|
| 156 |
+
AtomLayoutMSdP=AtomLayoutMSdP,
|
| 157 |
+
AtomLayoutNdKV=AtomLayoutNdKV,
|
| 158 |
+
AtomLayoutMdQ=AtomLayoutMdQ,
|
| 159 |
+
mma_dkv_is_rs=mma_dkv_is_rs,
|
| 160 |
+
regs_SdP=regs_SdP,
|
| 161 |
+
regs_dK=regs_dK,
|
| 162 |
+
regs_dV=regs_dV,
|
| 163 |
+
regs_dQ=regs_dQ,
|
| 164 |
+
total_regs=total_regs,
|
| 165 |
+
reg_limit=reg_limit,
|
| 166 |
+
smem_bytes=smem,
|
| 167 |
+
smem_kb=smem / 1024,
|
| 168 |
+
smem_traffic=smem_traffic,
|
| 169 |
+
smem_traffic_kb=smem_traffic / 1024,
|
| 170 |
+
smem_traffic_per_block=smem_traffic / (tile_m * tile_n),
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def find_feasible_bwd_configs(
|
| 175 |
+
head_dim,
|
| 176 |
+
head_dim_v=None,
|
| 177 |
+
tile_m_choices=(64, 80, 96, 112, 128),
|
| 178 |
+
tile_n_choices=(64, 80, 96, 112, 128),
|
| 179 |
+
):
|
| 180 |
+
if head_dim_v is None:
|
| 181 |
+
head_dim_v = head_dim
|
| 182 |
+
hdim = int(math.ceil(head_dim / 32) * 32)
|
| 183 |
+
hdimv = int(math.ceil(head_dim_v / 32) * 32)
|
| 184 |
+
|
| 185 |
+
results = []
|
| 186 |
+
for num_wg in (2, 3):
|
| 187 |
+
divs = _divisors(num_wg)
|
| 188 |
+
for tile_m in tile_m_choices:
|
| 189 |
+
for tile_n in tile_n_choices:
|
| 190 |
+
for SdP_swap in (False, True):
|
| 191 |
+
if (tile_n if SdP_swap else tile_m) % 64 != 0:
|
| 192 |
+
continue
|
| 193 |
+
for dKV_swap in (False, True):
|
| 194 |
+
if not dKV_swap and tile_n % 64 != 0:
|
| 195 |
+
continue
|
| 196 |
+
if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0):
|
| 197 |
+
continue
|
| 198 |
+
for dQ_swap in (False, True):
|
| 199 |
+
if (hdim if dQ_swap else tile_m) % 64 != 0:
|
| 200 |
+
continue
|
| 201 |
+
for a1 in divs:
|
| 202 |
+
for a2 in divs:
|
| 203 |
+
for a3 in divs:
|
| 204 |
+
cfg = _check_bwd_config(
|
| 205 |
+
hdim,
|
| 206 |
+
hdimv,
|
| 207 |
+
tile_m,
|
| 208 |
+
tile_n,
|
| 209 |
+
num_wg,
|
| 210 |
+
SdP_swap,
|
| 211 |
+
dKV_swap,
|
| 212 |
+
dQ_swap,
|
| 213 |
+
a1,
|
| 214 |
+
a2,
|
| 215 |
+
a3,
|
| 216 |
+
)
|
| 217 |
+
if cfg is not None:
|
| 218 |
+
results.append(cfg)
|
| 219 |
+
|
| 220 |
+
results.sort(key=lambda c: (-c["tile_n"], -c["tile_m"], c["smem_traffic_per_block"]))
|
| 221 |
+
return results
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def print_bwd_configs(configs, max_results=20):
|
| 225 |
+
if not configs:
|
| 226 |
+
print("No feasible configs found!")
|
| 227 |
+
return
|
| 228 |
+
n = min(len(configs), max_results)
|
| 229 |
+
print(f"Found {len(configs)} feasible configs (showing top {n}):\n")
|
| 230 |
+
hdr = (
|
| 231 |
+
f"{'wg':>2} {'tm':>3} {'tn':>3} "
|
| 232 |
+
f"{'SdP':>3} {'dKV':>3} {'dQ':>3} "
|
| 233 |
+
f"{'aSdP':>4} {'adKV':>4} {'adQ':>4} "
|
| 234 |
+
f"{'Qs':>2} {'dOs':>3} "
|
| 235 |
+
f"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3} "
|
| 236 |
+
f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}"
|
| 237 |
+
)
|
| 238 |
+
print(hdr)
|
| 239 |
+
print("-" * len(hdr))
|
| 240 |
+
B = lambda b: "T" if b else "F"
|
| 241 |
+
for c in configs[:max_results]:
|
| 242 |
+
print(
|
| 243 |
+
f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} "
|
| 244 |
+
f"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3} "
|
| 245 |
+
f"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4} "
|
| 246 |
+
f"{c['Q_stage']:>2} {c['dO_stage']:>3} "
|
| 247 |
+
f"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} "
|
| 248 |
+
f"{c['total_regs']:>4}/{c['reg_limit']:<3} "
|
| 249 |
+
f"{c['smem_kb']:>4.0f}K "
|
| 250 |
+
f"{c['smem_traffic_kb']:>6.0f}K "
|
| 251 |
+
f"{c['smem_traffic_per_block']:>6.1f}"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# ============================================================================
|
| 256 |
+
# Forward
|
| 257 |
+
# ============================================================================
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg):
|
| 261 |
+
reg_limit = REG_LIMITS[num_wg]
|
| 262 |
+
tile_m = num_wg * 64
|
| 263 |
+
|
| 264 |
+
if tile_n % 8 != 0:
|
| 265 |
+
return None
|
| 266 |
+
|
| 267 |
+
regs_S = _acc_regs(tile_m, tile_n, num_wg)
|
| 268 |
+
regs_O = _acc_regs(tile_m, hdimv, num_wg)
|
| 269 |
+
regs_P = regs_S // 2 # bf16 = half of f32
|
| 270 |
+
|
| 271 |
+
if overlap_wg:
|
| 272 |
+
total_regs = regs_S + regs_P + regs_O
|
| 273 |
+
else:
|
| 274 |
+
total_regs = regs_S + regs_O
|
| 275 |
+
|
| 276 |
+
if total_regs > reg_limit:
|
| 277 |
+
return None
|
| 278 |
+
|
| 279 |
+
# SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS
|
| 280 |
+
sQ = tile_m * hdim * 2
|
| 281 |
+
sK = tile_n * hdim * 2 * 2
|
| 282 |
+
sV = tile_n * hdimv * 2 * 2
|
| 283 |
+
sO = tile_m * hdimv * 2
|
| 284 |
+
sP = tile_m * tile_n * 2 if not pv_is_rs else 0
|
| 285 |
+
smem = max(sQ, sO) + sK + sV + sP
|
| 286 |
+
if smem > SMEM_LIMIT:
|
| 287 |
+
return None
|
| 288 |
+
|
| 289 |
+
# SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1)
|
| 290 |
+
traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2)
|
| 291 |
+
A_pv = 64 * tile_n * 2 if not pv_is_rs else 0
|
| 292 |
+
traffic_O = num_wg * (A_pv + hdimv * tile_n * 2)
|
| 293 |
+
traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0
|
| 294 |
+
smem_traffic = traffic_S + traffic_O + traffic_P_store
|
| 295 |
+
|
| 296 |
+
return dict(
|
| 297 |
+
tile_m=tile_m,
|
| 298 |
+
tile_n=tile_n,
|
| 299 |
+
num_wg=num_wg,
|
| 300 |
+
pv_is_rs=pv_is_rs,
|
| 301 |
+
overlap_wg=overlap_wg,
|
| 302 |
+
regs_S=regs_S,
|
| 303 |
+
regs_O=regs_O,
|
| 304 |
+
regs_P=regs_P,
|
| 305 |
+
total_regs=total_regs,
|
| 306 |
+
reg_limit=reg_limit,
|
| 307 |
+
smem_bytes=smem,
|
| 308 |
+
smem_kb=smem / 1024,
|
| 309 |
+
smem_traffic=smem_traffic,
|
| 310 |
+
smem_traffic_kb=smem_traffic / 1024,
|
| 311 |
+
smem_traffic_per_block=smem_traffic / (tile_m * tile_n),
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def find_feasible_fwd_configs(
|
| 316 |
+
head_dim, head_dim_v=None, tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192)
|
| 317 |
+
):
|
| 318 |
+
if head_dim_v is None:
|
| 319 |
+
head_dim_v = head_dim
|
| 320 |
+
hdim = int(math.ceil(head_dim / 32) * 32)
|
| 321 |
+
hdimv = int(math.ceil(head_dim_v / 32) * 32)
|
| 322 |
+
|
| 323 |
+
results = []
|
| 324 |
+
for num_wg in (2, 3):
|
| 325 |
+
for tile_n in tile_n_choices:
|
| 326 |
+
for pv_is_rs in (True, False):
|
| 327 |
+
for overlap_wg in (True, False):
|
| 328 |
+
cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg)
|
| 329 |
+
if cfg is not None:
|
| 330 |
+
results.append(cfg)
|
| 331 |
+
|
| 332 |
+
results.sort(key=lambda c: (-c["tile_n"], c["smem_traffic_per_block"]))
|
| 333 |
+
return results
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def print_fwd_configs(configs, max_results=20):
|
| 337 |
+
if not configs:
|
| 338 |
+
print("No feasible configs found!")
|
| 339 |
+
return
|
| 340 |
+
n = min(len(configs), max_results)
|
| 341 |
+
print(f"Found {len(configs)} feasible configs (showing top {n}):\n")
|
| 342 |
+
hdr = (
|
| 343 |
+
f"{'wg':>2} {'tm':>3} {'tn':>3} "
|
| 344 |
+
f"{'RS':>2} {'olap':>4} "
|
| 345 |
+
f"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3} "
|
| 346 |
+
f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}"
|
| 347 |
+
)
|
| 348 |
+
print(hdr)
|
| 349 |
+
print("-" * len(hdr))
|
| 350 |
+
B = lambda b: "T" if b else "F"
|
| 351 |
+
for c in configs[:max_results]:
|
| 352 |
+
print(
|
| 353 |
+
f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} "
|
| 354 |
+
f"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4} "
|
| 355 |
+
f"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} "
|
| 356 |
+
f"{c['total_regs']:>4}/{c['reg_limit']:<3} "
|
| 357 |
+
f"{c['smem_kb']:>4.0f}K "
|
| 358 |
+
f"{c['smem_traffic_kb']:>6.0f}K "
|
| 359 |
+
f"{c['smem_traffic_per_block']:>6.1f}"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# ============================================================================
|
| 364 |
+
# CLI
|
| 365 |
+
# ============================================================================
|
| 366 |
+
|
| 367 |
+
if __name__ == "__main__":
|
| 368 |
+
import argparse
|
| 369 |
+
|
| 370 |
+
parser = argparse.ArgumentParser(description="Search feasible SM90 MMA configs")
|
| 371 |
+
parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both")
|
| 372 |
+
parser.add_argument(
|
| 373 |
+
"--headdim", type=str, default="128", help="Head dim, or hdim-hdimv (e.g. 192-128)"
|
| 374 |
+
)
|
| 375 |
+
parser.add_argument("--tile-m", type=str, default="64,80,96,112,128", help="Bwd tile_m choices")
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--tile-n",
|
| 378 |
+
type=str,
|
| 379 |
+
default=None,
|
| 380 |
+
help="tile_n choices (default: fwd up to 192, bwd up to 128)",
|
| 381 |
+
)
|
| 382 |
+
parser.add_argument("-n", "--num-results", type=int, default=30)
|
| 383 |
+
args = parser.parse_args()
|
| 384 |
+
|
| 385 |
+
parts = args.headdim.split("-")
|
| 386 |
+
hdim = int(parts[0])
|
| 387 |
+
hdimv = int(parts[1]) if len(parts) > 1 else hdim
|
| 388 |
+
|
| 389 |
+
TN_FWD = "64,80,96,112,128,144,160,176,192"
|
| 390 |
+
TN_BWD = "64,80,96,112,128"
|
| 391 |
+
|
| 392 |
+
if args.mode in ("fwd", "both"):
|
| 393 |
+
tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(","))
|
| 394 |
+
print(f"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\n")
|
| 395 |
+
print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results)
|
| 396 |
+
print()
|
| 397 |
+
|
| 398 |
+
if args.mode in ("bwd", "both"):
|
| 399 |
+
tm = tuple(int(x) for x in args.tile_m.split(","))
|
| 400 |
+
tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(","))
|
| 401 |
+
print(f"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\n")
|
| 402 |
+
print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results)
|
build/torch-cuda/softmax.py
CHANGED
|
@@ -10,7 +10,7 @@ import cutlass.cute as cute
|
|
| 10 |
from cutlass import Float32
|
| 11 |
|
| 12 |
from .quack import layout_utils
|
| 13 |
-
from . import utils
|
| 14 |
from .quack.cute_dsl_utils import ParamsBase
|
| 15 |
from .seqlen_info import SeqlenInfoQK
|
| 16 |
|
|
|
|
| 10 |
from cutlass import Float32
|
| 11 |
|
| 12 |
from .quack import layout_utils
|
| 13 |
+
from . import utils as utils
|
| 14 |
from .quack.cute_dsl_utils import ParamsBase
|
| 15 |
from .seqlen_info import SeqlenInfoQK
|
| 16 |
|
build/torch-cuda/tile_scheduler.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
-
from
|
|
|
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
try:
|
|
@@ -9,17 +10,80 @@ except ImportError: # Python < 3.12
|
|
| 9 |
from typing_extensions import override
|
| 10 |
|
| 11 |
import cutlass
|
|
|
|
| 12 |
from cutlass._mlir import ir
|
| 13 |
import cutlass.cute as cute
|
| 14 |
from cutlass import Int32, const_expr
|
| 15 |
from cutlass.cute import FastDivmodDivisor
|
|
|
|
| 16 |
|
| 17 |
from .quack.cute_dsl_utils import ParamsBase
|
| 18 |
|
| 19 |
-
from . import utils
|
| 20 |
from .fast_math import clz
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
class WorkTileInfo(cutlass.utils.WorkTileInfo):
|
| 24 |
"""Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
|
| 25 |
|
|
@@ -31,6 +95,47 @@ class WorkTileInfo(cutlass.utils.WorkTileInfo):
|
|
| 31 |
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
@dataclass
|
| 35 |
class TileSchedulerArguments(ParamsBase):
|
| 36 |
num_block: Int32
|
|
@@ -51,6 +156,7 @@ class TileSchedulerArguments(ParamsBase):
|
|
| 51 |
lpt: cutlass.Constexpr[bool] = False
|
| 52 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 53 |
head_swizzle: cutlass.Constexpr[bool] = False
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
class SingleTileScheduler:
|
|
@@ -63,6 +169,7 @@ class SingleTileScheduler:
|
|
| 63 |
num_splits_divmod: FastDivmodDivisor
|
| 64 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 65 |
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
|
|
|
|
| 66 |
|
| 67 |
@staticmethod
|
| 68 |
def create(
|
|
@@ -76,6 +183,7 @@ class SingleTileScheduler:
|
|
| 76 |
FastDivmodDivisor(args.num_splits),
|
| 77 |
args.is_split_kv,
|
| 78 |
args.cluster_shape_mn,
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
|
|
@@ -86,18 +194,26 @@ class SingleTileScheduler:
|
|
| 86 |
self._ip = ip
|
| 87 |
|
| 88 |
@staticmethod
|
| 89 |
-
def to_underlying_arguments(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 91 |
|
| 92 |
@staticmethod
|
| 93 |
-
def create(
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
blk_coord = cute.arch.block_idx()
|
| 101 |
return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
|
| 102 |
|
| 103 |
# called by host
|
|
@@ -110,8 +226,13 @@ class SingleTileScheduler:
|
|
| 110 |
) -> Tuple[Int32, Int32, Int32]:
|
| 111 |
# TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
|
| 112 |
assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
return (
|
| 114 |
-
|
| 115 |
params.num_head * params.num_splits,
|
| 116 |
params.num_batch,
|
| 117 |
)
|
|
@@ -135,6 +256,10 @@ class SingleTileScheduler:
|
|
| 135 |
|
| 136 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 137 |
self._is_first_block = False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def __extract_mlir_values__(self):
|
| 140 |
values, self._values_pos = [], []
|
|
@@ -180,18 +305,28 @@ class StaticPersistentTileScheduler:
|
|
| 180 |
self._ip = ip
|
| 181 |
|
| 182 |
@staticmethod
|
| 183 |
-
def to_underlying_arguments(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 185 |
|
| 186 |
@staticmethod
|
| 187 |
-
def create(
|
|
|
|
|
|
|
| 188 |
if const_expr(cute.size(params.cluster_shape_m) == 1):
|
| 189 |
tile_idx = cute.arch.block_idx()[0]
|
| 190 |
else:
|
| 191 |
tile_idx = cute.arch.cluster_idx()[0]
|
| 192 |
return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
|
| 193 |
|
| 194 |
-
# called by host
|
| 195 |
@staticmethod
|
| 196 |
def get_grid_shape(
|
| 197 |
params: Params,
|
|
@@ -201,18 +336,14 @@ class StaticPersistentTileScheduler:
|
|
| 201 |
) -> Tuple[Int32, Int32, Int32]:
|
| 202 |
hardware_info = cutlass.utils.HardwareInfo()
|
| 203 |
sm_count = hardware_info.get_device_multiprocessor_count()
|
| 204 |
-
# Grid must be a multiple of cluster_shape_m for CUDA cluster launch.
|
| 205 |
max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
|
| 206 |
grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
|
| 207 |
return (grid_x, Int32(1), Int32(1))
|
| 208 |
|
| 209 |
-
# @cute.jit
|
| 210 |
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 211 |
hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
|
| 212 |
batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
|
| 213 |
is_valid = self._tile_idx < self.params.total_blocks_cluster
|
| 214 |
-
# if cute.arch.thread_idx()[0] == 0:
|
| 215 |
-
# cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
|
| 216 |
return WorkTileInfo(
|
| 217 |
(Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
|
| 218 |
)
|
|
@@ -228,6 +359,10 @@ class StaticPersistentTileScheduler:
|
|
| 228 |
self._tile_idx += cute.arch.grid_dim()[0]
|
| 229 |
else:
|
| 230 |
self._tile_idx += cute.arch.cluster_dim()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
def __extract_mlir_values__(self):
|
| 233 |
values, self._values_pos = [], []
|
|
@@ -254,32 +389,41 @@ class SingleTileLPTScheduler:
|
|
| 254 |
total_blocks: Int32
|
| 255 |
num_splits: Int32
|
| 256 |
num_block: Int32
|
|
|
|
|
|
|
| 257 |
l2_minor: Int32
|
| 258 |
-
num_block_divmod: FastDivmodDivisor
|
| 259 |
num_head_divmod: FastDivmodDivisor
|
| 260 |
l2_minor_divmod: FastDivmodDivisor
|
| 261 |
l2_major_divmod: FastDivmodDivisor
|
| 262 |
l2_minor_residual_divmod: FastDivmodDivisor
|
| 263 |
num_hb_quotient: Int32
|
|
|
|
| 264 |
is_split_kv: cutlass.Constexpr[bool] = False
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
@staticmethod
|
| 267 |
@cute.jit
|
| 268 |
def create(
|
| 269 |
-
args: TileSchedulerArguments,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
) -> "SingleTileLPTScheduler.Params":
|
| 271 |
-
|
|
|
|
|
|
|
| 272 |
size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
|
| 273 |
size_one_head = size_one_kv_head
|
| 274 |
size_l2 = 50 * 1024 * 1024 # 40 MB for K & V
|
| 275 |
# Swizzle is the size of each "section". Round swizzle to a power of 2
|
| 276 |
# Need to be careful about the case where only one head will fit
|
| 277 |
# swizzle is how many heads can fit in L2
|
| 278 |
-
#
|
| 279 |
-
# Seems faster if swizzle if a power of 2
|
| 280 |
log2_floor = lambda n: 31 - clz(n)
|
| 281 |
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
|
| 282 |
-
# swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
|
| 283 |
# If we're in the last section (called residual), we don't want to divide by
|
| 284 |
# swizzle. Instead we want to divide by the remainder.
|
| 285 |
num_hb_quotient = (args.num_head * args.num_batch) // swizzle
|
|
@@ -287,37 +431,84 @@ class SingleTileLPTScheduler:
|
|
| 287 |
return SingleTileLPTScheduler.Params(
|
| 288 |
total_blocks=args.num_block * args.num_head * args.num_batch,
|
| 289 |
num_block=args.num_block,
|
|
|
|
|
|
|
| 290 |
l2_minor=Int32(swizzle),
|
| 291 |
-
num_block_divmod=FastDivmodDivisor(args.num_block),
|
| 292 |
num_head_divmod=FastDivmodDivisor(args.num_head),
|
| 293 |
l2_minor_divmod=FastDivmodDivisor(swizzle),
|
| 294 |
l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
|
| 295 |
-
l2_minor_residual_divmod=FastDivmodDivisor(
|
| 296 |
-
max(num_hb_remainder, 1)
|
| 297 |
-
), # don't divide by 0
|
| 298 |
num_hb_quotient=Int32(num_hb_quotient),
|
| 299 |
num_splits=args.num_splits,
|
|
|
|
| 300 |
is_split_kv=args.is_split_kv,
|
|
|
|
|
|
|
|
|
|
| 301 |
)
|
| 302 |
|
| 303 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
self.params = params
|
| 305 |
self._tile_idx = tile_idx
|
| 306 |
self._split_idx = split_idx
|
|
|
|
| 307 |
self._loc = loc
|
| 308 |
self._ip = ip
|
| 309 |
|
| 310 |
@staticmethod
|
| 311 |
-
def to_underlying_arguments(
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
@staticmethod
|
| 315 |
@cute.jit
|
| 316 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 318 |
return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 319 |
|
| 320 |
-
# called by host
|
| 321 |
@staticmethod
|
| 322 |
def get_grid_shape(
|
| 323 |
params: Params,
|
|
@@ -325,10 +516,40 @@ class SingleTileLPTScheduler:
|
|
| 325 |
loc=None,
|
| 326 |
ip=None,
|
| 327 |
) -> Tuple[Int32, Int32, Int32]:
|
|
|
|
|
|
|
| 328 |
return (params.total_blocks, params.num_splits, Int32(1))
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
@cute.jit
|
| 331 |
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
params = self.params
|
| 333 |
# Implement LPT scheduling coordinate calculation
|
| 334 |
bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
|
|
@@ -342,25 +563,45 @@ class SingleTileLPTScheduler:
|
|
| 342 |
bidhb_actual = bidhb * params.l2_minor + bidhb_residual
|
| 343 |
batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
|
| 344 |
# Longest-processing-time-first
|
| 345 |
-
|
|
|
|
| 346 |
is_valid = self._tile_idx < params.total_blocks
|
| 347 |
return WorkTileInfo(
|
| 348 |
(Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
|
| 349 |
)
|
| 350 |
|
|
|
|
| 351 |
def initial_work_tile_info(self, *, loc=None, ip=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
return self.get_current_work(loc=loc, ip=ip)
|
| 353 |
|
| 354 |
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 355 |
-
|
|
|
|
| 356 |
|
| 357 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
# Single tile scheduler - set to invalid tile_idx to indicate no more work
|
| 359 |
self._tile_idx = self.params.total_blocks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
def __extract_mlir_values__(self):
|
| 362 |
values, self._values_pos = [], []
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
| 364 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 365 |
values += obj_values
|
| 366 |
self._values_pos.append(len(obj_values))
|
|
@@ -368,10 +609,13 @@ class SingleTileLPTScheduler:
|
|
| 368 |
|
| 369 |
def __new_from_mlir_values__(self, values):
|
| 370 |
obj_list = []
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
| 372 |
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 373 |
values = values[n_items:]
|
| 374 |
-
return self.__class__(*
|
| 375 |
|
| 376 |
|
| 377 |
class SingleTileLPTBwdScheduler:
|
|
@@ -395,8 +639,8 @@ class SingleTileLPTBwdScheduler:
|
|
| 395 |
) -> "SingleTileLPTBwdScheduler.Params":
|
| 396 |
size_l2 = 50 * 1024 * 1024
|
| 397 |
size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
|
| 398 |
-
|
| 399 |
-
size_one_dqaccum_head = 0
|
| 400 |
size_one_head = size_one_qdo_head + size_one_dqaccum_head
|
| 401 |
log2_floor = lambda n: 31 - clz(n)
|
| 402 |
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
|
|
@@ -430,7 +674,16 @@ class SingleTileLPTBwdScheduler:
|
|
| 430 |
self._ip = ip
|
| 431 |
|
| 432 |
@staticmethod
|
| 433 |
-
def to_underlying_arguments(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)
|
| 435 |
|
| 436 |
@staticmethod
|
|
@@ -481,6 +734,7 @@ class SingleTileLPTBwdScheduler:
|
|
| 481 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 482 |
# Single tile scheduler - set to invalid tile_idx to indicate no more work
|
| 483 |
self._tile_idx = self.params.total_blocks
|
|
|
|
| 484 |
|
| 485 |
def __extract_mlir_values__(self):
|
| 486 |
values, self._values_pos = [], []
|
|
@@ -514,20 +768,38 @@ class SingleTileVarlenScheduler:
|
|
| 514 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 515 |
head_swizzle: cutlass.Constexpr[bool] = False
|
| 516 |
cluster_shape_m: cutlass.Constexpr[int] = 1
|
|
|
|
| 517 |
|
| 518 |
@staticmethod
|
| 519 |
@cute.jit
|
| 520 |
def create(
|
| 521 |
-
args: TileSchedulerArguments,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
) -> "SingleTileVarlenScheduler.Params":
|
|
|
|
|
|
|
|
|
|
| 523 |
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
|
| 524 |
-
|
|
|
|
| 525 |
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
|
| 526 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
|
| 528 |
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
|
| 529 |
)
|
| 530 |
assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
return SingleTileVarlenScheduler.Params(
|
| 532 |
num_head=args.num_head,
|
| 533 |
num_batch=args.num_batch,
|
|
@@ -542,22 +814,65 @@ class SingleTileVarlenScheduler:
|
|
| 542 |
is_split_kv=args.is_split_kv,
|
| 543 |
head_swizzle=args.head_swizzle,
|
| 544 |
cluster_shape_m=args.cluster_shape_mn[0],
|
|
|
|
| 545 |
)
|
| 546 |
|
| 547 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
self.params = params
|
| 549 |
self._tile_idx = tile_idx
|
| 550 |
self._split_idx = split_idx
|
| 551 |
self._is_first_block = True
|
|
|
|
| 552 |
self._loc = loc
|
| 553 |
self._ip = ip
|
| 554 |
|
| 555 |
@staticmethod
|
| 556 |
-
def to_underlying_arguments(
|
| 557 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
|
| 559 |
@staticmethod
|
| 560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 562 |
return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 563 |
|
|
@@ -573,7 +888,7 @@ class SingleTileVarlenScheduler:
|
|
| 573 |
params.total_q
|
| 574 |
+ params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
|
| 575 |
) // params.tile_shape_mn[0]
|
| 576 |
-
#
|
| 577 |
total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
|
| 578 |
return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
|
| 579 |
|
|
@@ -601,7 +916,8 @@ class SingleTileVarlenScheduler:
|
|
| 601 |
)
|
| 602 |
|
| 603 |
@cute.jit
|
| 604 |
-
def
|
|
|
|
| 605 |
params = self.params
|
| 606 |
lane_idx = cute.arch.lane_idx()
|
| 607 |
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
|
|
@@ -654,6 +970,7 @@ class SingleTileVarlenScheduler:
|
|
| 654 |
num_n_blocks = (
|
| 655 |
num_m_blocks
|
| 656 |
* params.tile_shape_mn[0]
|
|
|
|
| 657 |
// params.qhead_per_kvhead_packgqa
|
| 658 |
// params.tile_shape_mn[1]
|
| 659 |
)
|
|
@@ -698,19 +1015,62 @@ class SingleTileVarlenScheduler:
|
|
| 698 |
split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
|
| 699 |
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
|
| 700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 705 |
-
|
|
|
|
| 706 |
|
| 707 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
self._is_first_block = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
|
| 711 |
def __extract_mlir_values__(self):
|
| 712 |
values, self._values_pos = [], []
|
| 713 |
-
|
|
|
|
|
|
|
|
|
|
| 714 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 715 |
values += obj_values
|
| 716 |
self._values_pos.append(len(obj_values))
|
|
@@ -718,10 +1078,10 @@ class SingleTileVarlenScheduler:
|
|
| 718 |
|
| 719 |
def __new_from_mlir_values__(self, values):
|
| 720 |
obj_list = []
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
self.
|
| 724 |
-
):
|
| 725 |
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 726 |
values = values[n_items:]
|
| 727 |
-
return
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
+
from enum import IntEnum, auto
|
| 4 |
+
from typing import Optional, Tuple, Protocol, runtime_checkable
|
| 5 |
from dataclasses import dataclass
|
| 6 |
|
| 7 |
try:
|
|
|
|
| 10 |
from typing_extensions import override
|
| 11 |
|
| 12 |
import cutlass
|
| 13 |
+
from cutlass.pipeline import PipelineClcFetchAsync, PipelineState
|
| 14 |
from cutlass._mlir import ir
|
| 15 |
import cutlass.cute as cute
|
| 16 |
from cutlass import Int32, const_expr
|
| 17 |
from cutlass.cute import FastDivmodDivisor
|
| 18 |
+
from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams
|
| 19 |
|
| 20 |
from .quack.cute_dsl_utils import ParamsBase
|
| 21 |
|
| 22 |
+
from . import utils as utils
|
| 23 |
from .fast_math import clz
|
| 24 |
|
| 25 |
|
| 26 |
+
class SchedulingMode(IntEnum):
|
| 27 |
+
NONE = auto()
|
| 28 |
+
STATIC = auto()
|
| 29 |
+
DYNAMIC = auto()
|
| 30 |
+
CLC = auto()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class ClcState(ParamsBase):
|
| 35 |
+
"""Owns the runtime state shared by CLC-capable tile schedulers.
|
| 36 |
+
|
| 37 |
+
`FlashAttentionForwardSm100` constructs this state because it owns the CLC
|
| 38 |
+
response buffer, mbarrier storage, and launch geometry needed to initialize
|
| 39 |
+
the hardware scheduler and async pipeline. Individual tile schedulers then
|
| 40 |
+
consume this state and map the returned hardware work tiles into their own
|
| 41 |
+
logical `WorkTileInfo` coordinates.
|
| 42 |
+
|
| 43 |
+
To add CLC support to a scheduler:
|
| 44 |
+
- implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler
|
| 45 |
+
- accept `clc: ClcState | None` in `create(...)` / `__init__`
|
| 46 |
+
- map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
_hw_scheduler: ClcDynamicPersistentTileScheduler
|
| 50 |
+
_pipeline: PipelineClcFetchAsync
|
| 51 |
+
_consumer_state: PipelineState
|
| 52 |
+
_producer_state: PipelineState
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def create(
|
| 56 |
+
*,
|
| 57 |
+
hw_scheduler: ClcDynamicPersistentTileScheduler,
|
| 58 |
+
pipeline: PipelineClcFetchAsync,
|
| 59 |
+
consumer_state: PipelineState,
|
| 60 |
+
producer_state: PipelineState,
|
| 61 |
+
) -> "ClcState":
|
| 62 |
+
return ClcState(hw_scheduler, pipeline, consumer_state, producer_state)
|
| 63 |
+
|
| 64 |
+
def initial_work_tile_info(self):
|
| 65 |
+
return self._hw_scheduler.initial_work_tile_info()
|
| 66 |
+
|
| 67 |
+
def get_current_work(self):
|
| 68 |
+
return self._hw_scheduler.get_current_work()
|
| 69 |
+
|
| 70 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 71 |
+
self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip)
|
| 72 |
+
mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip)
|
| 73 |
+
self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip)
|
| 74 |
+
self._producer_state.advance(loc=loc, ip=ip)
|
| 75 |
+
|
| 76 |
+
def consumer_wait(self, *, loc=None, ip=None):
|
| 77 |
+
self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip)
|
| 78 |
+
|
| 79 |
+
def consumer_release(self, *, loc=None, ip=None):
|
| 80 |
+
self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip)
|
| 81 |
+
self._consumer_state.advance(loc=loc, ip=ip)
|
| 82 |
+
|
| 83 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 84 |
+
self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
class WorkTileInfo(cutlass.utils.WorkTileInfo):
|
| 88 |
"""Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
|
| 89 |
|
|
|
|
| 95 |
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
| 96 |
|
| 97 |
|
| 98 |
+
@runtime_checkable
|
| 99 |
+
class TileSchedulerProtocol(Protocol):
|
| 100 |
+
"""Protocol defining the interface all tile schedulers must implement.
|
| 101 |
+
|
| 102 |
+
Schedulers are responsible for:
|
| 103 |
+
1. Coordinate mapping: linear tile index -> (m_block, head, batch, split)
|
| 104 |
+
2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic)
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
def get_current_work(self) -> WorkTileInfo:
|
| 108 |
+
"""Get the current work tile coordinates."""
|
| 109 |
+
...
|
| 110 |
+
|
| 111 |
+
def initial_work_tile_info(self) -> WorkTileInfo:
|
| 112 |
+
"""Get the initial work tile for this CTA."""
|
| 113 |
+
...
|
| 114 |
+
|
| 115 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 116 |
+
"""Consumer-side advance: move to next tile and return it.
|
| 117 |
+
|
| 118 |
+
For static schedulers: grid-stride increment + get_current_work.
|
| 119 |
+
For CLC schedulers: consumer wait + get_current_work + consumer release + state advance.
|
| 120 |
+
"""
|
| 121 |
+
...
|
| 122 |
+
|
| 123 |
+
def prefetch_next_work(self, *, loc=None, ip=None) -> None:
|
| 124 |
+
"""Producer-side prefetch of next work tile (no-op for static schedulers).
|
| 125 |
+
|
| 126 |
+
For CLC schedulers: producer acquire + issue CLC query + producer state advance.
|
| 127 |
+
Only called by the scheduler warp.
|
| 128 |
+
"""
|
| 129 |
+
...
|
| 130 |
+
|
| 131 |
+
def producer_tail(self, *, loc=None, ip=None) -> None:
|
| 132 |
+
"""Producer-side cleanup after the last tile.
|
| 133 |
+
|
| 134 |
+
No-op for static schedulers. For CLC schedulers: pipeline producer_tail.
|
| 135 |
+
"""
|
| 136 |
+
...
|
| 137 |
+
|
| 138 |
+
|
| 139 |
@dataclass
|
| 140 |
class TileSchedulerArguments(ParamsBase):
|
| 141 |
num_block: Int32
|
|
|
|
| 156 |
lpt: cutlass.Constexpr[bool] = False
|
| 157 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 158 |
head_swizzle: cutlass.Constexpr[bool] = False
|
| 159 |
+
use_cluster_idx: cutlass.Constexpr[bool] = False
|
| 160 |
|
| 161 |
|
| 162 |
class SingleTileScheduler:
|
|
|
|
| 169 |
num_splits_divmod: FastDivmodDivisor
|
| 170 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 171 |
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
|
| 172 |
+
use_cluster_idx: cutlass.Constexpr[bool] = False
|
| 173 |
|
| 174 |
@staticmethod
|
| 175 |
def create(
|
|
|
|
| 183 |
FastDivmodDivisor(args.num_splits),
|
| 184 |
args.is_split_kv,
|
| 185 |
args.cluster_shape_mn,
|
| 186 |
+
args.use_cluster_idx,
|
| 187 |
)
|
| 188 |
|
| 189 |
def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
|
|
|
|
| 194 |
self._ip = ip
|
| 195 |
|
| 196 |
@staticmethod
|
| 197 |
+
def to_underlying_arguments(
|
| 198 |
+
args: TileSchedulerArguments,
|
| 199 |
+
*,
|
| 200 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 201 |
+
loc=None,
|
| 202 |
+
ip=None,
|
| 203 |
+
) -> Params:
|
| 204 |
+
assert scheduling_mode == SchedulingMode.STATIC, (
|
| 205 |
+
f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}"
|
| 206 |
+
)
|
| 207 |
return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 208 |
|
| 209 |
@staticmethod
|
| 210 |
+
def create(
|
| 211 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 212 |
+
) -> "SingleTileScheduler":
|
| 213 |
+
if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx):
|
| 214 |
+
blk_coord = cute.arch.block_idx()
|
| 215 |
+
else:
|
| 216 |
+
blk_coord = cute.arch.cluster_idx()
|
|
|
|
| 217 |
return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
|
| 218 |
|
| 219 |
# called by host
|
|
|
|
| 226 |
) -> Tuple[Int32, Int32, Int32]:
|
| 227 |
# TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
|
| 228 |
assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
| 229 |
+
if const_expr(params.use_cluster_idx):
|
| 230 |
+
# Grid must have num_block * cluster_m physical blocks so that there are num_block clusters
|
| 231 |
+
grid_x = params.num_block * params.cluster_shape_mn[0]
|
| 232 |
+
else:
|
| 233 |
+
grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0])
|
| 234 |
return (
|
| 235 |
+
grid_x,
|
| 236 |
params.num_head * params.num_splits,
|
| 237 |
params.num_batch,
|
| 238 |
)
|
|
|
|
| 256 |
|
| 257 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 258 |
self._is_first_block = False
|
| 259 |
+
return self.get_current_work()
|
| 260 |
+
|
| 261 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 262 |
+
pass
|
| 263 |
|
| 264 |
def __extract_mlir_values__(self):
|
| 265 |
values, self._values_pos = [], []
|
|
|
|
| 305 |
self._ip = ip
|
| 306 |
|
| 307 |
@staticmethod
|
| 308 |
+
def to_underlying_arguments(
|
| 309 |
+
args: TileSchedulerArguments,
|
| 310 |
+
*,
|
| 311 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 312 |
+
loc=None,
|
| 313 |
+
ip=None,
|
| 314 |
+
) -> Params:
|
| 315 |
+
assert scheduling_mode == SchedulingMode.STATIC, (
|
| 316 |
+
f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}"
|
| 317 |
+
)
|
| 318 |
return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 319 |
|
| 320 |
@staticmethod
|
| 321 |
+
def create(
|
| 322 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 323 |
+
) -> "StaticPersistentTileScheduler":
|
| 324 |
if const_expr(cute.size(params.cluster_shape_m) == 1):
|
| 325 |
tile_idx = cute.arch.block_idx()[0]
|
| 326 |
else:
|
| 327 |
tile_idx = cute.arch.cluster_idx()[0]
|
| 328 |
return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
|
| 329 |
|
|
|
|
| 330 |
@staticmethod
|
| 331 |
def get_grid_shape(
|
| 332 |
params: Params,
|
|
|
|
| 336 |
) -> Tuple[Int32, Int32, Int32]:
|
| 337 |
hardware_info = cutlass.utils.HardwareInfo()
|
| 338 |
sm_count = hardware_info.get_device_multiprocessor_count()
|
|
|
|
| 339 |
max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
|
| 340 |
grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
|
| 341 |
return (grid_x, Int32(1), Int32(1))
|
| 342 |
|
|
|
|
| 343 |
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 344 |
hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
|
| 345 |
batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
|
| 346 |
is_valid = self._tile_idx < self.params.total_blocks_cluster
|
|
|
|
|
|
|
| 347 |
return WorkTileInfo(
|
| 348 |
(Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
|
| 349 |
)
|
|
|
|
| 359 |
self._tile_idx += cute.arch.grid_dim()[0]
|
| 360 |
else:
|
| 361 |
self._tile_idx += cute.arch.cluster_dim()[0]
|
| 362 |
+
return self.get_current_work()
|
| 363 |
+
|
| 364 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 365 |
+
pass
|
| 366 |
|
| 367 |
def __extract_mlir_values__(self):
|
| 368 |
values, self._values_pos = [], []
|
|
|
|
| 389 |
total_blocks: Int32
|
| 390 |
num_splits: Int32
|
| 391 |
num_block: Int32
|
| 392 |
+
num_head: Int32
|
| 393 |
+
num_batch: Int32
|
| 394 |
l2_minor: Int32
|
|
|
|
| 395 |
num_head_divmod: FastDivmodDivisor
|
| 396 |
l2_minor_divmod: FastDivmodDivisor
|
| 397 |
l2_major_divmod: FastDivmodDivisor
|
| 398 |
l2_minor_residual_divmod: FastDivmodDivisor
|
| 399 |
num_hb_quotient: Int32
|
| 400 |
+
num_splits_divmod: FastDivmodDivisor
|
| 401 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 402 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 403 |
+
scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
|
| 404 |
+
lpt: cutlass.Constexpr[bool] = True
|
| 405 |
|
| 406 |
@staticmethod
|
| 407 |
@cute.jit
|
| 408 |
def create(
|
| 409 |
+
args: TileSchedulerArguments,
|
| 410 |
+
*,
|
| 411 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 412 |
+
loc=None,
|
| 413 |
+
ip=None,
|
| 414 |
) -> "SingleTileLPTScheduler.Params":
|
| 415 |
+
assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
|
| 416 |
+
f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
|
| 417 |
+
)
|
| 418 |
size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
|
| 419 |
size_one_head = size_one_kv_head
|
| 420 |
size_l2 = 50 * 1024 * 1024 # 40 MB for K & V
|
| 421 |
# Swizzle is the size of each "section". Round swizzle to a power of 2
|
| 422 |
# Need to be careful about the case where only one head will fit
|
| 423 |
# swizzle is how many heads can fit in L2
|
| 424 |
+
# Seems faster if swizzle is a power of 2
|
|
|
|
| 425 |
log2_floor = lambda n: 31 - clz(n)
|
| 426 |
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
|
|
|
|
| 427 |
# If we're in the last section (called residual), we don't want to divide by
|
| 428 |
# swizzle. Instead we want to divide by the remainder.
|
| 429 |
num_hb_quotient = (args.num_head * args.num_batch) // swizzle
|
|
|
|
| 431 |
return SingleTileLPTScheduler.Params(
|
| 432 |
total_blocks=args.num_block * args.num_head * args.num_batch,
|
| 433 |
num_block=args.num_block,
|
| 434 |
+
num_head=args.num_head,
|
| 435 |
+
num_batch=args.num_batch,
|
| 436 |
l2_minor=Int32(swizzle),
|
|
|
|
| 437 |
num_head_divmod=FastDivmodDivisor(args.num_head),
|
| 438 |
l2_minor_divmod=FastDivmodDivisor(swizzle),
|
| 439 |
l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
|
| 440 |
+
l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)),
|
|
|
|
|
|
|
| 441 |
num_hb_quotient=Int32(num_hb_quotient),
|
| 442 |
num_splits=args.num_splits,
|
| 443 |
+
num_splits_divmod=FastDivmodDivisor(args.num_splits),
|
| 444 |
is_split_kv=args.is_split_kv,
|
| 445 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 446 |
+
scheduling_mode=scheduling_mode,
|
| 447 |
+
lpt=args.lpt,
|
| 448 |
)
|
| 449 |
|
| 450 |
+
def __init__(
|
| 451 |
+
self,
|
| 452 |
+
params: Params,
|
| 453 |
+
tile_idx: Int32,
|
| 454 |
+
split_idx: Int32,
|
| 455 |
+
clc: ClcState | None = None,
|
| 456 |
+
*,
|
| 457 |
+
loc=None,
|
| 458 |
+
ip=None,
|
| 459 |
+
):
|
| 460 |
self.params = params
|
| 461 |
self._tile_idx = tile_idx
|
| 462 |
self._split_idx = split_idx
|
| 463 |
+
self.clc = clc
|
| 464 |
self._loc = loc
|
| 465 |
self._ip = ip
|
| 466 |
|
| 467 |
@staticmethod
|
| 468 |
+
def to_underlying_arguments(
|
| 469 |
+
args: TileSchedulerArguments,
|
| 470 |
+
*,
|
| 471 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 472 |
+
loc=None,
|
| 473 |
+
ip=None,
|
| 474 |
+
) -> Params:
|
| 475 |
+
return SingleTileLPTScheduler.Params.create(
|
| 476 |
+
args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
@staticmethod
|
| 480 |
+
def _clc_grid_shape(params: Params):
|
| 481 |
+
num_batch_splits = (
|
| 482 |
+
params.num_batch * params.num_splits
|
| 483 |
+
if const_expr(params.is_split_kv)
|
| 484 |
+
else params.num_batch
|
| 485 |
+
)
|
| 486 |
+
return (
|
| 487 |
+
cute.round_up(params.num_block, params.cluster_shape_m),
|
| 488 |
+
params.num_head,
|
| 489 |
+
num_batch_splits,
|
| 490 |
+
)
|
| 491 |
|
| 492 |
@staticmethod
|
| 493 |
@cute.jit
|
| 494 |
+
def clc_problem_shape(params: Params):
|
| 495 |
+
return ClcDynamicPersistentTileSchedulerParams(
|
| 496 |
+
problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params),
|
| 497 |
+
cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
@staticmethod
|
| 501 |
+
@cute.jit
|
| 502 |
+
def create(
|
| 503 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 504 |
+
) -> "SingleTileLPTScheduler":
|
| 505 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 506 |
+
return SingleTileLPTScheduler(
|
| 507 |
+
params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip
|
| 508 |
+
)
|
| 509 |
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 510 |
return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 511 |
|
|
|
|
| 512 |
@staticmethod
|
| 513 |
def get_grid_shape(
|
| 514 |
params: Params,
|
|
|
|
| 516 |
loc=None,
|
| 517 |
ip=None,
|
| 518 |
) -> Tuple[Int32, Int32, Int32]:
|
| 519 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 520 |
+
return SingleTileLPTScheduler._clc_grid_shape(params)
|
| 521 |
return (params.total_blocks, params.num_splits, Int32(1))
|
| 522 |
|
| 523 |
+
@cute.jit
|
| 524 |
+
def clc_work_to_coords(self, work) -> WorkTileInfo:
|
| 525 |
+
"""Convert CLC response (block, head, batch_split) to WorkTileInfo.
|
| 526 |
+
|
| 527 |
+
CLC returns raw grid coordinates — no L2 swizzle (hardware decides order).
|
| 528 |
+
We only apply cluster division, optional LPT block reversal, and split_kv unpacking.
|
| 529 |
+
"""
|
| 530 |
+
block_idx = work.tile_idx[0]
|
| 531 |
+
if const_expr(self.params.cluster_shape_m > 1):
|
| 532 |
+
block_idx = block_idx // self.params.cluster_shape_m
|
| 533 |
+
if const_expr(self.params.lpt):
|
| 534 |
+
# Longest-processing-time-first: reverse block order
|
| 535 |
+
block_idx = self.params.num_block - 1 - block_idx
|
| 536 |
+
split_idx = Int32(0)
|
| 537 |
+
if const_expr(self.params.is_split_kv):
|
| 538 |
+
batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod)
|
| 539 |
+
else:
|
| 540 |
+
batch_idx = work.tile_idx[2]
|
| 541 |
+
return WorkTileInfo(
|
| 542 |
+
(Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)),
|
| 543 |
+
work.is_valid_tile,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
@cute.jit
|
| 547 |
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 548 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 549 |
+
work = self.clc.get_current_work()
|
| 550 |
+
self._tile_idx = work.tile_idx[0]
|
| 551 |
+
return self.clc_work_to_coords(work)
|
| 552 |
+
# Static path: L2-swizzled coordinate mapping
|
| 553 |
params = self.params
|
| 554 |
# Implement LPT scheduling coordinate calculation
|
| 555 |
bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
|
|
|
|
| 563 |
bidhb_actual = bidhb * params.l2_minor + bidhb_residual
|
| 564 |
batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
|
| 565 |
# Longest-processing-time-first
|
| 566 |
+
if const_expr(params.lpt):
|
| 567 |
+
block = params.num_block - 1 - block
|
| 568 |
is_valid = self._tile_idx < params.total_blocks
|
| 569 |
return WorkTileInfo(
|
| 570 |
(Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
|
| 571 |
)
|
| 572 |
|
| 573 |
+
@cute.jit
|
| 574 |
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 575 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 576 |
+
work = self.clc.initial_work_tile_info()
|
| 577 |
+
self._tile_idx = work.tile_idx[0]
|
| 578 |
+
return self.clc_work_to_coords(work)
|
| 579 |
return self.get_current_work(loc=loc, ip=ip)
|
| 580 |
|
| 581 |
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 582 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 583 |
+
self.clc.prefetch_next_work(loc=loc, ip=ip)
|
| 584 |
|
| 585 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 586 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 587 |
+
self.clc.consumer_wait(loc=loc, ip=ip)
|
| 588 |
+
work = self.get_current_work()
|
| 589 |
+
self.clc.consumer_release(loc=loc, ip=ip)
|
| 590 |
+
return work
|
| 591 |
# Single tile scheduler - set to invalid tile_idx to indicate no more work
|
| 592 |
self._tile_idx = self.params.total_blocks
|
| 593 |
+
return self.get_current_work()
|
| 594 |
+
|
| 595 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 596 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 597 |
+
self.clc.producer_tail(loc=loc, ip=ip)
|
| 598 |
|
| 599 |
def __extract_mlir_values__(self):
|
| 600 |
values, self._values_pos = [], []
|
| 601 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 602 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 603 |
+
objs += [self.clc]
|
| 604 |
+
for obj in objs:
|
| 605 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 606 |
values += obj_values
|
| 607 |
self._values_pos.append(len(obj_values))
|
|
|
|
| 609 |
|
| 610 |
def __new_from_mlir_values__(self, values):
|
| 611 |
obj_list = []
|
| 612 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 613 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 614 |
+
objs += [self.clc]
|
| 615 |
+
for obj, n_items in zip(objs, self._values_pos):
|
| 616 |
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 617 |
values = values[n_items:]
|
| 618 |
+
return self.__class__(*obj_list, loc=self._loc)
|
| 619 |
|
| 620 |
|
| 621 |
class SingleTileLPTBwdScheduler:
|
|
|
|
| 639 |
) -> "SingleTileLPTBwdScheduler.Params":
|
| 640 |
size_l2 = 50 * 1024 * 1024
|
| 641 |
size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
|
| 642 |
+
size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
|
| 643 |
+
# size_one_dqaccum_head = 0
|
| 644 |
size_one_head = size_one_qdo_head + size_one_dqaccum_head
|
| 645 |
log2_floor = lambda n: 31 - clz(n)
|
| 646 |
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
|
|
|
|
| 674 |
self._ip = ip
|
| 675 |
|
| 676 |
@staticmethod
|
| 677 |
+
def to_underlying_arguments(
|
| 678 |
+
args: TileSchedulerArguments,
|
| 679 |
+
*,
|
| 680 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 681 |
+
loc=None,
|
| 682 |
+
ip=None,
|
| 683 |
+
) -> Params:
|
| 684 |
+
assert scheduling_mode == SchedulingMode.STATIC, (
|
| 685 |
+
f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}"
|
| 686 |
+
)
|
| 687 |
return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)
|
| 688 |
|
| 689 |
@staticmethod
|
|
|
|
| 734 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 735 |
# Single tile scheduler - set to invalid tile_idx to indicate no more work
|
| 736 |
self._tile_idx = self.params.total_blocks
|
| 737 |
+
return self.get_current_work()
|
| 738 |
|
| 739 |
def __extract_mlir_values__(self):
|
| 740 |
values, self._values_pos = [], []
|
|
|
|
| 768 |
is_split_kv: cutlass.Constexpr[bool] = False
|
| 769 |
head_swizzle: cutlass.Constexpr[bool] = False
|
| 770 |
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 771 |
+
scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
|
| 772 |
|
| 773 |
@staticmethod
|
| 774 |
@cute.jit
|
| 775 |
def create(
|
| 776 |
+
args: TileSchedulerArguments,
|
| 777 |
+
*,
|
| 778 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 779 |
+
loc=None,
|
| 780 |
+
ip=None,
|
| 781 |
) -> "SingleTileVarlenScheduler.Params":
|
| 782 |
+
assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
|
| 783 |
+
f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
|
| 784 |
+
)
|
| 785 |
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
|
| 786 |
+
# if backward, this is qdo block size
|
| 787 |
+
kv_block_size = (
|
| 788 |
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
|
| 789 |
)
|
| 790 |
+
# if backward, add dqaccum block size to calculate swizzle
|
| 791 |
+
if args.head_swizzle:
|
| 792 |
+
kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
|
| 793 |
+
max_kvblock_in_l2 = size_l2 // kv_block_size
|
| 794 |
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
|
| 795 |
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
|
| 796 |
)
|
| 797 |
assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
| 798 |
+
# TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the
|
| 799 |
+
# flattened-tile decode so cluster unpacking semantics are explicit.
|
| 800 |
+
assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, (
|
| 801 |
+
"Varlen CLC currently requires cluster_shape_mn[0] == 1"
|
| 802 |
+
)
|
| 803 |
return SingleTileVarlenScheduler.Params(
|
| 804 |
num_head=args.num_head,
|
| 805 |
num_batch=args.num_batch,
|
|
|
|
| 814 |
is_split_kv=args.is_split_kv,
|
| 815 |
head_swizzle=args.head_swizzle,
|
| 816 |
cluster_shape_m=args.cluster_shape_mn[0],
|
| 817 |
+
scheduling_mode=scheduling_mode,
|
| 818 |
)
|
| 819 |
|
| 820 |
+
def __init__(
|
| 821 |
+
self,
|
| 822 |
+
params: Params,
|
| 823 |
+
tile_idx: Int32,
|
| 824 |
+
split_idx: Int32,
|
| 825 |
+
clc: ClcState | None = None,
|
| 826 |
+
*,
|
| 827 |
+
loc=None,
|
| 828 |
+
ip=None,
|
| 829 |
+
):
|
| 830 |
self.params = params
|
| 831 |
self._tile_idx = tile_idx
|
| 832 |
self._split_idx = split_idx
|
| 833 |
self._is_first_block = True
|
| 834 |
+
self.clc = clc
|
| 835 |
self._loc = loc
|
| 836 |
self._ip = ip
|
| 837 |
|
| 838 |
@staticmethod
|
| 839 |
+
def to_underlying_arguments(
|
| 840 |
+
args: TileSchedulerArguments,
|
| 841 |
+
*,
|
| 842 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 843 |
+
loc=None,
|
| 844 |
+
ip=None,
|
| 845 |
+
) -> Params:
|
| 846 |
+
return SingleTileVarlenScheduler.Params.create(
|
| 847 |
+
args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
|
| 848 |
+
)
|
| 849 |
|
| 850 |
@staticmethod
|
| 851 |
+
@cute.jit
|
| 852 |
+
def clc_problem_shape(params: Params):
|
| 853 |
+
return ClcDynamicPersistentTileSchedulerParams(
|
| 854 |
+
problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params),
|
| 855 |
+
cluster_shape_mnk=(1, 1, 1),
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
@staticmethod
|
| 859 |
+
@cute.jit
|
| 860 |
+
def create(
|
| 861 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 862 |
+
) -> "SingleTileVarlenScheduler":
|
| 863 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 864 |
+
block_idx = cute.arch.block_idx()
|
| 865 |
+
split_idx = Int32(0)
|
| 866 |
+
if const_expr(params.is_split_kv):
|
| 867 |
+
split_idx = block_idx[1]
|
| 868 |
+
return SingleTileVarlenScheduler(
|
| 869 |
+
params,
|
| 870 |
+
block_idx[0],
|
| 871 |
+
split_idx,
|
| 872 |
+
clc,
|
| 873 |
+
loc=loc,
|
| 874 |
+
ip=ip,
|
| 875 |
+
)
|
| 876 |
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 877 |
return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 878 |
|
|
|
|
| 888 |
params.total_q
|
| 889 |
+ params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
|
| 890 |
) // params.tile_shape_mn[0]
|
| 891 |
+
# Round down to nearest multiple of cluster since odd excess is always padding.
|
| 892 |
total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
|
| 893 |
return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
|
| 894 |
|
|
|
|
| 916 |
)
|
| 917 |
|
| 918 |
@cute.jit
|
| 919 |
+
def _varlen_coord_map(self) -> WorkTileInfo:
|
| 920 |
+
"""Map self._tile_idx to (block, head, batch) via warp-level prefix sums."""
|
| 921 |
params = self.params
|
| 922 |
lane_idx = cute.arch.lane_idx()
|
| 923 |
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
|
|
|
|
| 970 |
num_n_blocks = (
|
| 971 |
num_m_blocks
|
| 972 |
* params.tile_shape_mn[0]
|
| 973 |
+
* params.cluster_shape_m
|
| 974 |
// params.qhead_per_kvhead_packgqa
|
| 975 |
// params.tile_shape_mn[1]
|
| 976 |
)
|
|
|
|
| 1015 |
split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
|
| 1016 |
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
|
| 1017 |
|
| 1018 |
+
@cute.jit
|
| 1019 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 1020 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1021 |
+
clc_work = self.clc.get_current_work()
|
| 1022 |
+
# Default to grid_dim (one past last valid flat index) so _varlen_coord_map
|
| 1023 |
+
# returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when
|
| 1024 |
+
# invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural
|
| 1025 |
+
# mismatch on self inside the runtime if.
|
| 1026 |
+
new_tile_idx = cute.arch.grid_dim()[0]
|
| 1027 |
+
new_split_idx = Int32(0)
|
| 1028 |
+
if clc_work.is_valid_tile:
|
| 1029 |
+
new_tile_idx = clc_work.tile_idx[0]
|
| 1030 |
+
if const_expr(self.params.is_split_kv):
|
| 1031 |
+
new_split_idx = clc_work.tile_idx[1]
|
| 1032 |
+
self._tile_idx = new_tile_idx
|
| 1033 |
+
self._split_idx = new_split_idx
|
| 1034 |
+
return self._varlen_coord_map()
|
| 1035 |
+
|
| 1036 |
+
@cute.jit
|
| 1037 |
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 1038 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1039 |
+
clc_work = self.clc.initial_work_tile_info()
|
| 1040 |
+
# See get_current_work for why grid_dim and local-then-assign.
|
| 1041 |
+
new_tile_idx = cute.arch.grid_dim()[0]
|
| 1042 |
+
new_split_idx = Int32(0)
|
| 1043 |
+
if clc_work.is_valid_tile:
|
| 1044 |
+
new_tile_idx = clc_work.tile_idx[0]
|
| 1045 |
+
if const_expr(self.params.is_split_kv):
|
| 1046 |
+
new_split_idx = clc_work.tile_idx[1]
|
| 1047 |
+
self._tile_idx = new_tile_idx
|
| 1048 |
+
self._split_idx = new_split_idx
|
| 1049 |
+
return self._varlen_coord_map()
|
| 1050 |
|
| 1051 |
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 1052 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1053 |
+
self.clc.prefetch_next_work(loc=loc, ip=ip)
|
| 1054 |
|
| 1055 |
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 1056 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1057 |
+
self.clc.consumer_wait(loc=loc, ip=ip)
|
| 1058 |
+
work = self.get_current_work()
|
| 1059 |
+
self.clc.consumer_release(loc=loc, ip=ip)
|
| 1060 |
+
return work
|
| 1061 |
self._is_first_block = False
|
| 1062 |
+
return self.get_current_work()
|
| 1063 |
+
|
| 1064 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 1065 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1066 |
+
self.clc.producer_tail(loc=loc, ip=ip)
|
| 1067 |
|
| 1068 |
def __extract_mlir_values__(self):
|
| 1069 |
values, self._values_pos = [], []
|
| 1070 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 1071 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1072 |
+
objs += [self.clc]
|
| 1073 |
+
for obj in objs:
|
| 1074 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 1075 |
values += obj_values
|
| 1076 |
self._values_pos.append(len(obj_values))
|
|
|
|
| 1078 |
|
| 1079 |
def __new_from_mlir_values__(self, values):
|
| 1080 |
obj_list = []
|
| 1081 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 1082 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 1083 |
+
objs += [self.clc]
|
| 1084 |
+
for obj, n_items in zip(objs, self._values_pos):
|
| 1085 |
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 1086 |
values = values[n_items:]
|
| 1087 |
+
return self.__class__(*obj_list, loc=self._loc)
|
build/torch-cuda/utils.py
CHANGED
|
@@ -3,12 +3,14 @@
|
|
| 3 |
import math
|
| 4 |
import hashlib
|
| 5 |
import inspect
|
|
|
|
| 6 |
from typing import Type, Callable, Optional, Tuple, overload
|
| 7 |
|
| 8 |
import cutlass
|
| 9 |
import cutlass.cute as cute
|
| 10 |
|
| 11 |
-
from cutlass import Float32, const_expr
|
|
|
|
| 12 |
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 13 |
from cutlass._mlir.dialects import nvvm, llvm
|
| 14 |
from cutlass.cute.runtime import from_dlpack
|
|
@@ -54,6 +56,17 @@ POLY_EX2 = {
|
|
| 54 |
),
|
| 55 |
}
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def _compute_base_hash(func: Callable) -> str:
|
| 59 |
"""Compute hash from source code or bytecode and closure values."""
|
|
@@ -123,6 +136,40 @@ def create_softcap_scoremod(softcap_val):
|
|
| 123 |
return scoremod_premask_fn
|
| 124 |
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
|
| 127 |
return (
|
| 128 |
from_dlpack(x, assumed_align=alignment)
|
|
@@ -215,6 +262,21 @@ def warp_reduce(
|
|
| 215 |
return val
|
| 216 |
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
@dsl_user_op
|
| 219 |
def fmax(
|
| 220 |
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
|
|
@@ -429,8 +491,48 @@ def shuffle_sync(
|
|
| 429 |
return val[0]
|
| 430 |
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
@dsl_user_op
|
| 433 |
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
return cutlass.Uint32(
|
| 435 |
llvm.inline_asm(
|
| 436 |
T.i32(),
|
|
@@ -438,7 +540,7 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) ->
|
|
| 438 |
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
| 439 |
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
| 440 |
],
|
| 441 |
-
"shr.
|
| 442 |
"=r,r,r",
|
| 443 |
has_side_effects=False,
|
| 444 |
is_align_stack=False,
|
|
|
|
| 3 |
import math
|
| 4 |
import hashlib
|
| 5 |
import inspect
|
| 6 |
+
import os
|
| 7 |
from typing import Type, Callable, Optional, Tuple, overload
|
| 8 |
|
| 9 |
import cutlass
|
| 10 |
import cutlass.cute as cute
|
| 11 |
|
| 12 |
+
from cutlass import Float32, Int32, const_expr
|
| 13 |
+
from cutlass.cute import FastDivmodDivisor
|
| 14 |
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 15 |
from cutlass._mlir.dialects import nvvm, llvm
|
| 16 |
from cutlass.cute.runtime import from_dlpack
|
|
|
|
| 56 |
),
|
| 57 |
}
|
| 58 |
|
| 59 |
+
_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1"
|
| 60 |
+
_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1"
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _get_use_clc_scheduler_default() -> bool:
|
| 64 |
+
return _fa_clc_enabled
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _get_disable_2cta_default() -> bool:
|
| 68 |
+
return _fa_disable_2cta_enabled
|
| 69 |
+
|
| 70 |
|
| 71 |
def _compute_base_hash(func: Callable) -> str:
|
| 72 |
"""Compute hash from source code or bytecode and closure values."""
|
|
|
|
| 136 |
return scoremod_premask_fn
|
| 137 |
|
| 138 |
|
| 139 |
+
LOG2_E = math.log2(math.e)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def compute_softmax_scale_log2(softmax_scale, score_mod):
|
| 143 |
+
"""Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used.
|
| 144 |
+
|
| 145 |
+
When score_mod is None, fold the log2(e) factor into softmax_scale_log2 and set softmax_scale
|
| 146 |
+
to None. When score_mod is present, keep softmax_scale separate so it can be applied before
|
| 147 |
+
the score_mod, and set softmax_scale_log2 to just the change-of-base constant.
|
| 148 |
+
|
| 149 |
+
Returns (softmax_scale_log2, softmax_scale).
|
| 150 |
+
"""
|
| 151 |
+
if const_expr(score_mod is None):
|
| 152 |
+
return softmax_scale * LOG2_E, None
|
| 153 |
+
else:
|
| 154 |
+
return LOG2_E, softmax_scale
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None):
|
| 158 |
+
"""Compute FastDivmodDivisor pairs for aux_tensors index computation.
|
| 159 |
+
|
| 160 |
+
Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None.
|
| 161 |
+
"""
|
| 162 |
+
if const_expr(aux_tensors is None):
|
| 163 |
+
return None
|
| 164 |
+
seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1)
|
| 165 |
+
seqlen_k = (
|
| 166 |
+
cute.size(mK.shape[0])
|
| 167 |
+
if const_expr(mPageTable is None)
|
| 168 |
+
else mK.shape[0] * mPageTable.shape[1]
|
| 169 |
+
)
|
| 170 |
+
return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
|
| 174 |
return (
|
| 175 |
from_dlpack(x, assumed_align=alignment)
|
|
|
|
| 262 |
return val
|
| 263 |
|
| 264 |
|
| 265 |
+
@dsl_user_op
|
| 266 |
+
def smid(*, loc=None, ip=None) -> Int32:
|
| 267 |
+
return Int32(
|
| 268 |
+
llvm.inline_asm(
|
| 269 |
+
T.i32(),
|
| 270 |
+
[],
|
| 271 |
+
"mov.u32 $0, %smid;",
|
| 272 |
+
"=r",
|
| 273 |
+
has_side_effects=False,
|
| 274 |
+
is_align_stack=False,
|
| 275 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 276 |
+
)
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
@dsl_user_op
|
| 281 |
def fmax(
|
| 282 |
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
|
|
|
|
| 491 |
return val[0]
|
| 492 |
|
| 493 |
|
| 494 |
+
@dsl_user_op
|
| 495 |
+
def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
|
| 496 |
+
"""
|
| 497 |
+
Left-shift val by shift bits using PTX shl.b32 (sign-agnostic).
|
| 498 |
+
|
| 499 |
+
Named ``shl_u32`` (not ``shl_b32``) because python type annotations
|
| 500 |
+
distinguish signed/unsigned.
|
| 501 |
+
|
| 502 |
+
PTX semantics (§9.7.8.8): "Shift amounts greater than the register width N
|
| 503 |
+
are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0.
|
| 504 |
+
|
| 505 |
+
This differs from C/C++ and LLVM IR, where shifting by >= the type width is
|
| 506 |
+
undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain
|
| 507 |
+
Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer
|
| 508 |
+
may treat the result as poison and eliminate dependent code. Inline PTX
|
| 509 |
+
bypasses the LLVM IR shift entirely — the instruction is emitted verbatim
|
| 510 |
+
into PTX where clamping makes it safe for all shift amounts.
|
| 511 |
+
"""
|
| 512 |
+
return cutlass.Uint32(
|
| 513 |
+
llvm.inline_asm(
|
| 514 |
+
T.i32(),
|
| 515 |
+
[
|
| 516 |
+
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
| 517 |
+
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
| 518 |
+
],
|
| 519 |
+
"shl.b32 $0, $1, $2;",
|
| 520 |
+
"=r,r,r",
|
| 521 |
+
has_side_effects=False,
|
| 522 |
+
is_align_stack=False,
|
| 523 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 524 |
+
)
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
@dsl_user_op
|
| 529 |
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
|
| 530 |
+
"""
|
| 531 |
+
Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills).
|
| 532 |
+
|
| 533 |
+
See ``shl_u32`` docstring for why inline PTX is used instead of plain
|
| 534 |
+
CuTeDSL shift operators (LLVM shift-by-type-width UB).
|
| 535 |
+
"""
|
| 536 |
return cutlass.Uint32(
|
| 537 |
llvm.inline_asm(
|
| 538 |
T.i32(),
|
|
|
|
| 540 |
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
| 541 |
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
| 542 |
],
|
| 543 |
+
"shr.u32 $0, $1, $2;",
|
| 544 |
"=r,r,r",
|
| 545 |
has_side_effects=False,
|
| 546 |
is_align_stack=False,
|